
Commit b551ec4

copyrightly authored and facebook-github-bot committed
add aten.sum.default (#2807)
Summary: The operator `aten.sum.dim_IntList` can take an empty list as its `dims` parameter. We modify `vulkan_graph_builder.py` to accommodate the empty list. Moreover, the op `aten.sum.default` is implemented as a [decomposition](https://www.internalfb.com/code/fbsource/[96e496f9db8f92967b4394bd4f60e39ab916740b]/xplat/caffe2/torch/_decomp/decompositions.py?lines=4676) into `aten.sum.dim_IntList` with empty `dims`, so these changes also add support for `aten.sum.default`.

Context: `torch.sum(x, ())` and `torch.sum(x)` are two ways to compute the sum of all elements in tensor `x`.

Differential Revision: D55630993
1 parent d612c23 commit b551ec4
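As the Summary notes, `torch.sum(x)` lowers to `aten.sum.default`, which decomposes into `aten.sum.dim_IntList` with empty `dims`; both forms reduce over every dimension. A minimal eager-mode sketch of that equivalence (the tensor shape is borrowed from the test added below):

    import torch

    x = torch.rand(3, 2, 7, 5)

    # aten.sum.default: no dims argument, sums all elements.
    total_default = torch.sum(x)

    # aten.sum.dim_IntList with empty dims: the same reduction over every dim.
    total_empty_dims = torch.sum(x, ())

    assert torch.allclose(total_default, total_empty_dims)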

File tree

4 files changed: +40 -6 lines changed

backends/vulkan/partitioner/vulkan_partitioner.py

Lines changed: 1 addition & 0 deletions
@@ -48,6 +48,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
             exir_ops.edge.aten.max_pool2d_with_indices.default,
             # Sum
             exir_ops.edge.aten.sum.dim_IntList,
+            exir_ops.edge.aten.sum.default,
             # Other
             operator.getitem,
         ]

backends/vulkan/runtime/graph/ops/impl/Sum.cpp

Lines changed: 11 additions & 4 deletions
@@ -123,10 +123,17 @@ void add_sum_dim_IntList(
   auto dims_to_sum = graph.get_val(opt_dim).toIntList();
   int64_t in_dim = in_tensor.sizes().size();

-  for (const auto& dim : dims_to_sum) {
-    // Normalize (negative) dim into range [0, self.dim() - 1]
-    int64_t dim_normalized = normalize(dim, in_dim);
-    dims_set.insert(dim_normalized);
+  if (dims_to_sum.empty()) {
+    // If dim is not specified, reduce over all dims
+    for (int64_t i = 0; i < in_dim; ++i) {
+      dims_set.insert(i);
+    }
+  } else {
+    for (const auto& dim : dims_to_sum) {
+      // Normalize (negative) dim into range [0, self.dim() - 1]
+      int64_t dim_normalized = normalize(dim, in_dim);
+      dims_set.insert(dim_normalized);
+    }
   }

   // Reduce the higher dimensionalities first, otherwise when keepdim is
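The trailing context line hints at an ordering detail worth spelling out: with `keepdim` false, each reduction removes a dimension, so reducing the higher dims first keeps the remaining indices in `dims_set` valid. A small Python sketch of that reasoning (the shape and dims here are illustrative, not taken from the kernel):

    import torch

    x = torch.rand(3, 2, 7, 5)
    dims_set = {1, 2}

    # Reduce the higher dims first: after summing dim 2, the old dim 1 is
    # still at index 1. Reducing dim 1 first would shift the old dim 2 down
    # to index 1, and the next iteration would sum the wrong axis.
    out = x
    for d in sorted(dims_set, reverse=True):
        out = out.sum(dim=d)

    assert torch.allclose(out, x.sum(dim=(1, 2)))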

backends/vulkan/serialization/vulkan_graph_builder.py

Lines changed: 8 additions & 2 deletions
@@ -178,7 +178,11 @@ def create_tensor_value(self, spec: TensorSpec, constant_id: int = -1) -> int:

     def create_scalar_list_value(self, arg: List[_ScalarType]) -> int:
         new_id = len(self.values)
-        if isinstance(arg[0], bool):
+        if len(arg) == 0:
+            self.values.append(
+                vk_graph_schema.VkValue(vk_graph_schema.IntList(items=[]))
+            )
+        elif isinstance(arg[0], bool):
             self.values.append(
                 vk_graph_schema.VkValue(
                     vk_graph_schema.BoolList(items=[cast(bool, e) for e in arg])
@@ -227,7 +231,9 @@ def get_or_create_value_for(self, arg: _Argument):
             return self.create_scalar_value(arg)
         elif isinstance(arg, TensorSpec):
             return self.create_tensor_value(arg)
-        elif isinstance(arg, list) and isinstance(arg[0], _ScalarType):
+        elif isinstance(arg, list) and (
+            len(arg) == 0 or isinstance(arg[0], _ScalarType)
+        ):
             # pyre-ignore[6]
             return self.create_scalar_list_value(arg)
         elif isinstance(arg, list) and isinstance(arg[0], Node):
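The builder change guards against a plain `IndexError`: an empty list has no `arg[0]` to inspect for an element type, so the new branch short-circuits and serializes it as an empty `IntList` by default, which is exactly what empty `dims` needs. A stripped-down sketch of the dispatch; `classify_scalar_list` and its return strings are hypothetical stand-ins for the real `vk_graph_schema` list types:

    from typing import List, Union

    _ScalarType = Union[bool, int, float]

    def classify_scalar_list(arg: List[_ScalarType]) -> str:
        # Empty list: no element to inspect, so default to IntList.
        # This is the case hit by sum over all dims, where dims == [].
        if len(arg) == 0:
            return "IntList"
        # bool must be checked before int: bool is a subclass of int in Python.
        elif isinstance(arg[0], bool):
            return "BoolList"
        elif isinstance(arg[0], int):
            return "IntList"
        else:
            return "DoubleList"

    assert classify_scalar_list([]) == "IntList"
    assert classify_scalar_list([True, False]) == "BoolList"
    assert classify_scalar_list([2, 3]) == "IntList"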

backends/vulkan/test/test_vulkan_delegate.py

Lines changed: 20 additions & 0 deletions
@@ -496,3 +496,23 @@ def forward(self, x):
             sample_inputs,
             memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
         )
+
+    def test_vulkan_backend_sum(self):
+        class SumModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            # test both torch.sum(x, ()) and torch.sum(x)
+            def forward(self, x):
+                x = torch.sum(x, (), keepdim=True)
+                x = torch.sum(x)
+                return x
+
+        module = SumModule()
+        sample_inputs = (torch.rand(size=(3, 2, 7, 5), dtype=torch.float32),)
+
+        self.lower_module_and_test_output(
+            module,
+            sample_inputs,
+            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
+        )
