Skip to content

Commit 9c029a5

Browse files
committed
Fixed a bug of dynamo.trace without dynamic shape
1 parent db40534 commit 9c029a5

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

py/torch_tensorrt/dynamo/_tracer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ def trace(
7777
device = to_torch_device(kwargs.get("device", default_device()))
7878
torch_arg_inputs = get_torch_inputs(arg_inputs, device)
7979
torch_kwarg_inputs = get_torch_inputs(kwarg_inputs, device)
80+
# Constructing dynamic shape list as a nested dict
8081
dynamic_shapes = get_dynamic_shapes_args(mod, arg_inputs)
8182
dynamic_shapes.update(get_dynamic_shapes_kwargs(kwarg_inputs))
8283
exp_program = export(
@@ -120,7 +121,8 @@ def get_dynamic_shapes_args(mod: torch.nn.Module, inputs: Any) -> dict[str, Any]
120121

121122
def get_dynamic_shapes(input: Input) -> dict[Any, Any]:
122123
if not isinstance(input, Input):
123-
raise TypeError(f"Expected type torch_trt.Input, but got {type(input)}")
124+
# If the input is torch.Tensor, no dynamic is needed. Return empty dict
125+
return {}
124126
else:
125127
dynamic_dims = {}
126128
if input.shape_mode == Input._ShapeMode.DYNAMIC:

0 commit comments

Comments
 (0)