Skip to content

Commit 07c5f28

Browse files
address JarkAKirk's comments
1 parent 2a178bc commit 07c5f28

File tree

1 file changed

+42
-46
lines changed

1 file changed

+42
-46
lines changed

sycl/include/sycl/ext/oneapi/matrix/matrix-jit-use.hpp

Lines changed: 42 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -90,47 +90,43 @@ struct joint_matrix {
9090
};
9191

9292
template <typename Group, typename T, size_t NumRows, size_t NumCols, use Use,
93-
layout Layout, access::address_space Space>
94-
inline __SYCL_ALWAYS_INLINE void
95-
joint_matrix_load(Group sg,
96-
joint_matrix<T, NumRows, NumCols, Use, Layout, Group> &res,
97-
multi_ptr<T, Space> src, size_t stride, layout MemL) {
93+
access::address_space Space>
94+
inline __SYCL_ALWAYS_INLINE void joint_matrix_load(
95+
Group sg,
96+
joint_matrix<T, NumRows, NumCols, Use, layout::unused, Group> &res,
97+
multi_ptr<T, Space> src, size_t stride, layout MemL) {
9898
#ifdef __SYCL_DEVICE_ONLY__
9999
T *Ptr = src.get();
100100
switch (MemL) {
101101
default:
102102
assert(false && "Invalid Memory Layout!");
103103
case layout::row_major:
104-
res.spvm =
105-
__spirv_JointMatrixLoadINTEL<T, NumRows, NumCols,
106-
spv_matrix_use_traits<Use>::value,
107-
spv_matrix_layout_traits<Layout>::value>(
108-
Ptr, stride, __spv::MatrixLayout::RowMajor,
109-
spv_scope_traits<Group>::value);
104+
res.spvm = __spirv_JointMatrixLoadINTEL<
105+
T, NumRows, NumCols, spv_matrix_use_traits<Use>::value,
106+
spv_matrix_layout_traits<layout::unused>::value>(
107+
Ptr, stride, __spv::MatrixLayout::RowMajor,
108+
spv_scope_traits<Group>::value);
110109
break;
111110
case layout::col_major:
112-
res.spvm =
113-
__spirv_JointMatrixLoadINTEL<T, NumRows, NumCols,
114-
spv_matrix_use_traits<Use>::value,
115-
spv_matrix_layout_traits<Layout>::value>(
116-
Ptr, stride, __spv::MatrixLayout::ColumnMajor,
117-
spv_scope_traits<Group>::value);
111+
res.spvm = __spirv_JointMatrixLoadINTEL<
112+
T, NumRows, NumCols, spv_matrix_use_traits<Use>::value,
113+
spv_matrix_layout_traits<layout::unused>::value>(
114+
Ptr, stride, __spv::MatrixLayout::ColumnMajor,
115+
spv_scope_traits<Group>::value);
118116
break;
119117
case layout::packed_a:
120-
res.spvm =
121-
__spirv_JointMatrixLoadINTEL<T, NumRows, NumCols,
122-
spv_matrix_use_traits<Use>::value,
123-
spv_matrix_layout_traits<Layout>::value>(
124-
Ptr, stride, __spv::MatrixLayout::PackedA,
125-
spv_scope_traits<Group>::value);
118+
res.spvm = __spirv_JointMatrixLoadINTEL<
119+
T, NumRows, NumCols, spv_matrix_use_traits<Use>::value,
120+
spv_matrix_layout_traits<layout::unused>::value>(
121+
Ptr, stride, __spv::MatrixLayout::PackedA,
122+
spv_scope_traits<Group>::value);
126123
break;
127124
case layout::packed_b:
128-
res.spvm =
129-
__spirv_JointMatrixLoadINTEL<T, NumRows, NumCols,
130-
spv_matrix_use_traits<Use>::value,
131-
spv_matrix_layout_traits<Layout>::value>(
132-
Ptr, stride, __spv::MatrixLayout::PackedB,
133-
spv_scope_traits<Group>::value);
125+
res.spvm = __spirv_JointMatrixLoadINTEL<
126+
T, NumRows, NumCols, spv_matrix_use_traits<Use>::value,
127+
spv_matrix_layout_traits<layout::unused>::value>(
128+
Ptr, stride, __spv::MatrixLayout::PackedB,
129+
spv_scope_traits<Group>::value);
134130
break;
135131
}
136132
#else
@@ -145,41 +141,41 @@ joint_matrix_load(Group sg,
145141
}
146142

147143
template <typename Group, typename T, size_t NumRows, size_t NumCols, use Use,
148-
layout MatL, access::address_space Space>
149-
inline __SYCL_ALWAYS_INLINE void
150-
joint_matrix_store(Group sg,
151-
joint_matrix<T, NumRows, NumCols, Use, MatL, Group> &src,
152-
multi_ptr<T, Space> res, size_t stride, layout MemL) {
144+
access::address_space Space>
145+
inline __SYCL_ALWAYS_INLINE void joint_matrix_store(
146+
Group sg,
147+
joint_matrix<T, NumRows, NumCols, Use, layout::unused, Group> &src,
148+
multi_ptr<T, Space> res, size_t stride, layout MemL) {
153149
#ifdef __SYCL_DEVICE_ONLY__
154150
T *Ptr = res.get();
155151
switch (MemL) {
156152
default:
157153
assert(false && "Invalid Memory Layout!");
158154
case layout::row_major:
159-
__spirv_JointMatrixStoreINTEL<T, NumRows, NumCols,
160-
spv_matrix_use_traits<Use>::value,
161-
spv_matrix_layout_traits<MatL>::value>(
155+
__spirv_JointMatrixStoreINTEL<
156+
T, NumRows, NumCols, spv_matrix_use_traits<Use>::value,
157+
spv_matrix_layout_traits<layout::unused>::value>(
162158
Ptr, src.spvm, stride, __spv::MatrixLayout::RowMajor,
163159
spv_scope_traits<Group>::value);
164160
break;
165161
case layout::col_major:
166-
__spirv_JointMatrixStoreINTEL<T, NumRows, NumCols,
167-
spv_matrix_use_traits<Use>::value,
168-
spv_matrix_layout_traits<MatL>::value>(
162+
__spirv_JointMatrixStoreINTEL<
163+
T, NumRows, NumCols, spv_matrix_use_traits<Use>::value,
164+
spv_matrix_layout_traits<layout::unused>::value>(
169165
Ptr, src.spvm, stride, __spv::MatrixLayout::ColumnMajor,
170166
spv_scope_traits<Group>::value);
171167
break;
172168
case layout::packed_a:
173-
__spirv_JointMatrixStoreINTEL<T, NumRows, NumCols,
174-
spv_matrix_use_traits<Use>::value,
175-
spv_matrix_layout_traits<MatL>::value>(
169+
__spirv_JointMatrixStoreINTEL<
170+
T, NumRows, NumCols, spv_matrix_use_traits<Use>::value,
171+
spv_matrix_layout_traits<layout::unused>::value>(
176172
Ptr, src.spvm, stride, __spv::MatrixLayout::PackedA,
177173
spv_scope_traits<Group>::value);
178174
break;
179175
case layout::packed_b:
180-
__spirv_JointMatrixStoreINTEL<T, NumRows, NumCols,
181-
spv_matrix_use_traits<Use>::value,
182-
spv_matrix_layout_traits<MatL>::value>(
176+
__spirv_JointMatrixStoreINTEL<
177+
T, NumRows, NumCols, spv_matrix_use_traits<Use>::value,
178+
spv_matrix_layout_traits<layout::unused>::value>(
183179
Ptr, src.spvm, stride, __spv::MatrixLayout::PackedB,
184180
spv_scope_traits<Group>::value);
185181
break;

0 commit comments

Comments
 (0)