Skip to content

Commit 36108bb

Browse files
[fixup] Commenting and some minor refactoring
1 parent d184b22 commit 36108bb

File tree

2 files changed

+64
-54
lines changed

2 files changed

+64
-54
lines changed

mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,8 @@ class LowerContractionToSMMLAPattern
5656
// Avoid 0-D vectors and 1-D rhs:
5757
if (!lhsType.hasRank() || !rhsType.hasRank() || rhsType.getRank() < 2)
5858
return failure();
59-
// Avoid scalable vectors.
59+
// This codegen does not work for scalable vectors. Return failure so this
60+
// pattern is not accidentally chosen over patterns that lower to ArmSVE.
6061
if (lhsType.isScalable() || rhsType.isScalable())
6162
return failure();
6263
auto dimM = lhsType.getRank() == 1 ? 1 : lhsType.getDimSize(0);

mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp

Lines changed: 62 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -30,23 +30,26 @@ using namespace mlir;
3030
using namespace mlir::arm_sve;
3131

3232
namespace {
33-
// Get the LHS or RHS side operand of a vector contract. Handle two cases
34-
// * if the operand is a sign- or zero- extend operation of type `T` from i8
35-
// to i32, return the value before the extension, otherwise
36-
// * if the operand is of i8 type and the operation is sign-extend, return the
37-
// operand itself.
33+
// Get the operand of a `vector.contract`. This function is intended to abstract
34+
// away from the particular way a value is extended before feeding it into the
35+
// `vector.contract` - via zero-extend or an explicit or implicit sign-extend
36+
// (for implicit sign-extension see `vector.contract` documentation).
3837
//
39-
// This way we handle both explicit sign- or zero- extension or implicit
40-
// sign-extension.
41-
template <typename T>
38+
// The template parameter `Op` indicates the extension operation (explicit or
39+
// implicit) for which we are checking.
40+
//
41+
// Return success only for extensions from `i8` to `i32`.
42+
template <typename Op>
4243
std::optional<Value> getExtOperand(Value v, Type i8Ty, Type i32Ty) {
4344

44-
static_assert(llvm::is_one_of<T, arith::ExtSIOp, arith::ExtUIOp>::value,
45+
static_assert(llvm::is_one_of<Op, arith::ExtSIOp, arith::ExtUIOp>::value,
4546
"Must be instantiated with either sign- or zero- extension op");
4647

47-
auto extOp = dyn_cast_or_null<T>(v.getDefiningOp());
48+
// If the operand is not defined by an explicit extend operation of the
49+
// accepted operation type allow for an implicit sign-extension.
50+
auto extOp = dyn_cast_or_null<Op>(v.getDefiningOp());
4851
if (!extOp) {
49-
if constexpr (std::is_same<T, arith::ExtSIOp>::value) {
52+
if constexpr (std::is_same<Op, arith::ExtSIOp>::value) {
5053
auto vTy = cast<VectorType>(v.getType());
5154
if (vTy.getElementType() != i8Ty)
5255
return {};
@@ -55,6 +58,8 @@ std::optional<Value> getExtOperand(Value v, Type i8Ty, Type i32Ty) {
5558
return {};
5659
}
5760

61+
// If the operand is defined by an explicit extend operation of the accepted
62+
// operation type, check it's extended from `i8` to `i32`.
5863
auto inOp = extOp.getIn();
5964
auto inTy = dyn_cast<VectorType>(inOp.getType());
6065
if (!inTy || inTy.getElementType() != i8Ty)
@@ -93,37 +98,38 @@ Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc,
9398
}
9499
}
95100

96-
// Lower a contraction operation that performs a matrix multiplication
97-
// of two 8-bit integer matrix tiles with logical dimensions <Mx8> and <8x[N]>
98-
// for the left-hand side and the right-hand side, respectively,
99-
// yielding a <Mx[N]> 32-bit integer result.
100-
//
101-
// The operands shapes are such that the operands can be evenly split into
102-
// sub-tiles with dimensions as expected by the targeted FEAT_I8MM instructions.
103-
// The intent is that M and N are chosen (by higher level transforms) in such a
104-
// way as to maximise register usage. The main use case we envision as of now is
105-
// MMT4D, thus the RHS operand is expected pre-transposed.
106-
//
107-
// The matrix multiplication is performed by unrolling the usual tiled matrix
108-
// multiplication algorithm using sub-tiles with dimensions <2x8> for the LHS,
109-
// <8x[2]> for the RHS, and <2x[2]> for the result and the input accumulator.
110-
//
111-
// One way to illustrate the operation is as follows:
112-
//
113-
// RHS<8x[N]>: <8x[2]> <8x[2]> ... <8x[2]>
114-
// +-----------------------------
115-
// LHS<Mx8>: <2x8> | <2x[2]> <2x[2]> ... <2x[2]>
116-
// <2x8> | <2x[2]> <2x[2]> ... <2x[2]>
117-
// ... | ... ... ... ...
118-
// <2x8> | <2x[2]> <2x[2]> ... <2x[2]>
119-
//
120-
// The RHS operand is unpacked into N/2 values, each representing a sequence of
121-
// VSCALE number of sub-tiles with dimensions <8x2>.
122-
// The LHS operand is initially unpacked into M/2 values, each representing a
123-
// sub-tile with dimensions <2x8>, and then each such sub-tile is replicated
124-
// VSCALE times.
125-
// Multiplying thus replicated LHS sub-tile by the corresposponing RHS sub-tile
126-
// correctly computes an entire result sub-tile.
101+
/// Lower a contraction operation that performs a matrix multiplication
102+
/// of two 8-bit integer matrix tiles with logical dimensions <Mx8> and <8x[N]>
103+
/// for the left-hand side and the right-hand side, respectively,
104+
/// yielding a <Mx[N]> 32-bit integer result.
105+
///
106+
/// The operands' shapes are such that the operands can be evenly split into
107+
/// sub-tiles with dimensions as expected by the targeted FEAT_I8MM
108+
/// instructions. The intent is that M and N are chosen (by higher level
109+
/// transforms) in such a way as to maximise register usage. The main use case
110+
/// we envision as of now is MMT4D, thus the RHS operand is expected
111+
/// pre-transposed.
112+
///
113+
/// The matrix multiplication is performed by unrolling the usual tiled matrix
114+
/// multiplication algorithm using sub-tiles with dimensions <2x8> for the LHS,
115+
/// <8x[2]> for the RHS, and <2x[2]> for the result and the input accumulator.
116+
///
117+
/// One way to illustrate the operation is as follows:
118+
///
119+
/// RHS<8x[N]>: <8x[2]> <8x[2]> ... <8x[2]>
120+
/// +-----------------------------
121+
/// LHS<Mx8>: <2x8> | <2x[2]> <2x[2]> ... <2x[2]>
122+
/// <2x8> | <2x[2]> <2x[2]> ... <2x[2]>
123+
/// ... | ... ... ... ...
124+
/// <2x8> | <2x[2]> <2x[2]> ... <2x[2]>
125+
///
126+
/// The RHS operand is unpacked into N/2 values, each representing a sequence of
127+
/// VSCALE number of sub-tiles with dimensions <8x2>.
128+
/// The LHS operand is initially unpacked into M/2 values, each representing a
129+
/// sub-tile with dimensions <2x8>, and then each such sub-tile is replicated
130+
/// VSCALE times.
131+
/// Multiplying the thus-replicated LHS sub-tile by the corresponding RHS sub-tile
132+
/// correctly computes an entire result sub-tile.
127133
class LowerContractionToSVEI8MMPattern
128134
: public OpRewritePattern<vector::ContractionOp> {
129135
public:
@@ -135,27 +141,30 @@ class LowerContractionToSVEI8MMPattern
135141
mlir::VectorType lhsType = op.getLhsType();
136142
mlir::VectorType rhsType = op.getRhsType();
137143

138-
// Check the operands have the expected shape. M and N dimensions must be
139-
// even and at least 2.
140-
if (lhsType.getRank() != 2 || rhsType.getRank() != 2 ||
141-
lhsType.isScalable() || !rhsType.isScalable())
144+
// Check the ranks of the types so we can safely examine their dimensions.
145+
if (lhsType.getRank() != 2 || rhsType.getRank() != 2)
142146
return rewriter.notifyMatchFailure(op, "non-matching operand shape");
143147

144-
// M, N, and K are the conventional names for matrix dimensions in the
145-
// context of matrix multiplication.
146148
auto M = lhsType.getDimSize(0);
147149
auto N = rhsType.getDimSize(0);
148150
auto K = rhsType.getDimSize(1);
149151

150-
if (lhsType.getDimSize(1) != K || K != 8 || M < 2 || M % 2 != 0 || N < 2 ||
151-
N % 2 != 0 || !rhsType.getScalableDims()[0])
152+
// Check the operands have the expected shape:
153+
// * for LHS: fixed vector MxK
154+
// * for RHS: scalable vector [N]xK
155+
// * K == 8
156+
// * M and N even and at least 2
157+
if (lhsType.isScalable() || !rhsType.getScalableDims()[0] ||
158+
rhsType.getScalableDims()[1] || lhsType.getDimSize(1) != K || K != 8 ||
159+
M < 2 || M % 2 != 0 || N < 2 || N % 2 != 0 ||
160+
!rhsType.getScalableDims()[0])
152161
return rewriter.notifyMatchFailure(op, "non-matching operand shape");
153162

154163
// Check permutation maps. For now only accept
155164
// lhs: (d0, d1, d2) -> (d0, d2)
156165
// rhs: (d0, d1, d2) -> (d1, d2)
157166
// acc: (d0, d1, d2) -> (d0, d1)
158-
// Note: RHS is transposed.
167+
// This corresponds to matrix multiplication with transposed RHS.
159168
if (op.getIndexingMapsArray()[0] !=
160169
AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 2u},
161170
op.getContext()) ||
@@ -245,7 +254,7 @@ class LowerContractionToSVEI8MMPattern
245254
}
246255

247256
// "Flatten" the RHS tile from <[N]x8> to <[8*N]>.
248-
auto RHS = rewriter.create<vector::ShapeCastOp>(
257+
auto rhs = rewriter.create<vector::ShapeCastOp>(
249258
maybeRhs->getLoc(),
250259
VectorType::get(/*shape=*/8 * N, rewriter.getI8Type(),
251260
/*scalableDims=*/{true}),
@@ -255,7 +264,7 @@ class LowerContractionToSVEI8MMPattern
255264
SmallVector<Value> rhsTile;
256265
for (int64_t j = 0; j < N; j += 2)
257266
rhsTile.push_back(
258-
rewriter.create<vector::ScalableExtractOp>(loc, nxv16i8, RHS, j * 8));
267+
rewriter.create<vector::ScalableExtractOp>(loc, nxv16i8, rhs, j * 8));
259268

260269
// Handy types for packing/unpacking of the accumulator tile.
261270
auto accRowTy = VectorType::get(/*shape=*/N, rewriter.getI32Type(),

0 commit comments

Comments
 (0)