Skip to content

Commit 0f349b7

Browse files
authored
[HLSL] Implement support for HLSL intrinsic - select (#107129)
Implement support for HLSL intrinsic select. This would close issue #75377
1 parent 34e3007 commit 0f349b7

File tree

8 files changed

+334
-4
lines changed

8 files changed

+334
-4
lines changed

clang/include/clang/Basic/Builtins.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4763,6 +4763,12 @@ def HLSLSaturate : LangBuiltin<"HLSL_LANG"> {
47634763
let Prototype = "void(...)";
47644764
}
47654765

4766+
def HLSLSelect : LangBuiltin<"HLSL_LANG"> {
4767+
let Spellings = ["__builtin_hlsl_select"];
4768+
let Attributes = [NoThrow, Const];
4769+
let Prototype = "void(...)";
4770+
}
4771+
47664772
// Builtins for XRay.
47674773
def XRayCustomEvent : Builtin {
47684774
let Spellings = ["__xray_customevent"];

clang/include/clang/Basic/DiagnosticSemaKinds.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9206,6 +9206,9 @@ def err_typecheck_expect_scalar_operand : Error<
92069206
"operand of type %0 where arithmetic or pointer type is required">;
92079207
def err_typecheck_cond_incompatible_operands : Error<
92089208
"incompatible operand types%diff{ ($ and $)|}0,1">;
9209+
def err_typecheck_expect_scalar_or_vector : Error<
9210+
"invalid operand of type %0 where %1 or "
9211+
"a vector of such type is required">;
92099212
def err_typecheck_expect_flt_or_vector : Error<
92109213
"invalid operand of type %0 where floating, complex or "
92119214
"a vector of such types is required">;

clang/lib/CodeGen/CGBuiltin.cpp

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6244,8 +6244,20 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
62446244
}
62456245

62466246
// EmitHLSLBuiltinExpr will check getLangOpts().HLSL
6247-
if (Value *V = EmitHLSLBuiltinExpr(BuiltinID, E))
6248-
return RValue::get(V);
6247+
if (Value *V = EmitHLSLBuiltinExpr(BuiltinID, E, ReturnValue)) {
6248+
switch (EvalKind) {
6249+
case TEK_Scalar:
6250+
if (V->getType()->isVoidTy())
6251+
return RValue::get(nullptr);
6252+
return RValue::get(V);
6253+
case TEK_Aggregate:
6254+
return RValue::getAggregate(ReturnValue.getAddress(),
6255+
ReturnValue.isVolatile());
6256+
case TEK_Complex:
6257+
llvm_unreachable("No current hlsl builtin returns complex");
6258+
}
6259+
llvm_unreachable("Bad evaluation kind in EmitBuiltinExpr");
6260+
}
62496261

62506262
if (getLangOpts().HIPStdPar && getLangOpts().CUDAIsDevice)
62516263
return EmitHipStdParUnsupportedBuiltin(this, FD);
@@ -18640,7 +18652,8 @@ Intrinsic::ID getDotProductIntrinsic(CGHLSLRuntime &RT, QualType QT) {
1864018652
}
1864118653

1864218654
Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
18643-
const CallExpr *E) {
18655+
const CallExpr *E,
18656+
ReturnValueSlot ReturnValue) {
1864418657
if (!getLangOpts().HLSL)
1864518658
return nullptr;
1864618659

@@ -18827,6 +18840,27 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
1882718840
CGM.getHLSLRuntime().getSaturateIntrinsic(), ArrayRef<Value *>{Op0},
1882818841
nullptr, "hlsl.saturate");
1882918842
}
18843+
case Builtin::BI__builtin_hlsl_select: {
18844+
Value *OpCond = EmitScalarExpr(E->getArg(0));
18845+
RValue RValTrue = EmitAnyExpr(E->getArg(1));
18846+
Value *OpTrue =
18847+
RValTrue.isScalar()
18848+
? RValTrue.getScalarVal()
18849+
: RValTrue.getAggregatePointer(E->getArg(1)->getType(), *this);
18850+
RValue RValFalse = EmitAnyExpr(E->getArg(2));
18851+
Value *OpFalse =
18852+
RValFalse.isScalar()
18853+
? RValFalse.getScalarVal()
18854+
: RValFalse.getAggregatePointer(E->getArg(2)->getType(), *this);
18855+
18856+
Value *SelectVal =
18857+
Builder.CreateSelect(OpCond, OpTrue, OpFalse, "hlsl.select");
18858+
if (!RValTrue.isScalar())
18859+
Builder.CreateStore(SelectVal, ReturnValue.getAddress(),
18860+
ReturnValue.isVolatile());
18861+
18862+
return SelectVal;
18863+
}
1883018864
case Builtin::BI__builtin_hlsl_wave_get_lane_index: {
1883118865
return EmitRuntimeCall(CGM.CreateRuntimeFunction(
1883218866
llvm::FunctionType::get(IntTy, {}, false), "__hlsl_wave_get_lane_index",

clang/lib/CodeGen/CodeGenFunction.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4704,7 +4704,8 @@ class CodeGenFunction : public CodeGenTypeCache {
47044704
llvm::Value *EmitX86BuiltinExpr(unsigned BuiltinID, const CallExpr *E);
47054705
llvm::Value *EmitPPCBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
47064706
llvm::Value *EmitAMDGPUBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
4707-
llvm::Value *EmitHLSLBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
4707+
llvm::Value *EmitHLSLBuiltinExpr(unsigned BuiltinID, const CallExpr *E,
4708+
ReturnValueSlot ReturnValue);
47084709
llvm::Value *EmitScalarOrConstFoldImmArg(unsigned ICEArguments, unsigned Idx,
47094710
const CallExpr *E);
47104711
llvm::Value *EmitSystemZBuiltinExpr(unsigned BuiltinID, const CallExpr *E);

clang/lib/Headers/hlsl/hlsl_intrinsics.h

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1603,6 +1603,32 @@ double3 saturate(double3);
16031603
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_saturate)
16041604
double4 saturate(double4);
16051605

1606+
//===----------------------------------------------------------------------===//
1607+
// select builtins
1608+
//===----------------------------------------------------------------------===//
1609+
1610+
/// \fn T select(bool Cond, T TrueVal, T FalseVal)
1611+
/// \brief Ternary operator: selects TrueVal when Cond is true, else FalseVal.
1612+
/// \param Cond The Condition input value.
1613+
/// \param TrueVal The Value returned if Cond is true.
1614+
/// \param FalseVal The Value returned if Cond is false.
1615+
1616+
template <typename T>
1617+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_select)
1618+
T select(bool, T, T);
1619+
1620+
/// \fn vector<T,Sz> select(vector<bool,Sz> Conds, vector<T,Sz> TrueVals,
1621+
/// vector<T,Sz> FalseVals)
1622+
/// \brief ternary operator for vectors. All vectors must be the same size.
1623+
/// \param Conds The Condition input values.
1624+
/// \param TrueVals Vector whose elements are chosen where Conds is true.
1625+
/// \param FalseVals Vector whose elements are chosen where the condition is
1626+
/// false.
1627+
1628+
template <typename T, int Sz>
1629+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_select)
1630+
vector<T, Sz> select(vector<bool, Sz>, vector<T, Sz>, vector<T, Sz>);
1631+
16061632
//===----------------------------------------------------------------------===//
16071633
// sin builtins
16081634
//===----------------------------------------------------------------------===//

clang/lib/Sema/SemaHLSL.cpp

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1531,6 +1531,79 @@ void SetElementTypeAsReturnType(Sema *S, CallExpr *TheCall,
15311531
TheCall->setType(ReturnType);
15321532
}
15331533

1534+
/// Check that one argument of a builtin call is either exactly \p Scalar or a
/// vector whose element type is \p Scalar. Qualifiers are ignored in both
/// comparisons.
///
/// \param S        Sema instance used for type comparison and diagnostics.
/// \param TheCall  The builtin call whose argument is inspected.
/// \param Scalar   The required scalar type (also the required element type
///                 when the argument is a vector).
/// \param ArgIndex Zero-based index of the argument to inspect.
/// \returns true if the argument is invalid and
///          err_typecheck_expect_scalar_or_vector was emitted; false if the
///          argument is acceptable.
static bool CheckScalarOrVector(Sema *S, CallExpr *TheCall, QualType Scalar,
                                unsigned ArgIndex) {
  // ArgIndex is zero-based and is passed to getArg() below, so it has to be
  // strictly smaller than the number of arguments.
  assert(ArgIndex < TheCall->getNumArgs() && "argument index out of range");

  Expr *Arg = TheCall->getArg(ArgIndex);
  QualType ArgType = Arg->getType();
  const auto *VTy = ArgType->getAs<VectorType>();

  const bool IsScalar = S->Context.hasSameUnqualifiedType(ArgType, Scalar);
  const bool IsVectorOfScalar =
      VTy && S->Context.hasSameUnqualifiedType(VTy->getElementType(), Scalar);

  // Neither the scalar itself nor vector<scalar>.
  if (!IsScalar && !IsVectorOfScalar) {
    // Anchor the diagnostic at the argument that was actually checked.
    S->Diag(Arg->getBeginLoc(), diag::err_typecheck_expect_scalar_or_vector)
        << ArgType << Scalar;
    return true;
  }
  return false;
}
1550+
1551+
/// Semantic check for a three-argument select call whose second and third
/// arguments must have the same unqualified type.
///
/// On success the call's result type is set to the type of the second
/// argument.
///
/// \returns true if err_typecheck_call_different_arg_types was emitted,
///          false if the call is well-formed.
static bool CheckBoolSelect(Sema *S, CallExpr *TheCall) {
  assert(TheCall->getNumArgs() == 3);
  Expr *TrueArg = TheCall->getArg(1);
  Expr *FalseArg = TheCall->getArg(2);

  // Matching value types: the call yields that type.
  if (S->Context.hasSameUnqualifiedType(TrueArg->getType(),
                                        FalseArg->getType())) {
    TheCall->setType(TrueArg->getType());
    return false;
  }

  // Mismatched value types: report both types and both source ranges.
  S->Diag(TheCall->getBeginLoc(), diag::err_typecheck_call_different_arg_types)
      << TrueArg->getType() << FalseArg->getType()
      << TrueArg->getSourceRange() << FalseArg->getSourceRange();
  return true;
}
1566+
1567+
/// Semantic check for a three-argument select call whose first argument is a
/// vector (the caller is expected to have verified that already).
///
/// Checks, in order: the second argument is a vector, the third argument is a
/// vector, the second and third arguments have the same unqualified type, and
/// the first and second arguments have the same number of elements. On
/// success the call's result type is set to the type of the second argument.
///
/// \returns true if a diagnostic was emitted, false if the call is
///          well-formed.
static bool CheckVectorSelect(Sema *S, CallExpr *TheCall) {
  assert(TheCall->getNumArgs() == 3);
  Expr *TrueArg = TheCall->getArg(1);
  Expr *FalseArg = TheCall->getArg(2);

  // Emits err_builtin_non_vector_type for Arg unless it has vector type.
  // Returns true when the diagnostic was emitted.
  auto DiagnoseNonVector = [&](Expr *Arg, const char *Ordinal) {
    if (Arg->getType()->isVectorType())
      return false;
    S->Diag(Arg->getBeginLoc(), diag::err_builtin_non_vector_type)
        << Ordinal << TheCall->getDirectCallee() << Arg->getType()
        << Arg->getSourceRange();
    return true;
  };

  if (DiagnoseNonVector(TrueArg, "Second"))
    return true;
  if (DiagnoseNonVector(FalseArg, "Third"))
    return true;

  // Both value operands must be the same vector type.
  if (!S->Context.hasSameUnqualifiedType(TrueArg->getType(),
                                         FalseArg->getType())) {
    S->Diag(TheCall->getBeginLoc(),
            diag::err_typecheck_call_different_arg_types)
        << TrueArg->getType() << FalseArg->getType()
        << TrueArg->getSourceRange() << FalseArg->getSourceRange();
    return true;
  }

  // The caller has checked that the condition argument is a vector. Since the
  // two value operands already share a type, comparing the condition against
  // the second argument covers all three lengths.
  Expr *CondArg = TheCall->getArg(0);
  const unsigned CondLen =
      CondArg->getType()->getAs<VectorType>()->getNumElements();
  const unsigned ValueLen =
      TrueArg->getType()->getAs<VectorType>()->getNumElements();
  if (CondLen != ValueLen) {
    S->Diag(TheCall->getBeginLoc(),
            diag::err_typecheck_vector_lengths_not_equal)
        << CondArg->getType() << TrueArg->getType()
        << CondArg->getSourceRange() << TrueArg->getSourceRange();
    return true;
  }

  TheCall->setType(TrueArg->getType());
  return false;
}
1606+
15341607
// Note: returning true in this case results in CheckBuiltinFunctionCall
15351608
// returning an ExprError
15361609
bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
@@ -1563,6 +1636,20 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
15631636
return true;
15641637
break;
15651638
}
1639+
case Builtin::BI__builtin_hlsl_select: {
1640+
if (SemaRef.checkArgCount(TheCall, 3))
1641+
return true;
1642+
if (CheckScalarOrVector(&SemaRef, TheCall, getASTContext().BoolTy, 0))
1643+
return true;
1644+
QualType ArgTy = TheCall->getArg(0)->getType();
1645+
if (ArgTy->isBooleanType() && CheckBoolSelect(&SemaRef, TheCall))
1646+
return true;
1647+
auto *VTy = ArgTy->getAs<VectorType>();
1648+
if (VTy && VTy->getElementType()->isBooleanType() &&
1649+
CheckVectorSelect(&SemaRef, TheCall))
1650+
return true;
1651+
break;
1652+
}
15661653
case Builtin::BI__builtin_hlsl_elementwise_saturate:
15671654
case Builtin::BI__builtin_hlsl_elementwise_rcp: {
15681655
if (CheckAllArgsHaveFloatRepresentation(&SemaRef, TheCall))
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
2+
// RUN: dxil-pc-shadermodel6.3-library %s -emit-llvm -disable-llvm-passes \
3+
// RUN: -o - | FileCheck %s --check-prefixes=CHECK
4+
5+
// CHECK-LABEL: test_select_bool_int
6+
// CHECK: [[SELECT:%.*]] = select i1 {{%.*}}, i32 {{%.*}}, i32 {{%.*}}
7+
// CHECK: ret i32 [[SELECT]]
8+
int test_select_bool_int(bool cond0, int tVal, int fVal) {
9+
return select<int>(cond0, tVal, fVal);
10+
}
11+
12+
struct S { int a; };
13+
// CHECK-LABEL: test_select_infer
14+
// CHECK: [[SELECT:%.*]] = select i1 {{%.*}}, ptr {{%.*}}, ptr {{%.*}}
15+
// CHECK: store ptr [[SELECT]]
16+
// CHECK: ret void
17+
struct S test_select_infer(bool cond0, struct S tVal, struct S fVal) {
18+
return select(cond0, tVal, fVal);
19+
}
20+
21+
// CHECK-LABEL: test_select_bool_vector
22+
// CHECK: [[SELECT:%.*]] = select i1 {{%.*}}, <2 x i32> {{%.*}}, <2 x i32> {{%.*}}
23+
// CHECK: ret <2 x i32> [[SELECT]]
24+
int2 test_select_bool_vector(bool cond0, int2 tVal, int2 fVal) {
25+
return select<int2>(cond0, tVal, fVal);
26+
}
27+
28+
// CHECK-LABEL: test_select_vector_1
29+
// CHECK: [[SELECT:%.*]] = select <1 x i1> {{%.*}}, <1 x i32> {{%.*}}, <1 x i32> {{%.*}}
30+
// CHECK: ret <1 x i32> [[SELECT]]
31+
int1 test_select_vector_1(bool1 cond0, int1 tVals, int1 fVals) {
32+
return select<int,1>(cond0, tVals, fVals);
33+
}
34+
35+
// CHECK-LABEL: test_select_vector_2
36+
// CHECK: [[SELECT:%.*]] = select <2 x i1> {{%.*}}, <2 x i32> {{%.*}}, <2 x i32> {{%.*}}
37+
// CHECK: ret <2 x i32> [[SELECT]]
38+
int2 test_select_vector_2(bool2 cond0, int2 tVals, int2 fVals) {
39+
return select<int,2>(cond0, tVals, fVals);
40+
}
41+
42+
// CHECK-LABEL: test_select_vector_3
43+
// CHECK: [[SELECT:%.*]] = select <3 x i1> {{%.*}}, <3 x i32> {{%.*}}, <3 x i32> {{%.*}}
44+
// CHECK: ret <3 x i32> [[SELECT]]
45+
int3 test_select_vector_3(bool3 cond0, int3 tVals, int3 fVals) {
46+
return select<int,3>(cond0, tVals, fVals);
47+
}
48+
49+
// CHECK-LABEL: test_select_vector_4
50+
// CHECK: [[SELECT:%.*]] = select <4 x i1> {{%.*}}, <4 x i32> {{%.*}}, <4 x i32> {{%.*}}
51+
// CHECK: ret <4 x i32> [[SELECT]]
52+
int4 test_select_vector_4(bool4 cond0, int4 tVals, int4 fVals) {
53+
return select(cond0, tVals, fVals);
54+
}

0 commit comments

Comments
 (0)