Skip to content

Commit 76d8513

Browse files
SS-JIA authored and facebook-github-bot committed
Introduce vTensorPtr to prevent reference invalidation and remove get_val() API (#2978)
Summary: Pull Request resolved: #2978 ## Context Currently when writing operators developers will save a reference to a `vTensor` retrieved from a `ComputeGraph`'s list of `values_` like so: ``` vTensor& vten = graph.get_val(vref).toTensor(); ``` However, this is dangerous since if any values are added once the reference has been stored, `values_` which is a `std::vector` may have been resized and therefore have its contents moved, meaning the reference is now invalid. To protect against this, this changeset introduces the `vTensorPtr` class which is a wrapper around a `vTensor*`. When constructed, it will increment a counter in the `ComputeGraph` instance, and when destroyed it will decrement the counter. `ComputeGraph` cannot add any values while the counter is not zero. Since `Value` can be converted to other non-trivial types, this changeset also removes the `get_val` function entirely to guard against unsafe behaviour. ghstack-source-id: 222224052 exported-using-ghexport Reviewed By: jorgep31415 Differential Revision: D55984187 fbshipit-source-id: 22c619f651b5b3783c7626263694ca46b9f84723
1 parent 5ef8427 commit 76d8513

20 files changed

+411
-263
lines changed

backends/vulkan/runtime/VulkanBackend.cpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -334,18 +334,18 @@ bool maybe_resize_input(
334334
const size_t input_i,
335335
exec_aten::Tensor& et_tensor) {
336336
ValueRef in_tensor_ref = graph->inputs()[input_i].value;
337-
vTensor& in_tensor = graph->get_val(in_tensor_ref).toTensor();
337+
vTensorPtr in_tensor = graph->get_tensor(in_tensor_ref);
338338

339339
ET_CHECK_MSG(
340-
et_tensor.dim() == in_tensor.sizes().size(),
340+
et_tensor.dim() == in_tensor->sizes().size(),
341341
"Cannot resize input tensor: old ndim %zu does not match new ndim %zu",
342-
static_cast<size_t>(in_tensor.sizes().size()),
342+
static_cast<size_t>(in_tensor->sizes().size()),
343343
static_cast<size_t>(et_tensor.dim()));
344344

345345
bool should_resize = false;
346346
std::vector<int64_t> new_sizes(et_tensor.dim());
347347
for (size_t i = 0; i < et_tensor.dim(); i++) {
348-
if (in_tensor.sizes()[i] != et_tensor.sizes()[i]) {
348+
if (in_tensor->sizes()[i] != et_tensor.sizes()[i]) {
349349
should_resize = true;
350350
}
351351
new_sizes.at(i) = et_tensor.sizes()[i];
@@ -356,9 +356,9 @@ bool maybe_resize_input(
356356
}
357357

358358
ET_CHECK_MSG(
359-
in_tensor.numel() == et_tensor.numel(),
359+
in_tensor->numel() == et_tensor.numel(),
360360
"Vulkan tensor numel %zu does not match ET tensor numel %zu",
361-
static_cast<size_t>(in_tensor.numel()),
361+
static_cast<size_t>(in_tensor->numel()),
362362
static_cast<size_t>(et_tensor.numel()));
363363

364364
return should_resize;
@@ -369,12 +369,12 @@ void maybe_resize_output(
369369
const size_t output_i,
370370
exec_aten::Tensor& et_tensor) {
371371
ValueRef out_tensor_ref = graph->outputs()[output_i].value;
372-
vTensor& out_tensor = graph->get_val(out_tensor_ref).toTensor();
372+
vTensorPtr out_tensor = graph->get_tensor(out_tensor_ref);
373373

374374
exec_aten::SizesType new_output_size[kTensorDimensionLimit];
375-
size_t ndim = out_tensor.sizes().size();
375+
size_t ndim = out_tensor->sizes().size();
376376
for (int i = 0; i < ndim; ++i) {
377-
new_output_size[i] = out_tensor.sizes()[i];
377+
new_output_size[i] = out_tensor->sizes()[i];
378378
}
379379

380380
exec_aten::ArrayRef<exec_aten::SizesType> output_size{new_output_size, ndim};

backends/vulkan/runtime/graph/ComputeGraph.cpp

Lines changed: 86 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,39 @@
1717

1818
namespace vkcompute {
1919

20+
//
21+
// VTensorPtr
22+
//
23+
24+
#define VALUE_PTR_CLASS_IMPL(classname, ctype, type_name) \
25+
classname::classname(ComputeGraph* const graph, const ValueRef idx) \
26+
: graph_(graph), ptr_(&(graph_->values_.at(idx).to##type_name())) { \
27+
graph_->values_in_use_++; \
28+
} \
29+
ctype* classname::operator->() const { \
30+
return ptr_; \
31+
} \
32+
ctype& classname::operator*() const { \
33+
return *ptr_; \
34+
} \
35+
classname::~classname() { \
36+
graph_->values_in_use_--; \
37+
}
38+
39+
VALUE_PTR_CLASS_IMPL(vTensorPtr, vTensor, Tensor)
40+
VALUE_PTR_CLASS_IMPL(TensorRefPtr, TensorRef, TensorRef)
41+
VALUE_PTR_CLASS_IMPL(StagingPtr, api::StorageBuffer, Staging)
42+
VALUE_PTR_CLASS_IMPL(IntListPtr, std::vector<int64_t>, IntList)
43+
VALUE_PTR_CLASS_IMPL(DoubleListPtr, std::vector<double>, DoubleList)
44+
VALUE_PTR_CLASS_IMPL(BoolListPtr, std::vector<bool>, BoolList)
45+
VALUE_PTR_CLASS_IMPL(ValueListPtr, std::vector<ValueRef>, ValueList)
46+
47+
#undef VALUE_PTR_CLASS_IMPL
48+
49+
//
50+
// ComputeGraph
51+
//
52+
2053
ComputeGraph::ComputeGraph(GraphConfig config)
2154
: config_{config},
2255
prepack_descriptor_counts_{},
@@ -105,6 +138,35 @@ api::GPUMemoryLayout ComputeGraph::suggested_memory_layout(
105138
return api::kChannelsPacked;
106139
}
107140

141+
void ComputeGraph::check_no_active_value_ptrs() {
142+
VK_CHECK_COND(
143+
values_in_use_ == 0,
144+
"Make sure that there are no pointers stored from the return values of "
145+
"`ComputeGraph::get_*()` functions in scope before adding Values to the "
146+
"graph. Modifying the graph's values may cause existing pointers to be "
147+
"invalidated.");
148+
}
149+
150+
std::vector<int64_t> ComputeGraph::get_sizes_of(ValueRef idx) {
151+
Value& val = values_.at(idx);
152+
if (val.isTensor()) {
153+
return val.toTensor().sizes();
154+
} else if (val.isTensorRef()) {
155+
return val.toTensorRef().sizes;
156+
}
157+
VK_THROW("Could not get sizes of value with type ", val.type());
158+
}
159+
160+
api::ScalarType ComputeGraph::get_dtype_of(ValueRef idx) {
161+
Value& val = values_.at(idx);
162+
if (val.isTensor()) {
163+
return val.toTensor().dtype();
164+
} else if (val.isTensorRef()) {
165+
return val.toTensorRef().dtype;
166+
}
167+
VK_THROW("Could not get dtype of value with type ", val.type());
168+
}
169+
108170
ValueRef ComputeGraph::add_tensor(
109171
const std::vector<int64_t>& sizes,
110172
const api::ScalarType dtype,
@@ -114,6 +176,7 @@ ValueRef ComputeGraph::add_tensor(
114176
bool allocate_memory = shared_object_idx < 0;
115177

116178
ValueRef idx(static_cast<int>(values_.size()));
179+
check_no_active_value_ptrs();
117180
values_.emplace_back(vTensor(
118181
context(), sizes, dtype, storage_type, memory_layout, allocate_memory));
119182

@@ -133,18 +196,17 @@ ValueRef ComputeGraph::add_tensor(
133196
}
134197

135198
ValueRef ComputeGraph::add_tensor_like(
136-
const ValueRef vref,
199+
const ValueRef idx,
137200
const api::StorageType storage_type,
138201
const api::GPUMemoryLayout memory_layout) {
139-
TensorRef& tref = get_val(vref).toTensorRef();
140-
return add_tensor(tref.sizes, tref.dtype, storage_type, memory_layout);
202+
return add_tensor(
203+
get_sizes_of(idx), get_dtype_of(idx), storage_type, memory_layout);
141204
}
142205

143206
ValueRef ComputeGraph::add_tensor_like(
144-
const ValueRef vref,
207+
const ValueRef idx,
145208
const api::GPUMemoryLayout memory_layout) {
146-
TensorRef& tref = get_val(vref).toTensorRef();
147-
return add_tensor(tref.sizes, tref.dtype, memory_layout);
209+
return add_tensor(get_sizes_of(idx), get_dtype_of(idx), memory_layout);
148210
}
149211

150212
ValueRef ComputeGraph::add_tensor(
@@ -160,6 +222,7 @@ ValueRef ComputeGraph::add_tensorref(
160222
const api::ScalarType dtype,
161223
const void* const data) {
162224
ValueRef idx(static_cast<int>(values_.size()));
225+
check_no_active_value_ptrs();
163226
values_.emplace_back(TensorRef(sizes, dtype, data));
164227
return idx;
165228
}
@@ -168,24 +231,28 @@ ValueRef ComputeGraph::add_staging(
168231
const api::ScalarType dtype,
169232
const size_t numel) {
170233
ValueRef idx(static_cast<int>(values_.size()));
234+
check_no_active_value_ptrs();
171235
values_.emplace_back(api::StorageBuffer(context(), dtype, numel));
172236
return idx;
173237
}
174238

175239
ValueRef ComputeGraph::add_none() {
176240
ValueRef idx(static_cast<int>(values_.size()));
241+
check_no_active_value_ptrs();
177242
values_.emplace_back();
178243
return idx;
179244
}
180245

181246
ValueRef ComputeGraph::add_value_list(std::vector<ValueRef>&& value) {
182247
ValueRef idx(static_cast<int>(values_.size()));
248+
check_no_active_value_ptrs();
183249
values_.emplace_back(std::move(value));
184250
return idx;
185251
}
186252

187253
ValueRef ComputeGraph::add_string(std::string&& str) {
188254
ValueRef idx(static_cast<int>(values_.size()));
255+
check_no_active_value_ptrs();
189256
values_.emplace_back(std::move(str));
190257
return idx;
191258
}
@@ -194,8 +261,9 @@ ValueRef ComputeGraph::set_input_tensor(
194261
const ValueRef idx,
195262
const bool use_staging) {
196263
if (use_staging) {
197-
vTensor& tensor = get_val(idx).toTensor();
198-
ValueRef staging_idx = add_staging(tensor.dtype(), tensor.gpu_numel());
264+
api::ScalarType dtype = get_tensor(idx)->dtype();
265+
size_t gpu_numel = get_tensor(idx)->gpu_numel();
266+
ValueRef staging_idx = add_staging(dtype, gpu_numel);
199267
add_staging_to_tensor_node(*this, staging_idx, idx);
200268
inputs_.push_back({idx, staging_idx});
201269
return staging_idx;
@@ -208,8 +276,9 @@ ValueRef ComputeGraph::set_output_tensor(
208276
const ValueRef idx,
209277
const bool use_staging) {
210278
if (use_staging) {
211-
vTensor& tensor = get_val(idx).toTensor();
212-
ValueRef staging_idx = add_staging(tensor.dtype(), tensor.gpu_numel());
279+
api::ScalarType dtype = get_tensor(idx)->dtype();
280+
size_t gpu_numel = get_tensor(idx)->gpu_numel();
281+
ValueRef staging_idx = add_staging(dtype, gpu_numel);
213282
add_tensor_to_staging_node(*this, idx, staging_idx);
214283
outputs_.push_back({idx, staging_idx});
215284
return staging_idx;
@@ -229,20 +298,18 @@ void ComputeGraph::copy_into_staging(
229298
const ValueRef idx,
230299
const void* data,
231300
const size_t numel) {
232-
Value& in_val = get_val(idx);
233-
api::StorageBuffer& staging = in_val.toStaging();
234-
size_t nbytes = numel * api::element_size(staging.dtype());
235-
copy_ptr_to_staging(data, staging, nbytes);
301+
StagingPtr staging = get_staging(idx);
302+
size_t nbytes = numel * api::element_size(staging->dtype());
303+
copy_ptr_to_staging(data, *staging, nbytes);
236304
}
237305

238306
void ComputeGraph::copy_from_staging(
239307
const ValueRef idx,
240308
void* data,
241309
const size_t numel) {
242-
Value& out_val = get_val(idx);
243-
api::StorageBuffer& staging = out_val.toStaging();
244-
size_t nbytes = numel * api::element_size(staging.dtype());
245-
copy_staging_to_ptr(staging, data, nbytes);
310+
StagingPtr staging = get_staging(idx);
311+
size_t nbytes = numel * api::element_size(staging->dtype());
312+
copy_staging_to_ptr(*staging, data, nbytes);
246313
}
247314

248315
void ComputeGraph::prepare() {
@@ -308,7 +375,7 @@ void ComputeGraph::resize_input(
308375
const int64_t idx,
309376
const std::vector<int64_t>& new_sizes) {
310377
IOValueRef io_val = inputs_.at(idx);
311-
get_val(io_val.value).toTensor().virtual_resize(new_sizes);
378+
get_tensor(io_val.value)->virtual_resize(new_sizes);
312379
}
313380

314381
void ComputeGraph::propagate_resize() {

0 commit comments

Comments (0)