
Commit 159e77c

Shirong Wu authored and Wei Wei committed
Add baddbmm operator support (#14)
Summary: Pull Request resolved: https://github.com/pytorch/fx2trt/pull/14

Add support for torch.baddbmm. Add unit tests to cover bmm and addmm; both map to existing converters instead of requiring a new acc_op converter.

Reviewed By: yinghai

Differential Revision: D34743279

fbshipit-source-id: efd417b2b494635c63f2ec58daa3fc568f72111a
1 parent e488ad1 commit 159e77c
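
For context, the mapper added in this commit decomposes baddbmm(input, batch1, batch2, alpha=a, beta=b) into b * input + a * matmul(batch1, batch2). A quick numeric check of that identity in plain PyTorch (an illustration only, not part of the diff):

import torch

# baddbmm(input, m1, m2, alpha, beta) == beta * input + alpha * (m1 @ m2)
input, m1, m2 = torch.randn(10, 2, 2), torch.randn(10, 2, 3), torch.randn(10, 3, 2)
alpha, beta = 1.2, 1.1
torch.testing.assert_allclose(
    torch.baddbmm(input, m1, m2, alpha=alpha, beta=beta),
    torch.matmul(m1, m2) * alpha + input * beta,  # matmul -> mul -> mul -> add
)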

File tree

3 files changed: 171 additions & 1 deletion


test/converters/acc_op/test_mm_ops.py

Lines changed: 66 additions & 0 deletions
@@ -0,0 +1,66 @@
from torch.testing._internal.common_fx2trt import AccTestCase
from parameterized import parameterized
from torch import nn
import torch
from fx2trt_oss.tracer.acc_tracer import acc_ops


class TestBmmConverters(AccTestCase):
    @parameterized.expand(
        [
            ("1", (10, 2, 13), (10, 13, 2)),
        ]
    )
    def test_bmm(self, _, input_shape, other_shape, alpha=1, beta=1):
        class Bmm(nn.Module):
            def forward(self, input, other):
                return torch.bmm(input, other)

        inputs = [torch.randn(*input_shape), torch.randn(*other_shape)]

        self.run_test(
            Bmm(),
            inputs,
            expected_ops={acc_ops.matmul}
        )

    @parameterized.expand(
        [
            ("default", (2, 3), (2, 3), (3, 3)),
            ("broadcast", (1, 1), (10, 7), (7, 13)),
        ]
    )
    def test_addmm(self, _, input_shape, m1_shape, m2_shape, alpha=1, beta=1):
        class Addmm(nn.Module):
            def forward(self, input, m1, m2):
                return torch.addmm(input, m1, m2, alpha=alpha, beta=beta)

        inputs = [torch.randn(input_shape), torch.randn(*m1_shape), torch.randn(*m2_shape)]
        test_implicit_batch_dim = len(input_shape) > 2

        self.run_test(
            Addmm(),
            inputs,
            expected_ops={acc_ops.matmul, acc_ops.add},
            test_implicit_batch_dim=test_implicit_batch_dim,
        )

    @parameterized.expand(
        [
            ("default", (10, 2, 2), (10, 2, 3), (10, 3, 2)),
            ("broadcast", (10, 2, 1), (10, 2, 3), (10, 3, 2)),
        ]
    )
    def test_baddbmm(self, _, input_shape, m1_shape, m2_shape, alpha=1, beta=1):
        class Baddbmm(nn.Module):
            def forward(self, input, m1, m2):
                return torch.baddbmm(input, m1, m2, alpha=alpha, beta=beta)

        inputs = [torch.randn(*input_shape), torch.randn(*m1_shape), torch.randn(*m2_shape)]

        self.run_test(
            Baddbmm(),
            inputs,
            expected_ops={acc_ops.matmul, acc_ops.add}
        )
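
The "broadcast" baddbmm case above exercises elementwise broadcasting of the (10, 2, 1) input across the last dimension of the bmm result. A minimal sanity check of that expectation in plain PyTorch (an illustration only, not part of the diff):

import torch

input = torch.randn(10, 2, 1)  # rank 3; broadcast over the last dim only
m1, m2 = torch.randn(10, 2, 3), torch.randn(10, 3, 2)
torch.testing.assert_allclose(
    torch.baddbmm(input, m1, m2),  # alpha == beta == 1
    torch.bmm(m1, m2) + input,     # the decomposed form the mapper emits
)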

test/tracer/test_acc_tracer.py

Lines changed: 44 additions & 1 deletion
@@ -1314,6 +1314,49 @@ def test_bmm(self):
             acc_ops.matmul, lambda x: torch.bmm(x, x), input_shape=(2, 4, 4)
         )
 
+    def test_baddbmm_with_alpha_beta(self):
+        class TestModule(torch.nn.Module):
+            def forward(self, input: torch.Tensor, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
+                return torch.baddbmm(input, a, b, alpha=1.2, beta=1.1)
+
+        m = TestModule()
+        input, a, b = torch.randn(10, 2, 3), torch.randn(10, 2, 4), torch.randn(10, 4, 3)
+        traced = acc_tracer.trace(m, [input, a, b])
+
+        ph_in = ph_a = ph_b = mm = add = mm_mul = add_mul = None
+        for node in traced.graph.nodes:
+            if node.op == "placeholder":
+                if str(node.target) == "a":
+                    ph_a = node
+                elif str(node.target) == "b":
+                    ph_b = node
+                else:
+                    self.assertTrue(str(node.target) == "input")
+                    ph_in = node
+            elif node.op == "call_function":
+                if node.target == acc_ops.matmul:
+                    self.assertEqual(node.kwargs["input"], ph_a)
+                    self.assertEqual(node.kwargs["other"], ph_b)
+                    mm = node
+                elif node.target == acc_ops.add:
+                    self.assertEqual(node.kwargs["input"], mm_mul)
+                    self.assertEqual(node.kwargs["other"], add_mul)
+                    add = node
+                elif mm_mul:
+                    self.assertEqual(node.kwargs["input"], ph_in)
+                    self.assertEqual(node.kwargs["other"], 1.1)
+                    add_mul = node
+                else:
+                    self.assertEqual(node.kwargs["input"], mm)
+                    self.assertEqual(node.kwargs["other"], 1.2)
+                    mm_mul = node
+            elif node.op == "output":
+                self.assertEqual(add, node.args[0])
+            else:
+                self.fail(f"Unexpected node: {node.format_node()}")
+        torch.testing.assert_allclose(m(input, a, b), traced(input, a, b))
+
     def test_tile(self):
         return self._make_acc_op_function_test(
             acc_ops.tile, lambda x: torch.tile(x, (2, 1, 2)), input_shape=(1, 2)
@@ -1716,7 +1759,7 @@ def forward(
         input, a, b = torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2)
         traced = acc_tracer.trace(m, [input, a, b])
 
-        ph_in = ph_a = ph_b = mm = add = mm_mul = add_mul = None
+        ph_in = ph_a = ph_b = mm = add = mm_mul = add_mul = reshape = None
         for node in traced.graph.nodes:
             if node.op == "placeholder":
                 if str(node.target) == "a":
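
The assertions above walk the traced graph in order: acc_ops.matmul, the alpha-scaling mul, the beta-scaling mul on the bias input, then acc_ops.add. A hedged sketch to inspect that graph directly, assuming fx2trt_oss is installed and the acc_tracer import path matches this repo:

import torch
from fx2trt_oss.tracer.acc_tracer import acc_tracer  # import path assumed from this diff

class M(torch.nn.Module):
    def forward(self, input, a, b):
        return torch.baddbmm(input, a, b, alpha=1.2, beta=1.1)

inputs = [torch.randn(10, 2, 3), torch.randn(10, 2, 4), torch.randn(10, 4, 3)]
traced = acc_tracer.trace(M(), inputs)
traced.graph.print_tabular()  # expected rows: matmul, mul (x1.2), mul (x1.1), add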

tracer/acc_tracer/acc_ops.py

Lines changed: 61 additions & 0 deletions
@@ -361,6 +361,60 @@ def softmax(*, input, dim, dtype):
     return torch.nn.functional.softmax(input=input, dim=dim, dtype=dtype)
 
 
+@register_custom_acc_mapper_fn(
+    op_and_target=("call_function", torch.baddbmm),
+    arg_replacement_tuples=[
+        ("input", "input"),
+        ("batch1", "mat1"),
+        ("batch2", "mat2"),
+        ("beta", "beta"),
+        ("alpha", "alpha"),
+    ],
+)
+def baddbmm_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.Node:
+    """
+    Mapping from torch.baddbmm to acc_ops.matmul -> acc_ops.add; if alpha or
+    beta is not 1, we also insert acc_ops.mul in the right place.
+    """
+    with node.graph.inserting_before(node):
+        mm_kwargs = {"input": node.kwargs["mat1"], "other": node.kwargs["mat2"]}
+        mm_node = node.graph.create_node(
+            "call_function", matmul, kwargs=mm_kwargs, name=f"{node.name}_mm"
+        )
+        mm_node.meta = node.meta.copy()
+
+        if node.kwargs["alpha"] != 1:
+            mul_kwargs = {"input": mm_node, "other": node.kwargs["alpha"]}
+            mm_node = node.graph.create_node(
+                "call_function", mul, kwargs=mul_kwargs, name=f"{mm_node.name}_mul"
+            )
+            mm_node.meta = node.meta.copy()
+
+        input_node = node.kwargs["input"]
+        if node.kwargs["beta"] != 1:
+            mul_kwargs = {"input": input_node, "other": node.kwargs["beta"]}
+            new_input_node = node.graph.create_node(
+                "call_function", mul, kwargs=mul_kwargs, name=f"{node.name}_input_mul"
+            )
+            assert isinstance(input_node, torch.fx.Node)
+            new_input_node.meta = input_node.meta.copy()
+            input_node = new_input_node
+
+        # Broadcasting a lower-rank input would touch the batch dimension; reject it.
+        if input_node.meta["tensor_rank"] < mm_node.meta["tensor_rank"]:
+            rank = input_node.meta["tensor_rank"]
+            raise RuntimeError(
+                f"Unable to broadcast input with rank {rank} on the batch dimension."
+            )
+
+        add_kwargs = {"input": mm_node, "other": input_node}
+        add_node = node.graph.create_node(
+            "call_function", add, kwargs=add_kwargs, name=f"{node.name}_add"
+        )
+        add_node.meta = node.meta.copy()
+        return add_node
+
+
 @register_custom_acc_mapper_fn(
     op_and_target=("call_function", torch.addmm),
     arg_replacement_tuples=[
@@ -400,6 +454,13 @@ def addmm_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.Node:
             new_input_node.meta = input_node.meta.copy()
             input_node = new_input_node
 
+        # Broadcasting a lower-rank input would touch the batch dimension; reject it.
+        if input_node.meta["tensor_rank"] < mm_node.meta["tensor_rank"]:
+            rank = input_node.meta["tensor_rank"]
+            raise RuntimeError(
+                f"Unable to broadcast input with rank {rank} on the batch dimension."
+            )
+
         add_kwargs = {"input": mm_node, "other": input_node}
         add_node = node.graph.create_node(
             "call_function", add, kwargs=add_kwargs, name=f"{node.name}_add"
