@@ -90,47 +90,43 @@ struct joint_matrix {
90
90
};
91
91
92
92
template <typename Group, typename T, size_t NumRows, size_t NumCols, use Use,
93
- layout Layout, access::address_space Space>
94
- inline __SYCL_ALWAYS_INLINE void
95
- joint_matrix_load ( Group sg,
96
- joint_matrix<T, NumRows, NumCols, Use, Layout , Group> &res,
97
- multi_ptr<T, Space> src, size_t stride, layout MemL) {
93
+ access::address_space Space>
94
+ inline __SYCL_ALWAYS_INLINE void joint_matrix_load (
95
+ Group sg,
96
+ joint_matrix<T, NumRows, NumCols, Use, layout::unused , Group> &res,
97
+ multi_ptr<T, Space> src, size_t stride, layout MemL) {
98
98
#ifdef __SYCL_DEVICE_ONLY__
99
99
T *Ptr = src.get ();
100
100
switch (MemL) {
101
101
default :
102
102
assert (false && " Invalid Memory Layout!" );
103
103
case layout::row_major:
104
- res.spvm =
105
- __spirv_JointMatrixLoadINTEL<T, NumRows, NumCols,
106
- spv_matrix_use_traits<Use>::value,
107
- spv_matrix_layout_traits<Layout>::value>(
108
- Ptr, stride, __spv::MatrixLayout::RowMajor,
109
- spv_scope_traits<Group>::value);
104
+ res.spvm = __spirv_JointMatrixLoadINTEL<
105
+ T, NumRows, NumCols, spv_matrix_use_traits<Use>::value,
106
+ spv_matrix_layout_traits<layout::unused>::value>(
107
+ Ptr, stride, __spv::MatrixLayout::RowMajor,
108
+ spv_scope_traits<Group>::value);
110
109
break ;
111
110
case layout::col_major:
112
- res.spvm =
113
- __spirv_JointMatrixLoadINTEL<T, NumRows, NumCols,
114
- spv_matrix_use_traits<Use>::value,
115
- spv_matrix_layout_traits<Layout>::value>(
116
- Ptr, stride, __spv::MatrixLayout::ColumnMajor,
117
- spv_scope_traits<Group>::value);
111
+ res.spvm = __spirv_JointMatrixLoadINTEL<
112
+ T, NumRows, NumCols, spv_matrix_use_traits<Use>::value,
113
+ spv_matrix_layout_traits<layout::unused>::value>(
114
+ Ptr, stride, __spv::MatrixLayout::ColumnMajor,
115
+ spv_scope_traits<Group>::value);
118
116
break ;
119
117
case layout::packed_a:
120
- res.spvm =
121
- __spirv_JointMatrixLoadINTEL<T, NumRows, NumCols,
122
- spv_matrix_use_traits<Use>::value,
123
- spv_matrix_layout_traits<Layout>::value>(
124
- Ptr, stride, __spv::MatrixLayout::PackedA,
125
- spv_scope_traits<Group>::value);
118
+ res.spvm = __spirv_JointMatrixLoadINTEL<
119
+ T, NumRows, NumCols, spv_matrix_use_traits<Use>::value,
120
+ spv_matrix_layout_traits<layout::unused>::value>(
121
+ Ptr, stride, __spv::MatrixLayout::PackedA,
122
+ spv_scope_traits<Group>::value);
126
123
break ;
127
124
case layout::packed_b:
128
- res.spvm =
129
- __spirv_JointMatrixLoadINTEL<T, NumRows, NumCols,
130
- spv_matrix_use_traits<Use>::value,
131
- spv_matrix_layout_traits<Layout>::value>(
132
- Ptr, stride, __spv::MatrixLayout::PackedB,
133
- spv_scope_traits<Group>::value);
125
+ res.spvm = __spirv_JointMatrixLoadINTEL<
126
+ T, NumRows, NumCols, spv_matrix_use_traits<Use>::value,
127
+ spv_matrix_layout_traits<layout::unused>::value>(
128
+ Ptr, stride, __spv::MatrixLayout::PackedB,
129
+ spv_scope_traits<Group>::value);
134
130
break ;
135
131
}
136
132
#else
@@ -145,41 +141,41 @@ joint_matrix_load(Group sg,
145
141
}
146
142
147
143
template <typename Group, typename T, size_t NumRows, size_t NumCols, use Use,
148
- layout MatL, access::address_space Space>
149
- inline __SYCL_ALWAYS_INLINE void
150
- joint_matrix_store ( Group sg,
151
- joint_matrix<T, NumRows, NumCols, Use, MatL , Group> &src,
152
- multi_ptr<T, Space> res, size_t stride, layout MemL) {
144
+ access::address_space Space>
145
+ inline __SYCL_ALWAYS_INLINE void joint_matrix_store (
146
+ Group sg,
147
+ joint_matrix<T, NumRows, NumCols, Use, layout::unused , Group> &src,
148
+ multi_ptr<T, Space> res, size_t stride, layout MemL) {
153
149
#ifdef __SYCL_DEVICE_ONLY__
154
150
T *Ptr = res.get ();
155
151
switch (MemL) {
156
152
default :
157
153
assert (false && " Invalid Memory Layout!" );
158
154
case layout::row_major:
159
- __spirv_JointMatrixStoreINTEL<T, NumRows, NumCols,
160
- spv_matrix_use_traits<Use>::value,
161
- spv_matrix_layout_traits<MatL >::value>(
155
+ __spirv_JointMatrixStoreINTEL<
156
+ T, NumRows, NumCols, spv_matrix_use_traits<Use>::value,
157
+ spv_matrix_layout_traits<layout::unused >::value>(
162
158
Ptr, src.spvm , stride, __spv::MatrixLayout::RowMajor,
163
159
spv_scope_traits<Group>::value);
164
160
break ;
165
161
case layout::col_major:
166
- __spirv_JointMatrixStoreINTEL<T, NumRows, NumCols,
167
- spv_matrix_use_traits<Use>::value,
168
- spv_matrix_layout_traits<MatL >::value>(
162
+ __spirv_JointMatrixStoreINTEL<
163
+ T, NumRows, NumCols, spv_matrix_use_traits<Use>::value,
164
+ spv_matrix_layout_traits<layout::unused >::value>(
169
165
Ptr, src.spvm , stride, __spv::MatrixLayout::ColumnMajor,
170
166
spv_scope_traits<Group>::value);
171
167
break ;
172
168
case layout::packed_a:
173
- __spirv_JointMatrixStoreINTEL<T, NumRows, NumCols,
174
- spv_matrix_use_traits<Use>::value,
175
- spv_matrix_layout_traits<MatL >::value>(
169
+ __spirv_JointMatrixStoreINTEL<
170
+ T, NumRows, NumCols, spv_matrix_use_traits<Use>::value,
171
+ spv_matrix_layout_traits<layout::unused >::value>(
176
172
Ptr, src.spvm , stride, __spv::MatrixLayout::PackedA,
177
173
spv_scope_traits<Group>::value);
178
174
break ;
179
175
case layout::packed_b:
180
- __spirv_JointMatrixStoreINTEL<T, NumRows, NumCols,
181
- spv_matrix_use_traits<Use>::value,
182
- spv_matrix_layout_traits<MatL >::value>(
176
+ __spirv_JointMatrixStoreINTEL<
177
+ T, NumRows, NumCols, spv_matrix_use_traits<Use>::value,
178
+ spv_matrix_layout_traits<layout::unused >::value>(
183
179
Ptr, src.spvm , stride, __spv::MatrixLayout::PackedB,
184
180
spv_scope_traits<Group>::value);
185
181
break ;
0 commit comments