Skip to content

Implement product #1426

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Oct 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion dpctl/tensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@
tanh,
trunc,
)
from ._reduction import argmax, argmin, max, min, sum
from ._reduction import argmax, argmin, max, min, prod, sum
from ._testing import allclose

__all__ = [
Expand Down Expand Up @@ -313,4 +313,5 @@
"min",
"argmax",
"argmin",
"prod",
]
69 changes: 65 additions & 4 deletions dpctl/tensor/_reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,12 +144,12 @@ def _reduction_over_axis(
def sum(x, axis=None, dtype=None, keepdims=False):
"""sum(x, axis=None, dtype=None, keepdims=False)

Calculates the sum of the input array `x`.
Calculates the sum of elements in the input array `x`.

Args:
x (usm_ndarray):
input array.
axis (Optional[int, Tuple[int,...]]):
axis (Optional[int, Tuple[int, ...]]):
axis or axes along which sums must be computed. If a tuple
of unique integers, sums are computed over multiple axes.
If `None`, the sum is computed over the entire array.
Expand Down Expand Up @@ -202,6 +202,67 @@ def sum(x, axis=None, dtype=None, keepdims=False):
)


def prod(x, axis=None, dtype=None, keepdims=False):
"""prod(x, axis=None, dtype=None, keepdims=False)

Calculates the product of elements in the input array `x`.

Args:
x (usm_ndarray):
input array.
axis (Optional[int, Tuple[int, ...]]):
axis or axes along which products must be computed. If a tuple
of unique integers, products are computed over multiple axes.
If `None`, the product is computed over the entire array.
Default: `None`.
dtype (Optional[dtype]):
data type of the returned array. If `None`, the default data
type is inferred from the "kind" of the input array data type.
* If `x` has a real-valued floating-point data type,
the returned array will have the default real-valued
floating-point data type for the device where input
array `x` is allocated.
* If x` has signed integral data type, the returned array
will have the default signed integral type for the device
where input array `x` is allocated.
* If `x` has unsigned integral data type, the returned array
will have the default unsigned integral type for the device
where input array `x` is allocated.
* If `x` has a complex-valued floating-point data typee,
the returned array will have the default complex-valued
floating-pointer data type for the device where input
array `x` is allocated.
* If `x` has a boolean data type, the returned array will
have the default signed integral type for the device
where input array `x` is allocated.
If the data type (either specified or resolved) differs from the
data type of `x`, the input array elements are cast to the
specified data type before computing the product. Default: `None`.
keepdims (Optional[bool]):
if `True`, the reduced axes (dimensions) are included in the result
as singleton dimensions, so that the returned array remains
compatible with the input arrays according to Array Broadcasting
rules. Otherwise, if `False`, the reduced axes are not included in
the returned array. Default: `False`.
Returns:
usm_ndarray:
an array containing the products. If the product was computed over
the entire array, a zero-dimensional array is returned. The returned
array has the data type as described in the `dtype` parameter
description above.
"""
return _reduction_over_axis(
x,
axis,
dtype,
keepdims,
ti._prod_over_axis,
ti._prod_over_axis_dtype_supported,
_default_reduction_dtype,
_identity=1,
)


def _comparison_over_axis(x, axis, keepdims, _reduction_fn):
if not isinstance(x, dpt.usm_ndarray):
raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")
Expand Down Expand Up @@ -253,7 +314,7 @@ def max(x, axis=None, keepdims=False):
Args:
x (usm_ndarray):
input array.
axis (Optional[int, Tuple[int,...]]):
axis (Optional[int, Tuple[int, ...]]):
axis or axes along which maxima must be computed. If a tuple
of unique integers, the maxima are computed over multiple axes.
If `None`, the max is computed over the entire array.
Expand Down Expand Up @@ -281,7 +342,7 @@ def min(x, axis=None, keepdims=False):
Args:
x (usm_ndarray):
input array.
axis (Optional[int, Tuple[int,...]]):
axis (Optional[int, Tuple[int, ...]]):
axis or axes along which minima must be computed. If a tuple
of unique integers, the minima are computed over multiple axes.
If `None`, the min is computed over the entire array.
Expand Down
Loading