Skip to content

[DirectX] Implement typedBufferLoad_checkbit #108087

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 11, 2024

Conversation

bogner
Copy link
Contributor

@bogner bogner commented Sep 10, 2024

This represents a typedBufferLoad that's followed by "CheckAccessFullyMapped". In addition to the loaded data, it returns an extra i1 holding the result of that check.

Fixes #108085

This represents a typedBufferLoad that's followed by "CheckAccessFullyMapped".
In addition to the loaded data, it returns an extra `i1` holding the result of that check.

Fixes llvm#108085
@llvmbot
Copy link
Member

llvmbot commented Sep 10, 2024

@llvm/pr-subscribers-backend-directx

@llvm/pr-subscribers-llvm-ir

Author: Justin Bogner (bogner)

Changes

This represents a typedBufferLoad that's followed by "CheckAccessFullyMapped". In addition to the loaded data, it returns an extra i1 holding the result of that check.

Fixes #108085


Full diff: https://github.com/llvm/llvm-project/pull/108087.diff

5 Files Affected:

  • (modified) llvm/docs/DirectX/DXILResources.rst (+6)
  • (modified) llvm/include/llvm/IR/IntrinsicsDirectX.td (+3)
  • (modified) llvm/lib/Target/DirectX/DXIL.td (+9)
  • (modified) llvm/lib/Target/DirectX/DXILOpLowering.cpp (+57-11)
  • (modified) llvm/test/CodeGen/DirectX/BufferLoad.ll (+22)
diff --git a/llvm/docs/DirectX/DXILResources.rst b/llvm/docs/DirectX/DXILResources.rst
index a982c3a29fcc3b..ad8ede9c59fbfa 100644
--- a/llvm/docs/DirectX/DXILResources.rst
+++ b/llvm/docs/DirectX/DXILResources.rst
@@ -361,6 +361,12 @@ Examples:
      - ``i32``
      - Index into the buffer
 
+.. code-block:: llvm
+
+   %ret = call {<4 x float>, i1}
+       @llvm.dx.typedBufferLoad.checkbit.v4f32.tdx.TypedBuffer_v4f32_0_0_0t(
+           target("dx.TypedBuffer", <4 x float>, 0, 0, 0) %buffer, i32 %index)
+
 Texture and Typed Buffer Stores
 -------------------------------
 
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index c36e98f040ab81..f1017bdd512496 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -32,6 +32,9 @@ def int_dx_handle_fromBinding
 
 def int_dx_typedBufferLoad
     : DefaultAttrsIntrinsic<[llvm_any_ty], [llvm_any_ty, llvm_i32_ty]>;
+def int_dx_typedBufferLoad_checkbit
+    : DefaultAttrsIntrinsic<[llvm_any_ty, llvm_i1_ty],
+                            [llvm_any_ty, llvm_i32_ty]>;
 def int_dx_typedBufferStore
     : DefaultAttrsIntrinsic<[], [llvm_any_ty, llvm_i32_ty, llvm_anyvector_ty]>;
 
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 759a58ed3930e3..902ab37bf741ed 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -719,6 +719,15 @@ def BufferStore : DXILOp<69, bufferStore> {
   let stages = [Stages<DXIL1_0, [all_stages]>];
 }
 
+def CheckAccessFullyMapped : DXILOp<71, checkAccessFullyMapped> {
+  let Doc = "checks whether a Sample, Gather, or Load operation "
+            "accessed mapped tiles in a tiled resource";
+  let arguments = [OverloadTy];
+  let result = Int1Ty;
+  let overloads = [Overloads<DXIL1_0, [Int32Ty]>];
+  let stages = [Stages<DXIL1_0, [all_stages]>];
+}
+
 def ThreadId :  DXILOp<93, threadId> {
   let Doc = "Reads the thread ID";
   let LLVMIntrinsic = int_dx_thread_id;
diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
index f968cab1dccf1e..572766eb087724 100644
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -265,16 +265,51 @@ class OpLowerer {
 
   /// Replace uses of \c Intrin with the values in the `dx.ResRet` of \c Op.
   /// Since we expect to be post-scalarization, make an effort to avoid vectors.
-  Error replaceResRetUses(CallInst *Intrin, CallInst *Op) {
+  Error replaceResRetUses(CallInst *Intrin, CallInst *Op, bool HasCheckBit) {
     IRBuilder<> &IRB = OpBuilder.getIRB();
 
+    Instruction *OldResult = Intrin;
     Type *OldTy = Intrin->getType();
 
+    if (HasCheckBit) {
+      auto *ST = cast<StructType>(OldTy);
+
+      Value *CheckOp = nullptr;
+      Type *Int32Ty = IRB.getInt32Ty();
+      for (Use &U : make_early_inc_range(OldResult->uses())) {
+        if (auto *EVI = dyn_cast<ExtractValueInst>(U.getUser())) {
+          ArrayRef<unsigned> Indices = EVI->getIndices();
+          assert(Indices.size() == 1);
+          // We're only interested in uses of the check bit for now.
+          if (Indices[0] != 1)
+            continue;
+          if (!CheckOp) {
+            Value *NewEVI = IRB.CreateExtractValue(Op, 4);
+            Expected<CallInst *> OpCall = OpBuilder.tryCreateOp(
+                OpCode::CheckAccessFullyMapped, {NewEVI}, Int32Ty);
+            if (Error E = OpCall.takeError())
+              return E;
+            CheckOp = *OpCall;
+          }
+          EVI->replaceAllUsesWith(CheckOp);
+          EVI->eraseFromParent();
+        }
+      }
+
+      OldResult =
+          cast<Instruction>(IRB.CreateExtractValue(Op, 0));
+      OldTy = ST->getElementType(0);
+    }
+
     // For scalars, we just extract the first element.
     if (!isa<FixedVectorType>(OldTy)) {
       Value *EVI = IRB.CreateExtractValue(Op, 0);
-      Intrin->replaceAllUsesWith(EVI);
-      Intrin->eraseFromParent();
+      OldResult->replaceAllUsesWith(EVI);
+      OldResult->eraseFromParent();
+      if (OldResult != Intrin) {
+        assert(Intrin->use_empty() && "Intrinsic still has uses?");
+        Intrin->eraseFromParent();
+      }
       return Error::success();
     }
 
@@ -283,7 +318,7 @@ class OpLowerer {
 
     // The users of the operation should all be scalarized, so we attempt to
     // replace the extractelements with extractvalues directly.
-    for (Use &U : make_early_inc_range(Intrin->uses())) {
+    for (Use &U : make_early_inc_range(OldResult->uses())) {
       if (auto *EEI = dyn_cast<ExtractElementInst>(U.getUser())) {
         if (auto *IndexOp = dyn_cast<ConstantInt>(EEI->getIndexOperand())) {
           size_t IndexVal = IndexOp->getZExtValue();
@@ -331,7 +366,7 @@ class OpLowerer {
     // If we still have uses, then we're not fully scalarized and need to
     // recreate the vector. This should only happen for things like exported
     // functions from libraries.
-    if (!Intrin->use_empty()) {
+    if (!OldResult->use_empty()) {
       for (int I = 0, E = N; I != E; ++I)
         if (!Extracts[I])
           Extracts[I] = IRB.CreateExtractValue(Op, I);
@@ -339,14 +374,19 @@ class OpLowerer {
       Value *Vec = UndefValue::get(OldTy);
       for (int I = 0, E = N; I != E; ++I)
         Vec = IRB.CreateInsertElement(Vec, Extracts[I], I);
-      Intrin->replaceAllUsesWith(Vec);
+      OldResult->replaceAllUsesWith(Vec);
+    }
+
+    OldResult->eraseFromParent();
+    if (OldResult != Intrin) {
+      assert(Intrin->use_empty() && "Intrinsic still has uses?");
+      Intrin->eraseFromParent();
     }
 
-    Intrin->eraseFromParent();
     return Error::success();
   }
 
-  [[nodiscard]] bool lowerTypedBufferLoad(Function &F) {
+  [[nodiscard]] bool lowerTypedBufferLoad(Function &F, bool HasCheckBit) {
     IRBuilder<> &IRB = OpBuilder.getIRB();
     Type *Int32Ty = IRB.getInt32Ty();
 
@@ -358,14 +398,17 @@ class OpLowerer {
       Value *Index0 = CI->getArgOperand(1);
       Value *Index1 = UndefValue::get(Int32Ty);
 
-      Type *NewRetTy = OpBuilder.getResRetType(CI->getType()->getScalarType());
+      Type *OldTy = CI->getType();
+      if (HasCheckBit)
+        OldTy = cast<StructType>(OldTy)->getElementType(0);
+      Type *NewRetTy = OpBuilder.getResRetType(OldTy->getScalarType());
 
       std::array<Value *, 3> Args{Handle, Index0, Index1};
       Expected<CallInst *> OpCall =
           OpBuilder.tryCreateOp(OpCode::BufferLoad, Args, NewRetTy);
       if (Error E = OpCall.takeError())
         return E;
-      if (Error E = replaceResRetUses(CI, *OpCall))
+      if (Error E = replaceResRetUses(CI, *OpCall, HasCheckBit))
         return E;
 
       return Error::success();
@@ -434,7 +477,10 @@ class OpLowerer {
         HasErrors |= lowerHandleFromBinding(F);
         break;
       case Intrinsic::dx_typedBufferLoad:
-        HasErrors |= lowerTypedBufferLoad(F);
+        HasErrors |= lowerTypedBufferLoad(F, /*HasCheckBit=*/false);
+        break;
+      case Intrinsic::dx_typedBufferLoad_checkbit:
+        HasErrors |= lowerTypedBufferLoad(F, /*HasCheckBit=*/true);
         break;
       case Intrinsic::dx_typedBufferStore:
         HasErrors |= lowerTypedBufferStore(F);
diff --git a/llvm/test/CodeGen/DirectX/BufferLoad.ll b/llvm/test/CodeGen/DirectX/BufferLoad.ll
index 4b9fb52f0b5299..e3a4441ad6e833 100644
--- a/llvm/test/CodeGen/DirectX/BufferLoad.ll
+++ b/llvm/test/CodeGen/DirectX/BufferLoad.ll
@@ -4,6 +4,7 @@ target triple = "dxil-pc-shadermodel6.6-compute"
 
 declare void @scalar_user(float)
 declare void @vector_user(<4 x float>)
+declare void @check_user(i1)
 
 define void @loadv4f32() {
   ; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding
@@ -128,6 +129,27 @@ define void @loadv2f32() {
   ret void
 }
 
+define void @loadv4f32_checkbit() {
+  ; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding
+  ; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 217, %dx.types.Handle [[BIND]]
+  %buffer = call target("dx.TypedBuffer", <4 x float>, 0, 0, 0)
+      @llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4f32_0_0_0(
+          i32 0, i32 0, i32 1, i32 0, i1 false)
+
+  ; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef)
+  %data0 = call {<4 x float>, i1} @llvm.dx.typedBufferLoad.checkbit.f32(
+      target("dx.TypedBuffer", <4 x float>, 0, 0, 0) %buffer, i32 0)
+
+  ; CHECK: [[STATUS:%.*]] = extractvalue %dx.types.ResRet.f32 [[DATA0]], 4
+  ; CHECK: [[MAPPED:%.*]] = call i1 @dx.op.checkAccessFullyMapped.i32(i32 71, i32 [[STATUS]]
+  %check = extractvalue {<4 x float>, i1} %data0, 1
+
+  ; CHECK: call void @check_user(i1 [[MAPPED]])
+  call void @check_user(i1 %check)
+
+  ret void
+}
+
 define void @loadv4i32() {
   ; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding
   ; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 217, %dx.types.Handle [[BIND]]

Copy link

github-actions bot commented Sep 10, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

Copy link
Member

@hekota hekota left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

Type *OldTy = Intrin->getType();

if (HasCheckBit) {
auto *ST = cast<StructType>(OldTy);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does the success of cast need to be ensured by checking ST != nullptr, here and similarly on line 299?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cast asserts if it can't do the cast (unlike dyn_cast, which returns null). It's meant to be used only when it's known that it can't fail.

@bogner bogner merged commit 34e20f1 into llvm:main Sep 11, 2024
9 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
Archived in project
Development

Successfully merging this pull request may close these issues.

Implement the dx.typedBufferLoad.checkbit intrinsic
4 participants