@@ -183,9 +183,7 @@ struct FoldUnitDimLoops : public OpRewritePattern<GenericOp> {
183
183
AffineMap invertedMap = inversePermutation (concatAffineMaps (indexingMaps));
184
184
if (!invertedMap)
185
185
return failure ();
186
- SmallVector<int64_t , 4 > dims;
187
- for (ShapedType shapedType : genericOp.getShapedOperandTypes ())
188
- dims.append (shapedType.getShape ().begin (), shapedType.getShape ().end ());
186
+ SmallVector<int64_t > dims = genericOp.getStaticShape ();
189
187
190
188
// Find all the reduction iterators. Those need some special consideration
191
189
// (see below).
@@ -267,17 +265,18 @@ struct UnitExtentReplacementInfo {
267
265
// / - modified index map that can be used to access the replaced result/operand
268
266
// / - the reassociation that converts from the original tensor type to the
269
267
// / modified tensor type.
270
- static UnitExtentReplacementInfo replaceUnitExtents (AffineMap indexMap ,
271
- RankedTensorType type ,
268
+ static UnitExtentReplacementInfo replaceUnitExtents (GenericOp genericOp ,
269
+ OpOperand *opOperand ,
272
270
MLIRContext *context) {
273
- ArrayRef<int64_t > shape = type.getShape ();
274
- ArrayRef<AffineExpr> exprs = indexMap.getResults ();
271
+ AffineMap indexingMap = genericOp.getTiedIndexingMap (opOperand);
272
+ ArrayRef<int64_t > shape = genericOp.getShape (opOperand);
273
+ ArrayRef<AffineExpr> exprs = indexingMap.getResults ();
275
274
SmallVector<AffineExpr, 2 > reassociations;
276
275
SmallVector<Attribute, 4 > reassociationMaps;
277
276
SmallVector<AffineExpr, 4 > newIndexExprs;
278
277
SmallVector<int64_t , 4 > newShape;
279
278
280
- int64_t origRank = type .getRank ();
279
+ int64_t origRank = genericOp .getRank (opOperand );
281
280
AffineExpr zeroExpr = getAffineConstantExpr (0 , context);
282
281
auto isUnitExtent = [&](int64_t dim) -> bool {
283
282
return shape[dim] == 1 && exprs[dim] == zeroExpr;
@@ -302,8 +301,9 @@ static UnitExtentReplacementInfo replaceUnitExtents(AffineMap indexMap,
302
301
++dim;
303
302
}
304
303
UnitExtentReplacementInfo info = {
305
- RankedTensorType::get (newShape, type.getElementType ()),
306
- AffineMap::get (indexMap.getNumDims (), indexMap.getNumSymbols (),
304
+ RankedTensorType::get (newShape,
305
+ getElementTypeOrSelf (opOperand->get ().getType ())),
306
+ AffineMap::get (indexingMap.getNumDims (), indexingMap.getNumSymbols (),
307
307
newIndexExprs, context),
308
308
ArrayAttr::get (context, reassociationMaps)};
309
309
return info;
@@ -335,15 +335,13 @@ struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOp> {
335
335
SmallVector<ArrayAttr, 4 > reassociationMaps;
336
336
SmallVector<ShapedType, 4 > newInputOutputTypes;
337
337
bool doCanonicalization = false ;
338
- for (auto it : llvm::zip (genericOp.getIndexingMaps (),
339
- genericOp.getShapedOperandTypes ())) {
340
- auto replacementInfo = replaceUnitExtents (
341
- std::get<0 >(it), std::get<1 >(it).template cast <RankedTensorType>(),
342
- context);
338
+
339
+ for (OpOperand *opOperand : genericOp.getInputAndOutputOperands ()) {
340
+ auto replacementInfo = replaceUnitExtents (genericOp, opOperand, context);
343
341
reassociationMaps.push_back (replacementInfo.reassociation );
344
342
newIndexingMaps.push_back (replacementInfo.indexMap );
345
343
newInputOutputTypes.push_back (replacementInfo.type );
346
- doCanonicalization |= replacementInfo.type != std::get< 1 >(it );
344
+ doCanonicalization |= replacementInfo.type != opOperand-> get (). getType ( );
347
345
}
348
346
349
347
// If the indexing maps of the result operation are not invertible (i.e. not
0 commit comments