fix: add an arg in matmul #2279


Merged: 2 commits, Sep 1, 2023.

This PR registers a converter for torch.ops.aten.mv.default, adds optional input_matrix_op / other_matrix_op arguments to the dynamo matrix_multiply implementation, passes source_ir through to set_layer_name, and reworks the matmul tests accordingly.
8 changes: 7 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -171,6 +171,7 @@ def aten_ops_gelu(
 
 @dynamo_tensorrt_converter(torch.ops.aten.matmul)  # type: ignore[misc]
 @dynamo_tensorrt_converter(torch.ops.aten.mm.default)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.mv.default)  # type: ignore[misc]
 def aten_ops_matmul(
     network: TRTNetwork,
     target: Target,
@@ -179,7 +180,12 @@ def aten_ops_matmul(
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
     return impl.matmul.matrix_multiply(
-        network, target, SourceIR.ATEN, name, args[0], args[1]
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
     )


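A quick illustration (not part of the diff) of why the new torch.ops.aten.mv.default registration matters: torch.matmul decomposes to aten.mv when the first operand is 2-D and the second is 1-D, which is exactly the case the updated tests below cover. A minimal sketch:

import torch

a = torch.randn(2, 3)
v = torch.randn(3)

# matmul of a (2, 3) matrix with a (3,) vector is the matrix-vector
# product, i.e. the op that torch.ops.aten.mv.default computes.
assert torch.allclose(torch.matmul(a, v), torch.ops.aten.mv.default(a, v))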
8 changes: 4 additions & 4 deletions py/torch_tensorrt/dynamo/conversion/impl/matmul.py
@@ -1,5 +1,6 @@
 from typing import Optional
 
+import tensorrt as trt
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.fx.converters.converter_utils import (
@@ -10,8 +11,6 @@
 from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
 from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter
 
-import tensorrt as trt
-
 
 def matrix_multiply(
     network: TRTNetwork,
@@ -20,6 +19,8 @@ def matrix_multiply(
     name: str,
     input: TRTTensor,
     other: TRTTensor,
+    input_matrix_op: trt.MatrixOperation = trt.MatrixOperation.NONE,
+    other_matrix_op: trt.MatrixOperation = trt.MatrixOperation.NONE,
 ) -> TRTTensor:
     if not isinstance(input, trt.tensorrt.ITensor):
         input = get_trt_tensor(network, input, f"{name}_input")
@@ -31,7 +32,6 @@
         dtype=unified_dtype_converter(input.dtype, Frameworks.TORCH),
     )
 
-    input_matrix_op = other_matrix_op = trt.MatrixOperation.NONE
     preset_diff = 0
 
     if len(input.shape) == 1:
@@ -46,5 +46,5 @@
         network, input, other, f"{name}_input", f"{name}_other", preset_diff
     )
     layer = network.add_matrix_multiply(input, input_matrix_op, other, other_matrix_op)
-    set_layer_name(layer, target, name)
+    set_layer_name(layer, target, name, source_ir)
     return layer.get_output(0)
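To see what the new keyword arguments enable, here is a hypothetical sketch (invented for illustration, not code from this PR; the import of impl is assumed to mirror aten_ops_converters.py): a converter that needs A @ B.T can ask TensorRT's matrix-multiply layer to transpose an operand instead of inserting a separate shuffle layer.

import tensorrt as trt
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion import impl

# Hypothetical converter body: computes args[0] @ args[1].T by letting
# the matrix-multiply layer transpose its second operand.
def mm_transposed_other(network, target, args, kwargs, name):
    return impl.matmul.matrix_multiply(
        network,
        target,
        SourceIR.ATEN,
        name,
        args[0],
        args[1],
        other_matrix_op=trt.MatrixOperation.TRANSPOSE,
    )

Leaving both operations at their trt.MatrixOperation.NONE defaults keeps the previous behavior, so existing callers are unaffected.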
79 changes: 58 additions & 21 deletions tests/py/dynamo/conversion/test_matmul_aten.py
@@ -9,18 +9,44 @@
 class TestMatMulConverter(DispatchTestCase):
     @parameterized.expand(
         [
-            ("2_2", (2, 3), (3, 2)),
-            ("2_2", (2, 3), (3, 1)),
-            # FIXME torch.ops.aten.mv.default for (2,3), (3,1) - should mv be lowered to mm?
-            # (2,3), (3,) torch.ops.aten.mv.default
-            # Following cases use torch.ops.aten.bmm.defauly
+            (
+                "2_2",
+                (2, 3),
+                (3, 2),
+            ),
+            (
+                "4_6",
+                (4, 5),
+                (5, 6),
+            ),
+            (
+                "2_1",
+                (2, 3),
+                (3, 1),
+            ),
+            (
+                "4_1",
+                (4, 1),
+                (1, 1),
+            ),
+            (
+                "1_2",
+                (1, 3),
+                (3, 2),
+            ),
+            (
+                "1_3",
+                (1, 2),
+                (2, 3),
+            ),
+            # Following cases use torch.ops.aten.bmm.default
             # ("4_3", (3,1,3,2), (2,2,3)),
             # ("3_4", (3,1,3,2), (2,2,3)),
             # ("3_4", (2, 2, 3), (3, 1, 3, 3)),
             # ("4_2", (1, 2, 2, 3), (3, 2)),
         ]
     )
-    def test_matmul_other_constant(self, _, input_shape, other_shape):
+    def test_matmul_mm(self, _, input_shape, other_shape):
         class MatMul(nn.Module):
             def __init__(self):
                 super().__init__()
@@ -39,32 +65,43 @@ def forward(self, input):
 
     @parameterized.expand(
         [
-            ("2_2", (2, 3), (3, 2)),
-            ("1_2", (1, 3), (3, 2)),
-            # FIXME torch.ops.aten.mv.default for (2,3), (3,1) - should mv be lowered to mm?
-            # (2,3), (3,) torch.ops.aten.mv.default
-            # Following cases use torch.ops.aten.bmm.defauly
-            # ("4_3", (3,1,3,2), (2,2,3)),
-            # ("3_4", (3,1,3,2), (2,2,3)),
-            # ("3_4", (2, 2, 3), (3, 1, 3, 3)),
-            # ("4_2", (1, 2, 2, 3), (3, 2)),
+            (
+                "1_1",
+                (1, 1),
+                (1,),
+            ),
+            (
+                "1_1",
+                (1, 2),
+                (2,),
+            ),
+            (
+                "2_1",
+                (2, 1),
+                (1,),
+            ),
+            (
+                "3_1",
+                (3, 4),
+                (4,),
+            ),
         ]
     )
-    def test_matmul_input_constant(self, _, input_shape, other_shape):
+    def test_matmul_mv(self, _, input_shape, other_shape):
         class MatMul(nn.Module):
             def __init__(self):
                 super().__init__()
-                self.input = nn.Parameter(torch.randn(*input_shape))
+                self.other = nn.Parameter(torch.randn(*other_shape))
 
-            def forward(self, other):
-                return torch.matmul(self.input, other)
+            def forward(self, input):
+                return torch.matmul(input, self.other)
 
-        inputs = [torch.randn(*other_shape)]
+        inputs = [torch.randn(*input_shape)]
 
         self.run_test(
             MatMul(),
             inputs,
-            expected_ops={torch.ops.aten.mm.default},
+            expected_ops={torch.ops.aten.mv.default},
         )
 
     @parameterized.expand(
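The removed FIXME asked whether mv should be lowered to mm. Mathematically it could be: A @ v equals an mm against v with a trailing unit dimension, squeezed away afterwards. The PR instead converts mv directly through the same matrix_multiply implementation, presumably relying on its existing 1-D handling (the len(input.shape) == 1 branch). A small check of the equivalence, for illustration only:

import torch

a = torch.randn(2, 3)
v = torch.randn(3)

# mv as an mm plus reshapes: A @ v == (A @ v[:, None]) with the unit
# dimension squeezed away; the converter avoids this by handling mv natively.
assert torch.allclose(
    torch.mv(a, v),
    torch.mm(a, v.unsqueeze(1)).squeeze(1),
)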