
Commit 553eb7e

mergennachin authored and facebook-github-bot committed
Rename capture_pre_autograd_graph private method (#6214)
Summary: These are mostly no-op changes; the underlying function has already been changed. A git grep for `capture_pre_autograd_graph` still finds these helper methods, but the migration is in fact already done, so let's rename the method. This is not going to be cherry-picked, but fixed in main.

Reviewed By: helunwencser

Differential Revision: D64370367
1 parent e342a92 commit 553eb7e
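For context, a minimal sketch of what the rename means at a call site, mirroring the export_llama_lib.py hunk below; `modelname`, `args`, and `quantizers` are whatever that file already has in scope:

# Before this commit the pre-autograd export step was spelled out explicitly:
#     _prepare_for_llama_export(modelname, args).capture_pre_autograd_graph().pt2e_quantize(quantizers)
# After this commit the same step is just export():
builder_exported_to_edge = (
    _prepare_for_llama_export(modelname, args)
    .export()
    .pt2e_quantize(quantizers)
    .export_to_edge()
)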

6 files changed: +13 −15 lines changed

examples/models/llama2/eval_llama_lib.py

Lines changed: 2 additions & 2 deletions
@@ -194,7 +194,7 @@ def gen_eval_wrapper(
     manager: LLMEdgeManager = _prepare_for_llama_export(model_name, args)
 
     if len(quantizers) != 0:
-        manager = manager.capture_pre_autograd_graph().pt2e_quantize(quantizers)
+        manager = manager.export().pt2e_quantize(quantizers)
         model = (
             manager.pre_autograd_graph_module.to(device="cuda")  # pyre-ignore
             if torch.cuda.is_available()
@@ -209,7 +209,7 @@ def gen_eval_wrapper(
         )
     else:
         # TODO: use manager.pre_autograd_graph_module for the eval to remove the if-else branch
-        # for quantizers. Currently capture_pre_autograd_graph only works with --kv_cache, but
+        # for quantizers. Currently export_for_training only works with --kv_cache, but
         # fails without the kv_cache mode
         model = (
             manager.model.eval().to(device="cuda")

examples/models/llama2/export_llama_lib.py

Lines changed: 1 addition & 1 deletion
@@ -581,7 +581,7 @@ def _export_llama(modelname, args) -> LLMEdgeManager: # noqa: C901
     # export_to_edge
     builder_exported_to_edge = (
         _prepare_for_llama_export(modelname, args)
-        .capture_pre_autograd_graph()
+        .export()
         .pt2e_quantize(quantizers)
         .export_to_edge()
     )

examples/models/llava/export_llava.py

Lines changed: 3 additions & 3 deletions
@@ -53,7 +53,7 @@
 
 
 class LlavaEdgeManager(LLMEdgeManager):
-    def capture_pre_autograd_graph(self) -> "LlavaEdgeManager":
+    def export(self) -> "LlavaEdgeManager":
         dynamic_shape = self._get_dynamic_shape()
         # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
         # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
@@ -107,7 +107,7 @@ def forward(self, input_pos, embeddings):
         text_model_em.set_output_dir("./")
         .to_dtype(dtype_override)
         .source_transform(source_transforms)
-        .capture_pre_autograd_graph()
+        .export()
         .pt2e_quantize(quantizers)
     )
 
@@ -148,7 +148,7 @@ def forward(self, images):
             dynamic_shapes=dynamic_shapes,
             args=None,
         )
-        .capture_pre_autograd_graph()
+        .export()
         .pt2e_quantize([quantizer])
     )
 

examples/portable/scripts/export.py

Lines changed: 1 addition & 3 deletions
@@ -65,9 +65,7 @@ def main() -> None:
     backend_config = ExecutorchBackendConfig()
     if args.segment_alignment is not None:
         backend_config.segment_alignment = int(args.segment_alignment, 16)
-    if (
-        dynamic_shapes is not None
-    ):  # capture_pre_autograd_graph does not work with dynamic shapes
+    if dynamic_shapes is not None:
         edge_manager = export_to_edge(
             model,
             example_inputs,

extension/llm/README.md

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@ Commonly used methods in this class include:
 - _source_transform_: execute a series of source transform passes. Some transform passes include
   - weight only quantization, which can be done at source (eager mode) level.
   - replace some torch operators to a custom operator. For example, _replace_sdpa_with_custom_op_.
-- _capture_pre_autograd_graph_: get a graph that is ready for pt2 graph-based quantization.
+- _torch.export_for_training_: get a graph that is ready for pt2 graph-based quantization.
 - _pt2e_quantize_ with passed in quantizers.
   - util functions in _quantizer_lib.py_ can help to get different quantizers based on the needs.
 - _export_to_edge_: export to edge dialect
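
Put together, the methods listed in this README form a pipeline roughly like the sketch below; the manager construction and the `source_transforms`/`quantizers` values are placeholders for illustration, only the chained method names come from the repo:

manager = (
    manager.set_output_dir("./")
    .source_transform(source_transforms)  # eager-mode passes, e.g. _replace_sdpa_with_custom_op_
    .export()                             # pre-autograd graph, ready for pt2 graph-based quantization
    .pt2e_quantize(quantizers)            # quantizers can come from the quantizer_lib.py helpers
    .export_to_edge()                     # lower to edge dialect
)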

extension/llm/export/builder.py

Lines changed: 5 additions & 5 deletions
@@ -82,7 +82,7 @@ def __init__(
         dynamic_shapes: Optional[Any] = None,
     ):
         self.model = model
-        # graph module returned from capture_pre_autograd_graph
+        # graph module returned from export()
         self.pre_autograd_graph_module: Optional[torch.fx.GraphModule] = None
         self.modelname = modelname
         self.max_seq_len = max_seq_len
@@ -176,7 +176,7 @@ def _get_edge_config(self) -> EdgeCompileConfig:
         )
         return edge_config
 
-    def capture_pre_autograd_graph(self) -> "LLMEdgeManager":
+    def export(self) -> "LLMEdgeManager":
         dynamic_shape = self._get_dynamic_shape()
         # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
         # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
@@ -296,7 +296,7 @@ def pt2e_quantize(self, quantizers: Optional[List[Quantizer]]) -> "LLMEdgeManage
                 composed_quantizer = ComposableQuantizer(quantizers)
                 assert (
                     self.pre_autograd_graph_module is not None
-                ), "Please run capture_pre_autograd_graph first"
+                ), "Please run export() first"
                 m = prepare_pt2e(self.pre_autograd_graph_module, composed_quantizer)
                 logging.info(
                     f"Calibrating with tasks: {self.calibration_tasks}, limit: {self.calibration_limit}, calibration_data: {self.calibration_data}, tokenizer_path: {self.tokenizer_path}, seq_length: {self.calibration_seq_length}"
@@ -344,8 +344,8 @@ def export_to_edge(self) -> "LLMEdgeManager":
         # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
         with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
             if self.pre_autograd_graph_module is None:
-                # Run capture_pre_autograd_graph if it didn't run
-                self.capture_pre_autograd_graph()
+                # Run export() if it didn't run
+                self.export()
             self.edge_manager = export_to_edge(
                 self.pre_autograd_graph_module,  # pyre-fixme[6]
                 self.example_inputs,
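
One practical note from these builder.py hunks: pt2e_quantize() asserts that export() has already populated pre_autograd_graph_module, while export_to_edge() runs export() on its own if it hasn't happened yet. A small hedged sketch of both orderings, where `manager` and `quantizers` are placeholders:

# Explicit order: export() first, as the assert in pt2e_quantize() requires.
manager.export().pt2e_quantize(quantizers).export_to_edge()

# Implicit order: export_to_edge() calls export() itself when
# pre_autograd_graph_module is still None (unquantized path).
manager.export_to_edge()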
