@@ -30,23 +30,26 @@ using namespace mlir;
30
30
using namespace mlir ::arm_sve;
31
31
32
32
namespace {
33
- // Get the LHS or RHS side operand of a vector contract. Handle two cases
34
- // * if the operand is a sign- or zero- extend operation of type `T` from i8
35
- // to i32, return the value before the extension, otherwise
36
- // * if the operand is of i8 type and the operation is sign-extend, return the
37
- // operand itself.
33
+ // Get the operand of a `vector.contract`. This function is intended to abstract
34
+ // away from the particular way a value is extended before feeding it into the
35
+ // `vector.contract` - via zero-extend or an explicit or implicit sign-extend
36
+ // (for implicit sign-extension see `vector.contract` documentation).
38
37
//
39
- // This way we handle both explicit sign- or zero- extension or implicit
40
- // sign-extension.
41
- template <typename T>
38
+ // The template parameter `Op` indicates the extension operation (explicit or
39
+ // implicit) for which we are checking.
40
+ //
41
+ // Return success only for extensions from `i8` to `i32`.
42
+ template <typename Op>
42
43
std::optional<Value> getExtOperand (Value v, Type i8Ty, Type i32Ty) {
43
44
44
- static_assert (llvm::is_one_of<T , arith::ExtSIOp, arith::ExtUIOp>::value,
45
+ static_assert (llvm::is_one_of<Op , arith::ExtSIOp, arith::ExtUIOp>::value,
45
46
" Must be instantiated with either sign- or zero- extension op" );
46
47
47
- auto extOp = dyn_cast_or_null<T>(v.getDefiningOp ());
48
+ // If the operand is not defined by an explicit extend operation of the
49
+ // accepted operation type allow for an implicit sign-extension.
50
+ auto extOp = dyn_cast_or_null<Op>(v.getDefiningOp ());
48
51
if (!extOp) {
49
- if constexpr (std::is_same<T , arith::ExtSIOp>::value) {
52
+ if constexpr (std::is_same<Op , arith::ExtSIOp>::value) {
50
53
auto vTy = cast<VectorType>(v.getType ());
51
54
if (vTy.getElementType () != i8Ty)
52
55
return {};
@@ -55,6 +58,8 @@ std::optional<Value> getExtOperand(Value v, Type i8Ty, Type i32Ty) {
55
58
return {};
56
59
}
57
60
61
+ // If the operand is defined by an explicit extend operation of the accepted
62
+ // operation type, check that it is extended from `i8` to `i32`.
58
63
auto inOp = extOp.getIn ();
59
64
auto inTy = dyn_cast<VectorType>(inOp.getType ());
60
65
if (!inTy || inTy.getElementType () != i8Ty)
@@ -93,37 +98,38 @@ Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc,
93
98
}
94
99
}
95
100
96
- // Lower a contraction operation that performs a matrix multiplication
97
- // of two 8-bit integer matrix tiles with logical dimensions <Mx8> and <8x[N]>
98
- // for the left-hand side and the right-hand side, respectively,
99
- // yielding a <Mx[N]> 32-bit integer result.
100
- //
101
- // The operands shapes are such that the operands can be evenly split into
102
- // sub-tiles with dimensions as expected by the targeted FEAT_I8MM instructions.
103
- // The intent is that M and N are chosen (by higher level transforms) in such a
104
- // way as to maximise register usage. The main use case we envision as of now is
105
- // MMT4D, thus the RHS operand is expected pre-transposed.
106
- //
107
- // The matrix multiplication is performed by unrolling the usual tiled matrix
108
- // multiplication algorithm using sub-tiles with dimensions <2x8> for the LHS,
109
- // <8x[2]> for the RHS, and <2x[2]> for the result and the input accumulator.
110
- //
111
- // One way to illustrate the operation is as follows:
112
- //
113
- // RHS<8x[N]>: <8x[2]> <8x[2]> ... <8x[2]>
114
- // +-----------------------------
115
- // LHS<Mx8>: <2x8> | <2x[2]> <2x[2]> ... <2x[2]>
116
- // <2x8> | <2x[2]> <2x[2]> ... <2x[2]>
117
- // ... | ... ... ... ...
118
- // <2x8> | <2x[2]> <2x[2]> ... <2x[2]>
119
- //
120
- // The RHS operand is unpacked into N/2 values, each representing a sequence of
121
- // VSCALE number of sub-tiles with dimensions <8x2>.
122
- // The LHS operand is initially unpacked into M/2 values, each representing a
123
- // sub-tile with dimensions <2x8>, and then each such sub-tile is replicated
124
- // VSCALE times.
125
- // Multiplying thus replicated LHS sub-tile by the corresposponing RHS sub-tile
126
- // correctly computes an entire result sub-tile.
101
+ // / Lower a contraction operation that performs a matrix multiplication
102
+ // / of two 8-bit integer matrix tiles with logical dimensions <Mx8> and <8x[N]>
103
+ // / for the left-hand side and the right-hand side, respectively,
104
+ // / yielding a <Mx[N]> 32-bit integer result.
105
+ // /
106
+ // / The operands' shapes are such that the operands can be evenly split into
107
+ // / sub-tiles with dimensions as expected by the targeted FEAT_I8MM
108
+ // / instructions. The intent is that M and N are chosen (by higher level
109
+ // / transforms) in such a way as to maximise register usage. The main use case
110
+ // / we envision as of now is MMT4D, thus the RHS operand is expected
111
+ // / pre-transposed.
112
+ // /
113
+ // / The matrix multiplication is performed by unrolling the usual tiled matrix
114
+ // / multiplication algorithm using sub-tiles with dimensions <2x8> for the LHS,
115
+ // / <8x[2]> for the RHS, and <2x[2]> for the result and the input accumulator.
116
+ // /
117
+ // / One way to illustrate the operation is as follows:
118
+ // /
119
+ // / RHS<8x[N]>: <8x[2]> <8x[2]> ... <8x[2]>
120
+ // / +-----------------------------
121
+ // / LHS<Mx8>: <2x8> | <2x[2]> <2x[2]> ... <2x[2]>
122
+ // / <2x8> | <2x[2]> <2x[2]> ... <2x[2]>
123
+ // / ... | ... ... ... ...
124
+ // / <2x8> | <2x[2]> <2x[2]> ... <2x[2]>
125
+ // /
126
+ // / The RHS operand is unpacked into N/2 values, each representing a sequence of
127
+ // / VSCALE number of sub-tiles with dimensions <8x2>.
128
+ // / The LHS operand is initially unpacked into M/2 values, each representing a
129
+ // / sub-tile with dimensions <2x8>, and then each such sub-tile is replicated
130
+ // / VSCALE times.
131
+ // / Multiplying thus replicated LHS sub-tile by the corresponding RHS sub-tile
132
+ // / correctly computes an entire result sub-tile.
127
133
class LowerContractionToSVEI8MMPattern
128
134
: public OpRewritePattern<vector::ContractionOp> {
129
135
public:
@@ -135,27 +141,30 @@ class LowerContractionToSVEI8MMPattern
135
141
mlir::VectorType lhsType = op.getLhsType ();
136
142
mlir::VectorType rhsType = op.getRhsType ();
137
143
138
- // Check the operands have the expected shape. M and N dimensions must be
139
- // even and at least 2.
140
- if (lhsType.getRank () != 2 || rhsType.getRank () != 2 ||
141
- lhsType.isScalable () || !rhsType.isScalable ())
144
+ // Check the rank of the types so we can safely examine their dimensions.
145
+ if (lhsType.getRank () != 2 || rhsType.getRank () != 2 )
142
146
return rewriter.notifyMatchFailure (op, " non-matching operand shape" );
143
147
144
- // M, N, and K are the conventional names for matrix dimensions in the
145
- // context of matrix multiplication.
146
148
auto M = lhsType.getDimSize (0 );
147
149
auto N = rhsType.getDimSize (0 );
148
150
auto K = rhsType.getDimSize (1 );
149
151
150
- if (lhsType.getDimSize (1 ) != K || K != 8 || M < 2 || M % 2 != 0 || N < 2 ||
151
- N % 2 != 0 || !rhsType.getScalableDims ()[0 ])
152
+ // Check the operands have the expected shape:
153
+ // * for LHS: fixed vector MxK
154
+ // * for RHS: scalable vector [N]xK
155
+ // * K == 8
156
+ // * M and N even and at least 2
157
+ if (lhsType.isScalable () || !rhsType.getScalableDims ()[0 ] ||
158
+ rhsType.getScalableDims ()[1 ] || lhsType.getDimSize (1 ) != K || K != 8 ||
159
+ M < 2 || M % 2 != 0 || N < 2 || N % 2 != 0 ||
160
+ !rhsType.getScalableDims ()[0 ])
152
161
return rewriter.notifyMatchFailure (op, " non-matching operand shape" );
153
162
154
163
// Check permutation maps. For now only accept
155
164
// lhs: (d0, d1, d2) -> (d0, d2)
156
165
// rhs: (d0, d1, d2) -> (d1, d2)
157
166
// acc: (d0, d1, d2) -> (d0, d1)
158
- // Note: RHS is transposed.
167
+ // This corresponds to matrix multiplication with transposed RHS.
159
168
if (op.getIndexingMapsArray ()[0 ] !=
160
169
AffineMap::getMultiDimMapWithTargets (3 , ArrayRef{0u , 2u },
161
170
op.getContext ()) ||
@@ -245,7 +254,7 @@ class LowerContractionToSVEI8MMPattern
245
254
}
246
255
247
256
// "Flatten" the RHS tile from <[N]x8> to <[8*N]>.
248
- auto RHS = rewriter.create <vector::ShapeCastOp>(
257
+ auto rhs = rewriter.create <vector::ShapeCastOp>(
249
258
maybeRhs->getLoc (),
250
259
VectorType::get (/* shape=*/ 8 * N, rewriter.getI8Type (),
251
260
/* scalableDims=*/ {true }),
@@ -255,7 +264,7 @@ class LowerContractionToSVEI8MMPattern
255
264
SmallVector<Value> rhsTile;
256
265
for (int64_t j = 0 ; j < N; j += 2 )
257
266
rhsTile.push_back (
258
- rewriter.create <vector::ScalableExtractOp>(loc, nxv16i8, RHS , j * 8 ));
267
+ rewriter.create <vector::ScalableExtractOp>(loc, nxv16i8, rhs , j * 8 ));
259
268
260
269
// Handy types for packing/unpacking of the accumulator tile.
261
270
auto accRowTy = VectorType::get (/* shape=*/ N, rewriter.getI32Type (),
0 commit comments