Skip to content
This repository was archived by the owner on Mar 28, 2023. It is now read-only.

Commit 98f24ea

Browse files
address dounia's comments
1 parent 8302263 commit 98f24ea

9 files changed

+96
-630
lines changed

SYCL/Matrix/joint_matrix_bfloat16_col_major.cpp

Lines changed: 0 additions & 170 deletions
This file was deleted.
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
//==-- joint_matrix_bfloat16_colmajorA_colmajorB.cpp - DPC++ joint_matrix--==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
// REQUIRES: matrix
9+
10+
// RUN: %clangxx -fsycl %s -o %t.out
11+
// RUN: %CPU_RUN_PLACEHOLDER %t.out
12+
// RUN: %GPU_RUN_PLACEHOLDER %t.out
13+
// CHECK: passed
14+
15+
// This tests support of col major layout for matrix B which does transpose and
16+
// then VNNI transform. This is currently only available on AMX
17+
18+
// XFAIL: gpu
19+
20+
#include <iostream>
21+
#include <sycl/sycl.hpp>
22+
23+
using namespace sycl;
24+
using namespace sycl::ext::oneapi::experimental::matrix;
25+
using bfloat16 = sycl::ext::oneapi::experimental::bfloat16;
26+
27+
#define SG_SZ 16
28+
29+
#include "joint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp"

SYCL/Matrix/joint_matrix_bfloat16_col_majorA.cpp renamed to SYCL/Matrix/joint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp

Lines changed: 4 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,7 @@
1-
//==----- joint_matrix_bfloat16_col_major.cpp - DPC++ joint_matrix---------==//
2-
//
3-
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4-
// See https://llvm.org/LICENSE.txt for license information.
5-
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6-
//
7-
//===----------------------------------------------------------------------===//
8-
// REQUIRES: matrix
9-
10-
// RUN: %clangxx -fsycl %s -o %t.out
11-
// RUN: %CPU_RUN_PLACEHOLDER %t.out
12-
// RUN: %GPU_RUN_PLACEHOLDER %t.out
13-
14-
// This tests support of col major layout for matrix B which does transpose and
15-
// then VNNI transform. This is currently only available on AMX
16-
17-
// XFAIL: gpu
18-
19-
#include <iostream>
20-
#include <sycl/sycl.hpp>
21-
22-
using namespace sycl;
23-
using namespace sycl::ext::oneapi::experimental::matrix;
24-
using bfloat16 = sycl::ext::oneapi::experimental::bfloat16;
25-
26-
#define SG_SZ 8
27-
281
#define TM 8
29-
#define TN 8
2+
#define TN SG_SZ
303
#define TK 16
4+
#define BF16_EPSILON 0.00781250
315

326
template <typename T, size_t NUM_ROWS, size_t NUM_COLS> struct big_matrix {
337
private:
@@ -78,7 +52,7 @@ void matrix_multiply(big_matrix<T1, M, N> &C, big_matrix<T2, M, K> &A,
7852
N, matrix_layout::row_major);
7953
for (int k = 0; k < K / TK; k += 1) { //
8054
joint_matrix_load(
81-
sg, sub_a, accA.get_pointer() + ( k* TK) * M + sg_startx * TM,
55+
sg, sub_a, accA.get_pointer() + (k * TK) * M + sg_startx * TM,
8256
M, matrix_layout::col_major);
8357
joint_matrix_load(sg, sub_b,
8458
accB.get_pointer() +
@@ -158,9 +132,8 @@ int main() {
158132
bool res = true;
159133
for (int i = 0; i < MATRIX_M; i++) {
160134
for (int j = 0; j < MATRIX_N; j++) {
161-
if (C[i][j] != D[i][j]) {
135+
if ((fabs(C[i][j]) - fabs(D[i][j])) > BF16_EPSILON)
162136
res = false;
163-
}
164137
}
165138
}
166139
if (res)
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
//==--joint_matrix_bfloat16_rowmajorA_rowmajorB.cpp - DPC++ joint_matrix---==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
// REQUIRES: matrix
9+
10+
// RUN: %clangxx -fsycl %s -o %t.out
11+
// RUN: %CPU_RUN_PLACEHOLDER %t.out
12+
// RUN: %GPU_RUN_PLACEHOLDER %t.out
13+
// CHECK: passed
14+
15+
// This tests support of row major layout for matrix B which does automatic VNNI
16+
// transform. This is currently only available on AMX
17+
18+
// XFAIL: gpu
19+
20+
#include <iostream>
21+
#include <sycl/sycl.hpp>
22+
23+
using namespace sycl;
24+
using namespace sycl::ext::oneapi::experimental::matrix;
25+
using bfloat16 = sycl::ext::oneapi::experimental::bfloat16;
26+
27+
#define SG_SZ 16
28+
29+
#include "joint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp"

SYCL/Matrix/joint_matrix_bfloat16_row_major.cpp renamed to SYCL/Matrix/joint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp

Lines changed: 3 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,7 @@
1-
//==-------joint_matrix_bfloat16_row_major.cpp - DPC++ joint_matrix--------==//
2-
//
3-
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4-
// See https://llvm.org/LICENSE.txt for license information.
5-
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6-
//
7-
//===----------------------------------------------------------------------===//
8-
// REQUIRES: matrix
9-
10-
// RUN: %clangxx -fsycl %s -o %t.out
11-
// RUN: %CPU_RUN_PLACEHOLDER %t.out
12-
// RUN: %GPU_RUN_PLACEHOLDER %t.out
13-
14-
// This tests support of row major layout for matrix B which does automatic VNNI
15-
// transform. This is currently only available on AMX
16-
17-
// XFAIL: gpu
18-
19-
#include <iostream>
20-
#include <sycl/sycl.hpp>
21-
22-
using namespace sycl;
23-
using namespace sycl::ext::oneapi::experimental::matrix;
24-
using bfloat16 = sycl::ext::oneapi::experimental::bfloat16;
25-
26-
#define SG_SZ 8
27-
281
#define TM 8
29-
#define TN 8
2+
#define TN SG_SZ
303
#define TK 16
4+
#define BF16_EPSILON 0.00781250
315

326
template <typename T, size_t NUM_ROWS, size_t NUM_COLS> struct big_matrix {
337
private:
@@ -158,9 +132,8 @@ int main() {
158132
bool res = true;
159133
for (int i = 0; i < MATRIX_M; i++) {
160134
for (int j = 0; j < MATRIX_N; j++) {
161-
if (C[i][j] != D[i][j]) {
135+
if ((fabs(C[i][j]) - fabs(D[i][j])) > BF16_EPSILON)
162136
res = false;
163-
}
164137
}
165138
}
166139
if (res)

0 commit comments

Comments
 (0)