Skip to content

Commit 5f1b083

Browse files
committed
Apply fallback to numpy for all unsupported functions on cuda device.
1 parent b6bd08a commit 5f1b083

10 files changed

+125
-40
lines changed

dpnp/dpnp_iface.py

Lines changed: 32 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -69,10 +69,10 @@
6969
"get_normalized_queue_device",
7070
"get_result_array",
7171
"get_usm_ndarray",
72+
"is_cuda_backend",
7273
"get_usm_ndarray_or_scalar",
7374
"is_supported_array_or_scalar",
7475
"is_supported_array_type",
75-
"not_implemented_for_cuda_backend",
7676
"synchronize_array_data",
7777
]
7878

@@ -758,6 +758,37 @@ def get_usm_ndarray_or_scalar(a):
758758
return a if dpnp.isscalar(a) else get_usm_ndarray(a)
759759

760760

761+
def is_cuda_backend(obj=None):
762+
"""
763+
Checks that object has a cuda backend.
764+
765+
Parameters
766+
----------
767+
obj : {Device, SyclDevice, SyclQueue, dpnp.ndarray, usm_ndarray, None},
768+
optional
769+
An input object with sycl_device property to check device backend.
770+
If obj is ``None``, device backend will be checked for the default
771+
queue.
772+
Default: ``None``.
773+
774+
Returns
775+
-------
776+
out : bool
777+
Return ``True`` if object has a cuda backend, otherwise``False``.
778+
779+
"""
780+
781+
if obj is None:
782+
sycl_device = dpctl.SyclQueue().sycl_device
783+
elif isinstance(obj, dpctl.SyclDevice):
784+
sycl_device = obj
785+
else:
786+
sycl_device = getattr(obj, "sycl_device", None)
787+
if sycl_device is not None and "cuda" in sycl_device.backend.name:
788+
return True
789+
return False
790+
791+
761792
def is_supported_array_or_scalar(a):
762793
"""
763794
Return ``True`` if `a` is a scalar or an array of either
@@ -801,21 +832,6 @@ def is_supported_array_type(a):
801832
return isinstance(a, (dpnp_array, dpt.usm_ndarray))
802833

803834

804-
def not_implemented_for_cuda_backend(obj):
805-
"""
806-
Raise NotImplementedError for cuda devices.
807-
808-
Parameters
809-
----------
810-
obj : {SyclDevice, SyclQueue, dpnp.ndarray, usm_ndarray}
811-
An input object with sycl_device property to check device backend.
812-
813-
"""
814-
sycl_device = getattr(obj, "sycl_device", None)
815-
if sycl_device is not None and "cuda" in sycl_device.backend.name:
816-
raise NotImplementedError("function not implemented for cuda backend")
817-
818-
819835
def synchronize_array_data(a):
820836
"""
821837
The dpctl interface was reworked to make asynchronous execution.

dpnp/dpnp_iface_indexing.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,8 +174,6 @@ def choose(x1, choices, out=None, mode="raise"):
174174
175175
"""
176176

177-
dpnp.not_implemented_for_cuda_backend(x1)
178-
179177
x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_nondefault_queue=False)
180178

181179
choices_list = []
@@ -195,6 +193,8 @@ def choose(x1, choices, out=None, mode="raise"):
195193
pass
196194
elif not choices_list:
197195
pass
196+
elif dpnp.is_cuda_backend(x1):
197+
pass
198198
else:
199199
size = x1_desc.size
200200
choices_size = choices_list[0].size

dpnp/dpnp_iface_libmath.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,12 +78,10 @@ def erf(in_array1):
7878
7979
"""
8080

81-
dpnp.not_implemented_for_cuda_backend(in_array1)
82-
8381
x1_desc = dpnp.get_dpnp_descriptor(
8482
in_array1, copy_when_strides=False, copy_when_nondefault_queue=False
8583
)
86-
if x1_desc:
84+
if x1_desc and dpnp.is_cuda_backend(in_array1):
8785
return dpnp_erf(x1_desc).get_pyobj()
8886

8987
result = create_output_descriptor_py(

dpnp/dpnp_iface_mathematical.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2477,11 +2477,14 @@ def modf(x1, **kwargs):
24772477
24782478
"""
24792479

2480-
dpnp.not_implemented_for_cuda_backend(x1)
2481-
24822480
x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_nondefault_queue=False)
2483-
if x1_desc and not kwargs:
2484-
return dpnp_modf(x1_desc)
2481+
if x1_desc:
2482+
if not kwargs:
2483+
pass
2484+
elif dpnp.is_cuda_backend(x1):
2485+
pass
2486+
else:
2487+
return dpnp_modf(x1_desc)
24852488

24862489
return call_origin(numpy.modf, x1, **kwargs)
24872490

dpnp/dpnp_iface_sorting.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,8 +174,6 @@ def partition(x1, kth, axis=-1, kind="introselect", order=None):
174174
175175
"""
176176

177-
dpnp.not_implemented_for_cuda_backend(x1)
178-
179177
x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_nondefault_queue=False)
180178
if x1_desc:
181179
if not isinstance(kth, int):
@@ -190,6 +188,8 @@ def partition(x1, kth, axis=-1, kind="introselect", order=None):
190188
pass
191189
elif order is not None:
192190
pass
191+
elif dpnp.is_cuda_backend(x1):
192+
pass
193193
else:
194194
return dpnp_partition(x1_desc, kth, axis, kind, order).get_pyobj()
195195

dpnp/dpnp_iface_statistics.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -371,8 +371,6 @@ def correlate(x1, x2, mode="valid"):
371371
372372
"""
373373

374-
dpnp.not_implemented_for_cuda_backend(x1)
375-
376374
x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_nondefault_queue=False)
377375
x2_desc = dpnp.get_dpnp_descriptor(x2, copy_when_nondefault_queue=False)
378376
if x1_desc and x2_desc:
@@ -382,6 +380,8 @@ def correlate(x1, x2, mode="valid"):
382380
pass
383381
elif mode != "valid":
384382
pass
383+
elif dpnp.is_cuda_backend(x1) or dpnp.is_cuda_backend(x2):
384+
pass
385385
else:
386386
return dpnp_correlate(x1_desc, x2_desc).get_pyobj()
387387

@@ -657,8 +657,6 @@ def median(x1, axis=None, out=None, overwrite_input=False, keepdims=False):
657657
658658
"""
659659

660-
dpnp.not_implemented_for_cuda_backend(x1)
661-
662660
x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_nondefault_queue=False)
663661
if x1_desc:
664662
if axis is not None:
@@ -669,6 +667,8 @@ def median(x1, axis=None, out=None, overwrite_input=False, keepdims=False):
669667
pass
670668
elif keepdims:
671669
pass
670+
elif dpnp.is_cuda_backend(x1):
671+
pass
672672
else:
673673
result_obj = dpnp_median(x1_desc).get_pyobj()
674674
result = dpnp.convert_single_elem_array_to_scalar(result_obj)

0 commit comments

Comments
 (0)