@@ -43,20 +43,17 @@ using namespace mlir::vector;
43
43
struct VectorShape {
44
44
ArrayRef<int64_t > sizes;
45
45
ArrayRef<bool > scalableFlags;
46
-
47
- bool empty () const { return sizes.empty (); }
48
46
};
49
47
50
- // Returns vector shape if the type is a vector. Returns an empty shape if it is
51
- // not a vector.
52
- static VectorShape vectorShape (Type type) {
53
- auto vectorType = dyn_cast<VectorType>(type);
54
- return vectorType
55
- ? VectorShape{vectorType.getShape (), vectorType.getScalableDims ()}
56
- : VectorShape{};
48
+ // Returns vector shape if the type is a vector, otherwise return nullopt.
49
+ static std::optional<VectorShape> vectorShape (Type type) {
50
+ if (auto vectorType = dyn_cast<VectorType>(type)) {
51
+ return VectorShape{vectorType.getShape (), vectorType.getScalableDims ()};
52
+ }
53
+ return std::nullopt;
57
54
}
58
55
59
- static VectorShape vectorShape (Value value) {
56
+ static std::optional< VectorShape> vectorShape (Value value) {
60
57
return vectorShape (value.getType ());
61
58
}
62
59
@@ -65,19 +62,18 @@ static VectorShape vectorShape(Value value) {
65
62
// ----------------------------------------------------------------------------//
66
63
67
64
// Broadcasts scalar type into vector type (iff shape is non-scalar).
68
- static Type broadcast (Type type, VectorShape shape) {
65
+ static Type broadcast (Type type, std::optional< VectorShape> shape) {
69
66
assert (!isa<VectorType>(type) && " must be scalar type" );
70
- return !shape.empty ()
71
- ? VectorType::get (shape.sizes , type, shape.scalableFlags )
72
- : type;
67
+ return shape ? VectorType::get (shape->sizes , type, shape->scalableFlags )
68
+ : type;
73
69
}
74
70
75
71
// Broadcasts scalar value into vector (iff shape is non-scalar).
76
72
static Value broadcast (ImplicitLocOpBuilder &builder, Value value,
77
- VectorShape shape) {
73
+ std::optional< VectorShape> shape) {
78
74
assert (!isa<VectorType>(value.getType ()) && " must be scalar value" );
79
75
auto type = broadcast (value.getType (), shape);
80
- return ! shape. empty () ? builder.create <BroadcastOp>(type, value) : value;
76
+ return shape ? builder.create <BroadcastOp>(type, value) : value;
81
77
}
82
78
83
79
// ----------------------------------------------------------------------------//
@@ -227,7 +223,7 @@ static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound,
227
223
static std::pair<Value, Value> frexp (ImplicitLocOpBuilder &builder, Value arg,
228
224
bool isPositive = false ) {
229
225
assert (getElementTypeOrSelf (arg).isF32 () && " arg must be f32 type" );
230
- VectorShape shape = vectorShape (arg);
226
+ std::optional< VectorShape> shape = vectorShape (arg);
231
227
232
228
auto bcast = [&](Value value) -> Value {
233
229
return broadcast (builder, value, shape);
@@ -267,7 +263,7 @@ static std::pair<Value, Value> frexp(ImplicitLocOpBuilder &builder, Value arg,
267
263
// Computes exp2 for an i32 argument.
268
264
static Value exp2I32 (ImplicitLocOpBuilder &builder, Value arg) {
269
265
assert (getElementTypeOrSelf (arg).isInteger (32 ) && " arg must be i32 type" );
270
- VectorShape shape = vectorShape (arg);
266
+ std::optional< VectorShape> shape = vectorShape (arg);
271
267
272
268
auto bcast = [&](Value value) -> Value {
273
269
return broadcast (builder, value, shape);
@@ -293,7 +289,7 @@ Value makePolynomialCalculation(ImplicitLocOpBuilder &builder,
293
289
Type elementType = getElementTypeOrSelf (x);
294
290
assert ((elementType.isF32 () || elementType.isF16 ()) &&
295
291
" x must be f32 or f16 type" );
296
- VectorShape shape = vectorShape (x);
292
+ std::optional< VectorShape> shape = vectorShape (x);
297
293
298
294
if (coeffs.empty ())
299
295
return broadcast (builder, floatCst (builder, 0 .0f , elementType), shape);
@@ -391,7 +387,7 @@ AtanApproximation::matchAndRewrite(math::AtanOp op,
391
387
if (!getElementTypeOrSelf (operand).isF32 ())
392
388
return rewriter.notifyMatchFailure (op, " unsupported operand type" );
393
389
394
- VectorShape shape = vectorShape (op.getOperand ());
390
+ std::optional< VectorShape> shape = vectorShape (op.getOperand ());
395
391
396
392
ImplicitLocOpBuilder builder (op->getLoc (), rewriter);
397
393
Value abs = builder.create <math::AbsFOp>(operand);
@@ -490,7 +486,7 @@ Atan2Approximation::matchAndRewrite(math::Atan2Op op,
490
486
return rewriter.notifyMatchFailure (op, " unsupported operand type" );
491
487
492
488
ImplicitLocOpBuilder builder (op->getLoc (), rewriter);
493
- VectorShape shape = vectorShape (op.getResult ());
489
+ std::optional< VectorShape> shape = vectorShape (op.getResult ());
494
490
495
491
// Compute atan in the valid range.
496
492
auto div = builder.create <arith::DivFOp>(y, x);
@@ -556,7 +552,7 @@ TanhApproximation::matchAndRewrite(math::TanhOp op,
556
552
if (!getElementTypeOrSelf (op.getOperand ()).isF32 ())
557
553
return rewriter.notifyMatchFailure (op, " unsupported operand type" );
558
554
559
- VectorShape shape = vectorShape (op.getOperand ());
555
+ std::optional< VectorShape> shape = vectorShape (op.getOperand ());
560
556
561
557
ImplicitLocOpBuilder builder (op->getLoc (), rewriter);
562
558
auto bcast = [&](Value value) -> Value {
@@ -644,7 +640,7 @@ LogApproximationBase<Op>::logMatchAndRewrite(Op op, PatternRewriter &rewriter,
644
640
if (!getElementTypeOrSelf (op.getOperand ()).isF32 ())
645
641
return rewriter.notifyMatchFailure (op, " unsupported operand type" );
646
642
647
- VectorShape shape = vectorShape (op.getOperand ());
643
+ std::optional< VectorShape> shape = vectorShape (op.getOperand ());
648
644
649
645
ImplicitLocOpBuilder builder (op->getLoc (), rewriter);
650
646
auto bcast = [&](Value value) -> Value {
@@ -791,7 +787,7 @@ Log1pApproximation::matchAndRewrite(math::Log1pOp op,
791
787
if (!getElementTypeOrSelf (op.getOperand ()).isF32 ())
792
788
return rewriter.notifyMatchFailure (op, " unsupported operand type" );
793
789
794
- VectorShape shape = vectorShape (op.getOperand ());
790
+ std::optional< VectorShape> shape = vectorShape (op.getOperand ());
795
791
796
792
ImplicitLocOpBuilder builder (op->getLoc (), rewriter);
797
793
auto bcast = [&](Value value) -> Value {
@@ -846,7 +842,7 @@ AsinPolynomialApproximation::matchAndRewrite(math::AsinOp op,
846
842
if (!(elementType.isF32 () || elementType.isF16 ()))
847
843
return rewriter.notifyMatchFailure (op,
848
844
" only f32 and f16 type is supported." );
849
- VectorShape shape = vectorShape (operand);
845
+ std::optional< VectorShape> shape = vectorShape (operand);
850
846
851
847
ImplicitLocOpBuilder builder (op->getLoc (), rewriter);
852
848
auto bcast = [&](Value value) -> Value {
@@ -941,7 +937,7 @@ AcosPolynomialApproximation::matchAndRewrite(math::AcosOp op,
941
937
if (!(elementType.isF32 () || elementType.isF16 ()))
942
938
return rewriter.notifyMatchFailure (op,
943
939
" only f32 and f16 type is supported." );
944
- VectorShape shape = vectorShape (operand);
940
+ std::optional< VectorShape> shape = vectorShape (operand);
945
941
946
942
ImplicitLocOpBuilder builder (op->getLoc (), rewriter);
947
943
auto bcast = [&](Value value) -> Value {
@@ -1019,7 +1015,7 @@ ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op,
1019
1015
if (!(elementType.isF32 () || elementType.isF16 ()))
1020
1016
return rewriter.notifyMatchFailure (op,
1021
1017
" only f32 and f16 type is supported." );
1022
- VectorShape shape = vectorShape (operand);
1018
+ std::optional< VectorShape> shape = vectorShape (operand);
1023
1019
1024
1020
ImplicitLocOpBuilder builder (op->getLoc (), rewriter);
1025
1021
auto bcast = [&](Value value) -> Value {
@@ -1128,8 +1124,9 @@ ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op,
1128
1124
1129
1125
namespace {
1130
1126
1131
- Value clampWithNormals (ImplicitLocOpBuilder &builder, const VectorShape shape,
1132
- Value value, float lowerBound, float upperBound) {
1127
+ Value clampWithNormals (ImplicitLocOpBuilder &builder,
1128
+ const std::optional<VectorShape> shape, Value value,
1129
+ float lowerBound, float upperBound) {
1133
1130
assert (!std::isnan (lowerBound));
1134
1131
assert (!std::isnan (upperBound));
1135
1132
@@ -1320,7 +1317,7 @@ ExpM1Approximation::matchAndRewrite(math::ExpM1Op op,
1320
1317
if (!getElementTypeOrSelf (op.getOperand ()).isF32 ())
1321
1318
return rewriter.notifyMatchFailure (op, " unsupported operand type" );
1322
1319
1323
- VectorShape shape = vectorShape (op.getOperand ());
1320
+ std::optional< VectorShape> shape = vectorShape (op.getOperand ());
1324
1321
1325
1322
ImplicitLocOpBuilder builder (op->getLoc (), rewriter);
1326
1323
auto bcast = [&](Value value) -> Value {
@@ -1390,7 +1387,7 @@ LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite(
1390
1387
if (!getElementTypeOrSelf (op.getOperand ()).isF32 ())
1391
1388
return rewriter.notifyMatchFailure (op, " unsupported operand type" );
1392
1389
1393
- VectorShape shape = vectorShape (op.getOperand ());
1390
+ std::optional< VectorShape> shape = vectorShape (op.getOperand ());
1394
1391
1395
1392
ImplicitLocOpBuilder builder (op->getLoc (), rewriter);
1396
1393
auto bcast = [&](Value value) -> Value {
@@ -1517,7 +1514,7 @@ CbrtApproximation::matchAndRewrite(math::CbrtOp op,
1517
1514
return rewriter.notifyMatchFailure (op, " unsupported operand type" );
1518
1515
1519
1516
ImplicitLocOpBuilder b (op->getLoc (), rewriter);
1520
- VectorShape shape = vectorShape (operand);
1517
+ std::optional< VectorShape> shape = vectorShape (operand);
1521
1518
1522
1519
Type floatTy = getElementTypeOrSelf (operand.getType ());
1523
1520
Type intTy = b.getIntegerType (floatTy.getIntOrFloatBitWidth ());
@@ -1606,10 +1603,10 @@ RsqrtApproximation::matchAndRewrite(math::RsqrtOp op,
1606
1603
if (!getElementTypeOrSelf (op.getOperand ()).isF32 ())
1607
1604
return rewriter.notifyMatchFailure (op, " unsupported operand type" );
1608
1605
1609
- VectorShape shape = vectorShape (op.getOperand ());
1606
+ std::optional< VectorShape> shape = vectorShape (op.getOperand ());
1610
1607
1611
1608
// Only support already-vectorized rsqrt's.
1612
- if (shape.empty () || shape. sizes .back () % 8 != 0 )
1609
+ if (! shape || shape-> sizes .empty () || shape-> sizes .back () % 8 != 0 )
1613
1610
return rewriter.notifyMatchFailure (op, " unsupported operand type" );
1614
1611
1615
1612
ImplicitLocOpBuilder builder (op->getLoc (), rewriter);
0 commit comments