Commit 846b766

[doc] add small example to flight recorder tutorial
Summary: Add a small example that demonstrates Flight Recorder end-to-end.
Test Plan: Test on GitHub to make sure that the tutorial renders correctly.
1 parent 45d33e1 commit 846b766

File tree: 1 file changed (+71, -0 lines)


prototype_source/flight_recorder_tutorial.rst

Lines changed: 71 additions & 0 deletions
@@ -202,6 +202,77 @@ Caveat: tabulate module is needed, so you might need pip install it first.
python fr_trace.py <dump dir containing trace files> -j [--selected-ranks i j k ...] [--pg-filters tp dp]
torchfrtrace <dump dir containing trace files> -j [--selected-ranks i j k ...] [--pg-filters 0 2]

A small example
---------------
To put this all together, we demonstrate Flight Recorder using a small program in which we induce mismatched collectives.
`rank0` is programmed to perform one additional collective.
We write the Flight Recorder dump files to the `/tmp` directory.
For the purpose of this example, we named the small program `crash.py`.

.. code:: python

    import torch
    import torch.distributed as dist
    import os
    from datetime import timedelta

    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    assert world_size <= 8, "world size must be less than or equal to 8"
    # Configure Flight Recorder: the dump file prefix, dumping on collective timeout,
    # and the number of entries kept in the trace buffer
    os.environ["TORCH_NCCL_DEBUG_INFO_TEMP_FILE"] = "/tmp/trace_"
    os.environ["TORCH_NCCL_DUMP_ON_TIMEOUT"] = "1"
    os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "2000"
    device = torch.device(f"cuda:{local_rank}")
    print(f"{local_rank=} {world_size=} master addr: {os.environ['MASTER_ADDR']} master port: {os.environ['MASTER_PORT']} {device=}")

    # Initialize the process group with a small timeout so that jobs fail quickly
    dist.init_process_group("nccl", world_size=world_size, rank=local_rank, timeout=timedelta(seconds=1))

    a = torch.full((3, 4), float(local_rank), device=device)
    # Write some collectives to populate Flight Recorder data
    for i in range(2):
        print(f"calling allreduce on {local_rank=}")
        f = dist.all_reduce(a)

    # rank0 is doing an additional collective
    if local_rank == 0:
        print("rank0 is doing an allreduce on tensor b, but other ranks forgot")
        b = torch.full((4, 5), float(local_rank), device=device)
        f = dist.all_reduce(b)

    for i in range(2):
        print(f"calling allreduce on {local_rank=}")
        f = dist.all_reduce(a)

    torch.cuda.synchronize(device=device)
    print(f"{local_rank=} exiting")

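As a side note, the three `TORCH_NCCL_*` settings could equivalently be exported in the shell before launching, instead of being set inside the script. The sketch below simply mirrors the values used in `crash.py`; it is not part of the original example.

.. code:: bash

    # Equivalent shell exports for the Flight Recorder settings used in crash.py
    export TORCH_NCCL_DEBUG_INFO_TEMP_FILE=/tmp/trace_
    export TORCH_NCCL_DUMP_ON_TIMEOUT=1
    export TORCH_NCCL_TRACE_BUFFER_SIZE=2000
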
To run this program, we use `torchrun`.

.. code:: bash

    torchrun --nnodes=1 --nproc_per_node=2 crash.py

Because the collectives are mismatched, the job times out after the short timeout we configured, and each rank dumps its Flight Recorder buffer. You'll notice two files in the `/tmp` directory:

.. code:: bash

    ls /tmp/trace*
    # Expected output
    # /tmp/trace_0 /tmp/trace_1

Finally, to analyze these two files, we use the `torchfrtrace` command.

.. code:: bash

    torchfrtrace --prefix "trace_" /tmp/
    # Expected output
    # Collective 3 at entry 2 error
    # ...

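The flags shown earlier in this tutorial can also be applied to this dump directory. As a sketch (the flag combination is assumed here, and the output is not shown), the tabulated `-j` view can be restricted to particular ranks:

.. code:: bash

    # Assumed combination of the flags documented above:
    # -j prints the collected trace entries; --selected-ranks limits output to ranks 0 and 1
    torchfrtrace --prefix "trace_" /tmp/ -j --selected-ranks 0 1
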
Conclusion
----------
In this tutorial, we have learned about a new PyTorch diagnostic tool called Flight Recorder.
