Skip to content

Commit e77e445

Browse files
authored
fix: Repair broadcasting utility for aten.where (#2228)
1 parent 6a69c6a commit e77e445

File tree

2 files changed

+29
-17
lines changed

2 files changed

+29
-17
lines changed

py/torch_tensorrt/dynamo/conversion/converter_utils.py

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -137,25 +137,23 @@ def broadcastable(
137137
"Check if two tensors are broadcastable according to torch rules"
138138
a_shape = tuple(a.shape)
139139
b_shape = tuple(b.shape)
140+
140141
# check from the trailing
141142
diff = len(a_shape) - len(b_shape)
142-
if diff == 0:
143+
144+
# Validate tensors have same rank and shape
145+
if diff == 0 and all(a_shape[i] == b_shape[i] for i in range(len(a_shape))):
143146
return True
147+
148+
# Left-pad the shorter dimension with ones
144149
if diff > 0:
145-
max = len(a_shape)
146-
min = len(b_shape)
147-
greater_tensor = a_shape
148-
lesser_tensor = b_shape
149-
elif diff < 0:
150-
max = len(b_shape)
151-
min = len(a_shape)
152-
greater_tensor = b_shape
153-
lesser_tensor = a_shape
154-
j = min - 1
155-
for i in range(max - 1, diff - 1, -1):
156-
if not (
157-
greater_tensor[i] != lesser_tensor[j]
158-
and (greater_tensor[i] == 1 or lesser_tensor[i] == 1)
159-
):
150+
b_shape = (1,) * abs(diff) + b_shape
151+
else:
152+
a_shape = (1,) * abs(diff) + a_shape
153+
154+
# Validate one of the following conditions for broadcastability per-dimension
155+
# 1. Equal number of dimensions or 2. Dimension has shape 1
156+
for i in range(len(a_shape)):
157+
if not (a_shape[i] == b_shape[i] or a_shape[i] == 1 or b_shape[i] == 1):
160158
return False
161159
return True

tests/py/dynamo/converters/test_where_aten.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import torch
22
import torch.nn as nn
3+
from harness import DispatchTestCase
34
from parameterized import parameterized
45
from torch.testing._internal.common_utils import run_tests
5-
from harness import DispatchTestCase
66

77

88
class TestWhereConverter(DispatchTestCase):
@@ -28,6 +28,20 @@ def forward(self, condition, x, y):
2828
expected_ops={torch.ops.aten.where.self},
2929
)
3030

31+
def test_0D_input(self):
32+
class Where(nn.Module):
33+
def forward(self, condition, x, y):
34+
return torch.where(condition, x, y)
35+
36+
inputX = torch.randn((5, 6, 7, 1, 3))
37+
inputOther = torch.tensor(8.0, dtype=torch.float)
38+
condition = inputX < 0
39+
self.run_test(
40+
Where(),
41+
(condition, inputX, inputOther),
42+
expected_ops={torch.ops.aten.where.self},
43+
)
44+
3145

3246
if __name__ == "__main__":
3347
run_tests()

0 commit comments

Comments (0)