
Commit e36680a

Normalize reinterpret_cast op
Normalize reinterpret_cast op for statically shaped input and output memrefs.
1 parent 2df25a4 commit e36680a
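
For context, this is the kind of rewrite the change enables, as an illustrative sketch (the IR below is not taken from this commit's tests, and the SSA names are made up): a reinterpret_cast whose result type carries a non-identity strided layout is re-created with an identity-layout, statically shaped result type, and its uses are rewritten against the new value (see the sketches after the diffs below).

  // Before: the result type carries a non-identity (strided) layout map.
  %m = memref.reinterpret_cast %src to offset: [0], sizes: [3, 5], strides: [5, 1]
      : memref<15xf32> to memref<3x5xf32, strided<[5, 1]>>

  // After normalization: identity layout, linearized static shape.
  %m_norm = memref.reinterpret_cast %src to offset: [0], sizes: [15], strides: [1]
      : memref<15xf32> to memref<15xf32>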

4 files changed: +154 −171 lines

mlir/include/mlir/Dialect/Affine/Utils.h

Lines changed: 6 additions & 3 deletions
@@ -32,6 +32,7 @@ class FuncOp;
 namespace memref {
 class AllocOp;
 class AllocaOp;
+class ReinterpretCastOp;
 } // namespace memref

 namespace affine {
@@ -243,16 +244,18 @@ LogicalResult replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef,
                                        ArrayRef<Value> symbolOperands = {},
                                        bool allowNonDereferencingOps = false);

-/// Rewrites the memref defined by this alloc op to have an identity layout map
-/// and updates all its indexing uses. Returns failure if any of its uses
-/// escape (while leaving the IR in a valid state).
+/// Rewrites the memref defined by alloc or reinterpret_cast op to have an
+/// identity layout map and updates all its indexing uses. Returns failure if
+/// any of its uses escape (while leaving the IR in a valid state).
 template <typename AllocLikeOp>
 LogicalResult normalizeMemRef(AllocLikeOp *op);
 extern template LogicalResult
 normalizeMemRef<memref::AllocaOp>(memref::AllocaOp *op);
 extern template LogicalResult
 normalizeMemRef<memref::AllocOp>(memref::AllocOp *op);

+LogicalResult normalizeMemRef(memref::ReinterpretCastOp *op);
+
 /// Normalizes `memrefType` so that the affine layout map of the memref is
 /// transformed to an identity map with a new shape being computed for the
 /// normalized memref type and returns it. The old memref type is simplify

mlir/lib/Dialect/Affine/Utils/Utils.cpp

Lines changed: 139 additions & 28 deletions
@@ -1101,7 +1101,7 @@ void mlir::affine::affineScalarReplace(func::FuncOp f, DominanceInfo &domInfo,
 // Private helper function to transform memref.load with reduced rank.
 // This function will modify the indices of the memref.load to match the
 // newMemRef.
-LogicalResult transformMemRefLoadWithReducedRank(
+LogicalResult transformMemRefLoadOrStoreWithReducedRank(
     Operation *op, Value oldMemRef, Value newMemRef, unsigned memRefOperandPos,
     ArrayRef<Value> extraIndices, ArrayRef<Value> extraOperands,
     ArrayRef<Value> symbolOperands, AffineMap indexRemap) {
@@ -1182,6 +1182,14 @@ LogicalResult transformMemRefLoadWithReducedRank(

   return success();
 }
+
+// Checks whether `op` is a dereferencing op.
+// TODO: This hardcoded check will be removed once the right interface is added.
+static bool isDereferencingOp(Operation *op) {
+  return isa<AffineMapAccessInterface, memref::LoadOp, memref::StoreOp>(
+      op);
+}
+
 // Perform the replacement in `op`.
 LogicalResult mlir::affine::replaceAllMemRefUsesWith(
     Value oldMemRef, Value newMemRef, Operation *op,
@@ -1228,41 +1236,44 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
   // The following checks if op is dereferencing memref and performs the access
   // index rewrites.
   auto affMapAccInterface = dyn_cast<AffineMapAccessInterface>(op);
-  if (!affMapAccInterface) {
+  if (!isDereferencingOp(op)) {
     if (!allowNonDereferencingOps) {
       // Failure: memref used in a non-dereferencing context (potentially
       // escapes); no replacement in these cases unless allowNonDereferencingOps
       // is set.
       return failure();
     }
+    op->setOperand(memRefOperandPos, newMemRef);
+    return success();
+  }

-    // Check if it is a memref.load
-    auto memrefLoad = dyn_cast<memref::LoadOp>(op);
-    bool isReductionLike =
-        indexRemap.getNumResults() < indexRemap.getNumInputs();
-    if (!memrefLoad || !isReductionLike) {
-      op->setOperand(memRefOperandPos, newMemRef);
-      return success();
-    }
-
-    return transformMemRefLoadWithReducedRank(
-        op, oldMemRef, newMemRef, memRefOperandPos, extraIndices, extraOperands,
-        symbolOperands, indexRemap);
+  // Perform index rewrites for the dereferencing op and then replace the op.
+  SmallVector<Value, 4> oldMapOperands;
+  AffineMap oldMap;
+  unsigned oldMemRefNumIndices = oldMemRefRank;
+  if (affMapAccInterface) {
+    // If `op` implements AffineMapAccessInterface, we can get the indices by
+    // querying the number of map operands from the operand list from a certain
+    // offset (`memRefOperandPos` in this case).
+    NamedAttribute oldMapAttrPair =
+        affMapAccInterface.getAffineMapAttrForMemRef(oldMemRef);
+    oldMap = cast<AffineMapAttr>(oldMapAttrPair.getValue()).getValue();
+    oldMemRefNumIndices = oldMap.getNumInputs();
+    oldMapOperands.assign(op->operand_begin() + memRefOperandPos + 1,
+                          op->operand_begin() + memRefOperandPos + 1 +
+                              oldMemRefNumIndices);
+  } else {
+    oldMapOperands.assign(op->operand_begin() + memRefOperandPos + 1,
+                          op->operand_begin() + memRefOperandPos + 1 +
+                              oldMemRefRank);
   }
-  // Perform index rewrites for the dereferencing op and then replace the op
-  NamedAttribute oldMapAttrPair =
-      affMapAccInterface.getAffineMapAttrForMemRef(oldMemRef);
-  AffineMap oldMap = cast<AffineMapAttr>(oldMapAttrPair.getValue()).getValue();
-  unsigned oldMapNumInputs = oldMap.getNumInputs();
-  SmallVector<Value, 4> oldMapOperands(
-      op->operand_begin() + memRefOperandPos + 1,
-      op->operand_begin() + memRefOperandPos + 1 + oldMapNumInputs);

   // Apply 'oldMemRefOperands = oldMap(oldMapOperands)'.
   SmallVector<Value, 4> oldMemRefOperands;
   SmallVector<Value, 4> affineApplyOps;
   oldMemRefOperands.reserve(oldMemRefRank);
-  if (oldMap != builder.getMultiDimIdentityMap(oldMap.getNumDims())) {
+  if (affMapAccInterface &&
+      oldMap != builder.getMultiDimIdentityMap(oldMap.getNumDims())) {
     for (auto resultExpr : oldMap.getResults()) {
       auto singleResMap = AffineMap::get(oldMap.getNumDims(),
                                          oldMap.getNumSymbols(), resultExpr);
@@ -1287,7 +1298,6 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(

   SmallVector<Value, 4> remapOutputs;
   remapOutputs.reserve(oldMemRefRank);
-
   if (indexRemap &&
       indexRemap != builder.getMultiDimIdentityMap(indexRemap.getNumDims())) {
     // Remapped indices.
@@ -1303,7 +1313,6 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
     // No remapping specified.
     remapOutputs.assign(remapOperands.begin(), remapOperands.end());
   }
-
   SmallVector<Value, 4> newMapOperands;
   newMapOperands.reserve(newMemRefRank);

@@ -1338,13 +1347,26 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
   state.operands.push_back(newMemRef);

   // Insert the new memref map operands.
-  state.operands.append(newMapOperands.begin(), newMapOperands.end());
+  if (affMapAccInterface) {
+    state.operands.append(newMapOperands.begin(), newMapOperands.end());
+  } else {
+    // In the case of dereferencing ops not implementing
+    // AffineMapAccessInterface, we need to apply the values of `newMapOperands`
+    // to the `newMap` to get the correct indices.
+    for (unsigned i = 0; i < newMemRefRank; i++)
+      state.operands.push_back(builder.create<AffineApplyOp>(
+          op->getLoc(),
+          AffineMap::get(newMap.getNumDims(), newMap.getNumSymbols(),
+                         newMap.getResult(i)),
+          newMapOperands));
+  }

   // Insert the remaining operands unmodified.
+  unsigned oldMapNumInputs = oldMapOperands.size();
+
   state.operands.append(op->operand_begin() + memRefOperandPos + 1 +
                             oldMapNumInputs,
                         op->operand_end());
-
   // Result types don't change. Both memref's are of the same elemental type.
   state.types.reserve(op->getNumResults());
   for (auto result : op->getResults())
@@ -1353,7 +1375,9 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
   // Add attribute for 'newMap', other Attributes do not change.
   auto newMapAttr = AffineMapAttr::get(newMap);
   for (auto namedAttr : op->getAttrs()) {
-    if (namedAttr.getName() == oldMapAttrPair.getName())
+    if (affMapAccInterface &&
+        namedAttr.getName() ==
+            affMapAccInterface.getAffineMapAttrForMemRef(oldMemRef).getName())
       state.attributes.push_back({namedAttr.getName(), newMapAttr});
     else
       state.attributes.push_back(namedAttr);
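
The non-AffineMapAccessInterface path added in the hunks above matters for plain memref.load/memref.store users of the replaced memref: those ops carry no affine map attribute, so the new access indices are materialized with explicit affine.apply ops. Roughly, continuing the illustrative strided sketch near the commit summary (%m, %m_norm, %i, %j are made-up names, and the exact IR may differ after canonicalization):

  // Before: a load through the non-identity layout.
  %v = memref.load %m[%i, %j] : memref<3x5xf32, strided<[5, 1]>>

  // After replaceAllMemRefUsesWith with indexRemap = (d0, d1) -> (d0 * 5 + d1):
  %idx = affine.apply affine_map<(d0, d1) -> (d0 * 5 + d1)>(%i, %j)
  %v = memref.load %m_norm[%idx] : memref<15xf32>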
@@ -1846,6 +1870,93 @@ LogicalResult mlir::affine::normalizeMemRef(AllocLikeOp *allocOp) {
   return success();
 }

+LogicalResult
+mlir::affine::normalizeMemRef(memref::ReinterpretCastOp *reinterpretCastOp) {
+  MemRefType memrefType = reinterpretCastOp->getType();
+  AffineMap oldLayoutMap = memrefType.getLayout().getAffineMap();
+  Value oldMemRef = reinterpretCastOp->getResult();
+
+  // If `oldLayoutMap` is identity, `memrefType` is already normalized.
+  if (oldLayoutMap.isIdentity())
+    return success();
+
+  // Fetch a new memref type after normalizing the old memref to have an
+  // identity map layout.
+  MemRefType newMemRefType =
+      normalizeMemRefType(memrefType);
+  if (newMemRefType == memrefType)
+    // `oldLayoutMap` couldn't be transformed to an identity map.
+    return failure();
+
+  uint64_t newRank = newMemRefType.getRank();
+  SmallVector<Value> mapOperands(oldLayoutMap.getNumDims() +
+                                 oldLayoutMap.getNumSymbols());
+  SmallVector<Value> oldStrides = reinterpretCastOp->getStrides();
+  Location loc = reinterpretCastOp->getLoc();
+  // As `newMemRefType` is normalized, it is unit strided.
+  SmallVector<int64_t> newStaticStrides(newRank, 1);
+  ArrayRef<int64_t> oldShape = memrefType.getShape();
+  mlir::ValueRange oldSizes = reinterpretCastOp->getSizes();
+  unsigned idx = 0;
+  SmallVector<int64_t> newStaticSizes;
+  OpBuilder b(*reinterpretCastOp);
+  // Collect the map operands which will be used to compute the new normalized
+  // memref shape.
+  for (unsigned i = 0, e = memrefType.getRank(); i < e; i++) {
+    if (oldShape[i] == ShapedType::kDynamic)
+      mapOperands[i] =
+          b.create<arith::SubIOp>(loc, oldSizes[0].getType(), oldSizes[idx++],
+                                  b.create<arith::ConstantIndexOp>(loc, 1));
+    else
+      mapOperands[i] = b.create<arith::ConstantIndexOp>(loc, oldShape[i] - 1);
+  }
+  for (unsigned i = 0, e = oldStrides.size(); i < e; i++)
+    mapOperands[memrefType.getRank() + i] = oldStrides[i];
+  SmallVector<Value> newSizes;
+  ArrayRef<int64_t> newShape = newMemRefType.getShape();
+  // Compute size along all the dimensions of the new normalized memref.
+  for (unsigned i = 0; i < newRank; i++) {
+    if (newShape[i] != ShapedType::kDynamic)
+      continue;
+    newSizes.push_back(b.create<AffineApplyOp>(
+        loc,
+        AffineMap::get(oldLayoutMap.getNumDims(), oldLayoutMap.getNumSymbols(),
+                       oldLayoutMap.getResult(i)),
+        mapOperands));
+  }
+  for (unsigned i = 0, e = newSizes.size(); i < e; i++)
+    newSizes[i] =
+        b.create<arith::AddIOp>(loc, newSizes[i].getType(), newSizes[i],
+                                b.create<arith::ConstantIndexOp>(loc, 1));
+  // Create the new reinterpret_cast op.
+  memref::ReinterpretCastOp newReinterpretCast =
+      b.create<memref::ReinterpretCastOp>(
+          loc, newMemRefType, reinterpretCastOp->getSource(),
+          reinterpretCastOp->getOffsets(), newSizes, mlir::ValueRange(),
+          /*static_offsets=*/reinterpretCastOp->getStaticOffsets(),
+          /*static_sizes=*/newShape,
+          /*static_strides=*/newStaticStrides);
+
+  // Replace all uses of the old memref.
+  if (failed(replaceAllMemRefUsesWith(oldMemRef,
+                                      /*newMemRef=*/newReinterpretCast,
+                                      /*extraIndices=*/{},
+                                      /*indexRemap=*/oldLayoutMap,
+                                      /*extraOperands=*/{},
+                                      /*symbolOperands=*/oldStrides,
+                                      /*domOpFilter=*/nullptr,
+                                      /*postDomOpFilter=*/nullptr,
+                                      /*allowNonDereferencingOps=*/true))) {
+    // If it failed (due to escapes, for example), bail out.
+    newReinterpretCast->erase();
+    return failure();
+  }
+
+  oldMemRef.replaceAllUsesWith(newReinterpretCast);
+  reinterpretCastOp->erase();
+  return success();
+}
+
 template LogicalResult
 mlir::affine::normalizeMemRef<memref::AllocaOp>(memref::AllocaOp *op);
 template LogicalResult
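
The size computation in the new normalizeMemRef overload evaluates the old layout map at the last valid index of each dimension and adds one: conceptually, newSize_i = oldLayoutMap_i(size_0 - 1, ..., strides...) + 1. For the illustrative layout (d0, d1) -> (d0 * 5 + d1) with static sizes 3 and 5, that is (2 * 5 + 4) + 1 = 15, which normalizeMemRefType already folds statically. When a size is dynamic, the same arithmetic is emitted as IR by the loop above; a hedged sketch of its shape (the names and the single-dimension layout are assumptions, not taken from this patch):

  // Hypothetical dynamic case for layout (d0)[s0] -> (d0 * s0), where %size and
  // %stride are the reinterpret_cast's dynamic size and stride operands.
  %c1 = arith.constant 1 : index
  %last = arith.subi %size, %c1 : index
  %lin = affine.apply affine_map<(d0)[s0] -> (d0 * s0)>(%last)[%stride]
  %new_size = arith.addi %lin, %c1 : index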

mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp

Lines changed: 9 additions & 0 deletions
@@ -363,6 +363,15 @@ void NormalizeMemRefs::normalizeFuncOpMemRefs(func::FuncOp funcOp,
   for (memref::AllocaOp allocaOp : allocaOps)
     (void)normalizeMemRef(&allocaOp);

+  // Turn memrefs' non-identity layout maps into identity ones. Collect
+  // reinterpret_cast ops first and then process them, since normalizeMemRef
+  // replaces/erases ops during memref rewriting.
+  SmallVector<memref::ReinterpretCastOp> reinterpretCastOps;
+  funcOp.walk(
+      [&](memref::ReinterpretCastOp op) { reinterpretCastOps.push_back(op); });
+  for (memref::ReinterpretCastOp reinterpretCastOp : reinterpretCastOps)
+    (void)normalizeMemRef(&reinterpretCastOp);
+
   // We use this OpBuilder to create new memref layout later.
   OpBuilder b(funcOp);

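
Affine accesses take the other branch of the rewrite: ops implementing AffineMapAccessInterface (e.g. affine.load/affine.store) get the old layout folded into their access map rather than going through an explicit affine.apply. A rough sketch, again reusing the made-up names from the earlier examples:

  // Before:
  %v = affine.load %m[%i, %j] : memref<3x5xf32, strided<[5, 1]>>

  // After normalization, the layout becomes part of the access map:
  %v = affine.load %m_norm[%i * 5 + %j] : memref<15xf32>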
