Skip to content

Commit 20a1723

Browse files
committed
[HLSL] Overloads for lerp with a scalar weight
This adds overloads for the `lerp` function that accept a scalar for the weight parameter by splatting it into the appropriate vector. Fixes #137827
1 parent 2c9a706 commit 20a1723

File tree

4 files changed

+47
-14
lines changed

4 files changed

+47
-14
lines changed

clang/lib/Headers/hlsl/hlsl_compat_overloads.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,12 @@ constexpr bool4 isinf(double4 V) { return isinf((float4)V); }
277277
// lerp builtins overloads
278278
//===----------------------------------------------------------------------===//
279279

280+
template <typename T, uint N>
281+
constexpr __detail::enable_if_t<(N > 1 && N <= 4), vector<T, N>>
282+
lerp(vector<T, N> x, vector<T, N> y, T s) {
283+
return lerp(x, y, (vector<T, N>)s);
284+
}
285+
280286
_DXC_COMPAT_TERNARY_DOUBLE_OVERLOADS(lerp)
281287
_DXC_COMPAT_TERNARY_INTEGER_OVERLOADS(lerp)
282288

clang/lib/Sema/SemaHLSL.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2555,7 +2555,8 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
25552555
case Builtin::BI__builtin_hlsl_lerp: {
25562556
if (SemaRef.checkArgCount(TheCall, 3))
25572557
return true;
2558-
if (CheckVectorElementCallArgs(&SemaRef, TheCall))
2558+
if (CheckAnyScalarOrVector(&SemaRef, TheCall, 0) ||
2559+
CheckAllArgsHaveSameType(&SemaRef, TheCall))
25592560
return true;
25602561
if (SemaRef.BuiltinElementwiseTernaryMath(TheCall))
25612562
return true;

clang/test/CodeGenHLSL/builtins/lerp-overloads.hlsl

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1-
// RUN: %clang_cc1 -std=hlsl202x -finclude-default-header -x hlsl -triple dxil-pc-shadermodel6.3-library %s -emit-llvm -O1 -o - | FileCheck %s --check-prefixes=CHECK -DFNATTRS="noundef nofpclass(nan inf)" -DTARGET=dx
2-
// RUN: %clang_cc1 -std=hlsl202x -finclude-default-header -x hlsl -triple spirv-unknown-vulkan-compute %s -emit-llvm -O1 -o - | FileCheck %s --check-prefixes=CHECK -DFNATTRS="spir_func noundef nofpclass(nan inf)" -DTARGET=spv
1+
// RUN: %clang_cc1 -std=hlsl202x -finclude-default-header -x hlsl -triple dxil-pc-shadermodel6.3-library %s -fnative-half-type -emit-llvm -O1 -o - | FileCheck %s --check-prefixes=CHECK,NATIVE_HALF -DFNATTRS="noundef nofpclass(nan inf)" -DTARGET=dx
2+
// RUN: %clang_cc1 -std=hlsl202x -finclude-default-header -x hlsl -triple dxil-pc-shadermodel6.3-library %s -emit-llvm -O1 -o - | FileCheck %s --check-prefixes=CHECK,NO_HALF -DFNATTRS="noundef nofpclass(nan inf)" -DTARGET=dx
3+
// RUN: %clang_cc1 -std=hlsl202x -finclude-default-header -x hlsl -triple spirv-unknown-vulkan-compute %s -fnative-half-type -emit-llvm -O1 -o - | FileCheck %s --check-prefixes=CHECK,NATIVE_HALF -DFNATTRS="spir_func noundef nofpclass(nan inf)" -DTARGET=spv
4+
// RUN: %clang_cc1 -std=hlsl202x -finclude-default-header -x hlsl -triple spirv-unknown-vulkan-compute %s -emit-llvm -O1 -o - | FileCheck %s --check-prefixes=CHECK,NO_HALF -DFNATTRS="spir_func noundef nofpclass(nan inf)" -DTARGET=spv
35

46
// CHECK: define [[FNATTRS]] float @_Z16test_lerp_doubled(
57
// CHECK-NEXT: [[ENTRY:.*:]]
@@ -158,3 +160,27 @@ float3 test_lerp_uint64_t3(uint64_t3 p0) { return lerp(p0, p0, p0); }
158160
// CHECK-NEXT: ret <4 x float> [[LERP]]
159161
//
160162
float4 test_lerp_uint64_t4(uint64_t4 p0) { return lerp(p0, p0, p0); }
163+
164+
// NATIVE_HALF: define [[FNATTRS]] <3 x half> @_Z21test_lerp_half_scalarDv3_DhS_Dh{{.*}}(
165+
// NO_HALF: define [[FNATTRS]] <3 x float> @_Z21test_lerp_half_scalarDv3_DhS_Dh(
166+
// CHECK-NEXT: [[ENTRY:.*:]]
167+
// NATIVE_HALF-NEXT: [[SPLATINSERT:%.*]] = insertelement <3 x half> poison, half %{{.*}}, i64 0
168+
// NATIVE_HALF-NEXT: [[SPLAT:%.*]] = shufflevector <3 x half> [[SPLATINSERT]], <3 x half> poison, <3 x i32> zeroinitializer
169+
// NATIVE_HALF-NEXT: [[LERP:%.*]] = tail call {{.*}} <3 x half> @llvm.[[TARGET]].lerp.v3f16(<3 x half> {{.*}}, <3 x half> {{.*}}, <3 x half> [[SPLAT]])
170+
// NATIVE_HALF-NEXT: ret <3 x half> [[LERP]]
171+
// NO_HALF-NEXT: [[SPLATINSERT:%.*]] = insertelement <3 x float> poison, float %{{.*}}, i64 0
172+
// NO_HALF-NEXT: [[SPLAT:%.*]] = shufflevector <3 x float> [[SPLATINSERT]], <3 x float> poison, <3 x i32> zeroinitializer
173+
// NO_HALF-NEXT: [[LERP:%.*]] = tail call {{.*}} <3 x float> @llvm.[[TARGET]].lerp.v3f32(<3 x float> {{.*}}, <3 x float> {{.*}}, <3 x float> [[SPLAT]])
174+
// NO_HALF-NEXT: ret <3 x float> [[LERP]]
175+
half3 test_lerp_half_scalar(half3 x, half3 y, half s) { return lerp(x, y, s); }
176+
177+
// CHECK: define [[FNATTRS]] <3 x float> @_Z22test_lerp_float_scalarDv3_fS_f(
178+
// CHECK-NEXT: [[ENTRY:.*:]]
179+
// CHECK-NEXT: [[SPLATINSERT:%.*]] = insertelement <3 x float> poison, float %{{.*}}, i64 0
180+
// CHECK-NEXT: [[SPLAT:%.*]] = shufflevector <3 x float> [[SPLATINSERT]], <3 x float> poison, <3 x i32> zeroinitializer
181+
// CHECK-NEXT: [[LERP:%.*]] = tail call {{.*}} <3 x float> @llvm.[[TARGET]].lerp.v3f32(<3 x float> {{.*}}, <3 x float> {{.*}}, <3 x float> [[SPLAT]])
182+
// CHECK-NEXT: ret <3 x float> [[LERP]]
183+
//
184+
float3 test_lerp_float_scalar(float3 x, float3 y, float s) {
185+
return lerp(x, y, s);
186+
}

clang/test/SemaHLSL/BuiltIns/lerp-errors.hlsl

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -62,42 +62,42 @@ float2 test_lerp_element_type_mismatch(half2 p0, float2 p1) {
6262

6363
float2 test_builtin_lerp_float2_splat(float p0, float2 p1) {
6464
return __builtin_hlsl_lerp(p0, p1, p1);
65-
// expected-error@-1 {{all arguments to '__builtin_hlsl_lerp' must be vectors}}
65+
// expected-error@-1 {{all arguments to '__builtin_hlsl_lerp' must have the same type}}
6666
}
6767

6868
float2 test_builtin_lerp_float2_splat2(double p0, double2 p1) {
6969
return __builtin_hlsl_lerp(p1, p0, p1);
70-
// expected-error@-1 {{all arguments to '__builtin_hlsl_lerp' must be vectors}}
70+
// expected-error@-1 {{all arguments to '__builtin_hlsl_lerp' must have the same type}}
7171
}
7272

7373
float2 test_builtin_lerp_float2_splat3(double p0, double2 p1) {
7474
return __builtin_hlsl_lerp(p1, p1, p0);
75-
// expected-error@-1 {{all arguments to '__builtin_hlsl_lerp' must be vectors}}
75+
// expected-error@-1 {{all arguments to '__builtin_hlsl_lerp' must have the same type}}
7676
}
7777

7878
float3 test_builtin_lerp_float3_splat(float p0, float3 p1) {
7979
return __builtin_hlsl_lerp(p0, p1, p1);
80-
// expected-error@-1 {{all arguments to '__builtin_hlsl_lerp' must be vectors}}
80+
// expected-error@-1 {{all arguments to '__builtin_hlsl_lerp' must have the same type}}
8181
}
8282

8383
float4 test_builtin_lerp_float4_splat(float p0, float4 p1) {
8484
return __builtin_hlsl_lerp(p0, p1, p1);
85-
// expected-error@-1 {{all arguments to '__builtin_hlsl_lerp' must be vectors}}
85+
// expected-error@-1 {{all arguments to '__builtin_hlsl_lerp' must have the same type}}
8686
}
8787

8888
float2 test_lerp_float2_int_splat(float2 p0, int p1) {
8989
return __builtin_hlsl_lerp(p0, p1, p1);
90-
// expected-error@-1 {{all arguments to '__builtin_hlsl_lerp' must be vectors}}
90+
// expected-error@-1 {{all arguments to '__builtin_hlsl_lerp' must have the same type}}
9191
}
9292

9393
float3 test_lerp_float3_int_splat(float3 p0, int p1) {
9494
return __builtin_hlsl_lerp(p0, p1, p1);
95-
// expected-error@-1 {{all arguments to '__builtin_hlsl_lerp' must be vectors}}
95+
// expected-error@-1 {{all arguments to '__builtin_hlsl_lerp' must have the same type}}
9696
}
9797

9898
float2 test_builtin_lerp_int_vect_to_float_vec_promotion(int2 p0, float p1) {
9999
return __builtin_hlsl_lerp(p0, p1, p1);
100-
// expected-error@-1 {{all arguments to '__builtin_hlsl_lerp' must be vectors}}
100+
// expected-error@-1 {{all arguments to '__builtin_hlsl_lerp' must have the same type}}
101101
}
102102

103103
float test_builtin_lerp_bool_type_promotion(bool p0) {
@@ -107,17 +107,17 @@ float test_builtin_lerp_bool_type_promotion(bool p0) {
107107

108108
float builtin_bool_to_float_type_promotion(float p0, bool p1) {
109109
return __builtin_hlsl_lerp(p0, p0, p1);
110-
// expected-error@-1 {{3rd argument must be a scalar or vector of floating-point types (was 'bool')}}
110+
// expected-error@-1 {{all arguments to '__builtin_hlsl_lerp' must have the same type}}
111111
}
112112

113113
float builtin_bool_to_float_type_promotion2(bool p0, float p1) {
114114
return __builtin_hlsl_lerp(p1, p0, p1);
115-
// expected-error@-1 {{2nd argument must be a scalar or vector of floating-point types (was 'bool')}}
115+
// expected-error@-1 {{all arguments to '__builtin_hlsl_lerp' must have the same type}}
116116
}
117117

118118
float builtin_lerp_int_to_float_promotion(float p0, int p1) {
119119
return __builtin_hlsl_lerp(p0, p0, p1);
120-
// expected-error@-1 {{3rd argument must be a scalar or vector of floating-point types (was 'int')}}
120+
// expected-error@-1 {{all arguments to '__builtin_hlsl_lerp' must have the same type}}
121121
}
122122

123123
float4 test_lerp_int4(int4 p0, int4 p1, int4 p2) {

0 commit comments

Comments
 (0)