[mlir][VectorOps] Add unrolling for n-D vector.interleave ops (3/4) #80967

Merged: 2 commits merged on Feb 20, 2024

MacDue (Member) commented Feb 7, 2024

This unrolls n-D vector.interleave ops like:

vector.interleave %i, %j : vector<6x3xf32>

To a sequence of 1-D operations:

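%i_0 = vector.extract %i[0] : vector<3xf32> from vector<6x3xf32>
%j_0 = vector.extract %j[0] : vector<3xf32> from vector<6x3xf32>
%res_0 = vector.interleave %i_0, %j_0 : vector<3xf32>
%result_0 = vector.insert %res_0, %result[0] : vector<6xf32> into vector<6x6xf32>
// ... repeated for each of the 6 leading indices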

The 1-D operations can then be directly lowered to LLVM.
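
As a rough illustration of what the unrolling computes, here is a minimal plain-C++ model that uses nested `std::vector`s in place of MLIR vectors. The names `interleave1D` and `interleave2DUnrolled` are hypothetical and are not part of MLIR; the sketch only mirrors the extract, 1-D interleave, insert sequence described above.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Plain-C++ model of a 1-D interleave: {a0, b0, a1, b1, ...}.
// Hypothetical helper for illustration only; not an MLIR API.
std::vector<float> interleave1D(const std::vector<float> &lhs,
                                const std::vector<float> &rhs) {
  assert(lhs.size() == rhs.size() && "operands must have the same length");
  std::vector<float> result;
  result.reserve(lhs.size() * 2);
  for (std::size_t i = 0; i < lhs.size(); ++i) {
    result.push_back(lhs[i]);
    result.push_back(rhs[i]);
  }
  return result;
}

using Matrix = std::vector<std::vector<float>>;

// Model of the unrolled 2-D case: one "extract, 1-D interleave, insert" step
// per index of the leading dimension.
Matrix interleave2DUnrolled(const Matrix &lhs, const Matrix &rhs) {
  assert(lhs.size() == rhs.size() && "leading dimensions must match");
  // One (initially empty) row per leading index; every row is overwritten below.
  Matrix result(lhs.size());
  for (std::size_t idx = 0; idx < lhs.size(); ++idx)
    result[idx] = interleave1D(lhs[idx], rhs[idx]);
  return result;
}
```

In this model, two 6x3 inputs produce a 6x6 result, matching the `vector<6x3xf32>` example: the leading dimension is preserved and only the trailing dimension doubles.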

Depends on: #80966

llvmbot (Member) commented Feb 13, 2024

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Benjamin Maxwell (MacDue)

Changes


Full diff: https://github.com/llvm/llvm-project/pull/80967.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h (+8)
  • (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp (+1)
  • (modified) mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt (+1)
  • (added) mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp (+64)
  • (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir (+48)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 57b39f5f52c6d3..1cd3bab46396e3 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -264,6 +264,14 @@ void populateVectorMaskLoweringPatternsForSideEffectingOps(
 void populateVectorMaskedLoadStoreEmulationPatterns(RewritePatternSet &patterns,
                                                     PatternBenefit benefit = 1);
 
+/// Populate the pattern set with the following patterns:
+///
+/// [InterleaveOpLowering]
+/// Progressive lowering of InterleaveOp to ExtractOp + InsertOp + lower-D
+/// InterleaveOp until dim 1.
+void populateVectorInterleaveLoweringPatterns(RewritePatternSet &patterns,
+                                              PatternBenefit benefit = 1);
+
 } // namespace vector
 } // namespace mlir
 #endif // MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index ff8e78a668e0f1..e3a436c4a94009 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -68,6 +68,7 @@ void LowerVectorToLLVMPass::runOnOperation() {
     populateVectorContractLoweringPatterns(patterns, VectorTransformsOptions());
     populateVectorMaskOpLoweringPatterns(patterns);
     populateVectorShapeCastLoweringPatterns(patterns);
+    populateVectorInterleaveLoweringPatterns(patterns);
     populateVectorTransposeLoweringPatterns(patterns,
                                             VectorTransformsOptions());
     // Vector transfer ops with rank > 1 should be lowered with VectorToSCF.
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index daf28882976ef6..f221b7462dfd7a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
   LowerVectorBroadcast.cpp
   LowerVectorContract.cpp
   LowerVectorGather.cpp
+  LowerVectorInterleave.cpp
   LowerVectorMask.cpp
   LowerVectorMultiReduction.cpp
   LowerVectorScan.cpp
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
new file mode 100644
index 00000000000000..0ca38eba942a5d
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
@@ -0,0 +1,64 @@
+//===- LowerVectorInterleave.cpp - Lower 'vector.interleave' operation ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements target-independent rewrites and utilities to lower the
+// 'vector.interleave' operation.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/PatternMatch.h"
+
+#define DEBUG_TYPE "vector-interleave-lowering"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+namespace {
+/// Progressive lowering of InterleaveOp.
+class InterleaveOpLowering : public OpRewritePattern<vector::InterleaveOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::InterleaveOp op,
+                                PatternRewriter &rewriter) const override {
+    VectorType resultType = op.getResultVectorType();
+    // 1-D vector.interleave ops can be directly lowered to LLVM (later).
+    if (resultType.getRank() == 1)
+      return failure();
+
+    // Below we unroll the leading (or front) dimension. If that dimension is
+    // scalable we can't unroll it.
+    if (resultType.getScalableDims().front())
+      return failure();
+
+    // n-D case: Unroll the leading dimension.
+    auto loc = op.getLoc();
+    Value result = rewriter.create<arith::ConstantOp>(
+        loc, resultType, rewriter.getZeroAttr(resultType));
+    for (int idx = 0, end = resultType.getDimSize(0); idx < end; ++idx) {
+      Value extractLhs = rewriter.create<ExtractOp>(loc, op.getLhs(), idx);
+      Value extractRhs = rewriter.create<ExtractOp>(loc, op.getRhs(), idx);
+      Value interleave =
+          rewriter.create<InterleaveOp>(loc, extractLhs, extractRhs);
+      result = rewriter.create<InsertOp>(loc, interleave, result, idx);
+    }
+
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::vector::populateVectorInterleaveLoweringPatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns.add<InterleaveOpLowering>(patterns.getContext(), benefit);
+}
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index a46f2e101f3c35..3cbca65472fb69 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -2497,3 +2497,51 @@ func.func @vector_interleave_1d_scalable(%a: vector<[4]xi32>, %b: vector<[4]xi32
   %0 = vector.interleave %a, %b : vector<[4]xi32>
   return %0 : vector<[8]xi32>
 }
+
+// -----
+
+// CHECK-LABEL: @vector_interleave_2d
+//  CHECK-SAME:     %[[LHS:.*]]: vector<2x3xi8>, %[[RHS:.*]]: vector<2x3xi8>)
+func.func @vector_interleave_2d(%a: vector<2x3xi8>, %b: vector<2x3xi8>) -> vector<2x6xi8>
+{
+  // CHECK: %[[LHS_LLVM:.*]] = builtin.unrealized_conversion_cast %[[LHS]] : vector<2x3xi8> to !llvm.array<2 x vector<3xi8>>
+  // CHECK: %[[RHS_LLVM:.*]] = builtin.unrealized_conversion_cast %[[RHS]] : vector<2x3xi8> to !llvm.array<2 x vector<3xi8>>
+  // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<2x6xi8>
+  // CHECK: %[[CST_LLVM:.*]] = builtin.unrealized_conversion_cast %[[CST]] : vector<2x6xi8> to !llvm.array<2 x vector<6xi8>>
+  // CHECK: %[[LHS_DIM_0:.*]] = llvm.extractvalue %[[LHS_LLVM]][0] : !llvm.array<2 x vector<3xi8>>
+  // CHECK: %[[RHS_DIM_0:.*]] = llvm.extractvalue %[[RHS_LLVM]][0] : !llvm.array<2 x vector<3xi8>>
+  // CHECK: %[[ZIM_DIM_0:.*]] = llvm.shufflevector %[[LHS_DIM_0]], %[[RHS_DIM_0]] [0, 3, 1, 4, 2, 5] : vector<3xi8>
+  // CHECK: %[[RES_0:.*]] = llvm.insertvalue %[[ZIM_DIM_0]], %[[CST_LLVM]][0] : !llvm.array<2 x vector<6xi8>>
+  // CHECK: %[[LHS_DIM_1:.*]] = llvm.extractvalue %[[LHS_LLVM]][1] : !llvm.array<2 x vector<3xi8>>
+  // CHECK: %[[RHS_DIM_1:.*]] = llvm.extractvalue %[[RHS_LLVM]][1] : !llvm.array<2 x vector<3xi8>>
+  // CHECK: %[[ZIM_DIM_1:.*]] = llvm.shufflevector %[[LHS_DIM_1]], %[[RHS_DIM_1]] [0, 3, 1, 4, 2, 5] : vector<3xi8>
+  // CHECK: %[[RES_1:.*]] = llvm.insertvalue %[[ZIM_DIM_1]], %[[RES_0]][1] : !llvm.array<2 x vector<6xi8>>
+  // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[RES_1]] : !llvm.array<2 x vector<6xi8>> to vector<2x6xi8>
+  // CHECK: return %[[RES]]
+  %0 = vector.interleave %a, %b : vector<2x3xi8>
+  return %0 : vector<2x6xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @vector_interleave_2d_scalable
+//  CHECK-SAME:     %[[LHS:.*]]: vector<2x[8]xi16>, %[[RHS:.*]]: vector<2x[8]xi16>)
+func.func @vector_interleave_2d_scalable(%a: vector<2x[8]xi16>, %b: vector<2x[8]xi16>) -> vector<2x[16]xi16>
+{
+  // CHECK: %[[LHS_LLVM:.*]] = builtin.unrealized_conversion_cast %[[LHS]] : vector<2x[8]xi16> to !llvm.array<2 x vector<[8]xi16>>
+  // CHECK: %[[RHS_LLVM:.*]] = builtin.unrealized_conversion_cast %[[RHS]] : vector<2x[8]xi16> to !llvm.array<2 x vector<[8]xi16>>
+  // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<2x[16]xi16>
+  // CHECK: %[[CST_LLVM:.*]] = builtin.unrealized_conversion_cast %[[CST]] : vector<2x[16]xi16> to !llvm.array<2 x vector<[16]xi16>>
+  // CHECK: %[[LHS_DIM_0:.*]] = llvm.extractvalue %[[LHS_LLVM]][0] : !llvm.array<2 x vector<[8]xi16>>
+  // CHECK: %[[RHS_DIM_0:.*]] = llvm.extractvalue %[[RHS_LLVM]][0] : !llvm.array<2 x vector<[8]xi16>>
+  // CHECK: %[[ZIM_DIM_0:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[LHS_DIM_0]], %[[RHS_DIM_0]]) : (vector<[8]xi16>, vector<[8]xi16>) -> vector<[16]xi16>
+  // CHECK: %[[RES_0:.*]] = llvm.insertvalue %[[ZIM_DIM_0]], %[[CST_LLVM]][0] : !llvm.array<2 x vector<[16]xi16>>
+  // CHECK: %[[LHS_DIM_1:.*]] = llvm.extractvalue %[[LHS_LLVM]][1] : !llvm.array<2 x vector<[8]xi16>>
+  // CHECK: %[[RHS_DIM_1:.*]] = llvm.extractvalue %[[RHS_LLVM]][1] : !llvm.array<2 x vector<[8]xi16>>
+  // CHECK: %[[ZIP_DIM_1:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[LHS_DIM_1]], %[[RHS_DIM_1]]) : (vector<[8]xi16>, vector<[8]xi16>) -> vector<[16]xi16>
+  // CHECK: %[[RES_1:.*]] = llvm.insertvalue %[[ZIP_DIM_1]], %[[RES_0]][1] : !llvm.array<2 x vector<[16]xi16>>
+  // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[RES_1]] : !llvm.array<2 x vector<[16]xi16>> to vector<2x[16]xi16>
+  // CHECK: return %[[RES]]
+  %0 = vector.interleave %a, %b : vector<2x[8]xi16>
+  return %0 : vector<2x[16]xi16>
+}
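
The fixed-size test in the diff above expects each 1-D interleave to become an `llvm.shufflevector` with mask `[0, 3, 1, 4, 2, 5]`. A small sketch of how such a mask can be derived is shown here; `makeInterleaveMask` is a hypothetical helper written for illustration and is not claimed to be the function the lowering uses.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Builds the shuffle mask that interleaves two fixed-size vectors of
// `numElts` elements each. Indices in [0, numElts) select from the first
// operand and indices in [numElts, 2 * numElts) select from the second.
// Hypothetical helper for illustration only.
std::vector<int64_t> makeInterleaveMask(int64_t numElts) {
  assert(numElts >= 0 && "element count must be non-negative");
  std::vector<int64_t> mask;
  mask.reserve(static_cast<std::size_t>(2 * numElts));
  for (int64_t i = 0; i < numElts; ++i) {
    mask.push_back(i);           // i-th element of the first operand
    mask.push_back(numElts + i); // i-th element of the second operand
  }
  return mask;
}
```

For `numElts = 3` this yields `[0, 3, 1, 4, 2, 5]`, the mask checked in `@vector_interleave_2d`. Scalable vectors have no compile-time element count, which is consistent with the scalable test checking for an interleave intrinsic rather than a shuffle.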

MacDue force-pushed the add_vector.interleave_2 branch from c8a0b8e to 661e3b6 on February 13, 2024 17:37

c-rhodes (Collaborator) left a comment:

LGTM cheers

banach-space (Contributor) left a comment:

LGTM, thanks!

This unrolls n-D vector.interleave ops like:

```mlir
vector.interleave %i, %j : vector<6x3xf32>
```

To a sequence of 1-D operations, which can then be directly lowered to
LLVM.

MacDue force-pushed the add_vector.interleave_2 branch 2 times, most recently from b25cd93 to 04a5655, on February 15, 2024 13:36

Instead of progressively unrolling one leading dimension at a time, this
now uses `vector::createUnrollIterator()`, which returns an iterator over
all leading dimensions of a vector type (down to a target rank).
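
The commit message only states that the iterator covers all leading dimensions down to a target rank. As a rough standalone model of that idea, the sketch below enumerates every index tuple over the leading dimensions of a static shape. It is plain C++ rather than the MLIR helper, `leadingDimPositions` is a hypothetical name, and scalable dimensions are ignored.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Enumerates every index tuple over the leading dimensions of `shape`,
// leaving the trailing `targetRank` dimensions untouched.
// Hypothetical stand-in for illustration; not the MLIR implementation.
std::vector<std::vector<int64_t>>
leadingDimPositions(const std::vector<int64_t> &shape, std::size_t targetRank) {
  std::vector<std::vector<int64_t>> positions;
  if (shape.size() <= targetRank)
    return positions; // Already at or below the target rank: nothing to unroll.
  const std::size_t numLeading = shape.size() - targetRank;
  for (std::size_t d = 0; d < numLeading; ++d)
    if (shape[d] <= 0)
      return positions; // An empty leading dimension has no positions.

  std::vector<int64_t> current(numLeading, 0);
  while (true) {
    positions.push_back(current);
    // Advance like an odometer: bump the innermost leading index and carry
    // into the outer ones when a dimension wraps around.
    std::size_t d = numLeading;
    while (d > 0) {
      --d;
      if (++current[d] < shape[d])
        break;
      current[d] = 0;
      if (d == 0)
        return positions; // The outermost index wrapped: enumeration is done.
    }
  }
}
```

For a shape of 2x3x4 and a target rank of 1, this produces the six positions (0,0), (0,1), (0,2), (1,0), (1,1), (1,2). Each position would correspond to one extract, 1-D interleave, insert step, so all leading dimensions are unrolled in a single rewrite instead of one dimension per pattern application.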
MacDue force-pushed the add_vector.interleave_2 branch from 04a5655 to 6b7d618 on February 15, 2024 14:26

hanhanW (Contributor) left a comment:

LGTM, thanks!

MacDue merged commit a1a6860 into llvm:main on Feb 20, 2024
MacDue deleted the add_vector.interleave_2 branch on February 20, 2024 14:33
dcaballe pushed a commit to iree-org/llvm-project that referenced this pull request Feb 23, 2024
…0967)

MacDue added a commit that referenced this pull request Mar 6, 2024
This folds fixed-size vector.shuffle ops that perform a 1-D interleave
to a vector.interleave operation.

For example:

```mlir
%0 = vector.shuffle %a, %b [0, 2, 1, 3] : vector<2xi32>, vector<2xi32>
```

folds to:

```mlir
%0 = vector.interleave %a, %b : vector<2xi32>
```

Depends on: #80967
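
To make the fold condition concrete: a shuffle of two `n`-element vectors performs a 1-D interleave when its mask has the form `[0, n, 1, n + 1, ..., n - 1, 2n - 1]`. The check below is a minimal sketch of that test in plain C++; `isInterleaveMask` is a hypothetical name and is not taken from the referenced commit.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Returns true if `mask` is [0, n, 1, n + 1, ..., n - 1, 2n - 1] for
// n == numElts, i.e. the shuffle mask of a 1-D interleave of two
// n-element vectors. Hypothetical helper for illustration only.
bool isInterleaveMask(const std::vector<int64_t> &mask, int64_t numElts) {
  if (numElts <= 0 || mask.size() != static_cast<std::size_t>(2 * numElts))
    return false;
  for (int64_t i = 0; i < numElts; ++i) {
    // Even result slots must take element i of the first operand, odd slots
    // element i of the second operand (offset by numElts in the mask).
    if (mask[2 * i] != i || mask[2 * i + 1] != numElts + i)
      return false;
  }
  return true;
}
```

Any mask entry outside the range `[0, 2n)` can never satisfy this check, because such an index does not name an element of either operand.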