Skip to content

Commit bd5a1f0

Browse files
authored
Merge pull request intel#1534 from dkhaldi/ats-m-double-bug-fix
[SYCL][Matrix] fix ATS-M double bug
2 parents b82d0ca + e9d9119 commit bd5a1f0

File tree

2 files changed

+14
-14
lines changed

2 files changed

+14
-14
lines changed

SYCL/Matrix/Legacy/element_wise_all_ops_half_impl.hpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ void matrix_verify_add(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
4141
ext::oneapi::sub_group sg = spmd_item.get_sub_group();
4242
joint_matrix<T, TM, TK> sub_a(sg);
4343

44-
joint_matrix_fill(sg, sub_a, 5.0);
44+
joint_matrix_fill(sg, sub_a, 5);
4545

4646
auto wi_slice_a = sub_a.get_wi_data();
4747
for (int i = 0; i < wi_slice_a.length(); i++) {
@@ -74,7 +74,7 @@ void matrix_verify_sub(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
7474
ext::oneapi::sub_group sg = spmd_item.get_sub_group();
7575
joint_matrix<T, TM, TK> sub_a(sg);
7676

77-
joint_matrix_fill(sg, sub_a, 5.0);
77+
joint_matrix_fill(sg, sub_a, 5);
7878

7979
auto wi_slice_a = sub_a.get_wi_data();
8080
for (int i = 0; i < wi_slice_a.length(); i++) {
@@ -107,7 +107,7 @@ void matrix_verify_mul(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
107107
ext::oneapi::sub_group sg = spmd_item.get_sub_group();
108108
joint_matrix<T, TM, TK> sub_a(sg);
109109

110-
joint_matrix_fill(sg, sub_a, 5.0);
110+
joint_matrix_fill(sg, sub_a, 5);
111111

112112
auto wi_slice_a = sub_a.get_wi_data();
113113
for (int i = 0; i < wi_slice_a.length(); i++) {
@@ -140,7 +140,7 @@ void matrix_verify_div(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
140140
ext::oneapi::sub_group sg = spmd_item.get_sub_group();
141141
joint_matrix<T, TM, TK> sub_a(sg);
142142

143-
joint_matrix_fill(sg, sub_a, 4.0);
143+
joint_matrix_fill(sg, sub_a, 4);
144144

145145
auto wi_slice_a = sub_a.get_wi_data();
146146
for (int i = 0; i < wi_slice_a.length(); i++) {
@@ -173,7 +173,7 @@ void matrix_verify_logic(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
173173
ext::oneapi::sub_group sg = spmd_item.get_sub_group();
174174
joint_matrix<T, TM, TK> sub_a(sg);
175175

176-
joint_matrix_fill(sg, sub_a, 5.0);
176+
joint_matrix_fill(sg, sub_a, 5);
177177

178178
auto wi_slice_a = sub_a.get_wi_data();
179179
for (int i = 0; i < wi_slice_a.length(); i++) {
@@ -189,8 +189,8 @@ void matrix_verify_logic(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
189189
val++;
190190
if (wi_slice_a[i] == static_cast<half>(2.0)) {
191191
val -= 2;
192-
val *= 3.0;
193-
val /= 2.0;
192+
val *= 3;
193+
val /= 2;
194194
} else {
195195
val += 2;
196196
}

SYCL/Matrix/element_wise_all_ops_half_impl.hpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ void matrix_verify_add(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
4141
sub_group sg = spmd_item.get_sub_group();
4242
joint_matrix<sub_group, T, use::a, TM, TK, layout::row_major> sub_a;
4343

44-
joint_matrix_fill(sg, sub_a, 5.0);
44+
joint_matrix_fill(sg, sub_a, 5);
4545

4646
auto wi_slice_a = get_wi_data(sg, sub_a);
4747
for (int i = 0; i < wi_slice_a.length(); i++) {
@@ -75,7 +75,7 @@ void matrix_verify_sub(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
7575
sub_group sg = spmd_item.get_sub_group();
7676
joint_matrix<sub_group, T, use::a, TM, TK, layout::row_major> sub_a;
7777

78-
joint_matrix_fill(sg, sub_a, 5.0);
78+
joint_matrix_fill(sg, sub_a, 5);
7979

8080
auto wi_slice_a = get_wi_data(sg, sub_a);
8181
for (int i = 0; i < wi_slice_a.length(); i++) {
@@ -109,7 +109,7 @@ void matrix_verify_mul(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
109109
sub_group sg = spmd_item.get_sub_group();
110110
joint_matrix<sub_group, T, use::a, TM, TK, layout::row_major> sub_a;
111111

112-
joint_matrix_fill(sg, sub_a, 5.0);
112+
joint_matrix_fill(sg, sub_a, 5);
113113

114114
auto wi_slice_a = get_wi_data(sg, sub_a);
115115
for (int i = 0; i < wi_slice_a.length(); i++) {
@@ -143,7 +143,7 @@ void matrix_verify_div(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
143143
sub_group sg = spmd_item.get_sub_group();
144144
joint_matrix<sub_group, T, use::a, TM, TK, layout::row_major> sub_a;
145145

146-
joint_matrix_fill(sg, sub_a, 4.0);
146+
joint_matrix_fill(sg, sub_a, 4);
147147

148148
auto wi_slice_a = get_wi_data(sg, sub_a);
149149
for (int i = 0; i < wi_slice_a.length(); i++) {
@@ -177,7 +177,7 @@ void matrix_verify_logic(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
177177
sub_group sg = spmd_item.get_sub_group();
178178
joint_matrix<sub_group, T, use::a, TM, TK, layout::row_major> sub_a;
179179

180-
joint_matrix_fill(sg, sub_a, 5.0);
180+
joint_matrix_fill(sg, sub_a, 5);
181181

182182
auto wi_slice_a = get_wi_data(sg, sub_a);
183183
for (int i = 0; i < wi_slice_a.length(); i++) {
@@ -193,8 +193,8 @@ void matrix_verify_logic(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
193193
val++;
194194
if (wi_slice_a[i] == static_cast<half>(2.0)) {
195195
val -= 2;
196-
val *= 3.0;
197-
val /= 2.0;
196+
val *= 3;
197+
val /= 2;
198198
} else {
199199
val += 2;
200200
}

0 commit comments

Comments
 (0)