Skip to content

[HLSL] Adding Flatten and Branch if attributes with test fixes #122157

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 13, 2025

Conversation

joaosaffran
Copy link
Contributor

@llvmbot llvmbot added clang Clang issues not falling into any other category clang:frontend Language frontend issues, e.g. anything involving "Sema" clang:codegen IR generation bugs: mangling, exceptions, etc. backend:DirectX HLSL HLSL Language Support backend:SPIR-V llvm:ir labels Jan 8, 2025
@llvmbot
Copy link
Member

llvmbot commented Jan 8, 2025

@llvm/pr-subscribers-clang
@llvm/pr-subscribers-backend-directx

@llvm/pr-subscribers-hlsl

Author: None (joaosaffran)

Changes
  • Adding the changes from PRs:
    • #116331
    • #121852
  • Fixes test tools/dxil-dis/debug-info.ll
  • Address some missed comments in the previous PR

Patch is 29.51 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/122157.diff

14 Files Affected:

  • (modified) clang/include/clang/Basic/Attr.td (+10)
  • (modified) clang/lib/CodeGen/CGStmt.cpp (+6)
  • (modified) clang/lib/CodeGen/CodeGenFunction.cpp (+25-1)
  • (modified) clang/lib/CodeGen/CodeGenFunction.h (+4)
  • (modified) clang/lib/Sema/SemaStmtAttr.cpp (+8)
  • (added) clang/test/AST/HLSL/HLSLControlFlowHint.hlsl (+43)
  • (added) clang/test/CodeGenHLSL/HLSLControlFlowHint.hlsl (+48)
  • (modified) llvm/include/llvm/IR/IntrinsicsSPIRV.td (+1-1)
  • (modified) llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp (+36)
  • (modified) llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp (+23-6)
  • (modified) llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp (+33-11)
  • (added) llvm/test/CodeGen/DirectX/HLSLControlFlowHint.ll (+98)
  • (added) llvm/test/CodeGen/SPIRV/structurizer/HLSLControlFlowHint-pass-check.ll (+90)
  • (added) llvm/test/CodeGen/SPIRV/structurizer/HLSLControlFlowHint.ll (+91)
diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td
index 12faf06597008e..6d7f65ab2c6135 100644
--- a/clang/include/clang/Basic/Attr.td
+++ b/clang/include/clang/Basic/Attr.td
@@ -4335,6 +4335,16 @@ def HLSLLoopHint: StmtAttr {
   let Documentation = [HLSLLoopHintDocs, HLSLUnrollHintDocs];
 }
 
+def HLSLControlFlowHint: StmtAttr {
+  /// [branch]
+  /// [flatten]
+  let Spellings = [Microsoft<"branch">, Microsoft<"flatten">];
+  let Subjects = SubjectList<[IfStmt],
+                              ErrorDiag, "'if' statements">;
+  let LangOpts = [HLSL];
+  let Documentation = [InternalOnly];
+}
+
 def CapturedRecord : InheritableAttr {
   // This attribute has no spellings as it is only ever created implicitly.
   let Spellings = [];
diff --git a/clang/lib/CodeGen/CGStmt.cpp b/clang/lib/CodeGen/CGStmt.cpp
index a87c50b8a1cbbf..c8ff48fc733125 100644
--- a/clang/lib/CodeGen/CGStmt.cpp
+++ b/clang/lib/CodeGen/CGStmt.cpp
@@ -757,6 +757,8 @@ void CodeGenFunction::EmitAttributedStmt(const AttributedStmt &S) {
   bool noinline = false;
   bool alwaysinline = false;
   bool noconvergent = false;
+  HLSLControlFlowHintAttr::Spelling flattenOrBranch =
+      HLSLControlFlowHintAttr::SpellingNotCalculated;
   const CallExpr *musttail = nullptr;
 
   for (const auto *A : S.getAttrs()) {
@@ -788,6 +790,9 @@ void CodeGenFunction::EmitAttributedStmt(const AttributedStmt &S) {
         Builder.CreateAssumption(AssumptionVal);
       }
     } break;
+    case attr::HLSLControlFlowHint: {
+      flattenOrBranch = cast<HLSLControlFlowHintAttr>(A)->getSemanticSpelling();
+    } break;
     }
   }
   SaveAndRestore save_nomerge(InNoMergeAttributedStmt, nomerge);
@@ -795,6 +800,7 @@ void CodeGenFunction::EmitAttributedStmt(const AttributedStmt &S) {
   SaveAndRestore save_alwaysinline(InAlwaysInlineAttributedStmt, alwaysinline);
   SaveAndRestore save_noconvergent(InNoConvergentAttributedStmt, noconvergent);
   SaveAndRestore save_musttail(MustTailCall, musttail);
+  SaveAndRestore save_flattenOrBranch(HLSLControlFlowAttr, flattenOrBranch);
   EmitStmt(S.getSubStmt(), S.getAttrs());
 }
 
diff --git a/clang/lib/CodeGen/CodeGenFunction.cpp b/clang/lib/CodeGen/CodeGenFunction.cpp
index af58fa64f86585..176e8386d933d6 100644
--- a/clang/lib/CodeGen/CodeGenFunction.cpp
+++ b/clang/lib/CodeGen/CodeGenFunction.cpp
@@ -40,6 +40,7 @@
 #include "llvm/IR/DataLayout.h"
 #include "llvm/IR/Dominators.h"
 #include "llvm/IR/FPEnv.h"
+#include "llvm/IR/Instruction.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/MDBuilder.h"
@@ -2083,7 +2084,30 @@ void CodeGenFunction::EmitBranchOnBoolExpr(
     Weights = createProfileWeights(TrueCount, CurrentCount - TrueCount);
   }
 
-  Builder.CreateCondBr(CondV, TrueBlock, FalseBlock, Weights, Unpredictable);
+  llvm::Instruction *BrInst = Builder.CreateCondBr(CondV, TrueBlock, FalseBlock,
+                                                   Weights, Unpredictable);
+  switch (HLSLControlFlowAttr) {
+  case HLSLControlFlowHintAttr::Microsoft_branch:
+  case HLSLControlFlowHintAttr::Microsoft_flatten: {
+    llvm::MDBuilder MDHelper(CGM.getLLVMContext());
+
+    llvm::ConstantInt *BranchHintConstant =
+        HLSLControlFlowAttr ==
+                HLSLControlFlowHintAttr::Spelling::Microsoft_branch
+            ? llvm::ConstantInt::get(CGM.Int32Ty, 1)
+            : llvm::ConstantInt::get(CGM.Int32Ty, 2);
+
+    SmallVector<llvm::Metadata *, 2> Vals(
+        {MDHelper.createString("hlsl.controlflow.hint"),
+         MDHelper.createConstant(BranchHintConstant)});
+    BrInst->setMetadata("hlsl.controlflow.hint",
+                        llvm::MDNode::get(CGM.getLLVMContext(), Vals));
+    break;
+  }
+  // This is required to avoid warnings during compilation
+  case HLSLControlFlowHintAttr::SpellingNotCalculated:
+    break;
+  }
 }
 
 /// ErrorUnsupported - Print out an error that codegen doesn't support the
diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
index f2240f8308ce38..bc612a0bfb32ba 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -615,6 +615,10 @@ class CodeGenFunction : public CodeGenTypeCache {
   /// True if the current statement has noconvergent attribute.
   bool InNoConvergentAttributedStmt = false;
 
+  /// HLSL Branch attribute.
+  HLSLControlFlowHintAttr::Spelling HLSLControlFlowAttr =
+      HLSLControlFlowHintAttr::SpellingNotCalculated;
+
   // The CallExpr within the current statement that the musttail attribute
   // applies to.  nullptr if there is no 'musttail' on the current statement.
   const CallExpr *MustTailCall = nullptr;
diff --git a/clang/lib/Sema/SemaStmtAttr.cpp b/clang/lib/Sema/SemaStmtAttr.cpp
index 106e2430de901e..422d8abc1028aa 100644
--- a/clang/lib/Sema/SemaStmtAttr.cpp
+++ b/clang/lib/Sema/SemaStmtAttr.cpp
@@ -619,6 +619,12 @@ static Attr *handleHLSLLoopHintAttr(Sema &S, Stmt *St, const ParsedAttr &A,
   return ::new (S.Context) HLSLLoopHintAttr(S.Context, A, UnrollFactor);
 }
 
+static Attr *handleHLSLControlFlowHint(Sema &S, Stmt *St, const ParsedAttr &A,
+                                       SourceRange Range) {
+
+  return ::new (S.Context) HLSLControlFlowHintAttr(S.Context, A);
+}
+
 static Attr *ProcessStmtAttribute(Sema &S, Stmt *St, const ParsedAttr &A,
                                   SourceRange Range) {
   if (A.isInvalid() || A.getKind() == ParsedAttr::IgnoredAttribute)
@@ -655,6 +661,8 @@ static Attr *ProcessStmtAttribute(Sema &S, Stmt *St, const ParsedAttr &A,
     return handleLoopHintAttr(S, St, A, Range);
   case ParsedAttr::AT_HLSLLoopHint:
     return handleHLSLLoopHintAttr(S, St, A, Range);
+  case ParsedAttr::AT_HLSLControlFlowHint:
+    return handleHLSLControlFlowHint(S, St, A, Range);
   case ParsedAttr::AT_OpenCLUnrollHint:
     return handleOpenCLUnrollHint(S, St, A, Range);
   case ParsedAttr::AT_Suppress:
diff --git a/clang/test/AST/HLSL/HLSLControlFlowHint.hlsl b/clang/test/AST/HLSL/HLSLControlFlowHint.hlsl
new file mode 100644
index 00000000000000..a36779c05fbc93
--- /dev/null
+++ b/clang/test/AST/HLSL/HLSLControlFlowHint.hlsl
@@ -0,0 +1,43 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-compute -ast-dump %s | FileCheck %s
+
+// CHECK: FunctionDecl 0x{{[0-9A-Fa-f]+}} <{{.*}}> {{.*}} used branch 'int (int)'
+// CHECK: AttributedStmt 0x{{[0-9A-Fa-f]+}} <<invalid sloc>
+// CHECK-NEXT: -HLSLControlFlowHintAttr 0x{{[0-9A-Fa-f]+}} <{{.*}}> branch
+export int branch(int X){
+    int resp;
+    [branch] if (X > 0) {
+        resp = -X;
+    } else {
+        resp = X * 2;
+    }
+
+    return resp;
+}
+
+// CHECK: FunctionDecl 0x{{[0-9A-Fa-f]+}} <{{.*}}> {{.*}} used flatten 'int (int)'
+// CHECK: AttributedStmt 0x{{[0-9A-Fa-f]+}} <<invalid sloc>
+// CHECK-NEXT: -HLSLControlFlowHintAttr 0x{{[0-9A-Fa-f]+}} <{{.*}}> flatten
+export int flatten(int X){
+    int resp;
+    [flatten] if (X > 0) {
+        resp = -X;
+    } else {
+        resp = X * 2;
+    }
+
+    return resp;
+}
+
+// CHECK: FunctionDecl 0x{{[0-9A-Fa-f]+}} <{{.*}}> {{.*}} used no_attr 'int (int)'
+// CHECK-NOT: AttributedStmt 0x{{[0-9A-Fa-f]+}} <<invalid sloc>
+// CHECK-NOT: -HLSLControlFlowHintAttr
+export int no_attr(int X){
+    int resp;
+    if (X > 0) {
+        resp = -X;
+    } else {
+        resp = X * 2;
+    }
+
+    return resp;
+}
diff --git a/clang/test/CodeGenHLSL/HLSLControlFlowHint.hlsl b/clang/test/CodeGenHLSL/HLSLControlFlowHint.hlsl
new file mode 100644
index 00000000000000..aa13b275818502
--- /dev/null
+++ b/clang/test/CodeGenHLSL/HLSLControlFlowHint.hlsl
@@ -0,0 +1,48 @@
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple dxil-pc-shadermodel6.3-library %s -fnative-half-type -emit-llvm -o - | FileCheck %s
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple spirv-vulkan-library %s -fnative-half-type -emit-llvm -o - | FileCheck %s
+
+// CHECK: define {{.*}} i32 {{.*}}test_branch{{.*}}(i32 {{.*}} [[VALD:%.*]])
+// CHECK: [[PARAM:%.*]] = load i32, ptr [[VALD]].addr, align 4
+// CHECK: [[CMP:%.*]] = icmp sgt i32 [[PARAM]], 0
+// CHECK: br i1 [[CMP]], label %if.then, label %if.else, !hlsl.controlflow.hint [[HINT_BRANCH:![0-9]+]]
+export int test_branch(int X){
+    int resp;
+    [branch] if (X > 0) {
+        resp = -X;
+    } else {
+        resp = X * 2;
+    }
+
+    return resp;
+}
+
+// CHECK: define {{.*}} i32 {{.*}}test_flatten{{.*}}(i32 {{.*}} [[VALD:%.*]])
+// CHECK: [[PARAM:%.*]] = load i32, ptr [[VALD]].addr, align 4
+// CHECK: [[CMP:%.*]] = icmp sgt i32 [[PARAM]], 0
+// CHECK: br i1 [[CMP]], label %if.then, label %if.else, !hlsl.controlflow.hint [[HINT_FLATTEN:![0-9]+]]
+export int test_flatten(int X){
+    int resp;
+    [flatten] if (X > 0) {
+        resp = -X;
+    } else {
+        resp = X * 2;
+    }
+
+    return resp;
+}
+
+// CHECK: define {{.*}} i32 {{.*}}test_no_attr{{.*}}(i32 {{.*}} [[VALD:%.*]])
+// CHECK-NOT: !hlsl.controlflow.hint
+export int test_no_attr(int X){
+    int resp;
+    if (X > 0) {
+        resp = -X;
+    } else {
+        resp = X * 2;
+    }
+
+    return resp;
+}
+
+//CHECK: [[HINT_BRANCH]] = !{!"hlsl.controlflow.hint", i32 1}
+//CHECK: [[HINT_FLATTEN]] = !{!"hlsl.controlflow.hint", i32 2}
diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
index b4d2dce66a6f0b..37057271b6c284 100644
--- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td
+++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
@@ -33,7 +33,7 @@ let TargetPrefix = "spv" in {
   def int_spv_ptrcast : Intrinsic<[llvm_any_ty], [llvm_any_ty, llvm_metadata_ty, llvm_i32_ty], [ImmArg<ArgIndex<2>>]>;
   def int_spv_switch : Intrinsic<[], [llvm_any_ty, llvm_vararg_ty]>;
   def int_spv_loop_merge : Intrinsic<[], [llvm_vararg_ty]>;
-  def int_spv_selection_merge : Intrinsic<[], [llvm_vararg_ty]>;
+  def int_spv_selection_merge : Intrinsic<[], [llvm_any_ty, llvm_i32_ty], [ImmArg<ArgIndex<1>>]>;
   def int_spv_cmpxchg : Intrinsic<[llvm_i32_ty], [llvm_any_ty, llvm_vararg_ty]>;
   def int_spv_unreachable : Intrinsic<[], []>;
   def int_spv_alloca : Intrinsic<[llvm_any_ty], [llvm_i8_ty], [ImmArg<ArgIndex<0>>]>;
diff --git a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
index 5afe6b2d2883db..5fd5c226eef894 100644
--- a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
+++ b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
@@ -15,12 +15,14 @@
 #include "llvm/ADT/Twine.h"
 #include "llvm/Analysis/DXILMetadataAnalysis.h"
 #include "llvm/Analysis/DXILResource.h"
+#include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/DiagnosticInfo.h"
 #include "llvm/IR/DiagnosticPrinter.h"
 #include "llvm/IR/Function.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/Metadata.h"
 #include "llvm/IR/Module.h"
 #include "llvm/InitializePasses.h"
@@ -300,6 +302,38 @@ static MDTuple *emitTopLevelLibraryNode(Module &M, MDNode *RMD,
   return constructEntryMetadata(nullptr, nullptr, RMD, Properties, Ctx);
 }
 
+// TODO: We might need to refactor this to be more generic,
+// in case we need more metadata to be replaced.
+static void translateBranchMetadata(Module &M) {
+  for (Function &F : M) {
+    for (BasicBlock &BB : F) {
+      Instruction *BBTerminatorInst = BB.getTerminator();
+
+      MDNode *HlslControlFlowMD =
+          BBTerminatorInst->getMetadata("hlsl.controlflow.hint");
+
+      if (!HlslControlFlowMD)
+        continue;
+
+      assert(HlslControlFlowMD->getNumOperands() == 2 &&
+             "invalid operands for hlsl.controlflow.hint");
+
+      MDBuilder MDHelper(M.getContext());
+      ConstantInt *Op1 =
+          mdconst::extract<ConstantInt>(HlslControlFlowMD->getOperand(1));
+
+      SmallVector<llvm::Metadata *, 2> Vals(
+          ArrayRef<Metadata *>{MDHelper.createString("dx.controlflow.hints"),
+                               MDHelper.createConstant(Op1)});
+
+      MDNode *MDNode = llvm::MDNode::get(M.getContext(), Vals);
+
+      BBTerminatorInst->setMetadata("dx.controlflow.hints", MDNode);
+      BBTerminatorInst->setMetadata("hlsl.controlflow.hint", nullptr);
+    }
+  }
+}
+
 static void translateMetadata(Module &M, DXILBindingMap &DBM,
                               DXILResourceTypeMap &DRTM,
                               const Resources &MDResources,
@@ -372,6 +406,7 @@ PreservedAnalyses DXILTranslateMetadata::run(Module &M,
   const dxil::ModuleMetadataInfo MMDI = MAM.getResult<DXILMetadataAnalysis>(M);
 
   translateMetadata(M, DBM, DRTM, MDResources, ShaderFlags, MMDI);
+  translateBranchMetadata(M);
 
   return PreservedAnalyses::all();
 }
@@ -409,6 +444,7 @@ class DXILTranslateMetadataLegacy : public ModulePass {
         getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata();
 
     translateMetadata(M, DBM, DRTM, MDResources, ShaderFlags, MMDI);
+    translateBranchMetadata(M);
     return true;
   }
 };
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 28c9b81db51f51..237f71a1b70e50 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -33,6 +33,7 @@
 #include "llvm/CodeGen/TargetOpcodes.h"
 #include "llvm/IR/IntrinsicsSPIRV.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/ErrorHandling.h"
 
 #define DEBUG_TYPE "spirv-isel"
 
@@ -45,6 +46,17 @@ using ExtInstList =
 
 namespace {
 
+llvm::SPIRV::SelectionControl::SelectionControl
+getSelectionOperandForImm(int Imm) {
+  if (Imm == 2)
+    return SPIRV::SelectionControl::Flatten;
+  if (Imm == 1)
+    return SPIRV::SelectionControl::DontFlatten;
+  if (Imm == 0)
+    return SPIRV::SelectionControl::None;
+  llvm_unreachable("Invalid immediate");
+}
+
 #define GET_GLOBALISEL_PREDICATE_BITSET
 #include "SPIRVGenGlobalISel.inc"
 #undef GET_GLOBALISEL_PREDICATE_BITSET
@@ -2818,12 +2830,8 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
     }
     return MIB.constrainAllUses(TII, TRI, RBI);
   }
-  case Intrinsic::spv_loop_merge:
-  case Intrinsic::spv_selection_merge: {
-    const auto Opcode = IID == Intrinsic::spv_selection_merge
-                            ? SPIRV::OpSelectionMerge
-                            : SPIRV::OpLoopMerge;
-    auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(Opcode));
+  case Intrinsic::spv_loop_merge: {
+    auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpLoopMerge));
     for (unsigned i = 1; i < I.getNumExplicitOperands(); ++i) {
       assert(I.getOperand(i).isMBB());
       MIB.addMBB(I.getOperand(i).getMBB());
@@ -2831,6 +2839,15 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
     MIB.addImm(SPIRV::SelectionControl::None);
     return MIB.constrainAllUses(TII, TRI, RBI);
   }
+  case Intrinsic::spv_selection_merge: {
+    auto MIB =
+        BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpSelectionMerge));
+    assert(I.getOperand(1).isMBB() &&
+           "operand 1 to spv_selection_merge must be a basic block");
+    MIB.addMBB(I.getOperand(1).getMBB());
+    MIB.addImm(getSelectionOperandForImm(I.getOperand(2).getImm()));
+    return MIB.constrainAllUses(TII, TRI, RBI);
+  }
   case Intrinsic::spv_cmpxchg:
     return selectAtomicCmpXchg(ResVReg, ResType, I);
   case Intrinsic::spv_unreachable:
diff --git a/llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp b/llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp
index 336cde4e782246..2e4343c7922f1c 100644
--- a/llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp
@@ -18,14 +18,16 @@
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/Analysis/LoopInfo.h"
 #include "llvm/CodeGen/IntrinsicLowering.h"
-#include "llvm/IR/Analysis.h"
 #include "llvm/IR/CFG.h"
 #include "llvm/IR/Dominators.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/IntrinsicsSPIRV.h"
+#include "llvm/IR/LegacyPassManager.h"
 #include "llvm/InitializePasses.h"
+#include "llvm/PassRegistry.h"
+#include "llvm/Transforms/Utils.h"
 #include "llvm/Transforms/Utils/Cloning.h"
 #include "llvm/Transforms/Utils/LoopSimplify.h"
 #include "llvm/Transforms/Utils/LowerMemIntrinsics.h"
@@ -646,8 +648,7 @@ class SPIRVStructurizer : public FunctionPass {
       Builder.SetInsertPoint(Header->getTerminator());
 
       auto MergeAddress = BlockAddress::get(BB.getParent(), &BB);
-      SmallVector<Value *, 1> Args = {MergeAddress};
-      Builder.CreateIntrinsic(Intrinsic::spv_selection_merge, {}, {Args});
+      createOpSelectMerge(&Builder, MergeAddress);
 
       Modified = true;
     }
@@ -769,10 +770,9 @@ class SPIRVStructurizer : public FunctionPass {
       BasicBlock *Merge = Candidates[0];
 
       auto MergeAddress = BlockAddress::get(Merge->getParent(), Merge);
-      SmallVector<Value *, 1> Args = {MergeAddress};
       IRBuilder<> Builder(&BB);
       Builder.SetInsertPoint(BB.getTerminator());
-      Builder.CreateIntrinsic(Intrinsic::spv_selection_merge, {}, {Args});
+      createOpSelectMerge(&Builder, MergeAddress);
     }
 
     return Modified;
@@ -1105,8 +1105,7 @@ class SPIRVStructurizer : public FunctionPass {
         Builder.SetInsertPoint(Header->getTerminator());
 
         auto MergeAddress = BlockAddress::get(Merge->getParent(), Merge);
-        SmallVector<Value *, 1> Args = {MergeAddress};
-        Builder.CreateIntrinsic(Intrinsic::spv_selection_merge, {}, {Args});
+        createOpSelectMerge(&Builder, MergeAddress);
         continue;
       }
 
@@ -1120,8 +1119,7 @@ class SPIRVStructurizer : public FunctionPass {
       Builder.SetInsertPoint(Header->getTerminator());
 
       auto MergeAddress = BlockAddress::get(NewMerge->getParent(), NewMerge);
-      SmallVector<Value *, 1> Args = {MergeAddress};
-      Builder.CreateIntrinsic(Intrinsic::spv_selection_merge, {}, {Args});
+      createOpSelectMerge(&Builder, MergeAddress);
     }
 
     return Modified;
@@ -1208,6 +1206,27 @@ class SPIRVStructurizer : public FunctionPass {
     AU.addPreserved<SPIRVConvergenceRegionAnalysisWrapperPass>();
     FunctionPass::getAnalysisUsage(AU);
   }
+
+  void createOpSelectMerge(IRBuilder<> *Builder, BlockAddress *MergeAddress) {
+    Instruction *BBTerminatorInst = Builder->GetInsertBlock()->getTerminator();
+
+    MDNode *MDNode = BBTerminatorInst->getMetadata("hlsl.controlflow.hint");
+
+    ConstantInt *BranchHint = llvm::ConstantInt::get(Builder->getInt32Ty(), 0);
+
+    if (MDNode) {
+      assert(MDNode->getNumOperands() == 2 &&
+             "invalid metadata hlsl.controlflow.hint");
+      BranchHint = mdconst::extract<ConstantInt>(MDNode->getOperand(1));
+
+      assert(BranchHint && "invalid metadata value for hlsl.controlflow.hint");
+    }
+
+    llvm::SmallVector<llvm::Value *, 2> Args = {MergeAddress, BranchHint};
+
+    Builder->CreateIntrinsic(Intrinsic::spv_selection_merge,
+                             {MergeAddress->getType()}, {Args});
+  }
 };
 } // namespace llvm
 
@@ -1229,8 +1248,11 @@ FunctionPass *llvm::createSPIRVStructurizerPass() {
 
 PreservedAnalyses SPIRVStructurizerWrapper::run(Function &F,
                                                 FunctionAnalysisManager &AF) {
-  FunctionPass *StructurizerPass = createSPIRVStructurizerPass();
-  if (!StructurizerPass->runOnFunction(F))
+
+  auto FPM = legacy::FunctionPassManager(F.getParent());
+  FPM.add(createSPIRVStructurizerPass());
+
+  if (!FPM.run(F))
     return PreservedAnalyses::all();
   PreservedAnalyses PA;
   PA.preserveSet<CFGAnalyses>();
diff --git a/llvm/test/CodeGen/DirectX/HLSLControlFlowHint.ll b/llvm/test/CodeGen/DirectX/HLSLControlFlowHint.ll
new file mode 100644
index 00000000000000..6a5274429930ea
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/HLSLControlFlowHint.ll
@@ -0,0 +1,98 @@
+; RUN: opt -S -dxil-op-lower -dxil-translate-metadata -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
+
+; This test make sure LLVM metadata is being translated into DXIL.
+
+
+; CHECK: define i32 @test_branch(i32 %X)
+; CHECK-NOT: hlsl.controlflow.hint
+; CHECK: br i1 %cmp, label %if.then, label %if.else, !dx.controlflow.hints [[HINT_BRANCH:![0-9]+]]
+define i32 @test_branch(i32 %X) {
+entry:
+  %X.addr = alloca i32, align 4
+  %resp = alloca i32, align 4
+  store i32 %X, ptr %X.addr, align 4
+  %0 = load i32, ptr %X.addr, align 4
+  %cmp = icmp sgt i32 %0, 0
+  br i1 ...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Jan 8, 2025

@llvm/pr-subscribers-backend-spir-v

Author: None (joaosaffran)

Changes
  • Adding the changes from PRs:
    • #116331
    • #121852
  • Fixes test tools/dxil-dis/debug-info.ll
  • Address some missed comments in the previous PR

Patch is 29.51 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/122157.diff

14 Files Affected:

  • (modified) clang/include/clang/Basic/Attr.td (+10)
  • (modified) clang/lib/CodeGen/CGStmt.cpp (+6)
  • (modified) clang/lib/CodeGen/CodeGenFunction.cpp (+25-1)
  • (modified) clang/lib/CodeGen/CodeGenFunction.h (+4)
  • (modified) clang/lib/Sema/SemaStmtAttr.cpp (+8)
  • (added) clang/test/AST/HLSL/HLSLControlFlowHint.hlsl (+43)
  • (added) clang/test/CodeGenHLSL/HLSLControlFlowHint.hlsl (+48)
  • (modified) llvm/include/llvm/IR/IntrinsicsSPIRV.td (+1-1)
  • (modified) llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp (+36)
  • (modified) llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp (+23-6)
  • (modified) llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp (+33-11)
  • (added) llvm/test/CodeGen/DirectX/HLSLControlFlowHint.ll (+98)
  • (added) llvm/test/CodeGen/SPIRV/structurizer/HLSLControlFlowHint-pass-check.ll (+90)
  • (added) llvm/test/CodeGen/SPIRV/structurizer/HLSLControlFlowHint.ll (+91)
diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td
index 12faf06597008e..6d7f65ab2c6135 100644
--- a/clang/include/clang/Basic/Attr.td
+++ b/clang/include/clang/Basic/Attr.td
@@ -4335,6 +4335,16 @@ def HLSLLoopHint: StmtAttr {
   let Documentation = [HLSLLoopHintDocs, HLSLUnrollHintDocs];
 }
 
+def HLSLControlFlowHint: StmtAttr {
+  /// [branch]
+  /// [flatten]
+  let Spellings = [Microsoft<"branch">, Microsoft<"flatten">];
+  let Subjects = SubjectList<[IfStmt],
+                              ErrorDiag, "'if' statements">;
+  let LangOpts = [HLSL];
+  let Documentation = [InternalOnly];
+}
+
 def CapturedRecord : InheritableAttr {
   // This attribute has no spellings as it is only ever created implicitly.
   let Spellings = [];
diff --git a/clang/lib/CodeGen/CGStmt.cpp b/clang/lib/CodeGen/CGStmt.cpp
index a87c50b8a1cbbf..c8ff48fc733125 100644
--- a/clang/lib/CodeGen/CGStmt.cpp
+++ b/clang/lib/CodeGen/CGStmt.cpp
@@ -757,6 +757,8 @@ void CodeGenFunction::EmitAttributedStmt(const AttributedStmt &S) {
   bool noinline = false;
   bool alwaysinline = false;
   bool noconvergent = false;
+  HLSLControlFlowHintAttr::Spelling flattenOrBranch =
+      HLSLControlFlowHintAttr::SpellingNotCalculated;
   const CallExpr *musttail = nullptr;
 
   for (const auto *A : S.getAttrs()) {
@@ -788,6 +790,9 @@ void CodeGenFunction::EmitAttributedStmt(const AttributedStmt &S) {
         Builder.CreateAssumption(AssumptionVal);
       }
     } break;
+    case attr::HLSLControlFlowHint: {
+      flattenOrBranch = cast<HLSLControlFlowHintAttr>(A)->getSemanticSpelling();
+    } break;
     }
   }
   SaveAndRestore save_nomerge(InNoMergeAttributedStmt, nomerge);
@@ -795,6 +800,7 @@ void CodeGenFunction::EmitAttributedStmt(const AttributedStmt &S) {
   SaveAndRestore save_alwaysinline(InAlwaysInlineAttributedStmt, alwaysinline);
   SaveAndRestore save_noconvergent(InNoConvergentAttributedStmt, noconvergent);
   SaveAndRestore save_musttail(MustTailCall, musttail);
+  SaveAndRestore save_flattenOrBranch(HLSLControlFlowAttr, flattenOrBranch);
   EmitStmt(S.getSubStmt(), S.getAttrs());
 }
 
diff --git a/clang/lib/CodeGen/CodeGenFunction.cpp b/clang/lib/CodeGen/CodeGenFunction.cpp
index af58fa64f86585..176e8386d933d6 100644
--- a/clang/lib/CodeGen/CodeGenFunction.cpp
+++ b/clang/lib/CodeGen/CodeGenFunction.cpp
@@ -40,6 +40,7 @@
 #include "llvm/IR/DataLayout.h"
 #include "llvm/IR/Dominators.h"
 #include "llvm/IR/FPEnv.h"
+#include "llvm/IR/Instruction.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/MDBuilder.h"
@@ -2083,7 +2084,30 @@ void CodeGenFunction::EmitBranchOnBoolExpr(
     Weights = createProfileWeights(TrueCount, CurrentCount - TrueCount);
   }
 
-  Builder.CreateCondBr(CondV, TrueBlock, FalseBlock, Weights, Unpredictable);
+  llvm::Instruction *BrInst = Builder.CreateCondBr(CondV, TrueBlock, FalseBlock,
+                                                   Weights, Unpredictable);
+  switch (HLSLControlFlowAttr) {
+  case HLSLControlFlowHintAttr::Microsoft_branch:
+  case HLSLControlFlowHintAttr::Microsoft_flatten: {
+    llvm::MDBuilder MDHelper(CGM.getLLVMContext());
+
+    llvm::ConstantInt *BranchHintConstant =
+        HLSLControlFlowAttr ==
+                HLSLControlFlowHintAttr::Spelling::Microsoft_branch
+            ? llvm::ConstantInt::get(CGM.Int32Ty, 1)
+            : llvm::ConstantInt::get(CGM.Int32Ty, 2);
+
+    SmallVector<llvm::Metadata *, 2> Vals(
+        {MDHelper.createString("hlsl.controlflow.hint"),
+         MDHelper.createConstant(BranchHintConstant)});
+    BrInst->setMetadata("hlsl.controlflow.hint",
+                        llvm::MDNode::get(CGM.getLLVMContext(), Vals));
+    break;
+  }
+  // This is required to avoid warnings during compilation
+  case HLSLControlFlowHintAttr::SpellingNotCalculated:
+    break;
+  }
 }
 
 /// ErrorUnsupported - Print out an error that codegen doesn't support the
diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
index f2240f8308ce38..bc612a0bfb32ba 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -615,6 +615,10 @@ class CodeGenFunction : public CodeGenTypeCache {
   /// True if the current statement has noconvergent attribute.
   bool InNoConvergentAttributedStmt = false;
 
+  /// HLSL Branch attribute.
+  HLSLControlFlowHintAttr::Spelling HLSLControlFlowAttr =
+      HLSLControlFlowHintAttr::SpellingNotCalculated;
+
   // The CallExpr within the current statement that the musttail attribute
   // applies to.  nullptr if there is no 'musttail' on the current statement.
   const CallExpr *MustTailCall = nullptr;
diff --git a/clang/lib/Sema/SemaStmtAttr.cpp b/clang/lib/Sema/SemaStmtAttr.cpp
index 106e2430de901e..422d8abc1028aa 100644
--- a/clang/lib/Sema/SemaStmtAttr.cpp
+++ b/clang/lib/Sema/SemaStmtAttr.cpp
@@ -619,6 +619,12 @@ static Attr *handleHLSLLoopHintAttr(Sema &S, Stmt *St, const ParsedAttr &A,
   return ::new (S.Context) HLSLLoopHintAttr(S.Context, A, UnrollFactor);
 }
 
+static Attr *handleHLSLControlFlowHint(Sema &S, Stmt *St, const ParsedAttr &A,
+                                       SourceRange Range) {
+
+  return ::new (S.Context) HLSLControlFlowHintAttr(S.Context, A);
+}
+
 static Attr *ProcessStmtAttribute(Sema &S, Stmt *St, const ParsedAttr &A,
                                   SourceRange Range) {
   if (A.isInvalid() || A.getKind() == ParsedAttr::IgnoredAttribute)
@@ -655,6 +661,8 @@ static Attr *ProcessStmtAttribute(Sema &S, Stmt *St, const ParsedAttr &A,
     return handleLoopHintAttr(S, St, A, Range);
   case ParsedAttr::AT_HLSLLoopHint:
     return handleHLSLLoopHintAttr(S, St, A, Range);
+  case ParsedAttr::AT_HLSLControlFlowHint:
+    return handleHLSLControlFlowHint(S, St, A, Range);
   case ParsedAttr::AT_OpenCLUnrollHint:
     return handleOpenCLUnrollHint(S, St, A, Range);
   case ParsedAttr::AT_Suppress:
diff --git a/clang/test/AST/HLSL/HLSLControlFlowHint.hlsl b/clang/test/AST/HLSL/HLSLControlFlowHint.hlsl
new file mode 100644
index 00000000000000..a36779c05fbc93
--- /dev/null
+++ b/clang/test/AST/HLSL/HLSLControlFlowHint.hlsl
@@ -0,0 +1,43 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-compute -ast-dump %s | FileCheck %s
+
+// CHECK: FunctionDecl 0x{{[0-9A-Fa-f]+}} <{{.*}}> {{.*}} used branch 'int (int)'
+// CHECK: AttributedStmt 0x{{[0-9A-Fa-f]+}} <<invalid sloc>
+// CHECK-NEXT: -HLSLControlFlowHintAttr 0x{{[0-9A-Fa-f]+}} <{{.*}}> branch
+export int branch(int X){
+    int resp;
+    [branch] if (X > 0) {
+        resp = -X;
+    } else {
+        resp = X * 2;
+    }
+
+    return resp;
+}
+
+// CHECK: FunctionDecl 0x{{[0-9A-Fa-f]+}} <{{.*}}> {{.*}} used flatten 'int (int)'
+// CHECK: AttributedStmt 0x{{[0-9A-Fa-f]+}} <<invalid sloc>
+// CHECK-NEXT: -HLSLControlFlowHintAttr 0x{{[0-9A-Fa-f]+}} <{{.*}}> flatten
+export int flatten(int X){
+    int resp;
+    [flatten] if (X > 0) {
+        resp = -X;
+    } else {
+        resp = X * 2;
+    }
+
+    return resp;
+}
+
+// CHECK: FunctionDecl 0x{{[0-9A-Fa-f]+}} <{{.*}}> {{.*}} used no_attr 'int (int)'
+// CHECK-NOT: AttributedStmt 0x{{[0-9A-Fa-f]+}} <<invalid sloc>
+// CHECK-NOT: -HLSLControlFlowHintAttr
+export int no_attr(int X){
+    int resp;
+    if (X > 0) {
+        resp = -X;
+    } else {
+        resp = X * 2;
+    }
+
+    return resp;
+}
diff --git a/clang/test/CodeGenHLSL/HLSLControlFlowHint.hlsl b/clang/test/CodeGenHLSL/HLSLControlFlowHint.hlsl
new file mode 100644
index 00000000000000..aa13b275818502
--- /dev/null
+++ b/clang/test/CodeGenHLSL/HLSLControlFlowHint.hlsl
@@ -0,0 +1,48 @@
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple dxil-pc-shadermodel6.3-library %s -fnative-half-type -emit-llvm -o - | FileCheck %s
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple spirv-vulkan-library %s -fnative-half-type -emit-llvm -o - | FileCheck %s
+
+// CHECK: define {{.*}} i32 {{.*}}test_branch{{.*}}(i32 {{.*}} [[VALD:%.*]])
+// CHECK: [[PARAM:%.*]] = load i32, ptr [[VALD]].addr, align 4
+// CHECK: [[CMP:%.*]] = icmp sgt i32 [[PARAM]], 0
+// CHECK: br i1 [[CMP]], label %if.then, label %if.else, !hlsl.controlflow.hint [[HINT_BRANCH:![0-9]+]]
+export int test_branch(int X){
+    int resp;
+    [branch] if (X > 0) {
+        resp = -X;
+    } else {
+        resp = X * 2;
+    }
+
+    return resp;
+}
+
+// CHECK: define {{.*}} i32 {{.*}}test_flatten{{.*}}(i32 {{.*}} [[VALD:%.*]])
+// CHECK: [[PARAM:%.*]] = load i32, ptr [[VALD]].addr, align 4
+// CHECK: [[CMP:%.*]] = icmp sgt i32 [[PARAM]], 0
+// CHECK: br i1 [[CMP]], label %if.then, label %if.else, !hlsl.controlflow.hint [[HINT_FLATTEN:![0-9]+]]
+export int test_flatten(int X){
+    int resp;
+    [flatten] if (X > 0) {
+        resp = -X;
+    } else {
+        resp = X * 2;
+    }
+
+    return resp;
+}
+
+// CHECK: define {{.*}} i32 {{.*}}test_no_attr{{.*}}(i32 {{.*}} [[VALD:%.*]])
+// CHECK-NOT: !hlsl.controlflow.hint
+export int test_no_attr(int X){
+    int resp;
+    if (X > 0) {
+        resp = -X;
+    } else {
+        resp = X * 2;
+    }
+
+    return resp;
+}
+
+//CHECK: [[HINT_BRANCH]] = !{!"hlsl.controlflow.hint", i32 1}
+//CHECK: [[HINT_FLATTEN]] = !{!"hlsl.controlflow.hint", i32 2}
diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
index b4d2dce66a6f0b..37057271b6c284 100644
--- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td
+++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
@@ -33,7 +33,7 @@ let TargetPrefix = "spv" in {
   def int_spv_ptrcast : Intrinsic<[llvm_any_ty], [llvm_any_ty, llvm_metadata_ty, llvm_i32_ty], [ImmArg<ArgIndex<2>>]>;
   def int_spv_switch : Intrinsic<[], [llvm_any_ty, llvm_vararg_ty]>;
   def int_spv_loop_merge : Intrinsic<[], [llvm_vararg_ty]>;
-  def int_spv_selection_merge : Intrinsic<[], [llvm_vararg_ty]>;
+  def int_spv_selection_merge : Intrinsic<[], [llvm_any_ty, llvm_i32_ty], [ImmArg<ArgIndex<1>>]>;
   def int_spv_cmpxchg : Intrinsic<[llvm_i32_ty], [llvm_any_ty, llvm_vararg_ty]>;
   def int_spv_unreachable : Intrinsic<[], []>;
   def int_spv_alloca : Intrinsic<[llvm_any_ty], [llvm_i8_ty], [ImmArg<ArgIndex<0>>]>;
diff --git a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
index 5afe6b2d2883db..5fd5c226eef894 100644
--- a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
+++ b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
@@ -15,12 +15,14 @@
 #include "llvm/ADT/Twine.h"
 #include "llvm/Analysis/DXILMetadataAnalysis.h"
 #include "llvm/Analysis/DXILResource.h"
+#include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/DiagnosticInfo.h"
 #include "llvm/IR/DiagnosticPrinter.h"
 #include "llvm/IR/Function.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/Metadata.h"
 #include "llvm/IR/Module.h"
 #include "llvm/InitializePasses.h"
@@ -300,6 +302,38 @@ static MDTuple *emitTopLevelLibraryNode(Module &M, MDNode *RMD,
   return constructEntryMetadata(nullptr, nullptr, RMD, Properties, Ctx);
 }
 
+// TODO: We might need to refactor this to be more generic,
+// in case we need more metadata to be replaced.
+static void translateBranchMetadata(Module &M) {
+  for (Function &F : M) {
+    for (BasicBlock &BB : F) {
+      Instruction *BBTerminatorInst = BB.getTerminator();
+
+      MDNode *HlslControlFlowMD =
+          BBTerminatorInst->getMetadata("hlsl.controlflow.hint");
+
+      if (!HlslControlFlowMD)
+        continue;
+
+      assert(HlslControlFlowMD->getNumOperands() == 2 &&
+             "invalid operands for hlsl.controlflow.hint");
+
+      MDBuilder MDHelper(M.getContext());
+      ConstantInt *Op1 =
+          mdconst::extract<ConstantInt>(HlslControlFlowMD->getOperand(1));
+
+      SmallVector<llvm::Metadata *, 2> Vals(
+          ArrayRef<Metadata *>{MDHelper.createString("dx.controlflow.hints"),
+                               MDHelper.createConstant(Op1)});
+
+      MDNode *MDNode = llvm::MDNode::get(M.getContext(), Vals);
+
+      BBTerminatorInst->setMetadata("dx.controlflow.hints", MDNode);
+      BBTerminatorInst->setMetadata("hlsl.controlflow.hint", nullptr);
+    }
+  }
+}
+
 static void translateMetadata(Module &M, DXILBindingMap &DBM,
                               DXILResourceTypeMap &DRTM,
                               const Resources &MDResources,
@@ -372,6 +406,7 @@ PreservedAnalyses DXILTranslateMetadata::run(Module &M,
   const dxil::ModuleMetadataInfo MMDI = MAM.getResult<DXILMetadataAnalysis>(M);
 
   translateMetadata(M, DBM, DRTM, MDResources, ShaderFlags, MMDI);
+  translateBranchMetadata(M);
 
   return PreservedAnalyses::all();
 }
@@ -409,6 +444,7 @@ class DXILTranslateMetadataLegacy : public ModulePass {
         getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata();
 
     translateMetadata(M, DBM, DRTM, MDResources, ShaderFlags, MMDI);
+    translateBranchMetadata(M);
     return true;
   }
 };
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 28c9b81db51f51..237f71a1b70e50 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -33,6 +33,7 @@
 #include "llvm/CodeGen/TargetOpcodes.h"
 #include "llvm/IR/IntrinsicsSPIRV.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/ErrorHandling.h"
 
 #define DEBUG_TYPE "spirv-isel"
 
@@ -45,6 +46,17 @@ using ExtInstList =
 
 namespace {
 
+llvm::SPIRV::SelectionControl::SelectionControl
+getSelectionOperandForImm(int Imm) {
+  if (Imm == 2)
+    return SPIRV::SelectionControl::Flatten;
+  if (Imm == 1)
+    return SPIRV::SelectionControl::DontFlatten;
+  if (Imm == 0)
+    return SPIRV::SelectionControl::None;
+  llvm_unreachable("Invalid immediate");
+}
+
 #define GET_GLOBALISEL_PREDICATE_BITSET
 #include "SPIRVGenGlobalISel.inc"
 #undef GET_GLOBALISEL_PREDICATE_BITSET
@@ -2818,12 +2830,8 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
     }
     return MIB.constrainAllUses(TII, TRI, RBI);
   }
-  case Intrinsic::spv_loop_merge:
-  case Intrinsic::spv_selection_merge: {
-    const auto Opcode = IID == Intrinsic::spv_selection_merge
-                            ? SPIRV::OpSelectionMerge
-                            : SPIRV::OpLoopMerge;
-    auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(Opcode));
+  case Intrinsic::spv_loop_merge: {
+    auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpLoopMerge));
     for (unsigned i = 1; i < I.getNumExplicitOperands(); ++i) {
       assert(I.getOperand(i).isMBB());
       MIB.addMBB(I.getOperand(i).getMBB());
@@ -2831,6 +2839,15 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
     MIB.addImm(SPIRV::SelectionControl::None);
     return MIB.constrainAllUses(TII, TRI, RBI);
   }
+  case Intrinsic::spv_selection_merge: {
+    auto MIB =
+        BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpSelectionMerge));
+    assert(I.getOperand(1).isMBB() &&
+           "operand 1 to spv_selection_merge must be a basic block");
+    MIB.addMBB(I.getOperand(1).getMBB());
+    MIB.addImm(getSelectionOperandForImm(I.getOperand(2).getImm()));
+    return MIB.constrainAllUses(TII, TRI, RBI);
+  }
   case Intrinsic::spv_cmpxchg:
     return selectAtomicCmpXchg(ResVReg, ResType, I);
   case Intrinsic::spv_unreachable:
diff --git a/llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp b/llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp
index 336cde4e782246..2e4343c7922f1c 100644
--- a/llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp
@@ -18,14 +18,16 @@
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/Analysis/LoopInfo.h"
 #include "llvm/CodeGen/IntrinsicLowering.h"
-#include "llvm/IR/Analysis.h"
 #include "llvm/IR/CFG.h"
 #include "llvm/IR/Dominators.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/IntrinsicsSPIRV.h"
+#include "llvm/IR/LegacyPassManager.h"
 #include "llvm/InitializePasses.h"
+#include "llvm/PassRegistry.h"
+#include "llvm/Transforms/Utils.h"
 #include "llvm/Transforms/Utils/Cloning.h"
 #include "llvm/Transforms/Utils/LoopSimplify.h"
 #include "llvm/Transforms/Utils/LowerMemIntrinsics.h"
@@ -646,8 +648,7 @@ class SPIRVStructurizer : public FunctionPass {
       Builder.SetInsertPoint(Header->getTerminator());
 
       auto MergeAddress = BlockAddress::get(BB.getParent(), &BB);
-      SmallVector<Value *, 1> Args = {MergeAddress};
-      Builder.CreateIntrinsic(Intrinsic::spv_selection_merge, {}, {Args});
+      createOpSelectMerge(&Builder, MergeAddress);
 
       Modified = true;
     }
@@ -769,10 +770,9 @@ class SPIRVStructurizer : public FunctionPass {
       BasicBlock *Merge = Candidates[0];
 
       auto MergeAddress = BlockAddress::get(Merge->getParent(), Merge);
-      SmallVector<Value *, 1> Args = {MergeAddress};
       IRBuilder<> Builder(&BB);
       Builder.SetInsertPoint(BB.getTerminator());
-      Builder.CreateIntrinsic(Intrinsic::spv_selection_merge, {}, {Args});
+      createOpSelectMerge(&Builder, MergeAddress);
     }
 
     return Modified;
@@ -1105,8 +1105,7 @@ class SPIRVStructurizer : public FunctionPass {
         Builder.SetInsertPoint(Header->getTerminator());
 
         auto MergeAddress = BlockAddress::get(Merge->getParent(), Merge);
-        SmallVector<Value *, 1> Args = {MergeAddress};
-        Builder.CreateIntrinsic(Intrinsic::spv_selection_merge, {}, {Args});
+        createOpSelectMerge(&Builder, MergeAddress);
         continue;
       }
 
@@ -1120,8 +1119,7 @@ class SPIRVStructurizer : public FunctionPass {
       Builder.SetInsertPoint(Header->getTerminator());
 
       auto MergeAddress = BlockAddress::get(NewMerge->getParent(), NewMerge);
-      SmallVector<Value *, 1> Args = {MergeAddress};
-      Builder.CreateIntrinsic(Intrinsic::spv_selection_merge, {}, {Args});
+      createOpSelectMerge(&Builder, MergeAddress);
     }
 
     return Modified;
@@ -1208,6 +1206,27 @@ class SPIRVStructurizer : public FunctionPass {
     AU.addPreserved<SPIRVConvergenceRegionAnalysisWrapperPass>();
     FunctionPass::getAnalysisUsage(AU);
   }
+
+  void createOpSelectMerge(IRBuilder<> *Builder, BlockAddress *MergeAddress) {
+    Instruction *BBTerminatorInst = Builder->GetInsertBlock()->getTerminator();
+
+    MDNode *MDNode = BBTerminatorInst->getMetadata("hlsl.controlflow.hint");
+
+    ConstantInt *BranchHint = llvm::ConstantInt::get(Builder->getInt32Ty(), 0);
+
+    if (MDNode) {
+      assert(MDNode->getNumOperands() == 2 &&
+             "invalid metadata hlsl.controlflow.hint");
+      BranchHint = mdconst::extract<ConstantInt>(MDNode->getOperand(1));
+
+      assert(BranchHint && "invalid metadata value for hlsl.controlflow.hint");
+    }
+
+    llvm::SmallVector<llvm::Value *, 2> Args = {MergeAddress, BranchHint};
+
+    Builder->CreateIntrinsic(Intrinsic::spv_selection_merge,
+                             {MergeAddress->getType()}, {Args});
+  }
 };
 } // namespace llvm
 
@@ -1229,8 +1248,11 @@ FunctionPass *llvm::createSPIRVStructurizerPass() {
 
 PreservedAnalyses SPIRVStructurizerWrapper::run(Function &F,
                                                 FunctionAnalysisManager &AF) {
-  FunctionPass *StructurizerPass = createSPIRVStructurizerPass();
-  if (!StructurizerPass->runOnFunction(F))
+
+  auto FPM = legacy::FunctionPassManager(F.getParent());
+  FPM.add(createSPIRVStructurizerPass());
+
+  if (!FPM.run(F))
     return PreservedAnalyses::all();
   PreservedAnalyses PA;
   PA.preserveSet<CFGAnalyses>();
diff --git a/llvm/test/CodeGen/DirectX/HLSLControlFlowHint.ll b/llvm/test/CodeGen/DirectX/HLSLControlFlowHint.ll
new file mode 100644
index 00000000000000..6a5274429930ea
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/HLSLControlFlowHint.ll
@@ -0,0 +1,98 @@
+; RUN: opt -S -dxil-op-lower -dxil-translate-metadata -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
+
+; This test make sure LLVM metadata is being translated into DXIL.
+
+
+; CHECK: define i32 @test_branch(i32 %X)
+; CHECK-NOT: hlsl.controlflow.hint
+; CHECK: br i1 %cmp, label %if.then, label %if.else, !dx.controlflow.hints [[HINT_BRANCH:![0-9]+]]
+define i32 @test_branch(i32 %X) {
+entry:
+  %X.addr = alloca i32, align 4
+  %resp = alloca i32, align 4
+  store i32 %X, ptr %X.addr, align 4
+  %0 = load i32, ptr %X.addr, align 4
+  %cmp = icmp sgt i32 %0, 0
+  br i1 ...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Jan 8, 2025

@llvm/pr-subscribers-clang-codegen

Author: None (joaosaffran)

Changes
  • Adding the changes from PRs:
    • #116331
    • #121852
  • Fixes test tools/dxil-dis/debug-info.ll
  • Address some missed comments in the previous PR

Patch is 29.51 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/122157.diff

14 Files Affected:

  • (modified) clang/include/clang/Basic/Attr.td (+10)
  • (modified) clang/lib/CodeGen/CGStmt.cpp (+6)
  • (modified) clang/lib/CodeGen/CodeGenFunction.cpp (+25-1)
  • (modified) clang/lib/CodeGen/CodeGenFunction.h (+4)
  • (modified) clang/lib/Sema/SemaStmtAttr.cpp (+8)
  • (added) clang/test/AST/HLSL/HLSLControlFlowHint.hlsl (+43)
  • (added) clang/test/CodeGenHLSL/HLSLControlFlowHint.hlsl (+48)
  • (modified) llvm/include/llvm/IR/IntrinsicsSPIRV.td (+1-1)
  • (modified) llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp (+36)
  • (modified) llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp (+23-6)
  • (modified) llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp (+33-11)
  • (added) llvm/test/CodeGen/DirectX/HLSLControlFlowHint.ll (+98)
  • (added) llvm/test/CodeGen/SPIRV/structurizer/HLSLControlFlowHint-pass-check.ll (+90)
  • (added) llvm/test/CodeGen/SPIRV/structurizer/HLSLControlFlowHint.ll (+91)
diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td
index 12faf06597008e..6d7f65ab2c6135 100644
--- a/clang/include/clang/Basic/Attr.td
+++ b/clang/include/clang/Basic/Attr.td
@@ -4335,6 +4335,16 @@ def HLSLLoopHint: StmtAttr {
   let Documentation = [HLSLLoopHintDocs, HLSLUnrollHintDocs];
 }
 
+def HLSLControlFlowHint: StmtAttr {
+  /// [branch]
+  /// [flatten]
+  let Spellings = [Microsoft<"branch">, Microsoft<"flatten">];
+  let Subjects = SubjectList<[IfStmt],
+                              ErrorDiag, "'if' statements">;
+  let LangOpts = [HLSL];
+  let Documentation = [InternalOnly];
+}
+
 def CapturedRecord : InheritableAttr {
   // This attribute has no spellings as it is only ever created implicitly.
   let Spellings = [];
diff --git a/clang/lib/CodeGen/CGStmt.cpp b/clang/lib/CodeGen/CGStmt.cpp
index a87c50b8a1cbbf..c8ff48fc733125 100644
--- a/clang/lib/CodeGen/CGStmt.cpp
+++ b/clang/lib/CodeGen/CGStmt.cpp
@@ -757,6 +757,8 @@ void CodeGenFunction::EmitAttributedStmt(const AttributedStmt &S) {
   bool noinline = false;
   bool alwaysinline = false;
   bool noconvergent = false;
+  HLSLControlFlowHintAttr::Spelling flattenOrBranch =
+      HLSLControlFlowHintAttr::SpellingNotCalculated;
   const CallExpr *musttail = nullptr;
 
   for (const auto *A : S.getAttrs()) {
@@ -788,6 +790,9 @@ void CodeGenFunction::EmitAttributedStmt(const AttributedStmt &S) {
         Builder.CreateAssumption(AssumptionVal);
       }
     } break;
+    case attr::HLSLControlFlowHint: {
+      flattenOrBranch = cast<HLSLControlFlowHintAttr>(A)->getSemanticSpelling();
+    } break;
     }
   }
   SaveAndRestore save_nomerge(InNoMergeAttributedStmt, nomerge);
@@ -795,6 +800,7 @@ void CodeGenFunction::EmitAttributedStmt(const AttributedStmt &S) {
   SaveAndRestore save_alwaysinline(InAlwaysInlineAttributedStmt, alwaysinline);
   SaveAndRestore save_noconvergent(InNoConvergentAttributedStmt, noconvergent);
   SaveAndRestore save_musttail(MustTailCall, musttail);
+  SaveAndRestore save_flattenOrBranch(HLSLControlFlowAttr, flattenOrBranch);
   EmitStmt(S.getSubStmt(), S.getAttrs());
 }
 
diff --git a/clang/lib/CodeGen/CodeGenFunction.cpp b/clang/lib/CodeGen/CodeGenFunction.cpp
index af58fa64f86585..176e8386d933d6 100644
--- a/clang/lib/CodeGen/CodeGenFunction.cpp
+++ b/clang/lib/CodeGen/CodeGenFunction.cpp
@@ -40,6 +40,7 @@
 #include "llvm/IR/DataLayout.h"
 #include "llvm/IR/Dominators.h"
 #include "llvm/IR/FPEnv.h"
+#include "llvm/IR/Instruction.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/MDBuilder.h"
@@ -2083,7 +2084,30 @@ void CodeGenFunction::EmitBranchOnBoolExpr(
     Weights = createProfileWeights(TrueCount, CurrentCount - TrueCount);
   }
 
-  Builder.CreateCondBr(CondV, TrueBlock, FalseBlock, Weights, Unpredictable);
+  llvm::Instruction *BrInst = Builder.CreateCondBr(CondV, TrueBlock, FalseBlock,
+                                                   Weights, Unpredictable);
+  switch (HLSLControlFlowAttr) {
+  case HLSLControlFlowHintAttr::Microsoft_branch:
+  case HLSLControlFlowHintAttr::Microsoft_flatten: {
+    llvm::MDBuilder MDHelper(CGM.getLLVMContext());
+
+    llvm::ConstantInt *BranchHintConstant =
+        HLSLControlFlowAttr ==
+                HLSLControlFlowHintAttr::Spelling::Microsoft_branch
+            ? llvm::ConstantInt::get(CGM.Int32Ty, 1)
+            : llvm::ConstantInt::get(CGM.Int32Ty, 2);
+
+    SmallVector<llvm::Metadata *, 2> Vals(
+        {MDHelper.createString("hlsl.controlflow.hint"),
+         MDHelper.createConstant(BranchHintConstant)});
+    BrInst->setMetadata("hlsl.controlflow.hint",
+                        llvm::MDNode::get(CGM.getLLVMContext(), Vals));
+    break;
+  }
+  // This is required to avoid warnings during compilation
+  case HLSLControlFlowHintAttr::SpellingNotCalculated:
+    break;
+  }
 }
 
 /// ErrorUnsupported - Print out an error that codegen doesn't support the
diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
index f2240f8308ce38..bc612a0bfb32ba 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -615,6 +615,10 @@ class CodeGenFunction : public CodeGenTypeCache {
   /// True if the current statement has noconvergent attribute.
   bool InNoConvergentAttributedStmt = false;
 
+  /// HLSL Branch attribute.
+  HLSLControlFlowHintAttr::Spelling HLSLControlFlowAttr =
+      HLSLControlFlowHintAttr::SpellingNotCalculated;
+
   // The CallExpr within the current statement that the musttail attribute
   // applies to.  nullptr if there is no 'musttail' on the current statement.
   const CallExpr *MustTailCall = nullptr;
diff --git a/clang/lib/Sema/SemaStmtAttr.cpp b/clang/lib/Sema/SemaStmtAttr.cpp
index 106e2430de901e..422d8abc1028aa 100644
--- a/clang/lib/Sema/SemaStmtAttr.cpp
+++ b/clang/lib/Sema/SemaStmtAttr.cpp
@@ -619,6 +619,12 @@ static Attr *handleHLSLLoopHintAttr(Sema &S, Stmt *St, const ParsedAttr &A,
   return ::new (S.Context) HLSLLoopHintAttr(S.Context, A, UnrollFactor);
 }
 
+static Attr *handleHLSLControlFlowHint(Sema &S, Stmt *St, const ParsedAttr &A,
+                                       SourceRange Range) {
+
+  return ::new (S.Context) HLSLControlFlowHintAttr(S.Context, A);
+}
+
 static Attr *ProcessStmtAttribute(Sema &S, Stmt *St, const ParsedAttr &A,
                                   SourceRange Range) {
   if (A.isInvalid() || A.getKind() == ParsedAttr::IgnoredAttribute)
@@ -655,6 +661,8 @@ static Attr *ProcessStmtAttribute(Sema &S, Stmt *St, const ParsedAttr &A,
     return handleLoopHintAttr(S, St, A, Range);
   case ParsedAttr::AT_HLSLLoopHint:
     return handleHLSLLoopHintAttr(S, St, A, Range);
+  case ParsedAttr::AT_HLSLControlFlowHint:
+    return handleHLSLControlFlowHint(S, St, A, Range);
   case ParsedAttr::AT_OpenCLUnrollHint:
     return handleOpenCLUnrollHint(S, St, A, Range);
   case ParsedAttr::AT_Suppress:
diff --git a/clang/test/AST/HLSL/HLSLControlFlowHint.hlsl b/clang/test/AST/HLSL/HLSLControlFlowHint.hlsl
new file mode 100644
index 00000000000000..a36779c05fbc93
--- /dev/null
+++ b/clang/test/AST/HLSL/HLSLControlFlowHint.hlsl
@@ -0,0 +1,43 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-compute -ast-dump %s | FileCheck %s
+
+// CHECK: FunctionDecl 0x{{[0-9A-Fa-f]+}} <{{.*}}> {{.*}} used branch 'int (int)'
+// CHECK: AttributedStmt 0x{{[0-9A-Fa-f]+}} <<invalid sloc>
+// CHECK-NEXT: -HLSLControlFlowHintAttr 0x{{[0-9A-Fa-f]+}} <{{.*}}> branch
+export int branch(int X){
+    int resp;
+    [branch] if (X > 0) {
+        resp = -X;
+    } else {
+        resp = X * 2;
+    }
+
+    return resp;
+}
+
+// CHECK: FunctionDecl 0x{{[0-9A-Fa-f]+}} <{{.*}}> {{.*}} used flatten 'int (int)'
+// CHECK: AttributedStmt 0x{{[0-9A-Fa-f]+}} <<invalid sloc>
+// CHECK-NEXT: -HLSLControlFlowHintAttr 0x{{[0-9A-Fa-f]+}} <{{.*}}> flatten
+export int flatten(int X){
+    int resp;
+    [flatten] if (X > 0) {
+        resp = -X;
+    } else {
+        resp = X * 2;
+    }
+
+    return resp;
+}
+
+// CHECK: FunctionDecl 0x{{[0-9A-Fa-f]+}} <{{.*}}> {{.*}} used no_attr 'int (int)'
+// CHECK-NOT: AttributedStmt 0x{{[0-9A-Fa-f]+}} <<invalid sloc>
+// CHECK-NOT: -HLSLControlFlowHintAttr
+export int no_attr(int X){
+    int resp;
+    if (X > 0) {
+        resp = -X;
+    } else {
+        resp = X * 2;
+    }
+
+    return resp;
+}
diff --git a/clang/test/CodeGenHLSL/HLSLControlFlowHint.hlsl b/clang/test/CodeGenHLSL/HLSLControlFlowHint.hlsl
new file mode 100644
index 00000000000000..aa13b275818502
--- /dev/null
+++ b/clang/test/CodeGenHLSL/HLSLControlFlowHint.hlsl
@@ -0,0 +1,48 @@
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple dxil-pc-shadermodel6.3-library %s -fnative-half-type -emit-llvm -o - | FileCheck %s
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple spirv-vulkan-library %s -fnative-half-type -emit-llvm -o - | FileCheck %s
+
+// CHECK: define {{.*}} i32 {{.*}}test_branch{{.*}}(i32 {{.*}} [[VALD:%.*]])
+// CHECK: [[PARAM:%.*]] = load i32, ptr [[VALD]].addr, align 4
+// CHECK: [[CMP:%.*]] = icmp sgt i32 [[PARAM]], 0
+// CHECK: br i1 [[CMP]], label %if.then, label %if.else, !hlsl.controlflow.hint [[HINT_BRANCH:![0-9]+]]
+export int test_branch(int X){
+    int resp;
+    [branch] if (X > 0) {
+        resp = -X;
+    } else {
+        resp = X * 2;
+    }
+
+    return resp;
+}
+
+// CHECK: define {{.*}} i32 {{.*}}test_flatten{{.*}}(i32 {{.*}} [[VALD:%.*]])
+// CHECK: [[PARAM:%.*]] = load i32, ptr [[VALD]].addr, align 4
+// CHECK: [[CMP:%.*]] = icmp sgt i32 [[PARAM]], 0
+// CHECK: br i1 [[CMP]], label %if.then, label %if.else, !hlsl.controlflow.hint [[HINT_FLATTEN:![0-9]+]]
+export int test_flatten(int X){
+    int resp;
+    [flatten] if (X > 0) {
+        resp = -X;
+    } else {
+        resp = X * 2;
+    }
+
+    return resp;
+}
+
+// CHECK: define {{.*}} i32 {{.*}}test_no_attr{{.*}}(i32 {{.*}} [[VALD:%.*]])
+// CHECK-NOT: !hlsl.controlflow.hint
+export int test_no_attr(int X){
+    int resp;
+    if (X > 0) {
+        resp = -X;
+    } else {
+        resp = X * 2;
+    }
+
+    return resp;
+}
+
+//CHECK: [[HINT_BRANCH]] = !{!"hlsl.controlflow.hint", i32 1}
+//CHECK: [[HINT_FLATTEN]] = !{!"hlsl.controlflow.hint", i32 2}
diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
index b4d2dce66a6f0b..37057271b6c284 100644
--- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td
+++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
@@ -33,7 +33,7 @@ let TargetPrefix = "spv" in {
   def int_spv_ptrcast : Intrinsic<[llvm_any_ty], [llvm_any_ty, llvm_metadata_ty, llvm_i32_ty], [ImmArg<ArgIndex<2>>]>;
   def int_spv_switch : Intrinsic<[], [llvm_any_ty, llvm_vararg_ty]>;
   def int_spv_loop_merge : Intrinsic<[], [llvm_vararg_ty]>;
-  def int_spv_selection_merge : Intrinsic<[], [llvm_vararg_ty]>;
+  def int_spv_selection_merge : Intrinsic<[], [llvm_any_ty, llvm_i32_ty], [ImmArg<ArgIndex<1>>]>;
   def int_spv_cmpxchg : Intrinsic<[llvm_i32_ty], [llvm_any_ty, llvm_vararg_ty]>;
   def int_spv_unreachable : Intrinsic<[], []>;
   def int_spv_alloca : Intrinsic<[llvm_any_ty], [llvm_i8_ty], [ImmArg<ArgIndex<0>>]>;
diff --git a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
index 5afe6b2d2883db..5fd5c226eef894 100644
--- a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
+++ b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
@@ -15,12 +15,14 @@
 #include "llvm/ADT/Twine.h"
 #include "llvm/Analysis/DXILMetadataAnalysis.h"
 #include "llvm/Analysis/DXILResource.h"
+#include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/DiagnosticInfo.h"
 #include "llvm/IR/DiagnosticPrinter.h"
 #include "llvm/IR/Function.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/Metadata.h"
 #include "llvm/IR/Module.h"
 #include "llvm/InitializePasses.h"
@@ -300,6 +302,38 @@ static MDTuple *emitTopLevelLibraryNode(Module &M, MDNode *RMD,
   return constructEntryMetadata(nullptr, nullptr, RMD, Properties, Ctx);
 }
 
+// TODO: We might need to refactor this to be more generic,
+// in case we need more metadata to be replaced.
+static void translateBranchMetadata(Module &M) {
+  for (Function &F : M) {
+    for (BasicBlock &BB : F) {
+      Instruction *BBTerminatorInst = BB.getTerminator();
+
+      MDNode *HlslControlFlowMD =
+          BBTerminatorInst->getMetadata("hlsl.controlflow.hint");
+
+      if (!HlslControlFlowMD)
+        continue;
+
+      assert(HlslControlFlowMD->getNumOperands() == 2 &&
+             "invalid operands for hlsl.controlflow.hint");
+
+      MDBuilder MDHelper(M.getContext());
+      ConstantInt *Op1 =
+          mdconst::extract<ConstantInt>(HlslControlFlowMD->getOperand(1));
+
+      SmallVector<llvm::Metadata *, 2> Vals(
+          ArrayRef<Metadata *>{MDHelper.createString("dx.controlflow.hints"),
+                               MDHelper.createConstant(Op1)});
+
+      MDNode *MDNode = llvm::MDNode::get(M.getContext(), Vals);
+
+      BBTerminatorInst->setMetadata("dx.controlflow.hints", MDNode);
+      BBTerminatorInst->setMetadata("hlsl.controlflow.hint", nullptr);
+    }
+  }
+}
+
 static void translateMetadata(Module &M, DXILBindingMap &DBM,
                               DXILResourceTypeMap &DRTM,
                               const Resources &MDResources,
@@ -372,6 +406,7 @@ PreservedAnalyses DXILTranslateMetadata::run(Module &M,
   const dxil::ModuleMetadataInfo MMDI = MAM.getResult<DXILMetadataAnalysis>(M);
 
   translateMetadata(M, DBM, DRTM, MDResources, ShaderFlags, MMDI);
+  translateBranchMetadata(M);
 
   return PreservedAnalyses::all();
 }
@@ -409,6 +444,7 @@ class DXILTranslateMetadataLegacy : public ModulePass {
         getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata();
 
     translateMetadata(M, DBM, DRTM, MDResources, ShaderFlags, MMDI);
+    translateBranchMetadata(M);
     return true;
   }
 };
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 28c9b81db51f51..237f71a1b70e50 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -33,6 +33,7 @@
 #include "llvm/CodeGen/TargetOpcodes.h"
 #include "llvm/IR/IntrinsicsSPIRV.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/ErrorHandling.h"
 
 #define DEBUG_TYPE "spirv-isel"
 
@@ -45,6 +46,17 @@ using ExtInstList =
 
 namespace {
 
+llvm::SPIRV::SelectionControl::SelectionControl
+getSelectionOperandForImm(int Imm) {
+  if (Imm == 2)
+    return SPIRV::SelectionControl::Flatten;
+  if (Imm == 1)
+    return SPIRV::SelectionControl::DontFlatten;
+  if (Imm == 0)
+    return SPIRV::SelectionControl::None;
+  llvm_unreachable("Invalid immediate");
+}
+
 #define GET_GLOBALISEL_PREDICATE_BITSET
 #include "SPIRVGenGlobalISel.inc"
 #undef GET_GLOBALISEL_PREDICATE_BITSET
@@ -2818,12 +2830,8 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
     }
     return MIB.constrainAllUses(TII, TRI, RBI);
   }
-  case Intrinsic::spv_loop_merge:
-  case Intrinsic::spv_selection_merge: {
-    const auto Opcode = IID == Intrinsic::spv_selection_merge
-                            ? SPIRV::OpSelectionMerge
-                            : SPIRV::OpLoopMerge;
-    auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(Opcode));
+  case Intrinsic::spv_loop_merge: {
+    auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpLoopMerge));
     for (unsigned i = 1; i < I.getNumExplicitOperands(); ++i) {
       assert(I.getOperand(i).isMBB());
       MIB.addMBB(I.getOperand(i).getMBB());
@@ -2831,6 +2839,15 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
     MIB.addImm(SPIRV::SelectionControl::None);
     return MIB.constrainAllUses(TII, TRI, RBI);
   }
+  case Intrinsic::spv_selection_merge: {
+    auto MIB =
+        BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpSelectionMerge));
+    assert(I.getOperand(1).isMBB() &&
+           "operand 1 to spv_selection_merge must be a basic block");
+    MIB.addMBB(I.getOperand(1).getMBB());
+    MIB.addImm(getSelectionOperandForImm(I.getOperand(2).getImm()));
+    return MIB.constrainAllUses(TII, TRI, RBI);
+  }
   case Intrinsic::spv_cmpxchg:
     return selectAtomicCmpXchg(ResVReg, ResType, I);
   case Intrinsic::spv_unreachable:
diff --git a/llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp b/llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp
index 336cde4e782246..2e4343c7922f1c 100644
--- a/llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp
@@ -18,14 +18,16 @@
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/Analysis/LoopInfo.h"
 #include "llvm/CodeGen/IntrinsicLowering.h"
-#include "llvm/IR/Analysis.h"
 #include "llvm/IR/CFG.h"
 #include "llvm/IR/Dominators.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/IntrinsicsSPIRV.h"
+#include "llvm/IR/LegacyPassManager.h"
 #include "llvm/InitializePasses.h"
+#include "llvm/PassRegistry.h"
+#include "llvm/Transforms/Utils.h"
 #include "llvm/Transforms/Utils/Cloning.h"
 #include "llvm/Transforms/Utils/LoopSimplify.h"
 #include "llvm/Transforms/Utils/LowerMemIntrinsics.h"
@@ -646,8 +648,7 @@ class SPIRVStructurizer : public FunctionPass {
       Builder.SetInsertPoint(Header->getTerminator());
 
       auto MergeAddress = BlockAddress::get(BB.getParent(), &BB);
-      SmallVector<Value *, 1> Args = {MergeAddress};
-      Builder.CreateIntrinsic(Intrinsic::spv_selection_merge, {}, {Args});
+      createOpSelectMerge(&Builder, MergeAddress);
 
       Modified = true;
     }
@@ -769,10 +770,9 @@ class SPIRVStructurizer : public FunctionPass {
       BasicBlock *Merge = Candidates[0];
 
       auto MergeAddress = BlockAddress::get(Merge->getParent(), Merge);
-      SmallVector<Value *, 1> Args = {MergeAddress};
       IRBuilder<> Builder(&BB);
       Builder.SetInsertPoint(BB.getTerminator());
-      Builder.CreateIntrinsic(Intrinsic::spv_selection_merge, {}, {Args});
+      createOpSelectMerge(&Builder, MergeAddress);
     }
 
     return Modified;
@@ -1105,8 +1105,7 @@ class SPIRVStructurizer : public FunctionPass {
         Builder.SetInsertPoint(Header->getTerminator());
 
         auto MergeAddress = BlockAddress::get(Merge->getParent(), Merge);
-        SmallVector<Value *, 1> Args = {MergeAddress};
-        Builder.CreateIntrinsic(Intrinsic::spv_selection_merge, {}, {Args});
+        createOpSelectMerge(&Builder, MergeAddress);
         continue;
       }
 
@@ -1120,8 +1119,7 @@ class SPIRVStructurizer : public FunctionPass {
       Builder.SetInsertPoint(Header->getTerminator());
 
       auto MergeAddress = BlockAddress::get(NewMerge->getParent(), NewMerge);
-      SmallVector<Value *, 1> Args = {MergeAddress};
-      Builder.CreateIntrinsic(Intrinsic::spv_selection_merge, {}, {Args});
+      createOpSelectMerge(&Builder, MergeAddress);
     }
 
     return Modified;
@@ -1208,6 +1206,27 @@ class SPIRVStructurizer : public FunctionPass {
     AU.addPreserved<SPIRVConvergenceRegionAnalysisWrapperPass>();
     FunctionPass::getAnalysisUsage(AU);
   }
+
+  void createOpSelectMerge(IRBuilder<> *Builder, BlockAddress *MergeAddress) {
+    Instruction *BBTerminatorInst = Builder->GetInsertBlock()->getTerminator();
+
+    MDNode *MDNode = BBTerminatorInst->getMetadata("hlsl.controlflow.hint");
+
+    ConstantInt *BranchHint = llvm::ConstantInt::get(Builder->getInt32Ty(), 0);
+
+    if (MDNode) {
+      assert(MDNode->getNumOperands() == 2 &&
+             "invalid metadata hlsl.controlflow.hint");
+      BranchHint = mdconst::extract<ConstantInt>(MDNode->getOperand(1));
+
+      assert(BranchHint && "invalid metadata value for hlsl.controlflow.hint");
+    }
+
+    llvm::SmallVector<llvm::Value *, 2> Args = {MergeAddress, BranchHint};
+
+    Builder->CreateIntrinsic(Intrinsic::spv_selection_merge,
+                             {MergeAddress->getType()}, {Args});
+  }
 };
 } // namespace llvm
 
@@ -1229,8 +1248,11 @@ FunctionPass *llvm::createSPIRVStructurizerPass() {
 
 PreservedAnalyses SPIRVStructurizerWrapper::run(Function &F,
                                                 FunctionAnalysisManager &AF) {
-  FunctionPass *StructurizerPass = createSPIRVStructurizerPass();
-  if (!StructurizerPass->runOnFunction(F))
+
+  auto FPM = legacy::FunctionPassManager(F.getParent());
+  FPM.add(createSPIRVStructurizerPass());
+
+  if (!FPM.run(F))
     return PreservedAnalyses::all();
   PreservedAnalyses PA;
   PA.preserveSet<CFGAnalyses>();
diff --git a/llvm/test/CodeGen/DirectX/HLSLControlFlowHint.ll b/llvm/test/CodeGen/DirectX/HLSLControlFlowHint.ll
new file mode 100644
index 00000000000000..6a5274429930ea
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/HLSLControlFlowHint.ll
@@ -0,0 +1,98 @@
+; RUN: opt -S -dxil-op-lower -dxil-translate-metadata -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
+
+; This test make sure LLVM metadata is being translated into DXIL.
+
+
+; CHECK: define i32 @test_branch(i32 %X)
+; CHECK-NOT: hlsl.controlflow.hint
+; CHECK: br i1 %cmp, label %if.then, label %if.else, !dx.controlflow.hints [[HINT_BRANCH:![0-9]+]]
+define i32 @test_branch(i32 %X) {
+entry:
+  %X.addr = alloca i32, align 4
+  %resp = alloca i32, align 4
+  store i32 %X, ptr %X.addr, align 4
+  %0 = load i32, ptr %X.addr, align 4
+  %cmp = icmp sgt i32 %0, 0
+  br i1 ...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Jan 8, 2025

@llvm/pr-subscribers-llvm-ir

Author: None (joaosaffran)

Changes
  • Adding the changes from PRs:
    • #116331
    • #121852
  • Fixes test tools/dxil-dis/debug-info.ll
  • Address some missed comments in the previous PR

Patch is 29.51 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/122157.diff

14 Files Affected:

  • (modified) clang/include/clang/Basic/Attr.td (+10)
  • (modified) clang/lib/CodeGen/CGStmt.cpp (+6)
  • (modified) clang/lib/CodeGen/CodeGenFunction.cpp (+25-1)
  • (modified) clang/lib/CodeGen/CodeGenFunction.h (+4)
  • (modified) clang/lib/Sema/SemaStmtAttr.cpp (+8)
  • (added) clang/test/AST/HLSL/HLSLControlFlowHint.hlsl (+43)
  • (added) clang/test/CodeGenHLSL/HLSLControlFlowHint.hlsl (+48)
  • (modified) llvm/include/llvm/IR/IntrinsicsSPIRV.td (+1-1)
  • (modified) llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp (+36)
  • (modified) llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp (+23-6)
  • (modified) llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp (+33-11)
  • (added) llvm/test/CodeGen/DirectX/HLSLControlFlowHint.ll (+98)
  • (added) llvm/test/CodeGen/SPIRV/structurizer/HLSLControlFlowHint-pass-check.ll (+90)
  • (added) llvm/test/CodeGen/SPIRV/structurizer/HLSLControlFlowHint.ll (+91)
diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td
index 12faf06597008e..6d7f65ab2c6135 100644
--- a/clang/include/clang/Basic/Attr.td
+++ b/clang/include/clang/Basic/Attr.td
@@ -4335,6 +4335,16 @@ def HLSLLoopHint: StmtAttr {
   let Documentation = [HLSLLoopHintDocs, HLSLUnrollHintDocs];
 }
 
+def HLSLControlFlowHint: StmtAttr {
+  /// [branch]
+  /// [flatten]
+  let Spellings = [Microsoft<"branch">, Microsoft<"flatten">];
+  let Subjects = SubjectList<[IfStmt],
+                              ErrorDiag, "'if' statements">;
+  let LangOpts = [HLSL];
+  let Documentation = [InternalOnly];
+}
+
 def CapturedRecord : InheritableAttr {
   // This attribute has no spellings as it is only ever created implicitly.
   let Spellings = [];
diff --git a/clang/lib/CodeGen/CGStmt.cpp b/clang/lib/CodeGen/CGStmt.cpp
index a87c50b8a1cbbf..c8ff48fc733125 100644
--- a/clang/lib/CodeGen/CGStmt.cpp
+++ b/clang/lib/CodeGen/CGStmt.cpp
@@ -757,6 +757,8 @@ void CodeGenFunction::EmitAttributedStmt(const AttributedStmt &S) {
   bool noinline = false;
   bool alwaysinline = false;
   bool noconvergent = false;
+  HLSLControlFlowHintAttr::Spelling flattenOrBranch =
+      HLSLControlFlowHintAttr::SpellingNotCalculated;
   const CallExpr *musttail = nullptr;
 
   for (const auto *A : S.getAttrs()) {
@@ -788,6 +790,9 @@ void CodeGenFunction::EmitAttributedStmt(const AttributedStmt &S) {
         Builder.CreateAssumption(AssumptionVal);
       }
     } break;
+    case attr::HLSLControlFlowHint: {
+      flattenOrBranch = cast<HLSLControlFlowHintAttr>(A)->getSemanticSpelling();
+    } break;
     }
   }
   SaveAndRestore save_nomerge(InNoMergeAttributedStmt, nomerge);
@@ -795,6 +800,7 @@ void CodeGenFunction::EmitAttributedStmt(const AttributedStmt &S) {
   SaveAndRestore save_alwaysinline(InAlwaysInlineAttributedStmt, alwaysinline);
   SaveAndRestore save_noconvergent(InNoConvergentAttributedStmt, noconvergent);
   SaveAndRestore save_musttail(MustTailCall, musttail);
+  SaveAndRestore save_flattenOrBranch(HLSLControlFlowAttr, flattenOrBranch);
   EmitStmt(S.getSubStmt(), S.getAttrs());
 }
 
diff --git a/clang/lib/CodeGen/CodeGenFunction.cpp b/clang/lib/CodeGen/CodeGenFunction.cpp
index af58fa64f86585..176e8386d933d6 100644
--- a/clang/lib/CodeGen/CodeGenFunction.cpp
+++ b/clang/lib/CodeGen/CodeGenFunction.cpp
@@ -40,6 +40,7 @@
 #include "llvm/IR/DataLayout.h"
 #include "llvm/IR/Dominators.h"
 #include "llvm/IR/FPEnv.h"
+#include "llvm/IR/Instruction.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/MDBuilder.h"
@@ -2083,7 +2084,30 @@ void CodeGenFunction::EmitBranchOnBoolExpr(
     Weights = createProfileWeights(TrueCount, CurrentCount - TrueCount);
   }
 
-  Builder.CreateCondBr(CondV, TrueBlock, FalseBlock, Weights, Unpredictable);
+  llvm::Instruction *BrInst = Builder.CreateCondBr(CondV, TrueBlock, FalseBlock,
+                                                   Weights, Unpredictable);
+  switch (HLSLControlFlowAttr) {
+  case HLSLControlFlowHintAttr::Microsoft_branch:
+  case HLSLControlFlowHintAttr::Microsoft_flatten: {
+    llvm::MDBuilder MDHelper(CGM.getLLVMContext());
+
+    llvm::ConstantInt *BranchHintConstant =
+        HLSLControlFlowAttr ==
+                HLSLControlFlowHintAttr::Spelling::Microsoft_branch
+            ? llvm::ConstantInt::get(CGM.Int32Ty, 1)
+            : llvm::ConstantInt::get(CGM.Int32Ty, 2);
+
+    SmallVector<llvm::Metadata *, 2> Vals(
+        {MDHelper.createString("hlsl.controlflow.hint"),
+         MDHelper.createConstant(BranchHintConstant)});
+    BrInst->setMetadata("hlsl.controlflow.hint",
+                        llvm::MDNode::get(CGM.getLLVMContext(), Vals));
+    break;
+  }
+  // This is required to avoid warnings during compilation
+  case HLSLControlFlowHintAttr::SpellingNotCalculated:
+    break;
+  }
 }
 
 /// ErrorUnsupported - Print out an error that codegen doesn't support the
diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
index f2240f8308ce38..bc612a0bfb32ba 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -615,6 +615,10 @@ class CodeGenFunction : public CodeGenTypeCache {
   /// True if the current statement has noconvergent attribute.
   bool InNoConvergentAttributedStmt = false;
 
+  /// HLSL Branch attribute.
+  HLSLControlFlowHintAttr::Spelling HLSLControlFlowAttr =
+      HLSLControlFlowHintAttr::SpellingNotCalculated;
+
   // The CallExpr within the current statement that the musttail attribute
   // applies to.  nullptr if there is no 'musttail' on the current statement.
   const CallExpr *MustTailCall = nullptr;
diff --git a/clang/lib/Sema/SemaStmtAttr.cpp b/clang/lib/Sema/SemaStmtAttr.cpp
index 106e2430de901e..422d8abc1028aa 100644
--- a/clang/lib/Sema/SemaStmtAttr.cpp
+++ b/clang/lib/Sema/SemaStmtAttr.cpp
@@ -619,6 +619,12 @@ static Attr *handleHLSLLoopHintAttr(Sema &S, Stmt *St, const ParsedAttr &A,
   return ::new (S.Context) HLSLLoopHintAttr(S.Context, A, UnrollFactor);
 }
 
+static Attr *handleHLSLControlFlowHint(Sema &S, Stmt *St, const ParsedAttr &A,
+                                       SourceRange Range) {
+
+  return ::new (S.Context) HLSLControlFlowHintAttr(S.Context, A);
+}
+
 static Attr *ProcessStmtAttribute(Sema &S, Stmt *St, const ParsedAttr &A,
                                   SourceRange Range) {
   if (A.isInvalid() || A.getKind() == ParsedAttr::IgnoredAttribute)
@@ -655,6 +661,8 @@ static Attr *ProcessStmtAttribute(Sema &S, Stmt *St, const ParsedAttr &A,
     return handleLoopHintAttr(S, St, A, Range);
   case ParsedAttr::AT_HLSLLoopHint:
     return handleHLSLLoopHintAttr(S, St, A, Range);
+  case ParsedAttr::AT_HLSLControlFlowHint:
+    return handleHLSLControlFlowHint(S, St, A, Range);
   case ParsedAttr::AT_OpenCLUnrollHint:
     return handleOpenCLUnrollHint(S, St, A, Range);
   case ParsedAttr::AT_Suppress:
diff --git a/clang/test/AST/HLSL/HLSLControlFlowHint.hlsl b/clang/test/AST/HLSL/HLSLControlFlowHint.hlsl
new file mode 100644
index 00000000000000..a36779c05fbc93
--- /dev/null
+++ b/clang/test/AST/HLSL/HLSLControlFlowHint.hlsl
@@ -0,0 +1,43 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-compute -ast-dump %s | FileCheck %s
+
+// CHECK: FunctionDecl 0x{{[0-9A-Fa-f]+}} <{{.*}}> {{.*}} used branch 'int (int)'
+// CHECK: AttributedStmt 0x{{[0-9A-Fa-f]+}} <<invalid sloc>
+// CHECK-NEXT: -HLSLControlFlowHintAttr 0x{{[0-9A-Fa-f]+}} <{{.*}}> branch
+export int branch(int X){
+    int resp;
+    [branch] if (X > 0) {
+        resp = -X;
+    } else {
+        resp = X * 2;
+    }
+
+    return resp;
+}
+
+// CHECK: FunctionDecl 0x{{[0-9A-Fa-f]+}} <{{.*}}> {{.*}} used flatten 'int (int)'
+// CHECK: AttributedStmt 0x{{[0-9A-Fa-f]+}} <<invalid sloc>
+// CHECK-NEXT: -HLSLControlFlowHintAttr 0x{{[0-9A-Fa-f]+}} <{{.*}}> flatten
+export int flatten(int X){
+    int resp;
+    [flatten] if (X > 0) {
+        resp = -X;
+    } else {
+        resp = X * 2;
+    }
+
+    return resp;
+}
+
+// CHECK: FunctionDecl 0x{{[0-9A-Fa-f]+}} <{{.*}}> {{.*}} used no_attr 'int (int)'
+// CHECK-NOT: AttributedStmt 0x{{[0-9A-Fa-f]+}} <<invalid sloc>
+// CHECK-NOT: -HLSLControlFlowHintAttr
+export int no_attr(int X){
+    int resp;
+    if (X > 0) {
+        resp = -X;
+    } else {
+        resp = X * 2;
+    }
+
+    return resp;
+}
diff --git a/clang/test/CodeGenHLSL/HLSLControlFlowHint.hlsl b/clang/test/CodeGenHLSL/HLSLControlFlowHint.hlsl
new file mode 100644
index 00000000000000..aa13b275818502
--- /dev/null
+++ b/clang/test/CodeGenHLSL/HLSLControlFlowHint.hlsl
@@ -0,0 +1,48 @@
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple dxil-pc-shadermodel6.3-library %s -fnative-half-type -emit-llvm -o - | FileCheck %s
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple spirv-vulkan-library %s -fnative-half-type -emit-llvm -o - | FileCheck %s
+
+// CHECK: define {{.*}} i32 {{.*}}test_branch{{.*}}(i32 {{.*}} [[VALD:%.*]])
+// CHECK: [[PARAM:%.*]] = load i32, ptr [[VALD]].addr, align 4
+// CHECK: [[CMP:%.*]] = icmp sgt i32 [[PARAM]], 0
+// CHECK: br i1 [[CMP]], label %if.then, label %if.else, !hlsl.controlflow.hint [[HINT_BRANCH:![0-9]+]]
+export int test_branch(int X){
+    int resp;
+    [branch] if (X > 0) {
+        resp = -X;
+    } else {
+        resp = X * 2;
+    }
+
+    return resp;
+}
+
+// CHECK: define {{.*}} i32 {{.*}}test_flatten{{.*}}(i32 {{.*}} [[VALD:%.*]])
+// CHECK: [[PARAM:%.*]] = load i32, ptr [[VALD]].addr, align 4
+// CHECK: [[CMP:%.*]] = icmp sgt i32 [[PARAM]], 0
+// CHECK: br i1 [[CMP]], label %if.then, label %if.else, !hlsl.controlflow.hint [[HINT_FLATTEN:![0-9]+]]
+export int test_flatten(int X){
+    int resp;
+    [flatten] if (X > 0) {
+        resp = -X;
+    } else {
+        resp = X * 2;
+    }
+
+    return resp;
+}
+
+// CHECK: define {{.*}} i32 {{.*}}test_no_attr{{.*}}(i32 {{.*}} [[VALD:%.*]])
+// CHECK-NOT: !hlsl.controlflow.hint
+export int test_no_attr(int X){
+    int resp;
+    if (X > 0) {
+        resp = -X;
+    } else {
+        resp = X * 2;
+    }
+
+    return resp;
+}
+
+//CHECK: [[HINT_BRANCH]] = !{!"hlsl.controlflow.hint", i32 1}
+//CHECK: [[HINT_FLATTEN]] = !{!"hlsl.controlflow.hint", i32 2}
diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
index b4d2dce66a6f0b..37057271b6c284 100644
--- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td
+++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
@@ -33,7 +33,7 @@ let TargetPrefix = "spv" in {
   def int_spv_ptrcast : Intrinsic<[llvm_any_ty], [llvm_any_ty, llvm_metadata_ty, llvm_i32_ty], [ImmArg<ArgIndex<2>>]>;
   def int_spv_switch : Intrinsic<[], [llvm_any_ty, llvm_vararg_ty]>;
   def int_spv_loop_merge : Intrinsic<[], [llvm_vararg_ty]>;
-  def int_spv_selection_merge : Intrinsic<[], [llvm_vararg_ty]>;
+  def int_spv_selection_merge : Intrinsic<[], [llvm_any_ty, llvm_i32_ty], [ImmArg<ArgIndex<1>>]>;
   def int_spv_cmpxchg : Intrinsic<[llvm_i32_ty], [llvm_any_ty, llvm_vararg_ty]>;
   def int_spv_unreachable : Intrinsic<[], []>;
   def int_spv_alloca : Intrinsic<[llvm_any_ty], [llvm_i8_ty], [ImmArg<ArgIndex<0>>]>;
diff --git a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
index 5afe6b2d2883db..5fd5c226eef894 100644
--- a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
+++ b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
@@ -15,12 +15,14 @@
 #include "llvm/ADT/Twine.h"
 #include "llvm/Analysis/DXILMetadataAnalysis.h"
 #include "llvm/Analysis/DXILResource.h"
+#include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/DiagnosticInfo.h"
 #include "llvm/IR/DiagnosticPrinter.h"
 #include "llvm/IR/Function.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/Metadata.h"
 #include "llvm/IR/Module.h"
 #include "llvm/InitializePasses.h"
@@ -300,6 +302,38 @@ static MDTuple *emitTopLevelLibraryNode(Module &M, MDNode *RMD,
   return constructEntryMetadata(nullptr, nullptr, RMD, Properties, Ctx);
 }
 
+// TODO: We might need to refactor this to be more generic,
+// in case we need more metadata to be replaced.
+static void translateBranchMetadata(Module &M) {
+  for (Function &F : M) {
+    for (BasicBlock &BB : F) {
+      Instruction *BBTerminatorInst = BB.getTerminator();
+
+      MDNode *HlslControlFlowMD =
+          BBTerminatorInst->getMetadata("hlsl.controlflow.hint");
+
+      if (!HlslControlFlowMD)
+        continue;
+
+      assert(HlslControlFlowMD->getNumOperands() == 2 &&
+             "invalid operands for hlsl.controlflow.hint");
+
+      MDBuilder MDHelper(M.getContext());
+      ConstantInt *Op1 =
+          mdconst::extract<ConstantInt>(HlslControlFlowMD->getOperand(1));
+
+      SmallVector<llvm::Metadata *, 2> Vals(
+          ArrayRef<Metadata *>{MDHelper.createString("dx.controlflow.hints"),
+                               MDHelper.createConstant(Op1)});
+
+      MDNode *MDNode = llvm::MDNode::get(M.getContext(), Vals);
+
+      BBTerminatorInst->setMetadata("dx.controlflow.hints", MDNode);
+      BBTerminatorInst->setMetadata("hlsl.controlflow.hint", nullptr);
+    }
+  }
+}
+
 static void translateMetadata(Module &M, DXILBindingMap &DBM,
                               DXILResourceTypeMap &DRTM,
                               const Resources &MDResources,
@@ -372,6 +406,7 @@ PreservedAnalyses DXILTranslateMetadata::run(Module &M,
   const dxil::ModuleMetadataInfo MMDI = MAM.getResult<DXILMetadataAnalysis>(M);
 
   translateMetadata(M, DBM, DRTM, MDResources, ShaderFlags, MMDI);
+  translateBranchMetadata(M);
 
   return PreservedAnalyses::all();
 }
@@ -409,6 +444,7 @@ class DXILTranslateMetadataLegacy : public ModulePass {
         getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata();
 
     translateMetadata(M, DBM, DRTM, MDResources, ShaderFlags, MMDI);
+    translateBranchMetadata(M);
     return true;
   }
 };
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 28c9b81db51f51..237f71a1b70e50 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -33,6 +33,7 @@
 #include "llvm/CodeGen/TargetOpcodes.h"
 #include "llvm/IR/IntrinsicsSPIRV.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/ErrorHandling.h"
 
 #define DEBUG_TYPE "spirv-isel"
 
@@ -45,6 +46,17 @@ using ExtInstList =
 
 namespace {
 
+llvm::SPIRV::SelectionControl::SelectionControl
+getSelectionOperandForImm(int Imm) {
+  if (Imm == 2)
+    return SPIRV::SelectionControl::Flatten;
+  if (Imm == 1)
+    return SPIRV::SelectionControl::DontFlatten;
+  if (Imm == 0)
+    return SPIRV::SelectionControl::None;
+  llvm_unreachable("Invalid immediate");
+}
+
 #define GET_GLOBALISEL_PREDICATE_BITSET
 #include "SPIRVGenGlobalISel.inc"
 #undef GET_GLOBALISEL_PREDICATE_BITSET
@@ -2818,12 +2830,8 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
     }
     return MIB.constrainAllUses(TII, TRI, RBI);
   }
-  case Intrinsic::spv_loop_merge:
-  case Intrinsic::spv_selection_merge: {
-    const auto Opcode = IID == Intrinsic::spv_selection_merge
-                            ? SPIRV::OpSelectionMerge
-                            : SPIRV::OpLoopMerge;
-    auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(Opcode));
+  case Intrinsic::spv_loop_merge: {
+    auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpLoopMerge));
     for (unsigned i = 1; i < I.getNumExplicitOperands(); ++i) {
       assert(I.getOperand(i).isMBB());
       MIB.addMBB(I.getOperand(i).getMBB());
@@ -2831,6 +2839,15 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
     MIB.addImm(SPIRV::SelectionControl::None);
     return MIB.constrainAllUses(TII, TRI, RBI);
   }
+  case Intrinsic::spv_selection_merge: {
+    auto MIB =
+        BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpSelectionMerge));
+    assert(I.getOperand(1).isMBB() &&
+           "operand 1 to spv_selection_merge must be a basic block");
+    MIB.addMBB(I.getOperand(1).getMBB());
+    MIB.addImm(getSelectionOperandForImm(I.getOperand(2).getImm()));
+    return MIB.constrainAllUses(TII, TRI, RBI);
+  }
   case Intrinsic::spv_cmpxchg:
     return selectAtomicCmpXchg(ResVReg, ResType, I);
   case Intrinsic::spv_unreachable:
diff --git a/llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp b/llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp
index 336cde4e782246..2e4343c7922f1c 100644
--- a/llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp
@@ -18,14 +18,16 @@
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/Analysis/LoopInfo.h"
 #include "llvm/CodeGen/IntrinsicLowering.h"
-#include "llvm/IR/Analysis.h"
 #include "llvm/IR/CFG.h"
 #include "llvm/IR/Dominators.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/IntrinsicsSPIRV.h"
+#include "llvm/IR/LegacyPassManager.h"
 #include "llvm/InitializePasses.h"
+#include "llvm/PassRegistry.h"
+#include "llvm/Transforms/Utils.h"
 #include "llvm/Transforms/Utils/Cloning.h"
 #include "llvm/Transforms/Utils/LoopSimplify.h"
 #include "llvm/Transforms/Utils/LowerMemIntrinsics.h"
@@ -646,8 +648,7 @@ class SPIRVStructurizer : public FunctionPass {
       Builder.SetInsertPoint(Header->getTerminator());
 
       auto MergeAddress = BlockAddress::get(BB.getParent(), &BB);
-      SmallVector<Value *, 1> Args = {MergeAddress};
-      Builder.CreateIntrinsic(Intrinsic::spv_selection_merge, {}, {Args});
+      createOpSelectMerge(&Builder, MergeAddress);
 
       Modified = true;
     }
@@ -769,10 +770,9 @@ class SPIRVStructurizer : public FunctionPass {
       BasicBlock *Merge = Candidates[0];
 
       auto MergeAddress = BlockAddress::get(Merge->getParent(), Merge);
-      SmallVector<Value *, 1> Args = {MergeAddress};
       IRBuilder<> Builder(&BB);
       Builder.SetInsertPoint(BB.getTerminator());
-      Builder.CreateIntrinsic(Intrinsic::spv_selection_merge, {}, {Args});
+      createOpSelectMerge(&Builder, MergeAddress);
     }
 
     return Modified;
@@ -1105,8 +1105,7 @@ class SPIRVStructurizer : public FunctionPass {
         Builder.SetInsertPoint(Header->getTerminator());
 
         auto MergeAddress = BlockAddress::get(Merge->getParent(), Merge);
-        SmallVector<Value *, 1> Args = {MergeAddress};
-        Builder.CreateIntrinsic(Intrinsic::spv_selection_merge, {}, {Args});
+        createOpSelectMerge(&Builder, MergeAddress);
         continue;
       }
 
@@ -1120,8 +1119,7 @@ class SPIRVStructurizer : public FunctionPass {
       Builder.SetInsertPoint(Header->getTerminator());
 
       auto MergeAddress = BlockAddress::get(NewMerge->getParent(), NewMerge);
-      SmallVector<Value *, 1> Args = {MergeAddress};
-      Builder.CreateIntrinsic(Intrinsic::spv_selection_merge, {}, {Args});
+      createOpSelectMerge(&Builder, MergeAddress);
     }
 
     return Modified;
@@ -1208,6 +1206,27 @@ class SPIRVStructurizer : public FunctionPass {
     AU.addPreserved<SPIRVConvergenceRegionAnalysisWrapperPass>();
     FunctionPass::getAnalysisUsage(AU);
   }
+
+  void createOpSelectMerge(IRBuilder<> *Builder, BlockAddress *MergeAddress) {
+    Instruction *BBTerminatorInst = Builder->GetInsertBlock()->getTerminator();
+
+    MDNode *MDNode = BBTerminatorInst->getMetadata("hlsl.controlflow.hint");
+
+    ConstantInt *BranchHint = llvm::ConstantInt::get(Builder->getInt32Ty(), 0);
+
+    if (MDNode) {
+      assert(MDNode->getNumOperands() == 2 &&
+             "invalid metadata hlsl.controlflow.hint");
+      BranchHint = mdconst::extract<ConstantInt>(MDNode->getOperand(1));
+
+      assert(BranchHint && "invalid metadata value for hlsl.controlflow.hint");
+    }
+
+    llvm::SmallVector<llvm::Value *, 2> Args = {MergeAddress, BranchHint};
+
+    Builder->CreateIntrinsic(Intrinsic::spv_selection_merge,
+                             {MergeAddress->getType()}, {Args});
+  }
 };
 } // namespace llvm
 
@@ -1229,8 +1248,11 @@ FunctionPass *llvm::createSPIRVStructurizerPass() {
 
 PreservedAnalyses SPIRVStructurizerWrapper::run(Function &F,
                                                 FunctionAnalysisManager &AF) {
-  FunctionPass *StructurizerPass = createSPIRVStructurizerPass();
-  if (!StructurizerPass->runOnFunction(F))
+
+  auto FPM = legacy::FunctionPassManager(F.getParent());
+  FPM.add(createSPIRVStructurizerPass());
+
+  if (!FPM.run(F))
     return PreservedAnalyses::all();
   PreservedAnalyses PA;
   PA.preserveSet<CFGAnalyses>();
diff --git a/llvm/test/CodeGen/DirectX/HLSLControlFlowHint.ll b/llvm/test/CodeGen/DirectX/HLSLControlFlowHint.ll
new file mode 100644
index 00000000000000..6a5274429930ea
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/HLSLControlFlowHint.ll
@@ -0,0 +1,98 @@
+; RUN: opt -S -dxil-op-lower -dxil-translate-metadata -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
+
+; This test make sure LLVM metadata is being translated into DXIL.
+
+
+; CHECK: define i32 @test_branch(i32 %X)
+; CHECK-NOT: hlsl.controlflow.hint
+; CHECK: br i1 %cmp, label %if.then, label %if.else, !dx.controlflow.hints [[HINT_BRANCH:![0-9]+]]
+define i32 @test_branch(i32 %X) {
+entry:
+  %X.addr = alloca i32, align 4
+  %resp = alloca i32, align 4
+  store i32 %X, ptr %X.addr, align 4
+  %0 = load i32, ptr %X.addr, align 4
+  %cmp = icmp sgt i32 %0, 0
+  br i1 ...
[truncated]

@joaosaffran joaosaffran merged commit 380bb51 into llvm:main Jan 13, 2025
16 of 17 checks passed
kazutakahirata pushed a commit to kazutakahirata/llvm-project that referenced this pull request Jan 13, 2025
…122157)

- Adding the changes from PRs: 
  - llvm#116331 
  - llvm#121852 
- Fixes test `tools/dxil-dis/debug-info.ll`
- Address some missed comments in the previous PR

---------

Co-authored-by: joaosaffran <[email protected]>
@joaosaffran joaosaffran deleted the hlsl/ms-if-test-fix branch February 26, 2025 22:47
@damyanp damyanp moved this to Closed in HLSL Support Apr 25, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
backend:DirectX backend:SPIR-V clang:codegen IR generation bugs: mangling, exceptions, etc. clang:frontend Language frontend issues, e.g. anything involving "Sema" clang Clang issues not falling into any other category HLSL HLSL Language Support llvm:ir
Projects
Status: Closed
Development

Successfully merging this pull request may close these issues.

3 participants