Skip to content

Commit aabddc9

Browse files
[MLIR][memref] Fix normalization issue in memref.load (#107771)
This change fixes the normalization of memref.load when the associated affine map reduces the rank of the memref, i.e. when the layout map has fewer results than inputs. Fixes #82675. Co-authored-by: Kai Sasaki <[email protected]>
1 parent 04a8bff commit aabddc9

File tree

2 files changed

+136
-2
lines changed

2 files changed

+136
-2
lines changed

mlir/lib/Dialect/Affine/Utils/Utils.cpp

Lines changed: 98 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include "mlir/IR/ImplicitLocOpBuilder.h"
2828
#include "mlir/IR/IntegerSet.h"
2929
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
30+
#include "llvm/Support/LogicalResult.h"
3031
#include <optional>
3132

3233
#define DEBUG_TYPE "affine-utils"
@@ -1093,6 +1094,90 @@ void mlir::affine::affineScalarReplace(func::FuncOp f, DominanceInfo &domInfo,
10931094
op->erase();
10941095
}
10951096

1097+
// Private helper function to transform memref.load with reduced rank.
1098+
// This function will modify the indices of the memref.load to match the
1099+
// newMemRef.
1100+
LogicalResult transformMemRefLoadWithReducedRank(
1101+
Operation *op, Value oldMemRef, Value newMemRef, unsigned memRefOperandPos,
1102+
ArrayRef<Value> extraIndices, ArrayRef<Value> extraOperands,
1103+
ArrayRef<Value> symbolOperands, AffineMap indexRemap) {
1104+
unsigned oldMemRefRank = cast<MemRefType>(oldMemRef.getType()).getRank();
1105+
unsigned newMemRefRank = cast<MemRefType>(newMemRef.getType()).getRank();
1106+
unsigned oldMapNumInputs = oldMemRefRank;
1107+
SmallVector<Value, 4> oldMapOperands(
1108+
op->operand_begin() + memRefOperandPos + 1,
1109+
op->operand_begin() + memRefOperandPos + 1 + oldMapNumInputs);
1110+
SmallVector<Value, 4> oldMemRefOperands;
1111+
oldMemRefOperands.assign(oldMapOperands.begin(), oldMapOperands.end());
1112+
SmallVector<Value, 4> remapOperands;
1113+
remapOperands.reserve(extraOperands.size() + oldMemRefRank +
1114+
symbolOperands.size());
1115+
remapOperands.append(extraOperands.begin(), extraOperands.end());
1116+
remapOperands.append(oldMemRefOperands.begin(), oldMemRefOperands.end());
1117+
remapOperands.append(symbolOperands.begin(), symbolOperands.end());
1118+
1119+
SmallVector<Value, 4> remapOutputs;
1120+
remapOutputs.reserve(oldMemRefRank);
1121+
SmallVector<Value, 4> affineApplyOps;
1122+
1123+
OpBuilder builder(op);
1124+
1125+
if (indexRemap &&
1126+
indexRemap != builder.getMultiDimIdentityMap(indexRemap.getNumDims())) {
1127+
// Remapped indices.
1128+
for (auto resultExpr : indexRemap.getResults()) {
1129+
auto singleResMap = AffineMap::get(
1130+
indexRemap.getNumDims(), indexRemap.getNumSymbols(), resultExpr);
1131+
auto afOp = builder.create<AffineApplyOp>(op->getLoc(), singleResMap,
1132+
remapOperands);
1133+
remapOutputs.push_back(afOp);
1134+
affineApplyOps.push_back(afOp);
1135+
}
1136+
} else {
1137+
// No remapping specified.
1138+
remapOutputs.assign(remapOperands.begin(), remapOperands.end());
1139+
}
1140+
1141+
SmallVector<Value, 4> newMapOperands;
1142+
newMapOperands.reserve(newMemRefRank);
1143+
1144+
// Prepend 'extraIndices' in 'newMapOperands'.
1145+
for (Value extraIndex : extraIndices) {
1146+
assert((isValidDim(extraIndex) || isValidSymbol(extraIndex)) &&
1147+
"invalid memory op index");
1148+
newMapOperands.push_back(extraIndex);
1149+
}
1150+
1151+
// Append 'remapOutputs' to 'newMapOperands'.
1152+
newMapOperands.append(remapOutputs.begin(), remapOutputs.end());
1153+
1154+
// Create new fully composed AffineMap for new op to be created.
1155+
assert(newMapOperands.size() == newMemRefRank);
1156+
1157+
OperationState state(op->getLoc(), op->getName());
1158+
// Construct the new operation using this memref.
1159+
state.operands.reserve(newMapOperands.size() + extraIndices.size());
1160+
state.operands.push_back(newMemRef);
1161+
1162+
// Insert the new memref map operands.
1163+
state.operands.append(newMapOperands.begin(), newMapOperands.end());
1164+
1165+
state.types.reserve(op->getNumResults());
1166+
for (auto result : op->getResults())
1167+
state.types.push_back(result.getType());
1168+
1169+
// Copy over the attributes from the old operation to the new operation.
1170+
for (auto namedAttr : op->getAttrs()) {
1171+
state.attributes.push_back(namedAttr);
1172+
}
1173+
1174+
// Create the new operation.
1175+
auto *repOp = builder.create(state);
1176+
op->replaceAllUsesWith(repOp);
1177+
op->erase();
1178+
1179+
return success();
1180+
}
10961181
// Perform the replacement in `op`.
10971182
LogicalResult mlir::affine::replaceAllMemRefUsesWith(
10981183
Value oldMemRef, Value newMemRef, Operation *op,
@@ -1146,8 +1231,19 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
11461231
// is set.
11471232
return failure();
11481233
}
1149-
op->setOperand(memRefOperandPos, newMemRef);
1150-
return success();
1234+
1235+
// Check if it is a memref.load
1236+
auto memrefLoad = dyn_cast<memref::LoadOp>(op);
1237+
bool isReductionLike =
1238+
indexRemap.getNumResults() < indexRemap.getNumInputs();
1239+
if (!memrefLoad || !isReductionLike) {
1240+
op->setOperand(memRefOperandPos, newMemRef);
1241+
return success();
1242+
}
1243+
1244+
return transformMemRefLoadWithReducedRank(
1245+
op, oldMemRef, newMemRef, memRefOperandPos, extraIndices, extraOperands,
1246+
symbolOperands, indexRemap);
11511247
}
11521248
// Perform index rewrites for the dereferencing op and then replace the op
11531249
NamedAttribute oldMapAttrPair =

mlir/test/Dialect/MemRef/normalize-memrefs.mlir

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@
33
// This file tests whether the memref type having non-trivial map layouts
44
// are normalized to trivial (identity) layouts.
55

6+
// CHECK-DAG: #[[$REDUCE_MAP1:.*]] = affine_map<(d0, d1) -> ((d0 mod 2) * 2 + d1 mod 2 + (d0 floordiv 2) * 4 + (d1 floordiv 2) * 8)>
7+
// CHECK-DAG: #[[$REDUCE_MAP2:.*]] = affine_map<(d0, d1) -> (d0 mod 2 + (d1 mod 2) * 2 + (d0 floordiv 2) * 8 + (d1 floordiv 2) * 4)>
8+
// CHECK-DAG: #[[$REDUCE_MAP3:.*]] = affine_map<(d0, d1) -> (d0 * 4 + d1)>
9+
610
// CHECK-LABEL: func @permute()
711
func.func @permute() {
812
%A = memref.alloc() : memref<64x256xf32, affine_map<(d0, d1) -> (d1, d0)>>
@@ -363,3 +367,37 @@ func.func @memref_with_strided_offset(%arg0: tensor<128x512xf32>, %arg1: index,
363367
%1 = bufferization.to_tensor %cast : memref<16x512xf32, strided<[?, ?], offset: ?>>
364368
return %1 : tensor<16x512xf32>
365369
}
370+
371+
#map0 = affine_map<(i,k) -> (2 * (i mod 2) + (k mod 2) + 4 * (i floordiv 2) + 8 * (k floordiv 2))>
#map1 = affine_map<(k,j) -> ((k mod 2) + 2 * (j mod 2) + 8 * (k floordiv 2) + 4 * (j floordiv 2))>
#map2 = affine_map<(i,j) -> (4 * i + j)>

// Each layout map above has two inputs and a single result, so every
// `memref.load` below must be rewritten to take one index that is computed by
// an `affine.apply` of the corresponding map. The allocations are captured by
// FileCheck variables so the test does not depend on printer-generated names.
// CHECK-LABEL: func @memref_load_with_reduction_map
func.func @memref_load_with_reduction_map(%arg0 : memref<4x4xf32,#map2>) -> () {
  // CHECK-NOT: memref<4x8xf32>
  // CHECK-NOT: memref<8x4xf32>
  // CHECK-NOT: memref<4x4xf32>
  // CHECK: %[[ALLOC0:.*]] = memref.alloc() : memref<32xf32>
  // CHECK: %[[ALLOC1:.*]] = memref.alloc() : memref<32xf32>
  // CHECK: %[[ALLOC2:.*]] = memref.alloc() : memref<16xf32>
  %0 = memref.alloc() : memref<4x8xf32,#map0>
  %1 = memref.alloc() : memref<8x4xf32,#map1>
  %2 = memref.alloc() : memref<4x4xf32,#map2>
  affine.for %i = 0 to 4 {
    affine.for %j = 0 to 8 {
      affine.for %k = 0 to 8 {
        // CHECK: %[[INDEX0:.*]] = affine.apply #[[$REDUCE_MAP1]](%{{.*}}, %{{.*}})
        // CHECK: memref.load %[[ALLOC0]][%[[INDEX0]]] : memref<32xf32>
        %a = memref.load %0[%i, %k] : memref<4x8xf32,#map0>
        // CHECK: %[[INDEX1:.*]] = affine.apply #[[$REDUCE_MAP2]](%{{.*}}, %{{.*}})
        // CHECK: memref.load %[[ALLOC1]][%[[INDEX1]]] : memref<32xf32>
        %b = memref.load %1[%k, %j] : memref<8x4xf32,#map1>
        // CHECK: %[[INDEX2:.*]] = affine.apply #[[$REDUCE_MAP3]](%{{.*}}, %{{.*}})
        // CHECK: memref.load %[[ALLOC2]][%[[INDEX2]]] : memref<16xf32>
        %c = memref.load %2[%i, %j] : memref<4x4xf32,#map2>
        %3 = arith.mulf %a, %b : f32
        %4 = arith.addf %3, %c : f32
        affine.store %4, %arg0[%i, %j] : memref<4x4xf32,#map2>
      }
    }
  }
  return
}

0 commit comments

Comments
 (0)