Skip to content

Commit 346d45f

Browse files
committed
Renamed inputs to arg_inputs
1 parent 2d19a0d commit 346d45f

File tree

5 files changed

+15
-15
lines changed

5 files changed

+15
-15
lines changed

docs/_downloads/6dc7949e3e7f3cb855e4a7b28eadc851/vgg16_fp8_ptq.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ def calibrate_loop(model):
229229
exp_program = torch.export.export(model, (input_tensor,))
230230
trt_model = torchtrt.dynamo.compile(
231231
exp_program,
232-
inputs=[input_tensor],
232+
arg_inputs=[input_tensor],
233233
enabled_precisions={torch.float8_e4m3fn},
234234
min_block_size=1,
235235
debug=False,

examples/dynamo/vgg16_fp8_ptq.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ def calibrate_loop(model):
229229
exp_program = torch.export.export(model, (input_tensor,))
230230
trt_model = torchtrt.dynamo.compile(
231231
exp_program,
232-
inputs=[input_tensor],
232+
arg_inputs=[input_tensor],
233233
enabled_precisions={torch.float8_e4m3fn},
234234
min_block_size=1,
235235
debug=False,

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,9 @@
4747

4848
def compile(
4949
exported_program: ExportedProgram,
50-
inputs: Tuple[Any, ...],
50+
arg_inputs: Tuple[Any, ...],
5151
*,
52-
kwarg_inputs: Any = None,
52+
kwarg_inputs: Optional[dict[Any, Any]] = None,
5353
device: Optional[Union[Device, torch.device, str]] = _defaults.DEVICE,
5454
disable_tf32: bool = _defaults.DISABLE_TF32,
5555
assume_dynamic_shape_support: bool = _defaults.ASSUME_DYNAMIC_SHAPE_SUPPORT,
@@ -169,11 +169,11 @@ def compile(
169169
"\nThis feature is unimplemented in Torch-TRT Dynamo currently."
170170
)
171171

172-
if not isinstance(inputs, collections.abc.Sequence):
173-
inputs = [inputs]
172+
if not isinstance(arg_inputs, collections.abc.Sequence):
173+
arg_inputs = [arg_inputs]
174174

175175
# Prepare torch_trt inputs
176-
inputs = prepare_inputs(inputs)
176+
arg_inputs = prepare_inputs(arg_inputs)
177177
kwarg_inputs = prepare_inputs(kwarg_inputs)
178178
device = to_torch_tensorrt_device(device)
179179
enabled_precisions = {dtype._from(p) for p in enabled_precisions}
@@ -228,7 +228,7 @@ def compile(
228228

229229
settings = CompilationSettings(**compilation_options)
230230
logger.info("Compilation Settings: %s\n", settings)
231-
trt_gm = compile_module(gm, inputs, kwarg_inputs, settings)
231+
trt_gm = compile_module(gm, arg_inputs, kwarg_inputs, settings)
232232
return trt_gm
233233

234234

tests/py/dynamo/models/test_dtype_support.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def forward(self, x):
3636
exp_mod = torch.export.export(mod, (in_tensor,))
3737
trt_mod = torch_tensorrt.dynamo.compile(
3838
exp_mod,
39-
inputs=[in_tensor],
39+
arg_inputs=[in_tensor],
4040
pass_through_build_failures=True,
4141
truncate_double=True,
4242
min_block_size=1,
@@ -74,7 +74,7 @@ def forward(self, x):
7474
exp_mod = torch.export.export(mod, (in_tensor,))
7575
trt_mod = torch_tensorrt.dynamo.compile(
7676
exp_mod,
77-
inputs=[in_tensor],
77+
arg_inputs=[in_tensor],
7878
pass_through_build_failures=True,
7979
truncate_double=True,
8080
min_block_size=1,
@@ -118,7 +118,7 @@ def forward(self, x):
118118
exp_mod = torch.export.export(mod, (in_tensor,))
119119
trt_mod = torch_tensorrt.dynamo.compile(
120120
exp_mod,
121-
inputs=[in_tensor],
121+
arg_inputs=[in_tensor],
122122
pass_through_build_failures=True,
123123
truncate_double=False,
124124
min_block_size=1,
@@ -157,7 +157,7 @@ def forward(self, x):
157157
exp_mod = torch.export.export(mod, (in_tensor,))
158158
trt_mod = torch_tensorrt.dynamo.compile(
159159
exp_mod,
160-
inputs=[in_tensor],
160+
arg_inputs=[in_tensor],
161161
pass_through_build_failures=True,
162162
truncate_double=False,
163163
min_block_size=1,
@@ -201,7 +201,7 @@ def forward(self, x):
201201
exp_mod = torch.export.export(mod, (in_tensor,))
202202
trt_mod = torch_tensorrt.dynamo.compile(
203203
exp_mod,
204-
inputs=[in_tensor],
204+
arg_inputs=[in_tensor],
205205
pass_through_build_failures=True,
206206
enabled_precisions={torch.float, torch.bfloat16, torch.half},
207207
min_block_size=1,
@@ -239,7 +239,7 @@ def forward(self, x):
239239
exp_mod = torch.export.export(mod, (in_tensor,))
240240
trt_mod = torch_tensorrt.dynamo.compile(
241241
exp_mod,
242-
inputs=[in_tensor],
242+
arg_inputs=[in_tensor],
243243
pass_through_build_failures=True,
244244
enabled_precisions={torch.float, torch.bfloat16, torch.half},
245245
min_block_size=1,

tests/py/dynamo/models/test_models_export.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ def calibrate_loop(model):
222222
exp_program = torch.export.export(model, (input_tensor,))
223223
trt_model = torchtrt.dynamo.compile(
224224
exp_program,
225-
inputs=[input_tensor],
225+
arg_inputs=[input_tensor],
226226
enabled_precisions={torch.float8_e4m3fn},
227227
min_block_size=1,
228228
debug=True,

0 commit comments

Comments
 (0)