Skip to content

Commit f8540ed

Browse files
committed
Changed the typing
1 parent f9df16e commit f8540ed

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -248,8 +248,8 @@ def compile(
248248

249249
def compile_module(
250250
gm: torch.fx.GraphModule,
251-
sample_inputs: Sequence[Input],
252-
sample_kwarg_inputs: Any = None,
251+
sample_arg_inputs: Sequence[Input],
252+
sample_kwarg_inputs: Optional[dict[Any, Any]] = None,
253253
settings: CompilationSettings = CompilationSettings(),
254254
) -> torch.fx.GraphModule:
255255
"""Compile a traced FX module
@@ -280,10 +280,12 @@ def compile_module(
280280
dryrun_tracker.total_ops_in_graph = total_ops
281281
dryrun_tracker.supported_ops_in_graph = num_supported_ops
282282
dryrun_tracker.graph_input_shapes = parse_complex_tensor_structs(
283-
sample_inputs, "shape", lambda x: dict(x) if isinstance(x, dict) else tuple(x)
283+
sample_arg_inputs,
284+
"shape",
285+
lambda x: dict(x) if isinstance(x, dict) else tuple(x),
284286
)
285287
dryrun_tracker.graph_input_dtypes = parse_complex_tensor_structs(
286-
sample_inputs, "dtype", lambda t: t.to(torch.dtype, use_default=True)
288+
sample_arg_inputs, "dtype", lambda t: t.to(torch.dtype, use_default=True)
287289
)
288290
dryrun_tracker.compilation_settings = settings
289291

@@ -442,7 +444,7 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
442444
trt_modules[name] = trt_module
443445

444446
torch_sample_inputs = get_torch_inputs(
445-
sample_inputs, to_torch_device(settings.device)
447+
sample_arg_inputs, to_torch_device(settings.device)
446448
)
447449
torch_sample_kwarg_inputs = get_torch_inputs(
448450
sample_kwarg_inputs, to_torch_device(settings.device)

0 commit comments

Comments
 (0)