Skip to content

Commit cd63859

Browse files
committed
diego mehdi comments
1 parent 5b239a8 commit cd63859

File tree

2 files changed

+134
-115
lines changed

2 files changed

+134
-115
lines changed

mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp

Lines changed: 38 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -68,12 +68,19 @@ class LowerContractionToSMMLAPattern
6868
return failure();
6969
}
7070

71-
// Check iterator types for contract.
71+
// Check iterator types for contract. All iterators except the inner-most
72+
// dimension must be parallel.
7273
auto iteratorTypes = op.getIteratorTypesArray();
7374
if (iteratorTypes.size() > 3 || iteratorTypes[iteratorTypes.size() - 1] !=
7475
vector::IteratorType::reduction) {
7576
return failure();
7677
}
78+
if (llvm::any_of(ArrayRef<vector::IteratorType>(iteratorTypes).drop_back(1),
79+
[](vector::IteratorType iteratorType) {
80+
return iteratorType != vector::IteratorType::parallel;
81+
})) {
82+
return failure();
83+
}
7784

7885
// Check that both the Lhs and Rhs inputs of the contract come from extsi ops.
7986
arith::ExtSIOp origLhsExtOp =
@@ -117,11 +124,11 @@ class LowerContractionToSMMLAPattern
117124
loc, op.getResultType(), rewriter.getZeroAttr(op.getResultType()));
118125

119126
SmallVector<int64_t> unrolledSize = *op.getShapeForUnroll();
120-
SmallVector<int64_t> smmlaShape{isVecmat ? 1 : 2, 2, 8};
121-
SmallVector<int64_t> loopOrder{0, 1, 2};
122-
if (unrolledSize.size() == 2) {
123-
smmlaShape = {2, 8};
124-
loopOrder = {0, 1};
127+
SmallVector<int64_t> smmlaShape{2, 8};
128+
SmallVector<int64_t> loopOrder{0, 1};
129+
if (unrolledSize.size() == 3) {
130+
smmlaShape.insert(smmlaShape.begin(), isVecmat ? 1 : 2);
131+
loopOrder.push_back(2);
125132
}
126133
for (SmallVector<int64_t> offsets :
127134
StaticTileOffsetRange(unrolledSize, smmlaShape, loopOrder)) {
@@ -150,30 +157,40 @@ class LowerContractionToSMMLAPattern
150157
Value tiledAcc =
151158
extractOperand(op.getAcc(), accPermutationMap, accOffsets);
152159

160+
auto inputElementType =
161+
tiledLhs.getType().cast<ShapedType>().getElementType();
162+
auto accElementType =
163+
tiledAcc.getType().cast<ShapedType>().getElementType();
164+
auto inputExpandedType = VectorType::get({2, 8}, inputElementType);
165+
auto outputExpandedType = VectorType::get({2, 2}, accElementType);
166+
153167
// With vecmat, tiled LHS and ACC will contain only one of 2 necessary
154-
// rows along dimM. Broadcast both to the full width
168+
// rows along dimM. Expand their shapes to match the smmla op.
155169
if (isVecmat) {
156-
auto lhsBroadcastType = VectorType::get(
157-
{2, 8}, tiledLhs.getType().cast<ShapedType>().getElementType());
158-
tiledLhs = rewriter.create<vector::BroadcastOp>(loc, lhsBroadcastType,
159-
tiledLhs);
160-
auto accBroadcastType = VectorType::get(
161-
{2, 2}, tiledAcc.getType().cast<ShapedType>().getElementType());
162-
tiledAcc = rewriter.create<vector::BroadcastOp>(loc, accBroadcastType,
163-
tiledAcc);
170+
auto expandForSMMLA = [&](Value tiledOperand,
171+
VectorType expandedTypeType) {
172+
auto emptyOperand = rewriter.create<arith::ConstantOp>(
173+
loc, expandedTypeType, rewriter.getZeroAttr(expandedTypeType));
174+
SmallVector<int64_t> offsets(
175+
emptyOperand.getType().cast<ShapedType>().getRank(), 0);
176+
SmallVector<int64_t> strides(
177+
tiledOperand.getType().cast<ShapedType>().getRank(), 1);
178+
return rewriter.createOrFold<vector::InsertStridedSliceOp>(
179+
loc, tiledOperand, emptyOperand, offsets, strides);
180+
};
181+
tiledLhs = expandForSMMLA(tiledLhs, inputExpandedType);
182+
tiledAcc = expandForSMMLA(tiledAcc, outputExpandedType);
164183
}
165184

166185
// Collapse tiled operands to 1D vectors required by smmla intrinsic
167-
auto collapsedInputType = VectorType::get(
168-
tiledLhs.getType().cast<ShapedType>().getNumElements(),
169-
tiledLhs.getType().cast<ShapedType>().getElementType());
186+
auto collapsedInputType =
187+
VectorType::get(inputExpandedType.getNumElements(), inputElementType);
170188
auto collapsedLhs = rewriter.createOrFold<vector::ShapeCastOp>(
171189
tiledLhs.getLoc(), collapsedInputType, tiledLhs);
172190
auto collapsedRhs = rewriter.createOrFold<vector::ShapeCastOp>(
173191
tiledRhs.getLoc(), collapsedInputType, tiledRhs);
174-
auto collapsedOutputType = VectorType::get(
175-
tiledAcc.getType().cast<ShapedType>().getNumElements(),
176-
tiledAcc.getType().cast<ShapedType>().getElementType());
192+
auto collapsedOutputType =
193+
VectorType::get(outputExpandedType.getNumElements(), accElementType);
177194
auto collapsedRes = rewriter.createOrFold<vector::ShapeCastOp>(
178195
tiledAcc.getLoc(), collapsedOutputType, tiledAcc);
179196

0 commit comments

Comments
 (0)