Skip to content

Commit 854ab7e

Browse files
[SYCL][Joint Matrix] Pass on address space to Load/Store (#9244)
Pass on address space information to SPIR-V Joint Matrix Load/Store intrinsics. --------- Signed-off-by: Sidorov, Dmitry <[email protected]> Co-authored-by: Sidorov, Dmitry <[email protected]>
1 parent ef0d151 commit 854ab7e

File tree

7 files changed

+206
-26
lines changed

7 files changed

+206
-26
lines changed

sycl/include/CL/__spirv/spirv_ops.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -140,19 +140,19 @@ extern __DPCPP_SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *
140140
__spirv_VectorInsertDynamic(__spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *,
141141
Ts val, size_t i);
142142
#else
143-
template <typename T, std::size_t R, std::size_t C,
143+
template <typename T, typename Tp, std::size_t R, std::size_t C,
144144
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
145145
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
146-
extern __DPCPP_SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<T, R, C, L, S> *
146+
extern __DPCPP_SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<Tp, R, C, L, S> *
147147
__spirv_JointMatrixLoadINTEL(T *Ptr, std::size_t Stride,
148148
__spv::MatrixLayout Layout = L,
149149
__spv::Scope::Flag Sc = S, int MemOperand = 0);
150150

151-
template <typename T, std::size_t R, std::size_t C,
151+
template <typename T, typename Tp, std::size_t R, std::size_t C,
152152
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
153153
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
154154
extern __DPCPP_SYCL_EXTERNAL void __spirv_JointMatrixStoreINTEL(
155-
T *Ptr, __spv::__spirv_JointMatrixINTEL<T, R, C, L, S> *Object,
155+
T *Ptr, __spv::__spirv_JointMatrixINTEL<Tp, R, C, L, S> *Object,
156156
std::size_t Stride, __spv::MatrixLayout Layout = L,
157157
__spv::Scope::Flag Sc = S, int MemOperand = 0);
158158

sycl/include/sycl/ext/oneapi/matrix/matrix-intel.hpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#pragma once
1010

1111
#include "matrix-unified-utils.hpp"
12+
#include "utils.hpp"
1213
#include <CL/__spirv/spirv_ops.hpp>
1314
#include <sycl/detail/defines_elementary.hpp>
1415
#include <sycl/feature_test.hpp>
@@ -481,6 +482,8 @@ joint_matrix_store(Group sg,
481482
Group, Tp, Use, NumRows, NumCols, Layout> &src,
482483
multi_ptr<T, Space, IsDecorated> dst, size_t stride) {
483484
#if defined(__SYCL_DEVICE_ONLY__)
485+
static_assert(Space != access::address_space::private_space,
486+
"Joint Matrix doesn't support store to private memory!");
484487
#if defined(__NVPTX__)
485488
std::ignore = sg;
486489
std::ignore = src;
@@ -492,8 +495,9 @@ joint_matrix_store(Group sg,
492495
PI_ERROR_INVALID_DEVICE);
493496
#else
494497
// intel's impl
495-
T *Ptr = dst.get();
496-
__spirv_JointMatrixStoreINTEL<T, Tp, NumRows, NumCols,
498+
using DecorT = typename sycl::detail::DecoratedType<T, Space>::type;
499+
DecorT *Ptr = sycl::detail::getDecorated<DecorT>(dst);
500+
__spirv_JointMatrixStoreINTEL<DecorT, Tp, NumRows, NumCols,
497501
sycl::ext::oneapi::experimental::matrix::
498502
spv_matrix_use_traits<Use>::value,
499503
sycl::ext::oneapi::experimental::matrix::

sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#pragma once
1010

11+
#include "utils.hpp"
1112
#include <CL/__spirv/spirv_ops.hpp>
1213
#include <sycl/detail/defines_elementary.hpp>
1314
#include <sycl/ext/oneapi/bfloat16.hpp>
@@ -77,34 +78,37 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_load(
7778
Group sg, joint_matrix<T, NumRows, NumCols, Layout, Group> &res,
7879
multi_ptr<T, Space, IsDecorated> src, size_t stride, matrix_layout MemL) {
7980
#ifdef __SYCL_DEVICE_ONLY__
80-
T *Ptr = src.get();
81+
static_assert(Space != access::address_space::private_space,
82+
"Joint Matrix doesn't support load from private memory!");
83+
using DecorT = typename sycl::detail::DecoratedType<T, Space>::type;
84+
DecorT *Ptr = sycl::detail::getDecorated<DecorT>(src);
8185
switch (MemL) {
8286
default:
8387
assert(false && "Invalid Memory Layout!");
8488
case matrix_layout::row_major:
8589
res.spvm =
86-
__spirv_JointMatrixLoadINTEL<T, NumRows, NumCols,
90+
__spirv_JointMatrixLoadINTEL<DecorT, T, NumRows, NumCols,
8791
spv_matrix_layout_traits<Layout>::value>(
8892
Ptr, stride, __spv::MatrixLayout::RowMajor,
8993
spv_scope_traits<Group>::value);
9094
break;
9195
case matrix_layout::col_major:
9296
res.spvm =
93-
__spirv_JointMatrixLoadINTEL<T, NumRows, NumCols,
97+
__spirv_JointMatrixLoadINTEL<DecorT, T, NumRows, NumCols,
9498
spv_matrix_layout_traits<Layout>::value>(
9599
Ptr, stride, __spv::MatrixLayout::ColumnMajor,
96100
spv_scope_traits<Group>::value);
97101
break;
98102
case matrix_layout::packed_a:
99103
res.spvm =
100-
__spirv_JointMatrixLoadINTEL<T, NumRows, NumCols,
104+
__spirv_JointMatrixLoadINTEL<DecorT, T, NumRows, NumCols,
101105
spv_matrix_layout_traits<Layout>::value>(
102106
Ptr, stride, __spv::MatrixLayout::PackedA,
103107
spv_scope_traits<Group>::value);
104108
break;
105109
case matrix_layout::packed_b:
106110
res.spvm =
107-
__spirv_JointMatrixLoadINTEL<T, NumRows, NumCols,
111+
__spirv_JointMatrixLoadINTEL<DecorT, T, NumRows, NumCols,
108112
spv_matrix_layout_traits<Layout>::value>(
109113
Ptr, stride, __spv::MatrixLayout::PackedB,
110114
spv_scope_traits<Group>::value);
@@ -128,30 +132,33 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_store(
128132
Group sg, joint_matrix<T, NumRows, NumCols, MatL, Group> &src,
129133
multi_ptr<T, Space, IsDecorated> res, size_t stride, matrix_layout MemL) {
130134
#ifdef __SYCL_DEVICE_ONLY__
131-
T *Ptr = res.get();
135+
static_assert(Space != access::address_space::private_space,
136+
"Joint Matrix doesn't support store to private memory!");
137+
using DecorT = typename sycl::detail::DecoratedType<T, Space>::type;
138+
DecorT *Ptr = sycl::detail::getDecorated<DecorT>(res);
132139
switch (MemL) {
133140
default:
134141
assert(false && "Invalid Memory Layout!");
135142
case matrix_layout::row_major:
136-
__spirv_JointMatrixStoreINTEL<T, NumRows, NumCols,
143+
__spirv_JointMatrixStoreINTEL<DecorT, T, NumRows, NumCols,
137144
spv_matrix_layout_traits<MatL>::value>(
138145
Ptr, src.spvm, stride, __spv::MatrixLayout::RowMajor,
139146
spv_scope_traits<Group>::value);
140147
break;
141148
case matrix_layout::col_major:
142-
__spirv_JointMatrixStoreINTEL<T, NumRows, NumCols,
149+
__spirv_JointMatrixStoreINTEL<DecorT, T, NumRows, NumCols,
143150
spv_matrix_layout_traits<MatL>::value>(
144151
Ptr, src.spvm, stride, __spv::MatrixLayout::ColumnMajor,
145152
spv_scope_traits<Group>::value);
146153
break;
147154
case matrix_layout::packed_a:
148-
__spirv_JointMatrixStoreINTEL<T, NumRows, NumCols,
155+
__spirv_JointMatrixStoreINTEL<DecorT, T, NumRows, NumCols,
149156
spv_matrix_layout_traits<MatL>::value>(
150157
Ptr, src.spvm, stride, __spv::MatrixLayout::PackedA,
151158
spv_scope_traits<Group>::value);
152159
break;
153160
case matrix_layout::packed_b:
154-
__spirv_JointMatrixStoreINTEL<T, NumRows, NumCols,
161+
__spirv_JointMatrixStoreINTEL<DecorT, T, NumRows, NumCols,
155162
spv_matrix_layout_traits<MatL>::value>(
156163
Ptr, src.spvm, stride, __spv::MatrixLayout::PackedB,
157164
spv_scope_traits<Group>::value);

sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#pragma once
1010
#include "matrix-intel.hpp"
11+
#include "utils.hpp"
1112
#include <sycl/ext/oneapi/matrix/matrix-tensorcores.hpp>
1213
namespace sycl {
1314
__SYCL_INLINE_VER_NAMESPACE(_V1) {
@@ -199,32 +200,38 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_load(
199200
multi_ptr<T, Space, IsDecorated> src, size_t stride,
200201
sycl::ext::oneapi::experimental::matrix::layout Layout) {
201202
#if defined(__SYCL_DEVICE_ONLY__)
203+
static_assert(Space != access::address_space::private_space,
204+
"Joint Matrix doesn't support load from private memory!");
202205
#if defined(__NVPTX__)
203206
std::ignore = sg;
204207
sycl::ext::oneapi::detail::load_accumulator_cuda(res.cuda_impl, src, stride,
205208
Layout);
206209
#else
207-
T *Ptr = src.get();
210+
using DecorT = typename sycl::detail::DecoratedType<T, Space>::type;
211+
DecorT *Ptr = sycl::detail::getDecorated<DecorT>(src);
208212
switch (Layout) {
209213
default:
210214
assert(false && "Invalid Memory Layout!");
211215
case layout::row_major:
212216
res.spvm = __spirv_JointMatrixLoadINTEL<
213-
T, S, NumRows, NumCols, spv_matrix_use_traits<use::accumulator>::value,
217+
DecorT, S, NumRows, NumCols,
218+
spv_matrix_use_traits<use::accumulator>::value,
214219
spv_matrix_layout_traits<layout::dynamic>::value>(
215220
Ptr, stride, __spv::MatrixLayout::RowMajor,
216221
spv_scope_traits<Group>::value);
217222
break;
218223
case layout::col_major:
219224
res.spvm = __spirv_JointMatrixLoadINTEL<
220-
T, S, NumRows, NumCols, spv_matrix_use_traits<use::accumulator>::value,
225+
DecorT, S, NumRows, NumCols,
226+
spv_matrix_use_traits<use::accumulator>::value,
221227
spv_matrix_layout_traits<layout::dynamic>::value>(
222228
Ptr, stride, __spv::MatrixLayout::ColumnMajor,
223229
spv_scope_traits<Group>::value);
224230
break;
225231
case sycl::ext::intel::experimental::matrix::layout::packed:
226232
res.spvm = __spirv_JointMatrixLoadINTEL<
227-
T, S, NumRows, NumCols, spv_matrix_use_traits<use::accumulator>::value,
233+
DecorT, S, NumRows, NumCols,
234+
spv_matrix_use_traits<use::accumulator>::value,
228235
spv_matrix_layout_traits<layout::dynamic>::value>(
229236
Ptr, stride, __spv::MatrixLayout::Packed,
230237
spv_scope_traits<Group>::value);
@@ -255,15 +262,18 @@ joint_matrix_load(Group sg,
255262
joint_matrix<Group, S, Use, NumRows, NumCols, Layout> &res,
256263
multi_ptr<T, Space, IsDecorated> src, size_t stride) {
257264
#if defined(__SYCL_DEVICE_ONLY__)
265+
static_assert(Space != access::address_space::private_space,
266+
"Joint Matrix doesn't support load from private memory!");
258267
#if defined(__NVPTX__)
259268
std::ignore = sg;
260269
sycl::ext::oneapi::detail::load_multiplicand_cuda<S, T, NumRows, NumCols, Use,
261270
Layout, Space>(
262271
res.cuda_impl, src, stride);
263272
#else
264-
T *Ptr = src.get();
273+
using DecorT = typename sycl::detail::DecoratedType<T, Space>::type;
274+
DecorT *Ptr = sycl::detail::getDecorated<DecorT>(src);
265275
res.spvm =
266-
__spirv_JointMatrixLoadINTEL<T, S, NumRows, NumCols,
276+
__spirv_JointMatrixLoadINTEL<DecorT, S, NumRows, NumCols,
267277
spv_matrix_use_traits<Use>::value,
268278
spv_matrix_layout_traits<Layout>::value>(
269279
Ptr, stride, spv_matrix_layout_traits<Layout>::value,
@@ -288,33 +298,39 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_store(
288298
multi_ptr<T, Space, IsDecorated> dst, size_t stride,
289299
sycl::ext::oneapi::experimental::matrix::layout Layout) {
290300
#if defined(__SYCL_DEVICE_ONLY__)
301+
static_assert(Space != access::address_space::private_space,
302+
"Joint Matrix doesn't support store to private memory!");
291303
#if defined(__NVPTX__)
292304
std::ignore = sg;
293305
sycl::ext::oneapi::detail::joint_matrix_store_cuda<T, NumRows, NumCols,
294306
Space>(src.cuda_impl, dst,
295307
stride, Layout);
296308
#else
297-
T *Ptr = dst.get();
309+
using DecorT = typename sycl::detail::DecoratedType<T, Space>::type;
310+
DecorT *Ptr = sycl::detail::getDecorated<DecorT>(dst);
298311
switch (Layout) {
299312
default:
300313
assert(false && "Invalid Memory Layout!");
301314
case layout::row_major:
302315
__spirv_JointMatrixStoreINTEL<
303-
T, T, NumRows, NumCols, spv_matrix_use_traits<use::accumulator>::value,
316+
DecorT, T, NumRows, NumCols,
317+
spv_matrix_use_traits<use::accumulator>::value,
304318
spv_matrix_layout_traits<layout::dynamic>::value>(
305319
Ptr, src.spvm, stride, __spv::MatrixLayout::RowMajor,
306320
spv_scope_traits<Group>::value);
307321
break;
308322
case layout::col_major:
309323
__spirv_JointMatrixStoreINTEL<
310-
T, T, NumRows, NumCols, spv_matrix_use_traits<use::accumulator>::value,
324+
DecorT, T, NumRows, NumCols,
325+
spv_matrix_use_traits<use::accumulator>::value,
311326
spv_matrix_layout_traits<layout::dynamic>::value>(
312327
Ptr, src.spvm, stride, __spv::MatrixLayout::ColumnMajor,
313328
spv_scope_traits<Group>::value);
314329
break;
315330
case sycl::ext::intel::experimental::matrix::layout::packed:
316331
__spirv_JointMatrixStoreINTEL<
317-
T, T, NumRows, NumCols, spv_matrix_use_traits<use::accumulator>::value,
332+
DecorT, T, NumRows, NumCols,
333+
spv_matrix_use_traits<use::accumulator>::value,
318334
spv_matrix_layout_traits<layout::dynamic>::value>(
319335
Ptr, src.spvm, stride, __spv::MatrixLayout::Packed,
320336
spv_scope_traits<Group>::value);
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
//===------- utils.hpp - SYCL matrix extension ----*- C++ -*------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
// ===--------------------------------------------------------------------=== //
8+
9+
#pragma once
10+
11+
#include <sycl/access/access.hpp>
12+
#include <sycl/multi_ptr.hpp>
13+
14+
namespace sycl {
15+
__SYCL_INLINE_VER_NAMESPACE(_V1) {
16+
namespace detail {
17+
18+
// Helper to return decorated pointer for different values
19+
// of access::decorated parameter.
20+
// If access::decorated::legacy is removed in the future
21+
// this helper usage can be replaced with ptr.get_decorated().
22+
template <typename DecorT, typename T, access::address_space Space,
23+
access::decorated IsDecorated>
24+
DecorT *getDecorated(multi_ptr<T, Space, IsDecorated> ptr) {
25+
if constexpr (IsDecorated == access::decorated::legacy)
26+
return ptr.get();
27+
else
28+
return ptr.get_decorated();
29+
}
30+
31+
} // namespace detail
32+
} // __SYCL_INLINE_VER_NAMESPACE(_V1)
33+
} // namespace sycl
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
// RUN: %clangxx -fsycl-device-only -S -emit-llvm -o - %s | FileCheck %s
2+
3+
// check that correct address spaces are used to load from and store to
4+
#define SYCL_EXT_ONEAPI_MATRIX_VERSION 4
5+
#include <sycl/sycl.hpp>
6+
7+
using namespace sycl;
8+
using namespace sycl::ext::oneapi::experimental::matrix;
9+
10+
int main(void) {
11+
queue q;
12+
unsigned short *A = malloc_shared<unsigned short>(8 * 16, q);
13+
unsigned short *B = malloc_shared<unsigned short>(16 * 16, q);
14+
float *C = malloc_shared<float>(8 * 16, q);
15+
16+
auto pA = multi_ptr<unsigned short, access::address_space::global_space>(A);
17+
auto pB = multi_ptr<unsigned short, access::address_space::global_space>(B);
18+
auto pC = multi_ptr<float, access::address_space::global_space>(C);
19+
20+
q.submit([&](handler &h) {
21+
local_accessor<unsigned short, 2> tileA{{8, 16}, h};
22+
23+
h.parallel_for(
24+
nd_range<2>({1, 16}, {1, 16}),
25+
[=](nd_item<2> it) [[intel::reqd_sub_group_size(16)]] {
26+
joint_matrix<sub_group, unsigned short, use::a, 8, 16,
27+
layout::row_major>
28+
tA;
29+
joint_matrix<sub_group, unsigned short, use::b, 16, 16,
30+
ext::intel::experimental::matrix::layout::packed>
31+
tB;
32+
joint_matrix<sub_group, float, use::accumulator, 8, 16> tC;
33+
34+
sub_group sg = it.get_sub_group();
35+
vec<unsigned short, 8> slmvec = sg.load<8>(pA);
36+
sg.store<8>(
37+
tileA.template get_multi_ptr<sycl::access::decorated::yes>(),
38+
slmvec);
39+
it.barrier(access::fence_space::local_space);
40+
41+
// A should load from local address space
42+
// CHECK: %{{.*}} = tail call spir_func noundef %spirv.JointMatrixINTEL._short_8_16_0_3_0 addrspace(4)* @_Z[[#]]__spirv_JointMatrixLoadINTEL{{.*}}(i16 addrspace(3)* noundef %{{.*}}, i64 noundef 16, i32 noundef 0, i32 noundef 3, i32 noundef 0) #{{.*}}
43+
joint_matrix_load(
44+
sg, tA,
45+
tileA.template get_multi_ptr<sycl::access::decorated::yes>(), 16);
46+
// B should load from global address space
47+
// CHECK: %{{.*}} = tail call spir_func noundef %spirv.JointMatrixINTEL._short_16_16_2_3_1 addrspace(4)* @_Z[[#]]__spirv_JointMatrixLoadINTEL{{.*}}(i16 addrspace(1)* noundef %{{.*}}, i64 noundef 32, i32 noundef 2, i32 noundef 3, i32 noundef 0) #{{.*}}
48+
joint_matrix_load(sg, tB, pB, 32);
49+
tC = joint_matrix_mad(sg, tA, tB, tC);
50+
// C should store to global address space
51+
// CHECK: tail call spir_func void @_Z[[#]]__spirv_JointMatrixStoreINTEL{{.*}}(float addrspace(1)* noundef %{{.*}}, %spirv.JointMatrixINTEL._float_8_16_3_3_2 addrspace(4)* noundef %{{.*}}, i64 noundef 16, i32 noundef 0, i32 noundef 3, i32 noundef 0) #{{.*}}
52+
joint_matrix_store(sg, tC, pC, 16, layout::row_major);
53+
});
54+
});
55+
56+
free(A, q);
57+
free(B, q);
58+
free(C, q);
59+
60+
return 0;
61+
}

0 commit comments

Comments
 (0)