Skip to content

Commit e0f6391

Browse files
authored
[SYCL][NFC] Cleanup wi_data/joint_matrix code (#7929)
There doesn't appear to be any reason to have an #ifdef in get_wi_data, so I used the template inference in both cases. I've updated the runtime errors to be accurate, and removed some of my old out of date comments. Signed-off-by: JackAKirk <[email protected]>
1 parent f5126d2 commit e0f6391

File tree

1 file changed

+11
-28
lines changed

1 file changed

+11
-28
lines changed

sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp

Lines changed: 11 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ struct joint_matrix {
2828
__spv::__spirv_JointMatrixINTEL<
2929
T, Rows, Cols, spv_matrix_layout_traits<Layout>::value,
3030
spv_scope_traits<Group>::value, spv_matrix_use_traits<Use>::value> *spvm;
31+
#endif // defined(__NVPTX__)
3132
#endif // defined(__SYCL_DEVICE_ONLY__)
32-
#endif
3333

3434
joint_matrix() {
3535
#ifndef __SYCL_DEVICE_ONLY__
@@ -93,12 +93,8 @@ template <typename Group, typename T, use Use, size_t Rows, size_t Cols,
9393
inline __SYCL_ALWAYS_INLINE decltype(auto)
9494
get_wi_data(Group sg, joint_matrix<Group, T, Use, Rows, Cols, Layout> &jm) {
9595
#if defined(__SYCL_DEVICE_ONLY__)
96-
#if defined(__NVPTX__)
9796
std::ignore = sg;
9897
return wi_data(jm);
99-
#else
100-
return wi_data<Group, T, Use, Rows, Cols, Layout>(jm);
101-
#endif // defined(__NVPTX__)
10298
#else
10399
if constexpr (std::is_same_v<T, precision::tf32>) {
104100
marray<float, 1> unused{};
@@ -131,10 +127,8 @@ joint_matrix_fill(Group sg,
131127
std::ignore = sg;
132128
std::ignore = res;
133129
std::ignore = v;
134-
throw runtime_error(
135-
"This version of the matrix extension is only currently supported on "
136-
"Nvidia devices",
137-
PI_ERROR_INVALID_DEVICE);
130+
throw runtime_error("joint matrix is not supported on host device.",
131+
PI_ERROR_INVALID_DEVICE);
138132
#endif // defined(__SYCL_DEVICE_ONLY__)
139133
}
140134

@@ -155,8 +149,6 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_load(
155149
sycl::ext::oneapi::detail::load_accumulator_cuda(res.cuda_impl, src, stride,
156150
Layout);
157151
#else
158-
// intel's impl
159-
// matL is determined by matrix.use?
160152
T *Ptr = src.get();
161153
switch (Layout) {
162154
default:
@@ -189,10 +181,8 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_load(
189181
std::ignore = res;
190182
std::ignore = src;
191183
std::ignore = stride;
192-
throw runtime_error(
193-
"This version of the matrix extension is only currently supported on "
194-
"Nvidia devices",
195-
PI_ERROR_INVALID_DEVICE);
184+
throw runtime_error("joint matrix is not supported on host device.",
185+
PI_ERROR_INVALID_DEVICE);
196186
#endif // defined(__SYCL_DEVICE_ONLY__)
197187
}
198188

@@ -228,10 +218,8 @@ joint_matrix_load(Group sg,
228218
std::ignore = res;
229219
std::ignore = src;
230220
std::ignore = stride;
231-
throw runtime_error(
232-
"This version of the matrix extension is only currently supported on "
233-
"Nvidia devices",
234-
PI_ERROR_INVALID_DEVICE);
221+
throw runtime_error("joint matrix is not supported on host device.",
222+
PI_ERROR_INVALID_DEVICE);
235223
#endif // defined(__SYCL_DEVICE_ONLY__)
236224
}
237225

@@ -250,7 +238,6 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_store(
250238
Space>(src.cuda_impl, dst,
251239
stride, Layout);
252240
#else
253-
// intel's impl
254241
T *Ptr = dst.get();
255242
switch (Layout) {
256243
default:
@@ -283,10 +270,8 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_store(
283270
std::ignore = src;
284271
std::ignore = dst;
285272
std::ignore = stride;
286-
throw runtime_error(
287-
"This version of the matrix extension is only currently supported on "
288-
"Nvidia devices",
289-
PI_ERROR_INVALID_DEVICE);
273+
throw runtime_error("joint matrix is not supported on host device.",
274+
PI_ERROR_INVALID_DEVICE);
290275
#endif // defined(__SYCL_DEVICE_ONLY__)
291276
}
292277

@@ -337,10 +322,8 @@ inline __SYCL_ALWAYS_INLINE
337322
std::ignore = A;
338323
std::ignore = B;
339324
std::ignore = C;
340-
throw runtime_error(
341-
"This version of the matrix extension is only currently supported on "
342-
"Nvidia devices",
343-
PI_ERROR_INVALID_DEVICE);
325+
throw runtime_error("joint matrix is not supported on host device.",
326+
PI_ERROR_INVALID_DEVICE);
344327
#endif // defined(__SYCL_DEVICE_ONLY__)
345328
}
346329

0 commit comments

Comments
 (0)