@@ -40,41 +40,45 @@ static Type matchContainerType(Type element, Type container) {
40
40
41
41
// / Lowering from a vector::contractOp to an arm neon smmla intrinsic. This will tile
42
42
// / any vector.contract into multiple smmla instructions with unrolling so long
43
- // / as [2,2,8] is a divisor of its shape. If no unrolling is necessary, a single
44
- // / smmla instruction is emitted.
43
+ // / as [2,2,8] is a divisor of its shape. It can also process vecmats with dimM
44
+ // / = 1 (either explicitly or inferred if LHS has only dimK). If no unrolling is
45
+ // / necessary, a single smmla instruction is emitted.
45
46
class LowerContractionToSMMLAPattern
46
47
: public OpRewritePattern<vector::ContractionOp> {
47
48
public:
48
49
using OpRewritePattern::OpRewritePattern;
49
50
LogicalResult matchAndRewrite (vector::ContractionOp op,
50
51
PatternRewriter &rewriter) const override {
51
52
Location loc = op.getLoc ();
52
- // Check index maps that represent M N K in contract.
53
- auto indexingMaps = op.getIndexingMapsArray ();
54
- if (llvm::any_of (indexingMaps, [](mlir::AffineMap affineMap) {
55
- return affineMap.isPermutation () || affineMap.getNumDims () != 3 ||
56
- affineMap.getNumResults () != 2 ;
57
- })) {
58
- return failure ();
59
- }
60
- // Check iterator types for contract.
61
- auto iteratorTypes = op.getIteratorTypesArray ();
62
- if (iteratorTypes.size () != 3 ||
63
- iteratorTypes[0 ] != vector::IteratorType::parallel ||
64
- iteratorTypes[1 ] != vector::IteratorType::parallel ||
65
- iteratorTypes[2 ] != vector::IteratorType::reduction) {
66
- return failure ();
67
- }
68
- // Infer tile sizes from operands; Note: RHS is not transposed.
53
+ // Infer tile sizes from operands. For vecmat, LHS may only have 1 dim.
54
+ // Note: RHS is not transposed.
69
55
mlir::VectorType lhsType = op.getLhsType ();
70
56
mlir::VectorType rhsType = op.getRhsType ();
71
- auto dimM = lhsType.getDimSize (0 );
57
+ auto dimM = lhsType.getRank () == 1 ? 1 : lhsType. getDimSize (0 );
72
58
auto dimN = rhsType.getDimSize (0 );
73
- auto dimK = lhsType.getDimSize (1 );
74
-
59
+ auto dimK = rhsType.getDimSize (1 );
60
+ bool isVecmat = dimM == 1 ? true : false ;
61
+ if (lhsType.getDimSize (lhsType.getRank () - 1 ) !=
62
+ rhsType.getDimSize (rhsType.getRank () - 1 )) {
63
+ return failure (); // dimK mismatch
64
+ }
75
65
// Unrolling patterns can handle any [2, 2, 8] shaped multiple of inputs for
76
66
// tiling.
77
- if (dimM % 2 != 0 || dimN % 2 != 0 || dimK % 8 != 0 ) {
67
+ if ((dimM % 2 != 0 && !isVecmat) || dimN % 2 != 0 || dimK % 8 != 0 ) {
68
+ return failure ();
69
+ }
70
+
71
+ // Check iterator types for contract. All iterators except inner-most
72
+ // dimension must be parallel.
73
+ auto iteratorTypes = op.getIteratorTypesArray ();
74
+ if (iteratorTypes.size () > 3 || iteratorTypes[iteratorTypes.size () - 1 ] !=
75
+ vector::IteratorType::reduction) {
76
+ return failure ();
77
+ }
78
+ if (llvm::any_of (ArrayRef<vector::IteratorType>(iteratorTypes).drop_back (1 ),
79
+ [](vector::IteratorType iteratorType) {
80
+ return iteratorType != vector::IteratorType::parallel;
81
+ })) {
78
82
return failure ();
79
83
}
80
84
@@ -120,11 +124,14 @@ class LowerContractionToSMMLAPattern
120
124
loc, op.getResultType (), rewriter.getZeroAttr (op.getResultType ()));
121
125
122
126
SmallVector<int64_t > unrolledSize = *op.getShapeForUnroll ();
123
- SmallVector<int64_t > smmlaShape{2 , 2 , 8 };
124
- SmallVector<int64_t > loopOrder{0 , 1 , 2 };
127
+ SmallVector<int64_t > smmlaShape{2 , 8 };
128
+ SmallVector<int64_t > loopOrder{0 , 1 };
129
+ if (unrolledSize.size () == 3 ) {
130
+ smmlaShape.insert (smmlaShape.begin (), isVecmat ? 1 : 2 );
131
+ loopOrder.push_back (2 );
132
+ }
125
133
for (SmallVector<int64_t > offsets :
126
134
StaticTileOffsetRange (unrolledSize, smmlaShape, loopOrder)) {
127
-
128
135
// Helper to compute the new shape of each operand and extract the slice.
129
136
auto extractOperand = [&](Value operand, AffineMap permutationMap,
130
137
ArrayRef<int64_t > operandOffsets) {
@@ -150,16 +157,40 @@ class LowerContractionToSMMLAPattern
150
157
Value tiledAcc =
151
158
extractOperand (op.getAcc (), accPermutationMap, accOffsets);
152
159
160
+ auto inputElementType =
161
+ tiledLhs.getType ().cast <ShapedType>().getElementType ();
162
+ auto accElementType =
163
+ tiledAcc.getType ().cast <ShapedType>().getElementType ();
164
+ auto inputExpandedType = VectorType::get ({2 , 8 }, inputElementType);
165
+ auto outputExpandedType = VectorType::get ({2 , 2 }, accElementType);
166
+
167
+ // With vecmat, tiled LHS and ACC will contain only one of 2 necessary
168
+ // rows along dimM. Expand their shapes to match the smmla op.
169
+ if (isVecmat) {
170
+ auto expandForSMMLA = [&](Value tiledOperand,
171
+ VectorType expandedTypeType) {
172
+ auto emptyOperand = rewriter.create <arith::ConstantOp>(
173
+ loc, expandedTypeType, rewriter.getZeroAttr (expandedTypeType));
174
+ SmallVector<int64_t > offsets (
175
+ emptyOperand.getType ().cast <ShapedType>().getRank (), 0 );
176
+ SmallVector<int64_t > strides (
177
+ tiledOperand.getType ().cast <ShapedType>().getRank (), 1 );
178
+ return rewriter.createOrFold <vector::InsertStridedSliceOp>(
179
+ loc, tiledOperand, emptyOperand, offsets, strides);
180
+ };
181
+ tiledLhs = expandForSMMLA (tiledLhs, inputExpandedType);
182
+ tiledAcc = expandForSMMLA (tiledAcc, outputExpandedType);
183
+ }
184
+
153
185
// Collapse tiled operands to 1D vectors required by smmla intrinsic
154
- auto collapsedInputType = VectorType::get (
155
- tiledLhs.getType ().cast <ShapedType>().getNumElements (),
156
- tiledLhs.getType ().cast <ShapedType>().getElementType ());
157
- auto collapsedOutputType = VectorType::get (
158
- {4 }, tiledAcc.getType ().cast <ShapedType>().getElementType ());
186
+ auto collapsedInputType =
187
+ VectorType::get (inputExpandedType.getNumElements (), inputElementType);
159
188
auto collapsedLhs = rewriter.createOrFold <vector::ShapeCastOp>(
160
189
tiledLhs.getLoc (), collapsedInputType, tiledLhs);
161
190
auto collapsedRhs = rewriter.createOrFold <vector::ShapeCastOp>(
162
191
tiledRhs.getLoc (), collapsedInputType, tiledRhs);
192
+ auto collapsedOutputType =
193
+ VectorType::get (outputExpandedType.getNumElements (), accElementType);
163
194
auto collapsedRes = rewriter.createOrFold <vector::ShapeCastOp>(
164
195
tiledAcc.getLoc (), collapsedOutputType, tiledAcc);
165
196
@@ -172,6 +203,11 @@ class LowerContractionToSMMLAPattern
172
203
Value tiledRes = rewriter.createOrFold <vector::ShapeCastOp>(
173
204
smmlaOp.getLoc (), tiledAcc.getType (), smmlaOp);
174
205
206
+ // With vecmat, only one row of tiled ACC can be inserted into the final result
207
+ if (isVecmat) {
208
+ tiledRes = rewriter.createOrFold <vector::ExtractOp>(loc, tiledRes, 0 );
209
+ }
210
+
175
211
// Insert the tiled result back into the non tiled result of the
176
212
// contract op.
177
213
SmallVector<int64_t > strides (
0 commit comments