Skip to content

Commit cf9a9bb

Browse files
committed
backward compatibility
1 parent 89b6216 commit cf9a9bb

File tree

1 file changed

+10
-7
lines changed

1 file changed

+10
-7
lines changed

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,9 @@
4747

4848
def compile(
4949
exported_program: ExportedProgram,
50-
arg_inputs: Tuple[Any, ...],
5150
*,
52-
inputs: Optional[Tuple[Any, ...]] = None,
51+
arg_inputs: Optional[Sequence[Any]] = None,
52+
inputs: Optional[Sequence[Any]] = None,
5353
kwarg_inputs: Optional[dict[Any, Any]] = None,
5454
device: Optional[Union[Device, torch.device, str]] = _defaults.DEVICE,
5555
disable_tf32: bool = _defaults.DISABLE_TF32,
@@ -183,22 +183,25 @@ def compile(
183183
f"Detected torch_executed_modules was non-empty: {torch_executed_modules}"
184184
"\nThis feature is unimplemented in Torch-TRT Dynamo currently."
185185
)
186-
if inputs:
186+
if inputs is not None:
187187
logger.warning(
188188
"'inputs' is deprecated. Please use 'args_inputs' in the future."
189189
)
190-
if not arg_inputs:
190+
if arg_inputs is None:
191191
arg_inputs = inputs
192192
else:
193193
logger.warning(
194194
"Both 'arg_inputs' and 'inputs' are received. 'inputs' will be ignored."
195195
)
196+
else:
197+
if arg_inputs is None:
198+
raise AssertionError("'arg_inputs' cannot be empty")
196199
if not isinstance(arg_inputs, collections.abc.Sequence):
197200
arg_inputs = [arg_inputs]
198201

199202
# Prepare torch_trt inputs
200-
arg_inputs = prepare_inputs(arg_inputs)
201-
kwarg_inputs = prepare_inputs(kwarg_inputs)
203+
trt_arg_inputs: Sequence[Input] = prepare_inputs(arg_inputs)
204+
trt_kwarg_inputs: Optional[dict[Any, Any]] = prepare_inputs(kwarg_inputs)
202205
device = to_torch_tensorrt_device(device)
203206
enabled_precisions = {dtype._from(p) for p in enabled_precisions}
204207

@@ -252,7 +255,7 @@ def compile(
252255

253256
settings = CompilationSettings(**compilation_options)
254257
logger.info("Compilation Settings: %s\n", settings)
255-
trt_gm = compile_module(gm, arg_inputs, kwarg_inputs, settings)
258+
trt_gm = compile_module(gm, trt_arg_inputs, trt_kwarg_inputs, settings)
256259
return trt_gm
257260

258261

0 commit comments

Comments
 (0)