 namespace vkcompute {
 
+//
+// VTensorPtr
+//
+
+vTensorPtr::vTensorPtr(ComputeGraph* graph, const ValueRef idx)
+    : graph_(graph), tensor_(&(graph_->values_.at(idx).toTensor())) {
+  graph_->values_in_use_++;
+}
+
+vTensorPtr::~vTensorPtr() {
+  graph_->values_in_use_--;
+}
+
+//
+// StagingPtr
+//
+
+StagingPtr::StagingPtr(ComputeGraph* graph, const ValueRef idx)
+    : graph_(graph), storage_(&(graph_->values_.at(idx).toStaging())) {
+  graph_->values_in_use_++;
+}
+
+StagingPtr::~StagingPtr() {
+  graph_->values_in_use_--;
+}
+
+//
+// ComputeGraph
+//
+
 ComputeGraph::ComputeGraph(GraphConfig config)
     : config_{config},
       prepack_descriptor_counts_{},
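The new vTensorPtr and StagingPtr types act as RAII handles: constructing one bumps the graph's values_in_use_ counter and the destructor decrements it, which lets the graph later assert that no pointers into its value list are still alive when that list is about to grow. Below is a minimal standalone sketch of the pattern; the names Graph and Handle are illustrative stand-ins, not the ExecuTorch classes.

// Minimal standalone sketch of the RAII "in use" counter pattern shown above.
// Graph and Handle are illustrative names, not the ExecuTorch types.
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

struct Graph;

struct Handle {
  explicit Handle(Graph* graph);
  ~Handle();
  Graph* graph_;
};

struct Graph {
  std::vector<int64_t> values_;
  std::size_t values_in_use_ = 0;

  // Analogous to check_no_active_value_ptrs(): adding values while a Handle
  // is alive could reallocate values_ and invalidate pointers into it.
  void add_value(int64_t v) {
    assert(values_in_use_ == 0 && "no live handles while mutating values_");
    values_.push_back(v);
  }
};

Handle::Handle(Graph* graph) : graph_(graph) {
  graph_->values_in_use_++;
}

Handle::~Handle() {
  graph_->values_in_use_--;
}

int main() {
  Graph g;
  g.add_value(1); // fine: no handles alive
  {
    Handle h(&g); // counter goes to 1
  } // h destroyed; counter back to 0
  g.add_value(2); // fine again
}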
@@ -105,6 +135,15 @@ api::GPUMemoryLayout ComputeGraph::suggested_memory_layout(
   return api::kChannelsPacked;
 }
 
+void ComputeGraph::check_no_active_value_ptrs() {
+  VK_CHECK_COND(
+      values_in_use_ == 0,
+      "Make sure that there are no pointers stored from the return values of "
+      "`ComputeGraph::get_*()` functions in scope before adding Values to the "
+      "graph. Modifying the graph's values may cause existing pointers to be "
+      "invalidated.");
+}
+
 ValueRef ComputeGraph::add_tensor(
     const std::vector<int64_t>& sizes,
     const api::ScalarType dtype,
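check_no_active_value_ptrs() is invoked at the top of every add_* function below. The reason, assuming values_ is a std::vector-like container (as its emplace_back/at usage suggests), is that appending a Value can reallocate the container and invalidate any pointer previously handed out by a get_*() accessor. The following standalone sketch shows the underlying hazard with a plain std::vector.

// Standalone sketch of the hazard check_no_active_value_ptrs() guards against:
// pointers into a std::vector may dangle after the vector grows.
#include <cstdio>
#include <vector>

int main() {
  std::vector<int> values;
  values.reserve(1); // keep capacity small so the second push_back is likely to reallocate
  values.push_back(10);

  int* p = &values[0];  // analogous to holding a vTensorPtr / StagingPtr
  values.push_back(20); // may reallocate; p may now dangle (UB to use)

  // Correct pattern: re-acquire the pointer after mutating the container.
  p = &values[0];
  std::printf("%d\n", *p);
}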
@@ -114,6 +153,7 @@ ValueRef ComputeGraph::add_tensor(
   bool allocate_memory = shared_object_idx < 0;
 
   ValueRef idx(static_cast<int>(values_.size()));
+  check_no_active_value_ptrs();
   values_.emplace_back(vTensor(
       context(), sizes, dtype, storage_type, memory_layout, allocate_memory));
 
@@ -136,14 +176,14 @@ ValueRef ComputeGraph::add_tensor_like(
     const ValueRef vref,
     const api::StorageType storage_type,
     const api::GPUMemoryLayout memory_layout) {
-  TensorRef& tref = get_val(vref).toTensorRef();
+  TensorRef tref = get_tref(vref);
   return add_tensor(tref.sizes, tref.dtype, storage_type, memory_layout);
 }
 
 ValueRef ComputeGraph::add_tensor_like(
     const ValueRef vref,
     const api::GPUMemoryLayout memory_layout) {
-  TensorRef& tref = get_val(vref).toTensorRef();
+  TensorRef tref = get_tref(vref);
   return add_tensor(tref.sizes, tref.dtype, memory_layout);
 }
 
@@ -160,6 +200,7 @@ ValueRef ComputeGraph::add_tensorref(
     const api::ScalarType dtype,
     const void* const data) {
   ValueRef idx(static_cast<int>(values_.size()));
+  check_no_active_value_ptrs();
   values_.emplace_back(TensorRef(sizes, dtype, data));
   return idx;
 }
@@ -168,24 +209,28 @@ ValueRef ComputeGraph::add_staging(
     const api::ScalarType dtype,
     const size_t numel) {
   ValueRef idx(static_cast<int>(values_.size()));
+  check_no_active_value_ptrs();
   values_.emplace_back(api::StorageBuffer(context(), dtype, numel));
   return idx;
 }
 
 ValueRef ComputeGraph::add_none() {
   ValueRef idx(static_cast<int>(values_.size()));
+  check_no_active_value_ptrs();
   values_.emplace_back();
   return idx;
 }
 
 ValueRef ComputeGraph::add_value_list(std::vector<ValueRef>&& value) {
   ValueRef idx(static_cast<int>(values_.size()));
+  check_no_active_value_ptrs();
   values_.emplace_back(std::move(value));
   return idx;
 }
 
 ValueRef ComputeGraph::add_string(std::string&& str) {
   ValueRef idx(static_cast<int>(values_.size()));
+  check_no_active_value_ptrs();
   values_.emplace_back(std::move(str));
   return idx;
 }
@@ -194,8 +239,9 @@ ValueRef ComputeGraph::set_input_tensor(
     const ValueRef idx,
     const bool use_staging) {
   if (use_staging) {
-    vTensor& tensor = get_val(idx).toTensor();
-    ValueRef staging_idx = add_staging(tensor.dtype(), tensor.gpu_numel());
+    api::ScalarType dtype = get_tensor(idx)->dtype();
+    size_t gpu_numel = get_tensor(idx)->gpu_numel();
+    ValueRef staging_idx = add_staging(dtype, gpu_numel);
     add_staging_to_tensor_node(*this, staging_idx, idx);
     inputs_.push_back({idx, staging_idx});
     return staging_idx;
@@ -208,8 +254,9 @@ ValueRef ComputeGraph::set_output_tensor(
     const ValueRef idx,
     const bool use_staging) {
   if (use_staging) {
-    vTensor& tensor = get_val(idx).toTensor();
-    ValueRef staging_idx = add_staging(tensor.dtype(), tensor.gpu_numel());
+    api::ScalarType dtype = get_tensor(idx)->dtype();
+    size_t gpu_numel = get_tensor(idx)->gpu_numel();
+    ValueRef staging_idx = add_staging(dtype, gpu_numel);
     add_tensor_to_staging_node(*this, idx, staging_idx);
     outputs_.push_back({idx, staging_idx});
     return staging_idx;
@@ -229,20 +276,18 @@ void ComputeGraph::copy_into_staging(
     const ValueRef idx,
     const void* data,
     const size_t numel) {
-  Value& in_val = get_val(idx);
-  api::StorageBuffer& staging = in_val.toStaging();
-  size_t nbytes = numel * api::element_size(staging.dtype());
-  copy_ptr_to_staging(data, staging, nbytes);
+  StagingPtr staging = get_staging(idx);
+  size_t nbytes = numel * api::element_size(staging->dtype());
+  copy_ptr_to_staging(data, *staging, nbytes);
 }
 
 void ComputeGraph::copy_from_staging(
     const ValueRef idx,
     void* data,
     const size_t numel) {
-  Value& out_val = get_val(idx);
-  api::StorageBuffer& staging = out_val.toStaging();
-  size_t nbytes = numel * api::element_size(staging.dtype());
-  copy_staging_to_ptr(staging, data, nbytes);
+  StagingPtr staging = get_staging(idx);
+  size_t nbytes = numel * api::element_size(staging->dtype());
+  copy_staging_to_ptr(*staging, data, nbytes);
 }
 
 void ComputeGraph::prepare() {
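In the two copy functions above, the StagingPtr returned by get_staging() is used like a raw pointer: staging->dtype() and *staging. This implies StagingPtr provides operator-> and operator* (those operators are not shown in this diff, so treat them as an assumption). A minimal sketch of such a pointer-like wrapper, with illustrative names, is given below.

// Minimal sketch of a pointer-like wrapper exposing the operators the hunk
// above relies on. Buffer and BufferPtr are illustrative; the real StagingPtr
// also updates the graph's values_in_use_ counter as shown earlier.
#include <cstdio>

struct Buffer {
  int dtype() const { return 42; }
};

class BufferPtr {
 public:
  explicit BufferPtr(Buffer* buf) : buf_(buf) {}
  Buffer* operator->() const { return buf_; } // enables staging->dtype()
  Buffer& operator*() const { return *buf_; } // enables passing *staging by reference
 private:
  Buffer* buf_;
};

int main() {
  Buffer b;
  BufferPtr p(&b);
  std::printf("%d\n", p->dtype());
  Buffer& ref = *p;
  (void)ref;
}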
@@ -308,7 +353,7 @@ void ComputeGraph::resize_input(
     const int64_t idx,
     const std::vector<int64_t>& new_sizes) {
   IOValueRef io_val = inputs_.at(idx);
-  get_val(io_val.value).toTensor().virtual_resize(new_sizes);
+  get_tensor(io_val.value)->virtual_resize(new_sizes);
 }
 
 void ComputeGraph::propagate_resize() {