@@ -39,14 +39,24 @@ using namespace mlir;
39
39
using namespace mlir ::math;
40
40
using namespace mlir ::vector;
41
41
42
+ // Helper to encapsulate a vector's shape (including scalable dims).
43
+ struct VectorShape {
44
+ ArrayRef<int64_t > sizes;
45
+ ArrayRef<bool > scalableFlags;
46
+
47
+ bool empty () const { return sizes.empty (); }
48
+ };
49
+
42
50
// Returns vector shape if the type is a vector. Returns an empty shape if it is
43
51
// not a vector.
44
- static ArrayRef< int64_t > vectorShape (Type type) {
52
+ static VectorShape vectorShape (Type type) {
45
53
auto vectorType = dyn_cast<VectorType>(type);
46
- return vectorType ? vectorType.getShape () : ArrayRef<int64_t >();
54
+ return vectorType
55
+ ? VectorShape{vectorType.getShape (), vectorType.getScalableDims ()}
56
+ : VectorShape{};
47
57
}
48
58
49
- static ArrayRef< int64_t > vectorShape (Value value) {
59
+ static VectorShape vectorShape (Value value) {
50
60
return vectorShape (value.getType ());
51
61
}
52
62
@@ -55,14 +65,16 @@ static ArrayRef<int64_t> vectorShape(Value value) {
55
65
// ----------------------------------------------------------------------------//
56
66
57
67
// Broadcasts scalar type into vector type (iff shape is non-scalar).
58
- static Type broadcast (Type type, ArrayRef< int64_t > shape) {
68
+ static Type broadcast (Type type, VectorShape shape) {
59
69
assert (!isa<VectorType>(type) && " must be scalar type" );
60
- return !shape.empty () ? VectorType::get (shape, type) : type;
70
+ return !shape.empty ()
71
+ ? VectorType::get (shape.sizes , type, shape.scalableFlags )
72
+ : type;
61
73
}
62
74
63
75
// Broadcasts scalar value into vector (iff shape is non-scalar).
64
76
static Value broadcast (ImplicitLocOpBuilder &builder, Value value,
65
- ArrayRef< int64_t > shape) {
77
+ VectorShape shape) {
66
78
assert (!isa<VectorType>(value.getType ()) && " must be scalar value" );
67
79
auto type = broadcast (value.getType (), shape);
68
80
return !shape.empty () ? builder.create <BroadcastOp>(type, value) : value;
@@ -215,7 +227,7 @@ static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound,
215
227
static std::pair<Value, Value> frexp (ImplicitLocOpBuilder &builder, Value arg,
216
228
bool isPositive = false ) {
217
229
assert (getElementTypeOrSelf (arg).isF32 () && " arg must be f32 type" );
218
- ArrayRef< int64_t > shape = vectorShape (arg);
230
+ VectorShape shape = vectorShape (arg);
219
231
220
232
auto bcast = [&](Value value) -> Value {
221
233
return broadcast (builder, value, shape);
@@ -255,7 +267,7 @@ static std::pair<Value, Value> frexp(ImplicitLocOpBuilder &builder, Value arg,
255
267
// Computes exp2 for an i32 argument.
256
268
static Value exp2I32 (ImplicitLocOpBuilder &builder, Value arg) {
257
269
assert (getElementTypeOrSelf (arg).isInteger (32 ) && " arg must be i32 type" );
258
- ArrayRef< int64_t > shape = vectorShape (arg);
270
+ VectorShape shape = vectorShape (arg);
259
271
260
272
auto bcast = [&](Value value) -> Value {
261
273
return broadcast (builder, value, shape);
@@ -281,7 +293,7 @@ Value makePolynomialCalculation(ImplicitLocOpBuilder &builder,
281
293
Type elementType = getElementTypeOrSelf (x);
282
294
assert ((elementType.isF32 () || elementType.isF16 ()) &&
283
295
" x must be f32 or f16 type" );
284
- ArrayRef< int64_t > shape = vectorShape (x);
296
+ VectorShape shape = vectorShape (x);
285
297
286
298
if (coeffs.empty ())
287
299
return broadcast (builder, floatCst (builder, 0 .0f , elementType), shape);
@@ -379,7 +391,7 @@ AtanApproximation::matchAndRewrite(math::AtanOp op,
379
391
if (!getElementTypeOrSelf (operand).isF32 ())
380
392
return rewriter.notifyMatchFailure (op, " unsupported operand type" );
381
393
382
- ArrayRef< int64_t > shape = vectorShape (op.getOperand ());
394
+ VectorShape shape = vectorShape (op.getOperand ());
383
395
384
396
ImplicitLocOpBuilder builder (op->getLoc (), rewriter);
385
397
Value abs = builder.create <math::AbsFOp>(operand);
@@ -478,7 +490,7 @@ Atan2Approximation::matchAndRewrite(math::Atan2Op op,
478
490
return rewriter.notifyMatchFailure (op, " unsupported operand type" );
479
491
480
492
ImplicitLocOpBuilder builder (op->getLoc (), rewriter);
481
- ArrayRef< int64_t > shape = vectorShape (op.getResult ());
493
+ VectorShape shape = vectorShape (op.getResult ());
482
494
483
495
// Compute atan in the valid range.
484
496
auto div = builder.create <arith::DivFOp>(y, x);
@@ -544,7 +556,7 @@ TanhApproximation::matchAndRewrite(math::TanhOp op,
544
556
if (!getElementTypeOrSelf (op.getOperand ()).isF32 ())
545
557
return rewriter.notifyMatchFailure (op, " unsupported operand type" );
546
558
547
- ArrayRef< int64_t > shape = vectorShape (op.getOperand ());
559
+ VectorShape shape = vectorShape (op.getOperand ());
548
560
549
561
ImplicitLocOpBuilder builder (op->getLoc (), rewriter);
550
562
auto bcast = [&](Value value) -> Value {
@@ -632,7 +644,7 @@ LogApproximationBase<Op>::logMatchAndRewrite(Op op, PatternRewriter &rewriter,
632
644
if (!getElementTypeOrSelf (op.getOperand ()).isF32 ())
633
645
return rewriter.notifyMatchFailure (op, " unsupported operand type" );
634
646
635
- ArrayRef< int64_t > shape = vectorShape (op.getOperand ());
647
+ VectorShape shape = vectorShape (op.getOperand ());
636
648
637
649
ImplicitLocOpBuilder builder (op->getLoc (), rewriter);
638
650
auto bcast = [&](Value value) -> Value {
@@ -779,7 +791,7 @@ Log1pApproximation::matchAndRewrite(math::Log1pOp op,
779
791
if (!getElementTypeOrSelf (op.getOperand ()).isF32 ())
780
792
return rewriter.notifyMatchFailure (op, " unsupported operand type" );
781
793
782
- ArrayRef< int64_t > shape = vectorShape (op.getOperand ());
794
+ VectorShape shape = vectorShape (op.getOperand ());
783
795
784
796
ImplicitLocOpBuilder builder (op->getLoc (), rewriter);
785
797
auto bcast = [&](Value value) -> Value {
@@ -829,7 +841,7 @@ ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op,
829
841
if (!(elementType.isF32 () || elementType.isF16 ()))
830
842
return rewriter.notifyMatchFailure (op,
831
843
" only f32 and f16 type is supported." );
832
- ArrayRef< int64_t > shape = vectorShape (operand);
844
+ VectorShape shape = vectorShape (operand);
833
845
834
846
ImplicitLocOpBuilder builder (op->getLoc (), rewriter);
835
847
auto bcast = [&](Value value) -> Value {
@@ -938,9 +950,8 @@ ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op,
938
950
939
951
namespace {
940
952
941
- Value clampWithNormals (ImplicitLocOpBuilder &builder,
942
- const llvm::ArrayRef<int64_t > shape, Value value,
943
- float lowerBound, float upperBound) {
953
+ Value clampWithNormals (ImplicitLocOpBuilder &builder, const VectorShape shape,
954
+ Value value, float lowerBound, float upperBound) {
944
955
assert (!std::isnan (lowerBound));
945
956
assert (!std::isnan (upperBound));
946
957
@@ -1131,7 +1142,7 @@ ExpM1Approximation::matchAndRewrite(math::ExpM1Op op,
1131
1142
if (!getElementTypeOrSelf (op.getOperand ()).isF32 ())
1132
1143
return rewriter.notifyMatchFailure (op, " unsupported operand type" );
1133
1144
1134
- ArrayRef< int64_t > shape = vectorShape (op.getOperand ());
1145
+ VectorShape shape = vectorShape (op.getOperand ());
1135
1146
1136
1147
ImplicitLocOpBuilder builder (op->getLoc (), rewriter);
1137
1148
auto bcast = [&](Value value) -> Value {
@@ -1201,7 +1212,7 @@ LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite(
1201
1212
if (!getElementTypeOrSelf (op.getOperand ()).isF32 ())
1202
1213
return rewriter.notifyMatchFailure (op, " unsupported operand type" );
1203
1214
1204
- ArrayRef< int64_t > shape = vectorShape (op.getOperand ());
1215
+ VectorShape shape = vectorShape (op.getOperand ());
1205
1216
1206
1217
ImplicitLocOpBuilder builder (op->getLoc (), rewriter);
1207
1218
auto bcast = [&](Value value) -> Value {
@@ -1328,7 +1339,7 @@ CbrtApproximation::matchAndRewrite(math::CbrtOp op,
1328
1339
return rewriter.notifyMatchFailure (op, " unsupported operand type" );
1329
1340
1330
1341
ImplicitLocOpBuilder b (op->getLoc (), rewriter);
1331
- ArrayRef< int64_t > shape = vectorShape (operand);
1342
+ VectorShape shape = vectorShape (operand);
1332
1343
1333
1344
Type floatTy = getElementTypeOrSelf (operand.getType ());
1334
1345
Type intTy = b.getIntegerType (floatTy.getIntOrFloatBitWidth ());
@@ -1417,10 +1428,10 @@ RsqrtApproximation::matchAndRewrite(math::RsqrtOp op,
1417
1428
if (!getElementTypeOrSelf (op.getOperand ()).isF32 ())
1418
1429
return rewriter.notifyMatchFailure (op, " unsupported operand type" );
1419
1430
1420
- ArrayRef< int64_t > shape = vectorShape (op.getOperand ());
1431
+ VectorShape shape = vectorShape (op.getOperand ());
1421
1432
1422
1433
// Only support already-vectorized rsqrt's.
1423
- if (shape.empty () || shape.back () % 8 != 0 )
1434
+ if (shape.empty () || shape.sizes . back () % 8 != 0 )
1424
1435
return rewriter.notifyMatchFailure (op, " unsupported operand type" );
1425
1436
1426
1437
ImplicitLocOpBuilder builder (op->getLoc (), rewriter);
0 commit comments