Skip to content

Commit 8dfb58e

Browse files
committed
add args input_matrix_op and other_matrix_op, support aten.mv.default
1 parent 9234306 commit 8dfb58e

File tree

2 files changed

+11
-2
lines changed

2 files changed

+11
-2
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,7 @@ def aten_ops_gelu(
171171

172172
@dynamo_tensorrt_converter(torch.ops.aten.matmul) # type: ignore[misc]
173173
@dynamo_tensorrt_converter(torch.ops.aten.mm.default) # type: ignore[misc]
174+
@dynamo_tensorrt_converter(torch.ops.aten.mv.default) # type: ignore[misc]
174175
def aten_ops_matmul(
175176
network: TRTNetwork,
176177
target: Target,
@@ -179,7 +180,14 @@ def aten_ops_matmul(
179180
name: str,
180181
) -> Union[TRTTensor, Sequence[TRTTensor]]:
181182
return impl.matmul.matrix_multiply(
182-
network, target, SourceIR.ATEN, name, args[0], args[1]
183+
network,
184+
target,
185+
SourceIR.ATEN,
186+
name,
187+
args[0],
188+
args[1],
189+
args_bounds_check(args, 2, trt.MatrixOperation.NONE),
190+
args_bounds_check(args, 3, trt.MatrixOperation.NONE),
183191
)
184192

185193

py/torch_tensorrt/dynamo/conversion/impl/matmul.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ def matrix_multiply(
1919
name: str,
2020
input: TRTTensor,
2121
other: TRTTensor,
22+
input_matrix_op: trt.MatrixOperation = trt.MatrixOperation.NONE,
23+
other_matrix_op: trt.MatrixOperation = trt.MatrixOperation.NONE,
2224
) -> TRTTensor:
2325
if not isinstance(input, trt.tensorrt.ITensor):
2426
input = get_trt_tensor(network, input, f"{name}_input")
@@ -30,7 +32,6 @@ def matrix_multiply(
3032
dtype=unified_dtype_converter(input.dtype, Frameworks.TORCH),
3133
)
3234

33-
input_matrix_op = other_matrix_op = trt.MatrixOperation.NONE
3435
preset_diff = 0
3536

3637
if len(input.shape) == 1:

0 commit comments

Comments
 (0)