Skip to content

Commit 4768563

Browse files
authored
AMDGPU: Make v4bf16 a legal type (#76217)
This yields a few code quality improvements, although a few cases get worse because load narrowing is lost. Depends on #76213, #76214, and #76215.
1 parent c1eef48 commit 4768563

File tree

8 files changed

+5678
-6430
lines changed

8 files changed

+5678
-6430
lines changed

llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -389,15 +389,16 @@ AMDGPUTargetLowering::AMDGPUTargetLowering(const TargetMachine &TM,
389389
Custom);
390390
setOperationAction(
391391
ISD::EXTRACT_SUBVECTOR,
392-
{MVT::v2f16, MVT::v2i16, MVT::v4f16, MVT::v4i16, MVT::v2f32,
393-
MVT::v2i32, MVT::v3f32, MVT::v3i32, MVT::v4f32, MVT::v4i32,
394-
MVT::v5f32, MVT::v5i32, MVT::v6f32, MVT::v6i32, MVT::v7f32,
395-
MVT::v7i32, MVT::v8f32, MVT::v8i32, MVT::v9f32, MVT::v9i32,
396-
MVT::v10i32, MVT::v10f32, MVT::v11i32, MVT::v11f32, MVT::v12i32,
397-
MVT::v12f32, MVT::v16f16, MVT::v16i16, MVT::v16f32, MVT::v16i32,
398-
MVT::v32f32, MVT::v32i32, MVT::v2f64, MVT::v2i64, MVT::v3f64,
399-
MVT::v3i64, MVT::v4f64, MVT::v4i64, MVT::v8f64, MVT::v8i64,
400-
MVT::v16f64, MVT::v16i64, MVT::v32i16, MVT::v32f16},
392+
{MVT::v2f16, MVT::v2i16, MVT::v2bf16, MVT::v4f16, MVT::v4i16,
393+
MVT::v4bf16, MVT::v2f32, MVT::v2i32, MVT::v3f32, MVT::v3i32,
394+
MVT::v4f32, MVT::v4i32, MVT::v5f32, MVT::v5i32, MVT::v6f32,
395+
MVT::v6i32, MVT::v7f32, MVT::v7i32, MVT::v8f32, MVT::v8i32,
396+
MVT::v9f32, MVT::v9i32, MVT::v10i32, MVT::v10f32, MVT::v11i32,
397+
MVT::v11f32, MVT::v12i32, MVT::v12f32, MVT::v16f16, MVT::v16i16,
398+
MVT::v16f32, MVT::v16i32, MVT::v32f32, MVT::v32i32, MVT::v2f64,
399+
MVT::v2i64, MVT::v3f64, MVT::v3i64, MVT::v4f64, MVT::v4i64,
400+
MVT::v8f64, MVT::v8i64, MVT::v16f64, MVT::v16i64, MVT::v32i16,
401+
MVT::v32f16},
401402
Custom);
402403

403404
setOperationAction(ISD::FP16_TO_FP, MVT::f64, Expand);

llvm/lib/Target/AMDGPU/SIISelLowering.cpp

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,7 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
164164
addRegisterClass(MVT::v2bf16, &AMDGPU::SReg_32RegClass);
165165
addRegisterClass(MVT::v4i16, &AMDGPU::SReg_64RegClass);
166166
addRegisterClass(MVT::v4f16, &AMDGPU::SReg_64RegClass);
167+
addRegisterClass(MVT::v4bf16, &AMDGPU::SReg_64RegClass);
167168
addRegisterClass(MVT::v8i16, &AMDGPU::SGPR_128RegClass);
168169
addRegisterClass(MVT::v8f16, &AMDGPU::SGPR_128RegClass);
169170
addRegisterClass(MVT::v16i16, &AMDGPU::SGPR_256RegClass);
@@ -312,10 +313,10 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
312313
{MVT::v8i32, MVT::v8f32, MVT::v9i32, MVT::v9f32, MVT::v10i32,
313314
MVT::v10f32, MVT::v11i32, MVT::v11f32, MVT::v12i32, MVT::v12f32,
314315
MVT::v16i32, MVT::v16f32, MVT::v2i64, MVT::v2f64, MVT::v4i16,
315-
MVT::v4f16, MVT::v3i64, MVT::v3f64, MVT::v6i32, MVT::v6f32,
316-
MVT::v4i64, MVT::v4f64, MVT::v8i64, MVT::v8f64, MVT::v8i16,
317-
MVT::v8f16, MVT::v16i16, MVT::v16f16, MVT::v16i64, MVT::v16f64,
318-
MVT::v32i32, MVT::v32f32, MVT::v32i16, MVT::v32f16}) {
316+
MVT::v4f16, MVT::v4bf16, MVT::v3i64, MVT::v3f64, MVT::v6i32,
317+
MVT::v6f32, MVT::v4i64, MVT::v4f64, MVT::v8i64, MVT::v8f64,
318+
MVT::v8i16, MVT::v8f16, MVT::v16i16, MVT::v16f16, MVT::v16i64,
319+
MVT::v16f64, MVT::v32i32, MVT::v32f32, MVT::v32i16, MVT::v32f16}) {
319320
for (unsigned Op = 0; Op < ISD::BUILTIN_OP_END; ++Op) {
320321
switch (Op) {
321322
case ISD::LOAD:
@@ -421,13 +422,14 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
421422
{MVT::v8i32, MVT::v8f32, MVT::v16i32, MVT::v16f32},
422423
Expand);
423424

424-
setOperationAction(ISD::BUILD_VECTOR, {MVT::v4f16, MVT::v4i16}, Custom);
425+
setOperationAction(ISD::BUILD_VECTOR, {MVT::v4f16, MVT::v4i16, MVT::v4bf16},
426+
Custom);
425427

426428
// Avoid stack access for these.
427429
// TODO: Generalize to more vector types.
428430
setOperationAction({ISD::EXTRACT_VECTOR_ELT, ISD::INSERT_VECTOR_ELT},
429431
{MVT::v2i16, MVT::v2f16, MVT::v2bf16, MVT::v2i8, MVT::v4i8,
430-
MVT::v8i8, MVT::v4i16, MVT::v4f16},
432+
MVT::v8i8, MVT::v4i16, MVT::v4f16, MVT::v4bf16},
431433
Custom);
432434

433435
// Deal with vec3 vector operations when widened to vec4.
@@ -667,11 +669,15 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
667669
AddPromotedToType(ISD::LOAD, MVT::v4i16, MVT::v2i32);
668670
setOperationAction(ISD::LOAD, MVT::v4f16, Promote);
669671
AddPromotedToType(ISD::LOAD, MVT::v4f16, MVT::v2i32);
672+
setOperationAction(ISD::LOAD, MVT::v4bf16, Promote);
673+
AddPromotedToType(ISD::LOAD, MVT::v4bf16, MVT::v2i32);
670674

671675
setOperationAction(ISD::STORE, MVT::v4i16, Promote);
672676
AddPromotedToType(ISD::STORE, MVT::v4i16, MVT::v2i32);
673677
setOperationAction(ISD::STORE, MVT::v4f16, Promote);
674678
AddPromotedToType(ISD::STORE, MVT::v4f16, MVT::v2i32);
679+
setOperationAction(ISD::STORE, MVT::v4bf16, Promote);
680+
AddPromotedToType(ISD::STORE, MVT::v4bf16, MVT::v2i32);
675681

676682
setOperationAction(ISD::LOAD, MVT::v8i16, Promote);
677683
AddPromotedToType(ISD::LOAD, MVT::v8i16, MVT::v4i32);
@@ -781,7 +787,8 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
781787
Custom);
782788

783789
setOperationAction(ISD::FEXP, MVT::v2f16, Custom);
784-
setOperationAction(ISD::SELECT, {MVT::v4i16, MVT::v4f16}, Custom);
790+
setOperationAction(ISD::SELECT, {MVT::v4i16, MVT::v4f16, MVT::v4bf16},
791+
Custom);
785792

786793
if (Subtarget->hasPackedFP32Ops()) {
787794
setOperationAction({ISD::FADD, ISD::FMUL, ISD::FMA, ISD::FNEG},
@@ -6805,7 +6812,7 @@ SDValue SITargetLowering::lowerBUILD_VECTOR(SDValue Op,
68056812
SDLoc SL(Op);
68066813
EVT VT = Op.getValueType();
68076814

6808-
if (VT == MVT::v4i16 || VT == MVT::v4f16 ||
6815+
if (VT == MVT::v4i16 || VT == MVT::v4f16 || VT == MVT::v4bf16 ||
68096816
VT == MVT::v8i16 || VT == MVT::v8f16) {
68106817
EVT HalfVT = MVT::getVectorVT(VT.getVectorElementType().getSimpleVT(),
68116818
VT.getVectorNumElements() / 2);
@@ -6871,7 +6878,7 @@ SDValue SITargetLowering::lowerBUILD_VECTOR(SDValue Op,
68716878
return DAG.getNode(ISD::BITCAST, SL, VT, Blend);
68726879
}
68736880

6874-
assert(VT == MVT::v2f16 || VT == MVT::v2i16);
6881+
assert(VT == MVT::v2f16 || VT == MVT::v2i16 || VT == MVT::v2bf16);
68756882
assert(!Subtarget->hasVOP3PInsts() && "this should be legal");
68766883

68776884
SDValue Lo = Op.getOperand(0);

llvm/lib/Target/AMDGPU/SIInstructions.td

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1548,6 +1548,19 @@ def : BitConvert <f64, v2i32, VReg_64>;
15481548
def : BitConvert <v2i32, f64, VReg_64>;
15491549
def : BitConvert <v4i16, v4f16, VReg_64>;
15501550
def : BitConvert <v4f16, v4i16, VReg_64>;
1551+
def : BitConvert <v4bf16, v2i32, VReg_64>;
1552+
def : BitConvert <v2i32, v4bf16, VReg_64>;
1553+
def : BitConvert <v4bf16, i64, VReg_64>;
1554+
def : BitConvert <i64, v4bf16, VReg_64>;
1555+
def : BitConvert <v4bf16, v4i16, VReg_64>;
1556+
def : BitConvert <v4i16, v4bf16, VReg_64>;
1557+
def : BitConvert <v4bf16, v4f16, VReg_64>;
1558+
def : BitConvert <v4f16, v4bf16, VReg_64>;
1559+
def : BitConvert <v4bf16, v2f32, VReg_64>;
1560+
def : BitConvert <v2f32, v4bf16, VReg_64>;
1561+
def : BitConvert <v4bf16, f64, VReg_64>;
1562+
def : BitConvert <f64, v4bf16, VReg_64>;
1563+
15511564

15521565
// FIXME: Make SGPR
15531566
def : BitConvert <v2i32, v4f16, VReg_64>;

0 commit comments

Comments
 (0)