AMDGPU: Make bf16/v2bf16 legal types #76215

Merged (2 commits) on Jan 4, 2024
26 changes: 21 additions & 5 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -3199,7 +3199,16 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
return true;
}
break;
case ISD::FP_ROUND:
case ISD::FP_ROUND: {
EVT VT = Node->getValueType(0);
if (VT.getScalarType() == MVT::bf16) {
Results.push_back(
DAG.getNode(ISD::FP_TO_BF16, SDLoc(Node), VT, Node->getOperand(0)));
break;
}

LLVM_FALLTHROUGH;
}
case ISD::BITCAST:
if ((Tmp1 = EmitStackConvert(Node->getOperand(0), Node->getValueType(0),
Node->getValueType(0), dl)))
@@ -3226,12 +3235,19 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
return true;
}
break;
case ISD::FP_EXTEND:
if ((Tmp1 = EmitStackConvert(Node->getOperand(0),
Node->getOperand(0).getValueType(),
Node->getValueType(0), dl)))
case ISD::FP_EXTEND: {
SDValue Op = Node->getOperand(0);
EVT SrcVT = Op.getValueType();
EVT DstVT = Node->getValueType(0);
if (SrcVT.getScalarType() == MVT::bf16) {
Results.push_back(DAG.getNode(ISD::BF16_TO_FP, SDLoc(Node), DstVT, Op));
break;
}

if ((Tmp1 = EmitStackConvert(Op, SrcVT, DstVT, dl)))
Results.push_back(Tmp1);
break;
}
case ISD::BF16_TO_FP: {
// Always expand bf16 to f32 casts, they lower to ext + shift.
//
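
For reference (this note and sketch are not part of the patch): the comment above says bf16 casts lower to "ext + shift". A minimal standalone C++ sketch of what those expansions compute, assuming shift-based extension and round-to-nearest-even truncation; the helper names are hypothetical and the backend's actual expansion (e.g. NaN handling) can differ.

#include <cstdint>
#include <cstring>

// bf16 reuses the upper 16 bits of an IEEE-754 f32, so extension is a
// plain 16-bit left shift of the payload into an f32 bit pattern.
float BF16ToF32(uint16_t Bits16) {
  uint32_t Bits32 = static_cast<uint32_t>(Bits16) << 16;
  float F;
  std::memcpy(&F, &Bits32, sizeof(F));
  return F;
}

// Truncation keeps the upper 16 bits after rounding to nearest, ties to
// even; NaN payloads are not special-cased in this sketch.
uint16_t F32ToBF16(float F) {
  uint32_t Bits32;
  std::memcpy(&Bits32, &F, sizeof(Bits32));
  Bits32 += 0x7FFFu + ((Bits32 >> 16) & 1u);
  return static_cast<uint16_t>(Bits32 >> 16);
}

This cheap bit-level relationship is why the FP_EXTEND case above can forward bf16 sources to ISD::BF16_TO_FP instead of going through EmitStackConvert.
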
26 changes: 13 additions & 13 deletions llvm/lib/Target/AMDGPU/AMDGPUCallingConv.td
@@ -22,28 +22,28 @@ def CC_SI_Gfx : CallingConv<[
// 32 is reserved for the stack pointer
// 33 is reserved for the frame pointer
// 34 is reserved for the base pointer
CCIfInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16] , CCAssignToReg<[
CCIfInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16, bf16, v2bf16] , CCAssignToReg<[
SGPR4, SGPR5, SGPR6, SGPR7,
SGPR8, SGPR9, SGPR10, SGPR11, SGPR12, SGPR13, SGPR14, SGPR15,
SGPR16, SGPR17, SGPR18, SGPR19, SGPR20, SGPR21, SGPR22, SGPR23,
SGPR24, SGPR25, SGPR26, SGPR27, SGPR28, SGPR29
]>>>,

CCIfNotInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16] , CCAssignToReg<[
CCIfNotInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16, bf16, v2bf16] , CCAssignToReg<[
VGPR0, VGPR1, VGPR2, VGPR3, VGPR4, VGPR5, VGPR6, VGPR7,
VGPR8, VGPR9, VGPR10, VGPR11, VGPR12, VGPR13, VGPR14, VGPR15,
VGPR16, VGPR17, VGPR18, VGPR19, VGPR20, VGPR21, VGPR22, VGPR23,
VGPR24, VGPR25, VGPR26, VGPR27, VGPR28, VGPR29, VGPR30, VGPR31
]>>>,

CCIfType<[i32, f32, v2i16, v2f16, i16, f16, i1], CCAssignToStack<4, 4>>
CCIfType<[i32, f32, v2i16, v2f16, i16, f16, i1, bf16, v2bf16], CCAssignToStack<4, 4>>
]>;

def RetCC_SI_Gfx : CallingConv<[
CCIfType<[i1], CCPromoteToType<i32>>,
CCIfType<[i1, i16], CCIfExtend<CCPromoteToType<i32>>>,

CCIfNotInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16] , CCAssignToReg<[
CCIfNotInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16, bf16, v2bf16] , CCAssignToReg<[
VGPR0, VGPR1, VGPR2, VGPR3, VGPR4, VGPR5, VGPR6, VGPR7,
VGPR8, VGPR9, VGPR10, VGPR11, VGPR12, VGPR13, VGPR14, VGPR15,
VGPR16, VGPR17, VGPR18, VGPR19, VGPR20, VGPR21, VGPR22, VGPR23,
@@ -66,7 +66,7 @@ def RetCC_SI_Gfx : CallingConv<[

def CC_SI_SHADER : CallingConv<[

CCIfInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16] , CCAssignToReg<[
CCIfInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16, bf16, v2bf16] , CCAssignToReg<[
SGPR0, SGPR1, SGPR2, SGPR3, SGPR4, SGPR5, SGPR6, SGPR7,
SGPR8, SGPR9, SGPR10, SGPR11, SGPR12, SGPR13, SGPR14, SGPR15,
SGPR16, SGPR17, SGPR18, SGPR19, SGPR20, SGPR21, SGPR22, SGPR23,
@@ -76,7 +76,7 @@ def CC_SI_SHADER : CallingConv<[
]>>>,

// 32*4 + 4 is the minimum for a fetch shader consumer with 32 inputs.
CCIfNotInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16] , CCAssignToReg<[
CCIfNotInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16, bf16, v2bf16] , CCAssignToReg<[
VGPR0, VGPR1, VGPR2, VGPR3, VGPR4, VGPR5, VGPR6, VGPR7,
VGPR8, VGPR9, VGPR10, VGPR11, VGPR12, VGPR13, VGPR14, VGPR15,
VGPR16, VGPR17, VGPR18, VGPR19, VGPR20, VGPR21, VGPR22, VGPR23,
@@ -109,7 +109,7 @@ def RetCC_SI_Shader : CallingConv<[
]>>,

// 32*4 + 4 is the minimum for a fetch shader with 32 outputs.
CCIfType<[f32, f16, v2f16] , CCAssignToReg<[
CCIfType<[f32, f16, v2f16, bf16, v2bf16] , CCAssignToReg<[
VGPR0, VGPR1, VGPR2, VGPR3, VGPR4, VGPR5, VGPR6, VGPR7,
VGPR8, VGPR9, VGPR10, VGPR11, VGPR12, VGPR13, VGPR14, VGPR15,
VGPR16, VGPR17, VGPR18, VGPR19, VGPR20, VGPR21, VGPR22, VGPR23,
@@ -188,23 +188,23 @@ def CC_AMDGPU_Func : CallingConv<[
CCIfType<[i1], CCPromoteToType<i32>>,
CCIfType<[i8, i16], CCIfExtend<CCPromoteToType<i32>>>,

CCIfInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16] , CCAssignToReg<
CCIfInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16, bf16, v2bf16] , CCAssignToReg<
!foreach(i, !range(0, 30), !cast<Register>("SGPR"#i)) // SGPR0-29
>>>,

CCIfType<[i32, f32, i16, f16, v2i16, v2f16, i1], CCAssignToReg<[
CCIfType<[i32, f32, i16, f16, v2i16, v2f16, i1, bf16, v2bf16], CCAssignToReg<[
VGPR0, VGPR1, VGPR2, VGPR3, VGPR4, VGPR5, VGPR6, VGPR7,
VGPR8, VGPR9, VGPR10, VGPR11, VGPR12, VGPR13, VGPR14, VGPR15,
VGPR16, VGPR17, VGPR18, VGPR19, VGPR20, VGPR21, VGPR22, VGPR23,
VGPR24, VGPR25, VGPR26, VGPR27, VGPR28, VGPR29, VGPR30, VGPR31]>>,
CCIfType<[i32, f32, v2i16, v2f16, i16, f16, i1], CCAssignToStack<4, 4>>
CCIfType<[i32, f32, v2i16, v2f16, i16, f16, i1, bf16, v2bf16], CCAssignToStack<4, 4>>
]>;

// Calling convention for leaf functions
def RetCC_AMDGPU_Func : CallingConv<[
CCIfType<[i1], CCPromoteToType<i32>>,
CCIfType<[i1, i16], CCIfExtend<CCPromoteToType<i32>>>,
CCIfType<[i32, f32, i16, f16, v2i16, v2f16], CCAssignToReg<[
CCIfType<[i32, f32, i16, f16, v2i16, v2f16, bf16, v2bf16], CCAssignToReg<[
VGPR0, VGPR1, VGPR2, VGPR3, VGPR4, VGPR5, VGPR6, VGPR7,
VGPR8, VGPR9, VGPR10, VGPR11, VGPR12, VGPR13, VGPR14, VGPR15,
VGPR16, VGPR17, VGPR18, VGPR19, VGPR20, VGPR21, VGPR22, VGPR23,
@@ -223,11 +223,11 @@ def CC_AMDGPU : CallingConv<[
]>;

def CC_AMDGPU_CS_CHAIN : CallingConv<[
CCIfInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16] , CCAssignToReg<
CCIfInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16, bf16, v2bf16] , CCAssignToReg<
!foreach(i, !range(105), !cast<Register>("SGPR"#i))
>>>,

CCIfNotInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16] , CCAssignToReg<
CCIfNotInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16, bf16, v2bf16] , CCAssignToReg<
!foreach(i, !range(8, 255), !cast<Register>("VGPR"#i))
>>>
]>;
1 change: 1 addition & 0 deletions llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp
@@ -303,6 +303,7 @@ void AMDGPUDAGToDAGISel::PreprocessISelDAG() {

switch (N->getOpcode()) {
case ISD::BUILD_VECTOR:
// TODO: Match load d16 from shl (extload:i16), 16
MadeChange |= matchLoadD16FromBuildVector(N);
break;
default:
37 changes: 32 additions & 5 deletions llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
@@ -3282,7 +3282,15 @@ SDValue AMDGPUTargetLowering::LowerUINT_TO_FP(SDValue Op,
return DAG.getNode(ISD::UINT_TO_FP, DL, DestVT, Ext);
}

assert(SrcVT == MVT::i64 && "operation should be legal");
if (DestVT == MVT::bf16) {
SDLoc SL(Op);
SDValue ToF32 = DAG.getNode(ISD::UINT_TO_FP, SL, MVT::f32, Src);
SDValue FPRoundFlag = DAG.getIntPtrConstant(0, SL, /*isTarget=*/true);
return DAG.getNode(ISD::FP_ROUND, SL, MVT::bf16, ToF32, FPRoundFlag);
}

if (SrcVT != MVT::i64)
return Op;

if (Subtarget->has16BitInsts() && DestVT == MVT::f16) {
SDLoc DL(Op);
@@ -3320,7 +3328,15 @@ SDValue AMDGPUTargetLowering::LowerSINT_TO_FP(SDValue Op,
return DAG.getNode(ISD::SINT_TO_FP, DL, DestVT, Ext);
}

assert(SrcVT == MVT::i64 && "operation should be legal");
if (DestVT == MVT::bf16) {
SDLoc SL(Op);
SDValue ToF32 = DAG.getNode(ISD::SINT_TO_FP, SL, MVT::f32, Src);
SDValue FPRoundFlag = DAG.getIntPtrConstant(0, SL, /*isTarget=*/true);
return DAG.getNode(ISD::FP_ROUND, SL, MVT::bf16, ToF32, FPRoundFlag);
}

if (SrcVT != MVT::i64)
return Op;

// TODO: Factor out code common with LowerUINT_TO_FP.

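An aside, not from the PR: in scalar terms, the integer-to-bf16 paths added to LowerUINT_TO_FP and LowerSINT_TO_FP above convert to f32 first and then round to bf16. A hedged sketch under the assumption that the final FP_ROUND expands to a round-to-nearest-even truncation; the helper name is made up.

#include <cstdint>
#include <cstring>

// Sketch only: mirrors the UINT_TO_FP-to-f32 followed by FP_ROUND-to-bf16
// sequence; the real rounding depends on how the target expands FP_ROUND.
uint16_t UIntToBF16(uint32_t X) {
  float F = static_cast<float>(X);        // ISD::UINT_TO_FP to f32
  uint32_t Bits;
  std::memcpy(&Bits, &F, sizeof(Bits));
  Bits += 0x7FFFu + ((Bits >> 16) & 1u);  // ISD::FP_ROUND to bf16
  return static_cast<uint16_t>(Bits >> 16);
}
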
@@ -3518,7 +3534,7 @@ SDValue AMDGPUTargetLowering::LowerFP_TO_FP16(SDValue Op, SelectionDAG &DAG) con
return DAG.getZExtOrTrunc(V, DL, Op.getValueType());
}

SDValue AMDGPUTargetLowering::LowerFP_TO_INT(SDValue Op,
SDValue AMDGPUTargetLowering::LowerFP_TO_INT(const SDValue Op,
SelectionDAG &DAG) const {
SDValue Src = Op.getOperand(0);
unsigned OpOpcode = Op.getOpcode();
@@ -3529,6 +3545,12 @@ SDValue AMDGPUTargetLowering::LowerFP_TO_INT(SDValue Op,
if (SrcVT == MVT::f16 && DestVT == MVT::i16)
return Op;

if (SrcVT == MVT::bf16) {
SDLoc DL(Op);
SDValue PromotedSrc = DAG.getNode(ISD::FP_EXTEND, DL, MVT::f32, Src);
return DAG.getNode(Op.getOpcode(), DL, DestVT, PromotedSrc);
}

// Promote i16 to i32
if (DestVT == MVT::i16 && (SrcVT == MVT::f32 || SrcVT == MVT::f64)) {
SDLoc DL(Op);
@@ -3537,6 +3559,9 @@ SDValue AMDGPUTargetLowering::LowerFP_TO_INT(SDValue Op,
return DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, FpToInt32);
}

if (DestVT != MVT::i64)
return Op;

if (SrcVT == MVT::f16 ||
(SrcVT == MVT::f32 && Src.getOpcode() == ISD::FP16_TO_FP)) {
SDLoc DL(Op);
Expand All @@ -3547,7 +3572,7 @@ SDValue AMDGPUTargetLowering::LowerFP_TO_INT(SDValue Op,
return DAG.getNode(Ext, DL, MVT::i64, FpToInt32);
}

if (DestVT == MVT::i64 && (SrcVT == MVT::f32 || SrcVT == MVT::f64))
if (SrcVT == MVT::f32 || SrcVT == MVT::f64)
return LowerFP_TO_INT64(Op, DAG, OpOpcode == ISD::FP_TO_SINT);

return SDValue();
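
For illustration only (not part of the patch): the new bf16 handling in LowerFP_TO_INT promotes the source to f32 and then reuses the ordinary f32 conversion. A minimal scalar sketch with a hypothetical helper name; NaN and out-of-range inputs are not handled here.

#include <cstdint>
#include <cstring>

// Sketch only: FP_EXTEND the bf16 bits to f32 (16-bit shift), then do the
// normal f32 -> signed integer conversion, as in the promoted path above.
int32_t BF16ToSInt32(uint16_t Bits16) {
  uint32_t Bits32 = static_cast<uint32_t>(Bits16) << 16; // ISD::FP_EXTEND
  float F;
  std::memcpy(&F, &Bits32, sizeof(F));
  return static_cast<int32_t>(F);                        // ISD::FP_TO_SINT
}
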
@@ -4948,7 +4973,9 @@ SDValue AMDGPUTargetLowering::PerformDAGCombine(SDNode *N,
// vnt1 = build_vector (t1 (bitcast t0:x)), (t1 (bitcast t0:y))
if (DestVT.isVector()) {
SDValue Src = N->getOperand(0);
if (Src.getOpcode() == ISD::BUILD_VECTOR) {
if (Src.getOpcode() == ISD::BUILD_VECTOR &&
(DCI.getDAGCombineLevel() < AfterLegalizeDAG ||
isOperationLegal(ISD::BUILD_VECTOR, DestVT))) {
EVT SrcVT = Src.getValueType();
unsigned NElts = DestVT.getVectorNumElements();

Expand Down
Loading