Skip to content

Commit ac71c30

Browse files
committed
[ET-VK] Introduce vTensorPtr to prevent reference invalidation and remove get_val() API
## Context Currently, when writing operators, developers will save a reference to a `vTensor` retrieved from a `ComputeGraph`'s list of `values_`, like so: ``` vTensor& vten = graph.get_val(vref).toTensor(); ``` However, this is dangerous: if any values are added after the reference has been stored, `values_` — which is a `std::vector` — may have been resized and therefore had its contents moved, meaning the reference is now invalid. To protect against this, this changeset introduces the `vTensorPtr` class, which is a wrapper around a `vTensor*`. When constructed, it increments a counter in the `ComputeGraph` instance, and when destroyed it decrements the counter. `ComputeGraph` cannot add any values while the counter is nonzero. Since `Value` can be converted to other non-trivial types, this changeset also removes the `get_val` function entirely to guard against unsafe behaviour. Differential Revision: [D55984187](https://our.internmc.facebook.com/intern/diff/D55984187/) ghstack-source-id: 222072108 Pull Request resolved: #2978
1 parent a983ebc commit ac71c30

18 files changed

+374
-245
lines changed

backends/vulkan/runtime/VulkanBackend.cpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -334,18 +334,18 @@ bool maybe_resize_input(
334334
const size_t input_i,
335335
exec_aten::Tensor& et_tensor) {
336336
ValueRef in_tensor_ref = graph->inputs()[input_i].value;
337-
vTensor& in_tensor = graph->get_val(in_tensor_ref).toTensor();
337+
vTensorPtr in_tensor = graph->get_tensor(in_tensor_ref);
338338

339339
ET_CHECK_MSG(
340-
et_tensor.dim() == in_tensor.sizes().size(),
340+
et_tensor.dim() == in_tensor->sizes().size(),
341341
"Cannot resize input tensor: old ndim %zu does not match new ndim %zu",
342-
static_cast<size_t>(in_tensor.sizes().size()),
342+
static_cast<size_t>(in_tensor->sizes().size()),
343343
static_cast<size_t>(et_tensor.dim()));
344344

345345
bool should_resize = false;
346346
std::vector<int64_t> new_sizes(et_tensor.dim());
347347
for (size_t i = 0; i < et_tensor.dim(); i++) {
348-
if (in_tensor.sizes()[i] != et_tensor.sizes()[i]) {
348+
if (in_tensor->sizes()[i] != et_tensor.sizes()[i]) {
349349
should_resize = true;
350350
}
351351
new_sizes.at(i) = et_tensor.sizes()[i];
@@ -356,9 +356,9 @@ bool maybe_resize_input(
356356
}
357357

358358
ET_CHECK_MSG(
359-
in_tensor.numel() == et_tensor.numel(),
359+
in_tensor->numel() == et_tensor.numel(),
360360
"Vulkan tensor numel %zu does not match ET tensor numel %zu",
361-
static_cast<size_t>(in_tensor.numel()),
361+
static_cast<size_t>(in_tensor->numel()),
362362
static_cast<size_t>(et_tensor.numel()));
363363

364364
return should_resize;
@@ -369,12 +369,12 @@ void maybe_resize_output(
369369
const size_t output_i,
370370
exec_aten::Tensor& et_tensor) {
371371
ValueRef out_tensor_ref = graph->outputs()[output_i].value;
372-
vTensor& out_tensor = graph->get_val(out_tensor_ref).toTensor();
372+
vTensorPtr out_tensor = graph->get_tensor(out_tensor_ref);
373373

374374
exec_aten::SizesType new_output_size[kTensorDimensionLimit];
375-
size_t ndim = out_tensor.sizes().size();
375+
size_t ndim = out_tensor->sizes().size();
376376
for (int i = 0; i < ndim; ++i) {
377-
new_output_size[i] = out_tensor.sizes()[i];
377+
new_output_size[i] = out_tensor->sizes()[i];
378378
}
379379

380380
exec_aten::ArrayRef<exec_aten::SizesType> output_size{new_output_size, ndim};

backends/vulkan/runtime/graph/ComputeGraph.cpp

Lines changed: 60 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,36 @@
1717

1818
namespace vkcompute {
1919

20+
//
21+
// vTensorPtr
22+
//
23+
24+
vTensorPtr::vTensorPtr(ComputeGraph* graph, const ValueRef idx)
25+
: graph_(graph), tensor_(&(graph_->values_.at(idx).toTensor())) {
26+
graph_->values_in_use_++;
27+
}
28+
29+
vTensorPtr::~vTensorPtr() {
30+
graph_->values_in_use_--;
31+
}
32+
33+
//
34+
// StagingPtr
35+
//
36+
37+
StagingPtr::StagingPtr(ComputeGraph* graph, const ValueRef idx)
38+
: graph_(graph), storage_(&(graph_->values_.at(idx).toStaging())) {
39+
graph_->values_in_use_++;
40+
}
41+
42+
StagingPtr::~StagingPtr() {
43+
graph_->values_in_use_--;
44+
}
45+
46+
//
47+
// ComputeGraph
48+
//
49+
2050
ComputeGraph::ComputeGraph(GraphConfig config)
2151
: config_{config},
2252
prepack_descriptor_counts_{},
@@ -105,6 +135,15 @@ api::GPUMemoryLayout ComputeGraph::suggested_memory_layout(
105135
return api::kChannelsPacked;
106136
}
107137

138+
void ComputeGraph::check_no_active_value_ptrs() {
139+
VK_CHECK_COND(
140+
values_in_use_ == 0,
141+
"Make sure that there are no pointers stored from the return values of "
142+
"`ComputeGraph::get_*()` functions in scope before adding Values to the "
143+
"graph. Modifying the graph's values may cause existing pointers to be "
144+
"invalidated.");
145+
}
146+
108147
ValueRef ComputeGraph::add_tensor(
109148
const std::vector<int64_t>& sizes,
110149
const api::ScalarType dtype,
@@ -114,6 +153,7 @@ ValueRef ComputeGraph::add_tensor(
114153
bool allocate_memory = shared_object_idx < 0;
115154

116155
ValueRef idx(static_cast<int>(values_.size()));
156+
check_no_active_value_ptrs();
117157
values_.emplace_back(vTensor(
118158
context(), sizes, dtype, storage_type, memory_layout, allocate_memory));
119159

@@ -136,14 +176,14 @@ ValueRef ComputeGraph::add_tensor_like(
136176
const ValueRef vref,
137177
const api::StorageType storage_type,
138178
const api::GPUMemoryLayout memory_layout) {
139-
TensorRef& tref = get_val(vref).toTensorRef();
179+
TensorRef tref = get_tref(vref);
140180
return add_tensor(tref.sizes, tref.dtype, storage_type, memory_layout);
141181
}
142182

143183
ValueRef ComputeGraph::add_tensor_like(
144184
const ValueRef vref,
145185
const api::GPUMemoryLayout memory_layout) {
146-
TensorRef& tref = get_val(vref).toTensorRef();
186+
TensorRef tref = get_tref(vref);
147187
return add_tensor(tref.sizes, tref.dtype, memory_layout);
148188
}
149189

@@ -160,6 +200,7 @@ ValueRef ComputeGraph::add_tensorref(
160200
const api::ScalarType dtype,
161201
const void* const data) {
162202
ValueRef idx(static_cast<int>(values_.size()));
203+
check_no_active_value_ptrs();
163204
values_.emplace_back(TensorRef(sizes, dtype, data));
164205
return idx;
165206
}
@@ -168,24 +209,28 @@ ValueRef ComputeGraph::add_staging(
168209
const api::ScalarType dtype,
169210
const size_t numel) {
170211
ValueRef idx(static_cast<int>(values_.size()));
212+
check_no_active_value_ptrs();
171213
values_.emplace_back(api::StorageBuffer(context(), dtype, numel));
172214
return idx;
173215
}
174216

175217
ValueRef ComputeGraph::add_none() {
176218
ValueRef idx(static_cast<int>(values_.size()));
219+
check_no_active_value_ptrs();
177220
values_.emplace_back();
178221
return idx;
179222
}
180223

181224
ValueRef ComputeGraph::add_value_list(std::vector<ValueRef>&& value) {
182225
ValueRef idx(static_cast<int>(values_.size()));
226+
check_no_active_value_ptrs();
183227
values_.emplace_back(std::move(value));
184228
return idx;
185229
}
186230

187231
ValueRef ComputeGraph::add_string(std::string&& str) {
188232
ValueRef idx(static_cast<int>(values_.size()));
233+
check_no_active_value_ptrs();
189234
values_.emplace_back(std::move(str));
190235
return idx;
191236
}
@@ -194,8 +239,9 @@ ValueRef ComputeGraph::set_input_tensor(
194239
const ValueRef idx,
195240
const bool use_staging) {
196241
if (use_staging) {
197-
vTensor& tensor = get_val(idx).toTensor();
198-
ValueRef staging_idx = add_staging(tensor.dtype(), tensor.gpu_numel());
242+
api::ScalarType dtype = get_tensor(idx)->dtype();
243+
size_t gpu_numel = get_tensor(idx)->gpu_numel();
244+
ValueRef staging_idx = add_staging(dtype, gpu_numel);
199245
add_staging_to_tensor_node(*this, staging_idx, idx);
200246
inputs_.push_back({idx, staging_idx});
201247
return staging_idx;
@@ -208,8 +254,9 @@ ValueRef ComputeGraph::set_output_tensor(
208254
const ValueRef idx,
209255
const bool use_staging) {
210256
if (use_staging) {
211-
vTensor& tensor = get_val(idx).toTensor();
212-
ValueRef staging_idx = add_staging(tensor.dtype(), tensor.gpu_numel());
257+
api::ScalarType dtype = get_tensor(idx)->dtype();
258+
size_t gpu_numel = get_tensor(idx)->gpu_numel();
259+
ValueRef staging_idx = add_staging(dtype, gpu_numel);
213260
add_tensor_to_staging_node(*this, idx, staging_idx);
214261
outputs_.push_back({idx, staging_idx});
215262
return staging_idx;
@@ -229,20 +276,18 @@ void ComputeGraph::copy_into_staging(
229276
const ValueRef idx,
230277
const void* data,
231278
const size_t numel) {
232-
Value& in_val = get_val(idx);
233-
api::StorageBuffer& staging = in_val.toStaging();
234-
size_t nbytes = numel * api::element_size(staging.dtype());
235-
copy_ptr_to_staging(data, staging, nbytes);
279+
StagingPtr staging = get_staging(idx);
280+
size_t nbytes = numel * api::element_size(staging->dtype());
281+
copy_ptr_to_staging(data, *staging, nbytes);
236282
}
237283

238284
void ComputeGraph::copy_from_staging(
239285
const ValueRef idx,
240286
void* data,
241287
const size_t numel) {
242-
Value& out_val = get_val(idx);
243-
api::StorageBuffer& staging = out_val.toStaging();
244-
size_t nbytes = numel * api::element_size(staging.dtype());
245-
copy_staging_to_ptr(staging, data, nbytes);
288+
StagingPtr staging = get_staging(idx);
289+
size_t nbytes = numel * api::element_size(staging->dtype());
290+
copy_staging_to_ptr(*staging, data, nbytes);
246291
}
247292

248293
void ComputeGraph::prepare() {
@@ -308,7 +353,7 @@ void ComputeGraph::resize_input(
308353
const int64_t idx,
309354
const std::vector<int64_t>& new_sizes) {
310355
IOValueRef io_val = inputs_.at(idx);
311-
get_val(io_val.value).toTensor().virtual_resize(new_sizes);
356+
get_tensor(io_val.value)->virtual_resize(new_sizes);
312357
}
313358

314359
void ComputeGraph::propagate_resize() {

0 commit comments

Comments
 (0)