
[mlir][memref] Add memref alias folders for expand/collapse_shape for vector load/store #95223


Merged
merged 3 commits into llvm:main on Jun 12, 2024

Conversation

Groverkss
Member

This patch adds patterns to fold memref aliases for expand_shape/collapse_shape feeding into vector.load/vector.store and vector.maskedload/vector.maskedstore.
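
For context, a minimal sketch (not part of this change) of how these folders are typically driven: `populateFoldMemRefAliasOpPatterns` registers them, and a greedy rewrite applies them to a fixpoint. The helper name `foldAliasOps` and the exact include paths are illustrative assumptions, not code from this PR.

```cpp
// Sketch: register the memref-alias folders (including the new
// vector.load/store cases added here) and run them with the greedy driver.
// `foldAliasOps` is a hypothetical helper; the populate function and the
// driver are existing MLIR APIs.
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

static LogicalResult foldAliasOps(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  // Adds the subview/expand_shape/collapse_shape folders for affine, memref,
  // and (with this patch) vector load/store ops.
  memref::populateFoldMemRefAliasOpPatterns(patterns);
  return applyPatternsAndFoldGreedily(root, std::move(patterns));
}
```

In-tree, the `.mlir` test file touched by this PR exercises the same patterns via `mlir-opt`'s `-fold-memref-alias-ops` test pass.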

@llvmbot
Member

llvmbot commented Jun 12, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-memref

Author: Kunwar Grover (Groverkss)

Changes

This patch adds patterns to fold memref aliases for expand_shape/collapse_shape feeding into vector.load/vector.store and vector.maskedload/vector.maskedstore.


Full diff: https://github.com/llvm/llvm-project/pull/95223.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp (+79-11)
  • (modified) mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir (+156-8)
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index db085b386483c..96daf4c5972a4 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -518,10 +518,25 @@ LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
           loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices)))
     return failure();
   llvm::TypeSwitch<Operation *, void>(loadOp)
-      .Case<affine::AffineLoadOp, memref::LoadOp>([&](auto op) {
-        rewriter.replaceOpWithNewOp<decltype(op)>(
+      .Case([&](affine::AffineLoadOp op) {
+        rewriter.replaceOpWithNewOp<affine::AffineLoadOp>(
             loadOp, expandShapeOp.getViewSource(), sourceIndices);
       })
+      .Case([&](memref::LoadOp op) {
+        rewriter.replaceOpWithNewOp<memref::LoadOp>(
+            loadOp, expandShapeOp.getViewSource(), sourceIndices,
+            op.getNontemporal());
+      })
+      .Case([&](vector::LoadOp op) {
+        rewriter.replaceOpWithNewOp<vector::LoadOp>(
+            op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
+            op.getNontemporal());
+      })
+      .Case([&](vector::MaskedLoadOp op) {
+        rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
+            op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
+            op.getMask(), op.getPassThru());
+      })
       .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
   return success();
 }
@@ -551,10 +566,25 @@ LogicalResult LoadOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
           loadOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices)))
     return failure();
   llvm::TypeSwitch<Operation *, void>(loadOp)
-      .Case<affine::AffineLoadOp, memref::LoadOp>([&](auto op) {
-        rewriter.replaceOpWithNewOp<decltype(op)>(
+      .Case([&](affine::AffineLoadOp op) {
+        rewriter.replaceOpWithNewOp<affine::AffineLoadOp>(
             loadOp, collapseShapeOp.getViewSource(), sourceIndices);
       })
+      .Case([&](memref::LoadOp op) {
+        rewriter.replaceOpWithNewOp<memref::LoadOp>(
+            loadOp, collapseShapeOp.getViewSource(), sourceIndices,
+            op.getNontemporal());
+      })
+      .Case([&](vector::LoadOp op) {
+        rewriter.replaceOpWithNewOp<vector::LoadOp>(
+            op, op.getType(), collapseShapeOp.getViewSource(), sourceIndices,
+            op.getNontemporal());
+      })
+      .Case([&](vector::MaskedLoadOp op) {
+        rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
+            op, op.getType(), collapseShapeOp.getViewSource(), sourceIndices,
+            op.getMask(), op.getPassThru());
+      })
       .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
   return success();
 }
@@ -651,10 +681,25 @@ LogicalResult StoreOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
           storeOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices)))
     return failure();
   llvm::TypeSwitch<Operation *, void>(storeOp)
-      .Case<affine::AffineStoreOp, memref::StoreOp>([&](auto op) {
-        rewriter.replaceOpWithNewOp<decltype(op)>(storeOp, storeOp.getValue(),
-                                                  expandShapeOp.getViewSource(),
-                                                  sourceIndices);
+      .Case([&](affine::AffineStoreOp op) {
+        rewriter.replaceOpWithNewOp<affine::AffineStoreOp>(
+            storeOp, op.getValueToStore(), expandShapeOp.getViewSource(),
+            sourceIndices);
+      })
+      .Case([&](memref::StoreOp op) {
+        rewriter.replaceOpWithNewOp<memref::StoreOp>(
+            storeOp, op.getValueToStore(), expandShapeOp.getViewSource(),
+            sourceIndices, op.getNontemporal());
+      })
+      .Case([&](vector::StoreOp op) {
+        rewriter.replaceOpWithNewOp<vector::StoreOp>(
+            op, op.getValueToStore(), expandShapeOp.getViewSource(),
+            sourceIndices, op.getNontemporal());
+      })
+      .Case([&](vector::MaskedStoreOp op) {
+        rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
+            op, expandShapeOp.getViewSource(), sourceIndices, op.getMask(),
+            op.getValueToStore());
       })
       .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
   return success();
@@ -685,11 +730,26 @@ LogicalResult StoreOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
           storeOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices)))
     return failure();
   llvm::TypeSwitch<Operation *, void>(storeOp)
-      .Case<affine::AffineStoreOp, memref::StoreOp>([&](auto op) {
-        rewriter.replaceOpWithNewOp<decltype(op)>(
-            storeOp, storeOp.getValue(), collapseShapeOp.getViewSource(),
+      .Case([&](affine::AffineStoreOp op) {
+        rewriter.replaceOpWithNewOp<affine::AffineStoreOp>(
+            storeOp, op.getValueToStore(), collapseShapeOp.getViewSource(),
             sourceIndices);
       })
+      .Case([&](memref::StoreOp op) {
+        rewriter.replaceOpWithNewOp<memref::StoreOp>(
+            storeOp, op.getValueToStore(), collapseShapeOp.getViewSource(),
+            sourceIndices, op.getNontemporal());
+      })
+      .Case([&](vector::StoreOp op) {
+        rewriter.replaceOpWithNewOp<vector::StoreOp>(
+            op, op.getValueToStore(), collapseShapeOp.getViewSource(),
+            sourceIndices, op.getNontemporal());
+      })
+      .Case([&](vector::MaskedStoreOp op) {
+        rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
+            op, collapseShapeOp.getViewSource(), sourceIndices, op.getMask(),
+            op.getValueToStore());
+      })
       .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
   return success();
 }
@@ -763,12 +823,20 @@ void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) {
                StoreOpOfSubViewOpFolder<gpu::SubgroupMmaStoreMatrixOp>,
                LoadOpOfExpandShapeOpFolder<affine::AffineLoadOp>,
                LoadOpOfExpandShapeOpFolder<memref::LoadOp>,
+               LoadOpOfExpandShapeOpFolder<vector::LoadOp>,
+               LoadOpOfExpandShapeOpFolder<vector::MaskedLoadOp>,
                StoreOpOfExpandShapeOpFolder<affine::AffineStoreOp>,
                StoreOpOfExpandShapeOpFolder<memref::StoreOp>,
+               StoreOpOfExpandShapeOpFolder<vector::StoreOp>,
+               StoreOpOfExpandShapeOpFolder<vector::MaskedStoreOp>,
                LoadOpOfCollapseShapeOpFolder<affine::AffineLoadOp>,
                LoadOpOfCollapseShapeOpFolder<memref::LoadOp>,
+               LoadOpOfCollapseShapeOpFolder<vector::LoadOp>,
+               LoadOpOfCollapseShapeOpFolder<vector::MaskedLoadOp>,
                StoreOpOfCollapseShapeOpFolder<affine::AffineStoreOp>,
                StoreOpOfCollapseShapeOpFolder<memref::StoreOp>,
+               StoreOpOfCollapseShapeOpFolder<vector::StoreOp>,
+               StoreOpOfCollapseShapeOpFolder<vector::MaskedStoreOp>,
                SubViewOfSubViewFolder, NVGPUAsyncCopyOpSubViewOpFolder>(
       patterns.getContext());
 }
diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index e49dff44ae0d6..d67d6df23f90b 100644
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -819,14 +819,14 @@ func.func @test_ldmatrix(%arg0: memref<4x32x32xf16, 3>, %arg1: index, %arg2: ind
 
 // -----
 
-func.func @fold_vector_load(
+func.func @fold_vector_load_subview(
   %arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index) -> vector<12x32xf32> {
   %0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
   %1 = vector.load %0[] : memref<f32, strided<[], offset: ?>>, vector<12x32xf32>
   return %1 : vector<12x32xf32>
 }
 
-//      CHECK: func @fold_vector_load
+//      CHECK: func @fold_vector_load_subview
 // CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
 // CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
 // CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: index
@@ -834,14 +834,14 @@ func.func @fold_vector_load(
 
 // -----
 
-func.func @fold_vector_maskedload(
+func.func @fold_vector_maskedload_subview(
   %arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3: vector<32xi1>, %arg4: vector<32xf32>) -> vector<32xf32> {
   %0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
   %1 = vector.maskedload %0[], %arg3, %arg4 : memref<f32, strided<[], offset: ?>>, vector<32xi1>, vector<32xf32> into vector<32xf32>
   return %1 : vector<32xf32>
 }
 
-//      CHECK: func @fold_vector_maskedload
+//      CHECK: func @fold_vector_maskedload_subview
 // CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
 // CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
 // CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: index
@@ -851,14 +851,14 @@ func.func @fold_vector_maskedload(
 
 // -----
 
-func.func @fold_vector_store(
+func.func @fold_vector_store_subview(
   %arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3: vector<2x32xf32>) -> () {
   %0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
   vector.store %arg3, %0[] : memref<f32, strided<[], offset: ?>>, vector<2x32xf32>
   return
 }
 
-//      CHECK: func @fold_vector_store
+//      CHECK: func @fold_vector_store_subview
 // CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
 // CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
 // CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: index
@@ -868,14 +868,14 @@ func.func @fold_vector_store(
 
 // -----
 
-func.func @fold_vector_maskedstore(
+func.func @fold_vector_maskedstore_subview(
   %arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3: vector<32xi1>, %arg4: vector<32xf32>) -> () {
   %0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
   vector.maskedstore %0[], %arg3, %arg4 : memref<f32, strided<[], offset: ?>>, vector<32xi1>, vector<32xf32>
   return
 }
 
-//      CHECK: func @fold_vector_maskedstore
+//      CHECK: func @fold_vector_maskedstore_subview
 // CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
 // CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
 // CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: index
@@ -883,3 +883,151 @@ func.func @fold_vector_maskedstore(
 // CHECK-SAME:   %[[ARG4:[a-zA-Z0-9_]+]]: vector<32xf32>
 //      CHECK:   vector.maskedstore %[[ARG0]][%[[ARG1]], %[[ARG2]]], %[[ARG3]], %[[ARG4]] : memref<12x32xf32>, vector<32xi1>, vector<32xf32>
 //      CHECK:   return
+
+// -----
+
+func.func @fold_vector_load_expand_shape(
+  %arg0 : memref<32xf32>, %arg1 : index) -> vector<8xf32> {
+  %c0 = arith.constant 0 : index
+  %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
+  %1 = vector.load %0[%arg1, %c0] : memref<4x8xf32>, vector<8xf32>
+  return %1 : vector<8xf32>
+}
+
+//   CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 * 8)>
+// CHECK-LABEL: func @fold_vector_load_expand_shape
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<32xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
+//       CHECK:   %[[IDX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]]]
+//       CHECK:   vector.load %[[ARG0]][%[[IDX]]]
+
+// -----
+
+func.func @fold_vector_maskedload_expand_shape(
+  %arg0 : memref<32xf32>, %arg1 : index, %arg3: vector<8xi1>, %arg4: vector<8xf32>) -> vector<8xf32> {
+  %c0 = arith.constant 0 : index
+  %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
+  %1 = vector.maskedload %0[%arg1, %c0], %arg3, %arg4 : memref<4x8xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
+  return %1 : vector<8xf32>
+}
+
+//   CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 * 8)>
+// CHECK-LABEL: func @fold_vector_maskedload_expand_shape
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<32xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
+//  CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: vector<8xi1>
+//  CHECK-SAME:   %[[ARG4:[a-zA-Z0-9_]+]]: vector<8xf32>
+//       CHECK:   %[[IDX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]]]
+//       CHECK:   vector.maskedload %[[ARG0]][%[[IDX]]], %[[ARG3]], %[[ARG4]]
+
+// -----
+
+func.func @fold_vector_store_expand_shape(
+  %arg0 : memref<32xf32>, %arg1 : index, %val : vector<8xf32>) {
+  %c0 = arith.constant 0 : index
+  %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
+  vector.store %val, %0[%arg1, %c0] : memref<4x8xf32>, vector<8xf32>
+  return
+}
+
+//   CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 * 8)>
+// CHECK-LABEL: func @fold_vector_store_expand_shape
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<32xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
+//       CHECK:   %[[IDX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]]]
+//       CHECK:   vector.store %{{.*}}, %[[ARG0]][%[[IDX]]]
+
+// -----
+
+func.func @fold_vector_maskedstore_expand_shape(
+  %arg0 : memref<32xf32>, %arg1 : index, %arg3: vector<8xi1>, %arg4: vector<8xf32>) {
+  %c0 = arith.constant 0 : index
+  %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
+  vector.maskedstore %0[%arg1, %c0], %arg3, %arg4 : memref<4x8xf32>, vector<8xi1>, vector<8xf32>
+  return
+}
+
+//   CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 * 8)>
+// CHECK-LABEL: func @fold_vector_maskedstore_expand_shape
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<32xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
+//  CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: vector<8xi1>
+//  CHECK-SAME:   %[[ARG4:[a-zA-Z0-9_]+]]: vector<8xf32>
+//       CHECK:   %[[IDX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]]]
+//       CHECK:   vector.maskedstore %[[ARG0]][%[[IDX]]], %[[ARG3]], %[[ARG4]]
+
+// -----
+
+func.func @fold_vector_load_collapse_shape(
+  %arg0 : memref<4x8xf32>, %arg1 : index) -> vector<8xf32> {
+  %0 = memref.collapse_shape %arg0 [[0, 1]] : memref<4x8xf32> into memref<32xf32>
+  %1 = vector.load %0[%arg1] : memref<32xf32>, vector<8xf32>
+  return %1 : vector<8xf32>
+}
+
+//   CHECK-DAG: #[[$MAP:.*]]  = affine_map<()[s0] -> (s0 floordiv 8)>
+//   CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 mod 8)>
+// CHECK-LABEL: func @fold_vector_load_collapse_shape
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<4x8xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
+//       CHECK:   %[[IDX:.*]] = affine.apply  #[[$MAP]]()[%[[ARG1]]]
+//       CHECK:   %[[IDX1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG1]]]
+//       CHECK:   vector.load %[[ARG0]][%[[IDX]], %[[IDX1]]]
+
+// -----
+
+func.func @fold_vector_maskedload_collapse_shape(
+  %arg0 : memref<4x8xf32>, %arg1 : index, %arg3: vector<8xi1>, %arg4: vector<8xf32>) -> vector<8xf32> {
+  %0 = memref.collapse_shape %arg0 [[0, 1]] : memref<4x8xf32> into memref<32xf32>
+  %1 = vector.maskedload %0[%arg1], %arg3, %arg4 : memref<32xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
+  return %1 : vector<8xf32>
+}
+
+//   CHECK-DAG: #[[$MAP:.*]]  = affine_map<()[s0] -> (s0 floordiv 8)>
+//   CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 mod 8)>
+// CHECK-LABEL: func @fold_vector_maskedload_collapse_shape
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<4x8xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
+//  CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: vector<8xi1>
+//  CHECK-SAME:   %[[ARG4:[a-zA-Z0-9_]+]]: vector<8xf32>
+//       CHECK:   %[[IDX:.*]] = affine.apply  #[[$MAP]]()[%[[ARG1]]]
+//       CHECK:   %[[IDX1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG1]]]
+//       CHECK:   vector.maskedload %[[ARG0]][%[[IDX]], %[[IDX1]]], %[[ARG3]], %[[ARG4]]
+
+// -----
+
+func.func @fold_vector_store_collapse_shape(
+  %arg0 : memref<4x8xf32>, %arg1 : index, %val : vector<8xf32>) {
+  %0 = memref.collapse_shape %arg0 [[0, 1]] : memref<4x8xf32> into memref<32xf32>
+  vector.store %val, %0[%arg1] : memref<32xf32>, vector<8xf32>
+  return
+}
+
+//   CHECK-DAG: #[[$MAP:.*]]  = affine_map<()[s0] -> (s0 floordiv 8)>
+//   CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 mod 8)>
+// CHECK-LABEL: func @fold_vector_store_collapse_shape
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<4x8xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
+//       CHECK:   %[[IDX:.*]] = affine.apply  #[[$MAP]]()[%[[ARG1]]]
+//       CHECK:   %[[IDX1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG1]]]
+//       CHECK:   vector.store %{{.*}}, %[[ARG0]][%[[IDX]], %[[IDX1]]]
+
+// -----
+
+func.func @fold_vector_maskedstore_collapse_shape(
+  %arg0 : memref<4x8xf32>, %arg1 : index, %arg3: vector<8xi1>, %arg4: vector<8xf32>) {
+  %0 = memref.collapse_shape %arg0 [[0, 1]] : memref<4x8xf32> into memref<32xf32>
+  vector.maskedstore %0[%arg1], %arg3, %arg4 : memref<32xf32>, vector<8xi1>, vector<8xf32>
+  return
+}
+
+//   CHECK-DAG: #[[$MAP:.*]]  = affine_map<()[s0] -> (s0 floordiv 8)>
+//   CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 mod 8)>
+// CHECK-LABEL: func @fold_vector_maskedstore_collapse_shape
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<4x8xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
+//  CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: vector<8xi1>
+//  CHECK-SAME:   %[[ARG4:[a-zA-Z0-9_]+]]: vector<8xf32>
+//       CHECK:   %[[IDX:.*]] = affine.apply  #[[$MAP]]()[%[[ARG1]]]
+//       CHECK:   %[[IDX1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG1]]]
+//       CHECK:   vector.maskedstore %[[ARG0]][%[[IDX]], %[[IDX1]]], %[[ARG3]], %[[ARG4]]

Member

@zero9178 zero9178 left a comment


Generally LGTM with comments addressed.

Is the `[arith]` in the title on purpose, or just a copy-paste error or typo, given that these patterns are in MemRef?

Side note: It's funny to see type switches when the folder is already templated on the operation type and could use `if constexpr`. Don't view this as an actionable item though :)
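
Purely to illustrate that side note (not something requested here), a rough sketch of what compile-time dispatch could look like for the load case; `replaceLoad` is a hypothetical helper, and the builder arguments are mirrored from the `Case` bodies in this patch.

```cpp
// Hypothetical sketch of the `if constexpr` alternative: since the folder is
// templated on OpTy, the replacement op can be selected at compile time
// instead of via llvm::TypeSwitch.
#include <type_traits>

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

template <typename OpTy>
static void replaceLoad(PatternRewriter &rewriter, OpTy loadOp, Value source,
                        ValueRange sourceIndices) {
  if constexpr (std::is_same_v<OpTy, affine::AffineLoadOp>) {
    rewriter.replaceOpWithNewOp<affine::AffineLoadOp>(loadOp, source,
                                                      sourceIndices);
  } else if constexpr (std::is_same_v<OpTy, memref::LoadOp>) {
    rewriter.replaceOpWithNewOp<memref::LoadOp>(loadOp, source, sourceIndices,
                                                loadOp.getNontemporal());
  } else if constexpr (std::is_same_v<OpTy, vector::LoadOp>) {
    rewriter.replaceOpWithNewOp<vector::LoadOp>(
        loadOp, loadOp.getType(), source, sourceIndices,
        loadOp.getNontemporal());
  } else if constexpr (std::is_same_v<OpTy, vector::MaskedLoadOp>) {
    rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
        loadOp, loadOp.getType(), source, sourceIndices, loadOp.getMask(),
        loadOp.getPassThru());
  }
}
```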

@Groverkss Groverkss changed the title [mlir][arith] Add memref alias folders for expand/collapse_shape for vector load/store [mlir][memref] Add memref alias folders for expand/collapse_shape for vector load/store Jun 12, 2024
@Groverkss
Member Author

Is the `[arith]` in the title on purpose, or just a copy-paste error or typo, given that these patterns are in MemRef?

Typo, sorry. I don't know where the arith came from haha.

@Groverkss Groverkss requested a review from zero9178 June 12, 2024 14:11
@Groverkss
Member Author

Looks like the Linux CI is stuck (the previous one was stuck for 3h). This issue should be platform-independent, and the Windows CI passes, so I'm landing this.

@Groverkss Groverkss merged commit 57e4360 into llvm:main Jun 12, 2024
5 of 6 checks passed