
Commit a1a6860

[mlir][VectorOps] Add unrolling for n-D vector.interleave ops (#80967)
This unrolls n-D vector.interleave ops like:

```mlir
vector.interleave %i, %j : vector<6x3xf32>
```

to a sequence of 1-D operations:

```mlir
%i_0 = vector.extract %i[0]
%j_0 = vector.extract %j[0]
%res_0 = vector.interleave %i_0, %j_0 : vector<3xf32>
vector.insert %res_0, %result[0] :
// ... repeated x6
```

The 1-D operations can then be directly lowered to LLVM.

Depends on: #80966
1 parent 052ee74 commit a1a6860

File tree

10 files changed: +255 -0 lines changed


mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td

Lines changed: 14 additions & 0 deletions
@@ -292,6 +292,20 @@ def ApplyLowerTransposePatternsOp : Op<Transform_Dialect,
   }];
 }
 
+def ApplyLowerInterleavePatternsOp : Op<Transform_Dialect,
+    "apply_patterns.vector.lower_interleave",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Indicates that vector interleave operations should be lowered to
+    finer-grained vector primitives.
+
+    This is usually a late step that is run after bufferization as part of the
+    process of lowering to e.g. LLVM or NVVM.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
 def ApplyRewriteNarrowTypePatternsOp : Op<Transform_Dialect,
     "apply_patterns.vector.rewrite_narrow_types",
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {

mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h

Lines changed: 9 additions & 0 deletions
@@ -264,6 +264,15 @@ void populateVectorMaskLoweringPatternsForSideEffectingOps(
 void populateVectorMaskedLoadStoreEmulationPatterns(RewritePatternSet &patterns,
                                                     PatternBenefit benefit = 1);
 
+/// Populate the pattern set with the following patterns:
+///
+/// [UnrollInterleaveOp]
+/// A one-shot unrolling of InterleaveOp to (one or more) ExtractOp +
+/// InterleaveOp (of `targetRank`) + InsertOp.
+void populateVectorInterleaveLoweringPatterns(RewritePatternSet &patterns,
+                                              int64_t targetRank = 1,
+                                              PatternBenefit benefit = 1);
+
 } // namespace vector
 } // namespace mlir
 #endif // MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
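For reference, a minimal sketch of how a downstream pass might consume this new entry point; the `lowerInterleaves` wrapper below is hypothetical and not part of this commit:

```cpp
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Hypothetical helper: unroll every n-D vector.interleave nested under `root`
// down to rank 1, so the remaining 1-D ops can be lowered directly to LLVM.
static LogicalResult lowerInterleaves(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  vector::populateVectorInterleaveLoweringPatterns(patterns, /*targetRank=*/1);
  return applyPatternsAndFoldGreedily(root, std::move(patterns));
}
```

In-tree, the same populate call is wired into the VectorToLLVM conversion pass and the `apply_patterns.vector.lower_interleave` transform op further down in this commit.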

mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h

Lines changed: 23 additions & 0 deletions
@@ -9,6 +9,7 @@
 #ifndef MLIR_DIALECT_VECTOR_UTILS_VECTORUTILS_H_
 #define MLIR_DIALECT_VECTOR_UTILS_VECTORUTILS_H_
 
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/Support/LLVM.h"
@@ -75,6 +76,28 @@ FailureOr<std::pair<int, int>> isTranspose2DSlice(vector::TransposeOp op);
 ///     vector<2x1x2x2xi32> from memref<5x4x3x2xi32>)
 bool isContiguousSlice(MemRefType memrefType, VectorType vectorType);
 
+/// Returns an iterator for all positions in the leading dimensions of `vType`
+/// up to the `targetRank`. If any leading dimension before the `targetRank` is
+/// scalable (so cannot be unrolled), it will return an iterator for positions
+/// up to the first scalable dimension.
+///
+/// If no leading dimensions can be unrolled an empty optional will be returned.
+///
+/// Examples:
+///
+///   For vType = vector<2x3x4> and targetRank = 1
+///
+///   The resulting iterator will yield:
+///     [0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2]
+///
+///   For vType = vector<3x[4]x5> and targetRank = 0
+///
+///   The scalable dimension blocks unrolling so the iterator yields only:
+///     [0], [1], [2]
+///
+std::optional<StaticTileOffsetRange>
+createUnrollIterator(VectorType vType, int64_t targetRank = 1);
+
 } // namespace vector
 
 /// Constructs a permutation map of invariant memref indices to vector
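As a rough illustration of what the new utility yields (not part of this commit; the `dumpUnrollPositions` helper is hypothetical), iterating the range visits each leading position once:

```cpp
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;

// Hypothetical helper: print every leading position the iterator yields.
// For vector<2x3x4xf32> with targetRank = 1 this prints [0, 0], [0, 1],
// [0, 2], [1, 0], [1, 1], [1, 2] -- one line per rank-1 slice.
static void dumpUnrollPositions(VectorType vType, int64_t targetRank = 1) {
  auto unrollIterator = vector::createUnrollIterator(vType, targetRank);
  if (!unrollIterator)
    return; // Rank <= targetRank, or the first dimension is scalable.
  for (auto position : *unrollIterator) {
    llvm::errs() << "[";
    llvm::interleaveComma(position, llvm::errs());
    llvm::errs() << "]\n";
  }
}
```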

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp

Lines changed: 1 addition & 0 deletions
@@ -68,6 +68,7 @@ void LowerVectorToLLVMPass::runOnOperation() {
   populateVectorContractLoweringPatterns(patterns, VectorTransformsOptions());
   populateVectorMaskOpLoweringPatterns(patterns);
   populateVectorShapeCastLoweringPatterns(patterns);
+  populateVectorInterleaveLoweringPatterns(patterns);
   populateVectorTransposeLoweringPatterns(patterns,
                                           VectorTransformsOptions());
   // Vector transfer ops with rank > 1 should be lowered with VectorToSCF.

mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp

Lines changed: 5 additions & 0 deletions
@@ -159,6 +159,11 @@ void transform::ApplyLowerTransposePatternsOp::populatePatterns(
   }
 }
 
+void transform::ApplyLowerInterleavePatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  vector::populateVectorInterleaveLoweringPatterns(patterns);
+}
+
 void transform::ApplyRewriteNarrowTypePatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
   populateVectorNarrowTypeRewritePatterns(patterns);

mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
   LowerVectorBroadcast.cpp
   LowerVectorContract.cpp
   LowerVectorGather.cpp
+  LowerVectorInterleave.cpp
   LowerVectorMask.cpp
   LowerVectorMultiReduction.cpp
   LowerVectorScan.cpp

mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp

Lines changed: 85 additions & 0 deletions
@@ -0,0 +1,85 @@
+//===- LowerVectorInterleave.cpp - Lower 'vector.interleave' operation ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements target-independent rewrites and utilities to lower the
+// 'vector.interleave' operation.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/PatternMatch.h"
+
+#define DEBUG_TYPE "vector-interleave-lowering"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+namespace {
+
+/// A one-shot unrolling of vector.interleave to the `targetRank`.
+///
+/// Example:
+///
+/// ```mlir
+/// vector.interleave %a, %b : vector<1x2x3x4xi64>
+/// ```
+/// Would be unrolled to:
+/// ```mlir
+/// %result = arith.constant dense<0> : vector<1x2x3x8xi64>
+/// %0 = vector.extract %a[0, 0, 0]                 ─┐
+///        : vector<4xi64> from vector<1x2x3x4xi64>  |
+/// %1 = vector.extract %b[0, 0, 0]                  |
+///        : vector<4xi64> from vector<1x2x3x4xi64>  | - Repeated 6x for
+/// %2 = vector.interleave %0, %1 : vector<4xi64>    | all leading positions
+/// %3 = vector.insert %2, %result [0, 0, 0]         |
+///        : vector<8xi64> into vector<1x2x3x8xi64>  ┘
+/// ```
+///
+/// Note: If any leading dimension before the `targetRank` is scalable the
+/// unrolling will stop before the scalable dimension.
+class UnrollInterleaveOp : public OpRewritePattern<vector::InterleaveOp> {
+public:
+  UnrollInterleaveOp(int64_t targetRank, MLIRContext *context,
+                     PatternBenefit benefit = 1)
+      : OpRewritePattern(context, benefit), targetRank(targetRank){};
+
+  LogicalResult matchAndRewrite(vector::InterleaveOp op,
+                                PatternRewriter &rewriter) const override {
+    VectorType resultType = op.getResultVectorType();
+    auto unrollIterator = vector::createUnrollIterator(resultType, targetRank);
+    if (!unrollIterator)
+      return failure();
+
+    auto loc = op.getLoc();
+    Value result = rewriter.create<arith::ConstantOp>(
+        loc, resultType, rewriter.getZeroAttr(resultType));
+    for (auto position : *unrollIterator) {
+      Value extractLhs = rewriter.create<ExtractOp>(loc, op.getLhs(), position);
+      Value extractRhs = rewriter.create<ExtractOp>(loc, op.getRhs(), position);
+      Value interleave =
+          rewriter.create<InterleaveOp>(loc, extractLhs, extractRhs);
+      result = rewriter.create<InsertOp>(loc, interleave, result, position);
+    }
+
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+
+private:
+  int64_t targetRank = 1;
+};
+
+} // namespace
+
+void mlir::vector::populateVectorInterleaveLoweringPatterns(
+    RewritePatternSet &patterns, int64_t targetRank, PatternBenefit benefit) {
+  patterns.add<UnrollInterleaveOp>(targetRank, patterns.getContext(), benefit);
+}

mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp

Lines changed: 22 additions & 0 deletions
@@ -278,3 +278,25 @@ bool vector::isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
 
   return llvm::all_of(leadingDims, [](auto x) { return x == 1; });
 }
+
+std::optional<StaticTileOffsetRange>
+vector::createUnrollIterator(VectorType vType, int64_t targetRank) {
+  if (vType.getRank() <= targetRank)
+    return {};
+  // Attempt to unroll until targetRank or the first scalable dimension (which
+  // cannot be unrolled).
+  auto shapeToUnroll = vType.getShape().drop_back(targetRank);
+  auto scalableDimsToUnroll = vType.getScalableDims().drop_back(targetRank);
+  auto it =
+      std::find(scalableDimsToUnroll.begin(), scalableDimsToUnroll.end(), true);
+  auto firstScalableDim = it - scalableDimsToUnroll.begin();
+  if (firstScalableDim == 0)
+    return {};
+  // All scalable dimensions should be removed now.
+  scalableDimsToUnroll = scalableDimsToUnroll.slice(0, firstScalableDim);
+  assert(!llvm::is_contained(scalableDimsToUnroll, true) &&
+         "unexpected leading scalable dimension");
+  // Create an unroll iterator for leading dimensions.
+  shapeToUnroll = shapeToUnroll.slice(0, firstScalableDim);
+  return StaticTileOffsetRange(shapeToUnroll, /*unrollStep=*/1);
+}
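To make the scalable-dimension handling concrete, here is a small self-contained check (illustration only; the function name and setup below are assumptions, not part of this commit):

```cpp
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"

using namespace mlir;

// Hypothetical check: for vector<3x[4]x5xf32> with targetRank = 1,
// drop_back(1) leaves {3, [4]}; the first scalable dim sits at index 1, so
// only the leading dim of size 3 is unrolled and the range yields [0], [1], [2].
static bool scalableDimStopsUnrolling() {
  MLIRContext context;
  Builder b(&context);
  auto vType = VectorType::get({3, 4, 5}, b.getF32Type(),
                               /*scalableDims=*/{false, true, false});
  auto range = vector::createUnrollIterator(vType, /*targetRank=*/1);
  if (!range)
    return false;
  int64_t numPositions = 0;
  for (auto position : *range) {
    (void)position;
    ++numPositions;
  }
  return numPositions == 3;
}
```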

mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Lines changed: 24 additions & 0 deletions
@@ -2497,3 +2497,27 @@ func.func @vector_interleave_1d_scalable(%a: vector<[4]xi32>, %b: vector<[4]xi32
   %0 = vector.interleave %a, %b : vector<[4]xi32>
   return %0 : vector<[8]xi32>
 }
+
+// -----
+
+// CHECK-LABEL: @vector_interleave_2d
+// CHECK-SAME: %[[LHS:.*]]: vector<2x3xi8>, %[[RHS:.*]]: vector<2x3xi8>)
+func.func @vector_interleave_2d(%a: vector<2x3xi8>, %b: vector<2x3xi8>) -> vector<2x6xi8>
+{
+  // CHECK: llvm.shufflevector
+  // CHECK-NOT: vector.interleave {{.*}} : vector<2x3xi8>
+  %0 = vector.interleave %a, %b : vector<2x3xi8>
+  return %0 : vector<2x6xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @vector_interleave_2d_scalable
+// CHECK-SAME: %[[LHS:.*]]: vector<2x[8]xi16>, %[[RHS:.*]]: vector<2x[8]xi16>)
+func.func @vector_interleave_2d_scalable(%a: vector<2x[8]xi16>, %b: vector<2x[8]xi16>) -> vector<2x[16]xi16>
+{
+  // CHECK: llvm.intr.experimental.vector.interleave2
+  // CHECK-NOT: vector.interleave {{.*}} : vector<2x[8]xi16>
+  %0 = vector.interleave %a, %b : vector<2x[8]xi16>
+  return %0 : vector<2x[16]xi16>
+}
Lines changed: 71 additions & 0 deletions
@@ -0,0 +1,71 @@
+// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
+
+// CHECK-LABEL: @vector_interleave_2d
+// CHECK-SAME: %[[LHS:.*]]: vector<2x3xi8>, %[[RHS:.*]]: vector<2x3xi8>)
+func.func @vector_interleave_2d(%a: vector<2x3xi8>, %b: vector<2x3xi8>) -> vector<2x6xi8>
+{
+  // CHECK-DAG: %[[CST:.*]] = arith.constant dense<0>
+  // CHECK-DAG: %[[LHS_0:.*]] = vector.extract %[[LHS]][0]
+  // CHECK-DAG: %[[RHS_0:.*]] = vector.extract %[[RHS]][0]
+  // CHECK-DAG: %[[LHS_1:.*]] = vector.extract %[[LHS]][1]
+  // CHECK-DAG: %[[RHS_1:.*]] = vector.extract %[[RHS]][1]
+  // CHECK-DAG: %[[ZIP_0:.*]] = vector.interleave %[[LHS_0]], %[[RHS_0]]
+  // CHECK-DAG: %[[ZIP_1:.*]] = vector.interleave %[[LHS_1]], %[[RHS_1]]
+  // CHECK-DAG: %[[RES_0:.*]] = vector.insert %[[ZIP_0]], %[[CST]] [0]
+  // CHECK-DAG: %[[RES_1:.*]] = vector.insert %[[ZIP_1]], %[[RES_0]] [1]
+  // CHECK-NEXT: return %[[RES_1]] : vector<2x6xi8>
+  %0 = vector.interleave %a, %b : vector<2x3xi8>
+  return %0 : vector<2x6xi8>
+}
+
+// CHECK-LABEL: @vector_interleave_2d_scalable
+// CHECK-SAME: %[[LHS:.*]]: vector<2x[8]xi16>, %[[RHS:.*]]: vector<2x[8]xi16>)
+func.func @vector_interleave_2d_scalable(%a: vector<2x[8]xi16>, %b: vector<2x[8]xi16>) -> vector<2x[16]xi16>
+{
+  // CHECK-DAG: %[[CST:.*]] = arith.constant dense<0>
+  // CHECK-DAG: %[[LHS_0:.*]] = vector.extract %[[LHS]][0]
+  // CHECK-DAG: %[[RHS_0:.*]] = vector.extract %[[RHS]][0]
+  // CHECK-DAG: %[[LHS_1:.*]] = vector.extract %[[LHS]][1]
+  // CHECK-DAG: %[[RHS_1:.*]] = vector.extract %[[RHS]][1]
+  // CHECK-DAG: %[[ZIP_0:.*]] = vector.interleave %[[LHS_0]], %[[RHS_0]]
+  // CHECK-DAG: %[[ZIP_1:.*]] = vector.interleave %[[LHS_1]], %[[RHS_1]]
+  // CHECK-DAG: %[[RES_0:.*]] = vector.insert %[[ZIP_0]], %[[CST]] [0]
+  // CHECK-DAG: %[[RES_1:.*]] = vector.insert %[[ZIP_1]], %[[RES_0]] [1]
+  // CHECK-NEXT: return %[[RES_1]] : vector<2x[16]xi16>
+  %0 = vector.interleave %a, %b : vector<2x[8]xi16>
+  return %0 : vector<2x[16]xi16>
+}
+
+// CHECK-LABEL: @vector_interleave_4d
+// CHECK-SAME: %[[LHS:.*]]: vector<1x2x3x4xi64>, %[[RHS:.*]]: vector<1x2x3x4xi64>)
+func.func @vector_interleave_4d(%a: vector<1x2x3x4xi64>, %b: vector<1x2x3x4xi64>) -> vector<1x2x3x8xi64>
+{
+  // CHECK: %[[LHS_0:.*]] = vector.extract %[[LHS]][0, 0, 0] : vector<4xi64> from vector<1x2x3x4xi64>
+  // CHECK: %[[RHS_0:.*]] = vector.extract %[[RHS]][0, 0, 0] : vector<4xi64> from vector<1x2x3x4xi64>
+  // CHECK: %[[ZIP_0:.*]] = vector.interleave %[[LHS_0]], %[[RHS_0]] : vector<4xi64>
+  // CHECK: %[[RES_0:.*]] = vector.insert %[[ZIP_0]], %{{.*}} [0, 0, 0] : vector<8xi64> into vector<1x2x3x8xi64>
+  // CHECK-COUNT-5: vector.interleave %{{.*}}, %{{.*}} : vector<4xi64>
+  %0 = vector.interleave %a, %b : vector<1x2x3x4xi64>
+  return %0 : vector<1x2x3x8xi64>
+}
+
+// CHECK-LABEL: @vector_interleave_nd_with_scalable_dim
+func.func @vector_interleave_nd_with_scalable_dim(%a: vector<1x3x[2]x2x3x4xf16>, %b: vector<1x3x[2]x2x3x4xf16>) -> vector<1x3x[2]x2x3x8xf16>
+{
+  // The scalable dim blocks unrolling so only the first two dims are unrolled.
+  // CHECK-COUNT-3: vector.interleave %{{.*}}, %{{.*}} : vector<[2]x2x3x4xf16>
+  %0 = vector.interleave %a, %b : vector<1x3x[2]x2x3x4xf16>
+  return %0 : vector<1x3x[2]x2x3x8xf16>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %f = transform.structured.match ops{["func.func"]} in %module_op
+      : (!transform.any_op) -> !transform.any_op
+
+    transform.apply_patterns to %f {
+      transform.apply_patterns.vector.lower_interleave
+    } : !transform.any_op
+    transform.yield
+  }
+}
