Skip to content

Commit 755c89c

Browse files
copyrightlyfacebook-github-bot
authored andcommitted
add aten.sum.dim_IntList (#2796)
Summary: Pull Request resolved: #2796 `aten.sum.dim_IntList` takes a parameter `dims`, the list of dimensions we want to reduce. We process these dimensions one at a time, adding intermediate nodes to store each intermediate output. bypass-github-pytorch-ci-checks bypass-github-export-checks Reviewed By: jorgep31415 Differential Revision: D55288560 fbshipit-source-id: e46e7f686aa61f27d38cf7fa9b20af853704d1cd
1 parent daa217f commit 755c89c

File tree

8 files changed

+427
-0
lines changed

8 files changed

+427
-0
lines changed

backends/vulkan/partitioner/vulkan_partitioner.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
4646
exir_ops.edge.aten.mm.default,
4747
# Pooling operators
4848
exir_ops.edge.aten.max_pool2d_with_indices.default,
49+
# Sum
50+
exir_ops.edge.aten.sum.dim_IntList,
4951
# Other
5052
operator.getitem,
5153
]
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core
10+
11+
#include "broadcasting_utils.h"
12+
#include "indexing_utils.h"
13+
14+
#define PRECISION ${PRECISION}
15+
16+
layout(std430) buffer;
17+
18+
// Output image; main() writes one texel per invocation with imageStore().
layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out;
19+
// Input tensor; main() reads it with texelFetch() as a 3D texture at LOD 0.
layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in;
20+
21+
// NOTE(review): out_extents is declared but never read by main() in this
// shader, so pos is not bounds-checked. Presumably this holds the output
// texture extents - confirm whether a guard on pos was intended.
layout(set = 0, binding = 2) uniform PRECISION restrict OutExtents {
22+
uvec4 data;
23+
}
24+
out_extents;
25+
26+
// dim to sum: main() treats 0 as batch, 1 as channel, 2 as height and any
// other value as width
27+
layout(set = 0, binding = 3) uniform PRECISION restrict DimVal {
28+
int data;
29+
}
30+
dim;
31+
32+
// size of dim (in the input); main() uses it as the number of elements
// accumulated along the summed dimension
33+
layout(set = 0, binding = 4) uniform PRECISION restrict DimSize {
34+
int data;
35+
}
36+
dim_size;
37+
38+
// main() uses this as the z-offset between consecutive batches of image_in;
// within a batch, channel c is read from z-texel c / 4, component c % 4.
layout(set = 0, binding = 5) uniform PRECISION restrict Channel {
39+
int data;
40+
}
41+
flattened_channels;
42+
43+
// Workgroup size is supplied through specialization constants 0, 1 and 2.
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
44+
45+
/*
 * Returns a new tensor with values summed along dimension dim; the summed
 * dimension is squeezed out of the output.
 *
 * dim.data selects the reduced dimension: 0 = batch, 1 = channel,
 * 2 = height, any other value = width. dim_size.data is the number of
 * elements accumulated. flattened_channels.data is the z-offset between
 * consecutive batches of image_in: channel c of batch n is read from z-texel
 * n * flattened_channels.data + c / 4, component c % 4.
 *
 * Each invocation writes the output texel at pos:
 * - Batch: all four components are summed over the batch index; pos.z is
 *   the channel texel within each batch.
 * - Channel: the batch moves into the output's packed dimension, so output
 *   component i reads source batch pos.z * 4 + i and sums its channels at
 *   (pos.x, pos.y).
 * - Height / Width: output component i reads source batch pos.z * 4 + i,
 *   pos.y is the source channel, and pos.x is the source x (height summed)
 *   or the source y (width summed).
 *
 * NOTE(review): out_extents is never read, so pos is not bounds-checked, and
 * the source batch pos.z * 4 + i is not checked against the input's batch
 * count. Confirm that out-of-range invocations and fetches are harmless for
 * the dispatch sizes in use.
 */

void main() {
  const ivec3 pos = ivec3(gl_GlobalInvocationID);

  vec4 out_texel = vec4(0);

  if (dim.data == 0) {
    // Batch
    for (int batch = 0; batch < dim_size.data; ++batch) {
      const int src_z = batch * flattened_channels.data + pos.z;
      out_texel += texelFetch(image_in, ivec3(pos.x, pos.y, src_z), 0);
    }
  } else if (dim.data == 1) {
    // Channel
    for (int out_index = 0; out_index < 4; ++out_index) {
      const int base_z = (pos.z * 4 + out_index) * flattened_channels.data;
      // Fetch every channel texel once and add its components in channel
      // order; this keeps the accumulation order of a per-channel loop while
      // issuing a quarter of the fetches.
      for (int first_c = 0; first_c < dim_size.data; first_c += 4) {
        const vec4 v =
            texelFetch(image_in, ivec3(pos.x, pos.y, base_z + first_c / 4), 0);
        // The last texel may be only partially filled with real channels.
        const int lanes = min(4, dim_size.data - first_c);
        for (int lane = 0; lane < lanes; ++lane) {
          out_texel[out_index] += v[lane];
        }
      }
    }
  } else {
    // Height, Width
    const bool sum_height = (dim.data == 2);
    const int lane = pos.y % 4;
    for (int out_index = 0; out_index < 4; ++out_index) {
      const int src_z =
          (pos.z * 4 + out_index) * flattened_channels.data + pos.y / 4;
      for (int hw = 0; hw < dim_size.data; ++hw) {
        const ivec3 src_pos = sum_height
            ? ivec3(pos.x, hw, src_z) // Height
            : ivec3(hw, pos.x, src_z); // Width
        const vec4 v = texelFetch(image_in, src_pos, 0);
        out_texel[out_index] += v[lane];
      }
    }
  }

  imageStore(image_out, pos, out_texel);
}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
sum_dim:
8+
parameter_names_with_default_values:
9+
NDIM: 3
10+
DTYPE: float
11+
generate_variant_forall:
12+
DTYPE:
13+
- VALUE: half
14+
SUFFIX: half
15+
- VALUE: float
16+
SUFFIX: float
17+
shader_variants:
18+
- NAME: sum_dim
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core
10+
11+
#include "indexing_utils.h"
12+
13+
#define PRECISION ${PRECISION}
14+
15+
layout(std430) buffer;
16+
17+
// Output image; main() writes one texel per invocation with imageStore().
layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out;
18+
// Input tensor; main() reads it with texelFetch() as a 3D texture at LOD 0.
layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in;
19+
20+
// NOTE(review): out_extents is declared but never read by main() in this
// shader, so pos is not bounds-checked. Presumably this holds the output
// texture extents - confirm whether a guard on pos was intended.
layout(set = 0, binding = 2) uniform PRECISION restrict OutExtents {
21+
uvec4 data;
22+
}
23+
out_extents;
24+
25+
// dim to sum: main() treats 0 as batch, 1 as channel, 2 as height and any
// other value as width
26+
layout(set = 0, binding = 3) uniform PRECISION restrict DimVal {
27+
int data;
28+
}
29+
dim;
30+
31+
// size of dim (in the input); main() uses it as the number of elements
// accumulated along the summed dimension
32+
layout(set = 0, binding = 4) uniform PRECISION restrict DimSize {
33+
int data;
34+
}
35+
dim_size;
36+
37+
// main() uses this as the z-offset between consecutive batches of image_in;
// within a batch, channel c is read from z-texel c / 4, component c % 4.
layout(set = 0, binding = 5) uniform PRECISION restrict Channel {
38+
int data;
39+
}
40+
flattened_channels;
41+
42+
// Workgroup size is supplied through specialization constants 0, 1 and 2.
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
43+
44+
/*
 * Returns a new tensor with values summed along dimension dim.
 * Output and input have the same number of dimensions; the summed dimension
 * has size 1 in the output.
 *
 * dim.data selects the reduced dimension: 0 = batch, 1 = channel,
 * 2 = height, any other value = width. dim_size.data is the number of
 * elements accumulated. flattened_channels.data is the z-offset between
 * consecutive batches of image_in: channel c of batch n is read from z-texel
 * n * flattened_channels.data + c / 4, component c % 4.
 *
 * Each invocation writes the output texel at pos:
 * - Batch: all four components are summed over the batch index; pos.z is
 *   the channel texel within each batch.
 * - Channel: pos.z is the source batch; the sum over its channels at
 *   (pos.x, pos.y) is written to all four components of the output texel.
 * - Height / Width: all four components are summed over y (height) or
 *   x (width) at the same z-texel.
 *
 * NOTE(review): out_extents is never read, so pos is not bounds-checked.
 * Confirm that out-of-range invocations are harmless for the dispatch sizes
 * in use.
 */

void main() {
  const ivec3 pos = ivec3(gl_GlobalInvocationID);

  vec4 out_texel = vec4(0);

  if (dim.data == 0) {
    // Batch
    for (int batch = 0; batch < dim_size.data; ++batch) {
      const int src_z = batch * flattened_channels.data + pos.z;
      out_texel += texelFetch(image_in, ivec3(pos.x, pos.y, src_z), 0);
    }
  } else if (dim.data == 1) {
    // Channel
    // The source batch does not depend on the output component, so the sum
    // is accumulated once and broadcast instead of being recomputed for each
    // of the four components. Every channel texel is fetched once and its
    // components are added in channel order.
    const int base_z = pos.z * flattened_channels.data;
    float channel_sum = 0.0;
    for (int first_c = 0; first_c < dim_size.data; first_c += 4) {
      const vec4 v =
          texelFetch(image_in, ivec3(pos.x, pos.y, base_z + first_c / 4), 0);
      // The last texel may be only partially filled with real channels.
      const int lanes = min(4, dim_size.data - first_c);
      for (int lane = 0; lane < lanes; ++lane) {
        channel_sum += v[lane];
      }
    }
    out_texel = vec4(channel_sum);
  } else if (dim.data == 2) {
    // Height
    for (int h = 0; h < dim_size.data; ++h) {
      out_texel += texelFetch(image_in, ivec3(pos.x, h, pos.z), 0);
    }
  } else {
    // Width
    for (int w = 0; w < dim_size.data; ++w) {
      out_texel += texelFetch(image_in, ivec3(w, pos.y, pos.z), 0);
    }
  }

  imageStore(image_out, pos, out_texel);
}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
sum_dim_keepdim:
8+
parameter_names_with_default_values:
9+
NDIM: 3
10+
DTYPE: float
11+
generate_variant_forall:
12+
DTYPE:
13+
- VALUE: half
14+
SUFFIX: half
15+
- VALUE: float
16+
SUFFIX: float
17+
shader_variants:
18+
- NAME: sum_dim_keepdim

0 commit comments

Comments
 (0)