Skip to content

Commit c3b5c2c

Browse files
author
git apple-llvm automerger
committed
Merge commit 'e74bcecd36a5' from llvm.org/main into next
2 parents e11eb9e + e74bcec commit c3b5c2c

File tree

2 files changed

+123
-23
lines changed

2 files changed

+123
-23
lines changed

mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp

Lines changed: 34 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -39,14 +39,24 @@ using namespace mlir;
3939
using namespace mlir::math;
4040
using namespace mlir::vector;
4141

42+
// Helper to encapsulate a vector's shape (including scalable dims).
43+
struct VectorShape {
44+
ArrayRef<int64_t> sizes;
45+
ArrayRef<bool> scalableFlags;
46+
47+
bool empty() const { return sizes.empty(); }
48+
};
49+
4250
// Returns vector shape if the type is a vector. Returns an empty shape if it is
4351
// not a vector.
44-
static ArrayRef<int64_t> vectorShape(Type type) {
52+
static VectorShape vectorShape(Type type) {
4553
auto vectorType = dyn_cast<VectorType>(type);
46-
return vectorType ? vectorType.getShape() : ArrayRef<int64_t>();
54+
return vectorType
55+
? VectorShape{vectorType.getShape(), vectorType.getScalableDims()}
56+
: VectorShape{};
4757
}
4858

49-
static ArrayRef<int64_t> vectorShape(Value value) {
59+
static VectorShape vectorShape(Value value) {
5060
return vectorShape(value.getType());
5161
}
5262

@@ -55,14 +65,16 @@ static ArrayRef<int64_t> vectorShape(Value value) {
5565
//----------------------------------------------------------------------------//
5666

5767
// Broadcasts scalar type into vector type (iff shape is non-scalar).
58-
static Type broadcast(Type type, ArrayRef<int64_t> shape) {
68+
static Type broadcast(Type type, VectorShape shape) {
5969
assert(!isa<VectorType>(type) && "must be scalar type");
60-
return !shape.empty() ? VectorType::get(shape, type) : type;
70+
return !shape.empty()
71+
? VectorType::get(shape.sizes, type, shape.scalableFlags)
72+
: type;
6173
}
6274

6375
// Broadcasts scalar value into vector (iff shape is non-scalar).
6476
static Value broadcast(ImplicitLocOpBuilder &builder, Value value,
65-
ArrayRef<int64_t> shape) {
77+
VectorShape shape) {
6678
assert(!isa<VectorType>(value.getType()) && "must be scalar value");
6779
auto type = broadcast(value.getType(), shape);
6880
return !shape.empty() ? builder.create<BroadcastOp>(type, value) : value;
@@ -215,7 +227,7 @@ static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound,
215227
static std::pair<Value, Value> frexp(ImplicitLocOpBuilder &builder, Value arg,
216228
bool isPositive = false) {
217229
assert(getElementTypeOrSelf(arg).isF32() && "arg must be f32 type");
218-
ArrayRef<int64_t> shape = vectorShape(arg);
230+
VectorShape shape = vectorShape(arg);
219231

220232
auto bcast = [&](Value value) -> Value {
221233
return broadcast(builder, value, shape);
@@ -255,7 +267,7 @@ static std::pair<Value, Value> frexp(ImplicitLocOpBuilder &builder, Value arg,
255267
// Computes exp2 for an i32 argument.
256268
static Value exp2I32(ImplicitLocOpBuilder &builder, Value arg) {
257269
assert(getElementTypeOrSelf(arg).isInteger(32) && "arg must be i32 type");
258-
ArrayRef<int64_t> shape = vectorShape(arg);
270+
VectorShape shape = vectorShape(arg);
259271

260272
auto bcast = [&](Value value) -> Value {
261273
return broadcast(builder, value, shape);
@@ -281,7 +293,7 @@ Value makePolynomialCalculation(ImplicitLocOpBuilder &builder,
281293
Type elementType = getElementTypeOrSelf(x);
282294
assert((elementType.isF32() || elementType.isF16()) &&
283295
"x must be f32 or f16 type");
284-
ArrayRef<int64_t> shape = vectorShape(x);
296+
VectorShape shape = vectorShape(x);
285297

286298
if (coeffs.empty())
287299
return broadcast(builder, floatCst(builder, 0.0f, elementType), shape);
@@ -379,7 +391,7 @@ AtanApproximation::matchAndRewrite(math::AtanOp op,
379391
if (!getElementTypeOrSelf(operand).isF32())
380392
return rewriter.notifyMatchFailure(op, "unsupported operand type");
381393

382-
ArrayRef<int64_t> shape = vectorShape(op.getOperand());
394+
VectorShape shape = vectorShape(op.getOperand());
383395

384396
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
385397
Value abs = builder.create<math::AbsFOp>(operand);
@@ -478,7 +490,7 @@ Atan2Approximation::matchAndRewrite(math::Atan2Op op,
478490
return rewriter.notifyMatchFailure(op, "unsupported operand type");
479491

480492
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
481-
ArrayRef<int64_t> shape = vectorShape(op.getResult());
493+
VectorShape shape = vectorShape(op.getResult());
482494

483495
// Compute atan in the valid range.
484496
auto div = builder.create<arith::DivFOp>(y, x);
@@ -544,7 +556,7 @@ TanhApproximation::matchAndRewrite(math::TanhOp op,
544556
if (!getElementTypeOrSelf(op.getOperand()).isF32())
545557
return rewriter.notifyMatchFailure(op, "unsupported operand type");
546558

547-
ArrayRef<int64_t> shape = vectorShape(op.getOperand());
559+
VectorShape shape = vectorShape(op.getOperand());
548560

549561
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
550562
auto bcast = [&](Value value) -> Value {
@@ -632,7 +644,7 @@ LogApproximationBase<Op>::logMatchAndRewrite(Op op, PatternRewriter &rewriter,
632644
if (!getElementTypeOrSelf(op.getOperand()).isF32())
633645
return rewriter.notifyMatchFailure(op, "unsupported operand type");
634646

635-
ArrayRef<int64_t> shape = vectorShape(op.getOperand());
647+
VectorShape shape = vectorShape(op.getOperand());
636648

637649
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
638650
auto bcast = [&](Value value) -> Value {
@@ -779,7 +791,7 @@ Log1pApproximation::matchAndRewrite(math::Log1pOp op,
779791
if (!getElementTypeOrSelf(op.getOperand()).isF32())
780792
return rewriter.notifyMatchFailure(op, "unsupported operand type");
781793

782-
ArrayRef<int64_t> shape = vectorShape(op.getOperand());
794+
VectorShape shape = vectorShape(op.getOperand());
783795

784796
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
785797
auto bcast = [&](Value value) -> Value {
@@ -829,7 +841,7 @@ ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op,
829841
if (!(elementType.isF32() || elementType.isF16()))
830842
return rewriter.notifyMatchFailure(op,
831843
"only f32 and f16 type is supported.");
832-
ArrayRef<int64_t> shape = vectorShape(operand);
844+
VectorShape shape = vectorShape(operand);
833845

834846
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
835847
auto bcast = [&](Value value) -> Value {
@@ -938,9 +950,8 @@ ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op,
938950

939951
namespace {
940952

941-
Value clampWithNormals(ImplicitLocOpBuilder &builder,
942-
const llvm::ArrayRef<int64_t> shape, Value value,
943-
float lowerBound, float upperBound) {
953+
Value clampWithNormals(ImplicitLocOpBuilder &builder, const VectorShape shape,
954+
Value value, float lowerBound, float upperBound) {
944955
assert(!std::isnan(lowerBound));
945956
assert(!std::isnan(upperBound));
946957

@@ -1131,7 +1142,7 @@ ExpM1Approximation::matchAndRewrite(math::ExpM1Op op,
11311142
if (!getElementTypeOrSelf(op.getOperand()).isF32())
11321143
return rewriter.notifyMatchFailure(op, "unsupported operand type");
11331144

1134-
ArrayRef<int64_t> shape = vectorShape(op.getOperand());
1145+
VectorShape shape = vectorShape(op.getOperand());
11351146

11361147
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
11371148
auto bcast = [&](Value value) -> Value {
@@ -1201,7 +1212,7 @@ LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite(
12011212
if (!getElementTypeOrSelf(op.getOperand()).isF32())
12021213
return rewriter.notifyMatchFailure(op, "unsupported operand type");
12031214

1204-
ArrayRef<int64_t> shape = vectorShape(op.getOperand());
1215+
VectorShape shape = vectorShape(op.getOperand());
12051216

12061217
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
12071218
auto bcast = [&](Value value) -> Value {
@@ -1328,7 +1339,7 @@ CbrtApproximation::matchAndRewrite(math::CbrtOp op,
13281339
return rewriter.notifyMatchFailure(op, "unsupported operand type");
13291340

13301341
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1331-
ArrayRef<int64_t> shape = vectorShape(operand);
1342+
VectorShape shape = vectorShape(operand);
13321343

13331344
Type floatTy = getElementTypeOrSelf(operand.getType());
13341345
Type intTy = b.getIntegerType(floatTy.getIntOrFloatBitWidth());
@@ -1417,10 +1428,10 @@ RsqrtApproximation::matchAndRewrite(math::RsqrtOp op,
14171428
if (!getElementTypeOrSelf(op.getOperand()).isF32())
14181429
return rewriter.notifyMatchFailure(op, "unsupported operand type");
14191430

1420-
ArrayRef<int64_t> shape = vectorShape(op.getOperand());
1431+
VectorShape shape = vectorShape(op.getOperand());
14211432

14221433
// Only support already-vectorized rsqrt's.
1423-
if (shape.empty() || shape.back() % 8 != 0)
1434+
if (shape.empty() || shape.sizes.back() % 8 != 0)
14241435
return rewriter.notifyMatchFailure(op, "unsupported operand type");
14251436

14261437
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);

mlir/test/Dialect/Math/polynomial-approximation.mlir

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,20 @@ func.func @erf_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
9494
return %0 : vector<8xf32>
9595
}
9696

97+
// CHECK-LABEL: func @erf_scalable_vector(
98+
// CHECK-SAME: %[[arg0:.*]]: vector<[8]xf32>) -> vector<[8]xf32> {
99+
// CHECK: %[[zero:.*]] = arith.constant dense<0.000000e+00> : vector<[8]xf32>
100+
// CHECK-NOT: erf
101+
// CHECK-NOT: vector<8xf32>
102+
// CHECK-COUNT-20: select {{.*}} : vector<[8]xi1>, vector<[8]xf32>
103+
// CHECK: %[[res:.*]] = arith.select {{.*}} : vector<[8]xi1>, vector<[8]xf32>
104+
// CHECK: return %[[res]] : vector<[8]xf32>
105+
// CHECK: }
106+
func.func @erf_scalable_vector(%arg0: vector<[8]xf32>) -> vector<[8]xf32> {
107+
%0 = math.erf %arg0 : vector<[8]xf32>
108+
return %0 : vector<[8]xf32>
109+
}
110+
97111
// CHECK-LABEL: func @exp_scalar(
98112
// CHECK-SAME: %[[VAL_0:.*]]: f32) -> f32 {
99113
// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 5.000000e-01 : f32
@@ -151,6 +165,17 @@ func.func @exp_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
151165
return %0 : vector<8xf32>
152166
}
153167

168+
// CHECK-LABEL: func @exp_scalable_vector
169+
// CHECK-NOT: math.exp
170+
// CHECK-NOT: vector<8xf32>
171+
// CHECK-COUNT-46: vector<[8]x{{(i32)|(f32)}}>
172+
// CHECK-NOT: vector<8xf32>
173+
// CHECK-NOT: math.exp
174+
func.func @exp_scalable_vector(%arg0: vector<[8]xf32>) -> vector<[8]xf32> {
175+
%0 = math.exp %arg0 : vector<[8]xf32>
176+
return %0 : vector<[8]xf32>
177+
}
178+
154179
// CHECK-LABEL: func @expm1_scalar(
155180
// CHECK-SAME: %[[X:.*]]: f32) -> f32 {
156181
// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 1.000000e+00 : f32
@@ -277,6 +302,22 @@ func.func @expm1_vector(%arg0: vector<8x8xf32>) -> vector<8x8xf32> {
277302
return %0 : vector<8x8xf32>
278303
}
279304

305+
// CHECK-LABEL: func @expm1_scalable_vector(
306+
// CHECK-SAME: %{{.*}}: vector<8x[8]xf32>) -> vector<8x[8]xf32> {
307+
// CHECK-NOT: vector<8x8xf32>
308+
// CHECK-NOT: exp
309+
// CHECK-NOT: log
310+
// CHECK-NOT: expm1
311+
// CHECK-COUNT-127: vector<8x[8]x{{(i32)|(f32)|(i1)}}>
312+
// CHECK-NOT: vector<8x8xf32>
313+
// CHECK-NOT: exp
314+
// CHECK-NOT: log
315+
// CHECK-NOT: expm1
316+
func.func @expm1_scalable_vector(%arg0: vector<8x[8]xf32>) -> vector<8x[8]xf32> {
317+
%0 = math.expm1 %arg0 : vector<8x[8]xf32>
318+
return %0 : vector<8x[8]xf32>
319+
}
320+
280321
// CHECK-LABEL: func @log_scalar(
281322
// CHECK-SAME: %[[X:.*]]: f32) -> f32 {
282323
// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 0.000000e+00 : f32
@@ -357,6 +398,18 @@ func.func @log_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
357398
return %0 : vector<8xf32>
358399
}
359400

401+
// CHECK-LABEL: func @log_scalable_vector(
402+
// CHECK-SAME: %{{.*}}: vector<[8]xf32>) -> vector<[8]xf32> {
403+
// CHECK: %[[CST_LN2:.*]] = arith.constant dense<0.693147182> : vector<[8]xf32>
404+
// CHECK-COUNT-5: select {{.*}} : vector<[8]xi1>, vector<[8]xf32>
405+
// CHECK: %[[VAL_71:.*]] = arith.select {{.*}} : vector<[8]xi1>, vector<[8]xf32>
406+
// CHECK: return %[[VAL_71]] : vector<[8]xf32>
407+
// CHECK: }
408+
func.func @log_scalable_vector(%arg0: vector<[8]xf32>) -> vector<[8]xf32> {
409+
%0 = math.log %arg0 : vector<[8]xf32>
410+
return %0 : vector<[8]xf32>
411+
}
412+
360413
// CHECK-LABEL: func @log2_scalar(
361414
// CHECK-SAME: %[[VAL_0:.*]]: f32) -> f32 {
362415
// CHECK: %[[CST_LOG2E:.*]] = arith.constant 1.44269502 : f32
@@ -381,6 +434,18 @@ func.func @log2_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
381434
return %0 : vector<8xf32>
382435
}
383436

437+
// CHECK-LABEL: func @log2_scalable_vector(
438+
// CHECK-SAME: %{{.*}}: vector<[8]xf32>) -> vector<[8]xf32> {
439+
// CHECK: %[[CST_LOG2E:.*]] = arith.constant dense<1.44269502> : vector<[8]xf32>
440+
// CHECK-COUNT-5: select {{.*}} : vector<[8]xi1>, vector<[8]xf32>
441+
// CHECK: %[[VAL_71:.*]] = arith.select {{.*}} : vector<[8]xi1>, vector<[8]xf32>
442+
// CHECK: return %[[VAL_71]] : vector<[8]xf32>
443+
// CHECK: }
444+
func.func @log2_scalable_vector(%arg0: vector<[8]xf32>) -> vector<[8]xf32> {
445+
%0 = math.log2 %arg0 : vector<[8]xf32>
446+
return %0 : vector<[8]xf32>
447+
}
448+
384449
// CHECK-LABEL: func @log1p_scalar(
385450
// CHECK-SAME: %[[X:.*]]: f32) -> f32 {
386451
// CHECK: %[[CST_ONE:.*]] = arith.constant 1.000000e+00 : f32
@@ -414,6 +479,17 @@ func.func @log1p_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
414479
return %0 : vector<8xf32>
415480
}
416481

482+
// CHECK-LABEL: func @log1p_scalable_vector(
483+
// CHECK-SAME: %[[VAL_0:.*]]: vector<[8]xf32>) -> vector<[8]xf32> {
484+
// CHECK: %[[CST_ONE:.*]] = arith.constant dense<1.000000e+00> : vector<[8]xf32>
485+
// CHECK-COUNT-6: select {{.*}} : vector<[8]xi1>, vector<[8]xf32>
486+
// CHECK: %[[VAL_79:.*]] = arith.select {{.*}} : vector<[8]xi1>, vector<[8]xf32>
487+
// CHECK: return %[[VAL_79]] : vector<[8]xf32>
488+
// CHECK: }
489+
func.func @log1p_scalable_vector(%arg0: vector<[8]xf32>) -> vector<[8]xf32> {
490+
%0 = math.log1p %arg0 : vector<[8]xf32>
491+
return %0 : vector<[8]xf32>
492+
}
417493

418494
// CHECK-LABEL: func @tanh_scalar(
419495
// CHECK-SAME: %[[VAL_0:.*]]: f32) -> f32 {
@@ -470,6 +546,19 @@ func.func @tanh_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
470546
return %0 : vector<8xf32>
471547
}
472548

549+
// CHECK-LABEL: func @tanh_scalable_vector(
550+
// CHECK-SAME: %[[VAL_0:.*]]: vector<[8]xf32>) -> vector<[8]xf32> {
551+
// CHECK: %[[VAL_1:.*]] = arith.constant dense<-7.99881172> : vector<[8]xf32>
552+
// CHECK-NOT: tanh
553+
// CHECK-COUNT-2: select {{.*}} : vector<[8]xi1>, vector<[8]xf32>
554+
// CHECK: %[[VAL_33:.*]] = arith.select {{.*}} : vector<[8]xi1>, vector<[8]xf32>
555+
// CHECK: return %[[VAL_33]] : vector<[8]xf32>
556+
// CHECK: }
557+
func.func @tanh_scalable_vector(%arg0: vector<[8]xf32>) -> vector<[8]xf32> {
558+
%0 = math.tanh %arg0 : vector<[8]xf32>
559+
return %0 : vector<[8]xf32>
560+
}
561+
473562
// We only approximate rsqrt for vectors and when the AVX2 option is enabled.
474563
// CHECK-LABEL: func @rsqrt_scalar
475564
// AVX2-LABEL: func @rsqrt_scalar

0 commit comments

Comments
 (0)