Skip to content

Commit 6f5c4f2

Browse files
authored
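[mlir][vector] Add vector bitwidth target to linearize Vectorizable and Constant ops (llvm#83314)
Added a new option `targetVectorBitwidth` to capture the target bit-width; an op is linearized only when the bit-width of its trailing vector dimension is below this target.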
1 parent a5c90e4 commit 6f5c4f2

File tree

4 files changed

+136
-14
lines changed

4 files changed

+136
-14
lines changed

mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -387,7 +387,7 @@ void populateVectorTransposeNarrowTypeRewritePatterns(
387387
/// the ops to get converted properly.
388388
void populateVectorLinearizeTypeConversionsAndLegality(
389389
TypeConverter &typeConverter, RewritePatternSet &patterns,
390-
ConversionTarget &target);
390+
ConversionTarget &target, unsigned targetBitWidth);
391391

392392
} // namespace vector
393393
} // namespace mlir

mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp

Lines changed: 51 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,30 @@
1919

2020
using namespace mlir;
2121

22+
/// Returns true when every result of `op` is a vector whose trailing
/// dimension occupies strictly fewer than `targetBitWidth` bits.
///
/// Returns false as soon as one result:
///   * is not a vector type (e.g. a scalar `arith.constant`),
///   * has `index` elements (their bit width is not queryable here),
///   * is a 0-D vector (there is no trailing dimension to measure), or
///   * has a trailing dimension that is `targetBitWidth` bits wide or more.
///
/// An op without results has nothing that violates the bound, so it yields
/// true.
static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) {
  for (Type resType : op->getResultTypes()) {
    // This helper is also reached for ops with scalar results, so use
    // dyn_cast: cast<> would assert on any non-vector type.
    auto vecType = dyn_cast<VectorType>(resType);
    if (!vecType)
      return false;
    // Reject index since getElementTypeBitWidth will abort for Index types.
    if (vecType.getElementType().isIndex())
      return false;
    // A 0-D vector has an empty shape; calling back() on it is invalid.
    if (vecType.getRank() == 0)
      return false;
    // Keep the product in the 64-bit type of the shape entries instead of
    // narrowing it to `unsigned`, so a very wide trailing dimension cannot
    // wrap around and appear smaller than the target.
    const auto trailingVecDimBitWidth =
        vecType.getShape().back() * vecType.getElementTypeBitWidth();
    if (trailingVecDimBitWidth >= targetBitWidth)
      return false;
  }
  return true;
}
36+
2237
namespace {
2338
struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
2439
using OpConversionPattern::OpConversionPattern;
25-
40+
LinearizeConstant(
41+
const TypeConverter &typeConverter, MLIRContext *context,
42+
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
43+
PatternBenefit benefit = 1)
44+
: OpConversionPattern(typeConverter, context, benefit),
45+
targetVectorBitWidth(targetVectBitWidth) {}
2646
LogicalResult
2747
matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
2848
ConversionPatternRewriter &rewriter) const override {
@@ -31,7 +51,9 @@ struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
3151
getTypeConverter()->convertType<VectorType>(constOp.getType());
3252
if (!resType)
3353
return rewriter.notifyMatchFailure(loc, "can't convert return type");
34-
54+
if (!isLessThanTargetBitWidth(constOp, targetVectorBitWidth))
55+
return rewriter.notifyMatchFailure(
56+
loc, "Can't flatten since targetBitWidth <= OpSize");
3557
auto dstElementsAttr = dyn_cast<DenseElementsAttr>(constOp.getValue());
3658
if (!dstElementsAttr)
3759
return rewriter.notifyMatchFailure(loc, "unsupported attr type");
@@ -41,15 +63,28 @@ struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
4163
dstElementsAttr);
4264
return success();
4365
}
66+
67+
private:
68+
unsigned targetVectorBitWidth;
4469
};
4570

4671
struct LinearizeVectorizable final
4772
: OpTraitConversionPattern<OpTrait::Vectorizable> {
4873
using OpTraitConversionPattern::OpTraitConversionPattern;
4974

75+
public:
76+
LinearizeVectorizable(
77+
const TypeConverter &typeConverter, MLIRContext *context,
78+
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
79+
PatternBenefit benefit = 1)
80+
: OpTraitConversionPattern(typeConverter, context, benefit),
81+
targetVectorBitWidth(targetVectBitWidth) {}
5082
LogicalResult
5183
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
5284
ConversionPatternRewriter &rewriter) const override {
85+
if (!isLessThanTargetBitWidth(op, targetVectorBitWidth))
86+
return rewriter.notifyMatchFailure(
87+
op->getLoc(), "Can't flatten since targetBitWidth <= OpSize");
5388
FailureOr<Operation *> newOp =
5489
convertOpResultTypes(op, operands, *getTypeConverter(), rewriter);
5590
if (failed(newOp))
@@ -58,12 +93,16 @@ struct LinearizeVectorizable final
5893
rewriter.replaceOp(op, (*newOp)->getResults());
5994
return success();
6095
}
96+
97+
private:
98+
unsigned targetVectorBitWidth;
6199
};
62100
} // namespace
63101

64102
void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
65103
TypeConverter &typeConverter, RewritePatternSet &patterns,
66-
ConversionTarget &target) {
104+
ConversionTarget &target, unsigned targetBitWidth) {
105+
67106
typeConverter.addConversion([](VectorType type) -> std::optional<Type> {
68107
// Ignore scalable vectors for now.
69108
if (type.getRank() <= 1 || type.isScalable())
@@ -83,15 +122,17 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
83122
typeConverter.addArgumentMaterialization(materializeCast);
84123
typeConverter.addSourceMaterialization(materializeCast);
85124
typeConverter.addTargetMaterialization(materializeCast);
86-
87125
target.markUnknownOpDynamicallyLegal(
88-
[&](Operation *op) -> std::optional<bool> {
89-
if (isa<arith::ConstantOp>(op) || op->hasTrait<OpTrait::Vectorizable>())
90-
return typeConverter.isLegal(op);
91-
126+
[=](Operation *op) -> std::optional<bool> {
127+
if ((isa<arith::ConstantOp>(op) ||
128+
op->hasTrait<OpTrait::Vectorizable>())) {
129+
return (isLessThanTargetBitWidth(op, targetBitWidth)
130+
? typeConverter.isLegal(op)
131+
: true);
132+
}
92133
return std::nullopt;
93134
});
94135

95-
patterns.add<LinearizeConstant, LinearizeVectorizable>(typeConverter,
96-
patterns.getContext());
136+
patterns.add<LinearizeConstant, LinearizeVectorizable>(
137+
typeConverter, patterns.getContext(), targetBitWidth);
97138
}
Lines changed: 74 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,92 @@
11
// RUN: mlir-opt %s -split-input-file -test-vector-linearize | FileCheck %s
2+
// RUN: mlir-opt %s -split-input-file -test-vector-linearize=target-vector-bitwidth=128 | FileCheck %s --check-prefix=CHECK128
3+
// RUN: mlir-opt %s -split-input-file -test-vector-linearize=target-vector-bitwidth=0 | FileCheck %s --check-prefix=CHECK0
24

35
// CHECK-LABEL: test_linearize
6+
// CHECK128-LABEL: test_linearize
7+
// CHECK0-LABEL: test_linearize
48
// CHECK-SAME: (%[[ORIG_ARG:.*]]: vector<2x2xf32>)
9+
// CHECK128-SAME: (%[[ORIG_ARG:.*]]: vector<2x2xf32>)
510
// CHECK: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x2xf32> to vector<4xf32>
11+
// CHECK128: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x2xf32> to vector<4xf32>
612
func.func @test_linearize(%arg0: vector<2x2xf32>) -> vector<2x2xf32> {
713
// CHECK: %[[C1:.*]] = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : vector<4xf32>
14+
// CHECK128: %[[C1:.*]] = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : vector<4xf32>
15+
// CHECK0: %[[C1:.*]] = arith.constant dense<{{.*}}> : vector<2x2xf32>
16+
817
%0 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32>
918
// CHECK: %[[RES:.*]] = vector.shape_cast %[[C1]] : vector<4xf32> to vector<2x2xf32>
10-
19+
// CHECK128: %[[RES:.*]] = vector.shape_cast %[[C1]] : vector<4xf32> to vector<2x2xf32>
1120
// Arith and math ops are handled in generic way, check some of them
1221
// CHECK: %{{.*}} = math.sin %[[ARG]] : vector<4xf32>
22+
// CHECK128: %{{.*}} = math.sin %[[ARG]] : vector<4xf32>
23+
// CHECK0: %{{.*}} = math.sin %{{.*}} : vector<2x2xf32>
24+
%1 = math.sin %arg0 : vector<2x2xf32>
25+
// CHECK: %{{.*}} = arith.addf %[[ARG]], %[[C1]] : vector<4xf32>
26+
// CHECK128: %{{.*}} = arith.addf %[[ARG]], %[[C1]] : vector<4xf32>
27+
// CHECK0: %{{.*}} = arith.addf %{{.*}} : vector<2x2xf32>
28+
29+
%2 = arith.addf %arg0, %0 : vector<2x2xf32>
30+
31+
// CHECK: return %[[RES]] : vector<2x2xf32>
32+
// CHECK128: return %[[RES]] : vector<2x2xf32>
33+
return %0 : vector<2x2xf32>
34+
}
35+
36+
// CHECK-LABEL: test_partial_linearize
37+
// CHECK128-LABEL: test_partial_linearize
38+
// CHECK0-LABEL: test_partial_linearize
39+
// CHECK-SAME: (%[[ORIG_ARG:.*]]: vector<2x2xf32>, %[[ORIG_ARG2:.*]]: vector<4x4xf32>)
40+
// CHECK128-SAME: (%[[ORIG_ARG:.*]]: vector<2x2xf32>, %[[ORIG_ARG2:.*]]: vector<4x4xf32>)
41+
// CHECK0-SAME: (%[[ORIG_ARG:.*]]: vector<2x2xf32>, %[[ORIG_ARG2:.*]]: vector<4x4xf32>)
42+
// CHECK: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x2xf32> to vector<4xf32>
43+
// CHECK128: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x2xf32> to vector<4xf32>
44+
// CHECK: %[[ARG2:.*]] = vector.shape_cast %[[ORIG_ARG2]] : vector<4x4xf32> to vector<16xf32>
45+
func.func @test_partial_linearize(%arg0: vector<2x2xf32>, %arg1: vector<4x4xf32>) -> vector<2x2xf32> {
46+
// CHECK: %[[C1:.*]] = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : vector<4xf32>
47+
// CHECK128: %[[C1:.*]] = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : vector<4xf32>
48+
// CHECK0: %[[C1:.*]] = arith.constant dense<{{.*}}> : vector<2x2xf32>
49+
50+
%0 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32>
51+
// CHECK: %[[RES:.*]] = vector.shape_cast %[[C1]] : vector<4xf32> to vector<2x2xf32>
52+
// CHECK128: %[[RES:.*]] = vector.shape_cast %[[C1]] : vector<4xf32> to vector<2x2xf32>
53+
54+
// CHECK: %[[C2:.*]] = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 1.000000e+00, 2.000000e+00, 5.000000e+00, 6.000000e+00]> : vector<16xf32>
55+
// CHECK128: %[[C2:.*]] = arith.constant dense<{{.*}}> : vector<4x4xf32>
56+
// CHECK0: %[[C2:.*]] = arith.constant dense<{{.*}}> : vector<4x4xf32>
57+
%5 = arith.constant dense<[[1.0, 2.0, 3.0, 4.0], [1.0, 2.0,3.0, 4.0], [1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 5.0, 6.0]]> : vector<4x4xf32>
58+
// Arith and math ops are handled in generic way, check some of them
59+
// CHECK: %[[SIN:.*]] = math.sin %[[ARG]] : vector<4xf32>
60+
// CHECK128: %[[SIN:.*]] = math.sin %[[ARG]] : vector<4xf32>
61+
// CHECK0: %[[SIN:.*]] = math.sin %[[ORIG_ARG]] : vector<2x2xf32>
1362
%1 = math.sin %arg0 : vector<2x2xf32>
63+
64+
// CHECK: %[[SIN1:.*]] = math.sin %[[ARG2]] : vector<16xf32>
65+
// CHECK128: %[[SIN1:.*]] = math.sin %[[ORIG_ARG2]] : vector<4x4xf32>
66+
// CHECK0: %[[SIN1:.*]] = math.sin %[[ORIG_ARG2]] : vector<4x4xf32>
67+
%6 = math.sin %arg1 : vector<4x4xf32>
1468
// CHECK: %{{.*}} = arith.addf %[[ARG]], %[[C1]] : vector<4xf32>
69+
// CHECK128: %{{.*}} = arith.addf %[[ARG]], %[[C1]] : vector<4xf32>
70+
// CHECK0: %{{.*}} = arith.addf %{{.*}} : vector<2x2xf32>
71+
1572
%2 = arith.addf %arg0, %0 : vector<2x2xf32>
1673

74+
// CHECK: %[[ADD2:.*]] = arith.addf %[[ARG2]], %[[C2]] : vector<16xf32>
75+
// CHECK128: %[[ADD2:.*]] = arith.addf %[[ORIG_ARG2]], %[[C2]] : vector<4x4xf32>
76+
// CHECK0: %[[ADD2:.*]] = arith.addf %[[ORIG_ARG2]], %[[C2]] : vector<4x4xf32>
77+
%7 = arith.addf %arg1, %5 : vector<4x4xf32>
1778
// CHECK: return %[[RES]] : vector<2x2xf32>
79+
// CHECK128: return %[[RES]] : vector<2x2xf32>
1880
return %0 : vector<2x2xf32>
1981
}
82+
83+
// CHECK-LABEL: test_index_no_linearize
84+
// CHECK128-LABEL: test_index_no_linearize
85+
// CHECK0-LABEL: test_index_no_linearize
86+
func.func @test_index_no_linearize(%arg0: vector<2x2xindex>, %arg1: vector<2x2xindex>) -> vector<2x2xindex> {
87+
// CHECK: %[[ADD:.*]] = arith.addi {{.*}} : vector<2x2xindex>
88+
// CHECK128: %[[ADD:.*]] = arith.addi {{.*}} : vector<2x2xindex>
89+
// CHECK0: %[[ADD:.*]] = arith.addi {{.*}} : vector<2x2xindex>
90+
%0 = arith.addi %arg0, %arg1 : vector<2x2xindex>
91+
return %0 : vector<2x2xindex>
92+
}

mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -840,6 +840,9 @@ struct TestVectorLinearize final
840840
: public PassWrapper<TestVectorLinearize, OperationPass<>> {
841841
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorLinearize)
842842

843+
TestVectorLinearize() = default;
844+
TestVectorLinearize(const TestVectorLinearize &pass) : PassWrapper(pass) {}
845+
843846
StringRef getArgument() const override { return "test-vector-linearize"; }
844847
StringRef getDescription() const override {
845848
return "Linearizes ND vectors for N >= 2 into 1D vectors";
@@ -848,15 +851,20 @@ struct TestVectorLinearize final
848851
registry.insert<vector::VectorDialect>();
849852
}
850853

854+
Option<unsigned> targetVectorBitwidth{
855+
*this, "target-vector-bitwidth",
856+
llvm::cl::desc(
857+
"Minimum vector bitwidth to enable the flattening transformation"),
858+
llvm::cl::init(std::numeric_limits<unsigned>::max())};
851859
void runOnOperation() override {
852860
auto *context = &getContext();
853861

854862
TypeConverter typeConverter;
855863
RewritePatternSet patterns(context);
856864
ConversionTarget target(*context);
857865

858-
vector::populateVectorLinearizeTypeConversionsAndLegality(typeConverter,
859-
patterns, target);
866+
vector::populateVectorLinearizeTypeConversionsAndLegality(
867+
typeConverter, patterns, target, targetVectorBitwidth);
860868
if (failed(applyPartialConversion(getOperation(), target,
861869
std::move(patterns))))
862870
return signalPassFailure();

0 commit comments

Comments
 (0)