Skip to content

Commit 7795ca2

Browse files
committed
Update on "[ET-VK] Introduce vTensorPtr to prevent reference invalidation and remove get_val() API"
## Context Currently when writing operators developers will save a reference to a `vTensor` retrieved from a `ComputeGraph`'s list of `values_` like so: ``` vTensor& vten = graph.get_val(vref).toTensor(); ``` However, this is dangerous since if any values are added once the reference has been stored, `values_` which is a `std::vector` may have been resized and therefore have its contents moved, meaning the reference is now invalid. To protect against this, this changeset introduces the `vTensorPtr` class which is a wrapper around a `vTensor*`. When constructed, it will increment a counter in the `ComputeGraph` instance, and when destroyed it will decrement the counter. `ComputeGraph` cannot add any values while the counter is not zero. Since `Value` can be converted to other non-trivial types, this changeset also removes the `get_val` function entirely to guard against unsafe behaviour. Differential Revision: [D55984187](https://our.internmc.facebook.com/intern/diff/D55984187/) [ghstack-poisoned]
1 parent 04975bd commit 7795ca2

File tree

2 files changed

+5
-2
lines changed

2 files changed

+5
-2
lines changed

backends/vulkan/runtime/graph/ops/impl/Sum.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,8 @@ ValueRef add_node(
102102
const int dim,
103103
const bool keepdim,
104104
const api::ScalarType dtype = api::kFloat) {
105-
std::vector<int64_t> output_size = calc_out_sizes(*(graph.get_tensor(input)), dim, keepdim);
105+
std::vector<int64_t> output_size =
106+
calc_out_sizes(*(graph.get_tensor(input)), dim, keepdim);
106107
return graph.add_tensor(output_size, dtype, api::kChannelsPacked);
107108
}
108109

backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,9 @@ std::vector<int64_t> calc_out_sizes_hw(
150150
return calc_transpose_out_sizes_hw(
151151
in_sizes, kernel_size, stride, padding, dilation, output_padding);
152152
} else {
153-
const bool ceil_mode = graph.extract_scalar<bool>(args[3]);
153+
const bool ceil_mode =
154+
graph.val_is_bool(args[3]) ? graph.get_bool(args[3]) : false;
155+
154156
return calc_out_sizes_hw(
155157
in_sizes, kernel_size, stride, padding, dilation, ceil_mode);
156158
}

0 commit comments

Comments
 (0)