Skip to content

Update docs in executorch, remove capture_pre_autograd_graph references #6613

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions backends/apple/coreml/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ To quantize a Program in a Core ML favored way, the client may utilize **CoreMLQ
import torch
import executorch.exir

from torch._export import capture_pre_autograd_graph
from torch.export import export_for_training
from torch.ao.quantization.quantize_pt2e import (
convert_pt2e,
prepare_pt2e,
Expand Down Expand Up @@ -93,7 +93,7 @@ class Model(torch.nn.Module):
source_model = Model()
example_inputs = (torch.randn((1, 3, 256, 256)), )

pre_autograd_aten_dialect = capture_pre_autograd_graph(model, example_inputs)
pre_autograd_aten_dialect = export_for_training(model, example_inputs).module()

quantization_config = LinearQuantizerConfig.from_dict(
{
Expand Down
11 changes: 5 additions & 6 deletions docs/source/llm/getting-started.md
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,7 @@ import torch

from executorch.exir import EdgeCompileConfig, to_edge
from torch.nn.attention import sdpa_kernel, SDPBackend
from torch._export import capture_pre_autograd_graph
from torch.export import export
from torch.export import export, export_for_training

from model import GPT

Expand All @@ -170,7 +169,7 @@ dynamic_shape = (
# Trace the model, converting it to a portable intermediate representation.
# The torch.no_grad() call tells PyTorch to exclude training-specific logic.
with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
m = capture_pre_autograd_graph(model, example_inputs, dynamic_shapes=dynamic_shape)
m = export_for_training(model, example_inputs, dynamic_shapes=dynamic_shape).module()
traced_model = export(m, example_inputs, dynamic_shapes=dynamic_shape)

# Convert the model into a runnable ExecuTorch program.
Expand Down Expand Up @@ -462,7 +461,7 @@ from executorch.exir import EdgeCompileConfig, to_edge
import torch
from torch.export import export
from torch.nn.attention import sdpa_kernel, SDPBackend
from torch._export import capture_pre_autograd_graph
from torch.export import export_for_training

from model import GPT

Expand All @@ -489,7 +488,7 @@ dynamic_shape = (
# Trace the model, converting it to a portable intermediate representation.
# The torch.no_grad() call tells PyTorch to exclude training-specific logic.
with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
m = capture_pre_autograd_graph(model, example_inputs, dynamic_shapes=dynamic_shape)
m = export_for_training(model, example_inputs, dynamic_shapes=dynamic_shape).module()
traced_model = export(m, example_inputs, dynamic_shapes=dynamic_shape)

# Convert the model into a runnable ExecuTorch program.
Expand Down Expand Up @@ -635,7 +634,7 @@ xnnpack_quant_config = get_symmetric_quantization_config(
xnnpack_quantizer = XNNPACKQuantizer()
xnnpack_quantizer.set_global(xnnpack_quant_config)

m = capture_pre_autograd_graph(model, example_inputs)
m = export_for_training(model, example_inputs).module()

# Annotate the model for quantization. This prepares the model for calibration.
m = prepare_pt2e(m, xnnpack_quantizer)
Expand Down
Loading