Skip to content

[mlir][math] Propagate scalability in polynomial approximation #84949

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 15, 2024

Conversation

MacDue
Copy link
Member

@MacDue MacDue commented Mar 12, 2024

This simply updates the rewrites to propagate the scalable flags (which as they do not alter the vector shape, is pretty simple).

The added tests are simply scalable versions of the existing vector tests.

This simply updates the rewrites to propagate the scalable flags (which
as they do not alter the vector shape, is pretty simple).

The added tests are simply scalable versions of the existing vector
tests.
@llvmbot
Copy link
Member

llvmbot commented Mar 12, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-math

Author: Benjamin Maxwell (MacDue)

Changes

This simply updates the rewrites to propagate the scalable flags (which as they do not alter the vector shape, is pretty simple).

The added tests are simply scalable versions of the existing vector tests.


Full diff: https://github.com/llvm/llvm-project/pull/84949.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp (+34-23)
  • (modified) mlir/test/Dialect/Math/polynomial-approximation.mlir (+89)
diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
index 962cb28b7c2ab9..428c1c37c4e8b5 100644
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -39,14 +39,24 @@ using namespace mlir;
 using namespace mlir::math;
 using namespace mlir::vector;
 
+// Helper to encapsulate a vector's shape (including scalable dims).
+struct VectorShape {
+  ArrayRef<int64_t> sizes;
+  ArrayRef<bool> scalableFlags;
+
+  bool empty() const { return sizes.empty(); }
+};
+
 // Returns vector shape if the type is a vector. Returns an empty shape if it is
 // not a vector.
-static ArrayRef<int64_t> vectorShape(Type type) {
+static VectorShape vectorShape(Type type) {
   auto vectorType = dyn_cast<VectorType>(type);
-  return vectorType ? vectorType.getShape() : ArrayRef<int64_t>();
+  return vectorType
+             ? VectorShape{vectorType.getShape(), vectorType.getScalableDims()}
+             : VectorShape{};
 }
 
-static ArrayRef<int64_t> vectorShape(Value value) {
+static VectorShape vectorShape(Value value) {
   return vectorShape(value.getType());
 }
 
@@ -55,14 +65,16 @@ static ArrayRef<int64_t> vectorShape(Value value) {
 //----------------------------------------------------------------------------//
 
 // Broadcasts scalar type into vector type (iff shape is non-scalar).
-static Type broadcast(Type type, ArrayRef<int64_t> shape) {
+static Type broadcast(Type type, VectorShape shape) {
   assert(!isa<VectorType>(type) && "must be scalar type");
-  return !shape.empty() ? VectorType::get(shape, type) : type;
+  return !shape.empty()
+             ? VectorType::get(shape.sizes, type, shape.scalableFlags)
+             : type;
 }
 
 // Broadcasts scalar value into vector (iff shape is non-scalar).
 static Value broadcast(ImplicitLocOpBuilder &builder, Value value,
-                       ArrayRef<int64_t> shape) {
+                       VectorShape shape) {
   assert(!isa<VectorType>(value.getType()) && "must be scalar value");
   auto type = broadcast(value.getType(), shape);
   return !shape.empty() ? builder.create<BroadcastOp>(type, value) : value;
@@ -215,7 +227,7 @@ static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound,
 static std::pair<Value, Value> frexp(ImplicitLocOpBuilder &builder, Value arg,
                                      bool isPositive = false) {
   assert(getElementTypeOrSelf(arg).isF32() && "arg must be f32 type");
-  ArrayRef<int64_t> shape = vectorShape(arg);
+  VectorShape shape = vectorShape(arg);
 
   auto bcast = [&](Value value) -> Value {
     return broadcast(builder, value, shape);
@@ -255,7 +267,7 @@ static std::pair<Value, Value> frexp(ImplicitLocOpBuilder &builder, Value arg,
 // Computes exp2 for an i32 argument.
 static Value exp2I32(ImplicitLocOpBuilder &builder, Value arg) {
   assert(getElementTypeOrSelf(arg).isInteger(32) && "arg must be i32 type");
-  ArrayRef<int64_t> shape = vectorShape(arg);
+  VectorShape shape = vectorShape(arg);
 
   auto bcast = [&](Value value) -> Value {
     return broadcast(builder, value, shape);
@@ -281,7 +293,7 @@ Value makePolynomialCalculation(ImplicitLocOpBuilder &builder,
   Type elementType = getElementTypeOrSelf(x);
   assert((elementType.isF32() || elementType.isF16()) &&
          "x must be f32 or f16 type");
-  ArrayRef<int64_t> shape = vectorShape(x);
+  VectorShape shape = vectorShape(x);
 
   if (coeffs.empty())
     return broadcast(builder, floatCst(builder, 0.0f, elementType), shape);
@@ -379,7 +391,7 @@ AtanApproximation::matchAndRewrite(math::AtanOp op,
   if (!getElementTypeOrSelf(operand).isF32())
     return rewriter.notifyMatchFailure(op, "unsupported operand type");
 
-  ArrayRef<int64_t> shape = vectorShape(op.getOperand());
+  VectorShape shape = vectorShape(op.getOperand());
 
   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
   Value abs = builder.create<math::AbsFOp>(operand);
@@ -478,7 +490,7 @@ Atan2Approximation::matchAndRewrite(math::Atan2Op op,
     return rewriter.notifyMatchFailure(op, "unsupported operand type");
 
   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
-  ArrayRef<int64_t> shape = vectorShape(op.getResult());
+  VectorShape shape = vectorShape(op.getResult());
 
   // Compute atan in the valid range.
   auto div = builder.create<arith::DivFOp>(y, x);
@@ -544,7 +556,7 @@ TanhApproximation::matchAndRewrite(math::TanhOp op,
   if (!getElementTypeOrSelf(op.getOperand()).isF32())
     return rewriter.notifyMatchFailure(op, "unsupported operand type");
 
-  ArrayRef<int64_t> shape = vectorShape(op.getOperand());
+  VectorShape shape = vectorShape(op.getOperand());
 
   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
   auto bcast = [&](Value value) -> Value {
@@ -632,7 +644,7 @@ LogApproximationBase<Op>::logMatchAndRewrite(Op op, PatternRewriter &rewriter,
   if (!getElementTypeOrSelf(op.getOperand()).isF32())
     return rewriter.notifyMatchFailure(op, "unsupported operand type");
 
-  ArrayRef<int64_t> shape = vectorShape(op.getOperand());
+  VectorShape shape = vectorShape(op.getOperand());
 
   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
   auto bcast = [&](Value value) -> Value {
@@ -779,7 +791,7 @@ Log1pApproximation::matchAndRewrite(math::Log1pOp op,
   if (!getElementTypeOrSelf(op.getOperand()).isF32())
     return rewriter.notifyMatchFailure(op, "unsupported operand type");
 
-  ArrayRef<int64_t> shape = vectorShape(op.getOperand());
+  VectorShape shape = vectorShape(op.getOperand());
 
   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
   auto bcast = [&](Value value) -> Value {
@@ -829,7 +841,7 @@ ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op,
   if (!(elementType.isF32() || elementType.isF16()))
     return rewriter.notifyMatchFailure(op,
                                        "only f32 and f16 type is supported.");
-  ArrayRef<int64_t> shape = vectorShape(operand);
+  VectorShape shape = vectorShape(operand);
 
   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
   auto bcast = [&](Value value) -> Value {
@@ -938,9 +950,8 @@ ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op,
 
 namespace {
 
-Value clampWithNormals(ImplicitLocOpBuilder &builder,
-                       const llvm::ArrayRef<int64_t> shape, Value value,
-                       float lowerBound, float upperBound) {
+Value clampWithNormals(ImplicitLocOpBuilder &builder, const VectorShape shape,
+                       Value value, float lowerBound, float upperBound) {
   assert(!std::isnan(lowerBound));
   assert(!std::isnan(upperBound));
 
@@ -1131,7 +1142,7 @@ ExpM1Approximation::matchAndRewrite(math::ExpM1Op op,
   if (!getElementTypeOrSelf(op.getOperand()).isF32())
     return rewriter.notifyMatchFailure(op, "unsupported operand type");
 
-  ArrayRef<int64_t> shape = vectorShape(op.getOperand());
+  VectorShape shape = vectorShape(op.getOperand());
 
   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
   auto bcast = [&](Value value) -> Value {
@@ -1201,7 +1212,7 @@ LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite(
   if (!getElementTypeOrSelf(op.getOperand()).isF32())
     return rewriter.notifyMatchFailure(op, "unsupported operand type");
 
-  ArrayRef<int64_t> shape = vectorShape(op.getOperand());
+  VectorShape shape = vectorShape(op.getOperand());
 
   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
   auto bcast = [&](Value value) -> Value {
@@ -1328,7 +1339,7 @@ CbrtApproximation::matchAndRewrite(math::CbrtOp op,
     return rewriter.notifyMatchFailure(op, "unsupported operand type");
 
   ImplicitLocOpBuilder b(op->getLoc(), rewriter);
-  ArrayRef<int64_t> shape = vectorShape(operand);
+  VectorShape shape = vectorShape(operand);
 
   Type floatTy = getElementTypeOrSelf(operand.getType());
   Type intTy = b.getIntegerType(floatTy.getIntOrFloatBitWidth());
@@ -1417,10 +1428,10 @@ RsqrtApproximation::matchAndRewrite(math::RsqrtOp op,
   if (!getElementTypeOrSelf(op.getOperand()).isF32())
     return rewriter.notifyMatchFailure(op, "unsupported operand type");
 
-  ArrayRef<int64_t> shape = vectorShape(op.getOperand());
+  VectorShape shape = vectorShape(op.getOperand());
 
   // Only support already-vectorized rsqrt's.
-  if (shape.empty() || shape.back() % 8 != 0)
+  if (shape.empty() || shape.sizes.back() % 8 != 0)
     return rewriter.notifyMatchFailure(op, "unsupported operand type");
 
   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
diff --git a/mlir/test/Dialect/Math/polynomial-approximation.mlir b/mlir/test/Dialect/Math/polynomial-approximation.mlir
index 834a7dc0af66d6..82b2646bea4a86 100644
--- a/mlir/test/Dialect/Math/polynomial-approximation.mlir
+++ b/mlir/test/Dialect/Math/polynomial-approximation.mlir
@@ -94,6 +94,20 @@ func.func @erf_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
   return %0 : vector<8xf32>
 }
 
+// CHECK-LABEL:   func @erf_scalable_vector(
+// CHECK-SAME:                     %[[arg0:.*]]: vector<[8]xf32>) -> vector<[8]xf32> {
+// CHECK:           %[[zero:.*]] = arith.constant dense<0.000000e+00> : vector<[8]xf32>
+// CHECK-NOT:       erf
+// CHECK-NOT:       vector<8xf32>
+// CHECK-COUNT-20:  select
+// CHECK:           %[[res:.*]] = arith.select
+// CHECK:           return %[[res]] : vector<[8]xf32>
+// CHECK:         }
+func.func @erf_scalable_vector(%arg0: vector<[8]xf32>) -> vector<[8]xf32> {
+  %0 = math.erf %arg0 : vector<[8]xf32>
+  return %0 : vector<[8]xf32>
+}
+
 // CHECK-LABEL:   func @exp_scalar(
 // CHECK-SAME:                     %[[VAL_0:.*]]: f32) -> f32 {
 // CHECK-DAG:       %[[VAL_1:.*]] = arith.constant 5.000000e-01 : f32
@@ -151,6 +165,17 @@ func.func @exp_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
   return %0 : vector<8xf32>
 }
 
+// CHECK-LABEL:   func @exp_scalable_vector
+// CHECK-NOT:  math.exp
+// CHECK-NOT:  vector<8xf32>
+// CHECK:      vector<[8]xf32>
+// CHECK-NOT:  vector<8xf32>
+// CHECK-NOT:  math.exp
+func.func @exp_scalable_vector(%arg0: vector<[8]xf32>) -> vector<[8]xf32> {
+  %0 = math.exp %arg0 : vector<[8]xf32>
+  return %0 : vector<[8]xf32>
+}
+
 // CHECK-LABEL:   func @expm1_scalar(
 // CHECK-SAME:                       %[[X:.*]]: f32) -> f32 {
 // CHECK-DAG:       %[[VAL_1:.*]] = arith.constant 1.000000e+00 : f32
@@ -277,6 +302,22 @@ func.func @expm1_vector(%arg0: vector<8x8xf32>) -> vector<8x8xf32> {
   return %0 : vector<8x8xf32>
 }
 
+// CHECK-LABEL:   func @expm1_scalable_vector(
+// CHECK-SAME:                       %{{.*}}: vector<8x[8]xf32>) -> vector<8x[8]xf32> {
+// CHECK-NOT:       vector<8x8xf32>
+// CHECK-NOT:       exp
+// CHECK-NOT:       log
+// CHECK-NOT:       expm1
+// CHECK:           vector<8x[8]xf32>
+// CHECK-NOT:       vector<8x8xf32>
+// CHECK-NOT:       exp
+// CHECK-NOT:       log
+// CHECK-NOT:       expm1
+func.func @expm1_scalable_vector(%arg0: vector<8x[8]xf32>) -> vector<8x[8]xf32> {
+  %0 = math.expm1 %arg0 : vector<8x[8]xf32>
+  return %0 : vector<8x[8]xf32>
+}
+
 // CHECK-LABEL:   func @log_scalar(
 // CHECK-SAME:                             %[[X:.*]]: f32) -> f32 {
 // CHECK-DAG:       %[[VAL_1:.*]] = arith.constant 0.000000e+00 : f32
@@ -357,6 +398,18 @@ func.func @log_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
   return %0 : vector<8xf32>
 }
 
+// CHECK-LABEL:   func @log_scalable_vector(
+// CHECK-SAME:                     %{{.*}}: vector<[8]xf32>) -> vector<[8]xf32> {
+// CHECK:           %[[CST_LN2:.*]] = arith.constant dense<0.693147182> : vector<[8]xf32>
+// CHECK-COUNT-5:   select
+// CHECK:           %[[VAL_71:.*]] = arith.select
+// CHECK:           return %[[VAL_71]] : vector<[8]xf32>
+// CHECK:         }
+func.func @log_scalable_vector(%arg0: vector<[8]xf32>) -> vector<[8]xf32> {
+  %0 = math.log %arg0 : vector<[8]xf32>
+  return %0 : vector<[8]xf32>
+}
+
 // CHECK-LABEL:   func @log2_scalar(
 // CHECK-SAME:                      %[[VAL_0:.*]]: f32) -> f32 {
 // CHECK:           %[[CST_LOG2E:.*]] = arith.constant 1.44269502 : f32
@@ -381,6 +434,18 @@ func.func @log2_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
   return %0 : vector<8xf32>
 }
 
+// CHECK-LABEL:   func @log2_scalable_vector(
+// CHECK-SAME:                      %{{.*}}: vector<[8]xf32>) -> vector<[8]xf32> {
+// CHECK:           %[[CST_LOG2E:.*]] = arith.constant dense<1.44269502> : vector<[8]xf32>
+// CHECK-COUNT-5:   select
+// CHECK:           %[[VAL_71:.*]] = arith.select
+// CHECK:           return %[[VAL_71]] : vector<[8]xf32>
+// CHECK:         }
+func.func @log2_scalable_vector(%arg0: vector<[8]xf32>) -> vector<[8]xf32> {
+  %0 = math.log2 %arg0 : vector<[8]xf32>
+  return %0 : vector<[8]xf32>
+}
+
 // CHECK-LABEL:   func @log1p_scalar(
 // CHECK-SAME:                       %[[X:.*]]: f32) -> f32 {
 // CHECK:           %[[CST_ONE:.*]] = arith.constant 1.000000e+00 : f32
@@ -414,6 +479,17 @@ func.func @log1p_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
   return %0 : vector<8xf32>
 }
 
+// CHECK-LABEL:   func @log1p_scalable_vector(
+// CHECK-SAME:                       %[[VAL_0:.*]]: vector<[8]xf32>) -> vector<[8]xf32> {
+// CHECK:           %[[CST_ONE:.*]] = arith.constant dense<1.000000e+00> : vector<[8]xf32>
+// CHECK-COUNT-6:   select
+// CHECK:           %[[VAL_79:.*]] = arith.select
+// CHECK:           return %[[VAL_79]] : vector<[8]xf32>
+// CHECK:         }
+func.func @log1p_scalable_vector(%arg0: vector<[8]xf32>) -> vector<[8]xf32> {
+  %0 = math.log1p %arg0 : vector<[8]xf32>
+  return %0 : vector<[8]xf32>
+}
 
 // CHECK-LABEL:   func @tanh_scalar(
 // CHECK-SAME:                      %[[VAL_0:.*]]: f32) -> f32 {
@@ -470,6 +546,19 @@ func.func @tanh_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
   return %0 : vector<8xf32>
 }
 
+// CHECK-LABEL:   func @tanh_scalable_vector(
+// CHECK-SAME:                      %[[VAL_0:.*]]: vector<[8]xf32>) -> vector<[8]xf32> {
+// CHECK:           %[[VAL_1:.*]] = arith.constant dense<-7.99881172> : vector<[8]xf32>
+// CHECK-NOT:       tanh
+// CHECK-COUNT-2:   select
+// CHECK:           %[[VAL_33:.*]] = arith.select
+// CHECK:           return %[[VAL_33]] : vector<[8]xf32>
+// CHECK:         }
+func.func @tanh_scalable_vector(%arg0: vector<[8]xf32>) -> vector<[8]xf32> {
+  %0 = math.tanh %arg0 : vector<[8]xf32>
+  return %0 : vector<[8]xf32>
+}
+
 // We only approximate rsqrt for vectors and when the AVX2 option is enabled.
 // CHECK-LABEL:   func @rsqrt_scalar
 // AVX2-LABEL:    func @rsqrt_scalar

Copy link
Collaborator

@c-rhodes c-rhodes left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just left one comment but otherwise LGTM cheers

@MacDue MacDue merged commit e74bcec into llvm:main Mar 15, 2024
@MacDue MacDue deleted the scalable_poly_approx branch March 15, 2024 20:09
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants