Skip to content

Commit 8709bca

Browse files
committed
clang: Add __builtin_elementwise_fma
I didn't understand why the other builtins have promotion logic, or how it would apply to a ternary operation. Implicit conversions are evil to begin with, and even more so when the purpose is to get an exact IR intrinsic. This checks that all the arguments have the same type.
1 parent 5cf549e commit 8709bca

File tree

7 files changed

+207
-11
lines changed

7 files changed

+207
-11
lines changed

clang/docs/LanguageExtensions.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -631,6 +631,7 @@ Unless specified otherwise operation(±0) = ±0 and operation(±infinity) = ±in
631631
=========================================== ================================================================ =========================================
632632
T __builtin_elementwise_abs(T x) return the absolute value of a number x; the absolute value of signed integer and floating point types
633633
the most negative integer remains the most negative integer
634+
T __builtin_elementwise_fma(T x, T y, T z) fused multiply add, (x * y) + z. floating point types
634635
T __builtin_elementwise_ceil(T x) return the smallest integral value greater than or equal to x floating point types
635636
T __builtin_elementwise_sin(T x) return the sine of x interpreted as an angle in radians floating point types
636637
T __builtin_elementwise_cos(T x) return the cosine of x interpreted as an angle in radians floating point types

clang/include/clang/Basic/Builtins.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -671,6 +671,7 @@ BUILTIN(__builtin_elementwise_sin, "v.", "nct")
671671
BUILTIN(__builtin_elementwise_trunc, "v.", "nct")
672672
BUILTIN(__builtin_elementwise_canonicalize, "v.", "nct")
673673
BUILTIN(__builtin_elementwise_copysign, "v.", "nct")
674+
BUILTIN(__builtin_elementwise_fma, "v.", "nct")
674675
BUILTIN(__builtin_elementwise_add_sat, "v.", "nct")
675676
BUILTIN(__builtin_elementwise_sub_sat, "v.", "nct")
676677
BUILTIN(__builtin_reduce_max, "v.", "nct")

clang/include/clang/Sema/Sema.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13531,6 +13531,7 @@ class Sema final {
1353113531
bool CheckPPCMMAType(QualType Type, SourceLocation TypeLoc);
1353213532

1353313533
bool SemaBuiltinElementwiseMath(CallExpr *TheCall);
13534+
bool SemaBuiltinElementwiseTernaryMath(CallExpr *TheCall);
1353413535
bool PrepareBuiltinElementwiseMathOneArgCall(CallExpr *TheCall);
1353513536
bool PrepareBuiltinReduceMathOneArgCall(CallExpr *TheCall);
1353613537

clang/lib/CodeGen/CGBuiltin.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3118,6 +3118,8 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
31183118
emitUnaryBuiltin(*this, E, llvm::Intrinsic::canonicalize, "elt.trunc"));
31193119
case Builtin::BI__builtin_elementwise_copysign:
31203120
return RValue::get(emitBinaryBuiltin(*this, E, llvm::Intrinsic::copysign));
3121+
case Builtin::BI__builtin_elementwise_fma:
3122+
return RValue::get(emitTernaryBuiltin(*this, E, llvm::Intrinsic::fma));
31213123
case Builtin::BI__builtin_elementwise_add_sat:
31223124
case Builtin::BI__builtin_elementwise_sub_sat: {
31233125
Value *Op0 = EmitScalarExpr(E->getArg(0));

clang/lib/Sema/SemaChecking.cpp

Lines changed: 41 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2626,20 +2626,16 @@ Sema::CheckBuiltinFunctionCall(FunctionDecl *FDecl, unsigned BuiltinID,
26262626
return ExprError();
26272627

26282628
QualType ArgTy = TheCall->getArg(0)->getType();
2629-
QualType EltTy = ArgTy;
2630-
2631-
if (auto *VecTy = EltTy->getAs<VectorType>())
2632-
EltTy = VecTy->getElementType();
2633-
if (!EltTy->isFloatingType()) {
2634-
Diag(TheCall->getArg(0)->getBeginLoc(),
2635-
diag::err_builtin_invalid_arg_type)
2636-
<< 1 << /* float ty*/ 5 << ArgTy;
2637-
2629+
if (checkFPMathBuiltinElementType(*this, TheCall->getArg(0)->getBeginLoc(),
2630+
ArgTy, 1))
2631+
return ExprError();
2632+
break;
2633+
}
2634+
case Builtin::BI__builtin_elementwise_fma: {
2635+
if (SemaBuiltinElementwiseTernaryMath(TheCall))
26382636
return ExprError();
2639-
}
26402637
break;
26412638
}
2642-
26432639
// These builtins restrict the element type to integer
26442640
// types only.
26452641
case Builtin::BI__builtin_elementwise_add_sat:
@@ -17877,6 +17873,40 @@ bool Sema::SemaBuiltinElementwiseMath(CallExpr *TheCall) {
1787717873
return false;
1787817874
}
1787917875

17876+
/// Type-check a call to a three-operand elementwise math builtin (used for
/// __builtin_elementwise_fma).
///
/// Each argument is run through UsualUnaryConversions and must then pass
/// checkFPMathBuiltinElementType. All three converted arguments must share the
/// same canonical type; no implicit promotion between operands is attempted,
/// so a mismatch is diagnosed with err_typecheck_call_different_arg_types.
/// On success the converted arguments are stored back into the call and the
/// call's result type is set to the type of the first converted argument.
///
/// \param TheCall the builtin call expression being checked.
/// \returns true if an error was diagnosed, false if the call is well-formed.
bool Sema::SemaBuiltinElementwiseTernaryMath(CallExpr *TheCall) {
  if (checkArgCount(*this, TheCall, 3))
    return true;

  // Convert every operand up front so the checks below see the converted
  // types rather than the types as written.
  Expr *Args[3];
  for (int I = 0; I < 3; ++I) {
    ExprResult Converted = UsualUnaryConversions(TheCall->getArg(I));
    if (Converted.isInvalid())
      return true;
    Args[I] = Converted.get();
  }

  // The ordinal is 1-based because it is passed through to the diagnostic
  // helper to identify the offending argument.
  int ArgOrdinal = 1;
  for (Expr *Arg : Args) {
    if (checkFPMathBuiltinElementType(*this, Arg->getBeginLoc(), Arg->getType(),
                                      ArgOrdinal++))
      return true;
  }

  // Require exact type agreement with the first operand. Compare canonical
  // types so that differently-spelled names for the same type are accepted.
  for (int I = 1; I < 3; ++I) {
    if (Args[0]->getType().getCanonicalType() !=
        Args[I]->getType().getCanonicalType()) {
      return Diag(Args[0]->getBeginLoc(),
                  diag::err_typecheck_call_different_arg_types)
             << Args[0]->getType() << Args[I]->getType();
    }
  }

  // Only mutate the call once all checks have passed, and store back all
  // three converted operands, including the first: the call's type below is
  // taken from Args[0], so argument 0 in the AST must be the converted
  // expression as well.
  for (int I = 0; I < 3; ++I)
    TheCall->setArg(I, Args[I]);

  TheCall->setType(Args[0]->getType());
  return false;
}
17909+
1788017910
bool Sema::PrepareBuiltinReduceMathOneArgCall(CallExpr *TheCall) {
1788117911
if (checkArgCount(*this, TheCall, 1))
1788217912
return true;

clang/test/CodeGen/builtins-elementwise-math.c

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
// RUN: %clang_cc1 -triple x86_64-apple-darwin %s -emit-llvm -disable-llvm-passes -o - | FileCheck %s
22

3+
typedef _Float16 half;
4+
5+
typedef half half2 __attribute__((ext_vector_type(2)));
6+
typedef float float2 __attribute__((ext_vector_type(2)));
37
typedef float float4 __attribute__((ext_vector_type(4)));
48
typedef short int si8 __attribute__((ext_vector_type(8)));
59
typedef unsigned int u4 __attribute__((ext_vector_type(4)));
@@ -525,3 +529,77 @@ void test_builtin_elementwise_copysign(float f1, float f2, double d1, double d2,
525529
// CHECK-NEXT: call <2 x double> @llvm.copysign.v2f64(<2 x double> <double 1.000000e+00, double 1.000000e+00>, <2 x double> [[V2F64]])
526530
v2f64 = __builtin_elementwise_copysign((double2)1.0, v2f64);
527531
}
532+
533+
void test_builtin_elementwise_fma(float f32, double f64,
534+
float2 v2f32, float4 v4f32,
535+
double2 v2f64, double3 v3f64,
536+
const float4 c_v4f32,
537+
half f16, half2 v2f16) {
538+
// CHECK-LABEL: define void @test_builtin_elementwise_fma(
539+
// CHECK: [[F32_0:%.+]] = load float, ptr %f32.addr
540+
// CHECK-NEXT: [[F32_1:%.+]] = load float, ptr %f32.addr
541+
// CHECK-NEXT: [[F32_2:%.+]] = load float, ptr %f32.addr
542+
// CHECK-NEXT: call float @llvm.fma.f32(float [[F32_0]], float [[F32_1]], float [[F32_2]])
543+
float f2 = __builtin_elementwise_fma(f32, f32, f32);
544+
545+
// CHECK: [[F64_0:%.+]] = load double, ptr %f64.addr
546+
// CHECK-NEXT: [[F64_1:%.+]] = load double, ptr %f64.addr
547+
// CHECK-NEXT: [[F64_2:%.+]] = load double, ptr %f64.addr
548+
// CHECK-NEXT: call double @llvm.fma.f64(double [[F64_0]], double [[F64_1]], double [[F64_2]])
549+
double d2 = __builtin_elementwise_fma(f64, f64, f64);
550+
551+
// CHECK: [[V4F32_0:%.+]] = load <4 x float>, ptr %v4f32.addr
552+
// CHECK-NEXT: [[V4F32_1:%.+]] = load <4 x float>, ptr %v4f32.addr
553+
// CHECK-NEXT: [[V4F32_2:%.+]] = load <4 x float>, ptr %v4f32.addr
554+
// CHECK-NEXT: call <4 x float> @llvm.fma.v4f32(<4 x float> [[V4F32_0]], <4 x float> [[V4F32_1]], <4 x float> [[V4F32_2]])
555+
float4 tmp_v4f32 = __builtin_elementwise_fma(v4f32, v4f32, v4f32);
556+
557+
558+
// FIXME: Are we really still doing the 3 vector load workaround?
559+
// CHECK: [[V3F64_LOAD_0:%.+]] = load <4 x double>, ptr %v3f64.addr
560+
// CHECK-NEXT: [[V3F64_0:%.+]] = shufflevector
561+
// CHECK-NEXT: [[V3F64_LOAD_1:%.+]] = load <4 x double>, ptr %v3f64.addr
562+
// CHECK-NEXT: [[V3F64_1:%.+]] = shufflevector
563+
// CHECK-NEXT: [[V3F64_LOAD_2:%.+]] = load <4 x double>, ptr %v3f64.addr
564+
// CHECK-NEXT: [[V3F64_2:%.+]] = shufflevector
565+
// CHECK-NEXT: call <3 x double> @llvm.fma.v3f64(<3 x double> [[V3F64_0]], <3 x double> [[V3F64_1]], <3 x double> [[V3F64_2]])
566+
v3f64 = __builtin_elementwise_fma(v3f64, v3f64, v3f64);
567+
568+
// CHECK: [[F64_0:%.+]] = load double, ptr %f64.addr
569+
// CHECK-NEXT: [[F64_1:%.+]] = load double, ptr %f64.addr
570+
// CHECK-NEXT: [[F64_2:%.+]] = load double, ptr %f64.addr
571+
// CHECK-NEXT: call double @llvm.fma.f64(double [[F64_0]], double [[F64_1]], double [[F64_2]])
572+
v2f64 = __builtin_elementwise_fma(f64, f64, f64);
573+
574+
// CHECK: [[V4F32_0:%.+]] = load <4 x float>, ptr %c_v4f32.addr
575+
// CHECK-NEXT: [[V4F32_1:%.+]] = load <4 x float>, ptr %c_v4f32.addr
576+
// CHECK-NEXT: [[V4F32_2:%.+]] = load <4 x float>, ptr %c_v4f32.addr
577+
// CHECK-NEXT: call <4 x float> @llvm.fma.v4f32(<4 x float> [[V4F32_0]], <4 x float> [[V4F32_1]], <4 x float> [[V4F32_2]])
578+
v4f32 = __builtin_elementwise_fma(c_v4f32, c_v4f32, c_v4f32);
579+
580+
// CHECK: [[F16_0:%.+]] = load half, ptr %f16.addr
581+
// CHECK-NEXT: [[F16_1:%.+]] = load half, ptr %f16.addr
582+
// CHECK-NEXT: [[F16_2:%.+]] = load half, ptr %f16.addr
583+
// CHECK-NEXT: call half @llvm.fma.f16(half [[F16_0]], half [[F16_1]], half [[F16_2]])
584+
half tmp_f16 = __builtin_elementwise_fma(f16, f16, f16);
585+
586+
// CHECK: [[V2F16_0:%.+]] = load <2 x half>, ptr %v2f16.addr
587+
// CHECK-NEXT: [[V2F16_1:%.+]] = load <2 x half>, ptr %v2f16.addr
588+
// CHECK-NEXT: [[V2F16_2:%.+]] = load <2 x half>, ptr %v2f16.addr
589+
// CHECK-NEXT: call <2 x half> @llvm.fma.v2f16(<2 x half> [[V2F16_0]], <2 x half> [[V2F16_1]], <2 x half> [[V2F16_2]])
590+
half2 tmp0_v2f16 = __builtin_elementwise_fma(v2f16, v2f16, v2f16);
591+
592+
// CHECK: [[V2F16_0:%.+]] = load <2 x half>, ptr %v2f16.addr
593+
// CHECK-NEXT: [[V2F16_1:%.+]] = load <2 x half>, ptr %v2f16.addr
594+
// CHECK-NEXT: [[F16_2:%.+]] = load half, ptr %f16.addr
595+
// CHECK-NEXT: [[V2F16_2_INSERT:%.+]] = insertelement
596+
// CHECK-NEXT: [[V2F16_2:%.+]] = shufflevector <2 x half> [[V2F16_2_INSERT]], <2 x half> poison, <2 x i32> zeroinitializer
597+
// CHECK-NEXT: call <2 x half> @llvm.fma.v2f16(<2 x half> [[V2F16_0]], <2 x half> [[V2F16_1]], <2 x half> [[V2F16_2]])
598+
half2 tmp1_v2f16 = __builtin_elementwise_fma(v2f16, v2f16, (half2)f16);
599+
600+
// CHECK: [[V2F16_0:%.+]] = load <2 x half>, ptr %v2f16.addr
601+
// CHECK-NEXT: [[V2F16_1:%.+]] = load <2 x half>, ptr %v2f16.addr
602+
// CHECK-NEXT: call <2 x half> @llvm.fma.v2f16(<2 x half> [[V2F16_0]], <2 x half> [[V2F16_1]], <2 x half> <half 0xH4400, half 0xH4400>)
603+
half2 tmp2_v2f16 = __builtin_elementwise_fma(v2f16, v2f16, (half2)4.0);
604+
605+
}

clang/test/Sema/builtins-elementwise-math.c

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ typedef double double2 __attribute__((ext_vector_type(2)));
44
typedef double double4 __attribute__((ext_vector_type(4)));
55
typedef float float2 __attribute__((ext_vector_type(2)));
66
typedef float float4 __attribute__((ext_vector_type(4)));
7+
8+
typedef int int2 __attribute__((ext_vector_type(2)));
79
typedef int int3 __attribute__((ext_vector_type(3)));
810
typedef unsigned unsigned3 __attribute__((ext_vector_type(3)));
911
typedef unsigned unsigned4 __attribute__((ext_vector_type(4)));
@@ -572,3 +574,84 @@ void test_builtin_elementwise_copysign(int i, short s, double d, float f, float4
572574
float2 tmp9 = __builtin_elementwise_copysign(v4f32, v4f32);
573575
// expected-error@-1 {{initializing 'float2' (vector of 2 'float' values) with an expression of incompatible type 'float4' (vector of 4 'float' values)}}
574576
}
577+
578+
void test_builtin_elementwise_fma(int i32, int2 v2i32, short i16,
579+
double f64, double2 v2f64, double2 v3f64,
580+
float f32, float2 v2f32, float v3f32, float4 v4f32,
581+
const float4 c_v4f32,
582+
int3 v3i32, int *ptr) {
583+
584+
f32 = __builtin_elementwise_fma();
585+
// expected-error@-1 {{too few arguments to function call, expected 3, have 0}}
586+
587+
f32 = __builtin_elementwise_fma(f32);
588+
// expected-error@-1 {{too few arguments to function call, expected 3, have 1}}
589+
590+
f32 = __builtin_elementwise_fma(f32, f32);
591+
// expected-error@-1 {{too few arguments to function call, expected 3, have 2}}
592+
593+
f32 = __builtin_elementwise_fma(f32, f32, f32, f32);
594+
// expected-error@-1 {{too many arguments to function call, expected 3, have 4}}
595+
596+
f32 = __builtin_elementwise_fma(f64, f32, f32);
597+
// expected-error@-1 {{arguments are of different types ('double' vs 'float')}}
598+
599+
f32 = __builtin_elementwise_fma(f32, f64, f32);
600+
// expected-error@-1 {{arguments are of different types ('float' vs 'double')}}
601+
602+
f32 = __builtin_elementwise_fma(f32, f32, f64);
603+
// expected-error@-1 {{arguments are of different types ('float' vs 'double')}}
604+
605+
f32 = __builtin_elementwise_fma(f32, f32, f64);
606+
// expected-error@-1 {{arguments are of different types ('float' vs 'double')}}
607+
608+
f64 = __builtin_elementwise_fma(f64, f32, f32);
609+
// expected-error@-1 {{arguments are of different types ('double' vs 'float')}}
610+
611+
f64 = __builtin_elementwise_fma(f64, f64, f32);
612+
// expected-error@-1 {{arguments are of different types ('double' vs 'float')}}
613+
614+
f64 = __builtin_elementwise_fma(f64, f32, f64);
615+
// expected-error@-1 {{arguments are of different types ('double' vs 'float')}}
616+
617+
v2f64 = __builtin_elementwise_fma(v2f32, f64, f64);
618+
// expected-error@-1 {{arguments are of different types ('float2' (vector of 2 'float' values) vs 'double'}}
619+
620+
v2f64 = __builtin_elementwise_fma(v2f32, v2f64, f64);
621+
// expected-error@-1 {{arguments are of different types ('float2' (vector of 2 'float' values) vs 'double2' (vector of 2 'double' values)}}
622+
623+
v2f64 = __builtin_elementwise_fma(v2f32, f64, v2f64);
624+
// expected-error@-1 {{arguments are of different types ('float2' (vector of 2 'float' values) vs 'double'}}
625+
626+
v2f64 = __builtin_elementwise_fma(f64, v2f32, v2f64);
627+
// expected-error@-1 {{arguments are of different types ('double' vs 'float2' (vector of 2 'float' values)}}
628+
629+
v2f64 = __builtin_elementwise_fma(f64, v2f64, v2f64);
630+
// expected-error@-1 {{arguments are of different types ('double' vs 'double2' (vector of 2 'double' values)}}
631+
632+
i32 = __builtin_elementwise_fma(i32, i32, i32);
633+
// expected-error@-1 {{1st argument must be a floating point type (was 'int')}}
634+
635+
v2i32 = __builtin_elementwise_fma(v2i32, v2i32, v2i32);
636+
// expected-error@-1 {{1st argument must be a floating point type (was 'int2' (vector of 2 'int' values))}}
637+
638+
f32 = __builtin_elementwise_fma(f32, f32, i32);
639+
// expected-error@-1 {{3rd argument must be a floating point type (was 'int')}}
640+
641+
f32 = __builtin_elementwise_fma(f32, i32, f32);
642+
// expected-error@-1 {{2nd argument must be a floating point type (was 'int')}}
643+
644+
f32 = __builtin_elementwise_fma(f32, f32, i32);
645+
// expected-error@-1 {{3rd argument must be a floating point type (was 'int')}}
646+
647+
648+
_Complex float c1, c2, c3;
649+
c1 = __builtin_elementwise_fma(c1, f32, f32);
650+
// expected-error@-1 {{1st argument must be a floating point type (was '_Complex float')}}
651+
652+
c2 = __builtin_elementwise_fma(f32, c2, f32);
653+
// expected-error@-1 {{2nd argument must be a floating point type (was '_Complex float')}}
654+
655+
c3 = __builtin_elementwise_fma(f32, f32, c3);
656+
// expected-error@-1 {{3rd argument must be a floating point type (was '_Complex float')}}
657+
}

0 commit comments

Comments
 (0)