Skip to content

Commit 88edc37

Browse files
committed
fix: Handle dynamic shapes in where ops
1 parent dfc31c7 commit 88edc37

File tree

2 files changed

+98
-62
lines changed

2 files changed

+98
-62
lines changed

py/torch_tensorrt/dynamo/conversion/impl/condition/ops.py

Lines changed: 17 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from typing import Optional, Union
22

33
import numpy as np
4-
import tensorrt as trt
54
import torch
65
from torch.fx.node import Target
76
from torch_tensorrt.dynamo._SourceIR import SourceIR
@@ -10,8 +9,7 @@
109
broadcastable,
1110
get_trt_tensor,
1211
)
13-
from torch_tensorrt.dynamo.conversion.impl.slice import expand
14-
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
12+
from torch_tensorrt.fx.converters.converter_utils import prepend_ones, set_layer_name
1513
from torch_tensorrt.fx.types import TRTTensor
1614

1715

@@ -30,73 +28,30 @@ def where(
3028
x_shape = list(input.shape)
3129
y_shape = list(other.shape)
3230
condition_shape = list(condition.shape)
31+
max_shape_len = max(len(x_shape), len(y_shape), len(condition_shape))
3332

34-
output_shape = list(torch.broadcast_shapes(condition_shape, x_shape, y_shape))
35-
36-
# expand shape
3733
if not isinstance(condition, TRTTensor):
3834
assert condition.dtype in (torch.bool, np.bool_), "condition dtype is not bool"
39-
if condition_shape != output_shape:
40-
condition = (
41-
condition.expand(output_shape)
42-
if isinstance(condition, torch.Tensor)
43-
else np.broadcast_to(condition, output_shape)
44-
)
45-
condition_val = get_trt_tensor(ctx, condition, f"{name}_condition")
46-
else:
47-
assert condition.dtype == trt.bool, "mask dtype is not bool!"
48-
if condition_shape != output_shape:
49-
condition_val = expand(
50-
ctx, target, source_ir, f"{name}_expand", condition, output_shape
51-
)
52-
else:
53-
condition_val = condition
35+
condition = get_trt_tensor(ctx, condition, f"{name}_condition")
36+
diff = max_shape_len - len(condition_shape)
37+
if diff > 0:
38+
condition = prepend_ones(
39+
ctx.net, condition, f"{name}_condition_broadcast", diff
40+
)
5441

5542
if not isinstance(input, TRTTensor):
56-
if x_shape != output_shape:
57-
# special case where 1 element in input
58-
if len(input.shape) == 0:
59-
input = (
60-
input.unsqueeze(0)
61-
if isinstance(input, torch.Tensor)
62-
else np.expand_dims(input, axis=0)
63-
)
64-
input = (
65-
input.expand(output_shape)
66-
if isinstance(input, torch.Tensor)
67-
else np.broadcast_to(input, output_shape)
68-
)
69-
x_val = get_trt_tensor(ctx, input, f"{name}_x")
70-
else:
71-
x_val = input
72-
if x_shape != output_shape:
73-
x_val = expand(
74-
ctx, target, source_ir, f"{name}_x_expand", input, output_shape
75-
)
43+
input = get_trt_tensor(ctx, input, f"{name}_x")
44+
diff = max_shape_len - len(x_shape)
45+
if diff > 0:
46+
input = prepend_ones(ctx.net, input, f"{name}_input_broadcast", diff)
7647

7748
if not isinstance(other, TRTTensor):
78-
if y_shape != output_shape:
79-
# special case where 1 element in other
80-
if len(other.shape) == 0:
81-
other = (
82-
other.unsqueeze(0)
83-
if isinstance(other, torch.Tensor)
84-
else np.expand_dims(other, axis=0)
85-
)
86-
other = (
87-
other.expand(output_shape)
88-
if isinstance(other, torch.Tensor)
89-
else np.broadcast_to(other, output_shape)
90-
)
91-
y_val = get_trt_tensor(ctx, other, f"{name}_y")
92-
else:
93-
y_val = other
94-
if y_shape != output_shape:
95-
y_val = expand(
96-
ctx, target, source_ir, f"{name}_y_expand", y_val, output_shape
97-
)
49+
other = get_trt_tensor(ctx, other, f"{name}_y")
50+
diff = max_shape_len - len(y_shape)
51+
if diff > 0:
52+
other = prepend_ones(ctx.net, other, f"{name}_other_broadcast", diff)
9853

99-
return select(ctx, target, source_ir, name, x_val, y_val, condition_val)
54+
return select(ctx, target, source_ir, name, input, other, condition)
10055

10156

10257
def select(

tests/py/dynamo/conversion/test_where_aten.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import torch.nn as nn
33
from parameterized import parameterized
44
from torch.testing._internal.common_utils import run_tests
5+
from torch_tensorrt import Input
56

67
from .harness import DispatchTestCase
78

@@ -76,6 +77,86 @@ def forward(self, condition):
7677
(condition,),
7778
)
7879

80+
# shape, min shape range, opt shape range, max shape range
81+
@parameterized.expand(
82+
[
83+
(
84+
"3d_condition_3d_xshape_3d_yshape",
85+
(-1, -1, -1),
86+
(1, 1, 1),
87+
(1, 2, 3),
88+
(3, 3, 3),
89+
(-1, -1, -1),
90+
(1, 1, 1),
91+
(1, 2, 3),
92+
(3, 3, 3),
93+
(-1, -1, -1),
94+
(1, 1, 1),
95+
(1, 2, 3),
96+
(3, 3, 3),
97+
),
98+
(
99+
"1d_condition_3d_xshape_2d_yshape",
100+
(-1),
101+
(1,),
102+
(2,),
103+
(4,),
104+
(-1, -1, -1),
105+
(1, 1, 1),
106+
(3, 2, 2),
107+
(3, 2, 4),
108+
(
109+
-1,
110+
-1,
111+
),
112+
(1, 1),
113+
(2, 2),
114+
(2, 4),
115+
),
116+
(
117+
"2d_condition_3d_xshape_2d_yshape",
118+
(-1, -1),
119+
(4, 1),
120+
(4, 2),
121+
(5, 4),
122+
(-1, -1, -1),
123+
(1, 1, 1),
124+
(3, 1, 2),
125+
(3, 1, 4),
126+
(
127+
-1,
128+
-1,
129+
),
130+
(4, 1),
131+
(4, 2),
132+
(5, 4),
133+
),
134+
]
135+
)
136+
def test_with_dynamic_shape(self, *args):
137+
class Where(nn.Module):
138+
def forward(self, condition, x, y):
139+
return torch.ops.aten.where.self(condition, x, y)
140+
141+
input_specs = [
142+
Input(
143+
shape=args[1],
144+
dtype=torch.bool,
145+
shape_ranges=[(args[2], args[3], args[4])],
146+
),
147+
Input(
148+
shape=args[5],
149+
dtype=torch.float32,
150+
shape_ranges=[(args[6], args[7], args[8])],
151+
),
152+
Input(
153+
shape=args[9],
154+
dtype=torch.float32,
155+
shape_ranges=[(args[10], args[11], args[12])],
156+
),
157+
]
158+
self.run_test_with_dynamic_shape(Where(), input_specs)
159+
79160

80161
if __name__ == "__main__":
81162
run_tests()

0 commit comments

Comments
 (0)