@@ -3000,6 +3000,134 @@ static QualType getNeonEltType(NeonTypeFlags Flags, ASTContext &Context,
3000
3000
3001
3001
enum ArmStreamingType { ArmNonStreaming, ArmStreaming, ArmStreamingCompatible };
3002
3002
3003
+ bool Sema::ParseSVEImmChecks(
3004
+ CallExpr *TheCall, SmallVector<std::tuple<int, int, int>, 3> &ImmChecks) {
3005
+ // Perform all the immediate checks for this builtin call.
3006
+ bool HasError = false;
3007
+ for (auto &I : ImmChecks) {
3008
+ int ArgNum, CheckTy, ElementSizeInBits;
3009
+ std::tie(ArgNum, CheckTy, ElementSizeInBits) = I;
3010
+
3011
+ typedef bool (*OptionSetCheckFnTy)(int64_t Value);
3012
+
3013
+ // Function that checks whether the operand (ArgNum) is an immediate
3014
+ // that is one of the predefined values.
3015
+ auto CheckImmediateInSet = [&](OptionSetCheckFnTy CheckImm,
3016
+ int ErrDiag) -> bool {
3017
+ // We can't check the value of a dependent argument.
3018
+ Expr *Arg = TheCall->getArg(ArgNum);
3019
+ if (Arg->isTypeDependent() || Arg->isValueDependent())
3020
+ return false;
3021
+
3022
+ // Check constant-ness first.
3023
+ llvm::APSInt Imm;
3024
+ if (SemaBuiltinConstantArg(TheCall, ArgNum, Imm))
3025
+ return true;
3026
+
3027
+ if (!CheckImm(Imm.getSExtValue()))
3028
+ return Diag(TheCall->getBeginLoc(), ErrDiag) << Arg->getSourceRange();
3029
+ return false;
3030
+ };
3031
+
3032
+ switch ((SVETypeFlags::ImmCheckType)CheckTy) {
3033
+ case SVETypeFlags::ImmCheck0_31:
3034
+ if (SemaBuiltinConstantArgRange(TheCall, ArgNum, 0, 31))
3035
+ HasError = true;
3036
+ break;
3037
+ case SVETypeFlags::ImmCheck0_13:
3038
+ if (SemaBuiltinConstantArgRange(TheCall, ArgNum, 0, 13))
3039
+ HasError = true;
3040
+ break;
3041
+ case SVETypeFlags::ImmCheck1_16:
3042
+ if (SemaBuiltinConstantArgRange(TheCall, ArgNum, 1, 16))
3043
+ HasError = true;
3044
+ break;
3045
+ case SVETypeFlags::ImmCheck0_7:
3046
+ if (SemaBuiltinConstantArgRange(TheCall, ArgNum, 0, 7))
3047
+ HasError = true;
3048
+ break;
3049
+ case SVETypeFlags::ImmCheckExtract:
3050
+ if (SemaBuiltinConstantArgRange(TheCall, ArgNum, 0,
3051
+ (2048 / ElementSizeInBits) - 1))
3052
+ HasError = true;
3053
+ break;
3054
+ case SVETypeFlags::ImmCheckShiftRight:
3055
+ if (SemaBuiltinConstantArgRange(TheCall, ArgNum, 1, ElementSizeInBits))
3056
+ HasError = true;
3057
+ break;
3058
+ case SVETypeFlags::ImmCheckShiftRightNarrow:
3059
+ if (SemaBuiltinConstantArgRange(TheCall, ArgNum, 1,
3060
+ ElementSizeInBits / 2))
3061
+ HasError = true;
3062
+ break;
3063
+ case SVETypeFlags::ImmCheckShiftLeft:
3064
+ if (SemaBuiltinConstantArgRange(TheCall, ArgNum, 0,
3065
+ ElementSizeInBits - 1))
3066
+ HasError = true;
3067
+ break;
3068
+ case SVETypeFlags::ImmCheckLaneIndex:
3069
+ if (SemaBuiltinConstantArgRange(TheCall, ArgNum, 0,
3070
+ (128 / (1 * ElementSizeInBits)) - 1))
3071
+ HasError = true;
3072
+ break;
3073
+ case SVETypeFlags::ImmCheckLaneIndexCompRotate:
3074
+ if (SemaBuiltinConstantArgRange(TheCall, ArgNum, 0,
3075
+ (128 / (2 * ElementSizeInBits)) - 1))
3076
+ HasError = true;
3077
+ break;
3078
+ case SVETypeFlags::ImmCheckLaneIndexDot:
3079
+ if (SemaBuiltinConstantArgRange(TheCall, ArgNum, 0,
3080
+ (128 / (4 * ElementSizeInBits)) - 1))
3081
+ HasError = true;
3082
+ break;
3083
+ case SVETypeFlags::ImmCheckComplexRot90_270:
3084
+ if (CheckImmediateInSet([](int64_t V) { return V == 90 || V == 270; },
3085
+ diag::err_rotation_argument_to_cadd))
3086
+ HasError = true;
3087
+ break;
3088
+ case SVETypeFlags::ImmCheckComplexRotAll90:
3089
+ if (CheckImmediateInSet(
3090
+ [](int64_t V) {
3091
+ return V == 0 || V == 90 || V == 180 || V == 270;
3092
+ },
3093
+ diag::err_rotation_argument_to_cmla))
3094
+ HasError = true;
3095
+ break;
3096
+ case SVETypeFlags::ImmCheck0_1:
3097
+ if (SemaBuiltinConstantArgRange(TheCall, ArgNum, 0, 1))
3098
+ HasError = true;
3099
+ break;
3100
+ case SVETypeFlags::ImmCheck0_2:
3101
+ if (SemaBuiltinConstantArgRange(TheCall, ArgNum, 0, 2))
3102
+ HasError = true;
3103
+ break;
3104
+ case SVETypeFlags::ImmCheck0_3:
3105
+ if (SemaBuiltinConstantArgRange(TheCall, ArgNum, 0, 3))
3106
+ HasError = true;
3107
+ break;
3108
+ case SVETypeFlags::ImmCheck0_0:
3109
+ if (SemaBuiltinConstantArgRange(TheCall, ArgNum, 0, 0))
3110
+ HasError = true;
3111
+ break;
3112
+ case SVETypeFlags::ImmCheck0_15:
3113
+ if (SemaBuiltinConstantArgRange(TheCall, ArgNum, 0, 15))
3114
+ HasError = true;
3115
+ break;
3116
+ case SVETypeFlags::ImmCheck0_255:
3117
+ if (SemaBuiltinConstantArgRange(TheCall, ArgNum, 0, 255))
3118
+ HasError = true;
3119
+ break;
3120
+ case SVETypeFlags::ImmCheck2_4_Mul2:
3121
+ if (SemaBuiltinConstantArgRange(TheCall, ArgNum, 2, 4) ||
3122
+ SemaBuiltinConstantArgMultiple(TheCall, ArgNum, 2))
3123
+ HasError = true;
3124
+ break;
3125
+ }
3126
+ }
3127
+
3128
+ return HasError;
3129
+ }
3130
+
3003
3131
static ArmStreamingType getArmStreamingFnType(const FunctionDecl *FD) {
3004
3132
if (FD->hasAttr<ArmLocallyStreamingAttr>())
3005
3133
return ArmStreaming;
@@ -3028,6 +3156,66 @@ static void checkArmStreamingBuiltin(Sema &S, CallExpr *TheCall,
3028
3156
<< TheCall->getSourceRange() << "streaming compatible";
3029
3157
return;
3030
3158
}
3159
+
3160
+ if (FnType == ArmNonStreaming && BuiltinType == ArmStreaming) {
3161
+ S.Diag(TheCall->getBeginLoc(), diag::warn_attribute_arm_sm_incompat_builtin)
3162
+ << TheCall->getSourceRange() << "non-streaming";
3163
+ }
3164
+ }
3165
+
3166
+ static bool hasSMEZAState(const FunctionDecl *FD) {
3167
+ if (FD->hasAttr<ArmNewZAAttr>())
3168
+ return true;
3169
+ if (const auto *T = FD->getType()->getAs<FunctionProtoType>())
3170
+ if (T->getAArch64SMEAttributes() & FunctionType::SME_PStateZASharedMask)
3171
+ return true;
3172
+ return false;
3173
+ }
3174
+
3175
+ static bool hasSMEZAState(unsigned BuiltinID) {
3176
+ switch (BuiltinID) {
3177
+ default:
3178
+ return false;
3179
+ #define GET_SME_BUILTIN_HAS_ZA_STATE
3180
+ #include "clang/Basic/arm_sme_builtins_za_state.inc"
3181
+ #undef GET_SME_BUILTIN_HAS_ZA_STATE
3182
+ }
3183
+ }
3184
+
3185
+ bool Sema::CheckSMEBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
3186
+ if (const FunctionDecl *FD = getCurFunctionDecl()) {
3187
+ bool debug = FD->getDeclName().getAsString() == "incompat_sve_sm";
3188
+ std::optional<ArmStreamingType> BuiltinType;
3189
+
3190
+ switch (BuiltinID) {
3191
+ default:
3192
+ break;
3193
+ #define GET_SME_STREAMING_ATTRS
3194
+ #include "clang/Basic/arm_sme_streaming_attrs.inc"
3195
+ #undef GET_SME_STREAMING_ATTRS
3196
+ }
3197
+
3198
+ if (BuiltinType)
3199
+ checkArmStreamingBuiltin(*this, TheCall, FD, *BuiltinType);
3200
+
3201
+ if (hasSMEZAState(BuiltinID) && !hasSMEZAState(FD))
3202
+ Diag(TheCall->getBeginLoc(),
3203
+ diag::warn_attribute_arm_za_builtin_no_za_state)
3204
+ << TheCall->getSourceRange();
3205
+ }
3206
+
3207
+ // Range check SME intrinsics that take immediate values.
3208
+ SmallVector<std::tuple<int, int, int>, 3> ImmChecks;
3209
+
3210
+ switch (BuiltinID) {
3211
+ default:
3212
+ return false;
3213
+ #define GET_SME_IMMEDIATE_CHECK
3214
+ #include "clang/Basic/arm_sme_sema_rangechecks.inc"
3215
+ #undef GET_SME_IMMEDIATE_CHECK
3216
+ }
3217
+
3218
+ return ParseSVEImmChecks(TheCall, ImmChecks);
3031
3219
}
3032
3220
3033
3221
bool Sema::CheckSVEBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
@@ -3564,6 +3752,9 @@ bool Sema::CheckAArch64BuiltinFunctionCall(const TargetInfo &TI,
3564
3752
if (CheckSVEBuiltinFunctionCall(BuiltinID, TheCall))
3565
3753
return true;
3566
3754
3755
+ if (CheckSMEBuiltinFunctionCall(BuiltinID, TheCall))
3756
+ return true;
3757
+
3567
3758
// For intrinsics which take an immediate value as part of the instruction,
3568
3759
// range check them here.
3569
3760
unsigned i = 0, l = 0, u = 0;
0 commit comments