@@ -1098,90 +1098,12 @@ void mlir::affine::affineScalarReplace(func::FuncOp f, DominanceInfo &domInfo,
   op->erase();
 }
 
-// Private helper function to transform memref.load with reduced rank.
-// This function will modify the indices of the memref.load to match the
-// newMemRef.
-LogicalResult transformMemRefLoadWithReducedRank(
-    Operation *op, Value oldMemRef, Value newMemRef, unsigned memRefOperandPos,
-    ArrayRef<Value> extraIndices, ArrayRef<Value> extraOperands,
-    ArrayRef<Value> symbolOperands, AffineMap indexRemap) {
-  unsigned oldMemRefRank = cast<MemRefType>(oldMemRef.getType()).getRank();
-  unsigned newMemRefRank = cast<MemRefType>(newMemRef.getType()).getRank();
-  unsigned oldMapNumInputs = oldMemRefRank;
-  SmallVector<Value, 4> oldMapOperands(
-      op->operand_begin() + memRefOperandPos + 1,
-      op->operand_begin() + memRefOperandPos + 1 + oldMapNumInputs);
-  SmallVector<Value, 4> oldMemRefOperands;
-  oldMemRefOperands.assign(oldMapOperands.begin(), oldMapOperands.end());
-  SmallVector<Value, 4> remapOperands;
-  remapOperands.reserve(extraOperands.size() + oldMemRefRank +
-                        symbolOperands.size());
-  remapOperands.append(extraOperands.begin(), extraOperands.end());
-  remapOperands.append(oldMemRefOperands.begin(), oldMemRefOperands.end());
-  remapOperands.append(symbolOperands.begin(), symbolOperands.end());
-
-  SmallVector<Value, 4> remapOutputs;
-  remapOutputs.reserve(oldMemRefRank);
-  SmallVector<Value, 4> affineApplyOps;
-
-  OpBuilder builder(op);
-
-  if (indexRemap &&
-      indexRemap != builder.getMultiDimIdentityMap(indexRemap.getNumDims())) {
-    // Remapped indices.
-    for (auto resultExpr : indexRemap.getResults()) {
-      auto singleResMap = AffineMap::get(
-          indexRemap.getNumDims(), indexRemap.getNumSymbols(), resultExpr);
-      auto afOp = builder.create<AffineApplyOp>(op->getLoc(), singleResMap,
-                                                remapOperands);
-      remapOutputs.push_back(afOp);
-      affineApplyOps.push_back(afOp);
-    }
-  } else {
-    // No remapping specified.
-    remapOutputs.assign(remapOperands.begin(), remapOperands.end());
-  }
-
-  SmallVector<Value, 4> newMapOperands;
-  newMapOperands.reserve(newMemRefRank);
-
-  // Prepend 'extraIndices' in 'newMapOperands'.
-  for (Value extraIndex : extraIndices) {
-    assert((isValidDim(extraIndex) || isValidSymbol(extraIndex)) &&
-           "invalid memory op index");
-    newMapOperands.push_back(extraIndex);
-  }
-
-  // Append 'remapOutputs' to 'newMapOperands'.
-  newMapOperands.append(remapOutputs.begin(), remapOutputs.end());
-
-  // Create new fully composed AffineMap for new op to be created.
-  assert(newMapOperands.size() == newMemRefRank);
-
-  OperationState state(op->getLoc(), op->getName());
-  // Construct the new operation using this memref.
-  state.operands.reserve(newMapOperands.size() + extraIndices.size());
-  state.operands.push_back(newMemRef);
-
-  // Insert the new memref map operands.
-  state.operands.append(newMapOperands.begin(), newMapOperands.end());
-
-  state.types.reserve(op->getNumResults());
-  for (auto result : op->getResults())
-    state.types.push_back(result.getType());
-
-  // Copy over the attributes from the old operation to the new operation.
-  for (auto namedAttr : op->getAttrs()) {
-    state.attributes.push_back(namedAttr);
-  }
-
-  // Create the new operation.
-  auto *repOp = builder.create(state);
-  op->replaceAllUsesWith(repOp);
-  op->erase();
-
-  return success();
+// Checks if `op` is a dereferencing op.
+// TODO: This hardcoded check will be removed once the right interface is
+// added.
+static bool isDereferencingOp(Operation *op) {
+  return isa<AffineMapAccessInterface, memref::LoadOp, memref::StoreOp>(op);
 }
+
 
 // Perform the replacement in `op`.
 LogicalResult mlir::affine::replaceAllMemRefUsesWith(
     Value oldMemRef, Value newMemRef, Operation *op,
@@ -1216,53 +1138,55 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
   if (usePositions.empty())
     return success();
 
-  if (usePositions.size() > 1) {
-    // TODO: extend it for this case when needed (rare).
-    assert(false && "multiple dereferencing uses in a single op not supported");
-    return failure();
-  }
-
   unsigned memRefOperandPos = usePositions.front();
 
   OpBuilder builder(op);
   // The following checks if op is dereferencing memref and performs the access
   // index rewrites.
   auto affMapAccInterface = dyn_cast<AffineMapAccessInterface>(op);
-  if (!affMapAccInterface) {
+  if (!isDereferencingOp(op)) {
     if (!allowNonDereferencingOps) {
       // Failure: memref used in a non-dereferencing context (potentially
       // escapes); no replacement in these cases unless allowNonDereferencingOps
       // is set.
       return failure();
     }
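+    // Non-dereferencing use: there are no access indices to rewrite, so
+    // substituting the memref operand(s) is all that is needed.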
+    for (unsigned pos : usePositions)
+      op->setOperand(pos, newMemRef);
+    return success();
+  }
 
-  // Check if it is a memref.load
-  auto memrefLoad = dyn_cast<memref::LoadOp>(op);
-  bool isReductionLike =
-      indexRemap.getNumResults() < indexRemap.getNumInputs();
-  if (!memrefLoad || !isReductionLike) {
-    op->setOperand(memRefOperandPos, newMemRef);
-    return success();
-  }
+  if (usePositions.size() > 1) {
+    // TODO: extend it for this case when needed (rare).
+    LLVM_DEBUG(llvm::dbgs()
+               << "multiple dereferencing uses in a single op not supported\n");
+    return failure();
+  }
 
-  return transformMemRefLoadWithReducedRank(
-      op, oldMemRef, newMemRef, memRefOperandPos, extraIndices, extraOperands,
-      symbolOperands, indexRemap);
+  // Perform index rewrites for the dereferencing op and then replace the op.
+  SmallVector<Value, 4> oldMapOperands;
+  AffineMap oldMap;
+  unsigned oldMemRefNumIndices = oldMemRefRank;
+  auto startIdx = op->operand_begin() + memRefOperandPos + 1;
+  if (affMapAccInterface) {
+    // If `op` implements AffineMapAccessInterface, we can get the indices by
+    // querying the number of map operands and reading them from the operand
+    // list at a known offset (`memRefOperandPos` in this case).
+    NamedAttribute oldMapAttrPair =
+        affMapAccInterface.getAffineMapAttrForMemRef(oldMemRef);
+    oldMap = cast<AffineMapAttr>(oldMapAttrPair.getValue()).getValue();
+    oldMemRefNumIndices = oldMap.getNumInputs();
+    oldMapOperands.assign(startIdx, startIdx + oldMemRefNumIndices);
+  } else {
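+    // Otherwise (memref.load/store), the access indices directly follow the
+    // memref operand, one per memref dimension.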
+    oldMapOperands.assign(startIdx, startIdx + oldMemRefRank);
   }
-  // Perform index rewrites for the dereferencing op and then replace the op
-  NamedAttribute oldMapAttrPair =
-      affMapAccInterface.getAffineMapAttrForMemRef(oldMemRef);
-  AffineMap oldMap = cast<AffineMapAttr>(oldMapAttrPair.getValue()).getValue();
-  unsigned oldMapNumInputs = oldMap.getNumInputs();
-  SmallVector<Value, 4> oldMapOperands(
-      op->operand_begin() + memRefOperandPos + 1,
-      op->operand_begin() + memRefOperandPos + 1 + oldMapNumInputs);
 
   // Apply 'oldMemRefOperands = oldMap(oldMapOperands)'.
   SmallVector<Value, 4> oldMemRefOperands;
   SmallVector<Value, 4> affineApplyOps;
   oldMemRefOperands.reserve(oldMemRefRank);
-  if (oldMap != builder.getMultiDimIdentityMap(oldMap.getNumDims())) {
+  if (affMapAccInterface &&
+      oldMap != builder.getMultiDimIdentityMap(oldMap.getNumDims())) {
     for (auto resultExpr : oldMap.getResults()) {
       auto singleResMap = AffineMap::get(oldMap.getNumDims(),
                                          oldMap.getNumSymbols(), resultExpr);
@@ -1287,7 +1211,6 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
 
   SmallVector<Value, 4> remapOutputs;
   remapOutputs.reserve(oldMemRefRank);
-
   if (indexRemap &&
       indexRemap != builder.getMultiDimIdentityMap(indexRemap.getNumDims())) {
     // Remapped indices.
@@ -1303,7 +1226,6 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
     // No remapping specified.
     remapOutputs.assign(remapOperands.begin(), remapOperands.end());
   }
-
   SmallVector<Value, 4> newMapOperands;
   newMapOperands.reserve(newMemRefRank);
 
@@ -1338,13 +1260,25 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
   state.operands.push_back(newMemRef);
 
   // Insert the new memref map operands.
-  state.operands.append(newMapOperands.begin(), newMapOperands.end());
+  if (affMapAccInterface) {
+    state.operands.append(newMapOperands.begin(), newMapOperands.end());
+  } else {
+    // In the case of dereferencing ops not implementing
+    // AffineMapAccessInterface, we need to apply the values of `newMapOperands`
+    // to the `newMap` to get the correct indices.
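+    // (Such ops take plain index operands rather than an affine map
+    // attribute, so each map result is materialized with an affine.apply.)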
+    for (unsigned i = 0; i < newMemRefRank; i++)
+      state.operands.push_back(builder.create<AffineApplyOp>(
+          op->getLoc(),
+          AffineMap::get(newMap.getNumDims(), newMap.getNumSymbols(),
+                         newMap.getResult(i)),
+          newMapOperands));
+  }
 
   // Insert the remaining operands unmodified.
+  unsigned oldMapNumInputs = oldMapOperands.size();
   state.operands.append(op->operand_begin() + memRefOperandPos + 1 +
                             oldMapNumInputs,
                         op->operand_end());
-
   // Result types don't change. Both memref's are of the same elemental type.
   state.types.reserve(op->getNumResults());
   for (auto result : op->getResults())
@@ -1353,7 +1287,9 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
   // Add attribute for 'newMap', other Attributes do not change.
   auto newMapAttr = AffineMapAttr::get(newMap);
   for (auto namedAttr : op->getAttrs()) {
-    if (namedAttr.getName() == oldMapAttrPair.getName())
+    if (affMapAccInterface &&
+        namedAttr.getName() ==
+            affMapAccInterface.getAffineMapAttrForMemRef(oldMemRef).getName())
       state.attributes.push_back({namedAttr.getName(), newMapAttr});
     else
       state.attributes.push_back(namedAttr);
@@ -1846,6 +1782,95 @@ LogicalResult mlir::affine::normalizeMemRef(AllocLikeOp *allocOp) {
   return success();
 }
 
+LogicalResult
+mlir::affine::normalizeMemRef(memref::ReinterpretCastOp *reinterpretCastOp) {
+  MemRefType memrefType = reinterpretCastOp->getType();
+  AffineMap oldLayoutMap = memrefType.getLayout().getAffineMap();
+  Value oldMemRef = reinterpretCastOp->getResult();
+
+  // If `oldLayoutMap` is identity, `memrefType` is already normalized.
+  if (oldLayoutMap.isIdentity())
+    return success();
+
+  // Fetch a new memref type after normalizing the old memref to have an
+  // identity map layout.
+  MemRefType newMemRefType = normalizeMemRefType(memrefType);
+  if (newMemRefType == memrefType)
+    // `oldLayoutMap` couldn't be transformed to an identity map.
+    return failure();
+
+  uint64_t newRank = newMemRefType.getRank();
+  SmallVector<Value> mapOperands(oldLayoutMap.getNumDims() +
+                                 oldLayoutMap.getNumSymbols());
+  SmallVector<Value> oldStrides = reinterpretCastOp->getStrides();
+  Location loc = reinterpretCastOp->getLoc();
+  // As `newMemRefType` is normalized, it is unit strided.
+  SmallVector<int64_t> newStaticStrides(newRank, 1);
+  SmallVector<int64_t> newStaticOffsets(newRank, 0);
+  ArrayRef<int64_t> oldShape = memrefType.getShape();
+  mlir::ValueRange oldSizes = reinterpretCastOp->getSizes();
+  unsigned idx = 0;
+  SmallVector<int64_t> newStaticSizes;
+  OpBuilder b(*reinterpretCastOp);
+  // Collect the map operands which will be used to compute the new normalized
+  // memref shape: for each dimension, the operand is its last valid index,
+  // i.e., size - 1 (dynamic sizes are read from the op's size operands).
+  for (unsigned i = 0, e = memrefType.getRank(); i < e; i++) {
+    if (memrefType.isDynamicDim(i))
+      mapOperands[i] =
+          b.create<arith::SubIOp>(loc, oldSizes[0].getType(), oldSizes[idx++],
+                                  b.create<arith::ConstantIndexOp>(loc, 1));
+    else
+      mapOperands[i] = b.create<arith::ConstantIndexOp>(loc, oldShape[i] - 1);
+  }
+  for (unsigned i = 0, e = oldStrides.size(); i < e; i++)
+    mapOperands[memrefType.getRank() + i] = oldStrides[i];
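+  // The strides of the old memref feed the symbol operands of the layout map.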
+  SmallVector<Value> newSizes;
+  ArrayRef<int64_t> newShape = newMemRefType.getShape();
+  // Compute size along all the dimensions of the new normalized memref.
+  for (unsigned i = 0; i < newRank; i++) {
+    if (!newMemRefType.isDynamicDim(i))
+      continue;
+    newSizes.push_back(b.create<AffineApplyOp>(
+        loc,
+        AffineMap::get(oldLayoutMap.getNumDims(), oldLayoutMap.getNumSymbols(),
+                       oldLayoutMap.getResult(i)),
+        mapOperands));
+  }
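+  // The map was evaluated at the last index along each dimension, so add 1
+  // to each result to turn a maximal index back into a size.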
+  for (unsigned i = 0, e = newSizes.size(); i < e; i++)
+    newSizes[i] =
+        b.create<arith::AddIOp>(loc, newSizes[i].getType(), newSizes[i],
+                                b.create<arith::ConstantIndexOp>(loc, 1));
+  // Create the new reinterpret_cast op.
+  memref::ReinterpretCastOp newReinterpretCast =
+      b.create<memref::ReinterpretCastOp>(
+          loc, newMemRefType, reinterpretCastOp->getSource(),
+          /*offsets=*/mlir::ValueRange(), newSizes,
+          /*strides=*/mlir::ValueRange(),
+          /*static_offsets=*/newStaticOffsets,
+          /*static_sizes=*/newShape,
+          /*static_strides=*/newStaticStrides);
+
+  // Replace all uses of the old memref.
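+  // `oldLayoutMap` serves as the index remapping: existing access indices
+  // are composed through it to address the normalized memref, with the old
+  // strides passed in as its symbol operands.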
+  if (failed(replaceAllMemRefUsesWith(oldMemRef,
+                                      /*newMemRef=*/newReinterpretCast,
+                                      /*extraIndices=*/{},
+                                      /*indexRemap=*/oldLayoutMap,
+                                      /*extraOperands=*/{},
+                                      /*symbolOperands=*/oldStrides,
+                                      /*domOpFilter=*/nullptr,
+                                      /*postDomOpFilter=*/nullptr,
+                                      /*allowNonDereferencingOps=*/true))) {
+    // If it failed (due to escapes for example), bail out.
+    newReinterpretCast->erase();
+    return failure();
+  }
+
+  oldMemRef.replaceAllUsesWith(newReinterpretCast);
+  reinterpretCastOp->erase();
+  return success();
+}
+
 template LogicalResult
 mlir::affine::normalizeMemRef<memref::AllocaOp>(memref::AllocaOp *op);
 template LogicalResult