Skip to content

[HLSL] implement mad intrinsic #83826

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions clang/include/clang/Basic/Builtins.td
Original file line number Diff line number Diff line change
Expand Up @@ -4548,6 +4548,12 @@ def HLSLLerp : LangBuiltin<"HLSL_LANG"> {
let Prototype = "void(...)";
}

def HLSLMad : LangBuiltin<"HLSL_LANG"> {
let Spellings = ["__builtin_hlsl_mad"];
let Attributes = [NoThrow, Const];
let Prototype = "void(...)";
}

// Builtins for XRay.
def XRayCustomEvent : Builtin {
let Spellings = ["__xray_customevent"];
Expand Down
3 changes: 2 additions & 1 deletion clang/include/clang/Sema/Sema.h
Original file line number Diff line number Diff line change
Expand Up @@ -14146,7 +14146,8 @@ class Sema final {
bool SemaBuiltinVectorMath(CallExpr *TheCall, QualType &Res);
bool SemaBuiltinVectorToScalarMath(CallExpr *TheCall);
bool SemaBuiltinElementwiseMath(CallExpr *TheCall);
bool SemaBuiltinElementwiseTernaryMath(CallExpr *TheCall);
bool SemaBuiltinElementwiseTernaryMath(CallExpr *TheCall,
bool CheckForFloatArgs = true);
bool PrepareBuiltinElementwiseMathOneArgCall(CallExpr *TheCall);
bool PrepareBuiltinReduceMathOneArgCall(CallExpr *TheCall);

Expand Down
19 changes: 19 additions & 0 deletions clang/lib/CodeGen/CGBuiltin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18044,6 +18044,25 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
/*ReturnType*/ Op0->getType(), Intrinsic::dx_frac,
ArrayRef<Value *>{Op0}, nullptr, "dx.frac");
}
case Builtin::BI__builtin_hlsl_mad: {
Value *M = EmitScalarExpr(E->getArg(0));
Value *A = EmitScalarExpr(E->getArg(1));
Value *B = EmitScalarExpr(E->getArg(2));
if (E->getArg(0)->getType()->hasFloatingRepresentation()) {
return Builder.CreateIntrinsic(
/*ReturnType*/ M->getType(), Intrinsic::fmuladd,
ArrayRef<Value *>{M, A, B}, nullptr, "dx.fmad");
}
if (E->getArg(0)->getType()->hasSignedIntegerRepresentation()) {
return Builder.CreateIntrinsic(
/*ReturnType*/ M->getType(), Intrinsic::dx_imad,
ArrayRef<Value *>{M, A, B}, nullptr, "dx.imad");
}
assert(E->getArg(0)->getType()->hasUnsignedIntegerRepresentation());
return Builder.CreateIntrinsic(
/*ReturnType*/ M->getType(), Intrinsic::dx_umad,
ArrayRef<Value *>{M, A, B}, nullptr, "dx.umad");
}
}
return nullptr;
}
Expand Down
105 changes: 105 additions & 0 deletions clang/lib/Headers/hlsl/hlsl_intrinsics.h
Original file line number Diff line number Diff line change
Expand Up @@ -511,6 +511,111 @@ double3 log2(double3);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_log2)
double4 log2(double4);

//===----------------------------------------------------------------------===//
// mad builtins
//===----------------------------------------------------------------------===//

/// \fn T mad(T M, T A, T B)
/// \brief The result of \a M * \a A + \a B.
/// \param M The multiplication value.
/// \param A The first addition value.
/// \param B The second addition value.

_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
half mad(half, half, half);
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
half2 mad(half2, half2, half2);
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
half3 mad(half3, half3, half3);
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
half4 mad(half4, half4, half4);

#ifdef __HLSL_ENABLE_16_BIT
_HLSL_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
int16_t mad(int16_t, int16_t, int16_t);
_HLSL_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
int16_t2 mad(int16_t2, int16_t2, int16_t2);
_HLSL_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
int16_t3 mad(int16_t3, int16_t3, int16_t3);
_HLSL_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
int16_t4 mad(int16_t4, int16_t4, int16_t4);

_HLSL_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
uint16_t mad(uint16_t, uint16_t, uint16_t);
_HLSL_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
uint16_t2 mad(uint16_t2, uint16_t2, uint16_t2);
_HLSL_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
uint16_t3 mad(uint16_t3, uint16_t3, uint16_t3);
_HLSL_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
uint16_t4 mad(uint16_t4, uint16_t4, uint16_t4);
#endif

_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
int mad(int, int, int);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
int2 mad(int2, int2, int2);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
int3 mad(int3, int3, int3);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
int4 mad(int4, int4, int4);

_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
uint mad(uint, uint, uint);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
uint2 mad(uint2, uint2, uint2);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
uint3 mad(uint3, uint3, uint3);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
uint4 mad(uint4, uint4, uint4);

_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
int64_t mad(int64_t, int64_t, int64_t);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
int64_t2 mad(int64_t2, int64_t2, int64_t2);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
int64_t3 mad(int64_t3, int64_t3, int64_t3);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
int64_t4 mad(int64_t4, int64_t4, int64_t4);

_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
uint64_t mad(uint64_t, uint64_t, uint64_t);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
uint64_t2 mad(uint64_t2, uint64_t2, uint64_t2);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
uint64_t3 mad(uint64_t3, uint64_t3, uint64_t3);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
uint64_t4 mad(uint64_t4, uint64_t4, uint64_t4);

_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
float mad(float, float, float);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
float2 mad(float2, float2, float2);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
float3 mad(float3, float3, float3);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
float4 mad(float4, float4, float4);

_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
double mad(double, double, double);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
double2 mad(double2, double2, double2);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
double3 mad(double3, double3, double3);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
double4 mad(double4, double4, double4);

//===----------------------------------------------------------------------===//
// max builtins
//===----------------------------------------------------------------------===//
Expand Down
28 changes: 23 additions & 5 deletions clang/lib/Sema/SemaChecking.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5298,6 +5298,14 @@ bool Sema::CheckHLSLBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
return true;
break;
}
case Builtin::BI__builtin_hlsl_mad: {
if (checkArgCount(*this, TheCall, 3))
return true;
if (CheckVectorElementCallArgs(this, TheCall))
return true;
if (SemaBuiltinElementwiseTernaryMath(TheCall, /*CheckForFloatArgs*/ false))
return true;
}
}
return false;
}
Expand Down Expand Up @@ -19800,7 +19808,8 @@ bool Sema::SemaBuiltinVectorMath(CallExpr *TheCall, QualType &Res) {
return false;
}

bool Sema::SemaBuiltinElementwiseTernaryMath(CallExpr *TheCall) {
bool Sema::SemaBuiltinElementwiseTernaryMath(CallExpr *TheCall,
bool CheckForFloatArgs) {
if (checkArgCount(*this, TheCall, 3))
return true;

Expand All @@ -19812,11 +19821,20 @@ bool Sema::SemaBuiltinElementwiseTernaryMath(CallExpr *TheCall) {
Args[I] = Converted.get();
}

int ArgOrdinal = 1;
for (Expr *Arg : Args) {
if (checkFPMathBuiltinElementType(*this, Arg->getBeginLoc(), Arg->getType(),
if (CheckForFloatArgs) {
int ArgOrdinal = 1;
for (Expr *Arg : Args) {
if (checkFPMathBuiltinElementType(*this, Arg->getBeginLoc(),
Arg->getType(), ArgOrdinal++))
return true;
}
} else {
int ArgOrdinal = 1;
for (Expr *Arg : Args) {
if (checkMathBuiltinElementType(*this, Arg->getBeginLoc(), Arg->getType(),
ArgOrdinal++))
return true;
return true;
}
}

for (int I = 1; I < 3; ++I) {
Expand Down
Loading