[mlir][vector] Add support for scalable vectors to VectorLinearize #86786

Merged
banach-space merged 3 commits into llvm:main on Mar 28, 2024

@banach-space (Contributor) commented Mar 27, 2024

Adds support for scalable vectors to patterns defined in
VectorLinearize.cpp.

Linearization is disabled in 2 notable cases:

  • vectors with more than 1 scalable dimension (we cannot represent
    vscale^2),
  • vectors initialised with arith.constant that's not a vector splat
    (such arith.constant Ops cannot be flattened).
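
The two exclusions above can be sketched outside of MLIR. This is a minimal standalone model, not the code from this PR: `ShapeModel`, `isLinearizable`, `linearize` and `canLinearizeConstant` are hypothetical names, and a vector type is assumed to be just a list of static sizes plus per-dimension scalable flags.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-in for a vector type: static sizes plus per-dimension
// scalable flags, e.g. vector<2x[4]xf32> is {{2, 4}, {false, true}}.
struct ShapeModel {
  std::vector<int64_t> dims;
  std::vector<bool> scalable;
};

// Rank must be > 1 (otherwise the vector is already flat) and at most one
// dimension may be scalable (two would need vscale^2 in the flat type).
bool isLinearizable(const ShapeModel &s) {
  int numScalable = 0;
  for (bool b : s.scalable)
    numScalable += b ? 1 : 0;
  return s.dims.size() > 1 && numScalable <= 1;
}

// Flattens to a rank-1 shape; the result is scalable if any input dim was.
std::optional<ShapeModel> linearize(const ShapeModel &s) {
  assert(s.dims.size() == s.scalable.size());
  if (!isLinearizable(s))
    return std::nullopt;
  int64_t numElements = 1;
  bool anyScalable = false;
  for (std::size_t i = 0; i < s.dims.size(); ++i) {
    numElements *= s.dims[i];
    anyScalable = anyScalable || s.scalable[i];
  }
  return ShapeModel{{numElements}, {anyScalable}};
}

// A scalable constant has no fixed element count, so only a splat (one value
// repeated) can be rewritten to the flat type.
bool canLinearizeConstant(bool resultIsScalable, bool isSplat) {
  return !resultIsScalable || isSplat;
}
```

Under this model, vector<2x[4]xf32> flattens to a single scalable dimension of base size 8, while vector<[2]x[2]xf32> is left untouched.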

@llvmbot (Member) commented Mar 27, 2024

@llvm/pr-subscribers-mlir-vector

Author: Andrzej Warzyński (banach-space)

Changes
  • WIP
  • [mlir][vector] Add support for scalable vectors to VectorLinearize

Full diff: https://github.com/llvm/llvm-project/pull/86786.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h (+10)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp (+8-3)
  • (modified) mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp (+5)
  • (modified) mlir/test/Dialect/Vector/linearize.mlir (+44)
  • (modified) mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp (+3-1)
diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index 2c548fb6740251..f88fbdf9e62765 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -170,6 +170,16 @@ struct MaskableOpRewritePattern : OpRewritePattern<SourceOp> {
                             PatternRewriter &rewriter) const = 0;
 };
 
+/// Returns true if the input Vector type can be linearized.
+///
+/// Linearization is meant in the sense of flattening vectors, e.g.:
+///   * vector<NxMxKxi32> -> vector<N*M*Kxi32>
+/// In this sense, Vectors that are either:
+///   * already linearized, or
+///   * contain more than 1 scalable dimension,
+/// are not linearizable.
+bool isLinearizableVector(VectorType type);
+
 } // namespace vector
 
 /// Constructs a permutation map of invariant memref indices to vector
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 38536de43f13f2..c8043fbb7c3061 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -49,6 +49,11 @@ struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
     Location loc = constOp.getLoc();
     auto resType =
         getTypeConverter()->convertType<VectorType>(constOp.getType());
+
+    if (resType.isScalable() && !isa<SplatElementsAttr>(constOp.getValue()))
+      return rewriter.notifyMatchFailure(
+          loc, "Cannot linearize a constant scalable vector that's not a splat");
+
     if (!resType)
       return rewriter.notifyMatchFailure(loc, "can't convert return type");
     if (!isLessThanTargetBitWidth(constOp, targetVectorBitWidth))
@@ -104,11 +109,11 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
     ConversionTarget &target, unsigned targetBitWidth) {
 
   typeConverter.addConversion([](VectorType type) -> std::optional<Type> {
-    // Ignore scalable vectors for now.
-    if (type.getRank() <= 1 || type.isScalable())
+    if (!isLinearizableVector(type))
       return type;
 
-    return VectorType::get(type.getNumElements(), type.getElementType());
+    return VectorType::get(type.getNumElements(), type.getElementType(),
+                           type.isScalable());
   });
 
   auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs,
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index 63ed0947cf6ce2..a4415a80139af1 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -317,3 +317,8 @@ SmallVector<OpFoldResult> vector::getMixedSizesXfer(bool hasTensorSemantics,
                          : memref::getMixedSizes(rewriter, loc, base);
   return mixedSourceDims;
 }
+
+bool vector::isLinearizableVector(VectorType type) {
+  auto numScalableDims = llvm::count(type.getScalableDims(), true);
+  return ((type.getRank() > 1) && (numScalableDims <= 1));
+}
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 1b225c7a97d233..3ab68f19aa0c60 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -97,3 +97,47 @@ func.func @test_tensor_no_linearize(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf3
 
     return %0, %arg0 : tensor<2x2xf32>, tensor<2x2xf32>
 }
+
+// -----
+
+// ALL-LABEL:   func.func @test_1_scalable_dim(
+// ALL-SAME:    %[[ARG_0:.*]]: vector<2x[4]xf32>) -> vector<2x[4]xf32> {
+func.func @test_1_scalable_dim(%arg0: vector<2x[4]xf32>) -> vector<2x[4]xf32> {
+  // DEFAULT:  %[[SC:.*]] = vector.shape_cast %[[ARG_0]] : vector<2x[4]xf32> to vector<[8]xf32>
+  // DEFAULT:  %[[CST:.*]] = arith.constant dense<3.000000e+00> : vector<[8]xf32>
+  // BW-128:  %[[CST:.*]] = arith.constant dense<3.000000e+00> : vector<2x[4]xf32>
+  // BW-0:  %[[CST:.*]] = arith.constant dense<3.000000e+00> : vector<2x[4]xf32>
+  %0 = arith.constant dense<[[3., 3., 3., 3.], [3., 3., 3., 3.]]> : vector<2x[4]xf32>
+
+  // DEFAULT: %[[SIN:.*]] = math.sin %[[SC]] : vector<[8]xf32>
+  // BW-128: %[[SIN:.*]] = math.sin %[[ARG_0]] : vector<2x[4]xf32>
+  // BW-0: %[[SIN:.*]] = math.sin %[[ARG_0]] : vector<2x[4]xf32>
+  %1 = math.sin %arg0 : vector<2x[4]xf32>
+
+  // DEFAULT: %[[ADDF:.*]] = arith.addf %[[SIN]], %[[CST]] : vector<[8]xf32>
+  // BW-128: %[[RES:.*]] = arith.addf %[[CST]], %[[SIN]] : vector<2x[4]xf32>
+  // BW-0: %[[RES:.*]] = arith.addf %[[CST]], %[[SIN]] : vector<2x[4]xf32>
+  %2 = arith.addf %0, %1 : vector<2x[4]xf32>
+
+  // DEFAULT: %[[RES:.*]] = vector.shape_cast %[[ADDF]] : vector<[8]xf32> to vector<2x[4]xf32>
+  // ALL: return %[[RES]] : vector<2x[4]xf32>
+  return %2 : vector<2x[4]xf32>
+}
+
+// -----
+
+// ALL-LABEL:   func.func @test_2_scalable_dims(
+// ALL-SAME:     %[[VAL_0:.*]]: vector<[2]x[2]xf32>) -> vector<[2]x[2]xf32> {
+func.func @test_2_scalable_dims(%arg0: vector<[2]x[2]xf32>) -> vector<[2]x[2]xf32> {
+  // ALL: %[[CST:.*]] = arith.constant dense<2.000000e+00> : vector<[2]x[2]xf32>
+  %0 = arith.constant dense<[[2., 2.], [2., 2.]]> : vector<[2]x[2]xf32>
+
+  // ALL: %[[SIN:.*]] = math.sin %[[VAL_0]] : vector<[2]x[2]xf32>
+  %1 = math.sin %arg0 : vector<[2]x[2]xf32>
+
+  // ALL: %[[RES:.*]] = arith.addf %[[CST]], %[[SIN]] : vector<[2]x[2]xf32>
+  %2 = arith.addf %0, %1 : vector<[2]x[2]xf32>
+
+  // ALL: return %[[RES]] : vector<[2]x[2]xf32>
+  return %2 : vector<[2]x[2]xf32>
+}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index f14fb18706d1b7..766ddae47c53b9 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -489,7 +489,9 @@ struct TestFlattenVectorTransferPatterns
   Option<unsigned> targetVectorBitwidth{
       *this, "target-vector-bitwidth",
       llvm::cl::desc(
-          "Minimum vector bitwidth to enable the flattening transformation"),
+          "Minimum vector bitwidth to enable the flattening transformation. "
+          "For scalable vectors this is the base size that's known at compile "
+          "time."),
       llvm::cl::init(std::numeric_limits<unsigned>::max())};
 
   void runOnOperation() override {

@llvmbot (Member) commented Mar 27, 2024

@llvm/pr-subscribers-mlir


@banach-space force-pushed the andrzej/linearize_add_scalable branch from 3b3ab4e to 457722c (March 27, 2024 09:49)
@banach-space changed the title from "andrzej/linearize add scalable" to "[mlir][vector] Add support for scalable vectors to VectorLinearize" (Mar 27, 2024)
Comment on lines 493 to 494
"For scalable vectors this is the base size that's known at compile "
"time."),
Collaborator

Suggested change
- "For scalable vectors this is the base size that's known at compile "
- "time."),
+ "For scalable vectors this is the base size (vscale=1) that's known at compile "
+ "time."),

Contributor Author

I think that this wording would be a bit confusing. How about:

      "Minimum vector bitwidth to enable the flattening transformation. "
      "For scalable vectors this is the base size, i.e. the size "
      "corresponding to vscale=1."),

?

Collaborator

Sure, or simply "For scalable vectors this is the size corresponding to vscale=1.". To me the important thing is mentioning vscale=1 as I think base size is a bit ambiguous.
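
To make the vscale=1 wording concrete, here is a small sketch of how the size of a scalable vector depends on vscale. `vectorBits` is a hypothetical helper, not an MLIR API, and this does not claim to reproduce how the pass compares a vector against target-vector-bitwidth; it only shows what "the size corresponding to vscale=1" means.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper (not part of MLIR): total bit width of a vector whose
// type lists `baseElements` elements of `elementBits` bits each, evaluated at
// a given runtime `vscale`. Fixed-width vectors ignore `vscale`.
inline uint64_t vectorBits(uint64_t baseElements, uint64_t elementBits,
                           bool scalable, uint64_t vscale) {
  assert(vscale >= 1 && "vscale is a positive runtime multiplier");
  return baseElements * elementBits * (scalable ? vscale : 1);
}
```

For vector<[8]xf32>, `vectorBits(8, 32, true, 1)` is 256, which is the only size known at compile time; at vscale=2 the same type occupies 512 bits.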


bool vector::isLinearizableVector(VectorType type) {
  auto numScalableDims = llvm::count(type.getScalableDims(), true);
  return ((type.getRank() > 1) && (numScalableDims <= 1));
Contributor

I think we can drop the outermost () pair.

Member

Both the outer and inner parentheses are redundant here: type.getRank() > 1 && numScalableDims <= 1

Contributor

oh good catch! I did not notice that there are inner parentheses!

Contributor Author

I'd like to keep the inner parentheses - I find that easier to parse. Probably a subjective thing 😅
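
For reference, the point about redundancy comes down to operator precedence: in C++ the relational operators bind tighter than `&&`, so both spellings parse identically and the choice is purely about readability. A minimal check, using hypothetical function names and plain `int` arguments in place of the real types:

```cpp
#include <cassert>

// Spelling with the inner parentheses kept.
constexpr bool withParens(int rank, int numScalableDims) {
  return (rank > 1) && (numScalableDims <= 1);
}

// Spelling with no parentheses; > and <= still bind before &&.
constexpr bool withoutParens(int rank, int numScalableDims) {
  return rank > 1 && numScalableDims <= 1;
}

// Compares both spellings over a small grid of inputs.
inline void checkAgreement() {
  for (int rank = 0; rank <= 3; ++rank)
    for (int n = 0; n <= 3; ++n)
      assert(withParens(rank, n) == withoutParens(rank, n));
}
```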

@c-rhodes (Collaborator) left a comment

LGTM cheers

@banach-space banach-space merged commit d3aa92e into llvm:main Mar 28, 2024