Skip to content

AMDGPU: Make bf16/v2bf16 legal types #76215

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 4, 2024
Merged

Conversation

arsenm
Copy link
Contributor

@arsenm arsenm commented Dec 22, 2023

There are some intrinsics are using i16 vectors in place of bfloat vectors.
Move towards making bf16 vectors legal so these can migrate. Leave the
larger vectors for a later change.

Depends #76213 #76214

@llvmbot llvmbot added the llvm:SelectionDAG SelectionDAGISel as well label Dec 22, 2023
@llvmbot
Copy link
Member

llvmbot commented Dec 22, 2023

@llvm/pr-subscribers-llvm-selectiondag

@llvm/pr-subscribers-backend-amdgpu

Author: Matt Arsenault (arsenm)

Changes

There are some intrinsics are using i16 vectors in place of bfloat vectors.
Move towards making bf16 vectors legal so these can migrate. Leave the
larger vectors for a later change.

Depends #76213 #76214


Patch is 1.11 MiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/76215.diff

28 Files Affected:

  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp (+28-8)
  • (modified) llvm/lib/Target/AMDGPU/AMDGPUCallingConv.td (+13-13)
  • (modified) llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp (+1)
  • (modified) llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp (+32-5)
  • (modified) llvm/lib/Target/AMDGPU/SIISelLowering.cpp (+88-8)
  • (modified) llvm/lib/Target/AMDGPU/SIISelLowering.h (+1)
  • (modified) llvm/lib/Target/AMDGPU/SIInstructions.td (+46-6)
  • (modified) llvm/lib/Target/AMDGPU/SIRegisterInfo.td (+31-31)
  • (modified) llvm/lib/Target/AMDGPU/VOP3PInstructions.td (+11-11)
  • (modified) llvm/test/CodeGen/AMDGPU/bf16.ll (+6030-7324)
  • (modified) llvm/test/CodeGen/AMDGPU/fcopysign.f32.ll (+5-7)
  • (modified) llvm/test/CodeGen/AMDGPU/fmed3-cast-combine.ll (+14-2)
  • (modified) llvm/test/CodeGen/AMDGPU/fneg-modifier-casting.ll (+5-5)
  • (modified) llvm/test/CodeGen/AMDGPU/function-args-inreg.ll (+2-2)
  • (modified) llvm/test/CodeGen/AMDGPU/function-args.ll (+78-89)
  • (modified) llvm/test/CodeGen/AMDGPU/function-returns.ll (+19-46)
  • (modified) llvm/test/CodeGen/AMDGPU/gfx-callable-argument-types.ll (+48-80)
  • (modified) llvm/test/CodeGen/AMDGPU/isel-amdgpu-cs-chain-preserve-cc.ll (+6-10)
  • (modified) llvm/test/CodeGen/AMDGPU/llvm.exp.ll (+13-12)
  • (modified) llvm/test/CodeGen/AMDGPU/llvm.exp10.ll (+13-12)
  • (modified) llvm/test/CodeGen/AMDGPU/llvm.exp2.ll (+43-15)
  • (modified) llvm/test/CodeGen/AMDGPU/llvm.is.fpclass.bf16.ll (+113-289)
  • (modified) llvm/test/CodeGen/AMDGPU/llvm.log.ll (+7-4)
  • (modified) llvm/test/CodeGen/AMDGPU/llvm.log10.ll (+7-4)
  • (modified) llvm/test/CodeGen/AMDGPU/llvm.log2.ll (+46-16)
  • (modified) llvm/test/CodeGen/AMDGPU/local-atomics-fp.ll (+2-2)
  • (modified) llvm/test/CodeGen/AMDGPU/select-undef.ll (+2-1)
  • (modified) llvm/test/CodeGen/AMDGPU/vector_shuffle.packed.ll (+258-592)
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
index a483b8028fda9e..296ed3a3c3dc11 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -3199,7 +3199,16 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
       return true;
     }
     break;
-  case ISD::FP_ROUND:
+  case ISD::FP_ROUND: {
+    EVT VT = Node->getValueType(0);
+    if (VT.getScalarType() == MVT::bf16) {
+      Results.push_back(
+          DAG.getNode(ISD::FP_TO_BF16, SDLoc(Node), VT, Node->getOperand(0)));
+      break;
+    }
+
+    LLVM_FALLTHROUGH;
+  }
   case ISD::BITCAST:
     if ((Tmp1 = EmitStackConvert(Node->getOperand(0), Node->getValueType(0),
                                  Node->getValueType(0), dl)))
@@ -3226,12 +3235,19 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
       return true;
     }
     break;
-  case ISD::FP_EXTEND:
-    if ((Tmp1 = EmitStackConvert(Node->getOperand(0),
-                                 Node->getOperand(0).getValueType(),
-                                 Node->getValueType(0), dl)))
+  case ISD::FP_EXTEND: {
+    SDValue Op = Node->getOperand(0);
+    EVT SrcVT = Op.getValueType();
+    EVT DstVT = Node->getValueType(0);
+    if (SrcVT.getScalarType() == MVT::bf16) {
+      Results.push_back(DAG.getNode(ISD::BF16_TO_FP, SDLoc(Node), DstVT, Op));
+      break;
+    }
+
+    if ((Tmp1 = EmitStackConvert(Op, SrcVT, DstVT, dl)))
       Results.push_back(Tmp1);
     break;
+  }
   case ISD::BF16_TO_FP: {
     // Always expand bf16 to f32 casts, they lower to ext + shift.
     //
@@ -4908,7 +4924,9 @@ void SelectionDAGLegalize::ConvertNodeToLibcall(SDNode *Node) {
 static MVT getPromotedVectorElementType(const TargetLowering &TLI,
                                         MVT EltVT, MVT NewEltVT) {
   unsigned OldEltsPerNewElt = EltVT.getSizeInBits() / NewEltVT.getSizeInBits();
-  MVT MidVT = MVT::getVectorVT(NewEltVT, OldEltsPerNewElt);
+  MVT MidVT = OldEltsPerNewElt == 1
+                  ? NewEltVT
+                  : MVT::getVectorVT(NewEltVT, OldEltsPerNewElt);
   assert(TLI.isTypeLegal(MidVT) && "unexpected");
   return MidVT;
 }
@@ -5395,7 +5413,7 @@ void SelectionDAGLegalize::PromoteNode(SDNode *Node) {
 
     assert(NVT.isVector() && OVT.getSizeInBits() == NVT.getSizeInBits() &&
            "Invalid promote type for build_vector");
-    assert(NewEltVT.bitsLT(EltVT) && "not handled");
+    assert(NewEltVT.bitsLE(EltVT) && "not handled");
 
     MVT MidVT = getPromotedVectorElementType(TLI, EltVT, NewEltVT);
 
@@ -5406,7 +5424,9 @@ void SelectionDAGLegalize::PromoteNode(SDNode *Node) {
     }
 
     SDLoc SL(Node);
-    SDValue Concat = DAG.getNode(ISD::CONCAT_VECTORS, SL, NVT, NewOps);
+    SDValue Concat =
+        DAG.getNode(MidVT == NewEltVT ? ISD::BUILD_VECTOR : ISD::CONCAT_VECTORS,
+                    SL, NVT, NewOps);
     SDValue CvtVec = DAG.getNode(ISD::BITCAST, SL, OVT, Concat);
     Results.push_back(CvtVec);
     break;
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUCallingConv.td b/llvm/lib/Target/AMDGPU/AMDGPUCallingConv.td
index 9036b26a6f6bcb..c5207228dc913f 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUCallingConv.td
+++ b/llvm/lib/Target/AMDGPU/AMDGPUCallingConv.td
@@ -22,28 +22,28 @@ def CC_SI_Gfx : CallingConv<[
   // 32 is reserved for the stack pointer
   // 33 is reserved for the frame pointer
   // 34 is reserved for the base pointer
-  CCIfInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16] , CCAssignToReg<[
+  CCIfInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16, bf16, v2bf16] , CCAssignToReg<[
     SGPR4, SGPR5, SGPR6, SGPR7,
     SGPR8, SGPR9, SGPR10, SGPR11, SGPR12, SGPR13, SGPR14, SGPR15,
     SGPR16, SGPR17, SGPR18, SGPR19, SGPR20, SGPR21, SGPR22, SGPR23,
     SGPR24, SGPR25, SGPR26, SGPR27, SGPR28, SGPR29
   ]>>>,
 
-  CCIfNotInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16] , CCAssignToReg<[
+  CCIfNotInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16, bf16, v2bf16] , CCAssignToReg<[
     VGPR0, VGPR1, VGPR2, VGPR3, VGPR4, VGPR5, VGPR6, VGPR7,
     VGPR8, VGPR9, VGPR10, VGPR11, VGPR12, VGPR13, VGPR14, VGPR15,
     VGPR16, VGPR17, VGPR18, VGPR19, VGPR20, VGPR21, VGPR22, VGPR23,
     VGPR24, VGPR25, VGPR26, VGPR27, VGPR28, VGPR29, VGPR30, VGPR31
   ]>>>,
 
-  CCIfType<[i32, f32, v2i16, v2f16, i16, f16, i1], CCAssignToStack<4, 4>>
+  CCIfType<[i32, f32, v2i16, v2f16, i16, f16, i1, bf16, v2bf16], CCAssignToStack<4, 4>>
 ]>;
 
 def RetCC_SI_Gfx : CallingConv<[
   CCIfType<[i1], CCPromoteToType<i32>>,
   CCIfType<[i1, i16], CCIfExtend<CCPromoteToType<i32>>>,
 
-  CCIfNotInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16] , CCAssignToReg<[
+  CCIfNotInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16, bf16, v2bf16] , CCAssignToReg<[
     VGPR0, VGPR1, VGPR2, VGPR3, VGPR4, VGPR5, VGPR6, VGPR7,
     VGPR8, VGPR9, VGPR10, VGPR11, VGPR12, VGPR13, VGPR14, VGPR15,
     VGPR16, VGPR17, VGPR18, VGPR19, VGPR20, VGPR21, VGPR22, VGPR23,
@@ -66,7 +66,7 @@ def RetCC_SI_Gfx : CallingConv<[
 
 def CC_SI_SHADER : CallingConv<[
 
-  CCIfInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16] , CCAssignToReg<[
+  CCIfInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16, bf16, v2bf16] , CCAssignToReg<[
     SGPR0, SGPR1, SGPR2, SGPR3, SGPR4, SGPR5, SGPR6, SGPR7,
     SGPR8, SGPR9, SGPR10, SGPR11, SGPR12, SGPR13, SGPR14, SGPR15,
     SGPR16, SGPR17, SGPR18, SGPR19, SGPR20, SGPR21, SGPR22, SGPR23,
@@ -76,7 +76,7 @@ def CC_SI_SHADER : CallingConv<[
   ]>>>,
 
   // 32*4 + 4 is the minimum for a fetch shader consumer with 32 inputs.
-  CCIfNotInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16] , CCAssignToReg<[
+  CCIfNotInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16, bf16, v2bf16] , CCAssignToReg<[
     VGPR0, VGPR1, VGPR2, VGPR3, VGPR4, VGPR5, VGPR6, VGPR7,
     VGPR8, VGPR9, VGPR10, VGPR11, VGPR12, VGPR13, VGPR14, VGPR15,
     VGPR16, VGPR17, VGPR18, VGPR19, VGPR20, VGPR21, VGPR22, VGPR23,
@@ -109,7 +109,7 @@ def RetCC_SI_Shader : CallingConv<[
   ]>>,
 
   // 32*4 + 4 is the minimum for a fetch shader with 32 outputs.
-  CCIfType<[f32, f16, v2f16] , CCAssignToReg<[
+  CCIfType<[f32, f16, v2f16, bf16, v2bf16] , CCAssignToReg<[
     VGPR0, VGPR1, VGPR2, VGPR3, VGPR4, VGPR5, VGPR6, VGPR7,
     VGPR8, VGPR9, VGPR10, VGPR11, VGPR12, VGPR13, VGPR14, VGPR15,
     VGPR16, VGPR17, VGPR18, VGPR19, VGPR20, VGPR21, VGPR22, VGPR23,
@@ -188,23 +188,23 @@ def CC_AMDGPU_Func : CallingConv<[
   CCIfType<[i1], CCPromoteToType<i32>>,
   CCIfType<[i8, i16], CCIfExtend<CCPromoteToType<i32>>>,
 
-  CCIfInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16] , CCAssignToReg<
+  CCIfInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16, bf16, v2bf16] , CCAssignToReg<
     !foreach(i, !range(0, 30), !cast<Register>("SGPR"#i))  // SGPR0-29
   >>>,
 
-  CCIfType<[i32, f32, i16, f16, v2i16, v2f16, i1], CCAssignToReg<[
+  CCIfType<[i32, f32, i16, f16, v2i16, v2f16, i1, bf16, v2bf16], CCAssignToReg<[
     VGPR0, VGPR1, VGPR2, VGPR3, VGPR4, VGPR5, VGPR6, VGPR7,
     VGPR8, VGPR9, VGPR10, VGPR11, VGPR12, VGPR13, VGPR14, VGPR15,
     VGPR16, VGPR17, VGPR18, VGPR19, VGPR20, VGPR21, VGPR22, VGPR23,
     VGPR24, VGPR25, VGPR26, VGPR27, VGPR28, VGPR29, VGPR30, VGPR31]>>,
-  CCIfType<[i32, f32, v2i16, v2f16, i16, f16, i1], CCAssignToStack<4, 4>>
+  CCIfType<[i32, f32, v2i16, v2f16, i16, f16, i1, bf16, v2bf16], CCAssignToStack<4, 4>>
 ]>;
 
 // Calling convention for leaf functions
 def RetCC_AMDGPU_Func : CallingConv<[
   CCIfType<[i1], CCPromoteToType<i32>>,
   CCIfType<[i1, i16], CCIfExtend<CCPromoteToType<i32>>>,
-  CCIfType<[i32, f32, i16, f16, v2i16, v2f16], CCAssignToReg<[
+  CCIfType<[i32, f32, i16, f16, v2i16, v2f16, bf16, v2bf16], CCAssignToReg<[
     VGPR0, VGPR1, VGPR2, VGPR3, VGPR4, VGPR5, VGPR6, VGPR7,
     VGPR8, VGPR9, VGPR10, VGPR11, VGPR12, VGPR13, VGPR14, VGPR15,
     VGPR16, VGPR17, VGPR18, VGPR19, VGPR20, VGPR21, VGPR22, VGPR23,
@@ -223,11 +223,11 @@ def CC_AMDGPU : CallingConv<[
 ]>;
 
 def CC_AMDGPU_CS_CHAIN : CallingConv<[
-  CCIfInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16] , CCAssignToReg<
+  CCIfInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16, bf16, v2bf16] , CCAssignToReg<
     !foreach(i, !range(105), !cast<Register>("SGPR"#i))
   >>>,
 
-  CCIfNotInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16] , CCAssignToReg<
+  CCIfNotInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16, bf16, v2bf16] , CCAssignToReg<
     !foreach(i, !range(8, 255), !cast<Register>("VGPR"#i))
   >>>
 ]>;
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp b/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp
index b0eac567ec9f18..40a49cbe3f518f 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp
@@ -303,6 +303,7 @@ void AMDGPUDAGToDAGISel::PreprocessISelDAG() {
 
     switch (N->getOpcode()) {
     case ISD::BUILD_VECTOR:
+      // TODO: Match load d16 from shl (extload:i16), 16
       MadeChange |= matchLoadD16FromBuildVector(N);
       break;
     default:
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
index 4bf4707553e5fe..131000830a73d5 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
@@ -3274,7 +3274,15 @@ SDValue AMDGPUTargetLowering::LowerUINT_TO_FP(SDValue Op,
     return DAG.getNode(ISD::UINT_TO_FP, DL, DestVT, Ext);
   }
 
-  assert(SrcVT == MVT::i64 && "operation should be legal");
+  if (DestVT == MVT::bf16) {
+    SDLoc SL(Op);
+    SDValue ToF32 = DAG.getNode(ISD::UINT_TO_FP, SL, MVT::f32, Src);
+    SDValue FPRoundFlag = DAG.getIntPtrConstant(0, SL, /*isTarget=*/true);
+    return DAG.getNode(ISD::FP_ROUND, SL, MVT::bf16, ToF32, FPRoundFlag);
+  }
+
+  if (SrcVT != MVT::i64)
+    return Op;
 
   if (Subtarget->has16BitInsts() && DestVT == MVT::f16) {
     SDLoc DL(Op);
@@ -3312,7 +3320,15 @@ SDValue AMDGPUTargetLowering::LowerSINT_TO_FP(SDValue Op,
     return DAG.getNode(ISD::SINT_TO_FP, DL, DestVT, Ext);
   }
 
-  assert(SrcVT == MVT::i64 && "operation should be legal");
+  if (DestVT == MVT::bf16) {
+    SDLoc SL(Op);
+    SDValue ToF32 = DAG.getNode(ISD::SINT_TO_FP, SL, MVT::f32, Src);
+    SDValue FPRoundFlag = DAG.getIntPtrConstant(0, SL, /*isTarget=*/true);
+    return DAG.getNode(ISD::FP_ROUND, SL, MVT::bf16, ToF32, FPRoundFlag);
+  }
+
+  if (SrcVT != MVT::i64)
+    return Op;
 
   // TODO: Factor out code common with LowerUINT_TO_FP.
 
@@ -3510,7 +3526,7 @@ SDValue AMDGPUTargetLowering::LowerFP_TO_FP16(SDValue Op, SelectionDAG &DAG) con
   return DAG.getZExtOrTrunc(V, DL, Op.getValueType());
 }
 
-SDValue AMDGPUTargetLowering::LowerFP_TO_INT(SDValue Op,
+SDValue AMDGPUTargetLowering::LowerFP_TO_INT(const SDValue Op,
                                              SelectionDAG &DAG) const {
   SDValue Src = Op.getOperand(0);
   unsigned OpOpcode = Op.getOpcode();
@@ -3521,6 +3537,12 @@ SDValue AMDGPUTargetLowering::LowerFP_TO_INT(SDValue Op,
   if (SrcVT == MVT::f16 && DestVT == MVT::i16)
     return Op;
 
+  if (SrcVT == MVT::bf16) {
+    SDLoc DL(Op);
+    SDValue PromotedSrc = DAG.getNode(ISD::FP_EXTEND, DL, MVT::f32, Src);
+    return DAG.getNode(Op.getOpcode(), DL, DestVT, PromotedSrc);
+  }
+
   // Promote i16 to i32
   if (DestVT == MVT::i16 && (SrcVT == MVT::f32 || SrcVT == MVT::f64)) {
     SDLoc DL(Op);
@@ -3529,6 +3551,9 @@ SDValue AMDGPUTargetLowering::LowerFP_TO_INT(SDValue Op,
     return DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, FpToInt32);
   }
 
+  if (DestVT != MVT::i64)
+    return Op;
+
   if (SrcVT == MVT::f16 ||
       (SrcVT == MVT::f32 && Src.getOpcode() == ISD::FP16_TO_FP)) {
     SDLoc DL(Op);
@@ -3539,7 +3564,7 @@ SDValue AMDGPUTargetLowering::LowerFP_TO_INT(SDValue Op,
     return DAG.getNode(Ext, DL, MVT::i64, FpToInt32);
   }
 
-  if (DestVT == MVT::i64 && (SrcVT == MVT::f32 || SrcVT == MVT::f64))
+  if (SrcVT == MVT::f32 || SrcVT == MVT::f64)
     return LowerFP_TO_INT64(Op, DAG, OpOpcode == ISD::FP_TO_SINT);
 
   return SDValue();
@@ -4940,7 +4965,9 @@ SDValue AMDGPUTargetLowering::PerformDAGCombine(SDNode *N,
     //   vnt1 = build_vector (t1 (bitcast t0:x)), (t1 (bitcast t0:y))
     if (DestVT.isVector()) {
       SDValue Src = N->getOperand(0);
-      if (Src.getOpcode() == ISD::BUILD_VECTOR) {
+      if (Src.getOpcode() == ISD::BUILD_VECTOR &&
+          (DCI.getDAGCombineLevel() < AfterLegalizeDAG ||
+           isOperationLegal(ISD::BUILD_VECTOR, DestVT))) {
         EVT SrcVT = Src.getValueType();
         unsigned NElts = DestVT.getVectorNumElements();
 
diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
index fc119aa61d01a2..eaf32850a87149 100644
--- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
@@ -151,14 +151,17 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
     if (Subtarget->useRealTrue16Insts()) {
       addRegisterClass(MVT::i16, &AMDGPU::VGPR_16RegClass);
       addRegisterClass(MVT::f16, &AMDGPU::VGPR_16RegClass);
+      addRegisterClass(MVT::bf16, &AMDGPU::VGPR_16RegClass);
     } else {
       addRegisterClass(MVT::i16, &AMDGPU::SReg_32RegClass);
       addRegisterClass(MVT::f16, &AMDGPU::SReg_32RegClass);
+      addRegisterClass(MVT::bf16, &AMDGPU::SReg_32RegClass);
     }
 
     // Unless there are also VOP3P operations, not operations are really legal.
     addRegisterClass(MVT::v2i16, &AMDGPU::SReg_32RegClass);
     addRegisterClass(MVT::v2f16, &AMDGPU::SReg_32RegClass);
+    addRegisterClass(MVT::v2bf16, &AMDGPU::SReg_32RegClass);
     addRegisterClass(MVT::v4i16, &AMDGPU::SReg_64RegClass);
     addRegisterClass(MVT::v4f16, &AMDGPU::SReg_64RegClass);
     addRegisterClass(MVT::v8i16, &AMDGPU::SGPR_128RegClass);
@@ -196,6 +199,41 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
                       MVT::i1,     MVT::v32i32},
                      Custom);
 
+  if (isTypeLegal(MVT::bf16)) {
+    for (unsigned Opc :
+         {ISD::FADD,     ISD::FSUB,       ISD::FMUL,    ISD::FDIV,
+          ISD::FREM,     ISD::FMA,        ISD::FMINNUM, ISD::FMAXNUM,
+          ISD::FMINIMUM, ISD::FMAXIMUM,   ISD::FSQRT,   ISD::FCBRT,
+          ISD::FSIN,     ISD::FCOS,       ISD::FPOW,    ISD::FPOWI,
+          ISD::FLDEXP,   ISD::FFREXP,     ISD::FLOG,    ISD::FLOG2,
+          ISD::FLOG10,   ISD::FEXP,       ISD::FEXP2,   ISD::FEXP10,
+          ISD::FCEIL,    ISD::FTRUNC,     ISD::FRINT,   ISD::FNEARBYINT,
+          ISD::FROUND,   ISD::FROUNDEVEN, ISD::FFLOOR,  ISD::FCANONICALIZE,
+          ISD::SETCC}) {
+      // FIXME: The promoted to type shouldn't need to be explicit
+      setOperationAction(Opc, MVT::bf16, Promote);
+      AddPromotedToType(Opc, MVT::bf16, MVT::f32);
+    }
+
+    setOperationAction(ISD::FP_ROUND, MVT::bf16, Expand);
+
+    setOperationAction(ISD::SELECT, MVT::bf16, Promote);
+    AddPromotedToType(ISD::SELECT, MVT::bf16, MVT::i16);
+
+    // TODO: Could make these legal
+    setOperationAction(ISD::FABS, MVT::bf16, Expand);
+    setOperationAction(ISD::FNEG, MVT::bf16, Expand);
+    setOperationAction(ISD::FCOPYSIGN, MVT::bf16, Expand);
+
+    // We only need to custom lower because we can't specify an action for bf16
+    // sources.
+    setOperationAction(ISD::FP_TO_SINT, MVT::i32, Custom);
+    setOperationAction(ISD::FP_TO_UINT, MVT::i32, Custom);
+
+    setOperationAction(ISD::BUILD_VECTOR, MVT::v2bf16, Promote);
+    AddPromotedToType(ISD::BUILD_VECTOR, MVT::v2bf16, MVT::v2i16);
+  }
+
   setTruncStoreAction(MVT::v2i32, MVT::v2i16, Expand);
   setTruncStoreAction(MVT::v3i32, MVT::v3i16, Expand);
   setTruncStoreAction(MVT::v4i32, MVT::v4i16, Expand);
@@ -388,8 +426,8 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
   // Avoid stack access for these.
   // TODO: Generalize to more vector types.
   setOperationAction({ISD::EXTRACT_VECTOR_ELT, ISD::INSERT_VECTOR_ELT},
-                     {MVT::v2i16, MVT::v2f16, MVT::v2i8, MVT::v4i8, MVT::v8i8,
-                      MVT::v4i16, MVT::v4f16},
+                     {MVT::v2i16, MVT::v2f16, MVT::v2bf16, MVT::v2i8, MVT::v4i8,
+                      MVT::v8i8, MVT::v4i16, MVT::v4f16},
                      Custom);
 
   // Deal with vec3 vector operations when widened to vec4.
@@ -498,6 +536,11 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
   setOperationAction(ISD::BF16_TO_FP, {MVT::i16, MVT::f32, MVT::f64}, Expand);
   setOperationAction(ISD::FP_TO_BF16, {MVT::i16, MVT::f32, MVT::f64}, Expand);
 
+  // Custom lower these because we can't specify a rule based on an illegal
+  // source bf16.
+  setOperationAction({ISD::FP_EXTEND, ISD::STRICT_FP_EXTEND}, MVT::f32, Custom);
+  setOperationAction({ISD::FP_EXTEND, ISD::STRICT_FP_EXTEND}, MVT::f64, Custom);
+
   if (Subtarget->has16BitInsts()) {
     setOperationAction({ISD::Constant, ISD::SMIN, ISD::SMAX, ISD::UMIN,
                         ISD::UMAX, ISD::UADDSAT, ISD::USUBSAT},
@@ -524,9 +567,14 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
     AddPromotedToType(ISD::FP_TO_FP16, MVT::i16, MVT::i32);
 
     setOperationAction({ISD::FP_TO_SINT, ISD::FP_TO_UINT}, MVT::i16, Custom);
+    setOperationAction({ISD::SINT_TO_FP, ISD::UINT_TO_FP}, MVT::i16, Custom);
+    setOperationAction({ISD::SINT_TO_FP, ISD::UINT_TO_FP}, MVT::i16, Custom);
+
+    setOperationAction({ISD::SINT_TO_FP, ISD::UINT_TO_FP}, MVT::i32, Custom);
 
     // F16 - Constant Actions.
     setOperationAction(ISD::ConstantFP, MVT::f16, Legal);
+    setOperationAction(ISD::ConstantFP, MVT::bf16, Legal);
 
     // F16 - Load/Store Actions.
     setOperationAction(ISD::LOAD, MVT::f16, Promote);
@@ -534,16 +582,23 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
     setOperationAction(ISD::STORE, MVT::f16, Promote);
     AddPromotedToType(ISD::STORE, MVT::f16, MVT::i16);
 
+    // BF16 - Load/Store Actions.
+    setOperationAction(ISD::LOAD, MVT::bf16, Promote);
+    AddPromotedToType(ISD::LOAD, MVT::bf16, MVT::i16);
+    setOperationAction(ISD::STORE, MVT::bf16, Promote);
+    AddPromotedToType(ISD::STORE, MVT::bf16, MVT::i16);
+
     // F16 - VOP1 Actions.
     setOperationAction({ISD::FP_ROUND, ISD::STRICT_FP_ROUND, ISD::FCOS,
                         ISD::FSIN, ISD::FROUND, ISD::FPTRUNC_ROUND},
                        MVT::f16, Custom);
 
-    setOperationAction({ISD::SINT_TO_FP, ISD::UINT_TO_FP}, MVT::i16, Custom);
     setOperationAction({ISD::FP_TO_SINT, ISD::FP_TO_UINT}, MVT::f16, Promote);
+    setOperationAction({ISD::FP_TO_SINT, ISD::FP_TO_UINT}, MVT::bf16, Promote);
 
     // F16 - VOP2 Actions.
-    setOperationAction({ISD::BR_CC, ISD::SELECT_CC}, MVT::f16, Expand);
+    setOperationAction({ISD::BR_CC, ISD::SELECT_CC}, {MVT::f16, MVT::bf16},
+                       Expand);
     setOperationAction({ISD::FLDEXP, ISD::STRICT_FLDEXP}, MVT::f16, Custom);
     setOperationAction(ISD::FFREXP, MVT::f16, Custom);
     setOperationAction(ISD::FDIV, MVT::f16, Custom);
@@ -554,8 +609,9 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
       setOperationAction(ISD::FMAD, MVT::f16, Legal);
 
     for (MVT VT :
-         {MVT::v2i16, MVT::v2f16, MVT::v4i16, MVT::v4f16, MVT::v8i16,
-          MVT::v8f16, MVT::v16i16, MVT::v16f16, MVT::v32i16, MVT::v32f16}) {
+         {MVT::v2i16, MVT::v2f16, MVT::v2bf16, MVT::v4i16, MVT::v4f16,
+          MVT::v4bf16, MVT::v8i16, MVT::v8f16, MVT::v8bf16, MVT::v16i16,
+          MVT::v16f16, MVT::v16bf16, MVT::v32i16, MVT::v32f16}) {
       for (unsigned Op = 0; Op < ISD::BUILTIN_OP_END; ++Op) {
         switch (Op) {
         case ISD::LOAD:
@@ -587,7 +643,8 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
     // XXX - Do these do anything? Vector constants turn into build_vector.
     setOperationAction(ISD::Constant, {MVT::v2i16, MVT::v2f16}, Legal);
 
-    setOperationAction(ISD::UNDEF, {MVT::v2i16, MVT::v2f16}, Legal);
+    setOperationAction(ISD::UNDEF, {MVT::v2i16, MVT::v2f16, MVT::v2bf16},
+                       Legal);
 
     setOperationAction(ISD::STORE, MVT::v2i16, Promote);
     AddPromotedToType(ISD::STORE, MVT::v2i16, MVT::i32);
@@ -3901,6 +3958,26 @@ SDValue SITargetLowering::lowerPREFETCH(SDValue Op, SelectionDAG &DAG) const {
   return Op;
 }
 
+// Work around DAG legality rules only based on the result type.
+SDValue SITargetLowering::lowerFP_EXTEND(SDValue Op, SelectionDAG &DAG) const {
+  bool IsStrict = Op.getOpcode() == ISD::STRICT_FP_EXTEND;
+  SDValue Src = Op.getOperand(IsStrict ? 1 : 0);
+  EVT S...
[truncated]

Copy link

github-actions bot commented Dec 27, 2023

⚠️ C/C++ code formatter, clang-format found issues in your code. ⚠️

You can test this locally with the following command:
git-clang-format --diff 25cd249355b0f3192ca5b0c69514ad68a1cb8897 da874220064ba6ba8fd02a50aaccd67ca726b23e -- llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp llvm/lib/Target/AMDGPU/SIISelLowering.cpp llvm/lib/Target/AMDGPU/SIISelLowering.h
View the diff from clang-format here.
diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
index 91f347be68..d917e84739 100644
--- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
@@ -756,8 +756,8 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
                         ISD::FMAXNUM_IEEE, ISD::FCANONICALIZE},
                        MVT::v2f16, Legal);
 
-    setOperationAction(ISD::EXTRACT_VECTOR_ELT, {MVT::v2i16, MVT::v2f16, MVT::v2bf16},
-                       Custom);
+    setOperationAction(ISD::EXTRACT_VECTOR_ELT,
+                       {MVT::v2i16, MVT::v2f16, MVT::v2bf16}, Custom);
 
     setOperationAction(ISD::VECTOR_SHUFFLE,
                        {MVT::v4f16, MVT::v4i16, MVT::v8f16, MVT::v8i16,

arsenm added 2 commits January 1, 2024 21:24
Assorted intrinsics are currently using i16 in place of a proper
bfloat type, but they should really switch to bfloat.

Note this only changes the type lists in tablegen, these are still
not registered to be truly treated as a legal type yet.
There are some intrinsics are using i16 vectors in place of bfloat vectors.
Move towards making bf16 vectors legal so these can migrate. Leave the
larger vectors for a later change.
@arsenm arsenm force-pushed the bf16/legal-bf16-v2bf16 branch from 676ef60 to da87422 Compare January 1, 2024 14:26
Copy link
Contributor

@Pierre-vh Pierre-vh left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Test changes are a bit hard to review so I didn't look at all of them

Do we need a GISel equivalent of this patch?

@@ -3901,6 +3958,26 @@ SDValue SITargetLowering::lowerPREFETCH(SDValue Op, SelectionDAG &DAG) const {
return Op;
}

// Work around DAG legality rules only based on the result type.
SDValue SITargetLowering::lowerFP_EXTEND(SDValue Op, SelectionDAG &DAG) const {
bool IsStrict = Op.getOpcode() == ISD::STRICT_FP_EXTEND;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tiny tiny nit: put parentheses around the initializer - = and == on each side of an expression is just a bit confusing at first glance

@@ -1021,13 +1021,19 @@ define half @fmed3_f32_fpext_bf16(bfloat %arg0, bfloat %arg1, bfloat %arg2) #1 {
; GFX8-LABEL: fmed3_f32_fpext_bf16:
; GFX8: ; %bb.0:
; GFX8-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
; GFX8-NEXT: v_lshlrev_b32_e32 v0, 16, v0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we have these now?

@arsenm
Copy link
Contributor Author

arsenm commented Jan 2, 2024

No, legal types are more of a DAG concept

Copy link
Collaborator

@rampitec rampitec left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, although I do not believe we can finally make bfloat16 legal in an FE now, we will still have a lot of emulation for many operations. I.e. if the final plan is to switch from i16 to bf16 in intrinsics and builtins, that is hardly possible. At best we will need to start with overloaded interfaces supporting both.

@arsenm
Copy link
Contributor Author

arsenm commented Jan 2, 2024

That's what the comment says but everything I tried works by promotion to float. I expanded the test to cover just about everything

@arsenm arsenm merged commit 460ffcd into llvm:main Jan 4, 2024
@arsenm arsenm deleted the bf16/legal-bf16-v2bf16 branch January 4, 2024 15:31
arsenm added a commit that referenced this pull request Jan 5, 2024
Gets a few code quality improvements. A few cases are worse
from losing load narrowing.
Depends #76213 #76214 #76215
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
backend:AMDGPU llvm:SelectionDAG SelectionDAGISel as well
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants