Skip to content

Commit 49d72c8

Browse files
committed
Update on "[ET-VK] Introduce add_tensor overloads consuming TensorRef"
From ssjia: > we should always make sure to store references produced from `graph.get_val()` only after any calls to `graph.add_*()` (i.e. modifications to the values list) are made. This is because `graph.values_`, being a `std::vector`, will reallocate with more space and move its contents if the current allocation is not sufficient. This means that if you store a reference then call `graph.add_*()` then the underlying resource the reference points to may have been moved. I think we can guard against this behavior by passing a `TensorRef` directly, and never having to declare a variable `TensorRef& tref` in the caller's scope. An example is shown in `Staging.cpp`. We could have it consume `ValueRef` for brevity of the passing parameter but IMO it hinders readability. Differential Revision: [D55703483](https://our.internmc.facebook.com/intern/diff/D55703483/) [ghstack-poisoned]
1 parent 1ad34f3 commit 49d72c8

File tree

3 files changed

+28
-26
lines changed

3 files changed

+28
-26
lines changed

backends/vulkan/runtime/graph/ComputeGraph.cpp

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -123,13 +123,6 @@ ValueRef ComputeGraph::add_tensor(
123123
return idx;
124124
}
125125

126-
ValueRef ComputeGraph::add_tensor(
127-
TensorRef& tref,
128-
const api::StorageType storage_type,
129-
const api::GPUMemoryLayout memory_layout) {
130-
return add_tensor(tref.sizes, tref.dtype, storage_type, memory_layout);
131-
}
132-
133126
ValueRef ComputeGraph::add_tensor(
134127
const std::vector<int64_t>& sizes,
135128
const api::ScalarType dtype,
@@ -139,9 +132,18 @@ ValueRef ComputeGraph::add_tensor(
139132
sizes, dtype, suggested_storage_type(), memory_layout, shared_object_idx);
140133
}
141134

142-
ValueRef ComputeGraph::add_tensor(
143-
TensorRef& tref,
135+
ValueRef ComputeGraph::add_tensor_like(
136+
const ValueRef vref,
137+
const api::StorageType storage_type,
138+
const api::GPUMemoryLayout memory_layout) {
139+
TensorRef& tref = get_val(vref).toTensorRef();
140+
return add_tensor(tref.sizes, tref.dtype, storage_type, memory_layout);
141+
}
142+
143+
ValueRef ComputeGraph::add_tensor_like(
144+
const ValueRef vref,
144145
const api::GPUMemoryLayout memory_layout) {
146+
TensorRef& tref = get_val(vref).toTensorRef();
145147
return add_tensor(tref.sizes, tref.dtype, memory_layout);
146148
}
147149

backends/vulkan/runtime/graph/ComputeGraph.h

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -174,14 +174,6 @@ class ComputeGraph final {
174174
const api::GPUMemoryLayout memory_layout,
175175
const int64_t shared_object_idx = -1);
176176

177-
/*
178-
* Add a `vTensor` value to the graph with the properties of `tref`.
179-
*/
180-
ValueRef add_tensor(
181-
TensorRef& tref,
182-
const api::StorageType storage_type,
183-
const api::GPUMemoryLayout memory_layout);
184-
185177
/*
186178
* Add a `vTensor` value to the graph with the specified properties. The
187179
* suggested storage type will be used to construct the `vTensor`.
@@ -192,14 +184,6 @@ class ComputeGraph final {
192184
const api::GPUMemoryLayout memory_layout,
193185
const int64_t shared_object_idx = -1);
194186

195-
/*
196-
* Add a `vTensor` value to the graph with the properties of `tref`. The
197-
* suggested storage type will be used to construct the `vTensor`.
198-
*/
199-
ValueRef add_tensor(
200-
TensorRef& tref,
201-
const api::GPUMemoryLayout memory_layout);
202-
203187
/*
204188
* Add a `vTensor` value to the graph with the specified properties. The
205189
* suggested storage type and memory layout will be used to construct the
@@ -210,6 +194,22 @@ class ComputeGraph final {
210194
const api::ScalarType dtype,
211195
const int64_t shared_object_idx = -1);
212196

197+
/*
198+
* Add a `vTensor` value to the graph with the properties of `vref`.
199+
*/
200+
ValueRef add_tensor_like(
201+
const ValueRef vref,
202+
const api::StorageType storage_type,
203+
const api::GPUMemoryLayout memory_layout);
204+
205+
/*
206+
* Add a `vTensor` value to the graph with the properties of `vref`. The
207+
* suggested storage type will be used to construct the `vTensor`.
208+
*/
209+
ValueRef add_tensor_like(
210+
const ValueRef vref,
211+
const api::GPUMemoryLayout memory_layout);
212+
213213
/*
214214
* Add a `TensorRef` value to the graph with the specific properties. A
215215
* `TensorRef` is a reference to a `vTensor` whose data is stored in an

backends/vulkan/runtime/graph/ops/impl/Staging.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ ValueRef prepack(
6363
ComputeGraph& graph,
6464
const ValueRef vref,
6565
const api::GPUMemoryLayout layout) {
66-
ValueRef v = graph.add_tensor(graph.get_val(vref).toTensorRef(), layout);
66+
ValueRef v = graph.add_tensor_like(vref, layout);
6767
vTensor& t = graph.get_val(v).toTensor();
6868

6969
api::ShaderInfo shader = get_nchw_to_image_shader(t);

0 commit comments

Comments
 (0)