Skip to content

Commit 299667d

Browse files
authored
Native astype (#642)
* start change
1 parent 92192ad commit 299667d

File tree

9 files changed

+143
-4
lines changed

9 files changed

+143
-4
lines changed

dpnp/backend/include/dpnp_iface.hpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,17 @@ INP_DLLEXPORT void dpnp_any_c(const void* array, void* result, const size_t size
141141
template <typename _DataType>
142142
INP_DLLEXPORT void dpnp_arange_c(size_t start, size_t step, void* result1, size_t size);
143143

144+
/**
145+
* @ingroup BACKEND_API
146+
* @brief Copy of the array, cast to a specified type.
147+
*
148+
* @param [in] array Input array.
149+
* @param [out] result Output array.
150+
* @param [in] size Number of input elements in `array`.
151+
*/
152+
template <typename _DataType, typename _ResultType>
153+
INP_DLLEXPORT void dpnp_astype_c(const void* array, void* result, const size_t size);
154+
144155
/**
145156
* @ingroup BACKEND_API
146157
* @brief Implementation of full function

dpnp/backend/include/dpnp_iface_fptr.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ enum class DPNPFuncName : size_t
7373
DPNP_FN_ARGMAX, /**< Used in numpy.argmax() implementation */
7474
DPNP_FN_ARGMIN, /**< Used in numpy.argmin() implementation */
7575
DPNP_FN_ARGSORT, /**< Used in numpy.argsort() implementation */
76+
DPNP_FN_ASTYPE, /**< Used in numpy.astype() implementation */
7677
DPNP_FN_BITWISE_AND, /**< Used in numpy.bitwise_and() implementation */
7778
DPNP_FN_BITWISE_OR, /**< Used in numpy.bitwise_or() implementation */
7879
DPNP_FN_BITWISE_XOR, /**< Used in numpy.bitwise_xor() implementation */

dpnp/backend/kernels/dpnp_krnl_common.cpp

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,42 @@
3535
namespace mkl_blas = oneapi::mkl::blas;
3636
namespace mkl_lapack = oneapi::mkl::lapack;
3737

38+
template <typename _DataType, typename _ResultType>
39+
class dpnp_astype_c_kernel;
40+
41+
template <typename _DataType, typename _ResultType>
42+
void dpnp_astype_c(const void* array1_in, void* result1, const size_t size)
43+
{
44+
cl::sycl::event event;
45+
46+
const _DataType* array_in = reinterpret_cast<const _DataType*>(array1_in);
47+
_ResultType* result = reinterpret_cast<_ResultType*>(result1);
48+
49+
if ((array_in == nullptr) || (result == nullptr))
50+
{
51+
return;
52+
}
53+
54+
if (size == 0)
55+
{
56+
return;
57+
}
58+
59+
cl::sycl::range<1> gws(size);
60+
auto kernel_parallel_for_func = [=](cl::sycl::id<1> global_id) {
61+
size_t i = global_id[0];
62+
result[i] = array_in[i];
63+
};
64+
65+
auto kernel_func = [&](cl::sycl::handler& cgh) {
66+
cgh.parallel_for<class dpnp_astype_c_kernel<_DataType, _ResultType>>(gws, kernel_parallel_for_func);
67+
};
68+
69+
event = DPNP_QUEUE.submit(kernel_func);
70+
71+
event.wait();
72+
}
73+
3874
template <typename _KernelNameSpecialization1, typename _KernelNameSpecialization2, typename _KernelNameSpecialization3>
3975
class dpnp_dot_c_kernel;
4076

@@ -324,6 +360,33 @@ void dpnp_matmul_c(void* array1_in, void* array2_in, void* result1, size_t size_
324360

325361
void func_map_init_linalg(func_map_t& fmap)
326362
{
363+
fmap[DPNPFuncName::DPNP_FN_ASTYPE][eft_BLN][eft_BLN] = {eft_BLN, (void*)dpnp_astype_c<bool, bool>};
364+
fmap[DPNPFuncName::DPNP_FN_ASTYPE][eft_BLN][eft_INT] = {eft_INT, (void*)dpnp_astype_c<bool, int>};
365+
fmap[DPNPFuncName::DPNP_FN_ASTYPE][eft_BLN][eft_LNG] = {eft_LNG, (void*)dpnp_astype_c<bool, long>};
366+
fmap[DPNPFuncName::DPNP_FN_ASTYPE][eft_BLN][eft_FLT] = {eft_FLT, (void*)dpnp_astype_c<bool, float>};
367+
fmap[DPNPFuncName::DPNP_FN_ASTYPE][eft_BLN][eft_DBL] = {eft_DBL, (void*)dpnp_astype_c<bool, double>};
368+
fmap[DPNPFuncName::DPNP_FN_ASTYPE][eft_INT][eft_BLN] = {eft_BLN, (void*)dpnp_astype_c<int, bool>};
369+
fmap[DPNPFuncName::DPNP_FN_ASTYPE][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_astype_c<int, int>};
370+
fmap[DPNPFuncName::DPNP_FN_ASTYPE][eft_INT][eft_LNG] = {eft_LNG, (void*)dpnp_astype_c<int, long>};
371+
fmap[DPNPFuncName::DPNP_FN_ASTYPE][eft_INT][eft_FLT] = {eft_FLT, (void*)dpnp_astype_c<int, float>};
372+
fmap[DPNPFuncName::DPNP_FN_ASTYPE][eft_INT][eft_DBL] = {eft_DBL, (void*)dpnp_astype_c<int, double>};
373+
fmap[DPNPFuncName::DPNP_FN_ASTYPE][eft_LNG][eft_BLN] = {eft_BLN, (void*)dpnp_astype_c<long, bool>};
374+
fmap[DPNPFuncName::DPNP_FN_ASTYPE][eft_LNG][eft_INT] = {eft_INT, (void*)dpnp_astype_c<long, int>};
375+
fmap[DPNPFuncName::DPNP_FN_ASTYPE][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_astype_c<long, long>};
376+
fmap[DPNPFuncName::DPNP_FN_ASTYPE][eft_LNG][eft_FLT] = {eft_FLT, (void*)dpnp_astype_c<long, float>};
377+
fmap[DPNPFuncName::DPNP_FN_ASTYPE][eft_LNG][eft_DBL] = {eft_DBL, (void*)dpnp_astype_c<long, double>};
378+
fmap[DPNPFuncName::DPNP_FN_ASTYPE][eft_FLT][eft_BLN] = {eft_BLN, (void*)dpnp_astype_c<float, bool>};
379+
fmap[DPNPFuncName::DPNP_FN_ASTYPE][eft_FLT][eft_INT] = {eft_INT, (void*)dpnp_astype_c<float, int>};
380+
fmap[DPNPFuncName::DPNP_FN_ASTYPE][eft_FLT][eft_LNG] = {eft_LNG, (void*)dpnp_astype_c<float, long>};
381+
fmap[DPNPFuncName::DPNP_FN_ASTYPE][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_astype_c<float, float>};
382+
fmap[DPNPFuncName::DPNP_FN_ASTYPE][eft_FLT][eft_DBL] = {eft_DBL, (void*)dpnp_astype_c<float, double>};
383+
fmap[DPNPFuncName::DPNP_FN_ASTYPE][eft_DBL][eft_BLN] = {eft_BLN, (void*)dpnp_astype_c<double, bool>};
384+
fmap[DPNPFuncName::DPNP_FN_ASTYPE][eft_DBL][eft_INT] = {eft_INT, (void*)dpnp_astype_c<double, int>};
385+
fmap[DPNPFuncName::DPNP_FN_ASTYPE][eft_DBL][eft_LNG] = {eft_LNG, (void*)dpnp_astype_c<double, long>};
386+
fmap[DPNPFuncName::DPNP_FN_ASTYPE][eft_DBL][eft_FLT] = {eft_FLT, (void*)dpnp_astype_c<double, float>};
387+
fmap[DPNPFuncName::DPNP_FN_ASTYPE][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_astype_c<double, double>};
388+
fmap[DPNPFuncName::DPNP_FN_ASTYPE][eft_C128][eft_C128] = {eft_C128, (void*)dpnp_astype_c<std::complex<double>, std::complex<double>>};
389+
327390
fmap[DPNPFuncName::DPNP_FN_DOT][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_dot_c<int, int, int>};
328391
fmap[DPNPFuncName::DPNP_FN_DOT][eft_INT][eft_LNG] = {eft_LNG, (void*)dpnp_dot_c<int, long, long>};
329392
fmap[DPNPFuncName::DPNP_FN_DOT][eft_INT][eft_FLT] = {eft_DBL, (void*)dpnp_dot_c<int, float, double>};

dpnp/dpnp_algo/dpnp_algo.pxd

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
4646
DPNP_FN_ARGMAX
4747
DPNP_FN_ARGMIN
4848
DPNP_FN_ARGSORT
49+
DPNP_FN_ASTYPE
4950
DPNP_FN_BITWISE_AND
5051
DPNP_FN_BITWISE_OR
5152
DPNP_FN_BITWISE_XOR

dpnp/dpnp_algo/dpnp_algo.pyx

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ include "dpnp_algo_trigonometric.pyx"
6969

7070

7171
ctypedef void(*fptr_dpnp_arange_t)(size_t, size_t, void *, size_t)
72+
ctypedef void(*fptr_dpnp_astype_t)(const void *, void * , const size_t)
7273
ctypedef void(*fptr_dpnp_initval_t)(void *, void * , size_t)
7374

7475

@@ -125,10 +126,16 @@ cpdef dparray dpnp_array(obj, dtype=None):
125126

126127

127128
cpdef dparray dpnp_astype(dparray array1, dtype_target):
128-
cdef dparray result = dparray(array1.shape, dtype=dtype_target)
129+
cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(array1.dtype)
130+
cdef DPNPFuncType param2_type = dpnp_dtype_to_DPNPFuncType(dtype_target)
129131

130-
for i in range(result.size):
131-
result[i] = array1[i]
132+
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_ASTYPE, param1_type, param2_type)
133+
134+
result_type = dpnp_DPNPFuncType_to_dtype(< size_t > kernel_data.return_type)
135+
cdef dparray result = dparray(array1.shape, dtype=result_type)
136+
137+
cdef fptr_dpnp_astype_t func = <fptr_dpnp_astype_t > kernel_data.ptr
138+
func(array1.get_data(), result.get_data(), array1.size)
132139

133140
return result
134141

tests/skipped_tests.tbl

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,21 @@
1+
tests/test_dparray.py::test_astype[[-2, -1, 0, 1, 2]-complex-float64]
2+
tests/test_dparray.py::test_astype[[-2, -1, 0, 1, 2]-complex-float32]
3+
tests/test_dparray.py::test_astype[[-2, -1, 0, 1, 2]-complex-int64]
4+
tests/test_dparray.py::test_astype[[-2, -1, 0, 1, 2]-complex-int32]
5+
tests/test_dparray.py::test_astype[[-2, -1, 0, 1, 2]-complex-bool]
6+
tests/test_dparray.py::test_astype[[-2, -1, 0, 1, 2]-complex-bool_]
7+
tests/test_dparray.py::test_astype[[[-2, -1], [1, 2]]-complex-float64]
8+
tests/test_dparray.py::test_astype[[[-2, -1], [1, 2]]-complex-float32]
9+
tests/test_dparray.py::test_astype[[[-2, -1], [1, 2]]-complex-int64]
10+
tests/test_dparray.py::test_astype[[[-2, -1], [1, 2]]-complex-int32]
11+
tests/test_dparray.py::test_astype[[[-2, -1], [1, 2]]-complex-bool]
12+
tests/test_dparray.py::test_astype[[[-2, -1], [1, 2]]-complex-bool_]
13+
tests/test_dparray.py::test_astype[[]-complex-float64]
14+
tests/test_dparray.py::test_astype[[]-complex-float32]
15+
tests/test_dparray.py::test_astype[[]-complex-int64]
16+
tests/test_dparray.py::test_astype[[]-complex-int32]
17+
tests/test_dparray.py::test_astype[[]-complex-bool]
18+
tests/test_dparray.py::test_astype[[]-complex-bool_]
119
tests/test_linalg.py::test_cond[-1-[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]
220
tests/test_linalg.py::test_cond[1-[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]
321
tests/test_linalg.py::test_cond[-2-[[1, 0, -1], [0, 1, 0], [1, 0, 1]]]
@@ -154,7 +172,6 @@ tests/test_linalg.py::test_svd[(5,3)-complex128]
154172
tests/test_random.py::TestPermutationsTestShuffle::test_shuffle1[ lambda x: x]
155173
tests/test_random.py::TestPermutationsTestShuffle::test_shuffle1[ lambda x: dpnp.asarray(x).astype(dpnp.int8)]
156174
tests/test_random.py::TestPermutationsTestShuffle::test_shuffle1[ lambda x: dpnp.asarray(x).astype(dpnp.complex64)]
157-
tests/test_random.py::TestPermutationsTestShuffle::test_shuffle1[ lambda x: dpnp.asarray(x).astype(object)]
158175
tests/test_random.py::TestPermutationsTestShuffle::test_shuffle1[ lambda x: [(i, i) for i in x]]
159176
tests/test_random.py::TestPermutationsTestShuffle::test_shuffle1[ lambda x: dpnp.vstack([x, x]).T]
160177
tests/test_random.py::TestPermutationsTestShuffle::test_shuffle1[ lambda x: (dpnp.asarray([(i, i) for i in x], [("a", int), ("b", int)]).view(dpnp.recarray))]

tests/skipped_tests_gpu.tbl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,21 @@
1+
tests/test_dparray.py::test_astype[[-2, -1, 0, 1, 2]-complex-float64]
2+
tests/test_dparray.py::test_astype[[-2, -1, 0, 1, 2]-complex-float32]
3+
tests/test_dparray.py::test_astype[[-2, -1, 0, 1, 2]-complex-int64]
4+
tests/test_dparray.py::test_astype[[-2, -1, 0, 1, 2]-complex-int32]
5+
tests/test_dparray.py::test_astype[[-2, -1, 0, 1, 2]-complex-bool]
6+
tests/test_dparray.py::test_astype[[-2, -1, 0, 1, 2]-complex-bool_]
7+
tests/test_dparray.py::test_astype[[[-2, -1], [1, 2]]-complex-float64]
8+
tests/test_dparray.py::test_astype[[[-2, -1], [1, 2]]-complex-float32]
9+
tests/test_dparray.py::test_astype[[[-2, -1], [1, 2]]-complex-int64]
10+
tests/test_dparray.py::test_astype[[[-2, -1], [1, 2]]-complex-int32]
11+
tests/test_dparray.py::test_astype[[[-2, -1], [1, 2]]-complex-bool]
12+
tests/test_dparray.py::test_astype[[[-2, -1], [1, 2]]-complex-bool_]
13+
tests/test_dparray.py::test_astype[[]-complex-float64]
14+
tests/test_dparray.py::test_astype[[]-complex-float32]
15+
tests/test_dparray.py::test_astype[[]-complex-int64]
16+
tests/test_dparray.py::test_astype[[]-complex-int32]
17+
tests/test_dparray.py::test_astype[[]-complex-bool]
18+
tests/test_dparray.py::test_astype[[]-complex-bool_]
119
tests/test_dot.py::test_dot_arange[float32]
220
tests/test_dot.py::test_dot_arange[float64]
321
tests/test_dot.py::test_dot_ones[float32]

tests/test_dparray.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import dpnp
2+
import numpy
3+
import pytest
4+
5+
6+
@pytest.mark.parametrize("res_dtype",
7+
[numpy.float64, numpy.float32, numpy.int64, numpy.int32, numpy.bool, numpy.bool_, numpy.complex],
8+
ids=['float64', 'float32', 'int64', 'int32', 'bool', 'bool_', 'complex'])
9+
@pytest.mark.parametrize("arr_dtype",
10+
[numpy.float64, numpy.float32, numpy.int64, numpy.int32, numpy.bool, numpy.bool_, numpy.complex],
11+
ids=['float64', 'float32', 'int64', 'int32', 'bool', 'bool_', 'complex'])
12+
@pytest.mark.parametrize("arr",
13+
[[-2, -1, 0, 1, 2], [[-2, -1], [1, 2]], []],
14+
ids=['[-2, -1, 0, 1, 2]', '[[-2, -1], [1, 2]]', '[]'])
15+
def test_astype(arr, arr_dtype, res_dtype):
16+
numpy_array = numpy.array(arr, dtype=arr_dtype)
17+
dpnp_array = dpnp.array(numpy_array)
18+
expected = numpy_array.astype(res_dtype)
19+
result = dpnp_array.astype(res_dtype)
20+
numpy.testing.assert_array_equal(expected, result)

tests_external/skipped_tests_numpy.tbl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2881,6 +2881,7 @@ tests/test_umath.py::test_rint_big_int
28812881
tests/test_umath.py::TestRoundingFunctions::test_object_direct
28822882
tests/test_umath.py::TestRoundingFunctions::test_object_indirect
28832883
tests/test_umath.py::test_signaling_nan_exceptions
2884+
tests/test_umath.py::TestSign::test_sign_dtype_nan_object
28842885
tests/test_umath.py::TestSign::test_sign
28852886
tests/test_umath.py::TestSign::test_sign_dtype_object
28862887
tests/test_umath.py::test_spacing

0 commit comments

Comments
 (0)