Skip to content

Commit 695e86b

Browse files
committed
Add kwarg example inputs to eager model base
1 parent 6ff6615 commit 695e86b

File tree

2 files changed

+11
-7
lines changed

2 files changed

+11
-7
lines changed

examples/models/llama2/model.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -250,32 +250,35 @@ def get_eager_model(self):
250250
# switch all to FP32
251251
return self.model_.to(torch.float32)
252252

253-
def get_example_inputs(self):
253+
def get_example_inputs(self) -> Tuple[Tuple, Dict]:
254254
if self.use_kv_cache:
255255
return self.get_example_inputs_kvcache_sdpa()
256256
else:
257-
return (
257+
positional_inputs = (
258258
torch.tensor(
259259
[[1, 2, 3]], dtype=torch.long
260260
), # tokens, with kv cache our input token length is always just 1 token.
261261
)
262+
return (positional_inputs, {})
262263

263264
# assumption is the custom op doesnt support dynamic shape right now. It might but its untested so lets first get static shape working
264-
def get_example_inputs_kvcache_sdpa(self):
265+
def get_example_inputs_kvcache_sdpa(self) -> Tuple[Tuple, Dict]:
265266
if self.enable_dynamic_shape:
266-
return (
267+
positional_inputs = (
267268
torch.tensor([[2, 3, 4]], dtype=torch.long),
268269
torch.tensor([0], dtype=torch.long),
269270
)
271+
return (positional_inputs, {})
270272
else:
271-
return (
273+
positional_inputs = (
272274
torch.tensor(
273275
[[1]], dtype=torch.long
274276
), # tokens, with kv cache our input token length is always just 1 token.
275277
torch.tensor(
276278
[0], dtype=torch.long
277279
), # start_pos, what token of output are we on.
278280
)
281+
return (positional_inputs, {})
279282

280283
def _transform_for_pre_quantization(self, checkpoint):
281284
assert hasattr(self.args, "preq_mode"), "preq_mode must be specified"

examples/models/model_base.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
# LICENSE file in the root directory of this source tree.
66

77
from abc import ABC, abstractmethod
8+
from typing import Dict, Tuple
89

910
import torch
1011

@@ -37,11 +38,11 @@ def get_eager_model(self) -> torch.nn.Module:
3738
raise NotImplementedError("get_eager_model")
3839

3940
@abstractmethod
40-
def get_example_inputs(self):
41+
def get_example_inputs(self) -> Tuple[Tuple, Dict]:
4142
"""
4243
Abstract method to provide example inputs for the model.
4344
4445
Returns:
45-
Any: Example inputs that can be used for testing and tracing.
46+
Tuple[Tuple, Dict]: The positional inputs (Tuple) and the kwarg inputs (Dict).
4647
"""
4748
raise NotImplementedError("get_example_inputs")

0 commit comments

Comments
 (0)