-
Notifications
You must be signed in to change notification settings - Fork 14.4k
[mlir][MemRef] Use specialized index ops to fold expand/collapse_shape #138930
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][MemRef] Use specialized index ops to fold expand/collapse_shape #138930
Conversation
PRs this depends on: |
@llvm/pr-subscribers-mlir-memref Author: Krzysztof Drewniak (krzysz00) ChangesThis PR updates the FoldMemRefAliasOps to use This also loosens some limitations of the pass:
Patch is 31.32 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/138930.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index d6d8161d3117b..f34b5b46cab50 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1342,14 +1342,14 @@ def MemRef_ReinterpretCastOp
according to specified offsets, sizes, and strides.
```mlir
- %result1 = memref.reinterpret_cast %arg0 to
+ %result1 = memref.reinterpret_cast %arg0 to
offset: [9],
sizes: [4, 4],
strides: [16, 2]
: memref<8x8xf32, strided<[8, 1], offset: 0>> to
memref<4x4xf32, strided<[16, 2], offset: 9>>
- %result2 = memref.reinterpret_cast %result1 to
+ %result2 = memref.reinterpret_cast %result1 to
offset: [0],
sizes: [2, 2],
strides: [4, 2]
@@ -1755,6 +1755,12 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
OpBuilder &b, Location loc, MemRefType expandedType,
ArrayRef<ReassociationIndices> reassociation,
ArrayRef<OpFoldResult> inputShape);
+
+ // Return a vector with all the static and dynamic values in the output shape.
+ SmallVector<OpFoldResult> getMixedOutputShape() {
+ OpBuilder builder(getContext());
+ return ::mlir::getMixedValues(getStaticOutputShape(), getOutputShape(), builder);
+ }
}];
let hasVerifier = 1;
@@ -1873,7 +1879,7 @@ def MemRef_StoreOp : MemRef_Op<"store",
let summary = "store operation";
let description = [{
The `store` op stores an element into a memref at the specified indices.
-
+
The number of indices must match the rank of the memref. The indices must
be in-bounds: `0 <= idx < dim_size`
@@ -2025,7 +2031,7 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
Unlike the `reinterpret_cast`, the values are relative to the strided
memref of the input (`%result1` in this case) and not its
underlying memory.
-
+
Example 2:
```mlir
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index e4fb3f9bb87ed..2acb90613e5d1 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -59,92 +59,28 @@ using namespace mlir;
///
/// %2 = load %0[6 * i1 + i2, %i3] :
/// memref<12x42xf32>
-static LogicalResult
-resolveSourceIndicesExpandShape(Location loc, PatternRewriter &rewriter,
- memref::ExpandShapeOp expandShapeOp,
- ValueRange indices,
- SmallVectorImpl<Value> &sourceIndices) {
- // Record the rewriter context for constructing ops later.
- MLIRContext *ctx = rewriter.getContext();
-
- // Capture expand_shape's input dimensions as `SmallVector<OpFoldResult>`.
- // This is done for the purpose of inferring the output shape via
- // `inferExpandOutputShape` which will in turn be used for suffix product
- // calculation later.
- SmallVector<OpFoldResult> srcShape;
- MemRefType srcType = expandShapeOp.getSrcType();
-
- for (int64_t i = 0, e = srcType.getRank(); i < e; ++i) {
- if (srcType.isDynamicDim(i)) {
- srcShape.push_back(
- rewriter.create<memref::DimOp>(loc, expandShapeOp.getSrc(), i)
- .getResult());
- } else {
- srcShape.push_back(rewriter.getIndexAttr(srcType.getShape()[i]));
- }
- }
-
- auto outputShape = inferExpandShapeOutputShape(
- rewriter, loc, expandShapeOp.getResultType(),
- expandShapeOp.getReassociationIndices(), srcShape);
- if (!outputShape.has_value())
- return failure();
+static LogicalResult resolveSourceIndicesExpandShape(
+ Location loc, PatternRewriter &rewriter,
+ memref::ExpandShapeOp expandShapeOp, ValueRange indices,
+ SmallVectorImpl<Value> &sourceIndices, bool startsInbounds) {
+ SmallVector<OpFoldResult> destShape = expandShapeOp.getMixedOutputShape();
// Traverse all reassociation groups to determine the appropriate indices
// corresponding to each one of them post op folding.
- for (ArrayRef<int64_t> groups : expandShapeOp.getReassociationIndices()) {
- assert(!groups.empty() && "association indices groups cannot be empty");
- // Flag to indicate the presence of dynamic dimensions in current
- // reassociation group.
- int64_t groupSize = groups.size();
-
- // Group output dimensions utilized in this reassociation group for suffix
- // product calculation.
- SmallVector<OpFoldResult> sizesVal(groupSize);
- for (int64_t i = 0; i < groupSize; ++i) {
- sizesVal[i] = (*outputShape)[groups[i]];
+ for (ArrayRef<int64_t> group : expandShapeOp.getReassociationIndices()) {
+ assert(!group.empty() && "association indices groups cannot be empty");
+ int64_t groupSize = group.size();
+ if (groupSize == 1) {
+ sourceIndices.push_back(indices[group[0]]);
+ continue;
}
-
- // Calculate suffix product of relevant output dimension sizes.
- SmallVector<OpFoldResult> suffixProduct =
- memref::computeSuffixProductIRBlock(loc, rewriter, sizesVal);
-
- // Create affine expression variables for dimensions and symbols in the
- // newly constructed affine map.
- SmallVector<AffineExpr> dims(groupSize), symbols(groupSize);
- bindDimsList<AffineExpr>(ctx, dims);
- bindSymbolsList<AffineExpr>(ctx, symbols);
-
- // Linearize binded dimensions and symbols to construct the resultant
- // affine expression for this indice.
- AffineExpr srcIndexExpr = linearize(ctx, dims, symbols);
-
- // Record the load index corresponding to each dimension in the
- // reassociation group. These are later supplied as operands to the affine
- // map used for calulating relevant index post op folding.
- SmallVector<OpFoldResult> dynamicIndices(groupSize);
- for (int64_t i = 0; i < groupSize; i++)
- dynamicIndices[i] = indices[groups[i]];
-
- // Supply suffix product results followed by load op indices as operands
- // to the map.
- SmallVector<OpFoldResult> mapOperands;
- llvm::append_range(mapOperands, suffixProduct);
- llvm::append_range(mapOperands, dynamicIndices);
-
- // Creating maximally folded and composed affine.apply composes better
- // with other transformations without interleaving canonicalization
- // passes.
- OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
- rewriter, loc,
- AffineMap::get(/*numDims=*/groupSize,
- /*numSymbols=*/groupSize, /*expression=*/srcIndexExpr),
- mapOperands);
-
- // Push index value in the op post folding corresponding to this
- // reassociation group.
- sourceIndices.push_back(
- getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
+ SmallVector<OpFoldResult> groupBasis =
+ llvm::map_to_vector(group, [&](int64_t d) { return destShape[d]; });
+ SmallVector<Value> groupIndices =
+ llvm::map_to_vector(group, [&](int64_t d) { return indices[d]; });
+ Value collapsedIndex = rewriter.create<affine::AffineLinearizeIndexOp>(
+ loc, groupIndices, groupBasis, /*disjoint=*/startsInbounds);
+ sourceIndices.push_back(collapsedIndex);
}
return success();
}
@@ -167,49 +103,34 @@ resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
memref::CollapseShapeOp collapseShapeOp,
ValueRange indices,
SmallVectorImpl<Value> &sourceIndices) {
- int64_t cnt = 0;
- SmallVector<OpFoldResult> dynamicIndices;
- for (ArrayRef<int64_t> groups : collapseShapeOp.getReassociationIndices()) {
- assert(!groups.empty() && "association indices groups cannot be empty");
- dynamicIndices.push_back(indices[cnt++]);
- int64_t groupSize = groups.size();
-
- // Calculate suffix product for all collapse op source dimension sizes
- // except the most major one of each group.
- // We allow the most major source dimension to be dynamic but enforce all
- // others to be known statically.
- SmallVector<int64_t> sizes(groupSize, 1);
- for (int64_t i = 1; i < groupSize; ++i) {
- sizes[i] = collapseShapeOp.getSrcType().getDimSize(groups[i]);
- if (sizes[i] == ShapedType::kDynamic)
- return failure();
- }
- SmallVector<int64_t> suffixProduct = computeSuffixProduct(sizes);
-
- // Derive the index values along all dimensions of the source corresponding
- // to the index wrt to collapsed shape op output.
- auto d0 = rewriter.getAffineDimExpr(0);
- SmallVector<AffineExpr> delinearizingExprs = delinearize(d0, suffixProduct);
-
- // Construct the AffineApplyOp for each delinearizingExpr.
- for (int64_t i = 0; i < groupSize; i++) {
- OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
- rewriter, loc,
- AffineMap::get(/*numDims=*/1, /*numSymbols=*/0,
- delinearizingExprs[i]),
- dynamicIndices);
- sourceIndices.push_back(
- getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
+ MemRefType sourceType = collapseShapeOp.getSrcType();
+ // Note: collapse_shape requires a strided memref, we can do this.
+ auto metadata = rewriter.create<memref::ExtractStridedMetadataOp>(
+ loc, collapseShapeOp.getSrc());
+ SmallVector<OpFoldResult> sourceSizes = metadata.getConstifiedMixedSizes();
+ for (auto [index, group] :
+ llvm::zip(indices, collapseShapeOp.getReassociationIndices())) {
+ assert(!group.empty() && "association indices groups cannot be empty");
+ int64_t groupSize = group.size();
+
+ if (groupSize == 1) {
+ sourceIndices.push_back(index);
+ continue;
}
- dynamicIndices.clear();
+
+ SmallVector<OpFoldResult> basis =
+ llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; });
+ auto delinearize = rewriter.create<affine::AffineDelinearizeIndexOp>(
+ loc, index, basis, /*hasOuterBound=*/true);
+ llvm::append_range(sourceIndices, delinearize.getResults());
}
if (collapseShapeOp.getReassociationIndices().empty()) {
auto zeroAffineMap = rewriter.getConstantAffineMap(0);
int64_t srcRank =
cast<MemRefType>(collapseShapeOp.getViewSource().getType()).getRank();
+ OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
+ rewriter, loc, zeroAffineMap, ArrayRef<OpFoldResult>{});
for (int64_t i = 0; i < srcRank; i++) {
- OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
- rewriter, loc, zeroAffineMap, dynamicIndices);
sourceIndices.push_back(
getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
}
@@ -513,8 +434,12 @@ LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
indices.assign(expandedIndices.begin(), expandedIndices.end());
}
SmallVector<Value> sourceIndices;
+ // memref.load and affine.load guarantee that indexes start inbounds
+ // while the vector operations don't. This impacts if our linearization
+ // is `disjoint`
if (failed(resolveSourceIndicesExpandShape(
- loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices)))
+ loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices,
+ isa<affine::AffineLoadOp, memref::LoadOp>(loadOp.getOperation()))))
return failure();
llvm::TypeSwitch<Operation *, void>(loadOp)
.Case([&](affine::AffineLoadOp op) {
@@ -676,8 +601,12 @@ LogicalResult StoreOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
indices.assign(expandedIndices.begin(), expandedIndices.end());
}
SmallVector<Value> sourceIndices;
+ // memref.store and affine.store guarantee that indexes start inbounds
+ // while the vector operations don't. This impacts if our linearization
+ // is `disjoint`
if (failed(resolveSourceIndicesExpandShape(
- storeOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices)))
+ storeOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices,
+ isa<affine::AffineStoreOp, memref::StoreOp>(storeOp.getOperation()))))
return failure();
llvm::TypeSwitch<Operation *, void>(storeOp)
.Case([&](affine::AffineStoreOp op) {
diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index a27fbf26e13d8..106652623933f 100644
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -408,7 +408,6 @@ func.func @fold_static_stride_subview_with_affine_load_store(%arg0 : memref<12x3
// -----
-// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 6 + s1)>
// CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape
// CHECK-SAME: (%[[ARG0:.*]]: memref<12x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index) -> f32 {
func.func @fold_static_stride_subview_with_affine_load_store_expand_shape(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index) -> f32 {
@@ -416,14 +415,12 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape(%arg0
%1 = affine.load %0[%arg1, %arg2, %arg3] : memref<2x6x32xf32>
return %1 : f32
}
-// CHECK: %[[INDEX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]], %[[ARG2]]]
+// CHECK: %[[INDEX:.*]] = affine.linearize_index disjoint [%[[ARG1]], %[[ARG2]]] by (2, 6)
// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[INDEX]], %[[ARG3]]] : memref<12x32xf32>
// CHECK-NEXT: return %[[RESULT]] : f32
// -----
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0] -> (s0 floordiv 6)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 mod 6)>
// CHECK-LABEL: @fold_static_stride_subview_with_affine_load_store_collapse_shape
// CHECK-SAME: (%[[ARG0:.*]]: memref<2x6x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
func.func @fold_static_stride_subview_with_affine_load_store_collapse_shape(%arg0 : memref<2x6x32xf32>, %arg1 : index, %arg2 : index) -> f32 {
@@ -431,15 +428,12 @@ func.func @fold_static_stride_subview_with_affine_load_store_collapse_shape(%arg
%1 = affine.load %0[%arg1, %arg2] : memref<12x32xf32>
return %1 : f32
}
-// CHECK-NEXT: %[[MODIFIED_INDEX0:.*]] = affine.apply #[[$MAP0]]()[%[[ARG1]]]
-// CHECK-NEXT: %[[MODIFIED_INDEX1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG1]]]
-// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[MODIFIED_INDEX0]], %[[MODIFIED_INDEX1]], %[[ARG2]]] : memref<2x6x32xf32>
+// CHECK-NEXT: %[[MODIFIED_INDEXES:.*]]:2 = affine.delinearize_index %[[ARG1]] into (2, 6)
+// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[MODIFIED_INDEXES]]#0, %[[MODIFIED_INDEXES]]#1, %[[ARG2]]] : memref<2x6x32xf32>
// CHECK-NEXT: return %[[RESULT]] : f32
// -----
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0] -> (s0 floordiv 6)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 mod 6)>
// CHECK-LABEL: @fold_dynamic_size_collapse_shape_with_affine_load
// CHECK-SAME: (%[[ARG0:.*]]: memref<?x6x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
func.func @fold_dynamic_size_collapse_shape_with_affine_load(%arg0 : memref<?x6x32xf32>, %arg1 : index, %arg2 : index) -> f32 {
@@ -447,14 +441,28 @@ func.func @fold_dynamic_size_collapse_shape_with_affine_load(%arg0 : memref<?x6x
%1 = affine.load %0[%arg1, %arg2] : memref<?x32xf32>
return %1 : f32
}
-// CHECK-NEXT: %[[MODIFIED_INDEX0:.*]] = affine.apply #[[$MAP0]]()[%[[ARG1]]]
-// CHECK-NEXT: %[[MODIFIED_INDEX1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG1]]]
-// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[MODIFIED_INDEX0]], %[[MODIFIED_INDEX1]], %[[ARG2]]] : memref<?x6x32xf32>
+// CHECK-NEXT: %{{.*}}, %{{.*}}, %[[SIZES:.*]]:3, %{{.*}}:3 = memref.extract_strided_metadata %[[ARG0]]
+// CHECK-NEXT: %[[MODIFIED_INDEXES:.*]]:2 = affine.delinearize_index %[[ARG1]] into (%[[SIZES]]#0, 6)
+// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[MODIFIED_INDEXES]]#0, %[[MODIFIED_INDEXES]]#1, %[[ARG2]]] : memref<?x6x32xf32>
// CHECK-NEXT: return %[[RESULT]] : f32
// -----
-// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0, s1, s2] -> (s0 * 6 + s1 * 3 + s2)>
+// CHECK-LABEL: @fold_fully_dynamic_size_collapse_shape_with_affine_load
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+func.func @fold_fully_dynamic_size_collapse_shape_with_affine_load(%arg0 : memref<?x?x?xf32>, %arg1 : index, %arg2 : index) -> f32 {
+ %0 = memref.collapse_shape %arg0 [[0, 1], [2]] : memref<?x?x?xf32> into memref<?x?xf32>
+ %1 = affine.load %0[%arg1, %arg2] : memref<?x?xf32>
+ return %1 : f32
+}
+// CHECK-NEXT: %{{.*}}, %{{.*}}, %[[SIZES:.*]]:3, %{{.*}}:3 = memref.extract_strided_metadata %[[ARG0]]
+// CHECK-NEXT: %[[MODIFIED_INDEXES:.*]]:2 = affine.delinearize_index %[[ARG1]] into (%[[SIZES]]#0, %[[SIZES]]#1)
+// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[MODIFIED_INDEXES]]#0, %[[MODIFIED_INDEXES]]#1, %[[ARG2]]] : memref<?x?x?xf32>
+// CHECK-NEXT: return %[[RESULT]] : f32
+
+
+// -----
+
// CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape_3d
// CHECK-SAME: (%[[ARG0:.*]]: memref<12x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index) -> f32 {
func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_3d(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4: index) -> f32 {
@@ -462,7 +470,7 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_3d(%ar
%1 = affine.load %0[%arg1, %arg2, %arg3, %arg4] : memref<2x2x3x32xf32>
return %1 : f32
}
-// CHECK: %[[INDEX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]], %[[ARG2]], %[[ARG3]]]
+// CHECK: %[[INDEX:.*]] = affine.linearize_index disjoint [%[[ARG1]], %[[ARG2]], %[[ARG3]]] by (2, 2, 3)
// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[INDEX]], %[[ARG4]]] : memref<12x32xf32>
// CHECK-NEXT: return %[[RESULT]] : f32
@@ -476,7 +484,10 @@ func.func @fold_dynamic_subview_with_memref_load_expand_shape(%arg0 : memref<16x
%0 = memref.load %expand_shape[%c0, %arg1, %arg2, %c0] {nontemporal = true} : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
return %0 : f32
}
-// CHECK-NEXT: %[[VAL1:.*]] = memref.load %[[ARG0]][%[[ARG1]], %[[ARG2]]] {nontemporal = true} : memref<16x?xf32, strided<[16, 1]>>
+// CHECK-NEXT: %[[C0:.*]] = arith.constant 0
+// CHECK-NEXT: %[[INDEX1:.*]] = affine.linearize_index disjoint [%[[C0]], %[[ARG1]]] by (1, 16)
+// CHECK-NEXT: %[[INDEX2:.*]] = affine.linearize_index disjoint [%[[ARG2]], %[[C0]]] by (%[[ARG3]], 1)
+// CHECK-NEXT: %[[VAL1:.*]] = memref.load %[[ARG0]][%[[INDEX1]], %[[INDEX2]]] {nontemporal = true} : memref<16x?xf32, strided<[16, 1]>>
// CHECK-NEXT: return %[[VAL1]] : f32
// -----
@@ -490,14 +501,16 @@ func.func @fold_dynamic_subview_with_memref_store_expand_shape(%arg0 : memref<16
memref.store %c1f32, %expand_shape[%c0, %arg1, %arg2, %c0] {nontemporal = true} : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
return
}
-// CHECK: %[[C1F32:.*]] = arith.constant 1.000000e+00 : f32
-// CHECK-NEXT: memref.store %[[C1F32]], %[[ARG0]][%[[ARG1]], %[[ARG2]]] {nontemporal = true} : memref<16x?xf32, strided<[16, 1]>>
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1F32:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK-NEXT: %[[INDEX1:.*]] = affine.linearize_index disjoint [%[[C0]], %[[ARG1]]] by (1, 16)
+// CHECK-NEXT: %[[INDEX2:.*]] = affine.linearize_index disjoint [%[[ARG2]], %[[C0]]] by (%[[ARG3]], 1)
+// CHECK-NEXT: memref.store %[[C1F32]], %[[ARG0]][%[[INDEX1]], %[[INDEX2]]] {nontemporal = true} : memref<16x?xf32, strided<[16, 1]>>
// CHECK-NEXT: return
// -----
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0) -> (d0 * 3)>
// CHECK-LABEL: fold_memref_alias_expand_shape_subview_load_store_dynamic_dim
// CHECK-SAME: (%[[ARG0:.*]]: memref<2048x16xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index)
func.func @fold_memref_alias_expand_shape_subview_load_store_dynamic_dim(%alloc: memref<2048x16xf32>, %c10: index, %c5: index, %c0: index, %sz0: index) {
@@ -513,21 +526,20 @@ func.func @fold_memref_alias_expand_shape_subview_load_store_dynamic_dim(%alloc:
}
return
}
+// CHECK-NEXT: %[[C0:.*]] = arith.constant 0...
[truncated]
|
@llvm/pr-subscribers-mlir Author: Krzysztof Drewniak (krzysz00) ChangesThis PR updates the FoldMemRefAliasOps to use This also loosens some limitations of the pass:
Patch is 31.32 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/138930.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index d6d8161d3117b..f34b5b46cab50 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1342,14 +1342,14 @@ def MemRef_ReinterpretCastOp
according to specified offsets, sizes, and strides.
```mlir
- %result1 = memref.reinterpret_cast %arg0 to
+ %result1 = memref.reinterpret_cast %arg0 to
offset: [9],
sizes: [4, 4],
strides: [16, 2]
: memref<8x8xf32, strided<[8, 1], offset: 0>> to
memref<4x4xf32, strided<[16, 2], offset: 9>>
- %result2 = memref.reinterpret_cast %result1 to
+ %result2 = memref.reinterpret_cast %result1 to
offset: [0],
sizes: [2, 2],
strides: [4, 2]
@@ -1755,6 +1755,12 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
OpBuilder &b, Location loc, MemRefType expandedType,
ArrayRef<ReassociationIndices> reassociation,
ArrayRef<OpFoldResult> inputShape);
+
+ // Return a vector with all the static and dynamic values in the output shape.
+ SmallVector<OpFoldResult> getMixedOutputShape() {
+ OpBuilder builder(getContext());
+ return ::mlir::getMixedValues(getStaticOutputShape(), getOutputShape(), builder);
+ }
}];
let hasVerifier = 1;
@@ -1873,7 +1879,7 @@ def MemRef_StoreOp : MemRef_Op<"store",
let summary = "store operation";
let description = [{
The `store` op stores an element into a memref at the specified indices.
-
+
The number of indices must match the rank of the memref. The indices must
be in-bounds: `0 <= idx < dim_size`
@@ -2025,7 +2031,7 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
Unlike the `reinterpret_cast`, the values are relative to the strided
memref of the input (`%result1` in this case) and not its
underlying memory.
-
+
Example 2:
```mlir
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index e4fb3f9bb87ed..2acb90613e5d1 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -59,92 +59,28 @@ using namespace mlir;
///
/// %2 = load %0[6 * i1 + i2, %i3] :
/// memref<12x42xf32>
-static LogicalResult
-resolveSourceIndicesExpandShape(Location loc, PatternRewriter &rewriter,
- memref::ExpandShapeOp expandShapeOp,
- ValueRange indices,
- SmallVectorImpl<Value> &sourceIndices) {
- // Record the rewriter context for constructing ops later.
- MLIRContext *ctx = rewriter.getContext();
-
- // Capture expand_shape's input dimensions as `SmallVector<OpFoldResult>`.
- // This is done for the purpose of inferring the output shape via
- // `inferExpandOutputShape` which will in turn be used for suffix product
- // calculation later.
- SmallVector<OpFoldResult> srcShape;
- MemRefType srcType = expandShapeOp.getSrcType();
-
- for (int64_t i = 0, e = srcType.getRank(); i < e; ++i) {
- if (srcType.isDynamicDim(i)) {
- srcShape.push_back(
- rewriter.create<memref::DimOp>(loc, expandShapeOp.getSrc(), i)
- .getResult());
- } else {
- srcShape.push_back(rewriter.getIndexAttr(srcType.getShape()[i]));
- }
- }
-
- auto outputShape = inferExpandShapeOutputShape(
- rewriter, loc, expandShapeOp.getResultType(),
- expandShapeOp.getReassociationIndices(), srcShape);
- if (!outputShape.has_value())
- return failure();
+static LogicalResult resolveSourceIndicesExpandShape(
+ Location loc, PatternRewriter &rewriter,
+ memref::ExpandShapeOp expandShapeOp, ValueRange indices,
+ SmallVectorImpl<Value> &sourceIndices, bool startsInbounds) {
+ SmallVector<OpFoldResult> destShape = expandShapeOp.getMixedOutputShape();
// Traverse all reassociation groups to determine the appropriate indices
// corresponding to each one of them post op folding.
- for (ArrayRef<int64_t> groups : expandShapeOp.getReassociationIndices()) {
- assert(!groups.empty() && "association indices groups cannot be empty");
- // Flag to indicate the presence of dynamic dimensions in current
- // reassociation group.
- int64_t groupSize = groups.size();
-
- // Group output dimensions utilized in this reassociation group for suffix
- // product calculation.
- SmallVector<OpFoldResult> sizesVal(groupSize);
- for (int64_t i = 0; i < groupSize; ++i) {
- sizesVal[i] = (*outputShape)[groups[i]];
+ for (ArrayRef<int64_t> group : expandShapeOp.getReassociationIndices()) {
+ assert(!group.empty() && "association indices groups cannot be empty");
+ int64_t groupSize = group.size();
+ if (groupSize == 1) {
+ sourceIndices.push_back(indices[group[0]]);
+ continue;
}
-
- // Calculate suffix product of relevant output dimension sizes.
- SmallVector<OpFoldResult> suffixProduct =
- memref::computeSuffixProductIRBlock(loc, rewriter, sizesVal);
-
- // Create affine expression variables for dimensions and symbols in the
- // newly constructed affine map.
- SmallVector<AffineExpr> dims(groupSize), symbols(groupSize);
- bindDimsList<AffineExpr>(ctx, dims);
- bindSymbolsList<AffineExpr>(ctx, symbols);
-
- // Linearize binded dimensions and symbols to construct the resultant
- // affine expression for this indice.
- AffineExpr srcIndexExpr = linearize(ctx, dims, symbols);
-
- // Record the load index corresponding to each dimension in the
- // reassociation group. These are later supplied as operands to the affine
- // map used for calulating relevant index post op folding.
- SmallVector<OpFoldResult> dynamicIndices(groupSize);
- for (int64_t i = 0; i < groupSize; i++)
- dynamicIndices[i] = indices[groups[i]];
-
- // Supply suffix product results followed by load op indices as operands
- // to the map.
- SmallVector<OpFoldResult> mapOperands;
- llvm::append_range(mapOperands, suffixProduct);
- llvm::append_range(mapOperands, dynamicIndices);
-
- // Creating maximally folded and composed affine.apply composes better
- // with other transformations without interleaving canonicalization
- // passes.
- OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
- rewriter, loc,
- AffineMap::get(/*numDims=*/groupSize,
- /*numSymbols=*/groupSize, /*expression=*/srcIndexExpr),
- mapOperands);
-
- // Push index value in the op post folding corresponding to this
- // reassociation group.
- sourceIndices.push_back(
- getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
+ SmallVector<OpFoldResult> groupBasis =
+ llvm::map_to_vector(group, [&](int64_t d) { return destShape[d]; });
+ SmallVector<Value> groupIndices =
+ llvm::map_to_vector(group, [&](int64_t d) { return indices[d]; });
+ Value collapsedIndex = rewriter.create<affine::AffineLinearizeIndexOp>(
+ loc, groupIndices, groupBasis, /*disjoint=*/startsInbounds);
+ sourceIndices.push_back(collapsedIndex);
}
return success();
}
@@ -167,49 +103,34 @@ resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
memref::CollapseShapeOp collapseShapeOp,
ValueRange indices,
SmallVectorImpl<Value> &sourceIndices) {
- int64_t cnt = 0;
- SmallVector<OpFoldResult> dynamicIndices;
- for (ArrayRef<int64_t> groups : collapseShapeOp.getReassociationIndices()) {
- assert(!groups.empty() && "association indices groups cannot be empty");
- dynamicIndices.push_back(indices[cnt++]);
- int64_t groupSize = groups.size();
-
- // Calculate suffix product for all collapse op source dimension sizes
- // except the most major one of each group.
- // We allow the most major source dimension to be dynamic but enforce all
- // others to be known statically.
- SmallVector<int64_t> sizes(groupSize, 1);
- for (int64_t i = 1; i < groupSize; ++i) {
- sizes[i] = collapseShapeOp.getSrcType().getDimSize(groups[i]);
- if (sizes[i] == ShapedType::kDynamic)
- return failure();
- }
- SmallVector<int64_t> suffixProduct = computeSuffixProduct(sizes);
-
- // Derive the index values along all dimensions of the source corresponding
- // to the index wrt to collapsed shape op output.
- auto d0 = rewriter.getAffineDimExpr(0);
- SmallVector<AffineExpr> delinearizingExprs = delinearize(d0, suffixProduct);
-
- // Construct the AffineApplyOp for each delinearizingExpr.
- for (int64_t i = 0; i < groupSize; i++) {
- OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
- rewriter, loc,
- AffineMap::get(/*numDims=*/1, /*numSymbols=*/0,
- delinearizingExprs[i]),
- dynamicIndices);
- sourceIndices.push_back(
- getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
+ MemRefType sourceType = collapseShapeOp.getSrcType();
+ // Note: collapse_shape requires a strided memref, we can do this.
+ auto metadata = rewriter.create<memref::ExtractStridedMetadataOp>(
+ loc, collapseShapeOp.getSrc());
+ SmallVector<OpFoldResult> sourceSizes = metadata.getConstifiedMixedSizes();
+ for (auto [index, group] :
+ llvm::zip(indices, collapseShapeOp.getReassociationIndices())) {
+ assert(!group.empty() && "association indices groups cannot be empty");
+ int64_t groupSize = group.size();
+
+ if (groupSize == 1) {
+ sourceIndices.push_back(index);
+ continue;
}
- dynamicIndices.clear();
+
+ SmallVector<OpFoldResult> basis =
+ llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; });
+ auto delinearize = rewriter.create<affine::AffineDelinearizeIndexOp>(
+ loc, index, basis, /*hasOuterBound=*/true);
+ llvm::append_range(sourceIndices, delinearize.getResults());
}
if (collapseShapeOp.getReassociationIndices().empty()) {
auto zeroAffineMap = rewriter.getConstantAffineMap(0);
int64_t srcRank =
cast<MemRefType>(collapseShapeOp.getViewSource().getType()).getRank();
+ OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
+ rewriter, loc, zeroAffineMap, ArrayRef<OpFoldResult>{});
for (int64_t i = 0; i < srcRank; i++) {
- OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
- rewriter, loc, zeroAffineMap, dynamicIndices);
sourceIndices.push_back(
getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
}
@@ -513,8 +434,12 @@ LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
indices.assign(expandedIndices.begin(), expandedIndices.end());
}
SmallVector<Value> sourceIndices;
+ // memref.load and affine.load guarantee that indexes start inbounds
+ // while the vector operations don't. This impacts if our linearization
+ // is `disjoint`
if (failed(resolveSourceIndicesExpandShape(
- loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices)))
+ loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices,
+ isa<affine::AffineLoadOp, memref::LoadOp>(loadOp.getOperation()))))
return failure();
llvm::TypeSwitch<Operation *, void>(loadOp)
.Case([&](affine::AffineLoadOp op) {
@@ -676,8 +601,12 @@ LogicalResult StoreOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
indices.assign(expandedIndices.begin(), expandedIndices.end());
}
SmallVector<Value> sourceIndices;
+ // memref.store and affine.store guarantee that indexes start inbounds
+ // while the vector operations don't. This impacts if our linearization
+ // is `disjoint`
if (failed(resolveSourceIndicesExpandShape(
- storeOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices)))
+ storeOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices,
+ isa<affine::AffineStoreOp, memref::StoreOp>(storeOp.getOperation()))))
return failure();
llvm::TypeSwitch<Operation *, void>(storeOp)
.Case([&](affine::AffineStoreOp op) {
diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index a27fbf26e13d8..106652623933f 100644
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -408,7 +408,6 @@ func.func @fold_static_stride_subview_with_affine_load_store(%arg0 : memref<12x3
// -----
-// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 6 + s1)>
// CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape
// CHECK-SAME: (%[[ARG0:.*]]: memref<12x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index) -> f32 {
func.func @fold_static_stride_subview_with_affine_load_store_expand_shape(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index) -> f32 {
@@ -416,14 +415,12 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape(%arg0
%1 = affine.load %0[%arg1, %arg2, %arg3] : memref<2x6x32xf32>
return %1 : f32
}
-// CHECK: %[[INDEX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]], %[[ARG2]]]
+// CHECK: %[[INDEX:.*]] = affine.linearize_index disjoint [%[[ARG1]], %[[ARG2]]] by (2, 6)
// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[INDEX]], %[[ARG3]]] : memref<12x32xf32>
// CHECK-NEXT: return %[[RESULT]] : f32
// -----
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0] -> (s0 floordiv 6)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 mod 6)>
// CHECK-LABEL: @fold_static_stride_subview_with_affine_load_store_collapse_shape
// CHECK-SAME: (%[[ARG0:.*]]: memref<2x6x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
func.func @fold_static_stride_subview_with_affine_load_store_collapse_shape(%arg0 : memref<2x6x32xf32>, %arg1 : index, %arg2 : index) -> f32 {
@@ -431,15 +428,12 @@ func.func @fold_static_stride_subview_with_affine_load_store_collapse_shape(%arg
%1 = affine.load %0[%arg1, %arg2] : memref<12x32xf32>
return %1 : f32
}
-// CHECK-NEXT: %[[MODIFIED_INDEX0:.*]] = affine.apply #[[$MAP0]]()[%[[ARG1]]]
-// CHECK-NEXT: %[[MODIFIED_INDEX1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG1]]]
-// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[MODIFIED_INDEX0]], %[[MODIFIED_INDEX1]], %[[ARG2]]] : memref<2x6x32xf32>
+// CHECK-NEXT: %[[MODIFIED_INDEXES:.*]]:2 = affine.delinearize_index %[[ARG1]] into (2, 6)
+// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[MODIFIED_INDEXES]]#0, %[[MODIFIED_INDEXES]]#1, %[[ARG2]]] : memref<2x6x32xf32>
// CHECK-NEXT: return %[[RESULT]] : f32
// -----
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0] -> (s0 floordiv 6)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 mod 6)>
// CHECK-LABEL: @fold_dynamic_size_collapse_shape_with_affine_load
// CHECK-SAME: (%[[ARG0:.*]]: memref<?x6x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
func.func @fold_dynamic_size_collapse_shape_with_affine_load(%arg0 : memref<?x6x32xf32>, %arg1 : index, %arg2 : index) -> f32 {
@@ -447,14 +441,28 @@ func.func @fold_dynamic_size_collapse_shape_with_affine_load(%arg0 : memref<?x6x
%1 = affine.load %0[%arg1, %arg2] : memref<?x32xf32>
return %1 : f32
}
-// CHECK-NEXT: %[[MODIFIED_INDEX0:.*]] = affine.apply #[[$MAP0]]()[%[[ARG1]]]
-// CHECK-NEXT: %[[MODIFIED_INDEX1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG1]]]
-// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[MODIFIED_INDEX0]], %[[MODIFIED_INDEX1]], %[[ARG2]]] : memref<?x6x32xf32>
+// CHECK-NEXT: %{{.*}}, %{{.*}}, %[[SIZES:.*]]:3, %{{.*}}:3 = memref.extract_strided_metadata %[[ARG0]]
+// CHECK-NEXT: %[[MODIFIED_INDEXES:.*]]:2 = affine.delinearize_index %[[ARG1]] into (%[[SIZES]]#0, 6)
+// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[MODIFIED_INDEXES]]#0, %[[MODIFIED_INDEXES]]#1, %[[ARG2]]] : memref<?x6x32xf32>
// CHECK-NEXT: return %[[RESULT]] : f32
// -----
-// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0, s1, s2] -> (s0 * 6 + s1 * 3 + s2)>
+// CHECK-LABEL: @fold_fully_dynamic_size_collapse_shape_with_affine_load
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+func.func @fold_fully_dynamic_size_collapse_shape_with_affine_load(%arg0 : memref<?x?x?xf32>, %arg1 : index, %arg2 : index) -> f32 {
+ %0 = memref.collapse_shape %arg0 [[0, 1], [2]] : memref<?x?x?xf32> into memref<?x?xf32>
+ %1 = affine.load %0[%arg1, %arg2] : memref<?x?xf32>
+ return %1 : f32
+}
+// CHECK-NEXT: %{{.*}}, %{{.*}}, %[[SIZES:.*]]:3, %{{.*}}:3 = memref.extract_strided_metadata %[[ARG0]]
+// CHECK-NEXT: %[[MODIFIED_INDEXES:.*]]:2 = affine.delinearize_index %[[ARG1]] into (%[[SIZES]]#0, %[[SIZES]]#1)
+// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[MODIFIED_INDEXES]]#0, %[[MODIFIED_INDEXES]]#1, %[[ARG2]]] : memref<?x?x?xf32>
+// CHECK-NEXT: return %[[RESULT]] : f32
+
+
+// -----
+
// CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape_3d
// CHECK-SAME: (%[[ARG0:.*]]: memref<12x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index) -> f32 {
func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_3d(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4: index) -> f32 {
@@ -462,7 +470,7 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_3d(%ar
%1 = affine.load %0[%arg1, %arg2, %arg3, %arg4] : memref<2x2x3x32xf32>
return %1 : f32
}
-// CHECK: %[[INDEX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]], %[[ARG2]], %[[ARG3]]]
+// CHECK: %[[INDEX:.*]] = affine.linearize_index disjoint [%[[ARG1]], %[[ARG2]], %[[ARG3]]] by (2, 2, 3)
// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[INDEX]], %[[ARG4]]] : memref<12x32xf32>
// CHECK-NEXT: return %[[RESULT]] : f32
@@ -476,7 +484,10 @@ func.func @fold_dynamic_subview_with_memref_load_expand_shape(%arg0 : memref<16x
%0 = memref.load %expand_shape[%c0, %arg1, %arg2, %c0] {nontemporal = true} : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
return %0 : f32
}
-// CHECK-NEXT: %[[VAL1:.*]] = memref.load %[[ARG0]][%[[ARG1]], %[[ARG2]]] {nontemporal = true} : memref<16x?xf32, strided<[16, 1]>>
+// CHECK-NEXT: %[[C0:.*]] = arith.constant 0
+// CHECK-NEXT: %[[INDEX1:.*]] = affine.linearize_index disjoint [%[[C0]], %[[ARG1]]] by (1, 16)
+// CHECK-NEXT: %[[INDEX2:.*]] = affine.linearize_index disjoint [%[[ARG2]], %[[C0]]] by (%[[ARG3]], 1)
+// CHECK-NEXT: %[[VAL1:.*]] = memref.load %[[ARG0]][%[[INDEX1]], %[[INDEX2]]] {nontemporal = true} : memref<16x?xf32, strided<[16, 1]>>
// CHECK-NEXT: return %[[VAL1]] : f32
// -----
@@ -490,14 +501,16 @@ func.func @fold_dynamic_subview_with_memref_store_expand_shape(%arg0 : memref<16
memref.store %c1f32, %expand_shape[%c0, %arg1, %arg2, %c0] {nontemporal = true} : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
return
}
-// CHECK: %[[C1F32:.*]] = arith.constant 1.000000e+00 : f32
-// CHECK-NEXT: memref.store %[[C1F32]], %[[ARG0]][%[[ARG1]], %[[ARG2]]] {nontemporal = true} : memref<16x?xf32, strided<[16, 1]>>
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1F32:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK-NEXT: %[[INDEX1:.*]] = affine.linearize_index disjoint [%[[C0]], %[[ARG1]]] by (1, 16)
+// CHECK-NEXT: %[[INDEX2:.*]] = affine.linearize_index disjoint [%[[ARG2]], %[[C0]]] by (%[[ARG3]], 1)
+// CHECK-NEXT: memref.store %[[C1F32]], %[[ARG0]][%[[INDEX1]], %[[INDEX2]]] {nontemporal = true} : memref<16x?xf32, strided<[16, 1]>>
// CHECK-NEXT: return
// -----
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0) -> (d0 * 3)>
// CHECK-LABEL: fold_memref_alias_expand_shape_subview_load_store_dynamic_dim
// CHECK-SAME: (%[[ARG0:.*]]: memref<2048x16xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index)
func.func @fold_memref_alias_expand_shape_subview_load_store_dynamic_dim(%alloc: memref<2048x16xf32>, %c10: index, %c5: index, %c0: index, %sz0: index) {
@@ -513,21 +526,20 @@ func.func @fold_memref_alias_expand_shape_subview_load_store_dynamic_dim(%alloc:
}
return
}
+// CHECK-NEXT: %[[C0:.*]] = arith.constant 0...
[truncated]
|
Ping |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The changes look good to me. It isn't strictly required, by given that book h of us work on the same downstream project, does this pass with the said downstream project. But this looks good to me
This PR updates the FoldMemRefAliasOps to use `affine.linearize_index` and `affine.delinearize_index` to perform the index computations needed to fold a `memref.expand_shape` or `memref.collapse_shape` into its consumers, respectively. This also loosens some limitations of the pass: 1. The existing `output_shape` argument to `memref.expand_shape` is now used, eliminating the need to re-infer this shape or call `memref.dim`. 2. Because we're using `affine.delinearize_index`, the restriction that each group in a `memref.collapse_shape` can only have one dynamic dimension is removed.
ae27c6f
to
249e426
Compare
Thanks! |
Integrate torch-mlir@389541fb9ddd Integrate stablehlo@5837b2a6ce # Dropped reverts - Drop revert of the 1:N dialect conversion removal since all dependencies have migrated - Drop revert of the APIntParameter error since all dependencies have migrated - Drop revert of allowing function type conversion to fail since all dependencies have migrated - Drop local modifications to torch-mlir and stablehlo, those are now clean submodules (they're no longer needed now that we've dropped reverts and they've migrated) # Continued reverts - We still have a revert of upstream #137930 since it's not clear the roccertness issue is resolved (that's the SDWA cndmask thing, see llvm/llvm-project#138766 ) - We still have a revert of upstream #133231 since that could still be breaking tests # Changes - Rename bufferization.to_memref to bufferization.to_buffer everewhere - Swap all getSource() on Transfer*Op to getBase() - Rename the memref argument of TransferGatherOp to `base` to match the transfer interface - Remove argument materialization calls - In the one case where this wasn't trivial, migrate to a target materilazition since that's what upstream advice suggested - Handle (in any way that seemed appropriate) the failures that eraseArguments() and eraseResults() can now have - Slightly reshuffle the LLVMCPU pipeline so that there's an `affine-expand-index-ops` after the last `FoldMemRefAliasOps` call, because `FoldMemRefAliasOps` now creates `affine.linearize_index` and `affine.delinearize_index` which don't seem to lower to LLVM right on their own. See llvm/llvm-project#138930 - Update narrow type emulation tests to account for correctness fixes in the lineraized shape determination widget - Update vectorization tests to account for changes in ee47454bb8be and update the attention tiling test to account for an unknown change - Update StableHLO rewrites to account for accuracy arguments
…pse_shape (llvm#138930)" This reverts commit a891163.
This PR updates the FoldMemRefAliasOps to use
affine.linearize_index
andaffine.delinearize_index
to perform the index computations needed to fold amemref.expand_shape
ormemref.collapse_shape
into its consumers, respectively.This also loosens some limitations of the pass:
output_shape
argument tomemref.expand_shape
is now used, eliminating the need to re-infer this shape or callmemref.dim
.affine.delinearize_index
, the restriction that each group in amemref.collapse_shape
can only have one dynamic dimension is removed.