
Commit 126eebf

jackzhxng authored and facebook-github-bot committed
Add kwarg example inputs to eager model base (#5765)
Summary: For situations where the forward has non-positional arguments, such as https://github.com/pytorch/torchtune/blob/3c450ef5f1fbe8237f899e942fd5222491a47ca7/torchtune/modules/transformer.py#L519

PR chain:
- **YOU ARE HERE ~>** [Add kwarg example inputs to eager model base](#5765)
- [Llama2 model cleanup](#5859)
- [Accept model type parameter in export_llama](#5910)
- [Export TorchTune llama3_2_vision in ET](#5911)
- [Add et version of TorchTune MHA for swapping with custom op](#5912)

Test Plan: Exported the Stories110M model.
```
wget "https://huggingface.co/karpathy/tinyllamas/resolve/main/stories110M.pt"
echo '{"dim": 768, "multiple_of": 32, "n_heads": 12, "n_layers": 12, "norm_eps": 1e-05, "vocab_size": 32000}' > params.json
python -m examples.models.llama2.export_llama -c stories110M.pt -p params.json -X -kv
```

Reviewed By: tarun292

Differential Revision: D64027696

Pulled By: dvorjackz
1 parent ed9f50f commit 126eebf
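For context, here is a minimal sketch (not part of this commit) of the kind of forward signature the change targets and how an eager model wrapper could expose example keyword inputs next to its positional ones. The class name, shapes, and hook names other than `get_example_inputs`/`get_example_kwarg_inputs` are hypothetical placeholders.

```python
# Minimal sketch, assuming a toy model; names and shapes are illustrative only.
from typing import Any, Dict, Optional, Tuple

import torch


class ToyAttentionModel(torch.nn.Module):
    def forward(
        self, tokens: torch.Tensor, *, input_pos: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        # `input_pos` is keyword-only, like the TorchTune transformer forward referenced
        # in the summary; positional example inputs alone cannot describe it.
        return tokens if input_pos is None else tokens + input_pos

    def get_example_inputs(self) -> Tuple[Any, ...]:
        # Positional example inputs, as before.
        return (torch.zeros(1, 8, dtype=torch.long),)

    def get_example_kwarg_inputs(self) -> Dict[str, Any]:
        # Keyword example inputs, which this commit threads through export.
        return {"input_pos": torch.zeros(1, 8, dtype=torch.long)}
```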

17 files changed: +69 −36 lines changed


examples/apple/coreml/scripts/debugger_cli.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -149,7 +149,7 @@ def main() -> None:
         root_dir_path=get_root_dir_path(), conda_env_name=args.conda_environment_name
     )
 
-    model, example_inputs, _ = EagerModelFactory.create_model(
+    model, example_inputs, _, _ = EagerModelFactory.create_model(
         *MODEL_NAME_TO_MODEL[args.model_name]
     )
```

examples/apple/coreml/scripts/export.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -158,7 +158,7 @@ def main():
             f"Valid compute units are {valid_compute_units}."
         )
 
-    model, example_inputs, _ = EagerModelFactory.create_model(
+    model, example_inputs, _, _ = EagerModelFactory.create_model(
         *MODEL_NAME_TO_MODEL[args.model_name]
     )
```

examples/apple/mps/scripts/mps_example.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -152,7 +152,7 @@ def get_model_config(args):
         raise RuntimeError(f"Available models are {list(MODEL_NAME_TO_MODEL.keys())}.")
 
     model_config = get_model_config(args)
-    model, example_inputs, _ = EagerModelFactory.create_model(**model_config)
+    model, example_inputs, _, _ = EagerModelFactory.create_model(**model_config)
 
     model = model.eval()
```

examples/arm/aot_arm_compiler.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -50,7 +50,7 @@ def get_model_and_inputs_from_name(model_name: str):
         logging.warning(
             "Using a model from examples/models not all of these are currently supported"
         )
-        model, example_inputs, _ = EagerModelFactory.create_model(
+        model, example_inputs, _, _ = EagerModelFactory.create_model(
             *MODEL_NAME_TO_MODEL[model_name]
         )
     # Case 3: Model is in an external python file loaded as a module.
```

examples/devtools/scripts/export_bundled_program.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -139,7 +139,7 @@ def main() -> None:
             f"Available models are {list(MODEL_NAME_TO_MODEL.keys())}."
         )
 
-    model, example_inputs, _ = EagerModelFactory.create_model(
+    model, example_inputs, _, _ = EagerModelFactory.create_model(
         *MODEL_NAME_TO_MODEL[args.model_name]
     )
```

examples/devtools/scripts/gen_sample_etrecord.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -74,7 +74,7 @@ def main() -> None:
             f"Available models are {list(MODEL_NAME_TO_MODEL.keys())}."
         )
 
-    model, example_inputs, _ = EagerModelFactory.create_model(
+    model, example_inputs, _, _ = EagerModelFactory.create_model(
         *MODEL_NAME_TO_MODEL[args.model_name]
     )
```

examples/models/llama2/export_llama_lib.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -780,7 +780,7 @@ def _load_llama_model(
     logging.info(
         f"Loading model with checkpoint={checkpoint}, params={params_path}, use_kv_cache={use_kv_cache}, weight_type={weight_type}"
     )
-    model, example_inputs, _ = EagerModelFactory.create_model(
+    model, example_inputs, example_kwarg_inputs, _ = EagerModelFactory.create_model(
         "llama2",
         "Llama2Model",
         checkpoint=checkpoint,
@@ -830,6 +830,7 @@ def _load_llama_model(
         use_kv_cache=use_kv_cache,
         generate_full_logits=generate_full_logits,
         example_inputs=example_inputs,
+        example_kwarg_inputs=example_kwarg_inputs,
         enable_dynamic_shape=enable_dynamic_shape,
         calibration_tasks=calibration_tasks,
         calibration_limit=calibration_limit,
```

examples/models/llama2/runner/eager.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -31,7 +31,7 @@ def __init__(self, args):
             **params,
         )
         super().__init__(tokenizer_path=args.tokenizer, model_args=model_args)
-        self.model, _, _ = EagerModelFactory.create_model(
+        self.model, _, _, _ = EagerModelFactory.create_model(
             "llama2",
             "Llama2Model",
             checkpoint=args.checkpoint,
```

examples/models/model_factory.py

Lines changed: 13 additions & 9 deletions

```diff
@@ -6,7 +6,7 @@
 
 import importlib
 import os
-from typing import Any, Tuple
+from typing import Any, Dict, Tuple
 
 import torch
 
@@ -19,7 +19,7 @@ class EagerModelFactory:
     @staticmethod
     def create_model(
         module_name, model_class_name, **kwargs
-    ) -> Tuple[torch.nn.Module, Any, Any]:
+    ) -> Tuple[torch.nn.Module, Tuple[Any], Dict[str, Any], Any]:
         """
         Create an instance of a model class that implements EagerModelBase and retrieve related data.
 
@@ -42,14 +42,18 @@ def create_model(
         if hasattr(module, model_class_name):
             model_class = getattr(module, model_class_name)
             model = model_class(**kwargs)
+            example_kwarg_inputs = None
+            dynamic_shapes = None
+            if hasattr(model, "get_example_kwarg_inputs()"):
+                example_kwarg_inputs = model.get_example_kwarg_inputs()
             if hasattr(model, "get_dynamic_shapes"):
-                return (
-                    model.get_eager_model(),
-                    model.get_example_inputs(),
-                    model.get_dynamic_shapes(),
-                )
-            else:
-                return model.get_eager_model(), model.get_example_inputs(), None
+                dynamic_shapes = model.get_dynamic_shapes()
+            return (
+                model.get_eager_model(),
+                model.get_example_inputs(),
+                example_kwarg_inputs,
+                dynamic_shapes,
+            )
 
         raise ValueError(
             f"Model class '{model_class_name}' not found in module '{module_name}'."
```

examples/models/test/test_export.py

Lines changed: 8 additions & 8 deletions

```diff
@@ -69,7 +69,7 @@ def validate_tensor_allclose(
         return self.assertTrue(result)
 
     def test_mv3_export_to_executorch(self):
-        eager_model, example_inputs, _ = EagerModelFactory.create_model(
+        eager_model, example_inputs, _, _ = EagerModelFactory.create_model(
             *MODEL_NAME_TO_MODEL["mv3"]
         )
         eager_output, executorch_output = self.collect_executorch_and_eager_outputs(
@@ -81,7 +81,7 @@ def test_mv3_export_to_executorch(self):
         )
 
     def test_mv2_export_to_executorch(self):
-        eager_model, example_inputs, _ = EagerModelFactory.create_model(
+        eager_model, example_inputs, _, _ = EagerModelFactory.create_model(
             *MODEL_NAME_TO_MODEL["mv2"]
         )
         eager_output, executorch_output = self.collect_executorch_and_eager_outputs(
@@ -90,7 +90,7 @@ def test_mv2_export_to_executorch(self):
         self.validate_tensor_allclose(eager_output, executorch_output[0])
 
     def test_vit_export_to_executorch(self):
-        eager_model, example_inputs, _ = EagerModelFactory.create_model(
+        eager_model, example_inputs, _, _ = EagerModelFactory.create_model(
             *MODEL_NAME_TO_MODEL["vit"]
         )
         eager_output, executorch_output = self.collect_executorch_and_eager_outputs(
@@ -102,7 +102,7 @@ def test_vit_export_to_executorch(self):
         )
 
     def test_w2l_export_to_executorch(self):
-        eager_model, example_inputs, _ = EagerModelFactory.create_model(
+        eager_model, example_inputs, _, _ = EagerModelFactory.create_model(
             *MODEL_NAME_TO_MODEL["w2l"]
         )
         eager_output, executorch_output = self.collect_executorch_and_eager_outputs(
@@ -111,7 +111,7 @@ def test_w2l_export_to_executorch(self):
         self.validate_tensor_allclose(eager_output, executorch_output[0])
 
     def test_ic3_export_to_executorch(self):
-        eager_model, example_inputs, _ = EagerModelFactory.create_model(
+        eager_model, example_inputs, _, _ = EagerModelFactory.create_model(
             *MODEL_NAME_TO_MODEL["ic3"]
         )
         eager_output, executorch_output = self.collect_executorch_and_eager_outputs(
@@ -123,7 +123,7 @@ def test_ic3_export_to_executorch(self):
         )
 
     def test_resnet18_export_to_executorch(self):
-        eager_model, example_inputs, _ = EagerModelFactory.create_model(
+        eager_model, example_inputs, _, _ = EagerModelFactory.create_model(
             *MODEL_NAME_TO_MODEL["resnet18"]
         )
         eager_output, executorch_output = self.collect_executorch_and_eager_outputs(
@@ -132,7 +132,7 @@ def test_resnet18_export_to_executorch(self):
         self.validate_tensor_allclose(eager_output, executorch_output[0])
 
     def test_resnet50_export_to_executorch(self):
-        eager_model, example_inputs, _ = EagerModelFactory.create_model(
+        eager_model, example_inputs, _, _ = EagerModelFactory.create_model(
             *MODEL_NAME_TO_MODEL["resnet50"]
         )
         eager_output, executorch_output = self.collect_executorch_and_eager_outputs(
@@ -141,7 +141,7 @@ def test_resnet50_export_to_executorch(self):
         self.validate_tensor_allclose(eager_output, executorch_output[0])
 
     def test_dl3_export_to_executorch(self):
-        eager_model, example_inputs, _ = EagerModelFactory.create_model(
+        eager_model, example_inputs, _, _ = EagerModelFactory.create_model(
             *MODEL_NAME_TO_MODEL["dl3"]
         )
         eager_output, executorch_output = self.collect_executorch_and_eager_outputs(
```

examples/portable/scripts/export.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -58,7 +58,7 @@ def main() -> None:
             f"Available models are {list(MODEL_NAME_TO_MODEL.keys())}."
         )
 
-    model, example_inputs, dynamic_shapes = EagerModelFactory.create_model(
+    model, example_inputs, _, dynamic_shapes = EagerModelFactory.create_model(
         *MODEL_NAME_TO_MODEL[args.model_name]
     )
```

examples/portable/scripts/export_and_delegate.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -57,7 +57,7 @@ def export_composite_module_with_lower_graph():
         "Running the example to export a composite module with lowered graph..."
     )
 
-    m, m_inputs, _ = EagerModelFactory.create_model(*MODEL_NAME_TO_MODEL["add_mul"])
+    m, m_inputs, _, _ = EagerModelFactory.create_model(*MODEL_NAME_TO_MODEL["add_mul"])
     m_compile_spec = m.get_compile_spec()
 
     # pre-autograd export. eventually this will become torch.export
@@ -166,7 +166,7 @@ def export_and_lower_the_whole_graph():
     """
     logging.info("Running the example to export and lower the whole graph...")
 
-    m, m_inputs, _ = EagerModelFactory.create_model(*MODEL_NAME_TO_MODEL["add_mul"])
+    m, m_inputs, _, _ = EagerModelFactory.create_model(*MODEL_NAME_TO_MODEL["add_mul"])
     m_compile_spec = m.get_compile_spec()
 
     m_inputs = m.get_example_inputs()
```

examples/qualcomm/scripts/export_example.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -58,7 +58,7 @@ def main() -> None:
             f"Available models are {list(MODEL_NAME_TO_MODEL.keys())}."
         )
 
-    model, example_inputs, _ = EagerModelFactory.create_model(
+    model, example_inputs, _, _ = EagerModelFactory.create_model(
         *MODEL_NAME_TO_MODEL[args.model_name]
     )
```

examples/xnnpack/aot_compiler.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -75,7 +75,7 @@
             f"Available models are {list(MODEL_NAME_TO_OPTIONS.keys())}."
         )
 
-    model, example_inputs, _ = EagerModelFactory.create_model(
+    model, example_inputs, _, _ = EagerModelFactory.create_model(
         *MODEL_NAME_TO_MODEL[args.model_name]
     )
```

examples/xnnpack/quantization/example.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -162,7 +162,7 @@ def main() -> None:
     )
 
     start = time.perf_counter()
-    model, example_inputs, _ = EagerModelFactory.create_model(
+    model, example_inputs, _, _ = EagerModelFactory.create_model(
         *MODEL_NAME_TO_MODEL[args.model_name]
    )
     end = time.perf_counter()
```

extension/export_util/utils.py

Lines changed: 24 additions & 3 deletions

```diff
@@ -26,6 +26,8 @@
 def _to_core_aten(
     model: Union[torch.fx.GraphModule, torch.nn.Module],
     example_inputs: Tuple[Value, ...],
+    *,
+    example_kwarg_inputs: Optional[Dict] = None,
     dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
     strict=True,
     verbose=True,
@@ -38,7 +40,11 @@ def _to_core_aten(
             f"Expected passed in model to be an instance of fx.GraphModule, got {type(model)}"
         )
     core_aten_ep = export(
-        model, example_inputs, dynamic_shapes=dynamic_shapes, strict=strict
+        model,
+        example_inputs,
+        example_kwarg_inputs,
+        dynamic_shapes=dynamic_shapes,
+        strict=strict,
     )
     if verbose:
         logging.info(f"Core ATen graph:\n{core_aten_ep.graph}")
@@ -69,14 +75,21 @@ def _core_aten_to_edge(
 def export_to_edge(
     model: Union[torch.fx.GraphModule, torch.nn.Module],
     example_inputs: Tuple[Value, ...],
+    *,
+    example_kwarg_inputs: Optional[Dict] = None,
     dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
     edge_constant_methods: Optional[Dict[str, Any]] = None,
     edge_compile_config=_EDGE_COMPILE_CONFIG,
     strict=True,
     verbose=True,
 ) -> EdgeProgramManager:
     core_aten_ep = _to_core_aten(
-        model, example_inputs, dynamic_shapes, strict=strict, verbose=verbose
+        model,
+        example_inputs,
+        example_kwarg_inputs=example_kwarg_inputs,
+        dynamic_shapes=dynamic_shapes,
+        strict=strict,
+        verbose=verbose,
     )
     return _core_aten_to_edge(
         core_aten_ep, edge_constant_methods, edge_compile_config, verbose=verbose
@@ -86,6 +99,8 @@ def export_to_edge(
 def export_to_exec_prog(
     model: Union[torch.fx.GraphModule, torch.nn.Module],
     example_inputs: Tuple[Value, ...],
+    *,
+    example_kwarg_inputs: Optional[Dict[str, Any]] = None,
     dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
     edge_constant_methods: Optional[Dict[str, Any]] = None,
     edge_compile_config=_EDGE_COMPILE_CONFIG,
@@ -96,7 +111,13 @@ def export_to_exec_prog(
     # pre-autograd export. eventually this will become torch.export
     m = export_for_training(m, example_inputs).module()
 
-    core_aten_ep = _to_core_aten(m, example_inputs, dynamic_shapes, strict=strict)
+    core_aten_ep = _to_core_aten(
+        m,
+        example_inputs,
+        example_kwarg_inputs=example_kwarg_inputs,
+        dynamic_shapes=dynamic_shapes,
+        strict=strict,
+    )
 
     edge_m = _core_aten_to_edge(
         core_aten_ep, edge_constant_methods, edge_compile_config
```
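After this change `example_kwarg_inputs` is a keyword-only parameter of `export_to_edge`. Below is a hedged usage sketch; the toy module is illustrative, and the import path assumes a source install where `executorch.extension.export_util.utils` is importable.

```python
# Hedged sketch of calling the updated helper; module and shapes are illustrative.
import torch

from executorch.extension.export_util.utils import export_to_edge


class AddBias(torch.nn.Module):
    def forward(self, x: torch.Tensor, *, bias: torch.Tensor) -> torch.Tensor:
        # The keyword-only `bias` is what example_kwarg_inputs describes at export time.
        return x + bias


edge_manager = export_to_edge(
    AddBias(),
    (torch.randn(2, 3),),
    example_kwarg_inputs={"bias": torch.randn(2, 3)},  # keyword-only after this change
)
```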

extension/llm/export/builder.py

Lines changed: 9 additions & 2 deletions

```diff
@@ -10,7 +10,7 @@
 
 import logging
 from enum import Enum
-from typing import Any, Callable, List, Optional
+from typing import Any, Callable, Dict, List, Optional
 
 import torch
 from executorch.backends.transforms.duplicate_dynamic_quant_chain import (
@@ -68,6 +68,7 @@ def __init__(
         dtype,
         use_kv_cache,
         example_inputs,
+        example_kwarg_inputs: Optional[Dict] = None,
         args: Optional[Any] = None,
         enable_dynamic_shape: bool = False,
         generate_full_logits: bool = False,
@@ -87,6 +88,7 @@ def __init__(
         self.max_seq_len = max_seq_len
         self.dtype = dtype
         self.example_inputs = example_inputs
+        self.example_kwarg_inputs = example_kwarg_inputs
         self.use_kv_cache = use_kv_cache
         self.generate_full_logits = generate_full_logits
         self.enable_dynamic_shape = enable_dynamic_shape
@@ -186,12 +188,16 @@ def capture_pre_autograd_graph(self) -> "LLMEdgeManager":
             self.pre_autograd_graph_module = torch.export.export(
                 self.model,
                 self.example_inputs,
+                self.example_kwarg_inputs,
                 dynamic_shapes=dynamic_shape,
                 strict=True,
             ).module()
         else:
             self.pre_autograd_graph_module = export_for_training(
-                self.model, self.example_inputs, dynamic_shapes=dynamic_shape
+                self.model,
+                self.example_inputs,
+                kwargs=self.example_kwarg_inputs,
+                dynamic_shapes=dynamic_shape,
             ).module()
 
         return self
@@ -340,6 +346,7 @@ def export_to_edge(self) -> "LLMEdgeManager":
             self.edge_manager = export_to_edge(
                 self.pre_autograd_graph_module,  # pyre-fixme[6]
                 self.example_inputs,
+                example_kwarg_inputs=self.example_kwarg_inputs,
                 dynamic_shapes=dynamic_shape,
                 edge_constant_methods=self.metadata,
                 edge_compile_config=edge_config,
```
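The builder threads `example_kwarg_inputs` into `torch.export.export` (third positional argument) and `export_for_training` (`kwargs=` parameter). A standalone sketch of that call pattern follows; it assumes a PyTorch build that provides `torch.export.export_for_training`, which this diff already relies on, and the tiny module stands in for the real LLM.

```python
# Hedged sketch of the torch.export call pattern used by LLMEdgeManager;
# TinyDecoder is a hypothetical stand-in, not the Llama model.
import torch
from torch.export import export_for_training


class TinyDecoder(torch.nn.Module):
    def forward(self, tokens: torch.Tensor, *, input_pos: torch.Tensor) -> torch.Tensor:
        return tokens + input_pos


example_inputs = (torch.zeros(1, 4, dtype=torch.long),)
example_kwarg_inputs = {"input_pos": torch.zeros(1, 4, dtype=torch.long)}

# Pre-autograd capture, mirroring capture_pre_autograd_graph():
graph_module = export_for_training(
    TinyDecoder(),
    example_inputs,
    kwargs=example_kwarg_inputs,
).module()

# Full export, mirroring the torch.export.export branch:
exported = torch.export.export(
    TinyDecoder(),
    example_inputs,
    example_kwarg_inputs,
    strict=True,
)
```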
