Skip to content

Commit 73218d5

Browse files
committed
Create create new method for example kwarg inputs instead
1 parent 695e86b commit 73218d5

File tree

2 files changed

+7
-11
lines changed

2 files changed

+7
-11
lines changed

examples/models/llama2/model.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -250,35 +250,32 @@ def get_eager_model(self):
250250
# switch all to FP32
251251
return self.model_.to(torch.float32)
252252

253-
def get_example_inputs(self) -> Tuple[Tuple, Dict]:
253+
def get_example_inputs(self):
254254
if self.use_kv_cache:
255255
return self.get_example_inputs_kvcache_sdpa()
256256
else:
257-
positional_inputs = (
257+
return (
258258
torch.tensor(
259259
[[1, 2, 3]], dtype=torch.long
260260
), # tokens, with kv cache our input token length is always just 1 token.
261261
)
262-
return (positional_inputs, {})
263262

264263
# assumption is the custom op doesnt support dynamic shape right now. It might but its untested so lets first get static shape working
265-
def get_example_inputs_kvcache_sdpa(self) -> Tuple[Tuple, Dict]:
264+
def get_example_inputs_kvcache_sdpa(self):
266265
if self.enable_dynamic_shape:
267-
positional_inputs = (
266+
return (
268267
torch.tensor([[2, 3, 4]], dtype=torch.long),
269268
torch.tensor([0], dtype=torch.long),
270269
)
271-
return (positional_inputs, {})
272270
else:
273-
positional_inputs = (
271+
return (
274272
torch.tensor(
275273
[[1]], dtype=torch.long
276274
), # tokens, with kv cache our input token length is always just 1 token.
277275
torch.tensor(
278276
[0], dtype=torch.long
279277
), # start_pos, what token of output are we on.
280278
)
281-
return (positional_inputs, {})
282279

283280
def _transform_for_pre_quantization(self, checkpoint):
284281
assert hasattr(self.args, "preq_mode"), "preq_mode must be specified"

examples/models/model_base.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
# LICENSE file in the root directory of this source tree.
66

77
from abc import ABC, abstractmethod
8-
from typing import Dict, Tuple
98

109
import torch
1110

@@ -38,11 +37,11 @@ def get_eager_model(self) -> torch.nn.Module:
3837
raise NotImplementedError("get_eager_model")
3938

4039
@abstractmethod
41-
def get_example_inputs(self) -> Tuple[Tuple, Dict]:
40+
def get_example_inputs(self):
4241
"""
4342
Abstract method to provide example inputs for the model.
4443
4544
Returns:
46-
Tuple[Tuple, Dict]: The positional inputs (Tuple) and the kwarg inputs (Dict).
45+
Any: Example inputs that can be used for testing and tracing.
4746
"""
4847
raise NotImplementedError("get_example_inputs")

0 commit comments

Comments
 (0)