@@ -267,9 +267,9 @@ struct UnitExtentReplacementInfo {
 /// - modified index map that can be used to access the replaced result/operand
 /// - the reassociation that converts from the original tensor type to the
 ///   modified tensor type.
-static UnitExtentReplacementInfo replaceUnitExtents(GenericOp genericOp,
-                                                    OpOperand *opOperand,
-                                                    MLIRContext *context) {
+static llvm::Optional<UnitExtentReplacementInfo>
+replaceUnitExtents(GenericOp genericOp, OpOperand *opOperand,
+                   MLIRContext *context) {
   AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
   ArrayRef<int64_t> shape = genericOp.getShape(opOperand);
   ArrayRef<AffineExpr> exprs = indexingMap.getResults();
@@ -284,6 +284,14 @@ static UnitExtentReplacementInfo replaceUnitExtents(GenericOp genericOp,
     return shape[dim] == 1 && exprs[dim] == zeroExpr;
   };
 
+  // Early return for memrefs with affine maps to represent that we will always
+  // leave them unchanged.
+  Type actualType = opOperand->get().getType();
+  if (auto memref = actualType.dyn_cast<MemRefType>()) {
+    if (!memref.getAffineMaps().empty())
+      return llvm::None;
+  }
+
   int64_t dim = 0;
   // Fold dimensions that are unit-extent at the beginning of the tensor.
   while (dim < origRank && isUnitExtent(dim))
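
Note (not part of the patch): the guard added above only fires for memrefs that carry a non-empty list of affine layout maps; identity-layout memrefs and tensors still go through the unit-extent folding path. A minimal stand-alone sketch of that check, using the same pre-layout-attribute MemRefType::getAffineMaps() API as the patch and a hypothetical helper name:

// Hypothetical helper, not part of this patch: returns true if `type` is a
// memref with an explicit (possibly strided) layout map, i.e. the case the
// pattern now leaves unchanged instead of asserting.
static bool hasExplicitLayout(Type type) {
  if (auto memrefType = type.dyn_cast<MemRefType>())
    return !memrefType.getAffineMaps().empty();
  // Tensors and scalars never carry a layout map.
  return false;
}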
@@ -302,17 +310,15 @@ static UnitExtentReplacementInfo replaceUnitExtents(GenericOp genericOp,
     reassociations.clear();
     ++dim;
   }
+
   // Compute the tensor or scalar replacement type.
-  Type actualType = opOperand->get().getType();
   Type elementType = getElementTypeOrSelf(opOperand->get());
   Type replacementType;
   if (elementType == opOperand->get().getType()) {
     replacementType = elementType;
   } else if (actualType.isa<RankedTensorType>()) {
     replacementType = RankedTensorType::get(newShape, elementType);
   } else if (actualType.isa<MemRefType>()) {
-    assert(actualType.cast<MemRefType>().getAffineMaps().empty() &&
-           "unsupported strided memrefs");
     replacementType = MemRefType::get(newShape, elementType);
   }
   assert(replacementType && "unsupported shaped type");
@@ -390,12 +396,28 @@ struct ReplaceUnitExtents : public OpRewritePattern<GenericOp> {
     SmallVector<Type> newInputOutputTypes;
     bool doCanonicalization = false;
     for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
-      UnitExtentReplacementInfo replacementInfo =
-          replaceUnitExtents(genericOp, opOperand, context);
-      reassociationMaps.push_back(replacementInfo.reassociation);
-      newIndexingMaps.push_back(replacementInfo.indexMap);
-      newInputOutputTypes.push_back(replacementInfo.type);
-      doCanonicalization |= replacementInfo.type != opOperand->get().getType();
+      auto replacementInfo = replaceUnitExtents(genericOp, opOperand, context);
+      if (replacementInfo) {
+        reassociationMaps.push_back(replacementInfo->reassociation);
+        newIndexingMaps.push_back(replacementInfo->indexMap);
+        newInputOutputTypes.push_back(replacementInfo->type);
+        doCanonicalization |=
+            replacementInfo->type != opOperand->get().getType();
+      } else {
+        // If replaceUnitExtents cannot handle this case, maintain the same
+        // type, indexing map, and create a set of mappings representing an
+        // identity matrix.
+        newInputOutputTypes.push_back(opOperand->get().getType());
+        newIndexingMaps.push_back(genericOp.getTiedIndexingMap(opOperand));
+        int64_t origRank = genericOp.getRank(opOperand);
+        auto maps = llvm::to_vector<8>(llvm::map_range(
+            llvm::seq<int64_t>(0, origRank), [&](int64_t dim) -> Attribute {
+              return AffineMapAttr::get(
+                  AffineMap::get(origRank, /*symbolCount=*/0,
+                                 getAffineDimExpr(dim, context), context));
+            }));
+        reassociationMaps.push_back(ArrayAttr::get(context, maps));
+      }
     }
 
     // If the indexing maps of the result operation are not invertible (i.e. not
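
Note (not part of the patch): the fallback branch above builds an identity reassociation, i.e. one affine map (d0, ..., dN-1) -> (di) per dimension of the operand, collected into an ArrayAttr, so nothing is regrouped. Pulled out as a stand-alone sketch with a hypothetical helper name, using the same calls as the hunk above:

// Hypothetical helper mirroring the else-branch: for a rank-N operand, produce
// [affine_map<(d0..dN-1) -> (d0)>, ..., affine_map<(d0..dN-1) -> (dN-1)>]
// wrapped in an ArrayAttr, i.e. a reassociation that changes nothing.
static ArrayAttr getIdentityReassociation(int64_t rank, MLIRContext *context) {
  auto maps = llvm::to_vector<8>(llvm::map_range(
      llvm::seq<int64_t>(0, rank), [&](int64_t dim) -> Attribute {
        return AffineMapAttr::get(AffineMap::get(
            rank, /*symbolCount=*/0, getAffineDimExpr(dim, context), context));
      }));
  return ArrayAttr::get(context, maps);
}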