Skip to content

Commit 54889b5

Browse files
npolina4oleksandr-pavlyk
authored andcommitted
Added can_cast and result_type
1 parent 6120377 commit 54889b5

File tree

3 files changed

+93
-3
lines changed

3 files changed

+93
-3
lines changed

dpctl/tensor/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
2222
"""
2323

24-
from numpy import dtype, finfo, iinfo
24+
from numpy import dtype
2525

2626
from dpctl.tensor._copy_utils import asnumpy, astype, copy, from_numpy, to_numpy
2727
from dpctl.tensor._ctors import (
@@ -45,10 +45,14 @@
4545
from dpctl.tensor._manipulation_functions import (
4646
broadcast_arrays,
4747
broadcast_to,
48+
can_cast,
4849
concat,
4950
expand_dims,
51+
finfo,
5052
flip,
53+
iinfo,
5154
permute_dims,
55+
result_type,
5256
roll,
5357
squeeze,
5458
stack,
@@ -121,4 +125,6 @@
121125
"complex128",
122126
"iinfo",
123127
"finfo",
128+
"can_cast",
129+
"result_type",
124130
]

dpctl/tensor/_manipulation_functions.py

Lines changed: 59 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -309,8 +309,7 @@ def _arrays_validation(arrays):
309309
raise ValueError("All the input arrays must have usm_type")
310310

311311
X0 = arrays[0]
312-
if not all(Xi.dtype.char in "?bBhHiIlLqQefdFD" for Xi in arrays):
313-
raise ValueError("Unsupported dtype encountered.")
312+
_support_dtype(Xi.dtype for Xi in arrays)
314313

315314
res_dtype = X0.dtype
316315
for i in range(1, n):
@@ -421,3 +420,61 @@ def stack(arrays, axis=0):
421420
dpctl.SyclEvent.wait_for(hev_list)
422421

423422
return res
423+
424+
425+
def can_cast(array_and_dtype_from, dtype_to):
426+
"""
427+
can_cast(from: usm_ndarray or dtype, to: dtype) -> bool
428+
429+
Determines if one data type can be cast to another data type according \
430+
Type Promotion Rules rules.
431+
"""
432+
if not isinstance(dtype_to, dpt.dtype):
433+
raise TypeError("Expected dtype type.")
434+
435+
dtype_from = dpt.dtype(array_and_dtype_from)
436+
437+
_support_dtype([dtype_to, dtype_from])
438+
439+
return np.can_cast(dtype_from, dtype_to)
440+
441+
442+
def result_type(*arrays_and_dtypes):
443+
"""
444+
result_type(arrays_and_dtypes: an arbitrary number usm_ndarrays or dtypes)\
445+
-> dtype
446+
447+
Returns the dtype that results from applying the Type Promotion Rules to \
448+
the arguments.
449+
"""
450+
dtypes = [dpt.dtype(X) for X in arrays_and_dtypes]
451+
452+
_support_dtype(dtypes)
453+
454+
return np.result_type(*dtypes)
455+
456+
457+
def iinfo(type):
458+
"""
459+
iinfo(type: integer data-type) -> iinfo_object
460+
461+
Returns machine limits for integer data types.
462+
"""
463+
_support_dtype(type)
464+
return np.iinfo(type)
465+
466+
467+
def finfo(type):
468+
"""
469+
finfo(type: float data-type) -> finfo_object
470+
471+
Returns machine limits for float data types.
472+
"""
473+
_support_dtype(type)
474+
return np.finfo(type)
475+
476+
477+
def _support_dtype(dtypes):
478+
if not all(dtype.char in "?bBhHiIlLqQefdFD" for dtype in dtypes):
479+
raise ValueError("Unsupported dtype encountered.")
480+
return True

dpctl/tests/test_usm_ndarray_manipulation.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -999,3 +999,30 @@ def test_stack_3arrays(data):
999999
R = dpt.stack([X, Y, Z], axis=axis)
10001000

10011001
assert_array_equal(Rnp, dpt.asnumpy(R))
1002+
1003+
1004+
def test_can_cast():
1005+
try:
1006+
q = dpctl.SyclQueue()
1007+
except dpctl.SyclQueueCreationError:
1008+
pytest.skip("Queue could not be created")
1009+
1010+
# incorrect input
1011+
X = dpt.ones((2, 2), dtype=dpt.int64, sycl_queue=q)
1012+
pytest.raises(TypeError, dpt.can_cast, X, 1)
1013+
X_np = np.ones((2, 2), dtype=np.int64)
1014+
1015+
assert dpt.can_cast(X, dpt.int32) == np.can_cast(X_np, np.int32)
1016+
assert dpt.can_cast(X, dpt.int64) == np.can_cast(X_np, np.int64)
1017+
1018+
1019+
def test_result_type():
1020+
try:
1021+
q = dpctl.SyclQueue()
1022+
except dpctl.SyclQueueCreationError:
1023+
pytest.skip("Queue could not be created")
1024+
1025+
X = [dpt.ones((2), dtype=dpt.int64, sycl_queue=q), dpt.int32, dpt.float16]
1026+
X_np = [np.ones((2), dtype=np.int64), dpt.int32, dpt.float16]
1027+
1028+
assert dpt.result_type(*X) == np.result_type(*X_np)

0 commit comments

Comments
 (0)