15
15
#include " llvm/IR/Module.h"
16
16
#include " llvm/Support/DXILABI.h"
17
17
#include " llvm/Support/ErrorHandling.h"
18
+ #include < optional>
18
19
19
20
using namespace llvm ;
20
21
using namespace llvm ::dxil;
21
22
22
23
constexpr StringLiteral DXILOpNamePrefix = " dx.op." ;
23
24
24
25
namespace {
25
-
26
26
enum OverloadKind : uint16_t {
27
+ UNDEFINED = 0 ,
27
28
VOID = 1 ,
28
29
HALF = 1 << 1 ,
29
30
FLOAT = 1 << 2 ,
@@ -36,9 +37,27 @@ enum OverloadKind : uint16_t {
36
37
UserDefineType = 1 << 9 ,
37
38
ObjectType = 1 << 10 ,
38
39
};
40
+ struct Version {
41
+ unsigned Major = 0 ;
42
+ unsigned Minor = 0 ;
43
+ };
39
44
45
+ struct OpOverload {
46
+ Version DXILVersion;
47
+ uint16_t ValidTys;
48
+ };
40
49
} // namespace
41
50
51
+ struct OpStage {
52
+ Version DXILVersion;
53
+ uint32_t ValidStages;
54
+ };
55
+
56
+ struct OpAttribute {
57
+ Version DXILVersion;
58
+ uint32_t ValidAttrs;
59
+ };
60
+
42
61
static const char *getOverloadTypeName (OverloadKind Kind) {
43
62
switch (Kind) {
44
63
case OverloadKind::HALF:
@@ -58,12 +77,13 @@ static const char *getOverloadTypeName(OverloadKind Kind) {
58
77
case OverloadKind::I64:
59
78
return " i64" ;
60
79
case OverloadKind::VOID:
80
+ case OverloadKind::UNDEFINED:
81
+ return " void" ;
61
82
case OverloadKind::ObjectType:
62
83
case OverloadKind::UserDefineType:
63
84
break ;
64
85
}
65
86
llvm_unreachable (" invalid overload type for name" );
66
- return " void" ;
67
87
}
68
88
69
89
static OverloadKind getOverloadKind (Type *Ty) {
@@ -131,8 +151,9 @@ struct OpCodeProperty {
131
151
dxil::OpCodeClass OpCodeClass;
132
152
// Offset in DXILOpCodeClassNameTable.
133
153
unsigned OpCodeClassNameOffset;
134
- uint16_t OverloadTys;
135
- llvm::Attribute::AttrKind FuncAttr;
154
+ llvm::SmallVector<OpOverload> Overloads;
155
+ llvm::SmallVector<OpStage> Stages;
156
+ llvm::SmallVector<OpAttribute> Attributes;
136
157
int OverloadParamIndex; // parameter index which control the overload.
137
158
// When < 0, should be only 1 overload type.
138
159
unsigned NumOfParameters; // Number of parameters include return value.
@@ -221,6 +242,45 @@ static Type *getTypeFromParameterKind(ParameterKind Kind, Type *OverloadTy) {
221
242
return nullptr ;
222
243
}
223
244
245
+ static ShaderKind getShaderKindEnum (Triple::EnvironmentType EnvType) {
246
+ switch (EnvType) {
247
+ case Triple::Pixel:
248
+ return ShaderKind::pixel;
249
+ case Triple::Vertex:
250
+ return ShaderKind::vertex;
251
+ case Triple::Geometry:
252
+ return ShaderKind::geometry;
253
+ case Triple::Hull:
254
+ return ShaderKind::hull;
255
+ case Triple::Domain:
256
+ return ShaderKind::domain;
257
+ case Triple::Compute:
258
+ return ShaderKind::compute;
259
+ case Triple::Library:
260
+ return ShaderKind::library;
261
+ case Triple::RayGeneration:
262
+ return ShaderKind::raygeneration;
263
+ case Triple::Intersection:
264
+ return ShaderKind::intersection;
265
+ case Triple::AnyHit:
266
+ return ShaderKind::anyhit;
267
+ case Triple::ClosestHit:
268
+ return ShaderKind::closesthit;
269
+ case Triple::Miss:
270
+ return ShaderKind::miss;
271
+ case Triple::Callable:
272
+ return ShaderKind::callable;
273
+ case Triple::Mesh:
274
+ return ShaderKind::mesh;
275
+ case Triple::Amplification:
276
+ return ShaderKind::amplification;
277
+ default :
278
+ break ;
279
+ }
280
+ llvm_unreachable (
281
+ " Shader Kind Not Found - Invalid DXIL Environment Specified" );
282
+ }
283
+
224
284
// / Construct DXIL function type. This is the type of a function with
225
285
// / the following prototype
226
286
// / OverloadType dx.op.<opclass>.<return-type>(int opcode, <param types>)
@@ -232,7 +292,7 @@ static FunctionType *getDXILOpFunctionType(const OpCodeProperty *Prop,
232
292
Type *ReturnTy, Type *OverloadTy) {
233
293
SmallVector<Type *> ArgTys;
234
294
235
- auto ParamKinds = getOpCodeParameterKind (*Prop);
295
+ const ParameterKind * ParamKinds = getOpCodeParameterKind (*Prop);
236
296
237
297
// Add ReturnTy as return type of the function
238
298
ArgTys.emplace_back (ReturnTy);
@@ -249,17 +309,103 @@ static FunctionType *getDXILOpFunctionType(const OpCodeProperty *Prop,
249
309
ArgTys[0 ], ArrayRef<Type *>(&ArgTys[1 ], ArgTys.size () - 1 ), false );
250
310
}
251
311
312
+ // / Get index of the property from PropList valid for the most recent
313
+ // / DXIL version not greater than DXILVer.
314
+ // / PropList is expected to be sorted in ascending order of DXIL version.
315
+ template <typename T>
316
+ static std::optional<size_t > getPropIndex (ArrayRef<T> PropList,
317
+ const VersionTuple DXILVer) {
318
+ size_t Index = PropList.size () - 1 ;
319
+ for (auto Iter = PropList.rbegin (); Iter != PropList.rend ();
320
+ Iter++, Index--) {
321
+ const T &Prop = *Iter;
322
+ if (VersionTuple (Prop.DXILVersion .Major , Prop.DXILVersion .Minor ) <=
323
+ DXILVer) {
324
+ return Index;
325
+ }
326
+ }
327
+ return std::nullopt;
328
+ }
329
+
252
330
namespace llvm {
253
331
namespace dxil {
254
332
333
+ // No extra checks on TargetTriple need be performed to verify that the
334
+ // Triple is well-formed or that the target is supported since these checks
335
+ // would have been done at the time the module M is constructed in the earlier
336
+ // stages of compilation.
337
+ DXILOpBuilder::DXILOpBuilder (Module &M, IRBuilderBase &B) : M(M), B(B) {
338
+ Triple TT (Triple (M.getTargetTriple ()));
339
+ DXILVersion = TT.getDXILVersion ();
340
+ ShaderStage = TT.getEnvironment ();
341
+ // Ensure Environment type is known
342
+ if (ShaderStage == Triple::UnknownEnvironment) {
343
+ report_fatal_error (
344
+ Twine (DXILVersion.getAsString ()) +
345
+ " : Unknown Compilation Target Shader Stage specified " ,
346
+ /* gen_crash_diag*/ false );
347
+ }
348
+ }
349
+
255
350
CallInst *DXILOpBuilder::createDXILOpCall (dxil::OpCode OpCode, Type *ReturnTy,
256
351
Type *OverloadTy,
257
352
SmallVector<Value *> Args) {
353
+
258
354
const OpCodeProperty *Prop = getOpCodeProperty (OpCode);
355
+ std::optional<size_t > OlIndexOrErr =
356
+ getPropIndex (ArrayRef (Prop->Overloads ), DXILVersion);
357
+ if (!OlIndexOrErr.has_value ()) {
358
+ report_fatal_error (Twine (getOpCodeName (OpCode)) +
359
+ " : No valid overloads found for DXIL Version - " +
360
+ DXILVersion.getAsString (),
361
+ /* gen_crash_diag*/ false );
362
+ }
363
+ uint16_t ValidTyMask = Prop->Overloads [*OlIndexOrErr].ValidTys ;
259
364
260
365
OverloadKind Kind = getOverloadKind (OverloadTy);
261
- if ((Prop->OverloadTys & (uint16_t )Kind) == 0 ) {
262
- report_fatal_error (" Invalid Overload Type" , /* gen_crash_diag=*/ false );
366
+
367
+ // Check if the operation supports overload types and OverloadTy is valid
368
+ // per the specified types for the operation
369
+ if ((ValidTyMask != OverloadKind::UNDEFINED) &&
370
+ (ValidTyMask & (uint16_t )Kind) == 0 ) {
371
+ report_fatal_error (Twine (" Invalid Overload Type for DXIL operation - " ) +
372
+ getOpCodeName (OpCode),
373
+ /* gen_crash_diag=*/ false );
374
+ }
375
+
376
+ // Perform necessary checks to ensure Opcode is valid in the targeted shader
377
+ // kind
378
+ std::optional<size_t > StIndexOrErr =
379
+ getPropIndex (ArrayRef (Prop->Stages ), DXILVersion);
380
+ if (!StIndexOrErr.has_value ()) {
381
+ report_fatal_error (Twine (getOpCodeName (OpCode)) +
382
+ " : No valid stages found for DXIL Version - " +
383
+ DXILVersion.getAsString (),
384
+ /* gen_crash_diag*/ false );
385
+ }
386
+ uint16_t ValidShaderKindMask = Prop->Stages [*StIndexOrErr].ValidStages ;
387
+
388
+ // Ensure valid shader stage properties are specified
389
+ if (ValidShaderKindMask == ShaderKind::removed) {
390
+ report_fatal_error (
391
+ Twine (DXILVersion.getAsString ()) +
392
+ " : Unsupported Target Shader Stage for DXIL operation - " +
393
+ getOpCodeName (OpCode),
394
+ /* gen_crash_diag*/ false );
395
+ }
396
+
397
+ // Shader stage need not be validated since getShaderKindEnum() fails
398
+ // for unknown shader stage.
399
+
400
+ // Verify the target shader stage is valid for the DXIL operation
401
+ ShaderKind ModuleStagekind = getShaderKindEnum (ShaderStage);
402
+ if (!(ValidShaderKindMask & ModuleStagekind)) {
403
+ auto ShaderEnvStr = Triple::getEnvironmentTypeName (ShaderStage);
404
+ report_fatal_error (Twine (ShaderEnvStr) +
405
+ " : Invalid Shader Stage for DXIL operation - " +
406
+ getOpCodeName (OpCode) + " for DXIL Version " +
407
+ DXILVersion.getAsString (),
408
+ /* gen_crash_diag*/ false );
263
409
}
264
410
265
411
std::string DXILFnName = constructOverloadName (Kind, OverloadTy, *Prop);
@@ -282,40 +428,18 @@ Type *DXILOpBuilder::getOverloadTy(dxil::OpCode OpCode, FunctionType *FT) {
282
428
// If DXIL Op has no overload parameter, just return the
283
429
// precise return type specified.
284
430
if (Prop->OverloadParamIndex < 0 ) {
285
- auto &Ctx = FT->getContext ();
286
- switch (Prop->OverloadTys ) {
287
- case OverloadKind::VOID:
288
- return Type::getVoidTy (Ctx);
289
- case OverloadKind::HALF:
290
- return Type::getHalfTy (Ctx);
291
- case OverloadKind::FLOAT:
292
- return Type::getFloatTy (Ctx);
293
- case OverloadKind::DOUBLE:
294
- return Type::getDoubleTy (Ctx);
295
- case OverloadKind::I1:
296
- return Type::getInt1Ty (Ctx);
297
- case OverloadKind::I8:
298
- return Type::getInt8Ty (Ctx);
299
- case OverloadKind::I16:
300
- return Type::getInt16Ty (Ctx);
301
- case OverloadKind::I32:
302
- return Type::getInt32Ty (Ctx);
303
- case OverloadKind::I64:
304
- return Type::getInt64Ty (Ctx);
305
- default :
306
- llvm_unreachable (" invalid overload type" );
307
- return nullptr ;
308
- }
431
+ return FT->getReturnType ();
309
432
}
310
433
311
- // Prop->OverloadParamIndex is 0, overload type is FT->getReturnType().
434
+ // Consider FT->getReturnType() as default overload type, unless
435
+ // Prop->OverloadParamIndex != 0.
312
436
Type *OverloadType = FT->getReturnType ();
313
437
if (Prop->OverloadParamIndex != 0 ) {
314
438
// Skip Return Type.
315
439
OverloadType = FT->getParamType (Prop->OverloadParamIndex - 1 );
316
440
}
317
441
318
- auto ParamKinds = getOpCodeParameterKind (*Prop);
442
+ const ParameterKind * ParamKinds = getOpCodeParameterKind (*Prop);
319
443
auto Kind = ParamKinds[Prop->OverloadParamIndex ];
320
444
// For ResRet and CBufferRet, OverloadTy is in field of StructType.
321
445
if (Kind == ParameterKind::CBufferRet ||
0 commit comments