@@ -234,8 +234,8 @@ def compile(
234
234
235
235
def compile_module (
236
236
gm : torch .fx .GraphModule ,
237
- sample_inputs : Sequence [Input ],
238
- sample_kwarg_inputs : Any = None ,
237
+ sample_arg_inputs : Sequence [Input ],
238
+ sample_kwarg_inputs : Optional [ dict [ Any , Any ]] = None ,
239
239
settings : CompilationSettings = CompilationSettings (),
240
240
) -> torch .fx .GraphModule :
241
241
"""Compile a traced FX module
@@ -266,10 +266,12 @@ def compile_module(
266
266
dryrun_tracker .total_ops_in_graph = total_ops
267
267
dryrun_tracker .supported_ops_in_graph = num_supported_ops
268
268
dryrun_tracker .graph_input_shapes = parse_complex_tensor_structs (
269
- sample_inputs , "shape" , lambda x : dict (x ) if isinstance (x , dict ) else tuple (x )
269
+ sample_arg_inputs ,
270
+ "shape" ,
271
+ lambda x : dict (x ) if isinstance (x , dict ) else tuple (x ),
270
272
)
271
273
dryrun_tracker .graph_input_dtypes = parse_complex_tensor_structs (
272
- sample_inputs , "dtype" , lambda t : t .to (torch .dtype , use_default = True )
274
+ sample_arg_inputs , "dtype" , lambda t : t .to (torch .dtype , use_default = True )
273
275
)
274
276
dryrun_tracker .compilation_settings = settings
275
277
@@ -428,7 +430,7 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
428
430
trt_modules [name ] = trt_module
429
431
430
432
torch_sample_inputs = get_torch_inputs (
431
- sample_inputs , to_torch_device (settings .device )
433
+ sample_arg_inputs , to_torch_device (settings .device )
432
434
)
433
435
torch_sample_kwarg_inputs = get_torch_inputs (
434
436
sample_kwarg_inputs , to_torch_device (settings .device )
0 commit comments