Skip to content

Commit c698505

Browse files
Author: Tobias Gysi (committed)
[mlir][linalg] Cleanup LinalgOp usage in drop unit dims.
Replace the uses of deprecated Structured Op Interface methods in DropUnitDims.cpp. This patch is based on https://reviews.llvm.org/D103394. Differential Revision: https://reviews.llvm.org/D103448
1 parent 728cc00 commit c698505

File tree

1 file changed

+14
-16
lines changed

1 file changed

+14
-16
lines changed

mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -183,9 +183,7 @@ struct FoldUnitDimLoops : public OpRewritePattern<GenericOp> {
183183
AffineMap invertedMap = inversePermutation(concatAffineMaps(indexingMaps));
184184
if (!invertedMap)
185185
return failure();
186-
SmallVector<int64_t, 4> dims;
187-
for (ShapedType shapedType : genericOp.getShapedOperandTypes())
188-
dims.append(shapedType.getShape().begin(), shapedType.getShape().end());
186+
SmallVector<int64_t> dims = genericOp.getStaticShape();
189187

190188
// Find all the reduction iterators. Those need some special consideration
191189
// (see below).
@@ -267,17 +265,18 @@ struct UnitExtentReplacementInfo {
267265
/// - modified index map that can be used to access the replaced result/operand
268266
/// - the reassociation that converts from the original tensor type to the
269267
/// modified tensor type.
270-
static UnitExtentReplacementInfo replaceUnitExtents(AffineMap indexMap,
271-
RankedTensorType type,
268+
static UnitExtentReplacementInfo replaceUnitExtents(GenericOp genericOp,
269+
OpOperand *opOperand,
272270
MLIRContext *context) {
273-
ArrayRef<int64_t> shape = type.getShape();
274-
ArrayRef<AffineExpr> exprs = indexMap.getResults();
271+
AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
272+
ArrayRef<int64_t> shape = genericOp.getShape(opOperand);
273+
ArrayRef<AffineExpr> exprs = indexingMap.getResults();
275274
SmallVector<AffineExpr, 2> reassociations;
276275
SmallVector<Attribute, 4> reassociationMaps;
277276
SmallVector<AffineExpr, 4> newIndexExprs;
278277
SmallVector<int64_t, 4> newShape;
279278

280-
int64_t origRank = type.getRank();
279+
int64_t origRank = genericOp.getRank(opOperand);
281280
AffineExpr zeroExpr = getAffineConstantExpr(0, context);
282281
auto isUnitExtent = [&](int64_t dim) -> bool {
283282
return shape[dim] == 1 && exprs[dim] == zeroExpr;
@@ -302,8 +301,9 @@ static UnitExtentReplacementInfo replaceUnitExtents(AffineMap indexMap,
302301
++dim;
303302
}
304303
UnitExtentReplacementInfo info = {
305-
RankedTensorType::get(newShape, type.getElementType()),
306-
AffineMap::get(indexMap.getNumDims(), indexMap.getNumSymbols(),
304+
RankedTensorType::get(newShape,
305+
getElementTypeOrSelf(opOperand->get().getType())),
306+
AffineMap::get(indexingMap.getNumDims(), indexingMap.getNumSymbols(),
307307
newIndexExprs, context),
308308
ArrayAttr::get(context, reassociationMaps)};
309309
return info;
@@ -335,15 +335,13 @@ struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOp> {
335335
SmallVector<ArrayAttr, 4> reassociationMaps;
336336
SmallVector<ShapedType, 4> newInputOutputTypes;
337337
bool doCanonicalization = false;
338-
for (auto it : llvm::zip(genericOp.getIndexingMaps(),
339-
genericOp.getShapedOperandTypes())) {
340-
auto replacementInfo = replaceUnitExtents(
341-
std::get<0>(it), std::get<1>(it).template cast<RankedTensorType>(),
342-
context);
338+
339+
for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
340+
auto replacementInfo = replaceUnitExtents(genericOp, opOperand, context);
343341
reassociationMaps.push_back(replacementInfo.reassociation);
344342
newIndexingMaps.push_back(replacementInfo.indexMap);
345343
newInputOutputTypes.push_back(replacementInfo.type);
346-
doCanonicalization |= replacementInfo.type != std::get<1>(it);
344+
doCanonicalization |= replacementInfo.type != opOperand->get().getType();
347345
}
348346

349347
// If the indexing maps of the result operation are not invertible (i.e. not

0 commit comments

Comments (0)