Skip to content

Commit 3699811

Browse files
authored
[AMDGPU] Handle bf16 operands the same way as f16. NFC. (#77826)
This is an infrastructure change that will allow bf16 operands to be used in instruction definitions.
1 parent c39926e commit 3699811

File tree

1 file changed

+29
-15
lines changed

1 file changed

+29
-15
lines changed

llvm/lib/Target/AMDGPU/SIInstrInfo.td

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -284,12 +284,17 @@ def SIfptrunc_round_downward : SDNode<"AMDGPUISD::FPTRUNC_ROUND_DOWNWARD",
284284
// Returns 1 if SrcVT is one of the listed floating-point value types (scalar or vector), 0 otherwise.
285285
class isFloatType<ValueType SrcVT> {
286286
bit ret = !or(!eq(SrcVT.Value, f16.Value),
287+
!eq(SrcVT.Value, bf16.Value),
287288
!eq(SrcVT.Value, f32.Value),
288289
!eq(SrcVT.Value, f64.Value),
289290
!eq(SrcVT.Value, v2f16.Value),
291+
!eq(SrcVT.Value, v2bf16.Value),
290292
!eq(SrcVT.Value, v4f16.Value),
293+
!eq(SrcVT.Value, v4bf16.Value),
291294
!eq(SrcVT.Value, v8f16.Value),
295+
!eq(SrcVT.Value, v8bf16.Value),
292296
!eq(SrcVT.Value, v16f16.Value),
297+
!eq(SrcVT.Value, v16bf16.Value),
293298
!eq(SrcVT.Value, v2f32.Value),
294299
!eq(SrcVT.Value, v4f32.Value),
295300
!eq(SrcVT.Value, v8f32.Value),
@@ -314,7 +319,9 @@ class isIntType<ValueType SrcVT> {
314319
class isPackedType<ValueType SrcVT> {
315320
bit ret = !or(!eq(SrcVT.Value, v2i16.Value),
316321
!eq(SrcVT.Value, v2f16.Value),
322+
!eq(SrcVT.Value, v2bf16.Value),
317323
!eq(SrcVT.Value, v4f16.Value),
324+
!eq(SrcVT.Value, v4bf16.Value),
318325
!eq(SrcVT.Value, v2i32.Value),
319326
!eq(SrcVT.Value, v2f32.Value),
320327
!eq(SrcVT.Value, v4i32.Value),
@@ -1495,14 +1502,14 @@ class getVOPSrc0ForVT<ValueType VT, bit IsTrue16, bit IsFake16 = 1> {
14951502
!if(isFP,
14961503
!if(!eq(VT.Size, 64),
14971504
VSrc_f64,
1498-
!if(!eq(VT.Value, f16.Value),
1505+
!if(!or(!eq(VT.Value, f16.Value), !eq(VT.Value, bf16.Value)),
14991506
!if(IsTrue16,
15001507
!if(IsFake16, VSrcFake16_f16_Lo128, VSrcT_f16_Lo128),
15011508
VSrc_f16
15021509
),
1503-
!if(!eq(VT.Value, v2f16.Value),
1510+
!if(!or(!eq(VT.Value, v2f16.Value), !eq(VT.Value, v2bf16.Value)),
15041511
VSrc_v2f16,
1505-
!if(!eq(VT.Value, v4f16.Value),
1512+
!if(!or(!eq(VT.Value, v4f16.Value), !eq(VT.Value, v4bf16.Value)),
15061513
AVSrc_64,
15071514
VSrc_f32
15081515
)
@@ -1576,11 +1583,11 @@ class getVOP3SrcForVT<ValueType VT, bit IsTrue16 = 0> {
15761583
!if(!eq(VT.Value, i1.Value),
15771584
SSrc_i1,
15781585
!if(isFP,
1579-
!if(!eq(VT.Value, f16.Value),
1586+
!if(!or(!eq(VT.Value, f16.Value), !eq(VT.Value, bf16.Value)),
15801587
!if(IsTrue16, VSrcT_f16, VSrc_f16),
1581-
!if(!eq(VT.Value, v2f16.Value),
1588+
!if(!or(!eq(VT.Value, v2f16.Value), !eq(VT.Value, v2bf16.Value)),
15821589
VSrc_v2f16,
1583-
!if(!eq(VT.Value, v4f16.Value),
1590+
!if(!or(!eq(VT.Value, v4f16.Value), !eq(VT.Value, v4bf16.Value)),
15841591
AVSrc_64,
15851592
VSrc_f32
15861593
)
@@ -1605,8 +1612,8 @@ class getVOP3DPPSrcForVT<ValueType VT> {
16051612
RegisterOperand ret =
16061613
!if (!eq(VT.Value, i1.Value), SSrc_i1,
16071614
!if (isFP,
1608-
!if (!eq(VT.Value, f16.Value), VCSrc_f16,
1609-
!if (!eq(VT.Value, v2f16.Value), VCSrc_v2f16, VCSrc_f32)),
1615+
!if (!or(!eq(VT.Value, f16.Value), !eq(VT.Value, bf16.Value)), VCSrc_f16,
1616+
!if (!or(!eq(VT.Value, v2f16.Value), !eq(VT.Value, v2bf16.Value)), VCSrc_v2f16, VCSrc_f32)),
16101617
!if (!eq(VT.Value, i16.Value), VCSrc_b16,
16111618
!if (!eq(VT.Value, v2i16.Value), VCSrc_v2b16,
16121619
VCSrc_b32))));
@@ -1615,22 +1622,27 @@ class getVOP3DPPSrcForVT<ValueType VT> {
16151622
// Float or packed int
16161623
class isModifierType<ValueType SrcVT> {
16171624
bit ret = !or(!eq(SrcVT.Value, f16.Value),
1625+
!eq(SrcVT.Value, bf16.Value),
16181626
!eq(SrcVT.Value, f32.Value),
16191627
!eq(SrcVT.Value, f64.Value),
16201628
!eq(SrcVT.Value, v2f16.Value),
16211629
!eq(SrcVT.Value, v2i16.Value),
1630+
!eq(SrcVT.Value, v2bf16.Value),
16221631
!eq(SrcVT.Value, v2f32.Value),
16231632
!eq(SrcVT.Value, v2i32.Value),
16241633
!eq(SrcVT.Value, v4f16.Value),
16251634
!eq(SrcVT.Value, v4i16.Value),
1635+
!eq(SrcVT.Value, v4bf16.Value),
16261636
!eq(SrcVT.Value, v4f32.Value),
16271637
!eq(SrcVT.Value, v4i32.Value),
16281638
!eq(SrcVT.Value, v8f16.Value),
16291639
!eq(SrcVT.Value, v8i16.Value),
1640+
!eq(SrcVT.Value, v8bf16.Value),
16301641
!eq(SrcVT.Value, v8f32.Value),
16311642
!eq(SrcVT.Value, v8i32.Value),
16321643
!eq(SrcVT.Value, v16f16.Value),
1633-
!eq(SrcVT.Value, v16i16.Value));
1644+
!eq(SrcVT.Value, v16i16.Value),
1645+
!eq(SrcVT.Value, v16bf16.Value));
16341646
}
16351647

16361648
// Return type of input modifiers operand for specified input operand
@@ -1646,7 +1658,8 @@ class getSrcMod <ValueType VT, bit IsTrue16 = 0> {
16461658
}
16471659

16481660
class getOpSelMod <ValueType VT> {
1649-
Operand ret = !if(!eq(VT.Value, f16.Value), FP16InputMods, IntOpSelMods);
1661+
Operand ret = !if(!or(!eq(VT.Value, f16.Value), !eq(VT.Value, bf16.Value)),
1662+
FP16InputMods, IntOpSelMods);
16501663
}
16511664

16521665
// Return type of input modifiers operand for specified input operand for DPP
@@ -1659,8 +1672,8 @@ class getSrcModDPP_t16 <ValueType VT> {
16591672
bit isFP = isFloatType<VT>.ret;
16601673
Operand ret =
16611674
!if (isFP,
1662-
!if (!eq(VT.Value, f16.Value), FPT16VRegInputMods,
1663-
FPVRegInputMods),
1675+
!if (!or(!eq(VT.Value, f16.Value), !eq(VT.Value, bf16.Value)),
1676+
FPT16VRegInputMods, FPVRegInputMods),
16641677
!if (!eq(VT.Value, i16.Value), IntT16VRegInputMods,
16651678
IntVRegInputMods));
16661679
}
@@ -1671,8 +1684,8 @@ class getSrcModVOP3DPP <ValueType VT> {
16711684
bit isPacked = isPackedType<VT>.ret;
16721685
Operand ret =
16731686
!if (isFP,
1674-
!if (!eq(VT.Value, f16.Value), FP16VCSrcInputMods,
1675-
FP32VCSrcInputMods),
1687+
!if (!or(!eq(VT.Value, f16.Value), !eq(VT.Value, bf16.Value)),
1688+
FP16VCSrcInputMods, FP32VCSrcInputMods),
16761689
Int32VCSrcInputMods);
16771690
}
16781691

@@ -1681,7 +1694,8 @@ class getSrcModSDWA <ValueType VT> {
16811694
Operand ret = !if(!eq(VT.Value, f16.Value), FP16SDWAInputMods,
16821695
!if(!eq(VT.Value, f32.Value), FP32SDWAInputMods,
16831696
!if(!eq(VT.Value, i16.Value), Int16SDWAInputMods,
1684-
Int32SDWAInputMods)));
1697+
!if(!eq(VT.Value, bf16.Value), FP16SDWAInputMods,
1698+
Int32SDWAInputMods))));
16851699
}
16861700

16871701
// Returns the input arguments for VOP[12C] instructions for the given SrcVT.

0 commit comments

Comments
 (0)