Skip to content

Commit 8ceb72f

Browse files
authored
[AMDGPU] make v32i16/v32f16 legal (#70484)
Some upcoming intrinsics will be using these new types
1 parent 6397ea7 commit 8ceb72f

File tree

17 files changed

+643
-545
lines changed

17 files changed

+643
-545
lines changed

llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -384,7 +384,7 @@ AMDGPUTargetLowering::AMDGPUTargetLowering(const TargetMachine &TM,
384384
MVT::v12f32, MVT::v16f16, MVT::v16i16, MVT::v16f32, MVT::v16i32,
385385
MVT::v32f32, MVT::v32i32, MVT::v2f64, MVT::v2i64, MVT::v3f64,
386386
MVT::v3i64, MVT::v4f64, MVT::v4i64, MVT::v8f64, MVT::v8i64,
387-
MVT::v16f64, MVT::v16i64},
387+
MVT::v16f64, MVT::v16i64, MVT::v32i16, MVT::v32f16},
388388
Custom);
389389

390390
setOperationAction(ISD::FP16_TO_FP, MVT::f64, Expand);

llvm/lib/Target/AMDGPU/SIISelLowering.cpp

Lines changed: 78 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,8 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
165165
addRegisterClass(MVT::v8f16, &AMDGPU::SGPR_128RegClass);
166166
addRegisterClass(MVT::v16i16, &AMDGPU::SGPR_256RegClass);
167167
addRegisterClass(MVT::v16f16, &AMDGPU::SGPR_256RegClass);
168+
addRegisterClass(MVT::v32i16, &AMDGPU::SGPR_512RegClass);
169+
addRegisterClass(MVT::v32f16, &AMDGPU::SGPR_512RegClass);
168170
}
169171

170172
addRegisterClass(MVT::v32i32, &AMDGPU::VReg_1024RegClass);
@@ -269,13 +271,13 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
269271
// We only support LOAD/STORE and vector manipulation ops for vectors
270272
// with > 4 elements.
271273
for (MVT VT :
272-
{MVT::v8i32, MVT::v8f32, MVT::v9i32, MVT::v9f32, MVT::v10i32,
273-
MVT::v10f32, MVT::v11i32, MVT::v11f32, MVT::v12i32, MVT::v12f32,
274-
MVT::v16i32, MVT::v16f32, MVT::v2i64, MVT::v2f64, MVT::v4i16,
275-
MVT::v4f16, MVT::v3i64, MVT::v3f64, MVT::v6i32, MVT::v6f32,
276-
MVT::v4i64, MVT::v4f64, MVT::v8i64, MVT::v8f64, MVT::v8i16,
277-
MVT::v8f16, MVT::v16i16, MVT::v16f16, MVT::v16i64, MVT::v16f64,
278-
MVT::v32i32, MVT::v32f32}) {
274+
{MVT::v8i32, MVT::v8f32, MVT::v9i32, MVT::v9f32, MVT::v10i32,
275+
MVT::v10f32, MVT::v11i32, MVT::v11f32, MVT::v12i32, MVT::v12f32,
276+
MVT::v16i32, MVT::v16f32, MVT::v2i64, MVT::v2f64, MVT::v4i16,
277+
MVT::v4f16, MVT::v3i64, MVT::v3f64, MVT::v6i32, MVT::v6f32,
278+
MVT::v4i64, MVT::v4f64, MVT::v8i64, MVT::v8f64, MVT::v8i16,
279+
MVT::v8f16, MVT::v16i16, MVT::v16f16, MVT::v16i64, MVT::v16f64,
280+
MVT::v32i32, MVT::v32f32, MVT::v32i16, MVT::v32f16}) {
279281
for (unsigned Op = 0; Op < ISD::BUILTIN_OP_END; ++Op) {
280282
switch (Op) {
281283
case ISD::LOAD:
@@ -553,8 +555,9 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
553555
if (STI.hasMadF16())
554556
setOperationAction(ISD::FMAD, MVT::f16, Legal);
555557

556-
for (MVT VT : {MVT::v2i16, MVT::v2f16, MVT::v4i16, MVT::v4f16, MVT::v8i16,
557-
MVT::v8f16, MVT::v16i16, MVT::v16f16}) {
558+
for (MVT VT :
559+
{MVT::v2i16, MVT::v2f16, MVT::v4i16, MVT::v4f16, MVT::v8i16,
560+
MVT::v8f16, MVT::v16i16, MVT::v16f16, MVT::v32i16, MVT::v32f16}) {
558561
for (unsigned Op = 0; Op < ISD::BUILTIN_OP_END; ++Op) {
559562
switch (Op) {
560563
case ISD::LOAD:
@@ -640,6 +643,16 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
640643
setOperationAction(ISD::STORE, MVT::v16f16, Promote);
641644
AddPromotedToType(ISD::STORE, MVT::v16f16, MVT::v8i32);
642645

646+
setOperationAction(ISD::LOAD, MVT::v32i16, Promote);
647+
AddPromotedToType(ISD::LOAD, MVT::v32i16, MVT::v16i32);
648+
setOperationAction(ISD::LOAD, MVT::v32f16, Promote);
649+
AddPromotedToType(ISD::LOAD, MVT::v32f16, MVT::v16i32);
650+
651+
setOperationAction(ISD::STORE, MVT::v32i16, Promote);
652+
AddPromotedToType(ISD::STORE, MVT::v32i16, MVT::v16i32);
653+
setOperationAction(ISD::STORE, MVT::v32f16, Promote);
654+
AddPromotedToType(ISD::STORE, MVT::v32f16, MVT::v16i32);
655+
643656
setOperationAction({ISD::ANY_EXTEND, ISD::ZERO_EXTEND, ISD::SIGN_EXTEND},
644657
MVT::v2i32, Expand);
645658
setOperationAction(ISD::FP_EXTEND, MVT::v2f32, Expand);
@@ -662,12 +675,15 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
662675
setOperationAction({ISD::FMAXNUM_IEEE, ISD::FMINNUM_IEEE}, MVT::f16, Legal);
663676

664677
setOperationAction({ISD::FMINNUM_IEEE, ISD::FMAXNUM_IEEE},
665-
{MVT::v4f16, MVT::v8f16, MVT::v16f16}, Custom);
678+
{MVT::v4f16, MVT::v8f16, MVT::v16f16, MVT::v32f16},
679+
Custom);
666680

667681
setOperationAction({ISD::FMINNUM, ISD::FMAXNUM},
668-
{MVT::v4f16, MVT::v8f16, MVT::v16f16}, Expand);
682+
{MVT::v4f16, MVT::v8f16, MVT::v16f16, MVT::v32f16},
683+
Expand);
669684

670-
for (MVT Vec16 : {MVT::v8i16, MVT::v8f16, MVT::v16i16, MVT::v16f16}) {
685+
for (MVT Vec16 : {MVT::v8i16, MVT::v8f16, MVT::v16i16, MVT::v16f16,
686+
MVT::v32i16, MVT::v32f16}) {
671687
setOperationAction(
672688
{ISD::BUILD_VECTOR, ISD::EXTRACT_VECTOR_ELT, ISD::SCALAR_TO_VECTOR},
673689
Vec16, Custom);
@@ -690,18 +706,18 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
690706

691707
setOperationAction(ISD::VECTOR_SHUFFLE,
692708
{MVT::v4f16, MVT::v4i16, MVT::v8f16, MVT::v8i16,
693-
MVT::v16f16, MVT::v16i16},
709+
MVT::v16f16, MVT::v16i16, MVT::v32f16, MVT::v32i16},
694710
Custom);
695711

696-
for (MVT VT : {MVT::v4i16, MVT::v8i16, MVT::v16i16})
712+
for (MVT VT : {MVT::v4i16, MVT::v8i16, MVT::v16i16, MVT::v32i16})
697713
// Split vector operations.
698714
setOperationAction({ISD::SHL, ISD::SRA, ISD::SRL, ISD::ADD, ISD::SUB,
699715
ISD::MUL, ISD::SMIN, ISD::SMAX, ISD::UMIN, ISD::UMAX,
700716
ISD::UADDSAT, ISD::SADDSAT, ISD::USUBSAT,
701717
ISD::SSUBSAT},
702718
VT, Custom);
703719

704-
for (MVT VT : {MVT::v4f16, MVT::v8f16, MVT::v16f16})
720+
for (MVT VT : {MVT::v4f16, MVT::v8f16, MVT::v16f16, MVT::v32f16})
705721
// Split vector operations.
706722
setOperationAction({ISD::FADD, ISD::FMUL, ISD::FMA, ISD::FCANONICALIZE},
707723
VT, Custom);
@@ -737,7 +753,8 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
737753

738754
setOperationAction(ISD::SELECT,
739755
{MVT::v4i16, MVT::v4f16, MVT::v2i8, MVT::v4i8, MVT::v8i8,
740-
MVT::v8i16, MVT::v8f16, MVT::v16i16, MVT::v16f16},
756+
MVT::v8i16, MVT::v8f16, MVT::v16i16, MVT::v16f16,
757+
MVT::v32i16, MVT::v32f16},
741758
Custom);
742759

743760
setOperationAction({ISD::SMULO, ISD::UMULO}, MVT::i64, Custom);
@@ -5107,7 +5124,7 @@ SDValue SITargetLowering::splitUnaryVectorOp(SDValue Op,
51075124
assert(VT == MVT::v4i16 || VT == MVT::v4f16 || VT == MVT::v4f32 ||
51085125
VT == MVT::v8i16 || VT == MVT::v8f16 || VT == MVT::v16i16 ||
51095126
VT == MVT::v16f16 || VT == MVT::v8f32 || VT == MVT::v16f32 ||
5110-
VT == MVT::v32f32);
5127+
VT == MVT::v32f32 || VT == MVT::v32i16 || VT == MVT::v32f16);
51115128

51125129
SDValue Lo, Hi;
51135130
std::tie(Lo, Hi) = DAG.SplitVectorOperand(Op.getNode(), 0);
@@ -5130,7 +5147,7 @@ SDValue SITargetLowering::splitBinaryVectorOp(SDValue Op,
51305147
assert(VT == MVT::v4i16 || VT == MVT::v4f16 || VT == MVT::v4f32 ||
51315148
VT == MVT::v8i16 || VT == MVT::v8f16 || VT == MVT::v16i16 ||
51325149
VT == MVT::v16f16 || VT == MVT::v8f32 || VT == MVT::v16f32 ||
5133-
VT == MVT::v32f32);
5150+
VT == MVT::v32f32 || VT == MVT::v32i16 || VT == MVT::v32f16);
51345151

51355152
SDValue Lo0, Hi0;
51365153
std::tie(Lo0, Hi0) = DAG.SplitVectorOperand(Op.getNode(), 0);
@@ -5897,7 +5914,8 @@ SDValue SITargetLowering::lowerFMINNUM_FMAXNUM(SDValue Op,
58975914
if (IsIEEEMode)
58985915
return expandFMINNUM_FMAXNUM(Op.getNode(), DAG);
58995916

5900-
if (VT == MVT::v4f16 || VT == MVT::v8f16 || VT == MVT::v16f16)
5917+
if (VT == MVT::v4f16 || VT == MVT::v8f16 || VT == MVT::v16f16 ||
5918+
VT == MVT::v32f16)
59015919
return splitBinaryVectorOp(Op, DAG);
59025920
return Op;
59035921
}
@@ -6415,7 +6433,7 @@ SDValue SITargetLowering::lowerEXTRACT_VECTOR_ELT(SDValue Op,
64156433
if (SDValue Combined = performExtractVectorEltCombine(Op.getNode(), DCI))
64166434
return Combined;
64176435

6418-
if (VecSize == 128 || VecSize == 256) {
6436+
if (VecSize == 128 || VecSize == 256 || VecSize == 512) {
64196437
SDValue Lo, Hi;
64206438
EVT LoVT, HiVT;
64216439
std::tie(LoVT, HiVT) = DAG.GetSplitDestVTs(VecVT);
@@ -6428,9 +6446,7 @@ SDValue SITargetLowering::lowerEXTRACT_VECTOR_ELT(SDValue Op,
64286446
Hi = DAG.getBitcast(HiVT,
64296447
DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL, MVT::i64, V2,
64306448
DAG.getConstant(1, SL, MVT::i32)));
6431-
} else {
6432-
assert(VecSize == 256);
6433-
6449+
} else if (VecSize == 256) {
64346450
SDValue V2 = DAG.getBitcast(MVT::v4i64, Vec);
64356451
SDValue Parts[4];
64366452
for (unsigned P = 0; P < 4; ++P) {
@@ -6442,6 +6458,22 @@ SDValue SITargetLowering::lowerEXTRACT_VECTOR_ELT(SDValue Op,
64426458
Parts[0], Parts[1]));
64436459
Hi = DAG.getBitcast(HiVT, DAG.getNode(ISD::BUILD_VECTOR, SL, MVT::v2i64,
64446460
Parts[2], Parts[3]));
6461+
} else {
6462+
assert(VecSize == 512);
6463+
6464+
SDValue V2 = DAG.getBitcast(MVT::v8i64, Vec);
6465+
SDValue Parts[8];
6466+
for (unsigned P = 0; P < 8; ++P) {
6467+
Parts[P] = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL, MVT::i64, V2,
6468+
DAG.getConstant(P, SL, MVT::i32));
6469+
}
6470+
6471+
Lo = DAG.getBitcast(LoVT,
6472+
DAG.getNode(ISD::BUILD_VECTOR, SL, MVT::v4i64,
6473+
Parts[0], Parts[1], Parts[2], Parts[3]));
6474+
Hi = DAG.getBitcast(HiVT,
6475+
DAG.getNode(ISD::BUILD_VECTOR, SL, MVT::v4i64,
6476+
Parts[4], Parts[5], Parts[6], Parts[7]));
64456477
}
64466478

64476479
EVT IdxVT = Idx.getValueType();
@@ -6607,6 +6639,27 @@ SDValue SITargetLowering::lowerBUILD_VECTOR(SDValue Op,
66076639
return DAG.getNode(ISD::BITCAST, SL, VT, Blend);
66086640
}
66096641

6642+
if (VT == MVT::v32i16 || VT == MVT::v32f16) {
6643+
EVT QuarterVT = MVT::getVectorVT(VT.getVectorElementType().getSimpleVT(),
6644+
VT.getVectorNumElements() / 8);
6645+
MVT QuarterIntVT = MVT::getIntegerVT(QuarterVT.getSizeInBits());
6646+
6647+
SmallVector<SDValue, 8> Parts[8];
6648+
for (unsigned I = 0, E = VT.getVectorNumElements() / 8; I != E; ++I) {
6649+
for (unsigned P = 0; P < 8; ++P)
6650+
Parts[P].push_back(Op.getOperand(I + P * E));
6651+
}
6652+
SDValue Casts[8];
6653+
for (unsigned P = 0; P < 8; ++P) {
6654+
SDValue Vec = DAG.getBuildVector(QuarterVT, SL, Parts[P]);
6655+
Casts[P] = DAG.getNode(ISD::BITCAST, SL, QuarterIntVT, Vec);
6656+
}
6657+
6658+
SDValue Blend =
6659+
DAG.getBuildVector(MVT::getVectorVT(QuarterIntVT, 8), SL, Casts);
6660+
return DAG.getNode(ISD::BITCAST, SL, VT, Blend);
6661+
}
6662+
66106663
assert(VT == MVT::v2f16 || VT == MVT::v2i16);
66116664
assert(!Subtarget->hasVOP3PInsts() && "this should be legal");
66126665

@@ -9507,7 +9560,8 @@ SDValue SITargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const {
95079560

95089561
SDValue SITargetLowering::LowerSELECT(SDValue Op, SelectionDAG &DAG) const {
95099562
EVT VT = Op.getValueType();
9510-
if (VT.getSizeInBits() == 128 || VT.getSizeInBits() == 256)
9563+
if (VT.getSizeInBits() == 128 || VT.getSizeInBits() == 256 ||
9564+
VT.getSizeInBits() == 512)
95119565
return splitTernaryVectorOp(Op, DAG);
95129566

95139567
assert(VT.getSizeInBits() == 64);

llvm/lib/Target/AMDGPU/SIInstructions.td

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1619,6 +1619,16 @@ def : BitConvert <v12i32, v12f32, VReg_384>;
16191619
def : BitConvert <v12f32, v12i32, VReg_384>;
16201620

16211621
// 512-bit bitcast
1622+
def : BitConvert <v32f16, v32i16, VReg_512>;
1623+
def : BitConvert <v32i16, v32f16, VReg_512>;
1624+
def : BitConvert <v32f16, v16i32, VReg_512>;
1625+
def : BitConvert <v32f16, v16f32, VReg_512>;
1626+
def : BitConvert <v16f32, v32f16, VReg_512>;
1627+
def : BitConvert <v16i32, v32f16, VReg_512>;
1628+
def : BitConvert <v32i16, v16i32, VReg_512>;
1629+
def : BitConvert <v32i16, v16f32, VReg_512>;
1630+
def : BitConvert <v16f32, v32i16, VReg_512>;
1631+
def : BitConvert <v16i32, v32i16, VReg_512>;
16221632
def : BitConvert <v16i32, v16f32, VReg_512>;
16231633
def : BitConvert <v16f32, v16i32, VReg_512>;
16241634
def : BitConvert <v8i64, v8f64, VReg_512>;

llvm/lib/Target/AMDGPU/SIRegisterInfo.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -930,7 +930,7 @@ defm "" : SRegClass<11, [v11i32, v11f32], SGPR_352Regs, TTMP_352Regs>;
930930
defm "" : SRegClass<12, [v12i32, v12f32], SGPR_384Regs, TTMP_384Regs>;
931931

932932
let GlobalPriority = true in {
933-
defm "" : SRegClass<16, [v16i32, v16f32, v8i64, v8f64], SGPR_512Regs, TTMP_512Regs>;
933+
defm "" : SRegClass<16, [v16i32, v16f32, v8i64, v8f64, v32i16, v32f16], SGPR_512Regs, TTMP_512Regs>;
934934
defm "" : SRegClass<32, [v32i32, v32f32, v16i64, v16f64], SGPR_1024Regs>;
935935
}
936936

@@ -984,7 +984,7 @@ defm VReg_352 : VRegClass<11, [v11i32, v11f32], (add VGPR_352)>;
984984
defm VReg_384 : VRegClass<12, [v12i32, v12f32], (add VGPR_384)>;
985985

986986
let GlobalPriority = true in {
987-
defm VReg_512 : VRegClass<16, [v16i32, v16f32, v8i64, v8f64], (add VGPR_512)>;
987+
defm VReg_512 : VRegClass<16, [v16i32, v16f32, v8i64, v8f64, v32i16, v32f16], (add VGPR_512)>;
988988
defm VReg_1024 : VRegClass<32, [v32i32, v32f32, v16i64, v16f64], (add VGPR_1024)>;
989989
}
990990

llvm/test/Analysis/CostModel/AMDGPU/add-sub.ll

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ define amdgpu_kernel void @add_i16() #0 {
7676
; FAST16-NEXT: Cost Model: Found an estimated cost of 4 for instruction: %v5i16 = add <5 x i16> undef, undef
7777
; FAST16-NEXT: Cost Model: Found an estimated cost of 4 for instruction: %v6i16 = add <6 x i16> undef, undef
7878
; FAST16-NEXT: Cost Model: Found an estimated cost of 8 for instruction: %v16i16 = add <16 x i16> undef, undef
79-
; FAST16-NEXT: Cost Model: Found an estimated cost of 32 for instruction: %v17i16 = add <17 x i16> undef, undef
79+
; FAST16-NEXT: Cost Model: Found an estimated cost of 48 for instruction: %v17i16 = add <17 x i16> undef, undef
8080
; FAST16-NEXT: Cost Model: Found an estimated cost of 10 for instruction: ret void
8181
;
8282
; SLOW16-LABEL: 'add_i16'
@@ -98,7 +98,7 @@ define amdgpu_kernel void @add_i16() #0 {
9898
; FAST16-SIZE-NEXT: Cost Model: Found an estimated cost of 4 for instruction: %v5i16 = add <5 x i16> undef, undef
9999
; FAST16-SIZE-NEXT: Cost Model: Found an estimated cost of 4 for instruction: %v6i16 = add <6 x i16> undef, undef
100100
; FAST16-SIZE-NEXT: Cost Model: Found an estimated cost of 8 for instruction: %v16i16 = add <16 x i16> undef, undef
101-
; FAST16-SIZE-NEXT: Cost Model: Found an estimated cost of 32 for instruction: %v17i16 = add <17 x i16> undef, undef
101+
; FAST16-SIZE-NEXT: Cost Model: Found an estimated cost of 48 for instruction: %v17i16 = add <17 x i16> undef, undef
102102
; FAST16-SIZE-NEXT: Cost Model: Found an estimated cost of 1 for instruction: ret void
103103
;
104104
; SLOW16-SIZE-LABEL: 'add_i16'

0 commit comments

Comments
 (0)