@@ -1101,7 +1101,7 @@ void mlir::affine::affineScalarReplace(func::FuncOp f, DominanceInfo &domInfo,
 // Private helper function to transform memref.load with reduced rank.
 // This function will modify the indices of the memref.load to match the
 // newMemRef.
-LogicalResult transformMemRefLoadWithReducedRank(
+LogicalResult transformMemRefLoadOrStoreWithReducedRank(
     Operation *op, Value oldMemRef, Value newMemRef, unsigned memRefOperandPos,
     ArrayRef<Value> extraIndices, ArrayRef<Value> extraOperands,
     ArrayRef<Value> symbolOperands, AffineMap indexRemap) {
@@ -1182,6 +1182,14 @@ LogicalResult transformMemRefLoadWithReducedRank(
   return success();
 }
+
+// Checks if `op` is a dereferencing op.
+// TODO: This hardcoded check will be removed once the right interface is
+// added.
+static bool isDereferencingOp(Operation *op) {
+  return isa<AffineMapAccessInterface, memref::LoadOp, memref::StoreOp>(op);
+}
+
 // Perform the replacement in `op`.
 LogicalResult mlir::affine::replaceAllMemRefUsesWith(
     Value oldMemRef, Value newMemRef, Operation *op,
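
For context, a hedged sketch (not part of the patch) of the dispatch this change introduces: dereferencing users go through full index rewriting, while any other user is only retargeted to the new memref when `allowNonDereferencingOps` is set. Names below are illustrative.

// Illustrative sketch only; assumes the includes already used by this file.
static mlir::LogicalResult
sketchReplaceOneUse(mlir::Operation *op, unsigned memRefOperandPos,
                    mlir::Value newMemRef, bool allowNonDereferencingOps) {
  using namespace mlir;
  // Dereferencing ops (affine.load/store via AffineMapAccessInterface, plus
  // memref.load/store) fall through to index rewriting.
  if (isa<affine::AffineMapAccessInterface, memref::LoadOp, memref::StoreOp>(
          op))
    return success(); // index rewriting elided in this sketch
  // Everything else gets a bare operand swap, and only when allowed
  // (e.g. a memref escaping into a call cannot be rewritten otherwise).
  if (!allowNonDereferencingOps)
    return failure();
  op->setOperand(memRefOperandPos, newMemRef);
  return success();
}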
@@ -1228,41 +1236,44 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
   // The following checks if op is dereferencing memref and performs the access
   // index rewrites.
   auto affMapAccInterface = dyn_cast<AffineMapAccessInterface>(op);
-  if (!affMapAccInterface) {
+  if (!isDereferencingOp(op)) {
     if (!allowNonDereferencingOps) {
       // Failure: memref used in a non-dereferencing context (potentially
       // escapes); no replacement in these cases unless allowNonDereferencingOps
       // is set.
       return failure();
     }
+    op->setOperand(memRefOperandPos, newMemRef);
+    return success();
+  }
 
-  // Check if it is a memref.load
-  auto memrefLoad = dyn_cast<memref::LoadOp>(op);
-  bool isReductionLike =
-      indexRemap.getNumResults() < indexRemap.getNumInputs();
-  if (!memrefLoad || !isReductionLike) {
-    op->setOperand(memRefOperandPos, newMemRef);
-    return success();
-  }
-
-  return transformMemRefLoadWithReducedRank(
-      op, oldMemRef, newMemRef, memRefOperandPos, extraIndices, extraOperands,
-      symbolOperands, indexRemap);
+  // Perform index rewrites for the dereferencing op and then replace the op.
+  SmallVector<Value, 4> oldMapOperands;
+  AffineMap oldMap;
+  unsigned oldMemRefNumIndices = oldMemRefRank;
+  if (affMapAccInterface) {
+    // If `op` implements AffineMapAccessInterface, we can get the indices by
+    // querying the number of map operands from the operand list, starting at a
+    // certain offset (`memRefOperandPos` in this case).
+    NamedAttribute oldMapAttrPair =
+        affMapAccInterface.getAffineMapAttrForMemRef(oldMemRef);
+    oldMap = cast<AffineMapAttr>(oldMapAttrPair.getValue()).getValue();
+    oldMemRefNumIndices = oldMap.getNumInputs();
+    oldMapOperands.assign(op->operand_begin() + memRefOperandPos + 1,
+                          op->operand_begin() + memRefOperandPos + 1 +
+                              oldMemRefNumIndices);
+  } else {
+    oldMapOperands.assign(op->operand_begin() + memRefOperandPos + 1,
+                          op->operand_begin() + memRefOperandPos + 1 +
+                              oldMemRefRank);
   }
-  // Perform index rewrites for the dereferencing op and then replace the op
-  NamedAttribute oldMapAttrPair =
-      affMapAccInterface.getAffineMapAttrForMemRef(oldMemRef);
-  AffineMap oldMap = cast<AffineMapAttr>(oldMapAttrPair.getValue()).getValue();
-  unsigned oldMapNumInputs = oldMap.getNumInputs();
-  SmallVector<Value, 4> oldMapOperands(
-      op->operand_begin() + memRefOperandPos + 1,
-      op->operand_begin() + memRefOperandPos + 1 + oldMapNumInputs);
 
   // Apply 'oldMemRefOperands = oldMap(oldMapOperands)'.
   SmallVector<Value, 4> oldMemRefOperands;
   SmallVector<Value, 4> affineApplyOps;
   oldMemRefOperands.reserve(oldMemRefRank);
-  if (oldMap != builder.getMultiDimIdentityMap(oldMap.getNumDims())) {
+  if (affMapAccInterface &&
+      oldMap != builder.getMultiDimIdentityMap(oldMap.getNumDims())) {
     for (auto resultExpr : oldMap.getResults()) {
       auto singleResMap = AffineMap::get(oldMap.getNumDims(),
                                          oldMap.getNumSymbols(), resultExpr);
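
A hedged sketch of the two collection paths above (illustrative only; `sketchNumIndices` is a made-up helper, not an MLIR API): for an AffineMapAccessInterface op the index count comes from the map attribute, which can differ from the memref rank, while for memref.load/store it is exactly rank-many trailing operands under an implicit identity map.

// Sketch assuming this file's includes.
static unsigned sketchNumIndices(mlir::Operation *op, mlir::Value memref) {
  using namespace mlir;
  if (auto iface = dyn_cast<affine::AffineMapAccessInterface>(op)) {
    // e.g. affine.load %m[%i + %j] : memref<16xf32> has one result
    // expression fed by two map inputs, so use the map's input count.
    AffineMap map =
        cast<AffineMapAttr>(iface.getAffineMapAttrForMemRef(memref).getValue())
            .getValue();
    return map.getNumInputs();
  }
  // e.g. memref.load %m[%i, %j] : memref<4x4xf32> carries plain indices,
  // one per dimension of the memref.
  return cast<MemRefType>(memref.getType()).getRank();
}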
@@ -1287,7 +1298,6 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
 
   SmallVector<Value, 4> remapOutputs;
   remapOutputs.reserve(oldMemRefRank);
-
   if (indexRemap &&
       indexRemap != builder.getMultiDimIdentityMap(indexRemap.getNumDims())) {
     // Remapped indices.
@@ -1303,7 +1313,6 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
     // No remapping specified.
     remapOutputs.assign(remapOperands.begin(), remapOperands.end());
   }
-
   SmallVector<Value, 4> newMapOperands;
   newMapOperands.reserve(newMemRefRank);
@@ -1338,13 +1347,26 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
   state.operands.push_back(newMemRef);
 
   // Insert the new memref map operands.
-  state.operands.append(newMapOperands.begin(), newMapOperands.end());
+  if (affMapAccInterface) {
+    state.operands.append(newMapOperands.begin(), newMapOperands.end());
+  } else {
+    // In the case of dereferencing ops not implementing
+    // AffineMapAccessInterface, we need to apply the values of `newMapOperands`
+    // to the `newMap` to get the correct indices.
+    for (unsigned i = 0; i < newMemRefRank; i++)
+      state.operands.push_back(builder.create<AffineApplyOp>(
+          op->getLoc(),
+          AffineMap::get(newMap.getNumDims(), newMap.getNumSymbols(),
                         newMap.getResult(i)),
+          newMapOperands));
+  }
 
   // Insert the remaining operands unmodified.
+  unsigned oldMapNumInputs = oldMapOperands.size();
+
   state.operands.append(op->operand_begin() + memRefOperandPos + 1 +
                             oldMapNumInputs,
                         op->operand_end());
-
   // Result types don't change. Both memref's are of the same elemental type.
   state.types.reserve(op->getNumResults());
   for (auto result : op->getResults())
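
A hedged note on the else branch in the hunk above: memref.load/store take raw index values rather than a map attribute, so each result of `newMap` must be materialized as its own affine.apply. A minimal sketch of that per-dimension materialization, assuming this file's includes (`sketchMaterializeIndex` is an illustrative name):

// Build the i-th index as a single-result affine.apply; for a map like
// (d0, d1) -> (d0 * 4 + d1) this creates
//   %idx = affine.apply affine_map<(d0, d1) -> (d0 * 4 + d1)>(%i, %j)
static mlir::Value sketchMaterializeIndex(mlir::OpBuilder &b,
                                          mlir::Location loc,
                                          mlir::AffineMap map, unsigned i,
                                          mlir::ValueRange mapOperands) {
  using namespace mlir;
  AffineMap single =
      AffineMap::get(map.getNumDims(), map.getNumSymbols(), map.getResult(i));
  return b.create<affine::AffineApplyOp>(loc, single, mapOperands);
}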
@@ -1353,7 +1375,9 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
   // Add attribute for 'newMap', other Attributes do not change.
   auto newMapAttr = AffineMapAttr::get(newMap);
   for (auto namedAttr : op->getAttrs()) {
-    if (namedAttr.getName() == oldMapAttrPair.getName())
+    if (affMapAccInterface &&
+        namedAttr.getName() ==
+            affMapAccInterface.getAffineMapAttrForMemRef(oldMemRef).getName())
       state.attributes.push_back({namedAttr.getName(), newMapAttr});
     else
       state.attributes.push_back(namedAttr);
@@ -1846,6 +1870,93 @@ LogicalResult mlir::affine::normalizeMemRef(AllocLikeOp *allocOp) {
   return success();
 }
 
+LogicalResult
+mlir::affine::normalizeMemRef(memref::ReinterpretCastOp *reinterpretCastOp) {
+  MemRefType memrefType = reinterpretCastOp->getType();
+  AffineMap oldLayoutMap = memrefType.getLayout().getAffineMap();
+  Value oldMemRef = reinterpretCastOp->getResult();
+
+  // If `oldLayoutMap` is identity, `memrefType` is already normalized.
+  if (oldLayoutMap.isIdentity())
+    return success();
+
+  // Fetch a new memref type after normalizing the old memref to have an
+  // identity map layout.
+  MemRefType newMemRefType = normalizeMemRefType(memrefType);
+  if (newMemRefType == memrefType)
+    // `oldLayoutMap` couldn't be transformed to an identity map.
+    return failure();
+
+  uint64_t newRank = newMemRefType.getRank();
+  SmallVector<Value> mapOperands(oldLayoutMap.getNumDims() +
+                                 oldLayoutMap.getNumSymbols());
+  SmallVector<Value> oldStrides = reinterpretCastOp->getStrides();
+  Location loc = reinterpretCastOp->getLoc();
+  // As `newMemRefType` is normalized, it is unit strided.
+  SmallVector<int64_t> newStaticStrides(newRank, 1);
+  ArrayRef<int64_t> oldShape = memrefType.getShape();
+  mlir::ValueRange oldSizes = reinterpretCastOp->getSizes();
+  unsigned idx = 0;
+  SmallVector<int64_t> newStaticSizes;
+  OpBuilder b(*reinterpretCastOp);
+  // Collect the map operands which will be used to compute the new normalized
+  // memref shape.
+  for (unsigned i = 0, e = memrefType.getRank(); i < e; i++) {
+    if (oldShape[i] == ShapedType::kDynamic)
+      mapOperands[i] =
+          b.create<arith::SubIOp>(loc, oldSizes[0].getType(), oldSizes[idx++],
+                                  b.create<arith::ConstantIndexOp>(loc, 1));
+    else
+      mapOperands[i] = b.create<arith::ConstantIndexOp>(loc, oldShape[i] - 1);
+  }
+  for (unsigned i = 0, e = oldStrides.size(); i < e; i++)
+    mapOperands[memrefType.getRank() + i] = oldStrides[i];
+  SmallVector<Value> newSizes;
+  ArrayRef<int64_t> newShape = newMemRefType.getShape();
+  // Compute size along all the dimensions of the new normalized memref.
+  for (unsigned i = 0; i < newRank; i++) {
+    if (newShape[i] != ShapedType::kDynamic)
+      continue;
+    newSizes.push_back(b.create<AffineApplyOp>(
+        loc,
+        AffineMap::get(oldLayoutMap.getNumDims(), oldLayoutMap.getNumSymbols(),
+                       oldLayoutMap.getResult(i)),
+        mapOperands));
+  }
+  for (unsigned i = 0, e = newSizes.size(); i < e; i++)
+    newSizes[i] =
+        b.create<arith::AddIOp>(loc, newSizes[i].getType(), newSizes[i],
+                                b.create<arith::ConstantIndexOp>(loc, 1));
+  // Create the new reinterpret_cast op.
+  memref::ReinterpretCastOp newReinterpretCast =
+      b.create<memref::ReinterpretCastOp>(
+          loc, newMemRefType, reinterpretCastOp->getSource(),
+          reinterpretCastOp->getOffsets(), newSizes, mlir::ValueRange(),
+          /*static_offsets=*/reinterpretCastOp->getStaticOffsets(),
+          /*static_sizes=*/newShape,
+          /*static_strides=*/newStaticStrides);
+
+  // Replace all uses of the old memref.
+  if (failed(replaceAllMemRefUsesWith(oldMemRef,
+                                      /*newMemRef=*/newReinterpretCast,
+                                      /*extraIndices=*/{},
+                                      /*indexRemap=*/oldLayoutMap,
+                                      /*extraOperands=*/{},
+                                      /*symbolOperands=*/oldStrides,
+                                      /*domOpFilter=*/nullptr,
+                                      /*postDomOpFilter=*/nullptr,
+                                      /*allowNonDereferencingOps=*/true))) {
+    // If it failed (due to escapes for example), bail out.
+    newReinterpretCast->erase();
+    return failure();
+  }
+
+  oldMemRef.replaceAllUsesWith(newReinterpretCast);
+  reinterpretCastOp->erase();
+  return success();
+}
+
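
A hedged usage sketch for the overload above: how a pass drives it is an assumption here (mirroring how the alloc-like overloads are used by the normalize-memrefs pass), and `funcOp` plus the sketch function name are hypothetical.

// Rewrite every reinterpret_cast with a non-identity layout inside a
// function; on failure (e.g. an escaping use) the op is left untouched.
void sketchNormalizeReinterpretCasts(mlir::func::FuncOp funcOp) {
  funcOp.walk([](mlir::memref::ReinterpretCastOp castOp) {
    // The new overload takes a pointer, like the alloc-like variants.
    (void)mlir::affine::normalizeMemRef(&castOp);
  });
}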
 template LogicalResult
 mlir::affine::normalizeMemRef<memref::AllocaOp>(memref::AllocaOp *op);
 template LogicalResult