@@ -1098,90 +1098,12 @@ void mlir::affine::affineScalarReplace(func::FuncOp f, DominanceInfo &domInfo,
     op->erase();
   }
 
-// Private helper function to transform memref.load with reduced rank.
-// This function will modify the indices of the memref.load to match the
-// newMemRef.
-LogicalResult transformMemRefLoadWithReducedRank(
-    Operation *op, Value oldMemRef, Value newMemRef, unsigned memRefOperandPos,
-    ArrayRef<Value> extraIndices, ArrayRef<Value> extraOperands,
-    ArrayRef<Value> symbolOperands, AffineMap indexRemap) {
-  unsigned oldMemRefRank = cast<MemRefType>(oldMemRef.getType()).getRank();
-  unsigned newMemRefRank = cast<MemRefType>(newMemRef.getType()).getRank();
-  unsigned oldMapNumInputs = oldMemRefRank;
-  SmallVector<Value, 4> oldMapOperands(
-      op->operand_begin() + memRefOperandPos + 1,
-      op->operand_begin() + memRefOperandPos + 1 + oldMapNumInputs);
-  SmallVector<Value, 4> oldMemRefOperands;
-  oldMemRefOperands.assign(oldMapOperands.begin(), oldMapOperands.end());
-  SmallVector<Value, 4> remapOperands;
-  remapOperands.reserve(extraOperands.size() + oldMemRefRank +
-                        symbolOperands.size());
-  remapOperands.append(extraOperands.begin(), extraOperands.end());
-  remapOperands.append(oldMemRefOperands.begin(), oldMemRefOperands.end());
-  remapOperands.append(symbolOperands.begin(), symbolOperands.end());
-
-  SmallVector<Value, 4> remapOutputs;
-  remapOutputs.reserve(oldMemRefRank);
-  SmallVector<Value, 4> affineApplyOps;
-
-  OpBuilder builder(op);
-
-  if (indexRemap &&
-      indexRemap != builder.getMultiDimIdentityMap(indexRemap.getNumDims())) {
-    // Remapped indices.
-    for (auto resultExpr : indexRemap.getResults()) {
-      auto singleResMap = AffineMap::get(
-          indexRemap.getNumDims(), indexRemap.getNumSymbols(), resultExpr);
-      auto afOp = builder.create<AffineApplyOp>(op->getLoc(), singleResMap,
-                                                remapOperands);
-      remapOutputs.push_back(afOp);
-      affineApplyOps.push_back(afOp);
-    }
-  } else {
-    // No remapping specified.
-    remapOutputs.assign(remapOperands.begin(), remapOperands.end());
-  }
-
-  SmallVector<Value, 4> newMapOperands;
-  newMapOperands.reserve(newMemRefRank);
-
-  // Prepend 'extraIndices' in 'newMapOperands'.
-  for (Value extraIndex : extraIndices) {
-    assert((isValidDim(extraIndex) || isValidSymbol(extraIndex)) &&
-           "invalid memory op index");
-    newMapOperands.push_back(extraIndex);
-  }
-
-  // Append 'remapOutputs' to 'newMapOperands'.
-  newMapOperands.append(remapOutputs.begin(), remapOutputs.end());
-
-  // Create new fully composed AffineMap for new op to be created.
-  assert(newMapOperands.size() == newMemRefRank);
-
-  OperationState state(op->getLoc(), op->getName());
-  // Construct the new operation using this memref.
-  state.operands.reserve(newMapOperands.size() + extraIndices.size());
-  state.operands.push_back(newMemRef);
-
-  // Insert the new memref map operands.
-  state.operands.append(newMapOperands.begin(), newMapOperands.end());
-
-  state.types.reserve(op->getNumResults());
-  for (auto result : op->getResults())
-    state.types.push_back(result.getType());
-
-  // Copy over the attributes from the old operation to the new operation.
-  for (auto namedAttr : op->getAttrs()) {
-    state.attributes.push_back(namedAttr);
-  }
-
-  // Create the new operation.
-  auto *repOp = builder.create(state);
-  op->replaceAllUsesWith(repOp);
-  op->erase();
-
-  return success();
+// Checks if `op` is a dereferencing op.
+// TODO: This hardcoded check will be removed once the right interface is added.
+static bool isDereferencingOp(Operation *op) {
+  return isa<AffineMapAccessInterface, memref::LoadOp, memref::StoreOp>(op);
 }
+
 
 // Perform the replacement in `op`.
 LogicalResult mlir::affine::replaceAllMemRefUsesWith(
     Value oldMemRef, Value newMemRef, Operation *op,
@@ -1216,53 +1138,53 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
   if (usePositions.empty())
     return success();
 
-  if (usePositions.size() > 1) {
-    // TODO: extend it for this case when needed (rare).
-    assert(false && "multiple dereferencing uses in a single op not supported");
-    return failure();
-  }
-
   unsigned memRefOperandPos = usePositions.front();
 
   OpBuilder builder(op);
   // The following checks if op is dereferencing memref and performs the access
   // index rewrites.
-  auto affMapAccInterface = dyn_cast<AffineMapAccessInterface>(op);
-  if (!affMapAccInterface) {
+  if (!isDereferencingOp(op)) {
     if (!allowNonDereferencingOps) {
       // Failure: memref used in a non-dereferencing context (potentially
       // escapes); no replacement in these cases unless allowNonDereferencingOps
      // is set.
       return failure();
     }
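+    // Memref is not dereferenced by `op`, but the caller allows replacement
+    // here: simply substitute the new memref at each use position; there are
+    // no access indices to rewrite.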
+    for (unsigned pos : usePositions)
+      op->setOperand(pos, newMemRef);
+    return success();
+  }
 
-  // Check if it is a memref.load
-  auto memrefLoad = dyn_cast<memref::LoadOp>(op);
-  bool isReductionLike =
-      indexRemap.getNumResults() < indexRemap.getNumInputs();
-  if (!memrefLoad || !isReductionLike) {
-    op->setOperand(memRefOperandPos, newMemRef);
-    return success();
-  }
+  if (usePositions.size() > 1) {
+    // TODO: extend it for this case when needed (rare).
+    LLVM_DEBUG(llvm::dbgs()
+               << "multiple dereferencing uses in a single op not supported");
+    return failure();
+  }
 
-  return transformMemRefLoadWithReducedRank(
-      op, oldMemRef, newMemRef, memRefOperandPos, extraIndices, extraOperands,
-      symbolOperands, indexRemap);
+  // Perform index rewrites for the dereferencing op and then replace the op.
+  SmallVector<Value, 4> oldMapOperands;
+  AffineMap oldMap;
+  unsigned oldMemRefNumIndices = oldMemRefRank;
+  auto startIdx = op->operand_begin() + memRefOperandPos + 1;
+  auto affMapAccInterface = dyn_cast<AffineMapAccessInterface>(op);
+  if (affMapAccInterface) {
+    // If `op` implements AffineMapAccessInterface, we can get the indices by
+    // querying the number of map operands and reading them from the operand
+    // list at a fixed offset (`memRefOperandPos` in this case).
+    NamedAttribute oldMapAttrPair =
+        affMapAccInterface.getAffineMapAttrForMemRef(oldMemRef);
+    oldMap = cast<AffineMapAttr>(oldMapAttrPair.getValue()).getValue();
+    oldMemRefNumIndices = oldMap.getNumInputs();
   }
-  // Perform index rewrites for the dereferencing op and then replace the op
-  NamedAttribute oldMapAttrPair =
-      affMapAccInterface.getAffineMapAttrForMemRef(oldMemRef);
-  AffineMap oldMap = cast<AffineMapAttr>(oldMapAttrPair.getValue()).getValue();
-  unsigned oldMapNumInputs = oldMap.getNumInputs();
-  SmallVector<Value, 4> oldMapOperands(
-      op->operand_begin() + memRefOperandPos + 1,
-      op->operand_begin() + memRefOperandPos + 1 + oldMapNumInputs);
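+  // Otherwise `op` is a plain memref.load/store carrying no affine map: its
+  // indices are the trailing operands themselves, and their count equals the
+  // memref's rank (the `oldMemRefNumIndices` initialization above).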
+  oldMapOperands.assign(startIdx, startIdx + oldMemRefNumIndices);
 
   // Apply 'oldMemRefOperands = oldMap(oldMapOperands)'.
   SmallVector<Value, 4> oldMemRefOperands;
   SmallVector<Value, 4> affineApplyOps;
   oldMemRefOperands.reserve(oldMemRefRank);
-  if (oldMap != builder.getMultiDimIdentityMap(oldMap.getNumDims())) {
+  if (affMapAccInterface &&
+      oldMap != builder.getMultiDimIdentityMap(oldMap.getNumDims())) {
     for (auto resultExpr : oldMap.getResults()) {
       auto singleResMap = AffineMap::get(oldMap.getNumDims(),
                                          oldMap.getNumSymbols(), resultExpr);
@@ -1287,7 +1209,6 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
 
   SmallVector<Value, 4> remapOutputs;
   remapOutputs.reserve(oldMemRefRank);
-
   if (indexRemap &&
       indexRemap != builder.getMultiDimIdentityMap(indexRemap.getNumDims())) {
     // Remapped indices.
@@ -1303,7 +1224,6 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
     // No remapping specified.
     remapOutputs.assign(remapOperands.begin(), remapOperands.end());
   }
-
   SmallVector<Value, 4> newMapOperands;
   newMapOperands.reserve(newMemRefRank);
 
@@ -1338,13 +1258,26 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
   state.operands.push_back(newMemRef);
 
   // Insert the new memref map operands.
-  state.operands.append(newMapOperands.begin(), newMapOperands.end());
+  if (affMapAccInterface) {
+    state.operands.append(newMapOperands.begin(), newMapOperands.end());
+  } else {
+    // For dereferencing ops that do not implement AffineMapAccessInterface,
+    // apply `newMapOperands` to `newMap` to compute the correct indices.
+    for (unsigned i = 0; i < newMemRefRank; i++) {
+      state.operands.push_back(builder.create<AffineApplyOp>(
+          op->getLoc(),
+          AffineMap::get(newMap.getNumDims(), newMap.getNumSymbols(),
                         newMap.getResult(i)),
+          newMapOperands));
+    }
+  }
 
   // Insert the remaining operands unmodified.
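+  // `oldMapOperands` holds exactly the index operands consumed above, whether
+  // they came from an affine map or were raw memref.load/store indices.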
+  unsigned oldMapNumInputs = oldMapOperands.size();
   state.operands.append(op->operand_begin() + memRefOperandPos + 1 +
                             oldMapNumInputs,
                         op->operand_end());
-
   // Result types don't change. Both memrefs are of the same elemental type.
   state.types.reserve(op->getNumResults());
   for (auto result : op->getResults())
@@ -1353,7 +1286,9 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
   // Add attribute for 'newMap', other attributes do not change.
   auto newMapAttr = AffineMapAttr::get(newMap);
   for (auto namedAttr : op->getAttrs()) {
-    if (namedAttr.getName() == oldMapAttrPair.getName())
+    if (affMapAccInterface &&
+        namedAttr.getName() ==
+            affMapAccInterface.getAffineMapAttrForMemRef(oldMemRef).getName())
       state.attributes.push_back({namedAttr.getName(), newMapAttr});
     else
       state.attributes.push_back(namedAttr);
@@ -1845,6 +1780,94 @@ LogicalResult mlir::affine::normalizeMemRef(AllocLikeOp allocOp) {
   return success();
 }
 
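+// Normalizes the layout of a memref.reinterpret_cast result to an identity
+// layout and rewrites all of its uses. Illustrative example: a result of type
+// memref<6xf32, affine_map<(d0) -> (d0 floordiv 2, d0 mod 2)>> is rebuilt as
+// a unit-strided memref<3x2xf32> reinterpret_cast, with every dereferencing
+// use rewritten through the old layout map.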
+LogicalResult
+mlir::affine::normalizeMemRef(memref::ReinterpretCastOp reinterpretCastOp) {
+  MemRefType memrefType = reinterpretCastOp.getType();
+  AffineMap oldLayoutMap = memrefType.getLayout().getAffineMap();
+  Value oldMemRef = reinterpretCastOp.getResult();
+
+  // If `oldLayoutMap` is identity, `memrefType` is already normalized.
+  if (oldLayoutMap.isIdentity())
+    return success();
+
+  // Fetch a new memref type after normalizing the old memref to have an
+  // identity map layout.
+  MemRefType newMemRefType = normalizeMemRefType(memrefType);
+  if (newMemRefType == memrefType)
+    // `oldLayoutMap` couldn't be transformed to an identity map.
+    return failure();
+
+  uint64_t newRank = newMemRefType.getRank();
+  SmallVector<Value> mapOperands(oldLayoutMap.getNumDims() +
+                                 oldLayoutMap.getNumSymbols());
+  SmallVector<Value> oldStrides = reinterpretCastOp.getStrides();
+  Location loc = reinterpretCastOp.getLoc();
+  // As `newMemRefType` is normalized, it is unit strided.
+  SmallVector<int64_t> newStaticStrides(newRank, 1);
+  SmallVector<int64_t> newStaticOffsets(newRank, 0);
+  ArrayRef<int64_t> oldShape = memrefType.getShape();
+  ValueRange oldSizes = reinterpretCastOp.getSizes();
+  unsigned idx = 0;
+  SmallVector<int64_t> newStaticSizes;
+  OpBuilder b(reinterpretCastOp);
+  // Collect the map operands which will be used to compute the new normalized
+  // memref shape.
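+  // Each dim operand is the last valid index along that dimension (size - 1):
+  // computed at runtime for dynamic dimensions, folded to a constant for
+  // static ones. The old strides are passed as the map's symbol operands.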
+  for (unsigned i = 0, e = memrefType.getRank(); i < e; i++) {
+    if (memrefType.isDynamicDim(i))
+      mapOperands[i] =
+          b.create<arith::SubIOp>(loc, oldSizes[0].getType(), oldSizes[idx++],
+                                  b.create<arith::ConstantIndexOp>(loc, 1));
+    else
+      mapOperands[i] = b.create<arith::ConstantIndexOp>(loc, oldShape[i] - 1);
+  }
+  for (unsigned i = 0, e = oldStrides.size(); i < e; i++)
+    mapOperands[memrefType.getRank() + i] = oldStrides[i];
+  SmallVector<Value> newSizes;
+  ArrayRef<int64_t> newShape = newMemRefType.getShape();
+  // Compute the size along each dynamic dimension of the new normalized
+  // memref; static sizes are taken directly from `newShape`.
+  for (unsigned i = 0; i < newRank; i++) {
+    if (!newMemRefType.isDynamicDim(i))
+      continue;
+    newSizes.push_back(b.create<AffineApplyOp>(
+        loc,
+        AffineMap::get(oldLayoutMap.getNumDims(), oldLayoutMap.getNumSymbols(),
+                       oldLayoutMap.getResult(i)),
+        mapOperands));
+  }
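+  // The applies above evaluated the layout map at the last valid index; add 1
+  // to turn each maximal index into a size.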
+  for (unsigned i = 0, e = newSizes.size(); i < e; i++) {
+    newSizes[i] =
+        b.create<arith::AddIOp>(loc, newSizes[i].getType(), newSizes[i],
+                                b.create<arith::ConstantIndexOp>(loc, 1));
+  }
+  // Create the new reinterpret_cast op.
+  auto newReinterpretCast = b.create<memref::ReinterpretCastOp>(
+      loc, newMemRefType, reinterpretCastOp.getSource(),
+      /*offsets=*/ValueRange(), newSizes,
+      /*strides=*/ValueRange(),
+      /*static_offsets=*/newStaticOffsets,
+      /*static_sizes=*/newShape,
+      /*static_strides=*/newStaticStrides);
+
+  // Replace all uses of the old memref.
+  if (failed(replaceAllMemRefUsesWith(oldMemRef,
+                                      /*newMemRef=*/newReinterpretCast,
+                                      /*extraIndices=*/{},
+                                      /*indexRemap=*/oldLayoutMap,
+                                      /*extraOperands=*/{},
+                                      /*symbolOperands=*/oldStrides,
+                                      /*domOpFilter=*/nullptr,
+                                      /*postDomOpFilter=*/nullptr,
+                                      /*allowNonDereferencingOps=*/true))) {
+    // If the replacement failed (e.g., the memref escapes), bail out.
+    newReinterpretCast.erase();
+    return failure();
+  }
+
+  oldMemRef.replaceAllUsesWith(newReinterpretCast);
+  reinterpretCastOp.erase();
+  return success();
+}
+
 
 template LogicalResult
 mlir::affine::normalizeMemRef<memref::AllocaOp>(memref::AllocaOp op);
 template LogicalResult