Skip to content

Commit 32d674c

Browse files
[ET-VK] Shortening code for slice op when packed dim is not the same as slice dim. (#9169)
This PR was created by the merge bot to help merge the original PR into the main branch.
ghstack PR number: #9136 by @trivedivivek
^ Please use this as the source of truth for the PR details, comments, and reviews
ghstack PR base: https://github.com/pytorch/executorch/tree/gh/trivedivivek/59/base
ghstack PR head: https://github.com/pytorch/executorch/tree/gh/trivedivivek/59/head
Merge bot PR base: https://github.com/pytorch/executorch/tree/main
Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/trivedivivek/59/orig
@diff-train-skip-merge
Co-authored-by: Vivek Trivedi <[email protected]>
1 parent c5e457a commit 32d674c

File tree

1 file changed

+9
-21
lines changed
  • backends/vulkan/runtime/graph/ops/impl

1 file changed

+9
-21
lines changed

backends/vulkan/runtime/graph/ops/impl/Slice.cpp

Lines changed: 9 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -108,27 +108,15 @@ void add_slice_tensor_copy_node(
108108
spec_vars));
109109

110110
} else {
111-
// GPU's coordinate is in x, y, z
112-
int64_t gpu_dim = -1;
113-
int64_t in_channel_stride = 1;
114-
if (dim_index == kWidth4D) {
115-
gpu_dim = 0; // width: x dimension in gpu
116-
VK_CHECK_COND(out_sizes[dim] == (1 + (end - start - 1) / step));
117-
} else if (dim_index == kHeight4D) {
118-
gpu_dim = 1; // height: y dimension
119-
VK_CHECK_COND(out_sizes[dim] == (1 + (end - start - 1) / step));
120-
} else if (dim_index == kChannel4D) {
121-
gpu_dim = 2; // channel: z dimension
122-
VK_CHECK_COND(out_sizes[dim] == (1 + (end - start - 1) / step));
123-
in_channel_stride = dim_at(in_sizes, kChannel4D);
124-
} else {
125-
gpu_dim = 3; // batch: w dimension
126-
127-
in_channel_stride = dim_at(in_sizes, kChannel4D);
128-
if (packed_dim_idx == kChannel4D) {
129-
// Due to channel packing, each batch value is span over stride planes
130-
in_channel_stride = utils::div_up_4(in_channel_stride);
131-
}
111+
// GPU's coordinate is in x = 0, y = 1, z = 2, w = 3
112+
const int64_t gpu_dim = -(dim_index + 1);
113+
// stride of input tensor's channel dimension
114+
int64_t in_channel_stride = dim_at(in_sizes, kChannel4D);
115+
VK_CHECK_COND(out_sizes[dim] == (1 + (end - start - 1) / step));
116+
117+
// Due to channel packing, each batch value is span over stride planes
118+
if (dim_index == kBatch4D && packed_dim_idx == kChannel4D) {
119+
in_channel_stride = utils::div_up_4(in_channel_stride);
132120
}
133121

134122
std::string kernel_name = "slice_batch_height_width";

0 commit comments

Comments
 (0)