Skip to content

Commit 205471c

Browse files
committed
[AArch64][SME] Warn when using a streaming builtin from a non-streaming function
This PR adds a warning that is emitted when a builtin that is not streaming-compatible (that is, a streaming-only or non-streaming-only builtin) is called from a function whose streaming mode is unsuitable for it. Uses work by Kerry McLaughlin.
1 parent 0626ced commit 205471c

File tree

8 files changed

+302
-0
lines changed

8 files changed

+302
-0
lines changed

clang/include/clang/Basic/CMakeLists.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,12 @@ clang_tablegen(arm_sme_builtin_cg.inc -gen-arm-sme-builtin-codegen
9797
clang_tablegen(arm_sme_sema_rangechecks.inc -gen-arm-sme-sema-rangechecks
9898
SOURCE arm_sme.td
9999
TARGET ClangARMSmeSemaRangeChecks)
100+
# Generate arm_sme_streaming_attrs.inc from arm_sme.td; it is included under
# GET_SME_STREAMING_ATTRS to look up a builtin's streaming-mode requirement.
clang_tablegen(arm_sme_streaming_attrs.inc -gen-arm-sme-streaming-attrs
101+
SOURCE arm_sme.td
102+
TARGET ClangARMSmeStreamingAttrs)
103+
# Generate arm_sme_builtins_za_state.inc from arm_sme.td.
# NOTE(review): no generator for -gen-arm-sme-builtin-za-state is visible in
# this view of the change; confirm the TableGen backend registers it.
clang_tablegen(arm_sme_builtins_za_state.inc -gen-arm-sme-builtin-za-state
104+
SOURCE arm_sme.td
105+
TARGET ClangARMSmeBuiltinsZAState)
100106
clang_tablegen(arm_cde_builtins.inc -gen-arm-cde-builtin-def
101107
SOURCE arm_cde.td
102108
TARGET ClangARMCdeBuiltinsDef)

clang/include/clang/Basic/DiagnosticSemaKinds.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3151,6 +3151,9 @@ def err_attribute_arm_feature_sve_bits_unsupported : Error<
31513151
def warn_attribute_arm_sm_incompat_builtin : Warning<
31523152
"builtin call has undefined behaviour when called from a %0 function">,
31533153
InGroup<DiagGroup<"undefined-arm-streaming">>;
3154+
// Warning (group "undefined-arm-za") for a builtin that requires ZA state when
// the calling function has none.
def warn_attribute_arm_za_builtin_no_za_state : Warning<
3155+
"builtin call is not valid when calling from a function without active ZA state">,
3156+
InGroup<DiagGroup<"undefined-arm-za">>;
31543157
def err_sve_vector_in_non_sve_target : Error<
31553158
"SVE vector type %0 cannot be used in a target without sve">;
31563159
def err_attribute_riscv_rvv_bits_unsupported : Error<

clang/include/clang/Sema/Sema.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13840,7 +13840,10 @@ class Sema final {
1384013840
bool CheckNeonBuiltinFunctionCall(const TargetInfo &TI, unsigned BuiltinID,
1384113841
CallExpr *TheCall);
1384213842
bool CheckMVEBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall);
13843+
bool ParseSVEImmChecks(CallExpr *TheCall,
13844+
SmallVector<std::tuple<int, int, int>, 3> &ImmChecks);
1384313845
bool CheckSVEBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall);
13846+
bool CheckSMEBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall);
1384413847
bool CheckCDEBuiltinFunctionCall(const TargetInfo &TI, unsigned BuiltinID,
1384513848
CallExpr *TheCall);
1384613849
bool CheckARMCoprocessorImmediate(const TargetInfo &TI, const Expr *CoprocArg,

clang/lib/Sema/SemaChecking.cpp

Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3000,6 +3000,134 @@ static QualType getNeonEltType(NeonTypeFlags Flags, ASTContext &Context,
30003000

30013001
enum ArmStreamingType { ArmNonStreaming, ArmStreaming, ArmStreamingCompatible };
30023002

3003+
/// Apply every immediate-operand check listed in \p ImmChecks to \p TheCall.
///
/// Each tuple holds, in order: the index of the argument to check, the kind of
/// check (an SVETypeFlags::ImmCheckType stored as an int) and the element size
/// in bits used by the size-dependent checks. Every tuple is processed, even
/// after an earlier one has failed.
///
/// \returns true if any check failed, false if all of them passed.
bool Sema::ParseSVEImmChecks(
    CallExpr *TheCall, SmallVector<std::tuple<int, int, int>, 3> &ImmChecks) {
  bool AnyCheckFailed = false;

  for (auto &Entry : ImmChecks) {
    int ArgNum, CheckTy, ElementSizeInBits;
    std::tie(ArgNum, CheckTy, ElementSizeInBits) = Entry;

    using OptionSetCheckFnTy = bool (*)(int64_t Value);

    // True when argument ArgNum fails the constant range check [Low, High].
    auto FailsRange = [&](int Low, int High) -> bool {
      return SemaBuiltinConstantArgRange(TheCall, ArgNum, Low, High);
    };

    // True when argument ArgNum is a constant that IsAllowed rejects (DiagID
    // is reported in that case) or when it is not a usable constant at all.
    auto FailsSetCheck = [&](OptionSetCheckFnTy IsAllowed,
                             int DiagID) -> bool {
      Expr *Arg = TheCall->getArg(ArgNum);

      // A dependent argument has no value to inspect yet.
      if (Arg->isTypeDependent() || Arg->isValueDependent())
        return false;

      // The argument has to be a constant before its value can be tested.
      llvm::APSInt Imm;
      if (SemaBuiltinConstantArg(TheCall, ArgNum, Imm))
        return true;

      if (IsAllowed(Imm.getSExtValue()))
        return false;
      return Diag(TheCall->getBeginLoc(), DiagID) << Arg->getSourceRange();
    };

    switch (static_cast<SVETypeFlags::ImmCheckType>(CheckTy)) {
    case SVETypeFlags::ImmCheck0_31:
      AnyCheckFailed |= FailsRange(0, 31);
      break;
    case SVETypeFlags::ImmCheck0_13:
      AnyCheckFailed |= FailsRange(0, 13);
      break;
    case SVETypeFlags::ImmCheck1_16:
      AnyCheckFailed |= FailsRange(1, 16);
      break;
    case SVETypeFlags::ImmCheck0_7:
      AnyCheckFailed |= FailsRange(0, 7);
      break;
    case SVETypeFlags::ImmCheckExtract:
      AnyCheckFailed |= FailsRange(0, (2048 / ElementSizeInBits) - 1);
      break;
    case SVETypeFlags::ImmCheckShiftRight:
      AnyCheckFailed |= FailsRange(1, ElementSizeInBits);
      break;
    case SVETypeFlags::ImmCheckShiftRightNarrow:
      AnyCheckFailed |= FailsRange(1, ElementSizeInBits / 2);
      break;
    case SVETypeFlags::ImmCheckShiftLeft:
      AnyCheckFailed |= FailsRange(0, ElementSizeInBits - 1);
      break;
    case SVETypeFlags::ImmCheckLaneIndex:
      AnyCheckFailed |= FailsRange(0, (128 / (1 * ElementSizeInBits)) - 1);
      break;
    case SVETypeFlags::ImmCheckLaneIndexCompRotate:
      AnyCheckFailed |= FailsRange(0, (128 / (2 * ElementSizeInBits)) - 1);
      break;
    case SVETypeFlags::ImmCheckLaneIndexDot:
      AnyCheckFailed |= FailsRange(0, (128 / (4 * ElementSizeInBits)) - 1);
      break;
    case SVETypeFlags::ImmCheckComplexRot90_270:
      AnyCheckFailed |=
          FailsSetCheck([](int64_t V) { return V == 90 || V == 270; },
                        diag::err_rotation_argument_to_cadd);
      break;
    case SVETypeFlags::ImmCheckComplexRotAll90:
      AnyCheckFailed |= FailsSetCheck(
          [](int64_t V) { return V == 0 || V == 90 || V == 180 || V == 270; },
          diag::err_rotation_argument_to_cmla);
      break;
    case SVETypeFlags::ImmCheck0_1:
      AnyCheckFailed |= FailsRange(0, 1);
      break;
    case SVETypeFlags::ImmCheck0_2:
      AnyCheckFailed |= FailsRange(0, 2);
      break;
    case SVETypeFlags::ImmCheck0_3:
      AnyCheckFailed |= FailsRange(0, 3);
      break;
    case SVETypeFlags::ImmCheck0_0:
      AnyCheckFailed |= FailsRange(0, 0);
      break;
    case SVETypeFlags::ImmCheck0_15:
      AnyCheckFailed |= FailsRange(0, 15);
      break;
    case SVETypeFlags::ImmCheck0_255:
      AnyCheckFailed |= FailsRange(0, 255);
      break;
    case SVETypeFlags::ImmCheck2_4_Mul2:
      // The multiple-of-2 check only runs when the range check passed.
      AnyCheckFailed |= FailsRange(2, 4) ||
                        SemaBuiltinConstantArgMultiple(TheCall, ArgNum, 2);
      break;
    }
  }

  return AnyCheckFailed;
}
3130+
30033131
static ArmStreamingType getArmStreamingFnType(const FunctionDecl *FD) {
30043132
if (FD->hasAttr<ArmLocallyStreamingAttr>())
30053133
return ArmStreaming;
@@ -3028,6 +3156,66 @@ static void checkArmStreamingBuiltin(Sema &S, CallExpr *TheCall,
30283156
<< TheCall->getSourceRange() << "streaming compatible";
30293157
return;
30303158
}
3159+
3160+
if (FnType == ArmNonStreaming && BuiltinType == ArmStreaming) {
3161+
S.Diag(TheCall->getBeginLoc(), diag::warn_attribute_arm_sm_incompat_builtin)
3162+
<< TheCall->getSourceRange() << "non-streaming";
3163+
}
3164+
}
3165+
3166+
static bool hasSMEZAState(const FunctionDecl *FD) {
3167+
if (FD->hasAttr<ArmNewZAAttr>())
3168+
return true;
3169+
if (const auto *T = FD->getType()->getAs<FunctionProtoType>())
3170+
if (T->getAArch64SMEAttributes() & FunctionType::SME_PStateZASharedMask)
3171+
return true;
3172+
return false;
3173+
}
3174+
3175+
/// Looks \p BuiltinID up in the TableGen-generated
/// arm_sme_builtins_za_state.inc table. Builtins with no entry in that table
/// yield false.
// NOTE(review): presumably every generated case returns true for a builtin
// that uses ZA state; confirm against the emitter that produces the .inc.
static bool hasSMEZAState(unsigned BuiltinID) {
3176+
switch (BuiltinID) {
3177+
default:
3178+
return false;
3179+
#define GET_SME_BUILTIN_HAS_ZA_STATE
3180+
#include "clang/Basic/arm_sme_builtins_za_state.inc"
3181+
#undef GET_SME_BUILTIN_HAS_ZA_STATE
3182+
}
3183+
}
3184+
3185+
/// Semantic checks for a call to an SME builtin.
///
/// When the call appears inside a function, two warnings may be issued: one
/// if the builtin's streaming-mode requirement does not suit the calling
/// function, and one if a builtin listed in the generated ZA-state table is
/// called from a function without ZA state. Afterwards the builtin's
/// immediate operands are range checked.
///
/// \returns true only if an immediate check failed; the warnings above do not
/// affect the return value.
bool Sema::CheckSMEBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
  if (const FunctionDecl *FD = getCurFunctionDecl()) {
    // Left empty for builtins that have no entry in the generated table.
    std::optional<ArmStreamingType> BuiltinType;

    switch (BuiltinID) {
    default:
      break;
#define GET_SME_STREAMING_ATTRS
#include "clang/Basic/arm_sme_streaming_attrs.inc"
#undef GET_SME_STREAMING_ATTRS
    }

    if (BuiltinType)
      checkArmStreamingBuiltin(*this, TheCall, FD, *BuiltinType);

    if (hasSMEZAState(BuiltinID) && !hasSMEZAState(FD))
      Diag(TheCall->getBeginLoc(),
           diag::warn_attribute_arm_za_builtin_no_za_state)
          << TheCall->getSourceRange();
  }

  // Range check SME intrinsics that take immediate values.
  SmallVector<std::tuple<int, int, int>, 3> ImmChecks;

  switch (BuiltinID) {
  default:
    // No immediate operands to check for this builtin.
    return false;
#define GET_SME_IMMEDIATE_CHECK
#include "clang/Basic/arm_sme_sema_rangechecks.inc"
#undef GET_SME_IMMEDIATE_CHECK
  }

  return ParseSVEImmChecks(TheCall, ImmChecks);
}
30323220

30333221
bool Sema::CheckSVEBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
@@ -3564,6 +3752,9 @@ bool Sema::CheckAArch64BuiltinFunctionCall(const TargetInfo &TI,
35643752
if (CheckSVEBuiltinFunctionCall(BuiltinID, TheCall))
35653753
return true;
35663754

3755+
if (CheckSMEBuiltinFunctionCall(BuiltinID, TheCall))
3756+
return true;
3757+
35673758
// For intrinsics which take an immediate value as part of the instruction,
35683759
// range check them here.
35693760
unsigned i = 0, l = 0, u = 0;

clang/test/Sema/aarch64-incompat-sm-builtin-calls.c

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
// REQUIRES: aarch64-registered-target
66

77
#include "arm_neon.h"
8+
#include "arm_sme_draft_spec_subject_to_change.h"
89

910
int16x8_t incompat_neon_sm(int16x8_t splat) __arm_streaming {
1011
// expected-warning@+1 {{builtin call has undefined behaviour when called from a streaming function}}
@@ -20,3 +21,23 @@ int16x8_t incompat_neon_smc(int16x8_t splat) __arm_streaming_compatible {
2021
// expected-warning@+1 {{builtin call has undefined behaviour when called from a streaming compatible function}}
2122
return (int16x8_t)__builtin_neon_vqaddq_v((int8x16_t)splat, (int8x16_t)splat, 33);
2223
}
24+
25+
// A function that only has __arm_shared_za is non-streaming, so calling this
// SME load builtin from it must produce the non-streaming warning.
void incompat_sme_norm(svbool_t pg, void const *ptr) __arm_shared_za {
  // expected-warning@+1 {{builtin call has undefined behaviour when called from a non-streaming function}}
  __builtin_sme_svld1_hor_za128(0, 0, pg, ptr);
}
29+
30+
// The same SME load builtin called from a streaming-compatible function must
// produce the streaming-compatible warning.
void incompat_sme_smc(svbool_t pg, void const *ptr) __arm_streaming_compatible __arm_shared_za {
  // expected-warning@+1 {{builtin call has undefined behaviour when called from a streaming compatible function}}
  __builtin_sme_svld1_hor_za128(0, 0, pg, ptr);
}
34+
35+
// Calls an SME intrinsic through the arm_sme header from a function that only
// has __arm_shared_za; the non-streaming warning is expected on the call.
void incompat_sme_sm(svbool_t pn, svbool_t pm, svfloat32_t zn, svfloat32_t zm) __arm_shared_za {
36+
// expected-warning@+1 {{builtin call has undefined behaviour when called from a non-streaming function}}
37+
svmops_za32_f32_m(0, pn, pm, zn, zm);
38+
}
39+
40+
// Negative case: these predicate intrinsics are called from a streaming
// function and no warning is expected.
svbool_t streaming_caller_ptrue(void) __arm_streaming {
41+
// expected-no-warning
42+
return svand_z(svptrue_b16(), svptrue_pat_b16(SV_ALL), svptrue_pat_b16(SV_VL4));
43+
}

clang/utils/TableGen/SveEmitter.cpp

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,9 @@ class SVEEmitter {
378378
/// Emit all the information needed to map builtin -> LLVM IR intrinsic.
379379
void createSMECodeGenMap(raw_ostream &o);
380380

381+
/// Create a table for a builtin's requirement for PSTATE.SM.
382+
void createStreamingAttrs(raw_ostream &o, ACLEKind Kind);
383+
381384
/// Emit all the range checks for the immediates.
382385
void createSMERangeChecks(raw_ostream &o);
383386

@@ -1369,6 +1372,12 @@ void SVEEmitter::createHeader(raw_ostream &OS) {
13691372
OS << "#define __aio static __inline__ __attribute__((__always_inline__, "
13701373
"__nodebug__, __overloadable__))\n\n";
13711374

1375+
OS << "#ifdef __ARM_FEATURE_SME\n";
1376+
OS << "#define __asc __attribute__((arm_streaming_compatible))\n";
1377+
OS << "#else\n";
1378+
OS << "#define __asc\n";
1379+
OS << "#endif\n\n";
1380+
13721381
// Add reinterpret functions.
13731382
for (auto [N, Suffix] :
13741383
std::initializer_list<std::pair<unsigned, const char *>>{
@@ -1688,6 +1697,61 @@ void SVEEmitter::createSMERangeChecks(raw_ostream &OS) {
16881697
OS << "#endif\n\n";
16891698
}
16901699

1700+
/// Emit a block of `case` labels, one per distinct builtin, that assigns the
/// builtin's streaming-mode requirement to a variable named `BuiltinType`.
///
/// The block is wrapped in `#ifdef GET_SME_STREAMING_ATTRS` or
/// `#ifdef GET_SVE_STREAMING_ATTRS` depending on \p Kind, and each case sets
/// ArmStreaming, ArmStreamingCompatible or ArmNonStreaming according to the
/// intrinsic's IsStreaming / IsStreamingCompatible flags.
///
/// \param OS   Stream the generated text is written to.
/// \param Kind Selects the guard macro and the case-label prefix.
void SVEEmitter::createStreamingAttrs(raw_ostream &OS, ACLEKind Kind) {
  std::vector<Record *> RV = Records.getAllDerivedDefinitions("Inst");
  SmallVector<std::unique_ptr<Intrinsic>, 128> Defs;
  for (auto *R : RV)
    createIntrinsic(R, Defs);

  // Order the definitions by mangled name.
  llvm::sort(Defs, [](const std::unique_ptr<Intrinsic> &A,
                      const std::unique_ptr<Intrinsic> &B) {
    return A->getMangledName() < B->getMangledName();
  });

  // Kind is fixed for the whole call, so choose the guard macro and the
  // case-label prefix once instead of switching again for every definition.
  const char *CasePrefix = "";
  switch (Kind) {
  case ACLEKind::SME:
    OS << "#ifdef GET_SME_STREAMING_ATTRS\n";
    CasePrefix = "case SME::BI__builtin_sme_";
    break;
  case ACLEKind::SVE:
    OS << "#ifdef GET_SVE_STREAMING_ATTRS\n";
    CasePrefix = "case SVE::BI__builtin_sve_";
    break;
  }

  // Ensure these are only emitted once.
  std::set<std::string> Emitted;

  const uint64_t IsStreamingFlag = getEnumValueForFlag("IsStreaming");
  const uint64_t IsStreamingCompatibleFlag =
      getEnumValueForFlag("IsStreamingCompatible");
  for (auto &Def : Defs) {
    // insert() reports whether the name was new, so a single lookup both
    // detects duplicates and records the name.
    if (!Emitted.insert(Def->getMangledName()).second)
      continue;

    OS << CasePrefix << Def->getMangledName() << ":\n";

    if (Def->isFlagSet(IsStreamingFlag))
      OS << "  BuiltinType = ArmStreaming;\n";
    else if (Def->isFlagSet(IsStreamingCompatibleFlag))
      OS << "  BuiltinType = ArmStreamingCompatible;\n";
    else
      OS << "  BuiltinType = ArmNonStreaming;\n";
    OS << "  break;\n";
  }

  OS << "#endif\n\n";
}
1754+
16911755
namespace clang {
16921756
void EmitSveHeader(RecordKeeper &Records, raw_ostream &OS) {
16931757
SVEEmitter(Records).createHeader(OS);
@@ -1724,4 +1788,8 @@ void EmitSmeBuiltinCG(RecordKeeper &Records, raw_ostream &OS) {
17241788
void EmitSmeRangeChecks(RecordKeeper &Records, raw_ostream &OS) {
17251789
SVEEmitter(Records).createSMERangeChecks(OS);
17261790
}
1791+
1792+
/// Entry point that writes the SME streaming-attribute table for \p Records
/// to \p OS by running SVEEmitter::createStreamingAttrs with ACLEKind::SME.
void EmitSmeStreamingAttrs(RecordKeeper &Records, raw_ostream &OS) {
1793+
SVEEmitter(Records).createStreamingAttrs(OS, ACLEKind::SME);
1794+
}
17271795
} // End namespace clang

0 commit comments

Comments
 (0)