|
47 | 47 |
|
48 | 48 | def compile(
|
49 | 49 | exported_program: ExportedProgram,
|
50 |
| - arg_inputs: Tuple[Any, ...], |
51 | 50 | *,
|
52 |
| - inputs: Optional[Tuple[Any, ...]] = None, |
| 51 | + arg_inputs: Optional[Sequence[Any]] = None, |
| 52 | + inputs: Optional[Sequence[Any]] = None, |
53 | 53 | kwarg_inputs: Optional[dict[Any, Any]] = None,
|
54 | 54 | device: Optional[Union[Device, torch.device, str]] = _defaults.DEVICE,
|
55 | 55 | disable_tf32: bool = _defaults.DISABLE_TF32,
|
@@ -183,22 +183,25 @@ def compile(
|
183 | 183 | f"Detected torch_executed_modules was non-empty: {torch_executed_modules}"
|
184 | 184 | "\nThis feature is unimplemented in Torch-TRT Dynamo currently."
|
185 | 185 | )
|
186 |
| - if inputs: |
| 186 | + if inputs is not None: |
187 | 187 | logger.warning(
|
188 | 188 | "'inputs' is deprecated. Please use 'args_inputs' in the future."
|
189 | 189 | )
|
190 |
| - if not arg_inputs: |
| 190 | + if arg_inputs is None: |
191 | 191 | arg_inputs = inputs
|
192 | 192 | else:
|
193 | 193 | logger.warning(
|
194 | 194 | "Both 'arg_inputs' and 'inputs' are received. 'inputs' will be ignored."
|
195 | 195 | )
|
| 196 | + else: |
| 197 | + if arg_inputs is None: |
| 198 | + raise AssertionError("'arg_inputs' cannot be empty") |
196 | 199 | if not isinstance(arg_inputs, collections.abc.Sequence):
|
197 | 200 | arg_inputs = [arg_inputs]
|
198 | 201 |
|
199 | 202 | # Prepare torch_trt inputs
|
200 |
| - arg_inputs = prepare_inputs(arg_inputs) |
201 |
| - kwarg_inputs = prepare_inputs(kwarg_inputs) |
| 203 | + trt_arg_inputs: Sequence[Input] = prepare_inputs(arg_inputs) |
| 204 | + trt_kwarg_inputs: Optional[dict[Any, Any]] = prepare_inputs(kwarg_inputs) |
202 | 205 | device = to_torch_tensorrt_device(device)
|
203 | 206 | enabled_precisions = {dtype._from(p) for p in enabled_precisions}
|
204 | 207 |
|
@@ -252,7 +255,7 @@ def compile(
|
252 | 255 |
|
253 | 256 | settings = CompilationSettings(**compilation_options)
|
254 | 257 | logger.info("Compilation Settings: %s\n", settings)
|
255 |
| - trt_gm = compile_module(gm, arg_inputs, kwarg_inputs, settings) |
| 258 | + trt_gm = compile_module(gm, trt_arg_inputs, trt_kwarg_inputs, settings) |
256 | 259 | return trt_gm
|
257 | 260 |
|
258 | 261 |
|
|
0 commit comments