@@ -250,35 +250,32 @@ def get_eager_model(self):
         # switch all to FP32
         return self.model_.to(torch.float32)
 
-    def get_example_inputs(self) -> Tuple[Tuple, Dict]:
+    def get_example_inputs(self):
         if self.use_kv_cache:
             return self.get_example_inputs_kvcache_sdpa()
         else:
-            positional_inputs = (
+            return (
                 torch.tensor(
                     [[1, 2, 3]], dtype=torch.long
                 ),  # tokens, with kv cache our input token length is always just 1 token.
             )
-            return (positional_inputs, {})
 
     # assumption is the custom op doesnt support dynamic shape right now. It might but its untested so lets first get static shape working
-    def get_example_inputs_kvcache_sdpa(self) -> Tuple[Tuple, Dict]:
+    def get_example_inputs_kvcache_sdpa(self):
         if self.enable_dynamic_shape:
-            positional_inputs = (
+            return (
                 torch.tensor([[2, 3, 4]], dtype=torch.long),
                 torch.tensor([0], dtype=torch.long),
             )
-            return (positional_inputs, {})
         else:
-            positional_inputs = (
+            return (
                 torch.tensor(
                     [[1]], dtype=torch.long
                 ),  # tokens, with kv cache our input token length is always just 1 token.
                 torch.tensor(
                     [0], dtype=torch.long
                 ),  # start_pos, what token of output are we on.
             )
-            return (positional_inputs, {})
 
     def _transform_for_pre_quantization(self, checkpoint):
         assert hasattr(self.args, "preq_mode"), "preq_mode must be specified"
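
A minimal sketch (not part of the diff) of how the positional-only example inputs returned above might be consumed. The tensors mirror the static-shape kv-cache branch; `model` is a hypothetical eager nn.Module standing in for the result of get_eager_model(), and torch.export is used only as an illustrative consumer of a plain args tuple.

import torch

# Positional example inputs, matching the static-shape kv-cache branch:
# a single-token batch plus a start_pos tensor.
example_inputs = (
    torch.tensor([[1]], dtype=torch.long),  # tokens
    torch.tensor([0], dtype=torch.long),    # start_pos
)

# With no kwargs dict in the return value, the tuple can be passed directly
# as the positional args to an exporter, e.g.:
# exported_program = torch.export.export(model, example_inputs)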