
Commit 809820e

Add kwarg example inputs to eager model base
1 parent 1751fdc · commit 809820e

4 files changed: +44 −12 lines changed

examples/models/llama2/model.py

Lines changed: 8 additions & 5 deletions
@@ -292,29 +292,32 @@ def get_eager_model(self):
         # switch all to FP32
         return self.model_.to(torch.float32)
 
-    def get_example_inputs(self):
+    def get_example_inputs(self) -> Tuple[Tuple, Dict]:
         if self.use_kv_cache:
             return self.get_example_inputs_kvcache_sdpa()
         else:
-            return (
+            positional_inputs = (
                 torch.tensor(
                     [[1, 2, 3]], dtype=torch.long
                 ),  # tokens
             )
+            return (positional_inputs, {})
 
     # assumption is the custom op doesn't support dynamic shape right now. It might, but it's untested, so let's first get static shape working
-    def get_example_inputs_kvcache_sdpa(self):
+    def get_example_inputs_kvcache_sdpa(self) -> Tuple[Tuple, Dict]:
         if self.enable_dynamic_shape:
-            return (
+            positional_inputs = (
                 torch.tensor([[2, 3, 4]], dtype=torch.long),
                 torch.tensor([0], dtype=torch.long),
             )
+            return (positional_inputs, {})
         else:
-            return (
+            positional_inputs = (
                 torch.tensor(
                     [[1]], dtype=torch.long
                 ),  # tokens; with kv cache our input token length is always just 1 token
                 torch.tensor(
                     [0], dtype=torch.long
                 ),  # start_pos, what token of output we are on
             )
+            return (positional_inputs, {})
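With this change, callers unpack a (positional, keyword) pair instead of a bare tuple of tensors. A minimal consuming-side sketch, assuming `model_wrapper` is an instance of one of these model classes (the smoke test itself is illustrative, not part of this commit):

    # Unpack the new (args, kwargs) convention and smoke-test the eager model.
    example_args, example_kwargs = model_wrapper.get_example_inputs()
    eager_model = model_wrapper.get_eager_model()
    with torch.no_grad():
        logits = eager_model(*example_args, **example_kwargs)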

examples/models/model_base.py

Lines changed: 3 additions & 2 deletions
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 from abc import ABC, abstractmethod
+from typing import Dict, Tuple
 
 import torch
 
@@ -37,11 +38,11 @@ def get_eager_model(self) -> torch.nn.Module:
         raise NotImplementedError("get_eager_model")
 
     @abstractmethod
-    def get_example_inputs(self):
+    def get_example_inputs(self) -> Tuple[Tuple, Dict]:
         """
         Abstract method to provide example inputs for the model.
 
         Returns:
-            Any: Example inputs that can be used for testing and tracing.
+            Tuple[Tuple, Dict]: The positional inputs (Tuple) and the kwarg inputs (Dict).
         """
         raise NotImplementedError("get_example_inputs")
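A hypothetical subclass showing how the updated contract is satisfied. The `EagerModelBase` name is assumed for the abstract base defined in this file; `TinyModel` and its `start_pos` keyword are invented for illustration:

    from typing import Dict, Tuple

    import torch

    class TinyModel(torch.nn.Module):
        def forward(self, tokens: torch.Tensor, start_pos: int = 0) -> torch.Tensor:
            # Toy computation; stands in for a real decoder forward pass.
            return tokens + start_pos

    class TinyModelWrapper(EagerModelBase):  # base-class name assumed
        def get_eager_model(self) -> torch.nn.Module:
            return TinyModel()

        def get_example_inputs(self) -> Tuple[Tuple, Dict]:
            # Positional example inputs go in the tuple, keyword example
            # inputs in the dict, matching the new return type.
            return (
                (torch.tensor([[1, 2, 3]], dtype=torch.long),),
                {"start_pos": 0},
            )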

extension/export_util/utils.py

Lines changed: 24 additions & 3 deletions
@@ -26,6 +26,8 @@
 def _to_core_aten(
     model: Union[torch.fx.GraphModule, torch.nn.Module],
     example_inputs: Tuple[Value, ...],
+    *,
+    example_kwarg_inputs: Optional[Dict] = None,
     dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
     strict=True,
     verbose=True,
@@ -38,7 +40,11 @@ def _to_core_aten(
             f"Expected passed in model to be an instance of fx.GraphModule, got {type(model)}"
         )
     core_aten_ep = export(
-        model, example_inputs, dynamic_shapes=dynamic_shapes, strict=strict
+        model,
+        example_inputs,
+        example_kwarg_inputs,
+        dynamic_shapes=dynamic_shapes,
+        strict=strict,
     )
     if verbose:
         logging.info(f"Core ATen graph:\n{core_aten_ep.graph}")
@@ -69,14 +75,21 @@ def _core_aten_to_edge(
 def export_to_edge(
     model: Union[torch.fx.GraphModule, torch.nn.Module],
     example_inputs: Tuple[Value, ...],
+    *,
+    example_kwarg_inputs: Optional[Dict] = None,
     dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
     edge_constant_methods: Optional[Dict[str, Any]] = None,
     edge_compile_config=_EDGE_COMPILE_CONFIG,
     strict=True,
     verbose=True,
 ) -> EdgeProgramManager:
     core_aten_ep = _to_core_aten(
-        model, example_inputs, dynamic_shapes, strict=strict, verbose=verbose
+        model,
+        example_inputs,
+        example_kwarg_inputs=example_kwarg_inputs,
+        dynamic_shapes=dynamic_shapes,
+        strict=strict,
+        verbose=verbose,
     )
     return _core_aten_to_edge(
         core_aten_ep, edge_constant_methods, edge_compile_config, verbose=verbose
@@ -86,6 +99,8 @@ def export_to_edge(
 def export_to_exec_prog(
     model: Union[torch.fx.GraphModule, torch.nn.Module],
     example_inputs: Tuple[Value, ...],
+    *,
+    example_kwarg_inputs: Optional[Dict[str, Any]] = None,
     dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
     edge_constant_methods: Optional[Dict[str, Any]] = None,
     edge_compile_config=_EDGE_COMPILE_CONFIG,
@@ -96,7 +111,13 @@ def export_to_exec_prog(
     # pre-autograd export. eventually this will become torch.export
     m = export_for_training(m, example_inputs).module()
 
-    core_aten_ep = _to_core_aten(m, example_inputs, dynamic_shapes, strict=strict)
+    core_aten_ep = _to_core_aten(
+        m,
+        example_inputs,
+        example_kwarg_inputs=example_kwarg_inputs,
+        dynamic_shapes=dynamic_shapes,
+        strict=strict,
+    )
 
     edge_m = _core_aten_to_edge(
         core_aten_ep, edge_constant_methods, edge_compile_config
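A usage sketch for the updated helper. The import path is assumed to be `executorch.extension.export_util.utils`, and `AddOffset` is a toy module invented for illustration:

    import torch

    from executorch.extension.export_util.utils import export_to_edge  # path assumed

    class AddOffset(torch.nn.Module):
        def forward(self, x, offset=None):
            return x if offset is None else x + offset

    edge_manager = export_to_edge(
        AddOffset(),
        (torch.ones(2, 2),),  # example_inputs: positional example inputs
        example_kwarg_inputs={"offset": torch.tensor(1.0)},  # new keyword-only arg
    )

Because the new parameters sit behind the bare `*`, existing positional call sites keep working unchanged.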

extension/llm/export/builder.py

Lines changed: 9 additions & 2 deletions
@@ -10,7 +10,7 @@
 
 import logging
 from enum import Enum
-from typing import Any, Callable, List, Optional
+from typing import Any, Callable, Dict, List, Optional
 
 import torch
 from executorch.backends.transforms.duplicate_dynamic_quant_chain import (
@@ -68,6 +68,7 @@ def __init__(
         dtype,
         use_kv_cache,
         example_inputs,
+        example_kwarg_inputs: Optional[Dict] = None,
         args: Optional[Any] = None,
         enable_dynamic_shape: bool = False,
         generate_full_logits: bool = False,
@@ -87,6 +88,7 @@ def __init__(
         self.max_seq_len = max_seq_len
         self.dtype = dtype
         self.example_inputs = example_inputs
+        self.example_kwarg_inputs = example_kwarg_inputs
         self.use_kv_cache = use_kv_cache
         self.generate_full_logits = generate_full_logits
         self.enable_dynamic_shape = enable_dynamic_shape
@@ -185,12 +187,16 @@ def capture_pre_autograd_graph(self) -> "LLMEdgeManager":
             self.pre_autograd_graph_module = torch.export.export(
                 self.model,
                 self.example_inputs,
+                self.example_kwarg_inputs,
                 dynamic_shapes=dynamic_shape,
                 strict=True,
             ).module()
         else:
             self.pre_autograd_graph_module = capture_pre_autograd_graph(
-                self.model, self.example_inputs, dynamic_shapes=dynamic_shape
+                self.model,
+                self.example_inputs,
+                kwargs=self.example_kwarg_inputs,
+                dynamic_shapes=dynamic_shape,
             )
 
         return self
@@ -337,6 +343,7 @@ def export_to_edge(self) -> "LLMEdgeManager":
         self.edge_manager = export_to_edge(
             self.pre_autograd_graph_module,  # pyre-fixme[6]
             self.example_inputs,
+            example_kwarg_inputs=self.example_kwarg_inputs,
             dynamic_shapes=dynamic_shape,
             edge_constant_methods=self.metadata,
             edge_compile_config=edge_config,
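The plumbing above relies on `torch.export.export` accepting the example keyword inputs as its third positional parameter, `kwargs`. A self-contained sketch of that behavior with a toy module (not from this commit):

    import torch

    class Scale(torch.nn.Module):
        def forward(self, x, scale):
            return x * scale

    ep = torch.export.export(
        Scale(),
        (torch.randn(4),),             # example positional inputs (args)
        {"scale": torch.tensor(2.0)},  # example kwarg inputs (kwargs)
        strict=True,
    )
    print(ep.graph)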
