Skip to content

Commit ea4c477

Browse files
committed
Argmin and argmax now handle identities correctly
Adds a test for this behavior Fixed a typo in argmin and argmax causing shared local memory variant to be used for more types than expected
1 parent 74d4eb3 commit ea4c477

File tree

2 files changed

+127
-86
lines changed

2 files changed

+127
-86
lines changed

dpctl/tensor/libtensor/include/kernels/reductions.hpp

Lines changed: 117 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -1670,25 +1670,37 @@ struct SearchReduction
16701670
auto inp_offset = inp_iter_offset + inp_reduction_offset;
16711671

16721672
argT val = inp_[inp_offset];
1673-
if constexpr (su_ns::IsMinimum<argT, ReductionOp>::value) {
1674-
if (val < local_red_val) {
1675-
local_red_val = val;
1676-
if constexpr (!First) {
1677-
local_idx = inds_[inp_offset];
1678-
}
1679-
else {
1680-
local_idx = static_cast<outT>(arg_reduce_gid);
1681-
}
1673+
if (val == local_red_val) {
1674+
if constexpr (!First) {
1675+
local_idx = std::min(local_idx, inds_[inp_offset]);
1676+
}
1677+
else {
1678+
local_idx = std::min(local_idx,
1679+
static_cast<outT>(arg_reduce_gid));
16821680
}
16831681
}
1684-
else if constexpr (su_ns::IsMaximum<argT, ReductionOp>::value) {
1685-
if (val > local_red_val) {
1686-
local_red_val = val;
1687-
if constexpr (!First) {
1688-
local_idx = inds_[inp_offset];
1682+
else {
1683+
if constexpr (su_ns::IsMinimum<argT, ReductionOp>::value) {
1684+
if (val < local_red_val) {
1685+
local_red_val = val;
1686+
if constexpr (!First) {
1687+
local_idx = inds_[inp_offset];
1688+
}
1689+
else {
1690+
local_idx = static_cast<outT>(arg_reduce_gid);
1691+
}
16891692
}
1690-
else {
1691-
local_idx = static_cast<outT>(arg_reduce_gid);
1693+
}
1694+
else if constexpr (su_ns::IsMaximum<argT,
1695+
ReductionOp>::value) {
1696+
if (val > local_red_val) {
1697+
local_red_val = val;
1698+
if constexpr (!First) {
1699+
local_idx = inds_[inp_offset];
1700+
}
1701+
else {
1702+
local_idx = static_cast<outT>(arg_reduce_gid);
1703+
}
16921704
}
16931705
}
16941706
}
@@ -1813,83 +1825,102 @@ struct CustomSearchReduction
18131825
auto inp_offset = inp_iter_offset + inp_reduction_offset;
18141826

18151827
argT val = inp_[inp_offset];
1816-
if constexpr (su_ns::IsMinimum<argT, ReductionOp>::value) {
1817-
using dpctl::tensor::type_utils::is_complex;
1818-
if constexpr (is_complex<argT>::value) {
1819-
using dpctl::tensor::math_utils::less_complex;
1820-
// less_complex always returns false for NaNs, so check
1821-
if (less_complex<argT>(val, local_red_val) ||
1822-
std::isnan(std::real(val)) ||
1823-
std::isnan(std::imag(val)))
1824-
{
1825-
local_red_val = val;
1826-
if constexpr (!First) {
1827-
local_idx = inds_[inp_offset];
1828-
}
1829-
else {
1830-
local_idx = static_cast<outT>(arg_reduce_gid);
1831-
}
1832-
}
1833-
}
1834-
else if constexpr (std::is_floating_point_v<argT>) {
1835-
if (val < local_red_val || std::isnan(val)) {
1836-
local_red_val = val;
1837-
if constexpr (!First) {
1838-
local_idx = inds_[inp_offset];
1839-
}
1840-
else {
1841-
local_idx = static_cast<outT>(arg_reduce_gid);
1842-
}
1843-
}
1828+
if (val == local_red_val) {
1829+
if constexpr (!First) {
1830+
local_idx = std::min(local_idx, inds_[inp_offset]);
18441831
}
18451832
else {
1846-
if (val < local_red_val) {
1847-
local_red_val = val;
1848-
if constexpr (!First) {
1849-
local_idx = inds_[inp_offset];
1850-
}
1851-
else {
1852-
local_idx = static_cast<outT>(arg_reduce_gid);
1853-
}
1854-
}
1833+
local_idx = std::min(local_idx,
1834+
static_cast<outT>(arg_reduce_gid));
18551835
}
18561836
}
1857-
else if constexpr (su_ns::IsMaximum<argT, ReductionOp>::value) {
1858-
using dpctl::tensor::type_utils::is_complex;
1859-
if constexpr (is_complex<argT>::value) {
1860-
using dpctl::tensor::math_utils::greater_complex;
1861-
if (greater_complex<argT>(val, local_red_val) ||
1862-
std::isnan(std::real(val)) ||
1863-
std::isnan(std::imag(val)))
1864-
{
1865-
local_red_val = val;
1866-
if constexpr (!First) {
1867-
local_idx = inds_[inp_offset];
1868-
}
1869-
else {
1870-
local_idx = static_cast<outT>(arg_reduce_gid);
1837+
else {
1838+
if constexpr (su_ns::IsMinimum<argT, ReductionOp>::value) {
1839+
using dpctl::tensor::type_utils::is_complex;
1840+
if constexpr (is_complex<argT>::value) {
1841+
using dpctl::tensor::math_utils::less_complex;
1842+
// less_complex always returns false for NaNs, so
1843+
// check
1844+
if (less_complex<argT>(val, local_red_val) ||
1845+
std::isnan(std::real(val)) ||
1846+
std::isnan(std::imag(val)))
1847+
{
1848+
local_red_val = val;
1849+
if constexpr (!First) {
1850+
local_idx = inds_[inp_offset];
1851+
}
1852+
else {
1853+
local_idx =
1854+
static_cast<outT>(arg_reduce_gid);
1855+
}
18711856
}
18721857
}
1873-
}
1874-
else if constexpr (std::is_floating_point_v<argT>) {
1875-
if (val > local_red_val || std::isnan(val)) {
1876-
local_red_val = val;
1877-
if constexpr (!First) {
1878-
local_idx = inds_[inp_offset];
1858+
else if constexpr (std::is_floating_point_v<argT>) {
1859+
if (val < local_red_val || std::isnan(val)) {
1860+
local_red_val = val;
1861+
if constexpr (!First) {
1862+
local_idx = inds_[inp_offset];
1863+
}
1864+
else {
1865+
local_idx =
1866+
static_cast<outT>(arg_reduce_gid);
1867+
}
18791868
}
1880-
else {
1881-
local_idx = static_cast<outT>(arg_reduce_gid);
1869+
}
1870+
else {
1871+
if (val < local_red_val) {
1872+
local_red_val = val;
1873+
if constexpr (!First) {
1874+
local_idx = inds_[inp_offset];
1875+
}
1876+
else {
1877+
local_idx =
1878+
static_cast<outT>(arg_reduce_gid);
1879+
}
18821880
}
18831881
}
18841882
}
1885-
else {
1886-
if (val > local_red_val) {
1887-
local_red_val = val;
1888-
if constexpr (!First) {
1889-
local_idx = inds_[inp_offset];
1883+
else if constexpr (su_ns::IsMaximum<argT,
1884+
ReductionOp>::value) {
1885+
using dpctl::tensor::type_utils::is_complex;
1886+
if constexpr (is_complex<argT>::value) {
1887+
using dpctl::tensor::math_utils::greater_complex;
1888+
if (greater_complex<argT>(val, local_red_val) ||
1889+
std::isnan(std::real(val)) ||
1890+
std::isnan(std::imag(val)))
1891+
{
1892+
local_red_val = val;
1893+
if constexpr (!First) {
1894+
local_idx = inds_[inp_offset];
1895+
}
1896+
else {
1897+
local_idx =
1898+
static_cast<outT>(arg_reduce_gid);
1899+
}
18901900
}
1891-
else {
1892-
local_idx = static_cast<outT>(arg_reduce_gid);
1901+
}
1902+
else if constexpr (std::is_floating_point_v<argT>) {
1903+
if (val > local_red_val || std::isnan(val)) {
1904+
local_red_val = val;
1905+
if constexpr (!First) {
1906+
local_idx = inds_[inp_offset];
1907+
}
1908+
else {
1909+
local_idx =
1910+
static_cast<outT>(arg_reduce_gid);
1911+
}
1912+
}
1913+
}
1914+
else {
1915+
if (val > local_red_val) {
1916+
local_red_val = val;
1917+
if constexpr (!First) {
1918+
local_idx = inds_[inp_offset];
1919+
}
1920+
else {
1921+
local_idx =
1922+
static_cast<outT>(arg_reduce_gid);
1923+
}
18931924
}
18941925
}
18951926
}
@@ -2042,7 +2073,7 @@ sycl::event search_reduction_over_group_temps_strided_impl(
20422073
sycl::range<1>{iter_nelems * reduction_groups * wg};
20432074
auto localRange = sycl::range<1>{wg};
20442075

2045-
if constexpr (su_ns::IsSyclOp<resTy, ReductionOpT>::value) {
2076+
if constexpr (su_ns::IsSyclOp<argTy, ReductionOpT>::value) {
20462077
using KernelName = class search_reduction_over_group_temps_krn<
20472078
argTy, resTy, ReductionOpT, IndexOpT,
20482079
InputOutputIterIndexerT, ReductionIndexerT, true, true>;
@@ -2141,7 +2172,7 @@ sycl::event search_reduction_over_group_temps_strided_impl(
21412172
sycl::range<1>{iter_nelems * reduction_groups * wg};
21422173
auto localRange = sycl::range<1>{wg};
21432174

2144-
if constexpr (su_ns::IsSyclOp<resTy, ReductionOpT>::value) {
2175+
if constexpr (su_ns::IsSyclOp<argTy, ReductionOpT>::value) {
21452176
using KernelName = class search_reduction_over_group_temps_krn<
21462177
argTy, resTy, ReductionOpT, IndexOpT,
21472178
InputOutputIterIndexerT, ReductionIndexerT, true, false>;
@@ -2221,7 +2252,7 @@ sycl::event search_reduction_over_group_temps_strided_impl(
22212252
auto globalRange =
22222253
sycl::range<1>{iter_nelems * reduction_groups_ * wg};
22232254
auto localRange = sycl::range<1>{wg};
2224-
if constexpr (su_ns::IsSyclOp<resTy, ReductionOpT>::value) {
2255+
if constexpr (su_ns::IsSyclOp<argTy, ReductionOpT>::value) {
22252256
using KernelName =
22262257
class search_reduction_over_group_temps_krn<
22272258
argTy, resTy, ReductionOpT, IndexOpT,
@@ -2304,7 +2335,7 @@ sycl::event search_reduction_over_group_temps_strided_impl(
23042335
sycl::range<1>{iter_nelems * reduction_groups * wg};
23052336
auto localRange = sycl::range<1>{wg};
23062337

2307-
if constexpr (su_ns::IsSyclOp<resTy, ReductionOpT>::value) {
2338+
if constexpr (su_ns::IsSyclOp<argTy, ReductionOpT>::value) {
23082339
using KernelName = class search_reduction_over_group_temps_krn<
23092340
argTy, resTy, ReductionOpT, IndexOpT,
23102341
InputOutputIterIndexerT, ReductionIndexerT, false, true>;

dpctl/tests/test_usm_ndarray_reductions.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,3 +201,13 @@ def test_argmax_argmin_nan_propagation():
201201
x[idx] = complex(0, dpt.nan)
202202
assert dpt.argmax(x) == idx
203203
assert dpt.argmin(x) == idx
204+
205+
206+
def test_argmax_argmin_identities():
207+
# make sure that identity arrays work as expected
208+
get_queue_or_skip()
209+
210+
x = dpt.full(3, dpt.iinfo(dpt.int32).min, dtype="i4")
211+
assert dpt.argmax(x) == 0
212+
x = dpt.full(3, dpt.iinfo(dpt.int32).max, dtype="i4")
213+
assert dpt.argmin(x) == 0

0 commit comments

Comments
 (0)