
Commit 8ffe881

Update on "[ET-VK] Using shared memory to save position in conv2d dw output op."
This diff changes the conv2d dw op to save output positions in shared memory, which reduces register usage and improves performance.

Differential Revision: [D68400890](https://our.internmc.facebook.com/intern/diff/D68400890/)

[ghstack-poisoned]
2 parents 011734f + 0d34737 commit 8ffe881
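
For orientation, the conv2d_pw.glsl change in this diff caches each invocation's tile of output positions in workgroup shared memory instead of a per-thread register array. Below is a minimal Python sketch of the indexing scheme only, assuming a 64-thread workgroup and a 2x2 tile as in the shader defaults; the names are illustrative, not part of the shader. The bank-conflict padding applied on top of this (offset_pos_index) is illustrated after the conv2d_pw.glsl diff further down.

```python
# Rough model of the strided shared-memory layout used to cache output positions.
# Entry i of every thread is stored LOCAL_WG_SIZE slots apart, so threads never collide.
LOCAL_WG_SIZE = 64      # threads per workgroup (LOCAL_WG_SIZE in the shader)
TILE_X, TILE_Y = 2, 2   # per-thread output tile (TILE_SIZE_X / TILE_SIZE_Y)

def shared_slot(tile_entry: int, local_invocation_index: int) -> int:
    """Slot in pos_shared holding tile entry `tile_entry` of a given thread."""
    return LOCAL_WG_SIZE * tile_entry + local_invocation_index

# Thread 5 owns slots 5, 69, 133, 197; thread 6 owns 6, 70, 134, 198, and so on.
assert [shared_slot(i, 5) for i in range(TILE_X * TILE_Y)] == [5, 69, 133, 197]
```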


25 files changed (+386, -93 lines)

backends/arm/operator_support/to_copy_support.py

Lines changed: 1 addition & 0 deletions
@@ -125,6 +125,7 @@ def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification) -> bool
         # Check dim_order (to_dim_order_copy)
         if "dim_order" in node.kwargs:
             dim_order = node.kwargs["dim_order"]
+            # pyre-ignore[6]
             if dim_order != list(range(len(dim_order))):
                 logger.info(
                     f"Argument {dim_order=} is not supported for "

backends/cadence/aot/compiler.py

Lines changed: 5 additions & 1 deletion
@@ -33,6 +33,7 @@
     ExecutorchProgramManager,
     to_edge,
 )
+from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import PassResult
 from executorch.exir.passes import ToOutVarPass
 from executorch.exir.passes.sym_shape_eval_pass import HintBasedSymShapeEvalPass
@@ -186,14 +187,17 @@ def export_to_edge(
     edge_prog_manager = to_edge(
         expo_program,
         compile_config=EdgeCompileConfig(
-            _skip_dim_order=True,
             # Allow specific non-core aten ops in the IR.
             _core_aten_ops_exception_list=[
                 torch.ops.aten._native_batch_norm_legit_functional.default,
                 torch.ops.aten.linear.default,
                 torch.ops.aten.linalg_vector_norm.default,
                 torch.ops.aten.unfold.default,
                 torch.ops.aten.angle.default,
+                # Cadence replaces to_dim_order_copy with _to_copy for performance;
+                # skip the _to_copy op here to get around the dim order check.
+                # This exception should be removed once Cadence supports dim order.
                 exir_ops.edge.aten._to_copy.default,
             ],
         ),
         constant_methods=constant_methods,
backends/cadence/aot/replace_ops.py

Lines changed: 73 additions & 0 deletions
@@ -11,6 +11,7 @@

 # pyre-unsafe

+import copy
 import math
 from operator import neg
 from typing import cast, Dict, Iterable, Sequence, Set, Tuple
@@ -35,7 +36,12 @@
 from executorch.backends.cadence.aot.utils import get_edge_overload_packet
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket
+from executorch.exir.dim_order_utils import get_memory_format
 from executorch.exir.pass_base import ExportPass, NodeMetadata, PassResult, ProxyValue
+from executorch.exir.passes.dim_order_ops_registry import (
+    DimOrderOpsMap,
+    MemoryFormatOpsMap,
+)
 from torch._subclasses import FakeTensor
 from torch.fx.node import Argument

@@ -1799,6 +1805,72 @@ def call_operator(
         )


+@register_cadence_pass(CadencePassAttribute(opt_level=0))
+class ReplaceToDimOrderCopyWithToCopyPass(ExportPass):
+    """
+    dim_order_ops::to_dim_order_copy is not supported, so this is an opt_level=0 pass.
+    If the dim order is sequential, we don't need the extra work with strides and
+    can just use to_copy.
+    """
+
+    def call_operator(
+        self,
+        op,
+        args: Tuple[Argument, ...],
+        kwargs: Dict[str, Argument],
+        meta: NodeMetadata,
+    ) -> ProxyValue:
+        if op not in DimOrderOpsMap:
+            return super().call_operator(op, args, kwargs, meta)
+
+        # Build new kwargs: replace dim_order with memory_format for the new op.
+        nkwargs = dict(copy.deepcopy(kwargs))  # orig kwargs are immutable
+
+        ndim = None
+
+        # can always get the shape, assuming rank is specialized
+
+        # pyre-ignore[16]: `None` has no attribute `to_tensor`
+        if isinstance(args[0], ProxyValue) and args[0].is_tensor():
+            # pyre-ignore[16]: `None` has no attribute `to_tensor`
+            ndim = args[0].to_tensor().dim()
+        elif isinstance(args[0], torch.Tensor):
+            # pyre-ignore[16]: `None` has no attribute `dim`
+            ndim = args[0].dim()
+        elif isinstance(args[0], torch.fx.immutable_collections.immutable_list):
+            # pyre-ignore[6]: Incompatible parameter type
+            ndim = len(args[0])
+        else:
+            assert 0, f"Expecting a Tensor or a ProxyValue but got {type(args[0])}"
+
+        # get the "to" memory format for the EdgeOp
+        contiguous_dim_order = list(range(ndim))
+        dim_order = nkwargs.pop("dim_order", None)
+
+        # Cadence only supports contiguous memory format
+        assert (
+            dim_order is None
+            # pyre-ignore[6]: Incompatible parameter type
+            or len(dim_order) == 0
+            or dim_order == contiguous_dim_order
+        ), "Expected dim order in contiguous or preserve memory format, but got {}".format(
+            dim_order
+        )
+
+        # bring back memory format
+        # pyre-ignore[6]: Incompatible parameter type
+        nkwargs["memory_format"] = get_memory_format(dim_order)
+
+        memory_format_op = MemoryFormatOpsMap[op]
+
+        return super().call_operator(
+            memory_format_op,
+            args,
+            nkwargs,
+            meta,
+        )
+
+
 @register_cadence_pass(CadencePassAttribute(opt_level=0))
 class ReplaceFullLikeWithFullPass(ExportPass):
     """
@@ -2108,4 +2180,5 @@ class CadenceReplaceOpsInGraph:
         ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass,
         ReplaceAtenAvgPoolWithJarvisAvgPoolPass,
         ReplaceAtenLinalgVectorNormWithCadenceLinalgVectorNormPass,
+        ReplaceToDimOrderCopyWithToCopyPass,
     ]
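
The pass above pops the dim_order kwarg and reinstates a memory_format kwarg via get_memory_format before dispatching to the corresponding op from MemoryFormatOpsMap. Below is a rough sketch of the mapping the pass relies on, written as a hypothetical helper rather than the exir utility itself:

```python
from typing import List, Optional
import torch

def memory_format_for(dim_order: Optional[List[int]]) -> torch.memory_format:
    # Assumption: only the cases admitted by the pass's assert ever reach this point.
    if dim_order is None or len(dim_order) == 0:
        return torch.preserve_format
    if list(dim_order) == list(range(len(dim_order))):
        return torch.contiguous_format
    raise AssertionError(f"unsupported dim order: {dim_order}")

assert memory_format_for([0, 1, 2, 3]) is torch.contiguous_format
assert memory_format_for(None) is torch.preserve_format
```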

backends/cadence/fusion_g3/operators/op_exp.cpp

Lines changed: 4 additions & 4 deletions
@@ -49,9 +49,9 @@ Tensor& exp_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
       out);
 #endif

-  if (out.scalar_type() == ScalarType::Float) {
-    float* const out_data = out.mutable_data_ptr<float>();
-    const float* const in_data = in.const_data_ptr<float>();
+  if (in.scalar_type() == ScalarType::Float) {
+    float* __restrict__ out_data = out.mutable_data_ptr<float>();
+    const float* __restrict__ in_data = in.const_data_ptr<float>();

     XT_KERNEL_CHECK(
         ctx, out, xa_nn_elm_exp_f32_f32, out_data, in_data, out.numel());
@@ -66,4 +66,4 @@ Tensor& exp_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
 } // namespace native
 } // namespace G3
 } // namespace impl
-} // namespace cadence
+} // namespace cadence

backends/vulkan/runtime/graph/ops/glsl/conv2d_dw.glsl

Lines changed: 1 addition & 1 deletion
@@ -41,7 +41,7 @@ void main() {
       div_by_x % out_limits.y,
       div_by_x / out_limits.y);

-  if (any(greaterThanEqual(pos, out_limits))) {
+  if (pos.z >= out_limits.z) {
     return;
   }

backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl

Lines changed: 1 addition & 1 deletion
@@ -66,7 +66,7 @@ void main() {
   pos.y *= BATCH_SIZE_Y;

   // do not process if top pixel does not fit within the output range
-  if (any(greaterThanEqual(pos, out_limits))) {
+  if (pos.z >= out_limits.z) {
     return;
   }

backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_sned_output_tile.glsl

Lines changed: 1 addition & 1 deletion
@@ -44,7 +44,7 @@ void main() {
      div_by_x % out_limits.y,
      div_by_x / out_limits.y);

-  if (any(greaterThanEqual(pos, out_limits))) {
+  if (pos.z >= out_limits.z) {
    return;
  }
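
In the shaders above that decode the output position from a flat invocation index, the y component is a modulo against out_limits.y and so is always in range; assuming the x component is likewise a modulo against out_limits.x (that line sits outside the visible hunks), only the z component can overflow, which is why the early-exit test shrinks to a z-only check. A small Python model of that decode with illustrative limits, mirroring the visible context rather than the full shaders:

```python
# Illustrative decode of a flat invocation index into (x, y, z).
out_limits = (7, 5, 3)        # example x/y/z limits, not taken from the model
for gid in range(2 * out_limits[0] * out_limits[1] * out_limits[2]):
    x = gid % out_limits[0]
    div_by_x = gid // out_limits[0]
    y = div_by_x % out_limits[1]
    z = div_by_x // out_limits[1]
    assert x < out_limits[0] and y < out_limits[1]  # in range by construction
    # z is the only component that can run past its limit, hence the z-only test
```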

backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl

Lines changed: 25 additions & 23 deletions
@@ -12,7 +12,9 @@

 #define VEC4_T ${texel_type(DTYPE)}

-#define TILE_SIZE ${TILE_SIZE}
+#define TILE_SIZE_X ${TILE_SIZE_X}
+#define TILE_SIZE_Y ${TILE_SIZE_Y}
+#define LOCAL_WG_SIZE 64

 #define op(X, A, B) ${OPERATOR}

@@ -41,19 +43,19 @@ layout(push_constant) uniform restrict Block {

 layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

-// shared memory to hold calculated positions, this would reduce register usage thus improving performance.
-// 64 is the number of threads in the local wg
-$num_shared = 64 * TILE_SIZE * TILE_SIZE
-shared ivec2 pos_shared[${num_shared}];
+// For performance improvement, reduce register usage by caching positions in shared memory.
+// Offset index by 1 every 16 points to avoid bank access conflict.
+#define offset_pos_index(index) (index + ((index) >> 4))
+shared ivec3 pos_shared[offset_pos_index(LOCAL_WG_SIZE * TILE_SIZE_X * TILE_SIZE_Y)];

 /*
  * Computes a 2D pointwise convolution of an NxN output tile. Calculating an
  * output tile for pointwise convolution is more efficient because the kernel
  * size is only 1x1, making it easier to re-use loaded texels from t_kernel.
  */
 void main() {
-  const ivec2 out_limits_scaled = (out_limits.xy + TILE_SIZE - 1) / TILE_SIZE;
-  const uint shared_mem_stride = gl_WorkGroupSize.x * gl_WorkGroupSize.y * gl_WorkGroupSize.z;
+  const ivec2 out_limits_scaled = (out_limits.xy + ivec2(TILE_SIZE_X - 1, TILE_SIZE_Y - 1)) / ivec2(TILE_SIZE_X, TILE_SIZE_Y);
+  const uint shared_mem_stride = LOCAL_WG_SIZE;

   const uint div_by_x = gl_GlobalInvocationID.x / out_limits_scaled.x;
   const ivec3 gpos = ivec3(
@@ -67,33 +69,32 @@ void main() {
   // +--------+--------+
   // | pos[2] | pos[3] |
   // +--------+--------+
-  ivec2 pos[TILE_SIZE * TILE_SIZE];
-  for (int y = 0, i = 0; y < TILE_SIZE; ++y) {
-    for (int x = 0; x < TILE_SIZE; ++x) {
-      pos[i] = ivec2(
-          gpos.x * TILE_SIZE + x, gpos.y * TILE_SIZE + y);
-      pos_shared[(shared_mem_stride * i) + gl_LocalInvocationIndex] = pos[i];
+  ivec2 pos[TILE_SIZE_X * TILE_SIZE_Y];
+  for (int y = 0, i = 0; y < TILE_SIZE_Y; ++y) {
+    for (int x = 0; x < TILE_SIZE_X; ++x) {
+      pos[i] = ivec2(gpos.x * TILE_SIZE_X + x, gpos.y * TILE_SIZE_Y + y);
+      pos_shared[offset_pos_index((shared_mem_stride * i) + gl_LocalInvocationIndex)] = ivec3(pos[i], gpos.z);
       i++;
     }
   }

   // If the top left position is out of bounds, then this invocation will have
   // no work to do.
-  if (any(greaterThanEqual(ivec3(pos[0], gpos.z), out_limits.xyz))) {
+  if (gpos.z >= out_limits.z) {
     return;
   }

   // Compute the index of the input texture that needs to be loaded for each
   // output position. Note that negative indices can be produced indicating that
   // the top-left element is in a region added by padding.
-  ivec2 ipos[TILE_SIZE * TILE_SIZE];
-  for (int i = 0; i < TILE_SIZE * TILE_SIZE; ++i) {
+  ivec2 ipos[TILE_SIZE_X * TILE_SIZE_Y];
+  for (int i = 0; i < TILE_SIZE_X * TILE_SIZE_Y; ++i) {
     ipos[i] = pos[i] * stride - padding;
   }

-  vec4 sum[TILE_SIZE * TILE_SIZE];
+  vec4 sum[TILE_SIZE_X * TILE_SIZE_Y];
   sum[0] = texelFetch(t_bias, ivec2(gpos.z, 0), 0);
-  for (int i = 1; i < TILE_SIZE * TILE_SIZE; ++i) {
+  for (int i = 1; i < TILE_SIZE_X * TILE_SIZE_Y; ++i) {
     sum[i] = sum[0];
   }

@@ -109,7 +110,7 @@
     const vec4 ktex_3 = texelFetchOffset(t_kernel, ivec2(z, gpos.z), 0, ivec2(3, 0));

 #pragma unroll
-    for (int i = 0; i < TILE_SIZE * TILE_SIZE; ++i) {
+    for (int i = 0; i < TILE_SIZE_X * TILE_SIZE_Y; ++i) {
       const vec4 in_tex = texelFetch(t_in, ivec3(ipos[i], z4), 0);
       // For 2x2 tile size algorithm works as follows.
       // To explain the calculations below, the contents of one in_tex and the
@@ -151,10 +152,11 @@
     }
   }

-  for (int i = 0; i < TILE_SIZE * TILE_SIZE; ++i) {
-    const ivec2 pos = pos_shared[(shared_mem_stride * i) + gl_LocalInvocationIndex];
-    if (all(lessThan(ivec3(pos, gpos.z), out_limits.xyz))) {
-      imageStore(t_out, ivec3(pos, gpos.z), op(sum[i], out_min, out_max));
+  for (int i = 0; i < TILE_SIZE_X * TILE_SIZE_Y; ++i) {
+    const uint index = (shared_mem_stride * i) + gl_LocalInvocationIndex;
+    const ivec3 pos = pos_shared[offset_pos_index(index)];
+    if (all(lessThan(pos, out_limits.xyz))) {
+      imageStore(t_out, pos, op(sum[i], out_min, out_max));
     }
   }
 }
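
The offset_pos_index macro above skips one shared-memory slot after every 16 entries so that the strided accesses do not land in the same bank (per the comment in the diff). A small Python rendering of the same arithmetic, with the array sized the same way so the padded indices still fit:

```python
def offset_pos_index(index: int) -> int:
    # Same arithmetic as the GLSL macro: add one pad slot for every 16 entries.
    return index + (index >> 4)

# Raw indices 0..15 map to themselves; index 16 jumps to 17, leaving slot 16 unused.
assert [offset_pos_index(i) for i in (0, 15, 16, 17, 32)] == [0, 15, 17, 18, 34]

# The shared array is declared with the padded length, so every index stays in bounds.
LOCAL_WG_SIZE, TILE_X, TILE_Y = 64, 2, 2
assert offset_pos_index(LOCAL_WG_SIZE * TILE_X * TILE_Y) == 272  # 256 entries + 16 pads
```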

backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.yaml

Lines changed: 2 additions & 1 deletion
@@ -9,7 +9,8 @@ conv2d_pw:
     OPERATOR: X
     NDIM: 3
     DTYPE: float
-    TILE_SIZE: 2
+    TILE_SIZE_X: 2
+    TILE_SIZE_Y: 2
   generate_variant_forall:
     DTYPE:
       - VALUE: half

backends/xnnpack/test/ops/test_cat.py

Lines changed: 9 additions & 0 deletions
@@ -187,6 +187,15 @@ def test_qs8_cat_gt_5(self):
             inputs.append(torch.randn(1, 2, 3))
         self._test_cat(self.Cat(), tuple(inputs), cat_num=num_inputs, quant=True)

+    def test_qs8_cat_with_empty_tensor(self):
+        inputs = (
+            torch.randn(0, 2, 3),
+            torch.randn(1, 2, 3),
+            torch.randn(3, 2, 3),
+            torch.randn(0, 2, 3),
+        )
+        self._test_cat(self.Cat(), inputs, cat_num=4, quant=True)
+
     class CatNegativeDim(torch.nn.Module):
         def __init__(self):
             super().__init__()
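
The new test feeds zero-size tensors through the quantized cat path. Empty tensors contribute nothing along the concatenation dimension, so the output size is just the sum of the non-empty parts; a plain torch check of that expectation, assuming concatenation along dim 0 (the test's Cat module is defined elsewhere in the file):

```python
import torch

parts = (
    torch.randn(0, 2, 3),
    torch.randn(1, 2, 3),
    torch.randn(3, 2, 3),
    torch.randn(0, 2, 3),
)
out = torch.cat(parts, dim=0)
assert out.shape == (4, 2, 3)  # the two empty tensors add no rows
```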

examples/cadence/operators/facto_util.py

Lines changed: 12 additions & 34 deletions
@@ -18,58 +18,36 @@

 def apply_tensor_contraints(op_name: str, tensor_constraints: list[object]) -> None:
     match op_name:
-        case (
-            "sigmoid.default"
-            | "_softmax.default"
-            | "rsqrt.default"
-            | "exp.default"
-            | "mul.Tensor"
-            | "div.Tensor"
-        ):
+        case "sigmoid.default" | "rsqrt.default":
             tensor_constraints.extend(
                 [
                     cp.Dtype.In(lambda deps: [torch.float]),
-                    cp.Size.Le(lambda deps, r, d: 2),
-                    cp.Rank.Le(lambda deps: 2),
+                    cp.Rank.Le(lambda deps: 2**3),
                 ]
             )
-        case (
-            "add.Tensor"
-            | "sub.Tensor"
-            | "add.Scalar"
-            | "sub.Scalar"
-            | "mul.Scalar"
-            | "div.Scalar"
-        ):
+        case "exp.default":
             tensor_constraints.extend(
                 [
-                    cp.Dtype.In(lambda deps: [torch.float, torch.int32]),
-                    cp.Size.Le(lambda deps, r, d: 2),
-                    cp.Rank.Le(lambda deps: 2),
-                ]
-            )
-        case "native_layer_norm.default":
-            tensor_constraints.extend(
-                [
-                    cp.Dtype.In(lambda deps: [torch.float, torch.int32]),
-                    cp.Size.Le(lambda deps, r, d: 2**4),
-                    cp.Rank.Le(lambda deps: 2**4),
+                    cp.Rank.Le(lambda deps: 2**3),
+                    cp.Value.Ge(lambda deps, dtype, struct: -(2**2)),
+                    cp.Value.Le(lambda deps, dtype, struct: 2**2),
                 ]
             )
         case _:
             tensor_constraints.extend(
                 [
-                    cp.Dtype.In(lambda deps: [torch.float, torch.int32]),
-                    cp.Size.Le(lambda deps, r, d: 2),
-                    cp.Rank.Le(lambda deps: 2),
+                    cp.Rank.Le(lambda deps: 2**2),
                 ]
             )
     tensor_constraints.extend(
         [
-            cp.Value.Ge(lambda deps, dtype, struct: -(2**8)),
-            cp.Value.Le(lambda deps, dtype, struct: 2**8),
+            cp.Dtype.In(lambda deps: [torch.int, torch.float]),
+            cp.Dtype.NotIn(lambda deps: [torch.int64, torch.float64]),
+            cp.Value.Ge(lambda deps, dtype, struct: -(2**4)),
+            cp.Value.Le(lambda deps, dtype, struct: 2**4),
             cp.Rank.Ge(lambda deps: 1),
             cp.Size.Ge(lambda deps, r, d: 1),
+            cp.Size.Le(lambda deps, r, d: 2**9),
         ]
     )
