Skip to content

Commit c043037

Browse files
authored
[SYCL][Joint matrix][tests] Add more test cases for transpose C (#10762)
1 parent a0cbbd8 commit c043037

6 files changed

+259
-0
lines changed
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
//==---------- joint_matrix_colA_rowB_colC.cpp - DPC++ joint_matrix---------==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// REQUIRES: matrix

// RUN: %{build} -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4
// RUN: %{run} %t.out

// XFAIL:gpu

#include "../common.hpp"

// Sub-group size and N tile dimension consumed by the shared implementation
// header below; this variant exercises the 8-wide configuration.
constexpr size_t SG_SZ = 8;
constexpr size_t TN = 8;

#include "../joint_matrix_colA_rowB_colC_impl.hpp"
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
//==----------- joint_matrix_transposeC.cpp - DPC++ joint_matrix-----------==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// REQUIRES: matrix

// RUN: %{build} -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4
// RUN: %{run} %t.out

// XFAIL:gpu

#include "../common.hpp"

// Sub-group size and N tile dimension consumed by the shared implementation
// header below; this variant exercises the 8-wide configuration.
constexpr size_t SG_SZ = 8;
constexpr size_t TN = 8;

#include "../joint_matrix_transposeC_impl.hpp"
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
//==---------- joint_matrix_colA_rowB_colC.cpp - DPC++ joint_matrix---------==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// REQUIRES: matrix

// RUN: %{build} -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4
// RUN: %{run} %t.out

// XFAIL:gpu

#include "common.hpp"

// Sub-group size and N tile dimension consumed by the shared implementation
// header below; this variant exercises the 16-wide configuration.
constexpr size_t SG_SZ = 16;
constexpr size_t TN = 16;

#include "joint_matrix_colA_rowB_colC_impl.hpp"
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
#include <iostream>
2+
#include <random>
3+
4+
using namespace sycl;
5+
using namespace sycl::ext::oneapi::experimental::matrix;
6+
7+
constexpr size_t TM = 8;
8+
constexpr size_t TK = 16;
9+
10+
template <typename T1, typename T2, size_t NUM_ROWS_A, size_t NUM_COLS_A,
11+
size_t NUM_ROWS_B, size_t NUM_COLS_B, size_t NUM_ROWS_C,
12+
size_t NUM_COLS_C>
13+
void matrix_multiply(T1 *C, T2 *A, T2 *B, queue q) {
14+
size_t M = NUM_ROWS_C;
15+
size_t N = NUM_COLS_C;
16+
size_t K = NUM_COLS_A;
17+
18+
assert(NUM_ROWS_C == NUM_ROWS_A && NUM_COLS_A == NUM_ROWS_B);
19+
size_t NDRangeM = M / TM;
20+
size_t NDRangeN = N / TN;
21+
22+
auto pA = multi_ptr<T2, sycl::access::address_space::global_space>(A);
23+
auto pB = multi_ptr<T2, sycl::access::address_space::global_space>(B);
24+
auto pC = multi_ptr<T1, sycl::access::address_space::global_space>(C);
25+
26+
q.submit([&](handler &cgh) {
27+
cgh.parallel_for(
28+
nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}),
29+
[=](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]]
30+
31+
{
32+
// The submatrix API has to be accessed by all the workitems in a
33+
// subgroup these functions will be called once by the subgroup no
34+
// code divergence between the workitems
35+
const auto global_idx = spmd_item.get_global_id(0);
36+
const auto global_idy = spmd_item.get_global_id(1);
37+
const auto sg_startx = global_idx - spmd_item.get_local_id(0);
38+
const auto sg_starty = global_idy - spmd_item.get_local_id(1);
39+
40+
sub_group sg = spmd_item.get_sub_group();
41+
joint_matrix<sub_group, bfloat16, use::a, TM, TK, layout::col_major>
42+
sub_a;
43+
joint_matrix<sub_group, bfloat16, use::b, TK, TN, layout::row_major>
44+
sub_b;
45+
joint_matrix<sub_group, float, use::accumulator, TM, TN> sub_c;
46+
joint_matrix_fill(sg, sub_c, 1);
47+
for (int k = 0; k < K; k += TK) {
48+
joint_matrix_load(sg, sub_a, pA + (sg_startx * TM) * K + k, K);
49+
joint_matrix_load(sg, sub_b, pB + k * N + sg_starty / SG_SZ * TN,
50+
N);
51+
sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
52+
}
53+
joint_matrix_store(
54+
sg, sub_c, pC + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, N,
55+
layout::col_major);
56+
}); // parallel for
57+
}).wait();
58+
}
59+
60+
int main() {
61+
static constexpr size_t MATRIX_M = 1024;
62+
static constexpr size_t MATRIX_N = 1024;
63+
static constexpr size_t MATRIX_K = 1024;
64+
queue q;
65+
bfloat16 *A = malloc_shared<bfloat16>(MATRIX_M * MATRIX_K, q);
66+
bfloat16 *B = malloc_shared<bfloat16>(MATRIX_K * MATRIX_N, q);
67+
float *C = malloc_shared<float>(MATRIX_M * MATRIX_N, q);
68+
float *D = malloc_shared<float>(MATRIX_M * MATRIX_N, q);
69+
70+
matrix_rand(MATRIX_M, MATRIX_K, A, (bfloat16)5);
71+
matrix_rand(MATRIX_K, MATRIX_N, B, (bfloat16)5);
72+
matrix_fill(MATRIX_M, MATRIX_N, C, (float)1.0);
73+
matrix_fill(MATRIX_M, MATRIX_N, D, (float)1.0);
74+
75+
matrix_multiply<float, bfloat16, MATRIX_M, MATRIX_K, MATRIX_K, MATRIX_N,
76+
MATRIX_M, MATRIX_N>(C, A, B, q);
77+
matrix_multiply_ref(A, B, D, MATRIX_M, MATRIX_N, MATRIX_K,
78+
true /*transposed c*/);
79+
80+
bool res = matrix_compare(MATRIX_M, MATRIX_N, C, D);
81+
82+
std::cout << (res ? "passed" : "failed") << std::endl;
83+
return !res;
84+
}
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
//==----------- joint_matrix_transposeC.cpp - DPC++ joint_matrix-----------==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// REQUIRES: matrix

// RUN: %{build} -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4
// RUN: %{run} %t.out

// XFAIL:gpu

#include "common.hpp"

// Sub-group size and N tile dimension consumed by the shared implementation
// header below; this variant exercises the 16-wide configuration.
constexpr size_t SG_SZ = 16;
constexpr size_t TN = 16;

#include "joint_matrix_transposeC_impl.hpp"
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
#include <iostream>
2+
#include <random>
3+
4+
using namespace sycl;
5+
using namespace sycl::ext::oneapi::experimental::matrix;
6+
7+
constexpr size_t TM = 8;
8+
constexpr size_t TK = 16;
9+
10+
template <typename T1, typename T2, size_t NUM_ROWS_A, size_t NUM_COLS_A,
11+
size_t NUM_ROWS_B, size_t NUM_COLS_B, size_t NUM_ROWS_C,
12+
size_t NUM_COLS_C>
13+
void matrix_multiply(T1 *C, T2 *A, T2 *B, queue q, unsigned int vnniFactor) {
14+
size_t M = NUM_ROWS_C;
15+
size_t N = NUM_COLS_C;
16+
size_t K = NUM_COLS_A;
17+
18+
size_t NDRangeM = M / TM;
19+
size_t NDRangeN = N / TN;
20+
21+
auto pA = multi_ptr<T2, sycl::access::address_space::global_space>(A);
22+
auto pB = multi_ptr<T2, sycl::access::address_space::global_space>(B);
23+
auto pC = multi_ptr<T1, sycl::access::address_space::global_space>(C);
24+
25+
q.submit([&](handler &cgh) {
26+
cgh.parallel_for(
27+
nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}),
28+
[=](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]]
29+
30+
{
31+
// The submatrix API has to be accessed by all the workitems in a
32+
// subgroup these functions will be called once by the subgroup no
33+
// code divergence between the workitems
34+
const auto global_idx = spmd_item.get_global_id(0);
35+
const auto global_idy = spmd_item.get_global_id(1);
36+
const auto sg_startx = global_idx - spmd_item.get_local_id(0);
37+
const auto sg_starty = global_idy - spmd_item.get_local_id(1);
38+
39+
sub_group sg = spmd_item.get_sub_group();
40+
joint_matrix<sub_group, bfloat16, use::a, TM, TK, layout::row_major>
41+
sub_a;
42+
43+
// For B, since current implementation does not support non-packed
44+
// layout, users need to specify the packed_b layout.
45+
joint_matrix<sub_group, bfloat16, use::b, TK, TN,
46+
ext::intel::experimental::matrix::layout::packed>
47+
sub_b;
48+
joint_matrix<sub_group, float, use::accumulator, TM, TN> sub_c;
49+
joint_matrix_load(sg, sub_c,
50+
pC + (sg_startx * TM) * N + sg_starty / SG_SZ * TN,
51+
N, layout::col_major);
52+
for (int k = 0; k < K; k += TK) {
53+
joint_matrix_load(sg, sub_a, pA + (sg_startx * TM) * K + k, K);
54+
// Assume we alreay in vnni format.
55+
joint_matrix_load(sg, sub_b,
56+
pB + k * N + sg_starty / SG_SZ * TN * vnniFactor,
57+
N * vnniFactor);
58+
sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
59+
}
60+
joint_matrix_store(
61+
sg, sub_c, pC + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, N,
62+
layout::col_major);
63+
}); // parallel for
64+
}).wait();
65+
}
66+
67+
int main() {
68+
static constexpr size_t MATRIX_M = 1024;
69+
static constexpr size_t MATRIX_N = 1024;
70+
static constexpr size_t MATRIX_K = 1024;
71+
static constexpr unsigned int vnniFactor = 2;
72+
queue q;
73+
bfloat16 *A = malloc_shared<bfloat16>(MATRIX_M * MATRIX_K, q);
74+
bfloat16 *B = malloc_shared<bfloat16>(MATRIX_K * MATRIX_N, q);
75+
bfloat16 *vnniB = malloc_shared<bfloat16>(MATRIX_K * MATRIX_N, q);
76+
float *C = malloc_shared<float>(MATRIX_M * MATRIX_N, q);
77+
float *D = malloc_shared<float>(MATRIX_M * MATRIX_N, q);
78+
79+
matrix_rand(MATRIX_M, MATRIX_K, A, (bfloat16)5);
80+
matrix_rand(MATRIX_K, MATRIX_N, B, (bfloat16)5);
81+
matrix_fill(MATRIX_M, MATRIX_N, C, (float)1.0);
82+
matrix_fill(MATRIX_M, MATRIX_N, D, (float)1.0);
83+
84+
matrix_vnni<bfloat16>(MATRIX_K, MATRIX_N, B, vnniB, vnniFactor);
85+
matrix_multiply<float, bfloat16, MATRIX_M, MATRIX_K, MATRIX_K / vnniFactor,
86+
MATRIX_N * vnniFactor, MATRIX_M, MATRIX_N>(C, A, vnniB, q,
87+
vnniFactor);
88+
matrix_multiply_ref(A, B, D, MATRIX_M, MATRIX_N, MATRIX_K,
89+
true /*transposed c*/);
90+
91+
bool res = matrix_compare(MATRIX_M, MATRIX_N, C, D);
92+
93+
std::cout << (res ? "passed" : "failed") << std::endl;
94+
return !res;
95+
}

0 commit comments

Comments
 (0)