Skip to content

Commit 051e473

Browse files
Merge pull request #913 from IntelPython/add-dtypes-and-infos
Exported data types, finfo and iinfo symbols
2 parents 8cbed99 + 21a6aaa commit 051e473

8 files changed

+197
-59
lines changed

dpctl/tensor/__init__.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
2222
"""
2323

24+
from numpy import dtype
25+
2426
from dpctl.tensor._copy_utils import asnumpy, astype, copy, from_numpy, to_numpy
2527
from dpctl.tensor._ctors import (
2628
arange,
@@ -44,17 +46,36 @@
4446
from dpctl.tensor._manipulation_functions import (
4547
broadcast_arrays,
4648
broadcast_to,
49+
can_cast,
4750
concat,
4851
expand_dims,
52+
finfo,
4953
flip,
54+
iinfo,
5055
permute_dims,
56+
result_type,
5157
roll,
5258
squeeze,
5359
stack,
5460
)
5561
from dpctl.tensor._reshape import reshape
5662
from dpctl.tensor._usmarray import usm_ndarray
5763

64+
bool = dtype("bool")
65+
int8 = dtype("int8")
66+
int16 = dtype("int16")
67+
int32 = dtype("int32")
68+
int64 = dtype("int64")
69+
uint8 = dtype("uint8")
70+
uint16 = dtype("uint16")
71+
uint32 = dtype("uint32")
72+
uint64 = dtype("uint64")
73+
float16 = dtype("float16")
74+
float32 = dtype("float32")
75+
float64 = dtype("float64")
76+
complex64 = dtype("complex64")
77+
complex128 = dtype("complex128")
78+
5879
__all__ = [
5980
"Device",
6081
"usm_ndarray",
@@ -88,5 +109,24 @@
88109
"from_dlpack",
89110
"tril",
90111
"triu",
112+
"dtype",
113+
"bool",
114+
"int8",
115+
"uint8",
116+
"int16",
117+
"uint16",
118+
"int32",
119+
"uint32",
120+
"int64",
121+
"uint64",
122+
"float16",
123+
"float32",
124+
"float64",
125+
"complex64",
126+
"complex128",
127+
"iinfo",
128+
"finfo",
129+
"can_cast",
130+
"result_type",
91131
"meshgrid",
92132
]

dpctl/tensor/_copy_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def _copy_from_numpy(np_ary, usm_type="device", sycl_queue=None):
6464
dt = Xnp.dtype
6565
if dt.char in "dD" and alloc_q.sycl_device.has_aspect_fp64 is False:
6666
Xusm_dtype = (
67-
np.dtype("float32") if dt.char == "d" else np.dtype("complex64")
67+
dpt.dtype("float32") if dt.char == "d" else dpt.dtype("complex64")
6868
)
6969
else:
7070
Xusm_dtype = dt
@@ -318,8 +318,8 @@ def astype(usm_ary, newdtype, order="K", casting="unsafe", copy=True):
318318
"Recognized values are 'A', 'C', 'F', or 'K'"
319319
)
320320
ary_dtype = usm_ary.dtype
321-
target_dtype = np.dtype(newdtype)
322-
if not np.can_cast(ary_dtype, target_dtype, casting=casting):
321+
target_dtype = dpt.dtype(newdtype)
322+
if not dpt.can_cast(ary_dtype, target_dtype, casting=casting):
323323
raise TypeError(
324324
"Can not cast from {} to {} according to rule {}".format(
325325
ary_dtype, newdtype, casting

dpctl/tensor/_ctors.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -33,20 +33,20 @@ def _get_dtype(dtype, sycl_obj, ref_type=None):
3333
if dtype is None:
3434
if ref_type in [None, float] or np.issubdtype(ref_type, np.floating):
3535
dtype = ti.default_device_fp_type(sycl_obj)
36-
return np.dtype(dtype)
36+
return dpt.dtype(dtype)
3737
elif ref_type in [bool, np.bool_]:
3838
dtype = ti.default_device_bool_type(sycl_obj)
39-
return np.dtype(dtype)
39+
return dpt.dtype(dtype)
4040
elif ref_type is int or np.issubdtype(ref_type, np.integer):
4141
dtype = ti.default_device_int_type(sycl_obj)
42-
return np.dtype(dtype)
42+
return dpt.dtype(dtype)
4343
elif ref_type is complex or np.issubdtype(ref_type, np.complexfloating):
4444
dtype = ti.default_device_complex_type(sycl_obj)
45-
return np.dtype(dtype)
45+
return dpt.dtype(dtype)
4646
else:
4747
raise TypeError(f"Reference type {ref_type} not recognized.")
4848
else:
49-
return np.dtype(dtype)
49+
return dpt.dtype(dtype)
5050

5151

5252
def _array_info_dispatch(obj):
@@ -313,7 +313,7 @@ def asarray(
313313
)
314314
# 2. Check that dtype is None, or a valid dtype
315315
if dtype is not None:
316-
dtype = np.dtype(dtype)
316+
dtype = dpt.dtype(dtype)
317317
# 3. Validate order
318318
if not isinstance(order, str):
319319
raise TypeError(
@@ -768,7 +768,7 @@ def empty_like(
768768
device = x.device
769769
sycl_queue = normalize_queue_device(sycl_queue=sycl_queue, device=device)
770770
sh = x.shape
771-
dtype = np.dtype(dtype)
771+
dtype = dpt.dtype(dtype)
772772
res = dpt.usm_ndarray(
773773
sh,
774774
dtype=dtype,
@@ -825,7 +825,7 @@ def zeros_like(
825825
device = x.device
826826
sycl_queue = normalize_queue_device(sycl_queue=sycl_queue, device=device)
827827
sh = x.shape
828-
dtype = np.dtype(dtype)
828+
dtype = dpt.dtype(dtype)
829829
return zeros(
830830
sh,
831831
dtype=dtype,
@@ -882,7 +882,7 @@ def ones_like(
882882
device = x.device
883883
sycl_queue = normalize_queue_device(sycl_queue=sycl_queue, device=device)
884884
sh = x.shape
885-
dtype = np.dtype(dtype)
885+
dtype = dpt.dtype(dtype)
886886
return ones(
887887
sh,
888888
dtype=dtype,
@@ -946,7 +946,7 @@ def full_like(
946946
device = x.device
947947
sycl_queue = normalize_queue_device(sycl_queue=sycl_queue, device=device)
948948
sh = x.shape
949-
dtype = np.dtype(dtype)
949+
dtype = dpt.dtype(dtype)
950950
return full(
951951
sh,
952952
fill_value,
@@ -1026,7 +1026,7 @@ def linspace(
10261026
)
10271027
if dtype is None and np.issubdtype(dt, np.integer):
10281028
dt = ti.default_device_fp_type(sycl_queue)
1029-
dt = np.dtype(dt)
1029+
dt = dpt.dtype(dt)
10301030
start = float(start)
10311031
stop = float(stop)
10321032
res = dpt.empty(num, dtype=dt, sycl_queue=sycl_queue)

dpctl/tensor/_manipulation_functions.py

Lines changed: 71 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -309,8 +309,7 @@ def _arrays_validation(arrays):
309309
raise ValueError("All the input arrays must have usm_type")
310310

311311
X0 = arrays[0]
312-
if not all(Xi.dtype.char in "?bBhHiIlLqQefdFD" for Xi in arrays):
313-
raise ValueError("Unsupported dtype encountered.")
312+
_supported_dtype(Xi.dtype for Xi in arrays)
314313

315314
res_dtype = X0.dtype
316315
for i in range(1, n):
@@ -421,3 +420,73 @@ def stack(arrays, axis=0):
421420
dpctl.SyclEvent.wait_for(hev_list)
422421

423422
return res
423+
424+
425+
def can_cast(from_, to, casting="safe"):
426+
"""
427+
can_cast(from: usm_ndarray or dtype, to: dtype) -> bool
428+
429+
Determines if one data type can be cast to another data type according \
430+
to Type Promotion Rules rules.
431+
"""
432+
if isinstance(to, dpt.usm_ndarray):
433+
raise TypeError("Expected dtype type.")
434+
435+
dtype_to = dpt.dtype(to)
436+
437+
dtype_from = (
438+
from_.dtype if isinstance(from_, dpt.usm_ndarray) else dpt.dtype(from_)
439+
)
440+
441+
_supported_dtype([dtype_from, dtype_to])
442+
443+
return np.can_cast(dtype_from, dtype_to, casting)
444+
445+
446+
def result_type(*arrays_and_dtypes):
447+
"""
448+
result_type(arrays_and_dtypes: an arbitrary number usm_ndarrays or dtypes)\
449+
-> dtype
450+
451+
Returns the dtype that results from applying the Type Promotion Rules to \
452+
the arguments.
453+
"""
454+
dtypes = [
455+
X.dtype if isinstance(X, dpt.usm_ndarray) else dpt.dtype(X)
456+
for X in arrays_and_dtypes
457+
]
458+
459+
_supported_dtype(dtypes)
460+
461+
return np.result_type(*dtypes)
462+
463+
464+
def iinfo(type):
465+
"""
466+
iinfo(type: integer data-type) -> iinfo_object
467+
468+
Returns machine limits for integer data types.
469+
"""
470+
if isinstance(type, dpt.usm_ndarray):
471+
raise TypeError("Expected dtype type, get {to}.")
472+
_supported_dtype([dpt.dtype(type)])
473+
return np.iinfo(type)
474+
475+
476+
def finfo(type):
477+
"""
478+
finfo(type: float data-type) -> finfo_object
479+
480+
Returns machine limits for float data types.
481+
"""
482+
if isinstance(type, dpt.usm_ndarray):
483+
raise TypeError("Expected dtype type, get {to}.")
484+
_supported_dtype([dpt.dtype(type)])
485+
return np.finfo(type)
486+
487+
488+
def _supported_dtype(dtypes):
489+
for dtype in dtypes:
490+
if dtype.char not in "?bBhHiIlLqQefdFD":
491+
raise ValueError(f"Dpctl doesn't support dtype {dtype}.")
492+
return True

dpctl/tests/test_sycl_kernel_submit.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,21 +31,21 @@
3131
@pytest.mark.parametrize(
3232
"ctype_str,dtype,ctypes_ctor",
3333
[
34-
("short", np.dtype("i2"), ctypes.c_short),
35-
("int", np.dtype("i4"), ctypes.c_int),
36-
("unsigned int", np.dtype("u4"), ctypes.c_uint),
37-
("long", np.dtype(np.longlong), ctypes.c_longlong),
38-
("unsigned long", np.dtype(np.ulonglong), ctypes.c_ulonglong),
39-
("float", np.dtype("f4"), ctypes.c_float),
40-
("double", np.dtype("f8"), ctypes.c_double),
34+
("short", dpt.dtype("i2"), ctypes.c_short),
35+
("int", dpt.dtype("i4"), ctypes.c_int),
36+
("unsigned int", dpt.dtype("u4"), ctypes.c_uint),
37+
("long", dpt.dtype(np.longlong), ctypes.c_longlong),
38+
("unsigned long", dpt.dtype(np.ulonglong), ctypes.c_ulonglong),
39+
("float", dpt.dtype("f4"), ctypes.c_float),
40+
("double", dpt.dtype("f8"), ctypes.c_double),
4141
],
4242
)
4343
def test_create_program_from_source(ctype_str, dtype, ctypes_ctor):
4444
try:
4545
q = dpctl.SyclQueue("opencl", property="enable_profiling")
4646
except dpctl.SyclQueueCreationError:
4747
pytest.skip("OpenCL queue could not be created")
48-
if dtype == np.dtype("f8") and q.sycl_device.has_aspect_fp64 is False:
48+
if dtype == dpt.dtype("f8") and q.sycl_device.has_aspect_fp64 is False:
4949
pytest.skip(
5050
"Device does not support double precision floating point type"
5151
)

dpctl/tests/test_tensor_asarray.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -175,21 +175,21 @@ def test_asarray_scalars():
175175
import ctypes
176176

177177
Y = dpt.asarray(5)
178-
assert Y.dtype == np.dtype(int)
178+
assert Y.dtype == dpt.dtype(int)
179179
Y = dpt.asarray(5.2)
180180
if Y.sycl_device.has_aspect_fp64:
181-
assert Y.dtype == np.dtype(float)
181+
assert Y.dtype == dpt.dtype(float)
182182
else:
183-
assert Y.dtype == np.dtype(np.float32)
183+
assert Y.dtype == dpt.dtype(dpt.float32)
184184
Y = dpt.asarray(np.float32(2.3))
185-
assert Y.dtype == np.dtype(np.float32)
185+
assert Y.dtype == dpt.dtype(dpt.float32)
186186
Y = dpt.asarray(1.0j)
187187
if Y.sycl_device.has_aspect_fp64:
188-
assert Y.dtype == np.dtype(complex)
188+
assert Y.dtype == dpt.dtype(complex)
189189
else:
190-
assert Y.dtype == np.dtype(np.complex64)
190+
assert Y.dtype == dpt.dtype(dpt.complex64)
191191
Y = dpt.asarray(ctypes.c_int(8))
192-
assert Y.dtype == np.dtype(ctypes.c_int)
192+
assert Y.dtype == dpt.dtype(ctypes.c_int)
193193

194194

195195
def test_asarray_copy_false():

0 commit comments

Comments
 (0)