Skip to content

Commit 9b43159

Browse files
[ET-VK] Adding convenience functions in Compute graph to get PushConstantDataInfo for various attributes of a tensor.
Pull Request resolved: #7224 This diff adds convenience functions in the Compute graph to get PushConstantDataInfo for various attributes of a tensor. ghstack-source-id: 257899588 @exported-using-ghexport Differential Revision: [D66853502](https://our.internmc.facebook.com/intern/diff/D66853502/) --------- Co-authored-by: Vivek Trivedi <[email protected]>
1 parent 10bd491 commit 9b43159

File tree

1 file changed

+23
-0
lines changed

1 file changed

+23
-0
lines changed

backends/vulkan/runtime/graph/ComputeGraph.h

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include <executorch/backends/vulkan/runtime/graph/containers/SharedObject.h>
2121
#include <executorch/backends/vulkan/runtime/graph/containers/Value.h>
2222

23+
#include <executorch/backends/vulkan/runtime/graph/ops/DispatchNode.h>
2324
#include <executorch/backends/vulkan/runtime/graph/ops/ExecuteNode.h>
2425
#include <executorch/backends/vulkan/runtime/graph/ops/PrepackNode.h>
2526

@@ -350,6 +351,28 @@ class ComputeGraph final {
350351
return values_.at(idx).toTensor().logical_limits_ubo();
351352
}
352353

354+
inline PushConstantDataInfo sizes_pc_of(const ValueRef idx) const {
355+
return PushConstantDataInfo(
356+
values_.at(idx).toConstTensor().get_uniform_data(), api::kTensorSizes);
357+
}
358+
359+
inline PushConstantDataInfo strides_pc_of(const ValueRef idx) const {
360+
return PushConstantDataInfo(
361+
values_.at(idx).toConstTensor().get_uniform_data(),
362+
api::kTensorStrides);
363+
}
364+
365+
inline PushConstantDataInfo logical_limits_pc_of(const ValueRef idx) const {
366+
return PushConstantDataInfo(
367+
values_.at(idx).toConstTensor().get_uniform_data(),
368+
api::kTensorLogicalLimits);
369+
}
370+
371+
inline PushConstantDataInfo numel_pc_of(const ValueRef idx) const {
372+
return PushConstantDataInfo(
373+
values_.at(idx).toConstTensor().get_uniform_data(), api::kTensorNumel);
374+
}
375+
353376
//
354377
// Scalar Value Extraction
355378
//

0 commit comments

Comments
 (0)