@@ -58,7 +58,7 @@ struct Value final {
58
58
bool as_bool;
59
59
} u;
60
60
61
- api::vTensor as_tensor;
61
+ std::unique_ptr< api::vTensor> as_tensor;
62
62
api::StagingBuffer as_staging;
63
63
TensorRef as_tensorref;
64
64
@@ -106,15 +106,18 @@ struct Value final {
106
106
rhs.payload.member_name.~dtor_name (); \
107
107
break ;
108
108
109
+ #define CASE_MOVE_UNIQUE_PTR_TYPE (type_tag, member_name ) \
110
+ case type_tag: \
111
+ payload.member_name = std::move(rhs.payload.member_name); \
112
+ break ;
113
+
109
114
Value (Value&& rhs) noexcept : tag(rhs.tag) {
110
115
switch (tag) {
111
116
// Scalar types
112
117
CASE_MOVE_TRIVIALLY_COPYABLE_TYPE (TypeTag::INT, as_int);
113
118
CASE_MOVE_TRIVIALLY_COPYABLE_TYPE (TypeTag::DOUBLE, as_double);
114
119
CASE_MOVE_TRIVIALLY_COPYABLE_TYPE (TypeTag::BOOL, as_bool);
115
- // Tensor and tensor adjacent types
116
- CASE_MOVE_MOVEABLE_TYPE (
117
- TypeTag::TENSOR, api::vTensor, as_tensor, vTensor);
120
+ // Tensor adjacent types
118
121
CASE_MOVE_MOVEABLE_TYPE (
119
122
TypeTag::STAGING, api::StagingBuffer, as_staging, StagingBuffer);
120
123
CASE_MOVE_MOVEABLE_TYPE (
@@ -132,6 +135,8 @@ struct Value final {
132
135
CASE_MOVE_MOVEABLE_TYPE (
133
136
TypeTag::STRING, std::string, as_string, basic_string);
134
137
CASE_MOVE_MOVEABLE_TYPE (TypeTag::SYMINT, SymInt, as_symint, SymInt);
138
+ // Tensor type
139
+ CASE_MOVE_UNIQUE_PTR_TYPE (TypeTag::TENSOR, as_tensor);
135
140
136
141
case TypeTag::NONE:
137
142
clearToNone ();
@@ -142,6 +147,7 @@ struct Value final {
142
147
143
148
#undef CASE_MOVE_TRIVIALLY_COPYABLE_TYPE
144
149
#undef CASE_MOVE_MOVEABLE_TYPE
150
+ #undef CASE_MOVE_UNIQUE_PTR_TYPE
145
151
146
152
//
147
153
// Accessors
@@ -157,9 +163,6 @@ struct Value final {
157
163
158
164
~Value () {
159
165
switch (tag) {
160
- case TypeTag::TENSOR:
161
- payload.as_tensor .~vTensor ();
162
- break ;
163
166
case TypeTag::STAGING:
164
167
payload.as_staging .~StagingBuffer ();
165
168
break ;
@@ -184,6 +187,9 @@ struct Value final {
184
187
case TypeTag::SYMINT:
185
188
payload.as_symint .~SymInt ();
186
189
break ;
190
+ case TypeTag::TENSOR:
191
+ payload.as_tensor .reset ();
192
+ break ;
187
193
// Manually list out the types so that if a type here is added later and
188
194
// not handled the compiler can catch it.
189
195
case TypeTag::NONE:
@@ -252,12 +258,6 @@ struct Value final {
252
258
return payload.member_name ; \
253
259
}
254
260
255
- SUPPORT_TRIVIALLY_MOVEABLE_TYPE (
256
- api::vTensor,
257
- Tensor,
258
- TypeTag::TENSOR,
259
- as_tensor);
260
-
261
261
SUPPORT_TRIVIALLY_MOVEABLE_TYPE (
262
262
api::StagingBuffer,
263
263
Staging,
@@ -302,9 +302,36 @@ struct Value final {
302
302
303
303
SUPPORT_TRIVIALLY_MOVEABLE_TYPE (SymInt, SymInt, TypeTag::SYMINT, as_symint);
304
304
305
- #undef SUPPORT_TRIVIALLY_COPYABLE_TYPE
306
305
#undef SUPPORT_TRIVIALLY_MOVEABLE_TYPE
307
306
307
+ #define SUPPORT_UNIQUE_PTR_TYPE (type, type_name, type_tag, member_name ) \
308
+ explicit Value (type t) : tag(type_tag) { \
309
+ payload.member_name = std::make_unique<type>(std::move (t)); \
310
+ } \
311
+ inline bool is##type_name() const { \
312
+ return tag == type_tag; \
313
+ } \
314
+ inline type& to##type_name() const { \
315
+ VK_CHECK_COND ( \
316
+ is##type_name (), \
317
+ " Expected value to have type " #type_name " , got " , \
318
+ tag, \
319
+ " instead." ); \
320
+ return *payload.member_name ; \
321
+ } \
322
+ inline const type& toConst##type_name() const { \
323
+ VK_CHECK_COND ( \
324
+ is##type_name (), \
325
+ " Expected value to have type " #type_name " , got " , \
326
+ tag, \
327
+ " instead." ); \
328
+ return *payload.member_name ; \
329
+ }
330
+
331
+ SUPPORT_UNIQUE_PTR_TYPE (api::vTensor, Tensor, TypeTag::TENSOR, as_tensor);
332
+
333
+ #undef SUPPORT_UNIQUE_PTR_TYPE
334
+
308
335
private:
309
336
Payload payload;
310
337
TypeTag tag;
0 commit comments