Skip to content

Commit 17b81ff

Browse files
committed
[SYCL][Matrix] Fix __spirv_JointMatrixINTEL signature
Default SYCL_EXT_ONEAPI_MATRIX to 1 Only add 'Use' parameter if testing macro __SYCL_EXT_ONEAPI_MATRIX_USE__ is defined. Signed-off-by: Sidorov, Dmitry <[email protected]>
1 parent 9f89247 commit 17b81ff

12 files changed

+60
-50
lines changed

sycl/include/CL/__spirv/spirv_ops.hpp

Lines changed: 33 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,20 @@
2222
#endif
2323

2424
#ifdef __SYCL_DEVICE_ONLY__
25+
26+
#ifdef __SYCL_EXT_ONEAPI_MATRIX_USE__
27+
#define JOINT_MATRIX_INTEL(T, R, C, L, S, U) \
28+
__spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U>
29+
#else
30+
#define JOINT_MATRIX_INTEL(T, R, C, L, S, U) \
31+
__spv::__spirv_JointMatrixINTEL<T, R, C, L, S>
32+
#endif // __SYCL_EXT_ONEAPI_MATRIX_USE__
33+
2534
template <typename T, std::size_t R, std::size_t C,
2635
__spv::MatrixUse U = __spv::MatrixUse::Unnecessary,
2736
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
2837
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
29-
extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *
38+
extern SYCL_EXTERNAL JOINT_MATRIX_INTEL(T, R, C, L, S, U) *
3039
__spirv_JointMatrixLoadINTEL(T *Ptr, std::size_t Stride,
3140
__spv::MatrixLayout Layout = L,
3241
__spv::Scope::Flag Sc = S, int MemOperand = 0);
@@ -36,7 +45,7 @@ template <typename T, std::size_t R, std::size_t C,
3645
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
3746
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
3847
extern SYCL_EXTERNAL void __spirv_JointMatrixStoreINTEL(
39-
T *Ptr, __spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *Object,
48+
T *Ptr, JOINT_MATRIX_INTEL(T, R, C, L, S, U) *Object,
4049
std::size_t Stride, __spv::MatrixLayout Layout = L,
4150
__spv::Scope::Flag Sc = S, int MemOperand = 0);
4251

@@ -48,11 +57,11 @@ template <typename T1, typename T2, std::size_t M, std::size_t K, std::size_t N,
4857
__spv::MatrixLayout LB = __spv::MatrixLayout::RowMajor,
4958
__spv::MatrixLayout LC = __spv::MatrixLayout::RowMajor,
5059
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
51-
extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<T2, M, N, LC, S, UC> *
60+
extern SYCL_EXTERNAL JOINT_MATRIX_INTEL(T2, M, N, LC, S, UC) *
5261
__spirv_JointMatrixMadINTEL(
53-
__spv::__spirv_JointMatrixINTEL<T1, M, K, LA, S, UA> *A,
54-
__spv::__spirv_JointMatrixINTEL<T1, K, N, LB, S, UB> *B,
55-
__spv::__spirv_JointMatrixINTEL<T2, M, N, LC, S, UC> *C,
62+
JOINT_MATRIX_INTEL(T1, M, K, LA, S, UA) *A,
63+
JOINT_MATRIX_INTEL(T1, K, N, LB, S, UB) *B,
64+
JOINT_MATRIX_INTEL(T2, M, N, LC, S, UC) *C,
5665
__spv::Scope::Flag Sc = __spv::Scope::Flag::Subgroup);
5766

5867
template <typename T1, typename T2, typename T3, std::size_t M, std::size_t K,
@@ -63,11 +72,11 @@ template <typename T1, typename T2, typename T3, std::size_t M, std::size_t K,
6372
__spv::MatrixLayout LB = __spv::MatrixLayout::RowMajor,
6473
__spv::MatrixLayout LC = __spv::MatrixLayout::RowMajor,
6574
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
66-
extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<T3, M, N, LC, S, UC> *
75+
extern SYCL_EXTERNAL JOINT_MATRIX_INTEL(T2, M, N, LC, S, UC) *
6776
__spirv_JointMatrixUUMadINTEL(
68-
__spv::__spirv_JointMatrixINTEL<T1, M, K, LA, S, UA> *A,
69-
__spv::__spirv_JointMatrixINTEL<T2, K, N, LB, S, UB> *B,
70-
__spv::__spirv_JointMatrixINTEL<T3, M, N, LC, S, UC> *C,
77+
JOINT_MATRIX_INTEL(T1, M, K, LA, S, UA) *A,
78+
JOINT_MATRIX_INTEL(T2, K, N, LB, S, UB) *B,
79+
JOINT_MATRIX_INTEL(T3, M, N, LC, S, UC) *C,
7180
__spv::Scope::Flag Sc = __spv::Scope::Flag::Subgroup);
7281

7382
template <typename T1, typename T2, typename T3, std::size_t M, std::size_t K,
@@ -78,11 +87,11 @@ template <typename T1, typename T2, typename T3, std::size_t M, std::size_t K,
7887
__spv::MatrixLayout LB = __spv::MatrixLayout::RowMajor,
7988
__spv::MatrixLayout LC = __spv::MatrixLayout::RowMajor,
8089
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
81-
extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<T3, M, N, LC, S, UC> *
90+
extern SYCL_EXTERNAL JOINT_MATRIX_INTEL(T3, M, N, LC, S, UC) *
8291
__spirv_JointMatrixUSMadINTEL(
83-
__spv::__spirv_JointMatrixINTEL<T1, M, K, LA, S, UA> *A,
84-
__spv::__spirv_JointMatrixINTEL<T2, K, N, LB, S, UB> *B,
85-
__spv::__spirv_JointMatrixINTEL<T3, M, N, LC, S, UC> *C,
92+
JOINT_MATRIX_INTEL(T1, M, K, LA, S, UA) *A,
93+
JOINT_MATRIX_INTEL(T2, K, N, LB, S, UB) *B,
94+
JOINT_MATRIX_INTEL(T3, M, N, LC, S, UC) *C,
8695
__spv::Scope::Flag Sc = __spv::Scope::Flag::Subgroup);
8796

8897
template <typename T1, typename T2, typename T3, std::size_t M, std::size_t K,
@@ -93,38 +102,39 @@ template <typename T1, typename T2, typename T3, std::size_t M, std::size_t K,
93102
__spv::MatrixLayout LB = __spv::MatrixLayout::RowMajor,
94103
__spv::MatrixLayout LC = __spv::MatrixLayout::RowMajor,
95104
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
96-
extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<T3, M, N, LC, S, UC> *
105+
extern SYCL_EXTERNAL JOINT_MATRIX_INTEL(T3, M, N, LC, S, UC) *
97106
__spirv_JointMatrixSUMadINTEL(
98-
__spv::__spirv_JointMatrixINTEL<T1, M, K, LA, S, UA> *A,
99-
__spv::__spirv_JointMatrixINTEL<T2, K, N, LB, S, UB> *B,
100-
__spv::__spirv_JointMatrixINTEL<T3, M, N, LC, S, UC> *C,
107+
JOINT_MATRIX_INTEL(T1, M, K, LA, S, UA) *A,
108+
JOINT_MATRIX_INTEL(T2, K, N, LB, S, UB) *B,
109+
JOINT_MATRIX_INTEL(T3, M, N, LC, S, UC) *C,
101110
__spv::Scope::Flag Sc = __spv::Scope::Flag::Subgroup);
102111

103112
template <typename T, std::size_t R, std::size_t C,
104113
__spv::MatrixUse U = __spv::MatrixUse::Unnecessary,
105114
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
106115
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
107-
extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *
116+
extern SYCL_EXTERNAL JOINT_MATRIX_INTEL(T, R, C, L, S, U) *
108117
__spirv_CompositeConstruct(const T v);
109118

110119
template <typename T, std::size_t R, std::size_t C, __spv::MatrixUse U,
111120
__spv::MatrixLayout L,
112121
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
113122
extern SYCL_EXTERNAL size_t __spirv_JointMatrixWorkItemLengthINTEL(
114-
__spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *);
123+
JOINT_MATRIX_INTEL(T, R, C, L, S, U) *);
115124

116125
template <typename T, std::size_t R, std::size_t C, __spv::MatrixUse U,
117126
__spv::MatrixLayout L,
118127
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
119128
extern SYCL_EXTERNAL T __spirv_VectorExtractDynamic(
120-
__spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *, size_t i);
129+
JOINT_MATRIX_INTEL(T, R, C, L, S, U) *, size_t i);
121130

122131
template <typename T, std::size_t R, std::size_t C, __spv::MatrixUse U,
123132
__spv::MatrixLayout L,
124133
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
125-
extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *
126-
__spirv_VectorInsertDynamic(__spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *,
134+
extern SYCL_EXTERNAL JOINT_MATRIX_INTEL(T, R, C, L, S, U) *
135+
__spirv_VectorInsertDynamic(JOINT_MATRIX_INTEL(T, R, C, L, S, U) *,
127136
T val, size_t i);
137+
#undef JOINT_MATRIX_INTEL
128138

129139
#ifndef __SPIRV_BUILTIN_DECLARATIONS__
130140
#error \

sycl/include/CL/__spirv/spirv_types.hpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ enum class MatrixUse : uint32_t {
137137
// information to SPIRV translator.
138138
// The long term solution would be to introduce a matrix type in Clang and use
139139
// it instead of this member.
140+
#ifdef __SYCL_EXT_ONEAPI_MATRIX_USE__
140141
template <typename T, std::size_t R, std::size_t C, MatrixLayout L,
141142
Scope::Flag S = Scope::Flag::Subgroup,
142143
MatrixUse U = MatrixUse::Unnecessary>
@@ -145,6 +146,14 @@ struct __spirv_JointMatrixINTEL {
145146
[R][C][static_cast<size_t>(L) + 1][static_cast<size_t>(S) + 1]
146147
[static_cast<size_t>(U) + 1];
147148
};
149+
#else
150+
template <typename T, std::size_t R, std::size_t C, MatrixLayout L,
151+
Scope::Flag S = Scope::Flag::Subgroup>
152+
struct __spirv_JointMatrixINTEL {
153+
T(*Value)
154+
[R][C][static_cast<size_t>(L) + 1][static_cast<size_t>(S) + 1];
155+
};
156+
#endif // __SYCL_EXT_ONEAPI_MATRIX_USE__
148157

149158
} // namespace __spv
150159

sycl/include/sycl/ext/oneapi/matrix/matrix.hpp

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,14 @@
1616

1717
#include <sycl/feature_test.hpp>
1818

19-
// the default is matrix-jit-use but existing tests in llvm-test-suite won't
20-
// fail because we have the "unnecessary" use value
21-
#if (SYCL_EXT_ONEAPI_MATRIX == 1)
22-
#include <sycl/ext/oneapi/matrix/matrix-jit.hpp>
23-
#include <sycl/ext/oneapi/matrix/static-query.hpp>
19+
#if (SYCL_EXT_ONEAPI_MATRIX == 3)
20+
#include <sycl/ext/oneapi/matrix/matrix-tensorcore.hpp>
2421
#endif
25-
#if (SYCL_EXT_ONEAPI_MATRIX == 2)
22+
#elif (SYCL_EXT_ONEAPI_MATRIX == 2 || __SYCL_EXT_ONEAPI_MATRIX_USE__)
2623
#include <sycl/ext/oneapi/matrix/matrix-jit-use.hpp>
2724
#include <sycl/ext/oneapi/matrix/static-query-use.hpp>
2825
#endif
29-
#if (SYCL_EXT_ONEAPI_MATRIX == 3)
30-
#include <sycl/ext/oneapi/matrix/matrix-tensorcore.hpp>
26+
#elif (SYCL_EXT_ONEAPI_MATRIX == 1)
27+
#include <sycl/ext/oneapi/matrix/matrix-jit.hpp>
28+
#include <sycl/ext/oneapi/matrix/static-query.hpp>
3129
#endif

sycl/include/sycl/feature_test.hpp.in

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,7 @@ __SYCL_INLINE_VER_NAMESPACE(_V1) {
3232
#define SYCL_EXT_INTEL_DEVICE_INFO 3
3333
#define SYCL_EXT_ONEAPI_SUB_GROUP_MASK 1
3434
#define SYCL_EXT_ONEAPI_LOCAL_MEMORY 1
35-
// As for SYCL_EXT_ONEAPI_MATRIX:
36-
// 1- provides AOT initial implementation for AMX for the experimental matrix
37-
// extension
38-
// 2- provides JIT implementation (target agnostic) for the
39-
// experimental matrix extension
40-
#ifndef SYCL_EXT_ONEAPI_MATRIX
41-
#define SYCL_EXT_ONEAPI_MATRIX 2
42-
#endif
35+
#define SYCL_EXT_ONEAPI_MATRIX 1
4336
#define SYCL_EXT_ONEAPI_ASSERT 1
4437
#define SYCL_EXT_ONEAPI_COMPLEX_ALGORITHMS 1
4538
#define SYCL_EXT_ONEAPI_DISCARD_QUEUE_EVENTS 1

sycl/test/matrix/matrix-bf16-test-SG-16.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: %clangxx -fsycl -O2 %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX=1
1+
// RUN: %clangxx -fsycl -O2 %s -o %t.out
22
#include <iostream>
33
#include <sycl/sycl.hpp>
44

sycl/test/matrix/matrix-bf16-test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: %clangxx -fsycl -O2 %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX=1
1+
// RUN: %clangxx -fsycl -O2 %s -o %t.out
22
#include <iostream>
33
#include <sycl/sycl.hpp>
44

sycl/test/matrix/matrix-bfloat16-test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: %clangxx -fsycl -O2 %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX=1
1+
// RUN: %clangxx -fsycl -O2 %s -o %t.out
22
#include <iostream>
33
#include <sycl/sycl.hpp>
44

sycl/test/matrix/matrix-elemwise-ops.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: %clangxx -fsycl -O2 %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX=1
1+
// RUN: %clangxx -fsycl -O2 %s -o %t.out
22

33
#include <iostream>
44
#include <sycl/sycl.hpp>

sycl/test/matrix/matrix-int8-test-SG-16.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: %clangxx -fsycl -O2 %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX=1
1+
// RUN: %clangxx -fsycl -O2 %s -o %t.out
22
#include <iostream>
33
#include <sycl/sycl.hpp>
44

sycl/test/matrix/matrix-int8-test-use.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: %clangxx -fsycl -fsycl-device-only -O2 -S -emit-llvm -o - %s | FileCheck %s
1+
// RUN: %clangxx -fsycl -fsycl-device-only -D__SYCL_EXT_ONEAPI_MATRIX_USE__ -O2 -S -emit-llvm -o - %s | FileCheck %s
22

33
// CHECK-DAG: %spirv.JointMatrixINTEL._char_12_48_4_3_0 = type { [12 x [48 x [5 x [4 x [1 x i8]]]]] addrspace(4)* }
44
// CHECK-DAG: %spirv.JointMatrixINTEL._int_12_12_4_3_2 = type { [12 x [12 x [5 x [4 x [3 x i32]]]]] addrspace(4)* }

sycl/test/matrix/matrix-int8-test.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
// RUN: %clangxx -DSYCL_EXT_ONEAPI_MATRIX=1 -fsycl -fsycl-device-only -O2 -S -emit-llvm -o - %s | FileCheck %s
1+
// RUN: %clangxx -fsycl -fsycl-device-only -O2 -S -emit-llvm -o - %s | FileCheck %s
22

3-
// CHECK-DAG: %spirv.JointMatrixINTEL._char_12_48_0_3_3 = type { [12 x [48 x [1 x [4 x [4 x i8]]]]] addrspace(4)* }
4-
// CHECK-DAG: %spirv.JointMatrixINTEL._int_12_12_0_3_3 = type { [12 x [12 x [1 x [4 x [4 x i32]]]]] addrspace(4)* }
5-
// CHECK-DAG: %spirv.JointMatrixINTEL._char_48_12_3_3_3 = type { [48 x [12 x [4 x [4 x [4 x i8]]]]] addrspace(4)* }
3+
// CHECK-DAG: %spirv.JointMatrixINTEL._char_12_48_0_3 = type { [12 x [48 x [1 x [4 x i8]]]] addrspace(4)* }
4+
// CHECK-DAG: %spirv.JointMatrixINTEL._int_12_12_0_3 = type { [12 x [12 x [1 x [4 x i32]]]] addrspace(4)* }
5+
// CHECK-DAG: %spirv.JointMatrixINTEL._char_48_12_3_3 = type { [48 x [12 x [4 x [4 x i8]]]] addrspace(4)* }
66

77
#include <iostream>
88
#include <sycl/sycl.hpp>

sycl/test/matrix/query.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: %clangxx -DSYCL_EXT_ONEAPI_MATRIX=1 -fsycl -o query %s
1+
// RUN: %clangxx -fsycl -o query %s
22
#include <iostream>
33
#include <sycl/sycl.hpp>
44

0 commit comments

Comments
 (0)