Skip to content

Commit f422159

Browse files
authored
chore: fix ValueRanges computation in symbolic nodes (#2918)
1 parent a14752c commit f422159

File tree

3 files changed

+59
-29
lines changed

3 files changed

+59
-29
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1779,6 +1779,7 @@ def aten_ops_add(
17791779
)
17801780

17811781

1782+
@dynamo_tensorrt_converter(operator.mul, supports_dynamic_shapes=True)
17821783
@dynamo_tensorrt_converter(torch.ops.aten.mul.Tensor, supports_dynamic_shapes=True)
17831784
@dynamo_tensorrt_converter(torch.ops.aten.mul.Scalar, supports_dynamic_shapes=True)
17841785
def aten_ops_mul(

py/torch_tensorrt/dynamo/partitioning/common.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,16 @@ def construct_dynamic_input(
3535
node = dim.node
3636
expr = node.expr
3737
shape_env = node.shape_env
38-
var_range = shape_env.var_to_range.get(expr, None)
39-
var_val = shape_env.var_to_val.get(expr, None)
38+
# An expr can be a independent SymInt node (eg: s0 or s1) or a composition of them eg: (48*s0 or s0*s1).
39+
# In the case of expr which has symbolic computation, bound_sympy evaluates them.
40+
# https://pytorch.org/docs/stable/generated/torch.fx.experimental.symbolic_shapes.ShapeEnv.html#torch.fx.experimental.symbolic_shapes.ShapeEnv.bound_sympy
41+
# expr.xreplace replaces the symbolic variables with their current values and computes the expression.
42+
var_range = shape_env.var_to_range.get(expr, None) or shape_env.bound_sympy(
43+
expr
44+
)
45+
var_val = shape_env.var_to_val.get(expr, None) or expr.xreplace(
46+
shape_env.var_to_val
47+
)
4048
assert var_range, var_val
4149
# Torchdynamo 0/1 specialization outlier
4250
if var_range.lower == 2:

tests/py/dynamo/models/test_dyn_models.py

Lines changed: 48 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -64,11 +64,6 @@ def forward(self, x):
6464
cos_sim > COSINE_THRESHOLD,
6565
msg=f"test_dyn_full_compile model TRT outputs don't match with the pytorch model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
6666
)
67-
# Clean up model env
68-
torch._dynamo.reset()
69-
70-
with torch.no_grad():
71-
torch.cuda.empty_cache()
7267

7368

7469
@unittest.skip(
@@ -128,12 +123,6 @@ def forward(self, x):
128123
msg=f"test_base_dynamic_fallback model TRT outputs don't match with the pytorch model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
129124
)
130125

131-
# Clean up model env
132-
torch._dynamo.reset()
133-
134-
with torch.no_grad():
135-
torch.cuda.empty_cache()
136-
137126

138127
@pytest.mark.unit
139128
def test_view(ir):
@@ -185,12 +174,6 @@ def forward(self, x):
185174
msg=f"test_view model TRT outputs don't match with the pytorch model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
186175
)
187176

188-
# Clean up model env
189-
torch._dynamo.reset()
190-
191-
with torch.no_grad():
192-
torch.cuda.empty_cache()
193-
194177

195178
@pytest.mark.unit
196179
def test_resnet_dynamic(ir):
@@ -234,12 +217,6 @@ def test_resnet_dynamic(ir):
234217
msg=f"test_resnet_dynamic model TRT outputs don't match with the pytorch model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
235218
)
236219

237-
# Clean up model env
238-
torch._dynamo.reset()
239-
240-
with torch.no_grad():
241-
torch.cuda.empty_cache()
242-
243220

244221
@pytest.mark.unit
245222
def test_view(ir):
@@ -284,8 +261,52 @@ def forward(self, x):
284261
msg=f"test_base_dynamic model TRT outputs don't match with the pytorch model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
285262
)
286263

287-
# Clean up model env
288-
torch._dynamo.reset()
289264

290-
with torch.no_grad():
291-
torch.cuda.empty_cache()
265+
@pytest.mark.unit
266+
def test_linear(ir):
267+
"""
268+
Tests the model with linear op and operator.mul (added internally by PyTorch)
269+
with dynamic shapes
270+
"""
271+
272+
class MyModule(torch.nn.Module):
273+
def __init__(self):
274+
super().__init__()
275+
self.linear1 = torch.nn.Linear(10, 10)
276+
277+
def forward(self, x):
278+
return self.linear1(x)
279+
280+
model = MyModule().eval().cuda()
281+
282+
compile_spec = {
283+
"device": torchtrt.Device("cuda:0"),
284+
"enabled_precisions": {torch.float},
285+
"ir": ir,
286+
"min_block_size": 1,
287+
}
288+
inputs_bs2 = torch.randn(2, 2, 10).to("cuda")
289+
if ir == "torch_compile":
290+
torch._dynamo.mark_dynamic(inputs_bs2, 0, min=1, max=10)
291+
torch._dynamo.mark_dynamic(inputs_bs2, 1, min=1, max=10)
292+
# Compile the model
293+
trt_model = torch.compile(model, backend="tensorrt", options=compile_spec)
294+
trt_model(inputs_bs2)
295+
elif ir == "dynamo":
296+
dynamic_shapes = (
297+
{
298+
0: torch.export.Dim("batch_size", min=1, max=10),
299+
1: torch.export.Dim("seq_len", max=10),
300+
},
301+
)
302+
exp_program = torch.export.export(
303+
model, (inputs_bs2,), dynamic_shapes=dynamic_shapes
304+
)
305+
trt_model = torchtrt.dynamo.compile(exp_program, [inputs_bs2], **compile_spec)
306+
307+
input_bs6_s3 = torch.randn((6, 3, 10)).to("cuda")
308+
cos_sim = cosine_similarity(model(input_bs6_s3), trt_model(input_bs6_s3))
309+
assertions.assertTrue(
310+
cos_sim > COSINE_THRESHOLD,
311+
msg=f"test_linear model TRT outputs don't match with the pytorch model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
312+
)

0 commit comments

Comments
 (0)