Skip to content

Commit a0cbbd8

Browse files
authored
[SYCL][joint matrix][tests]Add test for down convert float to bfloat16 using joint matrix (#10684)
1 parent 4f68009 commit a0cbbd8

File tree

1 file changed

+93
-0
lines changed

1 file changed

+93
-0
lines changed
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
//==-------- joint_matrix_down_convert.cpp - DPC++ joint_matrix------------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
// REQUIRES: matrix
9+
10+
// RUN: %{build} -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4
11+
// RUN: %{run} %t.out
12+
13+
#include "common.hpp"
14+
#include <iostream>
15+
16+
using namespace sycl;
17+
using namespace sycl::ext::oneapi::experimental::matrix;
18+
19+
constexpr size_t SG_SZ = 16;
20+
21+
constexpr size_t TM = 8;
22+
// TN and TK must be the same for this test.
23+
constexpr size_t TN = 16;
24+
constexpr size_t TK = 16;
25+
26+
template <typename T1, typename T2, size_t M, size_t N, size_t K>
27+
void matrix_copy(big_matrix<T1, M, N> &C, big_matrix<T2, M, K> &A) {
28+
size_t NDRangeM = M / TM;
29+
size_t NDRangeN = N / TN;
30+
buffer<bfloat16, 2> bufA(A.get_data(), range<2>(M, K));
31+
buffer<float, 2> bufC((float *)C.get_data(), range<2>(M, N));
32+
33+
queue q;
34+
q.submit([&](handler &cgh) {
35+
auto accC = bufC.get_access<access::mode::read_write>(cgh);
36+
auto accA = bufA.get_access<access::mode::write>(cgh);
37+
38+
cgh.parallel_for(
39+
nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}),
40+
[=](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] {
41+
// The submatrix API has to be accessed by all the workitems in a
42+
// subgroup these functions will be called once by the subgroup no
43+
// code divergence between the workitems
44+
const auto global_idx = spmd_item.get_global_id(0);
45+
const auto global_idy = spmd_item.get_global_id(1);
46+
const auto sg_startx = global_idx - spmd_item.get_local_id(0);
47+
const auto sg_starty = global_idy - spmd_item.get_local_id(1);
48+
49+
sub_group sg = spmd_item.get_sub_group();
50+
joint_matrix<sub_group, bfloat16, use::a, TM, TK, layout::row_major>
51+
sub_a;
52+
joint_matrix<sub_group, float, use::accumulator, TM, TN> sub_c;
53+
54+
joint_matrix_load(
55+
sg, sub_c,
56+
accC.template get_multi_ptr<access::decorated::no>() +
57+
(sg_startx * TM) * N + sg_starty / SG_SZ * TN,
58+
N, layout::row_major);
59+
// This will be replaced by joint_matrix_copy API
60+
// joint_matrix_copy(sg, sub_c, sub_ac);
61+
auto wi_slice_c =
62+
sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_c);
63+
auto wi_slice_a =
64+
sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a);
65+
for (int i = 0; i < wi_slice_c.length(); i++) {
66+
wi_slice_a[i] = (bfloat16)wi_slice_c[i];
67+
}
68+
ext::intel::experimental::matrix::joint_matrix_store(
69+
sg, sub_a,
70+
accA.template get_multi_ptr<access::decorated::no>() +
71+
(sg_startx * TM) * N + sg_starty / SG_SZ * TN,
72+
N);
73+
}); // parallel for
74+
}).wait();
75+
}
76+
77+
int main() {
78+
static constexpr size_t MATRIX_M = TM * 2;
79+
static constexpr size_t MATRIX_N = TN * 2;
80+
static constexpr size_t MATRIX_K = TK * 2;
81+
bfloat16 A[MATRIX_M][MATRIX_K];
82+
float C[MATRIX_M][MATRIX_N];
83+
84+
matrix_rand(MATRIX_M, MATRIX_N, *C, (float)5);
85+
86+
big_matrix<float, MATRIX_M, MATRIX_N> MC((float *)&C);
87+
big_matrix<bfloat16, MATRIX_M, MATRIX_K> MA((bfloat16 *)&A);
88+
matrix_copy(MC, MA);
89+
90+
bool res = matrix_compare(MATRIX_M, MATRIX_N, *A, *C);
91+
std::cout << (res ? "passed" : "failed") << std::endl;
92+
return !res;
93+
}

0 commit comments

Comments
 (0)