29
29
#include " mlir/IR/OpDefinition.h"
30
30
#include " mlir/IR/PatternMatch.h"
31
31
#include " mlir/IR/TypeUtilities.h"
32
+ #include " mlir/Support/ScalableVectorType.h"
32
33
#include " mlir/Transforms/DialectConversion.h"
33
34
#include " mlir/Transforms/GreedyPatternRewriteDriver.h"
34
35
#include " llvm/ADT/ArrayRef.h"
@@ -39,24 +40,14 @@ using namespace mlir;
39
40
using namespace mlir ::math;
40
41
using namespace mlir ::vector;
41
42
42
- // Helper to encapsulate a vector's shape (including scalable dims).
43
- struct VectorShape {
44
- ArrayRef<int64_t > sizes;
45
- ArrayRef<bool > scalableFlags;
46
-
47
- bool empty () const { return sizes.empty (); }
48
- };
49
-
50
43
// Returns vector shape if the type is a vector. Returns an empty shape if it is
51
44
// not a vector.
52
- static VectorShape vectorShape (Type type) {
45
+ static VectorDimList vectorShape (Type type) {
53
46
auto vectorType = dyn_cast<VectorType>(type);
54
- return vectorType
55
- ? VectorShape{vectorType.getShape (), vectorType.getScalableDims ()}
56
- : VectorShape{};
47
+ return VectorDimList::from (vectorType);
57
48
}
58
49
59
- static VectorShape vectorShape (Value value) {
50
+ static VectorDimList vectorShape (Value value) {
60
51
return vectorShape (value.getType ());
61
52
}
62
53
@@ -65,16 +56,14 @@ static VectorShape vectorShape(Value value) {
65
56
// ----------------------------------------------------------------------------//
66
57
67
58
// Broadcasts scalar type into vector type (iff shape is non-scalar).
68
- static Type broadcast (Type type, VectorShape shape) {
59
+ static Type broadcast (Type type, VectorDimList shape) {
69
60
assert (!isa<VectorType>(type) && " must be scalar type" );
70
- return !shape.empty ()
71
- ? VectorType::get (shape.sizes , type, shape.scalableFlags )
72
- : type;
61
+ return !shape.empty () ? ScalableVectorType::get (shape, type) : type;
73
62
}
74
63
75
64
// Broadcasts scalar value into vector (iff shape is non-scalar).
76
65
static Value broadcast (ImplicitLocOpBuilder &builder, Value value,
77
- VectorShape shape) {
66
+ VectorDimList shape) {
78
67
assert (!isa<VectorType>(value.getType ()) && " must be scalar value" );
79
68
auto type = broadcast (value.getType (), shape);
80
69
return !shape.empty () ? builder.create <BroadcastOp>(type, value) : value;
@@ -227,7 +216,7 @@ static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound,
227
216
static std::pair<Value, Value> frexp (ImplicitLocOpBuilder &builder, Value arg,
228
217
bool isPositive = false ) {
229
218
assert (getElementTypeOrSelf (arg).isF32 () && " arg must be f32 type" );
230
- VectorShape shape = vectorShape (arg);
219
+ VectorDimList shape = vectorShape (arg);
231
220
232
221
auto bcast = [&](Value value) -> Value {
233
222
return broadcast (builder, value, shape);
@@ -267,7 +256,7 @@ static std::pair<Value, Value> frexp(ImplicitLocOpBuilder &builder, Value arg,
267
256
// Computes exp2 for an i32 argument.
268
257
static Value exp2I32 (ImplicitLocOpBuilder &builder, Value arg) {
269
258
assert (getElementTypeOrSelf (arg).isInteger (32 ) && " arg must be i32 type" );
270
- VectorShape shape = vectorShape (arg);
259
+ VectorDimList shape = vectorShape (arg);
271
260
272
261
auto bcast = [&](Value value) -> Value {
273
262
return broadcast (builder, value, shape);
@@ -293,7 +282,7 @@ Value makePolynomialCalculation(ImplicitLocOpBuilder &builder,
293
282
Type elementType = getElementTypeOrSelf (x);
294
283
assert ((elementType.isF32 () || elementType.isF16 ()) &&
295
284
" x must be f32 or f16 type" );
296
- VectorShape shape = vectorShape (x);
285
+ VectorDimList shape = vectorShape (x);
297
286
298
287
if (coeffs.empty ())
299
288
return broadcast (builder, floatCst (builder, 0 .0f , elementType), shape);
@@ -391,7 +380,7 @@ AtanApproximation::matchAndRewrite(math::AtanOp op,
391
380
if (!getElementTypeOrSelf (operand).isF32 ())
392
381
return rewriter.notifyMatchFailure (op, " unsupported operand type" );
393
382
394
- VectorShape shape = vectorShape (op.getOperand ());
383
+ VectorDimList shape = vectorShape (op.getOperand ());
395
384
396
385
ImplicitLocOpBuilder builder (op->getLoc (), rewriter);
397
386
Value abs = builder.create <math::AbsFOp>(operand);
@@ -490,7 +479,7 @@ Atan2Approximation::matchAndRewrite(math::Atan2Op op,
490
479
return rewriter.notifyMatchFailure (op, " unsupported operand type" );
491
480
492
481
ImplicitLocOpBuilder builder (op->getLoc (), rewriter);
493
- VectorShape shape = vectorShape (op.getResult ());
482
+ VectorDimList shape = vectorShape (op.getResult ());
494
483
495
484
// Compute atan in the valid range.
496
485
auto div = builder.create <arith::DivFOp>(y, x);
@@ -556,7 +545,7 @@ TanhApproximation::matchAndRewrite(math::TanhOp op,
556
545
if (!getElementTypeOrSelf (op.getOperand ()).isF32 ())
557
546
return rewriter.notifyMatchFailure (op, " unsupported operand type" );
558
547
559
- VectorShape shape = vectorShape (op.getOperand ());
548
+ VectorDimList shape = vectorShape (op.getOperand ());
560
549
561
550
ImplicitLocOpBuilder builder (op->getLoc (), rewriter);
562
551
auto bcast = [&](Value value) -> Value {
@@ -644,7 +633,7 @@ LogApproximationBase<Op>::logMatchAndRewrite(Op op, PatternRewriter &rewriter,
644
633
if (!getElementTypeOrSelf (op.getOperand ()).isF32 ())
645
634
return rewriter.notifyMatchFailure (op, " unsupported operand type" );
646
635
647
- VectorShape shape = vectorShape (op.getOperand ());
636
+ VectorDimList shape = vectorShape (op.getOperand ());
648
637
649
638
ImplicitLocOpBuilder builder (op->getLoc (), rewriter);
650
639
auto bcast = [&](Value value) -> Value {
@@ -791,7 +780,7 @@ Log1pApproximation::matchAndRewrite(math::Log1pOp op,
791
780
if (!getElementTypeOrSelf (op.getOperand ()).isF32 ())
792
781
return rewriter.notifyMatchFailure (op, " unsupported operand type" );
793
782
794
- VectorShape shape = vectorShape (op.getOperand ());
783
+ VectorDimList shape = vectorShape (op.getOperand ());
795
784
796
785
ImplicitLocOpBuilder builder (op->getLoc (), rewriter);
797
786
auto bcast = [&](Value value) -> Value {
@@ -846,7 +835,7 @@ AsinPolynomialApproximation::matchAndRewrite(math::AsinOp op,
846
835
if (!(elementType.isF32 () || elementType.isF16 ()))
847
836
return rewriter.notifyMatchFailure (op,
848
837
" only f32 and f16 type is supported." );
849
- VectorShape shape = vectorShape (operand);
838
+ VectorDimList shape = vectorShape (operand);
850
839
851
840
ImplicitLocOpBuilder builder (op->getLoc (), rewriter);
852
841
auto bcast = [&](Value value) -> Value {
@@ -910,7 +899,7 @@ AcosPolynomialApproximation::matchAndRewrite(math::AcosOp op,
910
899
if (!(elementType.isF32 () || elementType.isF16 ()))
911
900
return rewriter.notifyMatchFailure (op,
912
901
" only f32 and f16 type is supported." );
913
- VectorShape shape = vectorShape (operand);
902
+ VectorDimList shape = vectorShape (operand);
914
903
915
904
ImplicitLocOpBuilder builder (op->getLoc (), rewriter);
916
905
auto bcast = [&](Value value) -> Value {
@@ -988,7 +977,7 @@ ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op,
988
977
if (!(elementType.isF32 () || elementType.isF16 ()))
989
978
return rewriter.notifyMatchFailure (op,
990
979
" only f32 and f16 type is supported." );
991
- VectorShape shape = vectorShape (operand);
980
+ VectorDimList shape = vectorShape (operand);
992
981
993
982
ImplicitLocOpBuilder builder (op->getLoc (), rewriter);
994
983
auto bcast = [&](Value value) -> Value {
@@ -1097,7 +1086,7 @@ ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op,
1097
1086
1098
1087
namespace {
1099
1088
1100
- Value clampWithNormals (ImplicitLocOpBuilder &builder, const VectorShape shape,
1089
+ Value clampWithNormals (ImplicitLocOpBuilder &builder, const VectorDimList shape,
1101
1090
Value value, float lowerBound, float upperBound) {
1102
1091
assert (!std::isnan (lowerBound));
1103
1092
assert (!std::isnan (upperBound));
@@ -1289,7 +1278,7 @@ ExpM1Approximation::matchAndRewrite(math::ExpM1Op op,
1289
1278
if (!getElementTypeOrSelf (op.getOperand ()).isF32 ())
1290
1279
return rewriter.notifyMatchFailure (op, " unsupported operand type" );
1291
1280
1292
- VectorShape shape = vectorShape (op.getOperand ());
1281
+ VectorDimList shape = vectorShape (op.getOperand ());
1293
1282
1294
1283
ImplicitLocOpBuilder builder (op->getLoc (), rewriter);
1295
1284
auto bcast = [&](Value value) -> Value {
@@ -1359,7 +1348,7 @@ LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite(
1359
1348
if (!getElementTypeOrSelf (op.getOperand ()).isF32 ())
1360
1349
return rewriter.notifyMatchFailure (op, " unsupported operand type" );
1361
1350
1362
- VectorShape shape = vectorShape (op.getOperand ());
1351
+ VectorDimList shape = vectorShape (op.getOperand ());
1363
1352
1364
1353
ImplicitLocOpBuilder builder (op->getLoc (), rewriter);
1365
1354
auto bcast = [&](Value value) -> Value {
@@ -1486,7 +1475,7 @@ CbrtApproximation::matchAndRewrite(math::CbrtOp op,
1486
1475
return rewriter.notifyMatchFailure (op, " unsupported operand type" );
1487
1476
1488
1477
ImplicitLocOpBuilder b (op->getLoc (), rewriter);
1489
- VectorShape shape = vectorShape (operand);
1478
+ VectorDimList shape = vectorShape (operand);
1490
1479
1491
1480
Type floatTy = getElementTypeOrSelf (operand.getType ());
1492
1481
Type intTy = b.getIntegerType (floatTy.getIntOrFloatBitWidth ());
@@ -1575,7 +1564,7 @@ RsqrtApproximation::matchAndRewrite(math::RsqrtOp op,
1575
1564
if (!getElementTypeOrSelf (op.getOperand ()).isF32 ())
1576
1565
return rewriter.notifyMatchFailure (op, " unsupported operand type" );
1577
1566
1578
- VectorShape shape = vectorShape (op.getOperand ());
1567
+ VectorDimList shape = vectorShape (op.getOperand ());
1579
1568
1580
1569
// Only support already-vectorized rsqrt's.
1581
1570
if (shape.empty () || shape.sizes .back () % 8 != 0 )
0 commit comments