Skip to content

[SYCL][CUDA][MATRIX] Remove using namespace experimental from headers #5217

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 25, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
221 changes: 146 additions & 75 deletions sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,34 +51,45 @@ struct joint_matrix<
} // namespace experimental::matrix

namespace detail {
using namespace experimental;

template <typename T, matrix::matrix_use MT, size_t NumRows, size_t NumCols,
matrix::matrix_layout Layout, access::address_space Space,
typename Cond = void>
template <typename T, sycl::ext::oneapi::experimental::matrix::matrix_use MT,
size_t NumRows, size_t NumCols,
sycl::ext::oneapi::experimental::matrix::matrix_layout Layout,
access::address_space Space, typename Cond = void>
struct joint_matrix_load_impl {
void load(matrix::joint_matrix<T, MT, NumRows, NumCols, Layout> &res,
void load(sycl::ext::oneapi::experimental::matrix::joint_matrix<
T, MT, NumRows, NumCols, Layout> &res,
multi_ptr<T, Space> src, size_t stride);
};

template <matrix::matrix_layout Layout> constexpr int get_layout_id();
template <sycl::ext::oneapi::experimental::matrix::matrix_layout Layout>
constexpr int get_layout_id();

template <> constexpr int get_layout_id<matrix::matrix_layout::row_major>() {
template <>
constexpr int get_layout_id<
sycl::ext::oneapi::experimental::matrix::matrix_layout::row_major>() {
return 0;
}

template <> constexpr int get_layout_id<matrix::matrix_layout::col_major>() {
template <>
constexpr int get_layout_id<
sycl::ext::oneapi::experimental::matrix::matrix_layout::col_major>() {
return 1;
}

template <matrix::matrix_layout Layout, access::address_space Space>
template <sycl::ext::oneapi::experimental::matrix::matrix_layout Layout,
access::address_space Space>
struct joint_matrix_load_impl<
double, matrix::matrix_use::a, 8, 4, Layout, Space,
typename std::enable_if_t<Layout == matrix::matrix_layout::row_major ||
Layout == matrix::matrix_layout::col_major>> {
void
load(matrix::joint_matrix<double, matrix::matrix_use::a, 8, 4, Layout> &res,
multi_ptr<double, Space> src, size_t stride) {
double, sycl::ext::oneapi::experimental::matrix::matrix_use::a, 8, 4,
Layout, Space,
typename std::enable_if_t<Layout == sycl::ext::oneapi::experimental::
matrix::matrix_layout::row_major ||
Layout == sycl::ext::oneapi::experimental::
matrix::matrix_layout::col_major>> {
void load(sycl::ext::oneapi::experimental::matrix::joint_matrix<
double, sycl::ext::oneapi::experimental::matrix::matrix_use::a,
8, 4, Layout> &res,
multi_ptr<double, Space> src, size_t stride) {

#ifdef __NVPTX__
#ifdef __SYCL_DEVICE_ONLY__
Expand All @@ -88,14 +99,19 @@ struct joint_matrix_load_impl<
}
};

template <matrix::matrix_layout Layout, access::address_space Space>
template <sycl::ext::oneapi::experimental::matrix::matrix_layout Layout,
access::address_space Space>
struct joint_matrix_load_impl<
double, matrix::matrix_use::b, 4, 8, Layout, Space,
typename std::enable_if_t<Layout == matrix::matrix_layout::row_major ||
Layout == matrix::matrix_layout::col_major>> {
void
load(matrix::joint_matrix<double, matrix::matrix_use::b, 4, 8, Layout> &res,
multi_ptr<double, Space> src, size_t stride) {
double, sycl::ext::oneapi::experimental::matrix::matrix_use::b, 4, 8,
Layout, Space,
typename std::enable_if_t<Layout == sycl::ext::oneapi::experimental::
matrix::matrix_layout::row_major ||
Layout == sycl::ext::oneapi::experimental::
matrix::matrix_layout::col_major>> {
void load(sycl::ext::oneapi::experimental::matrix::joint_matrix<
double, sycl::ext::oneapi::experimental::matrix::matrix_use::b,
4, 8, Layout> &res,
multi_ptr<double, Space> src, size_t stride) {
#ifdef __NVPTX__
#ifdef __SYCL_DEVICE_ONLY__
__dmma_m8n8k4_ld_b(res.data, src.get(), stride, get_layout_id<Layout>());
Expand All @@ -104,14 +120,21 @@ struct joint_matrix_load_impl<
}
};

template <matrix::matrix_layout Layout, access::address_space Space>
template <sycl::ext::oneapi::experimental::matrix::matrix_layout Layout,
access::address_space Space>
struct joint_matrix_load_impl<
double, matrix::matrix_use::accumulator, 8, 8, Layout, Space,
typename std::enable_if_t<Layout == matrix::matrix_layout::row_major ||
Layout == matrix::matrix_layout::col_major>> {
void load(matrix::joint_matrix<double, matrix::matrix_use::accumulator, 8, 8,
Layout> &res,
multi_ptr<double, Space> src, size_t stride) {
double, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, 8,
8, Layout, Space,
typename std::enable_if_t<Layout == sycl::ext::oneapi::experimental::
matrix::matrix_layout::row_major ||
Layout == sycl::ext::oneapi::experimental::
matrix::matrix_layout::col_major>> {
void
load(sycl::ext::oneapi::experimental::matrix::joint_matrix<
double,
sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, 8,
8, Layout> &res,
multi_ptr<double, Space> src, size_t stride) {

#ifdef __NVPTX__
#ifdef __SYCL_DEVICE_ONLY__
Expand All @@ -122,22 +145,30 @@ struct joint_matrix_load_impl<
};

template <typename T, size_t NumRows, size_t NumCols,
matrix::matrix_layout Layout, access::address_space Space,
typename Cond = void>
sycl::ext::oneapi::experimental::matrix::matrix_layout Layout,
access::address_space Space, typename Cond = void>
struct joint_matrix_store_impl {
void store(matrix::joint_matrix<T, matrix::matrix_use::accumulator, NumRows,
NumCols, Layout> &src,
multi_ptr<T, Space> dst, size_t stride);
void
store(sycl::ext::oneapi::experimental::matrix::joint_matrix<
T, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator,
NumRows, NumCols, Layout> &src,
multi_ptr<T, Space> dst, size_t stride);
};

template <matrix::matrix_layout Layout, access::address_space Space>
template <sycl::ext::oneapi::experimental::matrix::matrix_layout Layout,
access::address_space Space>
struct joint_matrix_store_impl<
double, 8, 8, Layout, Space,
typename std::enable_if_t<Layout == matrix::matrix_layout::row_major ||
Layout == matrix::matrix_layout::col_major>> {
void store(matrix::joint_matrix<double, matrix::matrix_use::accumulator, 8, 8,
Layout> &src,
multi_ptr<double, Space> dst, size_t stride) {
typename std::enable_if_t<Layout == sycl::ext::oneapi::experimental::
matrix::matrix_layout::row_major ||
Layout == sycl::ext::oneapi::experimental::
matrix::matrix_layout::col_major>> {
void
store(sycl::ext::oneapi::experimental::matrix::joint_matrix<
double,
sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, 8,
8, Layout> &src,
multi_ptr<double, Space> dst, size_t stride) {

#ifdef __NVPTX__
#ifdef __SYCL_DEVICE_ONLY__
Expand All @@ -149,60 +180,98 @@ struct joint_matrix_store_impl<
};

template <typename T1, typename T2, std::size_t M, std::size_t K, std::size_t N,
matrix::matrix_layout LayoutA, matrix::matrix_layout LayoutB,
matrix::matrix_layout LayoutC, typename Cond = void>
sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutA,
sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutB,
sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutC,
typename Cond = void>
struct joint_matrix_mad_impl {
matrix::joint_matrix<T2, matrix::matrix_use::accumulator, M, N, LayoutC>
mad(matrix::joint_matrix<T1, matrix::matrix_use::a, M, K, LayoutA> A,
matrix::joint_matrix<T1, matrix::matrix_use::b, K, N, LayoutB> B,
matrix::joint_matrix<T2, matrix::matrix_use::accumulator, M, N, LayoutC>
sycl::ext::oneapi::experimental::matrix::joint_matrix<
T2, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, M,
N, LayoutC>
mad(sycl::ext::oneapi::experimental::matrix::joint_matrix<
T1, sycl::ext::oneapi::experimental::matrix::matrix_use::a, M, K,
LayoutA>
A,
sycl::ext::oneapi::experimental::matrix::joint_matrix<
T1, sycl::ext::oneapi::experimental::matrix::matrix_use::b, K, N,
LayoutB>
B,
sycl::ext::oneapi::experimental::matrix::joint_matrix<
T2, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator,
M, N, LayoutC>
C);
};

template <matrix::matrix_layout LayoutA, matrix::matrix_layout LayoutB>
template <sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutA,
sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutB>
constexpr int get_layout_pair_id();

template <>
constexpr int get_layout_pair_id<matrix::matrix_layout::row_major,
matrix::matrix_layout::row_major>() {
constexpr int get_layout_pair_id<
sycl::ext::oneapi::experimental::matrix::matrix_layout::row_major,
sycl::ext::oneapi::experimental::matrix::matrix_layout::row_major>() {
return 0;
}

template <>
constexpr int get_layout_pair_id<matrix::matrix_layout::row_major,
matrix::matrix_layout::col_major>() {
constexpr int get_layout_pair_id<
sycl::ext::oneapi::experimental::matrix::matrix_layout::row_major,
sycl::ext::oneapi::experimental::matrix::matrix_layout::col_major>() {
return 1;
}

template <>
constexpr int get_layout_pair_id<matrix::matrix_layout::col_major,
matrix::matrix_layout::row_major>() {
constexpr int get_layout_pair_id<
sycl::ext::oneapi::experimental::matrix::matrix_layout::col_major,
sycl::ext::oneapi::experimental::matrix::matrix_layout::row_major>() {
return 2;
}

template <>
constexpr int get_layout_pair_id<matrix::matrix_layout::col_major,
matrix::matrix_layout::col_major>() {
constexpr int get_layout_pair_id<
sycl::ext::oneapi::experimental::matrix::matrix_layout::col_major,
sycl::ext::oneapi::experimental::matrix::matrix_layout::col_major>() {
return 3;
}

template <matrix::matrix_layout LayoutA, matrix::matrix_layout LayoutB,
matrix::matrix_layout LayoutC>
template <sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutA,
sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutB,
sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutC>
struct joint_matrix_mad_impl<
double, double, 8, 4, 8, LayoutA, LayoutB, LayoutC,
typename std::enable_if_t<(LayoutA == matrix::matrix_layout::row_major ||
LayoutA == matrix::matrix_layout::col_major) &&
(LayoutB == matrix::matrix_layout::row_major ||
LayoutB == matrix::matrix_layout::col_major) &&
(LayoutC == matrix::matrix_layout::row_major ||
LayoutC == matrix::matrix_layout::col_major)>> {
matrix::joint_matrix<double, matrix::matrix_use::accumulator, 8, 8, LayoutC>
mad(matrix::joint_matrix<double, matrix::matrix_use::a, 8, 4, LayoutA> A,
matrix::joint_matrix<double, matrix::matrix_use::b, 4, 8, LayoutB> B,
matrix::joint_matrix<double, matrix::matrix_use::accumulator, 8, 8,
LayoutC>
typename std::enable_if_t<
(LayoutA == sycl::ext::oneapi::experimental::matrix::matrix_layout::
row_major ||
LayoutA == sycl::ext::oneapi::experimental::matrix::matrix_layout::
col_major) &&
(LayoutB == sycl::ext::oneapi::experimental::matrix::matrix_layout::
row_major ||
LayoutB == sycl::ext::oneapi::experimental::matrix::matrix_layout::
col_major) &&
(LayoutC == sycl::ext::oneapi::experimental::matrix::matrix_layout::
row_major ||
LayoutC == sycl::ext::oneapi::experimental::matrix::matrix_layout::
col_major)>> {
sycl::ext::oneapi::experimental::matrix::joint_matrix<
double, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator,
8, 8, LayoutC>
mad(sycl::ext::oneapi::experimental::matrix::joint_matrix<
double, sycl::ext::oneapi::experimental::matrix::matrix_use::a, 8, 4,
LayoutA>
A,
sycl::ext::oneapi::experimental::matrix::joint_matrix<
double, sycl::ext::oneapi::experimental::matrix::matrix_use::b, 4, 8,
LayoutB>
B,
sycl::ext::oneapi::experimental::matrix::joint_matrix<
double,
sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, 8,
8, LayoutC>
C) {
matrix::joint_matrix<double, matrix::matrix_use::accumulator, 8, 8, LayoutC>
sycl::ext::oneapi::experimental::matrix::joint_matrix<
double,
sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, 8, 8,
LayoutC>
D;

#ifdef __NVPTX__
Expand All @@ -225,8 +294,9 @@ template <typename Group, typename T, matrix_use MT, size_t NumRows,
void joint_matrix_load(
Group sg, joint_matrix<T, MT, NumRows, NumCols, Layout, Group> &res,
multi_ptr<T, Space> src, size_t stride) {
detail::joint_matrix_load_impl<T, MT, NumRows, NumCols, Layout, Space>{}.load(
res, src, stride);
sycl::ext::oneapi::detail::joint_matrix_load_impl<T, MT, NumRows, NumCols,
Layout, Space>{}
.load(res, src, stride);
}

template <typename Group, typename T, size_t NumRows, size_t NumCols,
Expand All @@ -235,8 +305,9 @@ void joint_matrix_store(Group sg,
joint_matrix<T, matrix_use::accumulator, NumRows,
NumCols, Layout, Group> &src,
multi_ptr<T, Space> dst, size_t stride) {
detail::joint_matrix_store_impl<T, NumRows, NumCols, Layout, Space>{}.store(
src, dst, stride);
sycl::ext::oneapi::detail::joint_matrix_store_impl<T, NumRows, NumCols,
Layout, Space>{}
.store(src, dst, stride);
}

template <typename Group, typename T1, typename T2, std::size_t M,
Expand All @@ -247,8 +318,8 @@ joint_matrix_mad(
Group sg, joint_matrix<T1, matrix_use::a, M, K, LayoutA, Group> A,
joint_matrix<T1, matrix_use::b, K, N, LayoutB, Group> B,
joint_matrix<T2, matrix_use::accumulator, M, N, LayoutC, Group> C) {
return detail::joint_matrix_mad_impl<T1, T2, M, K, N, LayoutA, LayoutB,
LayoutC>{}
return sycl::ext::oneapi::detail::joint_matrix_mad_impl<
T1, T2, M, K, N, LayoutA, LayoutB, LayoutC>{}
.mad(A, B, C);
}

Expand Down
20 changes: 10 additions & 10 deletions sycl/test/check_device_code/matrix/matrix-nvptx-double-test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,15 @@ int main() {
joint_matrix<double, matrix_use::b, K, N, matrix_layout::row_major>
sub_b;

//CHECK: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.load.c.row.stride.f64.p1f64(double addrspace(1)* %add.ptr.i, i32 8) #{{.*}}
//CHECK: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.load.c.row.stride.f64.p1f64(double addrspace(1)* %_arg_, i32 8) #{{.*}}
joint_matrix_load(sg, sub_c, accC.get_pointer(), N);
//CHECK: tail call double @llvm.nvvm.wmma.m8n8k4.load.a.row.stride.f64.p1f64(double addrspace(1)* %add.ptr.i54, i32 4) #{{.*}}
//CHECK: tail call double @llvm.nvvm.wmma.m8n8k4.load.a.row.stride.f64.p1f64(double addrspace(1)* %_arg_4, i32 4) #{{.*}}
joint_matrix_load(sg, sub_a, accA.get_pointer(), K);
//CHECK: tail call double @llvm.nvvm.wmma.m8n8k4.load.b.row.stride.f64.p1f64(double addrspace(1)* %add.ptr.i65, i32 8) #{{.*}}
//CHECK: tail call double @llvm.nvvm.wmma.m8n8k4.load.b.row.stride.f64.p1f64(double addrspace(1)* %_arg_9, i32 8) #{{.*}}
joint_matrix_load(sg, sub_b, accB.get_pointer(), N);
//CHECK: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.mma.row.row.f64(double %11, double %12, double %9, double %10) #{{.*}}
//CHECK: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.mma.row.row.f64(double %3, double %4, double %1, double %2) #{{.*}}
sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
//CHECK: tail call void @llvm.nvvm.wmma.m8n8k4.store.d.row.stride.f64.p1f64(double addrspace(1)* %add.ptr.i76, double %14, double %15, i32 8) #{{.*}}
//CHECK: tail call void @llvm.nvvm.wmma.m8n8k4.store.d.row.stride.f64.p1f64(double addrspace(1)* %_arg_14, double %6, double %7, i32 8) #{{.*}}
joint_matrix_store(sg, sub_c, accD.get_pointer(), N);
});
});
Expand All @@ -84,15 +84,15 @@ int main() {
joint_matrix<double, matrix_use::b, K, N, matrix_layout::col_major>
sub_b;

//CHECK: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.load.c.col.stride.f64.p1f64(double addrspace(1)* %add.ptr.i, i32 8) #{{.*}}
//CHECK: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.load.c.col.stride.f64.p1f64(double addrspace(1)* %_arg_, i32 8) #{{.*}}
joint_matrix_load(sg, sub_c, accC.get_pointer(), M);
//CHECK: tail call double @llvm.nvvm.wmma.m8n8k4.load.a.col.stride.f64.p1f64(double addrspace(1)* %add.ptr.i54, i32 8) #{{.*}}
//CHECK: tail call double @llvm.nvvm.wmma.m8n8k4.load.a.col.stride.f64.p1f64(double addrspace(1)* %_arg_4, i32 8) #{{.*}}
joint_matrix_load(sg, sub_a, accA.get_pointer(), M);
//CHECK: tail call double @llvm.nvvm.wmma.m8n8k4.load.b.col.stride.f64.p1f64(double addrspace(1)* %add.ptr.i65, i32 4) #{{.*}}
//CHECK: tail call double @llvm.nvvm.wmma.m8n8k4.load.b.col.stride.f64.p1f64(double addrspace(1)* %_arg_9, i32 4) #{{.*}}
joint_matrix_load(sg, sub_b, accB.get_pointer(), K);
//CHECK: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.mma.col.col.f64(double %11, double %12, double %9, double %10) #{{.*}}
//CHECK: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.mma.col.col.f64(double %3, double %4, double %1, double %2) #{{.*}}
sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
//CHECK: tail call void @llvm.nvvm.wmma.m8n8k4.store.d.col.stride.f64.p1f64(double addrspace(1)* %add.ptr.i76, double %14, double %15, i32 8) #{{.*}}
//CHECK: tail call void @llvm.nvvm.wmma.m8n8k4.store.d.col.stride.f64.p1f64(double addrspace(1)* %_arg_14, double %6, double %7, i32 8) #{{.*}}
joint_matrix_store(sg, sub_c, accD.get_pointer(), M);
});
});
Expand Down