Skip to content

[HLSL] select scalar overloads for vector conditions #129396

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Mar 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions clang/include/clang/Basic/DiagnosticSemaKinds.td
Original file line number Diff line number Diff line change
Expand Up @@ -12707,6 +12707,10 @@ def err_hlsl_param_qualifier_mismatch :
def err_hlsl_vector_compound_assignment_truncation : Error<
"left hand operand of type %0 to compound assignment cannot be truncated "
"when used with right hand operand of type %1">;
// Reports that the operands of an HLSL builtin do not share a scalar element
// type. %0 selects the wording ("all" or "second and third" arguments), %1 is
// the builtin being called, and %2 / %3 are the two mismatching types that
// are rendered through %diff.
def err_hlsl_builtin_scalar_vector_mismatch
: Error<
"%select{all|second and third}0 arguments to %1 must be of scalar or "
"vector type with matching scalar element type%diff{: $ vs $|}2,3">;

def warn_hlsl_impcast_vector_truncation : Warning<
"implicit conversion truncates vector: %0 to %1">, InGroup<Conversion>;
Expand Down
8 changes: 8 additions & 0 deletions clang/lib/CodeGen/CGBuiltin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19836,6 +19836,14 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
RValFalse.isScalar()
? RValFalse.getScalarVal()
: RValFalse.getAggregatePointer(E->getArg(2)->getType(), *this);
if (auto *VTy = E->getType()->getAs<VectorType>()) {
if (!OpTrue->getType()->isVectorTy())
OpTrue =
Builder.CreateVectorSplat(VTy->getNumElements(), OpTrue, "splat");
if (!OpFalse->getType()->isVectorTy())
OpFalse =
Builder.CreateVectorSplat(VTy->getNumElements(), OpFalse, "splat");
}

Value *SelectVal =
Builder.CreateSelect(OpCond, OpTrue, OpFalse, "hlsl.select");
Expand Down
1 change: 1 addition & 0 deletions clang/lib/Headers/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ set(hlsl_h
# HLSL implementation headers that live under the hlsl/ subdirectory.
set(hlsl_subdir_files
hlsl/hlsl_basic_types.h
hlsl/hlsl_alias_intrinsics.h
hlsl/hlsl_intrinsic_helpers.h
hlsl/hlsl_intrinsics.h
hlsl/hlsl_detail.h
)
Expand Down
5 changes: 5 additions & 0 deletions clang/lib/Headers/hlsl.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,12 @@
#pragma clang diagnostic ignored "-Whlsl-dxc-compatability"
#endif

// Basic types, type traits and type-independent templates.
#include "hlsl/hlsl_basic_types.h"
#include "hlsl/hlsl_detail.h"

// HLSL standard library function declarations/definitions.
#include "hlsl/hlsl_alias_intrinsics.h"
#include "hlsl/hlsl_intrinsics.h"

#if defined(__clang__)
Expand Down
35 changes: 35 additions & 0 deletions clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h
Original file line number Diff line number Diff line change
Expand Up @@ -2123,6 +2123,41 @@ template <typename T, int Sz>
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_select)
vector<T, Sz> select(vector<bool, Sz>, vector<T, Sz>, vector<T, Sz>);

/// \fn vector<T,Sz> select(vector<bool,Sz> Conds, T TrueVal,
/// vector<T,Sz> FalseVals)
/// \brief Element-wise ternary operator for vectors where the true operand
/// is a scalar. Conds and FalseVals must have the same number of elements.
/// \param Conds The per-element condition values.
/// \param TrueVal The scalar value used for every element whose condition is
/// true.
/// \param FalseVals The vector whose elements are chosen where the condition
/// is false.

template <typename T, int Sz>
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_select)
vector<T, Sz> select(vector<bool, Sz>, T, vector<T, Sz>);

/// \fn vector<T,Sz> select(vector<bool,Sz> Conds, vector<T,Sz> TrueVals,
/// T FalseVal)
/// \brief Element-wise ternary operator for vectors where the false operand
/// is a scalar. Conds and TrueVals must have the same number of elements.
/// \param Conds The per-element condition values.
/// \param TrueVals The vector whose elements are chosen where the condition
/// is true.
/// \param FalseVal The scalar value used for every element whose condition is
/// false.

template <typename T, int Sz>
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_select)
vector<T, Sz> select(vector<bool, Sz>, vector<T, Sz>, T);

/// \fn vector<T,Sz> select(vector<bool,Sz> Conds, T TrueVal, T FalseVal)
/// \brief Element-wise ternary operator where both value operands are
/// scalars. The result has the same number of elements as Conds.
/// \param Conds The per-element condition values.
/// \param TrueVal The scalar value used for every element whose condition is
/// true.
/// \param FalseVal The scalar value used for every element whose condition is
/// false.
///
/// The return type is guarded by enable_if_t on is_arithmetic<T>, so this
/// overload is only available when T is an arithmetic type.

template <typename T, int Sz>
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_select)
__detail::enable_if_t<__detail::is_arithmetic<T>::Value, vector<T, Sz>> select(
vector<bool, Sz>, T, T);

//===----------------------------------------------------------------------===//
// sin builtins
//===----------------------------------------------------------------------===//
Expand Down
60 changes: 4 additions & 56 deletions clang/lib/Headers/hlsl/hlsl_detail.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
//===----- detail.h - HLSL definitions for intrinsics ----------===//
//===----- hlsl_detail.h - HLSL definitions for intrinsics ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
Expand All @@ -9,8 +9,6 @@
#ifndef _HLSL_HLSL_DETAILS_H_
#define _HLSL_HLSL_DETAILS_H_

#include "hlsl_alias_intrinsics.h"

namespace hlsl {

namespace __detail {
Expand Down Expand Up @@ -43,59 +41,9 @@ constexpr enable_if_t<sizeof(U) == sizeof(T), U> bit_cast(T F) {
return __builtin_bit_cast(U, F);
}

constexpr vector<uint, 4> d3d_color_to_ubyte4_impl(vector<float, 4> V) {
// Use the same scaling factor used by FXC, and DXC for DXIL
// (i.e., 255.001953)
// https://github.com/microsoft/DirectXShaderCompiler/blob/070d0d5a2beacef9eeb51037a9b04665716fd6f3/lib/HLSL/HLOperationLower.cpp#L666C1-L697C2
// The DXC implementation refers to a comment on the following stackoverflow
// discussion to justify the scaling factor: "Built-in rounding, necessary
// because of truncation. 0.001953 * 256 = 0.5"
// https://stackoverflow.com/questions/52103720/why-does-d3dcolortoubyte4-multiplies-components-by-255-001953f
return V.zyxw * 255.001953f;
}

template <typename T>
constexpr enable_if_t<is_same<float, T>::value || is_same<half, T>::value, T>
length_impl(T X) {
return abs(X);
}

template <typename T, int N>
constexpr enable_if_t<is_same<float, T>::value || is_same<half, T>::value, T>
length_vec_impl(vector<T, N> X) {
#if (__has_builtin(__builtin_spirv_length))
return __builtin_spirv_length(X);
#else
return sqrt(dot(X, X));
#endif
}

template <typename T>
constexpr enable_if_t<is_same<float, T>::value || is_same<half, T>::value, T>
distance_impl(T X, T Y) {
return length_impl(X - Y);
}

template <typename T, int N>
constexpr enable_if_t<is_same<float, T>::value || is_same<half, T>::value, T>
distance_vec_impl(vector<T, N> X, vector<T, N> Y) {
return length_vec_impl(X - Y);
}

template <typename T>
constexpr enable_if_t<is_same<float, T>::value || is_same<half, T>::value, T>
reflect_impl(T I, T N) {
return I - 2 * N * I * N;
}

template <typename T, int L>
constexpr vector<T, L> reflect_vec_impl(vector<T, L> I, vector<T, L> N) {
#if (__has_builtin(__builtin_spirv_reflect))
return __builtin_spirv_reflect(I, N);
#else
return I - 2 * N * dot(I, N);
#endif
}
// Type trait that exposes the compiler's __is_arithmetic(T) builtin as a
// static member constant. Note that the member is spelled `Value` (capital
// V), unlike the `value` member read from is_same elsewhere in these headers.
template <typename T> struct is_arithmetic {
static const bool Value = __is_arithmetic(T);
};

} // namespace __detail
} // namespace hlsl
Expand Down
71 changes: 71 additions & 0 deletions clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
//===----- hlsl_intrinsic_helpers.h - HLSL helpers for intrinsics ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef _HLSL_HLSL_INTRINSIC_HELPERS_H_
Copy link
Member

@farzonl farzonl Mar 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How does the enable_if_t template in hlsl_detail.h get exposed here? Does the include order in clang/lib/Headers/hlsl.h matter? I see you have hlsl_detail.h before hlsl_alias_intrinsics.h.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hlsl_detail.h is included first in hlsl.h, which makes it included before the other headers. These headers are all implementation details, and aren't expected to be exposed to user code since hlsl.h is implicitly included in all HLSL source files.

#define _HLSL_HLSL_INTRINSIC_HELPERS_H_

// Implementation helpers for HLSL intrinsics. This header has no includes of
// its own: enable_if_t / is_same and the abs, sqrt and dot functions used
// below must already be declared by whatever includes it.
// NOTE(review): presumably hlsl.h guarantees that ordering - confirm before
// including this header from anywhere else.
namespace hlsl {
namespace __detail {

// Reorders the components of V to z, y, x, w and scales them by 255.001953;
// the product is returned as a vector of four uints.
constexpr vector<uint, 4> d3d_color_to_ubyte4_impl(vector<float, 4> V) {
// Use the same scaling factor used by FXC, and DXC for DXIL
// (i.e., 255.001953)
// https://github.com/microsoft/DirectXShaderCompiler/blob/070d0d5a2beacef9eeb51037a9b04665716fd6f3/lib/HLSL/HLOperationLower.cpp#L666C1-L697C2
// The DXC implementation refers to a comment on the following stackoverflow
// discussion to justify the scaling factor: "Built-in rounding, necessary
// because of truncation. 0.001953 * 256 = 0.5"
// https://stackoverflow.com/questions/52103720/why-does-d3dcolortoubyte4-multiplies-components-by-255-001953f
return V.zyxw * 255.001953f;
}

// Scalar length: the absolute value of X. Only enabled for float and half.
template <typename T>
constexpr enable_if_t<is_same<float, T>::value || is_same<half, T>::value, T>
length_impl(T X) {
return abs(X);
}

// Vector length. Uses __builtin_spirv_length when that builtin is available
// and falls back to sqrt(dot(X, X)) otherwise. Only enabled for float and
// half element types.
template <typename T, int N>
constexpr enable_if_t<is_same<float, T>::value || is_same<half, T>::value, T>
length_vec_impl(vector<T, N> X) {
#if (__has_builtin(__builtin_spirv_length))
return __builtin_spirv_length(X);
#else
return sqrt(dot(X, X));
#endif
}

// Scalar distance: the length of X - Y. Only enabled for float and half.
template <typename T>
constexpr enable_if_t<is_same<float, T>::value || is_same<half, T>::value, T>
distance_impl(T X, T Y) {
return length_impl(X - Y);
}

// Vector distance: the length of X - Y. Only enabled for float and half
// element types.
template <typename T, int N>
constexpr enable_if_t<is_same<float, T>::value || is_same<half, T>::value, T>
distance_vec_impl(vector<T, N> X, vector<T, N> Y) {
return length_vec_impl(X - Y);
}

// Scalar reflect: I - 2 * N * I * N, i.e. the vector formula below with the
// dot product replaced by a plain multiply. Only enabled for float and half.
template <typename T>
constexpr enable_if_t<is_same<float, T>::value || is_same<half, T>::value, T>
reflect_impl(T I, T N) {
return I - 2 * N * I * N;
}

// Vector reflect. Uses __builtin_spirv_reflect when that builtin is available
// and falls back to I - 2 * N * dot(I, N) otherwise. Unlike the helpers
// above, this template places no enable_if_t constraint on T.
template <typename T, int L>
constexpr vector<T, L> reflect_vec_impl(vector<T, L> I, vector<T, L> N) {
#if (__has_builtin(__builtin_spirv_reflect))
return __builtin_spirv_reflect(I, N);
#else
return I - 2 * N * dot(I, N);
#endif
}
} // namespace __detail
} // namespace hlsl

#endif // _HLSL_HLSL_INTRINSIC_HELPERS_H_
2 changes: 1 addition & 1 deletion clang/lib/Headers/hlsl/hlsl_intrinsics.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
#ifndef _HLSL_HLSL_INTRINSICS_H_
#define _HLSL_HLSL_INTRINSICS_H_

#include "hlsl_detail.h"
#include "hlsl/hlsl_intrinsic_helpers.h"

namespace hlsl {

Expand Down
56 changes: 32 additions & 24 deletions clang/lib/Sema/SemaHLSL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2225,40 +2225,48 @@ static bool CheckBoolSelect(Sema *S, CallExpr *TheCall) {
static bool CheckVectorSelect(Sema *S, CallExpr *TheCall) {
assert(TheCall->getNumArgs() == 3);
Expr *Arg1 = TheCall->getArg(1);
QualType Arg1Ty = Arg1->getType();
Expr *Arg2 = TheCall->getArg(2);
if (!Arg1->getType()->isVectorType()) {
S->Diag(Arg1->getBeginLoc(), diag::err_builtin_non_vector_type)
<< "Second" << TheCall->getDirectCallee() << Arg1->getType()
QualType Arg2Ty = Arg2->getType();

QualType Arg1ScalarTy = Arg1Ty;
if (auto VTy = Arg1ScalarTy->getAs<VectorType>())
Arg1ScalarTy = VTy->getElementType();

QualType Arg2ScalarTy = Arg2Ty;
if (auto VTy = Arg2ScalarTy->getAs<VectorType>())
Arg2ScalarTy = VTy->getElementType();

if (!S->Context.hasSameUnqualifiedType(Arg1ScalarTy, Arg2ScalarTy))
S->Diag(Arg1->getBeginLoc(), diag::err_hlsl_builtin_scalar_vector_mismatch)
<< /* second and third */ 1 << TheCall->getCallee() << Arg1Ty << Arg2Ty;

QualType Arg0Ty = TheCall->getArg(0)->getType();
unsigned Arg0Length = Arg0Ty->getAs<VectorType>()->getNumElements();
unsigned Arg1Length = Arg1Ty->isVectorType()
? Arg1Ty->getAs<VectorType>()->getNumElements()
: 0;
unsigned Arg2Length = Arg2Ty->isVectorType()
? Arg2Ty->getAs<VectorType>()->getNumElements()
: 0;
if (Arg1Length > 0 && Arg0Length != Arg1Length) {
S->Diag(TheCall->getBeginLoc(),
diag::err_typecheck_vector_lengths_not_equal)
<< Arg0Ty << Arg1Ty << TheCall->getArg(0)->getSourceRange()
<< Arg1->getSourceRange();
return true;
}

if (!Arg2->getType()->isVectorType()) {
S->Diag(Arg2->getBeginLoc(), diag::err_builtin_non_vector_type)
<< "Third" << TheCall->getDirectCallee() << Arg2->getType()
<< Arg2->getSourceRange();
return true;
}

if (!S->Context.hasSameUnqualifiedType(Arg1->getType(), Arg2->getType())) {
if (Arg2Length > 0 && Arg0Length != Arg2Length) {
S->Diag(TheCall->getBeginLoc(),
diag::err_typecheck_call_different_arg_types)
<< Arg1->getType() << Arg2->getType() << Arg1->getSourceRange()
diag::err_typecheck_vector_lengths_not_equal)
<< Arg0Ty << Arg2Ty << TheCall->getArg(0)->getSourceRange()
<< Arg2->getSourceRange();
return true;
}

// caller has checked that Arg0 is a vector.
// check all three args have the same length.
if (TheCall->getArg(0)->getType()->getAs<VectorType>()->getNumElements() !=
Arg1->getType()->getAs<VectorType>()->getNumElements()) {
S->Diag(TheCall->getBeginLoc(),
diag::err_typecheck_vector_lengths_not_equal)
<< TheCall->getArg(0)->getType() << Arg1->getType()
<< TheCall->getArg(0)->getSourceRange() << Arg1->getSourceRange();
return true;
}
TheCall->setType(Arg1->getType());
TheCall->setType(
S->getASTContext().getExtVectorType(Arg1ScalarTy, Arg0Length));
return false;
}

Expand Down
29 changes: 29 additions & 0 deletions clang/test/CodeGenHLSL/builtins/select.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,32 @@ int3 test_select_vector_3(bool3 cond0, int3 tVals, int3 fVals) {
int4 test_select_vector_4(bool4 cond0, int4 tVals, int4 fVals) {
return select(cond0, tVals, fVals);
}

// Scalar true value with vector false values: the scalar operand is expected
// to be splatted to a 4-element vector before it feeds the select.
// CHECK-LABEL: test_select_vector_scalar_vector
// CHECK: [[SPLAT_SRC1:%.*]] = insertelement <4 x i32> poison, i32 {{%.*}}, i64 0
// CHECK: [[SPLAT1:%.*]] = shufflevector <4 x i32> [[SPLAT_SRC1]], <4 x i32> poison, <4 x i32> zeroinitializer
// CHECK: [[SELECT:%.*]] = select <4 x i1> {{%.*}}, <4 x i32> [[SPLAT1]], <4 x i32> {{%.*}}
// CHECK: ret <4 x i32> [[SELECT]]
int4 test_select_vector_scalar_vector(bool4 cond0, int tVal, int4 fVals) {
return select(cond0, tVal, fVals);
}

// Vector true values with a scalar false value: the scalar operand is
// expected to be splatted to a 4-element vector before it feeds the select.
// CHECK-LABEL: test_select_vector_vector_scalar
// CHECK: [[SPLAT_SRC1:%.*]] = insertelement <4 x i32> poison, i32 {{%.*}}, i64 0
// CHECK: [[SPLAT1:%.*]] = shufflevector <4 x i32> [[SPLAT_SRC1]], <4 x i32> poison, <4 x i32> zeroinitializer
// CHECK: [[SELECT:%.*]] = select <4 x i1> {{%.*}}, <4 x i32> {{%.*}}, <4 x i32> [[SPLAT1]]
// CHECK: ret <4 x i32> [[SELECT]]
int4 test_select_vector_vector_scalar(bool4 cond0, int4 tVals, int fVal) {
return select(cond0, tVals, fVal);
}

// Scalar true and false values: both operands are expected to be splatted to
// 4-element vectors before they feed the select. The insertelement operands
// are matched with wildcards so the test does not depend on SSA numbering.
// CHECK-LABEL: test_select_vector_scalar_scalar
// CHECK: [[SPLAT_SRC1:%.*]] = insertelement <4 x i32> poison, i32 {{%.*}}, i64 0
// CHECK: [[SPLAT1:%.*]] = shufflevector <4 x i32> [[SPLAT_SRC1]], <4 x i32> poison, <4 x i32> zeroinitializer
// CHECK: [[SPLAT_SRC2:%.*]] = insertelement <4 x i32> poison, i32 {{%.*}}, i64 0
// CHECK: [[SPLAT2:%.*]] = shufflevector <4 x i32> [[SPLAT_SRC2]], <4 x i32> poison, <4 x i32> zeroinitializer
// CHECK: [[SELECT:%.*]] = select <4 x i1> {{%.*}}, <4 x i32> [[SPLAT1]], <4 x i32> [[SPLAT2]]
// CHECK: ret <4 x i32> [[SELECT]]
int4 test_select_vector_scalar_scalar(bool4 cond0, int tVal, int fVal) {
return select(cond0, tVal, fVal);
}
Loading