@@ -17,11 +17,20 @@ enum class matrix_use { a, b, accumulator };
17
17
18
18
enum class matrix_layout { row_major, col_major, packed_a, packed_b };
19
19
20
+ template <typename T, matrix_use Use, size_t Rows = sycl::dynamic_extent,
21
+ size_t Cols = sycl::dynamic_extent,
22
+ matrix_layout Layout = matrix_layout::row_major,
23
+ typename Group = sycl::sub_group, typename Cond = void >
24
+ struct joint_matrix ;
25
+
20
26
template <typename type, size_t size> class wi_data {
21
27
marray<type, size> &data;
28
+ wi_data (marray<type, size> &wi_data) : data(wi_data){};
29
+ template <typename T, matrix_use Use, size_t Rows, size_t Cols,
30
+ matrix_layout Layout, typename Group, typename Cond>
31
+ friend struct joint_matrix ;
22
32
23
33
public:
24
- wi_data (marray<type, size> &wi_data) : data(wi_data){};
25
34
size_t length () {
26
35
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
27
36
return data.size ();
@@ -41,12 +50,6 @@ template <typename type, size_t size> class wi_data {
41
50
};
42
51
};
43
52
44
- template <typename T, matrix_use Use, size_t Rows = sycl::dynamic_extent,
45
- size_t Cols = sycl::dynamic_extent,
46
- matrix_layout Layout = matrix_layout::row_major,
47
- typename Group = sycl::sub_group, typename Cond = void >
48
- struct joint_matrix ;
49
-
50
53
#define __SYCL_JOINT_MATRIX_OVERLOAD_ARR (type, use, M, N, size ) \
51
54
template <matrix_layout Layout> \
52
55
struct joint_matrix < \
0 commit comments