45
45
def test_allocate_usm_ndarray (shape , usm_type ):
46
46
q = get_queue_or_skip ()
47
47
X = dpt .usm_ndarray (
48
- shape , dtype = "d " , buffer = usm_type , buffer_ctor_kwargs = {"queue" : q }
48
+ shape , dtype = "i8 " , buffer = usm_type , buffer_ctor_kwargs = {"queue" : q }
49
49
)
50
- Xnp = np .ndarray (shape , dtype = "d " )
50
+ Xnp = np .ndarray (shape , dtype = "i8 " )
51
51
assert X .usm_type == usm_type
52
52
assert X .sycl_context == q .sycl_context
53
53
assert X .sycl_device == q .sycl_device
@@ -57,13 +57,17 @@ def test_allocate_usm_ndarray(shape, usm_type):
57
57
58
58
59
59
def test_usm_ndarray_flags ():
60
- assert dpt .usm_ndarray ((5 ,)).flags .fc
61
- assert dpt .usm_ndarray ((5 , 2 )).flags .c_contiguous
62
- assert dpt .usm_ndarray ((5 , 2 ), order = "F" ).flags .f_contiguous
63
- assert dpt .usm_ndarray ((5 , 1 , 2 ), order = "F" ).flags .f_contiguous
64
- assert dpt .usm_ndarray ((5 , 1 , 2 ), strides = (2 , 0 , 1 )).flags .c_contiguous
65
- assert dpt .usm_ndarray ((5 , 1 , 2 ), strides = (1 , 0 , 5 )).flags .f_contiguous
66
- assert dpt .usm_ndarray ((5 , 1 , 1 ), strides = (1 , 0 , 1 )).flags .fc
60
+ assert dpt .usm_ndarray ((5 ,), dtype = "i4" ).flags .fc
61
+ assert dpt .usm_ndarray ((5 , 2 ), dtype = "i4" ).flags .c_contiguous
62
+ assert dpt .usm_ndarray ((5 , 2 ), dtype = "i4" , order = "F" ).flags .f_contiguous
63
+ assert dpt .usm_ndarray ((5 , 1 , 2 ), dtype = "i4" , order = "F" ).flags .f_contiguous
64
+ assert dpt .usm_ndarray (
65
+ (5 , 1 , 2 ), dtype = "i4" , strides = (2 , 0 , 1 )
66
+ ).flags .c_contiguous
67
+ assert dpt .usm_ndarray (
68
+ (5 , 1 , 2 ), dtype = "i4" , strides = (1 , 0 , 5 )
69
+ ).flags .f_contiguous
70
+ assert dpt .usm_ndarray ((5 , 1 , 1 ), dtype = "i4" , strides = (1 , 0 , 1 )).flags .fc
67
71
68
72
69
73
@pytest .mark .parametrize (
@@ -88,6 +92,8 @@ def test_usm_ndarray_flags():
88
92
],
89
93
)
90
94
def test_dtypes (dtype ):
95
+ q = get_queue_or_skip ()
96
+ skip_if_dtype_not_supported (dtype , q )
91
97
Xusm = dpt .usm_ndarray ((1 ,), dtype = dtype )
92
98
assert Xusm .itemsize == dpt .dtype (dtype ).itemsize
93
99
expected_fmt = (dpt .dtype (dtype ).str )[1 :]
@@ -169,15 +175,15 @@ def test_copy_scalar_with_method(method, shape, dtype):
169
175
@pytest .mark .parametrize ("func" , [bool , float , int , complex ])
170
176
@pytest .mark .parametrize ("shape" , [(2 ,), (1 , 2 ), (3 , 4 , 5 ), (0 ,)])
171
177
def test_copy_scalar_invalid_shape (func , shape ):
172
- X = dpt .usm_ndarray (shape )
178
+ X = dpt .usm_ndarray (shape , dtype = "i8" )
173
179
with pytest .raises (ValueError ):
174
180
func (X )
175
181
176
182
177
183
def test_index_noninteger ():
178
184
import operator
179
185
180
- X = dpt .usm_ndarray (1 , "d " )
186
+ X = dpt .usm_ndarray (1 , "f4 " )
181
187
with pytest .raises (IndexError ):
182
188
operator .index (X )
183
189
@@ -283,7 +289,7 @@ def test_slice_suai(usm_type):
283
289
284
290
285
291
def test_slicing_basic ():
286
- Xusm = dpt .usm_ndarray ((10 , 5 ), dtype = "c16 " )
292
+ Xusm = dpt .usm_ndarray ((10 , 5 ), dtype = "c8 " )
287
293
Xusm [None ]
288
294
Xusm [...]
289
295
Xusm [8 ]
@@ -318,20 +324,20 @@ def test_ctor_invalid_order():
318
324
319
325
320
326
def test_ctor_buffer_kwarg ():
321
- dpt .usm_ndarray (10 , buffer = b"device" )
327
+ dpt .usm_ndarray (10 , dtype = "i8" , buffer = b"device" )
322
328
with pytest .raises (ValueError ):
323
329
dpt .usm_ndarray (10 , buffer = "invalid_param" )
324
- Xusm = dpt .usm_ndarray ((10 , 5 ), dtype = "c16 " )
330
+ Xusm = dpt .usm_ndarray ((10 , 5 ), dtype = "c8 " )
325
331
X2 = dpt .usm_ndarray (Xusm .shape , buffer = Xusm , dtype = Xusm .dtype )
326
332
assert np .array_equal (
327
333
Xusm .usm_data .copy_to_host (), X2 .usm_data .copy_to_host ()
328
334
)
329
335
with pytest .raises (ValueError ):
330
- dpt .usm_ndarray (10 , buffer = dict ())
336
+ dpt .usm_ndarray (10 , dtype = "i4" , buffer = dict ())
331
337
332
338
333
339
def test_usm_ndarray_props ():
334
- Xusm = dpt .usm_ndarray ((10 , 5 ), dtype = "c16 " , order = "F" )
340
+ Xusm = dpt .usm_ndarray ((10 , 5 ), dtype = "c8 " , order = "F" )
335
341
Xusm .ndim
336
342
repr (Xusm )
337
343
Xusm .flags
@@ -348,7 +354,7 @@ def test_usm_ndarray_props():
348
354
349
355
350
356
def test_datapi_device ():
351
- X = dpt .usm_ndarray (1 )
357
+ X = dpt .usm_ndarray (1 , dtype = "i4" )
352
358
dev_t = type (X .device )
353
359
with pytest .raises (TypeError ):
354
360
dev_t ()
@@ -387,7 +393,7 @@ def _pyx_capi_fnptr_to_callable(
387
393
388
394
389
395
def test_pyx_capi_get_data ():
390
- X = dpt .usm_ndarray (17 )[1 ::2 ]
396
+ X = dpt .usm_ndarray (17 , dtype = "i8" )[1 ::2 ]
391
397
get_data_fn = _pyx_capi_fnptr_to_callable (
392
398
X ,
393
399
"UsmNDArray_GetData" ,
@@ -400,7 +406,7 @@ def test_pyx_capi_get_data():
400
406
401
407
402
408
def test_pyx_capi_get_shape ():
403
- X = dpt .usm_ndarray (17 )[1 ::2 ]
409
+ X = dpt .usm_ndarray (17 , dtype = "u4" )[1 ::2 ]
404
410
get_shape_fn = _pyx_capi_fnptr_to_callable (
405
411
X ,
406
412
"UsmNDArray_GetShape" ,
@@ -413,7 +419,7 @@ def test_pyx_capi_get_shape():
413
419
414
420
415
421
def test_pyx_capi_get_strides ():
416
- X = dpt .usm_ndarray (17 )[1 ::2 ]
422
+ X = dpt .usm_ndarray (17 , dtype = "f4" )[1 ::2 ]
417
423
get_strides_fn = _pyx_capi_fnptr_to_callable (
418
424
X ,
419
425
"UsmNDArray_GetStrides" ,
@@ -429,7 +435,7 @@ def test_pyx_capi_get_strides():
429
435
430
436
431
437
def test_pyx_capi_get_ndim ():
432
- X = dpt .usm_ndarray (17 )[1 ::2 ]
438
+ X = dpt .usm_ndarray (17 , dtype = "?" )[1 ::2 ]
433
439
get_ndim_fn = _pyx_capi_fnptr_to_callable (
434
440
X ,
435
441
"UsmNDArray_GetNDim" ,
@@ -440,7 +446,7 @@ def test_pyx_capi_get_ndim():
440
446
441
447
442
448
def test_pyx_capi_get_typenum ():
443
- X = dpt .usm_ndarray (17 )[1 ::2 ]
449
+ X = dpt .usm_ndarray (17 , dtype = "c8" )[1 ::2 ]
444
450
get_typenum_fn = _pyx_capi_fnptr_to_callable (
445
451
X ,
446
452
"UsmNDArray_GetTypenum" ,
@@ -453,7 +459,7 @@ def test_pyx_capi_get_typenum():
453
459
454
460
455
461
def test_pyx_capi_get_elemsize ():
456
- X = dpt .usm_ndarray (17 )[1 ::2 ]
462
+ X = dpt .usm_ndarray (17 , dtype = "u8" )[1 ::2 ]
457
463
get_elemsize_fn = _pyx_capi_fnptr_to_callable (
458
464
X ,
459
465
"UsmNDArray_GetElementSize" ,
@@ -466,7 +472,7 @@ def test_pyx_capi_get_elemsize():
466
472
467
473
468
474
def test_pyx_capi_get_flags ():
469
- X = dpt .usm_ndarray (17 )[1 ::2 ]
475
+ X = dpt .usm_ndarray (17 , dtype = "i8" )[1 ::2 ]
470
476
get_flags_fn = _pyx_capi_fnptr_to_callable (
471
477
X ,
472
478
"UsmNDArray_GetFlags" ,
@@ -478,7 +484,7 @@ def test_pyx_capi_get_flags():
478
484
479
485
480
486
def test_pyx_capi_get_offset ():
481
- X = dpt .usm_ndarray (17 )[1 ::2 ]
487
+ X = dpt .usm_ndarray (17 , dtype = "u2" )[1 ::2 ]
482
488
get_offset_fn = _pyx_capi_fnptr_to_callable (
483
489
X ,
484
490
"UsmNDArray_GetOffset" ,
@@ -491,7 +497,7 @@ def test_pyx_capi_get_offset():
491
497
492
498
493
499
def test_pyx_capi_get_queue_ref ():
494
- X = dpt .usm_ndarray (17 )[1 ::2 ]
500
+ X = dpt .usm_ndarray (17 , dtype = "i2" )[1 ::2 ]
495
501
get_queue_ref_fn = _pyx_capi_fnptr_to_callable (
496
502
X ,
497
503
"UsmNDArray_GetQueueRef" ,
@@ -521,7 +527,7 @@ def _pyx_capi_int(X, pyx_capi_name, caps_name=b"int", val_restype=ctypes.c_int):
521
527
522
528
523
529
def test_pyx_capi_check_constants ():
524
- X = dpt .usm_ndarray (17 )[1 ::2 ]
530
+ X = dpt .usm_ndarray (17 , dtype = "i1" )[1 ::2 ]
525
531
cc_flag = _pyx_capi_int (X , "USM_ARRAY_C_CONTIGUOUS" )
526
532
assert cc_flag > 0 and 0 == (cc_flag & (cc_flag - 1 ))
527
533
fc_flag = _pyx_capi_int (X , "USM_ARRAY_F_CONTIGUOUS" )
@@ -598,6 +604,7 @@ def test_pyx_capi_check_constants():
598
604
@pytest .mark .parametrize ("usm_type" , ["device" , "shared" , "host" ])
599
605
def test_tofrom_numpy (shape , dtype , usm_type ):
600
606
q = get_queue_or_skip ()
607
+ skip_if_dtype_not_supported (dtype , q )
601
608
Xnp = np .zeros (shape , dtype = dtype )
602
609
Xusm = dpt .from_numpy (Xnp , usm_type = usm_type , sycl_queue = q )
603
610
Ynp = np .ones (shape , dtype = dtype )
@@ -733,7 +740,7 @@ def relaxed_strides_equal(st1, st2, sh):
733
740
4 ,
734
741
5 ,
735
742
)
736
- X = dpt .usm_ndarray (sh_s , dtype = "d " )
743
+ X = dpt .usm_ndarray (sh_s , dtype = "i8 " )
737
744
X .shape = sh_f
738
745
assert X .shape == sh_f
739
746
assert relaxed_strides_equal (X .strides , cc_strides (sh_f ), sh_f )
@@ -750,27 +757,27 @@ def relaxed_strides_equal(st1, st2, sh):
750
757
4 ,
751
758
5 ,
752
759
)
753
- X = dpt .usm_ndarray (sh_s , dtype = "d " , order = "C" )
760
+ X = dpt .usm_ndarray (sh_s , dtype = "u4 " , order = "C" )
754
761
X .shape = sh_f
755
762
assert X .shape == sh_f
756
763
assert relaxed_strides_equal (X .strides , cc_strides (sh_f ), sh_f )
757
764
758
765
sh_s = (2 , 3 , 4 , 5 )
759
766
sh_f = (4 , 3 , 2 , 5 )
760
- X = dpt .usm_ndarray (sh_s , dtype = "d " )
767
+ X = dpt .usm_ndarray (sh_s , dtype = "f4 " )
761
768
X .shape = sh_f
762
769
assert relaxed_strides_equal (X .strides , cc_strides (sh_f ), sh_f )
763
770
764
771
sh_s = (2 , 3 , 4 , 5 )
765
772
sh_f = (4 , 3 , 1 , 2 , 5 )
766
- X = dpt .usm_ndarray (sh_s , dtype = "d " )
773
+ X = dpt .usm_ndarray (sh_s , dtype = "? " )
767
774
X .shape = sh_f
768
775
assert relaxed_strides_equal (X .strides , cc_strides (sh_f ), sh_f )
769
776
770
- X = dpt .usm_ndarray (sh_s , dtype = "d " )
777
+ X = dpt .usm_ndarray (sh_s , dtype = "u4 " )
771
778
with pytest .raises (TypeError ):
772
779
X .shape = "abcbe"
773
- X = dpt .usm_ndarray ((4 , 4 ), dtype = "d " )[::2 , ::2 ]
780
+ X = dpt .usm_ndarray ((4 , 4 ), dtype = "u1 " )[::2 , ::2 ]
774
781
with pytest .raises (AttributeError ):
775
782
X .shape = (4 ,)
776
783
X = dpt .usm_ndarray ((0 ,), dtype = "i4" )
@@ -814,7 +821,7 @@ def test_dlpack():
814
821
815
822
816
823
def test_to_device ():
817
- X = dpt .usm_ndarray (1 , "d " )
824
+ X = dpt .usm_ndarray (1 , "f4 " )
818
825
for dev in dpctl .get_devices ():
819
826
if dev .default_selector_score > 0 :
820
827
Y = X .to_device (dev )
@@ -900,7 +907,7 @@ def test_reshape():
900
907
W = dpt .reshape (Z , (- 1 ,), order = "C" )
901
908
assert W .shape == (Z .size ,)
902
909
903
- X = dpt .usm_ndarray ((1 ,))
910
+ X = dpt .usm_ndarray ((1 ,), dtype = "i8" )
904
911
Y = dpt .reshape (X , X .shape )
905
912
assert Y .flags == X .flags
906
913
@@ -970,7 +977,9 @@ def test_real_imag_views():
970
977
_all_dtypes ,
971
978
)
972
979
def test_zeros (dtype ):
973
- X = dpt .zeros (10 , dtype = dtype )
980
+ q = get_queue_or_skip ()
981
+ skip_if_dtype_not_supported (dtype , q )
982
+ X = dpt .zeros (10 , dtype = dtype , sycl_queue = q )
974
983
assert np .array_equal (dpt .asnumpy (X ), np .zeros (10 , dtype = dtype ))
975
984
976
985
@@ -1197,6 +1206,7 @@ def test_linspace_fp_max(dtype):
1197
1206
)
1198
1207
def test_empty_like (dt , usm_kind ):
1199
1208
q = get_queue_or_skip ()
1209
+ skip_if_dtype_not_supported (dt , q )
1200
1210
1201
1211
X = dpt .empty ((4 , 5 ), dtype = dt , usm_type = usm_kind , sycl_queue = q )
1202
1212
Y = dpt .empty_like (X )
@@ -1232,6 +1242,7 @@ def test_empty_unexpected_data_type():
1232
1242
)
1233
1243
def test_zeros_like (dt , usm_kind ):
1234
1244
q = get_queue_or_skip ()
1245
+ skip_if_dtype_not_supported (dt , q )
1235
1246
1236
1247
X = dpt .empty ((4 , 5 ), dtype = dt , usm_type = usm_kind , sycl_queue = q )
1237
1248
Y = dpt .zeros_like (X )
0 commit comments