@@ -68,12 +68,19 @@ class LowerContractionToSMMLAPattern
68
68
return failure ();
69
69
}
70
70
71
- // Check iterator types for contract.
71
+ // Check iterator types for contract. All iterators except inner-most
72
+ // dimension must be parallel.
72
73
auto iteratorTypes = op.getIteratorTypesArray ();
73
74
if (iteratorTypes.size () > 3 || iteratorTypes[iteratorTypes.size () - 1 ] !=
74
75
vector::IteratorType::reduction) {
75
76
return failure ();
76
77
}
78
+ if (llvm::any_of (ArrayRef<vector::IteratorType>(iteratorTypes).drop_back (1 ),
79
+ [](vector::IteratorType iteratorType) {
80
+ return iteratorType != vector::IteratorType::parallel;
81
+ })) {
82
+ return failure ();
83
+ }
77
84
78
85
// Check two extsi inputs Rhs Lhs for contract.
79
86
arith::ExtSIOp origLhsExtOp =
@@ -117,11 +124,11 @@ class LowerContractionToSMMLAPattern
117
124
loc, op.getResultType (), rewriter.getZeroAttr (op.getResultType ()));
118
125
119
126
SmallVector<int64_t > unrolledSize = *op.getShapeForUnroll ();
120
- SmallVector<int64_t > smmlaShape{isVecmat ? 1 : 2 , 2 , 8 };
121
- SmallVector<int64_t > loopOrder{0 , 1 , 2 };
122
- if (unrolledSize.size () == 2 ) {
123
- smmlaShape = { 2 , 8 } ;
124
- loopOrder = { 0 , 1 } ;
127
+ SmallVector<int64_t > smmlaShape{2 , 8 };
128
+ SmallVector<int64_t > loopOrder{0 , 1 };
129
+ if (unrolledSize.size () == 3 ) {
130
+ smmlaShape. insert (smmlaShape. begin (), isVecmat ? 1 : 2 ) ;
131
+ loopOrder. push_back ( 2 ) ;
125
132
}
126
133
for (SmallVector<int64_t > offsets :
127
134
StaticTileOffsetRange (unrolledSize, smmlaShape, loopOrder)) {
@@ -150,30 +157,40 @@ class LowerContractionToSMMLAPattern
150
157
Value tiledAcc =
151
158
extractOperand (op.getAcc (), accPermutationMap, accOffsets);
152
159
160
+ auto inputElementType =
161
+ tiledLhs.getType ().cast <ShapedType>().getElementType ();
162
+ auto accElementType =
163
+ tiledAcc.getType ().cast <ShapedType>().getElementType ();
164
+ auto inputExpandedType = VectorType::get ({2 , 8 }, inputElementType);
165
+ auto outputExpandedType = VectorType::get ({2 , 2 }, accElementType);
166
+
153
167
// With vecmat, tiled LHS and ACC will contain only one of 2 necessary
154
- // rows along dimM. Broadcast both to the full width
168
+ // rows along dimM. Expand their shapes to match the smmla op.
155
169
if (isVecmat) {
156
- auto lhsBroadcastType = VectorType::get (
157
- {2 , 8 }, tiledLhs.getType ().cast <ShapedType>().getElementType ());
158
- tiledLhs = rewriter.create <vector::BroadcastOp>(loc, lhsBroadcastType,
159
- tiledLhs);
160
- auto accBroadcastType = VectorType::get (
161
- {2 , 2 }, tiledAcc.getType ().cast <ShapedType>().getElementType ());
162
- tiledAcc = rewriter.create <vector::BroadcastOp>(loc, accBroadcastType,
163
- tiledAcc);
170
+ auto expandForSMMLA = [&](Value tiledOperand,
171
+ VectorType expandedTypeType) {
172
+ auto emptyOperand = rewriter.create <arith::ConstantOp>(
173
+ loc, expandedTypeType, rewriter.getZeroAttr (expandedTypeType));
174
+ SmallVector<int64_t > offsets (
175
+ emptyOperand.getType ().cast <ShapedType>().getRank (), 0 );
176
+ SmallVector<int64_t > strides (
177
+ tiledOperand.getType ().cast <ShapedType>().getRank (), 1 );
178
+ return rewriter.createOrFold <vector::InsertStridedSliceOp>(
179
+ loc, tiledOperand, emptyOperand, offsets, strides);
180
+ };
181
+ tiledLhs = expandForSMMLA (tiledLhs, inputExpandedType);
182
+ tiledAcc = expandForSMMLA (tiledAcc, outputExpandedType);
164
183
}
165
184
166
185
// Collapse tiled operands to 1D vectors required by smmla intrinsic
167
- auto collapsedInputType = VectorType::get (
168
- tiledLhs.getType ().cast <ShapedType>().getNumElements (),
169
- tiledLhs.getType ().cast <ShapedType>().getElementType ());
186
+ auto collapsedInputType =
187
+ VectorType::get (inputExpandedType.getNumElements (), inputElementType);
170
188
auto collapsedLhs = rewriter.createOrFold <vector::ShapeCastOp>(
171
189
tiledLhs.getLoc (), collapsedInputType, tiledLhs);
172
190
auto collapsedRhs = rewriter.createOrFold <vector::ShapeCastOp>(
173
191
tiledRhs.getLoc (), collapsedInputType, tiledRhs);
174
- auto collapsedOutputType = VectorType::get (
175
- tiledAcc.getType ().cast <ShapedType>().getNumElements (),
176
- tiledAcc.getType ().cast <ShapedType>().getElementType ());
192
+ auto collapsedOutputType =
193
+ VectorType::get (outputExpandedType.getNumElements (), accElementType);
177
194
auto collapsedRes = rewriter.createOrFold <vector::ShapeCastOp>(
178
195
tiledAcc.getLoc (), collapsedOutputType, tiledAcc);
179
196
0 commit comments