Commit 108fa3b (executorch)

Differential Revision: D71520535
Pull Request resolved: #9661

1 parent 5fcb6b8, commit 108fa3b

File tree: 16 files changed (+55, -60 lines)

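Every hunk below makes the same mechanical change: each `export_for_training` call site now passes `strict=True` explicitly instead of relying on the default. A plausible motivation (not stated in the commit message) is to pin strict, Dynamo-based tracing so these call sites keep their behavior if the upstream default moves to non-strict export. A minimal before/after sketch, reusing the `SimpleConv` name from the tutorial diff below; the module body here is assumed for illustration:

import torch
from torch.export import export_for_training

class SimpleConv(torch.nn.Module):  # body assumed for illustration
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 16, kernel_size=3)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(x)

example_args = (torch.randn(1, 3, 256, 256),)

# Before: relies on the default value of `strict`.
m_old = export_for_training(SimpleConv(), example_args).module()

# After: strict (Dynamo-based) tracing is pinned explicitly.
m_new = export_for_training(SimpleConv(), example_args, strict=True).module()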

backends/apple/coreml/test/test_coreml_quantizer.py

Lines changed: 3 additions & 1 deletion

@@ -32,7 +32,9 @@ def quantize_and_compare(
 ) -> None:
     assert quantization_type in {"PTQ", "QAT"}
 
-    pre_autograd_aten_dialect = export_for_training(model, example_inputs).module()
+    pre_autograd_aten_dialect = export_for_training(
+        model, example_inputs, strict=True
+    ).module()
 
     quantization_config = LinearQuantizerConfig.from_dict(
         {

backends/apple/mps/test/test_mps_utils.py

Lines changed: 1 addition & 1 deletion

@@ -207,7 +207,7 @@ def lower_module_and_test_output(
         expected_output = model(*sample_inputs)
 
         model = torch.export.export_for_training(
-            model, sample_inputs, dynamic_shapes=dynamic_shapes
+            model, sample_inputs, dynamic_shapes=dynamic_shapes, strict=True
         ).module()
 
         edge_program = export_to_edge(

backends/cadence/aot/compiler.py

Lines changed: 1 addition & 1 deletion

@@ -86,7 +86,7 @@ def convert_pt2(
     remove_decompositions(decomp_table, ops_to_keep)
     # Export with dynamo
     model_gm = (
-        torch.export.export_for_training(model, inputs)
+        torch.export.export_for_training(model, inputs, strict=True)
         .run_decompositions(decomp_table)
         .module()
     )

backends/cadence/aot/tests/test_remove_ops_passes.py

Lines changed: 1 addition & 1 deletion

@@ -474,7 +474,7 @@ def forward(self, x):
         # Run the standard quant/convert steps, but without fusing
         # this leaves two redundant quant/dequant pairs to test with
         quantizer = CadenceDefaultQuantizer()
-        model_exp = export_for_training(M(), (inp,)).module()
+        model_exp = export_for_training(M(), (inp,), strict=True).module()
         prepared_model = prepare_pt2e(model_exp, quantizer)
         prepared_model(inp)
         converted_model = convert_pt2e(prepared_model)

backends/example/test_example_delegate.py

Lines changed: 6 additions & 2 deletions

@@ -46,7 +46,9 @@ def get_example_inputs():
         )
 
         m = model.eval()
-        m = torch.export.export_for_training(m, copy.deepcopy(example_inputs)).module()
+        m = torch.export.export_for_training(
+            m, copy.deepcopy(example_inputs), strict=True
+        ).module()
         # print("original model:", m)
         quantizer = ExampleQuantizer()
         # quantizer = XNNPACKQuantizer()

@@ -82,7 +84,9 @@ def test_delegate_mobilenet_v2(self):
         )
 
         m = model.eval()
-        m = torch.export.export_for_training(m, copy.deepcopy(example_inputs)).module()
+        m = torch.export.export_for_training(
+            m, copy.deepcopy(example_inputs), strict=True
+        ).module()
         quantizer = ExampleQuantizer()
 
         m = prepare_pt2e(m, quantizer)

backends/mediatek/quantizer/annotator.py

Lines changed: 3 additions & 3 deletions

@@ -44,7 +44,6 @@ def annotate(graph: Graph, quant_config: QuantizationConfig) -> None:
 
 
 def register_annotator(ops: List[OpOverload]):
-
     def decorator(annotator_fn: Callable):
         for op in ops:
             OP_TO_ANNOTATOR[op] = annotator_fn

@@ -147,7 +146,6 @@ def _annotate_fused_activation_pattern(
 
 
 def _annotate_rmsnorm_pattern(graph: Graph, quant_config: QuantizationConfig) -> None:
-
     class ExecuTorchPattern(torch.nn.Module):
         def forward(self, x):
             norm = x * torch.rsqrt((x * x).mean(-1, keepdim=True) + 1e-6)

@@ -159,7 +157,9 @@ def forward(self, x):
             return norm, {}
 
     for pattern_cls in (ExecuTorchPattern, MTKPattern):
-        pattern_gm = export_for_training(pattern_cls(), (torch.randn(3, 3),)).module()
+        pattern_gm = export_for_training(
+            pattern_cls(), (torch.randn(3, 3),), strict=True
+        ).module()
         matcher = SubgraphMatcherWithNameNodeMap(
             pattern_gm, ignore_literals=True, remove_overlapping_matches=False
         )

backends/qualcomm/tests/utils.py

Lines changed: 1 addition & 1 deletion

@@ -567,7 +567,7 @@ def get_prepared_qat_module(
         custom_quant_annotations: Tuple[Callable] = (),
         quant_dtype: QuantDtype = QuantDtype.use_8a8w,
     ) -> torch.fx.GraphModule:
-        m = torch.export.export_for_training(module, inputs).module()
+        m = torch.export.export_for_training(module, inputs, strict=True).module()
 
         quantizer = make_quantizer(
             quant_dtype=quant_dtype,

backends/transforms/test/test_duplicate_dynamic_quant_chain.py

Lines changed: 1 addition & 4 deletions

@@ -58,10 +58,7 @@ def _test_duplicate_chain(
 
         # program capture
         m = copy.deepcopy(m_eager)
-        m = torch.export.export_for_training(
-            m,
-            example_inputs,
-        ).module()
+        m = torch.export.export_for_training(m, example_inputs, strict=True).module()
 
         m = prepare_pt2e(m, quantizer)
         # Calibrate

backends/xnnpack/test/quantizer/test_pt2e_quantization.py

Lines changed: 17 additions & 18 deletions

@@ -326,7 +326,7 @@ def test_disallow_eval_train(self) -> None:
         m.train()
 
         # After export: this is not OK
-        m = export_for_training(m, example_inputs).module()
+        m = export_for_training(m, example_inputs, strict=True).module()
         with self.assertRaises(NotImplementedError):
             m.eval()
         with self.assertRaises(NotImplementedError):

@@ -380,7 +380,7 @@ def forward(self, x):
         m = M().train()
         example_inputs = (torch.randn(1, 3, 3, 3),)
         bn_train_op, bn_eval_op = self._get_bn_train_eval_ops()  # pyre-ignore[23]
-        m = export_for_training(m, example_inputs).module()
+        m = export_for_training(m, example_inputs, strict=True).module()
 
         def _assert_ops_are_correct(m: torch.fx.GraphModule, train: bool) -> None:
             bn_op = bn_train_op if train else bn_eval_op

@@ -449,10 +449,7 @@ def forward(self, x):
         quantizer.set_global(operator_config)
         example_inputs = (torch.randn(2, 2),)
         m = M().eval()
-        m = export_for_training(
-            m,
-            example_inputs,
-        ).module()
+        m = export_for_training(m, example_inputs, strict=True).module()
         weight_meta = None
         for n in m.graph.nodes:  # pyre-ignore[16]
             if (

@@ -481,7 +478,7 @@ def test_reentrant(self) -> None:
             get_symmetric_quantization_config(is_per_channel=True, is_qat=True)
         )
         m.conv_bn_relu = export_for_training(  # pyre-ignore[8]
-            m.conv_bn_relu, example_inputs
+            m.conv_bn_relu, example_inputs, strict=True
         ).module()
         m.conv_bn_relu = prepare_qat_pt2e(m.conv_bn_relu, quantizer)  # pyre-ignore[6,8]
         m(*example_inputs)

@@ -490,7 +487,7 @@ def test_reentrant(self) -> None:
         quantizer = XNNPACKQuantizer().set_module_type(
             torch.nn.Linear, get_symmetric_quantization_config(is_per_channel=False)
         )
-        m = export_for_training(m, example_inputs).module()
+        m = export_for_training(m, example_inputs, strict=True).module()
         m = prepare_pt2e(m, quantizer)  # pyre-ignore[6]
         m = convert_pt2e(m)
 

@@ -553,7 +550,7 @@ def check_nn_module(node: torch.fx.Node) -> None:
         )
 
         m.conv_bn_relu = export_for_training(  # pyre-ignore[8]
-            m.conv_bn_relu, example_inputs
+            m.conv_bn_relu, example_inputs, strict=True
         ).module()
         for node in m.conv_bn_relu.graph.nodes:  # pyre-ignore[16]
             if node.op not in ["placeholder", "output", "get_attr"]:

@@ -568,7 +565,7 @@ def test_speed(self) -> None:
 
         def dynamic_quantize_pt2e(model, example_inputs) -> torch.fx.GraphModule:
             torch._dynamo.reset()
-            model = export_for_training(model, example_inputs).module()
+            model = export_for_training(model, example_inputs, strict=True).module()
             # Per channel quantization for weight
             # Dynamic quantization for activation
             # Please read a detail: https://fburl.com/code/30zds51q

@@ -625,7 +622,7 @@ def forward(self, x):
 
         example_inputs = (torch.randn(1, 3, 5, 5),)
         m = M()
-        m = export_for_training(m, example_inputs).module()
+        m = export_for_training(m, example_inputs, strict=True).module()
         quantizer = XNNPACKQuantizer().set_global(
             get_symmetric_quantization_config(),
         )

@@ -701,7 +698,6 @@ def test_save_load(self) -> None:
 
 
 class TestNumericDebugger(TestCase):
-
     def _extract_debug_handles(self, model) -> Dict[str, int]:
         debug_handle_map: Dict[str, int] = {}
 

@@ -731,7 +727,7 @@ def _assert_node_has_debug_handle(node: torch.fx.Node) -> None:
     def test_quantize_pt2e_preserve_handle(self) -> None:
         m = TestHelperModules.Conv2dThenConv1d()
         example_inputs = m.example_inputs()
-        ep = export_for_training(m, example_inputs)
+        ep = export_for_training(m, example_inputs, strict=True)
         generate_numeric_debug_handle(ep)
         m = ep.module()
 

@@ -761,7 +757,7 @@ def test_quantize_pt2e_preserve_handle(self) -> None:
     def test_extract_results_from_loggers(self) -> None:
         m = TestHelperModules.Conv2dThenConv1d()
         example_inputs = m.example_inputs()
-        ep = export_for_training(m, example_inputs)
+        ep = export_for_training(m, example_inputs, strict=True)
         generate_numeric_debug_handle(ep)
         m = ep.module()
         m_ref_logger = prepare_for_propagation_comparison(m)  # pyre-ignore[6]

@@ -779,18 +775,20 @@ def test_extract_results_from_loggers(self) -> None:
         ref_results = extract_results_from_loggers(m_ref_logger)
         quant_results = extract_results_from_loggers(m_quant_logger)
         comparison_results = compare_results(
-            ref_results, quant_results  # pyre-ignore[6]
+            ref_results,
+            quant_results,  # pyre-ignore[6]
         )
         for node_summary in comparison_results.values():
             if len(node_summary.results) > 0:
                 self.assertGreaterEqual(
-                    node_summary.results[0].sqnr, 35  # pyre-ignore[6]
+                    node_summary.results[0].sqnr,
+                    35,  # pyre-ignore[6]
                 )
 
     def test_extract_results_from_loggers_list_output(self) -> None:
         m = TestHelperModules.Conv2dWithSplit()
         example_inputs = m.example_inputs()
-        ep = export_for_training(m, example_inputs)
+        ep = export_for_training(m, example_inputs, strict=True)
         generate_numeric_debug_handle(ep)
         m = ep.module()
         m_ref_logger = prepare_for_propagation_comparison(m)  # pyre-ignore[6]

@@ -808,7 +806,8 @@ def test_extract_results_from_loggers_list_output(self) -> None:
         ref_results = extract_results_from_loggers(m_ref_logger)
         quant_results = extract_results_from_loggers(m_quant_logger)
         comparison_results = compare_results(
-            ref_results, quant_results  # pyre-ignore[6]
+            ref_results,
+            quant_results,  # pyre-ignore[6]
         )
         for node_summary in comparison_results.values():
             if len(node_summary.results) > 0:

backends/xnnpack/test/quantizer/test_representation.py

Lines changed: 1 addition & 4 deletions

@@ -33,10 +33,7 @@ def _test_representation(
     ) -> None:
         # resetting dynamo cache
        torch._dynamo.reset()
-        model = export_for_training(
-            model,
-            example_inputs,
-        ).module()
+        model = export_for_training(model, example_inputs, strict=True).module()
        model_copy = copy.deepcopy(model)
 
        model = prepare_pt2e(model, quantizer)  # pyre-ignore[6]

backends/xnnpack/test/quantizer/test_xnnpack_quantizer.py

Lines changed: 5 additions & 13 deletions

@@ -361,7 +361,7 @@ def forward(self, x):
         )
         example_inputs = (torch.randn(2, 2),)
         m = M().eval()
-        m = export_for_training(m, example_inputs).module()
+        m = export_for_training(m, example_inputs, strict=True).module()
         m = prepare_pt2e(m, quantizer)  # pyre-ignore[6]
         # Use a linear count instead of names because the names might change, but
         # the order should be the same.

@@ -497,10 +497,7 @@ def test_propagate_annotation(self):
         example_inputs = (torch.randn(1, 3, 5, 5),)
 
         # program capture
-        m = export_for_training(
-            m,
-            example_inputs,
-        ).module()
+        m = export_for_training(m, example_inputs, strict=True).module()
 
         m = prepare_pt2e(m, quantizer)
         m(*example_inputs)

@@ -766,8 +763,7 @@ def forward(self, input_tensor, hidden_tensor):
 
         with torchdynamo.config.patch(allow_rnn=True):
             model_graph = export_for_training(
-                model_graph,
-                example_inputs,
+                model_graph, example_inputs, strict=True
             ).module()
             quantizer = XNNPACKQuantizer()
             quantization_config = get_symmetric_quantization_config(

@@ -829,8 +825,7 @@ def forward(self, input_tensor, hidden_tensor):
 
         with torchdynamo.config.patch(allow_rnn=True):
             model_graph = export_for_training(
-                model_graph,
-                example_inputs,
+                model_graph, example_inputs, strict=True
             ).module()
             quantizer = XNNPACKQuantizer()
             quantization_config = get_symmetric_quantization_config(

@@ -1039,10 +1034,7 @@ def test_resnet18(self):
         m = torchvision.models.resnet18().eval()
         m_copy = copy.deepcopy(m)
         # program capture
-        m = export_for_training(
-            m,
-            example_inputs,
-        ).module()
+        m = export_for_training(m, example_inputs, strict=True).module()
 
         quantizer = XNNPACKQuantizer()
         quantization_config = get_symmetric_quantization_config(is_per_channel=True)

backends/xnnpack/test/test_xnnpack_utils.py

Lines changed: 1 addition & 4 deletions

@@ -317,10 +317,7 @@ def quantize_and_test_model_with_quantizer(
         module.eval()
         # program capture
 
-        m = export_for_training(
-            module,
-            example_inputs,
-        ).module()
+        m = export_for_training(module, example_inputs, strict=True).module()
 
         quantizer = XNNPACKQuantizer()
         quantization_config = get_symmetric_quantization_config()
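
The test hunks above all wrap the newly pinned export call in the same PT2E quantization flow: capture, prepare, calibrate, convert. A condensed sketch of that flow; the toy module is illustrative and the quantizer import path is assumed from the ExecuTorch XNNPACK backend:

import torch
from torch.export import export_for_training
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
# Assumed import path for the quantizer used in these tests:
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

class M(torch.nn.Module):  # toy module for illustration
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(2, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)

example_inputs = (torch.randn(2, 2),)

# program capture, with strict tracing pinned as in this commit
m = export_for_training(M().eval(), example_inputs, strict=True).module()

quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
m = prepare_pt2e(m, quantizer)  # insert observers
m(*example_inputs)              # calibrate on sample inputs
m = convert_pt2e(m)             # replace observers with quant/dequant ops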

backends/xnnpack/test/tester/tester.py

Lines changed: 1 addition & 1 deletion

@@ -166,7 +166,7 @@ def run(
         self, artifact: torch.nn.Module, inputs: Optional[Tuple[torch.Tensor]]
     ) -> None:
         assert inputs is not None
-        captured_graph = export_for_training(artifact, inputs).module()
+        captured_graph = export_for_training(artifact, inputs, strict=True).module()
 
         assert isinstance(captured_graph, torch.fx.GraphModule)
         prepared = prepare_pt2e(captured_graph, self.quantizer)

docs/source/tutorials_source/export-to-executorch-tutorial.py

Lines changed: 4 additions & 2 deletions

@@ -190,7 +190,9 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
 from torch.export import export_for_training
 
 example_args = (torch.randn(1, 3, 256, 256),)
-pre_autograd_aten_dialect = export_for_training(SimpleConv(), example_args).module()
+pre_autograd_aten_dialect = export_for_training(
+    SimpleConv(), example_args, strict=True
+).module()
 print("Pre-Autograd ATen Dialect Graph")
 print(pre_autograd_aten_dialect)
 

@@ -555,7 +557,7 @@ def forward(self, x):
 
 
 example_args = (torch.randn(3, 4),)
-pre_autograd_aten_dialect = export_for_training(M(), example_args).module()
+pre_autograd_aten_dialect = export_for_training(M(), example_args, strict=True).module()
 # Optionally do quantization:
 # pre_autograd_aten_dialect = convert_pt2e(prepare_pt2e(pre_autograd_aten_dialect, CustomBackendQuantizer))
 aten_dialect = export(pre_autograd_aten_dialect, example_args, strict=True)

examples/apple/mps/scripts/mps_example.py

Lines changed: 3 additions & 1 deletion

@@ -166,7 +166,9 @@ def get_model_config(args):
 
     # pre-autograd export. eventually this will become torch.export
     with torch.no_grad():
-        model = torch.export.export_for_training(model, example_inputs).module()
+        model = torch.export.export_for_training(
+            model, example_inputs, strict=True
+        ).module()
         edge: EdgeProgramManager = export_to_edge(
             model,
             example_inputs,

examples/arm/aot_arm_compiler.py

Lines changed: 6 additions & 3 deletions

@@ -224,7 +224,6 @@ def forward(self, x):
 
 
 class MultipleOutputsModule(torch.nn.Module):
-
     def forward(self, x: torch.Tensor, y: torch.Tensor):
         return (x * y, x.sum(dim=-1, keepdim=True))
 

@@ -648,7 +647,9 @@ def to_edge_TOSA_delegate(
     )
     model_int8 = model
     # Wrap quantized model back into an exported_program
-    exported_program = torch.export.export_for_training(model, example_inputs)
+    exported_program = torch.export.export_for_training(
+        model, example_inputs, strict=True
+    )
 
     if args.intermediates:
         os.makedirs(args.intermediates, exist_ok=True)

@@ -681,7 +682,9 @@ def to_edge_TOSA_delegate(
 
     # export_for_training under the assumption we quantize, the exported form also works
     # in to_edge if we don't quantize
-    exported_program = torch.export.export_for_training(model, example_inputs)
+    exported_program = torch.export.export_for_training(
+        model, example_inputs, strict=True
+    )
     model = exported_program.module()
     model_fp32 = model
 