Skip to content

Commit 7cecd73

Browse files
nathanaelseefacebook-github-bot
authored andcommitted
add handling of out-of-range indices to Slice (#3689)
Summary: Seeing some ops in graph like: ```aten edge dialect new_k[-self.left_context :, :, :], slice_66: "f32[10, 1, 256]" = torch.ops.aten.slice.Tensor(cat_8, 0, -10, 9223372036854775807); cat_8 = None ``` Negative indices and 9223372036854775807 are valid inputs to `start` and `end` params on slice op, but runtime checks in Slice.cpp don't accept them. (9223372036854775807 is the max value of signed int_64; it maps to the index not being provided.) Adding code to compute the real values to the range [0, size(dim)) Differential Revision: D57597106
1 parent 4b7c6db commit 7cecd73

File tree

2 files changed

+38
-2
lines changed

2 files changed

+38
-2
lines changed

backends/vulkan/runtime/graph/ops/impl/Slice.cpp

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,27 @@ void add_slice_tensor_out_node(
5757
int64_t start = opt_start.value_or(0);
5858
int64_t end = opt_end.value_or(in_sizes[dim]);
5959

60-
VK_CHECK_COND((0 <= start) && (start < in_sizes[dim]));
61-
VK_CHECK_COND((0 <= end) && (end <= in_sizes[dim]));
60+
// if start or end is < 0, calculate real index
61+
start = start < 0 ? start + in_sizes[dim] : start;
62+
end = end < 0 ? end + in_sizes[dim] : end;
63+
64+
// if end is LONG_MAX, we are slicing to the end of the dim
65+
if (end == LONG_MAX) {
66+
end = in_sizes[dim];
67+
}
68+
69+
VK_CHECK_COND(
70+
(0 <= start) && (start < in_sizes[dim]),
71+
"start must be in range of [0, self.size(dim)), but current start's value is ",
72+
start,
73+
" and self.size(dim) = ",
74+
in_sizes[dim]);
75+
VK_CHECK_COND(
76+
(0 <= end) && (end <= in_sizes[dim]),
77+
"end must be in range of [0, self.size(dim)), but current end's value is ",
78+
end,
79+
" and self.size(dim) = ",
80+
in_sizes[dim]);
6281

6382
if (dim_index == kChannel4D) {
6483
// slice by channel

backends/vulkan/test/op_tests/cases.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,23 @@ def get_slice_inputs():
382382
Test(self=[13, 1, 10], dim=0, start=1, step=20),
383383
]
384384

385+
# Slice by negative/unspecified indices
386+
LONG_MAX = 9223372036854775807 # represents arr[:]
387+
test_cases += [
388+
Test(self=[8, 9], dim=0, start=-2, step=1),
389+
Test(self=[8, 9], dim=0, start=-2, step=2),
390+
Test(self=[8, 9], dim=0, end=-2, step=1),
391+
Test(self=[8, 9], dim=0, end=-2, step=2),
392+
Test(self=[8, 9], dim=0, end=LONG_MAX, step=1),
393+
Test(self=[8, 9], dim=0, end=LONG_MAX, step=2),
394+
Test(self=[8, 9], dim=1, start=-2, step=1),
395+
Test(self=[8, 9], dim=1, start=-2, step=2),
396+
Test(self=[8, 9], dim=1, end=-2, step=1),
397+
Test(self=[8, 9], dim=1, end=-2, step=2),
398+
Test(self=[8, 9], dim=1, end=LONG_MAX, step=1),
399+
Test(self=[8, 9], dim=1, end=LONG_MAX, step=2),
400+
]
401+
385402
test_suite = VkTestSuite([tuple(tc) for tc in test_cases])
386403

387404
test_suite.dtypes = ["at::kFloat"]

0 commit comments

Comments
 (0)