Skip to content

Commit be1c72d

Browse files
authored
[mlir][linalg] Move transpose_matmul to targeted transform op (#89717)
More targeted than a blanket "apply everywhere" pattern. Follow up to #89075 to address @ftynse's feedback.
1 parent 719112c commit be1c72d

File tree

6 files changed

+182
-106
lines changed

6 files changed

+182
-106
lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 46 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -73,23 +73,6 @@ def ApplyTilingCanonicalizationPatternsOp : Op<Transform_Dialect,
7373
let assemblyFormat = "attr-dict";
7474
}
7575

76-
def ApplyTransposeMatmulPatternsOp : Op<Transform_Dialect,
77-
"apply_patterns.linalg.transpose_matmul",
78-
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
79-
let description = [{
80-
Collects patterns to convert Linalg matmul ops to transposed variants.
81-
82-
By default the LHS matrix is transposed. Set `inputToTranspose=<rhs>` to
83-
instead transpose the RHS matrix.
84-
}];
85-
86-
let arguments = (ins
87-
DefaultValuedAttr<TransposeMatmulInput,
88-
"TransposeMatmulInput::lhs">:$inputToTranspose);
89-
90-
let assemblyFormat = "(`<` $inputToTranspose^ `>`)? attr-dict";
91-
}
92-
9376
//===----------------------------------------------------------------------===//
9477
// BufferizeToAllocationOp
9578
//===----------------------------------------------------------------------===//
@@ -2429,6 +2412,52 @@ def TransposeConv2DOp : Op<Transform_Dialect,
24292412
}];
24302413
}
24312414

2415+
//===----------------------------------------------------------------------===//
2416+
// TransposeMatmulOp
2417+
//===----------------------------------------------------------------------===//
2418+
2419+
def TransposeMatmulOp : Op<Transform_Dialect,
2420+
"structured.transpose_matmul",
2421+
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
2422+
TransformOpInterface, TransformEachOpTrait,
2423+
ReportTrackingListenerFailuresOpTrait]> {
2424+
let description = [{
2425+
Convert Linalg matmul ops to transposed variants.
2426+
2427+
By default the LHS matrix is transposed. Specify `<rhs>` to instead
2428+
transpose the RHS matrix.
2429+
2430+
#### Return modes:
2431+
2432+
This operation fails if `target` is unsupported, i.e., not a
2433+
`linalg.matmul` or `linalg.batch_matmul`. Otherwise, the operation succeeds
2434+
and returns a handle to the transposed matmul op.
2435+
}];
2436+
2437+
let arguments = (ins
2438+
TransformHandleTypeInterface:$target,
2439+
DefaultValuedAttr<TransposeMatmulInput,
2440+
"TransposeMatmulInput::lhs">:$inputToTranspose);
2441+
let results = (outs TransformHandleTypeInterface:$transformed);
2442+
2443+
let assemblyFormat = [{
2444+
$target (`<` $inputToTranspose^ `>`)?
2445+
attr-dict `:` functional-type($target, results)
2446+
}];
2447+
2448+
let builders = [
2449+
OpBuilder<(ins "Value":$target)>
2450+
];
2451+
2452+
let extraClassDeclaration = [{
2453+
::mlir::DiagnosedSilenceableFailure applyToOne(
2454+
::mlir::transform::TransformRewriter &rewriter,
2455+
::mlir::linalg::LinalgOp target,
2456+
::mlir::transform::ApplyToEachResultList &results,
2457+
::mlir::transform::TransformState &state);
2458+
}];
2459+
}
2460+
24322461
//===----------------------------------------------------------------------===//
24332462
// InsertSliceToCopyOp
24342463
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1244,6 +1244,14 @@ FailureOr<Operation *> transposeConv2D(RewriterBase &rewriter,
12441244
FailureOr<Operation *> transposeConv2D(RewriterBase &rewriter,
12451245
linalg::Conv2DNhwcFhwcQOp op);
12461246

1247+
/// Convert Linalg matmul ops to transposed variants.
1248+
FailureOr<Operation *> transposeMatmul(RewriterBase &rewriter,
1249+
linalg::MatmulOp op,
1250+
bool transposeLHS = true);
1251+
FailureOr<Operation *> transposeBatchMatmul(RewriterBase &rewriter,
1252+
linalg::BatchMatmulOp op,
1253+
bool transposeLHS = true);
1254+
12471255
//===----------------------------------------------------------------------===//
12481256
// Rewrite patterns wrapping transformations.
12491257
// TODO: every single such pattern should be a close to noop wrapper around a

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -199,12 +199,6 @@ void transform::ApplyTilingCanonicalizationPatternsOp::populatePatterns(
199199
linalg::populateLinalgTilingCanonicalizationPatterns(patterns);
200200
}
201201

202-
void transform::ApplyTransposeMatmulPatternsOp::populatePatterns(
203-
RewritePatternSet &patterns) {
204-
bool transposeLHS = getInputToTranspose() == TransposeMatmulInput::lhs;
205-
linalg::populateTransposeMatmulPatterns(patterns, transposeLHS);
206-
}
207-
208202
//===----------------------------------------------------------------------===//
209203
// BufferizeToAllocationOp
210204
//===----------------------------------------------------------------------===//
@@ -3422,6 +3416,32 @@ DiagnosedSilenceableFailure transform::TransposeConv2DOp::applyToOne(
34223416
return DiagnosedSilenceableFailure::success();
34233417
}
34243418

3419+
//===----------------------------------------------------------------------===//
3420+
// TransposeMatmulOp
3421+
//===----------------------------------------------------------------------===//
3422+
3423+
DiagnosedSilenceableFailure transform::TransposeMatmulOp::applyToOne(
3424+
transform::TransformRewriter &rewriter, linalg::LinalgOp target,
3425+
transform::ApplyToEachResultList &results,
3426+
transform::TransformState &state) {
3427+
rewriter.setInsertionPoint(target);
3428+
bool transposeLHS = getInputToTranspose() == TransposeMatmulInput::lhs;
3429+
auto maybeTransformed =
3430+
TypeSwitch<Operation *, FailureOr<Operation *>>(target)
3431+
.Case([&](linalg::MatmulOp op) {
3432+
return transposeMatmul(rewriter, op, transposeLHS);
3433+
})
3434+
.Case([&](linalg::BatchMatmulOp op) {
3435+
return transposeBatchMatmul(rewriter, op, transposeLHS);
3436+
})
3437+
.Default([&](Operation *op) { return failure(); });
3438+
if (failed(maybeTransformed))
3439+
return emitSilenceableFailure(target->getLoc()) << "not supported";
3440+
// Handle to the new matmul operation with the transposed input
3441+
results.push_back(*maybeTransformed);
3442+
return DiagnosedSilenceableFailure::success();
3443+
}
3444+
34253445
//===----------------------------------------------------------------------===//
34263446
// InsertSliceToCopyOp
34273447
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp

Lines changed: 98 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
using namespace mlir;
1919
using namespace mlir::linalg;
2020

21-
namespace {
2221
/// Pattern to replace
2322
///
2423
/// linalg.matmul(a, b)
@@ -29,102 +28,124 @@ namespace {
2928
///
3029
/// By default the LHS is transposed. Set `transposeLHS=false` to
3130
/// transpose RHS instead.
31+
FailureOr<Operation *> mlir::linalg::transposeMatmul(RewriterBase &rewriter,
32+
linalg::MatmulOp matmulOp,
33+
bool transposeLHS) {
34+
if (!bufferization::hasTensorSemantics(matmulOp))
35+
return rewriter.notifyMatchFailure(
36+
matmulOp, "only matmul ops with tensors are supported");
37+
38+
Location loc = matmulOp.getLoc();
39+
Value input = matmulOp.getInputs()[transposeLHS ? 0 : 1];
40+
auto type = cast<ShapedType>(input.getType());
41+
42+
SmallVector<Value> dynamicDims;
43+
if (type.isDynamicDim(1))
44+
dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 1));
45+
if (type.isDynamicDim(0))
46+
dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
47+
48+
ArrayRef<int64_t> shape = type.getShape();
49+
Value empty = rewriter.create<tensor::EmptyOp>(
50+
loc, ArrayRef<int64_t>{shape[1], shape[0]}, type.getElementType(),
51+
dynamicDims);
52+
auto transposeOp = rewriter.create<linalg::TransposeOp>(
53+
loc, input, empty, ArrayRef<int64_t>{1, 0});
54+
Operation *newMatmulOp;
55+
if (transposeLHS) {
56+
newMatmulOp = rewriter.create<linalg::MatmulTransposeAOp>(
57+
loc, matmulOp.getResultTypes(),
58+
ValueRange{transposeOp->getResult(0), matmulOp.getInputs()[1]},
59+
matmulOp.getOutputs());
60+
} else {
61+
newMatmulOp = rewriter.create<linalg::MatmulTransposeBOp>(
62+
loc, matmulOp.getResultTypes(),
63+
ValueRange{matmulOp.getInputs()[0], transposeOp->getResult(0)},
64+
matmulOp.getOutputs());
65+
}
66+
rewriter.replaceOp(matmulOp, newMatmulOp);
67+
return newMatmulOp;
68+
}
69+
70+
/// Pattern to replace
71+
///
72+
/// linalg.batch_matmul(a, b)
73+
///
74+
/// with
75+
///
76+
/// linalg.batch_matmul_transpose_a(linalg.transpose(a), b)
77+
///
78+
/// Only the non-batch dimensions are transposed. By default the LHS is
79+
/// transposed. Set `transposeLHS=false` to transpose RHS instead.
80+
FailureOr<Operation *>
81+
mlir::linalg::transposeBatchMatmul(RewriterBase &rewriter,
82+
linalg::BatchMatmulOp batchMatmulOp,
83+
bool transposeLHS) {
84+
if (!bufferization::hasTensorSemantics(batchMatmulOp))
85+
return rewriter.notifyMatchFailure(
86+
batchMatmulOp, "only matmul ops with tensors are supported");
87+
88+
Location loc = batchMatmulOp.getLoc();
89+
Value input = batchMatmulOp.getInputs()[transposeLHS ? 0 : 1];
90+
auto type = cast<ShapedType>(input.getType());
91+
92+
SmallVector<Value> dynamicDims;
93+
if (type.isDynamicDim(0))
94+
dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
95+
if (type.isDynamicDim(2))
96+
dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 2));
97+
if (type.isDynamicDim(1))
98+
dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 1));
99+
100+
ArrayRef<int64_t> shape = type.getShape();
101+
Value empty = rewriter.create<tensor::EmptyOp>(
102+
loc, ArrayRef<int64_t>{shape[0], shape[2], shape[1]},
103+
type.getElementType(), dynamicDims);
104+
auto transposeOp = rewriter.create<linalg::TransposeOp>(
105+
loc, input, empty, ArrayRef<int64_t>{0, 2, 1});
106+
Operation *newMatmulOp;
107+
if (transposeLHS) {
108+
newMatmulOp = rewriter.create<linalg::BatchMatmulTransposeAOp>(
109+
loc, batchMatmulOp.getResultTypes(),
110+
ValueRange{transposeOp->getResult(0), batchMatmulOp.getInputs()[1]},
111+
batchMatmulOp.getOutputs());
112+
} else {
113+
newMatmulOp = rewriter.create<linalg::BatchMatmulTransposeBOp>(
114+
loc, batchMatmulOp.getResultTypes(),
115+
ValueRange{batchMatmulOp.getInputs()[0], transposeOp->getResult(0)},
116+
batchMatmulOp.getOutputs());
117+
}
118+
rewriter.replaceOp(batchMatmulOp, newMatmulOp);
119+
return newMatmulOp;
120+
}
121+
122+
namespace {
32123
struct TransposeMatmul final : public OpRewritePattern<linalg::MatmulOp> {
33124
TransposeMatmul(MLIRContext *ctx, bool transposeLHS)
34125
: OpRewritePattern(ctx), transposeLHS(transposeLHS) {}
35126

36-
LogicalResult matchAndRewrite(linalg::MatmulOp matmulOp,
127+
LogicalResult matchAndRewrite(linalg::MatmulOp op,
37128
PatternRewriter &rewriter) const override {
38-
if (!bufferization::hasTensorSemantics(matmulOp))
39-
return rewriter.notifyMatchFailure(
40-
matmulOp, "only matmul ops with tensors are supported");
41-
42-
Location loc = matmulOp.getLoc();
43-
Value input = matmulOp.getInputs()[transposeLHS ? 0 : 1];
44-
auto type = cast<ShapedType>(input.getType());
45-
46-
SmallVector<Value> dynamicDims;
47-
if (type.isDynamicDim(1))
48-
dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 1));
49-
if (type.isDynamicDim(0))
50-
dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
51-
52-
ArrayRef<int64_t> shape = type.getShape();
53-
Value empty = rewriter.create<tensor::EmptyOp>(
54-
loc, ArrayRef<int64_t>{shape[1], shape[0]}, type.getElementType(),
55-
dynamicDims);
56-
auto transposeOp = rewriter.create<linalg::TransposeOp>(
57-
loc, input, empty, ArrayRef<int64_t>{1, 0});
58-
if (transposeLHS) {
59-
rewriter.replaceOpWithNewOp<linalg::MatmulTransposeAOp>(
60-
matmulOp, matmulOp.getResultTypes(),
61-
ValueRange{transposeOp->getResult(0), matmulOp.getInputs()[1]},
62-
matmulOp.getOutputs());
63-
} else {
64-
rewriter.replaceOpWithNewOp<linalg::MatmulTransposeBOp>(
65-
matmulOp, matmulOp.getResultTypes(),
66-
ValueRange{matmulOp.getInputs()[0], transposeOp->getResult(0)},
67-
matmulOp.getOutputs());
129+
if (failed(transposeMatmul(rewriter, op, transposeLHS))) {
130+
return failure();
68131
}
69-
70132
return success();
71133
}
72134

73135
private:
74136
bool transposeLHS;
75137
};
76138

77-
/// Pattern to replace
78-
///
79-
/// linalg.batch_matmul(a, b)
80-
///
81-
/// with
82-
///
83-
/// linalg.batch_matmul_transpose_a(linalg.transpose(a), b)
84-
///
85-
/// Only the non-batch dimensions are transposed. By default the LHS is
86-
/// transposed. Set `transposeLHS=false` to transpose RHS instead.
87139
struct TransposeBatchMatmul final
88140
: public OpRewritePattern<linalg::BatchMatmulOp> {
89141
TransposeBatchMatmul(MLIRContext *ctx, bool transposeLHS)
90142
: OpRewritePattern(ctx), transposeLHS(transposeLHS) {}
91143

92-
LogicalResult matchAndRewrite(linalg::BatchMatmulOp batchMatmulOp,
144+
LogicalResult matchAndRewrite(linalg::BatchMatmulOp op,
93145
PatternRewriter &rewriter) const override {
94-
if (!bufferization::hasTensorSemantics(batchMatmulOp))
95-
return rewriter.notifyMatchFailure(
96-
batchMatmulOp, "only matmul ops with tensors are supported");
97-
98-
Location loc = batchMatmulOp.getLoc();
99-
Value input = batchMatmulOp.getInputs()[transposeLHS ? 0 : 1];
100-
auto type = cast<ShapedType>(input.getType());
101-
102-
SmallVector<Value> dynamicDims;
103-
if (type.isDynamicDim(0))
104-
dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
105-
if (type.isDynamicDim(2))
106-
dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 2));
107-
if (type.isDynamicDim(1))
108-
dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 1));
109-
110-
ArrayRef<int64_t> shape = type.getShape();
111-
Value empty = rewriter.create<tensor::EmptyOp>(
112-
loc, ArrayRef<int64_t>{shape[0], shape[2], shape[1]},
113-
type.getElementType(), dynamicDims);
114-
auto transposeOp = rewriter.create<linalg::TransposeOp>(
115-
loc, input, empty, ArrayRef<int64_t>{0, 2, 1});
116-
if (transposeLHS) {
117-
rewriter.replaceOpWithNewOp<linalg::BatchMatmulTransposeAOp>(
118-
batchMatmulOp, batchMatmulOp.getResultTypes(),
119-
ValueRange{transposeOp->getResult(0), batchMatmulOp.getInputs()[1]},
120-
batchMatmulOp.getOutputs());
121-
} else {
122-
rewriter.replaceOpWithNewOp<linalg::BatchMatmulTransposeBOp>(
123-
batchMatmulOp, batchMatmulOp.getResultTypes(),
124-
ValueRange{batchMatmulOp.getInputs()[0], transposeOp->getResult(0)},
125-
batchMatmulOp.getOutputs());
146+
if (failed(transposeBatchMatmul(rewriter, op, transposeLHS))) {
147+
return failure();
126148
}
127-
128149
return success();
129150
}
130151

mlir/test/Dialect/Linalg/transpose-matmul-a.mlir

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,9 @@
22

33
module attributes {transform.with_named_sequence} {
44
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
5+
%matmul = transform.structured.match ops{["linalg.matmul", "linalg.batch_matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
6+
transform.structured.transpose_matmul %matmul : (!transform.any_op) -> (!transform.any_op)
57
%0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
6-
transform.apply_patterns to %0 {
7-
transform.apply_patterns.linalg.transpose_matmul
8-
} : !transform.any_op
98
transform.apply_cse to %0 : !transform.any_op
109
transform.apply_patterns to %0 {
1110
transform.apply_patterns.canonicalization

mlir/test/Dialect/Linalg/transpose-matmul-b.mlir

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,9 @@
22

33
module attributes {transform.with_named_sequence} {
44
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
5+
%matmul = transform.structured.match ops{["linalg.matmul", "linalg.batch_matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
6+
transform.structured.transpose_matmul %matmul <rhs> : (!transform.any_op) -> (!transform.any_op)
57
%0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
6-
transform.apply_patterns to %0 {
7-
transform.apply_patterns.linalg.transpose_matmul <rhs>
8-
} : !transform.any_op
98
transform.apply_cse to %0 : !transform.any_op
109
transform.apply_patterns to %0 {
1110
transform.apply_patterns.canonicalization

0 commit comments

Comments
 (0)