Skip to content

Commit e7439ea

Browse files
committed
Update on "[ET-VK][10/n] copy node, aten.repeat"
1. Introduce a `CopyNode` for generic copy-with-offset operations. 2. `aten.repeat` on all dimensions. 2.1 Use `CopyNode` where possible. 2.2. Specialized `repeat_channel` shader to handle packings 3. Update codegen to support `Methods` variant only operations. Need a new route to trigger the dispatch. Differential Revision: [D56499329](https://our.internmc.facebook.com/intern/diff/D56499329/) [ghstack-poisoned]
2 parents 8c9f7d8 + 1c9088e commit e7439ea

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

backends/vulkan/runtime/graph/ops/impl/Repeat.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@ void check_args(
2828

2929
int64_t in_dim = in.dim();
3030
VK_CHECK_COND(
31-
in_dim == repeats.size(), "Input tensor dim size must match argument");
31+
in_dim <= repeats.size(),
32+
"Input tensor dim size must be not greater than the repeat argument's size");
3233

3334
VK_CHECK_COND(
3435
dim_at<Dim4D::Width>(in.sizes()) * dim_at<Dim4D::Width>(repeats) ==
@@ -98,7 +99,6 @@ void add_repeat_channel_node(
9899
};
99100

100101
auto shader = VK_KERNEL_FROM_STR(kernel_name);
101-
// std::cout << "out tile size: " << shader.out_tile_size << std::endl;
102102

103103
graph.execute_nodes().emplace_back(new ExecuteNode(
104104
graph,
@@ -110,7 +110,7 @@ void add_repeat_channel_node(
110110
// Parameter buffers
111111
{graph.create_params_buffer(repeat_channel_args)},
112112
// Specialization Constants
113-
{}));
113+
{SV(t_out->gpu_memory_layout_int())}));
114114
}
115115

116116
void add_repeat_node(

backends/vulkan/test/op_tests/cases.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -415,6 +415,9 @@ def get_repeat_inputs():
415415
((S1, S2, S2, S2), [1, 3, 1, 3]),
416416
((S1, S2, S2, S2), [3, 3, 3, 3]),
417417
((S1, S2, S2, S2), [3, 3, 1, 1]),
418+
# Expanding cases
419+
((2, 3), [3, 1, 4]),
420+
((2, 3), [3, 3, 2, 4]),
418421
]
419422
)
420423
test_suite.layouts = [

0 commit comments

Comments
 (0)