Skip to content

Commit 8ddc993

Browse files
Change dtype check in dpnp.asfarray and fix remarks
1 parent 0b3184f commit 8ddc993

File tree

2 files changed

+12
-12
lines changed

2 files changed

+12
-12
lines changed

dpnp/dpnp_iface_manipulation.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848

4949
import dpnp
5050
import numpy
51+
import numpy.core.numeric as _nx
5152

5253

5354
__all__ = [
@@ -82,15 +83,15 @@ def asfarray(x1, dtype=None):
8283
Notes
8384
-----
8485
This function works exactly the same as :obj:`dpnp.array`.
86+
If dtype is `None`, `bool` or one of the `int` dtypes, it is replaced with
87+
the default floating type in DPNP depending on device capabilities.
8588
8689
"""
8790

8891
x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_nondefault_queue=False)
8992
if x1_desc:
90-
# int types replaced with a floating type by default in DPNP
91-
# depending on device capabilities.
92-
if dtype is None or numpy.issubdtype(dtype, numpy.integer):
93-
dtype = dpnp.default_float_type()
93+
if dtype is None or not numpy.issubdtype(dtype, _nx.inexact):
94+
dtype = dpnp.default_float_type(sycl_queue=x1.sycl_queue)
9495

9596
# if type is the same then same object should be returned
9697
if x1_desc.dtype == dtype:

tests/test_arraymanipulation.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77

88
@pytest.mark.usefixtures("allow_fall_back_on_numpy")
9-
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True))
9+
@pytest.mark.parametrize("dtype", get_all_dtypes())
1010
@pytest.mark.parametrize(
1111
"data", [[1, 2, 3], [1.0, 2.0, 3.0]], ids=["[1, 2, 3]", "[1., 2., 3.]"]
1212
)
@@ -17,13 +17,12 @@ def test_asfarray(dtype, data):
1717
numpy.testing.assert_array_equal(result, expected)
1818

1919

20-
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True))
21-
@pytest.mark.parametrize(
22-
"data", [[1, 2, 3], [1.0, 2.0, 3.0]], ids=["[1, 2, 3]", "[1., 2., 3.]"]
23-
)
24-
def test_asfarray2(dtype, data):
25-
expected = numpy.asfarray(numpy.array(data), dtype)
26-
result = dpnp.asfarray(dpnp.array(data), dtype)
20+
@pytest.mark.parametrize("dtype", get_all_dtypes())
21+
@pytest.mark.parametrize("data", [[1.0, 2.0, 3.0]], ids=["[1., 2., 3.]"])
22+
@pytest.mark.parametrize("data_dtype", get_all_dtypes(no_none=True))
23+
def test_asfarray2(dtype, data, data_dtype):
24+
expected = numpy.asfarray(numpy.array(data, dtype=data_dtype), dtype)
25+
result = dpnp.asfarray(dpnp.array(data, dtype=data_dtype), dtype)
2726

2827
numpy.testing.assert_array_equal(result, expected)
2928

0 commit comments

Comments
 (0)