
Commit 166bbc3

[SYCL][CUDA] Implementation of matrix ext using new "unified" interface (#7077)
CUDA backend implementation using the "unified" matrix extension interface. The same interface will be used for a future Intel backend implementation of the matrix extension.

- The new "unified" interface uses SYCL_EXT_ONEAPI_MATRIX_VERSION=4.
- The `joint_matrix_load`, `joint_matrix_store`, `joint_matrix_mad` and `joint_matrix` interfaces match the new spec from #6662.
- Separated the `joint_matrix_*` functions into a new header, matrix-unified.hpp, so that Intel backend implementations can be called from the same functions in the future.
- C++17 everywhere, in line with #6678.
- Updated the device code tests to use the new interfaces.
- Completely removed the uint16 implementations; they are replaced by bfloat16, which is being moved out of the experimental namespace.
- Updated all CUDA runtime matrix tests here: intel/llvm-test-suite#1183

Signed-off-by: JackAKirk <[email protected]>
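For orientation, the sketch below shows how a kernel written against the unified interface looks on the CUDA backend. It is an illustrative example, not code from this commit: the 16x16x16 bfloat16 shape, buffer names, and launch configuration are assumptions, and it presumes the version-4 interface has been selected (see the matrix.hpp change at the end of this diff) and a device that supports this MMA shape.

```c++
#include <sycl/sycl.hpp>

using namespace sycl;
using namespace sycl::ext::oneapi::experimental::matrix;
using bfloat16 = sycl::ext::oneapi::bfloat16;

// One 16x16x16 tile: D = A * B + C, computed cooperatively by a 32-wide
// sub-group (one warp on the CUDA backend). Shapes and names are assumptions.
void mma_tile(queue &q, buffer<bfloat16, 1> &bufA, buffer<bfloat16, 1> &bufB,
              buffer<float, 1> &bufC, buffer<float, 1> &bufD) {
  q.submit([&](handler &cgh) {
    accessor accA{bufA, cgh, read_only};
    accessor accB{bufB, cgh, read_only};
    accessor accC{bufC, cgh, read_only};
    accessor accD{bufD, cgh, write_only};
    cgh.parallel_for(nd_range<2>({1, 32}, {1, 32}), [=](nd_item<2> item) {
      sub_group sg = item.get_sub_group();
      // use::a / use::b fragments carry their layout as a template parameter;
      // the accumulator uses layout::dynamic and takes a run-time layout on
      // load/store, matching the signatures in matrix-unified.hpp.
      joint_matrix<sub_group, bfloat16, use::a, 16, 16, layout::row_major> tA;
      joint_matrix<sub_group, bfloat16, use::b, 16, 16, layout::row_major> tB;
      joint_matrix<sub_group, float, use::accumulator, 16, 16, layout::dynamic> tC;
      joint_matrix_load(sg, tA, accA.get_pointer(), 16);
      joint_matrix_load(sg, tB, accB.get_pointer(), 16);
      joint_matrix_load(sg, tC, accC.get_pointer(), 16, layout::row_major);
      tC = joint_matrix_mad(sg, tA, tB, tC); // tC = tA * tB + tC
      joint_matrix_store(sg, tC, accD.get_pointer(), 16, layout::row_major);
    });
  });
}
```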
1 parent 5de8d26 commit 166bbc3

11 files changed: +1209, -418 lines

sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp renamed to sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores-legacy.hpp

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-//===---- matrix-tensorcore.hpp - SYCL tensor cores matrix ----*- C++ -*---===//
+//===-------------- matrix-tensorcores-legacy.hpp - -----------*- C++ -*---===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.

sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp

Lines changed: 639 additions & 0 deletions
Large diffs are not rendered by default.

sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp

Lines changed: 221 additions & 0 deletions
@@ -0,0 +1,221 @@
//===------- matrix-unified.hpp - SYCL matrix extension ----*- C++ -*------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
// ===--------------------------------------------------------------------=== //

#pragma once
#include <sycl/ext/oneapi/matrix/matrix-tensorcores.hpp>

namespace sycl {
__SYCL_INLINE_VER_NAMESPACE(_V1) {
namespace ext {
namespace oneapi {
namespace experimental {
namespace matrix {

template <typename Group, typename T, use Use, size_t Rows, size_t Cols,
          layout Layout>
struct joint_matrix {

#if defined(__SYCL_DEVICE_ONLY__) && defined(__SPIR__)
  // TODO: Intel case here: we use the ext_oneapi_cuda case also for the host,
  // because the Intel SPIRV functions will not be host compilable.
#else
  sycl::ext::oneapi::detail::joint_matrix_cuda<T, Use, Rows, Cols, Layout>
      cuda_impl;
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__SPIR__)

  joint_matrix() {
#ifndef __SYCL_DEVICE_ONLY__
    throw runtime_error("joint matrix is not supported on host device.",
                        PI_ERROR_INVALID_DEVICE);
#endif
  }
};

template <typename Group, typename T, use Use, size_t Rows, size_t Cols,
          layout Layout>
inline __SYCL_ALWAYS_INLINE wi_data<Group, T, Use, Rows, Cols, Layout>
get_wi_data(Group sg, joint_matrix<Group, T, Use, Rows, Cols, Layout> &jm) {
#if defined(__SYCL_DEVICE_ONLY__)
#if defined(__NVPTX__)
  std::ignore = sg;
  return wi_data(jm);
#else
  // TODO add Intel impl.
#endif // defined(__NVPTX__)
#endif // defined(__SYCL_DEVICE_ONLY__)
}

template <typename Group, typename T, size_t NumRows, size_t NumCols, use Use,
          layout Layout, typename T2>
inline __SYCL_ALWAYS_INLINE void
joint_matrix_fill(Group sg,
                  joint_matrix<Group, T, Use, NumRows, NumCols, Layout> &res,
                  const T2 &v) {
#if defined(__SYCL_DEVICE_ONLY__)
#if defined(__NVPTX__)
  std::ignore = sg;
  res.cuda_impl.wi_marray = v;
#endif // defined(__NVPTX__)
#else
  std::ignore = sg;
  std::ignore = res;
  std::ignore = v;
  throw runtime_error(
      "This version of the matrix extension is only currently supported on "
      "Nvidia devices",
      PI_ERROR_INVALID_DEVICE);
#endif // defined(__SYCL_DEVICE_ONLY__)
}

template <
    typename Group, typename S, typename T, size_t NumRows, size_t NumCols,
    access::address_space Space, access::decorated IsDecorated,
    std::enable_if_t<std::is_same<S, std::remove_const_t<T>>::value, bool> =
        true>
inline __SYCL_ALWAYS_INLINE void joint_matrix_load(
    Group sg,
    joint_matrix<Group, S, use::accumulator, NumRows, NumCols,
                 sycl::ext::oneapi::experimental::matrix::layout::dynamic> &res,
    multi_ptr<T, Space, IsDecorated> src, size_t stride,
    sycl::ext::oneapi::experimental::matrix::layout Layout) {
#if defined(__SYCL_DEVICE_ONLY__)
#if defined(__NVPTX__)
  std::ignore = sg;
  sycl::ext::oneapi::detail::load_accumulator_cuda(res.cuda_impl, src, stride,
                                                   Layout);
#endif // defined(__NVPTX__)
#else
  std::ignore = sg;
  std::ignore = res;
  std::ignore = src;
  std::ignore = stride;
  throw runtime_error(
      "This version of the matrix extension is only currently supported on "
      "Nvidia devices",
      PI_ERROR_INVALID_DEVICE);
#endif // defined(__SYCL_DEVICE_ONLY__)
}

template <
    typename Group, typename S, typename T, use Use, size_t NumRows,
    size_t NumCols, matrix::layout Layout, access::address_space Space,
    access::decorated IsDecorated,
    std::enable_if_t<std::is_same<S, std::remove_const_t<T>>::value ||
                         (std::is_same<S, precision::tf32>::value &&
                          std::is_same<std::remove_const_t<T>, float>::value),
                     bool> = true>
inline __SYCL_ALWAYS_INLINE void
joint_matrix_load(Group sg,
                  joint_matrix<Group, S, Use, NumRows, NumCols, Layout> &res,
                  multi_ptr<T, Space, IsDecorated> src, size_t stride) {
#if defined(__SYCL_DEVICE_ONLY__)
#if defined(__NVPTX__)
  std::ignore = sg;
  sycl::ext::oneapi::detail::load_multiplicand_cuda<S, T, NumRows, NumCols, Use,
                                                    Layout, Space>(
      res.cuda_impl, src, stride);
#endif // defined(__NVPTX__)
#else
  std::ignore = sg;
  std::ignore = res;
  std::ignore = src;
  std::ignore = stride;
  throw runtime_error(
      "This version of the matrix extension is only currently supported on "
      "Nvidia devices",
      PI_ERROR_INVALID_DEVICE);
#endif // defined(__SYCL_DEVICE_ONLY__)
}

template <typename Group, typename T, size_t NumRows, size_t NumCols,
          access::address_space Space, access::decorated IsDecorated>
inline __SYCL_ALWAYS_INLINE void joint_matrix_store(
    Group sg,
    joint_matrix<Group, T, use::accumulator, NumRows, NumCols,
                 sycl::ext::oneapi::experimental::matrix::layout::dynamic> &src,
    multi_ptr<T, Space, IsDecorated> dst, size_t stride,
    sycl::ext::oneapi::experimental::matrix::layout Layout) {
#if defined(__SYCL_DEVICE_ONLY__)
#if defined(__NVPTX__)
  std::ignore = sg;
  sycl::ext::oneapi::detail::joint_matrix_store_cuda<T, NumRows, NumCols,
                                                     Space>(src.cuda_impl, dst,
                                                            stride, Layout);
#endif // defined(__NVPTX__)
#else
  std::ignore = sg;
  std::ignore = src;
  std::ignore = dst;
  std::ignore = stride;
  throw runtime_error(
      "This version of the matrix extension is only currently supported on "
      "Nvidia devices",
      PI_ERROR_INVALID_DEVICE);
#endif // defined(__SYCL_DEVICE_ONLY__)
}

template <typename Group, typename Ta, typename Tb, typename Tc, std::size_t M,
          std::size_t K, std::size_t N, layout LayoutA, layout LayoutB>
inline __SYCL_ALWAYS_INLINE
    joint_matrix<Group, Tc, use::accumulator, M, N,
                 sycl::ext::oneapi::experimental::matrix::layout::dynamic>
    joint_matrix_mad(
        Group sg, joint_matrix<Group, Ta, use::a, M, K, LayoutA> &A,
        joint_matrix<Group, Tb, use::b, K, N, LayoutB> &B,
        joint_matrix<Group, Tc, use::accumulator, M, N,
                     sycl::ext::oneapi::experimental::matrix::layout::dynamic>
            &C) {
#if defined(__SYCL_DEVICE_ONLY__)
#if defined(__NVPTX__)
  std::ignore = sg;
  if constexpr (std::is_same<Ta, Tb>::value) {
    joint_matrix<Group, Tc, use::accumulator, M, N,
                 sycl::ext::oneapi::experimental::matrix::layout::dynamic>
        D;
    sycl::ext::oneapi::detail::joint_matrix_mad_cuda<Ta, Tc, M, K, N, LayoutA,
                                                     LayoutB>(
        D.cuda_impl, A.cuda_impl, B.cuda_impl, C.cuda_impl);
    return D;
  } else {
    assert(false && "Ta != Tb : In the CUDA backend joint_matrix_mad "
                    "requires that joint_matrix data types Ta and Tb match");
  }
#endif // defined(__NVPTX__)
#else
  std::ignore = sg;
  std::ignore = A;
  std::ignore = B;
  std::ignore = C;
  throw runtime_error(
      "This version of the matrix extension is only currently supported on "
      "Nvidia devices",
      PI_ERROR_INVALID_DEVICE);
#endif // defined(__SYCL_DEVICE_ONLY__)
}

// This function rounds the bottom 13 bits up or down, and then zeros out the
// bottom bits
inline __SYCL_ALWAYS_INLINE float round_to_tf32(float &a) {
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
  int32_t tmp_int = __nvvm_f2tf32_rna(a);
  return __nvvm_bitcast_i2f(tmp_int);
#else
  uint32_t tmp_uint = reinterpret_cast<uint32_t &>(a);
  tmp_uint += 0x1000u;
  tmp_uint &= 0xFFFFE000u;
  float ret = reinterpret_cast<float &>(tmp_uint);
  return ret;
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
}

} // namespace matrix
} // namespace experimental
} // namespace oneapi
} // namespace ext
} // __SYCL_INLINE_VER_NAMESPACE(_V1)
} // namespace sycl
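A note on `round_to_tf32` above: the host fallback rounds on the highest discarded mantissa bit and then masks away the bottom 13 bits, leaving the 10 explicit mantissa bits that tf32 keeps. A small self-contained check of that arithmetic (the test values are illustrative, not taken from the commit):

```c++
#include <cassert>
#include <cstdint>
#include <cstring>

// Mirrors the host fallback of round_to_tf32: add 0x1000 to round on the
// highest discarded bit, then clear the bottom 13 mantissa bits.
static float tf32_host(float a) {
  uint32_t bits;
  std::memcpy(&bits, &a, sizeof(bits));
  bits += 0x1000u;
  bits &= 0xFFFFE000u;
  float out;
  std::memcpy(&out, &bits, sizeof(out));
  return out;
}

int main() {
  // 1 + 2^-11 needs 11 mantissa bits, one more than tf32 keeps; it rounds up
  // to the nearest representable value, 1 + 2^-10.
  assert(tf32_host(1.0f + 0x1p-11f) == 1.0f + 0x1p-10f);
  // Values already representable in tf32 pass through unchanged.
  assert(tf32_host(1.5f) == 1.5f);
}
```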

sycl/include/sycl/ext/oneapi/matrix/matrix.hpp

Lines changed: 4 additions & 1 deletion
@@ -27,5 +27,8 @@
 #include <sycl/ext/oneapi/matrix/static-query-use.hpp>
 #endif // SYCL_EXT_ONEAPI_MATRIX_VERSION
 #if (SYCL_EXT_ONEAPI_MATRIX_VERSION == 3)
-#include <sycl/ext/oneapi/matrix/matrix-tensorcore.hpp>
+#include <sycl/ext/oneapi/matrix/matrix-tensorcores-legacy.hpp>
+#endif // SYCL_EXT_ONEAPI_MATRIX_VERSION
+#if (SYCL_EXT_ONEAPI_MATRIX_VERSION == 4)
+#include <sycl/ext/oneapi/matrix/matrix-unified.hpp>
 #endif // SYCL_EXT_ONEAPI_MATRIX_VERSION
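With the guards above, the interface is chosen purely by the `SYCL_EXT_ONEAPI_MATRIX_VERSION` macro. The following is a hedged sketch of how user code opts into the unified header; the compiler invocation in the comment is an assumed, conventional one for the CUDA backend rather than something specified in this commit.

```c++
// The macro is usually supplied on the command line rather than in source,
// e.g. (assumed invocation):
//   clang++ -fsycl -fsycl-targets=nvptx64-nvidia-cuda \
//           -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 app.cpp
#define SYCL_EXT_ONEAPI_MATRIX_VERSION 4
#include <sycl/sycl.hpp>

#if (SYCL_EXT_ONEAPI_MATRIX_VERSION == 4)
// The unified joint_matrix interface from matrix-unified.hpp is available.
#endif
```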
