Skip to content

Commit d3aa92e

Browse files
authored
[mlir][vector] Add support for scalable vectors to VectorLinearize (#86786)
Adds support for scalable vectors to patterns defined in VectorLinearize.cpp. Linearization is disabled in 2 notable cases: * vectors with more than 1 scalable dimension (we cannot represent vscale^2), * vectors initialised with arith.constant that's not a vector splat (such arith.constant Ops cannot be flattened).
1 parent ffed554 commit d3aa92e

File tree

5 files changed

+86
-6
lines changed

5 files changed

+86
-6
lines changed

mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,16 @@ struct MaskableOpRewritePattern : OpRewritePattern<SourceOp> {
170170
PatternRewriter &rewriter) const = 0;
171171
};
172172

173+
/// Returns true if the input Vector type can be linearized.
174+
///
175+
/// Linearization is meant in the sense of flattening vectors, e.g.:
176+
/// * vector<NxMxKxi32> -> vector<N*M*Kxi32>
177+
/// In this sense, Vectors that are either:
178+
/// * already linearized, or
179+
/// * contain more than 1 scalable dimension,
180+
/// are not linearizable.
181+
bool isLinearizableVector(VectorType type);
182+
173183
} // namespace vector
174184

175185
/// Constructs a permutation map of invariant memref indices to vector

mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,12 @@ struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
4949
Location loc = constOp.getLoc();
5050
auto resType =
5151
getTypeConverter()->convertType<VectorType>(constOp.getType());
52+
53+
if (resType.isScalable() && !isa<SplatElementsAttr>(constOp.getValue()))
54+
return rewriter.notifyMatchFailure(
55+
loc,
56+
"Cannot linearize a constant scalable vector that's not a splat");
57+
5258
if (!resType)
5359
return rewriter.notifyMatchFailure(loc, "can't convert return type");
5460
if (!isLessThanTargetBitWidth(constOp, targetVectorBitWidth))
@@ -104,11 +110,11 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
104110
ConversionTarget &target, unsigned targetBitWidth) {
105111

106112
typeConverter.addConversion([](VectorType type) -> std::optional<Type> {
107-
// Ignore scalable vectors for now.
108-
if (type.getRank() <= 1 || type.isScalable())
113+
if (!isLinearizableVector(type))
109114
return type;
110115

111-
return VectorType::get(type.getNumElements(), type.getElementType());
116+
return VectorType::get(type.getNumElements(), type.getElementType(),
117+
type.isScalable());
112118
});
113119

114120
auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs,

mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,3 +317,8 @@ SmallVector<OpFoldResult> vector::getMixedSizesXfer(bool hasTensorSemantics,
317317
: memref::getMixedSizes(rewriter, loc, base);
318318
return mixedSourceDims;
319319
}
320+
321+
/// Reports whether `type` can be linearized, i.e. flattened into a rank-1
/// vector.
///
/// The result is true only when both of the following hold:
///   * the rank of `type` is greater than 1 (a rank-0/rank-1 vector is
///     rejected), and
///   * at most one entry of `type.getScalableDims()` is `true`; with two or
///     more scalable dimensions the flattened size would involve vscale^2,
///     which cannot be represented.
bool vector::isLinearizableVector(VectorType type) {
322+
auto numScalableDims = llvm::count(type.getScalableDims(), true);
323+
return (type.getRank() > 1) && (numScalableDims <= 1);
324+
}

mlir/test/Dialect/Vector/linearize.mlir

Lines changed: 59 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
// RUN: mlir-opt %s -split-input-file -test-vector-linearize | FileCheck %s --check-prefixes=ALL,DEFAULT
2-
// RUN: mlir-opt %s -split-input-file -test-vector-linearize=target-vector-bitwidth=128 | FileCheck %s --check-prefixes=ALL,BW-128
1+
// RUN: mlir-opt %s -split-input-file -test-vector-linearize -verify-diagnostics | FileCheck %s --check-prefixes=ALL,DEFAULT
2+
// RUN: mlir-opt %s -split-input-file -test-vector-linearize=target-vector-bitwidth=128 -verify-diagnostics | FileCheck %s --check-prefixes=ALL,BW-128
33
// RUN: mlir-opt %s -split-input-file -test-vector-linearize=target-vector-bitwidth=0 | FileCheck %s --check-prefixes=ALL,BW-0
44

55
// ALL-LABEL: test_linearize
@@ -97,3 +97,60 @@ func.func @test_tensor_no_linearize(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf3
9797

9898
return %0, %arg0 : tensor<2x2xf32>, tensor<2x2xf32>
9999
}
100+
101+
// -----
102+
103+
// ALL-LABEL: func.func @test_scalable_linearize(
104+
// ALL-SAME: %[[ARG_0:.*]]: vector<2x[2]xf32>) -> vector<2x[2]xf32> {
105+
func.func @test_scalable_linearize(%arg0: vector<2x[2]xf32>) -> vector<2x[2]xf32> {
106+
// DEFAULT: %[[SC:.*]] = vector.shape_cast %[[ARG_0]] : vector<2x[2]xf32> to vector<[4]xf32>
107+
// DEFAULT: %[[CST:.*]] = arith.constant dense<3.000000e+00> : vector<[4]xf32>
108+
// BW-128: %[[SC:.*]] = vector.shape_cast %[[ARG_0]] : vector<2x[2]xf32> to vector<[4]xf32>
109+
// BW-128: %[[CST:.*]] = arith.constant dense<3.000000e+00> : vector<[4]xf32>
110+
// BW-0: %[[CST:.*]] = arith.constant dense<3.000000e+00> : vector<2x[2]xf32>
111+
%0 = arith.constant dense<[[3., 3.], [3., 3.]]> : vector<2x[2]xf32>
112+
113+
// DEFAULT: %[[SIN:.*]] = math.sin %[[SC]] : vector<[4]xf32>
114+
// BW-128: %[[SIN:.*]] = math.sin %[[SC]] : vector<[4]xf32>
115+
// BW-0: %[[SIN:.*]] = math.sin %[[ARG_0]] : vector<2x[2]xf32>
116+
%1 = math.sin %arg0 : vector<2x[2]xf32>
117+
118+
// DEFAULT: %[[ADDF:.*]] = arith.addf %[[SIN]], %[[CST]] : vector<[4]xf32>
119+
// BW-128: %[[ADDF:.*]] = arith.addf %[[SIN]], %[[CST]] : vector<[4]xf32>
120+
// BW-0: %[[RES:.*]] = arith.addf %[[CST]], %[[SIN]] : vector<2x[2]xf32>
121+
%2 = arith.addf %0, %1 : vector<2x[2]xf32>
122+
123+
// DEFAULT: %[[RES:.*]] = vector.shape_cast %[[ADDF]] : vector<[4]xf32> to vector<2x[2]xf32>
124+
// BW-128: %[[RES:.*]] = vector.shape_cast %[[ADDF]] : vector<[4]xf32> to vector<2x[2]xf32>
125+
// ALL: return %[[RES]] : vector<2x[2]xf32>
126+
return %2 : vector<2x[2]xf32>
127+
}
128+
129+
// -----
130+
131+
// ALL-LABEL: func.func @test_scalable_no_linearize(
132+
// ALL-SAME: %[[VAL_0:.*]]: vector<[2]x[2]xf32>) -> vector<[2]x[2]xf32> {
133+
func.func @test_scalable_no_linearize(%arg0: vector<[2]x[2]xf32>) -> vector<[2]x[2]xf32> {
134+
// ALL: %[[CST:.*]] = arith.constant dense<2.000000e+00> : vector<[2]x[2]xf32>
135+
%0 = arith.constant dense<[[2., 2.], [2., 2.]]> : vector<[2]x[2]xf32>
136+
137+
// ALL: %[[SIN:.*]] = math.sin %[[VAL_0]] : vector<[2]x[2]xf32>
138+
%1 = math.sin %arg0 : vector<[2]x[2]xf32>
139+
140+
// ALL: %[[RES:.*]] = arith.addf %[[CST]], %[[SIN]] : vector<[2]x[2]xf32>
141+
%2 = arith.addf %0, %1 : vector<[2]x[2]xf32>
142+
143+
// ALL: return %[[RES]] : vector<[2]x[2]xf32>
144+
return %2 : vector<[2]x[2]xf32>
145+
}
146+
147+
// -----
148+
149+
func.func @test_scalable_no_linearize(%arg0: vector<2x[2]xf32>) -> vector<2x[2]xf32> {
150+
// expected-error@+1 {{failed to legalize operation 'arith.constant' that was explicitly marked illegal}}
151+
%0 = arith.constant dense<[[1., 1.], [3., 3.]]> : vector<2x[2]xf32>
152+
%1 = math.sin %arg0 : vector<2x[2]xf32>
153+
%2 = arith.addf %0, %1 : vector<2x[2]xf32>
154+
155+
return %2 : vector<2x[2]xf32>
156+
}

mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -489,7 +489,9 @@ struct TestFlattenVectorTransferPatterns
489489
Option<unsigned> targetVectorBitwidth{
490490
*this, "target-vector-bitwidth",
491491
llvm::cl::desc(
492-
"Minimum vector bitwidth to enable the flattening transformation"),
492+
"Minimum vector bitwidth to enable the flattening transformation. "
493+
"For scalable vectors this is the base size, i.e. the size "
494+
"corresponding to vscale=1."),
493495
llvm::cl::init(std::numeric_limits<unsigned>::max())};
494496

495497
void runOnOperation() override {

0 commit comments

Comments
 (0)