@@ -207,6 +207,39 @@ def Vector_ContractionOp :
207
207
.template getAsValueRange<IteratorTypeAttr, IteratorType>();
208
208
return {range.begin(), range.end()};
209
209
}
210
+
211
+ //===------------------------------------------------------------------===//
212
+ // The code below is shared with LinalgStructuredInterface.
213
+ // vector.contract is really a linalg.generic on vectors without region.
214
+ // TODO: factor out in a common interface to inherit from ince identified.
215
+ //===------------------------------------------------------------------===//
216
+ ArrayRef<int64_t> getShape(OpOperand * opOperand) {
217
+ assert(opOperand->getOwner() == this->getOperation());
218
+ Type t = opOperand->get().getType();
219
+ return cast<VectorType>(t).getShape();
220
+ }
221
+
222
+ AffineMap getLoopsToShapesMap() {
223
+ auto maps = getIndexingMapsArray();
224
+ return concatAffineMaps(maps, getContext());
225
+ }
226
+
227
+ AffineMap getShapesToLoopsMap() {
228
+ return inversePermutation(getLoopsToShapesMap());
229
+ }
230
+
231
+ SmallVector<int64_t> getStaticShape(){
232
+ SmallVector<int64_t> res;
233
+ for (OpOperand &opOperand : this->getOperation()->getOpOperands())
234
+ llvm::append_range(res, getShape(&opOperand));
235
+ return res;
236
+ }
237
+
238
+ SmallVector<int64_t> getStaticLoopRanges() {
239
+ SmallVector<int64_t> viewSizes = getStaticShape();
240
+ AffineMap invertedMap = getShapesToLoopsMap();
241
+ return invertedMap.compose(viewSizes);
242
+ }
210
243
}];
211
244
212
245
let hasCanonicalizer = 1;
0 commit comments