Skip to content

Commit 446ce05

Browse files
Address PR feedback
sort.cpp -> merge_sort.cpp, argsort.cpp -> merge_argsort.cpp Refined exception texts thrown when implementation function pointer is found missing.
1 parent d63dd70 commit 446ce05

File tree

8 files changed

+22
-17
lines changed

8 files changed

+22
-17
lines changed

dpctl/tensor/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,8 +112,8 @@ set(_reduction_sources
112112
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/sum.cpp
113113
)
114114
set(_sorting_sources
115-
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/sort.cpp
116-
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/argsort.cpp
115+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/merge_sort.cpp
116+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/merge_argsort.cpp
117117
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/searchsorted.cpp
118118
)
119119
set(_sorting_radix_sources

dpctl/tensor/libtensor/source/sorting/argsort.cpp renamed to dpctl/tensor/libtensor/source/sorting/merge_argsort.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
#include "kernels/sorting/sort_impl_fn_ptr_t.hpp"
3737
#include "rich_comparisons.hpp"
3838

39-
#include "argsort.hpp"
39+
#include "merge_argsort.hpp"
4040
#include "py_argsort_common.hpp"
4141

4242
namespace td_ns = dpctl::tensor::type_dispatch;
@@ -93,7 +93,7 @@ struct DescendingArgSortContigFactory
9393
}
9494
};
9595

96-
void init_argsort_dispatch_tables(void)
96+
void init_merge_argsort_dispatch_tables(void)
9797
{
9898
using dpctl::tensor::kernels::sort_contig_fn_ptr_t;
9999

@@ -108,9 +108,9 @@ void init_argsort_dispatch_tables(void)
108108
dtb2.populate_dispatch_table(descending_argsort_contig_dispatch_table);
109109
}
110110

111-
void init_argsort_functions(py::module_ m)
111+
void init_merge_argsort_functions(py::module_ m)
112112
{
113-
dpctl::tensor::py_internal::init_argsort_dispatch_tables();
113+
dpctl::tensor::py_internal::init_merge_argsort_dispatch_tables();
114114

115115
auto py_argsort_ascending = [](const dpctl::tensor::usm_ndarray &src,
116116
const int trailing_dims_to_sort,

dpctl/tensor/libtensor/source/sorting/sort.hpp renamed to dpctl/tensor/libtensor/source/sorting/merge_argsort.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ namespace tensor
3535
namespace py_internal
3636
{
3737

38-
extern void init_sort_functions(py::module_);
38+
extern void init_merge_argsort_functions(py::module_);
3939

4040
} // namespace py_internal
4141
} // namespace tensor

dpctl/tensor/libtensor/source/sorting/sort.cpp renamed to dpctl/tensor/libtensor/source/sorting/merge_sort.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,9 @@
3636
#include "kernels/sorting/merge_sort.hpp"
3737
#include "kernels/sorting/sort_impl_fn_ptr_t.hpp"
3838

39+
#include "merge_sort.hpp"
3940
#include "py_sort_common.hpp"
4041
#include "rich_comparisons.hpp"
41-
#include "sort.hpp"
4242

4343
namespace td_ns = dpctl::tensor::type_dispatch;
4444

@@ -76,7 +76,7 @@ template <typename fnT, typename argTy> struct DescendingSortContigFactory
7676
}
7777
};
7878

79-
void init_sort_dispatch_vectors(void)
79+
void init_merge_sort_dispatch_vectors(void)
8080
{
8181
using dpctl::tensor::kernels::sort_contig_fn_ptr_t;
8282

@@ -91,9 +91,9 @@ void init_sort_dispatch_vectors(void)
9191
dtv2.populate_dispatch_vector(descending_sort_contig_dispatch_vector);
9292
}
9393

94-
void init_sort_functions(py::module_ m)
94+
void init_merge_sort_functions(py::module_ m)
9595
{
96-
dpctl::tensor::py_internal::init_sort_dispatch_vectors();
96+
dpctl::tensor::py_internal::init_merge_sort_dispatch_vectors();
9797

9898
auto py_sort_ascending = [](const dpctl::tensor::usm_ndarray &src,
9999
const int trailing_dims_to_sort,

dpctl/tensor/libtensor/source/sorting/argsort.hpp renamed to dpctl/tensor/libtensor/source/sorting/merge_sort.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ namespace tensor
3535
namespace py_internal
3636
{
3737

38-
extern void init_argsort_functions(py::module_);
38+
extern void init_merge_sort_functions(py::module_);
3939

4040
} // namespace py_internal
4141
} // namespace tensor

dpctl/tensor/libtensor/source/sorting/py_argsort_common.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ py_argsort(const dpctl::tensor::usm_ndarray &src,
130130
auto fn = sort_contig_fns[src_typeid][dst_typeid];
131131

132132
if (fn == nullptr) {
133-
throw py::value_error("Not implemented for given index type");
133+
throw py::value_error("Not implemented for dtypes of input arrays");
134134
}
135135

136136
sycl::event comp_ev =

dpctl/tensor/libtensor/source/sorting/py_sort_common.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,11 @@ py_sort(const dpctl::tensor::usm_ndarray &src,
130130

131131
auto fn = sort_contig_fns[src_typeid];
132132

133+
if (nullptr == fn) {
134+
throw py::value_error(
135+
"Not implemented for the dtype of input arrays");
136+
}
137+
133138
sycl::event comp_ev =
134139
fn(exec_q, iter_nelems, sort_nelems, src.get_data(), dst.get_data(),
135140
zero_offset, zero_offset, zero_offset, zero_offset, depends);

dpctl/tensor/libtensor/source/tensor_sorting.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,15 +25,15 @@
2525

2626
#include <pybind11/pybind11.h>
2727

28-
#include "sorting/argsort.hpp"
28+
#include "sorting/merge_argsort.hpp"
29+
#include "sorting/merge_sort.hpp"
2930
#include "sorting/searchsorted.hpp"
30-
#include "sorting/sort.hpp"
3131

3232
namespace py = pybind11;
3333

3434
PYBIND11_MODULE(_tensor_sorting_impl, m)
3535
{
36-
dpctl::tensor::py_internal::init_sort_functions(m);
37-
dpctl::tensor::py_internal::init_argsort_functions(m);
36+
dpctl::tensor::py_internal::init_merge_sort_functions(m);
37+
dpctl::tensor::py_internal::init_merge_argsort_functions(m);
3838
dpctl::tensor::py_internal::init_searchsorted_functions(m);
3939
}

0 commit comments

Comments
 (0)