Skip to content

Commit 1ad34f3

Browse files
committed
[ET-VK] Introduce add_tensor overloads consuming TensorRef
From @SSJia: > we should always make sure to store references produced from `graph.get_val()` only after any calls to `graph.add_*()` (i.e. modifications to the values list) are made. This is because `graph.values_`, being a `std::vector`, will reallocate with more space and move its contents if the current allocation is not sufficient. This means that if you store a reference then call `graph.add_*()` then the underlying resource the reference points to may have been moved. I think we can guard against this behavior by passing a `TensorRef` directly, and never having to declare a variable `TensorRef& tref` in the caller's scope. An example is shown in `Staging.cpp`. We could have it consume `ValueRef` for brevity of the passing parameter but IMO it hinders readability. Differential Revision: [D55703483](https://our.internmc.facebook.com/intern/diff/D55703483/) [ghstack-poisoned]
1 parent fa6731a commit 1ad34f3

File tree

3 files changed

+33
-9
lines changed

3 files changed

+33
-9
lines changed

backends/vulkan/runtime/graph/ComputeGraph.cpp

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,13 @@ ValueRef ComputeGraph::add_tensor(
123123
return idx;
124124
}
125125

126+
ValueRef ComputeGraph::add_tensor(
127+
TensorRef& tref,
128+
const api::StorageType storage_type,
129+
const api::GPUMemoryLayout memory_layout) {
130+
return add_tensor(tref.sizes, tref.dtype, storage_type, memory_layout);
131+
}
132+
126133
ValueRef ComputeGraph::add_tensor(
127134
const std::vector<int64_t>& sizes,
128135
const api::ScalarType dtype,
@@ -132,16 +139,18 @@ ValueRef ComputeGraph::add_tensor(
132139
sizes, dtype, suggested_storage_type(), memory_layout, shared_object_idx);
133140
}
134141

142+
ValueRef ComputeGraph::add_tensor(
143+
TensorRef& tref,
144+
const api::GPUMemoryLayout memory_layout) {
145+
return add_tensor(tref.sizes, tref.dtype, memory_layout);
146+
}
147+
135148
ValueRef ComputeGraph::add_tensor(
136149
const std::vector<int64_t>& sizes,
137150
const api::ScalarType dtype,
138151
const int64_t shared_object_idx) {
139152
return add_tensor(
140-
sizes,
141-
dtype,
142-
suggested_storage_type(),
143-
suggested_memory_layout(sizes),
144-
shared_object_idx);
153+
sizes, dtype, suggested_memory_layout(sizes), shared_object_idx);
145154
}
146155

147156
ValueRef ComputeGraph::add_tensorref(

backends/vulkan/runtime/graph/ComputeGraph.h

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,15 @@ class ComputeGraph final {
172172
const api::ScalarType dtype,
173173
const api::StorageType storage_type,
174174
const api::GPUMemoryLayout memory_layout,
175-
const int64_t shared_object_idx);
175+
const int64_t shared_object_idx = -1);
176+
177+
/*
178+
* Add a `vTensor` value to the graph with the properties of `tref`.
179+
*/
180+
ValueRef add_tensor(
181+
TensorRef& tref,
182+
const api::StorageType storage_type,
183+
const api::GPUMemoryLayout memory_layout);
176184

177185
/*
178186
* Add a `vTensor` value to the graph with the specified properties. The
@@ -184,14 +192,22 @@ class ComputeGraph final {
184192
const api::GPUMemoryLayout memory_layout,
185193
const int64_t shared_object_idx = -1);
186194

195+
/*
196+
* Add a `vTensor` value to the graph with the properties of `tref`. The
197+
* suggested storage type will be used to construct the `vTensor`.
198+
*/
199+
ValueRef add_tensor(
200+
TensorRef& tref,
201+
const api::GPUMemoryLayout memory_layout);
202+
187203
/*
188204
* Add a `vTensor` value to the graph with the specified properties. The
189205
* suggested storage type and memory layout will be used to construct the
190206
* `vTensor`.
191207
*/
192208
ValueRef add_tensor(
193209
const std::vector<int64_t>& sizes,
194-
const api::ScalarType dtype = api::ScalarType::Float,
210+
const api::ScalarType dtype,
195211
const int64_t shared_object_idx = -1);
196212

197213
/*

backends/vulkan/runtime/graph/ops/impl/Staging.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,7 @@ ValueRef prepack(
6363
ComputeGraph& graph,
6464
const ValueRef vref,
6565
const api::GPUMemoryLayout layout) {
66-
TensorRef& tref = graph.get_val(vref).toTensorRef();
67-
ValueRef v = graph.add_tensor(tref.sizes, tref.dtype, layout);
66+
ValueRef v = graph.add_tensor(graph.get_val(vref).toTensorRef(), layout);
6867
vTensor& t = graph.get_val(v).toTensor();
6968

7069
api::ShaderInfo shader = get_nchw_to_image_shader(t);

0 commit comments

Comments
 (0)