|
27 | 27 | #include "mlir/IR/ImplicitLocOpBuilder.h"
|
28 | 28 | #include "mlir/IR/IntegerSet.h"
|
29 | 29 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
| 30 | +#include "llvm/Support/LogicalResult.h" |
30 | 31 | #include <optional>
|
31 | 32 |
|
32 | 33 | #define DEBUG_TYPE "affine-utils"
|
@@ -1093,6 +1094,90 @@ void mlir::affine::affineScalarReplace(func::FuncOp f, DominanceInfo &domInfo,
|
1093 | 1094 | op->erase();
|
1094 | 1095 | }
|
1095 | 1096 |
|
| 1097 | +// Private helper function to transform memref.load with reduced rank. |
| 1098 | +// This function will modify the indices of the memref.load to match the |
| 1099 | +// newMemRef. |
| 1100 | +LogicalResult transformMemRefLoadWithReducedRank( |
| 1101 | + Operation *op, Value oldMemRef, Value newMemRef, unsigned memRefOperandPos, |
| 1102 | + ArrayRef<Value> extraIndices, ArrayRef<Value> extraOperands, |
| 1103 | + ArrayRef<Value> symbolOperands, AffineMap indexRemap) { |
| 1104 | + unsigned oldMemRefRank = cast<MemRefType>(oldMemRef.getType()).getRank(); |
| 1105 | + unsigned newMemRefRank = cast<MemRefType>(newMemRef.getType()).getRank(); |
| 1106 | + unsigned oldMapNumInputs = oldMemRefRank; |
| 1107 | + SmallVector<Value, 4> oldMapOperands( |
| 1108 | + op->operand_begin() + memRefOperandPos + 1, |
| 1109 | + op->operand_begin() + memRefOperandPos + 1 + oldMapNumInputs); |
| 1110 | + SmallVector<Value, 4> oldMemRefOperands; |
| 1111 | + oldMemRefOperands.assign(oldMapOperands.begin(), oldMapOperands.end()); |
| 1112 | + SmallVector<Value, 4> remapOperands; |
| 1113 | + remapOperands.reserve(extraOperands.size() + oldMemRefRank + |
| 1114 | + symbolOperands.size()); |
| 1115 | + remapOperands.append(extraOperands.begin(), extraOperands.end()); |
| 1116 | + remapOperands.append(oldMemRefOperands.begin(), oldMemRefOperands.end()); |
| 1117 | + remapOperands.append(symbolOperands.begin(), symbolOperands.end()); |
| 1118 | + |
| 1119 | + SmallVector<Value, 4> remapOutputs; |
| 1120 | + remapOutputs.reserve(oldMemRefRank); |
| 1121 | + SmallVector<Value, 4> affineApplyOps; |
| 1122 | + |
| 1123 | + OpBuilder builder(op); |
| 1124 | + |
| 1125 | + if (indexRemap && |
| 1126 | + indexRemap != builder.getMultiDimIdentityMap(indexRemap.getNumDims())) { |
| 1127 | + // Remapped indices. |
| 1128 | + for (auto resultExpr : indexRemap.getResults()) { |
| 1129 | + auto singleResMap = AffineMap::get( |
| 1130 | + indexRemap.getNumDims(), indexRemap.getNumSymbols(), resultExpr); |
| 1131 | + auto afOp = builder.create<AffineApplyOp>(op->getLoc(), singleResMap, |
| 1132 | + remapOperands); |
| 1133 | + remapOutputs.push_back(afOp); |
| 1134 | + affineApplyOps.push_back(afOp); |
| 1135 | + } |
| 1136 | + } else { |
| 1137 | + // No remapping specified. |
| 1138 | + remapOutputs.assign(remapOperands.begin(), remapOperands.end()); |
| 1139 | + } |
| 1140 | + |
| 1141 | + SmallVector<Value, 4> newMapOperands; |
| 1142 | + newMapOperands.reserve(newMemRefRank); |
| 1143 | + |
| 1144 | + // Prepend 'extraIndices' in 'newMapOperands'. |
| 1145 | + for (Value extraIndex : extraIndices) { |
| 1146 | + assert((isValidDim(extraIndex) || isValidSymbol(extraIndex)) && |
| 1147 | + "invalid memory op index"); |
| 1148 | + newMapOperands.push_back(extraIndex); |
| 1149 | + } |
| 1150 | + |
| 1151 | + // Append 'remapOutputs' to 'newMapOperands'. |
| 1152 | + newMapOperands.append(remapOutputs.begin(), remapOutputs.end()); |
| 1153 | + |
| 1154 | + // Create new fully composed AffineMap for new op to be created. |
| 1155 | + assert(newMapOperands.size() == newMemRefRank); |
| 1156 | + |
| 1157 | + OperationState state(op->getLoc(), op->getName()); |
| 1158 | + // Construct the new operation using this memref. |
| 1159 | + state.operands.reserve(newMapOperands.size() + extraIndices.size()); |
| 1160 | + state.operands.push_back(newMemRef); |
| 1161 | + |
| 1162 | + // Insert the new memref map operands. |
| 1163 | + state.operands.append(newMapOperands.begin(), newMapOperands.end()); |
| 1164 | + |
| 1165 | + state.types.reserve(op->getNumResults()); |
| 1166 | + for (auto result : op->getResults()) |
| 1167 | + state.types.push_back(result.getType()); |
| 1168 | + |
| 1169 | + // Copy over the attributes from the old operation to the new operation. |
| 1170 | + for (auto namedAttr : op->getAttrs()) { |
| 1171 | + state.attributes.push_back(namedAttr); |
| 1172 | + } |
| 1173 | + |
| 1174 | + // Create the new operation. |
| 1175 | + auto *repOp = builder.create(state); |
| 1176 | + op->replaceAllUsesWith(repOp); |
| 1177 | + op->erase(); |
| 1178 | + |
| 1179 | + return success(); |
| 1180 | +} |
1096 | 1181 | // Perform the replacement in `op`.
|
1097 | 1182 | LogicalResult mlir::affine::replaceAllMemRefUsesWith(
|
1098 | 1183 | Value oldMemRef, Value newMemRef, Operation *op,
|
@@ -1146,8 +1231,19 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
|
1146 | 1231 | // is set.
|
1147 | 1232 | return failure();
|
1148 | 1233 | }
|
1149 |
| - op->setOperand(memRefOperandPos, newMemRef); |
1150 |
| - return success(); |
| 1234 | + |
| 1235 | + // Check if it is a memref.load |
| 1236 | + auto memrefLoad = dyn_cast<memref::LoadOp>(op); |
| 1237 | + bool isReductionLike = |
| 1238 | + indexRemap.getNumResults() < indexRemap.getNumInputs(); |
| 1239 | + if (!memrefLoad || !isReductionLike) { |
| 1240 | + op->setOperand(memRefOperandPos, newMemRef); |
| 1241 | + return success(); |
| 1242 | + } |
| 1243 | + |
| 1244 | + return transformMemRefLoadWithReducedRank( |
| 1245 | + op, oldMemRef, newMemRef, memRefOperandPos, extraIndices, extraOperands, |
| 1246 | + symbolOperands, indexRemap); |
1151 | 1247 | }
|
1152 | 1248 | // Perform index rewrites for the dereferencing op and then replace the op
|
1153 | 1249 | NamedAttribute oldMapAttrPair =
|
|
0 commit comments