Skip to content

Commit c8a0b8e

Browse files
committed
[mlir][VectorOps] Add unrolling for n-D vector.interleave ops
This unrolls n-D vector.interleave ops such as `vector.interleave %i, %j : vector<6x3xf32>` into a sequence of 1-D operations, which can then be directly lowered to LLVM.
1 parent 79ce2c9 commit c8a0b8e

File tree

5 files changed

+122
-0
lines changed

5 files changed

+122
-0
lines changed

mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,14 @@ void populateVectorMaskLoweringPatternsForSideEffectingOps(
264264
void populateVectorMaskedLoadStoreEmulationPatterns(RewritePatternSet &patterns,
265265
PatternBenefit benefit = 1);
266266

267+
/// Populate the pattern set with the following patterns:
268+
///
269+
/// [InterleaveOpLowering]
270+
/// Progressive lowering of InterleaveOp to ExtractOp + InsertOp + lower-D
271+
/// InterleaveOp, unrolling the leading dimension until the result is 1-D.
/// 1-D ops and ops whose leading dimension is scalable are left untouched.
272+
void populateVectorInterleaveLoweringPatterns(RewritePatternSet &patterns,
273+
PatternBenefit benefit = 1);
274+
267275
} // namespace vector
268276
} // namespace mlir
269277
#endif // MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ void LowerVectorToLLVMPass::runOnOperation() {
6868
populateVectorContractLoweringPatterns(patterns, VectorTransformsOptions());
6969
populateVectorMaskOpLoweringPatterns(patterns);
7070
populateVectorShapeCastLoweringPatterns(patterns);
71+
populateVectorInterleaveLoweringPatterns(patterns);
7172
populateVectorTransposeLoweringPatterns(patterns,
7273
VectorTransformsOptions());
7374
// Vector transfer ops with rank > 1 should be lowered with VectorToSCF.

mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
44
LowerVectorBroadcast.cpp
55
LowerVectorContract.cpp
66
LowerVectorGather.cpp
7+
LowerVectorInterleave.cpp
78
LowerVectorMask.cpp
89
LowerVectorMultiReduction.cpp
910
LowerVectorScan.cpp
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
//===- LowerVectorInterleave.cpp - Lower 'vector.interleave' operation ----===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements target-independent rewrites and utilities to lower the
10+
// 'vector.interleave' operation.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include "mlir/Dialect/Vector/IR/VectorOps.h"
15+
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
16+
#include "mlir/IR/BuiltinTypes.h"
17+
#include "mlir/IR/PatternMatch.h"
18+
19+
#define DEBUG_TYPE "vector-interleave-lowering"
20+
21+
using namespace mlir;
22+
using namespace mlir::vector;
23+
24+
namespace {
25+
/// Progressive lowering of InterleaveOp.
///
/// Matches a `vector.interleave` whose result is not 1-D and whose leading
/// dimension is not scalable, and unrolls that leading dimension: for each
/// index along it, the LHS and RHS slices are extracted, interleaved with a
/// new InterleaveOp, and inserted into the result vector at the same index.
26+
class InterleaveOpLowering : public OpRewritePattern<vector::InterleaveOp> {
27+
public:
28+
using OpRewritePattern::OpRewritePattern;
29+
30+
/// Returns failure (leaving `op` untouched) for 1-D results and for results
/// with a scalable leading dimension; otherwise replaces `op` with the
/// unrolled sequence and returns success.
LogicalResult matchAndRewrite(vector::InterleaveOp op,
31+
PatternRewriter &rewriter) const override {
32+
VectorType resultType = op.getResultVectorType();
33+
// 1-D vector.interleave ops can be directly lowered to LLVM (later).
34+
if (resultType.getRank() == 1)
35+
return failure();
36+
37+
// Below we unroll the leading (or front) dimension. If that dimension is
38+
// scalable we can't unroll it.
39+
if (resultType.getScalableDims().front())
40+
return failure();
41+
42+
// n-D case: Unroll the leading dimension.
43+
auto loc = op.getLoc();
44+
// Start from an all-zero constant of the result type; the loop below
// inserts a value at every index of the leading dimension.
Value result = rewriter.create<arith::ConstantOp>(
45+
loc, resultType, rewriter.getZeroAttr(resultType));
46+
for (int idx = 0, end = resultType.getDimSize(0); idx < end; ++idx) {
47+
// Take the idx-th slice of each operand along the leading dimension.
Value extractLhs = rewriter.create<ExtractOp>(loc, op.getLhs(), idx);
48+
Value extractRhs = rewriter.create<ExtractOp>(loc, op.getRhs(), idx);
49+
// Interleave the two slices with a new InterleaveOp.
Value interleave =
50+
rewriter.create<InterleaveOp>(loc, extractLhs, extractRhs);
51+
// Thread the accumulated result through successive inserts.
result = rewriter.create<InsertOp>(loc, interleave, result, idx);
52+
}
53+
54+
rewriter.replaceOp(op, result);
55+
return success();
56+
}
57+
};
58+
59+
} // namespace
60+
61+
/// Adds the InterleaveOpLowering pattern to `patterns`, constructed with the
/// pattern set's context and the given `benefit`.
void mlir::vector::populateVectorInterleaveLoweringPatterns(
62+
RewritePatternSet &patterns, PatternBenefit benefit) {
63+
patterns.add<InterleaveOpLowering>(patterns.getContext(), benefit);
64+
}

mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2497,3 +2497,51 @@ func.func @vector_interleave_1d_scalable(%a: vector<[4]xi32>, %b: vector<[4]xi32
24972497
%0 = vector.interleave %a, %b : vector<[4]xi32>
24982498
return %0 : vector<[8]xi32>
24992499
}
2500+
2501+
// -----
2502+
2503+
// 2-D fixed-size interleave: the leading dimension is unrolled and every 1-D
// slice is lowered to an llvm.shufflevector.
// CHECK-LABEL: @vector_interleave_2d
2504+
// CHECK-SAME: %[[LHS:.*]]: vector<2x3xi8>, %[[RHS:.*]]: vector<2x3xi8>)
2505+
func.func @vector_interleave_2d(%a: vector<2x3xi8>, %b: vector<2x3xi8>) -> vector<2x6xi8>
2506+
{
2507+
// CHECK: %[[LHS_LLVM:.*]] = builtin.unrealized_conversion_cast %[[LHS]] : vector<2x3xi8> to !llvm.array<2 x vector<3xi8>>
2508+
// CHECK: %[[RHS_LLVM:.*]] = builtin.unrealized_conversion_cast %[[RHS]] : vector<2x3xi8> to !llvm.array<2 x vector<3xi8>>
2509+
// CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<2x6xi8>
2510+
// CHECK: %[[CST_LLVM:.*]] = builtin.unrealized_conversion_cast %[[CST]] : vector<2x6xi8> to !llvm.array<2 x vector<6xi8>>
2511+
// CHECK: %[[LHS_DIM_0:.*]] = llvm.extractvalue %[[LHS_LLVM]][0] : !llvm.array<2 x vector<3xi8>>
2512+
// CHECK: %[[RHS_DIM_0:.*]] = llvm.extractvalue %[[RHS_LLVM]][0] : !llvm.array<2 x vector<3xi8>>
2513+
// CHECK: %[[ZIP_DIM_0:.*]] = llvm.shufflevector %[[LHS_DIM_0]], %[[RHS_DIM_0]] [0, 3, 1, 4, 2, 5] : vector<3xi8>
2514+
// CHECK: %[[RES_0:.*]] = llvm.insertvalue %[[ZIP_DIM_0]], %[[CST_LLVM]][0] : !llvm.array<2 x vector<6xi8>>
2515+
// CHECK: %[[LHS_DIM_1:.*]] = llvm.extractvalue %[[LHS_LLVM]][1] : !llvm.array<2 x vector<3xi8>>
2516+
// CHECK: %[[RHS_DIM_1:.*]] = llvm.extractvalue %[[RHS_LLVM]][1] : !llvm.array<2 x vector<3xi8>>
2517+
// CHECK: %[[ZIP_DIM_1:.*]] = llvm.shufflevector %[[LHS_DIM_1]], %[[RHS_DIM_1]] [0, 3, 1, 4, 2, 5] : vector<3xi8>
2518+
// CHECK: %[[RES_1:.*]] = llvm.insertvalue %[[ZIP_DIM_1]], %[[RES_0]][1] : !llvm.array<2 x vector<6xi8>>
2519+
// CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[RES_1]] : !llvm.array<2 x vector<6xi8>> to vector<2x6xi8>
2520+
// CHECK: return %[[RES]]
2521+
%0 = vector.interleave %a, %b : vector<2x3xi8>
2522+
return %0 : vector<2x6xi8>
2523+
}
2524+
2525+
// -----
2526+
2527+
// 2-D interleave with a scalable trailing dimension: the fixed-size leading
// dimension is unrolled and every 1-D scalable slice is lowered to the LLVM
// interleave2 intrinsic.
// CHECK-LABEL: @vector_interleave_2d_scalable
2528+
// CHECK-SAME: %[[LHS:.*]]: vector<2x[8]xi16>, %[[RHS:.*]]: vector<2x[8]xi16>)
2529+
func.func @vector_interleave_2d_scalable(%a: vector<2x[8]xi16>, %b: vector<2x[8]xi16>) -> vector<2x[16]xi16>
2530+
{
2531+
// CHECK: %[[LHS_LLVM:.*]] = builtin.unrealized_conversion_cast %[[LHS]] : vector<2x[8]xi16> to !llvm.array<2 x vector<[8]xi16>>
2532+
// CHECK: %[[RHS_LLVM:.*]] = builtin.unrealized_conversion_cast %[[RHS]] : vector<2x[8]xi16> to !llvm.array<2 x vector<[8]xi16>>
2533+
// CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<2x[16]xi16>
2534+
// CHECK: %[[CST_LLVM:.*]] = builtin.unrealized_conversion_cast %[[CST]] : vector<2x[16]xi16> to !llvm.array<2 x vector<[16]xi16>>
2535+
// CHECK: %[[LHS_DIM_0:.*]] = llvm.extractvalue %[[LHS_LLVM]][0] : !llvm.array<2 x vector<[8]xi16>>
2536+
// CHECK: %[[RHS_DIM_0:.*]] = llvm.extractvalue %[[RHS_LLVM]][0] : !llvm.array<2 x vector<[8]xi16>>
2537+
// CHECK: %[[ZIP_DIM_0:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[LHS_DIM_0]], %[[RHS_DIM_0]]) : (vector<[8]xi16>, vector<[8]xi16>) -> vector<[16]xi16>
2538+
// CHECK: %[[RES_0:.*]] = llvm.insertvalue %[[ZIP_DIM_0]], %[[CST_LLVM]][0] : !llvm.array<2 x vector<[16]xi16>>
2539+
// CHECK: %[[LHS_DIM_1:.*]] = llvm.extractvalue %[[LHS_LLVM]][1] : !llvm.array<2 x vector<[8]xi16>>
2540+
// CHECK: %[[RHS_DIM_1:.*]] = llvm.extractvalue %[[RHS_LLVM]][1] : !llvm.array<2 x vector<[8]xi16>>
2541+
// CHECK: %[[ZIP_DIM_1:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[LHS_DIM_1]], %[[RHS_DIM_1]]) : (vector<[8]xi16>, vector<[8]xi16>) -> vector<[16]xi16>
2542+
// CHECK: %[[RES_1:.*]] = llvm.insertvalue %[[ZIP_DIM_1]], %[[RES_0]][1] : !llvm.array<2 x vector<[16]xi16>>
2543+
// CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[RES_1]] : !llvm.array<2 x vector<[16]xi16>> to vector<2x[16]xi16>
2544+
// CHECK: return %[[RES]]
2545+
%0 = vector.interleave %a, %b : vector<2x[8]xi16>
2546+
return %0 : vector<2x[16]xi16>
2547+
}

0 commit comments

Comments
 (0)