Skip to content

Commit 7f21853

Browse files
[SYCL] Enable matrix_load, matrix_store, and matrix_mad (#4076)
1 parent 91a79d5 commit 7f21853

14 files changed

+937
-38
lines changed

sycl/include/CL/__spirv/spirv_ops.hpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,33 @@
2222
#endif
2323

2424
#ifdef __SYCL_DEVICE_ONLY__
25+
// Declaration of the SPIR-V joint matrix load builtin (declared only when
// __SYCL_DEVICE_ONLY__ is defined; no definition is visible here).
//
// Template parameters: element type T, dimensions R x C, matrix layout L
// (defaults to RowMajor) and scope S (defaults to Subgroup).
// Parameters:
//   Ptr        - pointer to the elements to load from.
//   Stride     - stride of the source memory (units/meaning not shown here;
//                TODO confirm against the SPIR-V extension spec).
//   Layout     - layout operand, defaults to the template parameter L.
//   Sc         - scope operand, defaults to the template parameter S.
//   MemOperand - defaults to 0.
// Returns a pointer to the __spirv_JointMatrixINTEL<T, R, C, L, S> object.
template <typename T, std::size_t R, std::size_t C,
26+
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
27+
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
28+
extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<T, R, C, L, S> *
29+
__spirv_JointMatrixLoadINTEL(T *Ptr, std::size_t Stride,
30+
__spv::MatrixLayout Layout = L,
31+
__spv::Scope::Flag Sc = S, int MemOperand = 0);
32+
33+
// Declaration of the SPIR-V joint matrix store builtin (declared only when
// __SYCL_DEVICE_ONLY__ is defined; no definition is visible here).
//
// Template parameters mirror the load builtin: element type T, dimensions
// R x C, matrix layout L (defaults to RowMajor), scope S (defaults to
// Subgroup).
// Parameters:
//   Ptr        - pointer to the destination elements.
//   Object     - joint matrix object to store.
//   Stride     - stride of the destination memory (units/meaning not shown
//                here; TODO confirm against the SPIR-V extension spec).
//   Layout     - layout operand, defaults to the template parameter L.
//   Sc         - scope operand, defaults to the template parameter S.
//   MemOperand - defaults to 0.
template <typename T, std::size_t R, std::size_t C,
34+
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
35+
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
36+
extern SYCL_EXTERNAL void __spirv_JointMatrixStoreINTEL(
37+
T *Ptr, __spv::__spirv_JointMatrixINTEL<T, R, C, L, S> *Object,
38+
std::size_t Stride, __spv::MatrixLayout Layout = L,
39+
__spv::Scope::Flag Sc = S, int MemOperand = 0);
40+
41+
// Declaration of the SPIR-V joint matrix multiply-add builtin (declared only
// when __SYCL_DEVICE_ONLY__ is defined; no definition is visible here).
//
// Operand shapes are fixed by the template parameters:
//   A is M x K with element type T1 and layout LA,
//   B is K x N with element type T1 and layout LB,
//   C is M x N with element type T2 and layout LC.
// The result has the same type as C (T2, M x N, layout LC, scope S).
// All three layouts default to RowMajor and S defaults to Subgroup.
// NOTE(review): Sc defaults to the fixed value Scope::Flag::Subgroup rather
// than to the template parameter S as the load/store declarations do --
// confirm this is intentional.
template <typename T1, typename T2, std::size_t M, std::size_t K, std::size_t N,
42+
__spv::MatrixLayout LA = __spv::MatrixLayout::RowMajor,
43+
__spv::MatrixLayout LB = __spv::MatrixLayout::RowMajor,
44+
__spv::MatrixLayout LC = __spv::MatrixLayout::RowMajor,
45+
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
46+
extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<T2, M, N, LC, S> *
47+
__spirv_JointMatrixMadINTEL(
48+
__spv::__spirv_JointMatrixINTEL<T1, M, K, LA, S> *A,
49+
__spv::__spirv_JointMatrixINTEL<T1, K, N, LB, S> *B,
50+
__spv::__spirv_JointMatrixINTEL<T2, M, N, LC, S> *C,
51+
__spv::Scope::Flag Sc = __spv::Scope::Flag::Subgroup);
2552

2653
#ifndef __SPIRV_BUILTIN_DECLARATIONS__
2754
#error \

sycl/include/CL/__spirv/spirv_types.hpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#pragma once
1010

11+
#include <cstddef>
1112
#include <cstdint>
1213

1314
// TODO: include the header file with SPIR-V declarations from SPIRV-Headers
@@ -105,6 +106,12 @@ enum class GroupOperation : uint32_t {
105106
ExclusiveScan = 2
106107
};
107108

109+
// Layout tag for joint matrices. It is used both as a template argument of
// __spirv_JointMatrixINTEL and as the Layout operand of the joint matrix
// load/store builtins. No explicit enumerator values are given, so they are
// numbered 0..3 in declaration order.
enum class MatrixLayout { RowMajor, ColumnMajor, PackedA, PackedB };
110+
111+
// Forward declaration of the joint matrix type, parameterized by element
// type T, dimensions R x C, layout U and scope S (defaults to Subgroup).
// Only a declaration is given here (no definition in the visible hunk); the
// joint matrix builtins and the SYCL-level wrapper refer to it via pointers.
template <typename T, std::size_t R, std::size_t C, MatrixLayout U,
112+
Scope::Flag S = Scope::Flag::Subgroup>
113+
struct __spirv_JointMatrixINTEL;
114+
108115
} // namespace __spv
109116

110117
#ifdef __SYCL_DEVICE_ONLY__

sycl/include/CL/sycl/ONEAPI/matrix/matrix-amx.hpp

Lines changed: 0 additions & 17 deletions
This file was deleted.

sycl/include/CL/sycl/ONEAPI/matrix/matrix.hpp

Lines changed: 0 additions & 17 deletions
This file was deleted.

sycl/include/CL/sycl/feature_test.hpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,14 @@ namespace sycl {
1414

1515
// TODO: Move these feature-test macros to compiler driver.
1616
#define SYCL_EXT_INTEL_DEVICE_INFO 2
17-
#define SYCL_EXT_ONEAPI_MATRIX 1
17+
// Values of SYCL_EXT_ONEAPI_MATRIX:
18+
// 1 - provides the initial AOT implementation for AMX of the experimental
19+
// matrix extension
20+
// 2 - provides the JIT (target-agnostic) implementation of the
21+
// experimental matrix extension
// The #ifndef guard below makes 2 the default while leaving any earlier
// definition (e.g. -DSYCL_EXT_ONEAPI_MATRIX=1 on the command line) untouched.
22+
#ifndef SYCL_EXT_ONEAPI_MATRIX
23+
#define SYCL_EXT_ONEAPI_MATRIX 2
24+
#endif
1825

1926
} // namespace sycl
2027
} // __SYCL_INLINE_NAMESPACE(cl)
Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
//==------------------ matrix.hpp - SYCL matrix ----------------*- C++ -*---==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
// ===--------------------------------------------------------------------=== //
8+
9+
#pragma once
10+
11+
#include <CL/__spirv/spirv_ops.hpp>
12+
#include <CL/sycl/detail/defines_elementary.hpp>
13+
#include <CL/sycl/feature_test.hpp>
14+
15+
__SYCL_INLINE_NAMESPACE(cl) {
16+
namespace sycl {
17+
namespace ext {
18+
namespace oneapi {
19+
namespace experimental::matrix {
20+
21+
enum class matrix_layout { row_major, col_major, packed_a, packed_b };
22+
23+
// Compile-time mapping from the SYCL-level matrix_layout enumerator to the
// corresponding __spv::MatrixLayout enumerator, exposed as ::value.
// The primary template yields RowMajor; the SPV_MATRIX_LAYOUT_TRAITS macro
// stamps out one explicit specialization per matrix_layout enumerator
// (row_major, col_major, packed_a, packed_b), so the primary template is only
// reached for a value outside those four.
// Note: the helper macro is not #undef'd after use in the visible code.
template <matrix_layout Layout> struct spv_matrix_layout_traits {
24+
static constexpr __spv::MatrixLayout value = __spv::MatrixLayout::RowMajor;
25+
};
26+
27+
#define SPV_MATRIX_LAYOUT_TRAITS(LAYOUT, SPV_LAYOUT) \
28+
template <> struct spv_matrix_layout_traits<LAYOUT> { \
29+
static constexpr __spv::MatrixLayout value = SPV_LAYOUT; \
30+
};
31+
32+
SPV_MATRIX_LAYOUT_TRAITS(matrix_layout::row_major,
33+
__spv::MatrixLayout::RowMajor)
34+
SPV_MATRIX_LAYOUT_TRAITS(matrix_layout::col_major,
35+
__spv::MatrixLayout::ColumnMajor)
36+
SPV_MATRIX_LAYOUT_TRAITS(matrix_layout::packed_a, __spv::MatrixLayout::PackedA)
37+
SPV_MATRIX_LAYOUT_TRAITS(matrix_layout::packed_b, __spv::MatrixLayout::PackedB)
38+
39+
// Compile-time mapping from a SYCL group type to a __spv::Scope value,
// exposed as ::value:
//   sycl::sub_group -> __spv::Scope::Subgroup
//   sycl::group<D>  -> __spv::Scope::Workgroup (any dimensionality D)
// The primary template is intentionally empty, so using ::value with any
// other group type fails to compile.
template <typename G> struct spv_scope_traits {};
40+
template <> struct spv_scope_traits<sycl::sub_group> {
41+
constexpr static auto value = __spv::Scope::Subgroup;
42+
};
43+
template <int D> struct spv_scope_traits<sycl::group<D>> {
44+
constexpr static auto value = __spv::Scope::Workgroup;
45+
};
46+
47+
// SYCL-level handle for a joint matrix of NumRows x NumCols elements of type
// T. Layout defaults to row_major and Group defaults to sycl::sub_group.
//
// The only data member, spvm, is a pointer to the SPIR-V joint matrix type
// whose layout argument is derived from Layout via spv_matrix_layout_traits.
// NOTE(review): the scope template argument of __spirv_JointMatrixINTEL is
// not specified here, so it takes that template's default regardless of
// Group -- confirm this is intended for Group = sycl::group<D>.
//
// The constructor does not initialize spvm. When __SYCL_DEVICE_ONLY__ is not
// defined (host compilation) it throws runtime_error with PI_INVALID_DEVICE;
// in device compilation its body is empty.
template <typename T, size_t NumRows, size_t NumCols,
48+
matrix_layout Layout = matrix_layout::row_major,
49+
typename Group = sycl::sub_group>
50+
struct joint_matrix {
51+
public:
52+
__spv::__spirv_JointMatrixINTEL<
53+
T, NumRows, NumCols, spv_matrix_layout_traits<Layout>::value> *spvm;
54+
joint_matrix(Group sg) {
55+
#ifndef __SYCL_DEVICE_ONLY__
56+
(void)sg;
57+
throw runtime_error("joint matrix is not supported on host device.",
58+
PI_INVALID_DEVICE);
59+
#endif // __SYCL_DEVICE_ONLY__
60+
}
61+
};
62+
63+
// Loads a joint matrix from memory into `res`.
//
// Parameters:
//   sg     - group object; only referenced in the host branch.
//   res    - destination joint matrix; its spvm member is overwritten with
//            the pointer returned by __spirv_JointMatrixLoadINTEL.
//   src    - multi_ptr to the source elements (raw pointer taken via get()).
//   stride - forwarded unchanged as the builtin's Stride operand.
//   MemL   - runtime layout of the memory being read; selects which
//            __spv::MatrixLayout value is passed as the Layout operand.
//
// The matrix's own layout (template parameter Layout) is passed as the
// builtin's template argument; the scope operand comes from
// spv_scope_traits<Group>. The switch repeats the same call once per
// matrix_layout enumerator, differing only in the Layout operand.
//
// Host compilation (no __SYCL_DEVICE_ONLY__): always throws runtime_error
// with PI_INVALID_DEVICE.
template <typename Group, typename T, size_t NumRows, size_t NumCols,
64+
matrix_layout Layout = matrix_layout::row_major,
65+
access::address_space Space>
66+
inline __SYCL_ALWAYS_INLINE void
67+
joint_matrix_load(Group sg,
68+
joint_matrix<T, NumRows, NumCols, Layout, Group> &res,
69+
multi_ptr<T, Space> src, size_t stride, matrix_layout MemL) {
70+
#ifdef __SYCL_DEVICE_ONLY__
71+
T *Ptr = src.get();
72+
switch (MemL) {
73+
default:
74+
assert(false && "Invalid Memory Layout!"); // no break: continues into row_major if the assert does not abort
75+
case matrix_layout::row_major:
76+
res.spvm =
77+
__spirv_JointMatrixLoadINTEL<T, NumRows, NumCols,
78+
spv_matrix_layout_traits<Layout>::value>(
79+
Ptr, stride, __spv::MatrixLayout::RowMajor,
80+
spv_scope_traits<Group>::value);
81+
break;
82+
case matrix_layout::col_major:
83+
res.spvm =
84+
__spirv_JointMatrixLoadINTEL<T, NumRows, NumCols,
85+
spv_matrix_layout_traits<Layout>::value>(
86+
Ptr, stride, __spv::MatrixLayout::ColumnMajor,
87+
spv_scope_traits<Group>::value);
88+
break;
89+
case matrix_layout::packed_a:
90+
res.spvm =
91+
__spirv_JointMatrixLoadINTEL<T, NumRows, NumCols,
92+
spv_matrix_layout_traits<Layout>::value>(
93+
Ptr, stride, __spv::MatrixLayout::PackedA,
94+
spv_scope_traits<Group>::value);
95+
break;
96+
case matrix_layout::packed_b:
97+
res.spvm =
98+
__spirv_JointMatrixLoadINTEL<T, NumRows, NumCols,
99+
spv_matrix_layout_traits<Layout>::value>(
100+
Ptr, stride, __spv::MatrixLayout::PackedB,
101+
spv_scope_traits<Group>::value);
102+
break;
103+
}
104+
#else
105+
(void)sg; // the (void) casts presumably silence unused-parameter warnings
106+
(void)res;
107+
(void)src;
108+
(void)stride;
109+
(void)MemL;
110+
throw runtime_error("joint matrix is not supported on host device.",
111+
PI_INVALID_DEVICE);
112+
#endif // __SYCL_DEVICE_ONLY__
113+
}
114+
115+
// Stores the joint matrix `src` to memory.
//
// Parameters:
//   sg     - group object; only referenced in the host branch.
//   src    - joint matrix to store; its spvm member is passed as the
//            builtin's Object operand.
//   res    - multi_ptr to the destination elements (raw pointer taken via
//            get()).
//   stride - forwarded unchanged as the builtin's Stride operand.
//   MemL   - runtime layout of the memory being written; selects which
//            __spv::MatrixLayout value is passed as the Layout operand.
//
// The matrix's own layout (template parameter MatL) is passed as the
// builtin's template argument; the scope operand comes from
// spv_scope_traits<Group>. The switch repeats the same call once per
// matrix_layout enumerator, differing only in the Layout operand.
//
// Host compilation (no __SYCL_DEVICE_ONLY__): always throws runtime_error
// with PI_INVALID_DEVICE.
template <typename Group, typename T, size_t NumRows, size_t NumCols,
116+
matrix_layout MatL = matrix_layout::row_major,
117+
access::address_space Space>
118+
inline __SYCL_ALWAYS_INLINE void
119+
joint_matrix_store(Group sg,
120+
joint_matrix<T, NumRows, NumCols, MatL, Group> &src,
121+
multi_ptr<T, Space> res, size_t stride, matrix_layout MemL) {
122+
#ifdef __SYCL_DEVICE_ONLY__
123+
T *Ptr = res.get();
124+
switch (MemL) {
125+
default:
126+
assert(false && "Invalid Memory Layout!"); // no break: continues into row_major if the assert does not abort
127+
case matrix_layout::row_major:
128+
__spirv_JointMatrixStoreINTEL<T, NumRows, NumCols,
129+
spv_matrix_layout_traits<MatL>::value>(
130+
Ptr, src.spvm, stride, __spv::MatrixLayout::RowMajor,
131+
spv_scope_traits<Group>::value);
132+
break;
133+
case matrix_layout::col_major:
134+
__spirv_JointMatrixStoreINTEL<T, NumRows, NumCols,
135+
spv_matrix_layout_traits<MatL>::value>(
136+
Ptr, src.spvm, stride, __spv::MatrixLayout::ColumnMajor,
137+
spv_scope_traits<Group>::value);
138+
break;
139+
case matrix_layout::packed_a:
140+
__spirv_JointMatrixStoreINTEL<T, NumRows, NumCols,
141+
spv_matrix_layout_traits<MatL>::value>(
142+
Ptr, src.spvm, stride, __spv::MatrixLayout::PackedA,
143+
spv_scope_traits<Group>::value);
144+
break;
145+
case matrix_layout::packed_b:
146+
__spirv_JointMatrixStoreINTEL<T, NumRows, NumCols,
147+
spv_matrix_layout_traits<MatL>::value>(
148+
Ptr, src.spvm, stride, __spv::MatrixLayout::PackedB,
149+
spv_scope_traits<Group>::value);
150+
break;
151+
}
152+
#else
153+
(void)sg; // the (void) casts presumably silence unused-parameter warnings
154+
(void)src;
155+
(void)res;
156+
(void)stride;
157+
(void)MemL;
158+
throw runtime_error("joint matrix is not supported on host device.",
159+
PI_INVALID_DEVICE);
160+
#endif // __SYCL_DEVICE_ONLY__
161+
}
162+
163+
// Multiply-add on joint matrices: forwards mA, mB and mC to
// __spirv_JointMatrixMadINTEL and returns the result as a new joint_matrix.
//
// Shapes are fixed by the template parameters: mA is M x K and mB is K x N
// (both with element type T1), mC is M x N with element type T2. The return
// value has the same type as mC.
//
// Parameters:
//   sg - group object; used to construct the result matrix.
//   mA, mB, mC - operands; only their spvm pointers are read.
//
// The builtin is called without a scope argument, so its Sc parameter takes
// its declared default.
//
// Host compilation (no __SYCL_DEVICE_ONLY__): always throws runtime_error
// with PI_INVALID_DEVICE.
template <typename Group, typename T1, typename T2, size_t M, size_t K,
164+
size_t N, matrix_layout LayoutA, matrix_layout LayoutB,
165+
matrix_layout LayoutC>
166+
inline __SYCL_ALWAYS_INLINE joint_matrix<T2, M, N, LayoutC, Group>
167+
joint_matrix_mad(Group sg, joint_matrix<T1, M, K, LayoutA, Group> &mA,
168+
joint_matrix<T1, K, N, LayoutB, Group> &mB,
169+
joint_matrix<T2, M, N, LayoutC, Group> &mC) {
170+
#ifdef __SYCL_DEVICE_ONLY__
171+
joint_matrix<T2, M, N, LayoutC, Group> res(sg);
172+
res.spvm = __spirv_JointMatrixMadINTEL(mA.spvm, mB.spvm, mC.spvm);
173+
return res;
174+
#else
175+
(void)sg; // the (void) casts presumably silence unused-parameter warnings
176+
(void)mA;
177+
(void)mB;
178+
(void)mC;
179+
throw runtime_error("joint matrix is not supported on host device.",
180+
PI_INVALID_DEVICE);
181+
#endif // __SYCL_DEVICE_ONLY__
182+
}
183+
} // namespace experimental::matrix
184+
} // namespace oneapi
185+
} // namespace ext
186+
} // namespace sycl
187+
} // __SYCL_INLINE_NAMESPACE(cl)

sycl/include/sycl/ext/oneapi/matrix/matrix.hpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@
1818

1919
#if (SYCL_EXT_ONEAPI_MATRIX == 1)
2020
#if defined(__AMXTILE__) && defined(__AMXINT8__) && defined(__AMXBF16__)
21-
#include <sycl/ext/oneapi/matrix/matrix-amx.hpp>
21+
#include <sycl/ext/oneapi/matrix/matrix-aot-amx.hpp>
2222
#endif
2323
#endif
24+
#if (SYCL_EXT_ONEAPI_MATRIX == 2)
25+
#include <sycl/ext/oneapi/matrix/matrix-jit.hpp>
26+
#endif

sycl/test/matrix/matrix-amx-bf16-test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: %clangxx -march=sapphirerapids -fsycl -O2 %s -o %t.out
1+
// RUN: %clangxx -DSYCL_EXT_ONEAPI_MATRIX=1 -march=sapphirerapids -fsycl -O2 %s -o %t.out
22
#include <CL/sycl.hpp>
33
#if (SYCL_EXT_ONEAPI_MATRIX == 1)
44
#include <iostream>

sycl/test/matrix/matrix-amx-int8-test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: %clangxx -march=sapphirerapids -fsycl -O2 %s -o %t.out
1+
// RUN: %clangxx -DSYCL_EXT_ONEAPI_MATRIX=1 -march=sapphirerapids -fsycl -O2 %s -o %t.out
22
#include <CL/sycl.hpp>
33
#if (SYCL_EXT_ONEAPI_MATRIX == 1)
44
#include <iostream>

0 commit comments

Comments
 (0)