Skip to content

[mlir][vector] Fix FlattenGather for scalable vectors #96074

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 24, 2024

Conversation

c-rhodes
Copy link
Collaborator

This pattern flattens vector.gather ops by unrolling the outermost
dimension for rank > 2 vectors. There's two issues with this pattern for
scalable vectors:

  1. The unrolling doesn't take vscale into account. A constraint is
    added to disable this pattern for vectors with leading scalable
    dims.
  2. The scalable dims are dropped when creating the new gather. Fixed
    by propagating the flags.

Depends on #96049.

@c-rhodes c-rhodes requested a review from MacDue June 19, 2024 13:56
c-rhodes added 2 commits June 20, 2024 07:09
This pattern flattens vector.gather ops by unrolling the outermost
dimension for rank > 2 vectors. There's two issues with this pattern for
scalable vectors:

  1. The unrolling doesn't take vscale into account. A constraint is
     added to disable this pattern for vectors with leading scalable
     dims.
  2. The scalable dims are dropped when creating the new gather. Fixed
     by propagating the flags.

Depends on llvm#96049.
@c-rhodes c-rhodes force-pushed the mlir-vector-scalable-flatten-gather branch from 5e9f5e5 to ae6d11f Compare June 20, 2024 07:19
@c-rhodes c-rhodes marked this pull request as ready for review June 20, 2024 07:20
@llvmbot
Copy link
Member

llvmbot commented Jun 20, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-vector

Author: Cullen Rhodes (c-rhodes)

Changes

This pattern flattens vector.gather ops by unrolling the outermost
dimension for rank > 2 vectors. There's two issues with this pattern for
scalable vectors:

  1. The unrolling doesn't take vscale into account. A constraint is
    added to disable this pattern for vectors with leading scalable
    dims.
  2. The scalable dims are dropped when creating the new gather. Fixed
    by propagating the flags.

Depends on #96049.


Full diff: https://github.com/llvm/llvm-project/pull/96074.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp (+9-1)
  • (modified) mlir/test/Dialect/Vector/vector-gather-lowering.mlir (+26)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
index dd027d107d16a..1abde32450f1e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
@@ -55,6 +55,8 @@ namespace {
 /// ```
 ///
 /// When applied exhaustively, this will produce a sequence of 1-d gather ops.
+///
+/// Supports vector types with trailing scalable dim.
 struct FlattenGather : OpRewritePattern<vector::GatherOp> {
   using OpRewritePattern::OpRewritePattern;
 
@@ -64,6 +66,11 @@ struct FlattenGather : OpRewritePattern<vector::GatherOp> {
     if (resultTy.getRank() < 2)
       return rewriter.notifyMatchFailure(op, "already flat");
 
+    // Unrolling doesn't take vscale into account. Pattern is disabled for
+    // vectors with leading scalable dim(s).
+    if (resultTy.getScalableDims().front())
+      return rewriter.notifyMatchFailure(op, "cannot unroll scalable dim");
+
     Location loc = op.getLoc();
     Value indexVec = op.getIndexVec();
     Value maskVec = op.getMask();
@@ -73,7 +80,8 @@ struct FlattenGather : OpRewritePattern<vector::GatherOp> {
         loc, resultTy, rewriter.getZeroAttr(resultTy));
 
     Type subTy = VectorType::get(resultTy.getShape().drop_front(),
-                                 resultTy.getElementType());
+                                 resultTy.getElementType(),
+                                 resultTy.getScalableDims().drop_front());
 
     for (int64_t i = 0, e = resultTy.getShape().front(); i < e; ++i) {
       int64_t thisIdx[1] = {i};
diff --git a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
index c2eb88afa4dbf..ff1a92a65c42d 100644
--- a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
@@ -74,6 +74,32 @@ func.func @gather_memref_1d_i32_index(%base: memref<?xf32>, %v: vector<2xi32>, %
   return %0 : vector<2x3xf32>
  }
 
+// CHECK-LABEL: @scalable_gather_memref_2d
+// CHECK-SAME:      %[[BASE:.*]]: memref<?x?xf32>,
+// CHECK-SAME:      %[[IDXVEC:.*]]: vector<2x[3]xindex>,
+// CHECK-SAME:      %[[MASK:.*]]: vector<2x[3]xi1>,
+// CHECK-SAME:      %[[PASS:.*]]: vector<2x[3]xf32>
+// CHECK:         %[[C0:.*]] = arith.constant 0 : index
+// CHECK:         %[[C1:.*]] = arith.constant 1 : index
+// CHECK:         %[[INIT:.*]] = arith.constant dense<0.000000e+00> : vector<2x[3]xf32>
+// CHECK:         %[[IDXVEC0:.*]] = vector.extract %[[IDXVEC]][0] : vector<[3]xindex> from vector<2x[3]xindex>
+// CHECK:         %[[MASK0:.*]] = vector.extract %[[MASK]][0] : vector<[3]xi1> from vector<2x[3]xi1>
+// CHECK:         %[[PASS0:.*]] = vector.extract %[[PASS]][0] : vector<[3]xf32> from vector<2x[3]xf32>
+// CHECK:         %[[GATHER0:.*]] = vector.gather %[[BASE]]{{\[}}%[[C0]], %[[C1]]] {{\[}}%[[IDXVEC0]]], %[[MASK0]], %[[PASS0]] : memref<?x?xf32>, vector<[3]xindex>, vector<[3]xi1>, vector<[3]xf32> into vector<[3]xf32>
+// CHECK:         %[[INS0:.*]] = vector.insert %[[GATHER0]], %[[INIT]] [0] : vector<[3]xf32> into vector<2x[3]xf32>
+// CHECK:         %[[IDXVEC1:.*]] = vector.extract %[[IDXVEC]][1] : vector<[3]xindex> from vector<2x[3]xindex>
+// CHECK:         %[[MASK1:.*]] = vector.extract %[[MASK]][1] : vector<[3]xi1> from vector<2x[3]xi1>
+// CHECK:         %[[PASS1:.*]] = vector.extract %[[PASS]][1] : vector<[3]xf32> from vector<2x[3]xf32>
+// CHECK:         %[[GATHER1:.*]] = vector.gather %[[BASE]]{{\[}}%[[C0]], %[[C1]]] {{\[}}%[[IDXVEC1]]], %[[MASK1]], %[[PASS1]] : memref<?x?xf32>, vector<[3]xindex>, vector<[3]xi1>, vector<[3]xf32> into vector<[3]xf32>
+// CHECK:         %[[INS1:.*]] = vector.insert %[[GATHER1]], %[[INS0]] [1] : vector<[3]xf32> into vector<2x[3]xf32>
+// CHECK-NEXT:    return %[[INS1]] : vector<2x[3]xf32>
+func.func @scalable_gather_memref_2d(%base: memref<?x?xf32>, %v: vector<2x[3]xindex>, %mask: vector<2x[3]xi1>, %pass_thru: vector<2x[3]xf32>) -> vector<2x[3]xf32> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %0 = vector.gather %base[%c0, %c1][%v], %mask, %pass_thru : memref<?x?xf32>, vector<2x[3]xindex>, vector<2x[3]xi1>, vector<2x[3]xf32> into vector<2x[3]xf32>
+ return %0 : vector<2x[3]xf32>
+}
+
 // CHECK-LABEL: @gather_tensor_1d
 // CHECK-SAME:    ([[BASE:%.+]]: tensor<?xf32>, [[IDXVEC:%.+]]: vector<2xindex>, [[MASK:%.+]]: vector<2xi1>, [[PASS:%.+]]: vector<2xf32>)
 // CHECK-DAG:     [[M0:%.+]]    = vector.extract [[MASK]][0] : i1 from vector<2xi1>

Copy link
Member

@MacDue MacDue left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, just a few little nits:

@c-rhodes c-rhodes merged commit 9931ee6 into llvm:main Jun 24, 2024
7 checks passed
@c-rhodes c-rhodes deleted the mlir-vector-scalable-flatten-gather branch June 24, 2024 07:36
AlexisPerry pushed a commit to llvm-project-tlp/llvm-project that referenced this pull request Jul 9, 2024
This pattern flattens vector.gather ops by unrolling the outermost
dimension for rank > 2 vectors. There's two issues with this pattern for
scalable vectors:

  1. The unrolling doesn't take vscale into account. A constraint is
     added to disable this pattern for vectors with leading scalable
     dims.
  2. The scalable dims are dropped when creating the new gather. Fixed
     by propagating the flags.

Depends on llvm#96049.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants