Commit 572aa5c

OuadiElfarouki, Alcpz, and joeatodd authored
[SYCL][COMPAT] Add math extend_v*2 to SYCLCompat (#13953)
This PR adds math `extend_v*2` operators _(18 in total)_ along with unit-tests for signed and unsigned `int32` cases.

Co-authored-by: Alberto Cabrera Pérez <[email protected]>
Co-authored-by: Joe Todd <[email protected]>
1 parent cad941f commit 572aa5c

3 files changed: +977 -12 lines changed

sycl/doc/syclcompat/README.md

Lines changed: 239 additions & 1 deletion
@@ -1771,7 +1771,7 @@ struct sub_sat {
17711771
} // namespace syclcompat
17721772
```
17731773

1774-
Finally, the math header provides a set of functions to extend 32-bit operations
1774+
The math header provides a set of functions to extend 32-bit operations
17751775
to 33 bit, and handle sign extension internally. There is support for `add`,
17761776
`sub`, `absdiff`, `min` and `max` operations. Each operation provides overloads
17771777
to include a second, separate, `BinaryOperation` after the first, and include
@@ -1855,6 +1855,244 @@ inline constexpr RetT extend_max_sat(AT a, BT b, CT c,
18551855
BinaryOperation second_op);
18561856
```
18571857
1858+
Another set of vectorized extend 32-bit operations is provided in the math
1859+
header. These APIs treat each of the 32-bit operands as a 2-element vector
1860+
(16 bits each) while handling sign extension to 17 bits internally. There is
1861+
support for `add`, `sub`, `absdiff`, `min`, `max` and `avg` binary operations.
1862+
Each operation has a `_sat` variant which determines if the returned
1863+
value is saturated or not, and a `_add` variant that computes the binary sum
1864+
of the initial operation outputs and a third operand.
1865+
1866+
```cpp
1867+
/// Compute vectorized addition of \p a and \p b, with each value treated as a
1868+
/// 2 elements vector type and extend each element to 17 bit.
1869+
/// \tparam [in] RetT The type of the return value, can only be 32 bit integer
1870+
/// \tparam [in] AT The type of the first value, can only be 32 bit integer
1871+
/// \tparam [in] BT The type of the second value, can only be 32 bit integer
1872+
/// \param [in] a The first value
1873+
/// \param [in] b The second value
1874+
/// \param [in] c The third value
1875+
/// \returns The extend vectorized addition of the two values
1876+
template <typename RetT, typename AT, typename BT>
1877+
inline constexpr RetT extend_vadd2(AT a, BT b, RetT c);
1878+
1879+
/// Compute vectorized addition of \p a and \p b, with each value treated as a 2
1880+
/// elements vector type and extend each element to 17 bit. Then add each half
1881+
/// of the result and add with \p c.
1882+
/// \tparam [in] RetT The type of the return value, can only be 32 bit integer
1883+
/// \tparam [in] AT The type of the first value, can only be 32 bit integer
1884+
/// \tparam [in] BT The type of the second value, can only be 32 bit integer
1885+
/// \param [in] a The first value
1886+
/// \param [in] b The second value
1887+
/// \param [in] c The third value
1888+
/// \returns The addition of each half of extend vectorized addition of the two
1889+
/// values and the third value
1890+
template <typename RetT, typename AT, typename BT>
1891+
inline constexpr RetT extend_vadd2_add(AT a, BT b, RetT c);
1892+
1893+
/// Compute vectorized addition of \p a and \p b with saturation, with each
1894+
/// value treated as a 2 elements vector type and extend each element to 17 bit.
1895+
/// \tparam [in] RetT The type of the return value, can only be 32 bit integer
1896+
/// \tparam [in] AT The type of the first value, can only be 32 bit integer
1897+
/// \tparam [in] BT The type of the second value, can only be 32 bit integer
1898+
/// \param [in] a The first value
1899+
/// \param [in] b The second value
1900+
/// \param [in] c The third value
1901+
/// \returns The extend vectorized addition of the two values with saturation
1902+
template <typename RetT, typename AT, typename BT>
1903+
inline constexpr RetT extend_vadd2_sat(AT a, BT b, RetT c);
1904+
1905+
/// Compute vectorized subtraction of \p a and \p b, with each value treated as
1906+
/// a 2 elements vector type and extend each element to 17 bit.
1907+
/// \tparam [in] RetT The type of the return value, can only be 32 bit integer
1908+
/// \tparam [in] AT The type of the first value, can only be 32 bit integer
1909+
/// \tparam [in] BT The type of the second value, can only be 32 bit integer
1910+
/// \param [in] a The first value
1911+
/// \param [in] b The second value
1912+
/// \param [in] c The third value
1913+
/// \returns The extend vectorized subtraction of the two values
1914+
template <typename RetT, typename AT, typename BT>
1915+
inline constexpr RetT extend_vsub2(AT a, BT b, RetT c);
1916+
1917+
/// Compute vectorized subtraction of \p a and \p b, with each value treated as
1918+
/// a 2 elements vector type and extend each element to 17 bit. Then add each
1919+
/// half of the result and add with \p c.
1920+
/// \tparam [in] RetT The type of the return value, can only be 32 bit integer
1921+
/// \tparam [in] AT The type of the first value, can only be 32 bit integer
1922+
/// \tparam [in] BT The type of the second value, can only be 32 bit integer
1923+
/// \param [in] a The first value
1924+
/// \param [in] b The second value
1925+
/// \param [in] c The third value
1926+
/// \returns The addition of each half of extend vectorized subtraction of the
1927+
/// two values and the third value
1928+
template <typename RetT, typename AT, typename BT>
1929+
inline constexpr RetT extend_vsub2_add(AT a, BT b, RetT c);
1930+
1931+
/// Compute vectorized subtraction of \p a and \p b with saturation, with each
1932+
/// value treated as a 2 elements vector type and extend each element to 17 bit.
1933+
/// \tparam [in] RetT The type of the return value, can only be 32 bit integer
1934+
/// \tparam [in] AT The type of the first value, can only be 32 bit integer
1935+
/// \tparam [in] BT The type of the second value, can only be 32 bit integer
1936+
/// \param [in] a The first value
1937+
/// \param [in] b The second value
1938+
/// \param [in] c The third value
1939+
/// \returns The extend vectorized subtraction of the two values with saturation
1940+
template <typename RetT, typename AT, typename BT>
1941+
inline constexpr RetT extend_vsub2_sat(AT a, BT b, RetT c);
1942+
1943+
/// Compute vectorized abs_diff of \p a and \p b, with each value treated as a 2
1944+
/// elements vector type and extend each element to 17 bit.
1945+
/// \tparam [in] RetT The type of the return value, can only be 32 bit integer
1946+
/// \tparam [in] AT The type of the first value, can only be 32 bit integer
1947+
/// \tparam [in] BT The type of the second value, can only be 32 bit integer
1948+
/// \param [in] a The first value
1949+
/// \param [in] b The second value
1950+
/// \param [in] c The third value
1951+
/// \returns The extend vectorized abs_diff of the two values
1952+
template <typename RetT, typename AT, typename BT>
1953+
inline constexpr RetT extend_vabsdiff2(AT a, BT b, RetT c);
1954+
1955+
/// Compute vectorized abs_diff of \p a and \p b, with each value treated as a 2
1956+
/// elements vector type and extend each element to 17 bit. Then add each half
1957+
/// of the result and add with \p c.
1958+
/// \tparam [in] RetT The type of the return value, can only be 32 bit integer
1959+
/// \tparam [in] AT The type of the first value, can only be 32 bit integer
1960+
/// \tparam [in] BT The type of the second value, can only be 32 bit integer
1961+
/// \param [in] a The first value
1962+
/// \param [in] b The second value
1963+
/// \param [in] c The third value
1964+
/// \returns The addition of each half of extend vectorized abs_diff of the
1965+
/// two values and the third value
1966+
template <typename RetT, typename AT, typename BT>
1967+
inline constexpr RetT extend_vabsdiff2_add(AT a, BT b, RetT c);
1968+
1969+
/// Compute vectorized abs_diff of \p a and \p b with saturation, with each
1970+
/// value treated as a 2 elements vector type and extend each element to 17 bit.
1971+
/// \tparam [in] RetT The type of the return value, can only be 32 bit integer
1972+
/// \tparam [in] AT The type of the first value, can only be 32 bit integer
1973+
/// \tparam [in] BT The type of the second value, can only be 32 bit integer
1974+
/// \param [in] a The first value
1975+
/// \param [in] b The second value
1976+
/// \param [in] c The third value
1977+
/// \returns The extend vectorized abs_diff of the two values with saturation
1978+
template <typename RetT, typename AT, typename BT>
1979+
inline constexpr RetT extend_vabsdiff2_sat(AT a, BT b, RetT c);
1980+
1981+
/// Compute vectorized minimum of \p a and \p b, with each value treated as a 2
1982+
/// elements vector type and extend each element to 17 bit.
1983+
/// \tparam [in] RetT The type of the return value, can only be 32 bit integer
1984+
/// \tparam [in] AT The type of the first value, can only be 32 bit integer
1985+
/// \tparam [in] BT The type of the second value, can only be 32 bit integer
1986+
/// \param [in] a The first value
1987+
/// \param [in] b The second value
1988+
/// \param [in] c The third value
1989+
/// \returns The extend vectorized minimum of the two values
1990+
template <typename RetT, typename AT, typename BT>
1991+
inline constexpr RetT extend_vmin2(AT a, BT b, RetT c);
1992+
1993+
/// Compute vectorized minimum of \p a and \p b, with each value treated as a 2
1994+
/// elements vector type and extend each element to 17 bit. Then add each half
1995+
/// of the result and add with \p c.
1996+
/// \tparam [in] RetT The type of the return value, can only be 32 bit integer
1997+
/// \tparam [in] AT The type of the first value, can only be 32 bit integer
1998+
/// \tparam [in] BT The type of the second value, can only be 32 bit integer
1999+
/// \param [in] a The first value
2000+
/// \param [in] b The second value
2001+
/// \param [in] c The third value
2002+
/// \returns The addition of each half of extend vectorized minimum of the
2003+
/// two values and the third value
2004+
template <typename RetT, typename AT, typename BT>
2005+
inline constexpr RetT extend_vmin2_add(AT a, BT b, RetT c);
2006+
2007+
/// Compute vectorized minimum of \p a and \p b with saturation, with each value
2008+
/// treated as a 2 elements vector type and extend each element to 17 bit.
2009+
/// \tparam [in] RetT The type of the return value, can only be 32 bit integer
2010+
/// \tparam [in] AT The type of the first value, can only be 32 bit integer
2011+
/// \tparam [in] BT The type of the second value, can only be 32 bit integer
2012+
/// \param [in] a The first value
2013+
/// \param [in] b The second value
2014+
/// \param [in] c The third value
2015+
/// \returns The extend vectorized minimum of the two values with saturation
2016+
template <typename RetT, typename AT, typename BT>
2017+
inline constexpr RetT extend_vmin2_sat(AT a, BT b, RetT c);
2018+
2019+
/// Compute vectorized maximum of \p a and \p b, with each value treated as a 2
2020+
/// elements vector type and extend each element to 17 bit.
2021+
/// \tparam [in] RetT The type of the return value, can only be 32 bit integer
2022+
/// \tparam [in] AT The type of the first value, can only be 32 bit integer
2023+
/// \tparam [in] BT The type of the second value, can only be 32 bit integer
2024+
/// \param [in] a The first value
2025+
/// \param [in] b The second value
2026+
/// \param [in] c The third value
2027+
/// \returns The extend vectorized maximum of the two values
2028+
template <typename RetT, typename AT, typename BT>
2029+
inline constexpr RetT extend_vmax2(AT a, BT b, RetT c);
2030+
2031+
/// Compute vectorized maximum of \p a and \p b, with each value treated as a 2
2032+
/// elements vector type and extend each element to 17 bit. Then add each half
2033+
/// of the result and add with \p c.
2034+
/// \tparam [in] RetT The type of the return value, can only be 32 bit integer
2035+
/// \tparam [in] AT The type of the first value, can only be 32 bit integer
2036+
/// \tparam [in] BT The type of the second value, can only be 32 bit integer
2037+
/// \param [in] a The first value
2038+
/// \param [in] b The second value
2039+
/// \param [in] c The third value
2040+
/// \returns The addition of each half of extend vectorized maximum of the
2041+
/// two values and the third value
2042+
template <typename RetT, typename AT, typename BT>
2043+
inline constexpr RetT extend_vmax2_add(AT a, BT b, RetT c);
2044+
2045+
/// Compute vectorized maximum of \p a and \p b with saturation, with each value
2046+
/// treated as a 2 elements vector type and extend each element to 17 bit.
2047+
/// \tparam [in] RetT The type of the return value, can only be 32 bit integer
2048+
/// \tparam [in] AT The type of the first value, can only be 32 bit integer
2049+
/// \tparam [in] BT The type of the second value, can only be 32 bit integer
2050+
/// \param [in] a The first value
2051+
/// \param [in] b The second value
2052+
/// \param [in] c The third value
2053+
/// \returns The extend vectorized maximum of the two values with saturation
2054+
template <typename RetT, typename AT, typename BT>
2055+
inline constexpr RetT extend_vmax2_sat(AT a, BT b, RetT c);
2056+
2057+
/// Compute vectorized average of \p a and \p b, with each value treated as a 2
2058+
/// elements vector type and extend each element to 17 bit.
2059+
/// \tparam [in] RetT The type of the return value, can only be 32 bit integer
2060+
/// \tparam [in] AT The type of the first value, can only be 32 bit integer
2061+
/// \tparam [in] BT The type of the second value, can only be 32 bit integer
2062+
/// \param [in] a The first value
2063+
/// \param [in] b The second value
2064+
/// \param [in] c The third value
2065+
/// \returns The extend vectorized average of the two values
2066+
template <typename RetT, typename AT, typename BT>
2067+
inline constexpr RetT extend_vavrg2(AT a, BT b, RetT c);
2068+
2069+
/// Compute vectorized average of \p a and \p b, with each value treated as a 2
2070+
/// elements vector type and extend each element to 17 bit. Then add each half
2071+
/// of the result and add with \p c.
2072+
/// \tparam [in] RetT The type of the return value, can only be 32 bit integer
2073+
/// \tparam [in] AT The type of the first value, can only be 32 bit integer
2074+
/// \tparam [in] BT The type of the second value, can only be 32 bit integer
2075+
/// \param [in] a The first value
2076+
/// \param [in] b The second value
2077+
/// \param [in] c The third value
2078+
/// \returns The addition of each half of extend vectorized average of the
2079+
/// two values and the third value
2080+
template <typename RetT, typename AT, typename BT>
2081+
inline constexpr RetT extend_vavrg2_add(AT a, BT b, RetT c);
2082+
2083+
/// Compute vectorized average of \p a and \p b with saturation, with each value
2084+
/// treated as a 2 elements vector type and extend each element to 17 bit.
2085+
/// \tparam [in] RetT The type of the return value, can only be 32 bit integer
2086+
/// \tparam [in] AT The type of the first value, can only be 32 bit integer
2087+
/// \tparam [in] BT The type of the second value, can only be 32 bit integer
2088+
/// \param [in] a The first value
2089+
/// \param [in] b The second value
2090+
/// \param [in] c The third value
2091+
/// \returns The extend vectorized average of the two values with saturation
2092+
template <typename RetT, typename AT, typename BT>
2093+
inline constexpr RetT extend_vavrg2_sat(AT a, BT b, RetT c);
2094+
```
2095+
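To make the lane semantics described above concrete, here is a minimal reference model in plain C++ for the signed case of the three `vadd2` variants. It is a sketch only, not the SYCLcompat implementation. The names `sext16`, `pack`, `clamp_s16`, and the `model_*` functions are hypothetical. The model assumes that the low 16 bits form the first lane, that the non-saturating variant wraps each lane to 16 bits, and that the third operand `c` only contributes in the `_add` variant.

```cpp
#include <cstdint>

// Hypothetical helpers for illustration; none of these names exist in SYCLcompat.

// Sign-extend the low 16 bits of a word into a wider signed integer.
constexpr std::int32_t sext16(std::uint32_t bits) {
  return (bits & 0x8000u)
             ? static_cast<std::int32_t>(bits & 0xFFFFu) - 0x10000
             : static_cast<std::int32_t>(bits & 0xFFFFu);
}

// Pack two lane results back into one 32-bit word (low lane in the low bits).
constexpr std::uint32_t pack(std::int32_t lo, std::int32_t hi) {
  return (static_cast<std::uint32_t>(lo) & 0xFFFFu) |
         ((static_cast<std::uint32_t>(hi) & 0xFFFFu) << 16);
}

// Clamp a widened lane result to the signed 16-bit range.
constexpr std::int32_t clamp_s16(std::int32_t v) {
  return v < -32768 ? -32768 : (v > 32767 ? 32767 : v);
}

// Per-lane addition; each lane wraps to 16 bits (assumed behaviour).
constexpr std::uint32_t model_vadd2(std::uint32_t a, std::uint32_t b) {
  return pack(sext16(a) + sext16(b), sext16(a >> 16) + sext16(b >> 16));
}

// Per-lane addition with each lane saturated to the signed 16-bit range.
constexpr std::uint32_t model_vadd2_sat(std::uint32_t a, std::uint32_t b) {
  return pack(clamp_s16(sext16(a) + sext16(b)),
              clamp_s16(sext16(a >> 16) + sext16(b >> 16)));
}

// Per-lane addition, then the two lane results and c are summed into a scalar.
constexpr std::int32_t model_vadd2_add(std::uint32_t a, std::uint32_t b,
                                       std::int32_t c) {
  return (sext16(a) + sext16(b)) + (sext16(a >> 16) + sext16(b >> 16)) + c;
}
```

With `a = 0x7FFF0001` (lanes 32767 and 1) and `b = 0x00010002` (lanes 1 and 2), the high lane overflows the signed 16-bit range: the wrapping model yields `0x80000003`, the saturating model yields `0x7FFF0003`, and the `_add` model with `c = 10` yields `32781`.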
18582096
## Sample Code
18592097

18602098
Below is a simple linear algebra sample, which computes `y = mx + b` implemented
