@@ -126,23 +126,43 @@ constexpr int verify_parameters_and_deduce_exec_size() {
126
126
127
127
if constexpr (APrecision == dpas_argument_type::FP16 ||
128
128
BPrecision == dpas_argument_type::FP16) {
129
- static_assert (APrecision == BPrecision &&
130
- __ESIMD_DNS::is_type<T, float , sycl::half>() &&
131
- __ESIMD_DNS::is_type<CT, float , sycl::half>(),
132
- " Unsupported DPAS types! The supported types are:\n "
133
- " Result | C | B | A \n "
134
- " f, hf | f, hf | hf | hf \n " );
129
+ if constexpr (ExecutionSize == 8 ) {
130
+ static_assert (APrecision == BPrecision &&
131
+ __ESIMD_DNS::is_type<T, float >() &&
132
+ __ESIMD_DNS::is_type<CT, float >(),
133
+ " Unsupported DPAS types! The supported types are:\n "
134
+ " Result | C | B | A \n "
135
+ " f | f | hf | hf \n " );
136
+ } else {
137
+ static_assert (APrecision == BPrecision &&
138
+ __ESIMD_DNS::is_type<T, float , sycl::half>() &&
139
+ __ESIMD_DNS::is_type<CT, float , sycl::half>(),
140
+ " Unsupported DPAS types! The supported types are:\n "
141
+ " Result | C | B | A \n "
142
+ " f, hf | f, hf | hf | hf \n " );
143
+ }
135
144
} else if constexpr (APrecision == dpas_argument_type::BF16 ||
136
145
BPrecision == dpas_argument_type::BF16) {
137
146
using bfloat16 = sycl::ext::oneapi::experimental::bfloat16;
138
- static_assert (APrecision == BPrecision &&
139
- __ESIMD_DNS::is_type<T, float , bfloat16>() &&
140
- __ESIMD_DNS::is_type<CT, float , bfloat16>(),
141
- " Unsupported DPAS types! The supported types are:\n "
142
- " Result | C | B | A \n "
143
- " f, bf | f, bf | bf | bf \n " );
147
+ if constexpr (ExecutionSize == 8 ) {
148
+ static_assert (APrecision == BPrecision &&
149
+ __ESIMD_DNS::is_type<T, float , bfloat16>() &&
150
+ __ESIMD_DNS::is_type<CT, float , bfloat16>(),
151
+ " Unsupported DPAS types! The supported types are:\n "
152
+ " Result | C | B | A \n "
153
+ " f | f | bf | bf \n " );
154
+ } else {
155
+ static_assert (APrecision == BPrecision &&
156
+ __ESIMD_DNS::is_type<T, float , bfloat16>() &&
157
+ __ESIMD_DNS::is_type<CT, float , bfloat16>(),
158
+ " Unsupported DPAS types! The supported types are:\n "
159
+ " Result | C | B | A \n "
160
+ " f, bf | f, bf | bf | bf \n " );
161
+ }
144
162
} else if constexpr (APrecision == dpas_argument_type::TF32 ||
145
163
BPrecision == dpas_argument_type::TF32) {
164
+ static_assert (ExecutionSize == 16 ,
165
+ " tf32 type can be used only with ExecutionSize=16" );
146
166
static_assert (APrecision == BPrecision && std::is_same_v<T, float > &&
147
167
std::is_same_v<CT, float >,
148
168
" Unsupported DPAS types! The supported types are:\n "
@@ -223,7 +243,7 @@ auto dpas(__ESIMD_NS::simd<BT, BN> B, __ESIMD_NS::simd<AT, AN> A) {
223
243
detail::verify_parameters_and_deduce_exec_size<SystolicDepth, RepeatCount,
224
244
T, T, BT, AT, BPrecision,
225
245
APrecision, BN, AN>();
226
- // Result(_Mx_N) = A(_Mx_K) * B(_Kx_N) + C(_Mx_N)
246
+ // Result(_Mx_N) = A(_Mx_K) * B(_Kx_N)
227
247
// where:
228
248
// _M = RepeatCount;
229
249
// _K = SystolicDepth * OpsPerChannel;
@@ -237,8 +257,10 @@ auto dpas(__ESIMD_NS::simd<BT, BN> B, __ESIMD_NS::simd<AT, AN> A) {
237
257
238
258
constexpr int Info = (RepeatCount << 24 ) + (SystolicDepth << 16 ) +
239
259
((int )APrecision << 8 ) + (int )BPrecision;
240
- return __esimd_dpas_nosrc0<Info, T, int , int , ResultN, BNCasted, ANCasted>(
241
- BCasted.data (), ACasted.data ());
260
+ __ESIMD_NS::simd<T, ResultN> Result =
261
+ __esimd_dpas_nosrc0<Info, T, int , int , ResultN, BNCasted, ANCasted>(
262
+ BCasted.data (), ACasted.data ());
263
+ return Result;
242
264
}
243
265
244
266
/// DPAS (Dot Product Accumulate Systolic)
@@ -283,24 +305,32 @@ template <
283
305
int SystolicDepth, int RepeatCount, typename T, typename BT, typename AT,
284
306
dpas_argument_type BPrecision = detail::dpas_precision_from_type<BT>(),
285
307
dpas_argument_type APrecision = detail::dpas_precision_from_type<AT>(),
286
- int N, int BN, int AN>
287
- __ESIMD_NS::simd<T, N> dpasw (__ESIMD_NS::simd<BT, BN> B,
288
- __ESIMD_NS::simd<AT, AN> A) {
308
+ int BN, int AN>
309
+ auto dpasw (__ESIMD_NS::simd<BT, BN> B, __ESIMD_NS::simd<AT, AN> A) {
289
310
290
311
constexpr bool IsDPASW = true ;
291
- ( void ) detail::verify_parameters_and_deduce_exec_size<
312
+ constexpr int ExecutionSize = detail::verify_parameters_and_deduce_exec_size<
292
313
SystolicDepth, RepeatCount, T, T, BT, AT, BPrecision, APrecision, BN, AN,
293
314
IsDPASW>();
294
315
316
+ // Result(_Mx_N) = A(_Mx_K) * B(_Kx_N)
317
+ // where:
318
+ // _M = RepeatCount;
319
+ // _K = SystolicDepth * OpsPerChannel;
320
+ // _N = ExecutionSize (not a template argument; deduced above by verify_parameters_and_deduce_exec_size), must be 8 or 16.
321
+ constexpr int ResultN = RepeatCount * ExecutionSize;
322
+
295
323
constexpr int ANCasted = AN / (sizeof (int ) / sizeof (AT));
296
324
constexpr int BNCasted = BN / (sizeof (int ) / sizeof (BT));
297
325
__ESIMD_NS::simd<int , ANCasted> ACasted = A.template bit_cast_view <int >();
298
326
__ESIMD_NS::simd<int , BNCasted> BCasted = B.template bit_cast_view <int >();
299
327
300
328
constexpr int Info = (RepeatCount << 24 ) + (SystolicDepth << 16 ) +
301
329
((int )APrecision << 8 ) + (int )BPrecision;
302
- return __esimd_dpasw_nosrc0<Info, T, int , int , N, BNCasted, ANCasted>(
303
- BCasted.data (), ACasted.data ());
330
+ __ESIMD_NS::simd<T, ResultN> Result =
331
+ __esimd_dpasw_nosrc0<Info, T, int , int , ResultN, BNCasted, ANCasted>(
332
+ BCasted.data (), ACasted.data ());
333
+ return Result;
304
334
}
305
335
306
336
/// @} sycl_esimd_xmx_systolic_array_api
0 commit comments