@@ -250,32 +250,35 @@ def get_eager_model(self):
         # switch all to FP32
         return self.model_.to(torch.float32)

-    def get_example_inputs(self):
+    def get_example_inputs(self) -> Tuple[Tuple, Dict]:
         if self.use_kv_cache:
             return self.get_example_inputs_kvcache_sdpa()
         else:
-            return (
+            positional_inputs = (
                 torch.tensor(
                     [[1, 2, 3]], dtype=torch.long
                 ),  # tokens, with kv cache our input token length is always just 1 token.
             )
+            return (positional_inputs, {})

     # assumption is the custom op doesnt support dynamic shape right now. It might but its untested so lets first get static shape working
-    def get_example_inputs_kvcache_sdpa(self):
+    def get_example_inputs_kvcache_sdpa(self) -> Tuple[Tuple, Dict]:
         if self.enable_dynamic_shape:
-            return (
+            positional_inputs = (
                 torch.tensor([[2, 3, 4]], dtype=torch.long),
                 torch.tensor([0], dtype=torch.long),
             )
+            return (positional_inputs, {})
         else:
-            return (
+            positional_inputs = (
                 torch.tensor(
                     [[1]], dtype=torch.long
                 ),  # tokens, with kv cache our input token length is always just 1 token.
                 torch.tensor(
                     [0], dtype=torch.long
                 ),  # start_pos, what token of output are we on.
             )
+            return (positional_inputs, {})

     def _transform_for_pre_quantization(self, checkpoint):
         assert hasattr(self.args, "preq_mode"), "preq_mode must be specified"
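
For context, a minimal sketch (not part of this commit) of how a caller might unpack the new (positional_inputs, kwargs) return value when exporting; the `model` instance and the use of torch.export.export here are assumptions for illustration, not taken from the diff.

import torch

# Sketch only: `model` is assumed to be an instance of the class patched above,
# exposing get_eager_model() and get_example_inputs().
example_args, example_kwargs = model.get_example_inputs()

# torch.export.export takes positional example inputs as a tuple and keyword
# inputs as a dict, which matches the (positional_inputs, {}) shape returned above.
exported_program = torch.export.export(
    model.get_eager_model(),
    example_args,
    example_kwargs,
)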