Skip to content

Commit 643b31d

Browse files
farzonlFarzon Lotfi
andauthored
[HLSL] implement mad intrinsic (#83826)
This change implements #83736 The dot product lowering needs a tertiary multipy add operation. DXIL has three mad opcodes for `fmad`(46), `imad`(48), and `umad`(49). Dot product in DXIL only uses `imad`\ `umad`, but for completeness and because the hlsl `mad` intrinsic requires it `fmad` was also included. Two new intrinsics were needed to be created to complete this change. the `fmad` case already supported by llvm via `fmuladd` intrinsic. - `hlsl_intrinsics.h` - exposed mad api call. - `Builtins.td` - exposed a `mad` builtin. - `Sema.h` - make `tertiary` calls check for float types optional. - `CGBuiltin.cpp` - pick the intrinsic for singed\unsigned & float also reuse `int_fmuladd`. - `SemaChecking.cpp` - type checks for `__builtin_hlsl_mad`. - `IntrinsicsDirectX.td` create the two new intrinsics for `imad`\`umad`/ - `DXIL.td` - create the llvm intrinsic to `DXIL` opcode mapping. --------- Co-authored-by: Farzon Lotfi <[email protected]>
1 parent a730ed7 commit 643b31d

File tree

12 files changed

+638
-6
lines changed

12 files changed

+638
-6
lines changed

clang/include/clang/Basic/Builtins.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4572,6 +4572,12 @@ def HLSLLerp : LangBuiltin<"HLSL_LANG"> {
45724572
let Prototype = "void(...)";
45734573
}
45744574

4575+
def HLSLMad : LangBuiltin<"HLSL_LANG"> {
4576+
let Spellings = ["__builtin_hlsl_mad"];
4577+
let Attributes = [NoThrow, Const];
4578+
let Prototype = "void(...)";
4579+
}
4580+
45754581
// Builtins for XRay.
45764582
def XRayCustomEvent : Builtin {
45774583
let Spellings = ["__xray_customevent"];

clang/include/clang/Sema/Sema.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14146,7 +14146,8 @@ class Sema final {
1414614146
bool SemaBuiltinVectorMath(CallExpr *TheCall, QualType &Res);
1414714147
bool SemaBuiltinVectorToScalarMath(CallExpr *TheCall);
1414814148
bool SemaBuiltinElementwiseMath(CallExpr *TheCall);
14149-
bool SemaBuiltinElementwiseTernaryMath(CallExpr *TheCall);
14149+
bool SemaBuiltinElementwiseTernaryMath(CallExpr *TheCall,
14150+
bool CheckForFloatArgs = true);
1415014151
bool PrepareBuiltinElementwiseMathOneArgCall(CallExpr *TheCall);
1415114152
bool PrepareBuiltinReduceMathOneArgCall(CallExpr *TheCall);
1415214153

clang/lib/CodeGen/CGBuiltin.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18044,6 +18044,25 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
1804418044
/*ReturnType*/ Op0->getType(), Intrinsic::dx_frac,
1804518045
ArrayRef<Value *>{Op0}, nullptr, "dx.frac");
1804618046
}
18047+
case Builtin::BI__builtin_hlsl_mad: {
18048+
Value *M = EmitScalarExpr(E->getArg(0));
18049+
Value *A = EmitScalarExpr(E->getArg(1));
18050+
Value *B = EmitScalarExpr(E->getArg(2));
18051+
if (E->getArg(0)->getType()->hasFloatingRepresentation()) {
18052+
return Builder.CreateIntrinsic(
18053+
/*ReturnType*/ M->getType(), Intrinsic::fmuladd,
18054+
ArrayRef<Value *>{M, A, B}, nullptr, "dx.fmad");
18055+
}
18056+
if (E->getArg(0)->getType()->hasSignedIntegerRepresentation()) {
18057+
return Builder.CreateIntrinsic(
18058+
/*ReturnType*/ M->getType(), Intrinsic::dx_imad,
18059+
ArrayRef<Value *>{M, A, B}, nullptr, "dx.imad");
18060+
}
18061+
assert(E->getArg(0)->getType()->hasUnsignedIntegerRepresentation());
18062+
return Builder.CreateIntrinsic(
18063+
/*ReturnType*/ M->getType(), Intrinsic::dx_umad,
18064+
ArrayRef<Value *>{M, A, B}, nullptr, "dx.umad");
18065+
}
1804718066
}
1804818067
return nullptr;
1804918068
}

clang/lib/Headers/hlsl/hlsl_intrinsics.h

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -511,6 +511,111 @@ double3 log2(double3);
511511
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_log2)
512512
double4 log2(double4);
513513

514+
//===----------------------------------------------------------------------===//
515+
// mad builtins
516+
//===----------------------------------------------------------------------===//
517+
518+
/// \fn T mad(T M, T A, T B)
519+
/// \brief The result of \a M * \a A + \a B.
520+
/// \param M The multiplication value.
521+
/// \param A The first addition value.
522+
/// \param B The second addition value.
523+
524+
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
525+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
526+
half mad(half, half, half);
527+
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
528+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
529+
half2 mad(half2, half2, half2);
530+
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
531+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
532+
half3 mad(half3, half3, half3);
533+
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
534+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
535+
half4 mad(half4, half4, half4);
536+
537+
#ifdef __HLSL_ENABLE_16_BIT
538+
_HLSL_AVAILABILITY(shadermodel, 6.2)
539+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
540+
int16_t mad(int16_t, int16_t, int16_t);
541+
_HLSL_AVAILABILITY(shadermodel, 6.2)
542+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
543+
int16_t2 mad(int16_t2, int16_t2, int16_t2);
544+
_HLSL_AVAILABILITY(shadermodel, 6.2)
545+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
546+
int16_t3 mad(int16_t3, int16_t3, int16_t3);
547+
_HLSL_AVAILABILITY(shadermodel, 6.2)
548+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
549+
int16_t4 mad(int16_t4, int16_t4, int16_t4);
550+
551+
_HLSL_AVAILABILITY(shadermodel, 6.2)
552+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
553+
uint16_t mad(uint16_t, uint16_t, uint16_t);
554+
_HLSL_AVAILABILITY(shadermodel, 6.2)
555+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
556+
uint16_t2 mad(uint16_t2, uint16_t2, uint16_t2);
557+
_HLSL_AVAILABILITY(shadermodel, 6.2)
558+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
559+
uint16_t3 mad(uint16_t3, uint16_t3, uint16_t3);
560+
_HLSL_AVAILABILITY(shadermodel, 6.2)
561+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
562+
uint16_t4 mad(uint16_t4, uint16_t4, uint16_t4);
563+
#endif
564+
565+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
566+
int mad(int, int, int);
567+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
568+
int2 mad(int2, int2, int2);
569+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
570+
int3 mad(int3, int3, int3);
571+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
572+
int4 mad(int4, int4, int4);
573+
574+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
575+
uint mad(uint, uint, uint);
576+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
577+
uint2 mad(uint2, uint2, uint2);
578+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
579+
uint3 mad(uint3, uint3, uint3);
580+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
581+
uint4 mad(uint4, uint4, uint4);
582+
583+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
584+
int64_t mad(int64_t, int64_t, int64_t);
585+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
586+
int64_t2 mad(int64_t2, int64_t2, int64_t2);
587+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
588+
int64_t3 mad(int64_t3, int64_t3, int64_t3);
589+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
590+
int64_t4 mad(int64_t4, int64_t4, int64_t4);
591+
592+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
593+
uint64_t mad(uint64_t, uint64_t, uint64_t);
594+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
595+
uint64_t2 mad(uint64_t2, uint64_t2, uint64_t2);
596+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
597+
uint64_t3 mad(uint64_t3, uint64_t3, uint64_t3);
598+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
599+
uint64_t4 mad(uint64_t4, uint64_t4, uint64_t4);
600+
601+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
602+
float mad(float, float, float);
603+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
604+
float2 mad(float2, float2, float2);
605+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
606+
float3 mad(float3, float3, float3);
607+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
608+
float4 mad(float4, float4, float4);
609+
610+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
611+
double mad(double, double, double);
612+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
613+
double2 mad(double2, double2, double2);
614+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
615+
double3 mad(double3, double3, double3);
616+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
617+
double4 mad(double4, double4, double4);
618+
514619
//===----------------------------------------------------------------------===//
515620
// max builtins
516621
//===----------------------------------------------------------------------===//

clang/lib/Sema/SemaChecking.cpp

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5298,6 +5298,14 @@ bool Sema::CheckHLSLBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
52985298
return true;
52995299
break;
53005300
}
5301+
case Builtin::BI__builtin_hlsl_mad: {
5302+
if (checkArgCount(*this, TheCall, 3))
5303+
return true;
5304+
if (CheckVectorElementCallArgs(this, TheCall))
5305+
return true;
5306+
if (SemaBuiltinElementwiseTernaryMath(TheCall, /*CheckForFloatArgs*/ false))
5307+
return true;
5308+
}
53015309
}
53025310
return false;
53035311
}
@@ -19798,7 +19806,8 @@ bool Sema::SemaBuiltinVectorMath(CallExpr *TheCall, QualType &Res) {
1979819806
return false;
1979919807
}
1980019808

19801-
bool Sema::SemaBuiltinElementwiseTernaryMath(CallExpr *TheCall) {
19809+
bool Sema::SemaBuiltinElementwiseTernaryMath(CallExpr *TheCall,
19810+
bool CheckForFloatArgs) {
1980219811
if (checkArgCount(*this, TheCall, 3))
1980319812
return true;
1980419813

@@ -19810,11 +19819,20 @@ bool Sema::SemaBuiltinElementwiseTernaryMath(CallExpr *TheCall) {
1981019819
Args[I] = Converted.get();
1981119820
}
1981219821

19813-
int ArgOrdinal = 1;
19814-
for (Expr *Arg : Args) {
19815-
if (checkFPMathBuiltinElementType(*this, Arg->getBeginLoc(), Arg->getType(),
19822+
if (CheckForFloatArgs) {
19823+
int ArgOrdinal = 1;
19824+
for (Expr *Arg : Args) {
19825+
if (checkFPMathBuiltinElementType(*this, Arg->getBeginLoc(),
19826+
Arg->getType(), ArgOrdinal++))
19827+
return true;
19828+
}
19829+
} else {
19830+
int ArgOrdinal = 1;
19831+
for (Expr *Arg : Args) {
19832+
if (checkMathBuiltinElementType(*this, Arg->getBeginLoc(), Arg->getType(),
1981619833
ArgOrdinal++))
19817-
return true;
19834+
return true;
19835+
}
1981819836
}
1981919837

1982019838
for (int I = 1; I < 3; ++I) {

0 commit comments

Comments
 (0)