Skip to content

AMDGPU: Make v4bf16 a legal type #76217

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jan 5, 2024
Merged

AMDGPU: Make v4bf16 a legal type #76217

merged 5 commits into from
Jan 5, 2024

Conversation

arsenm
Copy link
Contributor

@arsenm arsenm commented Dec 22, 2023

Gets a few code quality improvements. A few cases are worse
from losing load narrowing.
Depends #76213 #76214 #76215

@llvmbot llvmbot added the llvm:SelectionDAG SelectionDAGISel as well label Dec 22, 2023
@llvmbot
Copy link
Member

llvmbot commented Dec 22, 2023

@llvm/pr-subscribers-llvm-selectiondag

@llvm/pr-subscribers-backend-amdgpu

Author: Matt Arsenault (arsenm)

Changes

Gets a few code quality improvements. A few cases are worse
from losing load narrowing.
Depends #76213 #76214 #76215


Patch is 1.08 MiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/76217.diff

27 Files Affected:

  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp (+28-8)
  • (modified) llvm/lib/Target/AMDGPU/AMDGPUCallingConv.td (+13-13)
  • (modified) llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp (+1)
  • (modified) llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp (+42-14)
  • (modified) llvm/lib/Target/AMDGPU/SIISelLowering.cpp (+101-15)
  • (modified) llvm/lib/Target/AMDGPU/SIISelLowering.h (+1)
  • (modified) llvm/lib/Target/AMDGPU/SIInstructions.td (+59-6)
  • (modified) llvm/lib/Target/AMDGPU/SIRegisterInfo.td (+31-31)
  • (modified) llvm/lib/Target/AMDGPU/VOP3PInstructions.td (+11-11)
  • (modified) llvm/test/CodeGen/AMDGPU/bf16.ll (+5121-6738)
  • (modified) llvm/test/CodeGen/AMDGPU/fcopysign.f32.ll (+5-7)
  • (modified) llvm/test/CodeGen/AMDGPU/fmed3-cast-combine.ll (+14-2)
  • (modified) llvm/test/CodeGen/AMDGPU/fneg-modifier-casting.ll (+5-5)
  • (modified) llvm/test/CodeGen/AMDGPU/function-args-inreg.ll (+2-2)
  • (modified) llvm/test/CodeGen/AMDGPU/function-args.ll (+70-88)
  • (modified) llvm/test/CodeGen/AMDGPU/function-returns.ll (+19-46)
  • (modified) llvm/test/CodeGen/AMDGPU/gfx-callable-argument-types.ll (+48-80)
  • (modified) llvm/test/CodeGen/AMDGPU/isel-amdgpu-cs-chain-preserve-cc.ll (+6-10)
  • (modified) llvm/test/CodeGen/AMDGPU/llvm.exp.ll (+13-12)
  • (modified) llvm/test/CodeGen/AMDGPU/llvm.exp10.ll (+13-12)
  • (modified) llvm/test/CodeGen/AMDGPU/llvm.exp2.ll (+43-15)
  • (modified) llvm/test/CodeGen/AMDGPU/llvm.is.fpclass.bf16.ll (+157-346)
  • (modified) llvm/test/CodeGen/AMDGPU/llvm.log.ll (+7-4)
  • (modified) llvm/test/CodeGen/AMDGPU/llvm.log10.ll (+7-4)
  • (modified) llvm/test/CodeGen/AMDGPU/llvm.log2.ll (+46-16)
  • (modified) llvm/test/CodeGen/AMDGPU/local-atomics-fp.ll (+2-2)
  • (modified) llvm/test/CodeGen/AMDGPU/vector_shuffle.packed.ll (+362-629)
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
index a483b8028fda9e..296ed3a3c3dc11 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -3199,7 +3199,16 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
       return true;
     }
     break;
-  case ISD::FP_ROUND:
+  case ISD::FP_ROUND: {
+    EVT VT = Node->getValueType(0);
+    if (VT.getScalarType() == MVT::bf16) {
+      Results.push_back(
+          DAG.getNode(ISD::FP_TO_BF16, SDLoc(Node), VT, Node->getOperand(0)));
+      break;
+    }
+
+    LLVM_FALLTHROUGH;
+  }
   case ISD::BITCAST:
     if ((Tmp1 = EmitStackConvert(Node->getOperand(0), Node->getValueType(0),
                                  Node->getValueType(0), dl)))
@@ -3226,12 +3235,19 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
       return true;
     }
     break;
-  case ISD::FP_EXTEND:
-    if ((Tmp1 = EmitStackConvert(Node->getOperand(0),
-                                 Node->getOperand(0).getValueType(),
-                                 Node->getValueType(0), dl)))
+  case ISD::FP_EXTEND: {
+    SDValue Op = Node->getOperand(0);
+    EVT SrcVT = Op.getValueType();
+    EVT DstVT = Node->getValueType(0);
+    if (SrcVT.getScalarType() == MVT::bf16) {
+      Results.push_back(DAG.getNode(ISD::BF16_TO_FP, SDLoc(Node), DstVT, Op));
+      break;
+    }
+
+    if ((Tmp1 = EmitStackConvert(Op, SrcVT, DstVT, dl)))
       Results.push_back(Tmp1);
     break;
+  }
   case ISD::BF16_TO_FP: {
     // Always expand bf16 to f32 casts, they lower to ext + shift.
     //
@@ -4908,7 +4924,9 @@ void SelectionDAGLegalize::ConvertNodeToLibcall(SDNode *Node) {
 static MVT getPromotedVectorElementType(const TargetLowering &TLI,
                                         MVT EltVT, MVT NewEltVT) {
   unsigned OldEltsPerNewElt = EltVT.getSizeInBits() / NewEltVT.getSizeInBits();
-  MVT MidVT = MVT::getVectorVT(NewEltVT, OldEltsPerNewElt);
+  MVT MidVT = OldEltsPerNewElt == 1
+                  ? NewEltVT
+                  : MVT::getVectorVT(NewEltVT, OldEltsPerNewElt);
   assert(TLI.isTypeLegal(MidVT) && "unexpected");
   return MidVT;
 }
@@ -5395,7 +5413,7 @@ void SelectionDAGLegalize::PromoteNode(SDNode *Node) {
 
     assert(NVT.isVector() && OVT.getSizeInBits() == NVT.getSizeInBits() &&
            "Invalid promote type for build_vector");
-    assert(NewEltVT.bitsLT(EltVT) && "not handled");
+    assert(NewEltVT.bitsLE(EltVT) && "not handled");
 
     MVT MidVT = getPromotedVectorElementType(TLI, EltVT, NewEltVT);
 
@@ -5406,7 +5424,9 @@ void SelectionDAGLegalize::PromoteNode(SDNode *Node) {
     }
 
     SDLoc SL(Node);
-    SDValue Concat = DAG.getNode(ISD::CONCAT_VECTORS, SL, NVT, NewOps);
+    SDValue Concat =
+        DAG.getNode(MidVT == NewEltVT ? ISD::BUILD_VECTOR : ISD::CONCAT_VECTORS,
+                    SL, NVT, NewOps);
     SDValue CvtVec = DAG.getNode(ISD::BITCAST, SL, OVT, Concat);
     Results.push_back(CvtVec);
     break;
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUCallingConv.td b/llvm/lib/Target/AMDGPU/AMDGPUCallingConv.td
index 9036b26a6f6bcb..c5207228dc913f 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUCallingConv.td
+++ b/llvm/lib/Target/AMDGPU/AMDGPUCallingConv.td
@@ -22,28 +22,28 @@ def CC_SI_Gfx : CallingConv<[
   // 32 is reserved for the stack pointer
   // 33 is reserved for the frame pointer
   // 34 is reserved for the base pointer
-  CCIfInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16] , CCAssignToReg<[
+  CCIfInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16, bf16, v2bf16] , CCAssignToReg<[
     SGPR4, SGPR5, SGPR6, SGPR7,
     SGPR8, SGPR9, SGPR10, SGPR11, SGPR12, SGPR13, SGPR14, SGPR15,
     SGPR16, SGPR17, SGPR18, SGPR19, SGPR20, SGPR21, SGPR22, SGPR23,
     SGPR24, SGPR25, SGPR26, SGPR27, SGPR28, SGPR29
   ]>>>,
 
-  CCIfNotInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16] , CCAssignToReg<[
+  CCIfNotInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16, bf16, v2bf16] , CCAssignToReg<[
     VGPR0, VGPR1, VGPR2, VGPR3, VGPR4, VGPR5, VGPR6, VGPR7,
     VGPR8, VGPR9, VGPR10, VGPR11, VGPR12, VGPR13, VGPR14, VGPR15,
     VGPR16, VGPR17, VGPR18, VGPR19, VGPR20, VGPR21, VGPR22, VGPR23,
     VGPR24, VGPR25, VGPR26, VGPR27, VGPR28, VGPR29, VGPR30, VGPR31
   ]>>>,
 
-  CCIfType<[i32, f32, v2i16, v2f16, i16, f16, i1], CCAssignToStack<4, 4>>
+  CCIfType<[i32, f32, v2i16, v2f16, i16, f16, i1, bf16, v2bf16], CCAssignToStack<4, 4>>
 ]>;
 
 def RetCC_SI_Gfx : CallingConv<[
   CCIfType<[i1], CCPromoteToType<i32>>,
   CCIfType<[i1, i16], CCIfExtend<CCPromoteToType<i32>>>,
 
-  CCIfNotInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16] , CCAssignToReg<[
+  CCIfNotInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16, bf16, v2bf16] , CCAssignToReg<[
     VGPR0, VGPR1, VGPR2, VGPR3, VGPR4, VGPR5, VGPR6, VGPR7,
     VGPR8, VGPR9, VGPR10, VGPR11, VGPR12, VGPR13, VGPR14, VGPR15,
     VGPR16, VGPR17, VGPR18, VGPR19, VGPR20, VGPR21, VGPR22, VGPR23,
@@ -66,7 +66,7 @@ def RetCC_SI_Gfx : CallingConv<[
 
 def CC_SI_SHADER : CallingConv<[
 
-  CCIfInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16] , CCAssignToReg<[
+  CCIfInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16, bf16, v2bf16] , CCAssignToReg<[
     SGPR0, SGPR1, SGPR2, SGPR3, SGPR4, SGPR5, SGPR6, SGPR7,
     SGPR8, SGPR9, SGPR10, SGPR11, SGPR12, SGPR13, SGPR14, SGPR15,
     SGPR16, SGPR17, SGPR18, SGPR19, SGPR20, SGPR21, SGPR22, SGPR23,
@@ -76,7 +76,7 @@ def CC_SI_SHADER : CallingConv<[
   ]>>>,
 
   // 32*4 + 4 is the minimum for a fetch shader consumer with 32 inputs.
-  CCIfNotInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16] , CCAssignToReg<[
+  CCIfNotInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16, bf16, v2bf16] , CCAssignToReg<[
     VGPR0, VGPR1, VGPR2, VGPR3, VGPR4, VGPR5, VGPR6, VGPR7,
     VGPR8, VGPR9, VGPR10, VGPR11, VGPR12, VGPR13, VGPR14, VGPR15,
     VGPR16, VGPR17, VGPR18, VGPR19, VGPR20, VGPR21, VGPR22, VGPR23,
@@ -109,7 +109,7 @@ def RetCC_SI_Shader : CallingConv<[
   ]>>,
 
   // 32*4 + 4 is the minimum for a fetch shader with 32 outputs.
-  CCIfType<[f32, f16, v2f16] , CCAssignToReg<[
+  CCIfType<[f32, f16, v2f16, bf16, v2bf16] , CCAssignToReg<[
     VGPR0, VGPR1, VGPR2, VGPR3, VGPR4, VGPR5, VGPR6, VGPR7,
     VGPR8, VGPR9, VGPR10, VGPR11, VGPR12, VGPR13, VGPR14, VGPR15,
     VGPR16, VGPR17, VGPR18, VGPR19, VGPR20, VGPR21, VGPR22, VGPR23,
@@ -188,23 +188,23 @@ def CC_AMDGPU_Func : CallingConv<[
   CCIfType<[i1], CCPromoteToType<i32>>,
   CCIfType<[i8, i16], CCIfExtend<CCPromoteToType<i32>>>,
 
-  CCIfInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16] , CCAssignToReg<
+  CCIfInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16, bf16, v2bf16] , CCAssignToReg<
     !foreach(i, !range(0, 30), !cast<Register>("SGPR"#i))  // SGPR0-29
   >>>,
 
-  CCIfType<[i32, f32, i16, f16, v2i16, v2f16, i1], CCAssignToReg<[
+  CCIfType<[i32, f32, i16, f16, v2i16, v2f16, i1, bf16, v2bf16], CCAssignToReg<[
     VGPR0, VGPR1, VGPR2, VGPR3, VGPR4, VGPR5, VGPR6, VGPR7,
     VGPR8, VGPR9, VGPR10, VGPR11, VGPR12, VGPR13, VGPR14, VGPR15,
     VGPR16, VGPR17, VGPR18, VGPR19, VGPR20, VGPR21, VGPR22, VGPR23,
     VGPR24, VGPR25, VGPR26, VGPR27, VGPR28, VGPR29, VGPR30, VGPR31]>>,
-  CCIfType<[i32, f32, v2i16, v2f16, i16, f16, i1], CCAssignToStack<4, 4>>
+  CCIfType<[i32, f32, v2i16, v2f16, i16, f16, i1, bf16, v2bf16], CCAssignToStack<4, 4>>
 ]>;
 
 // Calling convention for leaf functions
 def RetCC_AMDGPU_Func : CallingConv<[
   CCIfType<[i1], CCPromoteToType<i32>>,
   CCIfType<[i1, i16], CCIfExtend<CCPromoteToType<i32>>>,
-  CCIfType<[i32, f32, i16, f16, v2i16, v2f16], CCAssignToReg<[
+  CCIfType<[i32, f32, i16, f16, v2i16, v2f16, bf16, v2bf16], CCAssignToReg<[
     VGPR0, VGPR1, VGPR2, VGPR3, VGPR4, VGPR5, VGPR6, VGPR7,
     VGPR8, VGPR9, VGPR10, VGPR11, VGPR12, VGPR13, VGPR14, VGPR15,
     VGPR16, VGPR17, VGPR18, VGPR19, VGPR20, VGPR21, VGPR22, VGPR23,
@@ -223,11 +223,11 @@ def CC_AMDGPU : CallingConv<[
 ]>;
 
 def CC_AMDGPU_CS_CHAIN : CallingConv<[
-  CCIfInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16] , CCAssignToReg<
+  CCIfInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16, bf16, v2bf16] , CCAssignToReg<
     !foreach(i, !range(105), !cast<Register>("SGPR"#i))
   >>>,
 
-  CCIfNotInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16] , CCAssignToReg<
+  CCIfNotInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16, bf16, v2bf16] , CCAssignToReg<
     !foreach(i, !range(8, 255), !cast<Register>("VGPR"#i))
   >>>
 ]>;
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp b/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp
index b0eac567ec9f18..40a49cbe3f518f 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp
@@ -303,6 +303,7 @@ void AMDGPUDAGToDAGISel::PreprocessISelDAG() {
 
     switch (N->getOpcode()) {
     case ISD::BUILD_VECTOR:
+      // TODO: Match load d16 from shl (extload:i16), 16
       MadeChange |= matchLoadD16FromBuildVector(N);
       break;
     default:
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
index 4bf4707553e5fe..055684281d5955 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
@@ -389,15 +389,16 @@ AMDGPUTargetLowering::AMDGPUTargetLowering(const TargetMachine &TM,
                      Custom);
   setOperationAction(
       ISD::EXTRACT_SUBVECTOR,
-      {MVT::v2f16,  MVT::v2i16,  MVT::v4f16,  MVT::v4i16,  MVT::v2f32,
-       MVT::v2i32,  MVT::v3f32,  MVT::v3i32,  MVT::v4f32,  MVT::v4i32,
-       MVT::v5f32,  MVT::v5i32,  MVT::v6f32,  MVT::v6i32,  MVT::v7f32,
-       MVT::v7i32,  MVT::v8f32,  MVT::v8i32,  MVT::v9f32,  MVT::v9i32,
-       MVT::v10i32, MVT::v10f32, MVT::v11i32, MVT::v11f32, MVT::v12i32,
-       MVT::v12f32, MVT::v16f16, MVT::v16i16, MVT::v16f32, MVT::v16i32,
-       MVT::v32f32, MVT::v32i32, MVT::v2f64,  MVT::v2i64,  MVT::v3f64,
-       MVT::v3i64,  MVT::v4f64,  MVT::v4i64,  MVT::v8f64,  MVT::v8i64,
-       MVT::v16f64, MVT::v16i64, MVT::v32i16, MVT::v32f16},
+      {MVT::v2f16,  MVT::v2i16,  MVT::v2bf16, MVT::v4f16,  MVT::v4i16,
+       MVT::v4bf16, MVT::v2f32,  MVT::v2i32,  MVT::v3f32,  MVT::v3i32,
+       MVT::v4f32,  MVT::v4i32,  MVT::v5f32,  MVT::v5i32,  MVT::v6f32,
+       MVT::v6i32,  MVT::v7f32,  MVT::v7i32,  MVT::v8f32,  MVT::v8i32,
+       MVT::v9f32,  MVT::v9i32,  MVT::v10i32, MVT::v10f32, MVT::v11i32,
+       MVT::v11f32, MVT::v12i32, MVT::v12f32, MVT::v16f16, MVT::v16i16,
+       MVT::v16f32, MVT::v16i32, MVT::v32f32, MVT::v32i32, MVT::v2f64,
+       MVT::v2i64,  MVT::v3f64,  MVT::v3i64,  MVT::v4f64,  MVT::v4i64,
+       MVT::v8f64,  MVT::v8i64,  MVT::v16f64, MVT::v16i64, MVT::v32i16,
+       MVT::v32f16},
       Custom);
 
   setOperationAction(ISD::FP16_TO_FP, MVT::f64, Expand);
@@ -3274,7 +3275,15 @@ SDValue AMDGPUTargetLowering::LowerUINT_TO_FP(SDValue Op,
     return DAG.getNode(ISD::UINT_TO_FP, DL, DestVT, Ext);
   }
 
-  assert(SrcVT == MVT::i64 && "operation should be legal");
+  if (DestVT == MVT::bf16) {
+    SDLoc SL(Op);
+    SDValue ToF32 = DAG.getNode(ISD::UINT_TO_FP, SL, MVT::f32, Src);
+    SDValue FPRoundFlag = DAG.getIntPtrConstant(0, SL, /*isTarget=*/true);
+    return DAG.getNode(ISD::FP_ROUND, SL, MVT::bf16, ToF32, FPRoundFlag);
+  }
+
+  if (SrcVT != MVT::i64)
+    return Op;
 
   if (Subtarget->has16BitInsts() && DestVT == MVT::f16) {
     SDLoc DL(Op);
@@ -3312,7 +3321,15 @@ SDValue AMDGPUTargetLowering::LowerSINT_TO_FP(SDValue Op,
     return DAG.getNode(ISD::SINT_TO_FP, DL, DestVT, Ext);
   }
 
-  assert(SrcVT == MVT::i64 && "operation should be legal");
+  if (DestVT == MVT::bf16) {
+    SDLoc SL(Op);
+    SDValue ToF32 = DAG.getNode(ISD::SINT_TO_FP, SL, MVT::f32, Src);
+    SDValue FPRoundFlag = DAG.getIntPtrConstant(0, SL, /*isTarget=*/true);
+    return DAG.getNode(ISD::FP_ROUND, SL, MVT::bf16, ToF32, FPRoundFlag);
+  }
+
+  if (SrcVT != MVT::i64)
+    return Op;
 
   // TODO: Factor out code common with LowerUINT_TO_FP.
 
@@ -3510,7 +3527,7 @@ SDValue AMDGPUTargetLowering::LowerFP_TO_FP16(SDValue Op, SelectionDAG &DAG) con
   return DAG.getZExtOrTrunc(V, DL, Op.getValueType());
 }
 
-SDValue AMDGPUTargetLowering::LowerFP_TO_INT(SDValue Op,
+SDValue AMDGPUTargetLowering::LowerFP_TO_INT(const SDValue Op,
                                              SelectionDAG &DAG) const {
   SDValue Src = Op.getOperand(0);
   unsigned OpOpcode = Op.getOpcode();
@@ -3521,6 +3538,12 @@ SDValue AMDGPUTargetLowering::LowerFP_TO_INT(SDValue Op,
   if (SrcVT == MVT::f16 && DestVT == MVT::i16)
     return Op;
 
+  if (SrcVT == MVT::bf16) {
+    SDLoc DL(Op);
+    SDValue PromotedSrc = DAG.getNode(ISD::FP_EXTEND, DL, MVT::f32, Src);
+    return DAG.getNode(Op.getOpcode(), DL, DestVT, PromotedSrc);
+  }
+
   // Promote i16 to i32
   if (DestVT == MVT::i16 && (SrcVT == MVT::f32 || SrcVT == MVT::f64)) {
     SDLoc DL(Op);
@@ -3529,6 +3552,9 @@ SDValue AMDGPUTargetLowering::LowerFP_TO_INT(SDValue Op,
     return DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, FpToInt32);
   }
 
+  if (DestVT != MVT::i64)
+    return Op;
+
   if (SrcVT == MVT::f16 ||
       (SrcVT == MVT::f32 && Src.getOpcode() == ISD::FP16_TO_FP)) {
     SDLoc DL(Op);
@@ -3539,7 +3565,7 @@ SDValue AMDGPUTargetLowering::LowerFP_TO_INT(SDValue Op,
     return DAG.getNode(Ext, DL, MVT::i64, FpToInt32);
   }
 
-  if (DestVT == MVT::i64 && (SrcVT == MVT::f32 || SrcVT == MVT::f64))
+  if (SrcVT == MVT::f32 || SrcVT == MVT::f64)
     return LowerFP_TO_INT64(Op, DAG, OpOpcode == ISD::FP_TO_SINT);
 
   return SDValue();
@@ -4940,7 +4966,9 @@ SDValue AMDGPUTargetLowering::PerformDAGCombine(SDNode *N,
     //   vnt1 = build_vector (t1 (bitcast t0:x)), (t1 (bitcast t0:y))
     if (DestVT.isVector()) {
       SDValue Src = N->getOperand(0);
-      if (Src.getOpcode() == ISD::BUILD_VECTOR) {
+      if (Src.getOpcode() == ISD::BUILD_VECTOR &&
+          (DCI.getDAGCombineLevel() < AfterLegalizeDAG ||
+           isOperationLegal(ISD::BUILD_VECTOR, DestVT))) {
         EVT SrcVT = Src.getValueType();
         unsigned NElts = DestVT.getVectorNumElements();
 
diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
index fc119aa61d01a2..333f0cdd6d6541 100644
--- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
@@ -151,16 +151,20 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
     if (Subtarget->useRealTrue16Insts()) {
       addRegisterClass(MVT::i16, &AMDGPU::VGPR_16RegClass);
       addRegisterClass(MVT::f16, &AMDGPU::VGPR_16RegClass);
+      addRegisterClass(MVT::bf16, &AMDGPU::VGPR_16RegClass);
     } else {
       addRegisterClass(MVT::i16, &AMDGPU::SReg_32RegClass);
       addRegisterClass(MVT::f16, &AMDGPU::SReg_32RegClass);
+      addRegisterClass(MVT::bf16, &AMDGPU::SReg_32RegClass);
     }
 
     // Unless there are also VOP3P operations, not operations are really legal.
     addRegisterClass(MVT::v2i16, &AMDGPU::SReg_32RegClass);
     addRegisterClass(MVT::v2f16, &AMDGPU::SReg_32RegClass);
+    addRegisterClass(MVT::v2bf16, &AMDGPU::SReg_32RegClass);
     addRegisterClass(MVT::v4i16, &AMDGPU::SReg_64RegClass);
     addRegisterClass(MVT::v4f16, &AMDGPU::SReg_64RegClass);
+    addRegisterClass(MVT::v4bf16, &AMDGPU::SReg_64RegClass);
     addRegisterClass(MVT::v8i16, &AMDGPU::SGPR_128RegClass);
     addRegisterClass(MVT::v8f16, &AMDGPU::SGPR_128RegClass);
     addRegisterClass(MVT::v16i16, &AMDGPU::SGPR_256RegClass);
@@ -196,6 +200,41 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
                       MVT::i1,     MVT::v32i32},
                      Custom);
 
+  if (isTypeLegal(MVT::bf16)) {
+    for (unsigned Opc :
+         {ISD::FADD,     ISD::FSUB,       ISD::FMUL,    ISD::FDIV,
+          ISD::FREM,     ISD::FMA,        ISD::FMINNUM, ISD::FMAXNUM,
+          ISD::FMINIMUM, ISD::FMAXIMUM,   ISD::FSQRT,   ISD::FCBRT,
+          ISD::FSIN,     ISD::FCOS,       ISD::FPOW,    ISD::FPOWI,
+          ISD::FLDEXP,   ISD::FFREXP,     ISD::FLOG,    ISD::FLOG2,
+          ISD::FLOG10,   ISD::FEXP,       ISD::FEXP2,   ISD::FEXP10,
+          ISD::FCEIL,    ISD::FTRUNC,     ISD::FRINT,   ISD::FNEARBYINT,
+          ISD::FROUND,   ISD::FROUNDEVEN, ISD::FFLOOR,  ISD::FCANONICALIZE,
+          ISD::SETCC}) {
+      // FIXME: The promoted to type shouldn't need to be explicit
+      setOperationAction(Opc, MVT::bf16, Promote);
+      AddPromotedToType(Opc, MVT::bf16, MVT::f32);
+    }
+
+    setOperationAction(ISD::FP_ROUND, MVT::bf16, Expand);
+
+    setOperationAction(ISD::SELECT, MVT::bf16, Promote);
+    AddPromotedToType(ISD::SELECT, MVT::bf16, MVT::i16);
+
+    // TODO: Could make these legal
+    setOperationAction(ISD::FABS, MVT::bf16, Expand);
+    setOperationAction(ISD::FNEG, MVT::bf16, Expand);
+    setOperationAction(ISD::FCOPYSIGN, MVT::bf16, Expand);
+
+    // We only need to custom lower because we can't specify an action for bf16
+    // sources.
+    setOperationAction(ISD::FP_TO_SINT, MVT::i32, Custom);
+    setOperationAction(ISD::FP_TO_UINT, MVT::i32, Custom);
+
+    setOperationAction(ISD::BUILD_VECTOR, MVT::v2bf16, Promote);
+    AddPromotedToType(ISD::BUILD_VECTOR, MVT::v2bf16, MVT::v2i16);
+  }
+
   setTruncStoreAction(MVT::v2i32, MVT::v2i16, Expand);
   setTruncStoreAction(MVT::v3i32, MVT::v3i16, Expand);
   setTruncStoreAction(MVT::v4i32, MVT::v4i16, Expand);
@@ -274,10 +313,10 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
        {MVT::v8i32,  MVT::v8f32,  MVT::v9i32,  MVT::v9f32,  MVT::v10i32,
         MVT::v10f32, MVT::v11i32, MVT::v11f32, MVT::v12i32, MVT::v12f32,
         MVT::v16i32, MVT::v16f32, MVT::v2i64,  MVT::v2f64,  MVT::v4i16,
-        MVT::v4f16,  MVT::v3i64,  MVT::v3f64,  MVT::v6i32,  MVT::v6f32,
-        MVT::v4i64,  MVT::v4f64,  MVT::v8i64,  MVT::v8f64,  MVT::v8i16,
-        MVT::v8f16,  MVT::v16i16, MVT::v16f16, MVT::v16i64, MVT::v16f64,
-        MVT::v32i32, MVT::v32f32, MVT::v32i16, MVT::v32f16}) {
+        MVT::v4f16,  MVT::v4bf16, MVT::v3i64,  MVT::v3f64,  MVT::v6i32,
+        MVT::v6f32,  MVT::v4i64,  MVT::v4f64,  MVT::v8i64,  MVT::v8f64,
+        MVT::v8i16,  MVT::v8f16,  MVT::v16i16, MVT::v16f16, MVT::v16i64,
+        MVT::v16f64, MVT::v32i32, MVT::v32f32, MVT::v32i16, MVT::v32f16}) {
     for (unsigned Op = 0; Op < ISD::BUILTIN_OP_END; ++Op) {
       switch (Op) {
       case ISD::LOAD:
@@ -383,13 +422,14 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
                      {MVT::v8i32, MVT::v8f32, MVT::v16i32, MVT::v16f32},
                      Expand);
 
-  setOperationAction(ISD::BUILD_VECTOR, {MVT::v4f16, MVT::v4i16}, Custom);
+  setOperationAction(ISD::BUILD_VECTOR, {MVT::v4f16, MVT::v4i16, MVT::v4bf16},
+                     Custom);
 
   // Avoid stack access for these.
   // TODO: Generalize to more vector types.
   setOperationAction({ISD::EXTRACT_VECTOR_ELT, ISD::INSERT_VECTOR_ELT},
-                     {MVT::v2i16, MVT::v2f16, MVT::v2i8, MVT::v4i8, MVT::v8i8,
-                      MVT::v4i16, MVT::v4f16},
+                     {MVT::v2i16, MVT::v2f16, MVT::v2bf16, MVT::v2i8, MVT::v4i8,
+                      MVT::v8i8, MVT::v4i16, MVT::v4f16, MVT::v4bf16},
                      Custom);
 
   // Deal with vec3 vector operations when widened to vec4.
@@ -498,6 +538,11 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
   setOperationAction(ISD::BF16_TO_FP, {MVT::i16, MVT::f32, MVT::f64}, Expand);
   setOperationAction(ISD::FP_TO_BF16, {MVT::i16, MVT::f32, MVT::f64}, Expand);
 
+  // Custom lower these because we can't specify a rule based on an illegal
+  // source bf16.
+  setOperationAction({ISD::FP_EXTEND, ISD::STRICT_FP_EXTEND}, MVT::f32, Custom);
+  setOperationAction({ISD::FP_EXTEND, ISD::STRICT_FP_EXTEND}, MVT::f64, Custom);
+
   if (Subtarget->has16BitInsts()) {
     setOperationAction({ISD::Constant, ISD::SMIN, ISD::SMAX, ISD::UMIN,
                         ISD::UMAX, ISD::UADDSAT, ISD::USUBSAT},
@@ -524,9 +569,14 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
     AddPromotedToType(ISD::FP_TO_FP16, MVT::i16, MVT::i32);
 
     setOperationAction({ISD::FP_TO_SINT, ISD::FP_TO_UINT}, MVT::i16, Custom);
+    setOperationAction({ISD::SINT_TO_FP, ISD::UINT_TO_FP}, MVT::i16, Custom);
+    setOperationAction({ISD::SINT_TO_FP, ISD::UINT_TO_FP}, MVT::i16, Custom);
+
+    setOperationAction({ISD::SINT_TO_FP, ISD::UINT_TO_FP}, MVT::i32, Custom);
 
     // F16 - Constant Actions.
     setOperationA...
[truncated]

Copy link

github-actions bot commented Dec 27, 2023

✅ With the latest revision this PR passed the C/C++ code formatter.

arsenm added 3 commits January 1, 2024 21:24
Assorted intrinsics are currently using i16 in place of a proper
bfloat type, but they should really switch to bfloat.

Note this only changes the type lists in tablegen, these are still
not registered to be truly treated as a legal type yet.
There are some intrinsics are using i16 vectors in place of bfloat vectors.
Move towards making bf16 vectors legal so these can migrate. Leave the
larger vectors for a later change.
Gets a few code quality improvements. A few cases are worse
from losing load narrowing.
@arsenm arsenm merged commit 4768563 into llvm:main Jan 5, 2024
@arsenm arsenm deleted the bf16/legal-v4bf16 branch January 5, 2024 01:35
arsenm added a commit that referenced this pull request Jan 8, 2024
justinfargnoli pushed a commit to justinfargnoli/llvm-project that referenced this pull request Jan 28, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
backend:AMDGPU llvm:SelectionDAG SelectionDAGISel as well
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants