|
47 | 47 |
|
48 | 48 | def compile(
|
49 | 49 | exported_program: ExportedProgram,
|
50 |
| - arg_inputs: Tuple[Any, ...], |
51 | 50 | *,
|
52 |
| - inputs: Optional[Tuple[Any, ...]] = None, |
| 51 | + arg_inputs: Optional[Sequence[Any]] = None, |
| 52 | + inputs: Optional[Sequence[Any]] = None, |
53 | 53 | kwarg_inputs: Optional[dict[Any, Any]] = None,
|
54 | 54 | device: Optional[Union[Device, torch.device, str]] = _defaults.DEVICE,
|
55 | 55 | disable_tf32: bool = _defaults.DISABLE_TF32,
|
@@ -169,22 +169,25 @@ def compile(
|
169 | 169 | f"Detected torch_executed_modules was non-empty: {torch_executed_modules}"
|
170 | 170 | "\nThis feature is unimplemented in Torch-TRT Dynamo currently."
|
171 | 171 | )
|
172 |
| - if inputs: |
| 172 | + if inputs is not None: |
173 | 173 | logger.warning(
|
174 | 174 | "'inputs' is deprecated. Please use 'args_inputs' in the future."
|
175 | 175 | )
|
176 |
| - if not arg_inputs: |
| 176 | + if arg_inputs is None: |
177 | 177 | arg_inputs = inputs
|
178 | 178 | else:
|
179 | 179 | logger.warning(
|
180 | 180 | "Both 'arg_inputs' and 'inputs' are received. 'inputs' will be ignored."
|
181 | 181 | )
|
| 182 | + else: |
| 183 | + if arg_inputs is None: |
| 184 | + raise AssertionError("'arg_inputs' cannot be empty") |
182 | 185 | if not isinstance(arg_inputs, collections.abc.Sequence):
|
183 | 186 | arg_inputs = [arg_inputs]
|
184 | 187 |
|
185 | 188 | # Prepare torch_trt inputs
|
186 |
| - arg_inputs = prepare_inputs(arg_inputs) |
187 |
| - kwarg_inputs = prepare_inputs(kwarg_inputs) |
| 189 | + trt_arg_inputs: Sequence[Input] = prepare_inputs(arg_inputs) |
| 190 | + trt_kwarg_inputs: Optional[dict[Any, Any]] = prepare_inputs(kwarg_inputs) |
188 | 191 | device = to_torch_tensorrt_device(device)
|
189 | 192 | enabled_precisions = {dtype._from(p) for p in enabled_precisions}
|
190 | 193 |
|
@@ -238,7 +241,7 @@ def compile(
|
238 | 241 |
|
239 | 242 | settings = CompilationSettings(**compilation_options)
|
240 | 243 | logger.info("Compilation Settings: %s\n", settings)
|
241 |
| - trt_gm = compile_module(gm, arg_inputs, kwarg_inputs, settings) |
| 244 | + trt_gm = compile_module(gm, trt_arg_inputs, trt_kwarg_inputs, settings) |
242 | 245 | return trt_gm
|
243 | 246 |
|
244 | 247 |
|
|
0 commit comments