Skip to content

Commit 9a2b917

Browse files
authored
implement dpnp.apply_along_axis (#2169)
* implement dpnp.apply_along_axis * add has_support_aspect64 to tests * update doc * address comments
1 parent 078d9a3 commit 9a2b917

16 files changed

+288
-16
lines changed

doc/reference/fft.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
FFT Functions
2-
=============
1+
Discrete Fourier Transform
2+
==========================
33

44
.. https://numpy.org/doc/stable/reference/routines.fft.html
55

doc/reference/functional.rst

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
Functional programming
2+
======================
3+
4+
.. https://numpy.org/doc/stable/reference/routines.functional.html
5+
6+
.. autosummary::
7+
:toctree: generated/
8+
:nosignatures:
9+
10+
dpnp.apply_along_axis
11+
dpnp.apply_over_axes
12+
dpnp.vectorize
13+
dpnp.frompyfunc
14+
dpnp.piecewise

doc/reference/linalg.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
Linear Algebra
1+
Linear algebra
22
==============
33

44
.. https://numpy.org/doc/stable/reference/routines.linalg.html

doc/reference/logic.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
Logic Functions
1+
Logic functions
22
===============
33

44
.. https://numpy.org/doc/stable/reference/routines.logic.html

doc/reference/manipulation.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
Array Manipulation Routines
1+
Array manipulation routines
22
===========================
33

44
.. https://numpy.org/doc/stable/reference/routines.array-manipulation.html

doc/reference/random.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
Random Sampling (``dpnp.random``)
1+
Random sampling (``dpnp.random``)
22
=================================
33

44
.. https://numpy.org/doc/stable/reference/random/legacy.html

doc/reference/routines.rst

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ Routines
44

55
The following pages describe NumPy-compatible routines.
66
These functions cover a subset of
7-
`NumPy routines <https://docs.scipy.org/doc/numpy/reference/routines.html>`_.
7+
`NumPy routines <https://numpy.org/doc/stable/reference/routines.html>`_.
88

99
.. currentmodule:: dpnp
1010

@@ -13,10 +13,11 @@ These functions cover a subset of
1313

1414
creation
1515
manipulation
16-
indexing
1716
binary
1817
dtype
1918
fft
19+
functional
20+
indexing
2021
linalg
2122
logic
2223
math

doc/reference/sorting.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
Sorting, Searching, and Counting
1+
Sorting, searching, and counting
22
================================
33

44
.. https://numpy.org/doc/stable/reference/routines.sort.html

doc/reference/statistics.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
Statistical Functions
2-
=====================
1+
Statistics
2+
==========
33

44
.. https://numpy.org/doc/stable/reference/routines.statistics.html
55

dpnp/dpnp_iface.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,8 @@
8181
from dpnp.dpnp_iface_bitwise import __all__ as __all__bitwise
8282
from dpnp.dpnp_iface_counting import *
8383
from dpnp.dpnp_iface_counting import __all__ as __all__counting
84+
from dpnp.dpnp_iface_functional import *
85+
from dpnp.dpnp_iface_functional import __all__ as __all__functional
8486
from dpnp.dpnp_iface_histograms import *
8587
from dpnp.dpnp_iface_histograms import __all__ as __all__histograms
8688
from dpnp.dpnp_iface_indexing import *
@@ -116,6 +118,7 @@
116118
__all__ += __all__arraycreation
117119
__all__ += __all__bitwise
118120
__all__ += __all__counting
121+
__all__ += __all__functional
119122
__all__ += __all__histograms
120123
__all__ += __all__indexing
121124
__all__ += __all__libmath

dpnp/dpnp_iface_functional.py

Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
# *****************************************************************************
2+
# Copyright (c) 2024, Intel Corporation
3+
# All rights reserved.
4+
#
5+
# Redistribution and use in source and binary forms, with or without
6+
# modification, are permitted provided that the following conditions are met:
7+
# - Redistributions of source code must retain the above copyright notice,
8+
# this list of conditions and the following disclaimer.
9+
# - Redistributions in binary form must reproduce the above copyright notice,
10+
# this list of conditions and the following disclaimer in the documentation
11+
# and/or other materials provided with the distribution.
12+
#
13+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
# THE POSSIBILITY OF SUCH DAMAGE.
24+
# *****************************************************************************
25+
26+
"""
27+
Interface of the functional programming routines part of the DPNP
28+
29+
Notes
30+
-----
31+
This module is a face or public interface file for the library
32+
it contains:
33+
- Interface functions
34+
- documentation for the functions
35+
- The functions parameters check
36+
37+
"""
38+
39+
40+
import numpy
41+
from dpctl.tensor._numpy_helper import normalize_axis_index
42+
43+
import dpnp
44+
45+
__all__ = ["apply_along_axis"]
46+
47+
48+
def apply_along_axis(func1d, axis, arr, *args, **kwargs):
49+
"""
50+
Apply a function to 1-D slices along the given axis.
51+
52+
Execute ``func1d(a, *args, **kwargs)`` where `func1d` operates on
53+
1-D arrays and `a` is a 1-D slice of `arr` along `axis`.
54+
55+
This is equivalent to (but faster than) the following use of
56+
:obj:`dpnp.ndindex` and :obj:`dpnp.s_`, which sets each of
57+
``ii``, ``jj``, and ``kk`` to a tuple of indices::
58+
59+
Ni, Nk = a.shape[:axis], a.shape[axis+1:]
60+
for ii in ndindex(Ni):
61+
for kk in ndindex(Nk):
62+
f = func1d(arr[ii + s_[:,] + kk])
63+
Nj = f.shape
64+
for jj in ndindex(Nj):
65+
out[ii + jj + kk] = f[jj]
66+
67+
Equivalently, eliminating the inner loop, this can be expressed as::
68+
69+
Ni, Nk = a.shape[:axis], a.shape[axis+1:]
70+
for ii in ndindex(Ni):
71+
for kk in ndindex(Nk):
72+
out[ii + s_[...,] + kk] = func1d(arr[ii + s_[:,] + kk])
73+
74+
For full documentation refer to :obj:`numpy.apply_along_axis`.
75+
76+
Parameters
77+
----------
78+
func1d : function (M,) -> (Nj...)
79+
This function should accept 1-D arrays. It is applied to 1-D
80+
slices of `arr` along the specified axis.
81+
axis : int
82+
Axis along which `arr` is sliced.
83+
arr : {dpnp.ndarray, usm_ndarray} (Ni..., M, Nk...)
84+
Input array.
85+
args : any
86+
Additional arguments to `func1d`.
87+
kwargs : any
88+
Additional named arguments to `func1d`.
89+
90+
Returns
91+
-------
92+
out : dpnp.ndarray (Ni..., Nj..., Nk...)
93+
The output array. The shape of `out` is identical to the shape of
94+
`arr`, except along the `axis` dimension. This axis is removed, and
95+
replaced with new dimensions equal to the shape of the return value
96+
of `func1d`.
97+
98+
See Also
99+
--------
100+
:obj:`dpnp.apply_over_axes` : Apply a function repeatedly over
101+
multiple axes.
102+
103+
Examples
104+
--------
105+
>>> import dpnp as np
106+
>>> def my_func(a): # Average first and last element of a 1-D array
107+
... return (a[0] + a[-1]) * 0.5
108+
>>> b = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
109+
>>> np.apply_along_axis(my_func, 0, b)
110+
array([4., 5., 6.])
111+
>>> np.apply_along_axis(my_func, 1, b)
112+
array([2., 5., 8.])
113+
114+
For a function that returns a 1D array, the number of dimensions in
115+
`out` is the same as `arr`.
116+
117+
>>> b = np.array([[8, 1, 7], [4, 3, 9], [5, 2, 6]])
118+
>>> np.apply_along_axis(sorted, 1, b)
119+
array([[1, 7, 8],
120+
[3, 4, 9],
121+
[2, 5, 6]])
122+
123+
For a function that returns a higher dimensional array, those dimensions
124+
are inserted in place of the `axis` dimension.
125+
126+
>>> b = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
127+
>>> np.apply_along_axis(np.diag, -1, b)
128+
array([[[1, 0, 0],
129+
[0, 2, 0],
130+
[0, 0, 3]],
131+
[[4, 0, 0],
132+
[0, 5, 0],
133+
[0, 0, 6]],
134+
[[7, 0, 0],
135+
[0, 8, 0],
136+
[0, 0, 9]]])
137+
138+
"""
139+
140+
dpnp.check_supported_arrays_type(arr)
141+
nd = arr.ndim
142+
exec_q = arr.sycl_queue
143+
usm_type = arr.usm_type
144+
axis = normalize_axis_index(axis, nd)
145+
146+
# arr, with the iteration axis at the end
147+
inarr_view = dpnp.moveaxis(arr, axis, -1)
148+
149+
# compute indices for the iteration axes, and append a trailing ellipsis to
150+
# prevent 0d arrays decaying to scalars
151+
# TODO: replace with dpnp.ndindex
152+
inds = numpy.ndindex(inarr_view.shape[:-1])
153+
inds = (ind + (Ellipsis,) for ind in inds)
154+
155+
# invoke the function on the first item
156+
try:
157+
ind0 = next(inds)
158+
except StopIteration:
159+
raise ValueError(
160+
"Cannot apply_along_axis when any iteration dimensions are 0"
161+
) from None
162+
res = dpnp.asanyarray(
163+
func1d(inarr_view[ind0], *args, **kwargs),
164+
sycl_queue=exec_q,
165+
usm_type=usm_type,
166+
)
167+
168+
# build a buffer for storing evaluations of func1d.
169+
# remove the requested axis, and add the new ones on the end.
170+
# laid out so that each write is contiguous.
171+
# for a tuple index inds, buff[inds] = func1d(inarr_view[inds])
172+
buff = dpnp.empty_like(res, shape=inarr_view.shape[:-1] + res.shape)
173+
174+
# save the first result, then compute and save all remaining results
175+
buff[ind0] = res
176+
for ind in inds:
177+
buff[ind] = dpnp.asanyarray(
178+
func1d(inarr_view[ind], *args, **kwargs),
179+
sycl_queue=exec_q,
180+
usm_type=usm_type,
181+
)
182+
183+
# restore the inserted axes back to where they belong
184+
for _ in range(res.ndim):
185+
buff = dpnp.moveaxis(buff, -1, axis)
186+
187+
return buff

dpnp/dpnp_iface_manipulation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
# *****************************************************************************
2626

2727
"""
28-
Interface of the Array manipulation routines part of the DPNP
28+
Interface of the array manipulation routines part of the DPNP
2929
3030
Notes
3131
-----

tests/test_functional.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import numpy
2+
import pytest
3+
from numpy.testing import assert_array_equal, assert_equal, assert_raises
4+
5+
import dpnp
6+
7+
from .helper import get_all_dtypes
8+
9+
10+
class TestApplyAlongAxis:
11+
def test_tuple_func1d(self):
12+
def sample_1d(x):
13+
return x[1], x[0]
14+
15+
a = numpy.array([[1, 2], [3, 4]])
16+
ia = dpnp.array(a)
17+
18+
# 2d insertion along first axis
19+
expected = numpy.apply_along_axis(sample_1d, 1, a)
20+
result = dpnp.apply_along_axis(sample_1d, 1, ia)
21+
assert_array_equal(result, expected)
22+
23+
@pytest.mark.parametrize("stride", [-1, 2, -3])
24+
def test_stride(self, stride):
25+
a = numpy.ones((20, 10), dtype="f")
26+
ia = dpnp.array(a)
27+
28+
expected = numpy.apply_along_axis(len, 0, a[::stride, ::stride])
29+
result = dpnp.apply_along_axis(len, 0, ia[::stride, ::stride])
30+
assert_array_equal(result, expected)
31+
32+
@pytest.mark.parametrize("dtype", get_all_dtypes())
33+
def test_args(self, dtype):
34+
a = numpy.ones((20, 10))
35+
ia = dpnp.array(a)
36+
37+
# kwargs
38+
expected = numpy.apply_along_axis(
39+
numpy.mean, 0, a, dtype=dtype, keepdims=True
40+
)
41+
result = dpnp.apply_along_axis(
42+
dpnp.mean, 0, ia, dtype=dtype, keepdims=True
43+
)
44+
assert_array_equal(result, expected)
45+
46+
# positional args: axis, dtype, out, keepdims
47+
result = dpnp.apply_along_axis(dpnp.mean, 0, ia, 0, dtype, None, True)
48+
assert_array_equal(result, expected)

tests/test_sycl_queue.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2178,6 +2178,18 @@ def test_split(func, data1, device):
21782178
assert_sycl_queue_equal(result[1].sycl_queue, x1.sycl_queue)
21792179

21802180

2181+
@pytest.mark.parametrize(
2182+
"device",
2183+
valid_devices,
2184+
ids=[device.filter_string for device in valid_devices],
2185+
)
2186+
def test_apply_along_axis(device):
2187+
x = dpnp.arange(9, device=device).reshape(3, 3)
2188+
result = dpnp.apply_along_axis(dpnp.sum, 0, x)
2189+
2190+
assert_sycl_queue_equal(result.sycl_queue, x.sycl_queue)
2191+
2192+
21812193
@pytest.mark.parametrize(
21822194
"device_x",
21832195
valid_devices,

tests/test_usm_type.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -773,6 +773,14 @@ def test_2in_with_scalar_1out(func, data, scalar, usm_type):
773773
assert z.usm_type == usm_type
774774

775775

776+
@pytest.mark.parametrize("usm_type", list_of_usm_types, ids=list_of_usm_types)
777+
def test_apply_along_axis(usm_type):
778+
x = dp.arange(9, usm_type=usm_type).reshape(3, 3)
779+
y = dp.apply_along_axis(dp.sum, 0, x)
780+
781+
assert x.usm_type == y.usm_type
782+
783+
776784
@pytest.mark.parametrize("usm_type", list_of_usm_types, ids=list_of_usm_types)
777785
def test_broadcast_to(usm_type):
778786
x = dp.ones(7, usm_type=usm_type)

0 commit comments

Comments
 (0)