# .. warning::
#
#     ``torch.export`` and its related features are in prototype status and are subject to backwards compatibility
#     breaking changes. This tutorial provides a snapshot of ``torch.export`` usage as of PyTorch 2.3.
#
# :func:`torch.export` is the PyTorch 2.X way to export PyTorch models into
# standardized model representations, intended
# to be run in different (i.e. Python-less) environments. The official
# documentation can be found `here <https://pytorch.org/docs/main/export.html>`__.
#
# In this tutorial, you will learn how to use :func:`torch.export` to extract
# ``ExportedProgram``'s (i.e. single-graph representations) from PyTorch programs.
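
######################################################################
# The setup code and the ``MyModule`` definition used below come from earlier in
# the tutorial and are elided here. A minimal sketch that matches the
# ``(8, 100)`` example inputs below would be the following (the exact layers are
# an assumption, not the original definition):

import torch
from torch.export import export

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Assumed layer: any module taking two (8, 100) tensors works here.
        self.lin = torch.nn.Linear(100, 10)

    def forward(self, x, y):
        return torch.nn.functional.relu(self.lin(x + y))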

mod = MyModule()
exported_mod = export(mod, (torch.randn(8, 100), torch.randn(8, 100)))
print(type(exported_mod))
print(exported_mod.module()(torch.randn(8, 100), torch.randn(8, 100)))

######################################################################
# Other attributes of interest in ``ExportedProgram`` include:
#
# - ``graph_signature`` -- the inputs, outputs, parameters, buffers, etc. of the exported graph.
# - ``range_constraints`` -- constraints on the symbolic input shapes, covered later in this tutorial

print(exported_mod.graph_signature)

######################################################################
# Graph Breaks
# ------------
#
# ``torch.export`` traces a single graph, so it fails on code that tracing
# cannot support, for example:
#
# - data-dependent control flow

class Bad1(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return torch.sin(x)
        return torch.cos(x)

import traceback as tb
try:
    export(Bad1(), (torch.randn(3, 3),))
except Exception:
    tb.print_exc()

######################################################################
# - accessing tensor data with ``.data``

class Bad2(torch.nn.Module):
    def forward(self, x):
        x.data[0, 0] = 3
        return x

try:
    export(Bad2(), (torch.randn(3, 3),))
except Exception:
    tb.print_exc()

######################################################################
150
153
# - calling unsupported functions (such as many built-in functions)
151
154
152
class Bad3(torch.nn.Module):
    def forward(self, x):
        x = x + 1
        return x + id(x)

try:
    export(Bad3(), (torch.randn(3, 3),))
except Exception:
    tb.print_exc()

######################################################################
# - unsupported Python language features (e.g. throwing exceptions, match statements)

class Bad4(torch.nn.Module):
    def forward(self, x):
        try:
            x = x + 1
            raise RuntimeError("bad")
        except:
            x = x + 2
        return x

try:
    export(Bad4(), (torch.randn(3, 3),))
except Exception:
    tb.print_exc()


######################################################################
# Control Flow Ops
# ----------------
#
# ``torch.export`` can support data-dependent control flow if it is expressed
# with control flow ops such as ``cond``. The failing example above can be
# rewritten as follows:

from functorch.experimental.control_flow import cond

class Bad1Fixed(torch.nn.Module):
    def forward(self, x):
        def true_fn(x):
            return torch.sin(x)
        def false_fn(x):
            return torch.cos(x)
        return cond(x.sum() > 0, true_fn, false_fn, [x])

exported_bad1_fixed = export(Bad1Fixed(), (torch.randn(3, 3),))
print(exported_bad1_fixed.module()(torch.ones(3, 3)))
print(exported_bad1_fixed.module()(-torch.ones(3, 3)))

######################################################################
# There are limitations to ``cond`` that one should be aware of:
#
# - The predicate (i.e. ``x.sum() > 0``) must result in a boolean or a single-element tensor.
# - The operands (i.e. ``[x]``) must be tensors.
# - The branch functions must take the operands as arguments and return a single tensor
#   with the same metadata (e.g. shape, dtype); they cannot mutate inputs or globals,
#   and cannot access closure variables.
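
######################################################################
# Exported programs are specialized on the shapes of the example inputs. The
# module ``mod2`` exported below is defined earlier in the tutorial and is
# elided here; as a stand-in (an assumption, not the original definition), any
# module taking two same-shaped tensors demonstrates the specialization:

class Mod2(torch.nn.Module):
    def forward(self, x, y):
        # Shapes are traced from the (8, 100) example inputs below.
        return x + y

mod2 = Mod2()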

exported_mod2 = export(mod2, (torch.randn(8, 100), torch.randn(8, 100)))

try:
    exported_mod2.module()(torch.randn(10, 100), torch.randn(10, 100))
except Exception:
    tb.print_exc()

######################################################################
# Dynamic Shapes
# --------------
#
# By default, ``torch.export`` specializes on the example input shapes. We can
# mark dimensions as dynamic using ``torch.export.Dim`` (this import is elided
# above) and pass them through the ``dynamic_shapes`` argument:

from torch.export import Dim

inp1 = torch.randn(10, 10, 2)

class DynamicShapesExample1(torch.nn.Module):
    def forward(self, x):
        x = x[:, 2:]
        return torch.relu(x)

inp1_dim0 = Dim("inp1_dim0")
inp1_dim1 = Dim("inp1_dim1", min=4, max=18)
dynamic_shapes1 = {
    "x": {0: inp1_dim0, 1: inp1_dim1},
}

exported_dynamic_shapes_example1 = export(DynamicShapesExample1(), (inp1,), dynamic_shapes=dynamic_shapes1)

print(exported_dynamic_shapes_example1.module()(torch.randn(5, 5, 2)))

try:
    exported_dynamic_shapes_example1.module()(torch.randn(8, 1, 2))
except Exception:
    tb.print_exc()

try:
    exported_dynamic_shapes_example1.module()(torch.randn(8, 20, 2))
except Exception:
    tb.print_exc()

try:
    exported_dynamic_shapes_example1.module()(torch.randn(8, 8, 3))
except Exception:
    tb.print_exc()

######################################################################
# Trying to export with a ``Dim`` whose range does not cover the example input
# fails. The full specification is elided above; a sketch (the exact bounds are
# an assumption) that excludes the example's dimension 1 of size 10 would be:

inp1_dim1_bad = Dim("inp1_dim1_bad", min=11, max=18)  # assumed bounds; 10 is outside the range
dynamic_shapes1_bad = {
    "x": {0: inp1_dim0, 1: inp1_dim1_bad},
}

try:
    export(DynamicShapesExample1(), (inp1,), dynamic_shapes=dynamic_shapes1_bad)
except Exception:
    tb.print_exc()

######################################################################
# We can reuse the same ``Dim`` object across dimensions of different tensors
# to enforce that those dimensions are equal:

inp2 = torch.randn(4, 8)
inp3 = torch.randn(8, 2)

class DynamicShapesExample2(torch.nn.Module):
    def forward(self, x, y):
        return x @ y

inp2_dim0 = Dim("inp2_dim0")
inner_dim = Dim("inner_dim")
inp3_dim1 = Dim("inp3_dim1")

dynamic_shapes2 = {
    "x": {0: inp2_dim0, 1: inner_dim},
    "y": {0: inner_dim, 1: inp3_dim1},
}

exported_dynamic_shapes_example2 = export(DynamicShapesExample2(), (inp2, inp3), dynamic_shapes=dynamic_shapes2)

print(exported_dynamic_shapes_example2.module()(torch.randn(2, 16), torch.randn(16, 4)))

try:
    exported_dynamic_shapes_example2.module()(torch.randn(4, 8), torch.randn(4, 2))
except Exception:
    tb.print_exc()

######################################################################
# ``torch.export`` also detects when the ``dynamic_shapes`` we provide conflict
# with what tracing discovers about the program, and suggests fixes:

inp4 = torch.randn(8, 16)
inp5 = torch.randn(16, 32)

class DynamicShapesExample3(torch.nn.Module):
    def forward(self, x, y):
        if x.shape[0] <= 16:
            return x @ y[:, :16]
        return y

dynamic_shapes3 = {
    "x": {i: Dim(f"inp4_dim{i}") for i in range(inp4.dim())},
    "y": {i: Dim(f"inp5_dim{i}") for i in range(inp5.dim())},
}

try:
    export(DynamicShapesExample3(), (inp4, inp5), dynamic_shapes=dynamic_shapes3)
except Exception:
    tb.print_exc()

######################################################################
# The error message from the export above suggests fixes to the dynamic shapes
# specification. The original helper is elided here; a sketch applying fixes of
# the suggested form (the exact bounds are assumptions) would be:

def suggested_fixes():
    # Share the inner dimension between x and y, and cap x's dim 0, since the
    # traced branch assumes x.shape[0] <= 16.
    inp4_dim1 = Dim("shared_dim")
    inp4_dim0 = Dim("inp4_dim0", max=16)
    inp5_dim1 = Dim("inp5_dim1", min=17)  # assumed lower bound from y[:, :16]
    return {
        "x": {0: inp4_dim0, 1: inp4_dim1},
        "y": {0: inp4_dim1, 1: inp5_dim1},
    }

dynamic_shapes3_fixed = suggested_fixes()
exported_dynamic_shapes_example3 = export(DynamicShapesExample3(), (inp4, inp5), dynamic_shapes=dynamic_shapes3_fixed)
print(exported_dynamic_shapes_example3.module()(torch.randn(4, 32), torch.randn(32, 64)))

######################################################################
# Note that in the example above, because we constrained the value of ``x.shape[0]``
# to be at most 16, the exported program only reflects the traced branch where
# ``x.shape[0] <= 16``. We can inspect how export arrives at such constraints by
# turning on dynamic-shape logging:

import logging
torch._logging.set_logs(dynamic=logging.INFO, dynamo=logging.INFO)
exported_dynamic_shapes_example3 = export(DynamicShapesExample3(), (inp4, inp5), dynamic_shapes=dynamic_shapes3_fixed)

# reset to previous values
torch._logging.set_logs(dynamic=logging.WARNING, dynamo=logging.WARNING)

######################################################################
# We can view an ``ExportedProgram``'s symbolic shape ranges using the
# ``range_constraints`` field.

print(exported_dynamic_shapes_example3.range_constraints)

######################################################################
# Custom Ops
# ----------
#
# ``torch.export`` can export PyTorch programs with custom operators. To do so:
#
# - Define the custom op using ``torch.library`` (`reference <https://pytorch.org/docs/main/library.html>`__)
#   as with any other custom op

from torch.library import Library, impl, impl_abstract

m = Library("my_custom_library", "DEF")
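
# The op's schema definition and its actual implementation are elided above; a
# minimal sketch (the schema and the body are assumptions, not the tutorial's
# exact code) would be:

m.define("custom_op(Tensor input) -> Tensor")

@impl(m, "custom_op", "CompositeExplicitAutograd")
def custom_op(x):
    # Assumed implementation for illustration only.
    return x.relu()

######################################################################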
# - Define a ``"Meta"`` implementation of the custom op that returns an empty
#   tensor with the same shape as the expected output

@impl_abstract("my_custom_library::custom_op")
def custom_op_meta(x):
    return torch.empty_like(x)

######################################################################
# - Call the custom op from the code you want to export using ``torch.ops``

class CustomOpExample(torch.nn.Module):
    def forward(self, x):
        x = torch.sin(x)
        x = torch.ops.my_custom_library.custom_op(x)
        x = torch.cos(x)
        return x

######################################################################
# - Export the code as before

exported_custom_op_example = export(CustomOpExample(), (torch.randn(3, 3),))
exported_custom_op_example.graph_module.print_readable()
print(exported_custom_op_example.module()(torch.randn(3, 3)))

######################################################################
# Note in the above outputs that the custom op is included in the exported graph.

######################################################################
# ExportDB
# --------
#
# ``torch.export`` documents supported and unsupported PyTorch constructs with
# runnable examples in a database called ExportDB.
# ExportDB is not exhaustive, but is intended to cover all use cases found in typical PyTorch code. Feel free to reach
# out if there is an important Python/PyTorch feature that should be added to ExportDB or supported by ``torch.export``.

######################################################################
# Running the Exported Program
# ----------------------------
#
# As ``torch.export`` is only a graph capturing mechanism, calling the artifact
# produced by ``torch.export`` eagerly will be equivalent to running the eager
# module. To optimize the execution of the Exported Program, we can pass this
# exported artifact to backends such as Inductor through ``torch.compile``,
# `AOTInductor <https://pytorch.org/docs/main/torch.compiler_aot_inductor.html>`__,
# or `TensorRT <https://pytorch.org/TensorRT/dynamo/dynamo_export.html>`__.

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 3)

    def forward(self, x):
        x = self.linear(x)
        return x

ep = torch.export.export(M().to(device="cuda"), (torch.ones(2, 3, device="cuda"),))
inp = torch.randn(2, 3, device="cuda")

# Run it eagerly
res = ep.module()(inp)
print(res)

# Run it with torch.compile
res = torch.compile(ep.module(), backend="inductor")(inp)
print(res)

# Compile the exported program to a .so using AOTInductor
so_path = torch._export.aot_compile(ep.module(), (inp,))
# Load and run the .so in Python.
# To load and run it in a C++ environment, see:
# https://pytorch.org/docs/main/torch.compiler_aot_inductor.html
res = torch._export.aot_load(so_path, device="cuda")(inp)
print(res)

######################################################################
# Conclusion
# ----------