
Commit 5da146d

tissue3 (Wei Wei) authored and committed
[fx2trt][bootcamp] Add support for 2 ops: torch.logical_or, torch.logical_xor (#41)
Summary: Pull Request resolved: pytorch/fx2trt#41

fx2trt is a tool we use to create a TensorRT engine from a PyTorch model. The lowering is composed of three steps: 1) start from the PyTorch model, 2) trace the model with the acc tracer into acc_ops, and 3) use the TRT Interpreter to create a TensorRT engine. Here I:

1. Add the corresponding acc ops.
2. Add a converter for each acc op to acc_ops_converters.py.
3. Add a unit test for each converter in fbcode/deeplearning/trt/fx2trt_oss/test/converters/acc_op/test_logical_or/xor.py.

Reviewed By: frank-wei

Differential Revision: D35237918

fbshipit-source-id: 82720b764f0c886749aafea84584cdcb5172d206
1 parent 4510e1a commit 5da146d
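
The three lowering steps above can be sketched end to end. This is a minimal sketch only: the import paths, the TRTInterpreter arguments, and the toy model are assumptions based on the fx2trt_oss layout, not part of this commit.

# Hedged sketch of the lowering flow: PyTorch model -> acc tracer -> TRT Interpreter.
# ASSUMPTIONS: import paths and TRTInterpreter arguments are illustrative guesses.
import torch
import fx2trt_oss.tracer.acc_tracer.acc_tracer as acc_tracer
from fx2trt_oss.fx import InputTensorSpec, TRTInterpreter

class ToyModel(torch.nn.Module):
    def forward(self, x):
        # Uses one of the ops added in this commit.
        return torch.logical_or(x > 0, x < -1)

model = ToyModel().eval()                          # 1) start from the PyTorch model
sample_inputs = [torch.randn(3, 4)]
traced = acc_tracer.trace(model, sample_inputs)    # 2) trace the model into acc_ops
interpreter = TRTInterpreter(
    traced,
    InputTensorSpec.from_tensors(sample_inputs),
    explicit_batch_dimension=True,                 # the logical ops require explicit batch
)
result = interpreter.run()                         # 3) build the TensorRT engine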

File tree

6 files changed: +236 -1 lines changed


fx/converters/acc_ops_converters.py

Lines changed: 67 additions & 0 deletions
@@ -1313,6 +1313,73 @@ def acc_ops_lt(
         network, input_t, other_t, trt.ElementWiseOperation.LESS, target, name
     )
 
+@tensorrt_converter(acc_ops.logical_or, no_implicit_batch_dim=True)
+def acc_ops_logical_or(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    if network.has_implicit_batch_dimension:
+        raise RuntimeError("The `logical_or` function should be called with explicit batch dimension.")
+
+    input_t = kwargs["input"]
+    other_t = kwargs["other"]
+    if isinstance(other_t, (torch.Tensor, bool)):
+        if isinstance(other_t, bool):
+            other_t = int(other_t)
+        elif other_t.dtype == torch.bool:
+            other_t = other_t.to(torch.int32)
+    other_t = get_trt_tensor(network, other_t, f"{name}_other_t")
+    if input_t.dtype != trt.bool:
+        layer_i = network.add_identity(input_t)
+        layer_i.set_output_type(0, trt.bool)
+        set_layer_name(layer_i, target, f"{name}_input_dtype_change")
+        input_t = layer_i.get_output(0)
+    if other_t.dtype != trt.bool:
+        layer_o = network.add_identity(other_t)
+        layer_o.set_output_type(0, trt.bool)
+        set_layer_name(layer_o, target, f"{name}_other_dtype_change")
+        other_t = layer_o.get_output(0)
+
+    return add_binary_elementwise_layer(
+        network, input_t, other_t, trt.ElementWiseOperation.OR, target, name
+    )
+
+@tensorrt_converter(acc_ops.logical_xor, no_implicit_batch_dim=True)
+def acc_ops_logical_xor(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    if network.has_implicit_batch_dimension:
+        raise RuntimeError("The `logical_xor` function should be called with explicit batch dimension.")
+
+    input_t = kwargs["input"]
+    other_t = kwargs["other"]
+    if isinstance(other_t, (torch.Tensor, bool)):
+        if isinstance(other_t, bool):
+            other_t = int(other_t)
+        elif other_t.dtype == torch.bool:
+            other_t = other_t.to(torch.int32)
+    other_t = get_trt_tensor(network, other_t, f"{name}_other_t")
+    if input_t.dtype != trt.bool:
+        layer_i = network.add_identity(input_t)
+        layer_i.set_output_type(0, trt.bool)
+        set_layer_name(layer_i, target, f"{name}_input_dtype_change")
+        input_t = layer_i.get_output(0)
+    if other_t.dtype != trt.bool:
+        layer_o = network.add_identity(other_t)
+        layer_o.set_output_type(0, trt.bool)
+        set_layer_name(layer_o, target, f"{name}_other_dtype_change")
+        other_t = layer_o.get_output(0)
+
+    return add_binary_elementwise_layer(
+        network, input_t, other_t, trt.ElementWiseOperation.XOR, target, name
+    )
 
 @tensorrt_converter(acc_ops.fmod)
 def acc_ops_fmod(
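
Both converters above share the same dtype workaround: TensorRT's ElementWise OR/XOR layers operate on bool tensors, so any non-bool operand is routed through an identity layer whose output type is forced to trt.bool. The pattern in isolation, as a sketch (cast_to_bool is a hypothetical helper, not part of this commit; network is an existing INetworkDefinition and tensor an existing ITensor):

import tensorrt as trt

def cast_to_bool(network, tensor, name):
    # An identity layer with a forced bool output type performs the cast,
    # mirroring the {name}_input_dtype_change / {name}_other_dtype_change
    # layers in the converters above.
    layer = network.add_identity(tensor)
    layer.set_output_type(0, trt.bool)
    layer.name = f"{name}_dtype_change"
    return layer.get_output(0)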

fx/fx2trt.py

Lines changed: 1 addition & 1 deletion
@@ -297,7 +297,7 @@ def output(self, target, args, kwargs):
            raise RuntimeError("TensorRT requires all outputs to be Tensor!")
 
        for i, output in enumerate(outputs):
-           if any(op_name in output.name.split("_") for op_name in ("eq", "gt", "lt")):
+           if any(op_name in output.name.split("_") for op_name in ("eq", "gt", "lt", "or", "xor")):
                output_bool = True
            else:
                output_bool = False
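
This check works because converter layer names embed the op name as an underscore-separated token, so splitting the output name on "_" flags boolean-producing outputs. A plain-Python illustration (the layer name here is a made-up example):

# Token check as used above; the output name is hypothetical.
output_name = "logical_or_1_output"
tokens = output_name.split("_")  # ["logical", "or", "1", "output"]
output_bool = any(op_name in tokens for op_name in ("eq", "gt", "lt", "or", "xor"))
assert output_bool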
test/converters/acc_op/test_logical_or.py

Lines changed: 75 additions & 0 deletions

@@ -0,0 +1,75 @@
+import torch
+import fx2trt_oss.tracer.acc_tracer.acc_ops as acc_ops
+from torch.testing._internal.common_fx2trt import AccTestCase
+from torch.testing._internal.common_utils import run_tests
+from parameterized import parameterized
+
+class TestLogicalOrMethodSimpleConverter(AccTestCase):
+    @parameterized.expand(
+        [
+            ("rand_2d_bool_bool", torch.randn(3, 4) > 0, torch.randn(3, 4) > 0),
+            ("rand_3d_bool_bool", torch.randn(3, 4, 5) > 0, torch.randn(3, 4, 5) > 0),
+            ("rand_4d_bool_bool", torch.randn(3, 4, 5, 6) > 0, torch.randn(3, 4, 5, 6) > 0),
+            ("rand_2d_bool_single_bool", torch.randn(3, 4) > 0, torch.tensor(0) > 0),
+            ("rand_2d_int_bool", torch.randn(3, 4).to(torch.int), torch.randn(3, 4) > 0),
+            ("rand_2d_int_single_bool", torch.randn(3, 4).to(torch.int), torch.tensor(0) > 0),
+        ]
+    )
+    def test_logical_or(self, _, input, other):
+        class LogicalOr(torch.nn.Module):
+            def forward(self, x, y):
+                return x.logical_or(y)
+
+        inputs = [
+            input,
+            other,
+        ]
+        self.run_test(LogicalOr(), inputs, expected_ops={acc_ops.logical_or}, test_implicit_batch_dim=False)
+
+class TestLogicalOrFunctionSimpleConverter(AccTestCase):
+    @parameterized.expand(
+        [
+            ("rand_2d_bool_bool", torch.randn(3, 4) > 0, torch.randn(3, 4) > 0),
+            ("rand_3d_bool_bool", torch.randn(3, 4, 5) > 0, torch.randn(3, 4, 5) > 0),
+            ("rand_4d_bool_bool", torch.randn(3, 4, 5, 6) > 0, torch.randn(3, 4, 5, 6) > 0),
+            ("rand_2d_bool_single_bool", torch.randn(3, 4) > 0, torch.tensor(0) > 0),
+            ("rand_2d_int_bool", torch.randn(3, 4).to(torch.int), torch.randn(3, 4) > 0),
+            ("rand_2d_int_single_bool", torch.randn(3, 4).to(torch.int), torch.tensor(0) > 0),
+        ]
+    )
+    def test_logical_or(self, _, input, other):
+        class LogicalOr(torch.nn.Module):
+            def forward(self, x, y):
+                return torch.logical_or(x, y)
+
+        inputs = [
+            input,
+            other,
+        ]
+        self.run_test(LogicalOr(), inputs, expected_ops={acc_ops.logical_or}, test_implicit_batch_dim=False)
+
+class TestLogicalOrOperatorSimpleConverter(AccTestCase):
+    @parameterized.expand(
+        [
+            ("rand_2d_bool_bool", torch.randn(3, 4) > 0, torch.randn(3, 4) > 0),
+            ("rand_3d_bool_bool", torch.randn(3, 4, 5) > 0, torch.randn(3, 4, 5) > 0),
+            ("rand_4d_bool_bool", torch.randn(3, 4, 5, 6) > 0, torch.randn(3, 4, 5, 6) > 0),
+            ("rand_2d_bool_single_bool", torch.randn(3, 4) > 0, torch.tensor(0) > 0),
+            ("rand_2d_int_bool", torch.randn(3, 4).to(torch.int), torch.randn(3, 4) > 0),
+            ("rand_2d_int_single_bool", torch.randn(3, 4).to(torch.int), torch.tensor(0) > 0),
+        ]
+    )
+    def test_logical_or(self, _, input, other):
+        class LogicalOr(torch.nn.Module):
+            def forward(self, x, y):
+                return x | y
+
+        inputs = [
+            input,
+            other,
+        ]
+        self.run_test(LogicalOr(), inputs, expected_ops={acc_ops.logical_or}, test_implicit_batch_dim=False)
+
+
+if __name__ == '__main__':
+    run_tests()
test/converters/acc_op/test_logical_xor.py

Lines changed: 75 additions & 0 deletions

@@ -0,0 +1,75 @@
+import torch
+import fx2trt_oss.tracer.acc_tracer.acc_ops as acc_ops
+from torch.testing._internal.common_fx2trt import AccTestCase
+from torch.testing._internal.common_utils import run_tests
+from parameterized import parameterized
+
+class TestLogicalXorMethodSimpleConverter(AccTestCase):
+    @parameterized.expand(
+        [
+            ("rand_2d_bool_bool", torch.randn(3, 4) > 0, torch.randn(3, 4) > 0),
+            ("rand_3d_bool_bool", torch.randn(3, 4, 5) > 0, torch.randn(3, 4, 5) > 0),
+            ("rand_4d_bool_bool", torch.randn(3, 4, 5, 6) > 0, torch.randn(3, 4, 5, 6) > 0),
+            ("rand_2d_bool_single_bool", torch.randn(3, 4) > 0, torch.tensor(0) > 0),
+            ("rand_2d_int_bool", torch.randn(3, 4).to(torch.int), torch.randn(3, 4) > 0),
+            ("rand_2d_int_single_bool", torch.randn(3, 4).to(torch.int), torch.tensor(0) > 0),
+        ]
+    )
+    def test_logical_xor(self, _, input, other):
+        class LogicalXor(torch.nn.Module):
+            def forward(self, x, y):
+                return x.logical_xor(y)
+
+        inputs = [
+            input,
+            other,
+        ]
+        self.run_test(LogicalXor(), inputs, expected_ops={acc_ops.logical_xor}, test_implicit_batch_dim=False)
+
+class TestLogicalXorFunctionSimpleConverter(AccTestCase):
+    @parameterized.expand(
+        [
+            ("rand_2d_bool_bool", torch.randn(3, 4) > 0, torch.randn(3, 4) > 0),
+            ("rand_3d_bool_bool", torch.randn(3, 4, 5) > 0, torch.randn(3, 4, 5) > 0),
+            ("rand_4d_bool_bool", torch.randn(3, 4, 5, 6) > 0, torch.randn(3, 4, 5, 6) > 0),
+            ("rand_2d_bool_single_bool", torch.randn(3, 4) > 0, torch.tensor(0) > 0),
+            ("rand_2d_int_bool", torch.randn(3, 4).to(torch.int), torch.randn(3, 4) > 0),
+            ("rand_2d_int_single_bool", torch.randn(3, 4).to(torch.int), torch.tensor(0) > 0),
+        ]
+    )
+    def test_logical_xor(self, _, input, other):
+        class LogicalXor(torch.nn.Module):
+            def forward(self, x, y):
+                return torch.logical_xor(x, y)
+
+        inputs = [
+            input,
+            other,
+        ]
+        self.run_test(LogicalXor(), inputs, expected_ops={acc_ops.logical_xor}, test_implicit_batch_dim=False)
+
+class TestLogicalXorOperatorSimpleConverter(AccTestCase):
+    @parameterized.expand(
+        [
+            ("rand_2d_bool_bool", torch.randn(3, 4) > 0, torch.randn(3, 4) > 0),
+            ("rand_3d_bool_bool", torch.randn(3, 4, 5) > 0, torch.randn(3, 4, 5) > 0),
+            ("rand_4d_bool_bool", torch.randn(3, 4, 5, 6) > 0, torch.randn(3, 4, 5, 6) > 0),
+            ("rand_2d_bool_single_bool", torch.randn(3, 4) > 0, torch.tensor(0) > 0),
+            ("rand_2d_int_bool", torch.randn(3, 4).to(torch.int), torch.randn(3, 4) > 0),
+            ("rand_2d_int_single_bool", torch.randn(3, 4).to(torch.int), torch.tensor(0) > 0),
+        ]
+    )
+    def test_logical_xor(self, _, input, other):
+        class LogicalXor(torch.nn.Module):
+            def forward(self, x, y):
+                return x ^ y
+
+        inputs = [
+            input,
+            other,
+        ]
+        self.run_test(LogicalXor(), inputs, expected_ops={acc_ops.logical_xor}, test_implicit_batch_dim=False)
+
+
+if __name__ == '__main__':
+    run_tests()

test/tracer/test_acc_tracer.py

Lines changed: 2 additions & 0 deletions
@@ -2364,6 +2364,8 @@ def test_all_acc_ops_registered(self):
                 acc_ops.eq,
                 acc_ops.gt,
                 acc_ops.lt,
+                acc_ops.logical_or,
+                acc_ops.logical_xor,
                 acc_ops.gather,
                 acc_ops.index_select,
                 acc_ops.interpolate,

tracer/acc_tracer/acc_ops.py

Lines changed: 16 additions & 0 deletions
@@ -1237,6 +1237,22 @@ def gt(*, input, other):
 def lt(*, input, other):
     return torch.lt(input=input, other=other)
 
+@register_acc_op_properties(AccOpProperty.pointwise)
+@register_acc_op_mapping(op_and_target=("call_function", operator.or_))
+@register_acc_op_mapping(op_and_target=("call_function", torch.logical_or))
+@register_acc_op_mapping(op_and_target=("call_method", "logical_or"))
+@register_acc_op
+def logical_or(*, input, other):
+    return torch.logical_or(input=input, other=other)
+
+@register_acc_op_properties(AccOpProperty.pointwise)
+@register_acc_op_mapping(op_and_target=("call_function", operator.xor))
+@register_acc_op_mapping(op_and_target=("call_function", torch.logical_xor))
+@register_acc_op_mapping(op_and_target=("call_method", "logical_xor"))
+@register_acc_op
+def logical_xor(*, input, other):
+    return torch.logical_xor(input=input, other=other)
+
 
 @register_acc_op_properties(AccOpProperty.pointwise)
 @register_acc_op_mapping(op_and_target=("call_function", torch.fmod))
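
With these mappings, the operator form (| and ^), the torch function, and the tensor method all normalize to a single acc op, which is exactly what the three test classes per file exercise. For bool tensors the three spellings also agree in eager mode:

# The three spellings that trace to acc_ops.logical_or above agree in eager mode.
import torch

x = torch.randn(3, 4) > 0
y = torch.randn(3, 4) > 0
assert torch.equal(torch.logical_or(x, y), x.logical_or(y))  # function vs. method
assert torch.equal(torch.logical_or(x, y), x | y)            # function vs. operator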
