Skip to content

Commit 8f1233e

Browse files
Shirong WuWei Wei
authored andcommitted
Back out "Add baddbmm operator support" (#26)
Summary: Pull Request resolved: https://github.com/pytorch/fx2trt/pull/26 Original commit changeset: efd417b2b494 Original Phabricator Diff: D34743279 (https://github.com/pytorch/fx2trt/commit/4f75c6f42213c841880de5358d32a9cee09beb4f) Reviewed By: frank-wei Differential Revision: D35002287 fbshipit-source-id: 7929aae49ef1daa0e116ba930f25a40c9b1388aa
1 parent 159e77c commit 8f1233e

File tree

3 files changed

+1
-171
lines changed

3 files changed

+1
-171
lines changed

test/converters/acc_op/test_mm_ops.py

Lines changed: 0 additions & 66 deletions
This file was deleted.

test/tracer/test_acc_tracer.py

Lines changed: 1 addition & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1314,49 +1314,6 @@ def test_bmm(self):
13141314
acc_ops.matmul, lambda x: torch.bmm(x, x), input_shape=(2, 4, 4)
13151315
)
13161316

1317-
def test_baddbmm_with_alpha_beta(self):
1318-
class TestModule(torch.nn.Module):
1319-
def forward(self, input: torch.Tensor, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
1320-
return torch.baddbmm(input, a, b, alpha = 1.2, beta = 1.1)
1321-
1322-
m = TestModule()
1323-
input, a, b = torch.randn(10, 2, 3), torch.randn(10, 2, 4), torch.randn(10, 4, 3)
1324-
traced = acc_tracer.trace(m, [input, a, b])
1325-
1326-
ph_in = ph_a = ph_b = mm = add = mm_mul = add_mul = None
1327-
for node in traced.graph.nodes:
1328-
if node.op == "placeholder":
1329-
if str(node.target) == "a":
1330-
ph_a = node
1331-
elif str(node.target) == "b":
1332-
ph_b = node
1333-
else:
1334-
self.assertTrue(str(node.target) == "input")
1335-
ph_in = node
1336-
elif node.op == "call_function":
1337-
if node.target == acc_ops.matmul:
1338-
self.assertEqual(node.kwargs["input"], ph_a)
1339-
self.assertEqual(node.kwargs["other"], ph_b)
1340-
mm = node
1341-
elif node.target == acc_ops.add:
1342-
self.assertEqual(node.kwargs["input"], mm_mul)
1343-
self.assertEqual(node.kwargs["other"], add_mul)
1344-
add = node
1345-
elif mm_mul:
1346-
self.assertEqual(node.kwargs["input"], ph_in)
1347-
self.assertEqual(node.kwargs["other"], 1.1)
1348-
add_mul = node
1349-
1350-
else:
1351-
self.assertEqual(node.kwargs["input"], mm)
1352-
self.assertEqual(node.kwargs["other"], 1.2)
1353-
mm_mul = node
1354-
elif node.op == "output":
1355-
self.assertEqual(add, node.args[0])
1356-
else:
1357-
self.fail(f"Unexpected node: {node.format_node()}")
1358-
torch.testing.assert_allclose(m(input, a, b), traced(input, a, b))
1359-
13601317
def test_tile(self):
13611318
return self._make_acc_op_function_test(
13621319
acc_ops.tile, lambda x: torch.tile(x, (2, 1, 2)), input_shape=(1, 2)
@@ -1759,7 +1716,7 @@ def forward(
17591716
input, a, b = torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2)
17601717
traced = acc_tracer.trace(m, [input, a, b])
17611718

1762-
ph_in = ph_a = ph_b = mm = add = mm_mul = add_mul = reshape = None
1719+
ph_in = ph_a = ph_b = mm = add = mm_mul = add_mul = None
17631720
for node in traced.graph.nodes:
17641721
if node.op == "placeholder":
17651722
if str(node.target) == "a":

tracer/acc_tracer/acc_ops.py

Lines changed: 0 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -361,60 +361,6 @@ def softmax(*, input, dim, dtype):
361361
return torch.nn.functional.softmax(input=input, dim=dim, dtype=dtype)
362362

363363

364-
@register_custom_acc_mapper_fn(
365-
op_and_target=("call_function", torch.baddbmm),
366-
arg_replacement_tuples=[
367-
("input", "input"),
368-
("batch1", "mat1"),
369-
("batch2", "mat2"),
370-
("beta", "beta"),
371-
("alpha", "alpha"),
372-
],
373-
)
374-
def baddbmm_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.Node:
375-
"""
376-
Mapping from torch.addmm to acc_ops.mm -> acc_ops.add, if alpha or beta is not 1
377-
then we also insert acc_ops.mul to the right place.
378-
"""
379-
with node.graph.inserting_before(node):
380-
mm_kwargs = {"input": node.kwargs["mat1"], "other": node.kwargs["mat2"]}
381-
mm_node = node.graph.create_node(
382-
"call_function", matmul, kwargs=mm_kwargs, name=f"{node.name}_mm"
383-
)
384-
mm_node.meta = node.meta.copy()
385-
386-
if node.kwargs["alpha"] != 1:
387-
mul_kwargs = {"input": mm_node, "other": node.kwargs["alpha"]}
388-
mm_node = node.graph.create_node(
389-
"call_function", mul, kwargs=mul_kwargs, name=f"{mm_node.name}_mul"
390-
)
391-
mm_node.meta = node.meta.copy()
392-
393-
input_node = node.kwargs["input"]
394-
if node.kwargs["beta"] != 1:
395-
mul_kwargs = {"input": input_node, "other": node.kwargs["beta"]}
396-
new_input_node = node.graph.create_node(
397-
"call_function", mul, kwargs=mul_kwargs, name=f"{node.name}_input_mul"
398-
)
399-
assert isinstance(input_node, torch.fx.Node)
400-
new_input_node.meta = input_node.meta.copy()
401-
input_node = new_input_node
402-
403-
# broadcast input to target shape
404-
if input_node.meta["tensor_rank"] < mm_node.meta["tensor_rank"]:
405-
rank = input_node.meta["tensor_rank"]
406-
raise RuntimeError(
407-
f"Unable to broadcast input with dimension {rank} on batch size dimension. "
408-
)
409-
410-
add_kwargs = {"input": mm_node, "other": input_node}
411-
add_node = node.graph.create_node(
412-
"call_function", add, kwargs=add_kwargs, name=f"{node.name}_add"
413-
)
414-
add_node.meta = node.meta.copy()
415-
return add_node
416-
417-
418364
@register_custom_acc_mapper_fn(
419365
op_and_target=("call_function", torch.addmm),
420366
arg_replacement_tuples=[
@@ -454,13 +400,6 @@ def addmm_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.Node:
454400
new_input_node.meta = input_node.meta.copy()
455401
input_node = new_input_node
456402

457-
# broadcast input to target shape
458-
if input_node.meta["tensor_rank"] < mm_node.meta["tensor_rank"]:
459-
rank = input_node.meta["tensor_rank"]
460-
raise RuntimeError(
461-
f"Unable to broadcast input with dimension {rank} on batch size dimension. "
462-
)
463-
464403
add_kwargs = {"input": mm_node, "other": input_node}
465404
add_node = node.graph.create_node(
466405
"call_function", add, kwargs=add_kwargs, name=f"{node.name}_add"

0 commit comments

Comments
 (0)