@@ -1098,90 +1098,12 @@ void mlir::affine::affineScalarReplace(func::FuncOp f, DominanceInfo &domInfo,
1098
1098
op->erase ();
1099
1099
}
1100
1100
1101
- // Private helper function to transform memref.load with reduced rank.
1102
- // This function will modify the indices of the memref.load to match the
1103
- // newMemRef.
1104
- LogicalResult transformMemRefLoadWithReducedRank (
1105
- Operation *op, Value oldMemRef, Value newMemRef, unsigned memRefOperandPos,
1106
- ArrayRef<Value> extraIndices, ArrayRef<Value> extraOperands,
1107
- ArrayRef<Value> symbolOperands, AffineMap indexRemap) {
1108
- unsigned oldMemRefRank = cast<MemRefType>(oldMemRef.getType ()).getRank ();
1109
- unsigned newMemRefRank = cast<MemRefType>(newMemRef.getType ()).getRank ();
1110
- unsigned oldMapNumInputs = oldMemRefRank;
1111
- SmallVector<Value, 4 > oldMapOperands (
1112
- op->operand_begin () + memRefOperandPos + 1 ,
1113
- op->operand_begin () + memRefOperandPos + 1 + oldMapNumInputs);
1114
- SmallVector<Value, 4 > oldMemRefOperands;
1115
- oldMemRefOperands.assign (oldMapOperands.begin (), oldMapOperands.end ());
1116
- SmallVector<Value, 4 > remapOperands;
1117
- remapOperands.reserve (extraOperands.size () + oldMemRefRank +
1118
- symbolOperands.size ());
1119
- remapOperands.append (extraOperands.begin (), extraOperands.end ());
1120
- remapOperands.append (oldMemRefOperands.begin (), oldMemRefOperands.end ());
1121
- remapOperands.append (symbolOperands.begin (), symbolOperands.end ());
1122
-
1123
- SmallVector<Value, 4 > remapOutputs;
1124
- remapOutputs.reserve (oldMemRefRank);
1125
- SmallVector<Value, 4 > affineApplyOps;
1126
-
1127
- OpBuilder builder (op);
1128
-
1129
- if (indexRemap &&
1130
- indexRemap != builder.getMultiDimIdentityMap (indexRemap.getNumDims ())) {
1131
- // Remapped indices.
1132
- for (auto resultExpr : indexRemap.getResults ()) {
1133
- auto singleResMap = AffineMap::get (
1134
- indexRemap.getNumDims (), indexRemap.getNumSymbols (), resultExpr);
1135
- auto afOp = builder.create <AffineApplyOp>(op->getLoc (), singleResMap,
1136
- remapOperands);
1137
- remapOutputs.push_back (afOp);
1138
- affineApplyOps.push_back (afOp);
1139
- }
1140
- } else {
1141
- // No remapping specified.
1142
- remapOutputs.assign (remapOperands.begin (), remapOperands.end ());
1143
- }
1144
-
1145
- SmallVector<Value, 4 > newMapOperands;
1146
- newMapOperands.reserve (newMemRefRank);
1147
-
1148
- // Prepend 'extraIndices' in 'newMapOperands'.
1149
- for (Value extraIndex : extraIndices) {
1150
- assert ((isValidDim (extraIndex) || isValidSymbol (extraIndex)) &&
1151
- " invalid memory op index" );
1152
- newMapOperands.push_back (extraIndex);
1153
- }
1154
-
1155
- // Append 'remapOutputs' to 'newMapOperands'.
1156
- newMapOperands.append (remapOutputs.begin (), remapOutputs.end ());
1157
-
1158
- // Create new fully composed AffineMap for new op to be created.
1159
- assert (newMapOperands.size () == newMemRefRank);
1160
-
1161
- OperationState state (op->getLoc (), op->getName ());
1162
- // Construct the new operation using this memref.
1163
- state.operands .reserve (newMapOperands.size () + extraIndices.size ());
1164
- state.operands .push_back (newMemRef);
1165
-
1166
- // Insert the new memref map operands.
1167
- state.operands .append (newMapOperands.begin (), newMapOperands.end ());
1168
-
1169
- state.types .reserve (op->getNumResults ());
1170
- for (auto result : op->getResults ())
1171
- state.types .push_back (result.getType ());
1172
-
1173
- // Copy over the attributes from the old operation to the new operation.
1174
- for (auto namedAttr : op->getAttrs ()) {
1175
- state.attributes .push_back (namedAttr);
1176
- }
1177
-
1178
- // Create the new operation.
1179
- auto *repOp = builder.create (state);
1180
- op->replaceAllUsesWith (repOp);
1181
- op->erase ();
1182
-
1183
- return success ();
1101
+ // Checks if `op` is non dereferencing.
1102
+ // TODO: This hardcoded check will be removed once the right interface is added.
1103
+ static bool isDereferencingOp (Operation *op) {
1104
+ return isa<AffineMapAccessInterface, memref::LoadOp, memref::StoreOp>(op);
1184
1105
}
1106
+
1185
1107
// Perform the replacement in `op`.
1186
1108
LogicalResult mlir::affine::replaceAllMemRefUsesWith (
1187
1109
Value oldMemRef, Value newMemRef, Operation *op,
@@ -1216,53 +1138,57 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
1216
1138
if (usePositions.empty ())
1217
1139
return success ();
1218
1140
1219
- if (usePositions.size () > 1 ) {
1220
- // TODO: extend it for this case when needed (rare).
1221
- assert (false && " multiple dereferencing uses in a single op not supported" );
1222
- return failure ();
1223
- }
1224
-
1225
1141
unsigned memRefOperandPos = usePositions.front ();
1226
1142
1227
1143
OpBuilder builder (op);
1228
1144
// The following checks if op is dereferencing memref and performs the access
1229
1145
// index rewrites.
1230
1146
auto affMapAccInterface = dyn_cast<AffineMapAccessInterface>(op);
1231
- if (!affMapAccInterface ) {
1147
+ if (!isDereferencingOp (op) ) {
1232
1148
if (!allowNonDereferencingOps) {
1233
1149
// Failure: memref used in a non-dereferencing context (potentially
1234
1150
// escapes); no replacement in these cases unless allowNonDereferencingOps
1235
1151
// is set.
1236
1152
return failure ();
1237
1153
}
1154
+ for (unsigned pos : usePositions)
1155
+ op->setOperand (pos, newMemRef);
1156
+ return success ();
1157
+ }
1238
1158
1239
- // Check if it is a memref.load
1240
- auto memrefLoad = dyn_cast<memref::LoadOp>(op);
1241
- bool isReductionLike =
1242
- indexRemap.getNumResults () < indexRemap.getNumInputs ();
1243
- if (!memrefLoad || !isReductionLike) {
1244
- op->setOperand (memRefOperandPos, newMemRef);
1245
- return success ();
1246
- }
1159
+ if (usePositions.size () > 1 ) {
1160
+ // TODO: extend it for this case when needed (rare).
1161
+ assert (false && " multiple dereferencing uses in a single op not supported" );
1162
+ return failure ();
1163
+ }
1247
1164
1248
- return transformMemRefLoadWithReducedRank (
1249
- op, oldMemRef, newMemRef, memRefOperandPos, extraIndices, extraOperands,
1250
- symbolOperands, indexRemap);
1165
+ // Perform index rewrites for the dereferencing op and then replace the op.
1166
+ SmallVector<Value, 4 > oldMapOperands;
1167
+ AffineMap oldMap;
1168
+ unsigned oldMemRefNumIndices = oldMemRefRank;
1169
+ if (affMapAccInterface) {
1170
+ // If `op` implements AffineMapAccessInterface, we can get the indices by
1171
+ // quering the number of map operands from the operand list from a certain
1172
+ // offset (`memRefOperandPos` in this case).
1173
+ NamedAttribute oldMapAttrPair =
1174
+ affMapAccInterface.getAffineMapAttrForMemRef (oldMemRef);
1175
+ oldMap = cast<AffineMapAttr>(oldMapAttrPair.getValue ()).getValue ();
1176
+ oldMemRefNumIndices = oldMap.getNumInputs ();
1177
+ oldMapOperands.assign (op->operand_begin () + memRefOperandPos + 1 ,
1178
+ op->operand_begin () + memRefOperandPos + 1 +
1179
+ oldMemRefNumIndices);
1180
+ } else {
1181
+ oldMapOperands.assign (op->operand_begin () + memRefOperandPos + 1 ,
1182
+ op->operand_begin () + memRefOperandPos + 1 +
1183
+ oldMemRefRank);
1251
1184
}
1252
- // Perform index rewrites for the dereferencing op and then replace the op
1253
- NamedAttribute oldMapAttrPair =
1254
- affMapAccInterface.getAffineMapAttrForMemRef (oldMemRef);
1255
- AffineMap oldMap = cast<AffineMapAttr>(oldMapAttrPair.getValue ()).getValue ();
1256
- unsigned oldMapNumInputs = oldMap.getNumInputs ();
1257
- SmallVector<Value, 4 > oldMapOperands (
1258
- op->operand_begin () + memRefOperandPos + 1 ,
1259
- op->operand_begin () + memRefOperandPos + 1 + oldMapNumInputs);
1260
1185
1261
1186
// Apply 'oldMemRefOperands = oldMap(oldMapOperands)'.
1262
1187
SmallVector<Value, 4 > oldMemRefOperands;
1263
1188
SmallVector<Value, 4 > affineApplyOps;
1264
1189
oldMemRefOperands.reserve (oldMemRefRank);
1265
- if (oldMap != builder.getMultiDimIdentityMap (oldMap.getNumDims ())) {
1190
+ if (affMapAccInterface &&
1191
+ oldMap != builder.getMultiDimIdentityMap (oldMap.getNumDims ())) {
1266
1192
for (auto resultExpr : oldMap.getResults ()) {
1267
1193
auto singleResMap = AffineMap::get (oldMap.getNumDims (),
1268
1194
oldMap.getNumSymbols (), resultExpr);
@@ -1287,7 +1213,6 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
1287
1213
1288
1214
SmallVector<Value, 4 > remapOutputs;
1289
1215
remapOutputs.reserve (oldMemRefRank);
1290
-
1291
1216
if (indexRemap &&
1292
1217
indexRemap != builder.getMultiDimIdentityMap (indexRemap.getNumDims ())) {
1293
1218
// Remapped indices.
@@ -1303,7 +1228,6 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
1303
1228
// No remapping specified.
1304
1229
remapOutputs.assign (remapOperands.begin (), remapOperands.end ());
1305
1230
}
1306
-
1307
1231
SmallVector<Value, 4 > newMapOperands;
1308
1232
newMapOperands.reserve (newMemRefRank);
1309
1233
@@ -1338,13 +1262,26 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
1338
1262
state.operands .push_back (newMemRef);
1339
1263
1340
1264
// Insert the new memref map operands.
1341
- state.operands .append (newMapOperands.begin (), newMapOperands.end ());
1265
+ if (affMapAccInterface) {
1266
+ state.operands .append (newMapOperands.begin (), newMapOperands.end ());
1267
+ } else {
1268
+ // In the case of dereferencing ops not implementing
1269
+ // AffineMapAccessInterface, we need to apply the values of `newMapOperands`
1270
+ // to the `newMap` to get the correct indices.
1271
+ for (unsigned i = 0 ; i < newMemRefRank; i++)
1272
+ state.operands .push_back (builder.create <AffineApplyOp>(
1273
+ op->getLoc (),
1274
+ AffineMap::get (newMap.getNumDims (), newMap.getNumSymbols (),
1275
+ newMap.getResult (i)),
1276
+ newMapOperands));
1277
+ }
1342
1278
1343
1279
// Insert the remaining operands unmodified.
1280
+ unsigned oldMapNumInputs = oldMapOperands.size ();
1281
+
1344
1282
state.operands .append (op->operand_begin () + memRefOperandPos + 1 +
1345
1283
oldMapNumInputs,
1346
1284
op->operand_end ());
1347
-
1348
1285
// Result types don't change. Both memref's are of the same elemental type.
1349
1286
state.types .reserve (op->getNumResults ());
1350
1287
for (auto result : op->getResults ())
@@ -1353,7 +1290,9 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
1353
1290
// Add attribute for 'newMap', other Attributes do not change.
1354
1291
auto newMapAttr = AffineMapAttr::get (newMap);
1355
1292
for (auto namedAttr : op->getAttrs ()) {
1356
- if (namedAttr.getName () == oldMapAttrPair.getName ())
1293
+ if (affMapAccInterface &&
1294
+ namedAttr.getName () ==
1295
+ affMapAccInterface.getAffineMapAttrForMemRef (oldMemRef).getName ())
1357
1296
state.attributes .push_back ({namedAttr.getName (), newMapAttr});
1358
1297
else
1359
1298
state.attributes .push_back (namedAttr);
@@ -1846,6 +1785,95 @@ LogicalResult mlir::affine::normalizeMemRef(AllocLikeOp *allocOp) {
1846
1785
return success ();
1847
1786
}
1848
1787
1788
+ LogicalResult
1789
+ mlir::affine::normalizeMemRef (memref::ReinterpretCastOp *reinterpretCastOp) {
1790
+ MemRefType memrefType = reinterpretCastOp->getType ();
1791
+ AffineMap oldLayoutMap = memrefType.getLayout ().getAffineMap ();
1792
+ Value oldMemRef = reinterpretCastOp->getResult ();
1793
+
1794
+ // If `oldLayoutMap` is identity, `memrefType` is already normalized.
1795
+ if (oldLayoutMap.isIdentity ())
1796
+ return success ();
1797
+
1798
+ // Fetch a new memref type after normalizing the old memref to have an
1799
+ // identity map layout.
1800
+ MemRefType newMemRefType = normalizeMemRefType (memrefType);
1801
+ newMemRefType.dump ();
1802
+ if (newMemRefType == memrefType)
1803
+ // `oldLayoutMap` couldn't be transformed to an identity map.
1804
+ return failure ();
1805
+
1806
+ uint64_t newRank = newMemRefType.getRank ();
1807
+ SmallVector<Value> mapOperands (oldLayoutMap.getNumDims () +
1808
+ oldLayoutMap.getNumSymbols ());
1809
+ SmallVector<Value> oldStrides = reinterpretCastOp->getStrides ();
1810
+ Location loc = reinterpretCastOp->getLoc ();
1811
+ // As `newMemRefType` is normalized, it is unit strided.
1812
+ SmallVector<int64_t > newStaticStrides (newRank, 1 );
1813
+ SmallVector<int64_t > newStaticOffsets (newRank, 0 );
1814
+ ArrayRef<int64_t > oldShape = memrefType.getShape ();
1815
+ mlir::ValueRange oldSizes = reinterpretCastOp->getSizes ();
1816
+ unsigned idx = 0 ;
1817
+ SmallVector<int64_t > newStaticSizes;
1818
+ OpBuilder b (*reinterpretCastOp);
1819
+ // Collectthe map operands which will be used to compute the new normalized
1820
+ // memref shape.
1821
+ for (unsigned i = 0 , e = memrefType.getRank (); i < e; i++) {
1822
+ if (memrefType.isDynamicDim (i))
1823
+ mapOperands[i] =
1824
+ b.create <arith::SubIOp>(loc, oldSizes[0 ].getType (), oldSizes[idx++],
1825
+ b.create <arith::ConstantIndexOp>(loc, 1 ));
1826
+ else
1827
+ mapOperands[i] = b.create <arith::ConstantIndexOp>(loc, oldShape[i] - 1 );
1828
+ }
1829
+ for (unsigned i = 0 , e = oldStrides.size (); i < e; i++)
1830
+ mapOperands[memrefType.getRank () + i] = oldStrides[i];
1831
+ SmallVector<Value> newSizes;
1832
+ ArrayRef<int64_t > newShape = newMemRefType.getShape ();
1833
+ // Compute size along all the dimensions of the new normalized memref.
1834
+ for (unsigned i = 0 ; i < newRank; i++) {
1835
+ if (!newMemRefType.isDynamicDim (i))
1836
+ continue ;
1837
+ newSizes.push_back (b.create <AffineApplyOp>(
1838
+ loc,
1839
+ AffineMap::get (oldLayoutMap.getNumDims (), oldLayoutMap.getNumSymbols (),
1840
+ oldLayoutMap.getResult (i)),
1841
+ mapOperands));
1842
+ }
1843
+ for (unsigned i = 0 , e = newSizes.size (); i < e; i++)
1844
+ newSizes[i] =
1845
+ b.create <arith::AddIOp>(loc, newSizes[i].getType (), newSizes[i],
1846
+ b.create <arith::ConstantIndexOp>(loc, 1 ));
1847
+ // Create the new reinterpret_cast op.
1848
+ memref::ReinterpretCastOp newReinterpretCast =
1849
+ b.create <memref::ReinterpretCastOp>(
1850
+ loc, newMemRefType, reinterpretCastOp->getSource (),
1851
+ /* offsets=*/ mlir::ValueRange (), newSizes,
1852
+ /* strides=*/ mlir::ValueRange (),
1853
+ /* static_offsets=*/ newStaticOffsets,
1854
+ /* static_sizes=*/ newShape,
1855
+ /* static_strides=*/ newStaticStrides);
1856
+
1857
+ // Replace all uses of the old memref.
1858
+ if (failed (replaceAllMemRefUsesWith (oldMemRef,
1859
+ /* newMemRef=*/ newReinterpretCast,
1860
+ /* extraIndices=*/ {},
1861
+ /* indexRemap=*/ oldLayoutMap,
1862
+ /* extraOperands=*/ {},
1863
+ /* symbolOperands=*/ oldStrides,
1864
+ /* domOpFilter=*/ nullptr ,
1865
+ /* postDomOpFilter=*/ nullptr ,
1866
+ /* allowNonDereferencingOps=*/ true ))) {
1867
+ // If it failed (due to escapes for example), bail out.
1868
+ newReinterpretCast->erase ();
1869
+ return failure ();
1870
+ }
1871
+
1872
+ oldMemRef.replaceAllUsesWith (newReinterpretCast);
1873
+ reinterpretCastOp->erase ();
1874
+ return success ();
1875
+ }
1876
+
1849
1877
template LogicalResult
1850
1878
mlir::affine::normalizeMemRef<memref::AllocaOp>(memref::AllocaOp *op);
1851
1879
template LogicalResult
0 commit comments