Skip to content

Commit 91c4988

Browse files
committed
test cases for boolean input
1 parent 5a4daf6 commit 91c4988

File tree

2 files changed

+65
-1
lines changed

2 files changed

+65
-1
lines changed

py/torch_tensorrt/dynamo/conversion/impl/reduce.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -220,12 +220,13 @@ def any(
220220
dim: Sequence[int] = [],
221221
keepdim: bool = False,
222222
) -> TRTTensor:
223+
num_out = cast_trt_tensor(ctx, input_val, trt.float32, f"{name}_cast")
223224
abs_out = impl.unary.abs(
224225
ctx,
225226
target,
226227
source_ir,
227228
f"{name}_abs",
228-
input_val,
229+
num_out,
229230
)
230231
max_out = amax(ctx, target, source_ir, f"{name}_amax", abs_out, dim, keepdim)
231232

tests/py/dynamo/conversion/test_any.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,69 @@ def forward(self, x):
126126
output_dtypes=[torch.bool],
127127
)
128128

129+
@parameterized.expand(
    [
        ((2, 3, 4), torch.int, -5, 0),
        ((6, 7, 5, 4, 5), torch.int, -5, 5),
        ((1, 5, 2, 1), torch.int, -5, 5),
    ]
)
def test_any_default_bool_dtype(self, input_shape, dtype, low, high):
    """Exercise aten.any.default with a boolean input tensor.

    Builds a random integer tensor in [low, high) and casts it to bool,
    then runs the harness comparison (self.run_test — presumably TRT vs.
    PyTorch; defined in the base class) expecting a bool output.
    """

    class Any(nn.Module):
        def forward(self, x):
            return torch.ops.aten.any.default(x)

    # .bool() maps every nonzero value to True, zero to False — this is
    # what makes these cases cover the boolean-input path of the converter.
    bool_input = torch.randint(low, high, input_shape, dtype=dtype).bool()
    self.run_test(
        Any(),
        [bool_input],
        output_dtypes=[torch.bool],
    )
147+
148+
@parameterized.expand(
    [
        ((3, 2, 4), 1, True, torch.int, 0, 5),
        ((2, 3, 4, 5), 3, True, torch.int, -10, 10),
        ((2, 3, 4, 5), 2, False, torch.int32, -5, 0),
        ((6, 7, 5, 4, 5), 4, False, torch.int32, -5, 5),
        ((1, 5, 2, 1), -4, False, torch.int32, -5, 5),
    ]
)
def test_any_dim_bool_dtype(self, input_shape, dim, keep_dims, dtype, low, high):
    """Exercise aten.any.dim with a boolean input tensor.

    Covers positive and negative reduction dims with both keepdim settings.
    Input is random ints in [low, high) cast to bool; the harness
    (self.run_test, defined in the base class) checks a bool output.
    """

    class AnyDim(nn.Module):
        def forward(self, x):
            return torch.ops.aten.any.dim(x, dim, keep_dims)

    # Cast to bool so the converter's boolean-input handling is exercised.
    bool_input = torch.randint(low, high, input_shape, dtype=dtype).bool()
    self.run_test(
        AnyDim(),
        [bool_input],
        output_dtypes=[torch.bool],
    )
168+
169+
@parameterized.expand(
    [
        ((3, 2, 4), [1], True, torch.int, 0, 5),
        ((2, 1, 4, 5), [0, 3], True, torch.int, -10, 10),
        ((2, 3, 4, 5), [0, 1, 2, 3], False, torch.int32, -5, 0),
        ((6, 7, 5, 4, 5), [1, 3, 4], False, torch.int32, -5, 5),
        ((1, 5, 2, 1), [-3, -1], False, torch.int32, -5, 5),
    ]
)
def test_any_dims_tuple_bool_dtype(
    self, input_shape, dims, keep_dims, dtype, low, high
):
    """Exercise aten.any.dims (multi-dim reduction) with a boolean input.

    Covers single-dim lists, full reduction, and negative dims under both
    keepdim settings. Input is random ints in [low, high) cast to bool;
    the harness (self.run_test, defined in the base class) checks a bool
    output.
    """

    class AnyDims(nn.Module):
        def forward(self, x):
            return torch.ops.aten.any.dims(x, dims, keep_dims)

    # Cast to bool so the converter's boolean-input handling is exercised.
    bool_input = torch.randint(low, high, input_shape, dtype=dtype).bool()
    self.run_test(
        AnyDims(),
        [bool_input],
        output_dtypes=[torch.bool],
    )
191+
129192

130193
if __name__ == "__main__":
131194
run_tests()

0 commit comments

Comments
 (0)