Skip to content

[mlir][math] Propagate scalability in polynomial approximation #84949

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 34 additions & 23 deletions mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,24 @@ using namespace mlir;
using namespace mlir::math;
using namespace mlir::vector;

// Helper to encapsulate a vector's shape (including scalable dims).
struct VectorShape {
ArrayRef<int64_t> sizes;
ArrayRef<bool> scalableFlags;

bool empty() const { return sizes.empty(); }
};

// Returns vector shape if the type is a vector. Returns an empty shape if it is
// not a vector.
static ArrayRef<int64_t> vectorShape(Type type) {
static VectorShape vectorShape(Type type) {
auto vectorType = dyn_cast<VectorType>(type);
return vectorType ? vectorType.getShape() : ArrayRef<int64_t>();
return vectorType
? VectorShape{vectorType.getShape(), vectorType.getScalableDims()}
: VectorShape{};
}

static ArrayRef<int64_t> vectorShape(Value value) {
static VectorShape vectorShape(Value value) {
return vectorShape(value.getType());
}

Expand All @@ -55,14 +65,16 @@ static ArrayRef<int64_t> vectorShape(Value value) {
//----------------------------------------------------------------------------//

// Broadcasts scalar type into vector type (iff shape is non-scalar).
static Type broadcast(Type type, ArrayRef<int64_t> shape) {
static Type broadcast(Type type, VectorShape shape) {
assert(!isa<VectorType>(type) && "must be scalar type");
return !shape.empty() ? VectorType::get(shape, type) : type;
return !shape.empty()
? VectorType::get(shape.sizes, type, shape.scalableFlags)
: type;
}

// Broadcasts scalar value into vector (iff shape is non-scalar).
static Value broadcast(ImplicitLocOpBuilder &builder, Value value,
ArrayRef<int64_t> shape) {
VectorShape shape) {
assert(!isa<VectorType>(value.getType()) && "must be scalar value");
auto type = broadcast(value.getType(), shape);
return !shape.empty() ? builder.create<BroadcastOp>(type, value) : value;
Expand Down Expand Up @@ -215,7 +227,7 @@ static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound,
static std::pair<Value, Value> frexp(ImplicitLocOpBuilder &builder, Value arg,
bool isPositive = false) {
assert(getElementTypeOrSelf(arg).isF32() && "arg must be f32 type");
ArrayRef<int64_t> shape = vectorShape(arg);
VectorShape shape = vectorShape(arg);

auto bcast = [&](Value value) -> Value {
return broadcast(builder, value, shape);
Expand Down Expand Up @@ -255,7 +267,7 @@ static std::pair<Value, Value> frexp(ImplicitLocOpBuilder &builder, Value arg,
// Computes exp2 for an i32 argument.
static Value exp2I32(ImplicitLocOpBuilder &builder, Value arg) {
assert(getElementTypeOrSelf(arg).isInteger(32) && "arg must be i32 type");
ArrayRef<int64_t> shape = vectorShape(arg);
VectorShape shape = vectorShape(arg);

auto bcast = [&](Value value) -> Value {
return broadcast(builder, value, shape);
Expand All @@ -281,7 +293,7 @@ Value makePolynomialCalculation(ImplicitLocOpBuilder &builder,
Type elementType = getElementTypeOrSelf(x);
assert((elementType.isF32() || elementType.isF16()) &&
"x must be f32 or f16 type");
ArrayRef<int64_t> shape = vectorShape(x);
VectorShape shape = vectorShape(x);

if (coeffs.empty())
return broadcast(builder, floatCst(builder, 0.0f, elementType), shape);
Expand Down Expand Up @@ -379,7 +391,7 @@ AtanApproximation::matchAndRewrite(math::AtanOp op,
if (!getElementTypeOrSelf(operand).isF32())
return rewriter.notifyMatchFailure(op, "unsupported operand type");

ArrayRef<int64_t> shape = vectorShape(op.getOperand());
VectorShape shape = vectorShape(op.getOperand());

ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
Value abs = builder.create<math::AbsFOp>(operand);
Expand Down Expand Up @@ -478,7 +490,7 @@ Atan2Approximation::matchAndRewrite(math::Atan2Op op,
return rewriter.notifyMatchFailure(op, "unsupported operand type");

ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
ArrayRef<int64_t> shape = vectorShape(op.getResult());
VectorShape shape = vectorShape(op.getResult());

// Compute atan in the valid range.
auto div = builder.create<arith::DivFOp>(y, x);
Expand Down Expand Up @@ -544,7 +556,7 @@ TanhApproximation::matchAndRewrite(math::TanhOp op,
if (!getElementTypeOrSelf(op.getOperand()).isF32())
return rewriter.notifyMatchFailure(op, "unsupported operand type");

ArrayRef<int64_t> shape = vectorShape(op.getOperand());
VectorShape shape = vectorShape(op.getOperand());

ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
auto bcast = [&](Value value) -> Value {
Expand Down Expand Up @@ -632,7 +644,7 @@ LogApproximationBase<Op>::logMatchAndRewrite(Op op, PatternRewriter &rewriter,
if (!getElementTypeOrSelf(op.getOperand()).isF32())
return rewriter.notifyMatchFailure(op, "unsupported operand type");

ArrayRef<int64_t> shape = vectorShape(op.getOperand());
VectorShape shape = vectorShape(op.getOperand());

ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
auto bcast = [&](Value value) -> Value {
Expand Down Expand Up @@ -779,7 +791,7 @@ Log1pApproximation::matchAndRewrite(math::Log1pOp op,
if (!getElementTypeOrSelf(op.getOperand()).isF32())
return rewriter.notifyMatchFailure(op, "unsupported operand type");

ArrayRef<int64_t> shape = vectorShape(op.getOperand());
VectorShape shape = vectorShape(op.getOperand());

ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
auto bcast = [&](Value value) -> Value {
Expand Down Expand Up @@ -829,7 +841,7 @@ ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op,
if (!(elementType.isF32() || elementType.isF16()))
return rewriter.notifyMatchFailure(op,
"only f32 and f16 type is supported.");
ArrayRef<int64_t> shape = vectorShape(operand);
VectorShape shape = vectorShape(operand);

ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
auto bcast = [&](Value value) -> Value {
Expand Down Expand Up @@ -938,9 +950,8 @@ ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op,

namespace {

Value clampWithNormals(ImplicitLocOpBuilder &builder,
const llvm::ArrayRef<int64_t> shape, Value value,
float lowerBound, float upperBound) {
Value clampWithNormals(ImplicitLocOpBuilder &builder, const VectorShape shape,
Value value, float lowerBound, float upperBound) {
assert(!std::isnan(lowerBound));
assert(!std::isnan(upperBound));

Expand Down Expand Up @@ -1131,7 +1142,7 @@ ExpM1Approximation::matchAndRewrite(math::ExpM1Op op,
if (!getElementTypeOrSelf(op.getOperand()).isF32())
return rewriter.notifyMatchFailure(op, "unsupported operand type");

ArrayRef<int64_t> shape = vectorShape(op.getOperand());
VectorShape shape = vectorShape(op.getOperand());

ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
auto bcast = [&](Value value) -> Value {
Expand Down Expand Up @@ -1201,7 +1212,7 @@ LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite(
if (!getElementTypeOrSelf(op.getOperand()).isF32())
return rewriter.notifyMatchFailure(op, "unsupported operand type");

ArrayRef<int64_t> shape = vectorShape(op.getOperand());
VectorShape shape = vectorShape(op.getOperand());

ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
auto bcast = [&](Value value) -> Value {
Expand Down Expand Up @@ -1328,7 +1339,7 @@ CbrtApproximation::matchAndRewrite(math::CbrtOp op,
return rewriter.notifyMatchFailure(op, "unsupported operand type");

ImplicitLocOpBuilder b(op->getLoc(), rewriter);
ArrayRef<int64_t> shape = vectorShape(operand);
VectorShape shape = vectorShape(operand);

Type floatTy = getElementTypeOrSelf(operand.getType());
Type intTy = b.getIntegerType(floatTy.getIntOrFloatBitWidth());
Expand Down Expand Up @@ -1417,10 +1428,10 @@ RsqrtApproximation::matchAndRewrite(math::RsqrtOp op,
if (!getElementTypeOrSelf(op.getOperand()).isF32())
return rewriter.notifyMatchFailure(op, "unsupported operand type");

ArrayRef<int64_t> shape = vectorShape(op.getOperand());
VectorShape shape = vectorShape(op.getOperand());

// Only support already-vectorized rsqrt's.
if (shape.empty() || shape.back() % 8 != 0)
if (shape.empty() || shape.sizes.back() % 8 != 0)
return rewriter.notifyMatchFailure(op, "unsupported operand type");

ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
Expand Down
89 changes: 89 additions & 0 deletions mlir/test/Dialect/Math/polynomial-approximation.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,20 @@ func.func @erf_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
return %0 : vector<8xf32>
}

// CHECK-LABEL: func @erf_scalable_vector(
// CHECK-SAME: %[[arg0:.*]]: vector<[8]xf32>) -> vector<[8]xf32> {
// CHECK: %[[zero:.*]] = arith.constant dense<0.000000e+00> : vector<[8]xf32>
// CHECK-NOT: erf
// CHECK-NOT: vector<8xf32>
// CHECK-COUNT-20: select {{.*}} : vector<[8]xi1>, vector<[8]xf32>
// CHECK: %[[res:.*]] = arith.select {{.*}} : vector<[8]xi1>, vector<[8]xf32>
// CHECK: return %[[res]] : vector<[8]xf32>
// CHECK: }
func.func @erf_scalable_vector(%arg0: vector<[8]xf32>) -> vector<[8]xf32> {
%0 = math.erf %arg0 : vector<[8]xf32>
return %0 : vector<[8]xf32>
}

// CHECK-LABEL: func @exp_scalar(
// CHECK-SAME: %[[VAL_0:.*]]: f32) -> f32 {
// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 5.000000e-01 : f32
Expand Down Expand Up @@ -151,6 +165,17 @@ func.func @exp_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
return %0 : vector<8xf32>
}

// CHECK-LABEL: func @exp_scalable_vector
// CHECK-NOT: math.exp
// CHECK-NOT: vector<8xf32>
// CHECK-COUNT-46: vector<[8]x{{(i32)|(f32)}}>
// CHECK-NOT: vector<8xf32>
// CHECK-NOT: math.exp
func.func @exp_scalable_vector(%arg0: vector<[8]xf32>) -> vector<[8]xf32> {
%0 = math.exp %arg0 : vector<[8]xf32>
return %0 : vector<[8]xf32>
}

// CHECK-LABEL: func @expm1_scalar(
// CHECK-SAME: %[[X:.*]]: f32) -> f32 {
// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 1.000000e+00 : f32
Expand Down Expand Up @@ -277,6 +302,22 @@ func.func @expm1_vector(%arg0: vector<8x8xf32>) -> vector<8x8xf32> {
return %0 : vector<8x8xf32>
}

// CHECK-LABEL: func @expm1_scalable_vector(
// CHECK-SAME: %{{.*}}: vector<8x[8]xf32>) -> vector<8x[8]xf32> {
// CHECK-NOT: vector<8x8xf32>
// CHECK-NOT: exp
// CHECK-NOT: log
// CHECK-NOT: expm1
// CHECK-COUNT-127: vector<8x[8]x{{(i32)|(f32)|(i1)}}>
// CHECK-NOT: vector<8x8xf32>
// CHECK-NOT: exp
// CHECK-NOT: log
// CHECK-NOT: expm1
func.func @expm1_scalable_vector(%arg0: vector<8x[8]xf32>) -> vector<8x[8]xf32> {
%0 = math.expm1 %arg0 : vector<8x[8]xf32>
return %0 : vector<8x[8]xf32>
}

// CHECK-LABEL: func @log_scalar(
// CHECK-SAME: %[[X:.*]]: f32) -> f32 {
// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 0.000000e+00 : f32
Expand Down Expand Up @@ -357,6 +398,18 @@ func.func @log_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
return %0 : vector<8xf32>
}

// CHECK-LABEL: func @log_scalable_vector(
// CHECK-SAME: %{{.*}}: vector<[8]xf32>) -> vector<[8]xf32> {
// CHECK: %[[CST_LN2:.*]] = arith.constant dense<0.693147182> : vector<[8]xf32>
// CHECK-COUNT-5: select {{.*}} : vector<[8]xi1>, vector<[8]xf32>
// CHECK: %[[VAL_71:.*]] = arith.select {{.*}} : vector<[8]xi1>, vector<[8]xf32>
// CHECK: return %[[VAL_71]] : vector<[8]xf32>
// CHECK: }
func.func @log_scalable_vector(%arg0: vector<[8]xf32>) -> vector<[8]xf32> {
%0 = math.log %arg0 : vector<[8]xf32>
return %0 : vector<[8]xf32>
}

// CHECK-LABEL: func @log2_scalar(
// CHECK-SAME: %[[VAL_0:.*]]: f32) -> f32 {
// CHECK: %[[CST_LOG2E:.*]] = arith.constant 1.44269502 : f32
Expand All @@ -381,6 +434,18 @@ func.func @log2_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
return %0 : vector<8xf32>
}

// CHECK-LABEL: func @log2_scalable_vector(
// CHECK-SAME: %{{.*}}: vector<[8]xf32>) -> vector<[8]xf32> {
// CHECK: %[[CST_LOG2E:.*]] = arith.constant dense<1.44269502> : vector<[8]xf32>
// CHECK-COUNT-5: select {{.*}} : vector<[8]xi1>, vector<[8]xf32>
// CHECK: %[[VAL_71:.*]] = arith.select {{.*}} : vector<[8]xi1>, vector<[8]xf32>
// CHECK: return %[[VAL_71]] : vector<[8]xf32>
// CHECK: }
func.func @log2_scalable_vector(%arg0: vector<[8]xf32>) -> vector<[8]xf32> {
%0 = math.log2 %arg0 : vector<[8]xf32>
return %0 : vector<[8]xf32>
}

// CHECK-LABEL: func @log1p_scalar(
// CHECK-SAME: %[[X:.*]]: f32) -> f32 {
// CHECK: %[[CST_ONE:.*]] = arith.constant 1.000000e+00 : f32
Expand Down Expand Up @@ -414,6 +479,17 @@ func.func @log1p_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
return %0 : vector<8xf32>
}

// CHECK-LABEL: func @log1p_scalable_vector(
// CHECK-SAME: %[[VAL_0:.*]]: vector<[8]xf32>) -> vector<[8]xf32> {
// CHECK: %[[CST_ONE:.*]] = arith.constant dense<1.000000e+00> : vector<[8]xf32>
// CHECK-COUNT-6: select {{.*}} : vector<[8]xi1>, vector<[8]xf32>
// CHECK: %[[VAL_79:.*]] = arith.select {{.*}} : vector<[8]xi1>, vector<[8]xf32>
// CHECK: return %[[VAL_79]] : vector<[8]xf32>
// CHECK: }
func.func @log1p_scalable_vector(%arg0: vector<[8]xf32>) -> vector<[8]xf32> {
%0 = math.log1p %arg0 : vector<[8]xf32>
return %0 : vector<[8]xf32>
}

// CHECK-LABEL: func @tanh_scalar(
// CHECK-SAME: %[[VAL_0:.*]]: f32) -> f32 {
Expand Down Expand Up @@ -470,6 +546,19 @@ func.func @tanh_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
return %0 : vector<8xf32>
}

// CHECK-LABEL: func @tanh_scalable_vector(
// CHECK-SAME: %[[VAL_0:.*]]: vector<[8]xf32>) -> vector<[8]xf32> {
// CHECK: %[[VAL_1:.*]] = arith.constant dense<-7.99881172> : vector<[8]xf32>
// CHECK-NOT: tanh
// CHECK-COUNT-2: select {{.*}} : vector<[8]xi1>, vector<[8]xf32>
// CHECK: %[[VAL_33:.*]] = arith.select {{.*}} : vector<[8]xi1>, vector<[8]xf32>
// CHECK: return %[[VAL_33]] : vector<[8]xf32>
// CHECK: }
func.func @tanh_scalable_vector(%arg0: vector<[8]xf32>) -> vector<[8]xf32> {
%0 = math.tanh %arg0 : vector<[8]xf32>
return %0 : vector<[8]xf32>
}

// We only approximate rsqrt for vectors and when the AVX2 option is enabled.
// CHECK-LABEL: func @rsqrt_scalar
// AVX2-LABEL: func @rsqrt_scalar
Expand Down