@@ -167,8 +167,10 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
167
167
addRegisterClass(MVT::v4bf16, &AMDGPU::SReg_64RegClass);
168
168
addRegisterClass(MVT::v8i16, &AMDGPU::SGPR_128RegClass);
169
169
addRegisterClass(MVT::v8f16, &AMDGPU::SGPR_128RegClass);
170
+ addRegisterClass(MVT::v8bf16, &AMDGPU::SGPR_128RegClass);
170
171
addRegisterClass(MVT::v16i16, &AMDGPU::SGPR_256RegClass);
171
172
addRegisterClass(MVT::v16f16, &AMDGPU::SGPR_256RegClass);
173
+ addRegisterClass(MVT::v16bf16, &AMDGPU::SGPR_256RegClass);
172
174
addRegisterClass(MVT::v32i16, &AMDGPU::SGPR_512RegClass);
173
175
addRegisterClass(MVT::v32f16, &AMDGPU::SGPR_512RegClass);
174
176
}
@@ -310,13 +312,14 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
310
312
// We only support LOAD/STORE and vector manipulation ops for vectors
311
313
// with > 4 elements.
312
314
for (MVT VT :
313
- {MVT::v8i32, MVT::v8f32, MVT::v9i32, MVT::v9f32, MVT::v10i32,
314
- MVT::v10f32, MVT::v11i32, MVT::v11f32, MVT::v12i32, MVT::v12f32,
315
- MVT::v16i32, MVT::v16f32, MVT::v2i64, MVT::v2f64, MVT::v4i16,
316
- MVT::v4f16, MVT::v4bf16, MVT::v3i64, MVT::v3f64, MVT::v6i32,
317
- MVT::v6f32, MVT::v4i64, MVT::v4f64, MVT::v8i64, MVT::v8f64,
318
- MVT::v8i16, MVT::v8f16, MVT::v16i16, MVT::v16f16, MVT::v16i64,
319
- MVT::v16f64, MVT::v32i32, MVT::v32f32, MVT::v32i16, MVT::v32f16}) {
315
+ {MVT::v8i32, MVT::v8f32, MVT::v9i32, MVT::v9f32, MVT::v10i32,
316
+ MVT::v10f32, MVT::v11i32, MVT::v11f32, MVT::v12i32, MVT::v12f32,
317
+ MVT::v16i32, MVT::v16f32, MVT::v2i64, MVT::v2f64, MVT::v4i16,
318
+ MVT::v4f16, MVT::v4bf16, MVT::v3i64, MVT::v3f64, MVT::v6i32,
319
+ MVT::v6f32, MVT::v4i64, MVT::v4f64, MVT::v8i64, MVT::v8f64,
320
+ MVT::v8i16, MVT::v8f16, MVT::v8bf16, MVT::v16i16, MVT::v16f16,
321
+ MVT::v16bf16, MVT::v16i64, MVT::v16f64, MVT::v32i32, MVT::v32f32,
322
+ MVT::v32i16, MVT::v32f16, MVT::v32bf16}) {
320
323
for (unsigned Op = 0; Op < ISD::BUILTIN_OP_END; ++Op) {
321
324
switch (Op) {
322
325
case ISD::LOAD:
@@ -683,6 +686,8 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
683
686
AddPromotedToType(ISD::LOAD, MVT::v8i16, MVT::v4i32);
684
687
setOperationAction(ISD::LOAD, MVT::v8f16, Promote);
685
688
AddPromotedToType(ISD::LOAD, MVT::v8f16, MVT::v4i32);
689
+ setOperationAction(ISD::LOAD, MVT::v8bf16, Promote);
690
+ AddPromotedToType(ISD::LOAD, MVT::v8bf16, MVT::v4i32);
686
691
687
692
setOperationAction(ISD::STORE, MVT::v4i16, Promote);
688
693
AddPromotedToType(ISD::STORE, MVT::v4i16, MVT::v2i32);
@@ -693,16 +698,22 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
693
698
AddPromotedToType(ISD::STORE, MVT::v8i16, MVT::v4i32);
694
699
setOperationAction(ISD::STORE, MVT::v8f16, Promote);
695
700
AddPromotedToType(ISD::STORE, MVT::v8f16, MVT::v4i32);
701
+ setOperationAction(ISD::STORE, MVT::v8bf16, Promote);
702
+ AddPromotedToType(ISD::STORE, MVT::v8bf16, MVT::v4i32);
696
703
697
704
setOperationAction(ISD::LOAD, MVT::v16i16, Promote);
698
705
AddPromotedToType(ISD::LOAD, MVT::v16i16, MVT::v8i32);
699
706
setOperationAction(ISD::LOAD, MVT::v16f16, Promote);
700
707
AddPromotedToType(ISD::LOAD, MVT::v16f16, MVT::v8i32);
708
+ setOperationAction(ISD::LOAD, MVT::v16bf16, Promote);
709
+ AddPromotedToType(ISD::LOAD, MVT::v16bf16, MVT::v8i32);
701
710
702
711
setOperationAction(ISD::STORE, MVT::v16i16, Promote);
703
712
AddPromotedToType(ISD::STORE, MVT::v16i16, MVT::v8i32);
704
713
setOperationAction(ISD::STORE, MVT::v16f16, Promote);
705
714
AddPromotedToType(ISD::STORE, MVT::v16f16, MVT::v8i32);
715
+ setOperationAction(ISD::STORE, MVT::v16bf16, Promote);
716
+ AddPromotedToType(ISD::STORE, MVT::v16bf16, MVT::v8i32);
706
717
707
718
setOperationAction(ISD::LOAD, MVT::v32i16, Promote);
708
719
AddPromotedToType(ISD::LOAD, MVT::v32i16, MVT::v16i32);
@@ -725,7 +736,8 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
725
736
MVT::v8i32, Expand);
726
737
727
738
if (!Subtarget->hasVOP3PInsts())
728
- setOperationAction(ISD::BUILD_VECTOR, {MVT::v2i16, MVT::v2f16}, Custom);
739
+ setOperationAction(ISD::BUILD_VECTOR,
740
+ {MVT::v2i16, MVT::v2f16, MVT::v2bf16}, Custom);
729
741
730
742
setOperationAction(ISD::FNEG, MVT::v2f16, Legal);
731
743
// This isn't really legal, but this avoids the legalizer unrolling it (and
@@ -743,8 +755,9 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
743
755
{MVT::v4f16, MVT::v8f16, MVT::v16f16, MVT::v32f16},
744
756
Expand);
745
757
746
- for (MVT Vec16 : {MVT::v8i16, MVT::v8f16, MVT::v16i16, MVT::v16f16,
747
- MVT::v32i16, MVT::v32f16}) {
758
+ for (MVT Vec16 :
759
+ {MVT::v8i16, MVT::v8f16, MVT::v8bf16, MVT::v16i16, MVT::v16f16,
760
+ MVT::v16bf16, MVT::v32i16, MVT::v32f16, MVT::v32bf16}) {
748
761
setOperationAction(
749
762
{ISD::BUILD_VECTOR, ISD::EXTRACT_VECTOR_ELT, ISD::SCALAR_TO_VECTOR},
750
763
Vec16, Custom);
@@ -814,9 +827,10 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
814
827
}
815
828
816
829
setOperationAction(ISD::SELECT,
817
- {MVT::v4i16, MVT::v4f16, MVT::v2i8, MVT::v4i8, MVT::v8i8,
818
- MVT::v8i16, MVT::v8f16, MVT::v16i16, MVT::v16f16,
819
- MVT::v32i16, MVT::v32f16},
830
+ {MVT::v4i16, MVT::v4f16, MVT::v4bf16, MVT::v2i8, MVT::v4i8,
831
+ MVT::v8i8, MVT::v8i16, MVT::v8f16, MVT::v8bf16,
832
+ MVT::v16i16, MVT::v16f16, MVT::v16bf16, MVT::v32i16,
833
+ MVT::v32f16, MVT::v32bf16},
820
834
Custom);
821
835
822
836
setOperationAction({ISD::SMULO, ISD::UMULO}, MVT::i64, Custom);
@@ -5431,7 +5445,9 @@ SDValue SITargetLowering::splitTernaryVectorOp(SDValue Op,
5431
5445
assert(VT == MVT::v4i16 || VT == MVT::v4f16 || VT == MVT::v8i16 ||
5432
5446
VT == MVT::v8f16 || VT == MVT::v4f32 || VT == MVT::v16i16 ||
5433
5447
VT == MVT::v16f16 || VT == MVT::v8f32 || VT == MVT::v16f32 ||
5434
- VT == MVT::v32f32 || VT == MVT::v32f16 || VT == MVT::v32i16);
5448
+ VT == MVT::v32f32 || VT == MVT::v32f16 || VT == MVT::v32i16 ||
5449
+ VT == MVT::v4bf16 || VT == MVT::v8bf16 || VT == MVT::v16bf16 ||
5450
+ VT == MVT::v32bf16);
5435
5451
5436
5452
SDValue Lo0, Hi0;
5437
5453
SDValue Op0 = Op.getOperand(0);
@@ -6854,8 +6870,8 @@ SDValue SITargetLowering::lowerBUILD_VECTOR(SDValue Op,
6854
6870
SDLoc SL(Op);
6855
6871
EVT VT = Op.getValueType();
6856
6872
6857
- if (VT == MVT::v4i16 || VT == MVT::v4f16 || VT == MVT::v4bf16 ||
6858
- VT == MVT::v8i16 || VT == MVT::v8f16 ) {
6873
+ if (VT == MVT::v4i16 || VT == MVT::v4f16 || VT == MVT::v8i16 ||
6874
+ VT == MVT::v8f16 || VT == MVT::v4bf16 || VT == MVT::v8bf16 ) {
6859
6875
EVT HalfVT = MVT::getVectorVT(VT.getVectorElementType().getSimpleVT(),
6860
6876
VT.getVectorNumElements() / 2);
6861
6877
MVT HalfIntVT = MVT::getIntegerVT(HalfVT.getSizeInBits());
@@ -6878,7 +6894,7 @@ SDValue SITargetLowering::lowerBUILD_VECTOR(SDValue Op,
6878
6894
return DAG.getNode(ISD::BITCAST, SL, VT, Blend);
6879
6895
}
6880
6896
6881
- if (VT == MVT::v16i16 || VT == MVT::v16f16) {
6897
+ if (VT == MVT::v16i16 || VT == MVT::v16f16 || VT == MVT::v16bf16 ) {
6882
6898
EVT QuarterVT = MVT::getVectorVT(VT.getVectorElementType().getSimpleVT(),
6883
6899
VT.getVectorNumElements() / 4);
6884
6900
MVT QuarterIntVT = MVT::getIntegerVT(QuarterVT.getSizeInBits());
@@ -6899,7 +6915,7 @@ SDValue SITargetLowering::lowerBUILD_VECTOR(SDValue Op,
6899
6915
return DAG.getNode(ISD::BITCAST, SL, VT, Blend);
6900
6916
}
6901
6917
6902
- if (VT == MVT::v32i16 || VT == MVT::v32f16) {
6918
+ if (VT == MVT::v32i16 || VT == MVT::v32f16 || VT == MVT::v32bf16 ) {
6903
6919
EVT QuarterVT = MVT::getVectorVT(VT.getVectorElementType().getSimpleVT(),
6904
6920
VT.getVectorNumElements() / 8);
6905
6921
MVT QuarterIntVT = MVT::getIntegerVT(QuarterVT.getSizeInBits());
@@ -14182,11 +14198,11 @@ SDValue SITargetLowering::PerformDAGCombine(SDNode *N,
14182
14198
EVT VT = N->getValueType(0);
14183
14199
14184
14200
// v2i16 (scalar_to_vector i16:x) -> v2i16 (bitcast (any_extend i16:x))
14185
- if (VT == MVT::v2i16 || VT == MVT::v2f16) {
14201
+ if (VT == MVT::v2i16 || VT == MVT::v2f16 || VT == MVT::v2bf16) {
14186
14202
SDLoc SL(N);
14187
14203
SDValue Src = N->getOperand(0);
14188
14204
EVT EltVT = Src.getValueType();
14189
- if (EltVT == MVT::f16 )
14205
+ if (EltVT != MVT::i16 )
14190
14206
Src = DAG.getNode(ISD::BITCAST, SL, MVT::i16, Src);
14191
14207
14192
14208
SDValue Ext = DAG.getNode(ISD::ANY_EXTEND, SL, MVT::i32, Src);
0 commit comments