Skip to content

Commit f99b484

Browse files
committed
Added kwarg to refit
1 parent 7527995 commit f99b484

File tree

7 files changed

+35
-29
lines changed

7 files changed

+35
-29
lines changed

docs/_downloads/7e3a125a2d4ba8274a41b46f5e0723fa/refit_engine_example.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@
7878
new_trt_gm = refit_module_weights(
7979
compiled_module=compiled_trt_ep,
8080
new_weight_module=exp_program2,
81-
inputs=inputs,
81+
arg_inputs=inputs,
8282
)
8383

8484
# Check the output

examples/dynamo/refit_engine_example.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@
7878
new_trt_gm = refit_module_weights(
7979
compiled_module=compiled_trt_ep,
8080
new_weight_module=exp_program2,
81-
inputs=inputs,
81+
arg_inputs=inputs,
8282
)
8383

8484
# Check the output

py/torch_tensorrt/dynamo/_refit.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import collections.abc
44
import copy
55
import logging
6-
from typing import Any, Sequence, Tuple
6+
from typing import Any, Optional, Sequence, Tuple
77

88
import numpy as np
99
import tensorrt as trt
@@ -36,7 +36,6 @@
3636
from torch_tensorrt.dynamo.utils import (
3737
check_output,
3838
get_torch_inputs,
39-
prepare_inputs,
4039
set_log_level,
4140
to_torch_device,
4241
to_torch_tensorrt_device,
@@ -146,7 +145,8 @@ def _refit_single_trt_engine_with_gm(
146145
def refit_module_weights(
147146
compiled_module: torch.fx.GraphModule | ExportedProgram,
148147
new_weight_module: ExportedProgram,
149-
inputs: Tuple[Any, ...],
148+
arg_inputs: Optional[Tuple[Any, ...]] = None,
149+
kwarg_inputs: Optional[dict[str, Any]] = None,
150150
verify_output: bool = False,
151151
) -> torch.fx.GraphModule:
152152
"""
@@ -208,27 +208,29 @@ def refit_module_weights(
208208
if settings.debug:
209209
set_log_level(logger.parent, logging.DEBUG)
210210

211-
if not isinstance(inputs, collections.abc.Sequence):
212-
inputs = [inputs]
213-
214-
# Prepare torch_trt inputs
215-
inputs = prepare_inputs(inputs)
216211
device = to_torch_tensorrt_device(settings.device)
217-
torch_inputs = get_torch_inputs(inputs, device)
212+
if arg_inputs:
213+
if not isinstance(arg_inputs, collections.abc.Sequence):
214+
# Prepare torch_trt inputs
215+
arg_inputs = [arg_inputs]
216+
torch_inputs = get_torch_inputs(arg_inputs, device)
217+
218+
if kwarg_inputs:
219+
torch_kwarg_inputs = get_torch_inputs(kwarg_inputs, device)
218220
runtime = trt.Runtime(TRT_LOGGER)
219221
if not isinstance(new_weight_module, ExportedProgram):
220222
raise AssertionError(
221223
f"Input graph should be an ExportedProgram but got type {type(new_weight_module)}"
222224
)
223-
new_weight_module = pre_export_lowering(new_weight_module, torch_inputs)
225+
new_weight_module = pre_export_lowering(new_weight_module)
224226
new_weight_module = new_weight_module.run_decompositions(
225227
get_decompositions(settings.enable_experimental_decompositions)
226228
)
227229
new_gm = new_weight_module.module()
228230
logger.debug("Input graph: " + str(new_gm.graph))
229231
# Apply lowering on the graph module
230232

231-
new_gm = post_lowering(new_gm, torch_inputs)
233+
new_gm = post_lowering(new_gm)
232234

233235
logger.info("Compilation Settings: %s\n", settings)
234236

@@ -354,11 +356,12 @@ def refit_module_weights(
354356
refitted_engine = torch.classes.tensorrt.Engine(tuple(new_engine_info))
355357
setattr(compiled_module, f"{name}_engine", refitted_engine)
356358

357-
if verify_output:
359+
if verify_output and arg_inputs is not None:
358360
if check_output(
359361
new_module=new_gm,
360362
refitted_module=compiled_module,
361-
inputs=torch_inputs,
363+
arg_inputs=torch_inputs,
364+
kwarg_inputs=torch_kwarg_inputs,
362365
):
363366
logger.info("Refitting Succeed!")
364367
else:

py/torch_tensorrt/dynamo/utils.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -397,9 +397,12 @@ def function_wrapper(*args: Any, **kwargs: Any) -> Any:
397397
def check_output(
398398
new_module: torch.fx.GraphModule,
399399
refitted_module: torch.fx.GraphModule,
400-
inputs: tuple[Any, ...],
400+
arg_inputs: Any,
401+
kwarg_inputs: Any = None,
401402
) -> bool:
402-
old_outputs, new_outputs = refitted_module(*inputs), new_module(*inputs)
403+
old_outputs, new_outputs = refitted_module(*arg_inputs), new_module(
404+
*arg_inputs, **kwarg_inputs
405+
)
403406
for old_output, new_output in zip(old_outputs, new_outputs):
404407
if isinstance(old_output, torch.Tensor) and isinstance(
405408
new_outputs, torch.Tensor

tests/py/dynamo/models/test_dtype_support.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def forward(self, x):
3636
exp_mod = torch.export.export(mod, (in_tensor,))
3737
trt_mod = torch_tensorrt.dynamo.compile(
3838
exp_mod,
39-
arg_inputs=[in_tensor],
39+
inputs=[in_tensor],
4040
pass_through_build_failures=True,
4141
truncate_double=True,
4242
min_block_size=1,
@@ -74,7 +74,7 @@ def forward(self, x):
7474
exp_mod = torch.export.export(mod, (in_tensor,))
7575
trt_mod = torch_tensorrt.dynamo.compile(
7676
exp_mod,
77-
arg_inputs=[in_tensor],
77+
inputs=[in_tensor],
7878
pass_through_build_failures=True,
7979
truncate_double=True,
8080
min_block_size=1,
@@ -118,7 +118,7 @@ def forward(self, x):
118118
exp_mod = torch.export.export(mod, (in_tensor,))
119119
trt_mod = torch_tensorrt.dynamo.compile(
120120
exp_mod,
121-
arg_inputs=[in_tensor],
121+
inputs=[in_tensor],
122122
pass_through_build_failures=True,
123123
truncate_double=False,
124124
min_block_size=1,
@@ -157,7 +157,7 @@ def forward(self, x):
157157
exp_mod = torch.export.export(mod, (in_tensor,))
158158
trt_mod = torch_tensorrt.dynamo.compile(
159159
exp_mod,
160-
arg_inputs=[in_tensor],
160+
inputs=[in_tensor],
161161
pass_through_build_failures=True,
162162
truncate_double=False,
163163
min_block_size=1,
@@ -201,7 +201,7 @@ def forward(self, x):
201201
exp_mod = torch.export.export(mod, (in_tensor,))
202202
trt_mod = torch_tensorrt.dynamo.compile(
203203
exp_mod,
204-
arg_inputs=[in_tensor],
204+
inputs=[in_tensor],
205205
pass_through_build_failures=True,
206206
enabled_precisions={torch.float, torch.bfloat16, torch.half},
207207
min_block_size=1,
@@ -239,7 +239,7 @@ def forward(self, x):
239239
exp_mod = torch.export.export(mod, (in_tensor,))
240240
trt_mod = torch_tensorrt.dynamo.compile(
241241
exp_mod,
242-
arg_inputs=[in_tensor],
242+
inputs=[in_tensor],
243243
pass_through_build_failures=True,
244244
enabled_precisions={torch.float, torch.bfloat16, torch.half},
245245
min_block_size=1,

tests/py/dynamo/models/test_model_refit.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def test_refit_one_engine():
108108
new_trt_gm = refit_module_weights(
109109
compiled_module=trt_gm,
110110
new_weight_module=exp_program2,
111-
inputs=inputs,
111+
arg_inputs=inputs,
112112
)
113113

114114
# Check the output
@@ -154,7 +154,7 @@ def test_refit_one_engine_bert():
154154
new_trt_gm = refit_module_weights(
155155
compiled_module=trt_gm,
156156
new_weight_module=exp_program2,
157-
inputs=inputs,
157+
arg_inputs=inputs,
158158
)
159159

160160
# Check the output
@@ -203,7 +203,7 @@ def test_refit_one_engine_inline_runtime():
203203
new_trt_gm = refit_module_weights(
204204
compiled_module=trt_gm,
205205
new_weight_module=exp_program2,
206-
inputs=inputs,
206+
arg_inputs=inputs,
207207
)
208208

209209
# Check the output
@@ -247,7 +247,7 @@ def test_refit_one_engine_python_runtime():
247247
new_trt_gm = refit_module_weights(
248248
compiled_module=trt_gm,
249249
new_weight_module=exp_program2,
250-
inputs=inputs,
250+
arg_inputs=inputs,
251251
)
252252

253253
# Check the output
@@ -313,7 +313,7 @@ def forward(self, x):
313313
new_trt_gm = refit_module_weights(
314314
compiled_module=trt_gm,
315315
new_weight_module=exp_program2,
316-
inputs=inputs,
316+
arg_inputs=inputs,
317317
)
318318

319319
# Check the output

tests/py/dynamo/models/test_models_export.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ def calibrate_loop(model):
222222
exp_program = torch.export.export(model, (input_tensor,))
223223
trt_model = torchtrt.dynamo.compile(
224224
exp_program,
225-
arg_inputs=[input_tensor],
225+
inputs=[input_tensor],
226226
enabled_precisions={torch.float8_e4m3fn},
227227
min_block_size=1,
228228
debug=True,

0 commit comments

Comments
 (0)