
Commit 4629f3a

yushangdi authored and facebook-github-bot committed
Migrate to training IR in executorch, part 2 (#5950)
Summary: Pull Request resolved: #5950. As title.

Reviewed By: mergennachin, tugsbayasgalan

Differential Revision: D64004508

fbshipit-source-id: f605312c57c92b7f9e81e1283748a25466cbbc84
1 parent 085193e commit 4629f3a
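
Every file in this commit makes the same substitution: the removed torch._export.capture_pre_autograd_graph call is replaced by torch.export.export_for_training, and .module() is taken on the resulting ExportedProgram. A minimal before/after sketch, using a placeholder TinyModel rather than any model touched by this diff:

import torch
from torch.export import export_for_training

class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)

model = TinyModel().eval()
example_inputs = (torch.randn(1, 4),)

# Old API removed by this commit:
#   from torch._export import capture_pre_autograd_graph
#   gm = capture_pre_autograd_graph(model, example_inputs)

# New API: export_for_training returns an ExportedProgram, and .module()
# yields the pre-autograd GraphModule that the PT2E quantization APIs expect.
gm = export_for_training(model, example_inputs).module()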


9 files changed: +21, -21 lines changed


backends/apple/coreml/test/test_coreml_quantizer.py

Lines changed: 2 additions & 2 deletions
@@ -15,12 +15,12 @@
)

from executorch.backends.apple.coreml.quantizer import CoreMLQuantizer
-from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import (
    convert_pt2e,
    prepare_pt2e,
    prepare_qat_pt2e,
)
+from torch.export import export_for_training


class TestCoreMLQuantizer:
@@ -32,7 +32,7 @@ def quantize_and_compare(
    ) -> None:
        assert quantization_type in {"PTQ", "QAT"}

-        pre_autograd_aten_dialect = capture_pre_autograd_graph(model, example_inputs)
+        pre_autograd_aten_dialect = export_for_training(model, example_inputs).module()

        quantization_config = LinearQuantizerConfig.from_dict(
            {
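
The surrounding PT2E flow in this test is unchanged; only the capture call moves to the new API. A hedged sketch of that flow, substituting XNNPACKQuantizer (the quantizer used in exir/tests/test_quantization.py further down) for CoreMLQuantizer so the snippet stays self-contained; the model and inputs are placeholders:

import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)
from torch.export import export_for_training

model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()).eval()
example_inputs = (torch.randn(1, 3, 16, 16),)

# Capture the pre-autograd ATen dialect graph, as the updated test now does.
pre_autograd_aten_dialect = export_for_training(model, example_inputs).module()

quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
prepared = prepare_pt2e(pre_autograd_aten_dialect, quantizer)  # insert observers
prepared(*example_inputs)  # calibrate on example data
quantized = convert_pt2e(prepared)  # final quantized GraphModule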

backends/apple/mps/test/test_mps_utils.py

Lines changed: 2 additions & 2 deletions
@@ -209,9 +209,9 @@ def lower_module_and_test_output(

        expected_output = model(*sample_inputs)

-        model = torch._export.capture_pre_autograd_graph(
+        model = torch.export.export_for_training(
            model, sample_inputs, dynamic_shapes=dynamic_shapes
-        )
+        ).module()

        edge_program = export_to_edge(
            model,

backends/mediatek/quantizer/annotator.py

Lines changed: 3 additions & 3 deletions
@@ -7,8 +7,6 @@
from typing import Callable, List

import torch
-
-from torch._export import capture_pre_autograd_graph
from torch._ops import OpOverload
from torch._subclasses import FakeTensor

@@ -17,6 +15,8 @@
    _annotate_input_qspec_map,
    _annotate_output_qspec,
)
+
+from torch.export import export_for_training
from torch.fx import Graph, Node
from torch.fx.passes.utils.matcher_with_name_node_map_utils import (
    SubgraphMatcherWithNameNodeMap,
@@ -159,7 +159,7 @@ def forward(self, x):
            return norm, {}

    for pattern_cls in (ExecuTorchPattern, MTKPattern):
-        pattern_gm = capture_pre_autograd_graph(pattern_cls(), (torch.randn(3, 3),))
+        pattern_gm = export_for_training(pattern_cls(), (torch.randn(3, 3),)).module()
        matcher = SubgraphMatcherWithNameNodeMap(
            pattern_gm, ignore_literals=True, remove_overlapping_matches=False
        )
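
The annotator feeds the captured pattern into torch.fx's subgraph matcher. A rough sketch of that mechanism under the new API; LayerNormPattern and Target below are illustrative stand-ins, not the pattern classes defined in annotator.py:

import torch
from torch.export import export_for_training
from torch.fx.passes.utils.matcher_with_name_node_map_utils import (
    SubgraphMatcherWithNameNodeMap,
)

class LayerNormPattern(torch.nn.Module):
    def forward(self, x):
        norm = torch.nn.functional.layer_norm(x, [3])
        # The matcher expects (output, name_node_map); annotator.py returns an
        # empty map, so this sketch does the same.
        return norm, {}

class Target(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.layer_norm(x + 1, [3])

pattern_gm = export_for_training(LayerNormPattern(), (torch.randn(3, 3),)).module()
target_gm = export_for_training(Target(), (torch.randn(3, 3),)).module()

matcher = SubgraphMatcherWithNameNodeMap(
    pattern_gm, ignore_literals=True, remove_overlapping_matches=False
)
matches = matcher.match(target_gm.graph)  # InternalMatch entries for each hit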

backends/transforms/test/test_duplicate_dynamic_quant_chain.py

Lines changed: 2 additions & 3 deletions
@@ -8,7 +8,6 @@
import unittest

import torch
-import torch._export as export
from executorch.backends.transforms.duplicate_dynamic_quant_chain import (
    DuplicateDynamicQuantChainPass,
)
@@ -59,10 +58,10 @@ def _test_duplicate_chain(

        # program capture
        m = copy.deepcopy(m_eager)
-        m = export.capture_pre_autograd_graph(
+        m = torch.export.export_for_training(
            m,
            example_inputs,
-        )
+        ).module()

        m = prepare_pt2e(m, quantizer)
        # Calibrate

examples/llm_manual/export_nanogpt.py

Lines changed: 4 additions & 3 deletions
@@ -15,8 +15,7 @@
from executorch.exir import to_edge

from model import GPT
-from torch._export import capture_pre_autograd_graph
-from torch.export import export
+from torch.export import export, export_for_training
from torch.nn.attention import sdpa_kernel, SDPBackend

model = GPT.from_pretrained("gpt2")  # use gpt2 weight as pretrained weight
@@ -28,7 +27,9 @@
# Trace the model, converting it to a portable intermediate representation.
# The torch.no_grad() call tells PyTorch to exclude training-specific logic.
with sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
-    m = capture_pre_autograd_graph(model, example_inputs, dynamic_shapes=dynamic_shape)
+    m = export_for_training(
+        model, example_inputs, dynamic_shapes=dynamic_shape
+    ).module()
    traced_model = export(m, example_inputs, dynamic_shapes=dynamic_shape)

# Convert the model into a runnable ExecuTorch program.
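
The nanogpt example now captures with export_for_training and then re-exports the returned module with torch.export.export before lowering. A self-contained sketch of that two-step flow with a dynamic sequence dimension; TinyLM and the max token length are placeholders for GPT and its block size:

import torch
from executorch.exir import to_edge
from torch.export import Dim, export, export_for_training

class TinyLM(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = torch.nn.Embedding(128, 16)
        self.head = torch.nn.Linear(16, 128)

    def forward(self, tokens):
        return self.head(self.emb(tokens))

model = TinyLM().eval()
example_inputs = (torch.randint(0, 128, (1, 8)),)
# Mark the sequence dimension as dynamic, mirroring the nanogpt example.
dynamic_shape = ({1: Dim("token_dim", max=127)},)

with torch.no_grad():
    m = export_for_training(
        model, example_inputs, dynamic_shapes=dynamic_shape
    ).module()
    traced_model = export(m, example_inputs, dynamic_shapes=dynamic_shape)

edge_program = to_edge(traced_model)
et_program = edge_program.to_executorch()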

examples/mediatek/aot_utils/oss_utils/utils.py

Lines changed: 1 addition & 1 deletion
@@ -30,7 +30,7 @@ def build_executorch_binary(
        if quant_dtype not in Precision:
            raise AssertionError(f"No support for Precision {quant_dtype}.")

-        captured_model = torch._export.capture_pre_autograd_graph(model, inputs)
+        captured_model = torch.export.export_for_training(model, inputs).module()
        annotated_model = prepare_pt2e(captured_model, quantizer)
        print("Quantizing the model...")
        # calibration

examples/mediatek/model_export_scripts/llama.py

Lines changed: 2 additions & 2 deletions
@@ -318,9 +318,9 @@ def export_to_et_ir(
        max_num_token, max_cache_size, True
    )
    print("Getting pre autograd ATen Dialect Graph")
-    pre_autograd_aten_dialect = torch._export.capture_pre_autograd_graph(
+    pre_autograd_aten_dialect = torch.export.export_for_training(
        model, example_inputs, dynamic_shapes=dynamic_shapes
-    )  # NOTE: Will be replaced with export
+    ).module()  # NOTE: Will be replaced with export
    quantizer = NeuropilotQuantizer()
    quantizer.setup_precision(getattr(Precision, precision))
    prepared_graph = prepare_pt2e(pre_autograd_aten_dialect, quantizer)

exir/tests/test_quantization.py

Lines changed: 2 additions & 2 deletions
@@ -51,9 +51,9 @@ def test_resnet(self) -> None:
        m = torchvision.models.resnet18().eval()
        m_copy = copy.deepcopy(m)
        # program capture
-        m = torch._export.capture_pre_autograd_graph(
+        m = torch.export.export_for_training(
            m, copy.deepcopy(example_inputs)
-        )
+        ).module()

        quantizer = XNNPACKQuantizer()
        operator_config = get_symmetric_quantization_config(is_per_channel=True)

extension/llm/export/builder.py

Lines changed: 3 additions & 3 deletions
@@ -29,10 +29,10 @@

from executorch.extension.export_util.utils import export_to_edge, save_pte_program
from executorch.extension.llm.tokenizer.utils import get_tokenizer
-from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer import Quantizer
from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer
+from torch.export import export_for_training
from torch.nn.attention import SDPBackend

FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
@@ -190,9 +190,9 @@ def capture_pre_autograd_graph(self) -> "LLMEdgeManager":
                    strict=True,
                ).module()
            else:
-                self.pre_autograd_graph_module = capture_pre_autograd_graph(
+                self.pre_autograd_graph_module = export_for_training(
                    self.model, self.example_inputs, dynamic_shapes=dynamic_shape
-                )
+                ).module()

        return self
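
This hunk sits in LLMEdgeManager's capture step, which keeps two branches: one already exporting with torch.export.export(..., strict=True).module(), and the default branch now using export_for_training. A hedged sketch of that shape; the standalone capture helper and its use_strict_export flag are illustrative, not the actual condition or method in builder.py:

import torch
from torch.export import export, export_for_training

def capture(model, example_inputs, dynamic_shape=None, use_strict_export=False):
    # Either branch yields a pre-autograd GraphModule ready for PT2E
    # quantization or further lowering to edge dialect.
    if use_strict_export:
        ep = export(model, example_inputs, dynamic_shapes=dynamic_shape, strict=True)
    else:
        ep = export_for_training(model, example_inputs, dynamic_shapes=dynamic_shape)
    return ep.module()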
