@@ -7,7 +7,7 @@
 # pyre-unsafe
 
 import json
-from typing import Any, Dict, Tuple
+from typing import Any, Dict
 
 import torch
 
@@ -122,17 +122,20 @@ def get_eager_model(self) -> torch.nn.Module:
         else:
             return self.model_.to(torch.float16)
 
-    def get_example_inputs(self) -> Tuple[Tuple, Dict]:
+    def get_example_inputs(self):
         return (
-            (torch.ones(1, 64, dtype=torch.long),),  # positional inputs
-            {
-                # "mask": None,
-                # "encoder_input": None,
-                # "encoder_mask": None,
-                # "input_pos": torch.ones(64, dtype=torch.long),
-            }  # kwarg inputs
+            torch.ones(1, 64, dtype=torch.long),  # positional inputs
         )
 
+    def get_example_kwarg_inputs(self):
+        # TODO: add input_pos and mask after making the cache work.
+        return {
+            # "mask": None,
+            # "encoder_input": None,
+            # "encoder_mask": None,
+            # "input_pos": torch.ones(64, dtype=torch.long),
+        }
+
     def get_dynamic_shapes(self):
         dim = torch.export.Dim("token_dim", min=1,max=self.max_seq_len)
         dynamic_shapes = {
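For context, the change splits the old combined example-inputs tuple into positional inputs (get_example_inputs) and keyword inputs (get_example_kwarg_inputs), which lines up with the separate args and kwargs parameters of torch.export.export. Below is a minimal sketch of how a caller might consume the new interface; export_eager_model is a hypothetical helper written only for illustration and is not part of this diff.

import torch

def export_eager_model(model):
    # `model` is assumed to implement the accessor methods shown in the diff above.
    eager = model.get_eager_model()
    args = model.get_example_inputs()            # positional example inputs, e.g. (tokens,)
    kwargs = model.get_example_kwarg_inputs()    # keyword example inputs (currently empty)
    dynamic_shapes = model.get_dynamic_shapes()  # spec built around the "token_dim" Dim above
    # torch.export.export accepts positional and keyword example inputs as separate
    # arguments, which is why the two accessors are split in this change.
    return torch.export.export(eager, args, kwargs, dynamic_shapes=dynamic_shapes)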