Skip to content

Commit bfbb09b

Browse files
committed
backward compatibility
1 parent 0e4be6b commit bfbb09b

File tree

1 file changed

+10
-7
lines changed

1 file changed

+10
-7
lines changed

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,9 @@
4747

4848
def compile(
4949
exported_program: ExportedProgram,
50-
arg_inputs: Tuple[Any, ...],
5150
*,
52-
inputs: Optional[Tuple[Any, ...]] = None,
51+
arg_inputs: Optional[Sequence[Any]] = None,
52+
inputs: Optional[Sequence[Any]] = None,
5353
kwarg_inputs: Optional[dict[Any, Any]] = None,
5454
device: Optional[Union[Device, torch.device, str]] = _defaults.DEVICE,
5555
disable_tf32: bool = _defaults.DISABLE_TF32,
@@ -169,22 +169,25 @@ def compile(
169169
f"Detected torch_executed_modules was non-empty: {torch_executed_modules}"
170170
"\nThis feature is unimplemented in Torch-TRT Dynamo currently."
171171
)
172-
if inputs:
172+
if inputs is not None:
173173
logger.warning(
174174
"'inputs' is deprecated. Please use 'args_inputs' in the future."
175175
)
176-
if not arg_inputs:
176+
if arg_inputs is None:
177177
arg_inputs = inputs
178178
else:
179179
logger.warning(
180180
"Both 'arg_inputs' and 'inputs' are received. 'inputs' will be ignored."
181181
)
182+
else:
183+
if arg_inputs is None:
184+
raise AssertionError("'arg_inputs' cannot be empty")
182185
if not isinstance(arg_inputs, collections.abc.Sequence):
183186
arg_inputs = [arg_inputs]
184187

185188
# Prepare torch_trt inputs
186-
arg_inputs = prepare_inputs(arg_inputs)
187-
kwarg_inputs = prepare_inputs(kwarg_inputs)
189+
trt_arg_inputs: Sequence[Input] = prepare_inputs(arg_inputs)
190+
trt_kwarg_inputs: Optional[dict[Any, Any]] = prepare_inputs(kwarg_inputs)
188191
device = to_torch_tensorrt_device(device)
189192
enabled_precisions = {dtype._from(p) for p in enabled_precisions}
190193

@@ -238,7 +241,7 @@ def compile(
238241

239242
settings = CompilationSettings(**compilation_options)
240243
logger.info("Compilation Settings: %s\n", settings)
241-
trt_gm = compile_module(gm, arg_inputs, kwarg_inputs, settings)
244+
trt_gm = compile_module(gm, trt_arg_inputs, trt_kwarg_inputs, settings)
242245
return trt_gm
243246

244247

0 commit comments

Comments
 (0)