@@ -50,8 +50,8 @@ std::string toString(dpas_argument_type T) {
50
50
return " bf16" ;
51
51
case dpas_argument_type::tf32:
52
52
return " tf32" ;
53
- case dpas_argument_type::S1 :
54
- case dpas_argument_type::U1 :
53
+ case dpas_argument_type::s1 :
54
+ case dpas_argument_type::u1 :
55
55
case dpas_argument_type::Invalid:
56
56
return " UNSUPPORTED" ;
57
57
}
@@ -65,7 +65,7 @@ template <dpas_argument_type T> struct DpasPrintType {
65
65
static constexpr bool is_uint = T == dpas_argument_type::u2 ||
66
66
T == dpas_argument_type::u4 ||
67
67
T == dpas_argument_type::u8 ;
68
- static constexpr bool is_fp = T == dpas_argument_type::FP16 ||
68
+ static constexpr bool is_fp = T == dpas_argument_type::fp16 ||
69
69
T == dpas_argument_type::bf16 ||
70
70
T == dpas_argument_type::tf32;
71
71
@@ -100,7 +100,7 @@ template <dpas_argument_type T> struct DpasNaturalOperandType {
100
100
is_uint, unsigned char ,
101
101
std::conditional_t <
102
102
is_fp16, sycl::half,
103
- std::conditional <
103
+ std::conditional_t <
104
104
is_bf16, sycl::ext::oneapi::experimental::bfloat16, void >>>>;
105
105
};
106
106
@@ -123,6 +123,11 @@ template <dpas_argument_type T> constexpr int getBitSize() {
123
123
124
124
case dpas_argument_type::tf32:
125
125
return 32 ;
126
+
127
+ case dpas_argument_type::Invalid:
128
+ case dpas_argument_type::s1:
129
+ case dpas_argument_type::u1:
130
+ break ;
126
131
}
127
132
return 0 ;
128
133
}
@@ -282,7 +287,8 @@ void printMatrix(void *Vec, std::string Msg) {
282
287
}
283
288
284
289
template <int SystolicDepth, int RepeatCount, dpas_argument_type BPrec,
285
- dpas_argument_type APrec, bool UseSrc0>
290
+ dpas_argument_type APrec, bool UseSrc0, int ExecSize,
291
+ bool LetDeduceArgs>
286
292
bool test (queue &Q, bool Print) {
287
293
constexpr unsigned Size = 128 ;
288
294
constexpr unsigned VL = 16 ;
@@ -300,12 +306,13 @@ bool test(queue &Q, bool Print) {
300
306
// where:
301
307
constexpr int M = RepeatCount;
302
308
constexpr int K = SystolicDepth * OpsPerChannel;
303
- constexpr int N = 16 ; // Execution size: 16 for PVC.
309
+ constexpr int N = ExecSize ; // 16 for PVC, 8 for DG2 .
304
310
305
311
auto Dev = Q.get_device ();
306
- std::cout << " Running test case " << toString (BPrec, APrec)
307
- << " with UseSrc0 = " << UseSrc0 << " on "
308
- << Dev.get_info <info::device::name>() << " \n " ;
312
+ std::cout << " Running on " << Dev.get_info <info::device::name>()
313
+ << " (ExecSize = " << ExecSize << " ): " << toString (BPrec, APrec)
314
+ << " , UseSrc0 = " << UseSrc0
315
+ << " , LetDeduceArgs = " << LetDeduceArgs << std::endl;
309
316
310
317
using ANaturalType = typename DpasNaturalOperandType<APrec>::type;
311
318
using BNaturalType = typename DpasNaturalOperandType<BPrec>::type;
@@ -317,6 +324,7 @@ bool test(queue &Q, bool Print) {
317
324
auto BPacked = aligned_alloc_shared<BNaturalType>(128 , BPackedSize, Q);
318
325
auto Res = aligned_alloc_shared<ResNaturalType>(128 , M * N, Q);
319
326
// Init APacked;
327
+ <<<<<<< HEAD
320
328
<<<<<<< HEAD
321
329
int Value = 1 ;
322
330
for (int II = 0 ; II < M; II++) {
@@ -328,6 +336,12 @@ bool test(queue &Q, bool Print) {
328
336
for (int JJ = 0 ; JJ < K; JJ++) {
329
337
Value++;
330
338
>>>>>>> 78be3ae16 ([ESIMD] Add tests for new esimd::xmx:dpas API (#1281 ))
339
+ =======
340
+ float Value = 1.2 ;
341
+ for (int II = 0 ; II < M; II++) {
342
+ for (int JJ = 0 ; JJ < K; JJ++) {
343
+ Value += 1.1 ;
344
+ >>>>>>> 7fc11d5ad ([ESIMD] Add more tests for new xmx::dpas () (#1291 ))
331
345
writeToHorizontallyPackedMatrix<M, K, APrec>(
332
346
APacked, II, JJ, static_cast <ANaturalType>(Value));
333
347
}
@@ -357,15 +371,27 @@ bool test(queue &Q, bool Print) {
357
371
simd<BNaturalType, BPackedSize> B (BPacked, overaligned_tag<16 >{});
358
372
simd<ResNaturalType, M * N> C;
359
373
360
- if constexpr (UseSrc0) {
361
- // Compute C = C + AxB;
362
- C = 1 ;
363
- C = dpas<8 , RepeatCount, ResNaturalType, ResNaturalType, BNaturalType,
364
- ANaturalType, BPrec, APrec>(C, B, A);
374
+ if constexpr (LetDeduceArgs) {
375
+ if constexpr (UseSrc0) {
376
+ // Compute C = C + AxB;
377
+ C = 1 ;
378
+ C = dpas<8 , RepeatCount, ResNaturalType>(C, B, A);
379
+ } else {
380
+ // Compute C = AxB;
381
+ C = dpas<8 , RepeatCount, ResNaturalType>(B, A);
382
+ }
383
+
365
384
} else {
366
- // Compute C = AxB;
367
- C = dpas<8 , RepeatCount, ResNaturalType, BNaturalType, ANaturalType,
368
- BPrec, APrec>(B, A);
385
+ if constexpr (UseSrc0) {
386
+ // Compute C = C + AxB;
387
+ C = 1 ;
388
+ C = dpas<8 , RepeatCount, ResNaturalType, ResNaturalType, BNaturalType,
389
+ ANaturalType, BPrec, APrec>(C, B, A);
390
+ } else {
391
+ // Compute C = AxB;
392
+ C = dpas<8 , RepeatCount, ResNaturalType, BNaturalType, ANaturalType,
393
+ BPrec, APrec>(B, A);
394
+ }
369
395
}
370
396
371
397
C.copy_to (Res);
@@ -408,11 +434,40 @@ bool test(queue &Q, bool Print) {
408
434
}
409
435
410
436
template <int SystolicDepth, int RepeatCount, dpas_argument_type T1,
411
- dpas_argument_type T2>
437
+ dpas_argument_type T2, bool LetDeduceArgs = false >
412
438
bool tests (queue Q, bool Print) {
413
439
bool Passed = true ;
414
440
constexpr bool UseSrc0 = true ;
415
- Passed &= test<SystolicDepth, RepeatCount, T1, T2, UseSrc0>(Q, Print);
416
- Passed &= test<SystolicDepth, RepeatCount, T1, T2, !UseSrc0>(Q, Print);
441
+ auto Dev = Q.get_device ();
442
+
443
+ // Detect the execution size.
444
+ // The device trait is not implemented for esimd_emulator. Use both 8 and 16.
445
+ int ExecSize;
446
+ bool IsEmulator = false ;
447
+ try {
448
+ ExecSize = Dev.get_info <ext::intel::info::device::gpu_eu_simd_width>();
449
+ } catch (sycl::exception e) {
450
+ IsEmulator = true ;
451
+ }
452
+ assert ((IsEmulator || (ExecSize == 8 || ExecSize == 16 )) &&
453
+ " Execution size must be 8 or 16" );
454
+
455
+ if (ExecSize == 16 || IsEmulator) {
456
+ Passed &=
457
+ test<SystolicDepth, RepeatCount, T1, T2, UseSrc0, 16 , LetDeduceArgs>(
458
+ Q, Print);
459
+ Passed &=
460
+ test<SystolicDepth, RepeatCount, T1, T2, !UseSrc0, 16 , LetDeduceArgs>(
461
+ Q, Print);
462
+ }
463
+ if (ExecSize == 8 || IsEmulator) {
464
+ Passed &=
465
+ test<SystolicDepth, RepeatCount, T1, T2, UseSrc0, 8 , LetDeduceArgs>(
466
+ Q, Print);
467
+ Passed &=
468
+ test<SystolicDepth, RepeatCount, T1, T2, !UseSrc0, 8 , LetDeduceArgs>(
469
+ Q, Print);
470
+ }
471
+
417
472
return Passed;
418
473
}
0 commit comments