@@ -224,35 +224,32 @@ def get_eager_model(self) -> torch.nn.Module:
         # switch all to FP32
         return self.model_.to(torch.float32)

-    def get_example_inputs(self) -> Tuple[Tuple, Dict]:
+    def get_example_inputs(self):
         if self.use_kv_cache:
             return self.get_example_inputs_kvcache_sdpa()
         else:
-            positional_inputs = (
+            return (
                 torch.tensor(
                     [[1, 2, 3]], dtype=torch.long
                 ),  # tokens, with kv cache our input token length is always just 1 token.
             )
-            return (positional_inputs, {})

     # assumption is the custom op doesnt support dynamic shape right now. It might but its untested so lets first get static shape working
-    def get_example_inputs_kvcache_sdpa(self) -> Tuple[Tuple, Dict]:
+    def get_example_inputs_kvcache_sdpa(self):
         if self.enable_dynamic_shape:
-            positional_inputs = (
+            return (
                 torch.tensor([[2, 3, 4]], dtype=torch.long),
                 torch.tensor([0], dtype=torch.long),
             )
-            return (positional_inputs, {})
         else:
-            positional_inputs = (
+            return (
                 torch.tensor(
                     [[1]], dtype=torch.long
                 ),  # tokens, with kv cache our input token length is always just 1 token.
                 torch.tensor(
                     [0], dtype=torch.long
                 ),  # start_pos, what token of output are we on.
             )
-            return (positional_inputs, {})

     def _transform_for_pre_quantization(self, checkpoint):
         assert hasattr(self.args, "preq_mode"), "preq_mode must be specified"
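The change above drops the unused kwargs dict from the example-inputs helpers, so callers now receive a bare tuple of positional tensors. A minimal sketch of a call site under that assumption: the `export_eager_llama` helper below is hypothetical, and consuming the inputs via `torch.export.export` is an assumption about downstream usage, not something shown in this diff.

import torch

def export_eager_llama(model):
    # `model` is assumed to be the llama model wrapper edited above, exposing
    # get_eager_model() and the simplified get_example_inputs().
    example_inputs = model.get_example_inputs()  # now a plain tuple of tensors
    # torch.export.export takes the positional example args as a tuple, so the
    # simplified return value can be passed straight through without having to
    # unpack a (positional_inputs, kwargs) pair first.
    return torch.export.export(model.get_eager_model(), example_inputs)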