Skip to content

Enable matrix_load, matrix_store, and matrix_mad #4076

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Aug 17, 2021
Merged
27 changes: 27 additions & 0 deletions sycl/include/CL/__spirv/spirv_ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,33 @@
#endif

#ifdef __SYCL_DEVICE_ONLY__
/// Declaration of the joint-matrix load builtin. It is only declared when
/// __SYCL_DEVICE_ONLY__ is defined (enclosing #ifdef); no definition appears
/// in this header.
///
/// \tparam T element type of the matrix and of the memory pointed to by Ptr.
/// \tparam R first dimension of the matrix (callers pass the row count).
/// \tparam C second dimension of the matrix (callers pass the column count).
/// \tparam L layout encoded in the returned matrix type; defaults to RowMajor.
/// \tparam S scope encoded in the returned matrix type; defaults to Subgroup.
///
/// \param Ptr        memory to load from.
/// \param Stride     stride value supplied by the caller.
/// \param Layout     layout argument; defaults to the template parameter L.
/// \param Sc         scope argument; defaults to the template parameter S.
/// \param MemOperand defaults to 0.
/// \return pointer to the (forward-declared) __spirv_JointMatrixINTEL type.
template <typename T, std::size_t R, std::size_t C,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<T, R, C, L, S> *
__spirv_JointMatrixLoadINTEL(T *Ptr, std::size_t Stride,
__spv::MatrixLayout Layout = L,
__spv::Scope::Flag Sc = S, int MemOperand = 0);

/// Declaration of the joint-matrix store builtin. It is only declared when
/// __SYCL_DEVICE_ONLY__ is defined (enclosing #ifdef); no definition appears
/// in this header.
///
/// \tparam T element type of the matrix and of the memory pointed to by Ptr.
/// \tparam R first dimension of the matrix (callers pass the row count).
/// \tparam C second dimension of the matrix (callers pass the column count).
/// \tparam L layout encoded in the matrix type; defaults to RowMajor.
/// \tparam S scope encoded in the matrix type; defaults to Subgroup.
///
/// \param Ptr        memory to store to.
/// \param Object     matrix handle to be stored.
/// \param Stride     stride value supplied by the caller.
/// \param Layout     layout argument; defaults to the template parameter L.
/// \param Sc         scope argument; defaults to the template parameter S.
/// \param MemOperand defaults to 0.
template <typename T, std::size_t R, std::size_t C,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern SYCL_EXTERNAL void __spirv_JointMatrixStoreINTEL(
T *Ptr, __spv::__spirv_JointMatrixINTEL<T, R, C, L, S> *Object,
std::size_t Stride, __spv::MatrixLayout Layout = L,
__spv::Scope::Flag Sc = S, int MemOperand = 0);

/// Declaration of the joint-matrix "mad" builtin (presumably multiply-add,
/// i.e. A * B + C -- confirm against the SPIR-V extension). It is only
/// declared when __SYCL_DEVICE_ONLY__ is defined (enclosing #ifdef); no
/// definition appears in this header.
///
/// \tparam T1 element type shared by A and B.
/// \tparam T2 element type shared by C and the result.
/// \tparam M,K,N matrix dimensions: A is MxK, B is KxN, C and the result MxN.
/// \tparam LA,LB,LC layouts of A, B and C; each defaults to RowMajor. The
///         result uses LC.
/// \tparam S scope shared by all operand types; defaults to Subgroup.
///
/// \param A  first operand handle.
/// \param B  second operand handle.
/// \param C  third operand handle.
/// \param Sc scope argument; defaults to Subgroup (not to S).
/// \return pointer to an MxN matrix handle with element type T2, layout LC.
template <typename T1, typename T2, std::size_t M, std::size_t K, std::size_t N,
__spv::MatrixLayout LA = __spv::MatrixLayout::RowMajor,
__spv::MatrixLayout LB = __spv::MatrixLayout::RowMajor,
__spv::MatrixLayout LC = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<T2, M, N, LC, S> *
__spirv_JointMatrixMadINTEL(
__spv::__spirv_JointMatrixINTEL<T1, M, K, LA, S> *A,
__spv::__spirv_JointMatrixINTEL<T1, K, N, LB, S> *B,
__spv::__spirv_JointMatrixINTEL<T2, M, N, LC, S> *C,
__spv::Scope::Flag Sc = __spv::Scope::Flag::Subgroup);

#ifndef __SPIRV_BUILTIN_DECLARATIONS__
#error \
Expand Down
7 changes: 7 additions & 0 deletions sycl/include/CL/__spirv/spirv_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#pragma once

#include <cstddef>
#include <cstdint>

// TODO: include the header file with SPIR-V declarations from SPIRV-Headers
Expand Down Expand Up @@ -105,6 +106,12 @@ enum class GroupOperation : uint32_t {
ExclusiveScan = 2
};

// Layouts accepted by the joint-matrix builtins. No underlying type or
// explicit values are given, so the enumerators are numbered 0..3 in
// declaration order.
enum class MatrixLayout { RowMajor, ColumnMajor, PackedA, PackedB };

// Joint-matrix type parameterized by element type T, dimensions R and C,
// layout U and scope S (Subgroup unless specified). It is only forward
// declared here, so code in this header can form pointers to it but cannot
// inspect its contents.
template <typename T, std::size_t R, std::size_t C, MatrixLayout U,
Scope::Flag S = Scope::Flag::Subgroup>
struct __spirv_JointMatrixINTEL;

} // namespace __spv

#ifdef __SYCL_DEVICE_ONLY__
Expand Down
17 changes: 0 additions & 17 deletions sycl/include/CL/sycl/ONEAPI/matrix/matrix-amx.hpp

This file was deleted.

17 changes: 0 additions & 17 deletions sycl/include/CL/sycl/ONEAPI/matrix/matrix.hpp

This file was deleted.

9 changes: 8 additions & 1 deletion sycl/include/CL/sycl/feature_test.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,14 @@ namespace sycl {

// TODO: Move these feature-test macros to compiler driver.
#define SYCL_EXT_INTEL_DEVICE_INFO 2
#define SYCL_EXT_ONEAPI_MATRIX 1
// Values of SYCL_EXT_ONEAPI_MATRIX:
//   1 - initial AOT implementation of the experimental matrix extension for
//       AMX
//   2 - JIT (target-agnostic) implementation of the experimental matrix
//       extension
// The macro is defined here only if it is not already defined, so a value
// supplied earlier (e.g. -DSYCL_EXT_ONEAPI_MATRIX=1 on the command line)
// takes precedence over the default of 2.
#ifndef SYCL_EXT_ONEAPI_MATRIX
#define SYCL_EXT_ONEAPI_MATRIX 2
#endif

} // namespace sycl
} // __SYCL_INLINE_NAMESPACE(cl)
187 changes: 187 additions & 0 deletions sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
//==---------------- matrix-jit.hpp - SYCL matrix --------------*- C++ -*---==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
// ===--------------------------------------------------------------------=== //

#pragma once

#include <CL/__spirv/spirv_ops.hpp>
#include <CL/sycl/detail/defines_elementary.hpp>
#include <CL/sycl/feature_test.hpp>

__SYCL_INLINE_NAMESPACE(cl) {
namespace sycl {
namespace ext {
namespace oneapi {
namespace experimental::matrix {

/// User-facing matrix layouts. Each enumerator is mapped to the
/// __spv::MatrixLayout enumerator of the corresponding name by
/// spv_matrix_layout_traits.
enum class matrix_layout { row_major, col_major, packed_a, packed_b };

/// Compile-time mapping from a SYCL-level matrix_layout to the matching
/// __spv::MatrixLayout, exposed as the static member `value`.
///
/// The primary template yields RowMajor; every matrix_layout enumerator also
/// has an explicit specialization generated below.
template <matrix_layout Layout> struct spv_matrix_layout_traits {
  static constexpr __spv::MatrixLayout value = __spv::MatrixLayout::RowMajor;
};

// Generates one explicit specialization of spv_matrix_layout_traits that maps
// LAYOUT to SPV_LAYOUT.
#define SPV_MATRIX_LAYOUT_TRAITS(LAYOUT, SPV_LAYOUT)                           \
  template <> struct spv_matrix_layout_traits<LAYOUT> {                        \
    static constexpr __spv::MatrixLayout value = SPV_LAYOUT;                   \
  };

SPV_MATRIX_LAYOUT_TRAITS(matrix_layout::row_major,
                         __spv::MatrixLayout::RowMajor)
SPV_MATRIX_LAYOUT_TRAITS(matrix_layout::col_major,
                         __spv::MatrixLayout::ColumnMajor)
SPV_MATRIX_LAYOUT_TRAITS(matrix_layout::packed_a, __spv::MatrixLayout::PackedA)
SPV_MATRIX_LAYOUT_TRAITS(matrix_layout::packed_b, __spv::MatrixLayout::PackedB)

// The generator macro is an implementation detail of this header; undefine it
// so it does not leak into every translation unit that includes the header.
#undef SPV_MATRIX_LAYOUT_TRAITS

/// Compile-time mapping from a SYCL group type to a SPIR-V scope, exposed as
/// the static member `value`.
///
/// The primary template has no `value` member, so using the trait with a
/// group type other than the ones specialized below fails to compile.
template <typename GroupT> struct spv_scope_traits {};

/// sycl::sub_group maps to the Subgroup scope.
template <> struct spv_scope_traits<sycl::sub_group> {
  static constexpr auto value = __spv::Scope::Subgroup;
};

/// sycl::group of any dimensionality maps to the Workgroup scope.
template <int Dims> struct spv_scope_traits<sycl::group<Dims>> {
  static constexpr auto value = __spv::Scope::Workgroup;
};

/// Handle to a matrix operated on jointly by the work-items of a group.
///
/// \tparam T       element type.
/// \tparam NumRows number of rows.
/// \tparam NumCols number of columns.
/// \tparam Layout  matrix layout; defaults to row_major.
/// \tparam Group   SYCL group type; defaults to sycl::sub_group.
template <typename T, size_t NumRows, size_t NumCols,
          matrix_layout Layout = matrix_layout::row_major,
          typename Group = sycl::sub_group>
struct joint_matrix {
  /// Underlying SPIR-V matrix handle. The constructor does not initialize it;
  /// joint_matrix_load and joint_matrix_mad assign it.
  /// NOTE(review): no scope template argument is passed here, so the handle
  /// type uses its default scope regardless of Group -- confirm this is
  /// intended for non-sub_group groups.
  __spv::__spirv_JointMatrixINTEL<T, NumRows, NumCols,
                                  spv_matrix_layout_traits<Layout>::value>
      *spvm;

  /// Constructs a matrix handle for group \p sg.
  ///
  /// In device compilation the body is empty. In host compilation it always
  /// throws runtime_error with PI_INVALID_DEVICE.
  joint_matrix(Group sg) {
#ifndef __SYCL_DEVICE_ONLY__
    (void)sg;
    throw runtime_error("joint matrix is not supported on host device.",
                        PI_INVALID_DEVICE);
#endif // __SYCL_DEVICE_ONLY__
  }
};

/// Loads a joint matrix from memory into \p res.
///
/// In device compilation, calls __spirv_JointMatrixLoadINTEL with the
/// __spv::MatrixLayout constant matching the run-time value of \p MemL and
/// stores the returned handle in res.spvm. In host compilation, always throws
/// runtime_error with PI_INVALID_DEVICE.
///
/// \param sg     group object; not read in the device path (only its type is
///               used, via spv_scope_traits<Group>).
/// \param res    destination joint matrix; its spvm handle is overwritten.
/// \param src    pointer to the source data.
/// \param stride forwarded unchanged as the builtin's Stride argument.
/// \param MemL   layout of the data in memory.
template <typename Group, typename T, size_t NumRows, size_t NumCols,
matrix_layout Layout = matrix_layout::row_major,
Copy link
Contributor Author

@yubingex007-a11y yubingex007-a11y Jul 23, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

matL

access::address_space Space>
inline __SYCL_ALWAYS_INLINE void
joint_matrix_load(Group sg,
joint_matrix<T, NumRows, NumCols, Layout, Group> &res,
multi_ptr<T, Space> src, size_t stride, matrix_layout MemL) {
#ifdef __SYCL_DEVICE_ONLY__
T *Ptr = src.get();
switch (MemL) {
default:
// There is no break after the assert: when assertions are compiled out, an
// unrecognized MemL falls through and is handled as row_major.
assert(false && "Invalid Memory Layout!");
case matrix_layout::row_major:
res.spvm =
__spirv_JointMatrixLoadINTEL<T, NumRows, NumCols,
spv_matrix_layout_traits<Layout>::value>(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you omitted the group argument here right?

Ptr, stride, __spv::MatrixLayout::RowMajor,
spv_scope_traits<Group>::value);
break;
case matrix_layout::col_major:
res.spvm =
__spirv_JointMatrixLoadINTEL<T, NumRows, NumCols,
spv_matrix_layout_traits<Layout>::value>(
Ptr, stride, __spv::MatrixLayout::ColumnMajor,
spv_scope_traits<Group>::value);
break;
case matrix_layout::packed_a:
res.spvm =
__spirv_JointMatrixLoadINTEL<T, NumRows, NumCols,
spv_matrix_layout_traits<Layout>::value>(
Ptr, stride, __spv::MatrixLayout::PackedA,
spv_scope_traits<Group>::value);
break;
case matrix_layout::packed_b:
res.spvm =
__spirv_JointMatrixLoadINTEL<T, NumRows, NumCols,
spv_matrix_layout_traits<Layout>::value>(
Ptr, stride, __spv::MatrixLayout::PackedB,
spv_scope_traits<Group>::value);
break;
}
#else
// Host compilation: silence unused-parameter warnings, then report that the
// operation is unsupported.
(void)sg;
(void)res;
(void)src;
(void)stride;
(void)MemL;
throw runtime_error("joint matrix is not supported on host device.",
PI_INVALID_DEVICE);
#endif // __SYCL_DEVICE_ONLY__
}

/// Stores the joint matrix \p src to the memory pointed to by \p res.
///
/// In device compilation, calls __spirv_JointMatrixStoreINTEL with the
/// __spv::MatrixLayout constant matching the run-time value of \p MemL. In
/// host compilation, always throws runtime_error with PI_INVALID_DEVICE.
///
/// \param sg     group object; not read in the device path (only its type is
///               used, via spv_scope_traits<Group>).
/// \param src    joint matrix whose spvm handle is stored.
/// \param res    pointer to the destination memory.
/// \param stride forwarded unchanged as the builtin's Stride argument.
/// \param MemL   layout to use for the data in memory.
template <typename Group, typename T, size_t NumRows, size_t NumCols,
          matrix_layout MatL = matrix_layout::row_major,
          access::address_space Space>
inline __SYCL_ALWAYS_INLINE void
joint_matrix_store(Group sg,
                   joint_matrix<T, NumRows, NumCols, MatL, Group> &src,
                   multi_ptr<T, Space> res, size_t stride, matrix_layout MemL) {
#ifndef __SYCL_DEVICE_ONLY__
  // Host compilation: silence unused-parameter warnings, then report that the
  // operation is unsupported.
  (void)sg;
  (void)src;
  (void)res;
  (void)stride;
  (void)MemL;
  throw runtime_error("joint matrix is not supported on host device.",
                      PI_INVALID_DEVICE);
#else
  // Compile-time properties shared by every case below.
  constexpr auto SpvMatL = spv_matrix_layout_traits<MatL>::value;
  constexpr auto SpvScope = spv_scope_traits<Group>::value;

  T *Dst = res.get();
  // Each case passes its memory layout as a literal enumerator instead of
  // forwarding a run-time variable.
  switch (MemL) {
  case matrix_layout::packed_b:
    __spirv_JointMatrixStoreINTEL<T, NumRows, NumCols, SpvMatL>(
        Dst, src.spvm, stride, __spv::MatrixLayout::PackedB, SpvScope);
    break;
  case matrix_layout::packed_a:
    __spirv_JointMatrixStoreINTEL<T, NumRows, NumCols, SpvMatL>(
        Dst, src.spvm, stride, __spv::MatrixLayout::PackedA, SpvScope);
    break;
  case matrix_layout::col_major:
    __spirv_JointMatrixStoreINTEL<T, NumRows, NumCols, SpvMatL>(
        Dst, src.spvm, stride, __spv::MatrixLayout::ColumnMajor, SpvScope);
    break;
  default:
    // There is no break after the assert: when assertions are compiled out,
    // an unrecognized MemL falls through and is handled as row_major.
    assert(false && "Invalid Memory Layout!");
  case matrix_layout::row_major:
    __spirv_JointMatrixStoreINTEL<T, NumRows, NumCols, SpvMatL>(
        Dst, src.spvm, stride, __spv::MatrixLayout::RowMajor, SpvScope);
    break;
  }
#endif // __SYCL_DEVICE_ONLY__
}

/// Combines \p mA (MxK), \p mB (KxN) and \p mC (MxN) through
/// __spirv_JointMatrixMadINTEL and returns the resulting MxN matrix
/// (presumably mA * mB + mC -- confirm against the SPIR-V extension).
///
/// In host compilation, always throws runtime_error with PI_INVALID_DEVICE.
///
/// NOTE(review): unlike joint_matrix_load/joint_matrix_store, no scope
/// argument is passed to the builtin here, so its default scope is used even
/// when Group is not a sub_group -- confirm this is intended.
///
/// \param sg group object; used to construct the result matrix.
/// \param mA first operand, element type T1.
/// \param mB second operand, element type T1.
/// \param mC third operand, element type T2.
/// \return a new joint matrix with element type T2 and layout LayoutC.
template <typename Group, typename T1, typename T2, size_t M, size_t K,
          size_t N, matrix_layout LayoutA, matrix_layout LayoutB,
          matrix_layout LayoutC>
inline __SYCL_ALWAYS_INLINE joint_matrix<T2, M, N, LayoutC, Group>
joint_matrix_mad(Group sg, joint_matrix<T1, M, K, LayoutA, Group> &mA,
                 joint_matrix<T1, K, N, LayoutB, Group> &mB,
                 joint_matrix<T2, M, N, LayoutC, Group> &mC) {
#ifndef __SYCL_DEVICE_ONLY__
  // Host compilation: silence unused-parameter warnings, then report that the
  // operation is unsupported.
  (void)sg;
  (void)mA;
  (void)mB;
  (void)mC;
  throw runtime_error("joint matrix is not supported on host device.",
                      PI_INVALID_DEVICE);
#else
  // The builtin's template arguments are deduced from the handle types.
  joint_matrix<T2, M, N, LayoutC, Group> product(sg);
  product.spvm = __spirv_JointMatrixMadINTEL(mA.spvm, mB.spvm, mC.spvm);
  return product;
#endif // __SYCL_DEVICE_ONLY__
}
} // namespace experimental::matrix
} // namespace oneapi
} // namespace ext
} // namespace sycl
} // __SYCL_INLINE_NAMESPACE(cl)
5 changes: 4 additions & 1 deletion sycl/include/sycl/ext/oneapi/matrix/matrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@

#if (SYCL_EXT_ONEAPI_MATRIX == 1)
#if defined(__AMXTILE__) && defined(__AMXINT8__) && defined(__AMXBF16__)
#include <sycl/ext/oneapi/matrix/matrix-amx.hpp>
#include <sycl/ext/oneapi/matrix/matrix-aot-amx.hpp>
#endif
#endif
#if (SYCL_EXT_ONEAPI_MATRIX == 2)
#include <sycl/ext/oneapi/matrix/matrix-jit.hpp>
#endif
2 changes: 1 addition & 1 deletion sycl/test/matrix/matrix-amx-bf16-test.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: %clangxx -march=sapphirerapids -fsycl -O2 %s -o %t.out
// RUN: %clangxx -DSYCL_EXT_ONEAPI_MATRIX=1 -march=sapphirerapids -fsycl -O2 %s -o %t.out
#include <CL/sycl.hpp>
#if (SYCL_EXT_ONEAPI_MATRIX == 1)
Copy link
Contributor

@dkhaldi dkhaldi Jul 26, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

// Only the leader perform AMX computation.
Why do you have "the leader" test in the code? This is wrong for two reasons:
1- The SG size is one, so we don't need to test that.
2- All work items in the subgroup must enter the joint matrix code, so we should NOT have divergent code in there.

#include <iostream>
Expand Down
2 changes: 1 addition & 1 deletion sycl/test/matrix/matrix-amx-int8-test.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: %clangxx -march=sapphirerapids -fsycl -O2 %s -o %t.out
// RUN: %clangxx -DSYCL_EXT_ONEAPI_MATRIX=1 -march=sapphirerapids -fsycl -O2 %s -o %t.out
#include <CL/sycl.hpp>
#if (SYCL_EXT_ONEAPI_MATRIX == 1)
#include <iostream>
Expand Down
Loading