18
18
using namespace mlir ;
19
19
using namespace mlir ::linalg;
20
20
21
- namespace {
22
21
// / Pattern to replace
23
22
// /
24
23
// / linalg.matmul(a, b)
@@ -29,102 +28,124 @@ namespace {
29
28
// /
30
29
// / By default the LHS is transposed. Set `transposeLHS=false` to
31
30
// / transpose RHS instead.
31
+ FailureOr<Operation *> mlir::linalg::transposeMatmul (RewriterBase &rewriter,
32
+ linalg::MatmulOp matmulOp,
33
+ bool transposeLHS) {
34
+ if (!bufferization::hasTensorSemantics (matmulOp))
35
+ return rewriter.notifyMatchFailure (
36
+ matmulOp, " only matmul ops with tensors are supported" );
37
+
38
+ Location loc = matmulOp.getLoc ();
39
+ Value input = matmulOp.getInputs ()[transposeLHS ? 0 : 1 ];
40
+ auto type = cast<ShapedType>(input.getType ());
41
+
42
+ SmallVector<Value> dynamicDims;
43
+ if (type.isDynamicDim (1 ))
44
+ dynamicDims.push_back (rewriter.create <tensor::DimOp>(loc, input, 1 ));
45
+ if (type.isDynamicDim (0 ))
46
+ dynamicDims.push_back (rewriter.create <tensor::DimOp>(loc, input, 0 ));
47
+
48
+ ArrayRef<int64_t > shape = type.getShape ();
49
+ Value empty = rewriter.create <tensor::EmptyOp>(
50
+ loc, ArrayRef<int64_t >{shape[1 ], shape[0 ]}, type.getElementType (),
51
+ dynamicDims);
52
+ auto transposeOp = rewriter.create <linalg::TransposeOp>(
53
+ loc, input, empty, ArrayRef<int64_t >{1 , 0 });
54
+ Operation *newMatmulOp;
55
+ if (transposeLHS) {
56
+ newMatmulOp = rewriter.create <linalg::MatmulTransposeAOp>(
57
+ loc, matmulOp.getResultTypes (),
58
+ ValueRange{transposeOp->getResult (0 ), matmulOp.getInputs ()[1 ]},
59
+ matmulOp.getOutputs ());
60
+ } else {
61
+ newMatmulOp = rewriter.create <linalg::MatmulTransposeBOp>(
62
+ loc, matmulOp.getResultTypes (),
63
+ ValueRange{matmulOp.getInputs ()[0 ], transposeOp->getResult (0 )},
64
+ matmulOp.getOutputs ());
65
+ }
66
+ rewriter.replaceOp (matmulOp, newMatmulOp);
67
+ return newMatmulOp;
68
+ }
69
+
70
+ // / Pattern to replace
71
+ // /
72
+ // / linalg.batch_matmul(a, b)
73
+ // /
74
+ // / with
75
+ // /
76
+ // / linalg.batch_matmul_transpose_a(linalg.transpose(a), b)
77
+ // /
78
+ // / Only the non-batch dimensions are transposed. By default the LHS is
79
+ // / transposed. Set `transposeLHS=false` to transpose RHS instead.
80
+ FailureOr<Operation *>
81
+ mlir::linalg::transposeBatchMatmul (RewriterBase &rewriter,
82
+ linalg::BatchMatmulOp batchMatmulOp,
83
+ bool transposeLHS) {
84
+ if (!bufferization::hasTensorSemantics (batchMatmulOp))
85
+ return rewriter.notifyMatchFailure (
86
+ batchMatmulOp, " only matmul ops with tensors are supported" );
87
+
88
+ Location loc = batchMatmulOp.getLoc ();
89
+ Value input = batchMatmulOp.getInputs ()[transposeLHS ? 0 : 1 ];
90
+ auto type = cast<ShapedType>(input.getType ());
91
+
92
+ SmallVector<Value> dynamicDims;
93
+ if (type.isDynamicDim (0 ))
94
+ dynamicDims.push_back (rewriter.create <tensor::DimOp>(loc, input, 0 ));
95
+ if (type.isDynamicDim (2 ))
96
+ dynamicDims.push_back (rewriter.create <tensor::DimOp>(loc, input, 2 ));
97
+ if (type.isDynamicDim (1 ))
98
+ dynamicDims.push_back (rewriter.create <tensor::DimOp>(loc, input, 1 ));
99
+
100
+ ArrayRef<int64_t > shape = type.getShape ();
101
+ Value empty = rewriter.create <tensor::EmptyOp>(
102
+ loc, ArrayRef<int64_t >{shape[0 ], shape[2 ], shape[1 ]},
103
+ type.getElementType (), dynamicDims);
104
+ auto transposeOp = rewriter.create <linalg::TransposeOp>(
105
+ loc, input, empty, ArrayRef<int64_t >{0 , 2 , 1 });
106
+ Operation *newMatmulOp;
107
+ if (transposeLHS) {
108
+ newMatmulOp = rewriter.create <linalg::BatchMatmulTransposeAOp>(
109
+ loc, batchMatmulOp.getResultTypes (),
110
+ ValueRange{transposeOp->getResult (0 ), batchMatmulOp.getInputs ()[1 ]},
111
+ batchMatmulOp.getOutputs ());
112
+ } else {
113
+ newMatmulOp = rewriter.create <linalg::BatchMatmulTransposeBOp>(
114
+ loc, batchMatmulOp.getResultTypes (),
115
+ ValueRange{batchMatmulOp.getInputs ()[0 ], transposeOp->getResult (0 )},
116
+ batchMatmulOp.getOutputs ());
117
+ }
118
+ rewriter.replaceOp (batchMatmulOp, newMatmulOp);
119
+ return newMatmulOp;
120
+ }
121
+
122
+ namespace {
32
123
struct TransposeMatmul final : public OpRewritePattern<linalg::MatmulOp> {
33
124
TransposeMatmul (MLIRContext *ctx, bool transposeLHS)
34
125
: OpRewritePattern(ctx), transposeLHS(transposeLHS) {}
35
126
36
- LogicalResult matchAndRewrite (linalg::MatmulOp matmulOp ,
127
+ LogicalResult matchAndRewrite (linalg::MatmulOp op ,
37
128
PatternRewriter &rewriter) const override {
38
- if (!bufferization::hasTensorSemantics (matmulOp))
39
- return rewriter.notifyMatchFailure (
40
- matmulOp, " only matmul ops with tensors are supported" );
41
-
42
- Location loc = matmulOp.getLoc ();
43
- Value input = matmulOp.getInputs ()[transposeLHS ? 0 : 1 ];
44
- auto type = cast<ShapedType>(input.getType ());
45
-
46
- SmallVector<Value> dynamicDims;
47
- if (type.isDynamicDim (1 ))
48
- dynamicDims.push_back (rewriter.create <tensor::DimOp>(loc, input, 1 ));
49
- if (type.isDynamicDim (0 ))
50
- dynamicDims.push_back (rewriter.create <tensor::DimOp>(loc, input, 0 ));
51
-
52
- ArrayRef<int64_t > shape = type.getShape ();
53
- Value empty = rewriter.create <tensor::EmptyOp>(
54
- loc, ArrayRef<int64_t >{shape[1 ], shape[0 ]}, type.getElementType (),
55
- dynamicDims);
56
- auto transposeOp = rewriter.create <linalg::TransposeOp>(
57
- loc, input, empty, ArrayRef<int64_t >{1 , 0 });
58
- if (transposeLHS) {
59
- rewriter.replaceOpWithNewOp <linalg::MatmulTransposeAOp>(
60
- matmulOp, matmulOp.getResultTypes (),
61
- ValueRange{transposeOp->getResult (0 ), matmulOp.getInputs ()[1 ]},
62
- matmulOp.getOutputs ());
63
- } else {
64
- rewriter.replaceOpWithNewOp <linalg::MatmulTransposeBOp>(
65
- matmulOp, matmulOp.getResultTypes (),
66
- ValueRange{matmulOp.getInputs ()[0 ], transposeOp->getResult (0 )},
67
- matmulOp.getOutputs ());
129
+ if (failed (transposeMatmul (rewriter, op, transposeLHS))) {
130
+ return failure ();
68
131
}
69
-
70
132
return success ();
71
133
}
72
134
73
135
private:
74
136
bool transposeLHS;
75
137
};
76
138
77
- // / Pattern to replace
78
- // /
79
- // / linalg.batch_matmul(a, b)
80
- // /
81
- // / with
82
- // /
83
- // / linalg.batch_matmul_transpose_a(linalg.transpose(a), b)
84
- // /
85
- // / Only the non-batch dimensions are transposed. By default the LHS is
86
- // / transposed. Set `transposeLHS=false` to transpose RHS instead.
87
139
struct TransposeBatchMatmul final
88
140
: public OpRewritePattern<linalg::BatchMatmulOp> {
89
141
TransposeBatchMatmul (MLIRContext *ctx, bool transposeLHS)
90
142
: OpRewritePattern(ctx), transposeLHS(transposeLHS) {}
91
143
92
- LogicalResult matchAndRewrite (linalg::BatchMatmulOp batchMatmulOp ,
144
+ LogicalResult matchAndRewrite (linalg::BatchMatmulOp op ,
93
145
PatternRewriter &rewriter) const override {
94
- if (!bufferization::hasTensorSemantics (batchMatmulOp))
95
- return rewriter.notifyMatchFailure (
96
- batchMatmulOp, " only matmul ops with tensors are supported" );
97
-
98
- Location loc = batchMatmulOp.getLoc ();
99
- Value input = batchMatmulOp.getInputs ()[transposeLHS ? 0 : 1 ];
100
- auto type = cast<ShapedType>(input.getType ());
101
-
102
- SmallVector<Value> dynamicDims;
103
- if (type.isDynamicDim (0 ))
104
- dynamicDims.push_back (rewriter.create <tensor::DimOp>(loc, input, 0 ));
105
- if (type.isDynamicDim (2 ))
106
- dynamicDims.push_back (rewriter.create <tensor::DimOp>(loc, input, 2 ));
107
- if (type.isDynamicDim (1 ))
108
- dynamicDims.push_back (rewriter.create <tensor::DimOp>(loc, input, 1 ));
109
-
110
- ArrayRef<int64_t > shape = type.getShape ();
111
- Value empty = rewriter.create <tensor::EmptyOp>(
112
- loc, ArrayRef<int64_t >{shape[0 ], shape[2 ], shape[1 ]},
113
- type.getElementType (), dynamicDims);
114
- auto transposeOp = rewriter.create <linalg::TransposeOp>(
115
- loc, input, empty, ArrayRef<int64_t >{0 , 2 , 1 });
116
- if (transposeLHS) {
117
- rewriter.replaceOpWithNewOp <linalg::BatchMatmulTransposeAOp>(
118
- batchMatmulOp, batchMatmulOp.getResultTypes (),
119
- ValueRange{transposeOp->getResult (0 ), batchMatmulOp.getInputs ()[1 ]},
120
- batchMatmulOp.getOutputs ());
121
- } else {
122
- rewriter.replaceOpWithNewOp <linalg::BatchMatmulTransposeBOp>(
123
- batchMatmulOp, batchMatmulOp.getResultTypes (),
124
- ValueRange{batchMatmulOp.getInputs ()[0 ], transposeOp->getResult (0 )},
125
- batchMatmulOp.getOutputs ());
146
+ if (failed (transposeBatchMatmul (rewriter, op, transposeLHS))) {
147
+ return failure ();
126
148
}
127
-
128
149
return success ();
129
150
}
130
151
0 commit comments