Skip to content

Commit 9f41b06

Browse files
committed
Changed np namespace to dpt for dtype, iinfo, and can_cast functuons
1 parent 5050b47 commit 9f41b06

File tree

6 files changed

+65
-65
lines changed

6 files changed

+65
-65
lines changed

dpctl/tensor/_copy_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def _copy_from_numpy(np_ary, usm_type="device", sycl_queue=None):
6464
dt = Xnp.dtype
6565
if dt.char in "dD" and alloc_q.sycl_device.has_aspect_fp64 is False:
6666
Xusm_dtype = (
67-
np.dtype("float32") if dt.char == "d" else np.dtype("complex64")
67+
dpt.dtype("float32") if dt.char == "d" else dpt.dtype("complex64")
6868
)
6969
else:
7070
Xusm_dtype = dt
@@ -318,8 +318,8 @@ def astype(usm_ary, newdtype, order="K", casting="unsafe", copy=True):
318318
"Recognized values are 'A', 'C', 'F', or 'K'"
319319
)
320320
ary_dtype = usm_ary.dtype
321-
target_dtype = np.dtype(newdtype)
322-
if not np.can_cast(ary_dtype, target_dtype, casting=casting):
321+
target_dtype = dpt.dtype(newdtype)
322+
if not dpt.can_cast(ary_dtype, target_dtype, casting=casting):
323323
raise TypeError(
324324
"Can not cast from {} to {} according to rule {}".format(
325325
ary_dtype, newdtype, casting

dpctl/tensor/_ctors.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -33,20 +33,20 @@ def _get_dtype(dtype, sycl_obj, ref_type=None):
3333
if dtype is None:
3434
if ref_type in [None, float] or np.issubdtype(ref_type, np.floating):
3535
dtype = ti.default_device_fp_type(sycl_obj)
36-
return np.dtype(dtype)
36+
return dpt.dtype(dtype)
3737
elif ref_type in [bool, np.bool_]:
3838
dtype = ti.default_device_bool_type(sycl_obj)
39-
return np.dtype(dtype)
39+
return dpt.dtype(dtype)
4040
elif ref_type is int or np.issubdtype(ref_type, np.integer):
4141
dtype = ti.default_device_int_type(sycl_obj)
42-
return np.dtype(dtype)
42+
return dpt.dtype(dtype)
4343
elif ref_type is complex or np.issubdtype(ref_type, np.complexfloating):
4444
dtype = ti.default_device_complex_type(sycl_obj)
45-
return np.dtype(dtype)
45+
return dpt.dtype(dtype)
4646
else:
4747
raise TypeError(f"Reference type {ref_type} not recognized.")
4848
else:
49-
return np.dtype(dtype)
49+
return dpt.dtype(dtype)
5050

5151

5252
def _array_info_dispatch(obj):
@@ -313,7 +313,7 @@ def asarray(
313313
)
314314
# 2. Check that dtype is None, or a valid dtype
315315
if dtype is not None:
316-
dtype = np.dtype(dtype)
316+
dtype = dpt.dtype(dtype)
317317
# 3. Validate order
318318
if not isinstance(order, str):
319319
raise TypeError(
@@ -768,7 +768,7 @@ def empty_like(
768768
device = x.device
769769
sycl_queue = normalize_queue_device(sycl_queue=sycl_queue, device=device)
770770
sh = x.shape
771-
dtype = np.dtype(dtype)
771+
dtype = dpt.dtype(dtype)
772772
res = dpt.usm_ndarray(
773773
sh,
774774
dtype=dtype,
@@ -825,7 +825,7 @@ def zeros_like(
825825
device = x.device
826826
sycl_queue = normalize_queue_device(sycl_queue=sycl_queue, device=device)
827827
sh = x.shape
828-
dtype = np.dtype(dtype)
828+
dtype = dpt.dtype(dtype)
829829
return zeros(
830830
sh,
831831
dtype=dtype,
@@ -882,7 +882,7 @@ def ones_like(
882882
device = x.device
883883
sycl_queue = normalize_queue_device(sycl_queue=sycl_queue, device=device)
884884
sh = x.shape
885-
dtype = np.dtype(dtype)
885+
dtype = dpt.dtype(dtype)
886886
return ones(
887887
sh,
888888
dtype=dtype,
@@ -946,7 +946,7 @@ def full_like(
946946
device = x.device
947947
sycl_queue = normalize_queue_device(sycl_queue=sycl_queue, device=device)
948948
sh = x.shape
949-
dtype = np.dtype(dtype)
949+
dtype = dpt.dtype(dtype)
950950
return full(
951951
sh,
952952
fill_value,
@@ -1026,7 +1026,7 @@ def linspace(
10261026
)
10271027
if dtype is None and np.issubdtype(dt, np.integer):
10281028
dt = ti.default_device_fp_type(sycl_queue)
1029-
dt = np.dtype(dt)
1029+
dt = dpt.dtype(dt)
10301030
start = float(start)
10311031
stop = float(stop)
10321032
res = dpt.empty(num, dtype=dt, sycl_queue=sycl_queue)

dpctl/tensor/_manipulation_functions.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,7 @@ def _arrays_validation(arrays):
309309
raise ValueError("All the input arrays must have usm_type")
310310

311311
X0 = arrays[0]
312-
_support_dtype(Xi.dtype for Xi in arrays)
312+
_supported_dtype(Xi.dtype for Xi in arrays)
313313

314314
res_dtype = X0.dtype
315315
for i in range(1, n):
@@ -422,7 +422,7 @@ def stack(arrays, axis=0):
422422
return res
423423

424424

425-
def can_cast(array_and_dtype_from, dtype_to):
425+
def can_cast(array_and_dtype_from, dtype_to, casting="safe"):
426426
"""
427427
can_cast(from: usm_ndarray or dtype, to: dtype) -> bool
428428
@@ -434,9 +434,9 @@ def can_cast(array_and_dtype_from, dtype_to):
434434

435435
dtype_from = dpt.dtype(array_and_dtype_from)
436436

437-
_support_dtype([dtype_to, dtype_from])
437+
_supported_dtype([dtype_to, dtype_from])
438438

439-
return np.can_cast(dtype_from, dtype_to)
439+
return np.can_cast(dtype_from, dtype_to, casting)
440440

441441

442442
def result_type(*arrays_and_dtypes):
@@ -449,7 +449,7 @@ def result_type(*arrays_and_dtypes):
449449
"""
450450
dtypes = [dpt.dtype(X) for X in arrays_and_dtypes]
451451

452-
_support_dtype(dtypes)
452+
_supported_dtype(dtypes)
453453

454454
return np.result_type(*dtypes)
455455

@@ -460,7 +460,7 @@ def iinfo(type):
460460
461461
Returns machine limits for integer data types.
462462
"""
463-
_support_dtype(type)
463+
_supported_dtype([dpt.dtype(type)])
464464
return np.iinfo(type)
465465

466466

@@ -470,11 +470,11 @@ def finfo(type):
470470
471471
Returns machine limits for float data types.
472472
"""
473-
_support_dtype(type)
473+
_supported_dtype([dpt.dtype(type)])
474474
return np.finfo(type)
475475

476476

477-
def _support_dtype(dtypes):
477+
def _supported_dtype(dtypes):
478478
if not all(dtype.char in "?bBhHiIlLqQefdFD" for dtype in dtypes):
479479
raise ValueError("Unsupported dtype encountered.")
480480
return True

dpctl/tests/test_sycl_kernel_submit.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,21 +31,21 @@
3131
@pytest.mark.parametrize(
3232
"ctype_str,dtype,ctypes_ctor",
3333
[
34-
("short", np.dtype("i2"), ctypes.c_short),
35-
("int", np.dtype("i4"), ctypes.c_int),
36-
("unsigned int", np.dtype("u4"), ctypes.c_uint),
37-
("long", np.dtype(np.longlong), ctypes.c_longlong),
38-
("unsigned long", np.dtype(np.ulonglong), ctypes.c_ulonglong),
39-
("float", np.dtype("f4"), ctypes.c_float),
40-
("double", np.dtype("f8"), ctypes.c_double),
34+
("short", dpt.dtype("i2"), ctypes.c_short),
35+
("int", dpt.dtype("i4"), ctypes.c_int),
36+
("unsigned int", dpt.dtype("u4"), ctypes.c_uint),
37+
("long", dpt.dtype(np.longlong), ctypes.c_longlong),
38+
("unsigned long", dpt.dtype(np.ulonglong), ctypes.c_ulonglong),
39+
("float", dpt.dtype("f4"), ctypes.c_float),
40+
("double", dpt.dtype("f8"), ctypes.c_double),
4141
],
4242
)
4343
def test_create_program_from_source(ctype_str, dtype, ctypes_ctor):
4444
try:
4545
q = dpctl.SyclQueue("opencl", property="enable_profiling")
4646
except dpctl.SyclQueueCreationError:
4747
pytest.skip("OpenCL queue could not be created")
48-
if dtype == np.dtype("f8") and q.sycl_device.has_aspect_fp64 is False:
48+
if dtype == dpt.dtype("f8") and q.sycl_device.has_aspect_fp64 is False:
4949
pytest.skip(
5050
"Device does not support double precision floating point type"
5151
)

dpctl/tests/test_tensor_asarray.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -175,21 +175,21 @@ def test_asarray_scalars():
175175
import ctypes
176176

177177
Y = dpt.asarray(5)
178-
assert Y.dtype == np.dtype(int)
178+
assert Y.dtype == dpt.dtype(int)
179179
Y = dpt.asarray(5.2)
180180
if Y.sycl_device.has_aspect_fp64:
181-
assert Y.dtype == np.dtype(float)
181+
assert Y.dtype == dpt.dtype(float)
182182
else:
183-
assert Y.dtype == np.dtype(np.float32)
183+
assert Y.dtype == dpt.dtype(dpt.float32)
184184
Y = dpt.asarray(np.float32(2.3))
185-
assert Y.dtype == np.dtype(np.float32)
185+
assert Y.dtype == dpt.dtype(dpt.float32)
186186
Y = dpt.asarray(1.0j)
187187
if Y.sycl_device.has_aspect_fp64:
188-
assert Y.dtype == np.dtype(complex)
188+
assert Y.dtype == dpt.dtype(complex)
189189
else:
190-
assert Y.dtype == np.dtype(np.complex64)
190+
assert Y.dtype == dpt.dtype(dpt.complex64)
191191
Y = dpt.asarray(ctypes.c_int(8))
192-
assert Y.dtype == np.dtype(ctypes.c_int)
192+
assert Y.dtype == dpt.dtype(ctypes.c_int)
193193

194194

195195
def test_asarray_copy_false():

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 28 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -85,14 +85,14 @@ def test_usm_ndarray_flags():
8585
"c8",
8686
"c16",
8787
b"float32",
88-
np.dtype("d"),
88+
dpt.dtype("d"),
8989
np.half,
9090
],
9191
)
9292
def test_dtypes(dtype):
9393
Xusm = dpt.usm_ndarray((1,), dtype=dtype)
94-
assert Xusm.itemsize == np.dtype(dtype).itemsize
95-
expected_fmt = (np.dtype(dtype).str)[1:]
94+
assert Xusm.itemsize == dpt.dtype(dtype).itemsize
95+
expected_fmt = (dpt.dtype(dtype).str)[1:]
9696
actual_fmt = Xusm.__sycl_usm_array_interface__["typestr"][1:]
9797
assert expected_fmt == actual_fmt
9898

@@ -112,7 +112,7 @@ def test_properties(dt):
112112
assert isinstance(X.sycl_queue, dpctl.SyclQueue)
113113
assert isinstance(X.sycl_device, dpctl.SyclDevice)
114114
assert isinstance(X.sycl_context, dpctl.SyclContext)
115-
assert isinstance(X.dtype, np.dtype)
115+
assert isinstance(X.dtype, dpt.dtype)
116116
assert isinstance(X.__sycl_usm_array_interface__, dict)
117117
assert isinstance(X.mT, dpt.usm_ndarray)
118118
assert isinstance(X.imag, dpt.usm_ndarray)
@@ -521,44 +521,44 @@ def test_pyx_capi_check_constants():
521521
assert w_flag > 0 and 0 == (w_flag & (w_flag - 1))
522522

523523
bool_typenum = _pyx_capi_int(X, "UAR_BOOL")
524-
assert bool_typenum == np.dtype("bool_").num
524+
assert bool_typenum == dpt.dtype("bool_").num
525525

526526
byte_typenum = _pyx_capi_int(X, "UAR_BYTE")
527-
assert byte_typenum == np.dtype(np.byte).num
527+
assert byte_typenum == dpt.dtype(np.byte).num
528528
ubyte_typenum = _pyx_capi_int(X, "UAR_UBYTE")
529-
assert ubyte_typenum == np.dtype(np.ubyte).num
529+
assert ubyte_typenum == dpt.dtype(np.ubyte).num
530530

531531
short_typenum = _pyx_capi_int(X, "UAR_SHORT")
532-
assert short_typenum == np.dtype(np.short).num
532+
assert short_typenum == dpt.dtype(np.short).num
533533
ushort_typenum = _pyx_capi_int(X, "UAR_USHORT")
534-
assert ushort_typenum == np.dtype(np.ushort).num
534+
assert ushort_typenum == dpt.dtype(np.ushort).num
535535

536536
int_typenum = _pyx_capi_int(X, "UAR_INT")
537-
assert int_typenum == np.dtype(np.intc).num
537+
assert int_typenum == dpt.dtype(np.intc).num
538538
uint_typenum = _pyx_capi_int(X, "UAR_UINT")
539-
assert uint_typenum == np.dtype(np.uintc).num
539+
assert uint_typenum == dpt.dtype(np.uintc).num
540540

541541
long_typenum = _pyx_capi_int(X, "UAR_LONG")
542-
assert long_typenum == np.dtype(np.int_).num
542+
assert long_typenum == dpt.dtype(np.int_).num
543543
ulong_typenum = _pyx_capi_int(X, "UAR_ULONG")
544-
assert ulong_typenum == np.dtype(np.uint).num
544+
assert ulong_typenum == dpt.dtype(np.uint).num
545545

546546
longlong_typenum = _pyx_capi_int(X, "UAR_LONGLONG")
547-
assert longlong_typenum == np.dtype(np.longlong).num
547+
assert longlong_typenum == dpt.dtype(np.longlong).num
548548
ulonglong_typenum = _pyx_capi_int(X, "UAR_ULONGLONG")
549-
assert ulonglong_typenum == np.dtype(np.ulonglong).num
549+
assert ulonglong_typenum == dpt.dtype(np.ulonglong).num
550550

551551
half_typenum = _pyx_capi_int(X, "UAR_HALF")
552-
assert half_typenum == np.dtype(np.half).num
552+
assert half_typenum == dpt.dtype(np.half).num
553553
float_typenum = _pyx_capi_int(X, "UAR_FLOAT")
554-
assert float_typenum == np.dtype(np.single).num
554+
assert float_typenum == dpt.dtype(np.single).num
555555
double_typenum = _pyx_capi_int(X, "UAR_DOUBLE")
556-
assert double_typenum == np.dtype(np.double).num
556+
assert double_typenum == dpt.dtype(np.double).num
557557

558558
cfloat_typenum = _pyx_capi_int(X, "UAR_CFLOAT")
559-
assert cfloat_typenum == np.dtype(np.csingle).num
559+
assert cfloat_typenum == dpt.dtype(np.csingle).num
560560
cdouble_typenum = _pyx_capi_int(X, "UAR_CDOUBLE")
561-
assert cdouble_typenum == np.dtype(np.cdouble).num
561+
assert cdouble_typenum == dpt.dtype(np.cdouble).num
562562

563563

564564
_all_dtypes = [
@@ -720,12 +720,12 @@ def test_setitem_wingaps():
720720
q = dpctl.SyclQueue()
721721
except dpctl.SyclQueueCreationError:
722722
pytest.skip("Default queue could not be created")
723-
if np.dtype("intc").itemsize == np.dtype("int32").itemsize:
723+
if dpt.dtype("intc").itemsize == dpt.dtype("int32").itemsize:
724724
dpt_dst = dpt.empty(4, dtype="int32", sycl_queue=q)
725725
np_src = np.arange(4, dtype="intc")
726726
dpt_dst[:] = np_src # should not raise exceptions
727727
assert np.array_equal(dpt.asnumpy(dpt_dst), np_src)
728-
if np.dtype("long").itemsize == np.dtype("longlong").itemsize:
728+
if dpt.dtype("long").itemsize == dpt.dtype("longlong").itemsize:
729729
dpt_dst = dpt.empty(4, dtype="longlong", sycl_queue=q)
730730
np_src = np.arange(4, dtype="long")
731731
dpt_dst[:] = np_src # should not raise exceptions
@@ -1027,7 +1027,7 @@ def test_full(dtype):
10271027

10281028
def test_full_dtype_inference():
10291029
assert np.issubdtype(dpt.full(10, 4).dtype, np.integer)
1030-
assert dpt.full(10, True).dtype is np.dtype(np.bool_)
1030+
assert dpt.full(10, True).dtype is dpt.dtype(np.bool_)
10311031
assert np.issubdtype(dpt.full(10, 12.3).dtype, np.floating)
10321032
assert np.issubdtype(dpt.full(10, 0.3 - 2j).dtype, np.complexfloating)
10331033

@@ -1047,7 +1047,7 @@ def test_arange(dt):
10471047
"Device does not support double precision floating point type"
10481048
)
10491049
X = dpt.arange(0, 123, dtype=dt, sycl_queue=q)
1050-
dt = np.dtype(dt)
1050+
dt = dpt.dtype(dt)
10511051
if np.issubdtype(dt, np.integer):
10521052
assert int(X[47]) == 47
10531053
elif np.issubdtype(dt, np.floating):
@@ -1056,7 +1056,7 @@ def test_arange(dt):
10561056
assert complex(X[47]) == 47.0 + 0.0j
10571057

10581058
# choose size larger than maximal value that u1/u2 can accomodate
1059-
sz = int(np.iinfo(np.int16).max) + 1
1059+
sz = int(dpt.iinfo(dpt.int16).max) + 1
10601060
X1 = dpt.arange(sz, dtype=dt, sycl_queue=q)
10611061
assert X1.shape == (sz,)
10621062

@@ -1101,9 +1101,9 @@ def test_linspace_fp():
11011101
n = 16
11021102
X = dpt.linspace(0, n - 1, num=n, sycl_queue=q)
11031103
if q.sycl_device.has_aspect_fp64:
1104-
assert X.dtype == np.dtype("float64")
1104+
assert X.dtype == dpt.dtype("float64")
11051105
else:
1106-
assert X.dtype == np.dtype("float32")
1106+
assert X.dtype == dpt.dtype("float32")
11071107
assert X.shape == (n,)
11081108
assert X.strides == (1,)
11091109

@@ -1238,7 +1238,7 @@ def test_full_like(dt, usm_kind):
12381238
"Device does not support double precision floating point type"
12391239
)
12401240

1241-
fill_v = np.dtype(dt).type(1)
1241+
fill_v = dpt.dtype(dt).type(1)
12421242
X = dpt.empty((4, 5), dtype=dt, usm_type=usm_kind, sycl_queue=q)
12431243
Y = dpt.full_like(X, fill_v)
12441244
assert X.shape == Y.shape

0 commit comments

Comments
 (0)