Skip to content

Commit 9724728

Browse files
committed
Align specification of return type and parameter type fields of
DXIL Op mapping with those of TableGen class Intrinsic. A void return type of LLVM Intrinsic is represented as [] in its TableGen description record. Currently, a void return type of DXIL Operation is represented as [llvm_void_ty]. In addition, return and parameter types are recorded as a single list with an understanding that element at index `0` is the return type. These changes leverage and align DXIL Op type specification with the type specification of the LLVM Intrinsic. As a result, return and parameter types are now specified as two separate lists no longer requiring a different representation for void return type. Additionally, type specification would be more succinct yet equally informative for DXIL Op records for which the same LLVM Intrinsics types are also valid. Added a test to verify lowering of LLVM intrinsic with void return. Barrier intrinsic has a void return type. Specification of its DXIL Op can inherit the types of this intrinsic. The test verifies the changes. Move OverloadKind to DXILABI.h. Update definition names of enum Overload to follow naming conventions.
1 parent 6d3ec56 commit 9724728

File tree

6 files changed

+260
-176
lines changed

6 files changed

+260
-176
lines changed

llvm/include/llvm/IR/IntrinsicsDirectX.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ def int_dx_thread_id : Intrinsic<[llvm_i32_ty], [llvm_i32_ty], [IntrNoMem, IntrW
1616
def int_dx_group_id : Intrinsic<[llvm_i32_ty], [llvm_i32_ty], [IntrNoMem, IntrWillReturn]>;
1717
def int_dx_thread_id_in_group : Intrinsic<[llvm_i32_ty], [llvm_i32_ty], [IntrNoMem, IntrWillReturn]>;
1818
def int_dx_flattened_thread_id_in_group : Intrinsic<[llvm_i32_ty], [], [IntrNoMem, IntrWillReturn]>;
19+
def int_dx_barrier : Intrinsic<[], [llvm_i32_ty], [IntrNoDuplicate, IntrWillReturn]>;
1920

2021
def int_dx_create_handle : ClangBuiltin<"__builtin_hlsl_create_handle">,
2122
Intrinsic<[ llvm_ptr_ty ], [llvm_i8_ty], [IntrWillReturn]>;

llvm/include/llvm/Support/DXILABI.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,21 @@ enum class ParameterKind : uint8_t {
3939
DXILHandle,
4040
};
4141

42+
enum OverloadKind : uint16_t {
43+
Invalid = 0,
44+
Void = 1,
45+
Half = 1 << 1,
46+
Float = 1 << 2,
47+
Double = 1 << 3,
48+
I1 = 1 << 4,
49+
I8 = 1 << 5,
50+
I16 = 1 << 6,
51+
I32 = 1 << 7,
52+
I64 = 1 << 8,
53+
UserDefineType = 1 << 9,
54+
ObjectType = 1 << 10,
55+
};
56+
4257
/// The kind of resource for an SRV or UAV resource. Sometimes referred to as
4358
/// "Shape" in the DXIL docs.
4459
enum class ResourceKind : uint32_t {

llvm/lib/Target/DirectX/DXIL.td

Lines changed: 39 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -240,58 +240,63 @@ class DXILOpMappingBase {
240240
DXILOpClass OpClass = UnknownOpClass;// Class of DXIL Operation.
241241
Intrinsic LLVMIntrinsic = ?; // LLVM Intrinsic DXIL Operation maps to
242242
string Doc = ""; // A short description of the operation
243-
list<LLVMType> OpTypes = ?; // Valid types of DXIL Operation in the
244-
// format [returnTy, param1ty, ...]
243+
// The following fields denote the same semantics as those of Intrinsic class
244+
// and are initialized with the same values as those of LLVMIntrinsic unless
245+
// overridden in the definition of a record.
246+
list<LLVMType> OpRetTypes = ?; // Valid return types of DXIL Operation
247+
list<LLVMType> OpParamTypes = ?; // Valid parameter types of DXIL Operation
245248
}
246249

247250
class DXILOpMapping<int opCode, DXILOpClass opClass,
248251
Intrinsic intrinsic, string doc,
249-
list<LLVMType> opTys = []> : DXILOpMappingBase {
252+
list<LLVMType> retTys = [],
253+
list<LLVMType> paramTys = []> : DXILOpMappingBase {
250254
int OpCode = opCode; // Opcode corresponding to DXIL Operation
251255
DXILOpClass OpClass = opClass; // Class of DXIL Operation.
252256
Intrinsic LLVMIntrinsic = intrinsic; // LLVM Intrinsic the DXIL Operation maps
253257
string Doc = doc; // to a short description of the operation
254-
list<LLVMType> OpTypes = !if(!eq(!size(opTys), 0), LLVMIntrinsic.Types, opTys);
258+
list<LLVMType> OpRetTypes = !if(!eq(!size(retTys), 0), LLVMIntrinsic.RetTypes, retTys);
259+
list<LLVMType> OpParamTypes = !if(!eq(!size(paramTys), 0), LLVMIntrinsic.ParamTypes, paramTys);
255260
}
256261

257262
// Concrete definition of DXIL Operation mapping to corresponding LLVM intrinsic
258263
def Abs : DXILOpMapping<6, unary, int_fabs,
259264
"Returns the absolute value of the input.">;
260265
def IsInf : DXILOpMapping<9, isSpecialFloat, int_dx_isinf,
261266
"Determines if the specified value is infinite.",
262-
[llvm_i1_ty, llvm_halforfloat_ty]>;
267+
[llvm_i1_ty], [llvm_halforfloat_ty]>;
263268
def Cos : DXILOpMapping<12, unary, int_cos,
264269
"Returns cosine(theta) for theta in radians.",
265-
[llvm_halforfloat_ty, LLVMMatchType<0>]>;
270+
[llvm_halforfloat_ty], [LLVMMatchType<0>]>;
266271
def Sin : DXILOpMapping<13, unary, int_sin,
267272
"Returns sine(theta) for theta in radians.",
268-
[llvm_halforfloat_ty, LLVMMatchType<0>]>;
273+
[llvm_halforfloat_ty], [LLVMMatchType<0>]>;
269274
def Exp2 : DXILOpMapping<21, unary, int_exp2,
270275
"Returns the base 2 exponential, or 2**x, of the specified value."
271276
"exp2(x) = 2**x.",
272-
[llvm_halforfloat_ty, LLVMMatchType<0>]>;
277+
[llvm_halforfloat_ty], [LLVMMatchType<0>]>;
273278
def Frac : DXILOpMapping<22, unary, int_dx_frac,
274279
"Returns a fraction from 0 to 1 that represents the "
275280
"decimal part of the input.",
276-
[llvm_halforfloat_ty, LLVMMatchType<0>]>;
281+
[llvm_halforfloat_ty], [LLVMMatchType<0>]>;
277282
def Log2 : DXILOpMapping<23, unary, int_log2,
278283
"Returns the base-2 logarithm of the specified value.",
279-
[llvm_halforfloat_ty, LLVMMatchType<0>]>;
284+
[llvm_halforfloat_ty], [LLVMMatchType<0>]>;
280285
def Sqrt : DXILOpMapping<24, unary, int_sqrt,
281286
"Returns the square root of the specified floating-point"
282287
"value, per component.",
283-
[llvm_halforfloat_ty, LLVMMatchType<0>]>;
288+
[llvm_halforfloat_ty], [LLVMMatchType<0>]>;
284289
def RSqrt : DXILOpMapping<25, unary, int_dx_rsqrt,
285290
"Returns the reciprocal of the square root of the specified value."
286291
"rsqrt(x) = 1 / sqrt(x).",
287-
[llvm_halforfloat_ty, LLVMMatchType<0>]>;
292+
[llvm_halforfloat_ty], [LLVMMatchType<0>]>;
288293
def Round : DXILOpMapping<26, unary, int_round,
289294
"Returns the input rounded to the nearest integer"
290295
"within a floating-point type.",
291-
[llvm_halforfloat_ty, LLVMMatchType<0>]>;
296+
[llvm_halforfloat_ty], [LLVMMatchType<0>]>;
292297
def Floor : DXILOpMapping<27, unary, int_floor,
293298
"Returns the largest integer that is less than or equal to the input.",
294-
[llvm_halforfloat_ty, LLVMMatchType<0>]>;
299+
[llvm_halforfloat_ty], [LLVMMatchType<0>]>;
295300
def FMax : DXILOpMapping<35, binary, int_maxnum,
296301
"Float maximum. FMax(a,b) = a > b ? a : b">;
297302
def FMin : DXILOpMapping<36, binary, int_minnum,
@@ -305,20 +310,28 @@ def UMax : DXILOpMapping<39, binary, int_umax,
305310
def UMin : DXILOpMapping<40, binary, int_umin,
306311
"Unsigned integer minimum. UMin(a,b) = a < b ? a : b">;
307312
def FMad : DXILOpMapping<46, tertiary, int_fmuladd,
308-
"Floating point arithmetic multiply/add operation. fmad(m,a,b) = m * a + b.">;
313+
"Floating point arithmetic multiply/add operation. "
314+
"fmad(m,a,b) = m * a + b.">;
309315
def IMad : DXILOpMapping<48, tertiary, int_dx_imad,
310-
"Signed integer arithmetic multiply/add operation. imad(m,a,b) = m * a + b.">;
316+
"Signed integer arithmetic multiply/add operation. "
317+
"imad(m,a,b) = m * a + b.">;
311318
def UMad : DXILOpMapping<49, tertiary, int_dx_umad,
312-
"Unsigned integer arithmetic multiply/add operation. umad(m,a,b) = m * a + b.">;
313-
let OpTypes = !listconcat([llvm_halforfloat_ty], !listsplat(llvm_halforfloat_ty, 4)) in
314-
def Dot2 : DXILOpMapping<54, dot2, int_dx_dot2,
315-
"dot product of two float vectors Dot(a,b) = a[0]*b[0] + ... + a[n]*b[n] where n is between 0 and 1">;
316-
let OpTypes = !listconcat([llvm_halforfloat_ty], !listsplat(llvm_halforfloat_ty, 6)) in
317-
def Dot3 : DXILOpMapping<55, dot3, int_dx_dot3,
318-
"dot product of two float vectors Dot(a,b) = a[0]*b[0] + ... + a[n]*b[n] where n is between 0 and 2">;
319-
let OpTypes = !listconcat([llvm_halforfloat_ty], !listsplat(llvm_halforfloat_ty, 8)) in
320-
def Dot4 : DXILOpMapping<56, dot4, int_dx_dot4,
321-
"dot product of two float vectors Dot(a,b) = a[0]*b[0] + ... + a[n]*b[n] where n is between 0 and 3">;
319+
"Unsigned integer arithmetic multiply/add operation. "
320+
"umad(m,a,b) = m * a + b.">;
321+
def Dot2 : DXILOpMapping<54, dot2, int_dx_dot2,
322+
"dot product of two float vectors Dot(a,b) = a[0]*b[0]"
323+
" + ... + a[n]*b[n] where n is between 0 and 1",
324+
[llvm_halforfloat_ty], !listsplat(llvm_halforfloat_ty, 4)>;
325+
def Dot3 : DXILOpMapping<55, dot3, int_dx_dot3,
326+
"dot product of two float vectors Dot(a,b) = a[0]*b[0]"
327+
" + ... + a[n]*b[n] where n is between 0 and 2",
328+
[llvm_halforfloat_ty], !listsplat(llvm_halforfloat_ty, 6)>;
329+
def Dot4 : DXILOpMapping<56, dot4, int_dx_dot4,
330+
"dot product of two float vectors Dot(a,b) = a[0]*b[0]"
331+
" + ... + a[n]*b[n] where n is between 0 and 3",
332+
[llvm_halforfloat_ty], !listsplat(llvm_halforfloat_ty, 8)>;
333+
def Barrier : DXILOpMapping<80, barrier, int_dx_barrier,
334+
"Inserts a memory barrier in the shader">;
322335
def ThreadId : DXILOpMapping<93, threadId, int_dx_thread_id,
323336
"Reads the thread ID">;
324337
def GroupId : DXILOpMapping<94, groupId, int_dx_group_id,

llvm/lib/Target/DirectX/DXILOpBuilder.cpp

Lines changed: 20 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -21,31 +21,13 @@ using namespace llvm::dxil;
2121

2222
constexpr StringLiteral DXILOpNamePrefix = "dx.op.";
2323

24-
namespace {
25-
26-
enum OverloadKind : uint16_t {
27-
VOID = 1,
28-
HALF = 1 << 1,
29-
FLOAT = 1 << 2,
30-
DOUBLE = 1 << 3,
31-
I1 = 1 << 4,
32-
I8 = 1 << 5,
33-
I16 = 1 << 6,
34-
I32 = 1 << 7,
35-
I64 = 1 << 8,
36-
UserDefineType = 1 << 9,
37-
ObjectType = 1 << 10,
38-
};
39-
40-
} // namespace
41-
4224
static const char *getOverloadTypeName(OverloadKind Kind) {
4325
switch (Kind) {
44-
case OverloadKind::HALF:
26+
case OverloadKind::Half:
4527
return "f16";
46-
case OverloadKind::FLOAT:
28+
case OverloadKind::Float:
4729
return "f32";
48-
case OverloadKind::DOUBLE:
30+
case OverloadKind::Double:
4931
return "f64";
5032
case OverloadKind::I1:
5133
return "i1";
@@ -57,26 +39,29 @@ static const char *getOverloadTypeName(OverloadKind Kind) {
5739
return "i32";
5840
case OverloadKind::I64:
5941
return "i64";
60-
case OverloadKind::VOID:
42+
case OverloadKind::Void:
6143
case OverloadKind::ObjectType:
6244
case OverloadKind::UserDefineType:
6345
break;
46+
case OverloadKind::Invalid:
47+
report_fatal_error("Invalid Overload Type for type name lookup",
48+
/* gen_crash_diag=*/false);
6449
}
65-
llvm_unreachable("invalid overload type for name");
50+
llvm_unreachable("Unhandled Overload Type specified for type name lookup");
6651
return "void";
6752
}
6853

6954
static OverloadKind getOverloadKind(Type *Ty) {
7055
Type::TypeID T = Ty->getTypeID();
7156
switch (T) {
7257
case Type::VoidTyID:
73-
return OverloadKind::VOID;
58+
return OverloadKind::Void;
7459
case Type::HalfTyID:
75-
return OverloadKind::HALF;
60+
return OverloadKind::Half;
7661
case Type::FloatTyID:
77-
return OverloadKind::FLOAT;
62+
return OverloadKind::Float;
7863
case Type::DoubleTyID:
79-
return OverloadKind::DOUBLE;
64+
return OverloadKind::Double;
8065
case Type::IntegerTyID: {
8166
IntegerType *ITy = cast<IntegerType>(Ty);
8267
unsigned Bits = ITy->getBitWidth();
@@ -93,7 +78,7 @@ static OverloadKind getOverloadKind(Type *Ty) {
9378
return OverloadKind::I64;
9479
default:
9580
llvm_unreachable("invalid overload type");
96-
return OverloadKind::VOID;
81+
return OverloadKind::Void;
9782
}
9883
}
9984
case Type::PointerTyID:
@@ -102,7 +87,7 @@ static OverloadKind getOverloadKind(Type *Ty) {
10287
return OverloadKind::ObjectType;
10388
default:
10489
llvm_unreachable("invalid overload type");
105-
return OverloadKind::VOID;
90+
return OverloadKind::Void;
10691
}
10792
}
10893

@@ -147,7 +132,7 @@ struct OpCodeProperty {
147132

148133
static std::string constructOverloadName(OverloadKind Kind, Type *Ty,
149134
const OpCodeProperty &Prop) {
150-
if (Kind == OverloadKind::VOID) {
135+
if (Kind == OverloadKind::Void) {
151136
return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop)).str();
152137
}
153138
return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop) + "." +
@@ -157,7 +142,7 @@ static std::string constructOverloadName(OverloadKind Kind, Type *Ty,
157142

158143
static std::string constructOverloadTypeName(OverloadKind Kind,
159144
StringRef TypeName) {
160-
if (Kind == OverloadKind::VOID)
145+
if (Kind == OverloadKind::Void)
161146
return TypeName.str();
162147

163148
assert(Kind < OverloadKind::UserDefineType && "invalid overload kind");
@@ -284,13 +269,13 @@ Type *DXILOpBuilder::getOverloadTy(dxil::OpCode OpCode, FunctionType *FT) {
284269
if (Prop->OverloadParamIndex < 0) {
285270
auto &Ctx = FT->getContext();
286271
switch (Prop->OverloadTys) {
287-
case OverloadKind::VOID:
272+
case OverloadKind::Void:
288273
return Type::getVoidTy(Ctx);
289-
case OverloadKind::HALF:
274+
case OverloadKind::Half:
290275
return Type::getHalfTy(Ctx);
291-
case OverloadKind::FLOAT:
276+
case OverloadKind::Float:
292277
return Type::getFloatTy(Ctx);
293-
case OverloadKind::DOUBLE:
278+
case OverloadKind::Double:
294279
return Type::getDoubleTy(Ctx);
295280
case OverloadKind::I1:
296281
return Type::getInt1Ty(Ctx);

llvm/test/CodeGen/DirectX/barrier.ll

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
; RUN: opt -S -dxil-op-lower < %s | FileCheck %s
2+
3+
; Argument of llvm.dx.barrier is expected to be a mask of
4+
; DXIL::BarrierMode values. Chose an int value for testing.
5+
6+
define void @test_barrier() #0 {
7+
entry:
8+
; CHECK: call void @dx.op.barrier.i32(i32 80, i32 9)
9+
call void @llvm.dx.barrier(i32 noundef 9)
10+
ret void
11+
}

0 commit comments

Comments
 (0)