25
25
26
26
namespace mlir {
27
27
28
+ class ShapedTypeComponents ;
28
29
using ReifiedRankedShapedTypeDims = SmallVector<SmallVector<Value>>;
29
30
30
- // / ShapedTypeComponents that represents the components of a ShapedType.
31
- // / The components consist of
32
- // / - A ranked or unranked shape with the dimension specification match those
33
- // / of ShapeType's getShape() (e.g., dynamic dimension represented using
34
- // / ShapedType::kDynamicSize)
35
- // / - A element type, may be unset (nullptr)
36
- // / - A attribute, may be unset (nullptr)
37
- // / Used by ShapedType type inferences.
38
- class ShapedTypeComponents {
39
- // / Internal storage type for shape.
40
- using ShapeStorageT = SmallVector<int64_t , 3 >;
41
-
42
- public:
43
- // / Default construction is an unranked shape.
44
- ShapedTypeComponents () : elementType(nullptr ), attr(nullptr ){};
45
- ShapedTypeComponents (Type elementType)
46
- : elementType(elementType), attr(nullptr ), ranked(false ) {}
47
- ShapedTypeComponents (ShapedType shapedType) : attr(nullptr ) {
48
- ranked = shapedType.hasRank ();
49
- elementType = shapedType.getElementType ();
50
- if (ranked)
51
- dims = llvm::to_vector<4 >(shapedType.getShape ());
52
- }
53
- template <typename Arg, typename = typename std::enable_if_t <
54
- std::is_constructible<ShapeStorageT, Arg>::value>>
55
- ShapedTypeComponents (Arg &&arg, Type elementType = nullptr ,
56
- Attribute attr = nullptr )
57
- : dims(std::forward<Arg>(arg)), elementType(elementType), attr(attr),
58
- ranked (true ) {}
59
- ShapedTypeComponents (ArrayRef<int64_t > vec, Type elementType = nullptr ,
60
- Attribute attr = nullptr )
61
- : dims(vec.begin(), vec.end()), elementType(elementType), attr(attr),
62
- ranked(true ) {}
63
-
64
- // / Return the dimensions of the shape.
65
- // / Requires: shape is ranked.
66
- ArrayRef<int64_t > getDims () const {
67
- assert (ranked && " requires ranked shape" );
68
- return dims;
69
- }
70
-
71
- // / Return whether the shape has a rank.
72
- bool hasRank () const { return ranked; };
73
-
74
- // / Return the element type component.
75
- Type getElementType () const { return elementType; };
76
-
77
- // / Return the raw attribute component.
78
- Attribute getAttribute () const { return attr; };
79
-
80
- private:
81
- friend class ShapeAdaptor ;
82
-
83
- ShapeStorageT dims;
84
- Type elementType;
85
- Attribute attr;
86
- bool ranked{false };
87
- };
88
-
89
31
// / Adaptor class to abstract the differences between whether value is from
90
32
// / a ShapedType or ShapedTypeComponents or DenseIntElementsAttribute.
91
33
class ShapeAdaptor {
@@ -137,7 +79,7 @@ class ShapeAdaptor {
137
79
int64_t getNumElements () const ;
138
80
139
81
// / Returns whether valid (non-null) shape.
140
- operator bool () const { return !val.isNull (); }
82
+ explicit operator bool () const { return !val.isNull (); }
141
83
142
84
// / Dumps textual repesentation to stderr.
143
85
void dump () const ;
@@ -148,6 +90,71 @@ class ShapeAdaptor {
148
90
PointerUnion<ShapedTypeComponents *, Type, Attribute> val = nullptr ;
149
91
};
150
92
93
+ // / ShapedTypeComponents that represents the components of a ShapedType.
94
+ // / The components consist of
95
+ // / - A ranked or unranked shape with the dimension specification match those
96
+ // / of ShapeType's getShape() (e.g., dynamic dimension represented using
97
+ // / ShapedType::kDynamicSize)
98
+ // / - A element type, may be unset (nullptr)
99
+ // / - A attribute, may be unset (nullptr)
100
+ // / Used by ShapedType type inferences.
101
+ class ShapedTypeComponents {
102
+ // / Internal storage type for shape.
103
+ using ShapeStorageT = SmallVector<int64_t , 3 >;
104
+
105
+ public:
106
+ // / Default construction is an unranked shape.
107
+ ShapedTypeComponents () : elementType(nullptr ), attr(nullptr ){};
108
+ ShapedTypeComponents (Type elementType)
109
+ : elementType(elementType), attr(nullptr ), ranked(false ) {}
110
+ ShapedTypeComponents (ShapedType shapedType) : attr(nullptr ) {
111
+ ranked = shapedType.hasRank ();
112
+ elementType = shapedType.getElementType ();
113
+ if (ranked)
114
+ dims = llvm::to_vector<4 >(shapedType.getShape ());
115
+ }
116
+ ShapedTypeComponents (ShapeAdaptor adaptor) : attr(nullptr ) {
117
+ ranked = adaptor.hasRank ();
118
+ elementType = adaptor.getElementType ();
119
+ if (ranked)
120
+ adaptor.getDims (*this );
121
+ }
122
+ template <typename Arg, typename = typename std::enable_if_t <
123
+ std::is_constructible<ShapeStorageT, Arg>::value>>
124
+ ShapedTypeComponents (Arg &&arg, Type elementType = nullptr ,
125
+ Attribute attr = nullptr )
126
+ : dims(std::forward<Arg>(arg)), elementType(elementType), attr(attr),
127
+ ranked (true ) {}
128
+ ShapedTypeComponents (ArrayRef<int64_t > vec, Type elementType = nullptr ,
129
+ Attribute attr = nullptr )
130
+ : dims(vec.begin(), vec.end()), elementType(elementType), attr(attr),
131
+ ranked(true ) {}
132
+
133
+ // / Return the dimensions of the shape.
134
+ // / Requires: shape is ranked.
135
+ ArrayRef<int64_t > getDims () const {
136
+ assert (ranked && " requires ranked shape" );
137
+ return dims;
138
+ }
139
+
140
+ // / Return whether the shape has a rank.
141
+ bool hasRank () const { return ranked; };
142
+
143
+ // / Return the element type component.
144
+ Type getElementType () const { return elementType; };
145
+
146
+ // / Return the raw attribute component.
147
+ Attribute getAttribute () const { return attr; };
148
+
149
+ private:
150
+ friend class ShapeAdaptor ;
151
+
152
+ ShapeStorageT dims;
153
+ Type elementType;
154
+ Attribute attr;
155
+ bool ranked{false };
156
+ };
157
+
151
158
// / Range of values and shapes (corresponding effectively to Shapes dialect's
152
159
// / ValueShape type concept).
153
160
// Currently this exposes the Value (of operands) and Type of the Value. This is
0 commit comments