Skip to content

[DirectX] Implement typedBufferLoad_checkbit #108087

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions llvm/docs/DirectX/DXILResources.rst
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,12 @@ Examples:
- ``i32``
- Index into the buffer

.. code-block:: llvm

%ret = call {<4 x float>, i1}
@llvm.dx.typedBufferLoad.checkbit.v4f32.tdx.TypedBuffer_v4f32_0_0_0t(
target("dx.TypedBuffer", <4 x float>, 0, 0, 0) %buffer, i32 %index)

Texture and Typed Buffer Stores
-------------------------------

Expand Down
3 changes: 3 additions & 0 deletions llvm/include/llvm/IR/IntrinsicsDirectX.td
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ def int_dx_handle_fromBinding

def int_dx_typedBufferLoad
: DefaultAttrsIntrinsic<[llvm_any_ty], [llvm_any_ty, llvm_i32_ty]>;
// Variant of the typed buffer load that returns two results: the loaded value
// (overloaded type) and an i1 check bit. The operands are the buffer handle
// (overloaded type) and an i32 index into the buffer.
def int_dx_typedBufferLoad_checkbit
: DefaultAttrsIntrinsic<[llvm_any_ty, llvm_i1_ty],
[llvm_any_ty, llvm_i32_ty]>;
def int_dx_typedBufferStore
: DefaultAttrsIntrinsic<[], [llvm_any_ty, llvm_i32_ty, llvm_anyvector_ty]>;

Expand Down
9 changes: 9 additions & 0 deletions llvm/lib/Target/DirectX/DXIL.td
Original file line number Diff line number Diff line change
Expand Up @@ -719,6 +719,15 @@ def BufferStore : DXILOp<69, bufferStore> {
let stages = [Stages<DXIL1_0, [all_stages]>];
}

// DXIL operation 71: takes a single i32 operand and produces an i1.
// The typed buffer load lowering passes it the value extracted from element 4
// of the load's dx.ResRet and substitutes the i1 result for the check bit.
def CheckAccessFullyMapped : DXILOp<71, checkAccessFullyMapped> {
let Doc = "checks whether a Sample, Gather, or Load operation "
"accessed mapped tiles in a tiled resource";
let arguments = [OverloadTy];
let result = Int1Ty;
let overloads = [Overloads<DXIL1_0, [Int32Ty]>];
let stages = [Stages<DXIL1_0, [all_stages]>];
}

def ThreadId : DXILOp<93, threadId> {
let Doc = "Reads the thread ID";
let LLVMIntrinsic = int_dx_thread_id;
Expand Down
67 changes: 56 additions & 11 deletions llvm/lib/Target/DirectX/DXILOpLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -265,16 +265,50 @@ class OpLowerer {

/// Replace uses of \c Intrin with the values in the `dx.ResRet` of \c Op.
/// Since we expect to be post-scalarization, make an effort to avoid vectors.
Error replaceResRetUses(CallInst *Intrin, CallInst *Op) {
Error replaceResRetUses(CallInst *Intrin, CallInst *Op, bool HasCheckBit) {
IRBuilder<> &IRB = OpBuilder.getIRB();

Instruction *OldResult = Intrin;
Type *OldTy = Intrin->getType();

if (HasCheckBit) {
auto *ST = cast<StructType>(OldTy);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does the success of cast need to be ensured by checking ST != nullptr, here and similarly on line 299?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

`cast` asserts if the cast fails (unlike `dyn_cast`, which returns null). It's meant to be used only when it's known that the cast can't fail.


Value *CheckOp = nullptr;
Type *Int32Ty = IRB.getInt32Ty();
for (Use &U : make_early_inc_range(OldResult->uses())) {
if (auto *EVI = dyn_cast<ExtractValueInst>(U.getUser())) {
ArrayRef<unsigned> Indices = EVI->getIndices();
assert(Indices.size() == 1);
// We're only interested in uses of the check bit for now.
if (Indices[0] != 1)
continue;
if (!CheckOp) {
Value *NewEVI = IRB.CreateExtractValue(Op, 4);
Expected<CallInst *> OpCall = OpBuilder.tryCreateOp(
OpCode::CheckAccessFullyMapped, {NewEVI}, Int32Ty);
if (Error E = OpCall.takeError())
return E;
CheckOp = *OpCall;
}
EVI->replaceAllUsesWith(CheckOp);
EVI->eraseFromParent();
}
}

OldResult = cast<Instruction>(IRB.CreateExtractValue(Op, 0));
OldTy = ST->getElementType(0);
}

// For scalars, we just extract the first element.
if (!isa<FixedVectorType>(OldTy)) {
Value *EVI = IRB.CreateExtractValue(Op, 0);
Intrin->replaceAllUsesWith(EVI);
Intrin->eraseFromParent();
OldResult->replaceAllUsesWith(EVI);
OldResult->eraseFromParent();
if (OldResult != Intrin) {
assert(Intrin->use_empty() && "Intrinsic still has uses?");
Intrin->eraseFromParent();
}
return Error::success();
}

Expand All @@ -283,7 +317,7 @@ class OpLowerer {

// The users of the operation should all be scalarized, so we attempt to
// replace the extractelements with extractvalues directly.
for (Use &U : make_early_inc_range(Intrin->uses())) {
for (Use &U : make_early_inc_range(OldResult->uses())) {
if (auto *EEI = dyn_cast<ExtractElementInst>(U.getUser())) {
if (auto *IndexOp = dyn_cast<ConstantInt>(EEI->getIndexOperand())) {
size_t IndexVal = IndexOp->getZExtValue();
Expand Down Expand Up @@ -331,22 +365,27 @@ class OpLowerer {
// If we still have uses, then we're not fully scalarized and need to
// recreate the vector. This should only happen for things like exported
// functions from libraries.
if (!Intrin->use_empty()) {
if (!OldResult->use_empty()) {
for (int I = 0, E = N; I != E; ++I)
if (!Extracts[I])
Extracts[I] = IRB.CreateExtractValue(Op, I);

Value *Vec = UndefValue::get(OldTy);
for (int I = 0, E = N; I != E; ++I)
Vec = IRB.CreateInsertElement(Vec, Extracts[I], I);
Intrin->replaceAllUsesWith(Vec);
OldResult->replaceAllUsesWith(Vec);
}

OldResult->eraseFromParent();
if (OldResult != Intrin) {
assert(Intrin->use_empty() && "Intrinsic still has uses?");
Intrin->eraseFromParent();
}

Intrin->eraseFromParent();
return Error::success();
}

[[nodiscard]] bool lowerTypedBufferLoad(Function &F) {
[[nodiscard]] bool lowerTypedBufferLoad(Function &F, bool HasCheckBit) {
IRBuilder<> &IRB = OpBuilder.getIRB();
Type *Int32Ty = IRB.getInt32Ty();

Expand All @@ -358,14 +397,17 @@ class OpLowerer {
Value *Index0 = CI->getArgOperand(1);
Value *Index1 = UndefValue::get(Int32Ty);

Type *NewRetTy = OpBuilder.getResRetType(CI->getType()->getScalarType());
Type *OldTy = CI->getType();
if (HasCheckBit)
OldTy = cast<StructType>(OldTy)->getElementType(0);
Type *NewRetTy = OpBuilder.getResRetType(OldTy->getScalarType());

std::array<Value *, 3> Args{Handle, Index0, Index1};
Expected<CallInst *> OpCall =
OpBuilder.tryCreateOp(OpCode::BufferLoad, Args, NewRetTy);
if (Error E = OpCall.takeError())
return E;
if (Error E = replaceResRetUses(CI, *OpCall))
if (Error E = replaceResRetUses(CI, *OpCall, HasCheckBit))
return E;

return Error::success();
Expand Down Expand Up @@ -434,7 +476,10 @@ class OpLowerer {
HasErrors |= lowerHandleFromBinding(F);
break;
case Intrinsic::dx_typedBufferLoad:
HasErrors |= lowerTypedBufferLoad(F);
HasErrors |= lowerTypedBufferLoad(F, /*HasCheckBit=*/false);
break;
case Intrinsic::dx_typedBufferLoad_checkbit:
HasErrors |= lowerTypedBufferLoad(F, /*HasCheckBit=*/true);
break;
case Intrinsic::dx_typedBufferStore:
HasErrors |= lowerTypedBufferStore(F);
Expand Down
22 changes: 22 additions & 0 deletions llvm/test/CodeGen/DirectX/BufferLoad.ll
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ target triple = "dxil-pc-shadermodel6.6-compute"

declare void @scalar_user(float)
declare void @vector_user(<4 x float>)
declare void @check_user(i1)

define void @loadv4f32() {
; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding
Expand Down Expand Up @@ -128,6 +129,27 @@ define void @loadv2f32() {
ret void
}

; Verifies the lowering of llvm.dx.typedBufferLoad.checkbit when only the check
; bit (struct element 1) is used: the call becomes a dx.op.bufferLoad, and the
; check bit is replaced by dx.op.checkAccessFullyMapped applied to element 4 of
; the returned dx.types.ResRet.
define void @loadv4f32_checkbit() {
; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding
; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 217, %dx.types.Handle [[BIND]]
%buffer = call target("dx.TypedBuffer", <4 x float>, 0, 0, 0)
@llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4f32_0_0_0(
i32 0, i32 0, i32 1, i32 0, i1 false)

; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef)
%data0 = call {<4 x float>, i1} @llvm.dx.typedBufferLoad.checkbit.f32(
target("dx.TypedBuffer", <4 x float>, 0, 0, 0) %buffer, i32 0)

; CHECK: [[STATUS:%.*]] = extractvalue %dx.types.ResRet.f32 [[DATA0]], 4
; CHECK: [[MAPPED:%.*]] = call i1 @dx.op.checkAccessFullyMapped.i32(i32 71, i32 [[STATUS]]
%check = extractvalue {<4 x float>, i1} %data0, 1

; CHECK: call void @check_user(i1 [[MAPPED]])
call void @check_user(i1 %check)

ret void
}

define void @loadv4i32() {
; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding
; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 217, %dx.types.Handle [[BIND]]
Expand Down
Loading