Skip to content

[HLSL] select scalar overloads for vector conditions #129396

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Mar 9, 2025

Conversation

llvm-beanz
Copy link
Collaborator

This PR adds scalar/vector overloads for vector conditions to the select builtin, and updates the sema checking and codegen to allow scalars to extend to vectors.

Fixes #126570

@llvm-beanz llvm-beanz requested a review from spall March 1, 2025 19:24
@llvmbot llvmbot added clang Clang issues not falling into any other category backend:X86 clang:frontend Language frontend issues, e.g. anything involving "Sema" clang:headers Headers provided by Clang, e.g. for intrinsics clang:codegen IR generation bugs: mangling, exceptions, etc. HLSL HLSL Language Support labels Mar 1, 2025
@llvmbot
Copy link
Member

llvmbot commented Mar 1, 2025

@llvm/pr-subscribers-backend-x86

@llvm/pr-subscribers-clang

Author: Chris B (llvm-beanz)

Changes

This PR adds scalar/vector overloads for vector conditions to the select builtin, and updates the sema checking and codegen to allow scalars to extend to vectors.

Fixes #126570


Full diff: https://github.com/llvm/llvm-project/pull/129396.diff

7 Files Affected:

  • (modified) clang/include/clang/Basic/DiagnosticSemaKinds.td (+3)
  • (modified) clang/lib/CodeGen/CGBuiltin.cpp (+8)
  • (modified) clang/lib/Headers/hlsl/hlsl_detail.h (+5)
  • (modified) clang/lib/Headers/hlsl/hlsl_intrinsics.h (+36)
  • (modified) clang/lib/Sema/SemaHLSL.cpp (+32-24)
  • (modified) clang/test/CodeGenHLSL/builtins/select.hlsl (+29)
  • (modified) clang/test/SemaHLSL/BuiltIns/select-errors.hlsl (+22-76)
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index d094c075ecee2..be649f0bce320 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -12682,6 +12682,9 @@ def err_hlsl_param_qualifier_mismatch :
 def err_hlsl_vector_compound_assignment_truncation : Error<
   "left hand operand of type %0 to compound assignment cannot be truncated "
   "when used with right hand operand of type %1">;
+def err_hlsl_builtin_scalar_vector_mismatch : Error<
+  "%select{all|second and third}0 arguments to %1 must be of scalar or "
+  "vector type with matching scalar element type%diff{: $ vs $|}2,3">;
 
 def warn_hlsl_impcast_vector_truncation : Warning<
   "implicit conversion truncates vector: %0 to %1">, InGroup<Conversion>;
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 03b8d16b76e0d..a84e5e4b59c89 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -19741,6 +19741,14 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
         RValFalse.isScalar()
             ? RValFalse.getScalarVal()
             : RValFalse.getAggregatePointer(E->getArg(2)->getType(), *this);
+    if (auto *VTy = E->getType()->getAs<VectorType>()) {
+      if (!OpTrue->getType()->isVectorTy())
+        OpTrue =
+            Builder.CreateVectorSplat(VTy->getNumElements(), OpTrue, "splat");
+      if (!OpFalse->getType()->isVectorTy())
+        OpFalse =
+            Builder.CreateVectorSplat(VTy->getNumElements(), OpFalse, "splat");
+    }
 
     Value *SelectVal =
         Builder.CreateSelect(OpCond, OpTrue, OpFalse, "hlsl.select");
diff --git a/clang/lib/Headers/hlsl/hlsl_detail.h b/clang/lib/Headers/hlsl/hlsl_detail.h
index 0d568539cd66a..daccd2d793aa8 100644
--- a/clang/lib/Headers/hlsl/hlsl_detail.h
+++ b/clang/lib/Headers/hlsl/hlsl_detail.h
@@ -95,6 +95,11 @@ constexpr vector<T, L> reflect_vec_impl(vector<T, L> I, vector<T, L> N) {
 #endif
 }
 
+template<typename T>
+struct is_arithmetic {
+  static const bool Value = __is_arithmetic(T);
+};
+
 } // namespace __detail
 } // namespace hlsl
 #endif //_HLSL_HLSL_DETAILS_H_
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index ed008eeb04ba8..77a7f773b85b2 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -2246,6 +2246,42 @@ template <typename T, int Sz>
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_select)
 vector<T, Sz> select(vector<bool, Sz>, vector<T, Sz>, vector<T, Sz>);
 
+
+/// \fn vector<T,Sz> select(vector<bool,Sz> Conds, T TrueVal,
+///                         vector<T,Sz> FalseVals)
+/// \brief ternary operator for vectors. All vectors must be the same size.
+/// \param Conds The Condition input values.
+/// \param TrueVal The scalar value to splat from when conditions are true.
+/// \param FalseVals The vector values are chosen from when conditions are
+/// false.
+
+template <typename T, int Sz>
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_select)
+vector<T, Sz> select(vector<bool, Sz>, T, vector<T, Sz>);
+
+/// \fn vector<T,Sz> select(vector<bool,Sz> Conds, vector<T,Sz> TrueVals,
+///                         T FalseVal)
+/// \brief ternary operator for vectors. All vectors must be the same size.
+/// \param Conds The Condition input values.
+/// \param TrueVals The vector values are chosen from when conditions are true.
+/// \param FalseVal The scalar value to splat from when conditions are false.
+
+template <typename T, int Sz>
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_select)
+vector<T, Sz> select(vector<bool, Sz>, vector<T, Sz>, T);
+
+/// \fn vector<T,Sz> select(vector<bool,Sz> Conds, vector<T,Sz> TrueVals,
+///                         T FalseVal)
+/// \brief ternary operator for vectors. All vectors must be the same size.
+/// \param Conds The Condition input values.
+/// \param TrueVal The scalar value to splat from when conditions are true.
+/// \param FalseVal The scalar value to splat from when conditions are false.
+
+template <typename T, int Sz>
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_select)
+__detail::enable_if_t<__detail::is_arithmetic<T>::Value, vector<T, Sz>> select(
+    vector<bool, Sz>, T, T);
+
 //===----------------------------------------------------------------------===//
 // sin builtins
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index bfe84b16218b7..4ec31cd39eb60 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -2213,40 +2213,48 @@ static bool CheckBoolSelect(Sema *S, CallExpr *TheCall) {
 static bool CheckVectorSelect(Sema *S, CallExpr *TheCall) {
   assert(TheCall->getNumArgs() == 3);
   Expr *Arg1 = TheCall->getArg(1);
+  QualType Arg1Ty = Arg1->getType();
   Expr *Arg2 = TheCall->getArg(2);
-  if (!Arg1->getType()->isVectorType()) {
-    S->Diag(Arg1->getBeginLoc(), diag::err_builtin_non_vector_type)
-        << "Second" << TheCall->getDirectCallee() << Arg1->getType()
+  QualType Arg2Ty = Arg2->getType();
+
+  QualType Arg1ScalarTy = Arg1Ty;
+  if (auto VTy = Arg1ScalarTy->getAs<VectorType>())
+    Arg1ScalarTy = VTy->getElementType();
+
+  QualType Arg2ScalarTy = Arg2Ty;
+  if (auto VTy = Arg2ScalarTy->getAs<VectorType>())
+    Arg2ScalarTy = VTy->getElementType();
+
+  if (!S->Context.hasSameUnqualifiedType(Arg1ScalarTy, Arg2ScalarTy))
+    S->Diag(Arg1->getBeginLoc(), diag::err_hlsl_builtin_scalar_vector_mismatch)
+        << /* second and third */ 1 << TheCall->getCallee() << Arg1Ty << Arg2Ty;
+
+  QualType Arg0Ty = TheCall->getArg(0)->getType();
+  unsigned Arg0Length = Arg0Ty->getAs<VectorType>()->getNumElements();
+  unsigned Arg1Length = Arg1Ty->isVectorType()
+                            ? Arg1Ty->getAs<VectorType>()->getNumElements()
+                            : 0;
+  unsigned Arg2Length = Arg2Ty->isVectorType()
+                            ? Arg2Ty->getAs<VectorType>()->getNumElements()
+                            : 0;
+  if (Arg1Length > 0 && Arg0Length != Arg1Length) {
+    S->Diag(TheCall->getBeginLoc(),
+            diag::err_typecheck_vector_lengths_not_equal)
+        << Arg0Ty << Arg1Ty << TheCall->getArg(0)->getSourceRange()
         << Arg1->getSourceRange();
     return true;
   }
 
-  if (!Arg2->getType()->isVectorType()) {
-    S->Diag(Arg2->getBeginLoc(), diag::err_builtin_non_vector_type)
-        << "Third" << TheCall->getDirectCallee() << Arg2->getType()
-        << Arg2->getSourceRange();
-    return true;
-  }
-
-  if (!S->Context.hasSameUnqualifiedType(Arg1->getType(), Arg2->getType())) {
+  if (Arg2Length > 0 && Arg0Length != Arg2Length) {
     S->Diag(TheCall->getBeginLoc(),
-            diag::err_typecheck_call_different_arg_types)
-        << Arg1->getType() << Arg2->getType() << Arg1->getSourceRange()
+            diag::err_typecheck_vector_lengths_not_equal)
+        << Arg0Ty << Arg2Ty << TheCall->getArg(0)->getSourceRange()
         << Arg2->getSourceRange();
     return true;
   }
 
-  // caller has checked that Arg0 is a vector.
-  // check all three args have the same length.
-  if (TheCall->getArg(0)->getType()->getAs<VectorType>()->getNumElements() !=
-      Arg1->getType()->getAs<VectorType>()->getNumElements()) {
-    S->Diag(TheCall->getBeginLoc(),
-            diag::err_typecheck_vector_lengths_not_equal)
-        << TheCall->getArg(0)->getType() << Arg1->getType()
-        << TheCall->getArg(0)->getSourceRange() << Arg1->getSourceRange();
-    return true;
-  }
-  TheCall->setType(Arg1->getType());
+  TheCall->setType(
+      S->getASTContext().getExtVectorType(Arg1ScalarTy, Arg0Length));
   return false;
 }
 
diff --git a/clang/test/CodeGenHLSL/builtins/select.hlsl b/clang/test/CodeGenHLSL/builtins/select.hlsl
index cade938b71a2b..196b8a90cd877 100644
--- a/clang/test/CodeGenHLSL/builtins/select.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/select.hlsl
@@ -52,3 +52,32 @@ int3 test_select_vector_3(bool3 cond0, int3 tVals, int3 fVals) {
 int4 test_select_vector_4(bool4 cond0, int4 tVals, int4 fVals) {
   return select(cond0, tVals, fVals);
 }
+
+// CHECK-LABEL: test_select_vector_scalar_vector
+// CHECK: [[SPLAT_SRC1:%.*]] = insertelement <4 x i32> poison, i32 {{%.*}}, i64 0
+// CHECK: [[SPLAT1:%.*]] = shufflevector <4 x i32> [[SPLAT_SRC1]], <4 x i32> poison, <4 x i32> zeroinitializer
+// CHECK: [[SELECT:%.*]] = select <4 x i1> {{%.*}}, <4 x i32> [[SPLAT1]], <4 x i32> {{%.*}}
+// CHECK: ret <4 x i32> [[SELECT]]
+int4 test_select_vector_scalar_vector(bool4 cond0, int tVal, int4 fVals) {
+  return select(cond0, tVal, fVals);
+}
+
+// CHECK-LABEL: test_select_vector_vector_scalar
+// CHECK: [[SPLAT_SRC1:%.*]] = insertelement <4 x i32> poison, i32 {{%.*}}, i64 0
+// CHECK: [[SPLAT1:%.*]] = shufflevector <4 x i32> [[SPLAT_SRC1]], <4 x i32> poison, <4 x i32> zeroinitializer
+// CHECK: [[SELECT:%.*]] = select <4 x i1> {{%.*}}, <4 x i32> {{%.*}}, <4 x i32> [[SPLAT1]]
+// CHECK: ret <4 x i32> [[SELECT]]
+int4 test_select_vector_vector_scalar(bool4 cond0, int4 tVals, int fVal) {
+  return select(cond0, tVals, fVal);
+}
+
+// CHECK-LABEL: test_select_vector_scalar_scalar
+// CHECK: [[SPLAT_SRC1:%.*]] = insertelement <4 x i32> poison, i32 {{%.*}}, i64 0
+// CHECK: [[SPLAT1:%.*]] = shufflevector <4 x i32> [[SPLAT_SRC1]], <4 x i32> poison, <4 x i32> zeroinitializer
+// CHECK: [[SPLAT_SRC2:%.*]] = insertelement <4 x i32> poison, i32 %3, i64 0
+// CHECK: [[SPLAT2:%.*]] = shufflevector <4 x i32> [[SPLAT_SRC2]], <4 x i32> poison, <4 x i32> zeroinitializer
+// CHECK: [[SELECT:%.*]] = select <4 x i1> {{%.*}}, <4 x i32> [[SPLAT1]], <4 x i32> [[SPLAT2]]
+// CHECK: ret <4 x i32> [[SELECT]]
+int4 test_select_vector_scalar_scalar(bool4 cond0, int tVal, int fVal) {
+  return select(cond0, tVal, fVal);
+}
diff --git a/clang/test/SemaHLSL/BuiltIns/select-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/select-errors.hlsl
index 34b5fb6d54cd5..b445cedcba074 100644
--- a/clang/test/SemaHLSL/BuiltIns/select-errors.hlsl
+++ b/clang/test/SemaHLSL/BuiltIns/select-errors.hlsl
@@ -1,119 +1,65 @@
-// RUN: %clang_cc1 -finclude-default-header
-// -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only
-// -disable-llvm-passes -verify -verify-ignore-unexpected
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify
 
-int test_no_arg() {
-  return select();
-  // expected-error@-1 {{no matching function for call to 'select'}}
-  // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template
-  // not viable: requires 3 arguments, but 0 were provided}}
-  // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not
-  // viable: requires 3 arguments, but 0 were provided}}
-}
-
-int test_too_few_args(bool p0) {
-  return select(p0);
-  // expected-error@-1 {{no matching function for call to 'select'}}
-  // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not
-  // viable: requires 3 arguments, but 1 was provided}}
-  // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not
-  // viable: requires 3 arguments, but 1 was provided}}
-}
-
-int test_too_many_args(bool p0, int t0, int f0, int g0) {
-  return select<int>(p0, t0, f0, g0);
-  // expected-error@-1 {{no matching function for call to 'select'}}
-  // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not
-  // viable: requires 3 arguments, but 4 were provided}}
-  // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not
-  // viable: requires 3 arguments, but 4 were provided}}
-}
 
 int test_select_first_arg_wrong_type(int1 p0, int t0, int f0) {
   return select(p0, t0, f0);
-  // expected-error@-1 {{no matching function for call to 'select'}}
-  // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not
-  // viable: no known conversion from 'vector<int, 1>' (vector of 1 'int' value)
-  // to 'bool' for 1st argument}}
-  // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate template ignored: could
-  // not match 'vector<T, Sz>' against 'int'}}
 }
 
 int1 test_select_bool_vals_diff_vecs(bool p0, int1 t0, int1 f0) {
   return select<int1>(p0, t0, f0);
-  // expected-warning@-1 {{implicit conversion truncates vector:
-  // 'vector<int, 2>' (vector of 2 'int' values) to 'vector<int, 1>'
-  // (vector of 1 'int' value)}}
 }
 
 int2 test_select_vector_vals_not_vecs(bool2 p0, int t0,
                                                int f0) {
   return select(p0, t0, f0);
-  // expected-error@-1 {{no matching function for call to 'select'}}
-  // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate template ignored:
-  // could not match 'vector<T, Sz>' against 'int'}}
-  // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not
-  // viable: no known conversion from 'vector<bool, 2>'
-  // (vector of 2 'bool' values) to 'bool' for 1st argument}}
 }
 
 int1 test_select_vector_vals_wrong_size(bool2 p0, int1 t0, int1 f0) {
-  return select<int,1>(p0, t0, f0); // produce warnings
-  // expected-warning@-1 {{implicit conversion truncates vector:
-  // 'vector<bool, 2>' (vector of 2 'bool' values) to 'vector<bool, 1>'
-  // (vector of 1 'bool' value)}}
-  // expected-warning@-2 {{implicit conversion truncates vector:
-  // 'vector<int, 2>' (vector of 2 'int' values) to 'vector<int, 1>'
-  // (vector of 1 'int' value)}}
+  return select<int,1>(p0, t0, f0); // expected-warning{{implicit conversion truncates vector: 'bool2' (aka 'vector<bool, 2>') to 'vector<bool, 1>' (vector of 1 'bool' value)}}
+}
+
+int test_select_no_args() {
+  return __builtin_hlsl_select(); // expected-error{{too few arguments to function call, expected 3, have 0}}
+}
+
+int test_select_builtin_wrong_arg_count(bool p0) {
+  return __builtin_hlsl_select(p0); // expected-error{{too few arguments to function call, expected 3, have 1}}
 }
 
 // __builtin_hlsl_select tests
-int test_select_builtin_wrong_arg_count(bool p0, int t0) {
-  return __builtin_hlsl_select(p0, t0);
-  // expected-error@-1 {{too few arguments to function call, expected 3,
-  // have 2}}
+int test_select_builtin_wrong_arg_count2(bool p0, int t0) {
+  return __builtin_hlsl_select(p0, t0); // expected-error{{too few arguments to function call, expected 3, have 2}}
+}
+
+int test_too_many_args(bool p0, int t0, int f0, int g0) {
+  return __builtin_hlsl_select(p0, t0, f0, g0); // expected-error{{too many arguments to function call, expected 3, have 4}}
 }
 
 // not a bool or a vector of bool. should be 2 errors.
 int test_select_builtin_first_arg_wrong_type1(int p0, int t0, int f0) {
-  return __builtin_hlsl_select(p0, t0, f0);
-  // expected-error@-1 {{passing 'int' to parameter of incompatible type
-  // 'bool'}}
-  // expected-error@-2 {{First argument to __builtin_hlsl_select must be of
-  // vector type}}
-  }
+  return __builtin_hlsl_select(p0, t0, f0); // expected-error{{invalid operand of type 'int' where 'bool' or a vector of such type is required}}
+}
 
 int test_select_builtin_first_arg_wrong_type2(int1 p0, int t0, int f0) {
-  return __builtin_hlsl_select(p0, t0, f0);
-  // expected-error@-1 {{passing 'vector<int, 1>' (vector of 1 'int' value) to
-  // parameter of incompatible type 'bool'}}
-  // expected-error@-2 {{First argument to __builtin_hlsl_select must be of
-  // vector type}}
+  return __builtin_hlsl_select(p0, t0, f0); // expected-error{{invalid operand of type 'int1' (aka 'vector<int, 1>') where 'bool' or a vector of such type is required}}
 }
 
 // if a bool last 2 args are of same type
 int test_select_builtin_bool_incompatible_args(bool p0, int t0, double f0) {
-  return __builtin_hlsl_select(p0, t0, f0);
-  // expected-error@-1 {{arguments are of different types ('int' vs 'double')}}
+  return __builtin_hlsl_select(p0, t0, f0); // expected-error{{arguments are of different types ('int' vs 'double')}}
 }
 
 // if a vector second arg isnt a vector
 int2 test_select_builtin_second_arg_not_vector(bool2 p0, int t0, int2 f0) {
   return __builtin_hlsl_select(p0, t0, f0);
-  // expected-error@-1 {{Second argument to __builtin_hlsl_select must be of
-  // vector type}}
 }
 
 // if a vector third arg isn't a vector
 int2 test_select_builtin_second_arg_not_vector(bool2 p0, int2 t0, int f0) {
   return __builtin_hlsl_select(p0, t0, f0);
-  // expected-error@-1 {{Third argument to __builtin_hlsl_select must be of
-  // vector type}}
 }
 
 // if vector last 2 aren't same type (so both are vectors but wrong type)
-int2 test_select_builtin_diff_types(bool1 p0, int1 t0, float1 f0) {
-  return __builtin_hlsl_select(p0, t0, f0);
-  // expected-error@-1 {{arguments are of different types ('vector<int, [...]>'
-  // vs 'vector<float, [...]>')}}
+int1 test_select_builtin_diff_types(bool1 p0, int1 t0, float1 f0) {
+  return __builtin_hlsl_select(p0, t0, f0); // expected-error{{second and third arguments to __builtin_hlsl_select must be of scalar or vector type with matching scalar element type: 'vector<int, [...]>' vs 'vector<float, [...]>'}}
 }

@llvmbot
Copy link
Member

llvmbot commented Mar 1, 2025

@llvm/pr-subscribers-clang-codegen

Author: Chris B (llvm-beanz)

Changes

This PR adds scalar/vector overloads for vector conditions to the select builtin, and updates the sema checking and codegen to allow scalars to extend to vectors.

Fixes #126570


Full diff: https://github.com/llvm/llvm-project/pull/129396.diff

7 Files Affected:

  • (modified) clang/include/clang/Basic/DiagnosticSemaKinds.td (+3)
  • (modified) clang/lib/CodeGen/CGBuiltin.cpp (+8)
  • (modified) clang/lib/Headers/hlsl/hlsl_detail.h (+5)
  • (modified) clang/lib/Headers/hlsl/hlsl_intrinsics.h (+36)
  • (modified) clang/lib/Sema/SemaHLSL.cpp (+32-24)
  • (modified) clang/test/CodeGenHLSL/builtins/select.hlsl (+29)
  • (modified) clang/test/SemaHLSL/BuiltIns/select-errors.hlsl (+22-76)
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index d094c075ecee2..be649f0bce320 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -12682,6 +12682,9 @@ def err_hlsl_param_qualifier_mismatch :
 def err_hlsl_vector_compound_assignment_truncation : Error<
   "left hand operand of type %0 to compound assignment cannot be truncated "
   "when used with right hand operand of type %1">;
+def err_hlsl_builtin_scalar_vector_mismatch : Error<
+  "%select{all|second and third}0 arguments to %1 must be of scalar or "
+  "vector type with matching scalar element type%diff{: $ vs $|}2,3">;
 
 def warn_hlsl_impcast_vector_truncation : Warning<
   "implicit conversion truncates vector: %0 to %1">, InGroup<Conversion>;
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 03b8d16b76e0d..a84e5e4b59c89 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -19741,6 +19741,14 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
         RValFalse.isScalar()
             ? RValFalse.getScalarVal()
             : RValFalse.getAggregatePointer(E->getArg(2)->getType(), *this);
+    if (auto *VTy = E->getType()->getAs<VectorType>()) {
+      if (!OpTrue->getType()->isVectorTy())
+        OpTrue =
+            Builder.CreateVectorSplat(VTy->getNumElements(), OpTrue, "splat");
+      if (!OpFalse->getType()->isVectorTy())
+        OpFalse =
+            Builder.CreateVectorSplat(VTy->getNumElements(), OpFalse, "splat");
+    }
 
     Value *SelectVal =
         Builder.CreateSelect(OpCond, OpTrue, OpFalse, "hlsl.select");
diff --git a/clang/lib/Headers/hlsl/hlsl_detail.h b/clang/lib/Headers/hlsl/hlsl_detail.h
index 0d568539cd66a..daccd2d793aa8 100644
--- a/clang/lib/Headers/hlsl/hlsl_detail.h
+++ b/clang/lib/Headers/hlsl/hlsl_detail.h
@@ -95,6 +95,11 @@ constexpr vector<T, L> reflect_vec_impl(vector<T, L> I, vector<T, L> N) {
 #endif
 }
 
+template<typename T>
+struct is_arithmetic {
+  static const bool Value = __is_arithmetic(T);
+};
+
 } // namespace __detail
 } // namespace hlsl
 #endif //_HLSL_HLSL_DETAILS_H_
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index ed008eeb04ba8..77a7f773b85b2 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -2246,6 +2246,42 @@ template <typename T, int Sz>
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_select)
 vector<T, Sz> select(vector<bool, Sz>, vector<T, Sz>, vector<T, Sz>);
 
+
+/// \fn vector<T,Sz> select(vector<bool,Sz> Conds, T TrueVal,
+///                         vector<T,Sz> FalseVals)
+/// \brief ternary operator for vectors. All vectors must be the same size.
+/// \param Conds The Condition input values.
+/// \param TrueVal The scalar value to splat from when conditions are true.
+/// \param FalseVals The vector values are chosen from when conditions are
+/// false.
+
+template <typename T, int Sz>
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_select)
+vector<T, Sz> select(vector<bool, Sz>, T, vector<T, Sz>);
+
+/// \fn vector<T,Sz> select(vector<bool,Sz> Conds, vector<T,Sz> TrueVals,
+///                         T FalseVal)
+/// \brief ternary operator for vectors. All vectors must be the same size.
+/// \param Conds The Condition input values.
+/// \param TrueVals The vector values are chosen from when conditions are true.
+/// \param FalseVal The scalar value to splat from when conditions are false.
+
+template <typename T, int Sz>
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_select)
+vector<T, Sz> select(vector<bool, Sz>, vector<T, Sz>, T);
+
+/// \fn vector<T,Sz> select(vector<bool,Sz> Conds, vector<T,Sz> TrueVals,
+///                         T FalseVal)
+/// \brief ternary operator for vectors. All vectors must be the same size.
+/// \param Conds The Condition input values.
+/// \param TrueVal The scalar value to splat from when conditions are true.
+/// \param FalseVal The scalar value to splat from when conditions are false.
+
+template <typename T, int Sz>
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_select)
+__detail::enable_if_t<__detail::is_arithmetic<T>::Value, vector<T, Sz>> select(
+    vector<bool, Sz>, T, T);
+
 //===----------------------------------------------------------------------===//
 // sin builtins
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index bfe84b16218b7..4ec31cd39eb60 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -2213,40 +2213,48 @@ static bool CheckBoolSelect(Sema *S, CallExpr *TheCall) {
 static bool CheckVectorSelect(Sema *S, CallExpr *TheCall) {
   assert(TheCall->getNumArgs() == 3);
   Expr *Arg1 = TheCall->getArg(1);
+  QualType Arg1Ty = Arg1->getType();
   Expr *Arg2 = TheCall->getArg(2);
-  if (!Arg1->getType()->isVectorType()) {
-    S->Diag(Arg1->getBeginLoc(), diag::err_builtin_non_vector_type)
-        << "Second" << TheCall->getDirectCallee() << Arg1->getType()
+  QualType Arg2Ty = Arg2->getType();
+
+  QualType Arg1ScalarTy = Arg1Ty;
+  if (auto VTy = Arg1ScalarTy->getAs<VectorType>())
+    Arg1ScalarTy = VTy->getElementType();
+
+  QualType Arg2ScalarTy = Arg2Ty;
+  if (auto VTy = Arg2ScalarTy->getAs<VectorType>())
+    Arg2ScalarTy = VTy->getElementType();
+
+  if (!S->Context.hasSameUnqualifiedType(Arg1ScalarTy, Arg2ScalarTy))
+    S->Diag(Arg1->getBeginLoc(), diag::err_hlsl_builtin_scalar_vector_mismatch)
+        << /* second and third */ 1 << TheCall->getCallee() << Arg1Ty << Arg2Ty;
+
+  QualType Arg0Ty = TheCall->getArg(0)->getType();
+  unsigned Arg0Length = Arg0Ty->getAs<VectorType>()->getNumElements();
+  unsigned Arg1Length = Arg1Ty->isVectorType()
+                            ? Arg1Ty->getAs<VectorType>()->getNumElements()
+                            : 0;
+  unsigned Arg2Length = Arg2Ty->isVectorType()
+                            ? Arg2Ty->getAs<VectorType>()->getNumElements()
+                            : 0;
+  if (Arg1Length > 0 && Arg0Length != Arg1Length) {
+    S->Diag(TheCall->getBeginLoc(),
+            diag::err_typecheck_vector_lengths_not_equal)
+        << Arg0Ty << Arg1Ty << TheCall->getArg(0)->getSourceRange()
         << Arg1->getSourceRange();
     return true;
   }
 
-  if (!Arg2->getType()->isVectorType()) {
-    S->Diag(Arg2->getBeginLoc(), diag::err_builtin_non_vector_type)
-        << "Third" << TheCall->getDirectCallee() << Arg2->getType()
-        << Arg2->getSourceRange();
-    return true;
-  }
-
-  if (!S->Context.hasSameUnqualifiedType(Arg1->getType(), Arg2->getType())) {
+  if (Arg2Length > 0 && Arg0Length != Arg2Length) {
     S->Diag(TheCall->getBeginLoc(),
-            diag::err_typecheck_call_different_arg_types)
-        << Arg1->getType() << Arg2->getType() << Arg1->getSourceRange()
+            diag::err_typecheck_vector_lengths_not_equal)
+        << Arg0Ty << Arg2Ty << TheCall->getArg(0)->getSourceRange()
         << Arg2->getSourceRange();
     return true;
   }
 
-  // caller has checked that Arg0 is a vector.
-  // check all three args have the same length.
-  if (TheCall->getArg(0)->getType()->getAs<VectorType>()->getNumElements() !=
-      Arg1->getType()->getAs<VectorType>()->getNumElements()) {
-    S->Diag(TheCall->getBeginLoc(),
-            diag::err_typecheck_vector_lengths_not_equal)
-        << TheCall->getArg(0)->getType() << Arg1->getType()
-        << TheCall->getArg(0)->getSourceRange() << Arg1->getSourceRange();
-    return true;
-  }
-  TheCall->setType(Arg1->getType());
+  TheCall->setType(
+      S->getASTContext().getExtVectorType(Arg1ScalarTy, Arg0Length));
   return false;
 }
 
diff --git a/clang/test/CodeGenHLSL/builtins/select.hlsl b/clang/test/CodeGenHLSL/builtins/select.hlsl
index cade938b71a2b..196b8a90cd877 100644
--- a/clang/test/CodeGenHLSL/builtins/select.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/select.hlsl
@@ -52,3 +52,32 @@ int3 test_select_vector_3(bool3 cond0, int3 tVals, int3 fVals) {
 int4 test_select_vector_4(bool4 cond0, int4 tVals, int4 fVals) {
   return select(cond0, tVals, fVals);
 }
+
+// CHECK-LABEL: test_select_vector_scalar_vector
+// CHECK: [[SPLAT_SRC1:%.*]] = insertelement <4 x i32> poison, i32 {{%.*}}, i64 0
+// CHECK: [[SPLAT1:%.*]] = shufflevector <4 x i32> [[SPLAT_SRC1]], <4 x i32> poison, <4 x i32> zeroinitializer
+// CHECK: [[SELECT:%.*]] = select <4 x i1> {{%.*}}, <4 x i32> [[SPLAT1]], <4 x i32> {{%.*}}
+// CHECK: ret <4 x i32> [[SELECT]]
+int4 test_select_vector_scalar_vector(bool4 cond0, int tVal, int4 fVals) {
+  return select(cond0, tVal, fVals);
+}
+
+// CHECK-LABEL: test_select_vector_vector_scalar
+// CHECK: [[SPLAT_SRC1:%.*]] = insertelement <4 x i32> poison, i32 {{%.*}}, i64 0
+// CHECK: [[SPLAT1:%.*]] = shufflevector <4 x i32> [[SPLAT_SRC1]], <4 x i32> poison, <4 x i32> zeroinitializer
+// CHECK: [[SELECT:%.*]] = select <4 x i1> {{%.*}}, <4 x i32> {{%.*}}, <4 x i32> [[SPLAT1]]
+// CHECK: ret <4 x i32> [[SELECT]]
+int4 test_select_vector_vector_scalar(bool4 cond0, int4 tVals, int fVal) {
+  return select(cond0, tVals, fVal);
+}
+
+// CHECK-LABEL: test_select_vector_scalar_scalar
+// CHECK: [[SPLAT_SRC1:%.*]] = insertelement <4 x i32> poison, i32 {{%.*}}, i64 0
+// CHECK: [[SPLAT1:%.*]] = shufflevector <4 x i32> [[SPLAT_SRC1]], <4 x i32> poison, <4 x i32> zeroinitializer
+// CHECK: [[SPLAT_SRC2:%.*]] = insertelement <4 x i32> poison, i32 %3, i64 0
+// CHECK: [[SPLAT2:%.*]] = shufflevector <4 x i32> [[SPLAT_SRC2]], <4 x i32> poison, <4 x i32> zeroinitializer
+// CHECK: [[SELECT:%.*]] = select <4 x i1> {{%.*}}, <4 x i32> [[SPLAT1]], <4 x i32> [[SPLAT2]]
+// CHECK: ret <4 x i32> [[SELECT]]
+int4 test_select_vector_scalar_scalar(bool4 cond0, int tVal, int fVal) {
+  return select(cond0, tVal, fVal);
+}
diff --git a/clang/test/SemaHLSL/BuiltIns/select-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/select-errors.hlsl
index 34b5fb6d54cd5..b445cedcba074 100644
--- a/clang/test/SemaHLSL/BuiltIns/select-errors.hlsl
+++ b/clang/test/SemaHLSL/BuiltIns/select-errors.hlsl
@@ -1,119 +1,65 @@
-// RUN: %clang_cc1 -finclude-default-header
-// -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only
-// -disable-llvm-passes -verify -verify-ignore-unexpected
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify
 
-int test_no_arg() {
-  return select();
-  // expected-error@-1 {{no matching function for call to 'select'}}
-  // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template
-  // not viable: requires 3 arguments, but 0 were provided}}
-  // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not
-  // viable: requires 3 arguments, but 0 were provided}}
-}
-
-int test_too_few_args(bool p0) {
-  return select(p0);
-  // expected-error@-1 {{no matching function for call to 'select'}}
-  // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not
-  // viable: requires 3 arguments, but 1 was provided}}
-  // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not
-  // viable: requires 3 arguments, but 1 was provided}}
-}
-
-int test_too_many_args(bool p0, int t0, int f0, int g0) {
-  return select<int>(p0, t0, f0, g0);
-  // expected-error@-1 {{no matching function for call to 'select'}}
-  // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not
-  // viable: requires 3 arguments, but 4 were provided}}
-  // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not
-  // viable: requires 3 arguments, but 4 were provided}}
-}
 
 int test_select_first_arg_wrong_type(int1 p0, int t0, int f0) {
   return select(p0, t0, f0);
-  // expected-error@-1 {{no matching function for call to 'select'}}
-  // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not
-  // viable: no known conversion from 'vector<int, 1>' (vector of 1 'int' value)
-  // to 'bool' for 1st argument}}
-  // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate template ignored: could
-  // not match 'vector<T, Sz>' against 'int'}}
 }
 
 int1 test_select_bool_vals_diff_vecs(bool p0, int1 t0, int1 f0) {
   return select<int1>(p0, t0, f0);
-  // expected-warning@-1 {{implicit conversion truncates vector:
-  // 'vector<int, 2>' (vector of 2 'int' values) to 'vector<int, 1>'
-  // (vector of 1 'int' value)}}
 }
 
 int2 test_select_vector_vals_not_vecs(bool2 p0, int t0,
                                                int f0) {
   return select(p0, t0, f0);
-  // expected-error@-1 {{no matching function for call to 'select'}}
-  // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate template ignored:
-  // could not match 'vector<T, Sz>' against 'int'}}
-  // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not
-  // viable: no known conversion from 'vector<bool, 2>'
-  // (vector of 2 'bool' values) to 'bool' for 1st argument}}
 }
 
 int1 test_select_vector_vals_wrong_size(bool2 p0, int1 t0, int1 f0) {
-  return select<int,1>(p0, t0, f0); // produce warnings
-  // expected-warning@-1 {{implicit conversion truncates vector:
-  // 'vector<bool, 2>' (vector of 2 'bool' values) to 'vector<bool, 1>'
-  // (vector of 1 'bool' value)}}
-  // expected-warning@-2 {{implicit conversion truncates vector:
-  // 'vector<int, 2>' (vector of 2 'int' values) to 'vector<int, 1>'
-  // (vector of 1 'int' value)}}
+  return select<int,1>(p0, t0, f0); // expected-warning{{implicit conversion truncates vector: 'bool2' (aka 'vector<bool, 2>') to 'vector<bool, 1>' (vector of 1 'bool' value)}}
+}
+
+int test_select_no_args() {
+  return __builtin_hlsl_select(); // expected-error{{too few arguments to function call, expected 3, have 0}}
+}
+
+int test_select_builtin_wrong_arg_count(bool p0) {
+  return __builtin_hlsl_select(p0); // expected-error{{too few arguments to function call, expected 3, have 1}}
 }
 
 // __builtin_hlsl_select tests
-int test_select_builtin_wrong_arg_count(bool p0, int t0) {
-  return __builtin_hlsl_select(p0, t0);
-  // expected-error@-1 {{too few arguments to function call, expected 3,
-  // have 2}}
+int test_select_builtin_wrong_arg_count2(bool p0, int t0) {
+  return __builtin_hlsl_select(p0, t0); // expected-error{{too few arguments to function call, expected 3, have 2}}
+}
+
+int test_too_many_args(bool p0, int t0, int f0, int g0) {
+  return __builtin_hlsl_select(p0, t0, f0, g0); // expected-error{{too many arguments to function call, expected 3, have 4}}
 }
 
 // not a bool or a vector of bool. should be 2 errors.
 int test_select_builtin_first_arg_wrong_type1(int p0, int t0, int f0) {
-  return __builtin_hlsl_select(p0, t0, f0);
-  // expected-error@-1 {{passing 'int' to parameter of incompatible type
-  // 'bool'}}
-  // expected-error@-2 {{First argument to __builtin_hlsl_select must be of
-  // vector type}}
-  }
+  return __builtin_hlsl_select(p0, t0, f0); // expected-error{{invalid operand of type 'int' where 'bool' or a vector of such type is required}}
+}
 
 int test_select_builtin_first_arg_wrong_type2(int1 p0, int t0, int f0) {
-  return __builtin_hlsl_select(p0, t0, f0);
-  // expected-error@-1 {{passing 'vector<int, 1>' (vector of 1 'int' value) to
-  // parameter of incompatible type 'bool'}}
-  // expected-error@-2 {{First argument to __builtin_hlsl_select must be of
-  // vector type}}
+  return __builtin_hlsl_select(p0, t0, f0); // expected-error{{invalid operand of type 'int1' (aka 'vector<int, 1>') where 'bool' or a vector of such type is required}}
 }
 
 // if a bool last 2 args are of same type
 int test_select_builtin_bool_incompatible_args(bool p0, int t0, double f0) {
-  return __builtin_hlsl_select(p0, t0, f0);
-  // expected-error@-1 {{arguments are of different types ('int' vs 'double')}}
+  return __builtin_hlsl_select(p0, t0, f0); // expected-error{{arguments are of different types ('int' vs 'double')}}
 }
 
 // if a vector second arg isnt a vector
 int2 test_select_builtin_second_arg_not_vector(bool2 p0, int t0, int2 f0) {
   return __builtin_hlsl_select(p0, t0, f0);
-  // expected-error@-1 {{Second argument to __builtin_hlsl_select must be of
-  // vector type}}
 }
 
 // if a vector third arg isn't a vector
 int2 test_select_builtin_second_arg_not_vector(bool2 p0, int2 t0, int f0) {
   return __builtin_hlsl_select(p0, t0, f0);
-  // expected-error@-1 {{Third argument to __builtin_hlsl_select must be of
-  // vector type}}
 }
 
 // if vector last 2 aren't same type (so both are vectors but wrong type)
-int2 test_select_builtin_diff_types(bool1 p0, int1 t0, float1 f0) {
-  return __builtin_hlsl_select(p0, t0, f0);
-  // expected-error@-1 {{arguments are of different types ('vector<int, [...]>'
-  // vs 'vector<float, [...]>')}}
+int1 test_select_builtin_diff_types(bool1 p0, int1 t0, float1 f0) {
+  return __builtin_hlsl_select(p0, t0, f0); // expected-error{{second and third arguments to __builtin_hlsl_select must be of scalar or vector type with matching scalar element type: 'vector<int, [...]>' vs 'vector<float, [...]>'}}
 }

Copy link

github-actions bot commented Mar 1, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

// viable: requires 3 arguments, but 4 were provided}}
// expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not
// viable: requires 3 arguments, but 4 were provided}}
}

int test_select_first_arg_wrong_type(int1 p0, int t0, int f0) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This and the 2 tests below don't error? Do you just want to delete them?

}

// if a bool last 2 args are of same type
int test_select_builtin_bool_incompatible_args(bool p0, int t0, double f0) {
return __builtin_hlsl_select(p0, t0, f0);
// expected-error@-1 {{arguments are of different types ('int' vs 'double')}}
return __builtin_hlsl_select(p0, t0, f0); // expected-error{{arguments are of different types ('int' vs 'double')}}
}

// if a vector second arg isnt a vector
int2 test_select_builtin_second_arg_not_vector(bool2 p0, int t0, int2 f0) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here and the one below.

This PR adds scalar/vector overloads for vector conditions to the
`select` builtin, and updates the sema checking and codegen to allow
scalars to extend to vectors.

Fixes llvm#126570

clang-format
clang-format
'cbieneman/select' on '44f0fe9a2806'.
This fixes issues that were causing the new select overloads to fail to
compile.
//
//===----------------------------------------------------------------------===//

#ifndef _HLSL_HLSL_INTRINSIC_HELPERS_H_
Copy link
Member

@farzonl farzonl Mar 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How does the enable_if_t template in hlsl_detail.hget exposed here? Does order in clang/lib/Headers/hlsl.h matter I see you have hlsl_detail.h before hlsl_alias_intrinsics.h.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hlsl_detail.h is included first in hlsl.h, which makes it included before the other headers. These headers are all implementation details, and aren't expected to be exposed to user code since hlsl.h is implicitly included in all HLSL source files.

@llvm-beanz llvm-beanz merged commit e85e29c into llvm:main Mar 9, 2025
10 of 12 checks passed
@damyanp damyanp moved this to Closed in HLSL Support Apr 25, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
backend:X86 clang:codegen IR generation bugs: mangling, exceptions, etc. clang:frontend Language frontend issues, e.g. anything involving "Sema" clang:headers Headers provided by Clang, e.g. for intrinsics clang Clang issues not falling into any other category HLSL HLSL Language Support
Projects
Status: Closed
Development

Successfully merging this pull request may close these issues.

[HLSL] select not resolving correctly for complex cases
4 participants