Skip to content

Commit 148b3ba

Browse files
authored
Merge pull request #2198 from pytorch/dynamo_converter_where_type_mismatch
Type mismatch for dynamo aten::where converter
2 parents 91fcea4 + 8fdaaf5 commit 148b3ba

File tree

2 files changed

+4
-3
lines changed

2 files changed

+4
-3
lines changed

py/torch_tensorrt/dynamo/conversion/impl/condition/ops.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,15 +65,15 @@ def where(
6565
condition_val = condition_layer.get_output(0)
6666
else:
6767
assert condition.dtype == trt.bool, "mask dtype is not bool!"
68-
if condition_shape != condition_dim: # TODO: What is this checking?
68+
if len(condition_shape) != condition_dim:
6969
condition_val = expand(
7070
network, target, source_ir, f"{name}_expand", condition, output_shape
7171
)
7272
else:
7373
condition_val = condition
7474

7575
if type(input) != TRTTensor:
76-
if x_shape != input_dim: # TODO: What is this checking?
76+
if x_shape != output_shape:
7777
# special case where 1 element in input
7878
if len(input.shape) == 0:
7979
input = input.unsqueeze(0)
@@ -95,7 +95,7 @@ def where(
9595
y_val = get_trt_tensor(network, other, f"{name}_y")
9696
else:
9797
y_val = other
98-
if y_shape != other_dim: # TODO: What is this checking?
98+
if y_shape != output_shape:
9999
y_val = expand(
100100
network, target, source_ir, f"{name}_y_expand", y_val, output_shape
101101
)

tests/py/dynamo/converters/test_where_aten.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ class TestWhereConverter(DispatchTestCase):
1212
("2d_broadcast_condition_xshape_yshape", (2, 2), (2, 1)),
1313
("3d_condition_xshape_yshape", (2, 2, 1), (2, 2, 1)),
1414
("2d_3d_condition_xshape_yshape", (2, 2), (1, 2, 2)),
15+
("3d_2d_condition_xshape_yshape", (1, 2, 2), (2, 2)),
1516
]
1617
)
1718
def test_(self, _, x_size, y_size):

0 commit comments

Comments
 (0)