This repository was archived by the owner on Mar 28, 2023. It is now read-only.

Commit ac564a6

[SYCL][CUDA] Adds cuda test for joint_matrix_apply. (#1655)
intel/llvm#8417 is merged, so this is ready for review.

Signed-off-by: JackAKirk <[email protected]>
1 parent 5eb663d commit ac564a6
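For context, the pattern under test is joint_matrix_apply, which applies a caller-supplied callable to every element of a joint_matrix owned by a sub-group. A minimal kernel-side sketch, reduced from the full test below (the 16x16 `a` fragment shape and the apply_add lambda are taken from that test; this fragment is illustrative only, not a standalone test):

// Inside an nd_range kernel body, per sub-group (see the complete test file below):
auto sg = spmd_item.get_sub_group();
joint_matrix<sub_group, bfloat16, use::a, 16, 16, layout::row_major> sub_a;
joint_matrix_fill(sg, sub_a, 3);  // every element of sub_a becomes 3
// the callable runs once per element; after this call every element is 5
joint_matrix_apply(sg, sub_a, [](auto &x) { x = x + 2; });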

1 file changed: 124 additions, 0 deletions

joint_matrix_apply_cuda.cpp
//==------------ joint_matrix_apply_cuda.cpp - DPC++ joint_matrix----------==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// REQUIRES: cuda
// RUN: %clangxx -fsycl -fsycl-targets=%sycl_triple -Xsycl-target-backend --cuda-gpu-arch=sm_80 %s -o %t.out
// RUN: %t.out

#include <sycl/sycl.hpp>

using namespace sycl;
using namespace sycl::ext::oneapi::experimental::matrix;
using sycl::ext::oneapi::bfloat16;

#define SG_SZ 32
constexpr size_t nWGperDim = 2;

template <typename T1, typename T2, size_t M, size_t K, size_t N>
class KernelName;

template <typename T, size_t NUM_ROWS, size_t NUM_COLS> struct big_matrix {
public:
  T *mat;

  T *get_data() { return mat; }
  void set_data(T *data) { mat = data; }
  big_matrix(T *data) : mat(data) {}
};

template <typename T, size_t M, size_t N>
void assert_ref(T *C, const float ref) {
  for (size_t i = 0; i < M; i++)
    for (size_t j = 0; j < N; j++) {
      auto diff = C[i + j * M] - ref;
      assert(std::fabs(static_cast<float>(diff)) <
             std::numeric_limits<float>::epsilon());
    }
}

template <typename T, typename T2, size_t M, size_t K, size_t N, typename F>
void matrix_verify_lambda(queue q,
                          big_matrix<T2, M * nWGperDim, N * nWGperDim> &C,
                          nd_range<2> &r, const float ref, F &&lambda) {
  {
    buffer<T2, 2> bufC(C.get_data(), range<2>(N * nWGperDim, M * nWGperDim));

    q.submit([&](handler &cgh) {
      accessor<T2, 2, access::mode::read_write, target::device> accC(bufC, cgh);

      cgh.parallel_for<KernelName<T, T2, M, K, N>>(
          r, [accC, lambda](
                 nd_item<2> spmd_item) [[sycl::reqd_sub_group_size(SG_SZ)]] {
            const auto global_idx = spmd_item.get_global_id(0);
            const auto global_idy = spmd_item.get_global_id(1);
            const auto sg_startx = global_idx - spmd_item.get_local_id(0);
            const auto sg_starty = global_idy - spmd_item.get_local_id(1);

            auto sg = spmd_item.get_sub_group();

            joint_matrix<sub_group, T, use::a, M, K, layout::row_major> sub_a;
            joint_matrix<sub_group, T, use::b, K, N, layout::row_major> sub_b;
            joint_matrix<sub_group, T2, use::accumulator, M, N> sub_c;

            joint_matrix_fill(sg, sub_a, 3);
            joint_matrix_fill(sg, sub_b, 1);
            joint_matrix_fill(sg, sub_c, -80);

            joint_matrix_apply(sg, sub_a, lambda);

            sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);

            joint_matrix_store(sg, sub_c,
                               accC.get_pointer() +
                                   (sg_startx * M) * (N * nWGperDim) +
                                   sg_starty / SG_SZ * N,
                               (N * nWGperDim), layout::row_major);
          }); // parallel for
    });
  }
  assert_ref<T2, M * nWGperDim, N * nWGperDim>(C.get_data(), ref);
}

static constexpr size_t MATRIX_M = 16 * nWGperDim;
static constexpr size_t MATRIX_N = 16 * nWGperDim;

int main() {

  float D[MATRIX_M][MATRIX_N];
  big_matrix<float, MATRIX_M, MATRIX_N> MD_f((float *)&D);

  queue q;
  auto computeCapability =
      std::stof(q.get_device().get_info<sycl::info::device::backend_version>());
  nd_range<2> r({nWGperDim, nWGperDim * SG_SZ}, {1, 1 * SG_SZ});

  auto apply_add = [](auto &x) { x = x + 2; };

  if (computeCapability >= 7.0) {
    matrix_verify_lambda<half, float, 16, 16, 16>(q, MD_f, r, 0.0, apply_add);
  }

  if (computeCapability >= 7.2) {
    int32_t D_i[MATRIX_M][MATRIX_N];
    big_matrix<int32_t, MATRIX_M, MATRIX_N> MD_i((int32_t *)&D_i);
    matrix_verify_lambda<uint8_t, int32_t, 16, 16, 16>(q, MD_i, r, 0,
                                                       apply_add);
    matrix_verify_lambda<int8_t, int32_t, 16, 16, 16>(q, MD_i, r, 0, apply_add);
  }

  if (computeCapability >= 8.0) {
    matrix_verify_lambda<bfloat16, float, 16, 16, 16>(q, MD_f, r, 0.0,
                                                      apply_add);

    double D_d[MATRIX_M / 2][MATRIX_N / 2];
    big_matrix<double, 8 * nWGperDim, 8 * nWGperDim> MD_d((double *)&D_d);

    matrix_verify_lambda<double, double, 8, 4, 8>(q, MD_d, r, -60.0, apply_add);
  }

  return 0;
}
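A note on the expected reference values, worked from the fill constants in the kernel: every element of sub_a is filled with 3 and raised to 5 by apply_add, every element of sub_b is 1, and sub_c starts at -80, so each output element is 5 * 1 * K - 80. With K = 16 that gives 0 (the half, uint8_t/int8_t, and bfloat16 cases); with K = 4 it gives -60 (the double case), matching the ref arguments passed to matrix_verify_lambda.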
