Skip to content

Commit 2c3f748

Browse files
Call operator of all indexers must return py::ssize_t
1 parent 67cab69 commit 2c3f748

File tree

1 file changed

+10
-11
lines changed

1 file changed

+10
-11
lines changed

dpctl/tensor/libtensor/include/utils/offset_utils.hpp

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -144,12 +144,12 @@ struct StridedIndexer
144144
{
145145
}
146146

147-
size_t operator()(py::ssize_t gid) const
147+
py::ssize_t operator()(py::ssize_t gid) const
148148
{
149149
return compute_offset(gid);
150150
}
151151

152-
size_t operator()(size_t gid) const
152+
py::ssize_t operator()(size_t gid) const
153153
{
154154
return compute_offset(static_cast<py::ssize_t>(gid));
155155
}
@@ -159,7 +159,7 @@ struct StridedIndexer
159159
py::ssize_t starting_offset;
160160
py::ssize_t const *shape_strides;
161161

162-
size_t compute_offset(py::ssize_t gid) const
162+
py::ssize_t compute_offset(py::ssize_t gid) const
163163
{
164164
using dpctl::tensor::strides::CIndexer_vector;
165165

@@ -185,12 +185,12 @@ struct UnpackedStridedIndexer
185185
{
186186
}
187187

188-
size_t operator()(py::ssize_t gid) const
188+
py::ssize_t operator()(py::ssize_t gid) const
189189
{
190190
return compute_offset(gid);
191191
}
192192

193-
size_t operator()(size_t gid) const
193+
py::ssize_t operator()(size_t gid) const
194194
{
195195
return compute_offset(static_cast<py::ssize_t>(gid));
196196
}
@@ -201,7 +201,7 @@ struct UnpackedStridedIndexer
201201
py::ssize_t const *shape;
202202
py::ssize_t const *strides;
203203

204-
size_t compute_offset(py::ssize_t gid) const
204+
py::ssize_t compute_offset(py::ssize_t gid) const
205205
{
206206
using dpctl::tensor::strides::CIndexer_vector;
207207

@@ -223,11 +223,10 @@ struct Strided1DIndexer
223223
{
224224
}
225225

226-
size_t operator()(size_t gid) const
226+
py::ssize_t operator()(size_t gid) const
227227
{
228228
// ensure 0 <= gid < size
229-
return static_cast<size_t>(offset +
230-
std::min<size_t>(gid, size - 1) * step);
229+
return offset + std::min<size_t>(gid, size - 1) * step;
231230
}
232231

233232
private:
@@ -245,9 +244,9 @@ struct Strided1DCyclicIndexer
245244
{
246245
}
247246

248-
size_t operator()(size_t gid) const
247+
py::ssize_t operator()(size_t gid) const
249248
{
250-
return static_cast<size_t>(offset + (gid % size) * step);
249+
return offset + (gid % size) * step;
251250
}
252251

253252
private:

0 commit comments

Comments
 (0)