@@ -164,6 +164,7 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
164
164
addRegisterClass(MVT::v2bf16, &AMDGPU::SReg_32RegClass);
165
165
addRegisterClass(MVT::v4i16, &AMDGPU::SReg_64RegClass);
166
166
addRegisterClass(MVT::v4f16, &AMDGPU::SReg_64RegClass);
167
+ addRegisterClass(MVT::v4bf16, &AMDGPU::SReg_64RegClass);
167
168
addRegisterClass(MVT::v8i16, &AMDGPU::SGPR_128RegClass);
168
169
addRegisterClass(MVT::v8f16, &AMDGPU::SGPR_128RegClass);
169
170
addRegisterClass(MVT::v16i16, &AMDGPU::SGPR_256RegClass);
@@ -312,10 +313,10 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
312
313
{MVT::v8i32, MVT::v8f32, MVT::v9i32, MVT::v9f32, MVT::v10i32,
313
314
MVT::v10f32, MVT::v11i32, MVT::v11f32, MVT::v12i32, MVT::v12f32,
314
315
MVT::v16i32, MVT::v16f32, MVT::v2i64, MVT::v2f64, MVT::v4i16,
315
- MVT::v4f16, MVT::v3i64, MVT::v3f64 , MVT::v6i32 , MVT::v6f32 ,
316
- MVT::v4i64 , MVT::v4f64 , MVT::v8i64 , MVT::v8f64 , MVT::v8i16 ,
317
- MVT::v8f16 , MVT::v16i16, MVT::v16f16 , MVT::v16i64 , MVT::v16f64 ,
318
- MVT::v32i32, MVT::v32f32, MVT::v32i16, MVT::v32f16}) {
316
+ MVT::v4f16, MVT::v4bf16, MVT::v3i64 , MVT::v3f64 , MVT::v6i32 ,
317
+ MVT::v6f32 , MVT::v4i64 , MVT::v4f64 , MVT::v8i64 , MVT::v8f64 ,
318
+ MVT::v8i16 , MVT::v8f16, MVT::v16i16 , MVT::v16f16 , MVT::v16i64 ,
319
+ MVT::v16f64, MVT:: v32i32, MVT::v32f32, MVT::v32i16, MVT::v32f16}) {
319
320
for (unsigned Op = 0; Op < ISD::BUILTIN_OP_END; ++Op) {
320
321
switch (Op) {
321
322
case ISD::LOAD:
@@ -421,13 +422,14 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
421
422
{MVT::v8i32, MVT::v8f32, MVT::v16i32, MVT::v16f32},
422
423
Expand);
423
424
424
- setOperationAction(ISD::BUILD_VECTOR, {MVT::v4f16, MVT::v4i16}, Custom);
425
+ setOperationAction(ISD::BUILD_VECTOR, {MVT::v4f16, MVT::v4i16, MVT::v4bf16},
426
+ Custom);
425
427
426
428
// Avoid stack access for these.
427
429
// TODO: Generalize to more vector types.
428
430
setOperationAction({ISD::EXTRACT_VECTOR_ELT, ISD::INSERT_VECTOR_ELT},
429
431
{MVT::v2i16, MVT::v2f16, MVT::v2bf16, MVT::v2i8, MVT::v4i8,
430
- MVT::v8i8, MVT::v4i16, MVT::v4f16},
432
+ MVT::v8i8, MVT::v4i16, MVT::v4f16, MVT::v4bf16 },
431
433
Custom);
432
434
433
435
// Deal with vec3 vector operations when widened to vec4.
@@ -667,11 +669,15 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
667
669
AddPromotedToType(ISD::LOAD, MVT::v4i16, MVT::v2i32);
668
670
setOperationAction(ISD::LOAD, MVT::v4f16, Promote);
669
671
AddPromotedToType(ISD::LOAD, MVT::v4f16, MVT::v2i32);
672
+ setOperationAction(ISD::LOAD, MVT::v4bf16, Promote);
673
+ AddPromotedToType(ISD::LOAD, MVT::v4bf16, MVT::v2i32);
670
674
671
675
setOperationAction(ISD::STORE, MVT::v4i16, Promote);
672
676
AddPromotedToType(ISD::STORE, MVT::v4i16, MVT::v2i32);
673
677
setOperationAction(ISD::STORE, MVT::v4f16, Promote);
674
678
AddPromotedToType(ISD::STORE, MVT::v4f16, MVT::v2i32);
679
+ setOperationAction(ISD::STORE, MVT::v4bf16, Promote);
680
+ AddPromotedToType(ISD::STORE, MVT::v4bf16, MVT::v2i32);
675
681
676
682
setOperationAction(ISD::LOAD, MVT::v8i16, Promote);
677
683
AddPromotedToType(ISD::LOAD, MVT::v8i16, MVT::v4i32);
@@ -781,7 +787,8 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
781
787
Custom);
782
788
783
789
setOperationAction(ISD::FEXP, MVT::v2f16, Custom);
784
- setOperationAction(ISD::SELECT, {MVT::v4i16, MVT::v4f16}, Custom);
790
+ setOperationAction(ISD::SELECT, {MVT::v4i16, MVT::v4f16, MVT::v4bf16},
791
+ Custom);
785
792
786
793
if (Subtarget->hasPackedFP32Ops()) {
787
794
setOperationAction({ISD::FADD, ISD::FMUL, ISD::FMA, ISD::FNEG},
@@ -6804,7 +6811,7 @@ SDValue SITargetLowering::lowerBUILD_VECTOR(SDValue Op,
6804
6811
SDLoc SL(Op);
6805
6812
EVT VT = Op.getValueType();
6806
6813
6807
- if (VT == MVT::v4i16 || VT == MVT::v4f16 ||
6814
+ if (VT == MVT::v4i16 || VT == MVT::v4f16 || VT == MVT::v4bf16 ||
6808
6815
VT == MVT::v8i16 || VT == MVT::v8f16) {
6809
6816
EVT HalfVT = MVT::getVectorVT(VT.getVectorElementType().getSimpleVT(),
6810
6817
VT.getVectorNumElements() / 2);
@@ -6870,7 +6877,7 @@ SDValue SITargetLowering::lowerBUILD_VECTOR(SDValue Op,
6870
6877
return DAG.getNode(ISD::BITCAST, SL, VT, Blend);
6871
6878
}
6872
6879
6873
- assert(VT == MVT::v2f16 || VT == MVT::v2i16);
6880
+ assert(VT == MVT::v2f16 || VT == MVT::v2i16 || VT == MVT::v2bf16 );
6874
6881
assert(!Subtarget->hasVOP3PInsts() && "this should be legal");
6875
6882
6876
6883
SDValue Lo = Op.getOperand(0);
0 commit comments