@@ -132,6 +132,33 @@ ValueRef prepack(
132
132
return v;
133
133
}
134
134
135
+ ValueRef prepack_buffer (
136
+ ComputeGraph& graph,
137
+ const ValueRef vref,
138
+ const utils::GPUMemoryLayout layout) {
139
+ ValueRef v = graph.add_tensor_like (vref, layout);
140
+
141
+ vkapi::ShaderInfo shader = VK_KERNEL_FROM_STR (" buffer_to_buffer" );
142
+
143
+ vkapi::ParamsBindList ubos;
144
+ ubos.append ({graph.numel_ubo (v)});
145
+
146
+ graph.prepack_nodes ().emplace_back (new PrepackNode (
147
+ graph,
148
+ shader,
149
+ graph.create_global_wg_size (v),
150
+ graph.create_local_wg_size (v),
151
+ // Input and Outputs
152
+ vref,
153
+ v,
154
+ // Parameter Buffers
155
+ ubos,
156
+ // Specialization Constants
157
+ {}));
158
+
159
+ return v;
160
+ }
161
+
135
162
ValueRef prepack_if_tensor_ref (
136
163
ComputeGraph& graph,
137
164
const ValueRef v,
@@ -143,6 +170,17 @@ ValueRef prepack_if_tensor_ref(
143
170
}
144
171
}
145
172
173
+ ValueRef prepack_buffer_if_tensor_ref (
174
+ ComputeGraph& graph,
175
+ const ValueRef v,
176
+ const utils::GPUMemoryLayout layout) {
177
+ if (graph.val_is_tref (v)) {
178
+ return prepack_buffer (graph, v, layout);
179
+ } else {
180
+ return v;
181
+ }
182
+ }
183
+
146
184
ValueRef prepack_if_tensor_ref (ComputeGraph& graph, const ValueRef v) {
147
185
if (graph.val_is_tref (v)) {
148
186
utils::GPUMemoryLayout layout =
@@ -153,4 +191,14 @@ ValueRef prepack_if_tensor_ref(ComputeGraph& graph, const ValueRef v) {
153
191
}
154
192
}
155
193
194
+ ValueRef prepack_buffer_if_tensor_ref (ComputeGraph& graph, const ValueRef v) {
195
+ if (graph.val_is_tref (v)) {
196
+ utils::GPUMemoryLayout layout =
197
+ graph.suggested_memory_layout (graph.get_tref (v)->sizes );
198
+ return prepack_buffer (graph, v, layout);
199
+ } else {
200
+ return v;
201
+ }
202
+ }
203
+
156
204
} // namespace vkcompute
0 commit comments