
Commit e8998d6

pianpwk authored and pytorchmergebot committed
[export] add non-strict training IR (pytorch#130062)
Summary: Adds a non-strict implementation of training IR export. Any expected non-strict training IR failures are also either existing strict training IR or non-strict failures (no new failures added). Four strict training IR failures are also resolved. Refraining from unifying export/export_for_training, per @ydwu4's feedback :)

Test Plan: added test_export_training_ir_to_run_decomp_non_strict.py for non-strict training IR

Differential Revision: D59349454

Pull Request resolved: pytorch#130062

Approved by: https://github.com/ydwu4, https://github.com/zhxchen17
1 parent d2f44ea commit e8998d6
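The core flow this commit adds coverage for, as a minimal sketch: trace a module into training IR without TorchDynamo (non-strict), then lower it with decompositions. The import path for _export_for_training is an assumption here (the import is not visible in this diff):

import torch
from torch.export._trace import _export_for_training  # assumed import path

class M(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.relu(x)

# Non-strict: trace with fake tensors, no TorchDynamo, keeping training ops.
ep = _export_for_training(M(), (torch.randn(3),), strict=False)
# Lower to inference IR; an empty table applies no extra decompositions.
ep = ep.run_decompositions({})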

File tree

4 files changed: +160 -69 lines changed


test/export/test_export.py

Lines changed: 32 additions & 15 deletions

@@ -152,7 +152,8 @@ class Inp:
 RETRACEABILITY_SUFFIX = "_retraceability"
 SERDES_SUFFIX = "_serdes"
 PREDISPATCH_SUFFIX = "_pre_dispatch"
-TRAINING_IR_DECOMP_SUFFIX = "_training_ir_to_decomp"
+TRAINING_IR_DECOMP_STRICT_SUFFIX = "_training_ir_to_decomp"
+TRAINING_IR_DECOMP_NON_STRICT_SUFFIX = "_training_ir_to_decomp_non_strict"


 def is_non_strict_test(test_name):
@@ -167,6 +168,12 @@ def is_serdes_test(test_name):
     return test_name.endswith(SERDES_SUFFIX)


+def is_training_ir_test(test_name):
+    return test_name.endswith(TRAINING_IR_DECOMP_STRICT_SUFFIX) or test_name.endswith(
+        TRAINING_IR_DECOMP_NON_STRICT_SUFFIX
+    )
+
+
 @unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't support")
 class TestDynamismExpression(TestCase):
     def test_export_inline_constraints(self):
@@ -309,6 +316,7 @@ def forward(self, x, y):
         )

     # Errors because fake mode is not detected from non-tensor inputs
+    @testing.expectedFailureTrainingIRToRunDecompNonStrict
     @testing.expectedFailureTrainingIRToRunDecomp
     def test_no_tensor_computation_3(self):
         class Module(torch.nn.Module):
@@ -346,8 +354,6 @@ def forward(self, x, y):
     return (x_0,)""",
         )

-    # Errors because non-strict is not supported in training IR (T193692164)
-    @testing.expectedFailureTrainingIRToRunDecomp
     def test_external_call_non_strict_real_tensor(self):
         class ExternalMethod:
             def add(self, x):
@@ -418,8 +424,6 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         args = (torch.randn(15, 3, 256, 256), torch.ones(15, 32, 256, 256))
         self.assertEqual(gm(*args), m(*args))

-    # Errors because non-strict is not supported in training IR (T193692164)
-    @testing.expectedFailureTrainingIRToRunDecomp
     def test_basic_non_strict_real_tensor(self):
         class Basic(torch.nn.Module):
             def __init__(self):
@@ -434,8 +438,6 @@ def forward(self, x, y):
         ep = export(f, args, strict=False)
         self.assertEqual(ep.module()(*args), f(*args))

-    # Errors because non-strict is not supported in training IR (T193692164)
-    @testing.expectedFailureTrainingIRToRunDecomp
     def test_basic_non_strict_fake_tensor(self):
         class Basic(torch.nn.Module):
             def __init__(self):
@@ -690,8 +692,6 @@ def forward(self, x):
             torch.allclose(ep.module()(torch.zeros(2, 3)), torch.ones(2, 3) * 21)
         )

-    # Predispatch has different expected results
-    @testing.expectedFailureTrainingIRToRunDecomp  # T193700910
     def test_torch_fn(self):
         class M1(torch.nn.Module):
             def __init__(self):
@@ -823,6 +823,7 @@ def forward(self, p_linear_weight, p_linear_bias, x):
     @testing.expectedFailurePreDispatchRunDecomp
     @testing.expectedFailureRetraceability
     @testing.expectedFailureTrainingIRToRunDecomp  # T193700910
+    @testing.expectedFailureTrainingIRToRunDecompNonStrict
     def test_export_cond_preserve_torch_fn_for_subgraphs(self):
         class MySubModule(torch.nn.Module):
             def foo(self, x):
@@ -2178,6 +2179,7 @@ def forward(self, arg1, arg2, *args, kw1, kw2, **kwargs):
     @testing.expectedFailureSerDer  # we don't save placeholder metadata
     @testing.expectedFailureNonStrict
     @testing.expectedFailureTrainingIRToRunDecomp  # T193692674
+    @testing.expectedFailureTrainingIRToRunDecompNonStrict
     def test_linear_conv(self):
         class MyLinear(torch.nn.Module):
             def __init__(self):
@@ -2853,7 +2855,6 @@ def test_buffer_util(self):
         self.assertEqual(buffer[1].shape, torch.Size([100]))  # running_var
         self.assertEqual(buffer[2].shape, torch.Size([]))  # num_batches_tracked

-    @testing.expectedFailureTrainingIRToRunDecomp  # T193701564
     def test_export_dynamo_config(self):
         class MyModule(torch.nn.Module):
             def __init__(self):
@@ -2889,6 +2890,7 @@ def _patch_config(kwargs):
         _ = export(mod, inp, strict=True)

     @testing.expectedFailureTrainingIRToRunDecomp  # T193700396
+    @testing.expectedFailureTrainingIRToRunDecompNonStrict
     def test_device_to_static(self):
         class Module(torch.nn.Module):
             def forward(self, x):
@@ -2904,6 +2906,7 @@ def forward(self, x):
         self.assertIn(op, (torch.ops.aten._to_copy.default,))

     @testing.expectedFailureTrainingIRToRunDecomp  # T193700396
+    @testing.expectedFailureTrainingIRToRunDecompNonStrict
     def test_device_to_dynamic(self):
         class Module(torch.nn.Module):
             def forward(self, x):
@@ -2923,6 +2926,7 @@ def forward(self, x):
         self.assertIn(op, (torch.ops.aten._to_copy.default,))

     @testing.expectedFailureTrainingIRToRunDecomp  # T193700396
+    @testing.expectedFailureTrainingIRToRunDecompNonStrict
     def test_device_to_mutation(self):
         class Module(torch.nn.Module):
             def forward(self, x):
@@ -2936,6 +2940,7 @@ def forward(self, x):
         export(Module(), (torch.tensor(1, device="cpu"),))

     @testing.expectedFailureTrainingIRToRunDecomp  # T193700396
+    @testing.expectedFailureTrainingIRToRunDecompNonStrict
     def test_float_conversion(self):
         class Module(torch.nn.Module):
             def forward(self, x):
@@ -2951,6 +2956,7 @@ def forward(self, x):
         self.assertIn(op, (torch.ops.aten._to_copy.default,))

     @testing.expectedFailureTrainingIRToRunDecomp  # T193700396
+    @testing.expectedFailureTrainingIRToRunDecompNonStrict
     def test_device_to_mutation_float(self):
         class Module(torch.nn.Module):
             def forward(self, x):
@@ -2964,6 +2970,7 @@ def forward(self, x):
         export(Module(), (torch.tensor(1, dtype=torch.float),))

     @testing.expectedFailureTrainingIRToRunDecomp  # T193692674
+    @testing.expectedFailureTrainingIRToRunDecompNonStrict
     def test_module(self):
         class MyLinear(torch.nn.Module):
             def __init__(self):
@@ -3010,6 +3017,7 @@ def forward(self, x):
         )

     @testing.expectedFailureTrainingIRToRunDecomp  # T193701564
+    @testing.expectedFailureTrainingIRToRunDecompNonStrict
     def test_module_with_dict_container_inp_out(self):
         class MyLinear(torch.nn.Module):
             def __init__(self):
@@ -3773,6 +3781,7 @@ def forward(self, xs, y):

     @testing.expectedFailureSerDer  # We don't preserve metadata on graph module
     @testing.expectedFailureNonStrict
+    @testing.expectedFailureTrainingIRToRunDecompNonStrict
     def test_retrace_graph_level_meta_preservation(self):
         class Foo(torch.nn.Module):
             def __init__(self):
@@ -3854,6 +3863,7 @@ def forward(self, x):

     # TODO Retracing a module with constant attrs don't work.(T193692674)
     @testing.expectedFailureTrainingIRToRunDecomp
+    @testing.expectedFailureTrainingIRToRunDecompNonStrict
     @testing.expectedFailureRetraceability  # T183144788
     def test_lifted_constants(self) -> None:
         class Module(torch.nn.Module):
@@ -3890,6 +3900,7 @@ def forward(self, x):

     @testing.expectedFailureRetraceability  # T183144788
     @testing.expectedFailureTrainingIRToRunDecomp  # T193701164
+    @testing.expectedFailureTrainingIRToRunDecompNonStrict
     def test_tensor_attribute_zero_args(self):
         class Foo(torch.nn.Module):
             def __init__(self, value):
@@ -4237,6 +4248,7 @@ def forward(self, x):

     @testing.expectedFailureRetraceability  # Retracing tensor constants results in buffers
     @testing.expectedFailureTrainingIRToRunDecomp  # T193692674
+    @testing.expectedFailureTrainingIRToRunDecompNonStrict
     def test_nested_module_with_constant_buffer(self):
         class M1(torch.nn.Module):
             def __init__(self):
@@ -4386,6 +4398,8 @@ def forward(self, x, y):
         self.assertTrue(torch.allclose(ep.module()(*inp), M()(*inp)))

     # TODO Retracing a module with constant attrs don't work.(T193692674)
+    @testing.expectedFailureTrainingIRToRunDecomp
+    @testing.expectedFailureTrainingIRToRunDecompNonStrict
     @unittest.skip("Test is only supposed to work with non-strict mode")
     def test_issue_113041(self):
         class TestModule(torch.nn.Module):
@@ -5252,6 +5266,7 @@ def forward(self, x):
         self.assertEqual(ep.state_dict, m.state_dict())

     @testing.expectedFailureTrainingIRToRunDecomp  # T193692674
+    @testing.expectedFailureTrainingIRToRunDecompNonStrict
     def test_non_persistent_buffer(self):
         class MyModule(torch.nn.Module):
             def __init__(self):
@@ -5319,6 +5334,7 @@ def forward(self, x):

     # TODO Retracing a module with constant attrs don't work.(T193692674)
     @testing.expectedFailureTrainingIRToRunDecomp
+    @testing.expectedFailureTrainingIRToRunDecompNonStrict
     def test_fake_weights(self):
         class MyModule(torch.nn.Module):
             def __init__(self):
@@ -5377,8 +5393,6 @@ def forward(self, x):
         # under a new FakeTensorMode.
         ep = torch.export.export(m, (inp,))

-    # Errors because non-strict is not supported in training IR (T193692164)
-    @testing.expectedFailureTrainingIRToRunDecomp
     def test_compiling_state(self):
         class TestModule1(torch.nn.Module):
             def forward(self, x):
@@ -5428,7 +5442,6 @@ def forward(self, x):
         self.assertEqual(mod.foo, ep.module().foo)
         self.assertEqual(mod(torch.ones(4, 4)), ep.module()(torch.ones(4, 4)))

-    @testing.expectedFailureTrainingIRToRunDecomp  # T193702033
     def test_symint_tensor_return(self):
         class Module(torch.nn.Module):
             def forward(self, x):
@@ -5534,6 +5547,7 @@ def forward(self, x):
     # TODO Retracing a module with constant attrs don't work.(T193692674)
     @testing.expectedFailureRetraceability
     @testing.expectedFailureTrainingIRToRunDecomp
+    @testing.expectedFailureTrainingIRToRunDecompNonStrict
     def test_placeholder_naming_collisions(self):
         # test collisions between nested user inputs
         class Foo(torch.nn.Module):
@@ -6150,7 +6164,6 @@ def forward(self, x):
         for param in ["alpha", "beta", "gamma"]:
             self.assertTrue(param in unep.state_dict())

-    @testing.expectedFailureTrainingIRToRunDecomp  # nn_module_stack replacement when we do sympy_interp()
     def test_intermediate_shape_comp(self):
         class Foo(torch.nn.Module):
             def forward(self, x, y):
@@ -6182,14 +6195,18 @@ def forward(self, x, y):
             all(node.args[0].op == "placeholder" for node in sym_size_nodes)
         )
         # dynamo will DCE the repeat node, AOTAutograd will leave it
+        # training IR will also DCE due to retracing
         repeat_nodes = [
             node
             for node in ep.graph.nodes
             if node.target == torch.ops.aten.repeat.default
         ]
         self.assertEqual(
             len(repeat_nodes),
-            1 if is_non_strict_test(self._testMethodName) else 0,
+            1
+            if is_non_strict_test(self._testMethodName)
+            and not is_training_ir_test(self._testMethodName)
+            else 0,
         )

     def test_checks_to_constrain_range(self):
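A subtlety behind the last hunk: generated non-strict training IR test names end with "_training_ir_to_decomp_non_strict", which itself ends with "_non_strict", so is_non_strict_test matches them too. The new is_training_ir_test guard peels those cases back out, since the added comment notes that training IR retracing DCEs the repeat node even in non-strict mode. A quick illustration (assuming NON_STRICT_SUFFIX = "_non_strict" from the unmodified part of this file):

name = "test_intermediate_shape_comp_training_ir_to_decomp_non_strict"
assert name.endswith("_non_strict")                        # is_non_strict_test(name) -> True
assert name.endswith("_training_ir_to_decomp_non_strict")  # is_training_ir_test(name) -> True
# Both match, so the expected repeat-node count falls through to 0.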

test/export/test_export_training_ir_to_run_decomp.py

Lines changed: 29 additions & 11 deletions

@@ -10,24 +10,41 @@
 test_classes = {}


-def mocked_training_ir_to_run_decomp_export(*args, **kwargs):
+def mocked_training_ir_to_run_decomp_export_strict(*args, **kwargs):
     ep = _export_for_training(*args, **kwargs)
     return ep.run_decompositions(
         {}, _preserve_ops=testing._COMPOSITE_OPS_THAT_CAN_BE_PRESERVED_TESTING_ONLY
     )


-def make_dynamic_cls(cls):
-    cls_prefix = "TrainingIRToRunDecompExport"
-
-    test_class = testing.make_test_cls_with_mocked_export(
-        cls,
-        cls_prefix,
-        test_export.TRAINING_IR_DECOMP_SUFFIX,
-        mocked_training_ir_to_run_decomp_export,
-        xfail_prop="_expected_failure_training_ir_to_run_decomp",
+def mocked_training_ir_to_run_decomp_export_non_strict(*args, **kwargs):
+    if "strict" in kwargs:
+        ep = _export_for_training(*args, **kwargs)
+    else:
+        ep = _export_for_training(*args, **kwargs, strict=False)
+    return ep.run_decompositions(
+        {}, _preserve_ops=testing._COMPOSITE_OPS_THAT_CAN_BE_PRESERVED_TESTING_ONLY
     )

+
+def make_dynamic_cls(cls, strict):
+    if strict:
+        test_class = testing.make_test_cls_with_mocked_export(
+            cls,
+            "TrainingIRToRunDecompExport",
+            test_export.TRAINING_IR_DECOMP_STRICT_SUFFIX,
+            mocked_training_ir_to_run_decomp_export_strict,
+            xfail_prop="_expected_failure_training_ir_to_run_decomp",
+        )
+    else:
+        test_class = testing.make_test_cls_with_mocked_export(
+            cls,
+            "TrainingIRToRunDecompExportNonStrict",
+            test_export.TRAINING_IR_DECOMP_NON_STRICT_SUFFIX,
+            mocked_training_ir_to_run_decomp_export_non_strict,
+            xfail_prop="_expected_failure_training_ir_to_run_decomp_non_strict",
+        )
+
     test_classes[test_class.__name__] = test_class
     # REMOVING THIS LINE WILL STOP TESTS FROM RUNNING
     globals()[test_class.__name__] = test_class
@@ -40,7 +57,8 @@ def make_dynamic_cls(cls):
     test_export.TestExport,
 ]
 for test in tests:
-    make_dynamic_cls(test)
+    make_dynamic_cls(test, True)
+    make_dynamic_cls(test, False)
 del test

 if __name__ == "__main__":
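Note the non-strict mock only injects strict=False when the test did not pass strict explicitly. A usage sketch, with a hypothetical module M and assuming the functions above are in scope:

import torch

class M(torch.nn.Module):
    def forward(self, x):
        return x + 1

m, args = M(), (torch.randn(2),)
# Test passed no strict kwarg -> the wrapper forces strict=False:
ep1 = mocked_training_ir_to_run_decomp_export_non_strict(m, args)
# An explicit strict kwarg from the test is respected as-is (here a strict trace):
ep2 = mocked_training_ir_to_run_decomp_export_non_strict(m, args, strict=True)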

test/export/testing.py

Lines changed: 6 additions & 0 deletions

@@ -239,6 +239,12 @@ def expectedFailureTrainingIRToRunDecomp(fn):
     return fn


+# Controls tests generated in test/export/test_export_training_ir_to_run_decomp.py
+def expectedFailureTrainingIRToRunDecompNonStrict(fn):
+    fn._expected_failure_training_ir_to_run_decomp_non_strict = True
+    return fn
+
+
 # Controls tests generated in test/export/test_export_nonstrict.py
 def expectedFailureNonStrict(fn):
     fn._expected_failure_non_strict = True
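The decorator is just an attribute tag: make_test_cls_with_mocked_export reads the attribute back through the xfail_prop name when generating the suffixed test class (presumably marking matches as expected failures). A minimal sketch of the tagging side, with a hypothetical test method:

@expectedFailureTrainingIRToRunDecompNonStrict
def test_example(self):  # hypothetical test method
    ...

assert test_example._expected_failure_training_ir_to_run_decomp_non_strict is True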
