Skip to content

Commit bdbaf6e

Browse files
authored
AMDGPU: Make v8bf16/v16bf16 legal types (#76678)
Depends #76217
1 parent 4fdd24b commit bdbaf6e

File tree

4 files changed

+3284
-3507
lines changed

4 files changed

+3284
-3507
lines changed

llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -387,18 +387,20 @@ AMDGPUTargetLowering::AMDGPUTargetLowering(const TargetMachine &TM,
387387
MVT::v9i32, MVT::v9f32, MVT::v10i32, MVT::v10f32,
388388
MVT::v11i32, MVT::v11f32, MVT::v12i32, MVT::v12f32},
389389
Custom);
390+
391+
// FIXME: Why is v8f16/v8bf16 missing?
390392
setOperationAction(
391393
ISD::EXTRACT_SUBVECTOR,
392-
{MVT::v2f16, MVT::v2i16, MVT::v2bf16, MVT::v4f16, MVT::v4i16,
393-
MVT::v4bf16, MVT::v2f32, MVT::v2i32, MVT::v3f32, MVT::v3i32,
394+
{MVT::v2f16, MVT::v2bf16, MVT::v2i16, MVT::v4f16, MVT::v4bf16,
395+
MVT::v4i16, MVT::v2f32, MVT::v2i32, MVT::v3f32, MVT::v3i32,
394396
MVT::v4f32, MVT::v4i32, MVT::v5f32, MVT::v5i32, MVT::v6f32,
395397
MVT::v6i32, MVT::v7f32, MVT::v7i32, MVT::v8f32, MVT::v8i32,
396398
MVT::v9f32, MVT::v9i32, MVT::v10i32, MVT::v10f32, MVT::v11i32,
397-
MVT::v11f32, MVT::v12i32, MVT::v12f32, MVT::v16f16, MVT::v16i16,
398-
MVT::v16f32, MVT::v16i32, MVT::v32f32, MVT::v32i32, MVT::v2f64,
399-
MVT::v2i64, MVT::v3f64, MVT::v3i64, MVT::v4f64, MVT::v4i64,
400-
MVT::v8f64, MVT::v8i64, MVT::v16f64, MVT::v16i64, MVT::v32i16,
401-
MVT::v32f16},
399+
MVT::v11f32, MVT::v12i32, MVT::v12f32, MVT::v16f16, MVT::v16bf16,
400+
MVT::v16i16, MVT::v16f32, MVT::v16i32, MVT::v32f32, MVT::v32i32,
401+
MVT::v2f64, MVT::v2i64, MVT::v3f64, MVT::v3i64, MVT::v4f64,
402+
MVT::v4i64, MVT::v8f64, MVT::v8i64, MVT::v16f64, MVT::v16i64,
403+
MVT::v32i16, MVT::v32f16, MVT::v32bf16},
402404
Custom);
403405

404406
setOperationAction(ISD::FP16_TO_FP, MVT::f64, Expand);

llvm/lib/Target/AMDGPU/SIISelLowering.cpp

Lines changed: 36 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -167,8 +167,10 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
167167
addRegisterClass(MVT::v4bf16, &AMDGPU::SReg_64RegClass);
168168
addRegisterClass(MVT::v8i16, &AMDGPU::SGPR_128RegClass);
169169
addRegisterClass(MVT::v8f16, &AMDGPU::SGPR_128RegClass);
170+
addRegisterClass(MVT::v8bf16, &AMDGPU::SGPR_128RegClass);
170171
addRegisterClass(MVT::v16i16, &AMDGPU::SGPR_256RegClass);
171172
addRegisterClass(MVT::v16f16, &AMDGPU::SGPR_256RegClass);
173+
addRegisterClass(MVT::v16bf16, &AMDGPU::SGPR_256RegClass);
172174
addRegisterClass(MVT::v32i16, &AMDGPU::SGPR_512RegClass);
173175
addRegisterClass(MVT::v32f16, &AMDGPU::SGPR_512RegClass);
174176
}
@@ -310,13 +312,14 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
310312
// We only support LOAD/STORE and vector manipulation ops for vectors
311313
// with > 4 elements.
312314
for (MVT VT :
313-
{MVT::v8i32, MVT::v8f32, MVT::v9i32, MVT::v9f32, MVT::v10i32,
314-
MVT::v10f32, MVT::v11i32, MVT::v11f32, MVT::v12i32, MVT::v12f32,
315-
MVT::v16i32, MVT::v16f32, MVT::v2i64, MVT::v2f64, MVT::v4i16,
316-
MVT::v4f16, MVT::v4bf16, MVT::v3i64, MVT::v3f64, MVT::v6i32,
317-
MVT::v6f32, MVT::v4i64, MVT::v4f64, MVT::v8i64, MVT::v8f64,
318-
MVT::v8i16, MVT::v8f16, MVT::v16i16, MVT::v16f16, MVT::v16i64,
319-
MVT::v16f64, MVT::v32i32, MVT::v32f32, MVT::v32i16, MVT::v32f16}) {
315+
{MVT::v8i32, MVT::v8f32, MVT::v9i32, MVT::v9f32, MVT::v10i32,
316+
MVT::v10f32, MVT::v11i32, MVT::v11f32, MVT::v12i32, MVT::v12f32,
317+
MVT::v16i32, MVT::v16f32, MVT::v2i64, MVT::v2f64, MVT::v4i16,
318+
MVT::v4f16, MVT::v4bf16, MVT::v3i64, MVT::v3f64, MVT::v6i32,
319+
MVT::v6f32, MVT::v4i64, MVT::v4f64, MVT::v8i64, MVT::v8f64,
320+
MVT::v8i16, MVT::v8f16, MVT::v8bf16, MVT::v16i16, MVT::v16f16,
321+
MVT::v16bf16, MVT::v16i64, MVT::v16f64, MVT::v32i32, MVT::v32f32,
322+
MVT::v32i16, MVT::v32f16, MVT::v32bf16}) {
320323
for (unsigned Op = 0; Op < ISD::BUILTIN_OP_END; ++Op) {
321324
switch (Op) {
322325
case ISD::LOAD:
@@ -683,6 +686,8 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
683686
AddPromotedToType(ISD::LOAD, MVT::v8i16, MVT::v4i32);
684687
setOperationAction(ISD::LOAD, MVT::v8f16, Promote);
685688
AddPromotedToType(ISD::LOAD, MVT::v8f16, MVT::v4i32);
689+
setOperationAction(ISD::LOAD, MVT::v8bf16, Promote);
690+
AddPromotedToType(ISD::LOAD, MVT::v8bf16, MVT::v4i32);
686691

687692
setOperationAction(ISD::STORE, MVT::v4i16, Promote);
688693
AddPromotedToType(ISD::STORE, MVT::v4i16, MVT::v2i32);
@@ -693,16 +698,22 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
693698
AddPromotedToType(ISD::STORE, MVT::v8i16, MVT::v4i32);
694699
setOperationAction(ISD::STORE, MVT::v8f16, Promote);
695700
AddPromotedToType(ISD::STORE, MVT::v8f16, MVT::v4i32);
701+
setOperationAction(ISD::STORE, MVT::v8bf16, Promote);
702+
AddPromotedToType(ISD::STORE, MVT::v8bf16, MVT::v4i32);
696703

697704
setOperationAction(ISD::LOAD, MVT::v16i16, Promote);
698705
AddPromotedToType(ISD::LOAD, MVT::v16i16, MVT::v8i32);
699706
setOperationAction(ISD::LOAD, MVT::v16f16, Promote);
700707
AddPromotedToType(ISD::LOAD, MVT::v16f16, MVT::v8i32);
708+
setOperationAction(ISD::LOAD, MVT::v16bf16, Promote);
709+
AddPromotedToType(ISD::LOAD, MVT::v16bf16, MVT::v8i32);
701710

702711
setOperationAction(ISD::STORE, MVT::v16i16, Promote);
703712
AddPromotedToType(ISD::STORE, MVT::v16i16, MVT::v8i32);
704713
setOperationAction(ISD::STORE, MVT::v16f16, Promote);
705714
AddPromotedToType(ISD::STORE, MVT::v16f16, MVT::v8i32);
715+
setOperationAction(ISD::STORE, MVT::v16bf16, Promote);
716+
AddPromotedToType(ISD::STORE, MVT::v16bf16, MVT::v8i32);
706717

707718
setOperationAction(ISD::LOAD, MVT::v32i16, Promote);
708719
AddPromotedToType(ISD::LOAD, MVT::v32i16, MVT::v16i32);
@@ -725,7 +736,8 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
725736
MVT::v8i32, Expand);
726737

727738
if (!Subtarget->hasVOP3PInsts())
728-
setOperationAction(ISD::BUILD_VECTOR, {MVT::v2i16, MVT::v2f16}, Custom);
739+
setOperationAction(ISD::BUILD_VECTOR,
740+
{MVT::v2i16, MVT::v2f16, MVT::v2bf16}, Custom);
729741

730742
setOperationAction(ISD::FNEG, MVT::v2f16, Legal);
731743
// This isn't really legal, but this avoids the legalizer unrolling it (and
@@ -743,8 +755,9 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
743755
{MVT::v4f16, MVT::v8f16, MVT::v16f16, MVT::v32f16},
744756
Expand);
745757

746-
for (MVT Vec16 : {MVT::v8i16, MVT::v8f16, MVT::v16i16, MVT::v16f16,
747-
MVT::v32i16, MVT::v32f16}) {
758+
for (MVT Vec16 :
759+
{MVT::v8i16, MVT::v8f16, MVT::v8bf16, MVT::v16i16, MVT::v16f16,
760+
MVT::v16bf16, MVT::v32i16, MVT::v32f16, MVT::v32bf16}) {
748761
setOperationAction(
749762
{ISD::BUILD_VECTOR, ISD::EXTRACT_VECTOR_ELT, ISD::SCALAR_TO_VECTOR},
750763
Vec16, Custom);
@@ -814,9 +827,10 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
814827
}
815828

816829
setOperationAction(ISD::SELECT,
817-
{MVT::v4i16, MVT::v4f16, MVT::v2i8, MVT::v4i8, MVT::v8i8,
818-
MVT::v8i16, MVT::v8f16, MVT::v16i16, MVT::v16f16,
819-
MVT::v32i16, MVT::v32f16},
830+
{MVT::v4i16, MVT::v4f16, MVT::v4bf16, MVT::v2i8, MVT::v4i8,
831+
MVT::v8i8, MVT::v8i16, MVT::v8f16, MVT::v8bf16,
832+
MVT::v16i16, MVT::v16f16, MVT::v16bf16, MVT::v32i16,
833+
MVT::v32f16, MVT::v32bf16},
820834
Custom);
821835

822836
setOperationAction({ISD::SMULO, ISD::UMULO}, MVT::i64, Custom);
@@ -5431,7 +5445,9 @@ SDValue SITargetLowering::splitTernaryVectorOp(SDValue Op,
54315445
assert(VT == MVT::v4i16 || VT == MVT::v4f16 || VT == MVT::v8i16 ||
54325446
VT == MVT::v8f16 || VT == MVT::v4f32 || VT == MVT::v16i16 ||
54335447
VT == MVT::v16f16 || VT == MVT::v8f32 || VT == MVT::v16f32 ||
5434-
VT == MVT::v32f32 || VT == MVT::v32f16 || VT == MVT::v32i16);
5448+
VT == MVT::v32f32 || VT == MVT::v32f16 || VT == MVT::v32i16 ||
5449+
VT == MVT::v4bf16 || VT == MVT::v8bf16 || VT == MVT::v16bf16 ||
5450+
VT == MVT::v32bf16);
54355451

54365452
SDValue Lo0, Hi0;
54375453
SDValue Op0 = Op.getOperand(0);
@@ -6854,8 +6870,8 @@ SDValue SITargetLowering::lowerBUILD_VECTOR(SDValue Op,
68546870
SDLoc SL(Op);
68556871
EVT VT = Op.getValueType();
68566872

6857-
if (VT == MVT::v4i16 || VT == MVT::v4f16 || VT == MVT::v4bf16 ||
6858-
VT == MVT::v8i16 || VT == MVT::v8f16) {
6873+
if (VT == MVT::v4i16 || VT == MVT::v4f16 || VT == MVT::v8i16 ||
6874+
VT == MVT::v8f16 || VT == MVT::v4bf16 || VT == MVT::v8bf16) {
68596875
EVT HalfVT = MVT::getVectorVT(VT.getVectorElementType().getSimpleVT(),
68606876
VT.getVectorNumElements() / 2);
68616877
MVT HalfIntVT = MVT::getIntegerVT(HalfVT.getSizeInBits());
@@ -6878,7 +6894,7 @@ SDValue SITargetLowering::lowerBUILD_VECTOR(SDValue Op,
68786894
return DAG.getNode(ISD::BITCAST, SL, VT, Blend);
68796895
}
68806896

6881-
if (VT == MVT::v16i16 || VT == MVT::v16f16) {
6897+
if (VT == MVT::v16i16 || VT == MVT::v16f16 || VT == MVT::v16bf16) {
68826898
EVT QuarterVT = MVT::getVectorVT(VT.getVectorElementType().getSimpleVT(),
68836899
VT.getVectorNumElements() / 4);
68846900
MVT QuarterIntVT = MVT::getIntegerVT(QuarterVT.getSizeInBits());
@@ -6899,7 +6915,7 @@ SDValue SITargetLowering::lowerBUILD_VECTOR(SDValue Op,
68996915
return DAG.getNode(ISD::BITCAST, SL, VT, Blend);
69006916
}
69016917

6902-
if (VT == MVT::v32i16 || VT == MVT::v32f16) {
6918+
if (VT == MVT::v32i16 || VT == MVT::v32f16 || VT == MVT::v32bf16) {
69036919
EVT QuarterVT = MVT::getVectorVT(VT.getVectorElementType().getSimpleVT(),
69046920
VT.getVectorNumElements() / 8);
69056921
MVT QuarterIntVT = MVT::getIntegerVT(QuarterVT.getSizeInBits());
@@ -14182,11 +14198,11 @@ SDValue SITargetLowering::PerformDAGCombine(SDNode *N,
1418214198
EVT VT = N->getValueType(0);
1418314199

1418414200
// v2i16 (scalar_to_vector i16:x) -> v2i16 (bitcast (any_extend i16:x))
14185-
if (VT == MVT::v2i16 || VT == MVT::v2f16) {
14201+
if (VT == MVT::v2i16 || VT == MVT::v2f16 || VT == MVT::v2f16) {
1418614202
SDLoc SL(N);
1418714203
SDValue Src = N->getOperand(0);
1418814204
EVT EltVT = Src.getValueType();
14189-
if (EltVT == MVT::f16)
14205+
if (EltVT != MVT::i16)
1419014206
Src = DAG.getNode(ISD::BITCAST, SL, MVT::i16, Src);
1419114207

1419214208
SDValue Ext = DAG.getNode(ISD::ANY_EXTEND, SL, MVT::i32, Src);

llvm/lib/Target/AMDGPU/SIInstructions.td

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1633,6 +1633,37 @@ def : BitConvert <v2f64, v8i16, SReg_128>;
16331633
def : BitConvert <v2i64, v8f16, SReg_128>;
16341634
def : BitConvert <v2f64, v8f16, SReg_128>;
16351635

1636+
def : BitConvert <v4i32, v8bf16, SReg_128>;
1637+
def : BitConvert <v8bf16, v4i32, SReg_128>;
1638+
def : BitConvert <v4i32, v8bf16, VReg_128>;
1639+
def : BitConvert <v8bf16, v4i32, VReg_128>;
1640+
1641+
def : BitConvert <v4f32, v8bf16, SReg_128>;
1642+
def : BitConvert <v8bf16, v4f32, SReg_128>;
1643+
def : BitConvert <v4f32, v8bf16, VReg_128>;
1644+
def : BitConvert <v8bf16, v4f32, VReg_128>;
1645+
1646+
def : BitConvert <v8i16, v8bf16, SReg_128>;
1647+
def : BitConvert <v8bf16, v8i16, SReg_128>;
1648+
def : BitConvert <v8i16, v8bf16, VReg_128>;
1649+
def : BitConvert <v8bf16, v8i16, VReg_128>;
1650+
1651+
def : BitConvert <v8f16, v8bf16, SReg_128>;
1652+
def : BitConvert <v8bf16, v8f16, SReg_128>;
1653+
def : BitConvert <v8f16, v8bf16, VReg_128>;
1654+
def : BitConvert <v8bf16, v8f16, VReg_128>;
1655+
1656+
def : BitConvert <v2f64, v8bf16, SReg_128>;
1657+
def : BitConvert <v8bf16, v2f64, SReg_128>;
1658+
def : BitConvert <v2f64, v8bf16, VReg_128>;
1659+
def : BitConvert <v8bf16, v2f64, VReg_128>;
1660+
1661+
def : BitConvert <v2i64, v8bf16, SReg_128>;
1662+
def : BitConvert <v8bf16, v2i64, SReg_128>;
1663+
def : BitConvert <v2i64, v8bf16, VReg_128>;
1664+
def : BitConvert <v8bf16, v2i64, VReg_128>;
1665+
1666+
16361667
// 160-bit bitcast
16371668
def : BitConvert <v5i32, v5f32, SReg_160>;
16381669
def : BitConvert <v5f32, v5i32, SReg_160>;
@@ -1697,6 +1728,31 @@ def : BitConvert <v4i64, v16i16, VReg_256>;
16971728
def : BitConvert <v4f64, v16f16, VReg_256>;
16981729
def : BitConvert <v4f64, v16i16, VReg_256>;
16991730

1731+
1732+
def : BitConvert <v8i32, v16bf16, VReg_256>;
1733+
def : BitConvert <v16bf16, v8i32, VReg_256>;
1734+
def : BitConvert <v8f32, v16bf16, VReg_256>;
1735+
def : BitConvert <v16bf16, v8f32, VReg_256>;
1736+
def : BitConvert <v4i64, v16bf16, VReg_256>;
1737+
def : BitConvert <v16bf16, v4i64, VReg_256>;
1738+
def : BitConvert <v4f64, v16bf16, VReg_256>;
1739+
def : BitConvert <v16bf16, v4f64, VReg_256>;
1740+
1741+
1742+
1743+
def : BitConvert <v16i16, v16bf16, SReg_256>;
1744+
def : BitConvert <v16bf16, v16i16, SReg_256>;
1745+
def : BitConvert <v16i16, v16bf16, VReg_256>;
1746+
def : BitConvert <v16bf16, v16i16, VReg_256>;
1747+
1748+
def : BitConvert <v16f16, v16bf16, SReg_256>;
1749+
def : BitConvert <v16bf16, v16f16, SReg_256>;
1750+
def : BitConvert <v16f16, v16bf16, VReg_256>;
1751+
def : BitConvert <v16bf16, v16f16, VReg_256>;
1752+
1753+
1754+
1755+
17001756
// 288-bit bitcast
17011757
def : BitConvert <v9i32, v9f32, SReg_288>;
17021758
def : BitConvert <v9f32, v9i32, SReg_288>;

0 commit comments

Comments
 (0)