Skip to content

add handling of out-of-range indices to Slice #3689

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions backends/vulkan/runtime/graph/ops/impl/Slice.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,20 @@

namespace vkcompute {

inline int64_t normalize_idx(
const int64_t index,
const int64_t max,
const int64_t default_value) {
// INT64_MAX is passed when value is unspecified
if (index == INT64_MAX) {
return default_value;
}
if (index == default_value) {
return index;
}
return normalize(index, max);
}

void add_slice_tensor_out_node(
ComputeGraph& graph,
ValueRef in,
Expand Down Expand Up @@ -57,8 +71,8 @@ void add_slice_tensor_out_node(
int64_t start = opt_start.value_or(0);
int64_t end = opt_end.value_or(in_sizes[dim]);

VK_CHECK_COND((0 <= start) && (start < in_sizes[dim]));
VK_CHECK_COND((0 <= end) && (end <= in_sizes[dim]));
start = normalize_idx(start, in_sizes[dim], 0);
end = normalize_idx(end, in_sizes[dim], in_sizes[dim]);

if (dim_index == kChannel4D) {
// slice by channel
Expand Down
17 changes: 17 additions & 0 deletions backends/vulkan/test/op_tests/cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,23 @@ def get_slice_inputs():
Test(self=[13, 1, 10], dim=0, start=1, step=20),
]

# Slice by negative/unspecified indices
INT64_MAX = 9223372036854775807 # represents arr[:]
test_cases += [
Test(self=[8, 9], dim=0, start=-2, step=1),
Test(self=[8, 9], dim=0, start=-2, step=2),
Test(self=[8, 9], dim=0, end=-2, step=1),
Test(self=[8, 9], dim=0, end=-2, step=2),
Test(self=[8, 9], dim=0, end=INT64_MAX, step=1),
Test(self=[8, 9], dim=0, end=INT64_MAX, step=2),
Test(self=[8, 9], dim=1, start=-2, step=1),
Test(self=[8, 9], dim=1, start=-2, step=2),
Test(self=[8, 9], dim=1, end=-2, step=1),
Test(self=[8, 9], dim=1, end=-2, step=2),
Test(self=[8, 9], dim=1, end=INT64_MAX, step=1),
Test(self=[8, 9], dim=1, end=INT64_MAX, step=2),
]

test_suite = VkTestSuite([tuple(tc) for tc in test_cases])

test_suite.dtypes = ["at::kFloat"]
Expand Down
Loading