Skip to content

Commit daea2ba

Browse files
authored
[SYCL][Matrix] Fix the bug in element_wise_all_ops_tf32_impl.hpp (#10469)
1 parent 65cc0cf commit daea2ba

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

sycl/test-e2e/Matrix/element_wise_all_ops_tf32_impl.hpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ void matrix_verify_add(queue q, big_matrix<Ts, M, N> &A, nd_range<2> &r,
5252
ext::intel::experimental::matrix::joint_matrix_store(
5353
sg, sub_a,
5454
accA.template get_multi_ptr<access::decorated::no>() +
55-
(sg_startx * TM) * N + sg_starty / SG_SZ * TN,
55+
(sg_startx * TM) * N + sg_starty / SG_SZ * TK,
5656
N);
5757
}); // parallel for
5858
}).wait();
@@ -87,7 +87,7 @@ void matrix_verify_sub(queue q, big_matrix<Ts, M, N> &A, nd_range<2> &r,
8787
ext::intel::experimental::matrix::joint_matrix_store(
8888
sg, sub_a,
8989
accA.template get_multi_ptr<access::decorated::no>() +
90-
(sg_startx * TM) * N + sg_starty / SG_SZ * TN,
90+
(sg_startx * TM) * N + sg_starty / SG_SZ * TK,
9191
N);
9292
}); // parallel for
9393
}).wait();
@@ -121,7 +121,7 @@ void matrix_verify_mul(queue q, big_matrix<Ts, M, N> &A, nd_range<2> &r,
121121
ext::intel::experimental::matrix::joint_matrix_store(
122122
sg, sub_a,
123123
accA.template get_multi_ptr<access::decorated::no>() +
124-
(sg_startx * TM) * N + sg_starty / SG_SZ * TN,
124+
(sg_startx * TM) * N + sg_starty / SG_SZ * TK,
125125
N);
126126
}); // parallel for
127127
}).wait();
@@ -156,7 +156,7 @@ void matrix_verify_div(queue q, big_matrix<Ts, M, N> &A, nd_range<2> &r,
156156
ext::intel::experimental::matrix::joint_matrix_store(
157157
sg, sub_a,
158158
accA.template get_multi_ptr<access::decorated::no>() +
159-
(sg_startx * TM) * N + sg_starty / SG_SZ * TN,
159+
(sg_startx * TM) * N + sg_starty / SG_SZ * TK,
160160
N);
161161
}); // parallel for
162162
}).wait();
@@ -206,7 +206,7 @@ void matrix_verify_logic(queue q, big_matrix<Ts, M, N> &A, nd_range<2> &r,
206206
ext::intel::experimental::matrix::joint_matrix_store(
207207
sg, sub_a,
208208
accA.template get_multi_ptr<access::decorated::no>() +
209-
(sg_startx * TM) * N + sg_starty / SG_SZ * TN,
209+
(sg_startx * TM) * N + sg_starty / SG_SZ * TK,
210210
N);
211211
}); // parallel for
212212
}).wait();
@@ -224,7 +224,7 @@ int main() {
224224
big_matrix<float, MATRIX_M, MATRIX_N> MA((float *)&A);
225225

226226
size_t NDRangeM = MATRIX_M / TM;
227-
size_t NDRangeN = MATRIX_N / TN;
227+
size_t NDRangeN = MATRIX_N / TK;
228228
queue q;
229229
nd_range<2> r({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ});
230230

0 commit comments

Comments
 (0)