@@ -284,12 +284,17 @@ def SIfptrunc_round_downward : SDNode<"AMDGPUISD::FPTRUNC_ROUND_DOWNWARD",
284
284
// Returns 1 if the source arguments have modifiers, 0 if they do not.
285
285
class isFloatType<ValueType SrcVT> {
286
286
bit ret = !or(!eq(SrcVT.Value, f16.Value),
287
+ !eq(SrcVT.Value, bf16.Value),
287
288
!eq(SrcVT.Value, f32.Value),
288
289
!eq(SrcVT.Value, f64.Value),
289
290
!eq(SrcVT.Value, v2f16.Value),
291
+ !eq(SrcVT.Value, v2bf16.Value),
290
292
!eq(SrcVT.Value, v4f16.Value),
293
+ !eq(SrcVT.Value, v4bf16.Value),
291
294
!eq(SrcVT.Value, v8f16.Value),
295
+ !eq(SrcVT.Value, v8bf16.Value),
292
296
!eq(SrcVT.Value, v16f16.Value),
297
+ !eq(SrcVT.Value, v16bf16.Value),
293
298
!eq(SrcVT.Value, v2f32.Value),
294
299
!eq(SrcVT.Value, v4f32.Value),
295
300
!eq(SrcVT.Value, v8f32.Value),
@@ -314,7 +319,9 @@ class isIntType<ValueType SrcVT> {
314
319
class isPackedType<ValueType SrcVT> {
315
320
bit ret = !or(!eq(SrcVT.Value, v2i16.Value),
316
321
!eq(SrcVT.Value, v2f16.Value),
322
+ !eq(SrcVT.Value, v2bf16.Value),
317
323
!eq(SrcVT.Value, v4f16.Value),
324
+ !eq(SrcVT.Value, v4bf16.Value),
318
325
!eq(SrcVT.Value, v2i32.Value),
319
326
!eq(SrcVT.Value, v2f32.Value),
320
327
!eq(SrcVT.Value, v4i32.Value),
@@ -1495,14 +1502,14 @@ class getVOPSrc0ForVT<ValueType VT, bit IsTrue16, bit IsFake16 = 1> {
1495
1502
!if(isFP,
1496
1503
!if(!eq(VT.Size, 64),
1497
1504
VSrc_f64,
1498
- !if(!eq(VT.Value, f16.Value),
1505
+ !if(!or(! eq(VT.Value, f16.Value), !eq(VT.Value, bf16.Value) ),
1499
1506
!if(IsTrue16,
1500
1507
!if(IsFake16, VSrcFake16_f16_Lo128, VSrcT_f16_Lo128),
1501
1508
VSrc_f16
1502
1509
),
1503
- !if(!eq(VT.Value, v2f16.Value),
1510
+ !if(!or(! eq(VT.Value, v2f16.Value), !eq(VT.Value, v2bf16.Value) ),
1504
1511
VSrc_v2f16,
1505
- !if(!eq(VT.Value, v4f16.Value),
1512
+ !if(!or(! eq(VT.Value, v4f16.Value), !eq(VT.Value, v4bf16.Value) ),
1506
1513
AVSrc_64,
1507
1514
VSrc_f32
1508
1515
)
@@ -1576,11 +1583,11 @@ class getVOP3SrcForVT<ValueType VT, bit IsTrue16 = 0> {
1576
1583
!if(!eq(VT.Value, i1.Value),
1577
1584
SSrc_i1,
1578
1585
!if(isFP,
1579
- !if(!eq(VT.Value, f16.Value),
1586
+ !if(!or(! eq(VT.Value, f16.Value), !eq(VT.Value, bf16.Value) ),
1580
1587
!if(IsTrue16, VSrcT_f16, VSrc_f16),
1581
- !if(!eq(VT.Value, v2f16.Value),
1588
+ !if(!or(! eq(VT.Value, v2f16.Value), !eq(VT.Value, v2bf16.Value) ),
1582
1589
VSrc_v2f16,
1583
- !if(!eq(VT.Value, v4f16.Value),
1590
+ !if(!or(! eq(VT.Value, v4f16.Value), !eq(VT.Value, v4bf16.Value) ),
1584
1591
AVSrc_64,
1585
1592
VSrc_f32
1586
1593
)
@@ -1605,8 +1612,8 @@ class getVOP3DPPSrcForVT<ValueType VT> {
1605
1612
RegisterOperand ret =
1606
1613
!if (!eq(VT.Value, i1.Value), SSrc_i1,
1607
1614
!if (isFP,
1608
- !if (!eq(VT.Value, f16.Value), VCSrc_f16,
1609
- !if (!eq(VT.Value, v2f16.Value), VCSrc_v2f16, VCSrc_f32)),
1615
+ !if (!or(! eq(VT.Value, f16.Value), !eq(VT.Value, bf16.Value) ), VCSrc_f16,
1616
+ !if (!or(! eq(VT.Value, v2f16.Value), !eq(VT.Value, v2bf16.Value) ), VCSrc_v2f16, VCSrc_f32)),
1610
1617
!if (!eq(VT.Value, i16.Value), VCSrc_b16,
1611
1618
!if (!eq(VT.Value, v2i16.Value), VCSrc_v2b16,
1612
1619
VCSrc_b32))));
@@ -1615,22 +1622,27 @@ class getVOP3DPPSrcForVT<ValueType VT> {
1615
1622
// Float or packed int
1616
1623
class isModifierType<ValueType SrcVT> {
1617
1624
bit ret = !or(!eq(SrcVT.Value, f16.Value),
1625
+ !eq(SrcVT.Value, bf16.Value),
1618
1626
!eq(SrcVT.Value, f32.Value),
1619
1627
!eq(SrcVT.Value, f64.Value),
1620
1628
!eq(SrcVT.Value, v2f16.Value),
1621
1629
!eq(SrcVT.Value, v2i16.Value),
1630
+ !eq(SrcVT.Value, v2bf16.Value),
1622
1631
!eq(SrcVT.Value, v2f32.Value),
1623
1632
!eq(SrcVT.Value, v2i32.Value),
1624
1633
!eq(SrcVT.Value, v4f16.Value),
1625
1634
!eq(SrcVT.Value, v4i16.Value),
1635
+ !eq(SrcVT.Value, v4bf16.Value),
1626
1636
!eq(SrcVT.Value, v4f32.Value),
1627
1637
!eq(SrcVT.Value, v4i32.Value),
1628
1638
!eq(SrcVT.Value, v8f16.Value),
1629
1639
!eq(SrcVT.Value, v8i16.Value),
1640
+ !eq(SrcVT.Value, v8bf16.Value),
1630
1641
!eq(SrcVT.Value, v8f32.Value),
1631
1642
!eq(SrcVT.Value, v8i32.Value),
1632
1643
!eq(SrcVT.Value, v16f16.Value),
1633
- !eq(SrcVT.Value, v16i16.Value));
1644
+ !eq(SrcVT.Value, v16i16.Value),
1645
+ !eq(SrcVT.Value, v16bf16.Value));
1634
1646
}
1635
1647
1636
1648
// Return type of input modifiers operand for specified input operand
@@ -1646,7 +1658,8 @@ class getSrcMod <ValueType VT, bit IsTrue16 = 0> {
1646
1658
}
1647
1659
1648
1660
class getOpSelMod <ValueType VT> {
1649
- Operand ret = !if(!eq(VT.Value, f16.Value), FP16InputMods, IntOpSelMods);
1661
+ Operand ret = !if(!or(!eq(VT.Value, f16.Value), !eq(VT.Value, bf16.Value)),
1662
+ FP16InputMods, IntOpSelMods);
1650
1663
}
1651
1664
1652
1665
// Return type of input modifiers operand specified input operand for DPP
@@ -1659,8 +1672,8 @@ class getSrcModDPP_t16 <ValueType VT> {
1659
1672
bit isFP = isFloatType<VT>.ret;
1660
1673
Operand ret =
1661
1674
!if (isFP,
1662
- !if (!eq(VT.Value, f16.Value), FPT16VRegInputMods ,
1663
- FPVRegInputMods),
1675
+ !if (!or(! eq(VT.Value, f16.Value), !eq(VT.Value, bf16.Value)) ,
1676
+ FPT16VRegInputMods, FPVRegInputMods),
1664
1677
!if (!eq(VT.Value, i16.Value), IntT16VRegInputMods,
1665
1678
IntVRegInputMods));
1666
1679
}
@@ -1671,8 +1684,8 @@ class getSrcModVOP3DPP <ValueType VT> {
1671
1684
bit isPacked = isPackedType<VT>.ret;
1672
1685
Operand ret =
1673
1686
!if (isFP,
1674
- !if (!eq(VT.Value, f16.Value), FP16VCSrcInputMods ,
1675
- FP32VCSrcInputMods),
1687
+ !if (!or(! eq(VT.Value, f16.Value), !eq(VT.Value, bf16.Value)) ,
1688
+ FP16VCSrcInputMods, FP32VCSrcInputMods),
1676
1689
Int32VCSrcInputMods);
1677
1690
}
1678
1691
@@ -1681,7 +1694,8 @@ class getSrcModSDWA <ValueType VT> {
1681
1694
Operand ret = !if(!eq(VT.Value, f16.Value), FP16SDWAInputMods,
1682
1695
!if(!eq(VT.Value, f32.Value), FP32SDWAInputMods,
1683
1696
!if(!eq(VT.Value, i16.Value), Int16SDWAInputMods,
1684
- Int32SDWAInputMods)));
1697
+ !if(!eq(VT.Value, bf16.Value), FP16SDWAInputMods,
1698
+ Int32SDWAInputMods))));
1685
1699
}
1686
1700
1687
1701
// Returns the input arguments for VOP[12C] instructions for the given SrcVT.
0 commit comments