Skip to content

Commit 9be5f57

Browse files
committed
Create create new method for example kwarg inputs instead
1 parent 08aacb2 commit 9be5f57

File tree

17 files changed

+43
-42
lines changed

17 files changed

+43
-42
lines changed

examples/apple/coreml/scripts/debugger_cli.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ def main() -> None:
149149
root_dir_path=get_root_dir_path(), conda_env_name=args.conda_environment_name
150150
)
151151

152-
model, example_inputs, _ = EagerModelFactory.create_model(
152+
model, example_inputs, _, _ = EagerModelFactory.create_model(
153153
*MODEL_NAME_TO_MODEL[args.model_name]
154154
)
155155

examples/apple/coreml/scripts/export.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ def main():
158158
f"Valid compute units are {valid_compute_units}."
159159
)
160160

161-
model, example_inputs, _ = EagerModelFactory.create_model(
161+
model, example_inputs, _, _ = EagerModelFactory.create_model(
162162
*MODEL_NAME_TO_MODEL[args.model_name]
163163
)
164164

examples/apple/mps/scripts/mps_example.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ def get_model_config(args):
152152
raise RuntimeError(f"Available models are {list(MODEL_NAME_TO_MODEL.keys())}.")
153153

154154
model_config = get_model_config(args)
155-
model, example_inputs, _ = EagerModelFactory.create_model(**model_config)
155+
model, example_inputs, _, _ = EagerModelFactory.create_model(**model_config)
156156

157157
model = model.eval()
158158

examples/arm/aot_arm_compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def get_model_and_inputs_from_name(model_name: str):
4646
logging.warning(
4747
"Using a model from examples/models not all of these are currently supported"
4848
)
49-
model, example_inputs, _ = EagerModelFactory.create_model(
49+
model, example_inputs, _, _ = EagerModelFactory.create_model(
5050
*MODEL_NAME_TO_MODEL[model_name]
5151
)
5252
# Case 3: Model is in an external python file loaded as a module.

examples/devtools/scripts/export_bundled_program.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ def main() -> None:
139139
f"Available models are {list(MODEL_NAME_TO_MODEL.keys())}."
140140
)
141141

142-
model, example_inputs, _ = EagerModelFactory.create_model(
142+
model, example_inputs, _, _ = EagerModelFactory.create_model(
143143
*MODEL_NAME_TO_MODEL[args.model_name]
144144
)
145145

examples/devtools/scripts/gen_sample_etrecord.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def main() -> None:
7474
f"Available models are {list(MODEL_NAME_TO_MODEL.keys())}."
7575
)
7676

77-
model, example_inputs, _ = EagerModelFactory.create_model(
77+
model, example_inputs, _, _ = EagerModelFactory.create_model(
7878
*MODEL_NAME_TO_MODEL[args.model_name]
7979
)
8080

examples/models/llama2/export_llama_lib.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -780,7 +780,7 @@ def _load_llama_model(
780780
logging.info(
781781
f"Loading model with checkpoint={checkpoint}, params={params_path}, use_kv_cache={use_kv_cache}, weight_type={weight_type}"
782782
)
783-
model, example_inputs, _ = EagerModelFactory.create_model(
783+
model, example_inputs, example_kwarg_inputs, _ = EagerModelFactory.create_model(
784784
"llama2",
785785
"Llama2Model",
786786
checkpoint=checkpoint,
@@ -830,6 +830,7 @@ def _load_llama_model(
830830
use_kv_cache=use_kv_cache,
831831
generate_full_logits=generate_full_logits,
832832
example_inputs=example_inputs,
833+
example_kwarg_inputs=example_kwarg_inputs,
833834
enable_dynamic_shape=enable_dynamic_shape,
834835
calibration_tasks=calibration_tasks,
835836
calibration_limit=calibration_limit,

examples/models/llama2/model.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -250,35 +250,32 @@ def get_eager_model(self):
250250
# switch all to FP32
251251
return self.model_.to(torch.float32)
252252

253-
def get_example_inputs(self) -> Tuple[Tuple, Dict]:
253+
def get_example_inputs(self):
254254
if self.use_kv_cache:
255255
return self.get_example_inputs_kvcache_sdpa()
256256
else:
257-
positional_inputs = (
257+
return (
258258
torch.tensor(
259259
[[1, 2, 3]], dtype=torch.long
260260
), # tokens, with kv cache our input token length is always just 1 token.
261261
)
262-
return (positional_inputs, {})
263262

264263
# assumption is the custom op doesnt support dynamic shape right now. It might but its untested so lets first get static shape working
265-
def get_example_inputs_kvcache_sdpa(self) -> Tuple[Tuple, Dict]:
264+
def get_example_inputs_kvcache_sdpa(self):
266265
if self.enable_dynamic_shape:
267-
positional_inputs = (
266+
return (
268267
torch.tensor([[2, 3, 4]], dtype=torch.long),
269268
torch.tensor([0], dtype=torch.long),
270269
)
271-
return (positional_inputs, {})
272270
else:
273-
positional_inputs = (
271+
return (
274272
torch.tensor(
275273
[[1]], dtype=torch.long
276274
), # tokens, with kv cache our input token length is always just 1 token.
277275
torch.tensor(
278276
[0], dtype=torch.long
279277
), # start_pos, what token of output are we on.
280278
)
281-
return (positional_inputs, {})
282279

283280
def _transform_for_pre_quantization(self, checkpoint):
284281
assert hasattr(self.args, "preq_mode"), "preq_mode must be specified"

examples/models/llama2/runner/eager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def __init__(self, args):
3131
**params,
3232
)
3333
super().__init__(tokenizer_path=args.tokenizer, model_args=model_args)
34-
self.model, _, _ = EagerModelFactory.create_model(
34+
self.model, _, _, _ = EagerModelFactory.create_model(
3535
"llama2",
3636
"Llama2Model",
3737
checkpoint=args.checkpoint,

examples/models/model_base.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
# LICENSE file in the root directory of this source tree.
66

77
from abc import ABC, abstractmethod
8-
from typing import Dict, Tuple
98

109
import torch
1110

@@ -38,11 +37,11 @@ def get_eager_model(self) -> torch.nn.Module:
3837
raise NotImplementedError("get_eager_model")
3938

4039
@abstractmethod
41-
def get_example_inputs(self) -> Tuple[Tuple, Dict]:
40+
def get_example_inputs(self):
4241
"""
4342
Abstract method to provide example inputs for the model.
4443
4544
Returns:
46-
Tuple[Tuple, Dict]: The positional inputs (Tuple) and the kwarg inputs (Dict).
45+
Any: Example inputs that can be used for testing and tracing.
4746
"""
4847
raise NotImplementedError("get_example_inputs")

examples/models/model_factory.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import importlib
88
import os
9-
from typing import Any, Tuple
9+
from typing import Any, Dict, Tuple
1010

1111
import torch
1212

@@ -19,7 +19,7 @@ class EagerModelFactory:
1919
@staticmethod
2020
def create_model(
2121
module_name, model_class_name, **kwargs
22-
) -> Tuple[torch.nn.Module, Any, Any]:
22+
) -> Tuple[torch.nn.Module, Tuple[Any], Dict[str, Any], Any]:
2323
"""
2424
Create an instance of a model class that implements EagerModelBase and retrieve related data.
2525
@@ -42,14 +42,18 @@ def create_model(
4242
if hasattr(module, model_class_name):
4343
model_class = getattr(module, model_class_name)
4444
model = model_class(**kwargs)
45+
example_kwarg_inputs = None
46+
dynamic_shapes = None
47+
if hasattr(model, "get_example_kwarg_inputs()"):
48+
example_kwarg_inputs = model.get_example_kwarg_inputs()
4549
if hasattr(model, "get_dynamic_shapes"):
46-
return (
47-
model.get_eager_model(),
48-
model.get_example_inputs(),
49-
model.get_dynamic_shapes(),
50-
)
51-
else:
52-
return model.get_eager_model(), model.get_example_inputs(), None
50+
dynamic_shapes = model.get_dynamic_shapes()
51+
return (
52+
model.get_eager_model(),
53+
model.get_example_inputs(),
54+
example_kwarg_inputs,
55+
dynamic_shapes,
56+
)
5357

5458
raise ValueError(
5559
f"Model class '{model_class_name}' not found in module '{module_name}'."

examples/models/test/test_export.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def validate_tensor_allclose(
6969
return self.assertTrue(result)
7070

7171
def test_mv3_export_to_executorch(self):
72-
eager_model, example_inputs, _ = EagerModelFactory.create_model(
72+
eager_model, example_inputs, _, _ = EagerModelFactory.create_model(
7373
*MODEL_NAME_TO_MODEL["mv3"]
7474
)
7575
eager_output, executorch_output = self.collect_executorch_and_eager_outputs(
@@ -81,7 +81,7 @@ def test_mv3_export_to_executorch(self):
8181
)
8282

8383
def test_mv2_export_to_executorch(self):
84-
eager_model, example_inputs, _ = EagerModelFactory.create_model(
84+
eager_model, example_inputs, _, _ = EagerModelFactory.create_model(
8585
*MODEL_NAME_TO_MODEL["mv2"]
8686
)
8787
eager_output, executorch_output = self.collect_executorch_and_eager_outputs(
@@ -90,7 +90,7 @@ def test_mv2_export_to_executorch(self):
9090
self.validate_tensor_allclose(eager_output, executorch_output[0])
9191

9292
def test_vit_export_to_executorch(self):
93-
eager_model, example_inputs, _ = EagerModelFactory.create_model(
93+
eager_model, example_inputs, _, _ = EagerModelFactory.create_model(
9494
*MODEL_NAME_TO_MODEL["vit"]
9595
)
9696
eager_output, executorch_output = self.collect_executorch_and_eager_outputs(
@@ -102,7 +102,7 @@ def test_vit_export_to_executorch(self):
102102
)
103103

104104
def test_w2l_export_to_executorch(self):
105-
eager_model, example_inputs, _ = EagerModelFactory.create_model(
105+
eager_model, example_inputs, _, _ = EagerModelFactory.create_model(
106106
*MODEL_NAME_TO_MODEL["w2l"]
107107
)
108108
eager_output, executorch_output = self.collect_executorch_and_eager_outputs(
@@ -111,7 +111,7 @@ def test_w2l_export_to_executorch(self):
111111
self.validate_tensor_allclose(eager_output, executorch_output[0])
112112

113113
def test_ic3_export_to_executorch(self):
114-
eager_model, example_inputs, _ = EagerModelFactory.create_model(
114+
eager_model, example_inputs, _, _ = EagerModelFactory.create_model(
115115
*MODEL_NAME_TO_MODEL["ic3"]
116116
)
117117
eager_output, executorch_output = self.collect_executorch_and_eager_outputs(
@@ -123,7 +123,7 @@ def test_ic3_export_to_executorch(self):
123123
)
124124

125125
def test_resnet18_export_to_executorch(self):
126-
eager_model, example_inputs, _ = EagerModelFactory.create_model(
126+
eager_model, example_inputs, _, _ = EagerModelFactory.create_model(
127127
*MODEL_NAME_TO_MODEL["resnet18"]
128128
)
129129
eager_output, executorch_output = self.collect_executorch_and_eager_outputs(
@@ -132,7 +132,7 @@ def test_resnet18_export_to_executorch(self):
132132
self.validate_tensor_allclose(eager_output, executorch_output[0])
133133

134134
def test_resnet50_export_to_executorch(self):
135-
eager_model, example_inputs, _ = EagerModelFactory.create_model(
135+
eager_model, example_inputs, _, _ = EagerModelFactory.create_model(
136136
*MODEL_NAME_TO_MODEL["resnet50"]
137137
)
138138
eager_output, executorch_output = self.collect_executorch_and_eager_outputs(
@@ -141,7 +141,7 @@ def test_resnet50_export_to_executorch(self):
141141
self.validate_tensor_allclose(eager_output, executorch_output[0])
142142

143143
def test_dl3_export_to_executorch(self):
144-
eager_model, example_inputs, _ = EagerModelFactory.create_model(
144+
eager_model, example_inputs, _, _ = EagerModelFactory.create_model(
145145
*MODEL_NAME_TO_MODEL["dl3"]
146146
)
147147
eager_output, executorch_output = self.collect_executorch_and_eager_outputs(

examples/portable/scripts/export.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def main() -> None:
5858
f"Available models are {list(MODEL_NAME_TO_MODEL.keys())}."
5959
)
6060

61-
model, example_inputs, dynamic_shapes = EagerModelFactory.create_model(
61+
model, example_inputs, _, dynamic_shapes = EagerModelFactory.create_model(
6262
*MODEL_NAME_TO_MODEL[args.model_name]
6363
)
6464

examples/portable/scripts/export_and_delegate.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def export_composite_module_with_lower_graph():
5757
"Running the example to export a composite module with lowered graph..."
5858
)
5959

60-
m, m_inputs, _ = EagerModelFactory.create_model(*MODEL_NAME_TO_MODEL["add_mul"])
60+
m, m_inputs, _, _ = EagerModelFactory.create_model(*MODEL_NAME_TO_MODEL["add_mul"])
6161
m_compile_spec = m.get_compile_spec()
6262

6363
# pre-autograd export. eventually this will become torch.export
@@ -166,7 +166,7 @@ def export_and_lower_the_whole_graph():
166166
"""
167167
logging.info("Running the example to export and lower the whole graph...")
168168

169-
m, m_inputs, _ = EagerModelFactory.create_model(*MODEL_NAME_TO_MODEL["add_mul"])
169+
m, m_inputs, _, _ = EagerModelFactory.create_model(*MODEL_NAME_TO_MODEL["add_mul"])
170170
m_compile_spec = m.get_compile_spec()
171171

172172
m_inputs = m.get_example_inputs()

examples/qualcomm/scripts/export_example.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def main() -> None:
5858
f"Available models are {list(MODEL_NAME_TO_MODEL.keys())}."
5959
)
6060

61-
model, example_inputs, _ = EagerModelFactory.create_model(
61+
model, example_inputs, _, _ = EagerModelFactory.create_model(
6262
*MODEL_NAME_TO_MODEL[args.model_name]
6363
)
6464

examples/xnnpack/aot_compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@
7575
f"Available models are {list(MODEL_NAME_TO_OPTIONS.keys())}."
7676
)
7777

78-
model, example_inputs, _ = EagerModelFactory.create_model(
78+
model, example_inputs, _, _ = EagerModelFactory.create_model(
7979
*MODEL_NAME_TO_MODEL[args.model_name]
8080
)
8181

examples/xnnpack/quantization/example.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ def main() -> None:
162162
)
163163

164164
start = time.perf_counter()
165-
model, example_inputs, _ = EagerModelFactory.create_model(
165+
model, example_inputs, _, _ = EagerModelFactory.create_model(
166166
*MODEL_NAME_TO_MODEL[args.model_name]
167167
)
168168
end = time.perf_counter()

0 commit comments

Comments
 (0)