Commit 2794049

[doc] add small example to flight recorder tutorial
Summary: Add a small example that demonstrates Flight Recorder end-to-end.

Test Plan: Test on GitHub to make sure that the tutorial renders correctly.
1 parent 45d33e1 commit 2794049

File tree

1 file changed: +72 -0 lines changed

prototype_source/flight_recorder_tutorial.rst

Lines changed: 72 additions & 0 deletions
@@ -202,6 +202,78 @@ Caveat: tabulate module is needed, so you might need pip install it first.
python fr_trace.py <dump dir containing trace files> -j [--selected-ranks i j k ...] [--pg-filters tp dp]
torchfrtrace <dump dir containing trace files> -j [--selected-ranks i j k ...] [--pg-filters 0 2]

A Small Example
---------------
To put this all together, we demonstrate Flight Recorder using a small program in which we induce mismatched collectives.
``rank0`` is programmed to do an additional collective.
We write out the Flight Recorder dump files to the ``/tmp`` directory.
For the purposes of this example, we name the small program ``crash.py``.

.. code:: python

    import torch
    import torch.distributed as dist
    import os
    from datetime import timedelta

    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    assert world_size <= 8, "world size must be less than or equal to 8"
    # Prefix for the dump files; the rank is appended (for example, /tmp/trace_0)
    os.environ["TORCH_NCCL_DEBUG_INFO_TEMP_FILE"] = "/tmp/trace_"
    # Write Flight Recorder dumps when a collective times out
    os.environ["TORCH_NCCL_DUMP_ON_TIMEOUT"] = "1"
    # Number of entries to keep in the Flight Recorder buffer
    os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "2000"
    device = torch.device(f"cuda:{local_rank}")
    print(f"{local_rank=} {world_size=} master addr: {os.environ['MASTER_ADDR']} master port: {os.environ['MASTER_PORT']} {device=}")

    # Initialize the process group with a small timeout so that jobs fail quickly
    dist.init_process_group("nccl", world_size=world_size, rank=local_rank, timeout=timedelta(seconds=1))

    a = torch.full((3, 4), float(local_rank), device=device)
    # Write some collectives to populate Flight Recorder data
    for i in range(2):
        print(f"calling allreduce on {local_rank=}")
        f = dist.all_reduce(a)

    # rank0 is doing an additional collective
    if local_rank == 0:
        print("rank0 is doing an allreduce on tensor b, but other ranks forgot")
        b = torch.full((4, 5), float(local_rank), device=device)
        f = dist.all_reduce(b)

    for i in range(2):
        print(f"calling allreduce on {local_rank=}")
        f = dist.all_reduce(a)

    torch.cuda.synchronize(device=device)
    print(f"{local_rank=} exiting")

To run this program, we use ``torchrun``.

.. code:: bash

    torchrun --nnodes=1 --nproc_per_node=2 crash.py
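
The script sets the Flight Recorder environment variables itself, before it calls ``init_process_group``. If you prefer to keep that configuration out of the code, the same variables can be supplied from the launching shell instead; a minimal sketch:

.. code:: bash

    TORCH_NCCL_DEBUG_INFO_TEMP_FILE=/tmp/trace_ \
    TORCH_NCCL_DUMP_ON_TIMEOUT=1 \
    TORCH_NCCL_TRACE_BUFFER_SIZE=2000 \
    torchrun --nnodes=1 --nproc_per_node=2 crash.py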

You'll notice two files in the ``/tmp`` directory:

.. code:: bash

    ls /tmp/trace*
    # Expected output
    # /tmp/trace_0 /tmp/trace_1

Finally, to analyze these two files, we use the ``torchfrtrace`` command:

.. code:: bash

    torchfrtrace --prefix "trace_" /tmp/
    # Expected output
    # Collective 3 at entry 2 error
    # ...
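
If you want to inspect a dump directly rather than go through ``torchfrtrace``, the files can also be loaded from Python. The sketch below assumes the dump is a pickled Python dictionary containing an ``entries`` list, which is the internal format the analyzer parses; the key names (``profiling_name``, ``state``) are implementation details and may change between releases, so treat this purely as an illustration. The file name ``inspect_trace.py`` is just a placeholder.

.. code:: python

    # inspect_trace.py: a minimal sketch for peeking inside a Flight Recorder dump.
    # Usage: python inspect_trace.py /tmp/trace_0
    import pickle
    import sys

    with open(sys.argv[1], "rb") as f:
        dump = pickle.load(f)

    # The dump is assumed to be a dictionary; print its top-level keys first.
    print("top-level keys:", sorted(dump.keys()))

    # Each recorded collective is assumed to appear as an item in "entries".
    for i, entry in enumerate(dump.get("entries", [])):
        name = entry.get("profiling_name", "<unknown op>")
        state = entry.get("state", "<unknown state>")
        print(f"entry {i}: {name} -> {state}")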

Conclusion
----------
In this tutorial, we have learned about a new PyTorch diagnostic tool called Flight Recorder.
