Skip to content

Commit 4b6a033

Browse files
SS-JIAfacebook-github-bot
authored andcommitted
Support serializing scalar tensors as SymInt values (#6070)
Summary: Pull Request resolved: #6070 ## Context * Add `SymInt` to serialization schema * Make the serializer serialize scalar tensors as `SymInt` instead of `VkTensor` * Add support for `SymInt` in `VulkanBackend.cpp` ghstack-source-id: 247163958 exported-using-ghexport Reviewed By: jorgep31415 Differential Revision: D64139868 fbshipit-source-id: 44d225ca6c63b311e4839783787713a38b8b6017
1 parent 1a0c2c7 commit 4b6a033

File tree

4 files changed

+31
-0
lines changed

4 files changed

+31
-0
lines changed

backends/vulkan/runtime/VulkanBackend.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,12 @@ class GraphBuilder {
248248
ref_mapping_[fb_id] = ref;
249249
}
250250

251+
void add_symint_to_graph(const uint32_t fb_id, VkValuePtr value) {
252+
const int32_t fb_symint = value->value_as_SymInt()->value();
253+
ValueRef ref = compute_graph_->add_symint(fb_symint);
254+
ref_mapping_[fb_id] = ref;
255+
}
256+
251257
void add_value_to_graph(const uint32_t fb_id, VkValuePtr value) {
252258
ET_CHECK_MSG(
253259
!fb_id_exists(fb_id),
@@ -300,6 +306,9 @@ class GraphBuilder {
300306
case vkgraph::GraphTypes::String:
301307
add_string_to_graph(fb_id, value);
302308
break;
309+
case vkgraph::GraphTypes::SymInt:
310+
add_symint_to_graph(fb_id, value);
311+
break;
303312
default:
304313
ET_CHECK_MSG(false, "Unsupported value type.");
305314
}

backends/vulkan/serialization/schema.fbs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,10 @@ table ValueList {
8989
items:[int];
9090
}
9191

92+
table SymInt {
93+
value:int;
94+
}
95+
9296
union GraphTypes {
9397
Null,
9498
Int,
@@ -100,6 +104,7 @@ union GraphTypes {
100104
BoolList,
101105
ValueList,
102106
String,
107+
SymInt,
103108
}
104109

105110
table VkValue {

backends/vulkan/serialization/vulkan_graph_builder.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,12 @@ def maybe_add_constant_tensor(self, node: Node) -> int:
139139
return constant_id
140140

141141
def create_node_value(self, node: Node) -> int:
142+
# If the node has been marked as a scalar tensor, create a SymInt instead of a tensor
143+
if node.meta.get("vkdg_is_scalar_tensor", False):
144+
new_id = self.create_symint_value()
145+
self.node_to_value_ids[node] = new_id
146+
return new_id
147+
142148
spec = node.meta.get("spec")
143149
if isinstance(spec, TensorSpec):
144150
constant_id = self.maybe_add_constant_tensor(node)
@@ -169,6 +175,11 @@ def create_scalar_value(self, scalar: _ScalarType) -> int:
169175
self.values.append(vk_graph_schema.VkValue(vk_graph_schema.Double(scalar)))
170176
return new_id
171177

178+
def create_symint_value(self) -> int:
179+
new_id = len(self.values)
180+
self.values.append(vk_graph_schema.VkValue(vk_graph_schema.SymInt(0)))
181+
return new_id
182+
172183
def create_tensor_value(self, spec: TensorSpec, constant_id: int = -1) -> int:
173184
# Negative id indicates that this tensor will have its own dedicated memory.
174185
mem_obj_id = -1

backends/vulkan/serialization/vulkan_graph_schema.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,11 @@ class String:
100100
string_val: str
101101

102102

103+
@dataclass
104+
class SymInt:
105+
value: int
106+
107+
103108
GraphTypes = Union[
104109
Null,
105110
Int,
@@ -111,6 +116,7 @@ class String:
111116
DoubleList,
112117
ValueList,
113118
String,
119+
SymInt,
114120
]
115121

116122

0 commit comments

Comments
 (0)