Skip to content

Commit 0ea1271

Browse files
authored
[mlir][vector] Add support for unrolling vector.bitcast ops. (#94064)
The revision unrolls vector.bitcast like: ```mlir %0 = vector.bitcast %arg0 : vector<2x4xi32> to vector<2x2xi64> ``` to ```mlir %cst = arith.constant dense<0> : vector<2x2xi64> %0 = vector.extract %arg0[0] : vector<4xi32> from vector<2x4xi32> %1 = vector.bitcast %0 : vector<4xi32> to vector<2xi64> %2 = vector.insert %1, %cst [0] : vector<2xi64> into vector<2x2xi64> %3 = vector.extract %arg0[1] : vector<4xi32> from vector<2x4xi32> %4 = vector.bitcast %3 : vector<4xi32> to vector<2xi64> %5 = vector.insert %4, %2 [1] : vector<2xi64> into vector<2x2xi64> ``` The scalable vector is not supported because of the limitation of `vector::createUnrollIterator`. The targetRank could mismatch the final rank during unrolling; there is no direct way to query what the final rank is from the object.
1 parent 43847c1 commit 0ea1271

File tree

8 files changed

+189
-0
lines changed

8 files changed

+189
-0
lines changed

mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,20 @@ def ApplyTransferPermutationPatternsOp : Op<Transform_Dialect,
8989
let assemblyFormat = "attr-dict";
9090
}
9191

92+
def ApplyLowerBitCastPatternsOp : Op<Transform_Dialect,
93+
"apply_patterns.vector.lower_bitcast",
94+
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
95+
let description = [{
96+
Indicates that vector bitcast operations should be lowered to
97+
finer-grained vector primitives.
98+
99+
This is usually a late step that is run after bufferization as part of the
100+
process of lowering to e.g. LLVM or NVVM.
101+
}];
102+
103+
let assemblyFormat = "attr-dict";
104+
}
105+
92106
def ApplyLowerBroadcastPatternsOp : Op<Transform_Dialect,
93107
"apply_patterns.vector.lower_broadcast",
94108
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {

mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,15 @@ void populateVectorInterleaveLoweringPatterns(RewritePatternSet &patterns,
276276
void populateVectorInterleaveToShufflePatterns(RewritePatternSet &patterns,
277277
PatternBenefit benefit = 1);
278278

279+
/// Populates the pattern set with the following patterns:
280+
///
281+
/// [UnrollBitCastOp]
282+
/// A one-shot unrolling of BitCastOp to (one or more) ExtractOp +
283+
/// BitCastOp (of `targetRank`) + InsertOp.
284+
void populateVectorBitCastLoweringPatterns(RewritePatternSet &patterns,
285+
int64_t targetRank = 1,
286+
PatternBenefit benefit = 1);
287+
279288
} // namespace vector
280289
} // namespace mlir
281290
#endif // MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ void LowerVectorToLLVMPass::runOnOperation() {
6464
{
6565
RewritePatternSet patterns(&getContext());
6666
populateVectorToVectorCanonicalizationPatterns(patterns);
67+
populateVectorBitCastLoweringPatterns(patterns);
6768
populateVectorBroadcastLoweringPatterns(patterns);
6869
populateVectorContractLoweringPatterns(patterns, VectorTransformsOptions());
6970
populateVectorMaskOpLoweringPatterns(patterns);

mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,11 @@ void transform::ApplyTransferPermutationPatternsOp::populatePatterns(
7979
vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
8080
}
8181

82+
// Populates `patterns` with the vector.bitcast lowering patterns. Only
// `patterns` is forwarded, so the populate function's default arguments apply.
void transform::ApplyLowerBitCastPatternsOp::populatePatterns(
83+
RewritePatternSet &patterns) {
84+
vector::populateVectorBitCastLoweringPatterns(patterns);
85+
}
86+
8287
void transform::ApplyLowerBroadcastPatternsOp::populatePatterns(
8388
RewritePatternSet &patterns) {
8489
populateVectorBroadcastLoweringPatterns(patterns);

mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
add_mlir_dialect_library(MLIRVectorTransforms
22
BufferizableOpInterfaceImpl.cpp
3+
LowerVectorBitCast.cpp
34
LowerVectorBroadcast.cpp
45
LowerVectorContract.cpp
56
LowerVectorGather.cpp
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
//===- LowerVectorBitCast.cpp - Lower 'vector.bitcast' operation ----------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements target-independent rewrites and utilities to lower the
10+
// 'vector.bitcast' operation.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include "mlir/Dialect/Vector/IR/VectorOps.h"
15+
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
16+
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
17+
#include "mlir/IR/BuiltinTypes.h"
18+
#include "mlir/IR/PatternMatch.h"
19+
#include "mlir/Support/LogicalResult.h"
20+
21+
#define DEBUG_TYPE "vector-bitcast-lowering"
22+
23+
using namespace mlir;
24+
using namespace mlir::vector;
25+
26+
namespace {
27+
28+
/// A one-shot unrolling of vector.bitcast to the `targetRank`.
29+
///
30+
/// Example:
31+
///
32+
/// vector.bitcast %a, %b : vector<1x2x3x4xi64> to vector<1x2x3x8xi32>
33+
///
34+
/// Would be unrolled to:
35+
///
36+
/// %result = arith.constant dense<0> : vector<1x2x3x8xi32>
37+
/// %0 = vector.extract %a[0, 0, 0] ─┐
38+
/// : vector<4xi64> from vector<1x2x3x4xi64> |
39+
/// %1 = vector.bitcast %0 | - Repeated 6x for
40+
/// : vector<4xi64> to vector<8xi32> | all leading positions
41+
/// %2 = vector.insert %1, %result [0, 0, 0] |
42+
/// : vector<8xi64> into vector<1x2x3x8xi32> ─┘
43+
///
44+
/// Note: If any leading dimension before the `targetRank` is scalable the
45+
/// unrolling will stop before the scalable dimension.
46+
class UnrollBitCastOp final : public OpRewritePattern<vector::BitCastOp> {
47+
public:
48+
UnrollBitCastOp(int64_t targetRank, MLIRContext *context,
49+
PatternBenefit benefit = 1)
50+
: OpRewritePattern(context, benefit), targetRank(targetRank) {};
51+
52+
LogicalResult matchAndRewrite(vector::BitCastOp op,
53+
PatternRewriter &rewriter) const override {
54+
VectorType resultType = op.getResultVectorType();
55+
auto unrollIterator = vector::createUnrollIterator(resultType, targetRank);
56+
if (!unrollIterator)
57+
return failure();
58+
59+
// TODO: Support the scalable vector cases. It is not supported because
60+
// the final rank could be values other than `targetRank`. It makes creating
61+
// the result type of new vector.bitcast ops much harder.
62+
if (resultType.isScalable()) {
63+
return rewriter.notifyMatchFailure(op,
64+
"unrolling vector.bitcast on scalable "
65+
"vectors is not yet implemented");
66+
}
67+
68+
ArrayRef<int64_t> shape = resultType.getShape().take_back(targetRank);
69+
auto bitcastResType = VectorType::get(shape, resultType.getElementType());
70+
71+
Location loc = op.getLoc();
72+
Value result = rewriter.create<arith::ConstantOp>(
73+
loc, resultType, rewriter.getZeroAttr(resultType));
74+
for (auto position : *unrollIterator) {
75+
Value extract =
76+
rewriter.create<vector::ExtractOp>(loc, op.getSource(), position);
77+
Value bitcast =
78+
rewriter.create<vector::BitCastOp>(loc, bitcastResType, extract);
79+
result =
80+
rewriter.create<vector::InsertOp>(loc, bitcast, result, position);
81+
}
82+
83+
rewriter.replaceOp(op, result);
84+
return success();
85+
}
86+
87+
private:
88+
int64_t targetRank = 1;
89+
};
90+
91+
} // namespace
92+
93+
// Adds the `UnrollBitCastOp` pattern to `patterns`, forwarding `targetRank` and
// `benefit` to its constructor together with the pattern set's context.
void mlir::vector::populateVectorBitCastLoweringPatterns(
94+
RewritePatternSet &patterns, int64_t targetRank, PatternBenefit benefit) {
95+
patterns.add<UnrollBitCastOp>(targetRank, patterns.getContext(), benefit);
96+
}

mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2564,3 +2564,13 @@ func.func @vector_deinterleave_1d_scalable(%a: vector<[4]xi32>) -> (vector<[2]xi
25642564
%0, %1 = vector.deinterleave %a : vector<[4]xi32> -> vector<[2]xi32>
25652565
return %0, %1 : vector<[2]xi32>, vector<[2]xi32>
25662566
}
2567+
2568+
// -----
2569+
2570+
// CHECK-LABEL: func.func @vector_bitcast_2d
2571+
// CHECK: llvm.bitcast
2572+
// CHECK-NOT: vector.bitcast
2573+
func.func @vector_bitcast_2d(%arg0: vector<2x4xi32>) -> vector<2x2xi64> {
2574+
%0 = vector.bitcast %arg0 : vector<2x4xi32> to vector<2x2xi64>
2575+
return %0 : vector<2x2xi64>
2576+
}
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
2+
3+
func.func @vector_bitcast_0d(%arg0: vector<i32>) -> vector<f32> {
4+
%0 = vector.bitcast %arg0 : vector<i32> to vector<f32>
5+
return %0 : vector<f32>
6+
}
7+
// CHECK-LABEL: func.func @vector_bitcast_0d
8+
// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]]
9+
// CHECK: %[[RES:.+]] = vector.bitcast %[[IN]] : vector<i32> to vector<f32>
10+
// CHECK: return %[[RES]]
11+
12+
func.func @vector_bitcast_1d(%arg0: vector<10xi64>) -> vector<20xi32> {
13+
%0 = vector.bitcast %arg0 : vector<10xi64> to vector<20xi32>
14+
return %0 : vector<20xi32>
15+
}
16+
// CHECK-LABEL: func.func @vector_bitcast_1d
17+
// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]]
18+
// CHECK: %[[RES:.+]] = vector.bitcast %[[IN]] : vector<10xi64> to vector<20xi32>
19+
// CHECK: return %[[RES]]
20+
21+
func.func @vector_bitcast_2d(%arg0: vector<2x4xi32>) -> vector<2x2xi64> {
22+
%0 = vector.bitcast %arg0 : vector<2x4xi32> to vector<2x2xi64>
23+
return %0 : vector<2x2xi64>
24+
}
25+
// CHECK-LABEL: func.func @vector_bitcast_2d
26+
// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]]
27+
// CHECK: %[[INIT:.+]] = arith.constant {{.+}} : vector<2x2xi64>
28+
// CHECK: %[[V1:.+]] = vector.extract %[[IN]][0] : vector<4xi32> from vector<2x4xi32>
29+
// CHECK: %[[B1:.+]] = vector.bitcast %[[V1]] : vector<4xi32> to vector<2xi64>
30+
// CHECK: %[[R1:.+]] = vector.insert %[[B1]], %[[INIT]] [0]
31+
// CHECK: %[[V2:.+]] = vector.extract %[[IN]][1] : vector<4xi32> from vector<2x4xi32>
32+
// CHECK: %[[B2:.+]] = vector.bitcast %[[V2]] : vector<4xi32> to vector<2xi64>
33+
// CHECK: %[[R2:.+]] = vector.insert %[[B2]], %[[R1]] [1]
34+
// CHECK: return %[[R2]]
35+
36+
func.func @vector_bitcast_4d_with_scalable_dim(%arg0: vector<1x2x[3]x4xi64>) -> vector<1x2x[3]x8xi32> {
37+
%0 = vector.bitcast %arg0 : vector<1x2x[3]x4xi64> to vector<1x2x[3]x8xi32>
38+
return %0 : vector<1x2x[3]x8xi32>
39+
}
40+
// CHECK-LABEL: func.func @vector_bitcast_4d_with_scalable_dim
41+
// CHECK: vector.bitcast {{.+}} : vector<1x2x[3]x4xi64> to vector<1x2x[3]x8xi32>
42+
43+
module attributes {transform.with_named_sequence} {
44+
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
45+
%f = transform.structured.match ops{["func.func"]} in %module_op
46+
: (!transform.any_op) -> !transform.any_op
47+
48+
transform.apply_patterns to %f {
49+
transform.apply_patterns.vector.lower_bitcast
50+
} : !transform.any_op
51+
transform.yield
52+
}
53+
}

0 commit comments

Comments
 (0)