Skip to content

Commit 3f51563

Browse files
committed
[DXIL] exp, any, lerp, & rcp Intrinsic Lowering
This change implements lowering for #70076, #70100, #70072, & #70102 `CGBuiltin.cpp` - - simplify `lerp` intrinsic `IntrinsicsDirectX.td` - simplify `lerp` intrinsic `SemaChecking.cpp` - remove unnecessary check `DXILIntrinsicExpansion.*` - add intrinsic to instruction expansion cases `DXILOpLowering.cpp` - make sure `DXILIntrinsicExpansion` happens first `DirectX.h` - changes to support new pass `DirectXTargetMachine.cpp` - changes to support new pass Why `any`, and `lerp` as instruction expansion just for DXIL? - SPIR-V there is an [OpAny](https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpAny) - SPIR-V has a GLSL lerp extension via [Fmix](https://registry.khronos.org/SPIR-V/specs/1.0/GLSL.std.450.html#FMix) Why `exp` instruction expansion? - We have an `exp2` opcode and `exp` reuses that opcode. So instruction expansion is a convenient way to do preprocessing. - Further SPIR-V has a GLSL exp extension via [Exp](https://registry.khronos.org/SPIR-V/specs/1.0/GLSL.std.450.html#Exp) and [Exp2](https://registry.khronos.org/SPIR-V/specs/1.0/GLSL.std.450.html#Exp2) Why `rcp` as instruction expansion? This one is a bit of the odd man out and might have to move to `cgbuiltins` when we better understand SPIRV requirements. However I included it because it seems like [fast math mode has an AllowRecip flag](https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#_fp_fast_math_mode) which lets you compute the reciprocal without performing the division. We don't have that in DXIL so thought to include it.
1 parent af2bf86 commit 3f51563

File tree

18 files changed

+557
-89
lines changed

18 files changed

+557
-89
lines changed

clang/include/clang/AST/Type.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2244,6 +2244,7 @@ class alignas(TypeAlignment) Type : public ExtQualsTypeCommonBase {
22442244
bool isFloatingType() const; // C99 6.2.5p11 (real floating + complex)
22452245
bool isHalfType() const; // OpenCL 6.1.1.1, NEON (IEEE 754-2008 half)
22462246
bool isFloat16Type() const; // C11 extension ISO/IEC TS 18661
2247+
bool isFloat32Type() const;
22472248
bool isBFloat16Type() const;
22482249
bool isFloat128Type() const;
22492250
bool isIbm128Type() const;
@@ -7452,6 +7453,10 @@ inline bool Type::isFloat16Type() const {
74527453
return isSpecificBuiltinType(BuiltinType::Float16);
74537454
}
74547455

7456+
inline bool Type::isFloat32Type() const {
7457+
return isSpecificBuiltinType(BuiltinType::Float);
7458+
}
7459+
74557460
inline bool Type::isBFloat16Type() const {
74567461
return isSpecificBuiltinType(BuiltinType::BFloat16);
74577462
}

clang/lib/CodeGen/CGBuiltin.cpp

Lines changed: 4 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -18021,38 +18021,11 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
1802118021
Value *X = EmitScalarExpr(E->getArg(0));
1802218022
Value *Y = EmitScalarExpr(E->getArg(1));
1802318023
Value *S = EmitScalarExpr(E->getArg(2));
18024-
llvm::Type *Xty = X->getType();
18025-
llvm::Type *Yty = Y->getType();
18026-
llvm::Type *Sty = S->getType();
18027-
if (!Xty->isVectorTy() && !Yty->isVectorTy() && !Sty->isVectorTy()) {
18028-
if (Xty->isFloatingPointTy()) {
18029-
auto V = Builder.CreateFSub(Y, X);
18030-
V = Builder.CreateFMul(S, V);
18031-
return Builder.CreateFAdd(X, V, "dx.lerp");
18032-
}
18033-
llvm_unreachable("Scalar Lerp is only supported on floats.");
18034-
}
18035-
// A VectorSplat should have happened
18036-
assert(Xty->isVectorTy() && Yty->isVectorTy() && Sty->isVectorTy() &&
18037-
"Lerp of vector and scalar is not supported.");
18038-
18039-
[[maybe_unused]] auto *XVecTy =
18040-
E->getArg(0)->getType()->getAs<VectorType>();
18041-
[[maybe_unused]] auto *YVecTy =
18042-
E->getArg(1)->getType()->getAs<VectorType>();
18043-
[[maybe_unused]] auto *SVecTy =
18044-
E->getArg(2)->getType()->getAs<VectorType>();
18045-
// A HLSLVectorTruncation should have happend
18046-
assert(XVecTy->getNumElements() == YVecTy->getNumElements() &&
18047-
XVecTy->getNumElements() == SVecTy->getNumElements() &&
18048-
"Lerp requires vectors to be of the same size.");
18049-
assert(XVecTy->getElementType()->isRealFloatingType() &&
18050-
XVecTy->getElementType() == YVecTy->getElementType() &&
18051-
XVecTy->getElementType() == SVecTy->getElementType() &&
18052-
"Lerp requires float vectors to be of the same type.");
18024+
if (!E->getArg(0)->getType()->hasFloatingRepresentation())
18025+
llvm_unreachable("lerp operand must have a float representation");
1805318026
return Builder.CreateIntrinsic(
18054-
/*ReturnType=*/Xty, Intrinsic::dx_lerp, ArrayRef<Value *>{X, Y, S},
18055-
nullptr, "dx.lerp");
18027+
/*ReturnType=*/X->getType(), Intrinsic::dx_lerp,
18028+
ArrayRef<Value *>{X, Y, S}, nullptr, "dx.lerp");
1805618029
}
1805718030
case Builtin::BI__builtin_hlsl_elementwise_frac: {
1805818031
Value *Op0 = EmitScalarExpr(E->getArg(0));

clang/lib/Sema/SemaChecking.cpp

Lines changed: 37 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5234,10 +5234,6 @@ bool CheckVectorElementCallArgs(Sema *S, CallExpr *TheCall) {
52345234
TheCall->getArg(1)->getEndLoc());
52355235
retValue = true;
52365236
}
5237-
5238-
if (!retValue)
5239-
TheCall->setType(VecTyA->getElementType());
5240-
52415237
return retValue;
52425238
}
52435239
}
@@ -5251,11 +5247,12 @@ bool CheckVectorElementCallArgs(Sema *S, CallExpr *TheCall) {
52515247
return true;
52525248
}
52535249

5254-
bool CheckAllArgsHaveFloatRepresentation(Sema *S, CallExpr *TheCall) {
5255-
QualType ExpectedType = S->Context.FloatTy;
5250+
bool CheckArgsTypesAreCorrect(
5251+
Sema *S, CallExpr *TheCall, QualType ExpectedType,
5252+
llvm::function_ref<bool(clang::QualType PassedType)> Check) {
52565253
for (unsigned i = 0; i < TheCall->getNumArgs(); ++i) {
52575254
QualType PassedType = TheCall->getArg(i)->getType();
5258-
if (!PassedType->hasFloatingRepresentation()) {
5255+
if (Check(PassedType)) {
52595256
if (auto *VecTyA = PassedType->getAs<VectorType>())
52605257
ExpectedType = S->Context.getVectorType(
52615258
ExpectedType, VecTyA->getNumElements(), VecTyA->getVectorKind());
@@ -5268,6 +5265,26 @@ bool CheckAllArgsHaveFloatRepresentation(Sema *S, CallExpr *TheCall) {
52685265
return false;
52695266
}
52705267

5268+
bool CheckAllArgsHaveFloatRepresentation(Sema *S, CallExpr *TheCall) {
5269+
auto checkAllFloatTypes = [](clang::QualType PassedType) -> bool {
5270+
return !PassedType->hasFloatingRepresentation();
5271+
};
5272+
return CheckArgsTypesAreCorrect(S, TheCall, S->Context.FloatTy,
5273+
checkAllFloatTypes);
5274+
}
5275+
5276+
bool CheckFloatOrHalfRepresentations(Sema *S, CallExpr *TheCall) {
5277+
auto checkFloatorHalf = [](clang::QualType PassedType) -> bool {
5278+
clang::QualType BaseType =
5279+
PassedType->isVectorType()
5280+
? PassedType->getAs<clang::VectorType>()->getElementType()
5281+
: PassedType;
5282+
return !BaseType->isHalfType() && !BaseType->isFloat32Type();
5283+
};
5284+
return CheckArgsTypesAreCorrect(S, TheCall, S->Context.FloatTy,
5285+
checkFloatorHalf);
5286+
}
5287+
52715288
void SetElementTypeAsReturnType(Sema *S, CallExpr *TheCall,
52725289
QualType ReturnType) {
52735290
auto *VecTyA = TheCall->getArg(0)->getType()->getAs<VectorType>();
@@ -5295,21 +5312,27 @@ bool Sema::CheckHLSLBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
52955312
return true;
52965313
break;
52975314
}
5298-
case Builtin::BI__builtin_hlsl_elementwise_isinf: {
5299-
if (checkArgCount(*this, TheCall, 1))
5300-
return true;
5315+
case Builtin::BI__builtin_hlsl_elementwise_rcp: {
53015316
if (CheckAllArgsHaveFloatRepresentation(this, TheCall))
53025317
return true;
5303-
SetElementTypeAsReturnType(this, TheCall, this->Context.BoolTy);
5318+
if (PrepareBuiltinElementwiseMathOneArgCall(TheCall))
5319+
return true;
53045320
break;
53055321
}
53065322
case Builtin::BI__builtin_hlsl_elementwise_rsqrt:
5307-
case Builtin::BI__builtin_hlsl_elementwise_rcp:
53085323
case Builtin::BI__builtin_hlsl_elementwise_frac: {
5309-
if (CheckAllArgsHaveFloatRepresentation(this, TheCall))
5324+
if (CheckFloatOrHalfRepresentations(this, TheCall))
5325+
return true;
5326+
if (PrepareBuiltinElementwiseMathOneArgCall(TheCall))
5327+
return true;
5328+
break;
5329+
}
5330+
case Builtin::BI__builtin_hlsl_elementwise_isinf: {
5331+
if (CheckFloatOrHalfRepresentations(this, TheCall))
53105332
return true;
53115333
if (PrepareBuiltinElementwiseMathOneArgCall(TheCall))
53125334
return true;
5335+
SetElementTypeAsReturnType(this, TheCall, this->Context.BoolTy);
53135336
break;
53145337
}
53155338
case Builtin::BI__builtin_hlsl_lerp: {
@@ -5319,7 +5342,7 @@ bool Sema::CheckHLSLBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
53195342
return true;
53205343
if (SemaBuiltinElementwiseTernaryMath(TheCall))
53215344
return true;
5322-
if (CheckAllArgsHaveFloatRepresentation(this, TheCall))
5345+
if (CheckFloatOrHalfRepresentations(this, TheCall))
53235346
return true;
53245347
break;
53255348
}

clang/test/CodeGenHLSL/builtins/lerp-builtin.hlsl

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,5 @@
11
// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple dxil-pc-shadermodel6.3-library %s -fnative-half-type -emit-llvm -disable-llvm-passes -o - | FileCheck %s
22

3-
4-
5-
// CHECK-LABEL: builtin_lerp_half_scalar
6-
// CHECK: %3 = fsub double %conv1, %conv
7-
// CHECK: %4 = fmul double %conv2, %3
8-
// CHECK: %dx.lerp = fadd double %conv, %4
9-
// CHECK: %conv3 = fptrunc double %dx.lerp to half
10-
// CHECK: ret half %conv3
11-
half builtin_lerp_half_scalar (half p0) {
12-
return __builtin_hlsl_lerp ( p0, p0, p0 );
13-
}
14-
15-
// CHECK-LABEL: builtin_lerp_float_scalar
16-
// CHECK: %3 = fsub double %conv1, %conv
17-
// CHECK: %4 = fmul double %conv2, %3
18-
// CHECK: %dx.lerp = fadd double %conv, %4
19-
// CHECK: %conv3 = fptrunc double %dx.lerp to float
20-
// CHECK: ret float %conv3
21-
float builtin_lerp_float_scalar ( float p0) {
22-
return __builtin_hlsl_lerp ( p0, p0, p0 );
23-
}
24-
253
// CHECK-LABEL: builtin_lerp_half_vector
264
// CHECK: %dx.lerp = call <3 x half> @llvm.dx.lerp.v3f16(<3 x half> %0, <3 x half> %1, <3 x half> %2)
275
// CHECK: ret <3 x half> %dx.lerp

clang/test/CodeGenHLSL/builtins/lerp.hlsl

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,51 +6,46 @@
66
// RUN: dxil-pc-shadermodel6.3-library %s -emit-llvm -disable-llvm-passes \
77
// RUN: -o - | FileCheck %s --check-prefixes=CHECK,NO_HALF
88

9-
// NATIVE_HALF: %3 = fsub half %1, %0
10-
// NATIVE_HALF: %4 = fmul half %2, %3
11-
// NATIVE_HALF: %dx.lerp = fadd half %0, %4
9+
10+
// NATIVE_HALF: %dx.lerp = call half @llvm.dx.lerp.f16(half %0, half %1, half %2)
1211
// NATIVE_HALF: ret half %dx.lerp
13-
// NO_HALF: %3 = fsub float %1, %0
14-
// NO_HALF: %4 = fmul float %2, %3
15-
// NO_HALF: %dx.lerp = fadd float %0, %4
12+
// NO_HALF: %dx.lerp = call float @llvm.dx.lerp.f32(float %0, float %1, float %2)
1613
// NO_HALF: ret float %dx.lerp
1714
half test_lerp_half(half p0) { return lerp(p0, p0, p0); }
1815

1916
// NATIVE_HALF: %dx.lerp = call <2 x half> @llvm.dx.lerp.v2f16(<2 x half> %0, <2 x half> %1, <2 x half> %2)
2017
// NATIVE_HALF: ret <2 x half> %dx.lerp
2118
// NO_HALF: %dx.lerp = call <2 x float> @llvm.dx.lerp.v2f32(<2 x float> %0, <2 x float> %1, <2 x float> %2)
2219
// NO_HALF: ret <2 x float> %dx.lerp
23-
half2 test_lerp_half2(half2 p0, half2 p1) { return lerp(p0, p0, p0); }
20+
half2 test_lerp_half2(half2 p0) { return lerp(p0, p0, p0); }
2421

2522
// NATIVE_HALF: %dx.lerp = call <3 x half> @llvm.dx.lerp.v3f16(<3 x half> %0, <3 x half> %1, <3 x half> %2)
2623
// NATIVE_HALF: ret <3 x half> %dx.lerp
2724
// NO_HALF: %dx.lerp = call <3 x float> @llvm.dx.lerp.v3f32(<3 x float> %0, <3 x float> %1, <3 x float> %2)
2825
// NO_HALF: ret <3 x float> %dx.lerp
29-
half3 test_lerp_half3(half3 p0, half3 p1) { return lerp(p0, p0, p0); }
26+
half3 test_lerp_half3(half3 p0) { return lerp(p0, p0, p0); }
3027

3128
// NATIVE_HALF: %dx.lerp = call <4 x half> @llvm.dx.lerp.v4f16(<4 x half> %0, <4 x half> %1, <4 x half> %2)
3229
// NATIVE_HALF: ret <4 x half> %dx.lerp
3330
// NO_HALF: %dx.lerp = call <4 x float> @llvm.dx.lerp.v4f32(<4 x float> %0, <4 x float> %1, <4 x float> %2)
3431
// NO_HALF: ret <4 x float> %dx.lerp
35-
half4 test_lerp_half4(half4 p0, half4 p1) { return lerp(p0, p0, p0); }
32+
half4 test_lerp_half4(half4 p0) { return lerp(p0, p0, p0); }
3633

37-
// CHECK: %3 = fsub float %1, %0
38-
// CHECK: %4 = fmul float %2, %3
39-
// CHECK: %dx.lerp = fadd float %0, %4
34+
// CHECK: %dx.lerp = call float @llvm.dx.lerp.f32(float %0, float %1, float %2)
4035
// CHECK: ret float %dx.lerp
41-
float test_lerp_float(float p0, float p1) { return lerp(p0, p0, p0); }
36+
float test_lerp_float(float p0) { return lerp(p0, p0, p0); }
4237

4338
// CHECK: %dx.lerp = call <2 x float> @llvm.dx.lerp.v2f32(<2 x float> %0, <2 x float> %1, <2 x float> %2)
4439
// CHECK: ret <2 x float> %dx.lerp
45-
float2 test_lerp_float2(float2 p0, float2 p1) { return lerp(p0, p0, p0); }
40+
float2 test_lerp_float2(float2 p0) { return lerp(p0, p0, p0); }
4641

4742
// CHECK: %dx.lerp = call <3 x float> @llvm.dx.lerp.v3f32(<3 x float> %0, <3 x float> %1, <3 x float> %2)
4843
// CHECK: ret <3 x float> %dx.lerp
49-
float3 test_lerp_float3(float3 p0, float3 p1) { return lerp(p0, p0, p0); }
44+
float3 test_lerp_float3(float3 p0) { return lerp(p0, p0, p0); }
5045

5146
// CHECK: %dx.lerp = call <4 x float> @llvm.dx.lerp.v4f32(<4 x float> %0, <4 x float> %1, <4 x float> %2)
5247
// CHECK: ret <4 x float> %dx.lerp
53-
float4 test_lerp_float4(float4 p0, float4 p1) { return lerp(p0, p0, p0); }
48+
float4 test_lerp_float4(float4 p0) { return lerp(p0, p0, p0); }
5449

5550
// CHECK: %dx.lerp = call <2 x float> @llvm.dx.lerp.v2f32(<2 x float> %splat.splat, <2 x float> %1, <2 x float> %2)
5651
// CHECK: ret <2 x float> %dx.lerp

clang/test/SemaHLSL/BuiltIns/lerp-errors.hlsl

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,5 +92,18 @@ float builtin_lerp_int_to_float_promotion(float p0, int p1) {
9292

9393
float4 test_lerp_int4(int4 p0, int4 p1, int4 p2) {
9494
return __builtin_hlsl_lerp(p0, p1, p2);
95-
// expected-error@-1 {{1st argument must be a floating point type (was 'int4' (aka 'vector<int, 4>'))}}
96-
}
95+
// expected-error@-1 {{1st argument must be a floating point type (was 'int4' (aka 'vector<int, 4>'))}}
96+
}
97+
98+
// note: DefaultVariadicArgumentPromotion --> DefaultArgumentPromotion has already promoted to double
99+
// we don't know anymore that the input was half when __builtin_hlsl_lerp is called so we default to float
100+
// for expected type
101+
half builtin_lerp_half_scalar (half p0) {
102+
return __builtin_hlsl_lerp ( p0, p0, p0 );
103+
// expected-error@-1 {{passing 'double' to parameter of incompatible type 'float'}}
104+
}
105+
106+
float builtin_lerp_float_scalar ( float p0) {
107+
return __builtin_hlsl_lerp ( p0, p0, p0 );
108+
// expected-error@-1 {{passing 'double' to parameter of incompatible type 'float'}}
109+
}

llvm/include/llvm/IR/IntrinsicsDirectX.td

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,7 @@ def int_dx_isinf :
3333
DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>],
3434
[llvm_anyfloat_ty]>;
3535

36-
def int_dx_lerp :
37-
Intrinsic<[LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
38-
[llvm_anyvector_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>,LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
36+
def int_dx_lerp : Intrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>,LLVMMatchType<0>],
3937
[IntrNoMem, IntrWillReturn] >;
4038

4139
def int_dx_imad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;

llvm/lib/Target/DirectX/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ add_llvm_target(DirectXCodeGen
1919
DirectXSubtarget.cpp
2020
DirectXTargetMachine.cpp
2121
DXContainerGlobals.cpp
22+
DXILIntrinsicExpansion.cpp
2223
DXILMetadata.cpp
2324
DXILOpBuilder.cpp
2425
DXILOpLowering.cpp

0 commit comments

Comments
 (0)