@@ -37,42 +37,41 @@ void matrix_multiply(big_matrix<T1, NUM_ROWS_C, NUM_COLS_C> &C,
37
37
cgh.parallel_for <class imatrix >(
38
38
nd_range<2 >({NDRangeM, NDRangeN * SG_SZ}, {1 , 1 * SG_SZ}),
39
39
[accA, accB, accC, M, N, K](nd_item<2 > spmd_item)
40
-
41
- {
42
- // The submatrix API has to be accessed by all the workitems in a
43
- // subgroup these functions will be called once by the subgroup no
44
- // code divergence between the workitems
45
- const auto global_idx = spmd_item.get_global_id (0 );
46
- const auto global_idy = spmd_item.get_global_id (1 );
47
- const auto sg_startx = global_idx - spmd_item.get_local_id (0 );
48
- const auto sg_starty = global_idy - spmd_item.get_local_id (1 );
49
-
50
- sub_group sg = spmd_item.get_sub_group ();
51
- joint_matrix<int8_t , TM, TK> sub_a (sg);
52
- joint_matrix<int8_t , TK, TN, matrix_layout::packed_b> sub_b (sg);
53
- joint_matrix<int32_t , TM, TN> sub_c (sg);
54
-
55
- joint_matrix_fill (sg, sub_c, 0 );
56
- for (int k = 0 ; k < K / TK; k += 1 ) {
57
- joint_matrix_load (
58
- sg, sub_a,
59
- accA.template get_multi_ptr <access::decorated::no>() +
60
- (sg_startx * TM) * K + k * TK,
61
- K, matrix_layout::row_major);
62
- // VNNI transform is done automatically at this level
63
- joint_matrix_load (
64
- sg, sub_b,
65
- accB.template get_multi_ptr <access::decorated::no>() +
66
- (k * TK) * N + sg_starty / SG_SZ * TN,
67
- N, matrix_layout::row_major);
68
- sub_c = joint_matrix_mad (sg, sub_a, sub_b, sub_c);
69
- }
70
- joint_matrix_store (
71
- sg, sub_c,
72
- accC.template get_multi_ptr <access::decorated::no>() +
73
- (sg_startx * TM) * N + sg_starty / SG_SZ * TN,
74
- N, matrix_layout::row_major);
75
- }); // parallel for
40
+ [[intel::reqd_sub_group_size (SG_SZ)]] {
41
+ // The submatrix API has to be accessed by all the workitems in a
42
+ // subgroup; these functions will be called once by the subgroup,
43
+ // with no code divergence between the workitems.
44
+ const auto global_idx = spmd_item.get_global_id (0 );
45
+ const auto global_idy = spmd_item.get_global_id (1 );
46
+ const auto sg_startx = global_idx - spmd_item.get_local_id (0 );
47
+ const auto sg_starty = global_idy - spmd_item.get_local_id (1 );
48
+
49
+ sub_group sg = spmd_item.get_sub_group ();
50
+ joint_matrix<int8_t , TM, TK> sub_a (sg);
51
+ joint_matrix<int8_t , TK, TN, matrix_layout::packed_b> sub_b (sg);
52
+ joint_matrix<int32_t , TM, TN> sub_c (sg);
53
+
54
+ joint_matrix_fill (sg, sub_c, 0 );
55
+ for (int k = 0 ; k < K / TK; k += 1 ) {
56
+ joint_matrix_load (
57
+ sg, sub_a,
58
+ accA.template get_multi_ptr <access::decorated::no>() +
59
+ (sg_startx * TM) * K + k * TK,
60
+ K, matrix_layout::row_major);
61
+ // VNNI transform is done automatically at this level
62
+ joint_matrix_load (
63
+ sg, sub_b,
64
+ accB.template get_multi_ptr <access::decorated::no>() +
65
+ (k * TK) * N + sg_starty / SG_SZ * TN,
66
+ N, matrix_layout::row_major);
67
+ sub_c = joint_matrix_mad (sg, sub_a, sub_b, sub_c);
68
+ }
69
+ joint_matrix_store (
70
+ sg, sub_c,
71
+ accC.template get_multi_ptr <access::decorated::no>() +
72
+ (sg_startx * TM) * N + sg_starty / SG_SZ * TN,
73
+ N, matrix_layout::row_major);
74
+ }); // parallel for
76
75
}).wait ();
77
76
}
78
77
0 commit comments