-
Notifications
You must be signed in to change notification settings - Fork 14.3k
AMDGPU: Make bf16/v2bf16 legal types #76215
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-llvm-selectiondag @llvm/pr-subscribers-backend-amdgpu Author: Matt Arsenault (arsenm) ChangesThere are some intrinsics are using i16 vectors in place of bfloat vectors. Depends #76213 #76214 Patch is 1.11 MiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/76215.diff 28 Files Affected:
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
index a483b8028fda9e..296ed3a3c3dc11 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -3199,7 +3199,16 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
return true;
}
break;
- case ISD::FP_ROUND:
+ case ISD::FP_ROUND: {
+ EVT VT = Node->getValueType(0);
+ if (VT.getScalarType() == MVT::bf16) {
+ Results.push_back(
+ DAG.getNode(ISD::FP_TO_BF16, SDLoc(Node), VT, Node->getOperand(0)));
+ break;
+ }
+
+ LLVM_FALLTHROUGH;
+ }
case ISD::BITCAST:
if ((Tmp1 = EmitStackConvert(Node->getOperand(0), Node->getValueType(0),
Node->getValueType(0), dl)))
@@ -3226,12 +3235,19 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
return true;
}
break;
- case ISD::FP_EXTEND:
- if ((Tmp1 = EmitStackConvert(Node->getOperand(0),
- Node->getOperand(0).getValueType(),
- Node->getValueType(0), dl)))
+ case ISD::FP_EXTEND: {
+ SDValue Op = Node->getOperand(0);
+ EVT SrcVT = Op.getValueType();
+ EVT DstVT = Node->getValueType(0);
+ if (SrcVT.getScalarType() == MVT::bf16) {
+ Results.push_back(DAG.getNode(ISD::BF16_TO_FP, SDLoc(Node), DstVT, Op));
+ break;
+ }
+
+ if ((Tmp1 = EmitStackConvert(Op, SrcVT, DstVT, dl)))
Results.push_back(Tmp1);
break;
+ }
case ISD::BF16_TO_FP: {
// Always expand bf16 to f32 casts, they lower to ext + shift.
//
@@ -4908,7 +4924,9 @@ void SelectionDAGLegalize::ConvertNodeToLibcall(SDNode *Node) {
static MVT getPromotedVectorElementType(const TargetLowering &TLI,
MVT EltVT, MVT NewEltVT) {
unsigned OldEltsPerNewElt = EltVT.getSizeInBits() / NewEltVT.getSizeInBits();
- MVT MidVT = MVT::getVectorVT(NewEltVT, OldEltsPerNewElt);
+ MVT MidVT = OldEltsPerNewElt == 1
+ ? NewEltVT
+ : MVT::getVectorVT(NewEltVT, OldEltsPerNewElt);
assert(TLI.isTypeLegal(MidVT) && "unexpected");
return MidVT;
}
@@ -5395,7 +5413,7 @@ void SelectionDAGLegalize::PromoteNode(SDNode *Node) {
assert(NVT.isVector() && OVT.getSizeInBits() == NVT.getSizeInBits() &&
"Invalid promote type for build_vector");
- assert(NewEltVT.bitsLT(EltVT) && "not handled");
+ assert(NewEltVT.bitsLE(EltVT) && "not handled");
MVT MidVT = getPromotedVectorElementType(TLI, EltVT, NewEltVT);
@@ -5406,7 +5424,9 @@ void SelectionDAGLegalize::PromoteNode(SDNode *Node) {
}
SDLoc SL(Node);
- SDValue Concat = DAG.getNode(ISD::CONCAT_VECTORS, SL, NVT, NewOps);
+ SDValue Concat =
+ DAG.getNode(MidVT == NewEltVT ? ISD::BUILD_VECTOR : ISD::CONCAT_VECTORS,
+ SL, NVT, NewOps);
SDValue CvtVec = DAG.getNode(ISD::BITCAST, SL, OVT, Concat);
Results.push_back(CvtVec);
break;
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUCallingConv.td b/llvm/lib/Target/AMDGPU/AMDGPUCallingConv.td
index 9036b26a6f6bcb..c5207228dc913f 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUCallingConv.td
+++ b/llvm/lib/Target/AMDGPU/AMDGPUCallingConv.td
@@ -22,28 +22,28 @@ def CC_SI_Gfx : CallingConv<[
// 32 is reserved for the stack pointer
// 33 is reserved for the frame pointer
// 34 is reserved for the base pointer
- CCIfInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16] , CCAssignToReg<[
+ CCIfInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16, bf16, v2bf16] , CCAssignToReg<[
SGPR4, SGPR5, SGPR6, SGPR7,
SGPR8, SGPR9, SGPR10, SGPR11, SGPR12, SGPR13, SGPR14, SGPR15,
SGPR16, SGPR17, SGPR18, SGPR19, SGPR20, SGPR21, SGPR22, SGPR23,
SGPR24, SGPR25, SGPR26, SGPR27, SGPR28, SGPR29
]>>>,
- CCIfNotInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16] , CCAssignToReg<[
+ CCIfNotInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16, bf16, v2bf16] , CCAssignToReg<[
VGPR0, VGPR1, VGPR2, VGPR3, VGPR4, VGPR5, VGPR6, VGPR7,
VGPR8, VGPR9, VGPR10, VGPR11, VGPR12, VGPR13, VGPR14, VGPR15,
VGPR16, VGPR17, VGPR18, VGPR19, VGPR20, VGPR21, VGPR22, VGPR23,
VGPR24, VGPR25, VGPR26, VGPR27, VGPR28, VGPR29, VGPR30, VGPR31
]>>>,
- CCIfType<[i32, f32, v2i16, v2f16, i16, f16, i1], CCAssignToStack<4, 4>>
+ CCIfType<[i32, f32, v2i16, v2f16, i16, f16, i1, bf16, v2bf16], CCAssignToStack<4, 4>>
]>;
def RetCC_SI_Gfx : CallingConv<[
CCIfType<[i1], CCPromoteToType<i32>>,
CCIfType<[i1, i16], CCIfExtend<CCPromoteToType<i32>>>,
- CCIfNotInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16] , CCAssignToReg<[
+ CCIfNotInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16, bf16, v2bf16] , CCAssignToReg<[
VGPR0, VGPR1, VGPR2, VGPR3, VGPR4, VGPR5, VGPR6, VGPR7,
VGPR8, VGPR9, VGPR10, VGPR11, VGPR12, VGPR13, VGPR14, VGPR15,
VGPR16, VGPR17, VGPR18, VGPR19, VGPR20, VGPR21, VGPR22, VGPR23,
@@ -66,7 +66,7 @@ def RetCC_SI_Gfx : CallingConv<[
def CC_SI_SHADER : CallingConv<[
- CCIfInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16] , CCAssignToReg<[
+ CCIfInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16, bf16, v2bf16] , CCAssignToReg<[
SGPR0, SGPR1, SGPR2, SGPR3, SGPR4, SGPR5, SGPR6, SGPR7,
SGPR8, SGPR9, SGPR10, SGPR11, SGPR12, SGPR13, SGPR14, SGPR15,
SGPR16, SGPR17, SGPR18, SGPR19, SGPR20, SGPR21, SGPR22, SGPR23,
@@ -76,7 +76,7 @@ def CC_SI_SHADER : CallingConv<[
]>>>,
// 32*4 + 4 is the minimum for a fetch shader consumer with 32 inputs.
- CCIfNotInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16] , CCAssignToReg<[
+ CCIfNotInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16, bf16, v2bf16] , CCAssignToReg<[
VGPR0, VGPR1, VGPR2, VGPR3, VGPR4, VGPR5, VGPR6, VGPR7,
VGPR8, VGPR9, VGPR10, VGPR11, VGPR12, VGPR13, VGPR14, VGPR15,
VGPR16, VGPR17, VGPR18, VGPR19, VGPR20, VGPR21, VGPR22, VGPR23,
@@ -109,7 +109,7 @@ def RetCC_SI_Shader : CallingConv<[
]>>,
// 32*4 + 4 is the minimum for a fetch shader with 32 outputs.
- CCIfType<[f32, f16, v2f16] , CCAssignToReg<[
+ CCIfType<[f32, f16, v2f16, bf16, v2bf16] , CCAssignToReg<[
VGPR0, VGPR1, VGPR2, VGPR3, VGPR4, VGPR5, VGPR6, VGPR7,
VGPR8, VGPR9, VGPR10, VGPR11, VGPR12, VGPR13, VGPR14, VGPR15,
VGPR16, VGPR17, VGPR18, VGPR19, VGPR20, VGPR21, VGPR22, VGPR23,
@@ -188,23 +188,23 @@ def CC_AMDGPU_Func : CallingConv<[
CCIfType<[i1], CCPromoteToType<i32>>,
CCIfType<[i8, i16], CCIfExtend<CCPromoteToType<i32>>>,
- CCIfInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16] , CCAssignToReg<
+ CCIfInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16, bf16, v2bf16] , CCAssignToReg<
!foreach(i, !range(0, 30), !cast<Register>("SGPR"#i)) // SGPR0-29
>>>,
- CCIfType<[i32, f32, i16, f16, v2i16, v2f16, i1], CCAssignToReg<[
+ CCIfType<[i32, f32, i16, f16, v2i16, v2f16, i1, bf16, v2bf16], CCAssignToReg<[
VGPR0, VGPR1, VGPR2, VGPR3, VGPR4, VGPR5, VGPR6, VGPR7,
VGPR8, VGPR9, VGPR10, VGPR11, VGPR12, VGPR13, VGPR14, VGPR15,
VGPR16, VGPR17, VGPR18, VGPR19, VGPR20, VGPR21, VGPR22, VGPR23,
VGPR24, VGPR25, VGPR26, VGPR27, VGPR28, VGPR29, VGPR30, VGPR31]>>,
- CCIfType<[i32, f32, v2i16, v2f16, i16, f16, i1], CCAssignToStack<4, 4>>
+ CCIfType<[i32, f32, v2i16, v2f16, i16, f16, i1, bf16, v2bf16], CCAssignToStack<4, 4>>
]>;
// Calling convention for leaf functions
def RetCC_AMDGPU_Func : CallingConv<[
CCIfType<[i1], CCPromoteToType<i32>>,
CCIfType<[i1, i16], CCIfExtend<CCPromoteToType<i32>>>,
- CCIfType<[i32, f32, i16, f16, v2i16, v2f16], CCAssignToReg<[
+ CCIfType<[i32, f32, i16, f16, v2i16, v2f16, bf16, v2bf16], CCAssignToReg<[
VGPR0, VGPR1, VGPR2, VGPR3, VGPR4, VGPR5, VGPR6, VGPR7,
VGPR8, VGPR9, VGPR10, VGPR11, VGPR12, VGPR13, VGPR14, VGPR15,
VGPR16, VGPR17, VGPR18, VGPR19, VGPR20, VGPR21, VGPR22, VGPR23,
@@ -223,11 +223,11 @@ def CC_AMDGPU : CallingConv<[
]>;
def CC_AMDGPU_CS_CHAIN : CallingConv<[
- CCIfInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16] , CCAssignToReg<
+ CCIfInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16, bf16, v2bf16] , CCAssignToReg<
!foreach(i, !range(105), !cast<Register>("SGPR"#i))
>>>,
- CCIfNotInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16] , CCAssignToReg<
+ CCIfNotInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16, bf16, v2bf16] , CCAssignToReg<
!foreach(i, !range(8, 255), !cast<Register>("VGPR"#i))
>>>
]>;
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp b/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp
index b0eac567ec9f18..40a49cbe3f518f 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp
@@ -303,6 +303,7 @@ void AMDGPUDAGToDAGISel::PreprocessISelDAG() {
switch (N->getOpcode()) {
case ISD::BUILD_VECTOR:
+ // TODO: Match load d16 from shl (extload:i16), 16
MadeChange |= matchLoadD16FromBuildVector(N);
break;
default:
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
index 4bf4707553e5fe..131000830a73d5 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
@@ -3274,7 +3274,15 @@ SDValue AMDGPUTargetLowering::LowerUINT_TO_FP(SDValue Op,
return DAG.getNode(ISD::UINT_TO_FP, DL, DestVT, Ext);
}
- assert(SrcVT == MVT::i64 && "operation should be legal");
+ if (DestVT == MVT::bf16) {
+ SDLoc SL(Op);
+ SDValue ToF32 = DAG.getNode(ISD::UINT_TO_FP, SL, MVT::f32, Src);
+ SDValue FPRoundFlag = DAG.getIntPtrConstant(0, SL, /*isTarget=*/true);
+ return DAG.getNode(ISD::FP_ROUND, SL, MVT::bf16, ToF32, FPRoundFlag);
+ }
+
+ if (SrcVT != MVT::i64)
+ return Op;
if (Subtarget->has16BitInsts() && DestVT == MVT::f16) {
SDLoc DL(Op);
@@ -3312,7 +3320,15 @@ SDValue AMDGPUTargetLowering::LowerSINT_TO_FP(SDValue Op,
return DAG.getNode(ISD::SINT_TO_FP, DL, DestVT, Ext);
}
- assert(SrcVT == MVT::i64 && "operation should be legal");
+ if (DestVT == MVT::bf16) {
+ SDLoc SL(Op);
+ SDValue ToF32 = DAG.getNode(ISD::SINT_TO_FP, SL, MVT::f32, Src);
+ SDValue FPRoundFlag = DAG.getIntPtrConstant(0, SL, /*isTarget=*/true);
+ return DAG.getNode(ISD::FP_ROUND, SL, MVT::bf16, ToF32, FPRoundFlag);
+ }
+
+ if (SrcVT != MVT::i64)
+ return Op;
// TODO: Factor out code common with LowerUINT_TO_FP.
@@ -3510,7 +3526,7 @@ SDValue AMDGPUTargetLowering::LowerFP_TO_FP16(SDValue Op, SelectionDAG &DAG) con
return DAG.getZExtOrTrunc(V, DL, Op.getValueType());
}
-SDValue AMDGPUTargetLowering::LowerFP_TO_INT(SDValue Op,
+SDValue AMDGPUTargetLowering::LowerFP_TO_INT(const SDValue Op,
SelectionDAG &DAG) const {
SDValue Src = Op.getOperand(0);
unsigned OpOpcode = Op.getOpcode();
@@ -3521,6 +3537,12 @@ SDValue AMDGPUTargetLowering::LowerFP_TO_INT(SDValue Op,
if (SrcVT == MVT::f16 && DestVT == MVT::i16)
return Op;
+ if (SrcVT == MVT::bf16) {
+ SDLoc DL(Op);
+ SDValue PromotedSrc = DAG.getNode(ISD::FP_EXTEND, DL, MVT::f32, Src);
+ return DAG.getNode(Op.getOpcode(), DL, DestVT, PromotedSrc);
+ }
+
// Promote i16 to i32
if (DestVT == MVT::i16 && (SrcVT == MVT::f32 || SrcVT == MVT::f64)) {
SDLoc DL(Op);
@@ -3529,6 +3551,9 @@ SDValue AMDGPUTargetLowering::LowerFP_TO_INT(SDValue Op,
return DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, FpToInt32);
}
+ if (DestVT != MVT::i64)
+ return Op;
+
if (SrcVT == MVT::f16 ||
(SrcVT == MVT::f32 && Src.getOpcode() == ISD::FP16_TO_FP)) {
SDLoc DL(Op);
@@ -3539,7 +3564,7 @@ SDValue AMDGPUTargetLowering::LowerFP_TO_INT(SDValue Op,
return DAG.getNode(Ext, DL, MVT::i64, FpToInt32);
}
- if (DestVT == MVT::i64 && (SrcVT == MVT::f32 || SrcVT == MVT::f64))
+ if (SrcVT == MVT::f32 || SrcVT == MVT::f64)
return LowerFP_TO_INT64(Op, DAG, OpOpcode == ISD::FP_TO_SINT);
return SDValue();
@@ -4940,7 +4965,9 @@ SDValue AMDGPUTargetLowering::PerformDAGCombine(SDNode *N,
// vnt1 = build_vector (t1 (bitcast t0:x)), (t1 (bitcast t0:y))
if (DestVT.isVector()) {
SDValue Src = N->getOperand(0);
- if (Src.getOpcode() == ISD::BUILD_VECTOR) {
+ if (Src.getOpcode() == ISD::BUILD_VECTOR &&
+ (DCI.getDAGCombineLevel() < AfterLegalizeDAG ||
+ isOperationLegal(ISD::BUILD_VECTOR, DestVT))) {
EVT SrcVT = Src.getValueType();
unsigned NElts = DestVT.getVectorNumElements();
diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
index fc119aa61d01a2..eaf32850a87149 100644
--- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
@@ -151,14 +151,17 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
if (Subtarget->useRealTrue16Insts()) {
addRegisterClass(MVT::i16, &AMDGPU::VGPR_16RegClass);
addRegisterClass(MVT::f16, &AMDGPU::VGPR_16RegClass);
+ addRegisterClass(MVT::bf16, &AMDGPU::VGPR_16RegClass);
} else {
addRegisterClass(MVT::i16, &AMDGPU::SReg_32RegClass);
addRegisterClass(MVT::f16, &AMDGPU::SReg_32RegClass);
+ addRegisterClass(MVT::bf16, &AMDGPU::SReg_32RegClass);
}
// Unless there are also VOP3P operations, not operations are really legal.
addRegisterClass(MVT::v2i16, &AMDGPU::SReg_32RegClass);
addRegisterClass(MVT::v2f16, &AMDGPU::SReg_32RegClass);
+ addRegisterClass(MVT::v2bf16, &AMDGPU::SReg_32RegClass);
addRegisterClass(MVT::v4i16, &AMDGPU::SReg_64RegClass);
addRegisterClass(MVT::v4f16, &AMDGPU::SReg_64RegClass);
addRegisterClass(MVT::v8i16, &AMDGPU::SGPR_128RegClass);
@@ -196,6 +199,41 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
MVT::i1, MVT::v32i32},
Custom);
+ if (isTypeLegal(MVT::bf16)) {
+ for (unsigned Opc :
+ {ISD::FADD, ISD::FSUB, ISD::FMUL, ISD::FDIV,
+ ISD::FREM, ISD::FMA, ISD::FMINNUM, ISD::FMAXNUM,
+ ISD::FMINIMUM, ISD::FMAXIMUM, ISD::FSQRT, ISD::FCBRT,
+ ISD::FSIN, ISD::FCOS, ISD::FPOW, ISD::FPOWI,
+ ISD::FLDEXP, ISD::FFREXP, ISD::FLOG, ISD::FLOG2,
+ ISD::FLOG10, ISD::FEXP, ISD::FEXP2, ISD::FEXP10,
+ ISD::FCEIL, ISD::FTRUNC, ISD::FRINT, ISD::FNEARBYINT,
+ ISD::FROUND, ISD::FROUNDEVEN, ISD::FFLOOR, ISD::FCANONICALIZE,
+ ISD::SETCC}) {
+ // FIXME: The promoted to type shouldn't need to be explicit
+ setOperationAction(Opc, MVT::bf16, Promote);
+ AddPromotedToType(Opc, MVT::bf16, MVT::f32);
+ }
+
+ setOperationAction(ISD::FP_ROUND, MVT::bf16, Expand);
+
+ setOperationAction(ISD::SELECT, MVT::bf16, Promote);
+ AddPromotedToType(ISD::SELECT, MVT::bf16, MVT::i16);
+
+ // TODO: Could make these legal
+ setOperationAction(ISD::FABS, MVT::bf16, Expand);
+ setOperationAction(ISD::FNEG, MVT::bf16, Expand);
+ setOperationAction(ISD::FCOPYSIGN, MVT::bf16, Expand);
+
+ // We only need to custom lower because we can't specify an action for bf16
+ // sources.
+ setOperationAction(ISD::FP_TO_SINT, MVT::i32, Custom);
+ setOperationAction(ISD::FP_TO_UINT, MVT::i32, Custom);
+
+ setOperationAction(ISD::BUILD_VECTOR, MVT::v2bf16, Promote);
+ AddPromotedToType(ISD::BUILD_VECTOR, MVT::v2bf16, MVT::v2i16);
+ }
+
setTruncStoreAction(MVT::v2i32, MVT::v2i16, Expand);
setTruncStoreAction(MVT::v3i32, MVT::v3i16, Expand);
setTruncStoreAction(MVT::v4i32, MVT::v4i16, Expand);
@@ -388,8 +426,8 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
// Avoid stack access for these.
// TODO: Generalize to more vector types.
setOperationAction({ISD::EXTRACT_VECTOR_ELT, ISD::INSERT_VECTOR_ELT},
- {MVT::v2i16, MVT::v2f16, MVT::v2i8, MVT::v4i8, MVT::v8i8,
- MVT::v4i16, MVT::v4f16},
+ {MVT::v2i16, MVT::v2f16, MVT::v2bf16, MVT::v2i8, MVT::v4i8,
+ MVT::v8i8, MVT::v4i16, MVT::v4f16},
Custom);
// Deal with vec3 vector operations when widened to vec4.
@@ -498,6 +536,11 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
setOperationAction(ISD::BF16_TO_FP, {MVT::i16, MVT::f32, MVT::f64}, Expand);
setOperationAction(ISD::FP_TO_BF16, {MVT::i16, MVT::f32, MVT::f64}, Expand);
+ // Custom lower these because we can't specify a rule based on an illegal
+ // source bf16.
+ setOperationAction({ISD::FP_EXTEND, ISD::STRICT_FP_EXTEND}, MVT::f32, Custom);
+ setOperationAction({ISD::FP_EXTEND, ISD::STRICT_FP_EXTEND}, MVT::f64, Custom);
+
if (Subtarget->has16BitInsts()) {
setOperationAction({ISD::Constant, ISD::SMIN, ISD::SMAX, ISD::UMIN,
ISD::UMAX, ISD::UADDSAT, ISD::USUBSAT},
@@ -524,9 +567,14 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
AddPromotedToType(ISD::FP_TO_FP16, MVT::i16, MVT::i32);
setOperationAction({ISD::FP_TO_SINT, ISD::FP_TO_UINT}, MVT::i16, Custom);
+ setOperationAction({ISD::SINT_TO_FP, ISD::UINT_TO_FP}, MVT::i16, Custom);
+ setOperationAction({ISD::SINT_TO_FP, ISD::UINT_TO_FP}, MVT::i16, Custom);
+
+ setOperationAction({ISD::SINT_TO_FP, ISD::UINT_TO_FP}, MVT::i32, Custom);
// F16 - Constant Actions.
setOperationAction(ISD::ConstantFP, MVT::f16, Legal);
+ setOperationAction(ISD::ConstantFP, MVT::bf16, Legal);
// F16 - Load/Store Actions.
setOperationAction(ISD::LOAD, MVT::f16, Promote);
@@ -534,16 +582,23 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
setOperationAction(ISD::STORE, MVT::f16, Promote);
AddPromotedToType(ISD::STORE, MVT::f16, MVT::i16);
+ // BF16 - Load/Store Actions.
+ setOperationAction(ISD::LOAD, MVT::bf16, Promote);
+ AddPromotedToType(ISD::LOAD, MVT::bf16, MVT::i16);
+ setOperationAction(ISD::STORE, MVT::bf16, Promote);
+ AddPromotedToType(ISD::STORE, MVT::bf16, MVT::i16);
+
// F16 - VOP1 Actions.
setOperationAction({ISD::FP_ROUND, ISD::STRICT_FP_ROUND, ISD::FCOS,
ISD::FSIN, ISD::FROUND, ISD::FPTRUNC_ROUND},
MVT::f16, Custom);
- setOperationAction({ISD::SINT_TO_FP, ISD::UINT_TO_FP}, MVT::i16, Custom);
setOperationAction({ISD::FP_TO_SINT, ISD::FP_TO_UINT}, MVT::f16, Promote);
+ setOperationAction({ISD::FP_TO_SINT, ISD::FP_TO_UINT}, MVT::bf16, Promote);
// F16 - VOP2 Actions.
- setOperationAction({ISD::BR_CC, ISD::SELECT_CC}, MVT::f16, Expand);
+ setOperationAction({ISD::BR_CC, ISD::SELECT_CC}, {MVT::f16, MVT::bf16},
+ Expand);
setOperationAction({ISD::FLDEXP, ISD::STRICT_FLDEXP}, MVT::f16, Custom);
setOperationAction(ISD::FFREXP, MVT::f16, Custom);
setOperationAction(ISD::FDIV, MVT::f16, Custom);
@@ -554,8 +609,9 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
setOperationAction(ISD::FMAD, MVT::f16, Legal);
for (MVT VT :
- {MVT::v2i16, MVT::v2f16, MVT::v4i16, MVT::v4f16, MVT::v8i16,
- MVT::v8f16, MVT::v16i16, MVT::v16f16, MVT::v32i16, MVT::v32f16}) {
+ {MVT::v2i16, MVT::v2f16, MVT::v2bf16, MVT::v4i16, MVT::v4f16,
+ MVT::v4bf16, MVT::v8i16, MVT::v8f16, MVT::v8bf16, MVT::v16i16,
+ MVT::v16f16, MVT::v16bf16, MVT::v32i16, MVT::v32f16}) {
for (unsigned Op = 0; Op < ISD::BUILTIN_OP_END; ++Op) {
switch (Op) {
case ISD::LOAD:
@@ -587,7 +643,8 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
// XXX - Do these do anything? Vector constants turn into build_vector.
setOperationAction(ISD::Constant, {MVT::v2i16, MVT::v2f16}, Legal);
- setOperationAction(ISD::UNDEF, {MVT::v2i16, MVT::v2f16}, Legal);
+ setOperationAction(ISD::UNDEF, {MVT::v2i16, MVT::v2f16, MVT::v2bf16},
+ Legal);
setOperationAction(ISD::STORE, MVT::v2i16, Promote);
AddPromotedToType(ISD::STORE, MVT::v2i16, MVT::i32);
@@ -3901,6 +3958,26 @@ SDValue SITargetLowering::lowerPREFETCH(SDValue Op, SelectionDAG &DAG) const {
return Op;
}
+// Work around DAG legality rules only based on the result type.
+SDValue SITargetLowering::lowerFP_EXTEND(SDValue Op, SelectionDAG &DAG) const {
+ bool IsStrict = Op.getOpcode() == ISD::STRICT_FP_EXTEND;
+ SDValue Src = Op.getOperand(IsStrict ? 1 : 0);
+ EVT S...
[truncated]
|
f533b99
to
676ef60
Compare
You can test this locally with the following command:git-clang-format --diff 25cd249355b0f3192ca5b0c69514ad68a1cb8897 da874220064ba6ba8fd02a50aaccd67ca726b23e -- llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp llvm/lib/Target/AMDGPU/SIISelLowering.cpp llvm/lib/Target/AMDGPU/SIISelLowering.h View the diff from clang-format here.diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
index 91f347be68..d917e84739 100644
--- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
@@ -756,8 +756,8 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
ISD::FMAXNUM_IEEE, ISD::FCANONICALIZE},
MVT::v2f16, Legal);
- setOperationAction(ISD::EXTRACT_VECTOR_ELT, {MVT::v2i16, MVT::v2f16, MVT::v2bf16},
- Custom);
+ setOperationAction(ISD::EXTRACT_VECTOR_ELT,
+ {MVT::v2i16, MVT::v2f16, MVT::v2bf16}, Custom);
setOperationAction(ISD::VECTOR_SHUFFLE,
{MVT::v4f16, MVT::v4i16, MVT::v8f16, MVT::v8i16,
|
Assorted intrinsics are currently using i16 in place of a proper bfloat type, but they should really switch to bfloat. Note this only changes the type lists in tablegen, these are still not registered to be truly treated as a legal type yet.
There are some intrinsics are using i16 vectors in place of bfloat vectors. Move towards making bf16 vectors legal so these can migrate. Leave the larger vectors for a later change.
676ef60
to
da87422
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Test changes are a bit hard to review so I didn't look at all of them
Do we need a GISel equivalent of this patch?
@@ -3901,6 +3958,26 @@ SDValue SITargetLowering::lowerPREFETCH(SDValue Op, SelectionDAG &DAG) const { | |||
return Op; | |||
} | |||
|
|||
// Work around DAG legality rules only based on the result type. | |||
SDValue SITargetLowering::lowerFP_EXTEND(SDValue Op, SelectionDAG &DAG) const { | |||
bool IsStrict = Op.getOpcode() == ISD::STRICT_FP_EXTEND; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
tiny tiny nit: put parentheses around the initializer - =
and ==
on each side of an expression is just a bit confusing at first glance
@@ -1021,13 +1021,19 @@ define half @fmed3_f32_fpext_bf16(bfloat %arg0, bfloat %arg1, bfloat %arg2) #1 { | |||
; GFX8-LABEL: fmed3_f32_fpext_bf16: | |||
; GFX8: ; %bb.0: | |||
; GFX8-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) | |||
; GFX8-NEXT: v_lshlrev_b32_e32 v0, 16, v0 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do we have these now?
No, legal types are more of a DAG concept |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, although I do not believe we can finally make bfloat16 legal in an FE now, we will still have a lot of emulation for many operations. I.e. if the final plan is to switch from i16 to bf16 in intrinsics and builtins, that is hardly possible. At best we will need to start with overloaded interfaces supporting both.
That's what the comment says but everything I tried works by promotion to float. I expanded the test to cover just about everything |
There are some intrinsics are using i16 vectors in place of bfloat vectors.
Move towards making bf16 vectors legal so these can migrate. Leave the
larger vectors for a later change.
Depends #76213 #76214