namespace vkcompute {

+ //
+ // VTensorPtr
+ //
+
+ #define VALUE_PTR_CLASS_IMPL(classname, ctype, type_name) \
+   classname::classname(ComputeGraph* const graph, const ValueRef idx) \
+       : graph_(graph), ptr_(&(graph_->values_.at(idx).to##type_name())) { \
+     graph_->values_in_use_++; \
+   } \
+   ctype* classname::operator->() const { \
+     return ptr_; \
+   } \
+   ctype& classname::operator*() const { \
+     return *ptr_; \
+   } \
+   classname::~classname() { \
+     graph_->values_in_use_--; \
+   }
+
+ VALUE_PTR_CLASS_IMPL(vTensorPtr, vTensor, Tensor)
+ VALUE_PTR_CLASS_IMPL(TensorRefPtr, TensorRef, TensorRef)
+ VALUE_PTR_CLASS_IMPL(StagingPtr, api::StorageBuffer, Staging)
+ VALUE_PTR_CLASS_IMPL(IntListPtr, std::vector<int64_t>, IntList)
+ VALUE_PTR_CLASS_IMPL(DoubleListPtr, std::vector<double>, DoubleList)
+ VALUE_PTR_CLASS_IMPL(BoolListPtr, std::vector<bool>, BoolList)
+ VALUE_PTR_CLASS_IMPL(ValueListPtr, std::vector<ValueRef>, ValueList)
+
+ #undef VALUE_PTR_CLASS_IMPL
+
+ //
+ // ComputeGraph
+ //
+
ComputeGraph::ComputeGraph(GraphConfig config)
    : config_{config},
      prepack_descriptor_counts_{},
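For readers tracing the macro: `VALUE_PTR_CLASS_IMPL(vTensorPtr, vTensor, Tensor)` expands to roughly the following. This is reconstructed mechanically from the macro body above, not a separate implementation; the comments are editorial.

    vTensorPtr::vTensorPtr(ComputeGraph* const graph, const ValueRef idx)
        : graph_(graph), ptr_(&(graph_->values_.at(idx).toTensor())) {
      // Mark a raw pointer into values_ as live for the wrapper's lifetime.
      graph_->values_in_use_++;
    }
    vTensor* vTensorPtr::operator->() const {
      return ptr_;
    }
    vTensor& vTensorPtr::operator*() const {
      return *ptr_;
    }
    vTensorPtr::~vTensorPtr() {
      // The pointer dies with the wrapper, so release the in-use count.
      graph_->values_in_use_--;
    }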
@@ -105,6 +138,35 @@ api::GPUMemoryLayout ComputeGraph::suggested_memory_layout(
  return api::kChannelsPacked;
}

+ void ComputeGraph::check_no_active_value_ptrs() {
+   VK_CHECK_COND(
+       values_in_use_ == 0,
+       "Make sure that there are no pointers stored from the return values of "
+       "`ComputeGraph::get_*()` functions in scope before adding Values to the "
+       "graph. Modifying the graph's values may cause existing pointers to be "
+       "invalidated.");
+ }
+
+ std::vector<int64_t> ComputeGraph::get_sizes_of(ValueRef idx) {
+   Value& val = values_.at(idx);
+   if (val.isTensor()) {
+     return val.toTensor().sizes();
+   } else if (val.isTensorRef()) {
+     return val.toTensorRef().sizes;
+   }
+   VK_THROW("Could not get sizes of value with type ", val.type());
+ }
+
+ api::ScalarType ComputeGraph::get_dtype_of(ValueRef idx) {
+   Value& val = values_.at(idx);
+   if (val.isTensor()) {
+     return val.toTensor().dtype();
+   } else if (val.isTensorRef()) {
+     return val.toTensorRef().dtype;
+   }
+   VK_THROW("Could not get dtype of value with type ", val.type());
+ }
+
ValueRef ComputeGraph::add_tensor(
    const std::vector<int64_t>& sizes,
    const api::ScalarType dtype,
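To make the guard concrete, here is a sketch of the misuse it catches. Names are illustrative (`graph`, `a`), and the two-argument `add_tensor()` call assumes defaulted trailing parameters, which this diff does not show: holding a *Ptr wrapper keeps `values_in_use_` nonzero, so any `add_*()` call made while it is alive trips `VK_CHECK_COND`.

    {
      vTensorPtr t = graph.get_tensor(a); // values_in_use_ is now 1
      // add_tensor() may grow values_ and reallocate its storage, which would
      // leave t's pointer dangling; check_no_active_value_ptrs() aborts here
      // instead of letting that happen.
      graph.add_tensor(t->sizes(), t->dtype());
    } // ~vTensorPtr() runs; values_in_use_ is back to 0

    // Safe pattern: copy the metadata out, keeping no pointer alive.
    std::vector<int64_t> sizes = graph.get_sizes_of(a);
    api::ScalarType dtype = graph.get_dtype_of(a);
    ValueRef b = graph.add_tensor(sizes, dtype);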
@@ -114,6 +176,7 @@ ValueRef ComputeGraph::add_tensor(
  bool allocate_memory = shared_object_idx < 0;

  ValueRef idx(static_cast<int>(values_.size()));
+   check_no_active_value_ptrs();
  values_.emplace_back(vTensor(
      context(), sizes, dtype, storage_type, memory_layout, allocate_memory));

@@ -133,18 +196,17 @@ ValueRef ComputeGraph::add_tensor(
}

ValueRef ComputeGraph::add_tensor_like(
-     const ValueRef vref,
+     const ValueRef idx,
    const api::StorageType storage_type,
    const api::GPUMemoryLayout memory_layout) {
-   TensorRef& tref = get_val(vref).toTensorRef();
-   return add_tensor(tref.sizes, tref.dtype, storage_type, memory_layout);
+   return add_tensor(
+       get_sizes_of(idx), get_dtype_of(idx), storage_type, memory_layout);
}

ValueRef ComputeGraph::add_tensor_like(
-     const ValueRef vref,
+     const ValueRef idx,
    const api::GPUMemoryLayout memory_layout) {
-   TensorRef& tref = get_val(vref).toTensorRef();
-   return add_tensor(tref.sizes, tref.dtype, memory_layout);
+   return add_tensor(get_sizes_of(idx), get_dtype_of(idx), memory_layout);
}

ValueRef ComputeGraph::add_tensor(
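A side effect of routing through `get_sizes_of()`/`get_dtype_of()`: `add_tensor_like()` now accepts a reference to either a Tensor or a TensorRef value, where the old body only handled TensorRef. A usage sketch with illustrative names (`graph`, `src`):

    // Works whether src points at a vTensor or a TensorRef.
    ValueRef like = graph.add_tensor_like(src, api::kChannelsPacked);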
@@ -160,6 +222,7 @@ ValueRef ComputeGraph::add_tensorref(
    const api::ScalarType dtype,
    const void* const data) {
  ValueRef idx(static_cast<int>(values_.size()));
+   check_no_active_value_ptrs();
  values_.emplace_back(TensorRef(sizes, dtype, data));
  return idx;
}
@@ -168,24 +231,28 @@ ValueRef ComputeGraph::add_staging(
    const api::ScalarType dtype,
    const size_t numel) {
  ValueRef idx(static_cast<int>(values_.size()));
+   check_no_active_value_ptrs();
  values_.emplace_back(api::StorageBuffer(context(), dtype, numel));
  return idx;
}

ValueRef ComputeGraph::add_none() {
  ValueRef idx(static_cast<int>(values_.size()));
+   check_no_active_value_ptrs();
  values_.emplace_back();
  return idx;
}

ValueRef ComputeGraph::add_value_list(std::vector<ValueRef>&& value) {
  ValueRef idx(static_cast<int>(values_.size()));
+   check_no_active_value_ptrs();
  values_.emplace_back(std::move(value));
  return idx;
}

ValueRef ComputeGraph::add_string(std::string&& str) {
  ValueRef idx(static_cast<int>(values_.size()));
+   check_no_active_value_ptrs();
  values_.emplace_back(std::move(str));
  return idx;
}
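Every `add_*()` above follows the same three-step shape. A sketch of the pattern (the `Thing` type and `add_thing` name are hypothetical):

    ValueRef ComputeGraph::add_thing(Thing&& thing) {
      // The new value's reference is the current end of values_.
      ValueRef idx(static_cast<int>(values_.size()));
      // Check before emplace_back(): growing the vector may reallocate and
      // invalidate any pointer held by a live *Ptr wrapper.
      check_no_active_value_ptrs();
      values_.emplace_back(std::move(thing));
      return idx;
    }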
@@ -194,8 +261,9 @@ ValueRef ComputeGraph::set_input_tensor(
    const ValueRef idx,
    const bool use_staging) {
  if (use_staging) {
-     vTensor& tensor = get_val(idx).toTensor();
-     ValueRef staging_idx = add_staging(tensor.dtype(), tensor.gpu_numel());
+     api::ScalarType dtype = get_tensor(idx)->dtype();
+     size_t gpu_numel = get_tensor(idx)->gpu_numel();
+     ValueRef staging_idx = add_staging(dtype, gpu_numel);
    add_staging_to_tensor_node(*this, staging_idx, idx);
    inputs_.push_back({idx, staging_idx});
    return staging_idx;
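Note the shape of the new code: `dtype` and `gpu_numel` are copied into locals through two short-lived `get_tensor()` temporaries, each destroyed at the end of its own full expression, so `values_in_use_` is zero again by the time `add_staging()` appends to `values_`. Holding the wrapper across the call would trip the check, as in this sketch:

    // Would abort: the vTensorPtr is still alive when add_staging() mutates
    // the graph's values.
    vTensorPtr t = get_tensor(idx);
    ValueRef staging_idx = add_staging(t->dtype(), t->gpu_numel());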
@@ -208,8 +276,9 @@ ValueRef ComputeGraph::set_output_tensor(
    const ValueRef idx,
    const bool use_staging) {
  if (use_staging) {
-     vTensor& tensor = get_val(idx).toTensor();
-     ValueRef staging_idx = add_staging(tensor.dtype(), tensor.gpu_numel());
+     api::ScalarType dtype = get_tensor(idx)->dtype();
+     size_t gpu_numel = get_tensor(idx)->gpu_numel();
+     ValueRef staging_idx = add_staging(dtype, gpu_numel);
    add_tensor_to_staging_node(*this, idx, staging_idx);
    outputs_.push_back({idx, staging_idx});
    return staging_idx;
@@ -229,20 +298,18 @@ void ComputeGraph::copy_into_staging(
    const ValueRef idx,
    const void* data,
    const size_t numel) {
-   Value& in_val = get_val(idx);
-   api::StorageBuffer& staging = in_val.toStaging();
-   size_t nbytes = numel * api::element_size(staging.dtype());
-   copy_ptr_to_staging(data, staging, nbytes);
+   StagingPtr staging = get_staging(idx);
+   size_t nbytes = numel * api::element_size(staging->dtype());
+   copy_ptr_to_staging(data, *staging, nbytes);
}

void ComputeGraph::copy_from_staging(
    const ValueRef idx,
    void* data,
    const size_t numel) {
-   Value& out_val = get_val(idx);
-   api::StorageBuffer& staging = out_val.toStaging();
-   size_t nbytes = numel * api::element_size(staging.dtype());
-   copy_staging_to_ptr(staging, data, nbytes);
+   StagingPtr staging = get_staging(idx);
+   size_t nbytes = numel * api::element_size(staging->dtype());
+   copy_staging_to_ptr(*staging, data, nbytes);
}

void ComputeGraph::prepare() {
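A host-side usage sketch of the staging copy API. The names `graph`, `in_staging`, `out_staging`, and `n` are illustrative, and the `execute()` call is an assumption not shown in this diff:

    std::vector<float> in_data(n), out_data(n);
    // ... fill in_data ...
    graph.copy_into_staging(in_staging, in_data.data(), n);
    graph.execute(); // hypothetical: run the graph
    graph.copy_from_staging(out_staging, out_data.data(), n);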
@@ -308,7 +375,7 @@ void ComputeGraph::resize_input(
    const int64_t idx,
    const std::vector<int64_t>& new_sizes) {
  IOValueRef io_val = inputs_.at(idx);
-   get_val(io_val.value).toTensor().virtual_resize(new_sizes);
+   get_tensor(io_val.value)->virtual_resize(new_sizes);
}

void ComputeGraph::propagate_resize() {
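Finally, a dynamic-shape sketch tying `resize_input()` to `propagate_resize()` (the concrete input index and sizes are illustrative):

    // Resize the first graph input, then re-run size propagation so
    // downstream tensors pick up the new sizes.
    graph.resize_input(0, {1, 3, 256, 256});
    graph.propagate_resize();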