Skip to content

Commit 0b5a9a7

Browse files
committed
Add kwarg example inputs to eager model base
1 parent 517fddb commit 0b5a9a7

File tree

2 files changed

+11
-7
lines changed

2 files changed

+11
-7
lines changed

examples/models/llama2/model.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -224,32 +224,35 @@ def get_eager_model(self) -> torch.nn.Module:
224224
# switch all to FP32
225225
return self.model_.to(torch.float32)
226226

227-
def get_example_inputs(self):
227+
def get_example_inputs(self) -> Tuple[Tuple, Dict]:
228228
if self.use_kv_cache:
229229
return self.get_example_inputs_kvcache_sdpa()
230230
else:
231-
return (
231+
positional_inputs = (
232232
torch.tensor(
233233
[[1, 2, 3]], dtype=torch.long
234234
), # tokens; without kv cache the example input is a multi-token sequence.
235235
)
236+
return (positional_inputs, {})
236237

237238
# Assumption: the custom op doesn't support dynamic shape right now. It might, but it's untested, so let's get static shape working first.
238-
def get_example_inputs_kvcache_sdpa(self):
239+
def get_example_inputs_kvcache_sdpa(self) -> Tuple[Tuple, Dict]:
239240
if self.enable_dynamic_shape:
240-
return (
241+
positional_inputs = (
241242
torch.tensor([[2, 3, 4]], dtype=torch.long),
242243
torch.tensor([0], dtype=torch.long),
243244
)
245+
return (positional_inputs, {})
244246
else:
245-
return (
247+
positional_inputs = (
246248
torch.tensor(
247249
[[1]], dtype=torch.long
248250
), # tokens, with kv cache our input token length is always just 1 token.
249251
torch.tensor(
250252
[0], dtype=torch.long
251253
), # start_pos, what token of output are we on.
252254
)
255+
return (positional_inputs, {})
253256

254257
def _transform_for_pre_quantization(self, checkpoint):
255258
assert hasattr(self.args, "preq_mode"), "preq_mode must be specified"

examples/models/model_base.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
# LICENSE file in the root directory of this source tree.
66

77
from abc import ABC, abstractmethod
8+
from typing import Dict, Tuple
89

910
import torch
1011

@@ -37,11 +38,11 @@ def get_eager_model(self) -> torch.nn.Module:
3738
raise NotImplementedError("get_eager_model")
3839

3940
@abstractmethod
40-
def get_example_inputs(self):
41+
def get_example_inputs(self) -> Tuple[Tuple, Dict]:
4142
"""
4243
Abstract method to provide example inputs for the model.
4344
4445
Returns:
45-
Any: Example inputs that can be used for testing and tracing.
46+
Tuple[Tuple, Dict]: The positional inputs (Tuple) and the kwarg inputs (Dict).
4647
"""
4748
raise NotImplementedError("get_example_inputs")

0 commit comments

Comments
 (0)