@@ -222,32 +222,35 @@ def get_eager_model(self) -> torch.nn.Module:
         # switch all to FP32
         return self.model_.to(torch.float32)

-    def get_example_inputs(self):
+    def get_example_inputs(self) -> Tuple[Tuple, Dict]:
         if self.use_kv_cache:
             return self.get_example_inputs_kvcache_sdpa()
         else:
-            return (
+            positional_inputs = (
                 torch.tensor(
                     [[1, 2, 3]], dtype=torch.long
                 ),  # tokens, with kv cache our input token length is always just 1 token.
             )
+            return (positional_inputs, {})

     # assumption is the custom op doesnt support dynamic shape right now. It might but its untested so lets first get static shape working
-    def get_example_inputs_kvcache_sdpa(self):
+    def get_example_inputs_kvcache_sdpa(self) -> Tuple[Tuple, Dict]:
         if self.enable_dynamic_shape:
-            return (
+            positional_inputs = (
                 torch.tensor([[2, 3, 4]], dtype=torch.long),
                 torch.tensor([0], dtype=torch.long),
             )
+            return (positional_inputs, {})
         else:
-            return (
+            positional_inputs = (
                 torch.tensor(
                     [[1]], dtype=torch.long
                 ),  # tokens, with kv cache our input token length is always just 1 token.
                 torch.tensor(
                     [0], dtype=torch.long
                 ),  # start_pos, what token of output are we on.
             )
+            return (positional_inputs, {})

     def _transform_for_pre_quantization(self, checkpoint):
         assert hasattr(self.args, "preq_mode"), "preq_mode must be specified"
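For context, the new `Tuple[Tuple, Dict]` return value is a `(positional_args, keyword_args)` pair of the kind that `torch.export`-style entry points accept. Below is a minimal sketch of how a caller might unpack it; the `TinyTokenModel` class and the standalone `get_example_inputs` helper are hypothetical stand-ins that mirror the non-KV-cache branch above, not code from this PR.

```python
from typing import Dict, Tuple

import torch


class TinyTokenModel(torch.nn.Module):
    """Hypothetical stand-in for the eager llama model, for illustration only."""

    def __init__(self) -> None:
        super().__init__()
        self.embedding = torch.nn.Embedding(num_embeddings=32, embedding_dim=8)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        return self.embedding(tokens)


def get_example_inputs() -> Tuple[Tuple, Dict]:
    # Mirrors the non-KV-cache branch in the diff: a tuple of positional
    # example inputs plus an (empty) dict of keyword example inputs.
    positional_inputs = (torch.tensor([[1, 2, 3]], dtype=torch.long),)
    return (positional_inputs, {})


# Unpack the pair and hand both parts to torch.export.
example_args, example_kwargs = get_example_inputs()
exported = torch.export.export(TinyTokenModel(), example_args, example_kwargs)
print(exported.graph_module.graph)
```

Returning the kwargs dict explicitly (even when empty) keeps every call site on the same unpacking pattern, so callers do not need to special-case models whose example inputs are purely positional.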