@@ -441,7 +441,9 @@ enum {
441
441
MatrixCSignedComponentsKHR = 0x4 ,
442
442
MatrixResultSignedComponentsKHR = 0x8 ,
443
443
// Unused right now
444
- SaturatingAccumulationKHR = 0x10
444
+ SaturatingAccumulationKHR = 0x10 ,
445
+ MatrixAAndBTF32ComponentsINTEL = 0x20 ,
446
+ MatrixAAndBBFloat16ComponentsINTEL = 0x40
445
447
};
446
448
447
449
namespace IGC {
@@ -1402,7 +1404,8 @@ Instruction *JointMatrixFuncsResolutionPass::ResolveStore(CallInst *CI)
1402
1404
return newCall;
1403
1405
}
1404
1406
1405
- static PrecisionType getElementPrecison (const JointMatrixTypeDescription *desc, bool floatOp, bool isUnsigned) {
1407
+ static PrecisionType getJointMatrixElementPrecison (
1408
+ const JointMatrixTypeDescription *desc, bool floatOp, bool isUnsigned) {
1406
1409
const unsigned width = desc->bitWidth ;
1407
1410
if (floatOp && width == 16 ) {
1408
1411
/* bf is passed as uint16_t, hf is using halfs */
@@ -1417,6 +1420,39 @@ static PrecisionType getElementPrecison(const JointMatrixTypeDescription *desc,
1417
1420
return PrecisionType::PRECISION_UNUSED;
1418
1421
}
1419
1422
1423
+ static PrecisionType getCoopMatrixElementPrecison (
1424
+ const JointMatrixTypeDescription *desc, unsigned OperandsMask, unsigned Use,
1425
+ bool floatOp) {
1426
+ const unsigned width = desc->bitWidth ;
1427
+ if (OperandsMask & MatrixAAndBBFloat16ComponentsINTEL) {
1428
+ IGC_ASSERT_MESSAGE (floatOp && width == 16 ,
1429
+ " Wrong OpCooperativeMatrixMulAddKHR ops for BFloat16" );
1430
+ return PrecisionType::BF16;
1431
+ }
1432
+ if (floatOp && width == 16 ) {
1433
+ IGC_ASSERT_MESSAGE (!OperandsMask,
1434
+ " Wrong OpCooperativeMatrixMulAddKHR ops for FP16" );
1435
+ /* bf is passed as uint16_t, hf is using halfs */
1436
+ return desc->isFloating ? PrecisionType::FP16 : PrecisionType::BF16;
1437
+ }
1438
+ if (OperandsMask & MatrixAAndBTF32ComponentsINTEL ||
1439
+ (floatOp && width == 32 )) {
1440
+ return PrecisionType::TF32;
1441
+ }
1442
+ if (!floatOp && width == 8 ) {
1443
+ if (OperandsMask & MatrixASignedComponentsKHR &&
1444
+ OperandsMask & MatrixBSignedComponentsKHR) {
1445
+ return PrecisionType::S8;
1446
+ } else if (OperandsMask & MatrixASignedComponentsKHR) {
1447
+ return Use == UseMatrixA ? PrecisionType::S8 : PrecisionType::U8;
1448
+ } else if (OperandsMask & MatrixBSignedComponentsKHR) {
1449
+ return Use == UseMatrixB ? PrecisionType::S8 : PrecisionType::U8;
1450
+ }
1451
+ return PrecisionType::U8;
1452
+ }
1453
+ return PrecisionType::PRECISION_UNUSED;
1454
+ }
1455
+
1420
1456
static const char *getElementName (PrecisionType P) {
1421
1457
switch (P) {
1422
1458
case PrecisionType::FP16: return " fp16_" ;
@@ -1499,28 +1535,20 @@ Instruction *JointMatrixFuncsResolutionPass::ResolveMad(CallInst *CI, unsigned O
1499
1535
1500
1536
const bool floatMad = cDesc.isFloating ;
1501
1537
1502
- // TODO: with Cooperative matrix extension and with further extend
1503
- // of a new version of Joint matrix extension we carry information of the
1504
- // type interpretation in MulAdd last masked parameter, so need to adjust
1505
- // getElementPrecison logic for the new versions
1538
+ PrecisionType PA = PrecisionType::PRECISION_UNUSED;
1539
+ PrecisionType PB = PrecisionType::PRECISION_UNUSED;
1506
1540
if (OperationType == CooperativeOp) {
1507
- OperationType = floatMad ? MadOpSS : MadOpUU;
1508
1541
const unsigned MulAddArgSize = CI->arg_size ();
1509
- if (MulAddArgSize > 3 ) {
1510
- const auto OperandsMask =
1511
- cast<ConstantInt>(CI->getArgOperand (3 ))->getZExtValue ();
1512
- if (OperandsMask & MatrixASignedComponentsKHR &&
1513
- OperandsMask & MatrixBSignedComponentsKHR) {
1514
- OperationType = MadOpSS;
1515
- } else if (OperandsMask & MatrixASignedComponentsKHR) {
1516
- OperationType = MadOpSU;
1517
- } else if (OperandsMask & MatrixBSignedComponentsKHR) {
1518
- OperationType = MadOpUS;
1519
- }
1520
- }
1542
+ const auto OperandsMask = MulAddArgSize > 3
1543
+ ? cast<ConstantInt>(CI->getArgOperand (3 ))->getZExtValue () : 0 ;
1544
+ PA = getCoopMatrixElementPrecison (&aDesc, OperandsMask, UseMatrixA, floatMad);
1545
+ PB = getCoopMatrixElementPrecison (&bDesc, OperandsMask, UseMatrixB, floatMad);
1546
+ } else {
1547
+ PA = getJointMatrixElementPrecison (&aDesc, floatMad,
1548
+ isOperandUnsigned (OperationType, 0 ));
1549
+ PB = getJointMatrixElementPrecison (&bDesc, floatMad,
1550
+ isOperandUnsigned (OperationType, 1 ));
1521
1551
}
1522
- PrecisionType PA = getElementPrecison (&aDesc, floatMad, isOperandUnsigned (OperationType, 0 ));
1523
- PrecisionType PB = getElementPrecison (&bDesc, floatMad, isOperandUnsigned (OperationType, 1 ));
1524
1552
1525
1553
IGC_ASSERT_MESSAGE (PA != PrecisionType::PRECISION_UNUSED, " Invalid matrix A element type." );
1526
1554
IGC_ASSERT_MESSAGE (PB != PrecisionType::PRECISION_UNUSED, " Invalid matrix B element type." );
0 commit comments