Commit 6cd759d
Add kwarg example inputs to eager model base
1 parent 6c53356

4 files changed: +44 −12 lines

examples/models/llama2/model.py
Lines changed: 8 additions & 5 deletions

@@ -250,32 +250,35 @@ def get_eager_model(self):
         # switch all to FP32
         return self.model_.to(torch.float32)
 
-    def get_example_inputs(self):
+    def get_example_inputs(self) -> Tuple[Tuple, Dict]:
         if self.use_kv_cache:
             return self.get_example_inputs_kvcache_sdpa()
         else:
-            return (
+            positional_inputs = (
                 torch.tensor(
                     [[1, 2, 3]], dtype=torch.long
                 ),  # tokens; without kv cache the example input is a short prompt.
             )
+            return (positional_inputs, {})
 
     # Assumption: the custom op doesn't support dynamic shape right now. It might, but it's untested, so let's first get static shape working.
-    def get_example_inputs_kvcache_sdpa(self):
+    def get_example_inputs_kvcache_sdpa(self) -> Tuple[Tuple, Dict]:
         if self.enable_dynamic_shape:
-            return (
+            positional_inputs = (
                 torch.tensor([[2, 3, 4]], dtype=torch.long),
                 torch.tensor([0], dtype=torch.long),
             )
+            return (positional_inputs, {})
         else:
-            return (
+            positional_inputs = (
                 torch.tensor(
                     [[1]], dtype=torch.long
                 ),  # tokens; with kv cache our input token length is always just 1 token.
                 torch.tensor(
                     [0], dtype=torch.long
                 ),  # start_pos: which output token we are on.
             )
+            return (positional_inputs, {})
 
     def _transform_for_pre_quantization(self, checkpoint):
         assert hasattr(self.args, "preq_mode"), "preq_mode must be specified"
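
For context, a minimal self-contained sketch of the convention this hunk adopts: example inputs become a (positional_inputs, kwarg_inputs) pair that callers can splat into an eager call or an export entry point. The free function get_example_inputs below is a hypothetical stand-in for the method above, not code from this commit.

import torch

def get_example_inputs():  # hypothetical stand-in for the method above
    # Mirrors the non-kv-cache branch: one token tensor, no kwarg inputs.
    positional_inputs = (torch.tensor([[1, 2, 3]], dtype=torch.long),)
    return (positional_inputs, {})

example_inputs, example_kwarg_inputs = get_example_inputs()
# Both halves can be splatted into any eager or export call:
#     model(*example_inputs, **example_kwarg_inputs)
print(example_inputs[0].shape, example_kwarg_inputs)  # torch.Size([1, 3]) {}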

examples/models/model_base.py
Lines changed: 3 additions & 2 deletions

@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 from abc import ABC, abstractmethod
+from typing import Dict, Tuple
 
 import torch
 
@@ -37,11 +38,11 @@ def get_eager_model(self) -> torch.nn.Module:
         raise NotImplementedError("get_eager_model")
 
     @abstractmethod
-    def get_example_inputs(self):
+    def get_example_inputs(self) -> Tuple[Tuple, Dict]:
         """
         Abstract method to provide example inputs for the model.
 
         Returns:
-            Any: Example inputs that can be used for testing and tracing.
+            Tuple[Tuple, Dict]: The positional inputs (Tuple) and the kwarg inputs (Dict).
         """
         raise NotImplementedError("get_example_inputs")
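
To illustrate the updated contract, here is a hedged sketch of a concrete subclass. The local EagerModelBase mirrors the two abstract methods shown in the hunk rather than importing executorch (the enclosing class name sits outside the diff context), and IdentityModel is invented for the example.

from abc import ABC, abstractmethod
from typing import Dict, Tuple

import torch

class EagerModelBase(ABC):  # local mirror of the ABC in model_base.py
    @abstractmethod
    def get_eager_model(self) -> torch.nn.Module: ...

    @abstractmethod
    def get_example_inputs(self) -> Tuple[Tuple, Dict]: ...

class IdentityModel(EagerModelBase):  # made-up example subclass
    def get_eager_model(self) -> torch.nn.Module:
        return torch.nn.Identity()

    def get_example_inputs(self) -> Tuple[Tuple, Dict]:
        # Positional inputs in the tuple, kwarg inputs in the dict.
        return ((torch.ones(1, 4),), {})

m = IdentityModel()
args, kwargs = m.get_example_inputs()
print(m.get_eager_model()(*args, **kwargs).shape)  # torch.Size([1, 4])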

extension/export_util/utils.py
Lines changed: 24 additions & 3 deletions

@@ -26,6 +26,8 @@
 def _to_core_aten(
     model: Union[torch.fx.GraphModule, torch.nn.Module],
     example_inputs: Tuple[Value, ...],
+    *,
+    example_kwarg_inputs: Optional[Dict] = None,
     dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
     strict=True,
     verbose=True,
@@ -38,7 +40,11 @@ def _to_core_aten(
             f"Expected passed in model to be an instance of fx.GraphModule, got {type(model)}"
         )
     core_aten_ep = export(
-        model, example_inputs, dynamic_shapes=dynamic_shapes, strict=strict
+        model,
+        example_inputs,
+        example_kwarg_inputs,
+        dynamic_shapes=dynamic_shapes,
+        strict=strict,
     )
     if verbose:
         logging.info(f"Core ATen graph:\n{core_aten_ep.graph}")
@@ -69,14 +75,21 @@ def _core_aten_to_edge(
 def export_to_edge(
     model: Union[torch.fx.GraphModule, torch.nn.Module],
     example_inputs: Tuple[Value, ...],
+    *,
+    example_kwarg_inputs: Optional[Dict] = None,
     dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
     edge_constant_methods: Optional[Dict[str, Any]] = None,
     edge_compile_config=_EDGE_COMPILE_CONFIG,
     strict=True,
     verbose=True,
 ) -> EdgeProgramManager:
     core_aten_ep = _to_core_aten(
-        model, example_inputs, dynamic_shapes, strict=strict, verbose=verbose
+        model,
+        example_inputs,
+        example_kwarg_inputs=example_kwarg_inputs,
+        dynamic_shapes=dynamic_shapes,
+        strict=strict,
+        verbose=verbose,
     )
     return _core_aten_to_edge(
         core_aten_ep, edge_constant_methods, edge_compile_config, verbose=verbose
@@ -86,6 +99,8 @@ def export_to_edge(
 def export_to_exec_prog(
     model: Union[torch.fx.GraphModule, torch.nn.Module],
     example_inputs: Tuple[Value, ...],
+    *,
+    example_kwarg_inputs: Optional[Dict[str, Any]] = None,
     dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
     edge_constant_methods: Optional[Dict[str, Any]] = None,
     edge_compile_config=_EDGE_COMPILE_CONFIG,
@@ -96,7 +111,13 @@ def export_to_exec_prog(
     # pre-autograd export. eventually this will become torch.export
     m = export_for_training(m, example_inputs).module()
 
-    core_aten_ep = _to_core_aten(m, example_inputs, dynamic_shapes, strict=strict)
+    core_aten_ep = _to_core_aten(
+        m,
+        example_inputs,
+        example_kwarg_inputs=example_kwarg_inputs,
+        dynamic_shapes=dynamic_shapes,
+        strict=strict,
+    )
 
     edge_m = _core_aten_to_edge(
         core_aten_ep, edge_constant_methods, edge_compile_config
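
The plumbing above ultimately lands in torch.export.export, whose third argument carries kwarg example inputs. A minimal sketch of that underlying behavior, using a toy module (AddBias is invented for this sketch; in this repo the caller-facing entry point would be export_to_edge with example_kwarg_inputs=...):

import torch

class AddBias(torch.nn.Module):  # toy module invented for this sketch
    def forward(self, x: torch.Tensor, *, bias: torch.Tensor) -> torch.Tensor:
        return x + bias

ep = torch.export.export(
    AddBias(),
    (torch.randn(2, 2),),         # positional example inputs
    {"bias": torch.zeros(2, 2)},  # kwarg example inputs, now forwarded by export_to_edge
)
print(ep.graph)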

extension/llm/export/builder.py
Lines changed: 9 additions & 2 deletions

@@ -10,7 +10,7 @@
 
 import logging
 from enum import Enum
-from typing import Any, Callable, List, Optional
+from typing import Any, Callable, Dict, List, Optional
 
 import torch
 from executorch.backends.transforms.duplicate_dynamic_quant_chain import (
@@ -68,6 +68,7 @@ def __init__(
         dtype,
         use_kv_cache,
         example_inputs,
+        example_kwarg_inputs: Optional[Dict] = None,
         args: Optional[Any] = None,
         enable_dynamic_shape: bool = False,
         generate_full_logits: bool = False,
@@ -87,6 +88,7 @@ def __init__(
         self.max_seq_len = max_seq_len
         self.dtype = dtype
         self.example_inputs = example_inputs
+        self.example_kwarg_inputs = example_kwarg_inputs
         self.use_kv_cache = use_kv_cache
         self.generate_full_logits = generate_full_logits
         self.enable_dynamic_shape = enable_dynamic_shape
@@ -186,12 +188,16 @@ def capture_pre_autograd_graph(self) -> "LLMEdgeManager":
             self.pre_autograd_graph_module = torch.export.export(
                 self.model,
                 self.example_inputs,
+                self.example_kwarg_inputs,
                 dynamic_shapes=dynamic_shape,
                 strict=True,
             ).module()
         else:
             self.pre_autograd_graph_module = capture_pre_autograd_graph(
-                self.model, self.example_inputs, dynamic_shapes=dynamic_shape
+                self.model,
+                self.example_inputs,
+                kwargs=self.example_kwarg_inputs,
+                dynamic_shapes=dynamic_shape,
             )
 
         return self
@@ -340,6 +346,7 @@ def export_to_edge(self) -> "LLMEdgeManager":
         self.edge_manager = export_to_edge(
             self.pre_autograd_graph_module,  # pyre-fixme[6]
             self.example_inputs,
+            example_kwarg_inputs=self.example_kwarg_inputs,
             dynamic_shapes=dynamic_shape,
             edge_constant_methods=self.metadata,
             edge_compile_config=edge_config,
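
Finally, a hedged end-to-end sketch of the builder-side pattern: store kwarg example inputs at construction time and forward them when capturing the graph. TinyEdgeManager and AddOne are stand-ins invented for this example, not the real LLMEdgeManager API or its constructor signature.

from typing import Dict, Optional, Tuple

import torch

class TinyEdgeManager:  # invented stand-in for LLMEdgeManager
    def __init__(
        self,
        model: torch.nn.Module,
        example_inputs: Tuple,
        example_kwarg_inputs: Optional[Dict] = None,
    ):
        self.model = model
        self.example_inputs = example_inputs
        self.example_kwarg_inputs = example_kwarg_inputs

    def capture(self) -> torch.export.ExportedProgram:
        # torch.export.export treats None kwargs as "no kwarg inputs".
        return torch.export.export(
            self.model, self.example_inputs, self.example_kwarg_inputs, strict=True
        )

class AddOne(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + 1

ep = TinyEdgeManager(AddOne(), (torch.randn(2),)).capture()
print(ep.graph)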
