Skip to content

Commit 2155906

Browse files
authored
[SYCL][Joint Matrix] Add more cases in common JM tests functions - Part 2 (#16021)
1 parent 4bca246 commit 2155906

File tree

1 file changed

+18
-16
lines changed

1 file changed

+18
-16
lines changed

sycl/test-e2e/Matrix/element_wise_all_ops_impl.hpp

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ void verify_op_c(const T l, const T r, const float ref, OP op) {
109109
template <typename T, size_t SROWS, size_t SCOLS, use Use, class name>
110110
class ewops_ab {};
111111
template <typename T, size_t SROWS, size_t SCOLS, use Use, layout Layout,
112-
size_t VF>
112+
size_t VF, typename Tv = T>
113113
void test_ewops_ab() {
114114
if constexpr (Use == use::a)
115115
std::cout << "Test A ";
@@ -122,41 +122,43 @@ void test_ewops_ab() {
122122

123123
verify_op_ab<T, NROWS, NCOLS, SROWS, SCOLS, Use, Layout, VF,
124124
ewops_ab<T, SROWS, SCOLS, Use, class ab_add>>(
125-
T(5.0), T(2.0), 7.0, [](auto l, auto r) { return l + r; });
125+
Tv(5.0), Tv(2.0), 7.0, [](auto l, auto r) { return l + r; });
126126
verify_op_ab<T, NROWS, NCOLS, SROWS, SCOLS, Use, Layout, VF,
127127
ewops_ab<T, SROWS, SCOLS, Use, class ab_sub>>(
128-
T(5.0), T(2.0), 3.0, [](auto l, auto r) { return l - r; });
128+
Tv(5.0), Tv(2.0), 3.0, [](auto l, auto r) { return l - r; });
129129
verify_op_ab<T, NROWS, NCOLS, SROWS, SCOLS, Use, Layout, VF,
130130
ewops_ab<T, SROWS, SCOLS, Use, class ab_mul>>(
131-
T(5.0), T(2.0), 10.0, [](auto l, auto r) { return l * r; });
131+
Tv(5.0), Tv(2.0), 10.0, [](auto l, auto r) { return l * r; });
132132
verify_op_ab<T, NROWS, NCOLS, SROWS, SCOLS, Use, Layout, VF,
133133
ewops_ab<T, SROWS, SCOLS, Use, class ab_div>>(
134-
T(5.0), T(2.0), 2.5, [](auto l, auto r) { return l / r; });
134+
Tv(5.0), Tv(2.0), 2.5, [](auto l, auto r) { return l / r; });
135135
verify_op_ab<T, NROWS, NCOLS, SROWS, SCOLS, Use, Layout, VF,
136136
ewops_ab<T, SROWS, SCOLS, Use, class ab_logical>>(
137-
T(5.0), T(5.0), 5.0, [](auto l, auto r) { return l == r ? l : T(1.0); });
137+
Tv(5.0), Tv(5.0), 5.0,
138+
[](auto l, auto r) { return l == r ? l : Tv(1.0); });
138139
verify_op_ab<T, NROWS, NCOLS, SROWS, SCOLS, Use, Layout, VF,
139140
ewops_ab<T, SROWS, SCOLS, Use, class ab_eq>>(
140-
T(5.0), T(4.0), 4.0, [](auto l, auto r) { return l == r ? l : r; });
141+
Tv(5.0), Tv(4.0), 4.0, [](auto l, auto r) { return l == r ? l : r; });
141142
verify_op_ab<T, NROWS, NCOLS, SROWS, SCOLS, Use, Layout, VF,
142143
ewops_ab<T, SROWS, SCOLS, Use, class ab_ne>>(
143-
T(5.0), T(5.0), 1.0, [](auto l, auto r) { return l != r ? l : T(1.0); });
144+
Tv(5.0), Tv(5.0), 1.0,
145+
[](auto l, auto r) { return l != r ? l : Tv(1.0); });
144146
verify_op_ab<T, NROWS, NCOLS, SROWS, SCOLS, Use, Layout, VF,
145147
ewops_ab<T, SROWS, SCOLS, Use, class ab_gt>>(
146-
T(5.0), T(2.0), 3.0,
147-
[](auto l, auto r) { return l > r ? T(3.0) : T(2.0); });
148+
Tv(5.0), Tv(2.0), 3.0,
149+
[](auto l, auto r) { return l > r ? Tv(3.0) : Tv(2.0); });
148150
verify_op_ab<T, NROWS, NCOLS, SROWS, SCOLS, Use, Layout, VF,
149151
ewops_ab<T, SROWS, SCOLS, Use, class ab_lt>>(
150-
T(5.0), T(2.0), 2.0,
151-
[](auto l, auto r) { return l < r ? T(3.0) : T(2.0); });
152+
Tv(5.0), Tv(2.0), 2.0,
153+
[](auto l, auto r) { return l < r ? Tv(3.0) : Tv(2.0); });
152154
verify_op_ab<T, NROWS, NCOLS, SROWS, SCOLS, Use, Layout, VF,
153155
ewops_ab<T, SROWS, SCOLS, Use, class ab_ge>>(
154-
T(5.0), T(2.0), 3.0,
155-
[](auto l, auto r) { return l >= r ? T(3.0) : T(2.0); });
156+
Tv(5.0), Tv(2.0), 3.0,
157+
[](auto l, auto r) { return l >= r ? Tv(3.0) : Tv(2.0); });
156158
verify_op_ab<T, NROWS, NCOLS, SROWS, SCOLS, Use, Layout, VF,
157159
ewops_ab<T, SROWS, SCOLS, Use, class ab_le>>(
158-
T(5.0), T(2.0), 2.0,
159-
[](auto l, auto r) { return l <= r ? T(3.0) : T(2.0); });
160+
Tv(5.0), Tv(2.0), 2.0,
161+
[](auto l, auto r) { return l <= r ? Tv(3.0) : Tv(2.0); });
160162
}
161163

162164
// Avoid same kernel name for different types and numbers of columns

0 commit comments

Comments
 (0)