Skip to content

Commit fcd4ee5

Browse files
committed
[mlir] Make ShapedTypeComponents contructible from ShapeAdaptor
ValueShapeRange::getShape() returns ShapeAdaptor rather than ShapedType and ShapeAdaptor allows implicit conversion to bool. It ends up that ShapedTypeComponents can be constructed with ShapeAdaptor incorrectly. The reason is that the type trait std::is_constructible<ShapeStorageT, Arg>::value is fulfilled because ShapeAdaptor can be converted to bool and it can be used to construct ShapeStorageT. In the end, we won't give any warning or error message when doing things like inferredReturnShapes.emplace_back(valueShapeRange.getShape(0)); Reviewed By: jpienaar Differential Revision: https://reviews.llvm.org/D120845
1 parent 2b0ec7c commit fcd4ee5

File tree

1 file changed

+67
-60
lines changed

1 file changed

+67
-60
lines changed

mlir/include/mlir/Interfaces/InferTypeOpInterface.h

Lines changed: 67 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -25,67 +25,9 @@
2525

2626
namespace mlir {
2727

28+
class ShapedTypeComponents;
2829
using ReifiedRankedShapedTypeDims = SmallVector<SmallVector<Value>>;
2930

30-
/// ShapedTypeComponents that represents the components of a ShapedType.
31-
/// The components consist of
32-
/// - A ranked or unranked shape with the dimension specification match those
33-
/// of ShapeType's getShape() (e.g., dynamic dimension represented using
34-
/// ShapedType::kDynamicSize)
35-
/// - A element type, may be unset (nullptr)
36-
/// - A attribute, may be unset (nullptr)
37-
/// Used by ShapedType type inferences.
38-
class ShapedTypeComponents {
39-
/// Internal storage type for shape.
40-
using ShapeStorageT = SmallVector<int64_t, 3>;
41-
42-
public:
43-
/// Default construction is an unranked shape.
44-
ShapedTypeComponents() : elementType(nullptr), attr(nullptr){};
45-
ShapedTypeComponents(Type elementType)
46-
: elementType(elementType), attr(nullptr), ranked(false) {}
47-
ShapedTypeComponents(ShapedType shapedType) : attr(nullptr) {
48-
ranked = shapedType.hasRank();
49-
elementType = shapedType.getElementType();
50-
if (ranked)
51-
dims = llvm::to_vector<4>(shapedType.getShape());
52-
}
53-
template <typename Arg, typename = typename std::enable_if_t<
54-
std::is_constructible<ShapeStorageT, Arg>::value>>
55-
ShapedTypeComponents(Arg &&arg, Type elementType = nullptr,
56-
Attribute attr = nullptr)
57-
: dims(std::forward<Arg>(arg)), elementType(elementType), attr(attr),
58-
ranked(true) {}
59-
ShapedTypeComponents(ArrayRef<int64_t> vec, Type elementType = nullptr,
60-
Attribute attr = nullptr)
61-
: dims(vec.begin(), vec.end()), elementType(elementType), attr(attr),
62-
ranked(true) {}
63-
64-
/// Return the dimensions of the shape.
65-
/// Requires: shape is ranked.
66-
ArrayRef<int64_t> getDims() const {
67-
assert(ranked && "requires ranked shape");
68-
return dims;
69-
}
70-
71-
/// Return whether the shape has a rank.
72-
bool hasRank() const { return ranked; };
73-
74-
/// Return the element type component.
75-
Type getElementType() const { return elementType; };
76-
77-
/// Return the raw attribute component.
78-
Attribute getAttribute() const { return attr; };
79-
80-
private:
81-
friend class ShapeAdaptor;
82-
83-
ShapeStorageT dims;
84-
Type elementType;
85-
Attribute attr;
86-
bool ranked{false};
87-
};
88-
8931
/// Adaptor class to abstract the differences between whether value is from
9032
/// a ShapedType or ShapedTypeComponents or DenseIntElementsAttribute.
9133
class ShapeAdaptor {
@@ -137,7 +79,7 @@ class ShapeAdaptor {
13779
int64_t getNumElements() const;
13880

13981
/// Returns whether valid (non-null) shape.
140-
operator bool() const { return !val.isNull(); }
82+
explicit operator bool() const { return !val.isNull(); }
14183

14284
/// Dumps textual repesentation to stderr.
14385
void dump() const;
@@ -148,6 +90,71 @@ class ShapeAdaptor {
14890
PointerUnion<ShapedTypeComponents *, Type, Attribute> val = nullptr;
14991
};
15092

93+
/// ShapedTypeComponents that represents the components of a ShapedType.
94+
/// The components consist of
95+
/// - A ranked or unranked shape with the dimension specification match those
96+
/// of ShapeType's getShape() (e.g., dynamic dimension represented using
97+
/// ShapedType::kDynamicSize)
98+
/// - A element type, may be unset (nullptr)
99+
/// - A attribute, may be unset (nullptr)
100+
/// Used by ShapedType type inferences.
101+
class ShapedTypeComponents {
102+
/// Internal storage type for shape.
103+
using ShapeStorageT = SmallVector<int64_t, 3>;
104+
105+
public:
106+
/// Default construction is an unranked shape.
107+
ShapedTypeComponents() : elementType(nullptr), attr(nullptr){};
108+
ShapedTypeComponents(Type elementType)
109+
: elementType(elementType), attr(nullptr), ranked(false) {}
110+
ShapedTypeComponents(ShapedType shapedType) : attr(nullptr) {
111+
ranked = shapedType.hasRank();
112+
elementType = shapedType.getElementType();
113+
if (ranked)
114+
dims = llvm::to_vector<4>(shapedType.getShape());
115+
}
116+
ShapedTypeComponents(ShapeAdaptor adaptor) : attr(nullptr) {
117+
ranked = adaptor.hasRank();
118+
elementType = adaptor.getElementType();
119+
if (ranked)
120+
adaptor.getDims(*this);
121+
}
122+
template <typename Arg, typename = typename std::enable_if_t<
123+
std::is_constructible<ShapeStorageT, Arg>::value>>
124+
ShapedTypeComponents(Arg &&arg, Type elementType = nullptr,
125+
Attribute attr = nullptr)
126+
: dims(std::forward<Arg>(arg)), elementType(elementType), attr(attr),
127+
ranked(true) {}
128+
ShapedTypeComponents(ArrayRef<int64_t> vec, Type elementType = nullptr,
129+
Attribute attr = nullptr)
130+
: dims(vec.begin(), vec.end()), elementType(elementType), attr(attr),
131+
ranked(true) {}
132+
133+
/// Return the dimensions of the shape.
134+
/// Requires: shape is ranked.
135+
ArrayRef<int64_t> getDims() const {
136+
assert(ranked && "requires ranked shape");
137+
return dims;
138+
}
139+
140+
/// Return whether the shape has a rank.
141+
bool hasRank() const { return ranked; };
142+
143+
/// Return the element type component.
144+
Type getElementType() const { return elementType; };
145+
146+
/// Return the raw attribute component.
147+
Attribute getAttribute() const { return attr; };
148+
149+
private:
150+
friend class ShapeAdaptor;
151+
152+
ShapeStorageT dims;
153+
Type elementType;
154+
Attribute attr;
155+
bool ranked{false};
156+
};
157+
151158
/// Range of values and shapes (corresponding effectively to Shapes dialect's
152159
/// ValueShape type concept).
153160
// Currently this exposes the Value (of operands) and Type of the Value. This is

0 commit comments

Comments
 (0)