-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][vector] Add support for scalable vectors to VectorLinearize #86786
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][vector] Add support for scalable vectors to VectorLinearize #86786
Conversation
@llvm/pr-subscribers-mlir-vector Author: Andrzej Warzyński (banach-space) Changes
Full diff: https://github.com/llvm/llvm-project/pull/86786.diff 5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index 2c548fb6740251..f88fbdf9e62765 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -170,6 +170,16 @@ struct MaskableOpRewritePattern : OpRewritePattern<SourceOp> {
PatternRewriter &rewriter) const = 0;
};
+/// Returns true if the input Vector type can be linearized.
+///
+/// Linearization is meant in the sense of flattening vectors, e.g.:
+/// * vector<NxMxKxi32> -> vector<N*M*Kxi32>
+/// In this sense, Vectors that are either:
+/// * already linearized, or
+/// * contain more than 1 scalable dimensions,
+/// are not linearizable.
+bool isLinearizableVector(VectorType type);
+
} // namespace vector
/// Constructs a permutation map of invariant memref indices to vector
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 38536de43f13f2..c8043fbb7c3061 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -49,6 +49,11 @@ struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
Location loc = constOp.getLoc();
auto resType =
getTypeConverter()->convertType<VectorType>(constOp.getType());
+
+ if (resType.isScalable() && !isa<SplatElementsAttr>(constOp.getValue()))
+ return rewriter.notifyMatchFailure(
+ loc, "Cannot linearize a constant scalable vector that's not a splt");
+
if (!resType)
return rewriter.notifyMatchFailure(loc, "can't convert return type");
if (!isLessThanTargetBitWidth(constOp, targetVectorBitWidth))
@@ -104,11 +109,11 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
ConversionTarget &target, unsigned targetBitWidth) {
typeConverter.addConversion([](VectorType type) -> std::optional<Type> {
- // Ignore scalable vectors for now.
- if (type.getRank() <= 1 || type.isScalable())
+ if (!isLinearizableVector(type))
return type;
- return VectorType::get(type.getNumElements(), type.getElementType());
+ return VectorType::get(type.getNumElements(), type.getElementType(),
+ type.isScalable());
});
auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs,
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index 63ed0947cf6ce2..a4415a80139af1 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -317,3 +317,8 @@ SmallVector<OpFoldResult> vector::getMixedSizesXfer(bool hasTensorSemantics,
: memref::getMixedSizes(rewriter, loc, base);
return mixedSourceDims;
}
+
+bool vector::isLinearizableVector(VectorType type) {
+ auto numScalableDims = llvm::count(type.getScalableDims(), true);
+ return ((type.getRank() > 1) && (numScalableDims <= 1));
+}
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 1b225c7a97d233..3ab68f19aa0c60 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -97,3 +97,47 @@ func.func @test_tensor_no_linearize(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf3
return %0, %arg0 : tensor<2x2xf32>, tensor<2x2xf32>
}
+
+// -----
+
+// ALL-LABEL: func.func @test_1_scalable_dim(
+// ALL-SAME: %[[ARG_0:.*]]: vector<2x[4]xf32>) -> vector<2x[4]xf32> {
+func.func @test_1_scalable_dim(%arg0: vector<2x[4]xf32>) -> vector<2x[4]xf32> {
+ // DEFAULT: %[[SC:.*]] = vector.shape_cast %[[ARG_0]] : vector<2x[4]xf32> to vector<[8]xf32>
+ // DEFAULT: %[[CST:.*]] = arith.constant dense<3.000000e+00> : vector<[8]xf32>
+ // BW-128: %[[CST:.*]] = arith.constant dense<3.000000e+00> : vector<2x[4]xf32>
+ // BW-0: %[[CST:.*]] = arith.constant dense<3.000000e+00> : vector<2x[4]xf32>
+ %0 = arith.constant dense<[[3., 3., 3., 3.], [3., 3., 3., 3.]]> : vector<2x[4]xf32>
+
+ // DEFAULT: %[[SIN:.*]] = math.sin %[[SC]] : vector<[8]xf32>
+ // BW-128: %[[SIN:.*]] = math.sin %[[ARG_0]] : vector<2x[4]xf32>
+ // BW-0: %[[SIN:.*]] = math.sin %[[ARG_0]] : vector<2x[4]xf32>
+ %1 = math.sin %arg0 : vector<2x[4]xf32>
+
+ // DEFAULT: %[[ADDF:.*]] = arith.addf %[[SIN]], %[[CST]] : vector<[8]xf32>
+ // BW-128: %[[RES:.*]] = arith.addf %[[CST]], %[[SIN]] : vector<2x[4]xf32>
+ // BW-0: %[[RES:.*]] = arith.addf %[[CST]], %[[SIN]] : vector<2x[4]xf32>
+ %2 = arith.addf %0, %1 : vector<2x[4]xf32>
+
+ // DEFAULT: %[[RES:.*]] = vector.shape_cast %[[ADDF]] : vector<[8]xf32> to vector<2x[4]xf32>
+ // ALL: return %[[RES]] : vector<2x[4]xf32>
+ return %2 : vector<2x[4]xf32>
+}
+
+// -----
+
+// ALL-LABEL: func.func @test_2_scalable_dims(
+// ALL-SAME: %[[VAL_0:.*]]: vector<[2]x[2]xf32>) -> vector<[2]x[2]xf32> {
+func.func @test_2_scalable_dims(%arg0: vector<[2]x[2]xf32>) -> vector<[2]x[2]xf32> {
+ // ALL: %[[CST:.*]] = arith.constant dense<2.000000e+00> : vector<[2]x[2]xf32>
+ %0 = arith.constant dense<[[2., 2.], [2., 2.]]> : vector<[2]x[2]xf32>
+
+ // ALL: %[[SIN:.*]] = math.sin %[[VAL_0]] : vector<[2]x[2]xf32>
+ %1 = math.sin %arg0 : vector<[2]x[2]xf32>
+
+ // ALL: %[[RES:.*]] = arith.addf %[[CST]], %[[SIN]] : vector<[2]x[2]xf32>
+ %2 = arith.addf %0, %1 : vector<[2]x[2]xf32>
+
+ // ALL: return %[[RES]] : vector<[2]x[2]xf32>
+ return %2 : vector<[2]x[2]xf32>
+}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index f14fb18706d1b7..766ddae47c53b9 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -489,7 +489,9 @@ struct TestFlattenVectorTransferPatterns
Option<unsigned> targetVectorBitwidth{
*this, "target-vector-bitwidth",
llvm::cl::desc(
- "Minimum vector bitwidth to enable the flattening transformation"),
+ "Minimum vector bitwidth to enable the flattening transformation. "
+ "For scalable vectors this is the base size that's known at compile "
+ "time."),
llvm::cl::init(std::numeric_limits<unsigned>::max())};
void runOnOperation() override {
|
@llvm/pr-subscribers-mlir Author: Andrzej Warzyński (banach-space) Changes
Full diff: https://github.com/llvm/llvm-project/pull/86786.diff 5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index 2c548fb6740251..f88fbdf9e62765 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -170,6 +170,16 @@ struct MaskableOpRewritePattern : OpRewritePattern<SourceOp> {
PatternRewriter &rewriter) const = 0;
};
+/// Returns true if the input Vector type can be linearized.
+///
+/// Linearization is meant in the sense of flattening vectors, e.g.:
+/// * vector<NxMxKxi32> -> vector<N*M*Kxi32>
+/// In this sense, Vectors that are either:
+/// * already linearized, or
+/// * contain more than 1 scalable dimensions,
+/// are not linearizable.
+bool isLinearizableVector(VectorType type);
+
} // namespace vector
/// Constructs a permutation map of invariant memref indices to vector
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 38536de43f13f2..c8043fbb7c3061 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -49,6 +49,11 @@ struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
Location loc = constOp.getLoc();
auto resType =
getTypeConverter()->convertType<VectorType>(constOp.getType());
+
+ if (resType.isScalable() && !isa<SplatElementsAttr>(constOp.getValue()))
+ return rewriter.notifyMatchFailure(
+ loc, "Cannot linearize a constant scalable vector that's not a splt");
+
if (!resType)
return rewriter.notifyMatchFailure(loc, "can't convert return type");
if (!isLessThanTargetBitWidth(constOp, targetVectorBitWidth))
@@ -104,11 +109,11 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
ConversionTarget &target, unsigned targetBitWidth) {
typeConverter.addConversion([](VectorType type) -> std::optional<Type> {
- // Ignore scalable vectors for now.
- if (type.getRank() <= 1 || type.isScalable())
+ if (!isLinearizableVector(type))
return type;
- return VectorType::get(type.getNumElements(), type.getElementType());
+ return VectorType::get(type.getNumElements(), type.getElementType(),
+ type.isScalable());
});
auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs,
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index 63ed0947cf6ce2..a4415a80139af1 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -317,3 +317,8 @@ SmallVector<OpFoldResult> vector::getMixedSizesXfer(bool hasTensorSemantics,
: memref::getMixedSizes(rewriter, loc, base);
return mixedSourceDims;
}
+
+bool vector::isLinearizableVector(VectorType type) {
+ auto numScalableDims = llvm::count(type.getScalableDims(), true);
+ return ((type.getRank() > 1) && (numScalableDims <= 1));
+}
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 1b225c7a97d233..3ab68f19aa0c60 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -97,3 +97,47 @@ func.func @test_tensor_no_linearize(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf3
return %0, %arg0 : tensor<2x2xf32>, tensor<2x2xf32>
}
+
+// -----
+
+// ALL-LABEL: func.func @test_1_scalable_dim(
+// ALL-SAME: %[[ARG_0:.*]]: vector<2x[4]xf32>) -> vector<2x[4]xf32> {
+func.func @test_1_scalable_dim(%arg0: vector<2x[4]xf32>) -> vector<2x[4]xf32> {
+ // DEFAULT: %[[SC:.*]] = vector.shape_cast %[[ARG_0]] : vector<2x[4]xf32> to vector<[8]xf32>
+ // DEFAULT: %[[CST:.*]] = arith.constant dense<3.000000e+00> : vector<[8]xf32>
+ // BW-128: %[[CST:.*]] = arith.constant dense<3.000000e+00> : vector<2x[4]xf32>
+ // BW-0: %[[CST:.*]] = arith.constant dense<3.000000e+00> : vector<2x[4]xf32>
+ %0 = arith.constant dense<[[3., 3., 3., 3.], [3., 3., 3., 3.]]> : vector<2x[4]xf32>
+
+ // DEFAULT: %[[SIN:.*]] = math.sin %[[SC]] : vector<[8]xf32>
+ // BW-128: %[[SIN:.*]] = math.sin %[[ARG_0]] : vector<2x[4]xf32>
+ // BW-0: %[[SIN:.*]] = math.sin %[[ARG_0]] : vector<2x[4]xf32>
+ %1 = math.sin %arg0 : vector<2x[4]xf32>
+
+ // DEFAULT: %[[ADDF:.*]] = arith.addf %[[SIN]], %[[CST]] : vector<[8]xf32>
+ // BW-128: %[[RES:.*]] = arith.addf %[[CST]], %[[SIN]] : vector<2x[4]xf32>
+ // BW-0: %[[RES:.*]] = arith.addf %[[CST]], %[[SIN]] : vector<2x[4]xf32>
+ %2 = arith.addf %0, %1 : vector<2x[4]xf32>
+
+ // DEFAULT: %[[RES:.*]] = vector.shape_cast %[[ADDF]] : vector<[8]xf32> to vector<2x[4]xf32>
+ // ALL: return %[[RES]] : vector<2x[4]xf32>
+ return %2 : vector<2x[4]xf32>
+}
+
+// -----
+
+// ALL-LABEL: func.func @test_2_scalable_dims(
+// ALL-SAME: %[[VAL_0:.*]]: vector<[2]x[2]xf32>) -> vector<[2]x[2]xf32> {
+func.func @test_2_scalable_dims(%arg0: vector<[2]x[2]xf32>) -> vector<[2]x[2]xf32> {
+ // ALL: %[[CST:.*]] = arith.constant dense<2.000000e+00> : vector<[2]x[2]xf32>
+ %0 = arith.constant dense<[[2., 2.], [2., 2.]]> : vector<[2]x[2]xf32>
+
+ // ALL: %[[SIN:.*]] = math.sin %[[VAL_0]] : vector<[2]x[2]xf32>
+ %1 = math.sin %arg0 : vector<[2]x[2]xf32>
+
+ // ALL: %[[RES:.*]] = arith.addf %[[CST]], %[[SIN]] : vector<[2]x[2]xf32>
+ %2 = arith.addf %0, %1 : vector<[2]x[2]xf32>
+
+ // ALL: return %[[RES]] : vector<[2]x[2]xf32>
+ return %2 : vector<[2]x[2]xf32>
+}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index f14fb18706d1b7..766ddae47c53b9 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -489,7 +489,9 @@ struct TestFlattenVectorTransferPatterns
Option<unsigned> targetVectorBitwidth{
*this, "target-vector-bitwidth",
llvm::cl::desc(
- "Minimum vector bitwidth to enable the flattening transformation"),
+ "Minimum vector bitwidth to enable the flattening transformation. "
+ "For scalable vectors this is the base size that's known at compile "
+ "time."),
llvm::cl::init(std::numeric_limits<unsigned>::max())};
void runOnOperation() override {
|
Adds support for scalable vectors to patterns defined in VectorLinearize.cpp. Linearization is disabled in 2 notable cases: * vectors with more than 1 scalable dimension (we cannot represent vscale^2), * vectors initialised with arith.constant that's not a vector splat (such arith.constant Ops cannot be flattened).
3b3ab4e
to
457722c
Compare
"For scalable vectors this is the base size that's known at compile " | ||
"time."), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"For scalable vectors this is the base size that's known at compile " | |
"time."), | |
"For scalable vectors this is the base size (vscale=1) that's known at compile " | |
"time."), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think that this wording would be a bit confusing. How about:
"Minimum vector bitwidth to enable the flattening transformation. " "For scalable vectors this is the base size, i.e. the size " "corresponding to vscale=1."),
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure, or simply "For scalable vectors this is the size corresponding to vscale=1.". To me the important thing is mentioning vscale=1, as I think "base size" is a bit ambiguous.
…rize Address PR comments
|
||
bool vector::isLinearizableVector(VectorType type) { | ||
auto numScalableDims = llvm::count(type.getScalableDims(), true); | ||
return ((type.getRank() > 1) && (numScalableDims <= 1)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we can drop the outermost () pair.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Both the outer and inner parentheses are redundant here: type.getRank() > 1 && numScalableDims <= 1
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
oh good catch! I did not notice that there are inner parentheses!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd like to keep the inner parentheses - I find that easier to parse. Probably a subjective thing 😅
…orLinearize Addressing PR comments
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM cheers
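Adds support for scalable vectors to patterns defined in
VectorLinearize.cpp.
Linearization is disabled in 2 notable cases:
* vectors with more than 1 scalable dimension (we cannot represent vscale^2),
* vectors initialised with arith.constant that's not a vector splat (such arith.constant Ops cannot be flattened).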