Skip to content

[ET-VK] Support serializing scalar tensors as SymInt values #6070

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions backends/vulkan/runtime/VulkanBackend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,12 @@ class GraphBuilder {
ref_mapping_[fb_id] = ref;
}

void add_symint_to_graph(const uint32_t fb_id, VkValuePtr value) {
const int32_t fb_symint = value->value_as_SymInt()->value();
ValueRef ref = compute_graph_->add_symint(fb_symint);
ref_mapping_[fb_id] = ref;
}

void add_value_to_graph(const uint32_t fb_id, VkValuePtr value) {
ET_CHECK_MSG(
!fb_id_exists(fb_id),
Expand Down Expand Up @@ -300,6 +306,9 @@ class GraphBuilder {
case vkgraph::GraphTypes::String:
add_string_to_graph(fb_id, value);
break;
case vkgraph::GraphTypes::SymInt:
add_symint_to_graph(fb_id, value);
break;
default:
ET_CHECK_MSG(false, "Unsupported value type.");
}
Expand Down
5 changes: 5 additions & 0 deletions backends/vulkan/serialization/schema.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,10 @@ table ValueList {
items:[int];
}

table SymInt {
value:int;
}

union GraphTypes {
Null,
Int,
Expand All @@ -100,6 +104,7 @@ union GraphTypes {
BoolList,
ValueList,
String,
SymInt,
}

table VkValue {
Expand Down
11 changes: 11 additions & 0 deletions backends/vulkan/serialization/vulkan_graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,12 @@ def maybe_add_constant_tensor(self, node: Node) -> int:
return constant_id

def create_node_value(self, node: Node) -> int:
# If the node has been marked as a scalar tensor, create a SymInt instead of a tensor
if node.meta.get("vkdg_is_scalar_tensor", False):
new_id = self.create_symint_value()
self.node_to_value_ids[node] = new_id
return new_id

spec = node.meta.get("spec")
if isinstance(spec, TensorSpec):
constant_id = self.maybe_add_constant_tensor(node)
Expand Down Expand Up @@ -169,6 +175,11 @@ def create_scalar_value(self, scalar: _ScalarType) -> int:
self.values.append(vk_graph_schema.VkValue(vk_graph_schema.Double(scalar)))
return new_id

def create_symint_value(self) -> int:
new_id = len(self.values)
self.values.append(vk_graph_schema.VkValue(vk_graph_schema.SymInt(0)))
return new_id

def create_tensor_value(self, spec: TensorSpec, constant_id: int = -1) -> int:
# Negative id indicates that this tensor will have its own dedicated memory.
mem_obj_id = -1
Expand Down
6 changes: 6 additions & 0 deletions backends/vulkan/serialization/vulkan_graph_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,11 @@ class String:
string_val: str


@dataclass
class SymInt:
value: int


GraphTypes = Union[
Null,
Int,
Expand All @@ -111,6 +116,7 @@ class String:
DoubleList,
ValueList,
String,
SymInt,
]


Expand Down
Loading