Skip to content

Commit c511c90

Browse files
authored
[mlir][ArmNeon] Updates LowerContractionToSMMLAPattern with vecmat unroll patterns (#86005)
Updates smmla unrolling patterns to handle vecmat contracts where `dimM=1`. This includes explicit vecmats in the form: `<1x8xi8> x <8x8xi8> --> <1x8xi32>` or implied with the leading dim folded: `<8xi8> x <8x8xi8> --> <8xi32>`. Since the smmla operates on two `<2x8xi8>` input vectors to produce `<2x2xi32>` accumulators, half of each 2x2 accumulator tile is dummy data not pertinent to the computation, resulting in half throughput.
1 parent be57c90 commit c511c90

File tree

2 files changed

+191
-31
lines changed

2 files changed

+191
-31
lines changed

mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp

Lines changed: 67 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -40,41 +40,45 @@ static Type matchContainerType(Type element, Type container) {
4040

4141
/// Lowering from a vector::contractOp arm neon smmla intrinsic. This will tile
4242
/// any vector.contract into multiple smmla instructions with unrolling so long
43-
/// as [2,2,8] is a divisor of its shape. If no unrolling is necessary, a single
44-
/// smmla instruction is emitted.
43+
/// as [2,2,8] is a divisor of its shape. It can also process vecmats with dimM
44+
/// = 1 (either explicitly or inferred if LHS has only dimK). If no unrolling is
45+
/// necessary, a single smmla instruction is emitted.
4546
class LowerContractionToSMMLAPattern
4647
: public OpRewritePattern<vector::ContractionOp> {
4748
public:
4849
using OpRewritePattern::OpRewritePattern;
4950
LogicalResult matchAndRewrite(vector::ContractionOp op,
5051
PatternRewriter &rewriter) const override {
5152
Location loc = op.getLoc();
52-
// Check index maps that represent M N K in contract.
53-
auto indexingMaps = op.getIndexingMapsArray();
54-
if (llvm::any_of(indexingMaps, [](mlir::AffineMap affineMap) {
55-
return affineMap.isPermutation() || affineMap.getNumDims() != 3 ||
56-
affineMap.getNumResults() != 2;
57-
})) {
58-
return failure();
59-
}
60-
// Check iterator types for contract.
61-
auto iteratorTypes = op.getIteratorTypesArray();
62-
if (iteratorTypes.size() != 3 ||
63-
iteratorTypes[0] != vector::IteratorType::parallel ||
64-
iteratorTypes[1] != vector::IteratorType::parallel ||
65-
iteratorTypes[2] != vector::IteratorType::reduction) {
66-
return failure();
67-
}
68-
// Infer tile sizes from operands; Note: RHS is not transposed.
53+
// Infer tile sizes from operands. For vecmat, LHS may only have 1 dim.
54+
// Note: RHS is not transposed.
6955
mlir::VectorType lhsType = op.getLhsType();
7056
mlir::VectorType rhsType = op.getRhsType();
71-
auto dimM = lhsType.getDimSize(0);
57+
auto dimM = lhsType.getRank() == 1 ? 1 : lhsType.getDimSize(0);
7258
auto dimN = rhsType.getDimSize(0);
73-
auto dimK = lhsType.getDimSize(1);
74-
59+
auto dimK = rhsType.getDimSize(1);
60+
bool isVecmat = dimM == 1 ? true : false;
61+
if (lhsType.getDimSize(lhsType.getRank() - 1) !=
62+
rhsType.getDimSize(rhsType.getRank() - 1)) {
63+
return failure(); // dimK mismatch
64+
}
7565
// Unrolling patterns can handle any [2, 2, 8] shaped multiple of inputs for
7666
// tiling.
77-
if (dimM % 2 != 0 || dimN % 2 != 0 || dimK % 8 != 0) {
67+
if ((dimM % 2 != 0 && !isVecmat) || dimN % 2 != 0 || dimK % 8 != 0) {
68+
return failure();
69+
}
70+
71+
// Check iterator types for contract. All iterators except inner-most
72+
// dimension must be parallel.
73+
auto iteratorTypes = op.getIteratorTypesArray();
74+
if (iteratorTypes.size() > 3 || iteratorTypes[iteratorTypes.size() - 1] !=
75+
vector::IteratorType::reduction) {
76+
return failure();
77+
}
78+
if (llvm::any_of(ArrayRef<vector::IteratorType>(iteratorTypes).drop_back(1),
79+
[](vector::IteratorType iteratorType) {
80+
return iteratorType != vector::IteratorType::parallel;
81+
})) {
7882
return failure();
7983
}
8084

@@ -120,11 +124,14 @@ class LowerContractionToSMMLAPattern
120124
loc, op.getResultType(), rewriter.getZeroAttr(op.getResultType()));
121125

122126
SmallVector<int64_t> unrolledSize = *op.getShapeForUnroll();
123-
SmallVector<int64_t> smmlaShape{2, 2, 8};
124-
SmallVector<int64_t> loopOrder{0, 1, 2};
127+
SmallVector<int64_t> smmlaShape{2, 8};
128+
SmallVector<int64_t> loopOrder{0, 1};
129+
if (unrolledSize.size() == 3) {
130+
smmlaShape.insert(smmlaShape.begin(), isVecmat ? 1 : 2);
131+
loopOrder.push_back(2);
132+
}
125133
for (SmallVector<int64_t> offsets :
126134
StaticTileOffsetRange(unrolledSize, smmlaShape, loopOrder)) {
127-
128135
// Helper to compute the new shape of each operand and extract the slice.
129136
auto extractOperand = [&](Value operand, AffineMap permutationMap,
130137
ArrayRef<int64_t> operandOffsets) {
@@ -150,16 +157,40 @@ class LowerContractionToSMMLAPattern
150157
Value tiledAcc =
151158
extractOperand(op.getAcc(), accPermutationMap, accOffsets);
152159

160+
auto inputElementType =
161+
tiledLhs.getType().cast<ShapedType>().getElementType();
162+
auto accElementType =
163+
tiledAcc.getType().cast<ShapedType>().getElementType();
164+
auto inputExpandedType = VectorType::get({2, 8}, inputElementType);
165+
auto outputExpandedType = VectorType::get({2, 2}, accElementType);
166+
167+
// With vecmat, tiled LHS and ACC will contain only one of 2 necessary
168+
// rows along dimM. Expand their shapes to match the smmla op.
169+
if (isVecmat) {
170+
auto expandForSMMLA = [&](Value tiledOperand,
171+
VectorType expandedTypeType) {
172+
auto emptyOperand = rewriter.create<arith::ConstantOp>(
173+
loc, expandedTypeType, rewriter.getZeroAttr(expandedTypeType));
174+
SmallVector<int64_t> offsets(
175+
emptyOperand.getType().cast<ShapedType>().getRank(), 0);
176+
SmallVector<int64_t> strides(
177+
tiledOperand.getType().cast<ShapedType>().getRank(), 1);
178+
return rewriter.createOrFold<vector::InsertStridedSliceOp>(
179+
loc, tiledOperand, emptyOperand, offsets, strides);
180+
};
181+
tiledLhs = expandForSMMLA(tiledLhs, inputExpandedType);
182+
tiledAcc = expandForSMMLA(tiledAcc, outputExpandedType);
183+
}
184+
153185
// Collapse tiled operands to 1D vectors required by smmla intrinsic
154-
auto collapsedInputType = VectorType::get(
155-
tiledLhs.getType().cast<ShapedType>().getNumElements(),
156-
tiledLhs.getType().cast<ShapedType>().getElementType());
157-
auto collapsedOutputType = VectorType::get(
158-
{4}, tiledAcc.getType().cast<ShapedType>().getElementType());
186+
auto collapsedInputType =
187+
VectorType::get(inputExpandedType.getNumElements(), inputElementType);
159188
auto collapsedLhs = rewriter.createOrFold<vector::ShapeCastOp>(
160189
tiledLhs.getLoc(), collapsedInputType, tiledLhs);
161190
auto collapsedRhs = rewriter.createOrFold<vector::ShapeCastOp>(
162191
tiledRhs.getLoc(), collapsedInputType, tiledRhs);
192+
auto collapsedOutputType =
193+
VectorType::get(outputExpandedType.getNumElements(), accElementType);
163194
auto collapsedRes = rewriter.createOrFold<vector::ShapeCastOp>(
164195
tiledAcc.getLoc(), collapsedOutputType, tiledAcc);
165196

@@ -172,6 +203,11 @@ class LowerContractionToSMMLAPattern
172203
Value tiledRes = rewriter.createOrFold<vector::ShapeCastOp>(
173204
smmlaOp.getLoc(), tiledAcc.getType(), smmlaOp);
174205

206+
// With vecmat, only one row of tiled ACC can be inserted into the final result
207+
if (isVecmat) {
208+
tiledRes = rewriter.createOrFold<vector::ExtractOp>(loc, tiledRes, 0);
209+
}
210+
175211
// Insert the tiled result back into the non tiled result of the
176212
// contract op.
177213
SmallVector<int64_t> strides(

0 commit comments

Comments (0)