Skip to content

Commit a9647d2

Browse files
committed
Create create new method for example kwarg inputs instead
1 parent 0b5a9a7 commit a9647d2

File tree

2 files changed

+7
-11
lines changed

2 files changed

+7
-11
lines changed

examples/models/llama2/model.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -224,35 +224,32 @@ def get_eager_model(self) -> torch.nn.Module:
224224
# switch all to FP32
225225
return self.model_.to(torch.float32)
226226

227-
def get_example_inputs(self) -> Tuple[Tuple, Dict]:
227+
def get_example_inputs(self):
228228
if self.use_kv_cache:
229229
return self.get_example_inputs_kvcache_sdpa()
230230
else:
231-
positional_inputs = (
231+
return (
232232
torch.tensor(
233233
[[1, 2, 3]], dtype=torch.long
234234
), # tokens, with kv cache our input token length is always just 1 token.
235235
)
236-
return (positional_inputs, {})
237236

238237
# assumption is the custom op doesnt support dynamic shape right now. It might but its untested so lets first get static shape working
239-
def get_example_inputs_kvcache_sdpa(self) -> Tuple[Tuple, Dict]:
238+
def get_example_inputs_kvcache_sdpa(self):
240239
if self.enable_dynamic_shape:
241-
positional_inputs = (
240+
return (
242241
torch.tensor([[2, 3, 4]], dtype=torch.long),
243242
torch.tensor([0], dtype=torch.long),
244243
)
245-
return (positional_inputs, {})
246244
else:
247-
positional_inputs = (
245+
return (
248246
torch.tensor(
249247
[[1]], dtype=torch.long
250248
), # tokens, with kv cache our input token length is always just 1 token.
251249
torch.tensor(
252250
[0], dtype=torch.long
253251
), # start_pos, what token of output are we on.
254252
)
255-
return (positional_inputs, {})
256253

257254
def _transform_for_pre_quantization(self, checkpoint):
258255
assert hasattr(self.args, "preq_mode"), "preq_mode must be specified"

examples/models/model_base.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
# LICENSE file in the root directory of this source tree.
66

77
from abc import ABC, abstractmethod
8-
from typing import Dict, Tuple
98

109
import torch
1110

@@ -38,11 +37,11 @@ def get_eager_model(self) -> torch.nn.Module:
3837
raise NotImplementedError("get_eager_model")
3938

4039
@abstractmethod
41-
def get_example_inputs(self) -> Tuple[Tuple, Dict]:
40+
def get_example_inputs(self):
4241
"""
4342
Abstract method to provide example inputs for the model.
4443
4544
Returns:
46-
Tuple[Tuple, Dict]: The positional inputs (Tuple) and the kwarg inputs (Dict).
45+
Any: Example inputs that can be used for testing and tracing.
4746
"""
4847
raise NotImplementedError("get_example_inputs")

0 commit comments

Comments
 (0)