Skip to content

Add support of combining arrays with different USM types #1237

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions .github/workflows/conda-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,15 @@ env:
PACKAGE_NAME: dpnp
MODULE_NAME: dpnp
CHANNELS: '-c dppy/label/dev -c intel -c defaults --override-channels'
TEST_SCOPE: >-
test_arraycreation.py
test_dparray.py
test_fft.py
test_linalg.py
test_mathematical.py
test_random_state.py
test_special.py
test_usm_type.py
VER_JSON_NAME: 'version.json'
VER_SCRIPT1: "import json; f = open('version.json', 'r'); j = json.load(f); f.close(); "
VER_SCRIPT2: "d = j['dpnp'][0]; print('='.join((d[s] for s in ('version', 'build'))))"
Expand Down Expand Up @@ -235,7 +244,7 @@ jobs:
# TODO: run the whole scope once the issues on CPU are resolved
- name: Run tests
run: |
python -m pytest -q -ra --disable-warnings -vv test_arraycreation.py test_dparray.py test_fft.py test_linalg.py test_mathematical.py test_random_state.py test_special.py
python -m pytest -q -ra --disable-warnings -vv ${{ env.TEST_SCOPE }}
env:
OCL_ICD_FILENAMES: 'libintelocl.so'
working-directory: ${{ env.tests-path }}
Expand Down Expand Up @@ -410,7 +419,7 @@ jobs:
# TODO: run the whole scope once the issues on CPU are resolved
- name: Run tests
run: |
python -m pytest -q -ra --disable-warnings -vv test_arraycreation.py test_dparray.py test_fft.py test_linalg.py test_mathematical.py test_random_state.py test_special.py
python -m pytest -q -ra --disable-warnings -vv ${{ env.TEST_SCOPE }}
working-directory: ${{ env.tests-path }}

upload_linux:
Expand Down
17 changes: 1 addition & 16 deletions dpnp/dpnp_utils/dpnp_algo_utils.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -540,22 +540,7 @@ cdef tuple get_common_usm_allocation(dpnp_descriptor x1, dpnp_descriptor x2):
array1_obj = x1.get_array()
array2_obj = x2.get_array()

def get_usm_type(usm_types):
if not isinstance(usm_types, (list, tuple)):
raise TypeError(
"Expected a list or a tuple, got {}".format(type(usm_types))
)
if len(usm_types) == 0:
return None
elif len(usm_types) == 1:
return usm_types[0]
for usm_type1, usm_type2 in zip(usm_types, usm_types[1:]):
if usm_type1 != usm_type2:
return None
return usm_types[0]

# TODO: use similar function from dpctl.utils instead of get_usm_type
common_usm_type = get_usm_type((array1_obj.usm_type, array2_obj.usm_type))
common_usm_type = dpctl.utils.get_coerced_usm_type((array1_obj.usm_type, array2_obj.usm_type))
if common_usm_type is None:
raise ValueError(
"could not recognize common USM type for inputs of USM types {} and {}"
Expand Down
36 changes: 36 additions & 0 deletions tests/test_usm_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import pytest

import dpnp as dp

import dpctl.utils as du

list_of_usm_types = [
"device",
"shared",
"host"
]


@pytest.mark.parametrize("usm_type", list_of_usm_types, ids=list_of_usm_types)
def test_coerced_usm_types_sum(usm_type):
x = dp.arange(10, usm_type = "device")
y = dp.arange(10, usm_type = usm_type)

z = x + y

assert z.usm_type == x.usm_type
assert z.usm_type == "device"
assert y.usm_type == usm_type


@pytest.mark.parametrize("usm_type_x", list_of_usm_types, ids=list_of_usm_types)
@pytest.mark.parametrize("usm_type_y", list_of_usm_types, ids=list_of_usm_types)
def test_coerced_usm_types_mul(usm_type_x, usm_type_y):
x = dp.arange(10, usm_type = usm_type_x)
y = dp.arange(10, usm_type = usm_type_y)

z = x * y

assert x.usm_type == usm_type_x
assert y.usm_type == usm_type_y
assert z.usm_type == du.get_coerced_usm_type([usm_type_x, usm_type_y])