Skip to content

Commit db40534

Browse files
committed
Fixed comments
1 parent 164b934 commit db40534

File tree

8 files changed

+303
-305
lines changed

8 files changed

+303
-305
lines changed

py/torch_tensorrt/_compile.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -243,11 +243,11 @@ def compile(
243243
elif target_ir == _IRType.dynamo:
244244
# Prepare torch and torchtrt inputs
245245
if not arg_inputs and not inputs:
246-
raise AssertionError("'arg_input' and 'input' should not both be None.")
246+
raise AssertionError("'arg_inputs' and 'inputs' should not both be None.")
247247

248248
elif arg_inputs and inputs:
249249
raise AssertionError(
250-
"'arg_input' and 'input' should not be used at the same time."
250+
"'arg_inputs' and 'inputs' should not be used at the same time."
251251
)
252252
arg_inputs = inputs or arg_inputs
253253

@@ -367,11 +367,11 @@ def convert_method_to_trt_engine(
367367
elif target_ir == _IRType.dynamo:
368368
# Prepare torch and torchtrt inputs
369369
if not arg_inputs and not inputs:
370-
raise AssertionError("'arg_input' and 'input' should not both be None.")
370+
raise AssertionError("'arg_inputs' and 'inputs' should not both be None.")
371371

372372
elif arg_inputs and inputs:
373373
raise AssertionError(
374-
"'arg_input' and 'input' should not be used at the same time."
374+
"'arg_inputs' and 'inputs' should not be used at the same time."
375375
)
376376
arg_inputs = arg_inputs or inputs
377377

@@ -450,7 +450,7 @@ def save(
450450
output_format: str = "exported_program",
451451
inputs: Optional[Sequence[torch.Tensor]] = None,
452452
arg_inputs: Optional[Sequence[torch.Tensor]] = None,
453-
kwargs_inputs: Optional[dict[str, Any]] = None,
453+
kwarg_inputs: Optional[dict[str, Any]] = None,
454454
retrace: bool = False,
455455
) -> None:
456456
"""
@@ -475,15 +475,15 @@ def save(
475475
)
476476
if arg_inputs and inputs:
477477
raise AssertionError(
478-
"'arg_input' and 'input' should not be used at the same time."
478+
"'arg_inputs' and 'inputs' should not be used at the same time."
479479
)
480480

481481
arg_inputs = inputs or arg_inputs
482482

483-
if kwargs_inputs is None:
484-
kwargs_inputs = {}
483+
if kwarg_inputs is None:
484+
kwarg_inputs = {}
485485

486-
if kwargs_inputs and any(value is None for value in kwargs_inputs.values()):
486+
if kwarg_inputs and any(value is None for value in kwarg_inputs.values()):
487487
raise ValueError("kwargs should not include None.")
488488
if output_format not in accepted_formats:
489489
raise ValueError(
@@ -518,20 +518,20 @@ def save(
518518
# The module type is torch.fx.GraphModule
519519
if output_format == "torchscript":
520520
module_ts = torch.jit.trace(
521-
module, arg_inputs, example_kwarg_inputs=kwargs_inputs
521+
module, arg_inputs, example_kwarg_inputs=kwarg_inputs
522522
)
523523
torch.jit.save(module_ts, file_path)
524524
else:
525525
if not retrace:
526526
from torch_tensorrt.dynamo._exporter import export
527527

528-
exp_program = export(module, arg_inputs, kwargs_inputs)
528+
exp_program = export(module, arg_inputs, kwarg_inputs)
529529
torch.export.save(exp_program, file_path)
530530
else:
531531
from torch._higher_order_ops.torchbind import enable_torchbind_tracing
532532

533533
with enable_torchbind_tracing():
534534
exp_program = torch.export.export(
535-
module, tuple(arg_inputs), kwargs=kwargs_inputs, strict=False
535+
module, tuple(arg_inputs), kwargs=kwarg_inputs, strict=False
536536
)
537537
torch.export.save(exp_program, file_path)

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -187,11 +187,11 @@ def compile(
187187

188188
# Aliasing inputs to arg_inputs for better understanding
189189
if not arg_inputs and not inputs:
190-
raise AssertionError("'arg_input' and 'input' should not both be None.")
190+
raise AssertionError("'arg_inputs' and 'inputs' should not both be None.")
191191

192192
elif arg_inputs and inputs:
193193
raise AssertionError(
194-
"'arg_input' and 'input' should not be used at the same time."
194+
"'arg_inputs' and 'inputs' should not be used at the same time."
195195
)
196196

197197
arg_inputs = inputs or arg_inputs
@@ -460,13 +460,13 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
460460

461461
trt_modules[name] = trt_module
462462

463-
torch_sample_inputs = get_torch_inputs(
463+
torch_sample_arg_inputs = get_torch_inputs(
464464
sample_arg_inputs, to_torch_device(settings.device)
465465
)
466466
torch_sample_kwarg_inputs = get_torch_inputs(
467467
sample_kwarg_inputs, to_torch_device(settings.device)
468468
)
469-
sample_outputs = gm(*torch_sample_inputs, **torch_sample_kwarg_inputs)
469+
sample_outputs = gm(*torch_sample_arg_inputs, **torch_sample_kwarg_inputs)
470470

471471
if not isinstance(sample_outputs, (list, tuple)):
472472
sample_outputs = [sample_outputs]
@@ -609,23 +609,22 @@ def convert_module_to_trt_engine(
609609
stacklevel=2,
610610
)
611611
if not arg_inputs and not inputs:
612-
raise AssertionError("'arg_input' and 'input' should not both be None.")
612+
raise AssertionError("'arg_inputs' and 'inputs' should not both be None.")
613613

614614
elif arg_inputs and inputs:
615615
raise AssertionError(
616-
"'arg_input' and 'input' should not be used at the same time."
616+
"'arg_inputs' and 'inputs' should not be used at the same time."
617617
)
618618

619619
arg_inputs = inputs or arg_inputs
620620

621-
arg_input_list = list(arg_inputs) if arg_inputs is not None else []
622621
torch_executed_ops = torch_executed_ops if torch_executed_ops is not None else set()
623622
if kwarg_inputs is None:
624623
kwarg_inputs = {}
625624
# Prepare torch_trt inputs
626-
arg_input_list = prepare_inputs(arg_input_list)
625+
arg_inputs = prepare_inputs(arg_inputs)
627626
kwarg_input_list = prepare_inputs(kwarg_inputs)
628-
flattened_input_list = arg_input_list + flatten_dict_value(kwarg_input_list)
627+
flattened_input_list = arg_inputs + flatten_dict_value(kwarg_input_list)
629628
device = to_torch_tensorrt_device(device)
630629
enabled_precisions = {dtype._from(e) for e in enabled_precisions}
631630

@@ -679,7 +678,7 @@ def convert_module_to_trt_engine(
679678
interpreter_result = interpret_module_to_result(
680679
gm,
681680
inputs=flattened_input_list,
682-
arg_inputs=arg_input_list,
681+
arg_inputs=arg_inputs,
683682
kwarg_inputs=kwarg_input_list,
684683
settings=settings,
685684
)

py/torch_tensorrt/dynamo/_exporter.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,25 +22,25 @@
2222
def export(
2323
gm: torch.fx.GraphModule,
2424
inputs: Sequence[torch.Tensor],
25-
kwargs_inputs: Optional[dict[str, Any]] = None,
25+
kwarg_inputs: Optional[dict[str, Any]] = None,
2626
) -> ExportedProgram:
2727
"""Export the result of TensorRT compilation into the desired output format.
2828
2929
Arguments:
3030
gm (torch.fx.GraphModule): Compiled Torch-TensorRT module, generated by ``torch_tensorrt.dynamo.compile``
3131
inputs (torch.Tensor): Torch input tensors
3232
"""
33-
if kwargs_inputs is None:
34-
kwargs_inputs = {}
35-
patched_module = transform(gm, inputs, kwargs_inputs)
33+
if kwarg_inputs is None:
34+
kwarg_inputs = {}
35+
patched_module = transform(gm, inputs, kwarg_inputs)
3636
exp_program = create_trt_exp_program(patched_module)
3737
return exp_program
3838

3939

4040
def transform(
4141
gm: torch.fx.GraphModule,
4242
inputs: Sequence[torch.Tensor],
43-
kwargs_inputs: Optional[dict[str, Any]] = None,
43+
kwarg_inputs: Optional[dict[str, Any]] = None,
4444
) -> torch.fx.GraphModule:
4545
"""
4646
Transforms the graphmodule by inlining Pytorch and TensorRT submodules.
@@ -55,11 +55,11 @@ def transform(
5555
"""
5656
# Make a copy the graph since this function transforms the input graph and changes it's attributes.
5757
# This transformed graph is meant to be consumed by `create_trt_exp_program`
58-
if kwargs_inputs is None:
59-
kwargs_inputs = {}
58+
if kwarg_inputs is None:
59+
kwarg_inputs = {}
6060
gm = copy.deepcopy(gm)
6161
# Run shape analysis
62-
_, outputs_map = partitioning.run_shape_analysis(gm, inputs, kwargs_inputs)
62+
_, outputs_map = partitioning.run_shape_analysis(gm, inputs, kwarg_inputs)
6363

6464
# Inline TensorRT submodules
6565
inline_trt_modules(gm, outputs_map)

py/torch_tensorrt/dynamo/_tracer.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,11 +59,11 @@ def trace(
5959

6060
# Set log level at the top of compilation (torch_tensorrt.dynamo)
6161
if not arg_inputs and not inputs:
62-
raise AssertionError("'arg_input' and 'input' should not both be None.")
62+
raise AssertionError("'arg_inputs' and 'inputs' should not both be None.")
6363

6464
elif arg_inputs and inputs:
6565
raise AssertionError(
66-
"'arg_input' and 'input' should not be used at the same time."
66+
"'arg_inputs' and 'inputs' should not be used at the same time."
6767
)
6868
arg_inputs = inputs or arg_inputs
6969

@@ -79,7 +79,6 @@ def trace(
7979
torch_kwarg_inputs = get_torch_inputs(kwarg_inputs, device)
8080
dynamic_shapes = get_dynamic_shapes_args(mod, arg_inputs)
8181
dynamic_shapes.update(get_dynamic_shapes_kwargs(kwarg_inputs))
82-
# breakpoint()
8382
exp_program = export(
8483
mod,
8584
tuple(torch_arg_inputs),

py/torch_tensorrt/dynamo/backend/backends.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def _pretraced_backend(
8484
]
8585

8686
# Remove detach nodes
87-
remove_detach(gm, torch_inputs)
87+
remove_detach(gm)
8888

8989
# Invoke AOTAutograd to translate operators to aten
9090
gm = aot_export_joint_simple(

py/torch_tensorrt/dynamo/partitioning/common.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def construct_submodule_inputs(module: torch.fx.GraphModule) -> Sequence[Input]:
131131
def run_shape_analysis(
132132
parent_module: torch.fx.GraphModule,
133133
inputs: Sequence[Input],
134-
kwargs_inputs: Optional[dict[str, Any]] = None,
134+
kwarg_inputs: Optional[dict[str, Any]] = None,
135135
) -> Tuple[Dict[Any, Sequence[Any]], Dict[Any, Sequence[Any]]]:
136136
submod_inputs_shape_map: Dict[Any, Sequence[Any]] = {}
137137
submod_outputs_shape_map: Dict[Any, Sequence[Any]] = {}
@@ -147,13 +147,13 @@ def get_submodule_io(
147147
sub_outputs = outputs
148148
return
149149

150-
if kwargs_inputs is None:
151-
kwargs_inputs = {}
150+
if kwarg_inputs is None:
151+
kwarg_inputs = {}
152152
# Iterate through submodules (both Torch and TRT) and store IO shapes
153153
for name, _ in parent_module.named_children():
154154
submodule = getattr(parent_module, name)
155155
handle = submodule.register_forward_hook(get_submodule_io)
156-
parent_module(*inputs, **kwargs_inputs)
156+
parent_module(*inputs, **kwarg_inputs)
157157
handle.remove()
158158
submod_inputs_shape_map[name] = (
159159
[input.shape for input in sub_inputs]

tests/py/dynamo/models/test_export_kwargs_serde.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def forward(self, x, b=5, c=None, d=None):
7474

7575
# Save the module
7676
trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep")
77-
torchtrt.save(trt_gm, trt_ep_path, inputs=args, kwargs_inputs=kwargs)
77+
torchtrt.save(trt_gm, trt_ep_path, inputs=args, kwarg_inputs=kwargs)
7878
# Clean up model env
7979
torch._dynamo.reset()
8080

@@ -133,7 +133,7 @@ def forward(self, x, b=5, c=None, d=None):
133133

134134
# Save the module
135135
trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep")
136-
torchtrt.save(trt_gm, trt_ep_path, inputs=args, kwargs_inputs=kwargs)
136+
torchtrt.save(trt_gm, trt_ep_path, inputs=args, kwarg_inputs=kwargs)
137137
# Clean up model env
138138
torch._dynamo.reset()
139139

@@ -201,7 +201,7 @@ def forward(self, x, b=5, c=None, d=None):
201201

202202
# Save the module
203203
trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep")
204-
torchtrt.save(trt_gm, trt_ep_path, inputs=args, kwargs_inputs=kwargs)
204+
torchtrt.save(trt_gm, trt_ep_path, inputs=args, kwarg_inputs=kwargs)
205205
# Clean up model env
206206
torch._dynamo.reset()
207207

@@ -288,7 +288,7 @@ def forward(self, x, b=None, c=None, d=None, e=[]):
288288
)
289289
# Save the module
290290
trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep")
291-
torchtrt.save(trt_gm, trt_ep_path, inputs=args, kwargs_inputs=kwargs)
291+
torchtrt.save(trt_gm, trt_ep_path, inputs=args, kwarg_inputs=kwargs)
292292
# Clean up model env
293293
torch._dynamo.reset()
294294

@@ -375,7 +375,7 @@ def forward(self, x, b=None, c=None, d=None, e=[]):
375375
)
376376
# Save the module
377377
trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep")
378-
torchtrt.save(trt_gm, trt_ep_path, inputs=args, kwargs_inputs=kwargs)
378+
torchtrt.save(trt_gm, trt_ep_path, inputs=args, kwarg_inputs=kwargs)
379379
# Clean up model env
380380
torch._dynamo.reset()
381381

0 commit comments

Comments
 (0)