Skip to content

Commit 8733db7

Browse files
nathanaelseefacebook-github-bot
authored andcommitted
add handling of out-of-range indices to Slice (#3689)
Summary: Seeing some ops in graph like: ```aten edge dialect new_k[-self.left_context :, :, :], slice_66: "f32[10, 1, 256]" = torch.ops.aten.slice.Tensor(cat_8, 0, -10, 9223372036854775807); cat_8 = None ``` Negative indices and 9223372036854775807 are valid inputs to `start` and `end` params on slice op, but runtime checks in Slice.cpp don't accept them. (9223372036854775807 is the max value of signed int_64; it maps to the index not being provided.) Adding code to compute the real values to the range [0, size(dim)) Reviewed By: jorgep31415 Differential Revision: D57597106
1 parent 79e9b79 commit 8733db7

File tree

2 files changed

+33
-2
lines changed

2 files changed

+33
-2
lines changed

backends/vulkan/runtime/graph/ops/impl/Slice.cpp

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,20 @@
1717

1818
namespace vkcompute {
1919

20+
inline int64_t normalize_idx(
21+
const int64_t index,
22+
const int64_t max,
23+
const int64_t default_value) {
24+
// INT64_MAX is passed when value is unspecified
25+
if (index == INT64_MAX) {
26+
return default_value;
27+
}
28+
if (index == default_value) {
29+
return index;
30+
}
31+
return normalize(index, max);
32+
}
33+
2034
void add_slice_tensor_out_node(
2135
ComputeGraph& graph,
2236
ValueRef in,
@@ -57,8 +71,8 @@ void add_slice_tensor_out_node(
5771
int64_t start = opt_start.value_or(0);
5872
int64_t end = opt_end.value_or(in_sizes[dim]);
5973

60-
VK_CHECK_COND((0 <= start) && (start < in_sizes[dim]));
61-
VK_CHECK_COND((0 <= end) && (end <= in_sizes[dim]));
74+
start = normalize_idx(start, in_sizes[dim], 0);
75+
end = normalize_idx(end, in_sizes[dim], in_sizes[dim]);
6276

6377
if (dim_index == kChannel4D) {
6478
// slice by channel

backends/vulkan/test/op_tests/cases.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,23 @@ def get_slice_inputs():
392392
Test(self=[13, 1, 10], dim=0, start=1, step=20),
393393
]
394394

395+
# Slice by negative/unspecified indices
396+
INT64_MAX = 9223372036854775807 # represents arr[:]
397+
test_cases += [
398+
Test(self=[8, 9], dim=0, start=-2, step=1),
399+
Test(self=[8, 9], dim=0, start=-2, step=2),
400+
Test(self=[8, 9], dim=0, end=-2, step=1),
401+
Test(self=[8, 9], dim=0, end=-2, step=2),
402+
Test(self=[8, 9], dim=0, end=INT64_MAX, step=1),
403+
Test(self=[8, 9], dim=0, end=INT64_MAX, step=2),
404+
Test(self=[8, 9], dim=1, start=-2, step=1),
405+
Test(self=[8, 9], dim=1, start=-2, step=2),
406+
Test(self=[8, 9], dim=1, end=-2, step=1),
407+
Test(self=[8, 9], dim=1, end=-2, step=2),
408+
Test(self=[8, 9], dim=1, end=INT64_MAX, step=1),
409+
Test(self=[8, 9], dim=1, end=INT64_MAX, step=2),
410+
]
411+
395412
test_suite = VkTestSuite([tuple(tc) for tc in test_cases])
396413

397414
test_suite.dtypes = ["at::kFloat"]

0 commit comments

Comments
 (0)