@@ -1831,6 +1831,250 @@ struct SumOverAxis0AtomicContigFactory
1831
1831
}
1832
1832
};
1833
1833
1834
+ // Product
1835
+
1836
+ /* @brief Types supported by plus-reduction code based on atomic_ref */
1837
+ template <typename argTy, typename outTy>
1838
+ struct TypePairSupportDataForProductReductionAtomic
1839
+ {
1840
+
1841
+ /* value if true a kernel for <argTy, outTy> must be instantiated, false
1842
+ * otherwise */
1843
+ static constexpr bool is_defined = std::disjunction< // disjunction is C++17
1844
+ // feature, supported
1845
+ // by DPC++ input bool
1846
+ td_ns::TypePairDefinedEntry<argTy, bool , outTy, std::int32_t >,
1847
+ td_ns::TypePairDefinedEntry<argTy, bool , outTy, std::uint32_t >,
1848
+ td_ns::TypePairDefinedEntry<argTy, bool , outTy, std::int64_t >,
1849
+ td_ns::TypePairDefinedEntry<argTy, bool , outTy, std::uint64_t >,
1850
+ td_ns::TypePairDefinedEntry<argTy, bool , outTy, float >,
1851
+ td_ns::TypePairDefinedEntry<argTy, bool , outTy, double >,
1852
+ // input int8
1853
+ td_ns::TypePairDefinedEntry<argTy, std::int8_t , outTy, std::int32_t >,
1854
+ td_ns::TypePairDefinedEntry<argTy, std::int8_t , outTy, std::int64_t >,
1855
+ td_ns::TypePairDefinedEntry<argTy, std::int8_t , outTy, float >,
1856
+ td_ns::TypePairDefinedEntry<argTy, std::int8_t , outTy, double >,
1857
+ // input uint8
1858
+ td_ns::TypePairDefinedEntry<argTy, std::uint8_t , outTy, std::int32_t >,
1859
+ td_ns::TypePairDefinedEntry<argTy, std::uint8_t , outTy, std::uint32_t >,
1860
+ td_ns::TypePairDefinedEntry<argTy, std::uint8_t , outTy, std::int64_t >,
1861
+ td_ns::TypePairDefinedEntry<argTy, std::uint8_t , outTy, std::uint64_t >,
1862
+ td_ns::TypePairDefinedEntry<argTy, std::uint8_t , outTy, float >,
1863
+ td_ns::TypePairDefinedEntry<argTy, std::uint8_t , outTy, double >,
1864
+ // input int16
1865
+ td_ns::TypePairDefinedEntry<argTy, std::int16_t , outTy, std::int32_t >,
1866
+ td_ns::TypePairDefinedEntry<argTy, std::int16_t , outTy, std::int64_t >,
1867
+ td_ns::TypePairDefinedEntry<argTy, std::int16_t , outTy, float >,
1868
+ td_ns::TypePairDefinedEntry<argTy, std::int16_t , outTy, double >,
1869
+ // input uint16
1870
+ td_ns::TypePairDefinedEntry<argTy, std::uint16_t , outTy, std::int32_t >,
1871
+ td_ns::TypePairDefinedEntry<argTy, std::uint16_t , outTy, std::uint32_t >,
1872
+ td_ns::TypePairDefinedEntry<argTy, std::uint16_t , outTy, std::int64_t >,
1873
+ td_ns::TypePairDefinedEntry<argTy, std::uint16_t , outTy, std::uint64_t >,
1874
+ td_ns::TypePairDefinedEntry<argTy, std::uint16_t , outTy, float >,
1875
+ td_ns::TypePairDefinedEntry<argTy, std::uint16_t , outTy, double >,
1876
+ // input int32
1877
+ td_ns::TypePairDefinedEntry<argTy, std::int32_t , outTy, std::int32_t >,
1878
+ td_ns::TypePairDefinedEntry<argTy, std::int32_t , outTy, std::int64_t >,
1879
+ td_ns::TypePairDefinedEntry<argTy, std::int32_t , outTy, float >,
1880
+ td_ns::TypePairDefinedEntry<argTy, std::int32_t , outTy, double >,
1881
+ // input uint32
1882
+ td_ns::TypePairDefinedEntry<argTy, std::uint32_t , outTy, std::uint32_t >,
1883
+ td_ns::TypePairDefinedEntry<argTy, std::uint32_t , outTy, std::int64_t >,
1884
+ td_ns::TypePairDefinedEntry<argTy, std::uint32_t , outTy, std::uint64_t >,
1885
+ td_ns::TypePairDefinedEntry<argTy, std::uint32_t , outTy, float >,
1886
+ td_ns::TypePairDefinedEntry<argTy, std::uint32_t , outTy, double >,
1887
+ // input int64
1888
+ td_ns::TypePairDefinedEntry<argTy, std::int64_t , outTy, std::int64_t >,
1889
+ td_ns::TypePairDefinedEntry<argTy, std::int64_t , outTy, double >,
1890
+ // input uint64
1891
+ td_ns::TypePairDefinedEntry<argTy, std::uint64_t , outTy, std::uint64_t >,
1892
+ td_ns::TypePairDefinedEntry<argTy, std::uint64_t , outTy, double >,
1893
+ // input half
1894
+ td_ns::TypePairDefinedEntry<argTy, sycl::half, outTy, float >,
1895
+ td_ns::TypePairDefinedEntry<argTy, float , outTy, double >,
1896
+ // input float
1897
+ td_ns::TypePairDefinedEntry<argTy, float , outTy, float >,
1898
+ td_ns::TypePairDefinedEntry<argTy, float , outTy, double >,
1899
+ // input double
1900
+ td_ns::TypePairDefinedEntry<argTy, double , outTy, double >,
1901
+ // fall-through
1902
+ td_ns::NotDefinedEntry>::is_defined;
1903
+ };
1904
+
1905
+ template <typename argTy, typename outTy>
1906
+ struct TypePairSupportDataForProductReductionTemps
1907
+ {
1908
+
1909
+ static constexpr bool is_defined = std::disjunction< // disjunction is C++17
1910
+ // feature, supported
1911
+ // by DPC++ input bool
1912
+ td_ns::TypePairDefinedEntry<argTy, bool , outTy, std::int8_t >,
1913
+ td_ns::TypePairDefinedEntry<argTy, bool , outTy, std::uint8_t >,
1914
+ td_ns::TypePairDefinedEntry<argTy, bool , outTy, std::int16_t >,
1915
+ td_ns::TypePairDefinedEntry<argTy, bool , outTy, std::uint16_t >,
1916
+ td_ns::TypePairDefinedEntry<argTy, bool , outTy, std::int32_t >,
1917
+ td_ns::TypePairDefinedEntry<argTy, bool , outTy, std::uint32_t >,
1918
+ td_ns::TypePairDefinedEntry<argTy, bool , outTy, std::int64_t >,
1919
+ td_ns::TypePairDefinedEntry<argTy, bool , outTy, std::uint64_t >,
1920
+
1921
+ // input int8_t
1922
+ td_ns::TypePairDefinedEntry<argTy, std::int8_t , outTy, std::int8_t >,
1923
+ td_ns::TypePairDefinedEntry<argTy, std::int8_t , outTy, std::int16_t >,
1924
+ td_ns::TypePairDefinedEntry<argTy, std::int8_t , outTy, std::int32_t >,
1925
+ td_ns::TypePairDefinedEntry<argTy, std::int8_t , outTy, std::int64_t >,
1926
+
1927
+ // input uint8_t
1928
+ td_ns::TypePairDefinedEntry<argTy, std::uint8_t , outTy, std::uint8_t >,
1929
+ td_ns::TypePairDefinedEntry<argTy, std::uint8_t , outTy, std::int16_t >,
1930
+ td_ns::TypePairDefinedEntry<argTy, std::uint8_t , outTy, std::uint16_t >,
1931
+ td_ns::TypePairDefinedEntry<argTy, std::uint8_t , outTy, std::int32_t >,
1932
+ td_ns::TypePairDefinedEntry<argTy, std::uint8_t , outTy, std::uint32_t >,
1933
+ td_ns::TypePairDefinedEntry<argTy, std::uint8_t , outTy, std::int64_t >,
1934
+ td_ns::TypePairDefinedEntry<argTy, std::uint8_t , outTy, std::uint64_t >,
1935
+
1936
+ // input int16_t
1937
+ td_ns::TypePairDefinedEntry<argTy, std::int16_t , outTy, std::int16_t >,
1938
+ td_ns::TypePairDefinedEntry<argTy, std::int16_t , outTy, std::int32_t >,
1939
+ td_ns::TypePairDefinedEntry<argTy, std::int16_t , outTy, std::int64_t >,
1940
+
1941
+ // input uint16_t
1942
+ td_ns::TypePairDefinedEntry<argTy, std::uint16_t , outTy, std::uint16_t >,
1943
+ td_ns::TypePairDefinedEntry<argTy, std::uint16_t , outTy, std::int32_t >,
1944
+ td_ns::TypePairDefinedEntry<argTy, std::uint16_t , outTy, std::uint32_t >,
1945
+ td_ns::TypePairDefinedEntry<argTy, std::uint16_t , outTy, std::int64_t >,
1946
+ td_ns::TypePairDefinedEntry<argTy, std::uint16_t , outTy, std::uint64_t >,
1947
+
1948
+ // input int32_t
1949
+ td_ns::TypePairDefinedEntry<argTy, std::int32_t , outTy, std::int32_t >,
1950
+ td_ns::TypePairDefinedEntry<argTy, std::int32_t , outTy, std::int64_t >,
1951
+
1952
+ // input uint32_t
1953
+ td_ns::TypePairDefinedEntry<argTy, std::uint32_t , outTy, std::uint32_t >,
1954
+ td_ns::TypePairDefinedEntry<argTy, std::uint32_t , outTy, std::uint64_t >,
1955
+
1956
+ // input int64_t
1957
+ td_ns::TypePairDefinedEntry<argTy, std::int64_t , outTy, std::int64_t >,
1958
+
1959
+ // input uint32_t
1960
+ td_ns::TypePairDefinedEntry<argTy, std::uint64_t , outTy, std::uint64_t >,
1961
+
1962
+ // input half
1963
+ td_ns::TypePairDefinedEntry<argTy, sycl::half, outTy, sycl::half>,
1964
+ td_ns::TypePairDefinedEntry<argTy, sycl::half, outTy, float >,
1965
+ td_ns::TypePairDefinedEntry<argTy, sycl::half, outTy, double >,
1966
+ td_ns::
1967
+ TypePairDefinedEntry<argTy, sycl::half, outTy, std::complex<float >>,
1968
+ td_ns::TypePairDefinedEntry<argTy,
1969
+ sycl::half,
1970
+ outTy,
1971
+ std::complex<double >>,
1972
+
1973
+ // input float
1974
+ td_ns::TypePairDefinedEntry<argTy, float , outTy, float >,
1975
+ td_ns::TypePairDefinedEntry<argTy, float , outTy, double >,
1976
+ td_ns::TypePairDefinedEntry<argTy, float , outTy, std::complex<float >>,
1977
+ td_ns::TypePairDefinedEntry<argTy, float , outTy, std::complex<double >>,
1978
+
1979
+ // input double
1980
+ td_ns::TypePairDefinedEntry<argTy, double , outTy, double >,
1981
+ td_ns::TypePairDefinedEntry<argTy, double , outTy, std::complex<double >>,
1982
+
1983
+ // input std::complex
1984
+ td_ns::TypePairDefinedEntry<argTy,
1985
+ std::complex<float >,
1986
+ outTy,
1987
+ std::complex<float >>,
1988
+ td_ns::TypePairDefinedEntry<argTy,
1989
+ std::complex<float >,
1990
+ outTy,
1991
+ std::complex<double >>,
1992
+
1993
+ td_ns::TypePairDefinedEntry<argTy,
1994
+ std::complex<double >,
1995
+ outTy,
1996
+ std::complex<double >>,
1997
+
1998
+ // fall-throug
1999
+ td_ns::NotDefinedEntry>::is_defined;
2000
+ };
2001
+
2002
+ template <typename fnT, typename srcTy, typename dstTy>
2003
+ struct ProductOverAxisAtomicStridedFactory
2004
+ {
2005
+ fnT get () const
2006
+ {
2007
+ if constexpr (TypePairSupportDataForProductReductionAtomic<
2008
+ srcTy, dstTy>::is_defined)
2009
+ {
2010
+ using ReductionOpT = sycl::multiplies<dstTy>;
2011
+ return dpctl::tensor::kernels::
2012
+ reduction_over_group_with_atomics_strided_impl<srcTy, dstTy,
2013
+ ReductionOpT>;
2014
+ }
2015
+ else {
2016
+ return nullptr ;
2017
+ }
2018
+ }
2019
+ };
2020
+
2021
+ template <typename fnT, typename srcTy, typename dstTy>
2022
+ struct ProductOverAxisTempsStridedFactory
2023
+ {
2024
+ fnT get () const
2025
+ {
2026
+ if constexpr (TypePairSupportDataForProductReductionTemps<
2027
+ srcTy, dstTy>::is_defined)
2028
+ {
2029
+ using ReductionOpT = sycl::multiplies<dstTy>;
2030
+ return dpctl::tensor::kernels::
2031
+ reduction_over_group_temps_strided_impl<srcTy, dstTy,
2032
+ ReductionOpT>;
2033
+ }
2034
+ else {
2035
+ return nullptr ;
2036
+ }
2037
+ }
2038
+ };
2039
+
2040
+ template <typename fnT, typename srcTy, typename dstTy>
2041
+ struct ProductOverAxis1AtomicContigFactory
2042
+ {
2043
+ fnT get () const
2044
+ {
2045
+ if constexpr (TypePairSupportDataForProductReductionAtomic<
2046
+ srcTy, dstTy>::is_defined)
2047
+ {
2048
+ using ReductionOpT = sycl::multiplies<dstTy>;
2049
+ return dpctl::tensor::kernels::
2050
+ reduction_axis1_over_group_with_atomics_contig_impl<
2051
+ srcTy, dstTy, ReductionOpT>;
2052
+ }
2053
+ else {
2054
+ return nullptr ;
2055
+ }
2056
+ }
2057
+ };
2058
+
2059
+ template <typename fnT, typename srcTy, typename dstTy>
2060
+ struct ProductOverAxis0AtomicContigFactory
2061
+ {
2062
+ fnT get () const
2063
+ {
2064
+ if constexpr (TypePairSupportDataForProductReductionAtomic<
2065
+ srcTy, dstTy>::is_defined)
2066
+ {
2067
+ using ReductionOpT = sycl::multiplies<dstTy>;
2068
+ return dpctl::tensor::kernels::
2069
+ reduction_axis0_over_group_with_atomics_contig_impl<
2070
+ srcTy, dstTy, ReductionOpT>;
2071
+ }
2072
+ else {
2073
+ return nullptr ;
2074
+ }
2075
+ }
2076
+ };
2077
+
1834
2078
// Argmax and Argmin
1835
2079
1836
2080
/* = Search reduction using reduce_over_group*/
0 commit comments