@@ -144,8 +144,7 @@ import torch
144
144
145
145
from executorch.exir import EdgeCompileConfig, to_edge
146
146
from torch.nn.attention import sdpa_kernel, SDPBackend
147
- from torch._export import capture_pre_autograd_graph
148
- from torch.export import export
147
+ from torch.export import export, export_for_training
149
148
150
149
from model import GPT
151
150
@@ -170,7 +169,7 @@ dynamic_shape = (
170
169
# Trace the model, converting it to a portable intermediate representation.
171
170
# The torch.no_grad() call tells PyTorch to exclude training-specific logic.
172
171
with torch.nn.attention.sdpa_kernel([SDPBackend.MATH ]), torch.no_grad():
173
- m = capture_pre_autograd_graph (model, example_inputs, dynamic_shapes = dynamic_shape)
172
+ m = export_for_training (model, example_inputs, dynamic_shapes = dynamic_shape).module( )
174
173
traced_model = export(m, example_inputs, dynamic_shapes = dynamic_shape)
175
174
176
175
# Convert the model into a runnable ExecuTorch program.
@@ -462,7 +461,7 @@ from executorch.exir import EdgeCompileConfig, to_edge
462
461
import torch
463
462
from torch.export import export
464
463
from torch.nn.attention import sdpa_kernel, SDPBackend
465
- from torch._export import capture_pre_autograd_graph
464
+ from torch.export import export_for_training
466
465
467
466
from model import GPT
468
467
@@ -489,7 +488,7 @@ dynamic_shape = (
489
488
# Trace the model, converting it to a portable intermediate representation.
490
489
# The torch.no_grad() call tells PyTorch to exclude training-specific logic.
491
490
with torch.nn.attention.sdpa_kernel([SDPBackend.MATH ]), torch.no_grad():
492
- m = capture_pre_autograd_graph (model, example_inputs, dynamic_shapes = dynamic_shape)
491
+ m = export_for_training (model, example_inputs, dynamic_shapes = dynamic_shape).module( )
493
492
traced_model = export(m, example_inputs, dynamic_shapes = dynamic_shape)
494
493
495
494
# Convert the model into a runnable ExecuTorch program.
@@ -635,7 +634,7 @@ xnnpack_quant_config = get_symmetric_quantization_config(
635
634
xnnpack_quantizer = XNNPACKQuantizer()
636
635
xnnpack_quantizer.set_global(xnnpack_quant_config)
637
636
638
- m = capture_pre_autograd_graph (model, example_inputs)
637
+ m = export_for_training (model, example_inputs).module( )
639
638
640
639
# Annotate the model for quantization. This prepares the model for calibration.
641
640
m = prepare_pt2e(m, xnnpack_quantizer)
0 commit comments