Skip to content

Commit aee3f5b

Browse files
committed
Incorporate PR review feedback
- Use VersionTriple to deal with Shader Model version. - Undo sin test reorganization.
1 parent 24b497b commit aee3f5b

File tree

13 files changed

+124
-115
lines changed

13 files changed

+124
-115
lines changed

llvm/include/llvm/Support/DXILABI.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -90,9 +90,9 @@ enum class ElementType : uint32_t {
9090
} // namespace dxil
9191
} // namespace llvm
9292

93-
// Generate a unique value for given Major, Minor pair of Shader Model
94-
// version. Allows for 100 minor versions for a given major version number.
95-
// To be used uniformly by DXILEmitter backend as well as DXIL Lowering pass.
96-
#define COMPUTE_SM_VERSION_VALUE(MAJ, MIN) ((MAJ * 100) + MIN)
93+
struct DXILShaderModel {
94+
unsigned Major = 0;
95+
unsigned Minor = 0;
96+
};
9797

9898
#endif // LLVM_SUPPORT_DXILABI_H

llvm/lib/Target/DirectX/DXIL.td

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,10 @@ class DXILShaderModel<int major, int minor> {
2222

2323
// Valid minimum Shader model version records
2424

25-
// Shader Mode 6.x
26-
foreach i = 0...9 in {
25+
// Shader Model 6.0 - 6.8
26+
foreach i = 0...8 in {
2727
def SM6_#i : DXILShaderModel<6, i>;
2828
}
29-
// Shader Model 7.x - for now 7.0 is defined. Extend as needed
30-
foreach i = 0 in {
31-
def SM7_#i : DXILShaderModel<7, i>;
32-
}
3329

3430
// Abstraction of class mapping valid DXIL Op overloads the minimum
3531
// version of Shader Model they are supported
@@ -109,14 +105,13 @@ let OpClass = isSpecialFloat in {
109105
"Determines if the specified value is infinite.">;
110106
}
111107

112-
// Unary Class
113108
let OpClass = unary in {
114109
def Abs : DXILOpMapping<6, int_fabs, [DXILOpOverload<SM6_0, [llvm_half_ty, llvm_float_ty, llvm_double_ty]>],
115110
"Returns the absolute value of the input.">;
116111

117112
def Cos : DXILOpMapping<12, int_cos, [DXILOpOverload<SM6_0, [llvm_half_ty, llvm_float_ty]>],
118113
"Returns cosine(theta) for theta in radians.">;
119-
def Sin : DXILOpMapping<13, int_sin, [DXILOpOverload<SM6_3, [llvm_half_ty, llvm_float_ty]>,
114+
def Sin : DXILOpMapping<13, int_sin, [DXILOpOverload<SM6_2, [llvm_half_ty, llvm_float_ty]>,
120115
DXILOpOverload<SM6_0, [llvm_float_ty]>],
121116
"Returns sine(theta) for theta in radians.">;
122117
def Exp2 : DXILOpMapping<21, int_exp2, [DXILOpOverload<SM6_0, [llvm_half_ty, llvm_float_ty]>],
@@ -146,7 +141,6 @@ let OpClass = unary in {
146141
"Returns the specified value with its bits reversed.">;
147142
}
148143

149-
// Binary Class
150144
let OpClass = binary in {
151145
// Float overloads
152146
def FMax : DXILOpMapping<35, int_maxnum, [DXILOpOverload<SM6_0, [llvm_half_ty, llvm_float_ty, llvm_double_ty]>],
@@ -164,10 +158,7 @@ let OpClass = binary in {
164158
"Unsigned integer minimum. UMin(a,b) = a < b ? a : b">;
165159
}
166160

167-
// Tertiary Class
168161
let OpClass = tertiary in {
169-
// Float overloads
170-
// let OpOverloadTypes = [llvm_half_ty, llvm_float_ty, llvm_double_ty] in {
171162
def FMad : DXILOpMapping<46, int_fmuladd, [DXILOpOverload<SM6_0, [llvm_half_ty, llvm_float_ty, llvm_double_ty]>],
172163
"Floating point arithmetic multiply/add operation."
173164
" fmad(m,a,b) = m * a + b.">;
@@ -181,7 +172,6 @@ def UMad : DXILOpMapping<49, int_dx_umad, [DXILOpOverload<SM6_0, [llvm_i16_ty, l
181172
}
182173

183174
// Dot Operations
184-
// let OpOverloadTypes = [llvm_half_ty, llvm_float_ty] in
185175
let OpClass = dot2 in
186176
def Dot2 : DXILOpMapping<54, int_dx_dot2, [DXILOpOverload<SM6_0, [llvm_half_ty, llvm_float_ty]>],
187177
"dot product of two float vectors Dot(a,b) = a[0]*b[0] +"

llvm/lib/Target/DirectX/DXILOpBuilder.cpp

Lines changed: 23 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515
#include "llvm/IR/Module.h"
1616
#include "llvm/Support/DXILABI.h"
1717
#include "llvm/Support/ErrorHandling.h"
18+
#include "llvm/Support/VersionTuple.h"
19+
#include <algorithm>
20+
#include <cassert>
1821
#include <string>
1922

2023
using namespace llvm;
@@ -125,7 +128,7 @@ static std::string getTypeName(OverloadKind Kind, Type *Ty) {
125128
}
126129

127130
struct OpSMOverloadProp {
128-
uint16_t ShaderModelVer;
131+
DXILShaderModel ShaderModelVer;
129132
uint16_t ValidTys;
130133
};
131134

@@ -256,35 +259,35 @@ static FunctionType *getDXILOpFunctionType(const OpCodeProperty *Prop,
256259
}
257260

258261
static uint16_t getValidOverloadMask(const OpCodeProperty *Prop,
259-
uint32_t SMVer) {
262+
VersionTuple SMVer) {
260263
uint16_t ValidTyMask = 0;
261264
// std::vector Prop->OverloadProp is in ascending order of SM Version
262265
// Overloads of highest SM version that is not greater than SMVer
263266
// are the ones that are valid for SMVer.
264-
for (auto OL : Prop->OverloadProp) {
265-
if (OL.ShaderModelVer <= SMVer) {
266-
ValidTyMask = OL.ValidTys;
267-
} else {
268-
break;
269-
}
270-
}
267+
268+
// Get the lower bound value iterator of SMVer
269+
auto LaterSM = std::lower_bound(
270+
Prop->OverloadProp.begin(), Prop->OverloadProp.end(), SMVer,
271+
[](const OpSMOverloadProp OL, VersionTuple VerTup) {
272+
return (VersionTuple(OL.ShaderModelVer.Major,
273+
OL.ShaderModelVer.Minor) <= VerTup);
274+
});
275+
// Valid overloads are of the version prior to the lower bound
276+
ValidTyMask = (--LaterSM)->ValidTys;
277+
assert(ValidTyMask != 0 && "No valid overload types found");
271278
return ValidTyMask;
272279
}
273280

274281
namespace llvm {
275282
namespace dxil {
276283

277-
CallInst *DXILOpBuilder::createDXILOpCall(dxil::OpCode OpCode, uint32_t SMVer,
278-
Type *ReturnTy, Type *OverloadTy,
284+
CallInst *DXILOpBuilder::createDXILOpCall(dxil::OpCode OpCode,
285+
VersionTuple &SMVer, Type *ReturnTy,
286+
Type *OverloadTy,
279287
SmallVector<Value *> Args) {
280288
const OpCodeProperty *Prop = getOpCodeProperty(OpCode);
281289
uint16_t ValidTyMask = getValidOverloadMask(Prop, SMVer);
282290

283-
if (ValidTyMask == 0) {
284-
report_fatal_error(StringRef(std::to_string(SMVer).append(
285-
": Unhandled Shader Model Version")),
286-
/*gen_crash_diag*/ false);
287-
}
288291
OverloadKind Kind = getOverloadKind(OverloadTy);
289292
if ((ValidTyMask & (uint16_t)Kind) == 0) {
290293
report_fatal_error("Invalid Overload Type", /* gen_crash_diag=*/false);
@@ -304,7 +307,7 @@ CallInst *DXILOpBuilder::createDXILOpCall(dxil::OpCode OpCode, uint32_t SMVer,
304307
return B.CreateCall(DXILFn, Args);
305308
}
306309

307-
Type *DXILOpBuilder::getOverloadTy(dxil::OpCode OpCode, uint32_t SMVer,
310+
Type *DXILOpBuilder::getOverloadTy(dxil::OpCode OpCode, VersionTuple &SMVer,
308311
FunctionType *FT) {
309312

310313
const OpCodeProperty *Prop = getOpCodeProperty(OpCode);
@@ -313,11 +316,6 @@ Type *DXILOpBuilder::getOverloadTy(dxil::OpCode OpCode, uint32_t SMVer,
313316
if (Prop->OverloadParamIndex < 0) {
314317
auto &Ctx = FT->getContext();
315318
uint16_t ValidTyMask = getValidOverloadMask(Prop, SMVer);
316-
if (ValidTyMask == 0) {
317-
report_fatal_error(StringRef(std::to_string(SMVer).append(
318-
": Unhandled Shader Model Version")),
319-
/*gen_crash_diag*/ false);
320-
}
321319

322320
switch (ValidTyMask) {
323321
case OverloadKind::VOID:
@@ -344,14 +342,15 @@ Type *DXILOpBuilder::getOverloadTy(dxil::OpCode OpCode, uint32_t SMVer,
344342
}
345343
}
346344

347-
// Prop->OverloadParamIndex is 0, overload type is FT->getReturnType().
345+
// Consider FT->getReturnType() as default overload type, unless
346+
// Prop->OverloadParamIndex != 0.
348347
Type *OverloadType = FT->getReturnType();
349348
if (Prop->OverloadParamIndex != 0) {
350349
// Skip Return Type.
351350
OverloadType = FT->getParamType(Prop->OverloadParamIndex - 1);
352351
}
353352

354-
auto ParamKinds = getOpCodeParameterKind(*Prop);
353+
const auto *ParamKinds = getOpCodeParameterKind(*Prop);
355354
auto Kind = ParamKinds[Prop->OverloadParamIndex];
356355
// For ResRet and CBufferRet, OverloadTy is in field of StructType.
357356
if (Kind == ParameterKind::CBufferRet ||

llvm/lib/Target/DirectX/DXILOpBuilder.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
#include "DXILConstants.h"
1616
#include "llvm/ADT/SmallVector.h"
17-
#include <cstdint>
17+
#include "llvm/Support/VersionTuple.h"
1818

1919
namespace llvm {
2020
class Module;
@@ -38,10 +38,11 @@ class DXILOpBuilder {
3838
/// \param ReturnTy Return type of the DXIL Op call constructed
3939
/// \param OverloadTy Overload type of the DXIL Op call constructed
4040
/// \return DXIL Op call constructed
41-
CallInst *createDXILOpCall(dxil::OpCode OpCode, uint32_t SMVer,
41+
CallInst *createDXILOpCall(dxil::OpCode OpCode, VersionTuple &SMVer,
4242
Type *ReturnTy, Type *OverloadTy,
4343
SmallVector<Value *> Args);
44-
Type *getOverloadTy(dxil::OpCode OpCode, uint32_t SMVer, FunctionType *FT);
44+
Type *getOverloadTy(dxil::OpCode OpCode, VersionTuple &SMVer,
45+
FunctionType *FT);
4546
static const char *getOpCodeName(dxil::OpCode DXILOp);
4647

4748
private:

llvm/lib/Target/DirectX/DXILOpLowering.cpp

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include "llvm/MC/TargetRegistry.h"
2626
#include "llvm/Pass.h"
2727
#include "llvm/Support/ErrorHandling.h"
28+
#include "llvm/Support/VersionTuple.h"
2829

2930
#define DEBUG_TYPE "dxil-op-lower"
3031

@@ -73,7 +74,7 @@ static SmallVector<Value *> argVectorFlatten(CallInst *Orig,
7374
return NewOperands;
7475
}
7576

76-
static uint32_t getModuleShaderModelVersion(Module &M) {
77+
static VersionTuple getModuleShaderModelVersion(Module &M) {
7778
std::string TTStr = M.getTargetTriple();
7879
std::string Error;
7980
auto Target = TargetRegistry::lookupTarget(TTStr, Error);
@@ -82,16 +83,13 @@ static uint32_t getModuleShaderModelVersion(Module &M) {
8283
report_fatal_error(StringRef(Error), /*gen_crash_diag*/ false);
8384
}
8485
}
85-
auto Major = Triple(TTStr).getOSVersion().getMajor();
86-
auto MinorOrErr = Triple(TTStr).getOSVersion().getMinor();
87-
uint32_t Minor = MinorOrErr.has_value() ? *MinorOrErr : 0;
88-
return COMPUTE_SM_VERSION_VALUE(Major, Minor);
86+
return Triple(TTStr).getOSVersion();
8987
}
9088

9189
static void lowerIntrinsic(dxil::OpCode DXILOp, Function &F, Module &M) {
9290
IRBuilder<> B(M.getContext());
9391
DXILOpBuilder DXILB(M, B);
94-
uint32_t SMVer = getModuleShaderModelVersion(M);
92+
VersionTuple SMVer = getModuleShaderModelVersion(M);
9593
Type *OverloadTy = DXILB.getOverloadTy(DXILOp, SMVer, F.getFunctionType());
9694
for (User *U : make_early_inc_range(F.users())) {
9795
CallInst *CI = dyn_cast<CallInst>(U);

llvm/test/CodeGen/DirectX/Inputs/sin/double.ll

Lines changed: 0 additions & 10 deletions
This file was deleted.

llvm/test/CodeGen/DirectX/sin.ll

Lines changed: 0 additions & 23 deletions
This file was deleted.
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
; RUN: not opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.0-library %s 2>&1 | FileCheck %s --check-prefix=SM6_0_DOUBLE
2+
; RUN: not opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s 2>&1 | FileCheck %s --check-prefix=SM6_3_DOUBLE
3+
4+
; Double is not valid in any Shader Model version
5+
; SM6_0_DOUBLE: LLVM ERROR: Invalid Overload
6+
; SM6_3_DOUBLE: LLVM ERROR: Invalid Overload
7+
8+
define noundef double @sin_double(double noundef %a) #0 {
9+
entry:
10+
%a.addr = alloca double, align 8
11+
store double %a, ptr %a.addr, align 8
12+
%0 = load double, ptr %a.addr, align 8
13+
%1 = call double @llvm.sin.f64(double %0)
14+
ret double %1
15+
}
16+

llvm/test/CodeGen/DirectX/Inputs/sin/float.ll renamed to llvm/test/CodeGen/DirectX/sin_sm_60.ll

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.0-library %s | FileCheck %s -check-prefix=SM6_0_FLOAT
2+
3+
; Float is valid for SM6.0
4+
; SM6_0_FLOAT: call float @dx.op.unary.f32(i32 13, float %{{.*}})
5+
16
; Function Attrs: noinline nounwind optnone
27
define noundef float @sin_float(float noundef %a) #0 {
38
entry:

llvm/test/CodeGen/DirectX/Inputs/sin/half.ll renamed to llvm/test/CodeGen/DirectX/sin_sm_60_error.ll

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
; RUN: not opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.0-library %s 2>&1 | FileCheck %s -check-prefix=SM6_0_HALF
2+
3+
; Half is not valid for SM6.0
4+
; SM6_0_HALF: LLVM ERROR: Invalid Overload
5+
16
; Function Attrs: noinline nounwind optnone
27
define noundef half @sin_half(half noundef %a) #0 {
38
entry:
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s -check-prefix=SM6_3
2+
; Half and float are valid for SM6.2 and later
3+
; SM6_3: call half @dx.op.unary.f16(i32 13, half %{{.*}})
4+
; SM6_3: call float @dx.op.unary.f32(i32 13, float %{{.*}})
5+
6+
; Function Attrs: noinline nounwind optnone
7+
define noundef half @sin_half(half noundef %a) #0 {
8+
entry:
9+
%a.addr = alloca half, align 2
10+
store half %a, ptr %a.addr, align 2
11+
%0 = load half, ptr %a.addr, align 2
12+
%1 = call half @llvm.sin.f16(half %0)
13+
ret half %1
14+
}
15+
16+
; Function Attrs: noinline nounwind optnone
17+
define noundef float @sin_float(float noundef %a) #0 {
18+
entry:
19+
%a.addr = alloca float, align 4
20+
store float %a, ptr %a.addr, align 4
21+
%0 = load float, ptr %a.addr, align 4
22+
%1 = call float @llvm.sin.f32(float %0)
23+
ret float %1
24+
}
25+
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
; RUN: not opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.0-library %s 2>&1 | FileCheck %s --check-prefix=SM6_0_DOUBLE
2+
; RUN: not opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s 2>&1 | FileCheck %s --check-prefix=SM6_3_DOUBLE
3+
4+
; Double is not valid in any Shader Model version
5+
; SM6_0_DOUBLE: LLVM ERROR: Invalid Overload
6+
; SM6_3_DOUBLE: LLVM ERROR: Invalid Overload
7+
8+
define noundef double @sin_double(double noundef %a) #0 {
9+
entry:
10+
%a.addr = alloca double, align 8
11+
store double %a, ptr %a.addr, align 8
12+
%0 = load double, ptr %a.addr, align 8
13+
%1 = call double @llvm.sin.f64(double %0)
14+
ret double %1
15+
}
16+

0 commit comments

Comments
 (0)