Caveat: the ``tabulate`` module is needed, so you might need to pip install it first.

.. code:: bash

    python fr_trace.py <dump dir containing trace files> -j [--selected-ranks i j k ...] [--pg-filters tp dp]
    torchfrtrace <dump dir containing trace files> -j [--selected-ranks i j k ...] [--pg-filters 0 2]
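If ``tabulate`` is missing from your environment, install it with pip:

.. code:: bash

    pip install tabulate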
A Small Example
---------------
To put this all together, we demonstrate Flight Recorder using a small program in which we induce mismatched collectives.
``rank0`` is programmed to do an additional collective.
We write out Flight Recorder dump files to the ``/tmp`` directory.
For the purpose of this example, we named the small program ``crash.py``.

.. code:: python

    import torch
    import torch.distributed as dist
    import os
    from datetime import timedelta

    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    assert world_size <= 8, "world size must be less than or equal to 8"
    os.environ["TORCH_NCCL_DEBUG_INFO_TEMP_FILE"] = "/tmp/trace_"
    os.environ["TORCH_NCCL_DUMP_ON_TIMEOUT"] = "1"
    os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "2000"
    device = torch.device(f"cuda:{local_rank}")
    print(f"{local_rank=} {world_size=} master addr: {os.environ['MASTER_ADDR']} master port: {os.environ['MASTER_PORT']} {device=}")

    # Initialize the process group with a small timeout so that jobs fail quickly
    dist.init_process_group("nccl", world_size=world_size, rank=local_rank, timeout=timedelta(seconds=1))

    a = torch.full((3, 4), float(local_rank), device=device)
    # Write some collectives to populate Flight Recorder data
    for i in range(2):
        print(f"calling allreduce on {local_rank=}")
        f = dist.all_reduce(a)

    # rank0 is doing an additional collective
    if local_rank == 0:
        print("rank0 is doing an allreduce on tensor b, but other ranks forgot")
        b = torch.full((4, 5), float(local_rank), device=device)
        f = dist.all_reduce(b)

    for i in range(2):
        print(f"calling allreduce on {local_rank=}")
        f = dist.all_reduce(a)

    torch.cuda.synchronize(device=device)
    print(f"{local_rank=} exiting")

To run this program, we use ``torchrun``.

.. code:: bash

    torchrun --nnodes=1 --nproc_per_node=2 crash.py

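Note that ``crash.py`` sets the Flight Recorder environment variables from inside the script. As a minimal alternative sketch (equivalent to the launch above, not an additional required step), the same variables can be supplied from the shell at launch time instead:

.. code:: bash

    # Same settings as in crash.py, passed via the environment at launch time
    TORCH_NCCL_DEBUG_INFO_TEMP_FILE=/tmp/trace_ \
    TORCH_NCCL_DUMP_ON_TIMEOUT=1 \
    TORCH_NCCL_TRACE_BUFFER_SIZE=2000 \
    torchrun --nnodes=1 --nproc_per_node=2 crash.py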
You'll notice two files in the ``/tmp`` directory:

.. code:: bash

    ls /tmp/trace*
    # Expected output
    # /tmp/trace_0 /tmp/trace_1

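These dump files can also be inspected directly from Python. The following is only a rough sketch: it assumes the dumps use the default pickle format (rather than the JSON output mentioned earlier) and contain an ``entries`` list, and the field names shown are illustrative and may differ between PyTorch versions.

.. code:: python

    import pickle

    # Load one Flight Recorder dump and list the recorded collectives.
    # Assumption: default pickle dump format with an "entries" list;
    # the field names below are illustrative, not guaranteed.
    with open("/tmp/trace_0", "rb") as f:
        trace = pickle.load(f)

    print(trace.keys())
    for entry in trace.get("entries", []):
        print(entry.get("record_id"), entry.get("profiling_name"), entry.get("state"))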
Finally, to analyze these two files, we use the ``torchfrtrace`` command.

.. code:: bash

    torchfrtrace --prefix "trace_" /tmp/
    # Expected output
    # Collective 3 at entry 2 error
    # ...

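The flags introduced earlier in this tutorial can also be combined with the prefix form. The invocation below is a hedged sketch that simply pairs options already shown above; it assumes they compose as documented.

.. code:: bash

    # Restrict the analysis to ranks 0 and 1 and request the tabular -j output
    # (requires the tabulate module); an illustrative combination of documented flags.
    torchfrtrace --prefix "trace_" /tmp/ -j --selected-ranks 0 1
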
Conclusion
----------
In this tutorial, we have learned about a new PyTorch diagnostic tool called Flight Recorder.