Skip to content

Commit ecc2080

Browse files
committed
[DirectX][NFC] Model precise overload type specification of DXIL Ops
Implements an abstraction to specify precise overload types supported by DXIL ops. These overload types are typically a subset of LLVM intrinsics. Implements the corresponding changes in DXILEmitter backend. Adds tests to verify expected errors for unsupported overload types at code generation time. Add tests to check for correct overrload error output.
1 parent 8fccf6b commit ecc2080

File tree

8 files changed

+194
-69
lines changed

8 files changed

+194
-69
lines changed

llvm/lib/Target/DirectX/DXIL.td

Lines changed: 43 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -205,25 +205,63 @@ defset list<DXILOpClass> OpClasses = {
205205
def writeSamplerFeedbackBias : DXILOpClass;
206206
def writeSamplerFeedbackGrad : DXILOpClass;
207207
def writeSamplerFeedbackLevel: DXILOpClass;
208+
209+
// This is a sentinel definition. Hence placed at the end of the list
210+
// and not as part of the above alphabetically sorted valid definitions.
211+
// Additionally it is capitalized unlike all the others.
212+
def UnknownOpClass: DXILOpClass;
213+
}
214+
215+
// Several of the overloaded DXIL Operations support for data types
216+
// that are a subset of the overloaded LLVM intrinsics that they map to.
217+
// For e.g., llvm.sin.* intrinsic operates on any floating-point type and
218+
// maps for lowering to DXIL Op Sin. However, valid overloads of DXIL Sin
219+
// operation overloads are half (f16) and float (f32) only.
220+
//
221+
// The following abstracts overload types specific to DXIL operations.
222+
223+
class DXILType : LLVMType<OtherVT> {
224+
let isAny = 1;
208225
}
209226

227+
// Concrete records for various overload types supported specifically by
228+
// DXIL Operations.
229+
230+
def llvm_i16ori32_ty : DXILType;
231+
def llvm_halforfloat_ty : DXILType;
232+
210233
// Abstraction DXIL Operation to LLVM intrinsic
211-
class DXILOpMapping<int opCode, DXILOpClass opClass, Intrinsic intrinsic, string doc> {
234+
class DXILOpMappingBase {
235+
int OpCode = 0; // Opcode of DXIL Operation
236+
DXILOpClass OpClass = UnknownOpClass;// Class of DXIL Operation.
237+
Intrinsic LLVMIntrinsic = ?; // LLVM Intrinsic DXIL Operation maps to
238+
string Doc = ""; // A short description of the operation
239+
list<LLVMType> OpTypes = ?; // Valid types of DXIL Operation in the
240+
// format [returnTy, param1ty, ...]
241+
}
242+
243+
class DXILOpMapping<int opCode, DXILOpClass opClass,
244+
Intrinsic intrinsic, string doc,
245+
list<LLVMType> opTys = []> : DXILOpMappingBase {
212246
int OpCode = opCode; // Opcode corresponding to DXIL Operation
213-
DXILOpClass OpClass = opClass; // Class of DXIL Operation.
247+
DXILOpClass OpClass = opClass; // Class of DXIL Operation.
214248
Intrinsic LLVMIntrinsic = intrinsic; // LLVM Intrinsic the DXIL Operation maps
215249
string Doc = doc; // to a short description of the operation
250+
list<LLVMType> OpTypes = !if(!eq(!size(opTys), 0), LLVMIntrinsic.Types, opTys);
216251
}
217252

218253
// Concrete definition of DXIL Operation mapping to corresponding LLVM intrinsic
219254
def Sin : DXILOpMapping<13, unary, int_sin,
220-
"Returns sine(theta) for theta in radians.">;
255+
"Returns sine(theta) for theta in radians.",
256+
[llvm_halforfloat_ty, LLVMMatchType<0>]>;
221257
def Frac : DXILOpMapping<22, unary, int_dx_frac,
222258
"Returns a fraction from 0 to 1 that represents the "
223-
"decimal part of the input.">;
259+
"decimal part of the input.",
260+
[llvm_halforfloat_ty, LLVMMatchType<0>]>;
224261
def Round : DXILOpMapping<26, unary, int_round,
225262
"Returns the input rounded to the nearest integer"
226-
"within a floating-point type.">;
263+
"within a floating-point type.",
264+
[llvm_halforfloat_ty, LLVMMatchType<0>]>;
227265
def UMax : DXILOpMapping<39, binary, int_umax,
228266
"Unsigned integer maximum. UMax(a,b) = a > b ? a : b">;
229267
def ThreadId : DXILOpMapping<93, threadId, int_dx_thread_id,

llvm/lib/Target/DirectX/DXILOpBuilder.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ static FunctionCallee getOrCreateDXILOpFunction(dxil::OpCode DXILOp,
257257
// FIXME: find the issue and report error in clang instead of check it in
258258
// backend.
259259
if ((Prop->OverloadTys & (uint16_t)Kind) == 0) {
260-
llvm_unreachable("invalid overload");
260+
report_fatal_error("Invalid Overload Type", false);
261261
}
262262

263263
std::string FnName = constructOverloadName(Kind, OverloadTy, *Prop);

llvm/test/CodeGen/DirectX/frac.ll

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,3 @@ entry:
2929
%dx.frac = call half @llvm.dx.frac.f16(half %0)
3030
ret half %dx.frac
3131
}
32-
33-
; Function Attrs: nocallback nofree nosync nounwind readnone speculatable willreturn
34-
declare half @llvm.dx.frac.f16(half) #1
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
; RUN: not opt -S -dxil-op-lower %s 2>&1 | FileCheck %s
2+
3+
; This test is expected to fail with the following error
4+
; CHECK: LLVM ERROR: Invalid Overload Type
5+
6+
; Function Attrs: noinline nounwind optnone
7+
define noundef double @frac_double(double noundef %a) #0 {
8+
entry:
9+
%a.addr = alloca double, align 8
10+
store double %a, ptr %a.addr, align 8
11+
%0 = load double, ptr %a.addr, align 8
12+
%dx.frac = call double @llvm.dx.frac.f64(double %0)
13+
ret double %dx.frac
14+
}
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
; RUN: not opt -S -dxil-op-lower %s 2>&1 | FileCheck %s
2+
3+
; This test is expected to fail with the following error
4+
; CHECK: LLVM ERROR: Invalid Overload Type
5+
6+
define noundef double @round_double(double noundef %a) #0 {
7+
entry:
8+
%a.addr = alloca double, align 8
9+
store double %a, ptr %a.addr, align 8
10+
%0 = load double, ptr %a.addr, align 8
11+
%elt.round = call double @llvm.round.f64(double %0)
12+
ret double %elt.round
13+
}

llvm/test/CodeGen/DirectX/sin.ll

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,8 @@
44
; CHECK:call float @dx.op.unary.f32(i32 13, float %{{.*}})
55
; CHECK:call half @dx.op.unary.f16(i32 13, half %{{.*}})
66

7-
target datalayout = "e-m:e-p:32:32-i1:32-i8:8-i16:16-i32:32-i64:64-f16:16-f32:32-f64:64-n8:16:32:64"
8-
target triple = "dxil-pc-shadermodel6.7-library"
9-
107
; Function Attrs: noinline nounwind optnone
11-
define noundef float @_Z3foof(float noundef %a) #0 {
8+
define noundef float @sin_float(float noundef %a) #0 {
129
entry:
1310
%a.addr = alloca float, align 4
1411
store float %a, ptr %a.addr, align 4
@@ -21,23 +18,11 @@ entry:
2118
declare float @llvm.sin.f32(float) #1
2219

2320
; Function Attrs: noinline nounwind optnone
24-
define noundef half @_Z3barDh(half noundef %a) #0 {
21+
define noundef half @sin_half(half noundef %a) #0 {
2522
entry:
2623
%a.addr = alloca half, align 2
2724
store half %a, ptr %a.addr, align 2
2825
%0 = load half, ptr %a.addr, align 2
2926
%1 = call half @llvm.sin.f16(half %0)
3027
ret half %1
3128
}
32-
33-
; Function Attrs: nocallback nofree nosync nounwind readnone speculatable willreturn
34-
declare half @llvm.sin.f16(half) #1
35-
36-
attributes #0 = { noinline nounwind optnone "frame-pointer"="none" "min-legal-vector-width"="0" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
37-
attributes #1 = { nocallback nofree nosync nounwind readnone speculatable willreturn }
38-
39-
!llvm.module.flags = !{!0}
40-
!llvm.ident = !{!1}
41-
42-
!0 = !{i32 1, !"wchar_size", i32 4}
43-
!1 = !{!"clang version 15.0.0 (https://github.com/llvm/llvm-project.git 73417c517644db5c419c85c0b3cb6750172fcab5)"}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
; RUN: not opt -S -dxil-op-lower %s 2>&1 | FileCheck %s
2+
3+
; This test is expected to fail with the following error
4+
; CHECK: LLVM ERROR: Invalid Overload
5+
6+
define noundef double @sin_double(double noundef %a) #0 {
7+
entry:
8+
%a.addr = alloca double, align 8
9+
store double %a, ptr %a.addr, align 8
10+
%0 = load double, ptr %a.addr, align 8
11+
%1 = call double @llvm.sin.f64(double %0)
12+
ret double %1
13+
}
14+

llvm/utils/TableGen/DXILEmitter.cpp

Lines changed: 107 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "llvm/Support/DXILABI.h"
2323
#include "llvm/TableGen/Record.h"
2424
#include "llvm/TableGen/TableGenBackend.h"
25+
#include <string>
2526

2627
using namespace llvm;
2728
using namespace llvm::dxil;
@@ -38,8 +39,8 @@ struct DXILOperationDesc {
3839
int OpCode; // ID of DXIL operation
3940
StringRef OpClass; // name of the opcode class
4041
StringRef Doc; // the documentation description of this instruction
41-
SmallVector<MVT::SimpleValueType> OpTypes; // Vector of operand types -
42-
// return type is at index 0
42+
SmallVector<Record *> OpTypes; // Vector of operand type records -
43+
// return type is at index 0
4344
SmallVector<std::string>
4445
OpAttributes; // operation attribute represented as strings
4546
StringRef Intrinsic; // The llvm intrinsic map to OpName. Default is "" which
@@ -57,20 +58,21 @@ struct DXILOperationDesc {
5758
DXILShaderModel ShaderModel; // minimum shader model required
5859
DXILShaderModel ShaderModelTranslated; // minimum shader model required with
5960
// translation by linker
60-
int OverloadParamIndex; // parameter index which control the overload.
61-
// When < 0, should be only 1 overload type.
61+
int OverloadParamIndex; // Index of parameter with overload type.
62+
// -1 : no overload types
6263
SmallVector<StringRef, 4> counters; // counters for this inst.
6364
DXILOperationDesc(const Record *);
6465
};
6566
} // end anonymous namespace
6667

67-
/// Convert DXIL type name string to dxil::ParameterKind
68+
/// Return dxil::ParameterKind corresponding to input LLVMType record
6869
///
69-
/// \param VT Simple Value Type
70+
/// \param R TableGen def record of class LLVMType
7071
/// \return ParameterKind As defined in llvm/Support/DXILABI.h
7172

72-
static ParameterKind getParameterKind(MVT::SimpleValueType VT) {
73-
switch (VT) {
73+
static ParameterKind getParameterKind(const Record *R) {
74+
auto VTRec = R->getValueAsDef("VT");
75+
switch (getValueType(VTRec)) {
7476
case MVT::isVoid:
7577
return ParameterKind::VOID;
7678
case MVT::f16:
@@ -90,6 +92,18 @@ static ParameterKind getParameterKind(MVT::SimpleValueType VT) {
9092
case MVT::fAny:
9193
case MVT::iAny:
9294
return ParameterKind::OVERLOAD;
95+
case MVT::Other:
96+
// Handle DXIL-specific overload types
97+
{
98+
auto RetKind = StringSwitch<ParameterKind>(R->getNameInitAsString())
99+
.Cases("llvm_i16ori32_ty", "llvm_halforfloat_ty",
100+
ParameterKind::OVERLOAD)
101+
.Default(ParameterKind::INVALID);
102+
if (RetKind != ParameterKind::INVALID) {
103+
return RetKind;
104+
}
105+
}
106+
LLVM_FALLTHROUGH;
93107
default:
94108
llvm_unreachable("Support for specified DXIL Type not yet implemented");
95109
}
@@ -106,45 +120,80 @@ DXILOperationDesc::DXILOperationDesc(const Record *R) {
106120

107121
Doc = R->getValueAsString("Doc");
108122

123+
auto TypeRecs = R->getValueAsListOfDefs("OpTypes");
124+
unsigned TypeRecsSize = TypeRecs.size();
125+
// Populate OpTypes with return type and parameter types
126+
127+
// Parameter indices of overloaded parameters.
128+
// This vector contains overload parameters in the order order used to
129+
// resolve an LLVMMatchType in accordance with convention outlined in
130+
// the comment before the definition of class LLVMMatchType in
131+
// llvm/IR/Intrinsics.td
132+
SmallVector<int> OverloadParamIndices;
133+
for (unsigned i = 0; i < TypeRecsSize; i++) {
134+
auto TR = TypeRecs[i];
135+
// Track operation parameter indices of any overload types
136+
auto isAny = TR->getValueAsInt("isAny");
137+
if (isAny == 1) {
138+
// TODO: At present it is expected that all overload types in a DXIL Op
139+
// are of the same type. Hence, OverloadParamIndices will have only one
140+
// element. This implies we do not need a vector. However, until more
141+
// (all?) DXIL Ops are added in DXIL.td, a vector is being used to flag
142+
// cases this assumption would not hold.
143+
if (!OverloadParamIndices.empty()) {
144+
bool knownType = true;
145+
// Ensure that the same overload type registered earlier is being used
146+
for (auto Idx : OverloadParamIndices) {
147+
if (TR != TypeRecs[Idx]) {
148+
knownType = false;
149+
break;
150+
}
151+
}
152+
if (!knownType) {
153+
report_fatal_error("Specification of multiple differing overload "
154+
"parameter types not yet supported",
155+
false);
156+
}
157+
} else {
158+
OverloadParamIndices.push_back(i);
159+
}
160+
}
161+
// Populate OpTypes array according to the type specification
162+
if (TR->isAnonymous()) {
163+
// Check prior overload types exist
164+
assert(!OverloadParamIndices.empty() &&
165+
"No prior overloaded parameter found to match.");
166+
// Get the parameter index of anonymous type, TR, references
167+
auto OLParamIndex = TR->getValueAsInt("Number");
168+
// Resolve and insert the type to that at OLParamIndex
169+
OpTypes.emplace_back(TypeRecs[OLParamIndex]);
170+
} else {
171+
// A non-anonymous type. Just record it in OpTypes
172+
OpTypes.emplace_back(TR);
173+
}
174+
}
175+
176+
// Set the index of the overload parameter, if any.
177+
OverloadParamIndex = -1; // default; indicating none
178+
if (!OverloadParamIndices.empty()) {
179+
if (OverloadParamIndices.size() > 1)
180+
report_fatal_error("Multiple overload type specification not supported",
181+
false);
182+
OverloadParamIndex = OverloadParamIndices[0];
183+
}
184+
// Get the operation class
185+
OpClass = R->getValueAsDef("OpClass")->getName();
186+
109187
if (R->getValue("LLVMIntrinsic")) {
110188
auto *IntrinsicDef = R->getValueAsDef("LLVMIntrinsic");
111189
auto DefName = IntrinsicDef->getName();
112190
assert(DefName.starts_with("int_") && "invalid intrinsic name");
113191
// Remove the int_ from intrinsic name.
114192
Intrinsic = DefName.substr(4);
115-
// TODO: It is expected that return type and parameter types of
116-
// DXIL Operation are the same as that of the intrinsic. Deviations
117-
// are expected to be encoded in TableGen record specification and
118-
// handled accordingly here. Support to be added later, as needed.
119-
// Get parameter type list of the intrinsic. Types attribute contains
120-
// the list of as [returnType, param1Type,, param2Type, ...]
121-
122-
OverloadParamIndex = -1;
123-
auto TypeRecs = IntrinsicDef->getValueAsListOfDefs("Types");
124-
unsigned TypeRecsSize = TypeRecs.size();
125-
// Populate return type and parameter type names
126-
for (unsigned i = 0; i < TypeRecsSize; i++) {
127-
auto TR = TypeRecs[i];
128-
OpTypes.emplace_back(getValueType(TR->getValueAsDef("VT")));
129-
// Get the overload parameter index.
130-
// TODO : Seems hacky. Is it possible that more than one parameter can
131-
// be of overload kind??
132-
// TODO: Check for any additional constraints specified for DXIL operation
133-
// restricting return type.
134-
if (i > 0) {
135-
auto &CurParam = OpTypes.back();
136-
if (getParameterKind(CurParam) >= ParameterKind::OVERLOAD) {
137-
OverloadParamIndex = i;
138-
}
139-
}
140-
}
141-
// Get the operation class
142-
OpClass = R->getValueAsDef("OpClass")->getName();
143-
144-
// NOTE: For now, assume that attributes of DXIL Operation are the same as
193+
// TODO: For now, assume that attributes of DXIL Operation are the same as
145194
// that of the intrinsic. Deviations are expected to be encoded in TableGen
146195
// record specification and handled accordingly here. Support to be added
147-
// later.
196+
// as needed.
148197
auto IntrPropList = IntrinsicDef->getValueAsListInit("IntrProperties");
149198
auto IntrPropListSize = IntrPropList->size();
150199
for (unsigned i = 0; i < IntrPropListSize; i++) {
@@ -191,12 +240,13 @@ static std::string getParameterKindStr(ParameterKind Kind) {
191240
}
192241

193242
/// Return a string representation of OverloadKind enum that maps to
194-
/// input Simple Value Type enum
195-
/// \param VT Simple Value Type enum
243+
/// input LLVMType record
244+
/// \param R TableGen def record of class LLVMType
196245
/// \return std::string string representation of OverloadKind
197246

198-
static std::string getOverloadKindStr(MVT::SimpleValueType VT) {
199-
switch (VT) {
247+
static std::string getOverloadKindStr(const Record *R) {
248+
auto VTRec = R->getValueAsDef("VT");
249+
switch (getValueType(VTRec)) {
200250
case MVT::isVoid:
201251
return "OverloadKind::VOID";
202252
case MVT::f16:
@@ -219,6 +269,20 @@ static std::string getOverloadKindStr(MVT::SimpleValueType VT) {
219269
return "OverloadKind::I16 | OverloadKind::I32 | OverloadKind::I64";
220270
case MVT::fAny:
221271
return "OverloadKind::HALF | OverloadKind::FLOAT | OverloadKind::DOUBLE";
272+
case MVT::Other:
273+
// Handle DXIL-specific overload types
274+
{
275+
auto RetStr =
276+
StringSwitch<std::string>(R->getNameInitAsString())
277+
.Case("llvm_i16ori32_ty", "OverloadKind::I16 | OverloadKind::I32")
278+
.Case("llvm_halforfloat_ty",
279+
"OverloadKind::HALF | OverloadKind::FLOAT")
280+
.Default("");
281+
if (RetStr != "") {
282+
return RetStr;
283+
}
284+
}
285+
LLVM_FALLTHROUGH;
222286
default:
223287
llvm_unreachable(
224288
"Support for specified parameter OverloadKind not yet implemented");

0 commit comments

Comments
 (0)