Skip to content

Commit 38fd39d

Browse files
npolina4antonwolfy
andauthored
Added device keyword argument to astype function (#1870)
* Added device keyword argument to astype function * Added test for astype function * address comments --------- Co-authored-by: Anton <[email protected]>
1 parent 6a737c4 commit 38fd39d

File tree

3 files changed

+47
-4
lines changed

3 files changed

+47
-4
lines changed

dpnp/dpnp_array.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -562,7 +562,15 @@ def asnumpy(self):
562562

563563
return dpt.asnumpy(self._array_obj)
564564

565-
def astype(self, dtype, order="K", casting="unsafe", subok=True, copy=True):
565+
def astype(
566+
self,
567+
dtype,
568+
order="K",
569+
casting="unsafe",
570+
subok=True,
571+
copy=True,
572+
device=None,
573+
):
566574
"""
567575
Copy the array with data type casting.
568576
@@ -597,6 +605,13 @@ def astype(self, dtype, order="K", casting="unsafe", subok=True, copy=True):
597605
this is set to ``False``, and the `dtype`, `order`, and `subok`
598606
requirements are satisfied, the input array is returned instead of
599607
a copy.
608+
device : {None, string, SyclDevice, SyclQueue}, optional
609+
An array API concept of device where the output array is created.
610+
The `device` can be ``None`` (the default), an OneAPI filter selector
611+
string, an instance of :class:`dpctl.SyclDevice` corresponding to
612+
a non-partitioned SYCL device, an instance of :class:`dpctl.SyclQueue`,
613+
or a `Device` object returned by
614+
:obj:`dpnp.dpnp_array.dpnp_array.device` property. Default: ``None``.
600615
601616
Returns
602617
-------
@@ -626,7 +641,9 @@ def astype(self, dtype, order="K", casting="unsafe", subok=True, copy=True):
626641
f"subok={subok} is currently not supported"
627642
)
628643

629-
return dpnp.astype(self, dtype, order=order, casting=casting, copy=copy)
644+
return dpnp.astype(
645+
self, dtype, order=order, casting=casting, copy=copy, device=device
646+
)
630647

631648
# 'base',
632649
# 'byteswap',

dpnp/dpnp_iface.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ def asnumpy(a, order="C"):
180180

181181

182182
# pylint: disable=redefined-outer-name
183-
def astype(x1, dtype, order="K", casting="unsafe", copy=True):
183+
def astype(x1, dtype, order="K", casting="unsafe", copy=True, device=None):
184184
"""
185185
Copy the array with data type casting.
186186
@@ -213,6 +213,13 @@ def astype(x1, dtype, order="K", casting="unsafe", copy=True):
213213
By default, ``astype`` always returns a newly allocated array. If this
214214
is set to ``False``, and the `dtype`, `order`, and `subok` requirements
215215
are satisfied, the input array is returned instead of a copy.
216+
device : {None, string, SyclDevice, SyclQueue}, optional
217+
An array API concept of device where the output array is created.
218+
The `device` can be ``None`` (the default), an OneAPI filter selector
219+
string, an instance of :class:`dpctl.SyclDevice` corresponding to
220+
a non-partitioned SYCL device, an instance of :class:`dpctl.SyclQueue`,
221+
or a `Device` object returned by
222+
:obj:`dpnp.dpnp_array.dpnp_array.device` property. Default: ``None``.
216223
217224
Returns
218225
-------
@@ -228,7 +235,7 @@ def astype(x1, dtype, order="K", casting="unsafe", copy=True):
228235

229236
x1_obj = dpnp.get_usm_ndarray(x1)
230237
array_obj = dpt.astype(
231-
x1_obj, dtype, order=order, casting=casting, copy=copy
238+
x1_obj, dtype, order=order, casting=casting, copy=copy, device=device
232239
)
233240

234241
# return x1 if dpctl returns a zero copy of x1_obj

tests/test_sycl_queue.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2211,3 +2211,22 @@ def test_histogram_bin_edges(weights, device):
22112211

22122212
edges_queue = result_edges.sycl_queue
22132213
assert_sycl_queue_equal(edges_queue, iv.sycl_queue)
2214+
2215+
2216+
@pytest.mark.parametrize(
2217+
"device_x",
2218+
valid_devices,
2219+
ids=[device.filter_string for device in valid_devices],
2220+
)
2221+
@pytest.mark.parametrize(
2222+
"device_y",
2223+
valid_devices,
2224+
ids=[device.filter_string for device in valid_devices],
2225+
)
2226+
def test_astype(device_x, device_y):
2227+
x = dpnp.array([1, 2, 3], dtype="i4", device=device_x)
2228+
y = dpnp.astype(x, dtype="f4")
2229+
assert_sycl_queue_equal(y.sycl_queue, x.sycl_queue)
2230+
sycl_queue = dpctl.SyclQueue(device_y)
2231+
y = dpnp.astype(x, dtype="f4", device=sycl_queue)
2232+
assert_sycl_queue_equal(y.sycl_queue, sycl_queue)

0 commit comments

Comments
 (0)