Skip to content

Commit 75dc485

Browse files
committed
Added kwarg support for dynamo.compile
1 parent abed8f0 commit 75dc485

File tree

3 files changed

+118
-16
lines changed

3 files changed

+118
-16
lines changed

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def compile(
4949
exported_program: ExportedProgram,
5050
inputs: Tuple[Any, ...],
5151
*,
52+
kwarg_inputs: Any = None,
5253
device: Optional[Union[Device, torch.device, str]] = _defaults.DEVICE,
5354
disable_tf32: bool = _defaults.DISABLE_TF32,
5455
assume_dynamic_shape_support: bool = _defaults.ASSUME_DYNAMIC_SHAPE_SUPPORT,
@@ -148,7 +149,6 @@ def compile(
148149

149150
if debug:
150151
set_log_level(logger.parent, logging.DEBUG)
151-
152152
if "truncate_long_and_double" in kwargs.keys():
153153
if truncate_double is not _defaults.TRUNCATE_DOUBLE:
154154
raise ValueError(
@@ -173,6 +173,8 @@ def compile(
173173
else:
174174
make_refitable = kwargs["refit"]
175175

176+
if kwarg_inputs is None:
177+
kwarg_inputs = {}
176178
engine_capability = EngineCapability._from(engine_capability)
177179

178180
if torch_executed_modules is not None and torch_executed_modules:
@@ -186,22 +188,22 @@ def compile(
186188

187189
# Prepare torch_trt inputs
188190
inputs = prepare_inputs(inputs)
189-
torch_inputs = get_torch_inputs(inputs, device)
191+
kwarg_inputs = prepare_inputs(kwarg_inputs)
190192
device = to_torch_tensorrt_device(device)
191193
enabled_precisions = {dtype._from(p) for p in enabled_precisions}
192194

193195
if not isinstance(exported_program, ExportedProgram):
194196
raise AssertionError(
195197
f"Input graph should be an ExportedProgram but got type {type(exported_program)}"
196198
)
197-
exported_program = pre_export_lowering(exported_program, torch_inputs)
199+
exported_program = pre_export_lowering(exported_program, None)
198200
exported_program = exported_program.run_decompositions(
199201
get_decompositions(enable_experimental_decompositions)
200202
)
201203
gm = exported_program.module()
202204
logger.debug("Input graph: " + str(gm.graph))
203205
# Apply lowering on the graph module
204-
gm = post_lowering(gm, torch_inputs)
206+
gm = post_lowering(gm, None)
205207
logger.debug("Lowered Input graph: " + str(gm.graph))
206208

207209
compilation_options = {
@@ -240,13 +242,14 @@ def compile(
240242

241243
settings = CompilationSettings(**compilation_options)
242244
logger.info("Compilation Settings: %s\n", settings)
243-
trt_gm = compile_module(gm, inputs, settings)
245+
trt_gm = compile_module(gm, inputs, kwarg_inputs, settings)
244246
return trt_gm
245247

246248

247249
def compile_module(
248250
gm: torch.fx.GraphModule,
249251
sample_inputs: Sequence[Input],
252+
sample_kwarg_inputs: Any = None,
250253
settings: CompilationSettings = CompilationSettings(),
251254
) -> torch.fx.GraphModule:
252255
"""Compile a traced FX module
@@ -261,7 +264,8 @@ def compile_module(
261264
Compiled FX GraphModule
262265
"""
263266
dryrun_tracker = DryRunTracker()
264-
267+
if sample_kwarg_inputs is None:
268+
sample_kwarg_inputs = {}
265269
# Assume converters support dynamic shapes and disable validation
266270
CONVERTERS.set_dynamic_shape_support(settings.assume_dynamic_shape_support)
267271

@@ -437,9 +441,13 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
437441

438442
trt_modules[name] = trt_module
439443

440-
sample_outputs = gm(
441-
*get_torch_inputs(sample_inputs, to_torch_device(settings.device))
444+
torch_sample_inputs = get_torch_inputs(
445+
sample_inputs, to_torch_device(settings.device)
446+
)
447+
torch_sample_kwarg_inputs = get_torch_inputs(
448+
sample_kwarg_inputs, to_torch_device(settings.device)
442449
)
450+
sample_outputs = gm(*torch_sample_inputs, **torch_sample_kwarg_inputs)
443451

444452
if not isinstance(sample_outputs, (list, tuple)):
445453
sample_outputs = [sample_outputs]

py/torch_tensorrt/dynamo/utils.py

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -128,24 +128,45 @@ def input_is_dynamic(inputs: Sequence[Union[Input, torch.Tensor]]) -> bool:
128128

129129

130130
def get_torch_inputs(
    inputs: Sequence[Input] | Dict[Any, Any],
    device: Union[Device, torch.device, str],
    mode: str = "",
) -> Sequence[torch.tensor] | Dict[Any, Any]:
    """
    Return the torch tensors held by the given Input objects, moved to ``device``.

    If mode is set, this implies the user is using dynamic shaped inputs and
    the tensor returned for each Input is ``example_tensor(mode)``; otherwise
    it is the Input's ``torch_tensor``.

    Args:
        inputs: Either a sequence of inputs (positional arguments) or a dict
            of inputs (keyword arguments). Dict values that are themselves a
            list, tuple or dict are converted recursively.
        device: Target device; normalized with ``to_torch_device`` before use.
        mode: Passed to ``Input.example_tensor`` when non-empty.

    Returns:
        A dict with the same keys as ``inputs`` when ``inputs`` is a dict,
        otherwise a list. For sequences, entries that are not ``Input``
        instances are passed through unchanged when ``mode`` is empty and are
        dropped when ``mode`` is set.
    """
    device = to_torch_device(device)

    if isinstance(inputs, dict):
        result = {}
        for key, value in inputs.items():
            if isinstance(value, (list, tuple, dict)):
                # Forward ``mode`` so nested containers resolve to the same
                # kind of tensor as top-level entries (it was previously
                # dropped, silently falling back to ``torch_tensor``).
                result[key] = get_torch_inputs(value, device, mode)
            elif mode:
                result[key] = value.example_tensor(mode).to(device)
            else:
                # Leaf values are expected to expose ``torch_tensor``
                # (presumably Input objects produced by prepare_inputs).
                result[key] = value.torch_tensor.to(device)
        return result

    if mode:
        return [
            item.example_tensor(mode).to(device)
            for item in inputs
            if isinstance(item, Input)
        ]

    return [
        item.torch_tensor.to(device) if isinstance(item, Input) else item
        for item in inputs
    ]
149170

150171

151172
def set_log_level(parent_logger: Any, level: Any) -> None:
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# type: ignore
2+
import unittest
3+
4+
import pytest
5+
import timm
6+
import torch
7+
import torch.nn.functional as F
8+
import torch_tensorrt as torchtrt
9+
import torchvision.models as models
10+
from torch import nn
11+
from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity
12+
from transformers import BertModel
13+
from transformers.utils.fx import symbolic_trace as transformers_trace
14+
15+
assertions = unittest.TestCase()
16+
17+
18+
@pytest.mark.unit
def test_custom_model():
    """Compile a small conv net that takes keyword inputs and compare outputs.

    The model is exported with one positional tensor plus keyword arguments
    (a tensor and a nested dict holding a tensor), compiled through
    ``torchtrt.dynamo.compile`` with ``kwarg_inputs``, and the compiled output
    is compared against the eager output using cosine similarity.
    """

    class net(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = nn.Conv2d(3, 12, 3, padding=1)
            self.bn = nn.BatchNorm2d(12)
            self.conv2 = nn.Conv2d(12, 12, 3, padding=1)
            self.fc1 = nn.Linear(12 * 56 * 56, 10)

        def forward(self, x, b=5, c=None, d=None):
            x = self.conv1(x)
            x = F.relu(x)
            x = self.bn(x)
            x = F.max_pool2d(x, (2, 2))
            x = self.conv2(x)
            x = F.relu(x)
            x = F.max_pool2d(x, (2, 2))
            x = torch.flatten(x, 1)
            x = x + b
            # Optional keyword inputs: ``c`` scales, ``d["value"]`` offsets.
            if c is not None:
                x = x * c
            if d is not None:
                x = x - d["value"]
            return self.fc1(x)

    model = net().eval().to("cuda")
    args = [torch.rand((1, 3, 224, 224)).to("cuda")]
    kwargs = {
        "b": torch.tensor(6).to("cuda"),
        "d": {"value": torch.tensor(8).to("cuda")},
    }

    compile_spec = {
        "inputs": args,
        "kwarg_inputs": kwargs,
        "device": torchtrt.Device("cuda:0"),
        "enabled_precisions": {torch.float},
        "pass_through_build_failures": True,
        "optimization_level": 1,
        "min_block_size": 1,
        "ir": "dynamo",
    }
    # TODO: Support torchtrt.compile
    # trt_mod = torchtrt.compile(model, **compile_spec)

    try:
        exp_program = torch.export.export(model, args=tuple(args), kwargs=kwargs)
        trt_mod = torchtrt.dynamo.compile(exp_program, **compile_spec)
        cos_sim = cosine_similarity(
            model(*args, **kwargs), trt_mod(*args, **kwargs)[0]
        )
        assertions.assertTrue(
            cos_sim > COSINE_THRESHOLD,
            msg=f"Custom model TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
        )
    finally:
        # Clean up model env even when export, compilation or the check fails,
        # so later tests do not inherit stale dynamo state.
        torch._dynamo.reset()

0 commit comments

Comments
 (0)