@@ -224,32 +224,35 @@ def get_eager_model(self) -> torch.nn.Module:
         # switch all to FP32
         return self.model_.to(torch.float32)
 
-    def get_example_inputs(self):
+    def get_example_inputs(self) -> Tuple[Tuple, Dict]:
         if self.use_kv_cache:
             return self.get_example_inputs_kvcache_sdpa()
         else:
-            return (
+            positional_inputs = (
                 torch.tensor(
                     [[1, 2, 3]], dtype=torch.long
                 ),  # tokens, with kv cache our input token length is always just 1 token.
             )
+            return (positional_inputs, {})
 
     # Assumption: the custom op doesn't support dynamic shape right now. It might, but it's untested, so let's get static shape working first.
-    def get_example_inputs_kvcache_sdpa(self):
+    def get_example_inputs_kvcache_sdpa(self) -> Tuple[Tuple, Dict]:
         if self.enable_dynamic_shape:
-            return (
+            positional_inputs = (
                 torch.tensor([[2, 3, 4]], dtype=torch.long),
                 torch.tensor([0], dtype=torch.long),
             )
+            return (positional_inputs, {})
         else:
-            return (
+            positional_inputs = (
                 torch.tensor(
                     [[1]], dtype=torch.long
                 ),  # tokens, with kv cache our input token length is always just 1 token.
                 torch.tensor(
                     [0], dtype=torch.long
                 ),  # start_pos, what token of output are we on.
             )
+            return (positional_inputs, {})
 
     def _transform_for_pre_quantization(self, checkpoint):
         assert hasattr(self.args, "preq_mode"), "preq_mode must be specified"
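The new return convention pairs the positional example inputs with a (possibly empty) kwargs dict, which lines up with the `(args, kwargs)` arguments of `torch.export.export`. Below is a minimal sketch of how a caller could consume the pair; `export_with_example_inputs` and `model_manager` are illustrative names assumed for this example, not the repo's actual call site.

```python
# Sketch only: assumes `model_manager` exposes the methods shown in the diff
# (get_eager_model / get_example_inputs); the surrounding export flow is
# illustrative, not this repo's real export path.
import torch
from torch.export import export, ExportedProgram


def export_with_example_inputs(model_manager) -> ExportedProgram:
    eager_model = model_manager.get_eager_model()  # FP32 eager module
    example_args, example_kwargs = model_manager.get_example_inputs()
    # The (args, kwargs) pair maps directly onto torch.export.export's
    # signature; an empty kwargs dict is harmless for tensor-only models.
    return export(eager_model, example_args, kwargs=example_kwargs)
```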