Skip to content

Commit 3e0cd3c

Browse files
committed
implement select intrinsic
1 parent 15fa3ba commit 3e0cd3c

File tree

4 files changed

+155
-0
lines changed

4 files changed

+155
-0
lines changed

clang/include/clang/Basic/Builtins.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4751,6 +4751,12 @@ def HLSLSaturate : LangBuiltin<"HLSL_LANG"> {
47514751
let Prototype = "void(...)";
47524752
}
47534753

4754+
def HLSLSelect : LangBuiltin<"HLSL_LANG"> {
4755+
let Spellings = ["__builtin_hlsl_select"];
4756+
let Attributes = [NoThrow, Const];
4757+
let Prototype = "void(...)";
4758+
}
4759+
47544760
// Builtins for XRay.
47554761
def XRayCustomEvent : Builtin {
47564762
let Spellings = ["__xray_customevent"];

clang/lib/CodeGen/CGBuiltin.cpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18695,6 +18695,47 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
1869518695
CGM.getHLSLRuntime().getSaturateIntrinsic(), ArrayRef<Value *>{Op0},
1869618696
nullptr, "hlsl.saturate");
1869718697
}
18698+
case Builtin::BI__builtin_hlsl_select: {
18699+
Value *OpCond = EmitScalarExpr(E->getArg(0));
18700+
Value *OpTrue = EmitScalarExpr(E->getArg(1));
18701+
Value *OpFalse = EmitScalarExpr(E->getArg(2));
18702+
llvm::Type *TCond = OpCond->getType();
18703+
18704+
// if cond is a bool emit a select instruction
18705+
if (TCond->isIntegerTy(1))
18706+
return Builder.CreateSelect(OpCond, OpTrue, OpFalse);
18707+
18708+
// if cond is a vector of bools lower to a shufflevector
18709+
// todo check if that true and false are vectors
18710+
// todo check that the size of true and false and cond are the same
18711+
if (TCond->isVectorTy() &&
18712+
E->getArg(0)->getType()->getAs<VectorType>()->isBooleanType()) {
18713+
assert(OpTrue->getType()->isVectorTy() && OpFalse->getType()->isVectorTy() &&
18714+
"Select's second and third operands must be vectors if first operand is a vector.");
18715+
18716+
auto *VecTyTrue = E->getArg(1)->getType()->getAs<VectorType>();
18717+
auto *VecTyFalse = E->getArg(2)->getType()->getAs<VectorType>();
18718+
18719+
assert(VecTyTrue->getElementType() == VecTyFalse->getElementType() &&
18720+
"Select's second and third vectors need the same element types.");
18721+
18722+
const unsigned N = VecTyTrue->getNumElements();
18723+
assert(N == VecTyFalse->getNumElements() &&
18724+
N == E->getArg(0)->getType()->getAs<VectorType>()->getNumElements() &&
18725+
"Select requires vectors to be of the same size.");
18726+
18727+
llvm::SmallVector<Value *> Mask;
18728+
for (unsigned I = 0; I < N; I++) {
18729+
Value *Index = ConstantInt::get(IntTy, I);
18730+
Value *IndexBool = Builder.CreateExtractElement(OpCond, Index);
18731+
Mask.push_back(Builder.CreateSelect(IndexBool, Index, ConstantInt::get(IntTy, I + N)));
18732+
}
18733+
18734+
return Builder.CreateShuffleVector(OpTrue, OpFalse, BuildVector(Mask));
18735+
}
18736+
18737+
llvm_unreachable("Select requires a bool or vector of bools as its first operand.");
18738+
}
1869818739
case Builtin::BI__builtin_hlsl_wave_get_lane_index: {
1869918740
return EmitRuntimeCall(CGM.CreateRuntimeFunction(
1870018741
llvm::FunctionType::get(IntTy, {}, false), "__hlsl_wave_get_lane_index",

clang/lib/Headers/hlsl/hlsl_intrinsics.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1603,6 +1603,30 @@ double3 saturate(double3);
16031603
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_saturate)
16041604
double4 saturate(double4);
16051605

1606+
//===----------------------------------------------------------------------===//
1607+
// select builtins
1608+
//===----------------------------------------------------------------------===//
1609+
1610+
/// \fn T select(bool Cond, T TrueVal, T FalseVal)
1611+
/// \brief ternary operator.
1612+
/// \param Cond The Condition input value.
1613+
/// \param TrueVal The Value returned if Cond is true.
1614+
/// \param FalseVal The Value returned if Cond is false.
1615+
1616+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_select)
1617+
template<typename T>
1618+
T select(bool, T, T);
1619+
1620+
/// \fn vector<T,Sz> select(vector<bool,Sz> Conds, vector<T,Sz> TrueVals, vector<T,Sz>, FalseVals)
1621+
/// \brief ternary operator for vectors. All vectors must be the same size.
1622+
/// \param Conds The Condition input values.
1623+
/// \param TrueVals The vector values are chosen from when conditions are true.
1624+
/// \param FalseVals The vector values are chosen from when conditions are false.
1625+
1626+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_select)
1627+
template<typename T, int Sz>
1628+
vector<T,Sz> select(vector<bool,Sz>, vector<T,Sz>, vector<T,Sz>);
1629+
16061630
//===----------------------------------------------------------------------===//
16071631
// sin builtins
16081632
//===----------------------------------------------------------------------===//

clang/lib/Sema/SemaHLSL.cpp

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1512,6 +1512,66 @@ void SetElementTypeAsReturnType(Sema *S, CallExpr *TheCall,
15121512
TheCall->setType(ReturnType);
15131513
}
15141514

1515+
bool CheckBoolSelect(Sema *S, CallExpr *TheCall) {
1516+
assert(TheCall->getNumArgs() == 3);
1517+
Expr *Arg1 = TheCall->getArg(1);
1518+
Expr *Arg2 = TheCall->getArg(2);
1519+
if(!S->Context.hasSameUnqualifiedType(Arg1->getType(),
1520+
Arg2->getType())) {
1521+
S->Diag(TheCall->getBeginLoc(),
1522+
diag::err_typecheck_call_different_arg_types)
1523+
<< Arg1->getType() << Arg2->getType()
1524+
<< Arg1->getSourceRange() << Arg2->getSourceRange();
1525+
return true;
1526+
}
1527+
1528+
TheCall->setType(Arg1->getType());
1529+
return false;
1530+
}
1531+
1532+
bool CheckVectorSelect(Sema *S, CallExpr *TheCall) {
1533+
assert(TheCall->getNumArgs() == 3);
1534+
Expr *Arg1 = TheCall->getArg(1);
1535+
Expr *Arg2 = TheCall->getArg(2);
1536+
if(!Arg1->getType()->isVectorType()) {
1537+
S->Diag(Arg1->getBeginLoc(),
1538+
diag::err_builtin_non_vector_type)
1539+
<< "Second" << "__builtin_hlsl_select" << Arg1->getType()
1540+
<< Arg1->getSourceRange();
1541+
return true;
1542+
}
1543+
1544+
if(!Arg2->getType()->isVectorType()) {
1545+
S->Diag(Arg2->getBeginLoc(),
1546+
diag::err_builtin_non_vector_type)
1547+
<< "Third" << "__builtin_hlsl_select" << Arg2->getType()
1548+
<< Arg2->getSourceRange();
1549+
return true;
1550+
}
1551+
1552+
if(!S->Context.hasSameUnqualifiedType(Arg1->getType(),
1553+
Arg2->getType())) {
1554+
S->Diag(TheCall->getBeginLoc(),
1555+
diag::err_typecheck_call_different_arg_types)
1556+
<< Arg1->getType() << Arg2->getType()
1557+
<< Arg1->getSourceRange() << Arg2->getSourceRange();
1558+
return true;
1559+
}
1560+
1561+
// caller has checked that Arg0 is a vector.
1562+
// check all three args have the same length.
1563+
if(TheCall->getArg(0)->getType()->getAs<VectorType>()->getNumElements() !=
1564+
Arg1->getType()->getAs<VectorType>()->getNumElements()) {
1565+
S->Diag(TheCall->getBeginLoc(),
1566+
diag::err_typecheck_vector_lengths_not_equal)
1567+
<< TheCall->getArg(0)->getType() << Arg1->getType()
1568+
<< TheCall->getArg(0)->getSourceRange() << Arg1->getSourceRange();
1569+
return true;
1570+
}
1571+
1572+
return false;
1573+
}
1574+
15151575
// Note: returning true in this case results in CheckBuiltinFunctionCall
15161576
// returning an ExprError
15171577
bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
@@ -1545,6 +1605,30 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
15451605
break;
15461606
}
15471607
case Builtin::BI__builtin_hlsl_elementwise_saturate:
1608+
case Builtin::BI__builtin_hlsl_select: {
1609+
if (SemaRef.checkArgCount(TheCall, 3))
1610+
return true;
1611+
QualType ArgTy = TheCall->getArg(0)->getType();
1612+
if (ArgTy->isBooleanType()) {
1613+
if (CheckBoolSelect(&SemaRef, TheCall))
1614+
return true;
1615+
} else if (ArgTy->isVectorType() &&
1616+
ArgTy->getAs<VectorType>()->getElementType()->isBooleanType()) {
1617+
if (CheckVectorSelect(&SemaRef, TheCall))
1618+
return true;
1619+
} else { // first operand is not a bool or a vector of bools.
1620+
SemaRef.Diag(TheCall->getArg(0)->getBeginLoc(),
1621+
diag::err_typecheck_convert_incompatible)
1622+
<< TheCall->getArg(0)->getType() << SemaRef.Context.getBOOLType()
1623+
<< 1 << 0 << 0;
1624+
SemaRef.Diag(TheCall->getArg(0)->getBeginLoc(),
1625+
diag::err_builtin_non_vector_type)
1626+
<< "First" << "__builtin_hlsl_select" << TheCall->getArg(0)->getType()
1627+
<< TheCall->getArg(0)->getSourceRange();
1628+
return true;
1629+
}
1630+
break;
1631+
}
15481632
case Builtin::BI__builtin_hlsl_elementwise_rcp: {
15491633
if (CheckAllArgsHaveFloatRepresentation(&SemaRef, TheCall))
15501634
return true;

0 commit comments

Comments
 (0)