
Commit 301f762

Add kwarg example inputs to eager model base
1 parent f9c001a commit 301f762

File tree

2 files changed: 11 additions & 7 deletions


examples/models/llama2/model.py

Lines changed: 8 additions & 5 deletions
@@ -222,32 +222,35 @@ def get_eager_model(self) -> torch.nn.Module:
         # switch all to FP32
         return self.model_.to(torch.float32)
 
-    def get_example_inputs(self):
+    def get_example_inputs(self) -> Tuple[Tuple, Dict]:
         if self.use_kv_cache:
             return self.get_example_inputs_kvcache_sdpa()
         else:
-            return (
+            positional_inputs = (
                 torch.tensor(
                     [[1, 2, 3]], dtype=torch.long
                 ),  # tokens, with kv cache our input token length is always just 1 token.
             )
+            return (positional_inputs, {})
 
     # assumption is the custom op doesnt support dynamic shape right now. It might but its untested so lets first get static shape working
-    def get_example_inputs_kvcache_sdpa(self):
+    def get_example_inputs_kvcache_sdpa(self) -> Tuple[Tuple, Dict]:
         if self.enable_dynamic_shape:
-            return (
+            positional_inputs = (
                 torch.tensor([[2, 3, 4]], dtype=torch.long),
                 torch.tensor([0], dtype=torch.long),
             )
+            return (positional_inputs, {})
         else:
-            return (
+            positional_inputs = (
                 torch.tensor(
                     [[1]], dtype=torch.long
                 ),  # tokens, with kv cache our input token length is always just 1 token.
                 torch.tensor(
                     [0], dtype=torch.long
                 ),  # start_pos, what token of output are we on.
             )
+            return (positional_inputs, {})
 
     def _transform_for_pre_quantization(self, checkpoint):
         assert hasattr(self.args, "preq_mode"), "preq_mode must be specified"
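With this change, get_example_inputs() returns a (positional args, kwarg args) pair instead of a bare tuple. A minimal sketch of how a caller might consume that pair; the Llama2Model import path and the use of torch.export.export as the export entry point are illustrative assumptions, not part of this commit:

import torch

from examples.models.llama2.model import Llama2Model  # assumed import path

model_wrapper = Llama2Model()  # constructor arguments omitted for brevity
eager_model = model_wrapper.get_eager_model()

# Unpack the new (positional inputs, kwarg inputs) pair.
example_args, example_kwargs = model_wrapper.get_example_inputs()

# Export/tracing entry points that accept args and kwargs separately can take
# the pair directly (illustrative only; this commit does not touch export code).
exported_program = torch.export.export(eager_model, example_args, example_kwargs)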

examples/models/model_base.py

Lines changed: 3 additions & 2 deletions
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 from abc import ABC, abstractmethod
+from typing import Dict, Tuple
 
 import torch
 
@@ -37,11 +38,11 @@ def get_eager_model(self) -> torch.nn.Module:
         raise NotImplementedError("get_eager_model")
 
     @abstractmethod
-    def get_example_inputs(self):
+    def get_example_inputs(self) -> Tuple[Tuple, Dict]:
         """
         Abstract method to provide example inputs for the model.
 
         Returns:
-            Any: Example inputs that can be used for testing and tracing.
+            Tuple[Tuple, Dict]: The positional inputs (Tuple) and the kwarg inputs (Dict).
         """
         raise NotImplementedError("get_example_inputs")
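The updated abstract signature means every subclass of the eager model base class is expected to hand back positional and kwarg example inputs together. A minimal, hypothetical sketch of a conforming subclass (TinyModel and TinyEagerModel are made-up names for illustration, not part of the commit):

from typing import Dict, Tuple

import torch


class TinyModel(torch.nn.Module):
    """Toy module used only to illustrate the example-input contract."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
        return self.linear(x) * scale


class TinyEagerModel:
    """Hypothetical wrapper following the new (args, kwargs) convention."""

    def get_eager_model(self) -> torch.nn.Module:
        return TinyModel()

    def get_example_inputs(self) -> Tuple[Tuple, Dict]:
        positional_inputs = (torch.randn(1, 4),)  # example positional args
        kwarg_inputs = {"scale": 2.0}  # example kwarg inputs
        return (positional_inputs, kwarg_inputs)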
