Skip to content

Commit d951743

Browse files
committed
Added dynamic shape support for kwargs and dynamo.trace
1 parent 08f2cbb commit d951743

File tree

3 files changed

+536
-253
lines changed

3 files changed

+536
-253
lines changed

py/torch_tensorrt/dynamo/_tracer.py

Lines changed: 51 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from __future__ import annotations
22

33
import logging
4-
from typing import Any, Optional, Tuple
4+
from inspect import signature
5+
from typing import Any, Optional, Tuple, Union
56

67
import torch
78
from torch.export import Dim, export
@@ -76,14 +77,58 @@ def trace(
7677
device = to_torch_device(kwargs.get("device", default_device()))
7778
torch_arg_inputs = get_torch_inputs(arg_inputs, device)
7879
torch_kwarg_inputs = get_torch_inputs(kwarg_inputs, device)
79-
dynamic_shapes = []
80-
for input in arg_inputs: # type: ignore
81-
if isinstance(input, Input) and input.shape_mode == Input._ShapeMode.DYNAMIC:
80+
dynamic_shapes = get_dynamic_shapes_args(mod, arg_inputs)
81+
dynamic_shapes.update(get_dynamic_shapes_kwargs(kwarg_inputs))
82+
# breakpoint()
83+
exp_program = export(
84+
mod,
85+
tuple(torch_arg_inputs),
86+
kwargs=torch_kwarg_inputs,
87+
dynamic_shapes=dynamic_shapes,
88+
)
89+
90+
return exp_program
91+
92+
93+
def get_dynamic_shapes_kwargs(inputs: Any) -> Union[dict[str, Any], list[Any]]:
94+
if isinstance(inputs, dict):
95+
dynamic_shapes_kwarg = {}
96+
for k, v in inputs.items():
97+
dynamic_shapes_kwarg[k] = get_dynamic_shapes_kwargs(v)
98+
return dynamic_shapes_kwarg
99+
100+
elif isinstance(inputs, Input):
101+
return get_dynamic_shapes(inputs)
102+
103+
elif isinstance(inputs, (list, tuple)):
104+
dynamic_shapes = []
105+
for input in inputs:
106+
dynamic_shapes.append(get_dynamic_shapes(input))
107+
return dynamic_shapes
108+
109+
raise TypeError(f"Unknown type {type(inputs)}.")
110+
111+
112+
def get_dynamic_shapes_args(mod: torch.nn.Module, inputs: Any) -> dict[str, Any]:
113+
# dynamic_shape is a dict and cannot work without keys. Here we use position argument name
114+
# in forward function as the name
115+
args = list(signature(mod.forward).parameters.keys())
116+
dynamic_shapes = {}
117+
for input, input_name in zip(inputs, args[: len(inputs)]):
118+
dynamic_shapes[input_name] = get_dynamic_shapes(input)
119+
return dynamic_shapes
120+
121+
122+
def get_dynamic_shapes(input: Input) -> dict[Any, Any]:
123+
if not isinstance(input, Input):
124+
raise TypeError(f"Expected type torch_trt.Input, but got {type(input)}")
125+
else:
126+
dynamic_dims = {}
127+
if input.shape_mode == Input._ShapeMode.DYNAMIC:
82128
min_shape = input.shape["min_shape"]
83129
opt_shape = input.shape["opt_shape"]
84130
max_shape = input.shape["max_shape"]
85131
assert len(min_shape) == len(opt_shape) == len(max_shape)
86-
dynamic_dims = {}
87132
for dim in range(len(min_shape)):
88133
if min_shape[dim] == opt_shape[dim] == max_shape[dim]:
89134
continue
@@ -93,14 +138,4 @@ def trace(
93138
min=min_shape[dim],
94139
max=max_shape[dim],
95140
)
96-
97-
dynamic_shapes.append(dynamic_dims)
98-
99-
exp_program = export(
100-
mod,
101-
tuple(torch_arg_inputs),
102-
kwargs=torch_kwarg_inputs,
103-
dynamic_shapes=tuple(dynamic_shapes),
104-
)
105-
106-
return exp_program
141+
return dynamic_dims

tests/py/dynamo/models/test_export_kwargs_serde.py

Lines changed: 217 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,135 @@
1111
import torchvision.models as models
1212
from torch import nn
1313
from torch_tensorrt.dynamo._compiler import convert_module_to_trt_engine
14-
from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity
14+
from torch_tensorrt.dynamo.utils import (
15+
COSINE_THRESHOLD,
16+
cosine_similarity,
17+
prepare_inputs,
18+
)
1519

1620
assertions = unittest.TestCase()
1721

1822

23+
# @pytest.mark.unit
24+
# def test_custom_model():
25+
# class net(nn.Module):
26+
# def __init__(self):
27+
# super().__init__()
28+
# self.conv1 = nn.Conv2d(3, 12, 3, padding=1)
29+
# self.bn = nn.BatchNorm2d(12)
30+
# self.conv2 = nn.Conv2d(12, 12, 3, padding=1)
31+
# self.fc1 = nn.Linear(12 * 56 * 56, 10)
32+
33+
# def forward(self, x, b=5, c=None, d=None):
34+
# x = self.conv1(x)
35+
# x = F.relu(x)
36+
# x = self.bn(x)
37+
# x = F.max_pool2d(x, (2, 2))
38+
# x = self.conv2(x)
39+
# x = F.relu(x)
40+
# x = F.max_pool2d(x, (2, 2))
41+
# x = torch.flatten(x, 1)
42+
# x = x + b
43+
# if c is not None:
44+
# x = x * c
45+
# if d is not None:
46+
# x = x - d["value"]
47+
# return self.fc1(x)
48+
49+
# model = net().eval().to("cuda")
50+
# args = [torch.rand((1, 3, 224, 224)).to("cuda")]
51+
# kwargs = {
52+
# "b": torch.tensor(6).to("cuda"),
53+
# "d": {"value": torch.tensor(8).to("cuda")},
54+
# }
55+
56+
# compile_spec = {
57+
# "inputs": args,
58+
# "kwarg_inputs": kwargs,
59+
# "device": torchtrt.Device("cuda:0"),
60+
# "enabled_precisions": {torch.float},
61+
# "pass_through_build_failures": True,
62+
# "optimization_level": 1,
63+
# "min_block_size": 1,
64+
# "ir": "dynamo",
65+
# }
66+
67+
# exp_program = torch.export.export(model, args=tuple(args), kwargs=kwargs)
68+
# trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec)
69+
# cos_sim = cosine_similarity(model(*args, **kwargs), trt_gm(*args, **kwargs)[0])
70+
# assertions.assertTrue(
71+
# cos_sim > COSINE_THRESHOLD,
72+
# msg=f"CustomKwargs Module TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
73+
# )
74+
75+
# # Save the module
76+
# trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep")
77+
# torchtrt.save(trt_gm, trt_ep_path, inputs=args, kwargs_inputs=kwargs)
78+
# # Clean up model env
79+
# torch._dynamo.reset()
80+
81+
82+
# @pytest.mark.unit
83+
# def test_custom_model_with_dynamo_trace():
84+
# class net(nn.Module):
85+
# def __init__(self):
86+
# super().__init__()
87+
# self.conv1 = nn.Conv2d(3, 12, 3, padding=1)
88+
# self.bn = nn.BatchNorm2d(12)
89+
# self.conv2 = nn.Conv2d(12, 12, 3, padding=1)
90+
# self.fc1 = nn.Linear(12 * 56 * 56, 10)
91+
92+
# def forward(self, x, b=5, c=None, d=None):
93+
# x = self.conv1(x)
94+
# x = F.relu(x)
95+
# x = self.bn(x)
96+
# x = F.max_pool2d(x, (2, 2))
97+
# x = self.conv2(x)
98+
# x = F.relu(x)
99+
# x = F.max_pool2d(x, (2, 2))
100+
# x = torch.flatten(x, 1)
101+
# x = x + b
102+
# if c is not None:
103+
# x = x * c
104+
# if d is not None:
105+
# x = x - d["value"]
106+
# return self.fc1(x)
107+
108+
# model = net().eval().to("cuda")
109+
# args = [torch.rand((1, 3, 224, 224)).to("cuda")]
110+
# kwargs = {
111+
# "b": torch.tensor(6).to("cuda"),
112+
# "d": {"value": torch.tensor(8).to("cuda")},
113+
# }
114+
115+
# compile_spec = {
116+
# "inputs": prepare_inputs(args),
117+
# "kwarg_inputs": prepare_inputs(kwargs),
118+
# "device": torchtrt.Device("cuda:0"),
119+
# "enabled_precisions": {torch.float},
120+
# "pass_through_build_failures": True,
121+
# "optimization_level": 1,
122+
# "min_block_size": 1,
123+
# "ir": "dynamo",
124+
# }
125+
126+
# exp_program = torchtrt.dynamo.trace(model, **compile_spec)
127+
# trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec)
128+
# cos_sim = cosine_similarity(model(*args, **kwargs), trt_gm(*args, **kwargs)[0])
129+
# assertions.assertTrue(
130+
# cos_sim > COSINE_THRESHOLD,
131+
# msg=f"CustomKwargs Module TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
132+
# )
133+
134+
# # Save the module
135+
# trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep")
136+
# torchtrt.save(trt_gm, trt_ep_path, inputs=args, kwargs_inputs=kwargs)
137+
# # Clean up model env
138+
# torch._dynamo.reset()
139+
140+
19141
@pytest.mark.unit
20-
def test_custom_model():
142+
def test_custom_model_with_dynamo_trace_dynamic():
21143
class net(nn.Module):
22144
def __init__(self):
23145
super().__init__()
@@ -50,8 +172,17 @@ def forward(self, x, b=5, c=None, d=None):
50172
}
51173

52174
compile_spec = {
53-
"inputs": args,
54-
"kwarg_inputs": kwargs,
175+
# "arg_inputs": prepare_inputs(args),
176+
"inputs": [
177+
torchtrt.Input(
178+
min_shape=(1, 3, 224, 224),
179+
opt_shape=(4, 3, 224, 224),
180+
max_shape=(8, 3, 224, 224),
181+
dtype=torch.float32,
182+
name="x",
183+
)
184+
],
185+
"kwarg_inputs": prepare_inputs(kwargs),
55186
"device": torchtrt.Device("cuda:0"),
56187
"enabled_precisions": {torch.float},
57188
"pass_through_build_failures": True,
@@ -60,7 +191,88 @@ def forward(self, x, b=5, c=None, d=None):
60191
"ir": "dynamo",
61192
}
62193

63-
exp_program = torch.export.export(model, args=tuple(args), kwargs=kwargs)
194+
exp_program = torchtrt.dynamo.trace(model, **compile_spec)
195+
trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec)
196+
cos_sim = cosine_similarity(model(*args, **kwargs), trt_gm(*args, **kwargs)[0])
197+
assertions.assertTrue(
198+
cos_sim > COSINE_THRESHOLD,
199+
msg=f"CustomKwargs Module TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
200+
)
201+
202+
# Save the module
203+
trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep")
204+
torchtrt.save(trt_gm, trt_ep_path, inputs=args, kwargs_inputs=kwargs)
205+
# Clean up model env
206+
torch._dynamo.reset()
207+
208+
209+
@pytest.mark.unit
210+
def test_custom_model_with_dynamo_trace_dynamic_complex():
211+
ir = "dynamo"
212+
213+
class net(nn.Module):
214+
def __init__(self):
215+
super().__init__()
216+
self.conv1 = nn.Conv2d(3, 12, 3, padding=1)
217+
self.bn = nn.BatchNorm2d(12)
218+
self.conv2 = nn.Conv2d(12, 12, 3, padding=1)
219+
self.fc1 = nn.Linear(12 * 56 * 56, 10)
220+
221+
def forward(self, x, b=None, c=None, d=None, e=[]):
222+
x = self.conv1(x)
223+
x = F.relu(x)
224+
x = self.bn(x)
225+
x = F.max_pool2d(x, (2, 2))
226+
x = self.conv2(x)
227+
x = F.relu(x)
228+
x = F.max_pool2d(x, (2, 2))
229+
x = torch.flatten(x, 1)
230+
x = x @ b
231+
if c is not None:
232+
x = x * c
233+
if d is not None:
234+
x = x - d["value"]
235+
for n in e:
236+
x += n
237+
return x
238+
239+
model = net().eval().to("cuda")
240+
args = [torch.rand((1, 3, 224, 224)).to("cuda")]
241+
kwargs = {
242+
"b": torch.rand((37632, 10)).to("cuda"),
243+
"d": {"value": torch.tensor(8).to("cuda")},
244+
"e": [torch.tensor(8).to("cuda"), torch.tensor(10).to("cuda")],
245+
}
246+
model(*args, **kwargs)
247+
kwarg_torchtrt_input = prepare_inputs(kwargs)
248+
kwarg_torchtrt_input["b"] = torchtrt.Input(
249+
min_shape=(37632, 1),
250+
opt_shape=(37632, 5),
251+
max_shape=(37632, 10),
252+
dtype=torch.float32,
253+
name="b",
254+
)
255+
compile_spec = {
256+
# "arg_inputs": prepare_inputs(args),
257+
"inputs": [
258+
torchtrt.Input(
259+
min_shape=(1, 3, 224, 224),
260+
opt_shape=(4, 3, 224, 224),
261+
max_shape=(8, 3, 224, 224),
262+
dtype=torch.float32,
263+
name="x",
264+
),
265+
],
266+
"kwarg_inputs": kwarg_torchtrt_input,
267+
"device": torchtrt.Device("cuda:0"),
268+
"enabled_precisions": {torch.float},
269+
"pass_through_build_failures": True,
270+
"optimization_level": 1,
271+
"min_block_size": 1,
272+
"ir": "dynamo",
273+
}
274+
275+
exp_program = torchtrt.dynamo.trace(model, **compile_spec)
64276
trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec)
65277
cos_sim = cosine_similarity(model(*args, **kwargs), trt_gm(*args, **kwargs)[0])
66278
assertions.assertTrue(

0 commit comments

Comments
 (0)