-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[DXIL] Add DXIL version-specific TableGen specification and implementation of DXIL Ops #97593
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
ef071d2
749c069
bc96291
e8bb0c6
5d2d02d
4f3be13
9b0f40e
ab19613
ecb8d1e
1c49dd0
17a9fe1
a1155f4
962d8a8
a43b1c9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,15 +15,16 @@ | |
#include "llvm/IR/Module.h" | ||
#include "llvm/Support/DXILABI.h" | ||
#include "llvm/Support/ErrorHandling.h" | ||
#include <optional> | ||
|
||
using namespace llvm; | ||
using namespace llvm::dxil; | ||
|
||
constexpr StringLiteral DXILOpNamePrefix = "dx.op."; | ||
|
||
namespace { | ||
|
||
enum OverloadKind : uint16_t { | ||
UNDEFINED = 0, | ||
VOID = 1, | ||
HALF = 1 << 1, | ||
FLOAT = 1 << 2, | ||
|
@@ -36,9 +37,27 @@ enum OverloadKind : uint16_t { | |
UserDefineType = 1 << 9, | ||
ObjectType = 1 << 10, | ||
}; | ||
struct Version { | ||
unsigned Major = 0; | ||
unsigned Minor = 0; | ||
}; | ||
|
||
struct OpOverload { | ||
Version DXILVersion; | ||
uint16_t ValidTys; | ||
}; | ||
} // namespace | ||
|
||
struct OpStage { | ||
Version DXILVersion; | ||
uint32_t ValidStages; | ||
}; | ||
|
||
struct OpAttribute { | ||
Version DXILVersion; | ||
uint32_t ValidAttrs; | ||
}; | ||
|
||
static const char *getOverloadTypeName(OverloadKind Kind) { | ||
switch (Kind) { | ||
case OverloadKind::HALF: | ||
|
@@ -58,12 +77,13 @@ static const char *getOverloadTypeName(OverloadKind Kind) { | |
case OverloadKind::I64: | ||
return "i64"; | ||
case OverloadKind::VOID: | ||
case OverloadKind::UNDEFINED: | ||
return "void"; | ||
case OverloadKind::ObjectType: | ||
case OverloadKind::UserDefineType: | ||
break; | ||
} | ||
llvm_unreachable("invalid overload type for name"); | ||
return "void"; | ||
} | ||
|
||
static OverloadKind getOverloadKind(Type *Ty) { | ||
|
@@ -131,8 +151,9 @@ struct OpCodeProperty { | |
dxil::OpCodeClass OpCodeClass; | ||
// Offset in DXILOpCodeClassNameTable. | ||
unsigned OpCodeClassNameOffset; | ||
uint16_t OverloadTys; | ||
llvm::Attribute::AttrKind FuncAttr; | ||
llvm::SmallVector<OpOverload> Overloads; | ||
llvm::SmallVector<OpStage> Stages; | ||
llvm::SmallVector<OpAttribute> Attributes; | ||
int OverloadParamIndex; // parameter index which control the overload. | ||
// When < 0, should be only 1 overload type. | ||
unsigned NumOfParameters; // Number of parameters include return value. | ||
|
@@ -221,6 +242,45 @@ static Type *getTypeFromParameterKind(ParameterKind Kind, Type *OverloadTy) { | |
return nullptr; | ||
} | ||
|
||
static ShaderKind getShaderKindEnum(Triple::EnvironmentType EnvType) { | ||
switch (EnvType) { | ||
case Triple::Pixel: | ||
return ShaderKind::pixel; | ||
case Triple::Vertex: | ||
return ShaderKind::vertex; | ||
case Triple::Geometry: | ||
return ShaderKind::geometry; | ||
case Triple::Hull: | ||
return ShaderKind::hull; | ||
case Triple::Domain: | ||
return ShaderKind::domain; | ||
case Triple::Compute: | ||
return ShaderKind::compute; | ||
case Triple::Library: | ||
return ShaderKind::library; | ||
case Triple::RayGeneration: | ||
return ShaderKind::raygeneration; | ||
case Triple::Intersection: | ||
return ShaderKind::intersection; | ||
case Triple::AnyHit: | ||
return ShaderKind::anyhit; | ||
case Triple::ClosestHit: | ||
return ShaderKind::closesthit; | ||
case Triple::Miss: | ||
return ShaderKind::miss; | ||
case Triple::Callable: | ||
return ShaderKind::callable; | ||
case Triple::Mesh: | ||
return ShaderKind::mesh; | ||
case Triple::Amplification: | ||
return ShaderKind::amplification; | ||
default: | ||
break; | ||
} | ||
llvm_unreachable( | ||
"Shader Kind Not Found - Invalid DXIL Environment Specified"); | ||
} | ||
|
||
/// Construct DXIL function type. This is the type of a function with | ||
/// the following prototype | ||
/// OverloadType dx.op.<opclass>.<return-type>(int opcode, <param types>) | ||
|
@@ -232,7 +292,7 @@ static FunctionType *getDXILOpFunctionType(const OpCodeProperty *Prop, | |
Type *ReturnTy, Type *OverloadTy) { | ||
SmallVector<Type *> ArgTys; | ||
|
||
auto ParamKinds = getOpCodeParameterKind(*Prop); | ||
const ParameterKind *ParamKinds = getOpCodeParameterKind(*Prop); | ||
|
||
// Add ReturnTy as return type of the function | ||
ArgTys.emplace_back(ReturnTy); | ||
|
@@ -249,17 +309,103 @@ static FunctionType *getDXILOpFunctionType(const OpCodeProperty *Prop, | |
ArgTys[0], ArrayRef<Type *>(&ArgTys[1], ArgTys.size() - 1), false); | ||
} | ||
|
||
/// Get index of the property from PropList valid for the most recent | ||
/// DXIL version not greater than DXILVer. | ||
/// PropList is expected to be sorted in ascending order of DXIL version. | ||
template <typename T> | ||
static std::optional<size_t> getPropIndex(ArrayRef<T> PropList, | ||
const VersionTuple DXILVer) { | ||
size_t Index = PropList.size() - 1; | ||
for (auto Iter = PropList.rbegin(); Iter != PropList.rend(); | ||
Iter++, Index--) { | ||
const T &Prop = *Iter; | ||
if (VersionTuple(Prop.DXILVersion.Major, Prop.DXILVersion.Minor) <= | ||
DXILVer) { | ||
return Index; | ||
} | ||
} | ||
return std::nullopt; | ||
} | ||
|
||
namespace llvm { | ||
namespace dxil { | ||
|
||
// No extra checks on TargetTriple need be performed to verify that the | ||
// Triple is well-formed or that the target is supported since these checks | ||
// would have been done at the time the module M is constructed in the earlier | ||
// stages of compilation. | ||
DXILOpBuilder::DXILOpBuilder(Module &M, IRBuilderBase &B) : M(M), B(B) { | ||
Triple TT(Triple(M.getTargetTriple())); | ||
DXILVersion = TT.getDXILVersion(); | ||
ShaderStage = TT.getEnvironment(); | ||
// Ensure Environment type is known | ||
if (ShaderStage == Triple::UnknownEnvironment) { | ||
report_fatal_error( | ||
Twine(DXILVersion.getAsString()) + | ||
": Unknown Compilation Target Shader Stage specified ", | ||
/*gen_crash_diag*/ false); | ||
} | ||
} | ||
|
||
CallInst *DXILOpBuilder::createDXILOpCall(dxil::OpCode OpCode, Type *ReturnTy, | ||
Type *OverloadTy, | ||
SmallVector<Value *> Args) { | ||
|
||
const OpCodeProperty *Prop = getOpCodeProperty(OpCode); | ||
std::optional<size_t> OlIndexOrErr = | ||
getPropIndex(ArrayRef(Prop->Overloads), DXILVersion); | ||
if (!OlIndexOrErr.has_value()) { | ||
report_fatal_error(Twine(getOpCodeName(OpCode)) + | ||
": No valid overloads found for DXIL Version - " + | ||
DXILVersion.getAsString(), | ||
/*gen_crash_diag*/ false); | ||
} | ||
uint16_t ValidTyMask = Prop->Overloads[*OlIndexOrErr].ValidTys; | ||
|
||
OverloadKind Kind = getOverloadKind(OverloadTy); | ||
if ((Prop->OverloadTys & (uint16_t)Kind) == 0) { | ||
report_fatal_error("Invalid Overload Type", /* gen_crash_diag=*/false); | ||
|
||
// Check if the operation supports overload types and OverloadTy is valid | ||
// per the specified types for the operation | ||
if ((ValidTyMask != OverloadKind::UNDEFINED) && | ||
(ValidTyMask & (uint16_t)Kind) == 0) { | ||
report_fatal_error(Twine("Invalid Overload Type for DXIL operation - ") + | ||
getOpCodeName(OpCode), | ||
/* gen_crash_diag=*/false); | ||
} | ||
|
||
// Perform necessary checks to ensure Opcode is valid in the targeted shader | ||
// kind | ||
std::optional<size_t> StIndexOrErr = | ||
getPropIndex(ArrayRef(Prop->Stages), DXILVersion); | ||
if (!StIndexOrErr.has_value()) { | ||
report_fatal_error(Twine(getOpCodeName(OpCode)) + | ||
": No valid stages found for DXIL Version - " + | ||
DXILVersion.getAsString(), | ||
/*gen_crash_diag*/ false); | ||
} | ||
uint16_t ValidShaderKindMask = Prop->Stages[*StIndexOrErr].ValidStages; | ||
|
||
// Ensure valid shader stage properties are specified | ||
if (ValidShaderKindMask == ShaderKind::removed) { | ||
report_fatal_error( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The clang code seems to have a diagnostic td file that these messages go into. Is it normal practice not to do that in llvm? Or is this more about the types of errors that mean they don't go into a centralized table? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There isn't really an equivalent to Clang's diagnostic tablegen in LLVM. Generally LLVM's errors are all fatal so they report this way (basically exiting the compiler). We may want to consider a larger design proposal for LLVM to allow passes to propagate errors up. This has been discussed in the past, but never really worked on. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There are so-called "Backend Diagnostics" in
For now I think the assert/fatal error approach is probably fine while we get things propped up a little better, but we will want to improve this over time. |
||
Twine(DXILVersion.getAsString()) + | ||
": Unsupported Target Shader Stage for DXIL operation - " + | ||
getOpCodeName(OpCode), | ||
/*gen_crash_diag*/ false); | ||
} | ||
|
||
// Shader stage need not be validated since getShaderKindEnum() fails | ||
// for unknown shader stage. | ||
|
||
// Verify the target shader stage is valid for the DXIL operation | ||
ShaderKind ModuleStagekind = getShaderKindEnum(ShaderStage); | ||
if (!(ValidShaderKindMask & ModuleStagekind)) { | ||
auto ShaderEnvStr = Triple::getEnvironmentTypeName(ShaderStage); | ||
report_fatal_error(Twine(ShaderEnvStr) + | ||
" : Invalid Shader Stage for DXIL operation - " + | ||
getOpCodeName(OpCode) + " for DXIL Version " + | ||
DXILVersion.getAsString(), | ||
/*gen_crash_diag*/ false); | ||
} | ||
|
||
std::string DXILFnName = constructOverloadName(Kind, OverloadTy, *Prop); | ||
|
@@ -282,40 +428,18 @@ Type *DXILOpBuilder::getOverloadTy(dxil::OpCode OpCode, FunctionType *FT) { | |
// If DXIL Op has no overload parameter, just return the | ||
// precise return type specified. | ||
if (Prop->OverloadParamIndex < 0) { | ||
auto &Ctx = FT->getContext(); | ||
switch (Prop->OverloadTys) { | ||
case OverloadKind::VOID: | ||
return Type::getVoidTy(Ctx); | ||
case OverloadKind::HALF: | ||
return Type::getHalfTy(Ctx); | ||
case OverloadKind::FLOAT: | ||
return Type::getFloatTy(Ctx); | ||
case OverloadKind::DOUBLE: | ||
return Type::getDoubleTy(Ctx); | ||
case OverloadKind::I1: | ||
return Type::getInt1Ty(Ctx); | ||
case OverloadKind::I8: | ||
return Type::getInt8Ty(Ctx); | ||
case OverloadKind::I16: | ||
return Type::getInt16Ty(Ctx); | ||
case OverloadKind::I32: | ||
return Type::getInt32Ty(Ctx); | ||
case OverloadKind::I64: | ||
return Type::getInt64Ty(Ctx); | ||
default: | ||
llvm_unreachable("invalid overload type"); | ||
return nullptr; | ||
} | ||
return FT->getReturnType(); | ||
} | ||
|
||
// Prop->OverloadParamIndex is 0, overload type is FT->getReturnType(). | ||
// Consider FT->getReturnType() as default overload type, unless | ||
// Prop->OverloadParamIndex != 0. | ||
Type *OverloadType = FT->getReturnType(); | ||
if (Prop->OverloadParamIndex != 0) { | ||
// Skip Return Type. | ||
OverloadType = FT->getParamType(Prop->OverloadParamIndex - 1); | ||
} | ||
|
||
auto ParamKinds = getOpCodeParameterKind(*Prop); | ||
const ParameterKind *ParamKinds = getOpCodeParameterKind(*Prop); | ||
auto Kind = ParamKinds[Prop->OverloadParamIndex]; | ||
// For ResRet and CBufferRet, OverloadTy is in field of StructType. | ||
if (Kind == ParameterKind::CBufferRet || | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should this be a diagnostic? I notice in the place that this function is called, UnknownEnvironment will cause a diagnostic.