Skip to content

[DirectX] Encapsulate DXILOpLowering's state into a class. NFC #104248

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Conversation

bogner
Copy link
Contributor

@bogner bogner commented Aug 14, 2024

This introduces an anonymous class "OpLowerer" to help with lowering DXIL ops,
and moves the DXILOpBuilder there instead of creating a new one for every
operation. DXILOpBuilder is also changed to own its IRBuilder, since that makes
it simpler to ensure that it isn't misused.

bogner added 2 commits August 15, 2024 00:27
Created using spr 1.3.5-bogner

[skip ci]
Created using spr 1.3.5-bogner
@llvmbot
Copy link
Member

llvmbot commented Aug 14, 2024

@llvm/pr-subscribers-backend-directx

Author: Justin Bogner (bogner)

Changes

This introduces an anonymous class "OpLowerer" to help with lowering DXIL ops,
and moves the DXILOpBuilder there instead of creating a new one for every
operation. DXILOpBuilder is also changed to own its IRBuilder, since that makes
it simpler to ensure that it isn't misused.


Full diff: https://github.com/llvm/llvm-project/pull/104248.diff

3 Files Affected:

  • (modified) llvm/lib/Target/DirectX/DXILOpBuilder.cpp (+3-4)
  • (modified) llvm/lib/Target/DirectX/DXILOpBuilder.h (+6-3)
  • (modified) llvm/lib/Target/DirectX/DXILOpLowering.cpp (+64-45)
diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
index 987437619f08e..7d2b40cc515cc 100644
--- a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
@@ -11,7 +11,6 @@
 
 #include "DXILOpBuilder.h"
 #include "DXILConstants.h"
-#include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/Module.h"
 #include "llvm/Support/DXILABI.h"
 #include "llvm/Support/ErrorHandling.h"
@@ -335,7 +334,7 @@ namespace dxil {
 // Triple is well-formed or that the target is supported since these checks
 // would have been done at the time the module M is constructed in the earlier
 // stages of compilation.
-DXILOpBuilder::DXILOpBuilder(Module &M, IRBuilderBase &B) : M(M), B(B) {
+DXILOpBuilder::DXILOpBuilder(Module &M) : M(M), IRB(M.getContext()) {
   Triple TT(Triple(M.getTargetTriple()));
   DXILVersion = TT.getDXILVersion();
   ShaderStage = TT.getEnvironment();
@@ -417,10 +416,10 @@ Expected<CallInst *> DXILOpBuilder::tryCreateOp(dxil::OpCode OpCode,
 
   // We need to inject the opcode as the first argument.
   SmallVector<Value *> OpArgs;
-  OpArgs.push_back(B.getInt32(llvm::to_underlying(OpCode)));
+  OpArgs.push_back(IRB.getInt32(llvm::to_underlying(OpCode)));
   OpArgs.append(Args.begin(), Args.end());
 
-  return B.CreateCall(DXILFn, OpArgs);
+  return IRB.CreateCall(DXILFn, OpArgs);
 }
 
 CallInst *DXILOpBuilder::createOp(dxil::OpCode OpCode, ArrayRef<Value *> Args,
diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.h b/llvm/lib/Target/DirectX/DXILOpBuilder.h
index 5d83357f7a2e9..483d5ddc8b619 100644
--- a/llvm/lib/Target/DirectX/DXILOpBuilder.h
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.h
@@ -14,8 +14,9 @@
 
 #include "DXILConstants.h"
 #include "llvm/ADT/SmallVector.h"
-#include "llvm/TargetParser/Triple.h"
+#include "llvm/IR/IRBuilder.h"
 #include "llvm/Support/Error.h"
+#include "llvm/TargetParser/Triple.h"
 
 namespace llvm {
 class Module;
@@ -29,7 +30,9 @@ namespace dxil {
 
 class DXILOpBuilder {
 public:
-  DXILOpBuilder(Module &M, IRBuilderBase &B);
+  DXILOpBuilder(Module &M);
+
+  IRBuilder<> &getIRB() { return IRB; }
 
   /// Create a call instruction for the given DXIL op. The arguments
   /// must be valid for an overload of the operation.
@@ -51,7 +54,7 @@ class DXILOpBuilder {
                                   Type *OverloadType = nullptr);
 
   Module &M;
-  IRBuilderBase &B;
+  IRBuilder<> IRB;
   VersionTuple DXILVersion;
   Triple::EnvironmentType ShaderStage;
 };
diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
index 5f84cdcfda6de..e458720fcd6e9 100644
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -73,67 +73,84 @@ static SmallVector<Value *> argVectorFlatten(CallInst *Orig,
   return NewOperands;
 }
 
-static void lowerIntrinsic(dxil::OpCode DXILOp, Function &F, Module &M) {
-  IRBuilder<> B(M.getContext());
-  DXILOpBuilder OpBuilder(M, B);
-  for (User *U : make_early_inc_range(F.users())) {
-    CallInst *CI = dyn_cast<CallInst>(U);
-    if (!CI)
-      continue;
-
-    SmallVector<Value *> Args;
-    B.SetInsertPoint(CI);
-    if (isVectorArgExpansion(F)) {
-      SmallVector<Value *> NewArgs = argVectorFlatten(CI, B);
-      Args.append(NewArgs.begin(), NewArgs.end());
-    } else
-      Args.append(CI->arg_begin(), CI->arg_end());
-
-    Expected<CallInst *> OpCallOrErr = OpBuilder.tryCreateOp(DXILOp, Args,
-                                                             F.getReturnType());
-    if (Error E = OpCallOrErr.takeError()) {
-      std::string Message(toString(std::move(E)));
-      DiagnosticInfoUnsupported Diag(*CI->getFunction(), Message,
-                                     CI->getDebugLoc());
-      M.getContext().diagnose(Diag);
-      continue;
+namespace {
+class OpLowerer {
+  Module &M;
+  DXILOpBuilder OpBuilder;
+
+public:
+  OpLowerer(Module &M) : M(M), OpBuilder(M) {}
+
+  void replaceFunction(Function &F,
+                       llvm::function_ref<Error(CallInst *CI)> ReplaceCall) {
+    for (User *U : make_early_inc_range(F.users())) {
+      CallInst *CI = dyn_cast<CallInst>(U);
+      if (!CI)
+        continue;
+
+      if (Error E = ReplaceCall(CI)) {
+        std::string Message(toString(std::move(E)));
+        DiagnosticInfoUnsupported Diag(*CI->getFunction(), Message,
+                                       CI->getDebugLoc());
+        M.getContext().diagnose(Diag);
+        continue;
+      }
     }
-    CallInst *OpCall = *OpCallOrErr;
+    if (F.user_empty())
+      F.eraseFromParent();
+  }
 
-    CI->replaceAllUsesWith(OpCall);
-    CI->eraseFromParent();
+  void replaceFunctionWithOp(Function &F, dxil::OpCode DXILOp) {
+    bool IsVectorArgExpansion = isVectorArgExpansion(F);
+    replaceFunction(F, [&](CallInst *CI) -> Error {
+      SmallVector<Value *> Args;
+      OpBuilder.getIRB().SetInsertPoint(CI);
+      if (IsVectorArgExpansion) {
+        SmallVector<Value *> NewArgs = argVectorFlatten(CI, OpBuilder.getIRB());
+        Args.append(NewArgs.begin(), NewArgs.end());
+      } else
+        Args.append(CI->arg_begin(), CI->arg_end());
+
+      Expected<CallInst *> OpCall =
+          OpBuilder.tryCreateOp(DXILOp, Args, F.getReturnType());
+      if (Error E = OpCall.takeError())
+        return E;
+
+      CI->replaceAllUsesWith(*OpCall);
+      CI->eraseFromParent();
+      return Error::success();
+    });
   }
-  if (F.user_empty())
-    F.eraseFromParent();
-}
 
-static bool lowerIntrinsics(Module &M) {
-  bool Updated = false;
+  bool lowerIntrinsics() {
+    bool Updated = false;
 
-  for (Function &F : make_early_inc_range(M.functions())) {
-    if (!F.isDeclaration())
-      continue;
-    Intrinsic::ID ID = F.getIntrinsicID();
-    switch (ID) {
-    default:
-      continue;
+    for (Function &F : make_early_inc_range(M.functions())) {
+      if (!F.isDeclaration())
+        continue;
+      Intrinsic::ID ID = F.getIntrinsicID();
+      switch (ID) {
+      default:
+        continue;
 #define DXIL_OP_INTRINSIC(OpCode, Intrin)                                      \
   case Intrin:                                                                 \
-    lowerIntrinsic(OpCode, F, M);                                              \
+    replaceFunctionWithOp(F, OpCode);                                          \
     break;
 #include "DXILOperation.inc"
+      }
+      Updated = true;
     }
-    Updated = true;
+    return Updated;
   }
-  return Updated;
-}
+};
+} // namespace
 
 namespace {
 /// A pass that transforms external global definitions into declarations.
 class DXILOpLowering : public PassInfoMixin<DXILOpLowering> {
 public:
   PreservedAnalyses run(Module &M, ModuleAnalysisManager &) {
-    if (lowerIntrinsics(M))
+    if (OpLowerer(M).lowerIntrinsics())
       return PreservedAnalyses::none();
     return PreservedAnalyses::all();
   }
@@ -143,7 +160,9 @@ class DXILOpLowering : public PassInfoMixin<DXILOpLowering> {
 namespace {
 class DXILOpLoweringLegacy : public ModulePass {
 public:
-  bool runOnModule(Module &M) override { return lowerIntrinsics(M); }
+  bool runOnModule(Module &M) override {
+    return OpLowerer(M).lowerIntrinsics();
+  }
   StringRef getPassName() const override { return "DXIL Op Lowering"; }
   DXILOpLoweringLegacy() : ModulePass(ID) {}
 

bogner added a commit to bogner/llvm-project that referenced this pull request Aug 14, 2024
This introduces an anonymous class "OpLowerer" to help with lowering DXIL ops,
and moves the DXILOpBuilder there instead of creating a new one for every
operation. DXILOpBuilder is also changed to own its IRBuilder, since that makes
it simpler to ensure that it isn't misused.

Pull Request: llvm#104248
Michael137 and others added 2 commits August 20, 2024 10:43
Created using spr 1.3.5-bogner

[skip ci]
Created using spr 1.3.5-bogner
@bogner bogner changed the base branch from users/bogner/sprmain.directx-encapsulate-dxiloplowerings-state-into-a-class-nfc to main August 20, 2024 17:51
@bogner bogner merged commit e56ad22 into main Aug 20, 2024
5 of 9 checks passed
@bogner bogner deleted the users/bogner/sprdirectx-encapsulate-dxiloplowerings-state-into-a-class-nfc branch August 20, 2024 17:51
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
Archived in project
Development

Successfully merging this pull request may close these issues.

5 participants