
[mlir][LinAlg] Vectorize reverse-like ops using vector.gather ops. #83205


Merged: 1 commit into llvm:main on Feb 28, 2024

Conversation

hanhanW (Contributor) commented on Feb 27, 2024:

The reverse op is treated as a VectorMemoryAccessKind::Contiguous load. The slice is indeed contiguous, but we would need to compute the indices differently and apply a reverse at the vector level; that approach takes non-trivial effort. This revision flips the case to use vector.gather instead. Without it there are correctness issues: e.g., the example below loads `2, 3, 4` (which is a bug), but what we want is `2, 1, 0`.

Before vectorization:

```mlir
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d0 + d1 + d2)>
func.func @vectorize_reverse_like_tensor_extract(%arg0: tensor<1x2x3xf32>, %arg1: tensor<1x1x3xf32>, %arg2: index) -> tensor<1x1x3xf32> {
  %c1 = arith.constant 1 : index
  %c0 = arith.constant 0 : index
  %c2 = arith.constant 2 : index
  %0 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel", "parallel"]} outs(%arg1 : tensor<1x1x3xf32>) {
  ^bb0(%out: f32):
    %1 = linalg.index 1 : index
    %2 = linalg.index 0 : index
    %3 = affine.apply #map1(%1, %2, %arg2)
    %4 = linalg.index 2 : index
    %5 = arith.subi %c2, %4 : index
    %extracted = tensor.extract %arg0[%c0, %3, %5] : tensor<1x2x3xf32>
    linalg.yield %extracted : f32
  } -> tensor<1x1x3xf32>
  return %0 : tensor<1x1x3xf32>
}
```

Here `#map` and `#map1` are as in the test below, and `%5 = %c2 - %4` walks the innermost dimension in reverse, yielding indices 2, 1, 0.

Partial IR after vectorization:

```mlir
  %5 = vector.constant_mask [1, 1, 3] : vector<1x1x4xi1>
  %6 = vector.broadcast %arg0 : index to vector<1x1x4xindex>
  %7 = vector.shape_cast %6 : vector<1x1x4xindex> to vector<4xindex>
  %8 = vector.extractelement %7[%c0_i32 : i32] : vector<4xindex>
  %9 = vector.transfer_read %3[%c0, %8, %c2], %cst, %5 {in_bounds = [true, true, true]} : tensor<1x2x3xf32>, vector<1x1x4xf32>
```
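
With this patch the example instead vectorizes to a `vector.gather`. As a sketch, reconstructed from the new test's CHECK lines in the diff below (the SSA value names are invented here, not verbatim compiler output):

```mlir
  %cst = arith.constant dense<3> : vector<1x1x3xindex>
  %c0 = arith.constant 0 : index
  %mask = arith.constant dense<true> : vector<1x1x3xi1>
  %passthru = arith.constant dense<0.000000e+00> : vector<1x1x3xf32>
  %init_idx = arith.constant dense<[2, 1, 0]> : vector<3xindex>
  // Linearized gather indices: %arg2 * 3 + [2, 1, 0], i.e. the last
  // dimension traversed in reverse.
  %0 = vector.broadcast %arg2 : index to vector<1x1x3xindex>
  %1 = arith.muli %0, %cst : vector<1x1x3xindex>
  %2 = vector.broadcast %init_idx : vector<3xindex> to vector<1x1x3xindex>
  %3 = arith.addi %2, %1 : vector<1x1x3xindex>
  %4 = vector.gather %arg0[%c0, %c0, %c0] [%3], %mask, %passthru
    : tensor<1x2x3xf32>, vector<1x1x3xindex>, vector<1x1x3xi1>, vector<1x1x3xf32> into vector<1x1x3xf32>
  vector.transfer_write %4, %arg1[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<1x1x3xf32>, tensor<1x1x3xf32>
```

This loads elements 2, 1, 0 of the selected row, as intended.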

llvmbot (Member) commented on Feb 27, 2024:

@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Author: Han-Chung Wang (hanhanW)

Full diff: https://github.com/llvm/llvm-project/pull/83205.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+1-2)
  • (modified) mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir (+46)
```diff
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index ac043e87223dfe..1e703dacfd0c75 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -891,8 +891,7 @@ static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
 
   // Conservatively reject Ops that could lead to indices with stride other
   // than 1.
-  if (!isa<arith::AddIOp, arith::SubIOp, arith::ConstantOp, linalg::IndexOp>(
-          ancestor))
+  if (!isa<arith::AddIOp, arith::ConstantOp, linalg::IndexOp>(ancestor))
     return false;
 
   bool result = false;
diff --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
index 96953c234a0873..9832b312c32439 100644
--- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
@@ -542,6 +542,52 @@ func.func @vectorize_0d_tensor_extract(%arg0: tensor<f32>, %arg2: tensor<1x1x3xf
 // CHECK:           %[[EXTRACT:.*]] = tensor.extract %[[ARG_0]][] : tensor<f32>
 // CHECK:           vector.broadcast %[[EXTRACT]] : f32 to vector<1x1x3xf32>
 
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+     %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+     %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+     %2 = transform.structured.vectorize_children_and_apply_patterns %1 { vectorize_nd_extract } : (!transform.any_op) -> !transform.any_op
+     transform.yield
+   }
+}
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0 + d1 + d2)>
+func.func @vectorize_reverse_like_tensor_extract(%arg0: tensor<1x2x3xf32>, %arg1: tensor<1x1x3xf32>, %arg2: index) -> tensor<1x1x3xf32> {
+  %c1 = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+  %0 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel", "parallel"]} outs(%arg1 : tensor<1x1x3xf32>) {
+  ^bb0(%out: f32):
+    %1 = linalg.index 1 : index
+    %2 = linalg.index 0 : index
+    %3 = affine.apply #map1(%1, %2, %arg2)
+    %4 = linalg.index 2 : index
+    %5 = arith.subi %c2, %4 : index
+    %extracted = tensor.extract %arg0[%c0, %3, %5] : tensor<1x2x3xf32>
+    linalg.yield %extracted : f32
+  } -> tensor<1x1x3xf32>
+  return %0 : tensor<1x1x3xf32>
+}
+// CHECK-LABEL: func.func @vectorize_reverse_like_tensor_extract
+// CHECK-SAME:    %[[ARG0:[0-9a-zA-Z]*]]
+// CHECK-SAME:    %[[ARG1:[0-9a-zA-Z]*]]
+// CHECK-SAME:    %[[ARG2:[0-9a-zA-Z]*]]
+// CHECK-DAG:    %[[CST:.+]] = arith.constant dense<3> : vector<1x1x3xindex>
+// CHECK-DAG:    %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG:    %[[MASK:.*]] = arith.constant dense<true> : vector<1x1x3xi1>
+// CHECK-DAG:    %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<1x1x3xf32>
+// CHECK-DAG:    %[[INIT_IDX:.+]] = arith.constant dense<[2, 1, 0]> : vector<3xindex>
+// CHECK:        %[[T0:.+]] = vector.broadcast %[[ARG2]] : index to vector<1x1x3xindex>
+// CHECK:        %[[T1:.+]] = arith.muli %[[T0]], %[[CST]] : vector<1x1x3xindex>
+// CHECK:        %[[T2:.+]] = vector.broadcast %[[INIT_IDX]]
+// CHECK:        %[[T3:.+]] = arith.addi %[[T2]], %[[T1]]
+// CHECK:        %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]] [%[[T3]]], %[[MASK]], %[[PASSTHRU]]
+// CHECK:        vector.transfer_write %[[GATHER]]
+
+
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
```

The review discussion below concerns this hunk in Vectorization.cpp:

```diff
@@ -891,8 +891,7 @@ static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
 
   // Conservatively reject Ops that could lead to indices with stride other
   // than 1.
-  if (!isa<arith::AddIOp, arith::SubIOp, arith::ConstantOp, linalg::IndexOp>(
-          ancestor))
+  if (!isa<arith::AddIOp, arith::ConstantOp, linalg::IndexOp>(ancestor))
```
Contributor:
Instead of removing the sub operation, we may want to check that the sub is of the form `index - loop_invariant`. This should ensure that the index is monotonically increasing.

Additionally, we may want this function to return not only true or false but also whether the index is an increasing or decreasing contiguous index. Then the `loop_invariant - index` case would return a decreasing index, and we can generate a gather for that case for now. That would leave the code ready to easily generate the proper transfer read later.
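
A sketch of the suggested refinement, for illustration (the helpers `isLinalgIndex` and `isLoopInvariant` are hypothetical names, not the actual API in Vectorization.cpp):

```cpp
// Sketch only: classify a subtraction feeding a tensor.extract index instead
// of rejecting it outright.
enum class ContiguousIdxKind { None, Increasing, Decreasing };

static ContiguousIdxKind classifySubIdx(linalg::LinalgOp linalgOp,
                                        arith::SubIOp subOp) {
  Value lhs = subOp.getLhs();
  Value rhs = subOp.getRhs();
  // `index - loop_invariant`: monotonically increasing with stride 1, so it
  // is still a contiguous (forward) load.
  if (isLinalgIndex(lhs) && isLoopInvariant(linalgOp, rhs))
    return ContiguousIdxKind::Increasing;
  // `loop_invariant - index`: monotonically decreasing; emit a gather for
  // now, and eventually a reversed contiguous transfer_read.
  if (isLoopInvariant(linalgOp, lhs) && isLinalgIndex(rhs))
    return ContiguousIdxKind::Decreasing;
  return ContiguousIdxKind::None;
}
```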

Contributor:
Also wondering... shouldn't the condition below be a &= instead of an or?
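
(For context, the question refers to how per-operand results are accumulated when walking the index computation; schematically, with trailing arguments elided and names assumed rather than copied from the actual code:)

```cpp
// Schematic of the accumulation being questioned, not the upstream code.
bool result = false;
for (Value operand : ancestor->getOperands())
  // With |=, a single stride-1 operand marks the whole index contiguous;
  // with &=, every operand would have to qualify.
  result |= isContiguousLoadIdx(linalgOp, operand /*, ...*/);
return result;
```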

hanhanW (author) replied:
yeah, totally agree with what you said. Checking that the sub is of the form `index - loop_invariant` is also what I had in mind, but I think @banach-space should weigh in, and he's OOO this week. I have a few questions too:

  1. Why do we add the sub op to the list when there are no other tests covering it?
  2. I'm also wondering whether the condition below should be a `&=`.

Can we land this as is for now to unblock the correctness issue, and follow up after Andrzej is back?

Contributor:
+1 to landing this as is.

Both issues that you mention look like my oversight; sorry about that.

There are quite a few cases for tensor.extract, and I struggled to come up with good tests covering all of them.

hanhanW (author) replied:
SGTM. If one of you approves the PR, I'll land it. Thanks!

dcaballe (Contributor) left a review comment:
SG! Probably good to file an issue so that we don't miss this case!

hanhanW (author) commented on Feb 28, 2024:

> SG! Probably good to file an issue so that we don't miss this case!

Sure, will do!

hanhanW merged commit 46bd65a into llvm:main on Feb 28, 2024.
hanhanW deleted the vectorize-reverse-like-using-gather branch on February 28, 2024 at 17:45.
mylai-mtk pushed a commit to mylai-mtk/llvm-project referencing this pull request on Jul 12, 2024:
[mlir][LinAlg] Vectorize reverse-like ops using vector.gather ops. (llvm#83205)
