Skip to content

Commit f3137e9

Browse files
authored
[SYCL][joint matrix] add implementation for prefetch and overloads of load/store on annotated pointers (#12066)
Spec added in #11473
1 parent c7f6229 commit f3137e9

11 files changed

+544
-58
lines changed

sycl/include/CL/__spirv/spirv_ops.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#ifdef __SYCL_DEVICE_ONLY__
2626

2727
extern __DPCPP_SYCL_EXTERNAL float __spirv_RoundFToTF32INTEL(float a);
28+
2829
template <typename T, typename Tp, std::size_t R, std::size_t C,
2930
__spv::MatrixUse U,
3031
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
@@ -139,6 +140,11 @@ extern __DPCPP_SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *
139140
__spirv_VectorInsertDynamic(__spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *,
140141
Ts val, size_t i);
141142

143+
template <typename T, std::size_t NumRows, std::size_t NumCols>
144+
extern __DPCPP_SYCL_EXTERNAL void __spirv_JointMatrixPrefetchINTEL(
145+
T *Ptr, std::size_t coordX, std::size_t coordY, unsigned int CacheLevel,
146+
__spv::MatrixLayout Layout, std::size_t Stride);
147+
142148
#ifndef __SPIRV_BUILTIN_DECLARATIONS__
143149
#error \
144150
"SPIR-V built-ins are not available. Please set -fdeclare-spirv-builtins flag."

sycl/include/sycl/ext/oneapi/matrix/matrix-intel.hpp

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -521,6 +521,51 @@ joint_matrix_store(Group,
521521
#endif // defined(__SYCL_DEVICE_ONLY__)
522522
}
523523

524+
template <
525+
typename Group, typename T, typename Tp,
526+
sycl::ext::oneapi::experimental::matrix::use Use, size_t NumRows,
527+
size_t NumCols, sycl::ext::oneapi::experimental::matrix::layout Layout,
528+
typename PropertyListT,
529+
std::enable_if_t<Use == sycl::ext::oneapi::experimental::matrix::use::a ||
530+
Use == sycl::ext::oneapi::experimental::matrix::use::b,
531+
bool> = true>
532+
inline __SYCL_ALWAYS_INLINE void joint_matrix_store(
533+
Group,
534+
const sycl::ext::oneapi::experimental::matrix::joint_matrix<
535+
Group, Tp, Use, NumRows, NumCols, Layout> &src,
536+
ext::oneapi::experimental::annotated_ptr<T, PropertyListT> dst,
537+
size_t stride) {
538+
#if defined(__SYCL_DEVICE_ONLY__)
539+
#if defined(__NVPTX__)
540+
std::ignore = src;
541+
std::ignore = dst;
542+
std::ignore = stride;
543+
throw runtime_error(
544+
"This version of the matrix extension is only currently supported on "
545+
"intel devices",
546+
PI_ERROR_INVALID_DEVICE);
547+
#else
548+
// intel's impl
549+
T *Ptr = dst.get();
550+
__spirv_JointMatrixStoreINTEL<T, Tp, NumRows, NumCols,
551+
sycl::ext::oneapi::experimental::matrix::
552+
spv_matrix_use_traits<Use>::value,
553+
sycl::ext::oneapi::experimental::matrix::
554+
spv_matrix_layout_traits<Layout>::value>(
555+
Ptr, src.spvm, stride,
556+
sycl::ext::oneapi::experimental::matrix::spv_matrix_layout_traits<
557+
Layout>::value,
558+
sycl::ext::oneapi::experimental::matrix::spv_scope_traits<Group>::value);
559+
#endif // defined(__NVPTX__)
560+
#else
561+
std::ignore = src;
562+
std::ignore = dst;
563+
std::ignore = stride;
564+
throw runtime_error("joint matrix is not supported on host device.",
565+
PI_ERROR_INVALID_DEVICE);
566+
#endif // defined(__SYCL_DEVICE_ONLY__)
567+
}
568+
524569
template <typename Group, typename T,
525570
sycl::ext::oneapi::experimental::matrix::use Use, size_t Rows,
526571
size_t Cols, sycl::ext::oneapi::experimental::matrix::layout Layout,

sycl/include/sycl/ext/oneapi/matrix/matrix-unified-utils.hpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,21 @@ convertMatrixUseStringToEnum(const char *UseString) {
6161
}
6262
return std::nullopt;
6363
}
64+
65+
inline __SYCL_ALWAYS_INLINE __spv::MatrixLayout joint_matrix_layout_to_spv(
66+
sycl::ext::oneapi::experimental::matrix::layout Layout) {
67+
switch (Layout) {
68+
case sycl::ext::oneapi::experimental::matrix::layout::row_major:
69+
return __spv::MatrixLayout::RowMajor;
70+
case sycl::ext::oneapi::experimental::matrix::layout::col_major:
71+
return __spv::MatrixLayout::ColumnMajor;
72+
case sycl::ext::oneapi::experimental::matrix::layout::ext_intel_packed:
73+
return __spv::MatrixLayout::Packed;
74+
case sycl::ext::oneapi::experimental::matrix::layout::dynamic:
75+
return __spv::MatrixLayout::Dynamic;
76+
}
77+
}
78+
6479
} // namespace detail
6580
} // namespace _V1
6681
} // namespace sycl

sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp

Lines changed: 168 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@
2424
#include <sycl/exception.hpp> // for runtime_error
2525
#include <sycl/ext/oneapi/matrix/matrix-unified-utils.hpp> // for layout, use, tf32, convertMatrixUseEnumToString
2626
#include <sycl/ext/oneapi/matrix/query-types.hpp> // for convertTypeToMatrixTypeString
27-
#include <sycl/marray.hpp> // for marray
28-
#include <sycl/multi_ptr.hpp> // for multi_ptr
27+
#include <sycl/marray.hpp> // for marray
28+
#include <sycl/multi_ptr.hpp> // for multi_ptr
2929

3030
#include <cstring> // for size_t, memcpy
3131
#include <stdint.h> // for uint32_t
@@ -165,34 +165,12 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_load(
165165
std::ignore = sg;
166166
using DecorT = typename sycl::detail::DecoratedType<T, Space>::type;
167167
DecorT *Ptr = sycl::detail::getDecorated<DecorT>(src);
168-
switch (Layout) {
169-
default:
170-
assert(false && "Invalid Memory Layout!");
171-
case layout::row_major:
172-
res.spvm = __spirv_JointMatrixLoadINTEL<
173-
DecorT, S, NumRows, NumCols,
174-
spv_matrix_use_traits<use::accumulator>::value,
175-
spv_matrix_layout_traits<layout::dynamic>::value>(
176-
Ptr, stride, __spv::MatrixLayout::RowMajor,
177-
spv_scope_traits<Group>::value);
178-
break;
179-
case layout::col_major:
180-
res.spvm = __spirv_JointMatrixLoadINTEL<
181-
DecorT, S, NumRows, NumCols,
182-
spv_matrix_use_traits<use::accumulator>::value,
183-
spv_matrix_layout_traits<layout::dynamic>::value>(
184-
Ptr, stride, __spv::MatrixLayout::ColumnMajor,
185-
spv_scope_traits<Group>::value);
186-
break;
187-
case layout::ext_intel_packed:
188-
res.spvm = __spirv_JointMatrixLoadINTEL<
189-
DecorT, S, NumRows, NumCols,
190-
spv_matrix_use_traits<use::accumulator>::value,
191-
spv_matrix_layout_traits<layout::dynamic>::value>(
192-
Ptr, stride, __spv::MatrixLayout::Packed,
193-
spv_scope_traits<Group>::value);
194-
break;
195-
}
168+
res.spvm = __spirv_JointMatrixLoadINTEL<
169+
DecorT, S, NumRows, NumCols,
170+
spv_matrix_use_traits<use::accumulator>::value,
171+
spv_matrix_layout_traits<layout::dynamic>::value>(
172+
Ptr, stride, sycl::detail::joint_matrix_layout_to_spv(Layout),
173+
spv_scope_traits<Group>::value);
196174
#endif // defined(__NVPTX__)
197175
#else
198176
std::ignore = sg;
@@ -250,6 +228,83 @@ joint_matrix_load(Group sg,
250228
#endif // defined(__SYCL_DEVICE_ONLY__)
251229
}
252230

231+
template <typename Group, typename S, typename T, size_t NumRows,
232+
size_t NumCols, typename PropertyListT,
233+
std::enable_if_t<std::is_same<S, std::remove_const_t<T>>::value,
234+
bool> = true>
235+
inline __SYCL_ALWAYS_INLINE void joint_matrix_load(
236+
Group sg,
237+
joint_matrix<Group, S, use::accumulator, NumRows, NumCols,
238+
sycl::ext::oneapi::experimental::matrix::layout::dynamic> &res,
239+
ext::oneapi::experimental::annotated_ptr<T, PropertyListT> src,
240+
size_t stride, sycl::ext::oneapi::experimental::matrix::layout Layout) {
241+
#if defined(__SYCL_DEVICE_ONLY__)
242+
#if defined(__NVPTX__)
243+
std::ignore = sg;
244+
throw runtime_error("Use joint_matrix_load on multi_ptr on Nvidia device.",
245+
PI_ERROR_INVALID_DEVICE);
246+
#elif defined(__HIP_PLATFORM_AMD_MFMA__)
247+
throw runtime_error("Use joint_matrix_load on multi_ptr on AMD device.",
248+
PI_ERROR_INVALID_DEVICE);
249+
#else
250+
std::ignore = sg;
251+
T *Ptr = src.get();
252+
res.spvm = __spirv_JointMatrixLoadINTEL<
253+
T, S, NumRows, NumCols, spv_matrix_use_traits<use::accumulator>::value,
254+
spv_matrix_layout_traits<layout::dynamic>::value>(
255+
Ptr, stride, sycl::detail::joint_matrix_layout_to_spv(Layout),
256+
spv_scope_traits<Group>::value);
257+
#endif // defined(__NVPTX__)
258+
#else
259+
std::ignore = sg;
260+
std::ignore = res;
261+
std::ignore = src;
262+
std::ignore = stride;
263+
std::ignore = Layout;
264+
throw runtime_error("joint matrix is not supported on host device.",
265+
PI_ERROR_INVALID_DEVICE);
266+
#endif // defined(__SYCL_DEVICE_ONLY__)
267+
}
268+
269+
template <
270+
typename Group, typename S, typename T, use Use, size_t NumRows,
271+
size_t NumCols, matrix::layout Layout, typename PropertyListT,
272+
std::enable_if_t<std::is_same<S, std::remove_const_t<T>>::value ||
273+
(std::is_same<S, precision::tf32>::value &&
274+
std::is_same<std::remove_const_t<T>, float>::value),
275+
bool> = true>
276+
inline __SYCL_ALWAYS_INLINE void joint_matrix_load(
277+
Group sg, joint_matrix<Group, S, Use, NumRows, NumCols, Layout> &res,
278+
ext::oneapi::experimental::annotated_ptr<T, PropertyListT> src,
279+
size_t stride) {
280+
#if defined(__SYCL_DEVICE_ONLY__)
281+
#if defined(__NVPTX__)
282+
std::ignore = sg;
283+
throw runtime_error("Use joint_matrix_load on multi_ptr on Nvidia device.",
284+
PI_ERROR_INVALID_DEVICE);
285+
#elif defined(__HIP_PLATFORM_AMD_MFMA__)
286+
throw runtime_error("Use joint_matrix_load on multi_ptr on AMD device.",
287+
PI_ERROR_INVALID_DEVICE);
288+
#else
289+
std::ignore = sg;
290+
T *Ptr = src.get();
291+
res.spvm =
292+
__spirv_JointMatrixLoadINTEL<T, S, NumRows, NumCols,
293+
spv_matrix_use_traits<Use>::value,
294+
spv_matrix_layout_traits<Layout>::value>(
295+
Ptr, stride, spv_matrix_layout_traits<Layout>::value,
296+
spv_scope_traits<Group>::value);
297+
#endif // defined(__NVPTX__)
298+
#else
299+
std::ignore = sg;
300+
std::ignore = res;
301+
std::ignore = src;
302+
std::ignore = stride;
303+
throw runtime_error("joint matrix is not supported on host device.",
304+
PI_ERROR_INVALID_DEVICE);
305+
#endif // defined(__SYCL_DEVICE_ONLY__)
306+
}
307+
253308
template <typename Group, typename T, size_t NumRows, size_t NumCols,
254309
access::address_space Space, access::decorated IsDecorated>
255310
inline __SYCL_ALWAYS_INLINE void joint_matrix_store(
@@ -275,34 +330,49 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_store(
275330
std::ignore = sg;
276331
using DecorT = typename sycl::detail::DecoratedType<T, Space>::type;
277332
DecorT *Ptr = sycl::detail::getDecorated<DecorT>(dst);
278-
switch (Layout) {
279-
default:
280-
assert(false && "Invalid Memory Layout!");
281-
case layout::row_major:
282-
__spirv_JointMatrixStoreINTEL<
283-
DecorT, T, NumRows, NumCols,
284-
spv_matrix_use_traits<use::accumulator>::value,
285-
spv_matrix_layout_traits<layout::dynamic>::value>(
286-
Ptr, src.spvm, stride, __spv::MatrixLayout::RowMajor,
287-
spv_scope_traits<Group>::value);
288-
break;
289-
case layout::col_major:
290-
__spirv_JointMatrixStoreINTEL<
291-
DecorT, T, NumRows, NumCols,
292-
spv_matrix_use_traits<use::accumulator>::value,
293-
spv_matrix_layout_traits<layout::dynamic>::value>(
294-
Ptr, src.spvm, stride, __spv::MatrixLayout::ColumnMajor,
295-
spv_scope_traits<Group>::value);
296-
break;
297-
case layout::ext_intel_packed:
298-
__spirv_JointMatrixStoreINTEL<
299-
DecorT, T, NumRows, NumCols,
300-
spv_matrix_use_traits<use::accumulator>::value,
301-
spv_matrix_layout_traits<layout::dynamic>::value>(
302-
Ptr, src.spvm, stride, __spv::MatrixLayout::Packed,
303-
spv_scope_traits<Group>::value);
304-
break;
305-
}
333+
__spirv_JointMatrixStoreINTEL<
334+
DecorT, T, NumRows, NumCols,
335+
spv_matrix_use_traits<use::accumulator>::value,
336+
spv_matrix_layout_traits<layout::dynamic>::value>(
337+
Ptr, src.spvm, stride, sycl::detail::joint_matrix_layout_to_spv(Layout),
338+
spv_scope_traits<Group>::value);
339+
#endif // defined(__NVPTX__)
340+
#else
341+
std::ignore = sg;
342+
std::ignore = src;
343+
std::ignore = dst;
344+
std::ignore = stride;
345+
std::ignore = Layout;
346+
throw runtime_error("joint matrix is not supported on host device.",
347+
PI_ERROR_INVALID_DEVICE);
348+
#endif // defined(__SYCL_DEVICE_ONLY__)
349+
}
350+
351+
template <typename Group, typename T, size_t NumRows, size_t NumCols,
352+
typename PropertyListT>
353+
inline __SYCL_ALWAYS_INLINE void joint_matrix_store(
354+
Group sg,
355+
const joint_matrix<Group, T, use::accumulator, NumRows, NumCols,
356+
sycl::ext::oneapi::experimental::matrix::layout::dynamic>
357+
&src,
358+
ext::oneapi::experimental::annotated_ptr<T, PropertyListT> dst,
359+
size_t stride, sycl::ext::oneapi::experimental::matrix::layout Layout) {
360+
#if defined(__SYCL_DEVICE_ONLY__)
361+
#if defined(__NVPTX__)
362+
std::ignore = sg;
363+
throw runtime_error("Use joint_matrix_store on multi_ptr on Nvidia device.",
364+
PI_ERROR_INVALID_DEVICE);
365+
#elif defined(__HIP_PLATFORM_AMD_MFMA__)
366+
throw runtime_error("Use joint_matrix_store on multi_ptr on AMD device.",
367+
PI_ERROR_INVALID_DEVICE);
368+
#else
369+
std::ignore = sg;
370+
T *Ptr = dst.get();
371+
__spirv_JointMatrixStoreINTEL<
372+
T, T, NumRows, NumCols, spv_matrix_use_traits<use::accumulator>::value,
373+
spv_matrix_layout_traits<layout::dynamic>::value>(
374+
Ptr, src.spvm, stride, sycl::detail::joint_matrix_layout_to_spv(Layout),
375+
spv_scope_traits<Group>::value);
306376
#endif // defined(__NVPTX__)
307377
#else
308378
std::ignore = sg;
@@ -429,6 +499,46 @@ inline __SYCL_ALWAYS_INLINE float round_to_tf32(const float &a) {
429499
return ret;
430500
#endif // defined(__SYCL_DEVICE_ONLY__)
431501
}
502+
503+
template <size_t NumRows, size_t NumCols, typename Group, typename T,
504+
typename Properties = ext::oneapi::experimental::empty_properties_t>
505+
inline __SYCL_ALWAYS_INLINE void
506+
joint_matrix_prefetch(Group sg, T *Ptr, size_t stride,
507+
sycl::ext::oneapi::experimental::matrix::layout Layout,
508+
Properties properties = {}) {
509+
#if defined(__SYCL_DEVICE_ONLY__)
510+
#if defined(__NVPTX__)
511+
std::ignore = sg;
512+
std::ignore = properties;
513+
throw runtime_error(
514+
"joint_matrix_prefetch is not supported on Nvidia device.",
515+
PI_ERROR_INVALID_DEVICE);
516+
#elif defined(__HIP_PLATFORM_AMD_MFMA__)
517+
std::ignore = sg;
518+
std::ignore = properties;
519+
throw runtime_error("joint_matrix_prefetch is not supported on AMD device.",
520+
PI_ERROR_INVALID_DEVICE);
521+
#else
522+
std::ignore = sg;
523+
auto prop = properties.template get_property<prefetch_hint_key>();
524+
// Will be removed once SPIRV implementation also uses offsetpointer
525+
size_t coordX = 0;
526+
size_t coordY = 0;
527+
__spirv_JointMatrixPrefetchINTEL<T, NumRows, NumCols>(
528+
Ptr, coordX, coordY, detail::PropertyMetaInfo<decltype(prop)>::value,
529+
sycl::detail::joint_matrix_layout_to_spv(Layout), stride);
530+
#endif // defined(__NVPTX__)
531+
#else
532+
std::ignore = sg;
533+
std::ignore = Ptr;
534+
std::ignore = stride;
535+
std::ignore = Layout;
536+
std::ignore = properties;
537+
throw runtime_error("joint matrix is not supported on host device.",
538+
PI_ERROR_INVALID_DEVICE);
539+
#endif // defined(__SYCL_DEVICE_ONLY__)
540+
}
541+
432542
} // namespace matrix
433543
} // namespace experimental
434544
} // namespace oneapi
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
//==-------- joint_matrix_annotated_ptr.cpp - DPC++ joint_matrix-----------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
// REQUIRES: matrix
9+
10+
// RUN: %{build} -o %t.out
11+
// RUN: %{run} %t.out
12+
13+
// Currently row major B fails when annotated_ptr is used
14+
// XFAIL: gpu
15+
16+
#include "../common.hpp"
17+
18+
#define SG_SZ 32
19+
constexpr size_t TN = 16;
20+
21+
#include "../joint_matrix_annotated_ptr_impl.hpp"
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
//==-------- joint_matrix_prefetch.cpp - DPC++ joint_matrix----------------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
// RUN: %{build} -o %t.out
9+
// RUN: %{run} %t.out
10+
11+
// XFAIL:*
12+
13+
#include "../common.hpp"
14+
15+
#define SG_SZ 32
16+
constexpr size_t TN = 16;
17+
#include "../joint_matrix_prefetch_impl.hpp"

0 commit comments

Comments
 (0)