inclusionAI / debug-distributed
Install for your project team
Run this command in your project directory to install the skill for your entire team:
mkdir -p .claude/skills/debug-distributed && curl -L -o skill.zip "https://fastmcp.me/Skills/Download/3466" && unzip -o skill.zip -d .claude/skills/debug-distributed && rm skill.zip
Project Skills
This skill will be saved in .claude/skills/debug-distributed/ and checked into git. All team members will have access to it automatically.
Important: Please verify the skill by reviewing its instructions before using it.
Guide for debugging distributed training issues in AReaL. Use when user encounters hangs, wrong results, OOM, or communication errors.
Skill Content
---
name: debug-distributed
description: Guide for debugging distributed training issues in AReaL. Use when user encounters hangs, wrong results, OOM, or communication errors.
---
# Debug Distributed Training
Debugging guide for distributed training issues in AReaL (FSDP2, Tensor Parallelism, Context Parallelism, Expert Parallelism).
## When to Use
This skill is triggered when:
- Training hangs or deadlocks
- Results differ across ranks or are numerically wrong
- OOM errors in distributed settings
- NCCL/communication errors or device mesh issues
## Debugging Principles
### Minimal Reproduction
**Always follow the minimal demo principle**: reproduce the failure with as little code as
possible. A small standalone script isolates the failing operation far faster than the full training loop.
```python
# Bad: debug inside the full training loop
# Good: create a minimal script (run with: torchrun --nproc_per_node=2 repro.py)
import os
import torch
import torch.distributed as dist

dist.init_process_group("nccl")
rank = dist.get_rank()
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))  # one GPU per rank
# Reproduce the exact operation that fails
tensor = torch.ones(10).cuda()
dist.all_reduce(tensor)  # <-- Isolate the failing op
print(f"Rank {rank}: {tensor}")
dist.destroy_process_group()
```
**Reduction strategy:**
1. Remove unrelated model components
2. Use small tensor sizes
3. Reduce world_size to the minimum (e.g., 2 GPUs)
4. Remove torch.compile if possible
5. Disable activation checkpointing
## Step-by-Step Debugging Guide
### 1. Hang Debugging (Deadlocks, Synchronization)
**Environment Variables for Debugging**:
```bash
# Full debug logging
export TORCH_DISTRIBUTED_DEBUG=DETAIL
export NCCL_DEBUG=INFO
export NCCL_DEBUG_SUBSYS=ALL
# torch.compile debugging
export TORCH_LOGS="+dynamo,recompiles"
export TORCHDYNAMO_VERBOSE=1
```
**Dump Call Stack with py-spy** (for hung processes):
```bash
# Find process IDs
ps aux | grep python
# Dump call stack of specific rank
py-spy dump --pid <PID>
# Record flame graph for performance analysis
py-spy record -o profile.svg --pid <PID> --duration 30
```
**Common Causes**:
1. **Mismatched Collectives**: One rank calls `all_reduce`, another doesn't (see the sketch below).
2. **Wrong Process Group**: Using the wrong group for a collective.
3. **Tensor Shape Mismatch**: Different shapes across ranks.
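A minimal sketch of cause 1, assuming a two-rank job launched with torchrun (a hypothetical repro, not AReaL code):
```python
# Hypothetical two-rank repro of a mismatched collective: rank 0 joins an
# all_reduce that rank 1 never enters, so rank 0's collective can never
# complete and the job hangs until the NCCL watchdog timeout fires.
import os
import torch
import torch.distributed as dist

dist.init_process_group("nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
tensor = torch.ones(1).cuda()
if dist.get_rank() == 0:
    dist.all_reduce(tensor)  # rank 1 never matches this call
torch.cuda.synchronize()     # rank 0 blocks here; rank 1 proceeds
```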
**Debug Steps**:
```python
# Verify group membership
mesh = parallel_dims.get_mesh("dp_shard_cp")
group = mesh.get_group()
print(f"Rank {dist.get_rank()}: group size = {dist.get_world_size(group)}")
# Print shapes on all ranks
print(f"Rank {dist.get_rank()}: tensor.shape = {tensor.shape}")
dist.barrier()
```
**Timeout Adjustment** (for debugging only):
```python
from areal.engine.core.distributed import patch_dist_group_timeout
from datetime import timedelta
patch_dist_group_timeout(timedelta(minutes=30))
```
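Outside AReaL, the stock PyTorch equivalent is to pass a `timeout` when initializing the process group (a sketch using standard `torch.distributed`):
```python
from datetime import timedelta
import torch.distributed as dist

# Raise the collective timeout so py-spy can be attached before the
# NCCL watchdog aborts the hung job.
dist.init_process_group("nccl", timeout=timedelta(minutes=30))
```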
### 2. Wrong Results (Gradient, Reduction Issues)
**Check DTensor Placements**:
```python
from torch.distributed.tensor import DTensor

for name, param in model.named_parameters():
    if isinstance(param, DTensor):
        print(f"{name}: placements={param.placements}, mesh={param.device_mesh}")
```
**Verify Gradient Reduction**:
```python
for name, param in model.named_parameters():
if param.grad is not None:
print(f"Rank {dist.get_rank()}: {name} grad_sum = {param.grad.sum().item()}")
```
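If reductions look correct but results still diverge, compare the distributed run against a single-process reference on a fixed input. A sketch, assuming the unsharded model fits on one GPU; `model` and the file name are illustrative:
```python
import torch
import torch.distributed as dist
from torch.distributed.tensor import DTensor

# Feed a deterministic batch, materialize the full output from its
# shards, and diff it offline against the same batch on a single GPU.
torch.manual_seed(0)
batch = torch.randint(0, 1000, (2, 16), device="cuda")
out = model(batch)
full = out.full_tensor() if isinstance(out, DTensor) else out
if dist.get_rank() == 0:
    torch.save(full.cpu(), "dist_out.pt")  # later: torch.allclose(...)
```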
### 3. OOM Issues (Memory, Sharding)
**Check Memory Usage**:
```python
print(f"Rank {dist.get_rank()}: "
f"allocated={torch.cuda.memory_allocated()/1e9:.2f}GB, "
f"reserved={torch.cuda.memory_reserved()/1e9:.2f}GB")
```
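For a deeper view, recent PyTorch releases can record an allocation history for the memory visualizer (note these are private `torch.cuda.memory` APIs and may change):
```python
import torch
import torch.distributed as dist

# Record allocator events, run the step that OOMs, then dump a snapshot
# viewable at https://pytorch.org/memory_viz (private API, recent PyTorch).
torch.cuda.memory._record_memory_history(max_entries=100_000)
# ... run the failing training step here ...
torch.cuda.memory._dump_snapshot(f"mem_rank{dist.get_rank()}.pickle")
```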
**Check FSDP Coverage**:
```python
from torch.distributed.tensor import DTensor

for name, param in model.named_parameters():
    is_dtensor = isinstance(param, DTensor)
    print(f"{name}: is_dtensor={is_dtensor}, shape={param.shape}")
```
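FSDP2 shards parameters along dim 0, so the local shard's first dimension should be roughly `full_dim0 / shard_degree`. A sketch using the standard `DTensor.to_local()` accessor:
```python
from torch.distributed.tensor import DTensor

# Compare each parameter's global shape with its local shard.
for name, param in model.named_parameters():
    if isinstance(param, DTensor):
        local = param.to_local()
        print(f"{name}: full={tuple(param.shape)}, local={tuple(local.shape)}")
```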
### 4. Communication Errors
| Error | Cause | Solution |
| ------------------------- | -------------------- | ---------------------------------- |
| `NCCL WARN Cuda failure` | GPU communication | Check NCCL version, GPU topology |
| `RuntimeError: Timed out` | Rank synchronization | Increase timeout, check code paths |
| `Invalid device mesh` | Mesh configuration | Verify world_size = dp * tp * cp |
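For `Invalid device mesh`, a quick sanity check is that the configured parallel degrees multiply out to the launched world size (the degrees below are hypothetical placeholders):
```python
import torch.distributed as dist

dp, tp, cp = 4, 2, 1  # hypothetical degrees from your config
assert dp * tp * cp == dist.get_world_size(), (
    f"mesh {dp}x{tp}x{cp} != world_size {dist.get_world_size()}"
)
```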
## Debugging Tools
### Environment Variables Reference
| Variable | Purpose |
| --------------------------------- | -------------------------------------- |
| `TORCH_DISTRIBUTED_DEBUG=DETAIL` | Detailed distributed logging |
| `NCCL_DEBUG=INFO` | NCCL communication logging |
| `NCCL_DEBUG_SUBSYS=ALL` | All NCCL subsystems |
| `TORCH_LOGS="+dynamo,recompiles"` | torch.compile logging |
| `TORCHDYNAMO_VERBOSE=1` | Dynamo verbose output |
| `CUDA_LAUNCH_BLOCKING=1` | Synchronous CUDA (slow, for debugging) |
### py-spy for Call Stack Analysis
```bash
# Install
pip install py-spy
# Dump call stack of hung process
py-spy dump --pid <PID>
# Dump all Python processes
pgrep -f python | xargs -I {} py-spy dump --pid {}
# Record flame graph
py-spy record -o profile.svg --pid <PID> --duration 30
```
### Rank-Conditional Printing
```python
def print_all_ranks(msg):
for r in range(dist.get_world_size()):
if dist.get_rank() == r:
print(f"[Rank {r}] {msg}")
dist.barrier()
```
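Every rank must call the helper on the same code path, or the barrier inside it deadlocks; e.g. (with `tensor` defined on all ranks):
```python
# Safe only where all ranks execute this line.
print_all_ranks(f"tensor.shape = {tuple(tensor.shape)}")
```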
### Check Device Mesh
```python
def debug_mesh(parallel_dims):
mesh = parallel_dims.world_mesh
for dim_name in mesh.mesh_dim_names:
submesh = parallel_dims.get_mesh(dim_name)
if submesh:
print(f"Rank {dist.get_rank()}: {dim_name} size={submesh.size()}")
```
### Validate Tensor Consistency
```python
def check_tensor_consistency(tensor, name, group=None):
local_sum = tensor.sum().item()
tensor_sums = [None] * dist.get_world_size(group)
dist.all_gather_object(tensor_sums, local_sum, group=group)
if dist.get_rank() == 0 and len(set(tensor_sums)) > 1:
print(f"WARNING: {name} inconsistent: {tensor_sums}")
```
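A usage sketch: buffers that should be replicated across ranks (e.g., normalization stats) are a good first target; `model` is illustrative:
```python
# Flag any buffer whose contents have drifted apart across ranks.
for name, buf in model.named_buffers():
    check_tensor_consistency(buf, name)
```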
## Key Files Reference
| Component | File |
| --------------- | ------------------------------------------------------------- |
| Parallel Dims | `areal/experimental/models/archon/parallel_dims.py` |
| Expert Parallel | `areal/experimental/models/archon/expert_parallel.py` |
| Ulysses (CP) | `areal/experimental/models/archon/ulysses.py` |
| FSDP/TP Apply | `areal/experimental/models/archon/qwen2/infra/parallelize.py` |
______________________________________________________________________
<!--
================================================================================
MAINTAINER GUIDE
================================================================================
Location: .claude/skills/debug-distributed/SKILL.md
Invocation: /debug-distributed
## Purpose
Debugging guide for distributed training issues.
Covers FSDP2, Tensor Parallelism, Context Parallelism, and Expert Parallelism.
## How to Update
### When Adding New Parallelism Features
1. Add section for the parallelism type
2. Document common error patterns and debugging snippets
### When PyTorch Distributed APIs Change
1. Update DTensor/DeviceMesh examples
2. Update environment variable references
### When New Error Patterns Emerge
1. Add to the Communication Errors table (section 4)
2. Reference relevant source files
================================================================================
-->