
Commit 7922534

[mlir][linalg] Add patterns to convert matmul to transposed variants (#89075)
This adds patterns to convert the Linalg matmul and batch_matmul ops to their transposed variants. By default the LHS matrix is transposed. Our work enabling a lowering path from linalg.matmul to ArmSME revealed that the current lowering results in non-contiguous memory accesses for the A matrix, and hence very poor performance. These patterns provide a simple option to fix this.
1 parent 9ba6961 commit 7922534
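
To make the rewrite concrete, here is a hand-written before/after sketch of the default (LHS) case; the SSA names and static shapes are illustrative, not output taken from the commit's tests:

// Before:
%0 = linalg.matmul ins(%A, %B : tensor<16x8xf32>, tensor<8x16xf32>)
                   outs(%C : tensor<16x16xf32>) -> tensor<16x16xf32>

// After (default, transposeLHS=true): the LHS is materialized transposed
// once, and the named transposed variant consumes it.
%empty = tensor.empty() : tensor<8x16xf32>
%At = linalg.transpose ins(%A : tensor<16x8xf32>)
                       outs(%empty : tensor<8x16xf32>) permutation = [1, 0]
%0 = linalg.matmul_transpose_a
       ins(%At, %B : tensor<8x16xf32>, tensor<8x16xf32>)
       outs(%C : tensor<16x16xf32>) -> tensor<16x16xf32>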

File tree: 9 files changed, +386 -0 lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformEnums.td

Lines changed: 9 additions & 0 deletions

@@ -8,3 +8,12 @@ def MatchInterfaceEnum : I32EnumAttr<"MatchInterfaceEnum", "An interface to matc
 ]>{
   let cppNamespace = "mlir::transform";
 }
+
+def TransposeMatmulInput : I32EnumAttr<"TransposeMatmulInput",
+    "Input to transpose when converting matmul ops to transposed variants",
+    [
+      I32EnumAttrCase<"lhs", 0>,
+      I32EnumAttrCase<"rhs", 1>,
+    ]>{
+  let cppNamespace = "mlir::transform";
+}

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 17 additions & 0 deletions

@@ -73,6 +73,23 @@ def ApplyTilingCanonicalizationPatternsOp : Op<Transform_Dialect,
   let assemblyFormat = "attr-dict";
 }
 
+def ApplyTransposeMatmulPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.linalg.transpose_matmul",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Collects patterns to convert Linalg matmul ops to transposed variants.
+
+    By default the LHS matrix is transposed. Set `inputToTranspose=<rhs>` to
+    instead transpose the RHS matrix.
+  }];
+
+  let arguments = (ins
+     DefaultValuedAttr<TransposeMatmulInput,
+                       "TransposeMatmulInput::lhs">:$inputToTranspose);
+
+  let assemblyFormat = "(`<` $inputToTranspose^ `>`)? attr-dict";
+}
+
 //===----------------------------------------------------------------------===//
 // BufferizeToAllocationOp
 //===----------------------------------------------------------------------===//
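
Concretely, the optional-group assembly format above admits both of the following forms (hand-written examples, not taken from the commit; the `<rhs>` variant is exercised in the tests at the end of this diff):

transform.apply_patterns.linalg.transpose_matmul
transform.apply_patterns.linalg.transpose_matmul <rhs>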

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 4 additions & 0 deletions

@@ -1616,6 +1616,10 @@ void populateSplitReductionPattern(
     const ControlSplitReductionFn &controlSplitReductionFn,
     bool useAlloc = false);
 
+/// Patterns to convert Linalg matmul ops to transposed variants.
+void populateTransposeMatmulPatterns(RewritePatternSet &patterns,
+                                     bool transposeLHS = true);
+
 } // namespace linalg
 } // namespace mlir
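
For orientation, a minimal sketch of how this entry point would typically be driven outside the transform dialect; the wrapper function here is a hypothetical example, not part of the commit, while `applyPatternsAndFoldGreedily` is MLIR's standard greedy rewrite driver:

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Hypothetical helper: rewrite all matmuls under `root` to transposed
// variants, transposing the LHS by default.
static mlir::LogicalResult transposeAllMatmuls(mlir::Operation *root,
                                               bool transposeLHS = true) {
  mlir::RewritePatternSet patterns(root->getContext());
  mlir::linalg::populateTransposeMatmulPatterns(patterns, transposeLHS);
  // Drive the patterns to a fixpoint over the whole region tree.
  return mlir::applyPatternsAndFoldGreedily(root, std::move(patterns));
}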

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 6 additions & 0 deletions

@@ -199,6 +199,12 @@ void transform::ApplyTilingCanonicalizationPatternsOp::populatePatterns(
   linalg::populateLinalgTilingCanonicalizationPatterns(patterns);
 }
 
+void transform::ApplyTransposeMatmulPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  bool transposeLHS = getInputToTranspose() == TransposeMatmulInput::lhs;
+  linalg::populateTransposeMatmulPatterns(patterns, transposeLHS);
+}
+
 //===----------------------------------------------------------------------===//
 // BufferizeToAllocationOp
 //===----------------------------------------------------------------------===//

mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions

@@ -22,6 +22,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   InlineScalarOperands.cpp
   Interchange.cpp
   Loops.cpp
+  TransposeMatmul.cpp
   MeshShardingInterfaceImpl.cpp
   NamedOpConversions.cpp
   Padding.cpp

mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp

Lines changed: 140 additions & 0 deletions

@@ -0,0 +1,140 @@
+//===- TransposeMatmul.cpp - Convert Linalg matmul to transposed variants -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+// This is intended to be a simple high-level (target-agnostic) matmul
+// transposition transformation.
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#define DEBUG_TYPE "linalg-transpose-matmul"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+namespace {
+/// Pattern to replace
+///
+///   linalg.matmul(a, b)
+///
+/// with
+///
+///   linalg.matmul_transpose_a(linalg.transpose(a), b)
+///
+/// By default the LHS is transposed. Set `transposeLHS=false` to
+/// transpose RHS instead.
+struct TransposeMatmul final : public OpRewritePattern<linalg::MatmulOp> {
+  TransposeMatmul(MLIRContext *ctx, bool transposeLHS)
+      : OpRewritePattern(ctx), transposeLHS(transposeLHS) {}
+
+  LogicalResult matchAndRewrite(linalg::MatmulOp matmulOp,
+                                PatternRewriter &rewriter) const override {
+    if (!bufferization::hasTensorSemantics(matmulOp))
+      return rewriter.notifyMatchFailure(
+          matmulOp, "only matmul ops with tensors are supported");
+
+    Location loc = matmulOp.getLoc();
+    Value input = matmulOp.getInputs()[transposeLHS ? 0 : 1];
+    auto type = cast<ShapedType>(input.getType());
+
+    SmallVector<Value> dynamicDims;
+    if (type.isDynamicDim(1))
+      dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 1));
+    if (type.isDynamicDim(0))
+      dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
+
+    ArrayRef<int64_t> shape = type.getShape();
+    Value empty = rewriter.create<tensor::EmptyOp>(
+        loc, ArrayRef<int64_t>{shape[1], shape[0]}, type.getElementType(),
+        dynamicDims);
+    auto transposeOp = rewriter.create<linalg::TransposeOp>(
+        loc, input, empty, ArrayRef<int64_t>{1, 0});
+    if (transposeLHS) {
+      rewriter.replaceOpWithNewOp<linalg::MatmulTransposeAOp>(
+          matmulOp, matmulOp.getResultTypes(),
+          ValueRange{transposeOp->getResult(0), matmulOp.getInputs()[1]},
+          matmulOp.getOutputs());
+    } else {
+      rewriter.replaceOpWithNewOp<linalg::MatmulTransposeBOp>(
+          matmulOp, matmulOp.getResultTypes(),
+          ValueRange{matmulOp.getInputs()[0], transposeOp->getResult(0)},
+          matmulOp.getOutputs());
+    }
+
+    return success();
+  }
+
+private:
+  bool transposeLHS;
+};
+
+/// Pattern to replace
+///
+///   linalg.batch_matmul(a, b)
+///
+/// with
+///
+///   linalg.batch_matmul_transpose_a(linalg.transpose(a), b)
+///
+/// Only the non-batch dimensions are transposed. By default the LHS is
+/// transposed. Set `transposeLHS=false` to transpose RHS instead.
+struct TransposeBatchMatmul final
+    : public OpRewritePattern<linalg::BatchMatmulOp> {
+  TransposeBatchMatmul(MLIRContext *ctx, bool transposeLHS)
+      : OpRewritePattern(ctx), transposeLHS(transposeLHS) {}
+
+  LogicalResult matchAndRewrite(linalg::BatchMatmulOp batchMatmulOp,
+                                PatternRewriter &rewriter) const override {
+    if (!bufferization::hasTensorSemantics(batchMatmulOp))
+      return rewriter.notifyMatchFailure(
+          batchMatmulOp, "only matmul ops with tensors are supported");
+
+    Location loc = batchMatmulOp.getLoc();
+    Value input = batchMatmulOp.getInputs()[transposeLHS ? 0 : 1];
+    auto type = cast<ShapedType>(input.getType());
+
+    SmallVector<Value> dynamicDims;
+    if (type.isDynamicDim(0))
+      dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
+    if (type.isDynamicDim(2))
+      dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 2));
+    if (type.isDynamicDim(1))
+      dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 1));
+
+    ArrayRef<int64_t> shape = type.getShape();
+    Value empty = rewriter.create<tensor::EmptyOp>(
+        loc, ArrayRef<int64_t>{shape[0], shape[2], shape[1]},
+        type.getElementType(), dynamicDims);
+    auto transposeOp = rewriter.create<linalg::TransposeOp>(
+        loc, input, empty, ArrayRef<int64_t>{0, 2, 1});
+    if (transposeLHS) {
+      rewriter.replaceOpWithNewOp<linalg::BatchMatmulTransposeAOp>(
+          batchMatmulOp, batchMatmulOp.getResultTypes(),
+          ValueRange{transposeOp->getResult(0), batchMatmulOp.getInputs()[1]},
+          batchMatmulOp.getOutputs());
+    } else {
+      rewriter.replaceOpWithNewOp<linalg::BatchMatmulTransposeBOp>(
+          batchMatmulOp, batchMatmulOp.getResultTypes(),
+          ValueRange{batchMatmulOp.getInputs()[0], transposeOp->getResult(0)},
+          batchMatmulOp.getOutputs());
+    }
+
+    return success();
+  }
+
+private:
+  bool transposeLHS;
+};
+} // namespace
+
+void mlir::linalg::populateTransposeMatmulPatterns(RewritePatternSet &patterns,
+                                                   bool transposeLHS) {
+  patterns.add<TransposeMatmul, TransposeBatchMatmul>(patterns.getContext(),
+                                                      transposeLHS);
+}
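
One subtlety worth noting in the patterns above: `tensor::EmptyOp` takes its dynamic sizes in the order the dynamic dimensions appear in the result shape, which is why the source dims are queried in the transposed order (dim 1 before dim 0 for matmul). A hand-written sketch of the dynamic matmul case (SSA names illustrative, not from the commit):

%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%d1 = tensor.dim %A, %c1 : tensor<?x?xf32>   // becomes dim 0 of the init
%d0 = tensor.dim %A, %c0 : tensor<?x?xf32>   // becomes dim 1 of the init
%empty = tensor.empty(%d1, %d0) : tensor<?x?xf32>
%At = linalg.transpose ins(%A : tensor<?x?xf32>)
                       outs(%empty : tensor<?x?xf32>) permutation = [1, 0]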

Lines changed: 15 additions & 0 deletions

@@ -0,0 +1,15 @@
+// RUN: mlir-opt %s
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %0 {
+      transform.apply_patterns.linalg.transpose_matmul
+    } : !transform.any_op
+    transform.apply_cse to %0 : !transform.any_op
+    transform.apply_patterns to %0 {
+      transform.apply_patterns.canonicalization
+    } : !transform.any_op
+    transform.yield
+  }
+}

Lines changed: 15 additions & 0 deletions

@@ -0,0 +1,15 @@
+// RUN: mlir-opt %s
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %0 {
+      transform.apply_patterns.linalg.transpose_matmul <rhs>
+    } : !transform.any_op
+    transform.apply_cse to %0 : !transform.any_op
+    transform.apply_patterns to %0 {
+      transform.apply_patterns.canonicalization
+    } : !transform.any_op
+    transform.yield
+  }
+}
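
The RUN lines above only parse the sequences; applying one to a payload module would go through the transform interpreter, roughly along the lines of the command below (the option spelling and file names are assumptions, verify against `mlir-opt --help`):

mlir-opt payload.mlir \
  -transform-preload-library='transform-library-paths=transpose-matmul-b.mlir' \
  -transform-interpreter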
