Skip to content

Add integer dtypes (int8, int16, uint8-uint64) to dpnp interface #2230

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 20 commits into from
Feb 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
a098910
Add integer dtypes (int8,int16, uint8-uint64) to dpnp interface
AlexanderKalistratov Oct 9, 2024
9c28750
update tests for newly added integr dtypes (int8, int16, uint8-uint64…
vtavana Jan 23, 2025
eb68f63
merge with master branch and resolve conflicts
vtavana Jan 23, 2025
d0870b1
update newly added tests for elemnt-wise in-place operators
vtavana Jan 24, 2025
56db205
remove debugging leftover
vtavana Jan 24, 2025
a9af727
Merge branch 'master' into extended_types_support
vtavana Jan 24, 2025
33a6608
Merge branch 'master' into extended_types_support
vtavana Jan 28, 2025
191688b
update TODOs in test_array_api_info.py
vtavana Jan 28, 2025
43fb4cd
update yaml file
vtavana Jan 28, 2025
1c50270
update yaml file
vtavana Jan 28, 2025
bd98e7a
Merge branch 'master' into extended_types_support
vtavana Jan 28, 2025
7c59295
Merge branch 'master' into extended_types_support
vtavana Feb 3, 2025
e90eba9
unmute array-api-test
vtavana Feb 3, 2025
66ef596
Merge branch 'master' into extended_types_support
vtavana Feb 4, 2025
98983bd
Merge branch 'master' into extended_types_support
vtavana Feb 5, 2025
6592abc
Merge branch 'master' into extended_types_support
vtavana Feb 7, 2025
e01c969
update puclic CI workflow for all integer dtypes (#2298)
vtavana Feb 10, 2025
6549d2b
Merge branch 'master' into extended_types_support
vtavana Feb 10, 2025
995c385
Merge branch 'master' into extended_types_support
vtavana Feb 13, 2025
14ece2e
Merge branch 'master' into extended_types_support
vtavana Feb 13, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions .github/workflows/array-api-skips.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
# array API tests to be skipped

# no 'uint8' dtype
array_api_tests/test_array_object.py::test_getitem_masking

# missing unique-like functions
array_api_tests/test_has_names.py::test_has_names[set-unique_all]
array_api_tests/test_has_names.py::test_has_names[set-unique_counts]
Expand Down
37 changes: 29 additions & 8 deletions .github/workflows/conda-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ env:
CONDA_BUILD_INDEX_ENV_PY_VER: '3.12' # conda does not support python 3.13
CONDA_BUILD_VERSION: '25.1.1'
CONDA_INDEX_VERSION: '0.5.0'
LATEST_PYTHON: '3.13'
RERUN_TESTS_ON_FAILURE: 'true'
RUN_TESTS_MAX_ATTEMPTS: 2
TEST_ENV_NAME: 'test'
Expand Down Expand Up @@ -189,7 +190,7 @@ jobs:
id: install_dpnp
continue-on-error: true
run: |
mamba install ${{ env.PACKAGE_NAME }}=${{ env.PACKAGE_VERSION }} pytest python=${{ matrix.python }} ${{ env.TEST_CHANNELS }}
mamba install ${{ env.PACKAGE_NAME }}=${{ env.PACKAGE_VERSION }} pytest pytest-xdist python=${{ matrix.python }} ${{ env.TEST_CHANNELS }}
env:
TEST_CHANNELS: '-c ${{ env.channel-path }} ${{ env.CHANNELS }}'

Expand All @@ -211,22 +212,32 @@ jobs:
- name: Run tests
if: env.RERUN_TESTS_ON_FAILURE != 'true'
run: |
python -m pytest -ra --pyargs ${{ env.PACKAGE_NAME }}.tests
if [[ ${{ matrix.python }} == ${{ env.LATEST_PYTHON }} ]]; then
export DPNP_TEST_ALL_INT_TYPES=1
python -m pytest -ra --pyargs ${{ env.PACKAGE_NAME }}.tests
else
python -m pytest -n auto -ra --pyargs ${{ env.PACKAGE_NAME }}.tests
fi

- name: Run tests
if: env.RERUN_TESTS_ON_FAILURE == 'true'
id: run_tests_linux
uses: nick-fields/retry@7152eba30c6575329ac0576536151aca5a72780e # v3.0.0
with:
timeout_minutes: 15
timeout_minutes: 25
max_attempts: ${{ env.RUN_TESTS_MAX_ATTEMPTS }}
retry_on: any
command: |
. $CONDA/etc/profile.d/conda.sh
. $CONDA/etc/profile.d/mamba.sh
mamba activate ${{ env.TEST_ENV_NAME }}

python -m pytest -ra --pyargs ${{ env.PACKAGE_NAME }}.tests
if [[ ${{ matrix.python }} == ${{ env.LATEST_PYTHON }} ]]; then
export DPNP_TEST_ALL_INT_TYPES=1
python -m pytest -ra --pyargs ${{ env.PACKAGE_NAME }}.tests
else
python -m pytest -n auto -ra --pyargs ${{ env.PACKAGE_NAME }}.tests
fi

test_windows:
name: Test
Expand Down Expand Up @@ -319,7 +330,7 @@ jobs:
- name: Install dpnp
run: |
@echo on
mamba install ${{ env.PACKAGE_NAME }}=${{ env.PACKAGE_VERSION }} pytest python=${{ matrix.python }} ${{ env.TEST_CHANNELS }}
mamba install ${{ env.PACKAGE_NAME }}=${{ env.PACKAGE_VERSION }} pytest pytest-xdist python=${{ matrix.python }} ${{ env.TEST_CHANNELS }}
env:
TEST_CHANNELS: '-c ${{ env.channel-path }} ${{ env.CHANNELS }}'
MAMBA_NO_LOW_SPEED_LIMIT: 1
Expand Down Expand Up @@ -348,18 +359,28 @@ jobs:
- name: Run tests
if: env.RERUN_TESTS_ON_FAILURE != 'true'
run: |
pytest -ra --pyargs ${{ env.PACKAGE_NAME }}.tests
if (${{ matrix.python }} -eq ${{ env.LATEST_PYTHON }}) {
set DPNP_TEST_ALL_INT_TYPES=1
python -m pytest -ra --pyargs ${{ env.PACKAGE_NAME }}.tests
} else {
python -m pytest -n auto -ra --pyargs ${{ env.PACKAGE_NAME }}.tests
}

- name: Run tests
if: env.RERUN_TESTS_ON_FAILURE == 'true'
id: run_tests_win
uses: nick-fields/retry@7152eba30c6575329ac0576536151aca5a72780e # v3.0.0
with:
timeout_minutes: 17
timeout_minutes: 35
max_attempts: ${{ env.RUN_TESTS_MAX_ATTEMPTS }}
retry_on: any
command: |
python -m pytest -ra --pyargs ${{ env.PACKAGE_NAME }}.tests
if ( ${{ matrix.python }} -eq ${{ env.LATEST_PYTHON }} ) {
set DPNP_TEST_ALL_INT_TYPES=1
python -m pytest -ra --pyargs ${{ env.PACKAGE_NAME }}.tests
} else {
python -m pytest -n auto -ra --pyargs ${{ env.PACKAGE_NAME }}.tests
}

upload:
name: Upload
Expand Down
15 changes: 14 additions & 1 deletion doc/reference/dtypes_table.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,35 @@ Table below shows a list of all supported data types (dtypes) and constants of t
- Constants
* -
- :obj:`bool <numpy.bool_>`
- :obj:`int8 <numpy.int8>`
- :obj:`int16 <numpy.int16>`
- :obj:`int32 <numpy.int32>`
- :obj:`int64 <numpy.int64>`
- :obj:`uint8 <numpy.uint8>`
- :obj:`uint16 <numpy.uint16>`
- :obj:`uint32 <numpy.uint32>`
- :obj:`uint64 <numpy.uint64>`
- :obj:`float32 <numpy.float32>`
- :obj:`float64 <numpy.float64>`
- :obj:`complex64 <numpy.complex64>`
- :obj:`complex128 <numpy.complex128>`
-
- :obj:`bool_ <numpy.bool_>`
- :obj:`byte <numpy.byte>`
- :obj:`cdouble <numpy.cdouble>`
- :obj:`csingle <numpy.csingle>`
- :obj:`double <numpy.double>`
- :obj:`float16 <numpy.float16>`
- :obj:`int <numpy.int>`
- :obj:`int_ <numpy.int_>`
- :obj:`intc <numpy.intc>`
- :obj:`intp <numpy.intp>`
- :obj:`longlong <numpy.longlong>`
- :obj:`single <numpy.single>`
- :obj:`ubyte <numpy.ubyte>`
- :obj:`uintc <numpy.uintc>`
- :obj:`uintp <numpy.uintp>`
- :obj:`ushort <numpy.ushort>`
- :obj:`ulonglong <numpy.ulonglong>`
-
- :obj:`e <numpy.e>`
- :obj:`euler_gamma <numpy.euler_gamma>`
Expand Down
16 changes: 11 additions & 5 deletions dpnp/dpnp_algo/dpnp_elementwise_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,12 +600,18 @@ def __init__(
def __call__(self, x, decimals=0, out=None, dtype=None):
if decimals != 0:
x_usm = dpnp.get_usm_ndarray(x)
if dpnp.issubdtype(x_usm.dtype, dpnp.integer) and dtype is None:
dtype = x_usm.dtype

out_usm = None if out is None else dpnp.get_usm_ndarray(out)
x_usm = dpt.round(x_usm * 10**decimals, out=out_usm)
res_usm = dpt.divide(x_usm, 10**decimals, out=out_usm)

if dpnp.issubdtype(x_usm.dtype, dpnp.integer):
if decimals < 0:
dtype = x_usm.dtype
x_usm = dpt.round(x_usm * 10**decimals, out=out_usm)
res_usm = dpt.divide(x_usm, 10**decimals, out=out_usm)
else:
res_usm = dpt.round(x_usm, out=out_usm)
else:
x_usm = dpt.round(x_usm * 10**decimals, out=out_usm)
res_usm = dpt.divide(x_usm, 10**decimals, out=out_usm)

if dtype is not None:
res_usm = dpt.astype(res_usm, dtype, copy=False)
Expand Down
32 changes: 22 additions & 10 deletions dpnp/dpnp_iface_linearalgebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,12 @@

"""

# pylint: disable=no-name-in-module
import numpy

import dpnp

from .dpnp_utils import map_dtype_to_device
from .dpnp_utils.dpnp_utils_einsum import dpnp_einsum
from .dpnp_utils.dpnp_utils_linearalgebra import (
dpnp_dot,
Expand All @@ -66,6 +68,20 @@
]


# TODO: implement a specific scalar-array kernel
def _call_multiply(a, b, out=None):
"""Call multiply function for special cases of scalar-array dots."""

sc, arr = (a, b) if dpnp.isscalar(a) else (b, a)
sc_dtype = map_dtype_to_device(type(sc), arr.sycl_device)
res_dtype = dpnp.result_type(sc_dtype, arr)
if out is not None and out.dtype == arr.dtype:
res = dpnp.multiply(a, b, out=out)
else:
res = dpnp.multiply(a, b, dtype=res_dtype)
return dpnp.get_result_array(res, out, casting="no")


def dot(a, b, out=None):
"""
Dot product of `a` and `b`.
Expand Down Expand Up @@ -139,8 +155,7 @@ def dot(a, b, out=None):
raise ValueError("Only C-contiguous array is acceptable.")

if dpnp.isscalar(a) or dpnp.isscalar(b):
# TODO: use specific scalar-vector kernel
return dpnp.multiply(a, b, out=out)
return _call_multiply(a, b, out=out)

a_ndim = a.ndim
b_ndim = b.ndim
Expand Down Expand Up @@ -635,8 +650,7 @@ def inner(a, b):
dpnp.check_supported_arrays_type(a, b, scalar_type=True)

if dpnp.isscalar(a) or dpnp.isscalar(b):
# TODO: use specific scalar-vector kernel
return dpnp.multiply(a, b)
return _call_multiply(a, b)

if a.ndim == 0 or b.ndim == 0:
# TODO: use specific scalar-vector kernel
Expand Down Expand Up @@ -714,8 +728,7 @@ def kron(a, b):
dpnp.check_supported_arrays_type(a, b, scalar_type=True)

if dpnp.isscalar(a) or dpnp.isscalar(b):
# TODO: use specific scalar-vector kernel
return dpnp.multiply(a, b)
return _call_multiply(a, b)

a_ndim = a.ndim
b_ndim = b.ndim
Expand Down Expand Up @@ -1199,8 +1212,7 @@ def tensordot(a, b, axes=2):
raise ValueError(
"One of the inputs is scalar, axes should be zero."
)
# TODO: use specific scalar-vector kernel
return dpnp.multiply(a, b)
return _call_multiply(a, b)

return dpnp_tensordot(a, b, axes=axes)

Expand Down Expand Up @@ -1263,13 +1275,13 @@ def vdot(a, b):
if b.size != 1:
raise ValueError("The second array should be of size one.")
a_conj = numpy.conj(a)
return dpnp.multiply(a_conj, b)
return _call_multiply(a_conj, b)

if dpnp.isscalar(b):
if a.size != 1:
raise ValueError("The first array should be of size one.")
a_conj = dpnp.conj(a)
return dpnp.multiply(a_conj, b)
return _call_multiply(a_conj, b)

if a.ndim == 1 and b.ndim == 1:
return dpnp_dot(a, b, out=None, conjugate=True)
Expand Down
13 changes: 13 additions & 0 deletions dpnp/dpnp_iface_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1494,7 +1494,20 @@ def copyto(dst, src, casting="same_kind", where=True):
f"but got {type(dst)}"
)
if not dpnp.is_supported_array_type(src):
no_dtype_attr = not hasattr(src, "dtype")
src = dpnp.array(src, sycl_queue=dst.sycl_queue)
if no_dtype_attr:
# This case (scalar, list, etc) needs special handling to
# behave similar to NumPy
if dpnp.issubdtype(src, dpnp.integer) and dpnp.issubdtype(
dst, dpnp.unsignedinteger
):
if dpnp.any(src < 0):
raise OverflowError(
"Cannot copy negative values to an unsigned int array"
)

src = src.astype(dst.dtype)

if not dpnp.can_cast(src.dtype, dst.dtype, casting=casting):
raise TypeError(
Expand Down
32 changes: 30 additions & 2 deletions dpnp/dpnp_iface_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
__all__ = [
"bool",
"bool_",
"byte",
"cdouble",
"complex128",
"complex64",
Expand All @@ -57,8 +58,9 @@
"iinfo",
"inexact",
"inf",
"int",
"int_",
"int8",
"int16",
"int32",
"int64",
"integer",
Expand All @@ -67,12 +69,24 @@
"isdtype",
"issubdtype",
"is_type_supported",
"longlong",
"nan",
"newaxis",
"number",
"pi",
"short",
"signedinteger",
"single",
"ubyte",
"uint8",
"uint16",
"uint32",
"uint64",
"uintc",
"uintp",
"unsignedinteger",
"ushort",
"ulonglong",
]


Expand All @@ -82,6 +96,7 @@
# =============================================================================
bool = numpy.bool_
bool_ = numpy.bool_
byte = numpy.byte
cdouble = numpy.cdouble
complex128 = numpy.complex128
complex64 = numpy.complex64
Expand All @@ -94,16 +109,29 @@
float64 = numpy.float64
floating = numpy.floating
inexact = numpy.inexact
int = numpy.int_
int_ = numpy.int_
int8 = numpy.int8
int16 = numpy.int16
int32 = numpy.int32
int64 = numpy.int64
integer = numpy.integer
intc = numpy.intc
intp = numpy.intp
longlong = numpy.longlong
number = numpy.number
short = numpy.short
signedinteger = numpy.signedinteger
single = numpy.single
ubyte = numpy.ubyte
uint8 = numpy.uint8
uint16 = numpy.uint16
uint32 = numpy.uint32
uint64 = numpy.uint64
uintc = numpy.uintc
uintp = numpy.uintp
unsignedinteger = numpy.unsignedinteger
ushort = numpy.ushort
ulonglong = numpy.ulonglong


# =============================================================================
Expand Down
Loading
Loading