Skip to content

Commit 72f3621

Browse files
authored
[mlir][Math] Fix 0-rank support for PolynomialApproximation (#114826)
This patch disambiguates 0-rank vectors and scalars in PolynomialApproximation. This fixes a bug in PolynomialApproximation where 0-rank vectors would be treated as scalars and arguments would not be broadcasted properly.
1 parent bf01bb8 commit 72f3621

File tree

2 files changed

+72
-34
lines changed

2 files changed

+72
-34
lines changed

mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp

Lines changed: 31 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -43,20 +43,17 @@ using namespace mlir::vector;
4343
struct VectorShape {
4444
ArrayRef<int64_t> sizes;
4545
ArrayRef<bool> scalableFlags;
46-
47-
bool empty() const { return sizes.empty(); }
4846
};
4947

50-
// Returns vector shape if the type is a vector. Returns an empty shape if it is
51-
// not a vector.
52-
static VectorShape vectorShape(Type type) {
53-
auto vectorType = dyn_cast<VectorType>(type);
54-
return vectorType
55-
? VectorShape{vectorType.getShape(), vectorType.getScalableDims()}
56-
: VectorShape{};
48+
// Returns vector shape if the type is a vector, otherwise return nullopt.
49+
static std::optional<VectorShape> vectorShape(Type type) {
50+
if (auto vectorType = dyn_cast<VectorType>(type)) {
51+
return VectorShape{vectorType.getShape(), vectorType.getScalableDims()};
52+
}
53+
return std::nullopt;
5754
}
5855

59-
static VectorShape vectorShape(Value value) {
56+
static std::optional<VectorShape> vectorShape(Value value) {
6057
return vectorShape(value.getType());
6158
}
6259

@@ -65,19 +62,18 @@ static VectorShape vectorShape(Value value) {
6562
//----------------------------------------------------------------------------//
6663

6764
// Broadcasts scalar type into vector type (iff shape is non-scalar).
68-
static Type broadcast(Type type, VectorShape shape) {
65+
static Type broadcast(Type type, std::optional<VectorShape> shape) {
6966
assert(!isa<VectorType>(type) && "must be scalar type");
70-
return !shape.empty()
71-
? VectorType::get(shape.sizes, type, shape.scalableFlags)
72-
: type;
67+
return shape ? VectorType::get(shape->sizes, type, shape->scalableFlags)
68+
: type;
7369
}
7470

7571
// Broadcasts scalar value into vector (iff shape is non-scalar).
7672
static Value broadcast(ImplicitLocOpBuilder &builder, Value value,
77-
VectorShape shape) {
73+
std::optional<VectorShape> shape) {
7874
assert(!isa<VectorType>(value.getType()) && "must be scalar value");
7975
auto type = broadcast(value.getType(), shape);
80-
return !shape.empty() ? builder.create<BroadcastOp>(type, value) : value;
76+
return shape ? builder.create<BroadcastOp>(type, value) : value;
8177
}
8278

8379
//----------------------------------------------------------------------------//
@@ -227,7 +223,7 @@ static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound,
227223
static std::pair<Value, Value> frexp(ImplicitLocOpBuilder &builder, Value arg,
228224
bool isPositive = false) {
229225
assert(getElementTypeOrSelf(arg).isF32() && "arg must be f32 type");
230-
VectorShape shape = vectorShape(arg);
226+
std::optional<VectorShape> shape = vectorShape(arg);
231227

232228
auto bcast = [&](Value value) -> Value {
233229
return broadcast(builder, value, shape);
@@ -267,7 +263,7 @@ static std::pair<Value, Value> frexp(ImplicitLocOpBuilder &builder, Value arg,
267263
// Computes exp2 for an i32 argument.
268264
static Value exp2I32(ImplicitLocOpBuilder &builder, Value arg) {
269265
assert(getElementTypeOrSelf(arg).isInteger(32) && "arg must be i32 type");
270-
VectorShape shape = vectorShape(arg);
266+
std::optional<VectorShape> shape = vectorShape(arg);
271267

272268
auto bcast = [&](Value value) -> Value {
273269
return broadcast(builder, value, shape);
@@ -293,7 +289,7 @@ Value makePolynomialCalculation(ImplicitLocOpBuilder &builder,
293289
Type elementType = getElementTypeOrSelf(x);
294290
assert((elementType.isF32() || elementType.isF16()) &&
295291
"x must be f32 or f16 type");
296-
VectorShape shape = vectorShape(x);
292+
std::optional<VectorShape> shape = vectorShape(x);
297293

298294
if (coeffs.empty())
299295
return broadcast(builder, floatCst(builder, 0.0f, elementType), shape);
@@ -391,7 +387,7 @@ AtanApproximation::matchAndRewrite(math::AtanOp op,
391387
if (!getElementTypeOrSelf(operand).isF32())
392388
return rewriter.notifyMatchFailure(op, "unsupported operand type");
393389

394-
VectorShape shape = vectorShape(op.getOperand());
390+
std::optional<VectorShape> shape = vectorShape(op.getOperand());
395391

396392
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
397393
Value abs = builder.create<math::AbsFOp>(operand);
@@ -490,7 +486,7 @@ Atan2Approximation::matchAndRewrite(math::Atan2Op op,
490486
return rewriter.notifyMatchFailure(op, "unsupported operand type");
491487

492488
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
493-
VectorShape shape = vectorShape(op.getResult());
489+
std::optional<VectorShape> shape = vectorShape(op.getResult());
494490

495491
// Compute atan in the valid range.
496492
auto div = builder.create<arith::DivFOp>(y, x);
@@ -556,7 +552,7 @@ TanhApproximation::matchAndRewrite(math::TanhOp op,
556552
if (!getElementTypeOrSelf(op.getOperand()).isF32())
557553
return rewriter.notifyMatchFailure(op, "unsupported operand type");
558554

559-
VectorShape shape = vectorShape(op.getOperand());
555+
std::optional<VectorShape> shape = vectorShape(op.getOperand());
560556

561557
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
562558
auto bcast = [&](Value value) -> Value {
@@ -644,7 +640,7 @@ LogApproximationBase<Op>::logMatchAndRewrite(Op op, PatternRewriter &rewriter,
644640
if (!getElementTypeOrSelf(op.getOperand()).isF32())
645641
return rewriter.notifyMatchFailure(op, "unsupported operand type");
646642

647-
VectorShape shape = vectorShape(op.getOperand());
643+
std::optional<VectorShape> shape = vectorShape(op.getOperand());
648644

649645
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
650646
auto bcast = [&](Value value) -> Value {
@@ -791,7 +787,7 @@ Log1pApproximation::matchAndRewrite(math::Log1pOp op,
791787
if (!getElementTypeOrSelf(op.getOperand()).isF32())
792788
return rewriter.notifyMatchFailure(op, "unsupported operand type");
793789

794-
VectorShape shape = vectorShape(op.getOperand());
790+
std::optional<VectorShape> shape = vectorShape(op.getOperand());
795791

796792
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
797793
auto bcast = [&](Value value) -> Value {
@@ -846,7 +842,7 @@ AsinPolynomialApproximation::matchAndRewrite(math::AsinOp op,
846842
if (!(elementType.isF32() || elementType.isF16()))
847843
return rewriter.notifyMatchFailure(op,
848844
"only f32 and f16 type is supported.");
849-
VectorShape shape = vectorShape(operand);
845+
std::optional<VectorShape> shape = vectorShape(operand);
850846

851847
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
852848
auto bcast = [&](Value value) -> Value {
@@ -941,7 +937,7 @@ AcosPolynomialApproximation::matchAndRewrite(math::AcosOp op,
941937
if (!(elementType.isF32() || elementType.isF16()))
942938
return rewriter.notifyMatchFailure(op,
943939
"only f32 and f16 type is supported.");
944-
VectorShape shape = vectorShape(operand);
940+
std::optional<VectorShape> shape = vectorShape(operand);
945941

946942
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
947943
auto bcast = [&](Value value) -> Value {
@@ -1019,7 +1015,7 @@ ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op,
10191015
if (!(elementType.isF32() || elementType.isF16()))
10201016
return rewriter.notifyMatchFailure(op,
10211017
"only f32 and f16 type is supported.");
1022-
VectorShape shape = vectorShape(operand);
1018+
std::optional<VectorShape> shape = vectorShape(operand);
10231019

10241020
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
10251021
auto bcast = [&](Value value) -> Value {
@@ -1128,8 +1124,9 @@ ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op,
11281124

11291125
namespace {
11301126

1131-
Value clampWithNormals(ImplicitLocOpBuilder &builder, const VectorShape shape,
1132-
Value value, float lowerBound, float upperBound) {
1127+
Value clampWithNormals(ImplicitLocOpBuilder &builder,
1128+
const std::optional<VectorShape> shape, Value value,
1129+
float lowerBound, float upperBound) {
11331130
assert(!std::isnan(lowerBound));
11341131
assert(!std::isnan(upperBound));
11351132

@@ -1320,7 +1317,7 @@ ExpM1Approximation::matchAndRewrite(math::ExpM1Op op,
13201317
if (!getElementTypeOrSelf(op.getOperand()).isF32())
13211318
return rewriter.notifyMatchFailure(op, "unsupported operand type");
13221319

1323-
VectorShape shape = vectorShape(op.getOperand());
1320+
std::optional<VectorShape> shape = vectorShape(op.getOperand());
13241321

13251322
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
13261323
auto bcast = [&](Value value) -> Value {
@@ -1390,7 +1387,7 @@ LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite(
13901387
if (!getElementTypeOrSelf(op.getOperand()).isF32())
13911388
return rewriter.notifyMatchFailure(op, "unsupported operand type");
13921389

1393-
VectorShape shape = vectorShape(op.getOperand());
1390+
std::optional<VectorShape> shape = vectorShape(op.getOperand());
13941391

13951392
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
13961393
auto bcast = [&](Value value) -> Value {
@@ -1517,7 +1514,7 @@ CbrtApproximation::matchAndRewrite(math::CbrtOp op,
15171514
return rewriter.notifyMatchFailure(op, "unsupported operand type");
15181515

15191516
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1520-
VectorShape shape = vectorShape(operand);
1517+
std::optional<VectorShape> shape = vectorShape(operand);
15211518

15221519
Type floatTy = getElementTypeOrSelf(operand.getType());
15231520
Type intTy = b.getIntegerType(floatTy.getIntOrFloatBitWidth());
@@ -1606,10 +1603,10 @@ RsqrtApproximation::matchAndRewrite(math::RsqrtOp op,
16061603
if (!getElementTypeOrSelf(op.getOperand()).isF32())
16071604
return rewriter.notifyMatchFailure(op, "unsupported operand type");
16081605

1609-
VectorShape shape = vectorShape(op.getOperand());
1606+
std::optional<VectorShape> shape = vectorShape(op.getOperand());
16101607

16111608
// Only support already-vectorized rsqrt's.
1612-
if (shape.empty() || shape.sizes.back() % 8 != 0)
1609+
if (!shape || shape->sizes.empty() || shape->sizes.back() % 8 != 0)
16131610
return rewriter.notifyMatchFailure(op, "unsupported operand type");
16141611

16151612
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);

mlir/test/Dialect/Math/polynomial-approximation.mlir

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -894,6 +894,47 @@ func.func @math_f16(%arg0 : vector<4xf16>) -> vector<4xf16> {
894894
return %11 : vector<4xf16>
895895
}
896896

897+
// CHECK-LABEL: @math_zero_rank
898+
func.func @math_zero_rank(%arg0 : vector<f16>) -> vector<f16> {
899+
900+
// CHECK-NOT: math.atan
901+
%0 = "math.atan"(%arg0) : (vector<f16>) -> vector<f16>
902+
903+
// CHECK-NOT: math.atan2
904+
%1 = "math.atan2"(%0, %arg0) : (vector<f16>, vector<f16>) -> vector<f16>
905+
906+
// CHECK-NOT: math.tanh
907+
%2 = "math.tanh"(%1) : (vector<f16>) -> vector<f16>
908+
909+
// CHECK-NOT: math.log
910+
%3 = "math.log"(%2) : (vector<f16>) -> vector<f16>
911+
912+
// CHECK-NOT: math.log2
913+
%4 = "math.log2"(%3) : (vector<f16>) -> vector<f16>
914+
915+
// CHECK-NOT: math.log1p
916+
%5 = "math.log1p"(%4) : (vector<f16>) -> vector<f16>
917+
918+
// CHECK-NOT: math.erf
919+
%6 = "math.erf"(%5) : (vector<f16>) -> vector<f16>
920+
921+
// CHECK-NOT: math.exp
922+
%7 = "math.exp"(%6) : (vector<f16>) -> vector<f16>
923+
924+
// CHECK-NOT: math.expm1
925+
%8 = "math.expm1"(%7) : (vector<f16>) -> vector<f16>
926+
927+
// CHECK-NOT: math.cbrt
928+
%9 = "math.cbrt"(%8) : (vector<f16>) -> vector<f16>
929+
930+
// CHECK-NOT: math.sin
931+
%10 = "math.sin"(%9) : (vector<f16>) -> vector<f16>
932+
933+
// CHECK-NOT: math.cos
934+
%11 = "math.cos"(%10) : (vector<f16>) -> vector<f16>
935+
936+
return %11 : vector<f16>
937+
}
897938

898939
// AVX2-LABEL: @rsqrt_f16
899940
func.func @rsqrt_f16(%arg0 : vector<2x8xf16>) -> vector<2x8xf16> {

0 commit comments

Comments
 (0)