Caveat: the ``tabulate`` module is needed, so you might need to ``pip install`` it first.
python fr_trace.py <dump dir containing trace files> -j [--selected-ranks i j k ...] [--pg-filters tp dp]
torchfrtrace <dump dir containing trace files> -j [--selected-ranks i j k ...] [--pg-filters 0 2]
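
As noted above, the tabular output of these commands relies on the ``tabulate`` module. If it is not already installed, you can add it with:

.. code:: bash

   pip install tabulate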

A Small Example
---------------
To put this all together, we demonstrate the Flight Recorder using a small program in which we induce mismatched collectives.
``rank0`` is programmed to do an additional collective.
The Flight Recorder dump files are written to the ``/tmp`` directory.
For the purpose of this example, we named the small program ``crash.py``.

.. code:: python

   import os
   from datetime import timedelta

   import torch
   import torch.distributed as dist

   local_rank = int(os.environ["LOCAL_RANK"])
   world_size = int(os.environ["WORLD_SIZE"])
   assert world_size <= 8, "world size must be less than or equal to 8"
   os.environ["TORCH_NCCL_DEBUG_INFO_TEMP_FILE"] = "/tmp/trace_"
   os.environ["TORCH_NCCL_DUMP_ON_TIMEOUT"] = "1"
   os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "2000"
   device = torch.device(f"cuda:{local_rank}")
   print(f"{local_rank=} {world_size=} master addr: {os.environ['MASTER_ADDR']} master port: {os.environ['MASTER_PORT']} {device=}")

   # Initialize the process group with a small timeout so that the job fails quickly.
   dist.init_process_group("nccl", world_size=world_size, rank=local_rank, timeout=timedelta(seconds=1))

   a = torch.full((3, 4), float(local_rank), device=device)
   # Issue a few matching collectives to populate the Flight Recorder buffer.
   for i in range(2):
       print(f"calling allreduce on {local_rank=}")
       dist.all_reduce(a)

   # rank0 issues an additional collective that the other ranks do not.
   if local_rank == 0:
       print("rank0 is doing an allreduce on tensor b, but other ranks forgot")
       b = torch.full((4, 5), float(local_rank), device=device)
       dist.all_reduce(b)

   for i in range(2):
       print(f"calling allreduce on {local_rank=}")
       dist.all_reduce(a)

   torch.cuda.synchronize(device=device)
   print(f"{local_rank=} exiting")
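
Note that the three ``TORCH_NCCL_*`` variables could equally be exported in the shell before launching, since ``torchrun`` propagates the environment to each worker. A sketch with the same values as above:

.. code:: bash

   export TORCH_NCCL_DEBUG_INFO_TEMP_FILE=/tmp/trace_
   export TORCH_NCCL_DUMP_ON_TIMEOUT=1
   export TORCH_NCCL_TRACE_BUFFER_SIZE=2000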

To run this program, we use ``torchrun``:

.. code:: bash

   torchrun --nnodes=1 --nproc_per_node=2 crash.py

Because of the mismatched collective, the job times out, and with ``TORCH_NCCL_DUMP_ON_TIMEOUT=1`` each rank dumps its Flight Recorder buffer. You should then see two files in the ``/tmp`` directory:

.. code:: bash

   ls /tmp/trace*
   # Expected output
   # /tmp/trace_0 /tmp/trace_1

Finally, to analyze these two files, we use the ``torchfrtrace`` command:

.. code:: bash

   torchfrtrace --prefix "trace_" /tmp/
   # Expected output
   # Collective 3 at entry 2 error
   # ...
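
The same dumps can also be inspected entry by entry with the ``-j`` flag shown earlier, optionally restricted to particular ranks with ``--selected-ranks``. A sketch, assuming these flags combine with ``--prefix`` as the usage above suggests:

.. code:: bash

   # Print the recorded entries as seen by rank 0 only.
   torchfrtrace --prefix "trace_" /tmp/ -j --selected-ranks 0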

Conclusion
----------
In this tutorial, we have learned about a new PyTorch diagnostic tool called Flight Recorder.