Skip to content
This repository was archived by the owner on Mar 28, 2023. It is now read-only.

Commit 61464a1

Browse files
authored
[SYCL][Matrix] Minor corrections to the matrix tests (#1475)
1 parent 87f7445 commit 61464a1

File tree

3 files changed

+10
-176
lines changed

3 files changed

+10
-176
lines changed

SYCL/Matrix/joint_matrix_int8_vnni_impl.hpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,7 @@ void matrix_multiply(big_matrix<T1, NUM_ROWS_C, NUM_COLS_C> &C,
5050
sub_group sg = spmd_item.get_sub_group();
5151
joint_matrix<sub_group, int8_t, use::a, TM, TK, layout::row_major>
5252
sub_a;
53-
joint_matrix<sub_group, int8_t, use::b, TK, TN,
54-
ext::intel::experimental::matrix::layout::packed>
53+
joint_matrix<sub_group, int8_t, use::b, TK, TN, layout::row_major>
5554
sub_b;
5655
joint_matrix<sub_group, int32_t, use::accumulator, TM, TN> sub_c;
5756

SYCL/Matrix/joint_matrix_query_default.cpp

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//==-------- joint_matrix_query.cpp - DPC++ joint_matrix------------ ----==//
1+
//==-------- joint_matrix_query_default.cpp - DPC++ joint_matrix-----------==//
22
//
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
@@ -38,9 +38,9 @@ void matrix_multiply(big_matrix<T1, NUM_ROWS_C, NUM_COLS_C> &C,
3838
assert(NUM_ROWS_C == NUM_ROWS_A && NUM_COLS_A == NUM_ROWS_B * 4);
3939

4040
using myparams2 = tpu_params<tpu::amx, int8_t, int8_t, int>;
41-
constexpr int TM = myparams2::defaultM;
42-
constexpr int TN = myparams2::defaultN;
43-
constexpr int TK = myparams2::defaultK;
41+
constexpr int TM = myparams2::M;
42+
constexpr int TN = myparams2::N;
43+
constexpr int TK = myparams2::K;
4444

4545
std::cout << "AMX query sizes are: M " << TM << " N " << TN << " K " << TK
4646
<< std::endl;
@@ -74,9 +74,11 @@ void matrix_multiply(big_matrix<T1, NUM_ROWS_C, NUM_COLS_C> &C,
7474

7575
ext::oneapi::sub_group sg = spmd_item.get_sub_group();
7676

77-
myparams2::joint_matrix_a<sub_group> sub_a;
78-
myparams2::joint_matrix_b<sub_group> sub_b;
79-
myparams2::joint_matrix_c<sub_group> sub_c;
77+
myparams2::joint_matrix_a<sub_group, layout::row_major> sub_a;
78+
myparams2::joint_matrix_b<
79+
sub_group, ext::intel::experimental::matrix::layout::packed>
80+
sub_b;
81+
myparams2::joint_matrix_accumulator<sub_group> sub_c;
8082

8183
joint_matrix_load(sg, sub_c,
8284
accC.get_pointer() + (sg_startx * TM) * N +

SYCL/Matrix/joint_matrix_query_use_default.cpp

Lines changed: 0 additions & 167 deletions
This file was deleted.

0 commit comments

Comments
 (0)