Skip to content

Commit 6171a71

Browse files
authored
Add support of combining arrays with different USM types (#1237)
1 parent 8d54d7c commit 6171a71

File tree

3 files changed

+48
-18
lines changed

3 files changed

+48
-18
lines changed

.github/workflows/conda-package.yml

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,15 @@ env:
1010
PACKAGE_NAME: dpnp
1111
MODULE_NAME: dpnp
1212
CHANNELS: '-c dppy/label/dev -c intel -c defaults --override-channels'
13+
TEST_SCOPE: >-
14+
test_arraycreation.py
15+
test_dparray.py
16+
test_fft.py
17+
test_linalg.py
18+
test_mathematical.py
19+
test_random_state.py
20+
test_special.py
21+
test_usm_type.py
1322
VER_JSON_NAME: 'version.json'
1423
VER_SCRIPT1: "import json; f = open('version.json', 'r'); j = json.load(f); f.close(); "
1524
VER_SCRIPT2: "d = j['dpnp'][0]; print('='.join((d[s] for s in ('version', 'build'))))"
@@ -235,7 +244,7 @@ jobs:
235244
# TODO: run the whole scope once the issues on CPU are resolved
236245
- name: Run tests
237246
run: |
238-
python -m pytest -q -ra --disable-warnings -vv test_arraycreation.py test_dparray.py test_fft.py test_linalg.py test_mathematical.py test_random_state.py test_special.py
247+
python -m pytest -q -ra --disable-warnings -vv ${{ env.TEST_SCOPE }}
239248
env:
240249
OCL_ICD_FILENAMES: 'libintelocl.so'
241250
working-directory: ${{ env.tests-path }}
@@ -410,7 +419,7 @@ jobs:
410419
# TODO: run the whole scope once the issues on CPU are resolved
411420
- name: Run tests
412421
run: |
413-
python -m pytest -q -ra --disable-warnings -vv test_arraycreation.py test_dparray.py test_fft.py test_linalg.py test_mathematical.py test_random_state.py test_special.py
422+
python -m pytest -q -ra --disable-warnings -vv ${{ env.TEST_SCOPE }}
414423
working-directory: ${{ env.tests-path }}
415424

416425
upload_linux:

dpnp/dpnp_utils/dpnp_algo_utils.pyx

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -540,22 +540,7 @@ cdef tuple get_common_usm_allocation(dpnp_descriptor x1, dpnp_descriptor x2):
540540
array1_obj = x1.get_array()
541541
array2_obj = x2.get_array()
542542

543-
def get_usm_type(usm_types):
544-
if not isinstance(usm_types, (list, tuple)):
545-
raise TypeError(
546-
"Expected a list or a tuple, got {}".format(type(usm_types))
547-
)
548-
if len(usm_types) == 0:
549-
return None
550-
elif len(usm_types) == 1:
551-
return usm_types[0]
552-
for usm_type1, usm_type2 in zip(usm_types, usm_types[1:]):
553-
if usm_type1 != usm_type2:
554-
return None
555-
return usm_types[0]
556-
557-
# TODO: use similar function from dpctl.utils instead of get_usm_type
558-
common_usm_type = get_usm_type((array1_obj.usm_type, array2_obj.usm_type))
543+
common_usm_type = dpctl.utils.get_coerced_usm_type((array1_obj.usm_type, array2_obj.usm_type))
559544
if common_usm_type is None:
560545
raise ValueError(
561546
"could not recognize common USM type for inputs of USM types {} and {}"

tests/test_usm_type.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import pytest
2+
3+
import dpnp as dp
4+
5+
import dpctl.utils as du
6+
7+
list_of_usm_types = [
8+
"device",
9+
"shared",
10+
"host"
11+
]
12+
13+
14+
@pytest.mark.parametrize("usm_type", list_of_usm_types, ids=list_of_usm_types)
15+
def test_coerced_usm_types_sum(usm_type):
16+
x = dp.arange(10, usm_type = "device")
17+
y = dp.arange(10, usm_type = usm_type)
18+
19+
z = x + y
20+
21+
assert z.usm_type == x.usm_type
22+
assert z.usm_type == "device"
23+
assert y.usm_type == usm_type
24+
25+
26+
@pytest.mark.parametrize("usm_type_x", list_of_usm_types, ids=list_of_usm_types)
27+
@pytest.mark.parametrize("usm_type_y", list_of_usm_types, ids=list_of_usm_types)
28+
def test_coerced_usm_types_mul(usm_type_x, usm_type_y):
29+
x = dp.arange(10, usm_type = usm_type_x)
30+
y = dp.arange(10, usm_type = usm_type_y)
31+
32+
z = x * y
33+
34+
assert x.usm_type == usm_type_x
35+
assert y.usm_type == usm_type_y
36+
assert z.usm_type == du.get_coerced_usm_type([usm_type_x, usm_type_y])

0 commit comments

Comments
 (0)