|
1 |
| -//==-------- joint_matrix_query.cpp - DPC++ joint_matrix------------ ----==// |
| 1 | +//==-------- joint_matrix_query_default.cpp - DPC++ joint_matrix-----------==// |
2 | 2 | //
|
3 | 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
4 | 4 | // See https://llvm.org/LICENSE.txt for license information.
|
@@ -38,9 +38,9 @@ void matrix_multiply(big_matrix<T1, NUM_ROWS_C, NUM_COLS_C> &C,
|
38 | 38 | assert(NUM_ROWS_C == NUM_ROWS_A && NUM_COLS_A == NUM_ROWS_B * 4);
|
39 | 39 |
|
40 | 40 | using myparams2 = tpu_params<tpu::amx, int8_t, int8_t, int>;
|
41 |
| - constexpr int TM = myparams2::defaultM; |
42 |
| - constexpr int TN = myparams2::defaultN; |
43 |
| - constexpr int TK = myparams2::defaultK; |
| 41 | + constexpr int TM = myparams2::M; |
| 42 | + constexpr int TN = myparams2::N; |
| 43 | + constexpr int TK = myparams2::K; |
44 | 44 |
|
45 | 45 | std::cout << "AMX query sizes are: M " << TM << " N " << TN << " K " << TK
|
46 | 46 | << std::endl;
|
@@ -74,9 +74,11 @@ void matrix_multiply(big_matrix<T1, NUM_ROWS_C, NUM_COLS_C> &C,
|
74 | 74 |
|
75 | 75 | ext::oneapi::sub_group sg = spmd_item.get_sub_group();
|
76 | 76 |
|
77 |
| - myparams2::joint_matrix_a<sub_group> sub_a; |
78 |
| - myparams2::joint_matrix_b<sub_group> sub_b; |
79 |
| - myparams2::joint_matrix_c<sub_group> sub_c; |
| 77 | + myparams2::joint_matrix_a<sub_group, layout::row_major> sub_a; |
| 78 | + myparams2::joint_matrix_b< |
| 79 | + sub_group, ext::intel::experimental::matrix::layout::packed> |
| 80 | + sub_b; |
| 81 | + myparams2::joint_matrix_accumulator<sub_group> sub_c; |
80 | 82 |
|
81 | 83 | joint_matrix_load(sg, sub_c,
|
82 | 84 | accC.get_pointer() + (sg_startx * TM) * N +
|
|
0 commit comments