@@ -152,7 +152,8 @@ class Inp:
 RETRACEABILITY_SUFFIX = "_retraceability"
 SERDES_SUFFIX = "_serdes"
 PREDISPATCH_SUFFIX = "_pre_dispatch"
-TRAINING_IR_DECOMP_SUFFIX = "_training_ir_to_decomp"
+TRAINING_IR_DECOMP_STRICT_SUFFIX = "_training_ir_to_decomp"
+TRAINING_IR_DECOMP_NON_STRICT_SUFFIX = "_training_ir_to_decomp_non_strict"


 def is_non_strict_test(test_name):
@@ -167,6 +168,12 @@ def is_serdes_test(test_name):
     return test_name.endswith(SERDES_SUFFIX)


+def is_training_ir_test(test_name):
+    return test_name.endswith(TRAINING_IR_DECOMP_STRICT_SUFFIX) or test_name.endswith(
+        TRAINING_IR_DECOMP_NON_STRICT_SUFFIX
+    )
+
+
 @unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't support")
 class TestDynamismExpression(TestCase):
     def test_export_inline_constraints(self):
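
The suffix constants and the new is_training_ir_test helper drive this file's test-variant machinery: each base test is re-generated once per export mode, with the mode's suffix appended to the test name, and helpers classify a variant from its name alone. A minimal sketch of how the new helper behaves; the test names below are hypothetical examples of generated variants, not taken from the patch:

    # Sketch only: exercises the helper added in the hunk above on
    # hypothetical generated test names.
    TRAINING_IR_DECOMP_STRICT_SUFFIX = "_training_ir_to_decomp"
    TRAINING_IR_DECOMP_NON_STRICT_SUFFIX = "_training_ir_to_decomp_non_strict"

    def is_training_ir_test(test_name):
        return test_name.endswith(TRAINING_IR_DECOMP_STRICT_SUFFIX) or test_name.endswith(
            TRAINING_IR_DECOMP_NON_STRICT_SUFFIX
        )

    assert is_training_ir_test("test_module_training_ir_to_decomp")
    assert is_training_ir_test("test_module_training_ir_to_decomp_non_strict")
    assert not is_training_ir_test("test_module_serdes")

Note that the new non-strict suffix itself ends in "_non_strict", so, assuming is_non_strict_test matches a trailing "_non_strict" the way its sibling helpers match their suffixes, a training-IR non-strict variant satisfies both predicates; the last hunk of this patch relies on exactly that overlap.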
@@ -309,6 +316,7 @@ def forward(self, x, y):
         )

     # Errors because fake mode is not detected from non-tensor inputs
+    @testing.expectedFailureTrainingIRToRunDecompNonStrict
     @testing.expectedFailureTrainingIRToRunDecomp
     def test_no_tensor_computation_3(self):
         class Module(torch.nn.Module):
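
The new expectedFailureTrainingIRToRunDecompNonStrict marker used here presumably follows the same pattern as the existing expectedFailure* decorators: tag the base test so that only the matching generated variant gets wrapped in unittest.expectedFailure. A rough sketch of that general pattern; the attribute name and wiring are assumptions for illustration, not the actual test-infra code:

    import unittest

    def expectedFailureTrainingIRToRunDecompNonStrict(fn):
        # Hypothetical attribute name, for illustration only.
        fn._expected_failure_training_ir_non_strict = True
        return fn

    def make_variant(fn, suffix):
        # When generating the "_training_ir_to_decomp_non_strict" variant,
        # wrap tagged tests so their failure is expected rather than fatal.
        if suffix == "_training_ir_to_decomp_non_strict" and getattr(
            fn, "_expected_failure_training_ir_non_strict", False
        ):
            return unittest.expectedFailure(fn)
        return fn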
@@ -346,8 +354,6 @@ def forward(self, x, y):
     return (x_0,)""",
         )

-    # Errors because non-strict is not supported in training IR (T193692164)
-    @testing.expectedFailureTrainingIRToRunDecomp
     def test_external_call_non_strict_real_tensor(self):
         class ExternalMethod:
             def add(self, x):
@@ -418,8 +424,6 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         args = (torch.randn(15, 3, 256, 256), torch.ones(15, 32, 256, 256))
         self.assertEqual(gm(*args), m(*args))

-    # Errors because non-strict is not supported in training IR (T193692164)
-    @testing.expectedFailureTrainingIRToRunDecomp
     def test_basic_non_strict_real_tensor(self):
         class Basic(torch.nn.Module):
             def __init__(self):
@@ -434,8 +438,6 @@ def forward(self, x, y):
         ep = export(f, args, strict=False)
         self.assertEqual(ep.module()(*args), f(*args))

-    # Errors because non-strict is not supported in training IR (T193692164)
-    @testing.expectedFailureTrainingIRToRunDecomp
     def test_basic_non_strict_fake_tensor(self):
         class Basic(torch.nn.Module):
             def __init__(self):
@@ -690,8 +692,6 @@ def forward(self, x):
             torch.allclose(ep.module()(torch.zeros(2, 3)), torch.ones(2, 3) * 21)
         )

-    # Predispatch has different expected results
-    @testing.expectedFailureTrainingIRToRunDecomp  # T193700910
     def test_torch_fn(self):
         class M1(torch.nn.Module):
             def __init__(self):
@@ -823,6 +823,7 @@ def forward(self, p_linear_weight, p_linear_bias, x):
     @testing.expectedFailurePreDispatchRunDecomp
     @testing.expectedFailureRetraceability
     @testing.expectedFailureTrainingIRToRunDecomp  # T193700910
+    @testing.expectedFailureTrainingIRToRunDecompNonStrict
     def test_export_cond_preserve_torch_fn_for_subgraphs(self):
         class MySubModule(torch.nn.Module):
             def foo(self, x):
@@ -2178,6 +2179,7 @@ def forward(self, arg1, arg2, *args, kw1, kw2, **kwargs):
     @testing.expectedFailureSerDer  # we don't save placeholder metadata
     @testing.expectedFailureNonStrict
     @testing.expectedFailureTrainingIRToRunDecomp  # T193692674
+    @testing.expectedFailureTrainingIRToRunDecompNonStrict
     def test_linear_conv(self):
         class MyLinear(torch.nn.Module):
             def __init__(self):
@@ -2853,7 +2855,6 @@ def test_buffer_util(self):
         self.assertEqual(buffer[1].shape, torch.Size([100]))  # running_var
         self.assertEqual(buffer[2].shape, torch.Size([]))  # num_batches_tracked

-    @testing.expectedFailureTrainingIRToRunDecomp  # T193701564
     def test_export_dynamo_config(self):
         class MyModule(torch.nn.Module):
             def __init__(self):
@@ -2889,6 +2890,7 @@ def _patch_config(kwargs):
             _ = export(mod, inp, strict=True)

     @testing.expectedFailureTrainingIRToRunDecomp  # T193700396
+    @testing.expectedFailureTrainingIRToRunDecompNonStrict
     def test_device_to_static(self):
         class Module(torch.nn.Module):
             def forward(self, x):
@@ -2904,6 +2906,7 @@ def forward(self, x):
         self.assertIn(op, (torch.ops.aten._to_copy.default,))

     @testing.expectedFailureTrainingIRToRunDecomp  # T193700396
+    @testing.expectedFailureTrainingIRToRunDecompNonStrict
     def test_device_to_dynamic(self):
         class Module(torch.nn.Module):
             def forward(self, x):
@@ -2923,6 +2926,7 @@ def forward(self, x):
         self.assertIn(op, (torch.ops.aten._to_copy.default,))

     @testing.expectedFailureTrainingIRToRunDecomp  # T193700396
+    @testing.expectedFailureTrainingIRToRunDecompNonStrict
     def test_device_to_mutation(self):
         class Module(torch.nn.Module):
             def forward(self, x):
@@ -2936,6 +2940,7 @@ def forward(self, x):
         export(Module(), (torch.tensor(1, device="cpu"),))

     @testing.expectedFailureTrainingIRToRunDecomp  # T193700396
+    @testing.expectedFailureTrainingIRToRunDecompNonStrict
     def test_float_conversion(self):
         class Module(torch.nn.Module):
             def forward(self, x):
@@ -2951,6 +2956,7 @@ def forward(self, x):
         self.assertIn(op, (torch.ops.aten._to_copy.default,))

     @testing.expectedFailureTrainingIRToRunDecomp  # T193700396
+    @testing.expectedFailureTrainingIRToRunDecompNonStrict
     def test_device_to_mutation_float(self):
         class Module(torch.nn.Module):
             def forward(self, x):
@@ -2964,6 +2970,7 @@ def forward(self, x):
         export(Module(), (torch.tensor(1, dtype=torch.float),))

     @testing.expectedFailureTrainingIRToRunDecomp  # T193692674
+    @testing.expectedFailureTrainingIRToRunDecompNonStrict
     def test_module(self):
         class MyLinear(torch.nn.Module):
             def __init__(self):
@@ -3010,6 +3017,7 @@ def forward(self, x):
         )

     @testing.expectedFailureTrainingIRToRunDecomp  # T193701564
+    @testing.expectedFailureTrainingIRToRunDecompNonStrict
     def test_module_with_dict_container_inp_out(self):
         class MyLinear(torch.nn.Module):
             def __init__(self):
@@ -3773,6 +3781,7 @@ def forward(self, xs, y):

     @testing.expectedFailureSerDer  # We don't preserve metadata on graph module
     @testing.expectedFailureNonStrict
+    @testing.expectedFailureTrainingIRToRunDecompNonStrict
     def test_retrace_graph_level_meta_preservation(self):
         class Foo(torch.nn.Module):
             def __init__(self):
@@ -3854,6 +3863,7 @@ def forward(self, x):

     # TODO Retracing a module with constant attrs don't work.(T193692674)
     @testing.expectedFailureTrainingIRToRunDecomp
+    @testing.expectedFailureTrainingIRToRunDecompNonStrict
     @testing.expectedFailureRetraceability  # T183144788
     def test_lifted_constants(self) -> None:
         class Module(torch.nn.Module):
@@ -3890,6 +3900,7 @@ def forward(self, x):

     @testing.expectedFailureRetraceability  # T183144788
     @testing.expectedFailureTrainingIRToRunDecomp  # T193701164
+    @testing.expectedFailureTrainingIRToRunDecompNonStrict
     def test_tensor_attribute_zero_args(self):
         class Foo(torch.nn.Module):
             def __init__(self, value):
@@ -4237,6 +4248,7 @@ def forward(self, x):

     @testing.expectedFailureRetraceability  # Retracing tensor constants results in buffers
     @testing.expectedFailureTrainingIRToRunDecomp  # T193692674
+    @testing.expectedFailureTrainingIRToRunDecompNonStrict
     def test_nested_module_with_constant_buffer(self):
         class M1(torch.nn.Module):
             def __init__(self):
@@ -4386,6 +4398,8 @@ def forward(self, x, y):
         self.assertTrue(torch.allclose(ep.module()(*inp), M()(*inp)))

     # TODO Retracing a module with constant attrs don't work.(T193692674)
+    @testing.expectedFailureTrainingIRToRunDecomp
+    @testing.expectedFailureTrainingIRToRunDecompNonStrict
     @unittest.skip("Test is only supposed to work with non-strict mode")
     def test_issue_113041(self):
         class TestModule(torch.nn.Module):
@@ -5252,6 +5266,7 @@ def forward(self, x):
         self.assertEqual(ep.state_dict, m.state_dict())

     @testing.expectedFailureTrainingIRToRunDecomp  # T193692674
+    @testing.expectedFailureTrainingIRToRunDecompNonStrict
     def test_non_persistent_buffer(self):
         class MyModule(torch.nn.Module):
             def __init__(self):
@@ -5319,6 +5334,7 @@ def forward(self, x):

     # TODO Retracing a module with constant attrs don't work.(T193692674)
     @testing.expectedFailureTrainingIRToRunDecomp
+    @testing.expectedFailureTrainingIRToRunDecompNonStrict
     def test_fake_weights(self):
         class MyModule(torch.nn.Module):
             def __init__(self):
@@ -5377,8 +5393,6 @@ def forward(self, x):
         # under a new FakeTensorMode.
         ep = torch.export.export(m, (inp,))

-    # Errors because non-strict is not supported in training IR (T193692164)
-    @testing.expectedFailureTrainingIRToRunDecomp
     def test_compiling_state(self):
         class TestModule1(torch.nn.Module):
             def forward(self, x):
@@ -5428,7 +5442,6 @@ def forward(self, x):
         self.assertEqual(mod.foo, ep.module().foo)
         self.assertEqual(mod(torch.ones(4, 4)), ep.module()(torch.ones(4, 4)))

-    @testing.expectedFailureTrainingIRToRunDecomp  # T193702033
     def test_symint_tensor_return(self):
         class Module(torch.nn.Module):
             def forward(self, x):
@@ -5534,6 +5547,7 @@ def forward(self, x):
     # TODO Retracing a module with constant attrs don't work.(T193692674)
     @testing.expectedFailureRetraceability
     @testing.expectedFailureTrainingIRToRunDecomp
+    @testing.expectedFailureTrainingIRToRunDecompNonStrict
     def test_placeholder_naming_collisions(self):
         # test collisions between nested user inputs
         class Foo(torch.nn.Module):
@@ -6150,7 +6164,6 @@ def forward(self, x):
         for param in ["alpha", "beta", "gamma"]:
             self.assertTrue(param in unep.state_dict())

-    @testing.expectedFailureTrainingIRToRunDecomp  # nn_module_stack replacement when we do sympy_interp()
     def test_intermediate_shape_comp(self):
         class Foo(torch.nn.Module):
             def forward(self, x, y):
@@ -6182,14 +6195,18 @@ def forward(self, x, y):
             all(node.args[0].op == "placeholder" for node in sym_size_nodes)
         )
         # dynamo will DCE the repeat node, AOTAutograd will leave it
+        # training IR will also DCE due to retracing
         repeat_nodes = [
             node
             for node in ep.graph.nodes
             if node.target == torch.ops.aten.repeat.default
         ]
         self.assertEqual(
             len(repeat_nodes),
-            1 if is_non_strict_test(self._testMethodName) else 0,
+            1
+            if is_non_strict_test(self._testMethodName)
+            and not is_training_ir_test(self._testMethodName)
+            else 0,
         )

     def test_checks_to_constrain_range(self):
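
The widened condition in the hunk above encodes which variants still carry the aten.repeat node: strict export (dynamo) dead-code-eliminates it, the training-IR variants retrace and eliminate it again, and only the plain non-strict variant keeps it. A small sketch of the resulting expectation, reusing the helpers defined at the top of this file; the generated names in the comments are hypothetical:

    def expected_repeat_count(test_name):
        # Only plain non-strict (AOTAutograd, no training-IR retrace)
        # leaves the unused repeat node in the exported graph.
        if is_non_strict_test(test_name) and not is_training_ir_test(test_name):
            return 1
        return 0

    # expected_repeat_count("test_intermediate_shape_comp_non_strict") -> 1
    # expected_repeat_count("test_intermediate_shape_comp_training_ir_to_decomp_non_strict") -> 0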