@@ -1670,25 +1670,37 @@ struct SearchReduction
1670
1670
auto inp_offset = inp_iter_offset + inp_reduction_offset;
1671
1671
1672
1672
argT val = inp_[inp_offset];
1673
- if constexpr (su_ns::IsMinimum<argT, ReductionOp>::value) {
1674
- if (val < local_red_val) {
1675
- local_red_val = val;
1676
- if constexpr (!First) {
1677
- local_idx = inds_[inp_offset];
1678
- }
1679
- else {
1680
- local_idx = static_cast <outT>(arg_reduce_gid);
1681
- }
1673
+ if (val == local_red_val) {
1674
+ if constexpr (!First) {
1675
+ local_idx = std::min (local_idx, inds_[inp_offset]);
1676
+ }
1677
+ else {
1678
+ local_idx = std::min (local_idx,
1679
+ static_cast <outT>(arg_reduce_gid));
1682
1680
}
1683
1681
}
1684
- else if constexpr (su_ns::IsMaximum<argT, ReductionOp>::value) {
1685
- if (val > local_red_val) {
1686
- local_red_val = val;
1687
- if constexpr (!First) {
1688
- local_idx = inds_[inp_offset];
1682
+ else {
1683
+ if constexpr (su_ns::IsMinimum<argT, ReductionOp>::value) {
1684
+ if (val < local_red_val) {
1685
+ local_red_val = val;
1686
+ if constexpr (!First) {
1687
+ local_idx = inds_[inp_offset];
1688
+ }
1689
+ else {
1690
+ local_idx = static_cast <outT>(arg_reduce_gid);
1691
+ }
1689
1692
}
1690
- else {
1691
- local_idx = static_cast <outT>(arg_reduce_gid);
1693
+ }
1694
+ else if constexpr (su_ns::IsMaximum<argT,
1695
+ ReductionOp>::value) {
1696
+ if (val > local_red_val) {
1697
+ local_red_val = val;
1698
+ if constexpr (!First) {
1699
+ local_idx = inds_[inp_offset];
1700
+ }
1701
+ else {
1702
+ local_idx = static_cast <outT>(arg_reduce_gid);
1703
+ }
1692
1704
}
1693
1705
}
1694
1706
}
@@ -1813,83 +1825,102 @@ struct CustomSearchReduction
1813
1825
auto inp_offset = inp_iter_offset + inp_reduction_offset;
1814
1826
1815
1827
argT val = inp_[inp_offset];
1816
- if constexpr (su_ns::IsMinimum<argT, ReductionOp>::value) {
1817
- using dpctl::tensor::type_utils::is_complex;
1818
- if constexpr (is_complex<argT>::value) {
1819
- using dpctl::tensor::math_utils::less_complex;
1820
- // less_complex always returns false for NaNs, so check
1821
- if (less_complex<argT>(val, local_red_val) ||
1822
- std::isnan (std::real (val)) ||
1823
- std::isnan (std::imag (val)))
1824
- {
1825
- local_red_val = val;
1826
- if constexpr (!First) {
1827
- local_idx = inds_[inp_offset];
1828
- }
1829
- else {
1830
- local_idx = static_cast <outT>(arg_reduce_gid);
1831
- }
1832
- }
1833
- }
1834
- else if constexpr (std::is_floating_point_v<argT>) {
1835
- if (val < local_red_val || std::isnan (val)) {
1836
- local_red_val = val;
1837
- if constexpr (!First) {
1838
- local_idx = inds_[inp_offset];
1839
- }
1840
- else {
1841
- local_idx = static_cast <outT>(arg_reduce_gid);
1842
- }
1843
- }
1828
+ if (val == local_red_val) {
1829
+ if constexpr (!First) {
1830
+ local_idx = std::min (local_idx, inds_[inp_offset]);
1844
1831
}
1845
1832
else {
1846
- if (val < local_red_val) {
1847
- local_red_val = val;
1848
- if constexpr (!First) {
1849
- local_idx = inds_[inp_offset];
1850
- }
1851
- else {
1852
- local_idx = static_cast <outT>(arg_reduce_gid);
1853
- }
1854
- }
1833
+ local_idx = std::min (local_idx,
1834
+ static_cast <outT>(arg_reduce_gid));
1855
1835
}
1856
1836
}
1857
- else if constexpr (su_ns::IsMaximum<argT, ReductionOp>::value) {
1858
- using dpctl::tensor::type_utils::is_complex;
1859
- if constexpr (is_complex<argT>::value) {
1860
- using dpctl::tensor::math_utils::greater_complex;
1861
- if (greater_complex<argT>(val, local_red_val) ||
1862
- std::isnan (std::real (val)) ||
1863
- std::isnan (std::imag (val)))
1864
- {
1865
- local_red_val = val;
1866
- if constexpr (!First) {
1867
- local_idx = inds_[inp_offset];
1868
- }
1869
- else {
1870
- local_idx = static_cast <outT>(arg_reduce_gid);
1837
+ else {
1838
+ if constexpr (su_ns::IsMinimum<argT, ReductionOp>::value) {
1839
+ using dpctl::tensor::type_utils::is_complex;
1840
+ if constexpr (is_complex<argT>::value) {
1841
+ using dpctl::tensor::math_utils::less_complex;
1842
+ // less_complex always returns false for NaNs, so
1843
+ // check
1844
+ if (less_complex<argT>(val, local_red_val) ||
1845
+ std::isnan (std::real (val)) ||
1846
+ std::isnan (std::imag (val)))
1847
+ {
1848
+ local_red_val = val;
1849
+ if constexpr (!First) {
1850
+ local_idx = inds_[inp_offset];
1851
+ }
1852
+ else {
1853
+ local_idx =
1854
+ static_cast <outT>(arg_reduce_gid);
1855
+ }
1871
1856
}
1872
1857
}
1873
- }
1874
- else if constexpr (std::is_floating_point_v<argT>) {
1875
- if (val > local_red_val || std::isnan (val)) {
1876
- local_red_val = val;
1877
- if constexpr (!First) {
1878
- local_idx = inds_[inp_offset];
1858
+ else if constexpr (std::is_floating_point_v<argT>) {
1859
+ if (val < local_red_val || std::isnan (val)) {
1860
+ local_red_val = val;
1861
+ if constexpr (!First) {
1862
+ local_idx = inds_[inp_offset];
1863
+ }
1864
+ else {
1865
+ local_idx =
1866
+ static_cast <outT>(arg_reduce_gid);
1867
+ }
1879
1868
}
1880
- else {
1881
- local_idx = static_cast <outT>(arg_reduce_gid);
1869
+ }
1870
+ else {
1871
+ if (val < local_red_val) {
1872
+ local_red_val = val;
1873
+ if constexpr (!First) {
1874
+ local_idx = inds_[inp_offset];
1875
+ }
1876
+ else {
1877
+ local_idx =
1878
+ static_cast <outT>(arg_reduce_gid);
1879
+ }
1882
1880
}
1883
1881
}
1884
1882
}
1885
- else {
1886
- if (val > local_red_val) {
1887
- local_red_val = val;
1888
- if constexpr (!First) {
1889
- local_idx = inds_[inp_offset];
1883
+ else if constexpr (su_ns::IsMaximum<argT,
1884
+ ReductionOp>::value) {
1885
+ using dpctl::tensor::type_utils::is_complex;
1886
+ if constexpr (is_complex<argT>::value) {
1887
+ using dpctl::tensor::math_utils::greater_complex;
1888
+ if (greater_complex<argT>(val, local_red_val) ||
1889
+ std::isnan (std::real (val)) ||
1890
+ std::isnan (std::imag (val)))
1891
+ {
1892
+ local_red_val = val;
1893
+ if constexpr (!First) {
1894
+ local_idx = inds_[inp_offset];
1895
+ }
1896
+ else {
1897
+ local_idx =
1898
+ static_cast <outT>(arg_reduce_gid);
1899
+ }
1890
1900
}
1891
- else {
1892
- local_idx = static_cast <outT>(arg_reduce_gid);
1901
+ }
1902
+ else if constexpr (std::is_floating_point_v<argT>) {
1903
+ if (val > local_red_val || std::isnan (val)) {
1904
+ local_red_val = val;
1905
+ if constexpr (!First) {
1906
+ local_idx = inds_[inp_offset];
1907
+ }
1908
+ else {
1909
+ local_idx =
1910
+ static_cast <outT>(arg_reduce_gid);
1911
+ }
1912
+ }
1913
+ }
1914
+ else {
1915
+ if (val > local_red_val) {
1916
+ local_red_val = val;
1917
+ if constexpr (!First) {
1918
+ local_idx = inds_[inp_offset];
1919
+ }
1920
+ else {
1921
+ local_idx =
1922
+ static_cast <outT>(arg_reduce_gid);
1923
+ }
1893
1924
}
1894
1925
}
1895
1926
}
@@ -2042,7 +2073,7 @@ sycl::event search_reduction_over_group_temps_strided_impl(
2042
2073
sycl::range<1 >{iter_nelems * reduction_groups * wg};
2043
2074
auto localRange = sycl::range<1 >{wg};
2044
2075
2045
- if constexpr (su_ns::IsSyclOp<resTy , ReductionOpT>::value) {
2076
+ if constexpr (su_ns::IsSyclOp<argTy , ReductionOpT>::value) {
2046
2077
using KernelName = class search_reduction_over_group_temps_krn <
2047
2078
argTy, resTy, ReductionOpT, IndexOpT,
2048
2079
InputOutputIterIndexerT, ReductionIndexerT, true , true >;
@@ -2141,7 +2172,7 @@ sycl::event search_reduction_over_group_temps_strided_impl(
2141
2172
sycl::range<1 >{iter_nelems * reduction_groups * wg};
2142
2173
auto localRange = sycl::range<1 >{wg};
2143
2174
2144
- if constexpr (su_ns::IsSyclOp<resTy , ReductionOpT>::value) {
2175
+ if constexpr (su_ns::IsSyclOp<argTy , ReductionOpT>::value) {
2145
2176
using KernelName = class search_reduction_over_group_temps_krn <
2146
2177
argTy, resTy, ReductionOpT, IndexOpT,
2147
2178
InputOutputIterIndexerT, ReductionIndexerT, true , false >;
@@ -2221,7 +2252,7 @@ sycl::event search_reduction_over_group_temps_strided_impl(
2221
2252
auto globalRange =
2222
2253
sycl::range<1 >{iter_nelems * reduction_groups_ * wg};
2223
2254
auto localRange = sycl::range<1 >{wg};
2224
- if constexpr (su_ns::IsSyclOp<resTy , ReductionOpT>::value) {
2255
+ if constexpr (su_ns::IsSyclOp<argTy , ReductionOpT>::value) {
2225
2256
using KernelName =
2226
2257
class search_reduction_over_group_temps_krn <
2227
2258
argTy, resTy, ReductionOpT, IndexOpT,
@@ -2304,7 +2335,7 @@ sycl::event search_reduction_over_group_temps_strided_impl(
2304
2335
sycl::range<1 >{iter_nelems * reduction_groups * wg};
2305
2336
auto localRange = sycl::range<1 >{wg};
2306
2337
2307
- if constexpr (su_ns::IsSyclOp<resTy , ReductionOpT>::value) {
2338
+ if constexpr (su_ns::IsSyclOp<argTy , ReductionOpT>::value) {
2308
2339
using KernelName = class search_reduction_over_group_temps_krn <
2309
2340
argTy, resTy, ReductionOpT, IndexOpT,
2310
2341
InputOutputIterIndexerT, ReductionIndexerT, false , true >;
0 commit comments