Skip to content

Commit e74bcec

Browse files
authored
[mlir][math] Propagate scalability in polynomial approximation (#84949)
This simply updates the rewrites to propagate the scalable flags (which as they do not alter the vector shape, is pretty simple). The added tests are simply scalable versions of the existing vector tests.
1 parent f75d164 commit e74bcec

File tree

2 files changed

+123
-23
lines changed

2 files changed

+123
-23
lines changed

mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp

Lines changed: 34 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -39,14 +39,24 @@ using namespace mlir;
3939
using namespace mlir::math;
4040
using namespace mlir::vector;
4141

42+
// Helper to encapsulate a vector's shape (including scalable dims).
43+
struct VectorShape {
44+
ArrayRef<int64_t> sizes;
45+
ArrayRef<bool> scalableFlags;
46+
47+
bool empty() const { return sizes.empty(); }
48+
};
49+
4250
// Returns vector shape if the type is a vector. Returns an empty shape if it is
4351
// not a vector.
44-
static ArrayRef<int64_t> vectorShape(Type type) {
52+
static VectorShape vectorShape(Type type) {
4553
auto vectorType = dyn_cast<VectorType>(type);
46-
return vectorType ? vectorType.getShape() : ArrayRef<int64_t>();
54+
return vectorType
55+
? VectorShape{vectorType.getShape(), vectorType.getScalableDims()}
56+
: VectorShape{};
4757
}
4858

49-
static ArrayRef<int64_t> vectorShape(Value value) {
59+
static VectorShape vectorShape(Value value) {
5060
return vectorShape(value.getType());
5161
}
5262

@@ -55,14 +65,16 @@ static ArrayRef<int64_t> vectorShape(Value value) {
5565
//----------------------------------------------------------------------------//
5666

5767
// Broadcasts scalar type into vector type (iff shape is non-scalar).
58-
static Type broadcast(Type type, ArrayRef<int64_t> shape) {
68+
static Type broadcast(Type type, VectorShape shape) {
5969
assert(!isa<VectorType>(type) && "must be scalar type");
60-
return !shape.empty() ? VectorType::get(shape, type) : type;
70+
return !shape.empty()
71+
? VectorType::get(shape.sizes, type, shape.scalableFlags)
72+
: type;
6173
}
6274

6375
// Broadcasts scalar value into vector (iff shape is non-scalar).
6476
static Value broadcast(ImplicitLocOpBuilder &builder, Value value,
65-
ArrayRef<int64_t> shape) {
77+
VectorShape shape) {
6678
assert(!isa<VectorType>(value.getType()) && "must be scalar value");
6779
auto type = broadcast(value.getType(), shape);
6880
return !shape.empty() ? builder.create<BroadcastOp>(type, value) : value;
@@ -215,7 +227,7 @@ static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound,
215227
static std::pair<Value, Value> frexp(ImplicitLocOpBuilder &builder, Value arg,
216228
bool isPositive = false) {
217229
assert(getElementTypeOrSelf(arg).isF32() && "arg must be f32 type");
218-
ArrayRef<int64_t> shape = vectorShape(arg);
230+
VectorShape shape = vectorShape(arg);
219231

220232
auto bcast = [&](Value value) -> Value {
221233
return broadcast(builder, value, shape);
@@ -255,7 +267,7 @@ static std::pair<Value, Value> frexp(ImplicitLocOpBuilder &builder, Value arg,
255267
// Computes exp2 for an i32 argument.
256268
static Value exp2I32(ImplicitLocOpBuilder &builder, Value arg) {
257269
assert(getElementTypeOrSelf(arg).isInteger(32) && "arg must be i32 type");
258-
ArrayRef<int64_t> shape = vectorShape(arg);
270+
VectorShape shape = vectorShape(arg);
259271

260272
auto bcast = [&](Value value) -> Value {
261273
return broadcast(builder, value, shape);
@@ -281,7 +293,7 @@ Value makePolynomialCalculation(ImplicitLocOpBuilder &builder,
281293
Type elementType = getElementTypeOrSelf(x);
282294
assert((elementType.isF32() || elementType.isF16()) &&
283295
"x must be f32 or f16 type");
284-
ArrayRef<int64_t> shape = vectorShape(x);
296+
VectorShape shape = vectorShape(x);
285297

286298
if (coeffs.empty())
287299
return broadcast(builder, floatCst(builder, 0.0f, elementType), shape);
@@ -379,7 +391,7 @@ AtanApproximation::matchAndRewrite(math::AtanOp op,
379391
if (!getElementTypeOrSelf(operand).isF32())
380392
return rewriter.notifyMatchFailure(op, "unsupported operand type");
381393

382-
ArrayRef<int64_t> shape = vectorShape(op.getOperand());
394+
VectorShape shape = vectorShape(op.getOperand());
383395

384396
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
385397
Value abs = builder.create<math::AbsFOp>(operand);
@@ -478,7 +490,7 @@ Atan2Approximation::matchAndRewrite(math::Atan2Op op,
478490
return rewriter.notifyMatchFailure(op, "unsupported operand type");
479491

480492
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
481-
ArrayRef<int64_t> shape = vectorShape(op.getResult());
493+
VectorShape shape = vectorShape(op.getResult());
482494

483495
// Compute atan in the valid range.
484496
auto div = builder.create<arith::DivFOp>(y, x);
@@ -544,7 +556,7 @@ TanhApproximation::matchAndRewrite(math::TanhOp op,
544556
if (!getElementTypeOrSelf(op.getOperand()).isF32())
545557
return rewriter.notifyMatchFailure(op, "unsupported operand type");
546558

547-
ArrayRef<int64_t> shape = vectorShape(op.getOperand());
559+
VectorShape shape = vectorShape(op.getOperand());
548560

549561
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
550562
auto bcast = [&](Value value) -> Value {
@@ -632,7 +644,7 @@ LogApproximationBase<Op>::logMatchAndRewrite(Op op, PatternRewriter &rewriter,
632644
if (!getElementTypeOrSelf(op.getOperand()).isF32())
633645
return rewriter.notifyMatchFailure(op, "unsupported operand type");
634646

635-
ArrayRef<int64_t> shape = vectorShape(op.getOperand());
647+
VectorShape shape = vectorShape(op.getOperand());
636648

637649
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
638650
auto bcast = [&](Value value) -> Value {
@@ -779,7 +791,7 @@ Log1pApproximation::matchAndRewrite(math::Log1pOp op,
779791
if (!getElementTypeOrSelf(op.getOperand()).isF32())
780792
return rewriter.notifyMatchFailure(op, "unsupported operand type");
781793

782-
ArrayRef<int64_t> shape = vectorShape(op.getOperand());
794+
VectorShape shape = vectorShape(op.getOperand());
783795

784796
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
785797
auto bcast = [&](Value value) -> Value {
@@ -829,7 +841,7 @@ ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op,
829841
if (!(elementType.isF32() || elementType.isF16()))
830842
return rewriter.notifyMatchFailure(op,
831843
"only f32 and f16 type is supported.");
832-
ArrayRef<int64_t> shape = vectorShape(operand);
844+
VectorShape shape = vectorShape(operand);
833845

834846
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
835847
auto bcast = [&](Value value) -> Value {
@@ -938,9 +950,8 @@ ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op,
938950

939951
namespace {
940952

941-
Value clampWithNormals(ImplicitLocOpBuilder &builder,
942-
const llvm::ArrayRef<int64_t> shape, Value value,
943-
float lowerBound, float upperBound) {
953+
Value clampWithNormals(ImplicitLocOpBuilder &builder, const VectorShape shape,
954+
Value value, float lowerBound, float upperBound) {
944955
assert(!std::isnan(lowerBound));
945956
assert(!std::isnan(upperBound));
946957

@@ -1131,7 +1142,7 @@ ExpM1Approximation::matchAndRewrite(math::ExpM1Op op,
11311142
if (!getElementTypeOrSelf(op.getOperand()).isF32())
11321143
return rewriter.notifyMatchFailure(op, "unsupported operand type");
11331144

1134-
ArrayRef<int64_t> shape = vectorShape(op.getOperand());
1145+
VectorShape shape = vectorShape(op.getOperand());
11351146

11361147
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
11371148
auto bcast = [&](Value value) -> Value {
@@ -1201,7 +1212,7 @@ LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite(
12011212
if (!getElementTypeOrSelf(op.getOperand()).isF32())
12021213
return rewriter.notifyMatchFailure(op, "unsupported operand type");
12031214

1204-
ArrayRef<int64_t> shape = vectorShape(op.getOperand());
1215+
VectorShape shape = vectorShape(op.getOperand());
12051216

12061217
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
12071218
auto bcast = [&](Value value) -> Value {
@@ -1328,7 +1339,7 @@ CbrtApproximation::matchAndRewrite(math::CbrtOp op,
13281339
return rewriter.notifyMatchFailure(op, "unsupported operand type");
13291340

13301341
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1331-
ArrayRef<int64_t> shape = vectorShape(operand);
1342+
VectorShape shape = vectorShape(operand);
13321343

13331344
Type floatTy = getElementTypeOrSelf(operand.getType());
13341345
Type intTy = b.getIntegerType(floatTy.getIntOrFloatBitWidth());
@@ -1417,10 +1428,10 @@ RsqrtApproximation::matchAndRewrite(math::RsqrtOp op,
14171428
if (!getElementTypeOrSelf(op.getOperand()).isF32())
14181429
return rewriter.notifyMatchFailure(op, "unsupported operand type");
14191430

1420-
ArrayRef<int64_t> shape = vectorShape(op.getOperand());
1431+
VectorShape shape = vectorShape(op.getOperand());
14211432

14221433
// Only support already-vectorized rsqrt's.
1423-
if (shape.empty() || shape.back() % 8 != 0)
1434+
if (shape.empty() || shape.sizes.back() % 8 != 0)
14241435
return rewriter.notifyMatchFailure(op, "unsupported operand type");
14251436

14261437
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);

mlir/test/Dialect/Math/polynomial-approximation.mlir

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,20 @@ func.func @erf_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
9494
return %0 : vector<8xf32>
9595
}
9696

97+
// CHECK-LABEL: func @erf_scalable_vector(
98+
// CHECK-SAME: %[[arg0:.*]]: vector<[8]xf32>) -> vector<[8]xf32> {
99+
// CHECK: %[[zero:.*]] = arith.constant dense<0.000000e+00> : vector<[8]xf32>
100+
// CHECK-NOT: erf
101+
// CHECK-NOT: vector<8xf32>
102+
// CHECK-COUNT-20: select {{.*}} : vector<[8]xi1>, vector<[8]xf32>
103+
// CHECK: %[[res:.*]] = arith.select {{.*}} : vector<[8]xi1>, vector<[8]xf32>
104+
// CHECK: return %[[res]] : vector<[8]xf32>
105+
// CHECK: }
106+
func.func @erf_scalable_vector(%arg0: vector<[8]xf32>) -> vector<[8]xf32> {
107+
%0 = math.erf %arg0 : vector<[8]xf32>
108+
return %0 : vector<[8]xf32>
109+
}
110+
97111
// CHECK-LABEL: func @exp_scalar(
98112
// CHECK-SAME: %[[VAL_0:.*]]: f32) -> f32 {
99113
// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 5.000000e-01 : f32
@@ -151,6 +165,17 @@ func.func @exp_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
151165
return %0 : vector<8xf32>
152166
}
153167

168+
// CHECK-LABEL: func @exp_scalable_vector
169+
// CHECK-NOT: math.exp
170+
// CHECK-NOT: vector<8xf32>
171+
// CHECK-COUNT-46: vector<[8]x{{(i32)|(f32)}}>
172+
// CHECK-NOT: vector<8xf32>
173+
// CHECK-NOT: math.exp
174+
func.func @exp_scalable_vector(%arg0: vector<[8]xf32>) -> vector<[8]xf32> {
175+
%0 = math.exp %arg0 : vector<[8]xf32>
176+
return %0 : vector<[8]xf32>
177+
}
178+
154179
// CHECK-LABEL: func @expm1_scalar(
155180
// CHECK-SAME: %[[X:.*]]: f32) -> f32 {
156181
// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 1.000000e+00 : f32
@@ -277,6 +302,22 @@ func.func @expm1_vector(%arg0: vector<8x8xf32>) -> vector<8x8xf32> {
277302
return %0 : vector<8x8xf32>
278303
}
279304

305+
// CHECK-LABEL: func @expm1_scalable_vector(
306+
// CHECK-SAME: %{{.*}}: vector<8x[8]xf32>) -> vector<8x[8]xf32> {
307+
// CHECK-NOT: vector<8x8xf32>
308+
// CHECK-NOT: exp
309+
// CHECK-NOT: log
310+
// CHECK-NOT: expm1
311+
// CHECK-COUNT-127: vector<8x[8]x{{(i32)|(f32)|(i1)}}>
312+
// CHECK-NOT: vector<8x8xf32>
313+
// CHECK-NOT: exp
314+
// CHECK-NOT: log
315+
// CHECK-NOT: expm1
316+
func.func @expm1_scalable_vector(%arg0: vector<8x[8]xf32>) -> vector<8x[8]xf32> {
317+
%0 = math.expm1 %arg0 : vector<8x[8]xf32>
318+
return %0 : vector<8x[8]xf32>
319+
}
320+
280321
// CHECK-LABEL: func @log_scalar(
281322
// CHECK-SAME: %[[X:.*]]: f32) -> f32 {
282323
// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 0.000000e+00 : f32
@@ -357,6 +398,18 @@ func.func @log_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
357398
return %0 : vector<8xf32>
358399
}
359400

401+
// CHECK-LABEL: func @log_scalable_vector(
402+
// CHECK-SAME: %{{.*}}: vector<[8]xf32>) -> vector<[8]xf32> {
403+
// CHECK: %[[CST_LN2:.*]] = arith.constant dense<0.693147182> : vector<[8]xf32>
404+
// CHECK-COUNT-5: select {{.*}} : vector<[8]xi1>, vector<[8]xf32>
405+
// CHECK: %[[VAL_71:.*]] = arith.select {{.*}} : vector<[8]xi1>, vector<[8]xf32>
406+
// CHECK: return %[[VAL_71]] : vector<[8]xf32>
407+
// CHECK: }
408+
func.func @log_scalable_vector(%arg0: vector<[8]xf32>) -> vector<[8]xf32> {
409+
%0 = math.log %arg0 : vector<[8]xf32>
410+
return %0 : vector<[8]xf32>
411+
}
412+
360413
// CHECK-LABEL: func @log2_scalar(
361414
// CHECK-SAME: %[[VAL_0:.*]]: f32) -> f32 {
362415
// CHECK: %[[CST_LOG2E:.*]] = arith.constant 1.44269502 : f32
@@ -381,6 +434,18 @@ func.func @log2_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
381434
return %0 : vector<8xf32>
382435
}
383436

437+
// CHECK-LABEL: func @log2_scalable_vector(
438+
// CHECK-SAME: %{{.*}}: vector<[8]xf32>) -> vector<[8]xf32> {
439+
// CHECK: %[[CST_LOG2E:.*]] = arith.constant dense<1.44269502> : vector<[8]xf32>
440+
// CHECK-COUNT-5: select {{.*}} : vector<[8]xi1>, vector<[8]xf32>
441+
// CHECK: %[[VAL_71:.*]] = arith.select {{.*}} : vector<[8]xi1>, vector<[8]xf32>
442+
// CHECK: return %[[VAL_71]] : vector<[8]xf32>
443+
// CHECK: }
444+
func.func @log2_scalable_vector(%arg0: vector<[8]xf32>) -> vector<[8]xf32> {
445+
%0 = math.log2 %arg0 : vector<[8]xf32>
446+
return %0 : vector<[8]xf32>
447+
}
448+
384449
// CHECK-LABEL: func @log1p_scalar(
385450
// CHECK-SAME: %[[X:.*]]: f32) -> f32 {
386451
// CHECK: %[[CST_ONE:.*]] = arith.constant 1.000000e+00 : f32
@@ -414,6 +479,17 @@ func.func @log1p_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
414479
return %0 : vector<8xf32>
415480
}
416481

482+
// CHECK-LABEL: func @log1p_scalable_vector(
483+
// CHECK-SAME: %[[VAL_0:.*]]: vector<[8]xf32>) -> vector<[8]xf32> {
484+
// CHECK: %[[CST_ONE:.*]] = arith.constant dense<1.000000e+00> : vector<[8]xf32>
485+
// CHECK-COUNT-6: select {{.*}} : vector<[8]xi1>, vector<[8]xf32>
486+
// CHECK: %[[VAL_79:.*]] = arith.select {{.*}} : vector<[8]xi1>, vector<[8]xf32>
487+
// CHECK: return %[[VAL_79]] : vector<[8]xf32>
488+
// CHECK: }
489+
func.func @log1p_scalable_vector(%arg0: vector<[8]xf32>) -> vector<[8]xf32> {
490+
%0 = math.log1p %arg0 : vector<[8]xf32>
491+
return %0 : vector<[8]xf32>
492+
}
417493

418494
// CHECK-LABEL: func @tanh_scalar(
419495
// CHECK-SAME: %[[VAL_0:.*]]: f32) -> f32 {
@@ -470,6 +546,19 @@ func.func @tanh_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
470546
return %0 : vector<8xf32>
471547
}
472548

549+
// CHECK-LABEL: func @tanh_scalable_vector(
550+
// CHECK-SAME: %[[VAL_0:.*]]: vector<[8]xf32>) -> vector<[8]xf32> {
551+
// CHECK: %[[VAL_1:.*]] = arith.constant dense<-7.99881172> : vector<[8]xf32>
552+
// CHECK-NOT: tanh
553+
// CHECK-COUNT-2: select {{.*}} : vector<[8]xi1>, vector<[8]xf32>
554+
// CHECK: %[[VAL_33:.*]] = arith.select {{.*}} : vector<[8]xi1>, vector<[8]xf32>
555+
// CHECK: return %[[VAL_33]] : vector<[8]xf32>
556+
// CHECK: }
557+
func.func @tanh_scalable_vector(%arg0: vector<[8]xf32>) -> vector<[8]xf32> {
558+
%0 = math.tanh %arg0 : vector<[8]xf32>
559+
return %0 : vector<[8]xf32>
560+
}
561+
473562
// We only approximate rsqrt for vectors and when the AVX2 option is enabled.
474563
// CHECK-LABEL: func @rsqrt_scalar
475564
// AVX2-LABEL: func @rsqrt_scalar

0 commit comments

Comments
 (0)