[mlir][memref] Add memref alias folders for expand/collapse_shape for vector load/store #95223
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-memref

Author: Kunwar Grover (Groverkss)

Changes

This patch adds patterns to fold memref aliases from expand_shape/collapse_shape ops feeding into vector.load/vector.store and vector.maskedload/vector.maskedstore.

Full diff: https://github.com/llvm/llvm-project/pull/95223.diff

2 Files Affected:
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index db085b386483c..96daf4c5972a4 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -518,10 +518,25 @@ LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices)))
return failure();
llvm::TypeSwitch<Operation *, void>(loadOp)
- .Case<affine::AffineLoadOp, memref::LoadOp>([&](auto op) {
- rewriter.replaceOpWithNewOp<decltype(op)>(
+ .Case([&](affine::AffineLoadOp op) {
+ rewriter.replaceOpWithNewOp<affine::AffineLoadOp>(
loadOp, expandShapeOp.getViewSource(), sourceIndices);
})
+ .Case([&](memref::LoadOp op) {
+ rewriter.replaceOpWithNewOp<memref::LoadOp>(
+ loadOp, expandShapeOp.getViewSource(), sourceIndices,
+ op.getNontemporal());
+ })
+ .Case([&](vector::LoadOp op) {
+ rewriter.replaceOpWithNewOp<vector::LoadOp>(
+ op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
+ op.getNontemporal());
+ })
+ .Case([&](vector::MaskedLoadOp op) {
+ rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
+ op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
+ op.getMask(), op.getPassThru());
+ })
.Default([](Operation *) { llvm_unreachable("unexpected operation."); });
return success();
}
@@ -551,10 +566,25 @@ LogicalResult LoadOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
loadOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices)))
return failure();
llvm::TypeSwitch<Operation *, void>(loadOp)
- .Case<affine::AffineLoadOp, memref::LoadOp>([&](auto op) {
- rewriter.replaceOpWithNewOp<decltype(op)>(
+ .Case([&](affine::AffineLoadOp op) {
+ rewriter.replaceOpWithNewOp<affine::AffineLoadOp>(
loadOp, collapseShapeOp.getViewSource(), sourceIndices);
})
+ .Case([&](memref::LoadOp op) {
+ rewriter.replaceOpWithNewOp<memref::LoadOp>(
+ loadOp, collapseShapeOp.getViewSource(), sourceIndices,
+ op.getNontemporal());
+ })
+ .Case([&](vector::LoadOp op) {
+ rewriter.replaceOpWithNewOp<vector::LoadOp>(
+ op, op.getType(), collapseShapeOp.getViewSource(), sourceIndices,
+ op.getNontemporal());
+ })
+ .Case([&](vector::MaskedLoadOp op) {
+ rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
+ op, op.getType(), collapseShapeOp.getViewSource(), sourceIndices,
+ op.getMask(), op.getPassThru());
+ })
.Default([](Operation *) { llvm_unreachable("unexpected operation."); });
return success();
}
@@ -651,10 +681,25 @@ LogicalResult StoreOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
storeOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices)))
return failure();
llvm::TypeSwitch<Operation *, void>(storeOp)
- .Case<affine::AffineStoreOp, memref::StoreOp>([&](auto op) {
- rewriter.replaceOpWithNewOp<decltype(op)>(storeOp, storeOp.getValue(),
- expandShapeOp.getViewSource(),
- sourceIndices);
+ .Case([&](affine::AffineStoreOp op) {
+ rewriter.replaceOpWithNewOp<affine::AffineStoreOp>(
+ storeOp, op.getValueToStore(), expandShapeOp.getViewSource(),
+ sourceIndices);
+ })
+ .Case([&](memref::StoreOp op) {
+ rewriter.replaceOpWithNewOp<memref::StoreOp>(
+ storeOp, op.getValueToStore(), expandShapeOp.getViewSource(),
+ sourceIndices, op.getNontemporal());
+ })
+ .Case([&](vector::StoreOp op) {
+ rewriter.replaceOpWithNewOp<vector::StoreOp>(
+ op, op.getValueToStore(), expandShapeOp.getViewSource(),
+ sourceIndices, op.getNontemporal());
+ })
+ .Case([&](vector::MaskedStoreOp op) {
+ rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
+ op, expandShapeOp.getViewSource(), sourceIndices, op.getMask(),
+ op.getValueToStore());
})
.Default([](Operation *) { llvm_unreachable("unexpected operation."); });
return success();
@@ -685,11 +730,26 @@ LogicalResult StoreOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
storeOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices)))
return failure();
llvm::TypeSwitch<Operation *, void>(storeOp)
- .Case<affine::AffineStoreOp, memref::StoreOp>([&](auto op) {
- rewriter.replaceOpWithNewOp<decltype(op)>(
- storeOp, storeOp.getValue(), collapseShapeOp.getViewSource(),
+ .Case([&](affine::AffineStoreOp op) {
+ rewriter.replaceOpWithNewOp<affine::AffineStoreOp>(
+ storeOp, op.getValueToStore(), collapseShapeOp.getViewSource(),
sourceIndices);
})
+ .Case([&](memref::StoreOp op) {
+ rewriter.replaceOpWithNewOp<memref::StoreOp>(
+ storeOp, op.getValueToStore(), collapseShapeOp.getViewSource(),
+ sourceIndices, op.getNontemporal());
+ })
+ .Case([&](vector::StoreOp op) {
+ rewriter.replaceOpWithNewOp<vector::StoreOp>(
+ op, op.getValueToStore(), collapseShapeOp.getViewSource(),
+ sourceIndices, op.getNontemporal());
+ })
+ .Case([&](vector::MaskedStoreOp op) {
+ rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
+ op, collapseShapeOp.getViewSource(), sourceIndices, op.getMask(),
+ op.getValueToStore());
+ })
.Default([](Operation *) { llvm_unreachable("unexpected operation."); });
return success();
}
@@ -763,12 +823,20 @@ void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) {
StoreOpOfSubViewOpFolder<gpu::SubgroupMmaStoreMatrixOp>,
LoadOpOfExpandShapeOpFolder<affine::AffineLoadOp>,
LoadOpOfExpandShapeOpFolder<memref::LoadOp>,
+ LoadOpOfExpandShapeOpFolder<vector::LoadOp>,
+ LoadOpOfExpandShapeOpFolder<vector::MaskedLoadOp>,
StoreOpOfExpandShapeOpFolder<affine::AffineStoreOp>,
StoreOpOfExpandShapeOpFolder<memref::StoreOp>,
+ StoreOpOfExpandShapeOpFolder<vector::StoreOp>,
+ StoreOpOfExpandShapeOpFolder<vector::MaskedStoreOp>,
LoadOpOfCollapseShapeOpFolder<affine::AffineLoadOp>,
LoadOpOfCollapseShapeOpFolder<memref::LoadOp>,
+ LoadOpOfCollapseShapeOpFolder<vector::LoadOp>,
+ LoadOpOfCollapseShapeOpFolder<vector::MaskedLoadOp>,
StoreOpOfCollapseShapeOpFolder<affine::AffineStoreOp>,
StoreOpOfCollapseShapeOpFolder<memref::StoreOp>,
+ StoreOpOfCollapseShapeOpFolder<vector::StoreOp>,
+ StoreOpOfCollapseShapeOpFolder<vector::MaskedStoreOp>,
SubViewOfSubViewFolder, NVGPUAsyncCopyOpSubViewOpFolder>(
patterns.getContext());
}
diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index e49dff44ae0d6..d67d6df23f90b 100644
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -819,14 +819,14 @@ func.func @test_ldmatrix(%arg0: memref<4x32x32xf16, 3>, %arg1: index, %arg2: ind
// -----
-func.func @fold_vector_load(
+func.func @fold_vector_load_subview(
%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index) -> vector<12x32xf32> {
%0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
%1 = vector.load %0[] : memref<f32, strided<[], offset: ?>>, vector<12x32xf32>
return %1 : vector<12x32xf32>
}
-// CHECK: func @fold_vector_load
+// CHECK: func @fold_vector_load_subview
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
@@ -834,14 +834,14 @@ func.func @fold_vector_load(
// -----
-func.func @fold_vector_maskedload(
+func.func @fold_vector_maskedload_subview(
%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3: vector<32xi1>, %arg4: vector<32xf32>) -> vector<32xf32> {
%0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
%1 = vector.maskedload %0[], %arg3, %arg4 : memref<f32, strided<[], offset: ?>>, vector<32xi1>, vector<32xf32> into vector<32xf32>
return %1 : vector<32xf32>
}
-// CHECK: func @fold_vector_maskedload
+// CHECK: func @fold_vector_maskedload_subview
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
@@ -851,14 +851,14 @@ func.func @fold_vector_maskedload(
// -----
-func.func @fold_vector_store(
+func.func @fold_vector_store_subview(
%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3: vector<2x32xf32>) -> () {
%0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
vector.store %arg3, %0[] : memref<f32, strided<[], offset: ?>>, vector<2x32xf32>
return
}
-// CHECK: func @fold_vector_store
+// CHECK: func @fold_vector_store_subview
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
@@ -868,14 +868,14 @@ func.func @fold_vector_store(
// -----
-func.func @fold_vector_maskedstore(
+func.func @fold_vector_maskedstore_subview(
%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3: vector<32xi1>, %arg4: vector<32xf32>) -> () {
%0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
vector.maskedstore %0[], %arg3, %arg4 : memref<f32, strided<[], offset: ?>>, vector<32xi1>, vector<32xf32>
return
}
-// CHECK: func @fold_vector_maskedstore
+// CHECK: func @fold_vector_maskedstore_subview
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
@@ -883,3 +883,151 @@ func.func @fold_vector_maskedstore(
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: vector<32xf32>
// CHECK: vector.maskedstore %[[ARG0]][%[[ARG1]], %[[ARG2]]], %[[ARG3]], %[[ARG4]] : memref<12x32xf32>, vector<32xi1>, vector<32xf32>
// CHECK: return
+
+// -----
+
+func.func @fold_vector_load_expand_shape(
+ %arg0 : memref<32xf32>, %arg1 : index) -> vector<8xf32> {
+ %c0 = arith.constant 0 : index
+ %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
+ %1 = vector.load %0[%arg1, %c0] : memref<4x8xf32>, vector<8xf32>
+ return %1 : vector<8xf32>
+}
+
+// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 * 8)>
+// CHECK-LABEL: func @fold_vector_load_expand_shape
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<32xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK: %[[IDX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]]]
+// CHECK: vector.load %[[ARG0]][%[[IDX]]]
+
+// -----
+
+func.func @fold_vector_maskedload_expand_shape(
+ %arg0 : memref<32xf32>, %arg1 : index, %arg3: vector<8xi1>, %arg4: vector<8xf32>) -> vector<8xf32> {
+ %c0 = arith.constant 0 : index
+ %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
+ %1 = vector.maskedload %0[%arg1, %c0], %arg3, %arg4 : memref<4x8xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
+ return %1 : vector<8xf32>
+}
+
+// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 * 8)>
+// CHECK-LABEL: func @fold_vector_maskedload_expand_shape
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<32xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: vector<8xi1>
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: vector<8xf32>
+// CHECK: %[[IDX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]]]
+// CHECK: vector.maskedload %[[ARG0]][%[[IDX]]], %[[ARG3]], %[[ARG4]]
+
+// -----
+
+func.func @fold_vector_store_expand_shape(
+ %arg0 : memref<32xf32>, %arg1 : index, %val : vector<8xf32>) {
+ %c0 = arith.constant 0 : index
+ %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
+ vector.store %val, %0[%arg1, %c0] : memref<4x8xf32>, vector<8xf32>
+ return
+}
+
+// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 * 8)>
+// CHECK-LABEL: func @fold_vector_store_expand_shape
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<32xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK: %[[IDX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]]]
+// CHECK: vector.store %{{.*}}, %[[ARG0]][%[[IDX]]]
+
+// -----
+
+func.func @fold_vector_maskedstore_expand_shape(
+ %arg0 : memref<32xf32>, %arg1 : index, %arg3: vector<8xi1>, %arg4: vector<8xf32>) {
+ %c0 = arith.constant 0 : index
+ %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
+ vector.maskedstore %0[%arg1, %c0], %arg3, %arg4 : memref<4x8xf32>, vector<8xi1>, vector<8xf32>
+ return
+}
+
+// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 * 8)>
+// CHECK-LABEL: func @fold_vector_maskedstore_expand_shape
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<32xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: vector<8xi1>
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: vector<8xf32>
+// CHECK: %[[IDX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]]]
+// CHECK: vector.maskedstore %[[ARG0]][%[[IDX]]], %[[ARG3]], %[[ARG4]]
+
+// -----
+
+func.func @fold_vector_load_collapse_shape(
+ %arg0 : memref<4x8xf32>, %arg1 : index) -> vector<8xf32> {
+ %0 = memref.collapse_shape %arg0 [[0, 1]] : memref<4x8xf32> into memref<32xf32>
+ %1 = vector.load %0[%arg1] : memref<32xf32>, vector<8xf32>
+ return %1 : vector<8xf32>
+}
+
+// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 8)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 mod 8)>
+// CHECK-LABEL: func @fold_vector_load_collapse_shape
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<4x8xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK: %[[IDX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]]]
+// CHECK: %[[IDX1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG1]]]
+// CHECK: vector.load %[[ARG0]][%[[IDX]], %[[IDX1]]]
+
+// -----
+
+func.func @fold_vector_maskedload_collapse_shape(
+ %arg0 : memref<4x8xf32>, %arg1 : index, %arg3: vector<8xi1>, %arg4: vector<8xf32>) -> vector<8xf32> {
+ %0 = memref.collapse_shape %arg0 [[0, 1]] : memref<4x8xf32> into memref<32xf32>
+ %1 = vector.maskedload %0[%arg1], %arg3, %arg4 : memref<32xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
+ return %1 : vector<8xf32>
+}
+
+// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 8)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 mod 8)>
+// CHECK-LABEL: func @fold_vector_maskedload_collapse_shape
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<4x8xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: vector<8xi1>
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: vector<8xf32>
+// CHECK: %[[IDX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]]]
+// CHECK: %[[IDX1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG1]]]
+// CHECK: vector.maskedload %[[ARG0]][%[[IDX]], %[[IDX1]]], %[[ARG3]], %[[ARG4]]
+
+// -----
+
+func.func @fold_vector_store_collapse_shape(
+ %arg0 : memref<4x8xf32>, %arg1 : index, %val : vector<8xf32>) {
+ %0 = memref.collapse_shape %arg0 [[0, 1]] : memref<4x8xf32> into memref<32xf32>
+ vector.store %val, %0[%arg1] : memref<32xf32>, vector<8xf32>
+ return
+}
+
+// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 8)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 mod 8)>
+// CHECK-LABEL: func @fold_vector_store_collapse_shape
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<4x8xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK: %[[IDX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]]]
+// CHECK: %[[IDX1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG1]]]
+// CHECK: vector.store %{{.*}}, %[[ARG0]][%[[IDX]], %[[IDX1]]]
+
+// -----
+
+func.func @fold_vector_maskedstore_collapse_shape(
+ %arg0 : memref<4x8xf32>, %arg1 : index, %arg3: vector<8xi1>, %arg4: vector<8xf32>) {
+ %0 = memref.collapse_shape %arg0 [[0, 1]] : memref<4x8xf32> into memref<32xf32>
+ vector.maskedstore %0[%arg1], %arg3, %arg4 : memref<32xf32>, vector<8xi1>, vector<8xf32>
+ return
+}
+
+// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 8)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 mod 8)>
+// CHECK-LABEL: func @fold_vector_maskedstore_collapse_shape
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<4x8xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: vector<8xi1>
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: vector<8xf32>
+// CHECK: %[[IDX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]]]
+// CHECK: %[[IDX1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG1]]]
+// CHECK: vector.maskedstore %[[ARG0]][%[[IDX]], %[[IDX1]]], %[[ARG3]], %[[ARG4]]
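For context, the new folders are exposed through the existing `populateFoldMemRefAliasOpPatterns` entry point shown in the diff, so any pass that already applies these patterns picks up the vector load/store cases automatically. Below is a minimal sketch of such a driver; the header paths and the `applyPatternsAndFoldGreedily` call are assumptions about the surrounding MLIR API at the time of this PR, not part of the patch itself.

```cpp
// Hypothetical driver showing how the alias-folding patterns are applied.
// populateFoldMemRefAliasOpPatterns is the entry point changed in this patch;
// the include paths and greedy-driver call are assumed, not from the diff.
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/IR/Operation.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Folds memref.subview/expand_shape/collapse_shape producers into their
// load/store users (now including vector.load/store and the masked variants)
// anywhere under `root`.
static LogicalResult foldMemRefAliases(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  memref::populateFoldMemRefAliasOpPatterns(patterns);
  return applyPatternsAndFoldGreedily(root, std::move(patterns));
}
```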
Generally LGTM with comments addressed.
Is the `[arith]` in the title on purpose, or just a copy-paste error or typo? Given that these patterns are in MemRef.
Side note: it's funny to see type switches when the pattern is already templated on the operation type and could use `if constexpr`s. Don't view this as an actionable item though :)
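For illustration, here is a rough sketch of the `if constexpr` alternative this side note alludes to: since the folder is already instantiated with a concrete load op type, the dispatch could be resolved at compile time instead of through a runtime `TypeSwitch`. The helper name is hypothetical and this is not code from the patch; the builder calls mirror the ones in the diff above.

```cpp
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/PatternMatch.h"
#include <type_traits>

using namespace mlir;

// Hypothetical compile-time dispatch: OpTy is the concrete load op the
// pattern was instantiated with, so each branch is selected at compile time
// and the unused branches are discarded.
template <typename OpTy>
static void replaceLoadOp(PatternRewriter &rewriter, OpTy loadOp,
                          Value sourceMemRef, ValueRange sourceIndices) {
  if constexpr (std::is_same_v<OpTy, vector::LoadOp>) {
    rewriter.replaceOpWithNewOp<vector::LoadOp>(
        loadOp, loadOp.getType(), sourceMemRef, sourceIndices,
        loadOp.getNontemporal());
  } else if constexpr (std::is_same_v<OpTy, vector::MaskedLoadOp>) {
    rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
        loadOp, loadOp.getType(), sourceMemRef, sourceIndices,
        loadOp.getMask(), loadOp.getPassThru());
  } else if constexpr (std::is_same_v<OpTy, memref::LoadOp>) {
    rewriter.replaceOpWithNewOp<memref::LoadOp>(
        loadOp, sourceMemRef, sourceIndices, loadOp.getNontemporal());
  } else {
    // affine::AffineLoadOp, the remaining instantiation in this pattern set.
    rewriter.replaceOpWithNewOp<affine::AffineLoadOp>(loadOp, sourceMemRef,
                                                      sourceIndices);
  }
}
```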
Typo, sorry. I don't know where the arith came from haha.
Looks like the Linux CI is stuck (the previous run was stuck for 3h). This issue should be platform-independent and the Windows CI passes, so I'm landing this.
This patch adds patterns to fold memref aliases from expand_shape/collapse_shape ops feeding into vector.load/vector.store and vector.maskedload/vector.maskedstore.