@@ -222,35 +222,32 @@ def get_eager_model(self) -> torch.nn.Module:
         # switch all to FP32
         return self.model_.to(torch.float32)
 
-    def get_example_inputs(self) -> Tuple[Tuple, Dict]:
+    def get_example_inputs(self):
         if self.use_kv_cache:
             return self.get_example_inputs_kvcache_sdpa()
         else:
-            positional_inputs = (
+            return (
                 torch.tensor(
                     [[1, 2, 3]], dtype=torch.long
                 ),  # tokens, with kv cache our input token length is always just 1 token.
             )
-            return (positional_inputs, {})
 
     # assumption is the custom op doesnt support dynamic shape right now. It might but its untested so lets first get static shape working
-    def get_example_inputs_kvcache_sdpa(self) -> Tuple[Tuple, Dict]:
+    def get_example_inputs_kvcache_sdpa(self):
         if self.enable_dynamic_shape:
-            positional_inputs = (
+            return (
                 torch.tensor([[2, 3, 4]], dtype=torch.long),
                 torch.tensor([0], dtype=torch.long),
             )
-            return (positional_inputs, {})
         else:
-            positional_inputs = (
+            return (
                 torch.tensor(
                     [[1]], dtype=torch.long
                 ),  # tokens, with kv cache our input token length is always just 1 token.
                 torch.tensor(
                     [0], dtype=torch.long
                 ),  # start_pos, what token of output are we on.
             )
-            return (positional_inputs, {})
 
     def _transform_for_pre_quantization(self, checkpoint):
         assert hasattr(self.args, "preq_mode"), "preq_mode must be specified"
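
For context, a minimal, self-contained sketch of how a plain positional-args tuple like the one get_example_inputs_kvcache_sdpa() now returns is typically consumed. The TinyModel module below is a hypothetical stand-in, not the llama model touched by this diff; the only assumption is that torch.export.export takes the example positional args as a tuple, with kwargs passed separately and optional, which is why dropping the trailing empty dict simplifies the call site.

    # Hypothetical usage sketch (not part of this diff).
    import torch


    class TinyModel(torch.nn.Module):
        """Stand-in module that takes (tokens, start_pos) like the kv-cache path."""

        def __init__(self):
            super().__init__()
            self.emb = torch.nn.Embedding(100, 8)

        def forward(self, tokens: torch.Tensor, start_pos: torch.Tensor) -> torch.Tensor:
            return self.emb(tokens).sum(dim=-1) + start_pos


    # Same shape of example inputs as the static-shape branch of
    # get_example_inputs_kvcache_sdpa(): a plain tuple of positional tensors.
    example_args = (
        torch.tensor([[1]], dtype=torch.long),  # tokens
        torch.tensor([0], dtype=torch.long),    # start_pos
    )

    # The tuple can be handed straight to torch.export; no kwargs dict needed.
    exported = torch.export.export(TinyModel(), example_args)
    print(exported.graph_signature)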