Skip to content

Commit 433c4d7

Browse files
Fix tests to not use fp64 dtypes without checking for device aspects
Where dtype is irrelavant use data types mandated by SYCL standard, otherwise check if dtype can be used before making the call.
1 parent 6e59293 commit 433c4d7

File tree

1 file changed

+47
-36
lines changed

1 file changed

+47
-36
lines changed

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 47 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,9 @@
4545
def test_allocate_usm_ndarray(shape, usm_type):
4646
q = get_queue_or_skip()
4747
X = dpt.usm_ndarray(
48-
shape, dtype="d", buffer=usm_type, buffer_ctor_kwargs={"queue": q}
48+
shape, dtype="i8", buffer=usm_type, buffer_ctor_kwargs={"queue": q}
4949
)
50-
Xnp = np.ndarray(shape, dtype="d")
50+
Xnp = np.ndarray(shape, dtype="i8")
5151
assert X.usm_type == usm_type
5252
assert X.sycl_context == q.sycl_context
5353
assert X.sycl_device == q.sycl_device
@@ -57,13 +57,17 @@ def test_allocate_usm_ndarray(shape, usm_type):
5757

5858

5959
def test_usm_ndarray_flags():
60-
assert dpt.usm_ndarray((5,)).flags.fc
61-
assert dpt.usm_ndarray((5, 2)).flags.c_contiguous
62-
assert dpt.usm_ndarray((5, 2), order="F").flags.f_contiguous
63-
assert dpt.usm_ndarray((5, 1, 2), order="F").flags.f_contiguous
64-
assert dpt.usm_ndarray((5, 1, 2), strides=(2, 0, 1)).flags.c_contiguous
65-
assert dpt.usm_ndarray((5, 1, 2), strides=(1, 0, 5)).flags.f_contiguous
66-
assert dpt.usm_ndarray((5, 1, 1), strides=(1, 0, 1)).flags.fc
60+
assert dpt.usm_ndarray((5,), dtype="i4").flags.fc
61+
assert dpt.usm_ndarray((5, 2), dtype="i4").flags.c_contiguous
62+
assert dpt.usm_ndarray((5, 2), dtype="i4", order="F").flags.f_contiguous
63+
assert dpt.usm_ndarray((5, 1, 2), dtype="i4", order="F").flags.f_contiguous
64+
assert dpt.usm_ndarray(
65+
(5, 1, 2), dtype="i4", strides=(2, 0, 1)
66+
).flags.c_contiguous
67+
assert dpt.usm_ndarray(
68+
(5, 1, 2), dtype="i4", strides=(1, 0, 5)
69+
).flags.f_contiguous
70+
assert dpt.usm_ndarray((5, 1, 1), dtype="i4", strides=(1, 0, 1)).flags.fc
6771

6872

6973
@pytest.mark.parametrize(
@@ -88,6 +92,8 @@ def test_usm_ndarray_flags():
8892
],
8993
)
9094
def test_dtypes(dtype):
95+
q = get_queue_or_skip()
96+
skip_if_dtype_not_supported(dtype, q)
9197
Xusm = dpt.usm_ndarray((1,), dtype=dtype)
9298
assert Xusm.itemsize == dpt.dtype(dtype).itemsize
9399
expected_fmt = (dpt.dtype(dtype).str)[1:]
@@ -169,15 +175,15 @@ def test_copy_scalar_with_method(method, shape, dtype):
169175
@pytest.mark.parametrize("func", [bool, float, int, complex])
170176
@pytest.mark.parametrize("shape", [(2,), (1, 2), (3, 4, 5), (0,)])
171177
def test_copy_scalar_invalid_shape(func, shape):
172-
X = dpt.usm_ndarray(shape)
178+
X = dpt.usm_ndarray(shape, dtype="i8")
173179
with pytest.raises(ValueError):
174180
func(X)
175181

176182

177183
def test_index_noninteger():
178184
import operator
179185

180-
X = dpt.usm_ndarray(1, "d")
186+
X = dpt.usm_ndarray(1, "f4")
181187
with pytest.raises(IndexError):
182188
operator.index(X)
183189

@@ -283,7 +289,7 @@ def test_slice_suai(usm_type):
283289

284290

285291
def test_slicing_basic():
286-
Xusm = dpt.usm_ndarray((10, 5), dtype="c16")
292+
Xusm = dpt.usm_ndarray((10, 5), dtype="c8")
287293
Xusm[None]
288294
Xusm[...]
289295
Xusm[8]
@@ -318,20 +324,20 @@ def test_ctor_invalid_order():
318324

319325

320326
def test_ctor_buffer_kwarg():
321-
dpt.usm_ndarray(10, buffer=b"device")
327+
dpt.usm_ndarray(10, dtype="i8", buffer=b"device")
322328
with pytest.raises(ValueError):
323329
dpt.usm_ndarray(10, buffer="invalid_param")
324-
Xusm = dpt.usm_ndarray((10, 5), dtype="c16")
330+
Xusm = dpt.usm_ndarray((10, 5), dtype="c8")
325331
X2 = dpt.usm_ndarray(Xusm.shape, buffer=Xusm, dtype=Xusm.dtype)
326332
assert np.array_equal(
327333
Xusm.usm_data.copy_to_host(), X2.usm_data.copy_to_host()
328334
)
329335
with pytest.raises(ValueError):
330-
dpt.usm_ndarray(10, buffer=dict())
336+
dpt.usm_ndarray(10, dtype="i4", buffer=dict())
331337

332338

333339
def test_usm_ndarray_props():
334-
Xusm = dpt.usm_ndarray((10, 5), dtype="c16", order="F")
340+
Xusm = dpt.usm_ndarray((10, 5), dtype="c8", order="F")
335341
Xusm.ndim
336342
repr(Xusm)
337343
Xusm.flags
@@ -348,7 +354,7 @@ def test_usm_ndarray_props():
348354

349355

350356
def test_datapi_device():
351-
X = dpt.usm_ndarray(1)
357+
X = dpt.usm_ndarray(1, dtype="i4")
352358
dev_t = type(X.device)
353359
with pytest.raises(TypeError):
354360
dev_t()
@@ -387,7 +393,7 @@ def _pyx_capi_fnptr_to_callable(
387393

388394

389395
def test_pyx_capi_get_data():
390-
X = dpt.usm_ndarray(17)[1::2]
396+
X = dpt.usm_ndarray(17, dtype="i8")[1::2]
391397
get_data_fn = _pyx_capi_fnptr_to_callable(
392398
X,
393399
"UsmNDArray_GetData",
@@ -400,7 +406,7 @@ def test_pyx_capi_get_data():
400406

401407

402408
def test_pyx_capi_get_shape():
403-
X = dpt.usm_ndarray(17)[1::2]
409+
X = dpt.usm_ndarray(17, dtype="u4")[1::2]
404410
get_shape_fn = _pyx_capi_fnptr_to_callable(
405411
X,
406412
"UsmNDArray_GetShape",
@@ -413,7 +419,7 @@ def test_pyx_capi_get_shape():
413419

414420

415421
def test_pyx_capi_get_strides():
416-
X = dpt.usm_ndarray(17)[1::2]
422+
X = dpt.usm_ndarray(17, dtype="f4")[1::2]
417423
get_strides_fn = _pyx_capi_fnptr_to_callable(
418424
X,
419425
"UsmNDArray_GetStrides",
@@ -429,7 +435,7 @@ def test_pyx_capi_get_strides():
429435

430436

431437
def test_pyx_capi_get_ndim():
432-
X = dpt.usm_ndarray(17)[1::2]
438+
X = dpt.usm_ndarray(17, dtype="?")[1::2]
433439
get_ndim_fn = _pyx_capi_fnptr_to_callable(
434440
X,
435441
"UsmNDArray_GetNDim",
@@ -440,7 +446,7 @@ def test_pyx_capi_get_ndim():
440446

441447

442448
def test_pyx_capi_get_typenum():
443-
X = dpt.usm_ndarray(17)[1::2]
449+
X = dpt.usm_ndarray(17, dtype="c8")[1::2]
444450
get_typenum_fn = _pyx_capi_fnptr_to_callable(
445451
X,
446452
"UsmNDArray_GetTypenum",
@@ -453,7 +459,7 @@ def test_pyx_capi_get_typenum():
453459

454460

455461
def test_pyx_capi_get_elemsize():
456-
X = dpt.usm_ndarray(17)[1::2]
462+
X = dpt.usm_ndarray(17, dtype="u8")[1::2]
457463
get_elemsize_fn = _pyx_capi_fnptr_to_callable(
458464
X,
459465
"UsmNDArray_GetElementSize",
@@ -466,7 +472,7 @@ def test_pyx_capi_get_elemsize():
466472

467473

468474
def test_pyx_capi_get_flags():
469-
X = dpt.usm_ndarray(17)[1::2]
475+
X = dpt.usm_ndarray(17, dtype="i8")[1::2]
470476
get_flags_fn = _pyx_capi_fnptr_to_callable(
471477
X,
472478
"UsmNDArray_GetFlags",
@@ -478,7 +484,7 @@ def test_pyx_capi_get_flags():
478484

479485

480486
def test_pyx_capi_get_offset():
481-
X = dpt.usm_ndarray(17)[1::2]
487+
X = dpt.usm_ndarray(17, dtype="u2")[1::2]
482488
get_offset_fn = _pyx_capi_fnptr_to_callable(
483489
X,
484490
"UsmNDArray_GetOffset",
@@ -491,7 +497,7 @@ def test_pyx_capi_get_offset():
491497

492498

493499
def test_pyx_capi_get_queue_ref():
494-
X = dpt.usm_ndarray(17)[1::2]
500+
X = dpt.usm_ndarray(17, dtype="i2")[1::2]
495501
get_queue_ref_fn = _pyx_capi_fnptr_to_callable(
496502
X,
497503
"UsmNDArray_GetQueueRef",
@@ -521,7 +527,7 @@ def _pyx_capi_int(X, pyx_capi_name, caps_name=b"int", val_restype=ctypes.c_int):
521527

522528

523529
def test_pyx_capi_check_constants():
524-
X = dpt.usm_ndarray(17)[1::2]
530+
X = dpt.usm_ndarray(17, dtype="i1")[1::2]
525531
cc_flag = _pyx_capi_int(X, "USM_ARRAY_C_CONTIGUOUS")
526532
assert cc_flag > 0 and 0 == (cc_flag & (cc_flag - 1))
527533
fc_flag = _pyx_capi_int(X, "USM_ARRAY_F_CONTIGUOUS")
@@ -598,6 +604,7 @@ def test_pyx_capi_check_constants():
598604
@pytest.mark.parametrize("usm_type", ["device", "shared", "host"])
599605
def test_tofrom_numpy(shape, dtype, usm_type):
600606
q = get_queue_or_skip()
607+
skip_if_dtype_not_supported(dtype, q)
601608
Xnp = np.zeros(shape, dtype=dtype)
602609
Xusm = dpt.from_numpy(Xnp, usm_type=usm_type, sycl_queue=q)
603610
Ynp = np.ones(shape, dtype=dtype)
@@ -733,7 +740,7 @@ def relaxed_strides_equal(st1, st2, sh):
733740
4,
734741
5,
735742
)
736-
X = dpt.usm_ndarray(sh_s, dtype="d")
743+
X = dpt.usm_ndarray(sh_s, dtype="i8")
737744
X.shape = sh_f
738745
assert X.shape == sh_f
739746
assert relaxed_strides_equal(X.strides, cc_strides(sh_f), sh_f)
@@ -750,27 +757,27 @@ def relaxed_strides_equal(st1, st2, sh):
750757
4,
751758
5,
752759
)
753-
X = dpt.usm_ndarray(sh_s, dtype="d", order="C")
760+
X = dpt.usm_ndarray(sh_s, dtype="u4", order="C")
754761
X.shape = sh_f
755762
assert X.shape == sh_f
756763
assert relaxed_strides_equal(X.strides, cc_strides(sh_f), sh_f)
757764

758765
sh_s = (2, 3, 4, 5)
759766
sh_f = (4, 3, 2, 5)
760-
X = dpt.usm_ndarray(sh_s, dtype="d")
767+
X = dpt.usm_ndarray(sh_s, dtype="f4")
761768
X.shape = sh_f
762769
assert relaxed_strides_equal(X.strides, cc_strides(sh_f), sh_f)
763770

764771
sh_s = (2, 3, 4, 5)
765772
sh_f = (4, 3, 1, 2, 5)
766-
X = dpt.usm_ndarray(sh_s, dtype="d")
773+
X = dpt.usm_ndarray(sh_s, dtype="?")
767774
X.shape = sh_f
768775
assert relaxed_strides_equal(X.strides, cc_strides(sh_f), sh_f)
769776

770-
X = dpt.usm_ndarray(sh_s, dtype="d")
777+
X = dpt.usm_ndarray(sh_s, dtype="u4")
771778
with pytest.raises(TypeError):
772779
X.shape = "abcbe"
773-
X = dpt.usm_ndarray((4, 4), dtype="d")[::2, ::2]
780+
X = dpt.usm_ndarray((4, 4), dtype="u1")[::2, ::2]
774781
with pytest.raises(AttributeError):
775782
X.shape = (4,)
776783
X = dpt.usm_ndarray((0,), dtype="i4")
@@ -814,7 +821,7 @@ def test_dlpack():
814821

815822

816823
def test_to_device():
817-
X = dpt.usm_ndarray(1, "d")
824+
X = dpt.usm_ndarray(1, "f4")
818825
for dev in dpctl.get_devices():
819826
if dev.default_selector_score > 0:
820827
Y = X.to_device(dev)
@@ -900,7 +907,7 @@ def test_reshape():
900907
W = dpt.reshape(Z, (-1,), order="C")
901908
assert W.shape == (Z.size,)
902909

903-
X = dpt.usm_ndarray((1,))
910+
X = dpt.usm_ndarray((1,), dtype="i8")
904911
Y = dpt.reshape(X, X.shape)
905912
assert Y.flags == X.flags
906913

@@ -970,7 +977,9 @@ def test_real_imag_views():
970977
_all_dtypes,
971978
)
972979
def test_zeros(dtype):
973-
X = dpt.zeros(10, dtype=dtype)
980+
q = get_queue_or_skip()
981+
skip_if_dtype_not_supported(dtype, q)
982+
X = dpt.zeros(10, dtype=dtype, sycl_queue=q)
974983
assert np.array_equal(dpt.asnumpy(X), np.zeros(10, dtype=dtype))
975984

976985

@@ -1197,6 +1206,7 @@ def test_linspace_fp_max(dtype):
11971206
)
11981207
def test_empty_like(dt, usm_kind):
11991208
q = get_queue_or_skip()
1209+
skip_if_dtype_not_supported(dt, q)
12001210

12011211
X = dpt.empty((4, 5), dtype=dt, usm_type=usm_kind, sycl_queue=q)
12021212
Y = dpt.empty_like(X)
@@ -1232,6 +1242,7 @@ def test_empty_unexpected_data_type():
12321242
)
12331243
def test_zeros_like(dt, usm_kind):
12341244
q = get_queue_or_skip()
1245+
skip_if_dtype_not_supported(dt, q)
12351246

12361247
X = dpt.empty((4, 5), dtype=dt, usm_type=usm_kind, sycl_queue=q)
12371248
Y = dpt.zeros_like(X)

0 commit comments

Comments
 (0)