Skip to content

Commit 2d19a0d

Browse files
committed
Changed the typing
1 parent c5a1e46 commit 2d19a0d

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -234,8 +234,8 @@ def compile(
234234

235235
def compile_module(
236236
gm: torch.fx.GraphModule,
237-
sample_inputs: Sequence[Input],
238-
sample_kwarg_inputs: Any = None,
237+
sample_arg_inputs: Sequence[Input],
238+
sample_kwarg_inputs: Optional[dict[Any, Any]] = None,
239239
settings: CompilationSettings = CompilationSettings(),
240240
) -> torch.fx.GraphModule:
241241
"""Compile a traced FX module
@@ -266,10 +266,12 @@ def compile_module(
266266
dryrun_tracker.total_ops_in_graph = total_ops
267267
dryrun_tracker.supported_ops_in_graph = num_supported_ops
268268
dryrun_tracker.graph_input_shapes = parse_complex_tensor_structs(
269-
sample_inputs, "shape", lambda x: dict(x) if isinstance(x, dict) else tuple(x)
269+
sample_arg_inputs,
270+
"shape",
271+
lambda x: dict(x) if isinstance(x, dict) else tuple(x),
270272
)
271273
dryrun_tracker.graph_input_dtypes = parse_complex_tensor_structs(
272-
sample_inputs, "dtype", lambda t: t.to(torch.dtype, use_default=True)
274+
sample_arg_inputs, "dtype", lambda t: t.to(torch.dtype, use_default=True)
273275
)
274276
dryrun_tracker.compilation_settings = settings
275277

@@ -428,7 +430,7 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
428430
trt_modules[name] = trt_module
429431

430432
torch_sample_inputs = get_torch_inputs(
431-
sample_inputs, to_torch_device(settings.device)
433+
sample_arg_inputs, to_torch_device(settings.device)
432434
)
433435
torch_sample_kwarg_inputs = get_torch_inputs(
434436
sample_kwarg_inputs, to_torch_device(settings.device)

0 commit comments

Comments
 (0)