Skip to content

Commit 821202f

Browse files
Normalize reinterpret_cast op
Rewrites the memref defined by reinterpret_cast op to have an identity layout map and updates all its indexing uses. Also extend `replaceAllMemRefUsesWith` utility to work when there are multiple occurrences of `oldMemRef` in `op`'s operand list when op is non-dereferencing.
1 parent 2df25a4 commit 821202f

File tree

5 files changed

+191
-120
lines changed

5 files changed

+191
-120
lines changed

mlir/include/mlir/Dialect/Affine/Utils.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ class FuncOp;
3232
namespace memref {
3333
class AllocOp;
3434
class AllocaOp;
35+
class ReinterpretCastOp;
3536
} // namespace memref
3637

3738
namespace affine {
@@ -243,15 +244,16 @@ LogicalResult replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef,
243244
ArrayRef<Value> symbolOperands = {},
244245
bool allowNonDereferencingOps = false);
245246

246-
/// Rewrites the memref defined by this alloc op to have an identity layout map
247-
/// and updates all its indexing uses. Returns failure if any of its uses
248-
/// escape (while leaving the IR in a valid state).
247+
/// Rewrites the memref defined by alloc or reinterpret_cast op to have an
248+
/// identity layout map and updates all its indexing uses. Returns failure if
249+
/// any of its uses escape (while leaving the IR in a valid state).
249250
template <typename AllocLikeOp>
250251
LogicalResult normalizeMemRef(AllocLikeOp *op);
251252
extern template LogicalResult
252253
normalizeMemRef<memref::AllocaOp>(memref::AllocaOp *op);
253254
extern template LogicalResult
254255
normalizeMemRef<memref::AllocOp>(memref::AllocOp *op);
256+
LogicalResult normalizeMemRef(memref::ReinterpretCastOp *op);
255257

256258
/// Normalizes `memrefType` so that the affine layout map of the memref is
257259
/// transformed to an identity map with a new shape being computed for the

mlir/lib/Dialect/Affine/Utils/Utils.cpp

Lines changed: 143 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -1098,90 +1098,12 @@ void mlir::affine::affineScalarReplace(func::FuncOp f, DominanceInfo &domInfo,
10981098
op->erase();
10991099
}
11001100

1101-
// Private helper function to transform memref.load with reduced rank.
1102-
// This function will modify the indices of the memref.load to match the
1103-
// newMemRef.
1104-
LogicalResult transformMemRefLoadWithReducedRank(
1105-
Operation *op, Value oldMemRef, Value newMemRef, unsigned memRefOperandPos,
1106-
ArrayRef<Value> extraIndices, ArrayRef<Value> extraOperands,
1107-
ArrayRef<Value> symbolOperands, AffineMap indexRemap) {
1108-
unsigned oldMemRefRank = cast<MemRefType>(oldMemRef.getType()).getRank();
1109-
unsigned newMemRefRank = cast<MemRefType>(newMemRef.getType()).getRank();
1110-
unsigned oldMapNumInputs = oldMemRefRank;
1111-
SmallVector<Value, 4> oldMapOperands(
1112-
op->operand_begin() + memRefOperandPos + 1,
1113-
op->operand_begin() + memRefOperandPos + 1 + oldMapNumInputs);
1114-
SmallVector<Value, 4> oldMemRefOperands;
1115-
oldMemRefOperands.assign(oldMapOperands.begin(), oldMapOperands.end());
1116-
SmallVector<Value, 4> remapOperands;
1117-
remapOperands.reserve(extraOperands.size() + oldMemRefRank +
1118-
symbolOperands.size());
1119-
remapOperands.append(extraOperands.begin(), extraOperands.end());
1120-
remapOperands.append(oldMemRefOperands.begin(), oldMemRefOperands.end());
1121-
remapOperands.append(symbolOperands.begin(), symbolOperands.end());
1122-
1123-
SmallVector<Value, 4> remapOutputs;
1124-
remapOutputs.reserve(oldMemRefRank);
1125-
SmallVector<Value, 4> affineApplyOps;
1126-
1127-
OpBuilder builder(op);
1128-
1129-
if (indexRemap &&
1130-
indexRemap != builder.getMultiDimIdentityMap(indexRemap.getNumDims())) {
1131-
// Remapped indices.
1132-
for (auto resultExpr : indexRemap.getResults()) {
1133-
auto singleResMap = AffineMap::get(
1134-
indexRemap.getNumDims(), indexRemap.getNumSymbols(), resultExpr);
1135-
auto afOp = builder.create<AffineApplyOp>(op->getLoc(), singleResMap,
1136-
remapOperands);
1137-
remapOutputs.push_back(afOp);
1138-
affineApplyOps.push_back(afOp);
1139-
}
1140-
} else {
1141-
// No remapping specified.
1142-
remapOutputs.assign(remapOperands.begin(), remapOperands.end());
1143-
}
1144-
1145-
SmallVector<Value, 4> newMapOperands;
1146-
newMapOperands.reserve(newMemRefRank);
1147-
1148-
// Prepend 'extraIndices' in 'newMapOperands'.
1149-
for (Value extraIndex : extraIndices) {
1150-
assert((isValidDim(extraIndex) || isValidSymbol(extraIndex)) &&
1151-
"invalid memory op index");
1152-
newMapOperands.push_back(extraIndex);
1153-
}
1154-
1155-
// Append 'remapOutputs' to 'newMapOperands'.
1156-
newMapOperands.append(remapOutputs.begin(), remapOutputs.end());
1157-
1158-
// Create new fully composed AffineMap for new op to be created.
1159-
assert(newMapOperands.size() == newMemRefRank);
1160-
1161-
OperationState state(op->getLoc(), op->getName());
1162-
// Construct the new operation using this memref.
1163-
state.operands.reserve(newMapOperands.size() + extraIndices.size());
1164-
state.operands.push_back(newMemRef);
1165-
1166-
// Insert the new memref map operands.
1167-
state.operands.append(newMapOperands.begin(), newMapOperands.end());
1168-
1169-
state.types.reserve(op->getNumResults());
1170-
for (auto result : op->getResults())
1171-
state.types.push_back(result.getType());
1172-
1173-
// Copy over the attributes from the old operation to the new operation.
1174-
for (auto namedAttr : op->getAttrs()) {
1175-
state.attributes.push_back(namedAttr);
1176-
}
1177-
1178-
// Create the new operation.
1179-
auto *repOp = builder.create(state);
1180-
op->replaceAllUsesWith(repOp);
1181-
op->erase();
1182-
1183-
return success();
1101+
// Checks if `op` is a dereferencing op (i.e., it accesses the memref through
// explicit indices).
1102+
// TODO: This hardcoded check will be removed once the right interface is added.
1103+
static bool isDereferencingOp(Operation *op) {
1104+
return isa<AffineMapAccessInterface, memref::LoadOp, memref::StoreOp>(op);
11841105
}
1106+
11851107
// Perform the replacement in `op`.
11861108
LogicalResult mlir::affine::replaceAllMemRefUsesWith(
11871109
Value oldMemRef, Value newMemRef, Operation *op,
@@ -1216,53 +1138,57 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
12161138
if (usePositions.empty())
12171139
return success();
12181140

1219-
if (usePositions.size() > 1) {
1220-
// TODO: extend it for this case when needed (rare).
1221-
assert(false && "multiple dereferencing uses in a single op not supported");
1222-
return failure();
1223-
}
1224-
12251141
unsigned memRefOperandPos = usePositions.front();
12261142

12271143
OpBuilder builder(op);
12281144
// The following checks if op is dereferencing memref and performs the access
12291145
// index rewrites.
12301146
auto affMapAccInterface = dyn_cast<AffineMapAccessInterface>(op);
1231-
if (!affMapAccInterface) {
1147+
if (!isDereferencingOp(op)) {
12321148
if (!allowNonDereferencingOps) {
12331149
// Failure: memref used in a non-dereferencing context (potentially
12341150
// escapes); no replacement in these cases unless allowNonDereferencingOps
12351151
// is set.
12361152
return failure();
12371153
}
1154+
for (unsigned pos : usePositions)
1155+
op->setOperand(pos, newMemRef);
1156+
return success();
1157+
}
12381158

1239-
// Check if it is a memref.load
1240-
auto memrefLoad = dyn_cast<memref::LoadOp>(op);
1241-
bool isReductionLike =
1242-
indexRemap.getNumResults() < indexRemap.getNumInputs();
1243-
if (!memrefLoad || !isReductionLike) {
1244-
op->setOperand(memRefOperandPos, newMemRef);
1245-
return success();
1246-
}
1159+
if (usePositions.size() > 1) {
1160+
// TODO: extend it for this case when needed (rare).
1161+
assert(false && "multiple dereferencing uses in a single op not supported");
1162+
return failure();
1163+
}
12471164

1248-
return transformMemRefLoadWithReducedRank(
1249-
op, oldMemRef, newMemRef, memRefOperandPos, extraIndices, extraOperands,
1250-
symbolOperands, indexRemap);
1165+
// Perform index rewrites for the dereferencing op and then replace the op.
1166+
SmallVector<Value, 4> oldMapOperands;
1167+
AffineMap oldMap;
1168+
unsigned oldMemRefNumIndices = oldMemRefRank;
1169+
if (affMapAccInterface) {
1170+
// If `op` implements AffineMapAccessInterface, we can get the indices by
1171+
// querying the number of map operands from the operand list from a certain
1172+
// offset (`memRefOperandPos` in this case).
1173+
NamedAttribute oldMapAttrPair =
1174+
affMapAccInterface.getAffineMapAttrForMemRef(oldMemRef);
1175+
oldMap = cast<AffineMapAttr>(oldMapAttrPair.getValue()).getValue();
1176+
oldMemRefNumIndices = oldMap.getNumInputs();
1177+
oldMapOperands.assign(op->operand_begin() + memRefOperandPos + 1,
1178+
op->operand_begin() + memRefOperandPos + 1 +
1179+
oldMemRefNumIndices);
1180+
} else {
1181+
oldMapOperands.assign(op->operand_begin() + memRefOperandPos + 1,
1182+
op->operand_begin() + memRefOperandPos + 1 +
1183+
oldMemRefRank);
12511184
}
1252-
// Perform index rewrites for the dereferencing op and then replace the op
1253-
NamedAttribute oldMapAttrPair =
1254-
affMapAccInterface.getAffineMapAttrForMemRef(oldMemRef);
1255-
AffineMap oldMap = cast<AffineMapAttr>(oldMapAttrPair.getValue()).getValue();
1256-
unsigned oldMapNumInputs = oldMap.getNumInputs();
1257-
SmallVector<Value, 4> oldMapOperands(
1258-
op->operand_begin() + memRefOperandPos + 1,
1259-
op->operand_begin() + memRefOperandPos + 1 + oldMapNumInputs);
12601185

12611186
// Apply 'oldMemRefOperands = oldMap(oldMapOperands)'.
12621187
SmallVector<Value, 4> oldMemRefOperands;
12631188
SmallVector<Value, 4> affineApplyOps;
12641189
oldMemRefOperands.reserve(oldMemRefRank);
1265-
if (oldMap != builder.getMultiDimIdentityMap(oldMap.getNumDims())) {
1190+
if (affMapAccInterface &&
1191+
oldMap != builder.getMultiDimIdentityMap(oldMap.getNumDims())) {
12661192
for (auto resultExpr : oldMap.getResults()) {
12671193
auto singleResMap = AffineMap::get(oldMap.getNumDims(),
12681194
oldMap.getNumSymbols(), resultExpr);
@@ -1287,7 +1213,6 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
12871213

12881214
SmallVector<Value, 4> remapOutputs;
12891215
remapOutputs.reserve(oldMemRefRank);
1290-
12911216
if (indexRemap &&
12921217
indexRemap != builder.getMultiDimIdentityMap(indexRemap.getNumDims())) {
12931218
// Remapped indices.
@@ -1303,7 +1228,6 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
13031228
// No remapping specified.
13041229
remapOutputs.assign(remapOperands.begin(), remapOperands.end());
13051230
}
1306-
13071231
SmallVector<Value, 4> newMapOperands;
13081232
newMapOperands.reserve(newMemRefRank);
13091233

@@ -1338,13 +1262,26 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
13381262
state.operands.push_back(newMemRef);
13391263

13401264
// Insert the new memref map operands.
1341-
state.operands.append(newMapOperands.begin(), newMapOperands.end());
1265+
if (affMapAccInterface) {
1266+
state.operands.append(newMapOperands.begin(), newMapOperands.end());
1267+
} else {
1268+
// In the case of dereferencing ops not implementing
1269+
// AffineMapAccessInterface, we need to apply the values of `newMapOperands`
1270+
// to the `newMap` to get the correct indices.
1271+
for (unsigned i = 0; i < newMemRefRank; i++)
1272+
state.operands.push_back(builder.create<AffineApplyOp>(
1273+
op->getLoc(),
1274+
AffineMap::get(newMap.getNumDims(), newMap.getNumSymbols(),
1275+
newMap.getResult(i)),
1276+
newMapOperands));
1277+
}
13421278

13431279
// Insert the remaining operands unmodified.
1280+
unsigned oldMapNumInputs = oldMapOperands.size();
1281+
13441282
state.operands.append(op->operand_begin() + memRefOperandPos + 1 +
13451283
oldMapNumInputs,
13461284
op->operand_end());
1347-
13481285
// Result types don't change. Both memref's are of the same elemental type.
13491286
state.types.reserve(op->getNumResults());
13501287
for (auto result : op->getResults())
@@ -1353,7 +1290,9 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
13531290
// Add attribute for 'newMap', other Attributes do not change.
13541291
auto newMapAttr = AffineMapAttr::get(newMap);
13551292
for (auto namedAttr : op->getAttrs()) {
1356-
if (namedAttr.getName() == oldMapAttrPair.getName())
1293+
if (affMapAccInterface &&
1294+
namedAttr.getName() ==
1295+
affMapAccInterface.getAffineMapAttrForMemRef(oldMemRef).getName())
13571296
state.attributes.push_back({namedAttr.getName(), newMapAttr});
13581297
else
13591298
state.attributes.push_back(namedAttr);
@@ -1846,6 +1785,95 @@ LogicalResult mlir::affine::normalizeMemRef(AllocLikeOp *allocOp) {
18461785
return success();
18471786
}
18481787

1788+
LogicalResult
1789+
mlir::affine::normalizeMemRef(memref::ReinterpretCastOp *reinterpretCastOp) {
1790+
MemRefType memrefType = reinterpretCastOp->getType();
1791+
AffineMap oldLayoutMap = memrefType.getLayout().getAffineMap();
1792+
Value oldMemRef = reinterpretCastOp->getResult();
1793+
1794+
// If `oldLayoutMap` is identity, `memrefType` is already normalized.
1795+
if (oldLayoutMap.isIdentity())
1796+
return success();
1797+
1798+
// Fetch a new memref type after normalizing the old memref to have an
1799+
// identity map layout.
1800+
MemRefType newMemRefType = normalizeMemRefType(memrefType);
1801+
newMemRefType.dump();
1802+
if (newMemRefType == memrefType)
1803+
// `oldLayoutMap` couldn't be transformed to an identity map.
1804+
return failure();
1805+
1806+
uint64_t newRank = newMemRefType.getRank();
1807+
SmallVector<Value> mapOperands(oldLayoutMap.getNumDims() +
1808+
oldLayoutMap.getNumSymbols());
1809+
SmallVector<Value> oldStrides = reinterpretCastOp->getStrides();
1810+
Location loc = reinterpretCastOp->getLoc();
1811+
// As `newMemRefType` is normalized, it is unit strided.
1812+
SmallVector<int64_t> newStaticStrides(newRank, 1);
1813+
SmallVector<int64_t> newStaticOffsets(newRank, 0);
1814+
ArrayRef<int64_t> oldShape = memrefType.getShape();
1815+
mlir::ValueRange oldSizes = reinterpretCastOp->getSizes();
1816+
unsigned idx = 0;
1817+
SmallVector<int64_t> newStaticSizes;
1818+
OpBuilder b(*reinterpretCastOp);
1819+
// Collect the map operands which will be used to compute the new normalized
1820+
// memref shape.
1821+
for (unsigned i = 0, e = memrefType.getRank(); i < e; i++) {
1822+
if (oldShape[i] == ShapedType::kDynamic)
1823+
mapOperands[i] =
1824+
b.create<arith::SubIOp>(loc, oldSizes[0].getType(), oldSizes[idx++],
1825+
b.create<arith::ConstantIndexOp>(loc, 1));
1826+
else
1827+
mapOperands[i] = b.create<arith::ConstantIndexOp>(loc, oldShape[i] - 1);
1828+
}
1829+
for (unsigned i = 0, e = oldStrides.size(); i < e; i++)
1830+
mapOperands[memrefType.getRank() + i] = oldStrides[i];
1831+
SmallVector<Value> newSizes;
1832+
ArrayRef<int64_t> newShape = newMemRefType.getShape();
1833+
// Compute size along all the dimensions of the new normalized memref.
1834+
for (unsigned i = 0; i < newRank; i++) {
1835+
if (newMemRefType.isDynamicDim(i))
1836+
continue;
1837+
newSizes.push_back(b.create<AffineApplyOp>(
1838+
loc,
1839+
AffineMap::get(oldLayoutMap.getNumDims(), oldLayoutMap.getNumSymbols(),
1840+
oldLayoutMap.getResult(i)),
1841+
mapOperands));
1842+
}
1843+
for (unsigned i = 0, e = newSizes.size(); i < e; i++)
1844+
newSizes[i] =
1845+
b.create<arith::AddIOp>(loc, newSizes[i].getType(), newSizes[i],
1846+
b.create<arith::ConstantIndexOp>(loc, 1));
1847+
// Create the new reinterpret_cast op.
1848+
memref::ReinterpretCastOp newReinterpretCast =
1849+
b.create<memref::ReinterpretCastOp>(
1850+
loc, newMemRefType, reinterpretCastOp->getSource(),
1851+
/*offsets=*/mlir::ValueRange(), newSizes,
1852+
/*strides=*/mlir::ValueRange(),
1853+
/*static_offsets=*/newStaticOffsets,
1854+
/*static_sizes=*/newShape,
1855+
/*static_strides=*/newStaticStrides);
1856+
1857+
// Replace all uses of the old memref.
1858+
if (failed(replaceAllMemRefUsesWith(oldMemRef,
1859+
/*newMemRef=*/newReinterpretCast,
1860+
/*extraIndices=*/{},
1861+
/*indexRemap=*/oldLayoutMap,
1862+
/*extraOperands=*/{},
1863+
/*symbolOperands=*/oldStrides,
1864+
/*domOpFilter=*/nullptr,
1865+
/*postDomOpFilter=*/nullptr,
1866+
/*allowNonDereferencingOps=*/true))) {
1867+
// If it failed (due to escapes for example), bail out.
1868+
newReinterpretCast->erase();
1869+
return failure();
1870+
}
1871+
1872+
oldMemRef.replaceAllUsesWith(newReinterpretCast);
1873+
reinterpretCastOp->erase();
1874+
return success();
1875+
}
1876+
18491877
template LogicalResult
18501878
mlir::affine::normalizeMemRef<memref::AllocaOp>(memref::AllocaOp *op);
18511879
template LogicalResult

mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,15 @@ void NormalizeMemRefs::normalizeFuncOpMemRefs(func::FuncOp funcOp,
363363
for (memref::AllocaOp allocaOp : allocaOps)
364364
(void)normalizeMemRef(&allocaOp);
365365

366+
// Turn memrefs' non-identity layouts maps into ones with identity. Collect
367+
// reinterpret_cast ops first and then process since normalizeMemRef
368+
// replaces/erases ops during memref rewriting.
369+
SmallVector<memref::ReinterpretCastOp> reinterpretCastOps;
370+
funcOp.walk(
371+
[&](memref::ReinterpretCastOp op) { reinterpretCastOps.push_back(op); });
372+
for (memref::ReinterpretCastOp reinterpretCastOp : reinterpretCastOps)
373+
(void)normalizeMemRef(&reinterpretCastOp);
374+
366375
// We use this OpBuilder to create new memref layout later.
367376
OpBuilder b(funcOp);
368377

0 commit comments

Comments
 (0)