Skip to content

Commit 258fe36

Browse files
committed
address comments
1 parent 6449353 commit 258fe36

File tree

8 files changed

+90
-136
lines changed

8 files changed

+90
-136
lines changed

dpnp/dpnp_array.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -600,8 +600,12 @@ def astype(self, dtype, order="K", casting="unsafe", subok=True, copy=True):
600600
"""
601601
Copy the array with data type casting.
602602
603+
For full documentation refer to :obj:`numpy.ndarray.astype`.
604+
603605
Parameters
604606
----------
607+
x1 : {dpnp.ndarray, usm_ndarray}
608+
Array data type casting.
605609
dtype : dtype
606610
Target data type.
607611
order : {'C', 'F', 'A', 'K'}
@@ -611,12 +615,23 @@ def astype(self, dtype, order="K", casting="unsafe", subok=True, copy=True):
611615
copy : bool
612616
If it is False and no cast happens, then this method returns the array itself.
613617
Otherwise, a copy is returned.
618+
casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional
619+
Controls what kind of data casting may occur. Defaults to 'unsafe' for backwards compatibility.
620+
'no' means the data types should not be cast at all.
621+
'equiv' means only byte-order changes are allowed.
622+
'safe' means only casts which can preserve values are allowed.
623+
'same_kind' means only safe casts or casts within a kind, like float64 to float32, are allowed.
624+
'unsafe' means any data conversions may be done.
625+
copy : bool, optional
626+
By default, astype always returns a newly allocated array. If this is set to false, and the dtype,
627+
order, and subok requirements are satisfied, the input array is returned instead of a copy.
614628
615629
Returns
616630
-------
617-
out : dpnp.ndarray
618-
If ``copy`` is False and no cast is required, then the array itself is returned.
619-
Otherwise, it returns a (possibly casted) copy of the array.
631+
arr_t : dpnp.ndarray
632+
Unless `copy` is ``False`` and the other conditions for returning the input array
633+
are satisfied, `arr_t` is a new array of the same shape as the input array,
634+
with dtype, order given by dtype, order.
620635
621636
Limitations
622637
-----------
@@ -634,9 +649,12 @@ def astype(self, dtype, order="K", casting="unsafe", subok=True, copy=True):
634649
635650
"""
636651

637-
return dpnp.astype(
638-
self, dtype, order=order, casting=casting, subok=subok, copy=copy
639-
)
652+
if subok is not True:
653+
raise NotImplementedError(
654+
f"subok={subok} is currently not supported"
655+
)
656+
657+
return dpnp.astype(self, dtype, order=order, casting=casting, copy=copy)
640658

641659
# 'base',
642660
# 'byteswap',

dpnp/dpnp_iface.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -150,12 +150,14 @@ def asnumpy(input, order="C"):
150150
return numpy.asarray(input, order=order)
151151

152152

153-
def astype(x1, dtype, order="K", casting="unsafe", subok=True, copy=True):
153+
def astype(x1, dtype, order="K", casting="unsafe", copy=True):
154154
"""
155155
Copy the array with data type casting.
156156
157157
Parameters
158158
----------
159+
x1 : {dpnp.ndarray, usm_ndarray}
160+
Array data type casting.
159161
dtype : dtype
160162
Target data type.
161163
order : {'C', 'F', 'A', 'K'}
@@ -165,6 +167,16 @@ def astype(x1, dtype, order="K", casting="unsafe", subok=True, copy=True):
165167
copy : bool
166168
If it is False and no cast happens, then this method returns the array itself.
167169
Otherwise, a copy is returned.
170+
casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional
171+
Controls what kind of data casting may occur. Defaults to 'unsafe' for backwards compatibility.
172+
'no' means the data types should not be cast at all.
173+
'equiv' means only byte-order changes are allowed.
174+
'safe' means only casts which can preserve values are allowed.
175+
'same_kind' means only safe casts or casts within a kind, like float64 to float32, are allowed.
176+
'unsafe' means any data conversions may be done.
177+
copy : bool, optional
178+
By default, astype always returns a newly allocated array. If this is set to false, and the dtype,
179+
order, and subok requirements are satisfied, the input array is returned instead of a copy.
168180
169181
Returns
170182
-------
@@ -173,18 +185,8 @@ def astype(x1, dtype, order="K", casting="unsafe", subok=True, copy=True):
173185
are satisfied, `arr_t` is a new array of the same shape as the input array,
174186
with dtype, order given by dtype, order.
175187
176-
Limitations
177-
-----------
178-
Parameter `subok` is supported with default value.
179-
Otherwise ``NotImplementedError`` exception will be raised.
180-
181188
"""
182189

183-
if subok is not True:
184-
raise NotImplementedError(
185-
f"Requested function={function.__name__} with subok={subok} isn't currently supported"
186-
)
187-
188190
if order is None:
189191
order = "K"
190192

tests/skipped_tests.tbl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,14 +111,12 @@ tests/test_umath.py::test_umaths[('spacing', 'd')]
111111
tests/third_party/cupy/core_tests/test_ndarray_complex_ops.py::TestAngle::test_angle
112112
tests/third_party/cupy/core_tests/test_ndarray_complex_ops.py::TestRealImag::test_imag_inplace
113113
tests/third_party/cupy/core_tests/test_ndarray_complex_ops.py::TestRealImag::test_real_inplace
114-
tests/third_party/cupy/core_tests/test_ndarray_complex_ops.py::TestScalarConversion::test_scalar_conversion
115114
tests/third_party/cupy/core_tests/test_ndarray_conversion.py::TestNdarrayToBytes_param_0_{shape=()}::test_item
116115
tests/third_party/cupy/core_tests/test_ndarray_conversion.py::TestNdarrayToBytes_param_1_{shape=(1,)}::test_item
117116
tests/third_party/cupy/core_tests/test_ndarray_conversion.py::TestNdarrayToBytes_param_2_{shape=(2, 3)}::test_item
118117
tests/third_party/cupy/core_tests/test_ndarray_conversion.py::TestNdarrayToBytes_param_3_{order='C', shape=(2, 3)}::test_item
119118
tests/third_party/cupy/core_tests/test_ndarray_conversion.py::TestNdarrayToBytes_param_4_{order='F', shape=(2, 3)}::test_item
120119

121-
tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayAsType::test_astype_strides_broadcast
122120
tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayDiagonal::test_diagonal1
123121
tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayDiagonal::test_diagonal2
124122
tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayFlatten::test_flatten_order

tests/skipped_tests_gpu.tbl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -240,9 +240,7 @@ tests/test_linalg.py::test_matrix_rank[None-[[1, 2], [3, 4]]-int32]
240240
tests/third_party/cupy/core_tests/test_ndarray_complex_ops.py::TestAngle::test_angle
241241
tests/third_party/cupy/core_tests/test_ndarray_complex_ops.py::TestRealImag::test_imag_inplace
242242
tests/third_party/cupy/core_tests/test_ndarray_complex_ops.py::TestRealImag::test_real_inplace
243-
tests/third_party/cupy/core_tests/test_ndarray_complex_ops.py::TestScalarConversion::test_scalar_conversion
244243

245-
tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayAsType::test_astype_strides_broadcast
246244
tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayDiagonal::test_diagonal1
247245
tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayDiagonal::test_diagonal2
248246
tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayFlatten::test_flatten_order

tests/test_dparray.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
)
1414

1515

16+
@pytest.mark.usefixtures("suppress_complex_warning")
1617
@pytest.mark.parametrize("res_dtype", get_all_dtypes())
1718
@pytest.mark.parametrize("arr_dtype", get_all_dtypes())
1819
@pytest.mark.parametrize(
@@ -28,6 +29,12 @@ def test_astype(arr, arr_dtype, res_dtype):
2829
assert_allclose(expected, result)
2930

3031

32+
def test_astype_subok_error():
33+
x = dpnp.ones((4))
34+
with pytest.raises(NotImplementedError):
35+
x.astype("i4", subok=False)
36+
37+
3138
@pytest.mark.parametrize("arr_dtype", get_all_dtypes())
3239
@pytest.mark.parametrize(
3340
"arr",

tests/test_mathematical.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1647,6 +1647,7 @@ def test_sum_empty_out(dtype):
16471647
assert_array_equal(out.asnumpy(), numpy.array(0, dtype=dtype))
16481648

16491649

1650+
@pytest.mark.usefixtures("suppress_complex_warning")
16501651
@pytest.mark.parametrize(
16511652
"shape",
16521653
[

0 commit comments

Comments
 (0)