[DirectX] Lower @llvm.dx.typedBufferLoad to DXIL ops #104252
Conversation
Created using spr 1.3.5-bogner [skip ci]
Created using spr 1.3.5-bogner
@llvm/pr-subscribers-llvm-ir @llvm/pr-subscribers-backend-directx
Author: Justin Bogner (bogner)
Changes: The `@llvm.dx.typedBufferLoad` intrinsic is lowered to `@dx.op.bufferLoad`. There's some complexity here due to translating from a vector return type to a named struct and trying to avoid excessive IR coming out of that. Note that this change includes a bit of a hack in how it deals with `getOverloadKind` for the `dx.ResRet` types - we need to adjust how we deal with operation overloads to generate a table directly rather than proxy through the OverloadKind enum, but that's left for a later change here.
Full diff: https://github.com/llvm/llvm-project/pull/104252.diff
7 Files Affected:
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index ca3682fa47767..d817b610fa71a 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -30,6 +30,10 @@ def int_dx_handle_fromBinding
[llvm_i32_ty, llvm_i32_ty, llvm_i32_ty, llvm_i32_ty, llvm_i1_ty],
[IntrNoMem]>;
+def int_dx_typedBufferLoad
+ : DefaultAttrsIntrinsic<[llvm_anyvector_ty],
+ [llvm_any_ty, llvm_i32_ty]>;
+
// Cast between target extension handle types and dxil-style opaque handles
def int_dx_cast_handle : Intrinsic<[llvm_any_ty], [llvm_any_ty]>;
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 31fee04d82158..b114148f84e84 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -40,7 +40,10 @@ def Int64Ty : DXILOpParamType;
def HalfTy : DXILOpParamType;
def FloatTy : DXILOpParamType;
def DoubleTy : DXILOpParamType;
-def ResRetTy : DXILOpParamType;
+def ResRetHalfTy : DXILOpParamType;
+def ResRetFloatTy : DXILOpParamType;
+def ResRetInt16Ty : DXILOpParamType;
+def ResRetInt32Ty : DXILOpParamType;
def HandleTy : DXILOpParamType;
def ResBindTy : DXILOpParamType;
def ResPropsTy : DXILOpParamType;
@@ -683,6 +686,17 @@ def CreateHandle : DXILOp<57, createHandle> {
let stages = [Stages<DXIL1_0, [all_stages]>];
}
+def BufferLoad : DXILOp<68, bufferLoad> {
+ let Doc = "reads from a TypedBuffer";
+ // Handle, Coord0, Coord1
+ let arguments = [HandleTy, Int32Ty, Int32Ty];
+ let result = OverloadTy;
+ let overloads =
+ [Overloads<DXIL1_0,
+ [ResRetHalfTy, ResRetFloatTy, ResRetInt16Ty, ResRetInt32Ty]>];
+ let stages = [Stages<DXIL1_0, [all_stages]>];
+}
+
def ThreadId : DXILOp<93, threadId> {
let Doc = "Reads the thread ID";
let LLVMIntrinsic = int_dx_thread_id;
diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
index 692af1b359ced..246e32c264dc9 100644
--- a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
@@ -120,8 +120,15 @@ static OverloadKind getOverloadKind(Type *Ty) {
}
case Type::PointerTyID:
return OverloadKind::UserDefineType;
- case Type::StructTyID:
+ case Type::StructTyID: {
+ // TODO: This is a hack. As described in DXILEmitter.cpp, we need to rework
+ // how we're handling overloads and remove the `OverloadKind` proxy enum.
+ StructType *ST = cast<StructType>(Ty);
+ if (ST->hasName() && ST->getName().starts_with("dx.types.ResRet"))
+ return getOverloadKind(ST->getElementType(0));
+
return OverloadKind::ObjectType;
+ }
default:
llvm_unreachable("invalid overload type");
return OverloadKind::VOID;
@@ -195,10 +202,11 @@ static StructType *getOrCreateStructType(StringRef Name,
return StructType::create(Ctx, EltTys, Name);
}
-static StructType *getResRetType(Type *OverloadTy, LLVMContext &Ctx) {
- OverloadKind Kind = getOverloadKind(OverloadTy);
+static StructType *getResRetType(Type *ElementTy) {
+ LLVMContext &Ctx = ElementTy->getContext();
+ OverloadKind Kind = getOverloadKind(ElementTy);
std::string TypeName = constructOverloadTypeName(Kind, "dx.types.ResRet.");
- Type *FieldTypes[5] = {OverloadTy, OverloadTy, OverloadTy, OverloadTy,
+ Type *FieldTypes[5] = {ElementTy, ElementTy, ElementTy, ElementTy,
Type::getInt32Ty(Ctx)};
return getOrCreateStructType(TypeName, FieldTypes, Ctx);
}
@@ -248,8 +256,14 @@ static Type *getTypeFromOpParamType(OpParamType Kind, LLVMContext &Ctx,
return Type::getInt64Ty(Ctx);
case OpParamType::OverloadTy:
return OverloadTy;
- case OpParamType::ResRetTy:
- return getResRetType(OverloadTy, Ctx);
+ case OpParamType::ResRetHalfTy:
+ return getResRetType(Type::getHalfTy(Ctx));
+ case OpParamType::ResRetFloatTy:
+ return getResRetType(Type::getFloatTy(Ctx));
+ case OpParamType::ResRetInt16Ty:
+ return getResRetType(Type::getInt16Ty(Ctx));
+ case OpParamType::ResRetInt32Ty:
+ return getResRetType(Type::getInt32Ty(Ctx));
case OpParamType::HandleTy:
return getHandleType(Ctx);
case OpParamType::ResBindTy:
@@ -391,6 +405,7 @@ Expected<CallInst *> DXILOpBuilder::tryCreateOp(dxil::OpCode OpCode,
return makeOpError(OpCode, "Wrong number of arguments");
OverloadTy = Args[ArgIndex]->getType();
}
+
FunctionType *DXILOpFT =
getDXILOpFunctionType(OpCode, M.getContext(), OverloadTy);
@@ -451,6 +466,10 @@ CallInst *DXILOpBuilder::createOp(dxil::OpCode OpCode, ArrayRef<Value *> Args,
return *Result;
}
+StructType *DXILOpBuilder::getResRetType(Type *ElementTy) {
+ return ::getResRetType(ElementTy);
+}
+
StructType *DXILOpBuilder::getHandleType() {
return ::getHandleType(IRB.getContext());
}
diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.h b/llvm/lib/Target/DirectX/DXILOpBuilder.h
index 4a55a8ac9eadb..a68f0c43f67af 100644
--- a/llvm/lib/Target/DirectX/DXILOpBuilder.h
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.h
@@ -46,6 +46,8 @@ class DXILOpBuilder {
Expected<CallInst *> tryCreateOp(dxil::OpCode Op, ArrayRef<Value *> Args,
Type *RetTy = nullptr);
+ /// Get a `%dx.types.ResRet` type with the given element type.
+ StructType *getResRetType(Type *ElementTy);
/// Get the `%dx.types.Handle` type.
StructType *getHandleType();
diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
index ab18c57efa307..46dfc905b5875 100644
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -236,6 +236,59 @@ class OpLowerer {
lowerToBindAndAnnotateHandle(F);
}
+ void lowerTypedBufferLoad(Function &F) {
+ IRBuilder<> &IRB = OpBuilder.getIRB();
+ Type *Int32Ty = IRB.getInt32Ty();
+
+ replaceFunction(F, [&](CallInst *CI) -> Error {
+ IRB.SetInsertPoint(CI);
+
+ Value *Handle =
+ createTmpHandleCast(CI->getArgOperand(0), OpBuilder.getHandleType());
+ Value *Index0 = CI->getArgOperand(1);
+ Value *Index1 = UndefValue::get(Int32Ty);
+ Type *RetTy = OpBuilder.getResRetType(CI->getType()->getScalarType());
+
+ std::array<Value *, 3> Args{Handle, Index0, Index1};
+ Expected<CallInst *> OpCall =
+ OpBuilder.tryCreateOp(OpCode::BufferLoad, Args, RetTy);
+ if (Error E = OpCall.takeError())
+ return E;
+
+ std::array<Value *, 4> Extracts = {};
+
+ // We've switched the return type from a vector to a struct, but at this
+ // point most vectors have probably already been scalarized. Try to
+ // forward arguments directly rather than inserting into and immediately
+ // extracting from a vector.
+ for (Use &U : make_early_inc_range(CI->uses()))
+ if (auto *EEI = dyn_cast<ExtractElementInst>(U.getUser()))
+ if (auto *Index = dyn_cast<ConstantInt>(EEI->getIndexOperand())) {
+ size_t IndexVal = Index->getZExtValue();
+ assert(IndexVal < 4 && "Index into buffer load out of range");
+ if (!Extracts[IndexVal])
+ Extracts[IndexVal] = IRB.CreateExtractValue(*OpCall, IndexVal);
+ EEI->replaceAllUsesWith(Extracts[IndexVal]);
+ EEI->eraseFromParent();
+ }
+
+ // If there are still uses then we need to create a vector.
+ if (!CI->use_empty()) {
+ for (int I = 0, E = 4; I != E; ++I)
+ if (!Extracts[I])
+ Extracts[I] = IRB.CreateExtractValue(*OpCall, I);
+
+ Value *Vec = UndefValue::get(CI->getType());
+ for (int I = 0, E = 4; I != E; ++I)
+ Vec = IRB.CreateInsertElement(Vec, Extracts[I], I);
+ CI->replaceAllUsesWith(Vec);
+ }
+
+ CI->eraseFromParent();
+ return Error::success();
+ });
+ }
+
bool lowerIntrinsics() {
bool Updated = false;
@@ -253,6 +306,10 @@ class OpLowerer {
#include "DXILOperation.inc"
case Intrinsic::dx_handle_fromBinding:
lowerHandleFromBinding(F);
+ break;
+ case Intrinsic::dx_typedBufferLoad:
+ lowerTypedBufferLoad(F);
+ break;
}
Updated = true;
}
diff --git a/llvm/test/CodeGen/DirectX/BufferLoad.ll b/llvm/test/CodeGen/DirectX/BufferLoad.ll
new file mode 100644
index 0000000000000..c3bb96dbdf909
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/BufferLoad.ll
@@ -0,0 +1,102 @@
+; RUN: opt -S -dxil-op-lower %s | FileCheck %s
+
+target triple = "dxil-pc-shadermodel6.6-compute"
+
+declare void @scalar_user(float)
+declare void @vector_user(<4 x float>)
+
+define void @loadfloats() {
+ ; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding
+ ; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 217, %dx.types.Handle [[BIND]]
+ %buffer = call target("dx.TypedBuffer", <4 x float>, 0, 0, 0)
+ @llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4f32_0_0_0(
+ i32 0, i32 0, i32 1, i32 0, i1 false)
+
+ ; The temporary casts should all have been cleaned up
+ ; CHECK-NOT: %dx.cast_handle
+
+ ; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef)
+ %data0 = call <4 x float> @llvm.dx.typedBufferLoad(
+ target("dx.TypedBuffer", <4 x float>, 0, 0, 0) %buffer, i32 0)
+
+ ; The extract order depends on the users, so don't enforce that here.
+ ; CHECK-DAG: extractvalue %dx.types.ResRet.f32 [[DATA0]], 0
+ %data0_0 = extractelement <4 x float> %data0, i32 0
+ ; CHECK-DAG: extractvalue %dx.types.ResRet.f32 [[DATA0]], 2
+ %data0_2 = extractelement <4 x float> %data0, i32 2
+
+ ; If all of the uses are extracts, we skip creating a vector
+ ; CHECK-NOT: insertelement
+ call void @scalar_user(float %data0_0)
+ call void @scalar_user(float %data0_2)
+
+ ; CHECK: [[DATA4:%.*]] = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle [[HANDLE]], i32 4, i32 undef)
+ %data4 = call <4 x float> @llvm.dx.typedBufferLoad(
+ target("dx.TypedBuffer", <4 x float>, 0, 0, 0) %buffer, i32 4)
+
+ ; CHECK: extractvalue %dx.types.ResRet.f32 [[DATA4]], 0
+ ; CHECK: extractvalue %dx.types.ResRet.f32 [[DATA4]], 1
+ ; CHECK: extractvalue %dx.types.ResRet.f32 [[DATA4]], 2
+ ; CHECK: extractvalue %dx.types.ResRet.f32 [[DATA4]], 3
+ ; CHECK: insertelement <4 x float> undef
+ ; CHECK: insertelement <4 x float>
+ ; CHECK: insertelement <4 x float>
+ ; CHECK: insertelement <4 x float>
+ call void @vector_user(<4 x float> %data4)
+
+ ; CHECK: [[DATA12:%.*]] = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle [[HANDLE]], i32 12, i32 undef)
+ %data12 = call <4 x float> @llvm.dx.typedBufferLoad(
+ target("dx.TypedBuffer", <4 x float>, 0, 0, 0) %buffer, i32 12)
+
+ ; CHECK: [[DATA12_3:%.*]] = extractvalue %dx.types.ResRet.f32 [[DATA12]], 3
+ %data12_3 = extractelement <4 x float> %data12, i32 3
+
+ ; If there are a mix of users we need the vector, but extracts are direct
+ ; CHECK: call void @scalar_user(float [[DATA12_3]])
+ call void @scalar_user(float %data12_3)
+ call void @vector_user(<4 x float> %data12)
+
+ ret void
+}
+
+define void @loadint() {
+ ; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding
+ ; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 217, %dx.types.Handle [[BIND]]
+ %buffer = call target("dx.TypedBuffer", <4 x i32>, 0, 0, 0)
+ @llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4i32_0_0_0(
+ i32 0, i32 0, i32 1, i32 0, i1 false)
+
+ ; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.i32 @dx.op.bufferLoad.i32(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef)
+ %data0 = call <4 x i32> @llvm.dx.typedBufferLoad(
+ target("dx.TypedBuffer", <4 x i32>, 0, 0, 0) %buffer, i32 0)
+
+ ret void
+}
+
+define void @loadhalf() {
+ ; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding
+ ; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 217, %dx.types.Handle [[BIND]]
+ %buffer = call target("dx.TypedBuffer", <4 x half>, 0, 0, 0)
+ @llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4f16_0_0_0(
+ i32 0, i32 0, i32 1, i32 0, i1 false)
+
+ ; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.f16 @dx.op.bufferLoad.f16(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef)
+ %data0 = call <4 x half> @llvm.dx.typedBufferLoad(
+ target("dx.TypedBuffer", <4 x half>, 0, 0, 0) %buffer, i32 0)
+
+ ret void
+}
+
+define void @loadi16() {
+ ; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding
+ ; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 217, %dx.types.Handle [[BIND]]
+ %buffer = call target("dx.TypedBuffer", <4 x i16>, 0, 0, 0)
+ @llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4i16_0_0_0(
+ i32 0, i32 0, i32 1, i32 0, i1 false)
+
+ ; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.i16 @dx.op.bufferLoad.i16(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef)
+ %data0 = call <4 x i16> @llvm.dx.typedBufferLoad(
+ target("dx.TypedBuffer", <4 x i16>, 0, 0, 0) %buffer, i32 0)
+
+ ret void
+}
diff --git a/llvm/utils/TableGen/DXILEmitter.cpp b/llvm/utils/TableGen/DXILEmitter.cpp
index 9cc1b5ccb8acb..332706f7e3e57 100644
--- a/llvm/utils/TableGen/DXILEmitter.cpp
+++ b/llvm/utils/TableGen/DXILEmitter.cpp
@@ -187,7 +187,11 @@ static StringRef getOverloadKindStr(const Record *R) {
.Case("Int8Ty", "OverloadKind::I8")
.Case("Int16Ty", "OverloadKind::I16")
.Case("Int32Ty", "OverloadKind::I32")
- .Case("Int64Ty", "OverloadKind::I64");
+ .Case("Int64Ty", "OverloadKind::I64")
+ .Case("ResRetHalfTy", "OverloadKind::HALF")
+ .Case("ResRetFloatTy", "OverloadKind::FLOAT")
+ .Case("ResRetInt16Ty", "OverloadKind::I16")
+ .Case("ResRetInt32Ty", "OverloadKind::I32");
}
/// Return a string representation of valid overload information denoted
Depends on #104251
The `@llvm.dx.typedBufferLoad` intrinsic is lowered to `@dx.op.bufferLoad`. There's some complexity here due to translating from a vector return type to a named struct and trying to avoid excessive IR coming out of that. Note that this change includes a bit of a hack in how it deals with `getOverloadKind` for the `dx.ResRet` types - we need to adjust how we deal with operation overloads to generate a table directly rather than proxy through the OverloadKind enum, but that's left for a later change here. Pull Request: llvm#104252
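To illustrate the vector-to-struct translation, here is a rough sketch based on the `BufferLoad.ll` test in the diff above (handle creation omitted and value names are illustrative). A load whose only users are constant-index extracts:

  %data = call <4 x float> @llvm.dx.typedBufferLoad(
      target("dx.TypedBuffer", <4 x float>, 0, 0, 0) %buffer, i32 0)
  %x = extractelement <4 x float> %data, i32 0

lowers to a ResRet struct plus a direct extractvalue, with no vector rebuilt:

  %ret = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(
      i32 68, %dx.types.Handle %handle, i32 0, i32 undef)
  %x = extractvalue %dx.types.ResRet.f32 %ret, 0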
Overall, LGTM
  let result = OverloadTy;
  let overloads =
      [Overloads<DXIL1_0,
                 [ResRetHalfTy, ResRetFloatTy, ResRetInt16Ty, ResRetInt32Ty]>];
Are the 16-bit overloads valid in DXIL 1.0? I suppose maybe they are used to represent the minprec types, but true 16-bit types only came in with DXIL 1.2 I think.
My understanding from a discussion about this quite a while ago was that 16-bit types are valid DXIL retroactively to DXIL 1.0, but it isn't actually possible to enable 16-bit types until circa SM6.2 / DXIL 1.2. I'm not 100% sure about that, though.
The 16-bit overloads were always valid in DXIL 1.0, but they didn't actually mean 16-bit types, they meant the min16{float|int|uint} types. This is one of the things that's really wonky about DXIL defining interpretations of LLVM IR that conflicted with LLVM's core definition.
I think the code here is accurate to what we need for that.
Created using spr 1.3.5-bogner [skip ci]
Created using spr 1.3.5-bogner
All looks good to me.
LGTM, some nits.
Created using spr 1.3.5-bogner
Created using spr 1.3.5-bogner
The `@llvm.dx.typedBufferLoad` intrinsic is lowered to `@dx.op.bufferLoad`. There's some complexity here in translating to scalarized IR, which I've abstracted out into a function that should be useful for samples, gathers, and CBuffer loads.

I've also updated the DXILResources.rst docs to match what I'm doing here and the proposal in llvm/wg-hlsl#59. I've removed the content about stores and raw buffers for now with the expectation that it will be added along with that work.

Note that this change includes a bit of a hack in how it deals with `getOverloadKind` for the `dx.ResRet` types - we need to adjust how we deal with operation overloads to generate a table directly rather than proxy through the OverloadKind enum, but that's left for a later change here.

Part of #91367
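To make the scalarization behavior concrete, here is a sketch derived from the mixed-use case in `BufferLoad.ll` (value names are illustrative, not what the pass emits verbatim). When a load has both a scalar and a vector user:

  %data = call <4 x float> @llvm.dx.typedBufferLoad(
      target("dx.TypedBuffer", <4 x float>, 0, 0, 0) %buffer, i32 12)
  %data_3 = extractelement <4 x float> %data, i32 3
  call void @scalar_user(float %data_3)
  call void @vector_user(<4 x float> %data)

the lowering forwards the constant-index extract straight to an extractvalue of the ResRet struct, and rebuilds a vector only for the remaining vector user:

  %ret = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(
      i32 68, %dx.types.Handle %handle, i32 12, i32 undef)
  %e3 = extractvalue %dx.types.ResRet.f32 %ret, 3
  call void @scalar_user(float %e3)
  ; the vector is reassembled only because @vector_user needs it
  %e0 = extractvalue %dx.types.ResRet.f32 %ret, 0
  %e1 = extractvalue %dx.types.ResRet.f32 %ret, 1
  %e2 = extractvalue %dx.types.ResRet.f32 %ret, 2
  %v0 = insertelement <4 x float> undef, float %e0, i32 0
  %v1 = insertelement <4 x float> %v0, float %e1, i32 1
  %v2 = insertelement <4 x float> %v1, float %e2, i32 2
  %v3 = insertelement <4 x float> %v2, float %e3, i32 3
  call void @vector_user(<4 x float> %v3)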