@@ -165,6 +165,8 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
165
165
addRegisterClass(MVT::v8f16, &AMDGPU::SGPR_128RegClass);
166
166
addRegisterClass(MVT::v16i16, &AMDGPU::SGPR_256RegClass);
167
167
addRegisterClass(MVT::v16f16, &AMDGPU::SGPR_256RegClass);
168
+ addRegisterClass(MVT::v32i16, &AMDGPU::SGPR_512RegClass);
169
+ addRegisterClass(MVT::v32f16, &AMDGPU::SGPR_512RegClass);
168
170
}
169
171
170
172
addRegisterClass(MVT::v32i32, &AMDGPU::VReg_1024RegClass);
@@ -269,13 +271,13 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
269
271
// We only support LOAD/STORE and vector manipulation ops for vectors
270
272
// with > 4 elements.
271
273
for (MVT VT :
272
- {MVT::v8i32, MVT::v8f32, MVT::v9i32, MVT::v9f32, MVT::v10i32,
273
- MVT::v10f32, MVT::v11i32, MVT::v11f32, MVT::v12i32, MVT::v12f32,
274
- MVT::v16i32, MVT::v16f32, MVT::v2i64, MVT::v2f64, MVT::v4i16,
275
- MVT::v4f16, MVT::v3i64, MVT::v3f64, MVT::v6i32, MVT::v6f32,
276
- MVT::v4i64, MVT::v4f64, MVT::v8i64, MVT::v8f64, MVT::v8i16,
277
- MVT::v8f16, MVT::v16i16, MVT::v16f16, MVT::v16i64, MVT::v16f64,
278
- MVT::v32i32, MVT::v32f32}) {
274
+ {MVT::v8i32, MVT::v8f32, MVT::v9i32, MVT::v9f32, MVT::v10i32,
275
+ MVT::v10f32, MVT::v11i32, MVT::v11f32, MVT::v12i32, MVT::v12f32,
276
+ MVT::v16i32, MVT::v16f32, MVT::v2i64, MVT::v2f64, MVT::v4i16,
277
+ MVT::v4f16, MVT::v3i64, MVT::v3f64, MVT::v6i32, MVT::v6f32,
278
+ MVT::v4i64, MVT::v4f64, MVT::v8i64, MVT::v8f64, MVT::v8i16,
279
+ MVT::v8f16, MVT::v16i16, MVT::v16f16, MVT::v16i64, MVT::v16f64,
280
+ MVT::v32i32, MVT::v32f32, MVT::v32i16, MVT::v32f16 }) {
279
281
for (unsigned Op = 0; Op < ISD::BUILTIN_OP_END; ++Op) {
280
282
switch (Op) {
281
283
case ISD::LOAD:
@@ -553,8 +555,9 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
553
555
if (STI.hasMadF16())
554
556
setOperationAction(ISD::FMAD, MVT::f16, Legal);
555
557
556
- for (MVT VT : {MVT::v2i16, MVT::v2f16, MVT::v4i16, MVT::v4f16, MVT::v8i16,
557
- MVT::v8f16, MVT::v16i16, MVT::v16f16}) {
558
+ for (MVT VT :
559
+ {MVT::v2i16, MVT::v2f16, MVT::v4i16, MVT::v4f16, MVT::v8i16,
560
+ MVT::v8f16, MVT::v16i16, MVT::v16f16, MVT::v32i16, MVT::v32f16}) {
558
561
for (unsigned Op = 0; Op < ISD::BUILTIN_OP_END; ++Op) {
559
562
switch (Op) {
560
563
case ISD::LOAD:
@@ -640,6 +643,16 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
640
643
setOperationAction(ISD::STORE, MVT::v16f16, Promote);
641
644
AddPromotedToType(ISD::STORE, MVT::v16f16, MVT::v8i32);
642
645
646
+ setOperationAction(ISD::LOAD, MVT::v32i16, Promote);
647
+ AddPromotedToType(ISD::LOAD, MVT::v32i16, MVT::v16i32);
648
+ setOperationAction(ISD::LOAD, MVT::v32f16, Promote);
649
+ AddPromotedToType(ISD::LOAD, MVT::v32f16, MVT::v16i32);
650
+
651
+ setOperationAction(ISD::STORE, MVT::v32i16, Promote);
652
+ AddPromotedToType(ISD::STORE, MVT::v32i16, MVT::v16i32);
653
+ setOperationAction(ISD::STORE, MVT::v32f16, Promote);
654
+ AddPromotedToType(ISD::STORE, MVT::v32f16, MVT::v16i32);
655
+
643
656
setOperationAction({ISD::ANY_EXTEND, ISD::ZERO_EXTEND, ISD::SIGN_EXTEND},
644
657
MVT::v2i32, Expand);
645
658
setOperationAction(ISD::FP_EXTEND, MVT::v2f32, Expand);
@@ -662,12 +675,15 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
662
675
setOperationAction({ISD::FMAXNUM_IEEE, ISD::FMINNUM_IEEE}, MVT::f16, Legal);
663
676
664
677
setOperationAction({ISD::FMINNUM_IEEE, ISD::FMAXNUM_IEEE},
665
- {MVT::v4f16, MVT::v8f16, MVT::v16f16}, Custom);
678
+ {MVT::v4f16, MVT::v8f16, MVT::v16f16, MVT::v32f16},
679
+ Custom);
666
680
667
681
setOperationAction({ISD::FMINNUM, ISD::FMAXNUM},
668
- {MVT::v4f16, MVT::v8f16, MVT::v16f16}, Expand);
682
+ {MVT::v4f16, MVT::v8f16, MVT::v16f16, MVT::v32f16},
683
+ Expand);
669
684
670
- for (MVT Vec16 : {MVT::v8i16, MVT::v8f16, MVT::v16i16, MVT::v16f16}) {
685
+ for (MVT Vec16 : {MVT::v8i16, MVT::v8f16, MVT::v16i16, MVT::v16f16,
686
+ MVT::v32i16, MVT::v32f16}) {
671
687
setOperationAction(
672
688
{ISD::BUILD_VECTOR, ISD::EXTRACT_VECTOR_ELT, ISD::SCALAR_TO_VECTOR},
673
689
Vec16, Custom);
@@ -690,18 +706,18 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
690
706
691
707
setOperationAction(ISD::VECTOR_SHUFFLE,
692
708
{MVT::v4f16, MVT::v4i16, MVT::v8f16, MVT::v8i16,
693
- MVT::v16f16, MVT::v16i16},
709
+ MVT::v16f16, MVT::v16i16, MVT::v32f16, MVT::v32i16 },
694
710
Custom);
695
711
696
- for (MVT VT : {MVT::v4i16, MVT::v8i16, MVT::v16i16})
712
+ for (MVT VT : {MVT::v4i16, MVT::v8i16, MVT::v16i16, MVT::v32i16 })
697
713
// Split vector operations.
698
714
setOperationAction({ISD::SHL, ISD::SRA, ISD::SRL, ISD::ADD, ISD::SUB,
699
715
ISD::MUL, ISD::SMIN, ISD::SMAX, ISD::UMIN, ISD::UMAX,
700
716
ISD::UADDSAT, ISD::SADDSAT, ISD::USUBSAT,
701
717
ISD::SSUBSAT},
702
718
VT, Custom);
703
719
704
- for (MVT VT : {MVT::v4f16, MVT::v8f16, MVT::v16f16})
720
+ for (MVT VT : {MVT::v4f16, MVT::v8f16, MVT::v16f16, MVT::v32f16 })
705
721
// Split vector operations.
706
722
setOperationAction({ISD::FADD, ISD::FMUL, ISD::FMA, ISD::FCANONICALIZE},
707
723
VT, Custom);
@@ -737,7 +753,8 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
737
753
738
754
setOperationAction(ISD::SELECT,
739
755
{MVT::v4i16, MVT::v4f16, MVT::v2i8, MVT::v4i8, MVT::v8i8,
740
- MVT::v8i16, MVT::v8f16, MVT::v16i16, MVT::v16f16},
756
+ MVT::v8i16, MVT::v8f16, MVT::v16i16, MVT::v16f16,
757
+ MVT::v32i16, MVT::v32f16},
741
758
Custom);
742
759
743
760
setOperationAction({ISD::SMULO, ISD::UMULO}, MVT::i64, Custom);
@@ -5107,7 +5124,7 @@ SDValue SITargetLowering::splitUnaryVectorOp(SDValue Op,
5107
5124
assert(VT == MVT::v4i16 || VT == MVT::v4f16 || VT == MVT::v4f32 ||
5108
5125
VT == MVT::v8i16 || VT == MVT::v8f16 || VT == MVT::v16i16 ||
5109
5126
VT == MVT::v16f16 || VT == MVT::v8f32 || VT == MVT::v16f32 ||
5110
- VT == MVT::v32f32);
5127
+ VT == MVT::v32f32 || VT == MVT::v32i16 || VT == MVT::v32f16 );
5111
5128
5112
5129
SDValue Lo, Hi;
5113
5130
std::tie(Lo, Hi) = DAG.SplitVectorOperand(Op.getNode(), 0);
@@ -5130,7 +5147,7 @@ SDValue SITargetLowering::splitBinaryVectorOp(SDValue Op,
5130
5147
assert(VT == MVT::v4i16 || VT == MVT::v4f16 || VT == MVT::v4f32 ||
5131
5148
VT == MVT::v8i16 || VT == MVT::v8f16 || VT == MVT::v16i16 ||
5132
5149
VT == MVT::v16f16 || VT == MVT::v8f32 || VT == MVT::v16f32 ||
5133
- VT == MVT::v32f32);
5150
+ VT == MVT::v32f32 || VT == MVT::v32i16 || VT == MVT::v32f16 );
5134
5151
5135
5152
SDValue Lo0, Hi0;
5136
5153
std::tie(Lo0, Hi0) = DAG.SplitVectorOperand(Op.getNode(), 0);
@@ -5897,7 +5914,8 @@ SDValue SITargetLowering::lowerFMINNUM_FMAXNUM(SDValue Op,
5897
5914
if (IsIEEEMode)
5898
5915
return expandFMINNUM_FMAXNUM(Op.getNode(), DAG);
5899
5916
5900
- if (VT == MVT::v4f16 || VT == MVT::v8f16 || VT == MVT::v16f16)
5917
+ if (VT == MVT::v4f16 || VT == MVT::v8f16 || VT == MVT::v16f16 ||
5918
+ VT == MVT::v16f16)
5901
5919
return splitBinaryVectorOp(Op, DAG);
5902
5920
return Op;
5903
5921
}
@@ -6415,7 +6433,7 @@ SDValue SITargetLowering::lowerEXTRACT_VECTOR_ELT(SDValue Op,
6415
6433
if (SDValue Combined = performExtractVectorEltCombine(Op.getNode(), DCI))
6416
6434
return Combined;
6417
6435
6418
- if (VecSize == 128 || VecSize == 256) {
6436
+ if (VecSize == 128 || VecSize == 256 || VecSize == 512 ) {
6419
6437
SDValue Lo, Hi;
6420
6438
EVT LoVT, HiVT;
6421
6439
std::tie(LoVT, HiVT) = DAG.GetSplitDestVTs(VecVT);
@@ -6428,9 +6446,7 @@ SDValue SITargetLowering::lowerEXTRACT_VECTOR_ELT(SDValue Op,
6428
6446
Hi = DAG.getBitcast(HiVT,
6429
6447
DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL, MVT::i64, V2,
6430
6448
DAG.getConstant(1, SL, MVT::i32)));
6431
- } else {
6432
- assert(VecSize == 256);
6433
-
6449
+ } else if (VecSize == 256) {
6434
6450
SDValue V2 = DAG.getBitcast(MVT::v4i64, Vec);
6435
6451
SDValue Parts[4];
6436
6452
for (unsigned P = 0; P < 4; ++P) {
@@ -6442,6 +6458,22 @@ SDValue SITargetLowering::lowerEXTRACT_VECTOR_ELT(SDValue Op,
6442
6458
Parts[0], Parts[1]));
6443
6459
Hi = DAG.getBitcast(HiVT, DAG.getNode(ISD::BUILD_VECTOR, SL, MVT::v2i64,
6444
6460
Parts[2], Parts[3]));
6461
+ } else {
6462
+ assert(VecSize == 512);
6463
+
6464
+ SDValue V2 = DAG.getBitcast(MVT::v8i64, Vec);
6465
+ SDValue Parts[8];
6466
+ for (unsigned P = 0; P < 8; ++P) {
6467
+ Parts[P] = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL, MVT::i64, V2,
6468
+ DAG.getConstant(P, SL, MVT::i32));
6469
+ }
6470
+
6471
+ Lo = DAG.getBitcast(LoVT,
6472
+ DAG.getNode(ISD::BUILD_VECTOR, SL, MVT::v4i64,
6473
+ Parts[0], Parts[1], Parts[2], Parts[3]));
6474
+ Hi = DAG.getBitcast(HiVT,
6475
+ DAG.getNode(ISD::BUILD_VECTOR, SL, MVT::v4i64,
6476
+ Parts[4], Parts[5],Parts[6], Parts[7]));
6445
6477
}
6446
6478
6447
6479
EVT IdxVT = Idx.getValueType();
@@ -6607,6 +6639,27 @@ SDValue SITargetLowering::lowerBUILD_VECTOR(SDValue Op,
6607
6639
return DAG.getNode(ISD::BITCAST, SL, VT, Blend);
6608
6640
}
6609
6641
6642
+ if (VT == MVT::v32i16 || VT == MVT::v32f16) {
6643
+ EVT QuarterVT = MVT::getVectorVT(VT.getVectorElementType().getSimpleVT(),
6644
+ VT.getVectorNumElements() / 8);
6645
+ MVT QuarterIntVT = MVT::getIntegerVT(QuarterVT.getSizeInBits());
6646
+
6647
+ SmallVector<SDValue, 8> Parts[8];
6648
+ for (unsigned I = 0, E = VT.getVectorNumElements() / 8; I != E; ++I) {
6649
+ for (unsigned P = 0; P < 8; ++P)
6650
+ Parts[P].push_back(Op.getOperand(I + P * E));
6651
+ }
6652
+ SDValue Casts[8];
6653
+ for (unsigned P = 0; P < 8; ++P) {
6654
+ SDValue Vec = DAG.getBuildVector(QuarterVT, SL, Parts[P]);
6655
+ Casts[P] = DAG.getNode(ISD::BITCAST, SL, QuarterIntVT, Vec);
6656
+ }
6657
+
6658
+ SDValue Blend =
6659
+ DAG.getBuildVector(MVT::getVectorVT(QuarterIntVT, 8), SL, Casts);
6660
+ return DAG.getNode(ISD::BITCAST, SL, VT, Blend);
6661
+ }
6662
+
6610
6663
assert(VT == MVT::v2f16 || VT == MVT::v2i16);
6611
6664
assert(!Subtarget->hasVOP3PInsts() && "this should be legal");
6612
6665
@@ -9507,7 +9560,8 @@ SDValue SITargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const {
9507
9560
9508
9561
SDValue SITargetLowering::LowerSELECT(SDValue Op, SelectionDAG &DAG) const {
9509
9562
EVT VT = Op.getValueType();
9510
- if (VT.getSizeInBits() == 128 || VT.getSizeInBits() == 256)
9563
+ if (VT.getSizeInBits() == 128 || VT.getSizeInBits() == 256 ||
9564
+ VT.getSizeInBits() == 512)
9511
9565
return splitTernaryVectorOp(Op, DAG);
9512
9566
9513
9567
assert(VT.getSizeInBits() == 64);
0 commit comments