Skip to content

Commit c3a5790

Browse files
committed
[mlir][VectorOps] Add unrolling for n-D vector.interleave ops
This unrolls n-D vector.interleave ops like: ```mlir vector.interleave %i, %j : vector<6x3xf32> ``` To a sequence of 1-D operations, which can then be directly lowered to LLVM.
1 parent f6f8e20 commit c3a5790

File tree

8 files changed

+190
-0
lines changed

8 files changed

+190
-0
lines changed

mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,20 @@ def ApplyLowerTransposePatternsOp : Op<Transform_Dialect,
292292
}];
293293
}
294294

295+
def ApplyLowerInterleavePatternsOp : Op<Transform_Dialect,
296+
"apply_patterns.vector.lower_interleave",
297+
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
298+
let description = [{
299+
Indicates that vector interleave operations should be lowered to
300+
finer-grained vector primitives.
301+
302+
This is usally a late step that is run after bufferization as part of the
303+
process of lowering to e.g. LLVM or NVVM.
304+
}];
305+
306+
let assemblyFormat = "attr-dict";
307+
}
308+
295309
def ApplyRewriteNarrowTypePatternsOp : Op<Transform_Dialect,
296310
"apply_patterns.vector.rewrite_narrow_types",
297311
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {

mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,14 @@ void populateVectorMaskLoweringPatternsForSideEffectingOps(
264264
void populateVectorMaskedLoadStoreEmulationPatterns(RewritePatternSet &patterns,
265265
PatternBenefit benefit = 1);
266266

267+
/// Populate the pattern set with the following patterns:
268+
///
269+
/// [InterleaveOpLowering]
270+
/// Progressive lowering of InterleaveOp to ExtractOp + InsertOp + lower-D
271+
/// InterleaveOp until dim 1.
272+
void populateVectorInterleaveLoweringPatterns(RewritePatternSet &patterns,
273+
PatternBenefit benefit = 1);
274+
267275
} // namespace vector
268276
} // namespace mlir
269277
#endif // MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ void LowerVectorToLLVMPass::runOnOperation() {
6868
populateVectorContractLoweringPatterns(patterns, VectorTransformsOptions());
6969
populateVectorMaskOpLoweringPatterns(patterns);
7070
populateVectorShapeCastLoweringPatterns(patterns);
71+
populateVectorInterleaveLoweringPatterns(patterns);
7172
populateVectorTransposeLoweringPatterns(patterns,
7273
VectorTransformsOptions());
7374
// Vector transfer ops with rank > 1 should be lowered with VectorToSCF.

mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,11 @@ void transform::ApplyLowerTransposePatternsOp::populatePatterns(
159159
}
160160
}
161161

162+
void transform::ApplyLowerInterleavePatternsOp::populatePatterns(
163+
RewritePatternSet &patterns) {
164+
vector::populateVectorInterleaveLoweringPatterns(patterns);
165+
}
166+
162167
void transform::ApplyRewriteNarrowTypePatternsOp::populatePatterns(
163168
RewritePatternSet &patterns) {
164169
populateVectorNarrowTypeRewritePatterns(patterns);

mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
44
LowerVectorBroadcast.cpp
55
LowerVectorContract.cpp
66
LowerVectorGather.cpp
7+
LowerVectorInterleave.cpp
78
LowerVectorMask.cpp
89
LowerVectorMultiReduction.cpp
910
LowerVectorScan.cpp
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
//===- LowerVectorInterleave.cpp - Lower 'vector.interleave' operation ----===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements target-independent rewrites and utilities to lower the
10+
// 'vector.interleave' operation.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include "mlir/Dialect/Vector/IR/VectorOps.h"
15+
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
16+
#include "mlir/IR/BuiltinTypes.h"
17+
#include "mlir/IR/PatternMatch.h"
18+
19+
#define DEBUG_TYPE "vector-interleave-lowering"
20+
21+
using namespace mlir;
22+
using namespace mlir::vector;
23+
24+
namespace {
25+
26+
/// Progressive lowering of InterleaveOp.
27+
///
28+
/// Each leading dimension is unrolled until the result of the interleave is
29+
/// rank 1 (or the dimension is scalable, so can't be unrolled).
30+
///
31+
/// Example:
32+
///
33+
/// ```
34+
/// %0 = vector.interleave %lhs, %rhs : vector<2x...8xty>
35+
/// ```
36+
/// Becomes:
37+
/// ```
38+
/// %lhs_0 = vector.extract %lhs[0]
39+
/// %rhs_0 = vector.extract %rhs[0]
40+
/// %lhs_1 = vector.extract %lhs[1]
41+
/// %rhs_1 = vector.extract %rhs[1]
42+
/// %zip_0 = vector.interleave %lhs_0, %rhs_0
43+
/// %zip_1 = vector.interleave %lhs_1, %rhs_1
44+
/// %res_0 = vector.insert %zip_0, %undef[0]
45+
/// %0 = vector.insert %zip_1, %res_0[1]
46+
/// ```
47+
///
48+
/// If %zip_0 and %zip_1 still have a rank > 1 they will be unrolled again
49+
/// following the same pattern.
50+
class InterleaveOpLowering : public OpRewritePattern<vector::InterleaveOp> {
51+
public:
52+
using OpRewritePattern::OpRewritePattern;
53+
54+
LogicalResult matchAndRewrite(vector::InterleaveOp op,
55+
PatternRewriter &rewriter) const override {
56+
VectorType resultType = op.getResultVectorType();
57+
// 1-D vector.interleave ops can be directly lowered to LLVM (later).
58+
if (resultType.getRank() == 1)
59+
return failure();
60+
61+
// Below we unroll the leading (or front) dimension. If that dimension is
62+
// scalable we can't unroll it.
63+
if (resultType.getScalableDims().front())
64+
return failure();
65+
66+
// n-D case: Unroll the leading dimension.
67+
auto loc = op.getLoc();
68+
Value result = rewriter.create<arith::ConstantOp>(
69+
loc, resultType, rewriter.getZeroAttr(resultType));
70+
for (int idx = 0, end = resultType.getDimSize(0); idx < end; ++idx) {
71+
Value extractLhs = rewriter.create<ExtractOp>(loc, op.getLhs(), idx);
72+
Value extractRhs = rewriter.create<ExtractOp>(loc, op.getRhs(), idx);
73+
Value interleave =
74+
rewriter.create<InterleaveOp>(loc, extractLhs, extractRhs);
75+
result = rewriter.create<InsertOp>(loc, interleave, result, idx);
76+
}
77+
78+
rewriter.replaceOp(op, result);
79+
return success();
80+
}
81+
};
82+
83+
} // namespace
84+
85+
void mlir::vector::populateVectorInterleaveLoweringPatterns(
86+
RewritePatternSet &patterns, PatternBenefit benefit) {
87+
patterns.add<InterleaveOpLowering>(patterns.getContext(), benefit);
88+
}

mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2497,3 +2497,27 @@ func.func @vector_interleave_1d_scalable(%a: vector<[4]xi32>, %b: vector<[4]xi32
24972497
%0 = vector.interleave %a, %b : vector<[4]xi32>
24982498
return %0 : vector<[8]xi32>
24992499
}
2500+
2501+
// -----
2502+
2503+
// CHECK-LABEL: @vector_interleave_2d
2504+
// CHECK-SAME: %[[LHS:.*]]: vector<2x3xi8>, %[[RHS:.*]]: vector<2x3xi8>)
2505+
func.func @vector_interleave_2d(%a: vector<2x3xi8>, %b: vector<2x3xi8>) -> vector<2x6xi8>
2506+
{
2507+
// CHECK: llvm.shufflevector
2508+
// CHECK-NOT: vector.interleave {{.*}} : vector<2x3xi8>
2509+
%0 = vector.interleave %a, %b : vector<2x3xi8>
2510+
return %0 : vector<2x6xi8>
2511+
}
2512+
2513+
// -----
2514+
2515+
// CHECK-LABEL: @vector_interleave_2d_scalable
2516+
// CHECK-SAME: %[[LHS:.*]]: vector<2x[8]xi16>, %[[RHS:.*]]: vector<2x[8]xi16>)
2517+
func.func @vector_interleave_2d_scalable(%a: vector<2x[8]xi16>, %b: vector<2x[8]xi16>) -> vector<2x[16]xi16>
2518+
{
2519+
// CHECK: llvm.intr.experimental.vector.interleave2
2520+
// CHECK-NOT: vector.interleave {{.*}} : vector<2x[8]xi16>
2521+
%0 = vector.interleave %a, %b : vector<2x[8]xi16>
2522+
return %0 : vector<2x[16]xi16>
2523+
}
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
2+
3+
// CHECK-LABEL: @vector_interleave_2d
4+
// CHECK-SAME: %[[LHS:.*]]: vector<2x3xi8>, %[[RHS:.*]]: vector<2x3xi8>)
5+
func.func @vector_interleave_2d(%a: vector<2x3xi8>, %b: vector<2x3xi8>) -> vector<2x6xi8>
6+
{
7+
// CHECK-DAG: %[[CST:.*]] = arith.constant dense<0>
8+
// CHECK-DAG: %[[LHS_0:.*]] = vector.extract %[[LHS]][0]
9+
// CHECK-DAG: %[[RHS_0:.*]] = vector.extract %[[RHS]][0]
10+
// CHECK-DAG: %[[LHS_1:.*]] = vector.extract %[[LHS]][1]
11+
// CHECK-DAG: %[[RHS_1:.*]] = vector.extract %[[RHS]][1]
12+
// CHECK-DAG: %[[ZIP_0:.*]] = vector.interleave %[[LHS_0]], %[[RHS_0]]
13+
// CHECK-DAG: %[[ZIP_1:.*]] = vector.interleave %[[LHS_1]], %[[RHS_1]]
14+
// CHECK-DAG: %[[RES_0:.*]] = vector.insert %[[ZIP_0]], %[[CST]] [0]
15+
// CHECK-DAG: %[[RES_1:.*]] = vector.insert %[[ZIP_1]], %[[RES_0]] [1]
16+
// CHECK-NEXT: return %[[RES_1]] : vector<2x6xi8>
17+
%0 = vector.interleave %a, %b : vector<2x3xi8>
18+
return %0 : vector<2x6xi8>
19+
}
20+
21+
// CHECK-LABEL: @vector_interleave_2d_scalable
22+
// CHECK-SAME: %[[LHS:.*]]: vector<2x[8]xi16>, %[[RHS:.*]]: vector<2x[8]xi16>)
23+
func.func @vector_interleave_2d_scalable(%a: vector<2x[8]xi16>, %b: vector<2x[8]xi16>) -> vector<2x[16]xi16>
24+
{
25+
// CHECK-DAG: %[[CST:.*]] = arith.constant dense<0>
26+
// CHECK-DAG: %[[LHS_0:.*]] = vector.extract %[[LHS]][0]
27+
// CHECK-DAG: %[[RHS_0:.*]] = vector.extract %[[RHS]][0]
28+
// CHECK-DAG: %[[LHS_1:.*]] = vector.extract %[[LHS]][1]
29+
// CHECK-DAG: %[[RHS_1:.*]] = vector.extract %[[RHS]][1]
30+
// CHECK-DAG: %[[ZIP_0:.*]] = vector.interleave %[[LHS_0]], %[[RHS_0]]
31+
// CHECK-DAG: %[[ZIP_1:.*]] = vector.interleave %[[LHS_1]], %[[RHS_1]]
32+
// CHECK-DAG: %[[RES_0:.*]] = vector.insert %[[ZIP_0]], %[[CST]] [0]
33+
// CHECK-DAG: %[[RES_1:.*]] = vector.insert %[[ZIP_1]], %[[RES_0]] [1]
34+
// CHECK-NEXT: return %[[RES_1]] : vector<2x[16]xi16>
35+
%0 = vector.interleave %a, %b : vector<2x[8]xi16>
36+
return %0 : vector<2x[16]xi16>
37+
}
38+
39+
module attributes {transform.with_named_sequence} {
40+
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
41+
%f = transform.structured.match ops{["func.func"]} in %module_op
42+
: (!transform.any_op) -> !transform.any_op
43+
44+
transform.apply_patterns to %f {
45+
transform.apply_patterns.vector.lower_interleave
46+
} : !transform.any_op
47+
transform.yield
48+
}
49+
}

0 commit comments

Comments
 (0)