Skip to content

Commit 91f7ea2

Browse files
[mlir][vector] NFC - Add more structured interface support to vector.contract
1 parent 7e77aae commit 91f7ea2

File tree

1 file changed

+33
-0
lines changed

1 file changed

+33
-0
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,39 @@ def Vector_ContractionOp :
207207
.template getAsValueRange<IteratorTypeAttr, IteratorType>();
208208
return {range.begin(), range.end()};
209209
}
210+
211+
//===------------------------------------------------------------------===//
212+
// The code below is shared with LinalgStructuredInterface.
213+
// vector.contract is really a linalg.generic on vectors without region.
214+
// TODO: factor out in a common interface to inherit from ince identified.
215+
//===------------------------------------------------------------------===//
216+
ArrayRef<int64_t> getShape(OpOperand * opOperand) {
217+
assert(opOperand->getOwner() == this->getOperation());
218+
Type t = opOperand->get().getType();
219+
return cast<VectorType>(t).getShape();
220+
}
221+
222+
AffineMap getLoopsToShapesMap() {
223+
auto maps = getIndexingMapsArray();
224+
return concatAffineMaps(maps, getContext());
225+
}
226+
227+
AffineMap getShapesToLoopsMap() {
228+
return inversePermutation(getLoopsToShapesMap());
229+
}
230+
231+
SmallVector<int64_t> getStaticShape(){
232+
SmallVector<int64_t> res;
233+
for (OpOperand &opOperand : this->getOperation()->getOpOperands())
234+
llvm::append_range(res, getShape(&opOperand));
235+
return res;
236+
}
237+
238+
SmallVector<int64_t> getStaticLoopRanges() {
239+
SmallVector<int64_t> viewSizes = getStaticShape();
240+
AffineMap invertedMap = getShapesToLoopsMap();
241+
return invertedMap.compose(viewSizes);
242+
}
210243
}];
211244

212245
let hasCanonicalizer = 1;

0 commit comments

Comments
 (0)