Skip to content

Commit 6afeb2b

Browse files
authored
[SYCL][Joint Matrix][Tests] Add tests for 16x16x16 and 32x64x32 joint_matrix shape combinations (#11649)
Note that I did not touch the file in XMX8 as optimal combinations on DG2 are different. So it is better to address DG2 case as a separate task
1 parent f7bf29b commit 6afeb2b

7 files changed

+111
-9
lines changed

sycl/test-e2e/Matrix/SG32/joint_matrix_bfloat16_32x64.cpp renamed to sycl/test-e2e/Matrix/SG32/joint_matrix_bfloat16_16x16x16.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//==----- joint_matrix_bfloat16_32x64.cpp - DPC++ joint_matrix-------------==//
1+
//==----- joint_matrix_bfloat16_16x16x16.cpp - DPC++ joint_matrix----------==//
22
//
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
@@ -17,6 +17,9 @@
1717
using namespace sycl;
1818
using namespace sycl::ext::oneapi::experimental::matrix;
1919

20-
constexpr size_t SG_SZ = 32;
20+
#define SG_SZ 32
21+
constexpr size_t TM = 16;
22+
constexpr size_t TN = 16;
23+
constexpr size_t TK = 16;
2124

22-
#include "../joint_matrix_bfloat16_32x64_impl.hpp"
25+
#include "../joint_matrix_bfloat16_packedB_impl.hpp"
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
//==----- joint_matrix_bfloat16_32x64x16.cpp - DPC++ joint_matrix----------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
// REQUIRES: matrix
9+
10+
// RUN: %{build} -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4
11+
// RUN: %{run} %t.out
12+
13+
// XFAIL: *
14+
15+
#include "../common.hpp"
16+
17+
using namespace sycl;
18+
using namespace sycl::ext::oneapi::experimental::matrix;
19+
20+
#define SG_SZ 32
21+
constexpr size_t TM = 32;
22+
constexpr size_t TN = 64;
23+
constexpr size_t TK = 16;
24+
25+
#include "../joint_matrix_bfloat16_packedB_impl.hpp"
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
//==----- joint_matrix_bfloat16_32x64x32.cpp - DPC++ joint_matrix----------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
// REQUIRES: matrix
9+
10+
// RUN: %{build} -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4
11+
// RUN: %{run} %t.out
12+
13+
// XFAIL: *
14+
15+
#include "../common.hpp"
16+
17+
using namespace sycl;
18+
using namespace sycl::ext::oneapi::experimental::matrix;
19+
20+
#define SG_SZ 32
21+
constexpr size_t TM = 32;
22+
constexpr size_t TN = 64;
23+
constexpr size_t TK = 32;
24+
25+
#include "../joint_matrix_bfloat16_packedB_impl.hpp"

sycl/test-e2e/Matrix/joint_matrix_bfloat16_32x64.cpp renamed to sycl/test-e2e/Matrix/joint_matrix_bfloat16_16x16x16.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//==----- joint_matrix_bfloat16_32x64.cpp - DPC++ joint_matrix-------------==//
1+
//==----- joint_matrix_bfloat16_16x16x16.cpp - DPC++ joint_matrix----------==//
22
//
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
@@ -18,5 +18,8 @@ using namespace sycl;
1818
using namespace sycl::ext::oneapi::experimental::matrix;
1919

2020
#define SG_SZ 16
21+
constexpr size_t TM = 16;
22+
constexpr size_t TN = 16;
23+
constexpr size_t TK = 16;
2124

22-
#include "joint_matrix_bfloat16_32x64_impl.hpp"
25+
#include "joint_matrix_bfloat16_packedB_impl.hpp"
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
//==----- joint_matrix_bfloat16_32x64x16.cpp - DPC++ joint_matrix----------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
// REQUIRES: matrix
9+
10+
// RUN: %{build} -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4
11+
// RUN: %{run} %t.out
12+
13+
// XFAIL: *
14+
15+
#include "common.hpp"
16+
17+
using namespace sycl;
18+
using namespace sycl::ext::oneapi::experimental::matrix;
19+
20+
#define SG_SZ 16
21+
constexpr size_t TM = 32;
22+
constexpr size_t TN = 64;
23+
constexpr size_t TK = 16;
24+
25+
#include "joint_matrix_bfloat16_packedB_impl.hpp"
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
//==----- joint_matrix_bfloat16_32x64x32.cpp - DPC++ joint_matrix----------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
// REQUIRES: matrix
9+
10+
// RUN: %{build} -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4
11+
// RUN: %{run} %t.out
12+
13+
// XFAIL: *
14+
15+
#include "common.hpp"
16+
17+
using namespace sycl;
18+
using namespace sycl::ext::oneapi::experimental::matrix;
19+
20+
#define SG_SZ 16
21+
constexpr size_t TM = 32;
22+
constexpr size_t TN = 64;
23+
constexpr size_t TK = 32;
24+
25+
#include "joint_matrix_bfloat16_packedB_impl.hpp"

sycl/test-e2e/Matrix/joint_matrix_bfloat16_32x64_impl.hpp renamed to sycl/test-e2e/Matrix/joint_matrix_bfloat16_packedB_impl.hpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,3 @@
1-
#define TM 32
2-
#define TN 64
3-
#define TK 16
4-
51
template <typename T1, typename T2, size_t M, size_t N, size_t K>
62
void matrix_multiply(big_matrix<T1, M, N> &C, big_matrix<T2, M, K> &A,
73
big_matrix<T2, K / 2, N * 2> &B) {

0 commit comments

Comments
 (0)