Skip to content

Commit 16a7d97

Browse files
committed
Fixed gh-1272
1 parent 744a3f2 commit 16a7d97

File tree

3 files changed

+36
-5
lines changed

3 files changed

+36
-5
lines changed

dpnp/dpnp_iface.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,8 @@
6666
"dpnp_queue_is_cpu",
6767
"get_dpnp_descriptor",
6868
"get_include",
69-
"get_normalized_queue_device"
69+
"get_normalized_queue_device",
70+
"isarray"
7071
]
7172

7273
from dpnp import (
@@ -338,3 +339,13 @@ def get_normalized_queue_device(obj=None,
338339
if hasattr(dpt._device, 'normalize_queue_device'):
339340
return dpt._device.normalize_queue_device(sycl_queue=sycl_queue, device=device)
340341
return sycl_queue
342+
343+
344+
def isarray(obj):
345+
"""
346+
Return True if:
347+
`obj` has a supported array type
348+
Return False if:
349+
`obj` has an unsupported array type or other data type
350+
"""
351+
return isinstance(obj, (dpnp_array, dpt.usm_ndarray))

dpnp/dpnp_iface_arraycreation.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -552,7 +552,7 @@ def empty_like(x1,
552552
553553
"""
554554

555-
if not isinstance(x1, dpnp.ndarray):
555+
if not dpnp.isarray(x1):
556556
pass
557557
elif order not in ('C', 'c', 'F', 'f', None):
558558
pass
@@ -783,7 +783,7 @@ def full_like(x1,
783783
[1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
784784
785785
"""
786-
if not isinstance(x1, dpnp.ndarray):
786+
if not dpnp.isarray(x1):
787787
pass
788788
elif order not in ('C', 'c', 'F', 'f', None):
789789
pass
@@ -1211,7 +1211,7 @@ def ones_like(x1,
12111211
[1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
12121212
12131213
"""
1214-
if not isinstance(x1, dpnp.ndarray):
1214+
if not dpnp.isarray(x1):
12151215
pass
12161216
elif order not in ('C', 'c', 'F', 'f', None):
12171217
pass
@@ -1524,7 +1524,7 @@ def zeros_like(x1,
15241524
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
15251525
15261526
"""
1527-
if not isinstance(x1, dpnp.ndarray):
1527+
if not dpnp.isarray(x1):
15281528
pass
15291529
elif order not in ('C', 'c', 'F', 'f', None):
15301530
pass

tests/test_arraycreation.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -485,3 +485,23 @@ def test_ones_like(array, dtype, order):
485485
a = numpy.array(array)
486486
ia = dpnp.array(array)
487487
assert_array_equal(func(numpy, a), func(dpnp, ia))
488+
489+
490+
@pytest.mark.parametrize(
491+
"func, args",
492+
[
493+
pytest.param("full_like",
494+
['x0', '4']),
495+
pytest.param("zeros_like",
496+
['x0']),
497+
pytest.param("ones_like",
498+
['x0']),
499+
pytest.param("empty_like",
500+
['x0']),
501+
])
502+
def test_dpctl_tensor_input(func, args):
503+
x0 = dpt.reshape(dpt.arange(9), (3,3))
504+
new_args = [eval(val, {'x0' : x0}) for val in args]
505+
X = getattr(dpt, func)(*new_args)
506+
Y = getattr(dpnp, func)(*new_args)
507+
assert_array_equal(X, Y)

0 commit comments

Comments
 (0)