22
22
#endif
23
23
24
24
#ifdef __SYCL_DEVICE_ONLY__
25
+
26
+ #ifdef __SYCL_EXT_ONEAPI_MATRIX_USE__
27
+ #define JOINT_MATRIX_INTEL (T, R, C, L, S, U ) \
28
+ __spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U>
29
+ #else
30
+ #define JOINT_MATRIX_INTEL (T, R, C, L, S, U ) \
31
+ __spv::__spirv_JointMatrixINTEL<T, R, C, L, S>
32
+ #endif // __SYCL_EXT_ONEAPI_MATRIX_USE__
33
+
25
34
template <typename T, std::size_t R, std::size_t C,
26
35
__spv::MatrixUse U = __spv::MatrixUse::Unnecessary,
27
36
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
28
37
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
29
- extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL< T, R, C, L, S, U> *
38
+ extern SYCL_EXTERNAL JOINT_MATRIX_INTEL ( T, R, C, L, S, U) *
30
39
__spirv_JointMatrixLoadINTEL(T *Ptr, std::size_t Stride,
31
40
__spv::MatrixLayout Layout = L,
32
41
__spv::Scope::Flag Sc = S, int MemOperand = 0 );
@@ -36,7 +45,7 @@ template <typename T, std::size_t R, std::size_t C,
36
45
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
37
46
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
38
47
extern SYCL_EXTERNAL void __spirv_JointMatrixStoreINTEL (
39
- T *Ptr, __spv::__spirv_JointMatrixINTEL< T, R, C, L, S, U> *Object,
48
+ T *Ptr, JOINT_MATRIX_INTEL( T, R, C, L, S, U) *Object,
40
49
std::size_t Stride, __spv::MatrixLayout Layout = L,
41
50
__spv::Scope::Flag Sc = S, int MemOperand = 0);
42
51
@@ -48,11 +57,11 @@ template <typename T1, typename T2, std::size_t M, std::size_t K, std::size_t N,
48
57
__spv::MatrixLayout LB = __spv::MatrixLayout::RowMajor,
49
58
__spv::MatrixLayout LC = __spv::MatrixLayout::RowMajor,
50
59
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
51
- extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL< T2, M, N, LC, S, UC> *
60
+ extern SYCL_EXTERNAL JOINT_MATRIX_INTEL ( T2, M, N, LC, S, UC) *
52
61
__spirv_JointMatrixMadINTEL(
53
- __spv::__spirv_JointMatrixINTEL< T1, M, K, LA, S, UA> *A,
54
- __spv::__spirv_JointMatrixINTEL< T1, K, N, LB, S, UB> *B,
55
- __spv::__spirv_JointMatrixINTEL< T2, M, N, LC, S, UC> *C,
62
+ JOINT_MATRIX_INTEL ( T1, M, K, LA, S, UA) *A,
63
+ JOINT_MATRIX_INTEL( T1, K, N, LB, S, UB) *B,
64
+ JOINT_MATRIX_INTEL( T2, M, N, LC, S, UC) *C,
56
65
__spv::Scope::Flag Sc = __spv::Scope::Flag::Subgroup);
57
66
58
67
template <typename T1, typename T2, typename T3, std::size_t M, std::size_t K,
@@ -63,11 +72,11 @@ template <typename T1, typename T2, typename T3, std::size_t M, std::size_t K,
63
72
__spv::MatrixLayout LB = __spv::MatrixLayout::RowMajor,
64
73
__spv::MatrixLayout LC = __spv::MatrixLayout::RowMajor,
65
74
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
66
- extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<T3 , M, N, LC, S, UC> *
75
+ extern SYCL_EXTERNAL JOINT_MATRIX_INTEL (T2 , M, N, LC, S, UC) *
67
76
__spirv_JointMatrixUUMadINTEL(
68
- __spv::__spirv_JointMatrixINTEL< T1, M, K, LA, S, UA> *A,
69
- __spv::__spirv_JointMatrixINTEL< T2, K, N, LB, S, UB> *B,
70
- __spv::__spirv_JointMatrixINTEL< T3, M, N, LC, S, UC> *C,
77
+ JOINT_MATRIX_INTEL ( T1, M, K, LA, S, UA) *A,
78
+ JOINT_MATRIX_INTEL( T2, K, N, LB, S, UB) *B,
79
+ JOINT_MATRIX_INTEL( T3, M, N, LC, S, UC) *C,
71
80
__spv::Scope::Flag Sc = __spv::Scope::Flag::Subgroup);
72
81
73
82
template <typename T1, typename T2, typename T3, std::size_t M, std::size_t K,
@@ -78,11 +87,11 @@ template <typename T1, typename T2, typename T3, std::size_t M, std::size_t K,
78
87
__spv::MatrixLayout LB = __spv::MatrixLayout::RowMajor,
79
88
__spv::MatrixLayout LC = __spv::MatrixLayout::RowMajor,
80
89
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
81
- extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL< T3, M, N, LC, S, UC> *
90
+ extern SYCL_EXTERNAL JOINT_MATRIX_INTEL ( T3, M, N, LC, S, UC) *
82
91
__spirv_JointMatrixUSMadINTEL(
83
- __spv::__spirv_JointMatrixINTEL< T1, M, K, LA, S, UA> *A,
84
- __spv::__spirv_JointMatrixINTEL< T2, K, N, LB, S, UB> *B,
85
- __spv::__spirv_JointMatrixINTEL< T3, M, N, LC, S, UC> *C,
92
+ JOINT_MATRIX_INTEL ( T1, M, K, LA, S, UA) *A,
93
+ JOINT_MATRIX_INTEL( T2, K, N, LB, S, UB) *B,
94
+ JOINT_MATRIX_INTEL( T3, M, N, LC, S, UC) *C,
86
95
__spv::Scope::Flag Sc = __spv::Scope::Flag::Subgroup);
87
96
88
97
template <typename T1, typename T2, typename T3, std::size_t M, std::size_t K,
@@ -93,38 +102,39 @@ template <typename T1, typename T2, typename T3, std::size_t M, std::size_t K,
93
102
__spv::MatrixLayout LB = __spv::MatrixLayout::RowMajor,
94
103
__spv::MatrixLayout LC = __spv::MatrixLayout::RowMajor,
95
104
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
96
- extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL< T3, M, N, LC, S, UC> *
105
+ extern SYCL_EXTERNAL JOINT_MATRIX_INTEL ( T3, M, N, LC, S, UC) *
97
106
__spirv_JointMatrixSUMadINTEL(
98
- __spv::__spirv_JointMatrixINTEL< T1, M, K, LA, S, UA> *A,
99
- __spv::__spirv_JointMatrixINTEL< T2, K, N, LB, S, UB> *B,
100
- __spv::__spirv_JointMatrixINTEL< T3, M, N, LC, S, UC> *C,
107
+ JOINT_MATRIX_INTEL ( T1, M, K, LA, S, UA) *A,
108
+ JOINT_MATRIX_INTEL( T2, K, N, LB, S, UB) *B,
109
+ JOINT_MATRIX_INTEL( T3, M, N, LC, S, UC) *C,
101
110
__spv::Scope::Flag Sc = __spv::Scope::Flag::Subgroup);
102
111
103
112
template <typename T, std::size_t R, std::size_t C,
104
113
__spv::MatrixUse U = __spv::MatrixUse::Unnecessary,
105
114
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
106
115
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
107
- extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL< T, R, C, L, S, U> *
116
+ extern SYCL_EXTERNAL JOINT_MATRIX_INTEL ( T, R, C, L, S, U) *
108
117
__spirv_CompositeConstruct(const T v);
109
118
110
119
template <typename T, std::size_t R, std::size_t C, __spv::MatrixUse U,
111
120
__spv::MatrixLayout L,
112
121
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
113
122
extern SYCL_EXTERNAL size_t __spirv_JointMatrixWorkItemLengthINTEL (
114
- __spv::__spirv_JointMatrixINTEL< T, R, C, L, S, U> *);
123
+ JOINT_MATRIX_INTEL ( T, R, C, L, S, U) *);
115
124
116
125
template <typename T, std::size_t R, std::size_t C, __spv::MatrixUse U,
117
126
__spv::MatrixLayout L,
118
127
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
119
128
extern SYCL_EXTERNAL T __spirv_VectorExtractDynamic (
120
- __spv::__spirv_JointMatrixINTEL< T, R, C, L, S, U> *, size_t i);
129
+ JOINT_MATRIX_INTEL ( T, R, C, L, S, U) *, size_t i);
121
130
122
131
template <typename T, std::size_t R, std::size_t C, __spv::MatrixUse U,
123
132
__spv::MatrixLayout L,
124
133
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
125
- extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL< T, R, C, L, S, U> *
126
- __spirv_VectorInsertDynamic (__spv::__spirv_JointMatrixINTEL< T, R, C, L, S, U> *,
134
+ extern SYCL_EXTERNAL JOINT_MATRIX_INTEL ( T, R, C, L, S, U) *
135
+ __spirv_VectorInsertDynamic(JOINT_MATRIX_INTEL( T, R, C, L, S, U) *,
127
136
T val, size_t i);
137
+ #undef JOINT_MATRIX_INTEL
128
138
129
139
#ifndef __SPIRV_BUILTIN_DECLARATIONS__
130
140
#error \
0 commit comments