@@ -248,8 +248,8 @@ def compile(
248
248
249
249
def compile_module (
250
250
gm : torch .fx .GraphModule ,
251
- sample_inputs : Sequence [Input ],
252
- sample_kwarg_inputs : Any = None ,
251
+ sample_arg_inputs : Sequence [Input ],
252
+ sample_kwarg_inputs : Optional [ dict [ Any , Any ]] = None ,
253
253
settings : CompilationSettings = CompilationSettings (),
254
254
) -> torch .fx .GraphModule :
255
255
"""Compile a traced FX module
@@ -280,10 +280,12 @@ def compile_module(
280
280
dryrun_tracker .total_ops_in_graph = total_ops
281
281
dryrun_tracker .supported_ops_in_graph = num_supported_ops
282
282
dryrun_tracker .graph_input_shapes = parse_complex_tensor_structs (
283
- sample_inputs , "shape" , lambda x : dict (x ) if isinstance (x , dict ) else tuple (x )
283
+ sample_arg_inputs ,
284
+ "shape" ,
285
+ lambda x : dict (x ) if isinstance (x , dict ) else tuple (x ),
284
286
)
285
287
dryrun_tracker .graph_input_dtypes = parse_complex_tensor_structs (
286
- sample_inputs , "dtype" , lambda t : t .to (torch .dtype , use_default = True )
288
+ sample_arg_inputs , "dtype" , lambda t : t .to (torch .dtype , use_default = True )
287
289
)
288
290
dryrun_tracker .compilation_settings = settings
289
291
@@ -442,7 +444,7 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
442
444
trt_modules [name ] = trt_module
443
445
444
446
torch_sample_inputs = get_torch_inputs (
445
- sample_inputs , to_torch_device (settings .device )
447
+ sample_arg_inputs , to_torch_device (settings .device )
446
448
)
447
449
torch_sample_kwarg_inputs = get_torch_inputs (
448
450
sample_kwarg_inputs , to_torch_device (settings .device )
0 commit comments