Skip to content

[SYCL][joint matrix] add missing licence to test and add combination-based query #12489

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jan 29, 2024
15 changes: 5 additions & 10 deletions sycl/test-e2e/Matrix/SG32/element_wise_all_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,16 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// REQUIRES: matrix
// REQUIRES: cpu, gpu
// REQUIRES-INTEL-DRIVER: lin: 27501, win: 101.4943
// SG size = 32 is not currently supported for SYCL Joint Matrix by IGC on DG2
// UNSUPPORTED: gpu-intel-dg2

// RUN: %{build} -o %t.out
// RUN: %{run} %t.out

#include <iostream>
#include <random>
#include <sycl/sycl.hpp>
#include "../common.hpp"

using namespace sycl;
using namespace sycl::ext::oneapi::experimental::matrix;
using bfloat16 = sycl::ext::oneapi::bfloat16;

constexpr size_t SG_SZ = 32;
constexpr size_t TN = 16;
#define SG_SZ 32

#include "../element_wise_all_ops_impl.hpp"
25 changes: 0 additions & 25 deletions sycl/test-e2e/Matrix/XMX8/element_wise_all_ops.cpp

This file was deleted.

31 changes: 31 additions & 0 deletions sycl/test-e2e/Matrix/common.hpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
//==------------------ common.hpp - DPC++ joint_matrix---------------------==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include <cmath>
#include <iostream>
#include <random>
Expand Down Expand Up @@ -173,3 +180,27 @@ bool matrix_compare(unsigned int rows, unsigned int cols, T1 *src, T2 *ref) {
}
return true;
}

bool is_type_supported_by_device(queue q, matrix_type type) {
std::vector<combination> combinations =
q.get_device()
.get_info<sycl::ext::oneapi::experimental::info::device::
matrix_combinations>();
for (int i = 0; i < combinations.size(); i++) {
if (combinations[i].atype == type) {
return true;
}
}
return false;
}

template <typename KernelName> size_t get_sg_size(queue q) {
auto KernelID = get_kernel_id<KernelName>();
auto KB =
get_kernel_bundle<bundle_state::executable>(q.get_context(), {KernelID});
auto kernel = KB.get_kernel(KernelID);

return kernel
.template get_info<info::kernel_device_specific::max_sub_group_size>(
q.get_device());
}
16 changes: 5 additions & 11 deletions sycl/test-e2e/Matrix/element_wise_all_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,14 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// REQUIRES: matrix
// REQUIRES: cpu, gpu
// Test is flaky/timeouts on some variants of DG2 and temporary disabled. Needs
// to be investigated.
// UNSUPPORTED: gpu-intel-dg2

// RUN: %{build} -o %t.out
// RUN: %{run} %t.out

#include <iostream>
#include <random>
#include <sycl/sycl.hpp>

using namespace sycl;
using namespace sycl::ext::oneapi::experimental::matrix;
using bfloat16 = sycl::ext::oneapi::bfloat16;

#define SG_SZ 16
constexpr size_t TN = 16;
#include "common.hpp"

#include "element_wise_all_ops_impl.hpp"
131 changes: 70 additions & 61 deletions sycl/test-e2e/Matrix/element_wise_all_ops_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,6 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

static float make_fp32(bfloat16 x) {
unsigned int y = *((int *)&x);
y = y << 16;
float *res = reinterpret_cast<float *>(&y);
return *res;
}

template <typename T, size_t NUM_ROWS, size_t NUM_COLS> struct big_matrix {
public:
T *mat;

public:
T *get_data() { return mat; }
void set_data(T *data) { mat = data; }
big_matrix(T *data) : mat(data) {}
};

template <typename T, size_t NUM_ROWS, size_t NUM_COLS>
void assert_ops_ref(host_accessor<T, 2, access::mode::read> mat,
const float ref) {
Expand All @@ -39,20 +21,25 @@ void assert_ops_ref(host_accessor<T, 2, access::mode::read> mat,
}

template <typename T, size_t NUM_ROWS, size_t NUM_COLS, size_t SUB_ROWS,
size_t SUB_COLS, typename OP>
size_t SUB_COLS, class kernel_name, typename OP>
void verify_op_a(const T l, const T r, const float ref, OP op) {
T mat[NUM_ROWS][NUM_COLS];
big_matrix<T, NUM_ROWS, NUM_COLS> big_mat((T *)&mat);

buffer<T, 2> bufMat(big_mat.get_data(), range<2>(NUM_ROWS, NUM_COLS));

queue q;
size_t sg_size = get_sg_size<kernel_name>(q);
q.submit([&](handler &cgh) {
sycl::accessor accessMat{bufMat, cgh, sycl::read_write};
cgh.parallel_for(
nd_range<2>({NUM_ROWS / SUB_ROWS, NUM_COLS / SUB_COLS * SG_SZ},
{1, 1 * SG_SZ}),
[=](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] {
cgh.parallel_for<kernel_name>(
nd_range<2>({NUM_ROWS / SUB_ROWS, NUM_COLS / SUB_COLS * sg_size},
{1, 1 * sg_size}),
[=](nd_item<2> spmd_item)
#ifdef SG_SZ
[[intel::reqd_sub_group_size(SG_SZ)]]
#endif
{
const auto global_idx = spmd_item.get_global_id(0);
const auto global_idy = spmd_item.get_global_id(1);
const auto sg_startx = global_idx - spmd_item.get_local_id(0);
Expand All @@ -68,28 +55,32 @@ void verify_op_a(const T l, const T r, const float ref, OP op) {
sg, sub_mat,
accessMat.template get_multi_ptr<access::decorated::no>() +
(sg_startx * SUB_ROWS) * NUM_COLS +
sg_starty / SG_SZ * SUB_COLS,
sg_starty / sg_size * SUB_COLS,
NUM_COLS);
}); // parallel for
}).wait();
assert_ops_ref<T, NUM_ROWS, NUM_COLS>(bufMat.get_host_access(read_only), ref);
}

template <typename T, size_t NUM_ROWS, size_t NUM_COLS, size_t SUB_ROWS,
size_t SUB_COLS, typename OP>
size_t SUB_COLS, class kernel_name, typename OP>
void verify_op_c(const T l, const T r, const float ref, OP op) {
T mat[NUM_ROWS][NUM_COLS];
big_matrix<T, NUM_ROWS, NUM_COLS> big_mat((T *)&mat);

buffer<T, 2> bufMat(big_mat.get_data(), range<2>(NUM_ROWS, NUM_COLS));

queue q;
size_t sg_size = get_sg_size<kernel_name>(q);
q.submit([&](handler &cgh) {
sycl::accessor accessMat{bufMat, cgh, sycl::read_write};
cgh.parallel_for(
nd_range<2>({NUM_ROWS / SUB_ROWS, NUM_COLS / SUB_COLS * SG_SZ},
{1, 1 * SG_SZ}),
[=](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] {
cgh.parallel_for<kernel_name>(
nd_range<2>({NUM_ROWS / SUB_ROWS, NUM_COLS / SUB_COLS * sg_size},
{1, 1 * sg_size}),
[=](nd_item<2> spmd_item)
#ifdef SG_SZ
[[intel::reqd_sub_group_size(SG_SZ)]]
#endif
{
const auto global_idx = spmd_item.get_global_id(0);
const auto global_idy = spmd_item.get_global_id(1);
const auto sg_startx = global_idx - spmd_item.get_local_id(0);
Expand All @@ -105,85 +96,103 @@ void verify_op_c(const T l, const T r, const float ref, OP op) {
sg, sub_mat,
accessMat.template get_multi_ptr<access::decorated::no>() +
(sg_startx * SUB_ROWS) * NUM_COLS +
sg_starty / SG_SZ * SUB_COLS,
sg_starty / sg_size * SUB_COLS,
NUM_COLS, layout::row_major);
}); // parallel for
}).wait();
assert_ops_ref<T, NUM_ROWS, NUM_COLS>(bufMat.get_host_access(read_only), ref);
}

// Avoid same kernel name for different types
template <typename T, class name> class ewops_a {};
template <typename T, size_t NROWS, size_t NCOLS, size_t SROWS, size_t SCOLS>
void test_ewops_a() {

verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS>(
verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS, ewops_a<T, class a_add>>(
T(5.0), T(2.0), 7.0, [](auto l, auto r) { return l + r; });
verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS>(
verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS, ewops_a<T, class a_sub>>(
T(5.0), T(2.0), 3.0, [](auto l, auto r) { return l - r; });
verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS>(
verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS, ewops_a<T, class a_mul>>(
T(5.0), T(2.0), 10.0, [](auto l, auto r) { return l * r; });
verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS>(
verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS, ewops_a<T, class a_div>>(
T(5.0), T(2.0), 2.5, [](auto l, auto r) { return l / r; });
verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS>(
verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS, ewops_a<T, class a_logical>>(
T(5.0), T(5.0), 5.0, [](auto l, auto r) { return l == r ? l : T(1.0); });
verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS>(
verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS, ewops_a<T, class a_eq>>(
T(5.0), T(4.0), 4.0, [](auto l, auto r) { return l == r ? l : r; });
verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS>(
verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS, ewops_a<T, class a_ne>>(
T(5.0), T(5.0), 1.0, [](auto l, auto r) { return l != r ? l : T(1.0); });
verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS>(
verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS, ewops_a<T, class a_gt>>(
T(5.0), T(2.0), 3.0,
[](auto l, auto r) { return l > r ? T(3.0) : T(2.0); });
verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS>(
verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS, ewops_a<T, class a_lt>>(
T(5.0), T(2.0), 2.0,
[](auto l, auto r) { return l < r ? T(3.0) : T(2.0); });
verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS>(
verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS, ewops_a<T, class a_ge>>(
T(5.0), T(2.0), 3.0,
[](auto l, auto r) { return l >= r ? T(3.0) : T(2.0); });
verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS>(
verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS, ewops_a<T, class a_le>>(
T(5.0), T(2.0), 2.0,
[](auto l, auto r) { return l <= r ? T(3.0) : T(2.0); });
}

// Avoid same kernel name for different types and numbers of columns
template <typename T, size_t COLS, class name> class ewops_c {};
template <typename T, size_t NROWS, size_t NCOLS, size_t SROWS, size_t SCOLS>
void test_ewops_c() {

verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS>(
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS, ewops_c<T, SCOLS, class c_add>>(
T(5.0), T(2.0), 7.0, [](auto l, auto r) { return l + r; });
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS>(
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS, ewops_c<T, SCOLS, class c_sub>>(
T(5.0), T(2.0), 3.0, [](auto l, auto r) { return l - r; });
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS>(
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS, ewops_c<T, SCOLS, class c_mul>>(
T(5.0), T(2.0), 10.0, [](auto l, auto r) { return l * r; });
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS>(
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS, ewops_c<T, SCOLS, class c_div>>(
T(5.0), T(2.0), 2.5, [](auto l, auto r) { return l / r; });
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS>(
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS,
ewops_c<T, SCOLS, class c_logical>>(
T(5.0), T(5.0), 5.0, [](auto l, auto r) { return l == r ? l : T(1.0); });
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS>(
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS, ewops_c<T, SCOLS, class c_eq>>(
T(5.0), T(4.0), 4.0, [](auto l, auto r) { return l == r ? l : r; });
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS>(
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS, ewops_c<T, SCOLS, class c_ne>>(
T(5.0), T(5.0), 1.0, [](auto l, auto r) { return l != r ? l : T(1.0); });
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS>(
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS, ewops_c<T, SCOLS, class c_gt>>(
T(5.0), T(2.0), 3.0,
[](auto l, auto r) { return l > r ? T(3.0) : T(2.0); });
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS>(
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS, ewops_c<T, SCOLS, class c_lt>>(
T(5.0), T(2.0), 2.0,
[](auto l, auto r) { return l < r ? T(3.0) : T(2.0); });
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS>(
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS, ewops_c<T, SCOLS, class c_ge>>(
T(5.0), T(2.0), 3.0,
[](auto l, auto r) { return l >= r ? T(3.0) : T(2.0); });
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS>(
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS, ewops_c<T, SCOLS, class c_le>>(
T(5.0), T(2.0), 2.0,
[](auto l, auto r) { return l <= r ? T(3.0) : T(2.0); });
}

int main() {
static constexpr size_t TM = 8;
static constexpr size_t TK = 16;

static constexpr size_t MATRIX_M = TM * 2;
static constexpr size_t MATRIX_N = TN * 2;
static constexpr size_t MATRIX_K = TK * 2;

test_ewops_a<bfloat16, MATRIX_M, MATRIX_K, TM, TK>();
test_ewops_c<float, MATRIX_M, MATRIX_N, TM, TN>();

static constexpr size_t MATRIX_N = 32;
static constexpr size_t MATRIX_K = 32;
queue q;
std::vector<combination> combinations =
q.get_device()
.get_info<sycl::ext::oneapi::experimental::info::device::
matrix_combinations>();
for (unsigned int i = 0; i < combinations.size(); i++) {
if (combinations[i].atype == matrix_type::bf16) {
if (combinations[i].nsize == 0 || combinations[i].nsize == 16) {
test_ewops_a<bfloat16, MATRIX_M, MATRIX_K, TM, 16>();
test_ewops_c<float, MATRIX_M, MATRIX_N, TM, 16>();
break;
}
if (combinations[i].nsize == 8) {
test_ewops_a<bfloat16, MATRIX_M, MATRIX_K, TM, 16>();
test_ewops_c<float, MATRIX_M, MATRIX_N, TM, 8>();
break;
}
}
}
return 0;
}