@@ -52,7 +52,7 @@ void matrix_verify_add(queue q, big_matrix<Ts, M, N> &A, nd_range<2> &r,
52
52
ext::intel::experimental::matrix::joint_matrix_store (
53
53
sg, sub_a,
54
54
accA.template get_multi_ptr <access::decorated::no>() +
55
- (sg_startx * TM) * N + sg_starty / SG_SZ * TN ,
55
+ (sg_startx * TM) * N + sg_starty / SG_SZ * TK ,
56
56
N);
57
57
}); // parallel for
58
58
}).wait ();
@@ -87,7 +87,7 @@ void matrix_verify_sub(queue q, big_matrix<Ts, M, N> &A, nd_range<2> &r,
87
87
ext::intel::experimental::matrix::joint_matrix_store (
88
88
sg, sub_a,
89
89
accA.template get_multi_ptr <access::decorated::no>() +
90
- (sg_startx * TM) * N + sg_starty / SG_SZ * TN ,
90
+ (sg_startx * TM) * N + sg_starty / SG_SZ * TK ,
91
91
N);
92
92
}); // parallel for
93
93
}).wait ();
@@ -121,7 +121,7 @@ void matrix_verify_mul(queue q, big_matrix<Ts, M, N> &A, nd_range<2> &r,
121
121
ext::intel::experimental::matrix::joint_matrix_store (
122
122
sg, sub_a,
123
123
accA.template get_multi_ptr <access::decorated::no>() +
124
- (sg_startx * TM) * N + sg_starty / SG_SZ * TN ,
124
+ (sg_startx * TM) * N + sg_starty / SG_SZ * TK ,
125
125
N);
126
126
}); // parallel for
127
127
}).wait ();
@@ -156,7 +156,7 @@ void matrix_verify_div(queue q, big_matrix<Ts, M, N> &A, nd_range<2> &r,
156
156
ext::intel::experimental::matrix::joint_matrix_store (
157
157
sg, sub_a,
158
158
accA.template get_multi_ptr <access::decorated::no>() +
159
- (sg_startx * TM) * N + sg_starty / SG_SZ * TN ,
159
+ (sg_startx * TM) * N + sg_starty / SG_SZ * TK ,
160
160
N);
161
161
}); // parallel for
162
162
}).wait ();
@@ -206,7 +206,7 @@ void matrix_verify_logic(queue q, big_matrix<Ts, M, N> &A, nd_range<2> &r,
206
206
ext::intel::experimental::matrix::joint_matrix_store (
207
207
sg, sub_a,
208
208
accA.template get_multi_ptr <access::decorated::no>() +
209
- (sg_startx * TM) * N + sg_starty / SG_SZ * TN ,
209
+ (sg_startx * TM) * N + sg_starty / SG_SZ * TK ,
210
210
N);
211
211
}); // parallel for
212
212
}).wait ();
@@ -224,7 +224,7 @@ int main() {
224
224
big_matrix<float , MATRIX_M, MATRIX_N> MA ((float *)&A);
225
225
226
226
size_t NDRangeM = MATRIX_M / TM;
227
- size_t NDRangeN = MATRIX_N / TN ;
227
+ size_t NDRangeN = MATRIX_N / TK ;
228
228
queue q;
229
229
nd_range<2 > r ({NDRangeM, NDRangeN * SG_SZ}, {1 , 1 * SG_SZ});
230
230
0 commit comments