
Commit 5f83e2c

Author: Wei Wei

[fx2trt] support for ne, logical_not, logical_and (#75444)
Summary: X-link: pytorch/pytorch#75444

As titled:
1. Support logical_and and logical_not.
2. Replace eq, gt, and lt with Python operators in acc_ops, since the torch ops require torch.Tensor inputs while the Python operators do not.
3. Add more test cases.
4. Add a standalone ne op rather than composing it from existing ops, since every composition has a limitation. Lowering ne to equal + logical_not fails the last test case in test_ne.py, because logical_not requires a tensor input. equal + operator.not_ does not work either, since `not` is not traceable in FX (it raises "symbolically traced variables cannot be used as inputs to control flow"). equal + operator.invert is also incorrect, because operator.invert(True) == -2.

(Note: this ignores all push blocking failures!)

Reviewed By: 842974287

Differential Revision: D35232917

fbshipit-source-id: d4601a6883c977caa263f67b9db86cbc862d4780
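
For context, here is a small standalone Python snippet (not part of the commit) illustrating why the operator-based compositions described above fail:

import operator

# operator.invert is bitwise NOT: True is treated as the int 1,
# so inverting it yields -2, not the logical negation False.
print(operator.invert(True))  # -2

# operator.not_ returns the correct boolean here, but applying it to an
# FX Proxy calls bool() on the proxy, which raises:
# "symbolically traced variables cannot be used as inputs to control flow"
print(operator.not_(True))  # False
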
1 parent 6ebec00 commit 5f83e2c

File tree

11 files changed: +454 −23 lines changed

fx/converters/acc_ops_converters.py

Lines changed: 90 additions & 17 deletions
@@ -1237,6 +1237,87 @@ def acc_ops_minimum(
     )


+
+@tensorrt_converter(acc_ops.logical_not)
+def acc_ops_logical_not(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    input_val = kwargs["input"]
+    operation_type = trt.UnaryOperation.NOT
+    # Cast the input to bool before applying logical NOT.
+    if input_val.dtype in (trt.float32, trt.float16, trt.int32):
+        input_val = type_cast(network, target, f"{name}_input", input_val, trt.bool)
+    return add_unary_layer(network, input_val, operation_type, target, name)
+
+
+@tensorrt_converter(acc_ops.logical_and, no_implicit_batch_dim=True)
+@tensorrt_converter(acc_ops.bitwise_and, no_implicit_batch_dim=True)
+def acc_ops_logical_and(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    if network.has_implicit_batch_dimension:
+        raise RuntimeError("The `logical_and` function should be called with explicit batch dimension.")
+
+    input_t = kwargs["input"]
+    other_t = kwargs["other"]
+    # For bitwise_and, we only support the case where both inputs are bool.
+    if target == acc_ops.bitwise_and:
+        def check_is_bool(input_t):
+            if isinstance(input_t, TRTTensor):
+                assert input_t.dtype == trt.bool, "We currently do not support non-bool inputs"
+            elif isinstance(input_t, torch.Tensor):
+                assert input_t.dtype == torch.bool, "We currently do not support non-bool inputs"
+            else:
+                assert isinstance(input_t, bool), "We currently do not support non-bool inputs"
+
+        check_is_bool(input_t)
+        check_is_bool(other_t)
+
+    input_t = get_trt_tensor(network, input_t, f"{name}_input_t")
+    other_t = get_trt_tensor(network, other_t, f"{name}_other_t")
+
+    if input_t.dtype != trt.bool:
+        input_t = type_cast(network, target, f"{name}_input", input_t, trt.bool)
+    if other_t.dtype != trt.bool:
+        other_t = type_cast(network, target, f"{name}_other", other_t, trt.bool)
+    return add_binary_elementwise_layer(
+        network, input_t, other_t, trt.ElementWiseOperation.AND, target, name
+    )
+
+
+@tensorrt_converter(acc_ops.ne, no_implicit_batch_dim=True)
+def acc_ops_ne(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    if network.has_implicit_batch_dimension:
+        raise RuntimeError("The `ne` function should be called with explicit batch dimension.")
+
+    input_t = kwargs["input"]
+    other_t = kwargs["other"]
+
+    input_t = get_trt_tensor(network, input_t, f"{name}_input_t")
+    other_t = get_trt_tensor(network, other_t, f"{name}_other_t")
+
+    # ne is lowered as EQUAL followed by unary NOT.
+    input_t, other_t = dtype_uniform(network, target, name, input_t, other_t)
+    eq_t = add_binary_elementwise_layer(
+        network, input_t, other_t, trt.ElementWiseOperation.EQUAL, target, name
+    )
+
+    return add_unary_layer(network, eq_t, trt.UnaryOperation.NOT, target, name)
+
+
 @tensorrt_converter(acc_ops.eq, no_implicit_batch_dim=True)
 def acc_ops_eq(
     network: TRTNetwork,
@@ -1250,14 +1331,11 @@ def acc_ops_eq(

     input_t = kwargs["input"]
     other_t = kwargs["other"]
-    if isinstance(other_t, (torch.Tensor, bool)):
-        if isinstance(other_t, bool):
-            other_t = int(other_t)
-        elif other_t.dtype == torch.bool:
-            other_t = other_t.to(torch.int32)
+
+    input_t = get_trt_tensor(network, input_t, f"{name}_input_t")
     other_t = get_trt_tensor(network, other_t, f"{name}_other_t")
-    input_t, other_t = dtype_uniform(network, target, name, input_t, other_t)

+    input_t, other_t = dtype_uniform(network, target, name, input_t, other_t)
     return add_binary_elementwise_layer(
         network, input_t, other_t, trt.ElementWiseOperation.EQUAL, target, name
     )
@@ -1276,12 +1354,10 @@ def acc_ops_gt(

     input_t = kwargs["input"]
     other_t = kwargs["other"]
-    if isinstance(other_t, (torch.Tensor, bool)):
-        if isinstance(other_t, bool):
-            other_t = int(other_t)
-        elif other_t.dtype == torch.bool:
-            other_t = other_t.to(torch.int32)
+
+    input_t = get_trt_tensor(network, input_t, f"{name}_input_t")
     other_t = get_trt_tensor(network, other_t, f"{name}_other_t")
+
     input_t, other_t = dtype_uniform(network, target, name, input_t, other_t)
     return add_binary_elementwise_layer(
         network, input_t, other_t, trt.ElementWiseOperation.GREATER, target, name
@@ -1301,14 +1377,11 @@ def acc_ops_lt(

     input_t = kwargs["input"]
     other_t = kwargs["other"]
-    if isinstance(other_t, (torch.Tensor, bool)):
-        if isinstance(other_t, bool):
-            other_t = int(other_t)
-        elif other_t.dtype == torch.bool:
-            other_t = other_t.to(torch.int32)
+
+    input_t = get_trt_tensor(network, input_t, f"{name}_input_t")
     other_t = get_trt_tensor(network, other_t, f"{name}_other_t")
-    input_t, other_t = dtype_uniform(network, target, name, input_t, other_t)

+    input_t, other_t = dtype_uniform(network, target, name, input_t, other_t)
     return add_binary_elementwise_layer(
         network, input_t, other_t, trt.ElementWiseOperation.LESS, target, name
     )
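
With these converters in place, a torch.ne call can be lowered end to end. Below is a minimal test-style sketch (not from this commit) modeled on the test files added here; it assumes the same AccTestCase harness and run_test signature used in test_eq.py below, and the TestNeSimpleConverter name is hypothetical:

import torch
import fx2trt_oss.tracer.acc_tracer.acc_ops as acc_ops
from torch.testing._internal.common_fx2trt import AccTestCase
from torch.testing._internal.common_utils import run_tests

class TestNeSimpleConverter(AccTestCase):
    def test_ne(self):
        class Ne(torch.nn.Module):
            def forward(self, x, y):
                # Lowered by acc_ops_ne to EQUAL followed by unary NOT.
                return torch.ne(x, y)

        inputs = [torch.randn(3, 4), torch.randn(3, 4)]
        self.run_test(Ne(), inputs, expected_ops={acc_ops.ne}, test_implicit_batch_dim=False)

if __name__ == '__main__':
    run_tests()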

fx/converters/converter_utils.py

Lines changed: 31 additions & 2 deletions
@@ -8,7 +8,16 @@
 import tensorrt as trt
 import torch
 from torch.fx.node import Target, Argument
-from fx2trt_oss.fx.types import * # noqa: F403
+from fx2trt_oss.fx.types import (
+    TRTNetwork,
+    TRTTensor,
+    TRTLayer,
+    TRTPluginFieldCollection,
+    TRTPlugin,
+    TRTDataType,
+    TRTElementWiseOp,
+    Shape
+)
 from fx2trt_oss.fx.utils import torch_dtype_from_trt


@@ -223,6 +232,15 @@ def get_trt_tensor(
     Returns:
         A TensorRT ITensor that represents the given value.
     """
+    # TRT cannot add a constant of bool type. As a workaround we 1) cast it
+    # to int here and 2) cast it back to bool later. This is needed for the
+    # logical operations, which require bool inputs.
+    if isinstance(input_val, bool):
+        input_val = int(input_val)
+    if isinstance(input_val, torch.Tensor) and input_val.dtype == torch.bool:
+        input_val = input_val.to(torch.int32)
+    if isinstance(input_val, torch.Tensor) and input_val.dtype == torch.int64:
+        input_val = input_val.to(torch.int32)
+
     if isinstance(input_val, (torch.Tensor, int, float)):
         return create_constant(network, input_val, name, dtype)
     elif not isinstance(input_val, TRTTensor):
@@ -439,6 +457,8 @@ def add_unary_layer(
     )
     layer = network.add_unary(input_val, operation_type)
     set_layer_name(layer, target, name)
+    output = layer.get_output(0)
+    output.name = output.name + "_" + target.__name__
     return layer.get_output(0)


@@ -672,7 +692,7 @@ def get_python_op_from_trt_elementwise_op(trt_op: TRTElementWiseOp) -> Callable[
     else:
         raise RuntimeError(f"{trt_op} is not supported yet!")

-def dtype_uniform(network, target, name, input, other):
+def dtype_uniform(network: TRTNetwork, target: Target, name: str, input: TRTTensor, other: TRTTensor):
     table = {trt.bool:0, trt.int32:1, trt.float16:2, trt.float32:3}
     input_dtype = input.dtype
     other_dtype = other.dtype
@@ -697,3 +717,12 @@ def dtype_uniform(network, target, name, input, other):
     set_layer_name(layer_o, target, f"{name}_other_dtype_change")
     other = layer_o.get_output(0)
     return input, other
+
+def type_cast(network: TRTNetwork, target: Target, name: str, input: TRTTensor, cast_type: TRTDataType):
+    """
+    Cast the input to cast_type via an identity layer.
+    """
+    layer_i = network.add_identity(input)
+    layer_i.set_output_type(0, cast_type)
+    set_layer_name(layer_i, target, f"{name}_dtype_change")
+    return layer_i.get_output(0)
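
The bool-constant workaround in get_trt_tensor can be exercised outside TensorRT. This throwaway snippet (illustration only; the normalize_constant helper is hypothetical) mirrors the casts the function now performs before create_constant is called:

import torch

def normalize_constant(input_val):
    # Mirror of the new casts in get_trt_tensor: bool -> int,
    # bool/int64 tensors -> int32 tensors, since TRT cannot
    # create bool constants directly.
    if isinstance(input_val, bool):
        input_val = int(input_val)
    if isinstance(input_val, torch.Tensor) and input_val.dtype == torch.bool:
        input_val = input_val.to(torch.int32)
    if isinstance(input_val, torch.Tensor) and input_val.dtype == torch.int64:
        input_val = input_val.to(torch.int32)
    return input_val

print(normalize_constant(True))                         # 1
print(normalize_constant(torch.tensor([True, False])))  # tensor([1, 0], dtype=torch.int32)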

fx/fx2trt.py

Lines changed: 1 addition & 1 deletion
@@ -297,7 +297,7 @@ def output(self, target, args, kwargs):
             raise RuntimeError("TensorRT requires all outputs to be Tensor!")

         for i, output in enumerate(outputs):
-            if any(op_name in output.name.split("_") for op_name in ("eq", "gt", "lt", "or", "xor")):
+            if any(op_name in output.name.split("_") for op_name in ("eq", "gt", "lt", "or", "xor", "and", "not", "ne")):
                 output_bool = True
             else:
                 output_bool = False
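
This check depends on the "_<op>" suffix that add_unary_layer now appends to output tensor names (see the converter_utils.py change above). A quick illustration of the matching, using made-up layer names:

bool_ops = ("eq", "gt", "lt", "or", "xor", "and", "not", "ne")

# Hypothetical output names, for illustration only.
for name in ("conv2d_3_output", "acc_ops_ne_1_ne", "logical_not_0_not"):
    is_bool = any(op_name in name.split("_") for op_name in bool_ops)
    print(name, "->", is_bool)
# conv2d_3_output -> False
# acc_ops_ne_1_ne -> True
# logical_not_0_not -> True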

test/converters/acc_op/test_eq.py

Lines changed: 15 additions & 0 deletions
@@ -127,6 +127,21 @@ def forward(self, x):
         ]
         self.run_test(Eq(), inputs, expected_ops={acc_ops.eq}, test_implicit_batch_dim = False)

+class TestConstInputConverter(AccTestCase):
+    def test_eq(self):
+        class Eq(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return x.shape[0] == 4
+
+        input = torch.randn(3, 4)
+        inputs = [
+            input,
+        ]
+        self.run_test(Eq(), inputs, expected_ops={acc_ops.eq}, test_implicit_batch_dim=False)
+
 if __name__ == '__main__':
     run_tests()

test/converters/acc_op/test_gt.py

Lines changed: 17 additions & 0 deletions
@@ -124,5 +124,22 @@ def forward(self, x):
         ]
         self.run_test(Eq(), inputs, expected_ops={acc_ops.gt}, test_implicit_batch_dim = False)

+
+class TestConstInputConverter(AccTestCase):
+    def test_gt(self):
+        class Gt(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return x.shape[0] > 4
+
+        input = torch.randn(3, 4)
+        inputs = [
+            input,
+        ]
+        self.run_test(Gt(), inputs, expected_ops={acc_ops.gt}, test_implicit_batch_dim=False)
+
+
 if __name__ == '__main__':
     run_tests()
Lines changed: 95 additions & 0 deletions
@@ -0,0 +1,95 @@
+import torch
+import fx2trt_oss.tracer.acc_tracer.acc_ops as acc_ops
+from torch.testing._internal.common_fx2trt import AccTestCase
+from torch.testing._internal.common_utils import run_tests
+from parameterized import parameterized
+
+class TestAndMethodSimpleConverter(AccTestCase):
+    @parameterized.expand(
+        [
+            ("rand_2d_float_bool", torch.randn(3, 4), torch.randn(3, 4).to(torch.bool)),
+            ("rand_2d_int_bool", torch.randn(3, 4).to(torch.int), torch.randn(3, 4).to(torch.bool)),
+            ("rand_2d_bool_bool", torch.randn(3, 4).to(torch.bool), torch.randn(3, 4).to(torch.bool)),
+            ("rand_2d_float_int", torch.randn(3, 4).to(torch.float), torch.randn(3, 4).to(torch.int)),
+            ("rand_2d_float_single_bool", torch.randn(3, 4), torch.tensor(0).to(torch.bool)),
+            ("rand_2d_int_single_bool", torch.randn(3, 4).to(torch.int), torch.tensor(0).to(torch.bool)),
+            ("rand_2d_bool_single_bool", torch.randn(3, 4).to(torch.bool), torch.tensor(0).to(torch.bool)),
+        ]
+    )
+    def test_and(self, _, input, other):
+        class And(torch.nn.Module):
+            def forward(self, x, y):
+                return x.logical_and(y)
+
+        inputs = [
+            input,
+            other,
+        ]
+        self.run_test(And(), inputs, expected_ops={acc_ops.logical_and}, test_implicit_batch_dim=False)
+
+class TestAndFunctionSimpleConverter(AccTestCase):
+    @parameterized.expand(
+        [
+            ("rand_2d_float_bool", torch.randn(3, 4), torch.randn(3, 4).to(torch.bool)),
+            ("rand_2d_int_bool", torch.randn(3, 4).to(torch.int), torch.randn(3, 4).to(torch.bool)),
+            ("rand_2d_bool_bool", torch.randn(3, 4).to(torch.bool), torch.randn(3, 4).to(torch.bool)),
+            ("rand_2d_float_int", torch.randn(3, 4).to(torch.float), torch.randn(3, 4).to(torch.int)),
+            ("rand_2d_float_single_bool", torch.randn(3, 4), torch.tensor(0).to(torch.bool)),
+            ("rand_2d_int_single_bool", torch.randn(3, 4).to(torch.int), torch.tensor(0).to(torch.bool)),
+            ("rand_2d_bool_single_bool", torch.randn(3, 4).to(torch.bool), torch.tensor(0).to(torch.bool)),
+        ]
+    )
+    def test_and(self, _, input, other):
+        class And(torch.nn.Module):
+            def forward(self, x, y):
+                return torch.logical_and(x, y)
+
+        inputs = [
+            input,
+            other,
+        ]
+        self.run_test(And(), inputs, expected_ops={acc_ops.logical_and}, test_implicit_batch_dim=False)
+
+class TestAndOperatorSimpleConverter(AccTestCase):
+    @parameterized.expand(
+        [
+            ("rand_2d_bool_bool", torch.randn(3, 4).to(torch.bool), torch.randn(3, 4).to(torch.bool)),
+            ("rand_2d_bool_single_bool", torch.randn(3, 4).to(torch.bool), torch.tensor(0).to(torch.bool)),
+        ]
+    )
+    def test_and(self, _, input, other):
+        class And(torch.nn.Module):
+            def forward(self, x, y):
+                return x & y
+
+        inputs = [
+            input,
+            other,
+        ]
+        self.run_test(And(), inputs, expected_ops={acc_ops.bitwise_and}, test_implicit_batch_dim=False)
+
+
+class TestAndOperatorConstantConverter(AccTestCase):
+    @parameterized.expand(
+        [
+            ("rand_2d_bool_bool", torch.randn(3, 4).to(torch.bool), torch.randn(3, 4).to(torch.bool)),
+            ("rand_2d_bool_single_bool", torch.randn(3, 4).to(torch.bool), torch.tensor(0).to(torch.bool)),
+        ]
+    )
+    def test_and(self, _, input, other):
+        class And(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.other = other
+
+            def forward(self, x):
+                return x & self.other
+
+        inputs = [
+            input,
+        ]
+        self.run_test(And(), inputs, expected_ops={acc_ops.bitwise_and}, test_implicit_batch_dim=False)
+
+
+if __name__ == '__main__':
+    run_tests()

test/converters/acc_op/test_lt.py

Lines changed: 17 additions & 0 deletions
@@ -124,5 +124,22 @@ def forward(self, x):
         ]
         self.run_test(Eq(), inputs, expected_ops={acc_ops.lt}, test_implicit_batch_dim = False)

+
+class TestConstInputConverter(AccTestCase):
+    def test_lt(self):
+        class Lt(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return x.shape[0] < 4
+
+        input = torch.randn(3, 4)
+        inputs = [
+            input,
+        ]
+        self.run_test(Lt(), inputs, expected_ops={acc_ops.lt}, test_implicit_batch_dim=False)
+
+
 if __name__ == '__main__':
     run_tests()
