Skip to content

Commit 0598416

Browse files
Adding functor factories for product over axis
1 parent 478b30c commit 0598416

File tree

1 file changed

+244
-0
lines changed

1 file changed

+244
-0
lines changed

dpctl/tensor/libtensor/include/kernels/reductions.hpp

Lines changed: 244 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1831,6 +1831,250 @@ struct SumOverAxis0AtomicContigFactory
18311831
}
18321832
};
18331833

1834+
// Product
1835+
1836+
/* @brief Types supported by plus-reduction code based on atomic_ref */
1837+
template <typename argTy, typename outTy>
1838+
struct TypePairSupportDataForProductReductionAtomic
1839+
{
1840+
1841+
/* value if true a kernel for <argTy, outTy> must be instantiated, false
1842+
* otherwise */
1843+
static constexpr bool is_defined = std::disjunction< // disjunction is C++17
1844+
// feature, supported
1845+
// by DPC++ input bool
1846+
td_ns::TypePairDefinedEntry<argTy, bool, outTy, std::int32_t>,
1847+
td_ns::TypePairDefinedEntry<argTy, bool, outTy, std::uint32_t>,
1848+
td_ns::TypePairDefinedEntry<argTy, bool, outTy, std::int64_t>,
1849+
td_ns::TypePairDefinedEntry<argTy, bool, outTy, std::uint64_t>,
1850+
td_ns::TypePairDefinedEntry<argTy, bool, outTy, float>,
1851+
td_ns::TypePairDefinedEntry<argTy, bool, outTy, double>,
1852+
// input int8
1853+
td_ns::TypePairDefinedEntry<argTy, std::int8_t, outTy, std::int32_t>,
1854+
td_ns::TypePairDefinedEntry<argTy, std::int8_t, outTy, std::int64_t>,
1855+
td_ns::TypePairDefinedEntry<argTy, std::int8_t, outTy, float>,
1856+
td_ns::TypePairDefinedEntry<argTy, std::int8_t, outTy, double>,
1857+
// input uint8
1858+
td_ns::TypePairDefinedEntry<argTy, std::uint8_t, outTy, std::int32_t>,
1859+
td_ns::TypePairDefinedEntry<argTy, std::uint8_t, outTy, std::uint32_t>,
1860+
td_ns::TypePairDefinedEntry<argTy, std::uint8_t, outTy, std::int64_t>,
1861+
td_ns::TypePairDefinedEntry<argTy, std::uint8_t, outTy, std::uint64_t>,
1862+
td_ns::TypePairDefinedEntry<argTy, std::uint8_t, outTy, float>,
1863+
td_ns::TypePairDefinedEntry<argTy, std::uint8_t, outTy, double>,
1864+
// input int16
1865+
td_ns::TypePairDefinedEntry<argTy, std::int16_t, outTy, std::int32_t>,
1866+
td_ns::TypePairDefinedEntry<argTy, std::int16_t, outTy, std::int64_t>,
1867+
td_ns::TypePairDefinedEntry<argTy, std::int16_t, outTy, float>,
1868+
td_ns::TypePairDefinedEntry<argTy, std::int16_t, outTy, double>,
1869+
// input uint16
1870+
td_ns::TypePairDefinedEntry<argTy, std::uint16_t, outTy, std::int32_t>,
1871+
td_ns::TypePairDefinedEntry<argTy, std::uint16_t, outTy, std::uint32_t>,
1872+
td_ns::TypePairDefinedEntry<argTy, std::uint16_t, outTy, std::int64_t>,
1873+
td_ns::TypePairDefinedEntry<argTy, std::uint16_t, outTy, std::uint64_t>,
1874+
td_ns::TypePairDefinedEntry<argTy, std::uint16_t, outTy, float>,
1875+
td_ns::TypePairDefinedEntry<argTy, std::uint16_t, outTy, double>,
1876+
// input int32
1877+
td_ns::TypePairDefinedEntry<argTy, std::int32_t, outTy, std::int32_t>,
1878+
td_ns::TypePairDefinedEntry<argTy, std::int32_t, outTy, std::int64_t>,
1879+
td_ns::TypePairDefinedEntry<argTy, std::int32_t, outTy, float>,
1880+
td_ns::TypePairDefinedEntry<argTy, std::int32_t, outTy, double>,
1881+
// input uint32
1882+
td_ns::TypePairDefinedEntry<argTy, std::uint32_t, outTy, std::uint32_t>,
1883+
td_ns::TypePairDefinedEntry<argTy, std::uint32_t, outTy, std::int64_t>,
1884+
td_ns::TypePairDefinedEntry<argTy, std::uint32_t, outTy, std::uint64_t>,
1885+
td_ns::TypePairDefinedEntry<argTy, std::uint32_t, outTy, float>,
1886+
td_ns::TypePairDefinedEntry<argTy, std::uint32_t, outTy, double>,
1887+
// input int64
1888+
td_ns::TypePairDefinedEntry<argTy, std::int64_t, outTy, std::int64_t>,
1889+
td_ns::TypePairDefinedEntry<argTy, std::int64_t, outTy, double>,
1890+
// input uint64
1891+
td_ns::TypePairDefinedEntry<argTy, std::uint64_t, outTy, std::uint64_t>,
1892+
td_ns::TypePairDefinedEntry<argTy, std::uint64_t, outTy, double>,
1893+
// input half
1894+
td_ns::TypePairDefinedEntry<argTy, sycl::half, outTy, float>,
1895+
td_ns::TypePairDefinedEntry<argTy, float, outTy, double>,
1896+
// input float
1897+
td_ns::TypePairDefinedEntry<argTy, float, outTy, float>,
1898+
td_ns::TypePairDefinedEntry<argTy, float, outTy, double>,
1899+
// input double
1900+
td_ns::TypePairDefinedEntry<argTy, double, outTy, double>,
1901+
// fall-through
1902+
td_ns::NotDefinedEntry>::is_defined;
1903+
};
1904+
1905+
template <typename argTy, typename outTy>
1906+
struct TypePairSupportDataForProductReductionTemps
1907+
{
1908+
1909+
static constexpr bool is_defined = std::disjunction< // disjunction is C++17
1910+
// feature, supported
1911+
// by DPC++ input bool
1912+
td_ns::TypePairDefinedEntry<argTy, bool, outTy, std::int8_t>,
1913+
td_ns::TypePairDefinedEntry<argTy, bool, outTy, std::uint8_t>,
1914+
td_ns::TypePairDefinedEntry<argTy, bool, outTy, std::int16_t>,
1915+
td_ns::TypePairDefinedEntry<argTy, bool, outTy, std::uint16_t>,
1916+
td_ns::TypePairDefinedEntry<argTy, bool, outTy, std::int32_t>,
1917+
td_ns::TypePairDefinedEntry<argTy, bool, outTy, std::uint32_t>,
1918+
td_ns::TypePairDefinedEntry<argTy, bool, outTy, std::int64_t>,
1919+
td_ns::TypePairDefinedEntry<argTy, bool, outTy, std::uint64_t>,
1920+
1921+
// input int8_t
1922+
td_ns::TypePairDefinedEntry<argTy, std::int8_t, outTy, std::int8_t>,
1923+
td_ns::TypePairDefinedEntry<argTy, std::int8_t, outTy, std::int16_t>,
1924+
td_ns::TypePairDefinedEntry<argTy, std::int8_t, outTy, std::int32_t>,
1925+
td_ns::TypePairDefinedEntry<argTy, std::int8_t, outTy, std::int64_t>,
1926+
1927+
// input uint8_t
1928+
td_ns::TypePairDefinedEntry<argTy, std::uint8_t, outTy, std::uint8_t>,
1929+
td_ns::TypePairDefinedEntry<argTy, std::uint8_t, outTy, std::int16_t>,
1930+
td_ns::TypePairDefinedEntry<argTy, std::uint8_t, outTy, std::uint16_t>,
1931+
td_ns::TypePairDefinedEntry<argTy, std::uint8_t, outTy, std::int32_t>,
1932+
td_ns::TypePairDefinedEntry<argTy, std::uint8_t, outTy, std::uint32_t>,
1933+
td_ns::TypePairDefinedEntry<argTy, std::uint8_t, outTy, std::int64_t>,
1934+
td_ns::TypePairDefinedEntry<argTy, std::uint8_t, outTy, std::uint64_t>,
1935+
1936+
// input int16_t
1937+
td_ns::TypePairDefinedEntry<argTy, std::int16_t, outTy, std::int16_t>,
1938+
td_ns::TypePairDefinedEntry<argTy, std::int16_t, outTy, std::int32_t>,
1939+
td_ns::TypePairDefinedEntry<argTy, std::int16_t, outTy, std::int64_t>,
1940+
1941+
// input uint16_t
1942+
td_ns::TypePairDefinedEntry<argTy, std::uint16_t, outTy, std::uint16_t>,
1943+
td_ns::TypePairDefinedEntry<argTy, std::uint16_t, outTy, std::int32_t>,
1944+
td_ns::TypePairDefinedEntry<argTy, std::uint16_t, outTy, std::uint32_t>,
1945+
td_ns::TypePairDefinedEntry<argTy, std::uint16_t, outTy, std::int64_t>,
1946+
td_ns::TypePairDefinedEntry<argTy, std::uint16_t, outTy, std::uint64_t>,
1947+
1948+
// input int32_t
1949+
td_ns::TypePairDefinedEntry<argTy, std::int32_t, outTy, std::int32_t>,
1950+
td_ns::TypePairDefinedEntry<argTy, std::int32_t, outTy, std::int64_t>,
1951+
1952+
// input uint32_t
1953+
td_ns::TypePairDefinedEntry<argTy, std::uint32_t, outTy, std::uint32_t>,
1954+
td_ns::TypePairDefinedEntry<argTy, std::uint32_t, outTy, std::uint64_t>,
1955+
1956+
// input int64_t
1957+
td_ns::TypePairDefinedEntry<argTy, std::int64_t, outTy, std::int64_t>,
1958+
1959+
// input uint32_t
1960+
td_ns::TypePairDefinedEntry<argTy, std::uint64_t, outTy, std::uint64_t>,
1961+
1962+
// input half
1963+
td_ns::TypePairDefinedEntry<argTy, sycl::half, outTy, sycl::half>,
1964+
td_ns::TypePairDefinedEntry<argTy, sycl::half, outTy, float>,
1965+
td_ns::TypePairDefinedEntry<argTy, sycl::half, outTy, double>,
1966+
td_ns::
1967+
TypePairDefinedEntry<argTy, sycl::half, outTy, std::complex<float>>,
1968+
td_ns::TypePairDefinedEntry<argTy,
1969+
sycl::half,
1970+
outTy,
1971+
std::complex<double>>,
1972+
1973+
// input float
1974+
td_ns::TypePairDefinedEntry<argTy, float, outTy, float>,
1975+
td_ns::TypePairDefinedEntry<argTy, float, outTy, double>,
1976+
td_ns::TypePairDefinedEntry<argTy, float, outTy, std::complex<float>>,
1977+
td_ns::TypePairDefinedEntry<argTy, float, outTy, std::complex<double>>,
1978+
1979+
// input double
1980+
td_ns::TypePairDefinedEntry<argTy, double, outTy, double>,
1981+
td_ns::TypePairDefinedEntry<argTy, double, outTy, std::complex<double>>,
1982+
1983+
// input std::complex
1984+
td_ns::TypePairDefinedEntry<argTy,
1985+
std::complex<float>,
1986+
outTy,
1987+
std::complex<float>>,
1988+
td_ns::TypePairDefinedEntry<argTy,
1989+
std::complex<float>,
1990+
outTy,
1991+
std::complex<double>>,
1992+
1993+
td_ns::TypePairDefinedEntry<argTy,
1994+
std::complex<double>,
1995+
outTy,
1996+
std::complex<double>>,
1997+
1998+
// fall-throug
1999+
td_ns::NotDefinedEntry>::is_defined;
2000+
};
2001+
2002+
template <typename fnT, typename srcTy, typename dstTy>
2003+
struct ProductOverAxisAtomicStridedFactory
2004+
{
2005+
fnT get() const
2006+
{
2007+
if constexpr (TypePairSupportDataForProductReductionAtomic<
2008+
srcTy, dstTy>::is_defined)
2009+
{
2010+
using ReductionOpT = sycl::multiplies<dstTy>;
2011+
return dpctl::tensor::kernels::
2012+
reduction_over_group_with_atomics_strided_impl<srcTy, dstTy,
2013+
ReductionOpT>;
2014+
}
2015+
else {
2016+
return nullptr;
2017+
}
2018+
}
2019+
};
2020+
2021+
template <typename fnT, typename srcTy, typename dstTy>
2022+
struct ProductOverAxisTempsStridedFactory
2023+
{
2024+
fnT get() const
2025+
{
2026+
if constexpr (TypePairSupportDataForProductReductionTemps<
2027+
srcTy, dstTy>::is_defined)
2028+
{
2029+
using ReductionOpT = sycl::multiplies<dstTy>;
2030+
return dpctl::tensor::kernels::
2031+
reduction_over_group_temps_strided_impl<srcTy, dstTy,
2032+
ReductionOpT>;
2033+
}
2034+
else {
2035+
return nullptr;
2036+
}
2037+
}
2038+
};
2039+
2040+
template <typename fnT, typename srcTy, typename dstTy>
2041+
struct ProductOverAxis1AtomicContigFactory
2042+
{
2043+
fnT get() const
2044+
{
2045+
if constexpr (TypePairSupportDataForProductReductionAtomic<
2046+
srcTy, dstTy>::is_defined)
2047+
{
2048+
using ReductionOpT = sycl::multiplies<dstTy>;
2049+
return dpctl::tensor::kernels::
2050+
reduction_axis1_over_group_with_atomics_contig_impl<
2051+
srcTy, dstTy, ReductionOpT>;
2052+
}
2053+
else {
2054+
return nullptr;
2055+
}
2056+
}
2057+
};
2058+
2059+
template <typename fnT, typename srcTy, typename dstTy>
2060+
struct ProductOverAxis0AtomicContigFactory
2061+
{
2062+
fnT get() const
2063+
{
2064+
if constexpr (TypePairSupportDataForProductReductionAtomic<
2065+
srcTy, dstTy>::is_defined)
2066+
{
2067+
using ReductionOpT = sycl::multiplies<dstTy>;
2068+
return dpctl::tensor::kernels::
2069+
reduction_axis0_over_group_with_atomics_contig_impl<
2070+
srcTy, dstTy, ReductionOpT>;
2071+
}
2072+
else {
2073+
return nullptr;
2074+
}
2075+
}
2076+
};
2077+
18342078
// Argmax and Argmin
18352079

18362080
/* = Search reduction using reduce_over_group*/

0 commit comments

Comments
 (0)