Skip to content

[ubsan][NFCI] Use SanitizerOrdinal instead of SanitizerMask for EmitCheck (exactly one sanitizer is required) #122511

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 10, 2025

Conversation

thurstond
Copy link
Contributor

@thurstond thurstond commented Jan 10, 2025

The Checked parameter of CodeGenFunction::EmitCheck is of type ArrayRef<std::pair<llvm::Value *, SanitizerMask>>, which is overly generalized: SanitizerMask can denote that zero or more sanitizers are enabled, but EmitCheck requires that exactly one sanitizer is specified in the SanitizerMask (e.g., SanitizeTrap.has(Checked[i].second) enforces that).

This patch replaces SanitizerMask with SanitizerOrdinal in the Checked parameter of EmitCheck and code that transitively relies on it. This should not affect the behavior of UBSan, but it has the advantages that:

  • the code is clearer: it avoids ambiguity in EmitCheck about what to do if multiple bits are set
  • specifying the wrong number of sanitizers in Checked[i].second will be detected as a compile-time error, rather than a runtime assertion failure

Suggested by Vitaly in #122392 as an alternative to adding an explicit runtime assertion that the SanitizerMask contains exactly one sanitizer.

The `Checked` parameter of `CodeGenFunction::EmitCheck` is of type `ArrayRef<std::pair<llvm::Value *, SanitizerMask>>`, which is overly generalized: SanitizerMask can denote that zero or more sanitizers are enabled, but `EmitCheck` requires that exactly one sanitizer is specified in the SanitizerMask (e.g., `SanitizeTrap.has(Checked[i].second)` enforces that).

This patch replaces SanitizerMask with SanitizerOrdinal for `EmitCheck`
and code that calls it transitively. This should not affect the behavior
of UBSan, but it has the advantage that specifying the wrong number of
sanitizers in `Checked[i].second` will be detected as a compile-time
error, rather than a runtime assertion failure.

Suggested by Vitaly in llvm#122392
as an alternative to adding an explicit runtime assertion that the SanitizerMask
contains exactly one sanitizer.
@llvmbot llvmbot added clang Clang issues not falling into any other category clang:frontend Language frontend issues, e.g. anything involving "Sema" clang:codegen IR generation bugs: mangling, exceptions, etc. labels Jan 10, 2025
@llvmbot
Copy link
Member

llvmbot commented Jan 10, 2025

@llvm/pr-subscribers-clang

Author: Thurston Dang (thurstond)

Changes

The Checked parameter of CodeGenFunction::EmitCheck is of type ArrayRef&lt;std::pair&lt;llvm::Value *, SanitizerMask&gt;&gt;, which is overly generalized: SanitizerMask can denote that zero or more sanitizers are enabled, but EmitCheck requires that exactly one sanitizer is specified in the SanitizerMask (e.g., SanitizeTrap.has(Checked[i].second) enforces that).

This patch replaces SanitizerMask with SanitizerOrdinal in EmitCheck and code that calls it transitively. This should not affect the behavior of UBSan, but it has the advantage that specifying the wrong number of sanitizers in Checked[i].second will be detected as a compile-time error, rather than a runtime assertion failure.

Suggested by Vitaly in #122392 as an alternative to adding an explicit runtime assertion that the SanitizerMask contains exactly one sanitizer.


Patch is 35.11 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/122511.diff

11 Files Affected:

  • (modified) clang/include/clang/Basic/Sanitizers.h (+4)
  • (modified) clang/lib/CodeGen/CGBuiltin.cpp (+3-3)
  • (modified) clang/lib/CodeGen/CGCall.cpp (+6-6)
  • (modified) clang/lib/CodeGen/CGClass.cpp (+8-7)
  • (modified) clang/lib/CodeGen/CGDecl.cpp (+1-1)
  • (modified) clang/lib/CodeGen/CGExpr.cpp (+39-34)
  • (modified) clang/lib/CodeGen/CGExprScalar.cpp (+47-37)
  • (modified) clang/lib/CodeGen/CGObjC.cpp (+1-1)
  • (modified) clang/lib/CodeGen/CodeGenFunction.cpp (+5-4)
  • (modified) clang/lib/CodeGen/CodeGenFunction.h (+8-5)
  • (modified) clang/lib/CodeGen/ItaniumCXXABI.cpp (+2-2)
diff --git a/clang/include/clang/Basic/Sanitizers.h b/clang/include/clang/Basic/Sanitizers.h
index c890242269b334..a12e779215bcf0 100644
--- a/clang/include/clang/Basic/Sanitizers.h
+++ b/clang/include/clang/Basic/Sanitizers.h
@@ -161,6 +161,10 @@ struct SanitizerSet {
     return static_cast<bool>(Mask & K);
   }
 
+  bool has(SanitizerKind::SanitizerOrdinal O) const {
+    return has(SanitizerMask::bitPosToMask(O));
+  }
+
   /// Check if one or more sanitizers are enabled.
   bool hasOneOf(SanitizerMask K) const { return static_cast<bool>(Mask & K); }
 
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index ca03fb665d423d..8941e66b92a2df 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -2240,7 +2240,7 @@ Value *CodeGenFunction::EmitCheckedArgForBuiltin(const Expr *E,
   SanitizerScope SanScope(this);
   Value *Cond = Builder.CreateICmpNE(
       ArgValue, llvm::Constant::getNullValue(ArgValue->getType()));
-  EmitCheck(std::make_pair(Cond, SanitizerKind::Builtin),
+  EmitCheck(std::make_pair(Cond, SanitizerKind::SO_Builtin),
             SanitizerHandler::InvalidBuiltin,
             {EmitCheckSourceLocation(E->getExprLoc()),
              llvm::ConstantInt::get(Builder.getInt8Ty(), Kind)},
@@ -2255,7 +2255,7 @@ Value *CodeGenFunction::EmitCheckedArgForAssume(const Expr *E) {
 
   SanitizerScope SanScope(this);
   EmitCheck(
-      std::make_pair(ArgValue, SanitizerKind::Builtin),
+      std::make_pair(ArgValue, SanitizerKind::SO_Builtin),
       SanitizerHandler::InvalidBuiltin,
       {EmitCheckSourceLocation(E->getExprLoc()),
        llvm::ConstantInt::get(Builder.getInt8Ty(), BCK_AssumePassedFalse)},
@@ -2290,7 +2290,7 @@ static Value *EmitOverflowCheckedAbs(CodeGenFunction &CGF, const CallExpr *E,
 
   // TODO: support -ftrapv-handler.
   if (SanitizeOverflow) {
-    CGF.EmitCheck({{NotOverflow, SanitizerKind::SignedIntegerOverflow}},
+    CGF.EmitCheck({{NotOverflow, SanitizerKind::SO_SignedIntegerOverflow}},
                   SanitizerHandler::NegateOverflow,
                   {CGF.EmitCheckSourceLocation(E->getArg(0)->getExprLoc()),
                    CGF.EmitCheckTypeDescriptor(E->getType())},
diff --git a/clang/lib/CodeGen/CGCall.cpp b/clang/lib/CodeGen/CGCall.cpp
index 7b0ef4be986193..379ce00b739aed 100644
--- a/clang/lib/CodeGen/CGCall.cpp
+++ b/clang/lib/CodeGen/CGCall.cpp
@@ -4024,20 +4024,20 @@ void CodeGenFunction::EmitReturnValueCheck(llvm::Value *RV) {
 
   // Prefer the returns_nonnull attribute if it's present.
   SourceLocation AttrLoc;
-  SanitizerMask CheckKind;
+  SanitizerKind::SanitizerOrdinal CheckKind;
   SanitizerHandler Handler;
   if (RetNNAttr) {
     assert(!requiresReturnValueNullabilityCheck() &&
            "Cannot check nullability and the nonnull attribute");
     AttrLoc = RetNNAttr->getLocation();
-    CheckKind = SanitizerKind::ReturnsNonnullAttribute;
+    CheckKind = SanitizerKind::SO_ReturnsNonnullAttribute;
     Handler = SanitizerHandler::NonnullReturn;
   } else {
     if (auto *DD = dyn_cast<DeclaratorDecl>(CurCodeDecl))
       if (auto *TSI = DD->getTypeSourceInfo())
         if (auto FTL = TSI->getTypeLoc().getAsAdjusted<FunctionTypeLoc>())
           AttrLoc = FTL.getReturnLoc().findNullabilityLoc();
-    CheckKind = SanitizerKind::NullabilityReturn;
+    CheckKind = SanitizerKind::SO_NullabilityReturn;
     Handler = SanitizerHandler::NullabilityReturn;
   }
 
@@ -4419,15 +4419,15 @@ void CodeGenFunction::EmitNonNullArgCheck(RValue RV, QualType ArgType,
     return;
 
   SourceLocation AttrLoc;
-  SanitizerMask CheckKind;
+  SanitizerKind::SanitizerOrdinal CheckKind;
   SanitizerHandler Handler;
   if (NNAttr) {
     AttrLoc = NNAttr->getLocation();
-    CheckKind = SanitizerKind::NonnullAttribute;
+    CheckKind = SanitizerKind::SO_NonnullAttribute;
     Handler = SanitizerHandler::NonnullArg;
   } else {
     AttrLoc = PVD->getTypeSourceInfo()->getTypeLoc().findNullabilityLoc();
-    CheckKind = SanitizerKind::NullabilityArg;
+    CheckKind = SanitizerKind::SO_NullabilityArg;
     Handler = SanitizerHandler::NullabilityArg;
   }
 
diff --git a/clang/lib/CodeGen/CGClass.cpp b/clang/lib/CodeGen/CGClass.cpp
index c45688bd1ed3ce..16a1c1cebdaa44 100644
--- a/clang/lib/CodeGen/CGClass.cpp
+++ b/clang/lib/CodeGen/CGClass.cpp
@@ -2843,23 +2843,23 @@ void CodeGenFunction::EmitVTablePtrCheck(const CXXRecordDecl *RD,
       !CGM.HasHiddenLTOVisibility(RD))
     return;
 
-  SanitizerMask M;
+  SanitizerKind::SanitizerOrdinal M;
   llvm::SanitizerStatKind SSK;
   switch (TCK) {
   case CFITCK_VCall:
-    M = SanitizerKind::CFIVCall;
+    M = SanitizerKind::SO_CFIVCall;
     SSK = llvm::SanStat_CFI_VCall;
     break;
   case CFITCK_NVCall:
-    M = SanitizerKind::CFINVCall;
+    M = SanitizerKind::SO_CFINVCall;
     SSK = llvm::SanStat_CFI_NVCall;
     break;
   case CFITCK_DerivedCast:
-    M = SanitizerKind::CFIDerivedCast;
+    M = SanitizerKind::SO_CFIDerivedCast;
     SSK = llvm::SanStat_CFI_DerivedCast;
     break;
   case CFITCK_UnrelatedCast:
-    M = SanitizerKind::CFIUnrelatedCast;
+    M = SanitizerKind::SO_CFIUnrelatedCast;
     SSK = llvm::SanStat_CFI_UnrelatedCast;
     break;
   case CFITCK_ICall:
@@ -2869,7 +2869,8 @@ void CodeGenFunction::EmitVTablePtrCheck(const CXXRecordDecl *RD,
   }
 
   std::string TypeName = RD->getQualifiedNameAsString();
-  if (getContext().getNoSanitizeList().containsType(M, TypeName))
+  if (getContext().getNoSanitizeList().containsType(
+          SanitizerMask::bitPosToMask(M), TypeName))
     return;
 
   SanitizerScope SanScope(this);
@@ -2945,7 +2946,7 @@ llvm::Value *CodeGenFunction::EmitVTableTypeCheckedLoad(
   if (SanOpts.has(SanitizerKind::CFIVCall) &&
       !getContext().getNoSanitizeList().containsType(SanitizerKind::CFIVCall,
                                                      TypeName)) {
-    EmitCheck(std::make_pair(CheckResult, SanitizerKind::CFIVCall),
+    EmitCheck(std::make_pair(CheckResult, SanitizerKind::SO_CFIVCall),
               SanitizerHandler::CFICheckFail, {}, {});
   }
 
diff --git a/clang/lib/CodeGen/CGDecl.cpp b/clang/lib/CodeGen/CGDecl.cpp
index 6f3ff050cb6978..f85e0b2c404f92 100644
--- a/clang/lib/CodeGen/CGDecl.cpp
+++ b/clang/lib/CodeGen/CGDecl.cpp
@@ -762,7 +762,7 @@ void CodeGenFunction::EmitNullabilityCheck(LValue LHS, llvm::Value *RHS,
       EmitCheckSourceLocation(Loc), EmitCheckTypeDescriptor(LHS.getType()),
       llvm::ConstantInt::get(Int8Ty, 0), // The LogAlignment info is unused.
       llvm::ConstantInt::get(Int8Ty, TCK_NonnullAssign)};
-  EmitCheck({{IsNotNull, SanitizerKind::NullabilityAssign}},
+  EmitCheck({{IsNotNull, SanitizerKind::SO_NullabilityAssign}},
             SanitizerHandler::TypeMismatch, StaticData, RHS);
 }
 
diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp
index 1bad7a722da07a..060d02b7f14873 100644
--- a/clang/lib/CodeGen/CGExpr.cpp
+++ b/clang/lib/CodeGen/CGExpr.cpp
@@ -735,7 +735,8 @@ void CodeGenFunction::EmitTypeCheck(TypeCheckKind TCK, SourceLocation Loc,
 
   SanitizerScope SanScope(this);
 
-  SmallVector<std::pair<llvm::Value *, SanitizerMask>, 3> Checks;
+  SmallVector<std::pair<llvm::Value *, SanitizerKind::SanitizerOrdinal>, 3>
+      Checks;
   llvm::BasicBlock *Done = nullptr;
 
   // Quickly determine whether we have a pointer to an alloca. It's possible
@@ -767,7 +768,7 @@ void CodeGenFunction::EmitTypeCheck(TypeCheckKind TCK, SourceLocation Loc,
         Builder.CreateCondBr(IsNonNull, Rest, Done);
         EmitBlock(Rest);
       } else {
-        Checks.push_back(std::make_pair(IsNonNull, SanitizerKind::Null));
+        Checks.push_back(std::make_pair(IsNonNull, SanitizerKind::SO_Null));
       }
     }
   }
@@ -794,7 +795,8 @@ void CodeGenFunction::EmitTypeCheck(TypeCheckKind TCK, SourceLocation Loc,
       llvm::Value *Dynamic = Builder.getFalse();
       llvm::Value *LargeEnough = Builder.CreateICmpUGE(
           Builder.CreateCall(F, {Ptr, Min, NullIsUnknown, Dynamic}), Size);
-      Checks.push_back(std::make_pair(LargeEnough, SanitizerKind::ObjectSize));
+      Checks.push_back(
+          std::make_pair(LargeEnough, SanitizerKind::SO_ObjectSize));
     }
   }
 
@@ -818,7 +820,7 @@ void CodeGenFunction::EmitTypeCheck(TypeCheckKind TCK, SourceLocation Loc,
       llvm::Value *Aligned =
           Builder.CreateICmpEQ(Align, llvm::ConstantInt::get(IntPtrTy, 0));
       if (Aligned != True)
-        Checks.push_back(std::make_pair(Aligned, SanitizerKind::Alignment));
+        Checks.push_back(std::make_pair(Aligned, SanitizerKind::SO_Alignment));
     }
   }
 
@@ -902,7 +904,7 @@ void CodeGenFunction::EmitTypeCheck(TypeCheckKind TCK, SourceLocation Loc,
         llvm::ConstantInt::get(Int8Ty, TCK)
       };
       llvm::Value *DynamicData[] = { Ptr, Hash };
-      EmitCheck(std::make_pair(EqualHash, SanitizerKind::Vptr),
+      EmitCheck(std::make_pair(EqualHash, SanitizerKind::SO_Vptr),
                 SanitizerHandler::DynamicTypeCacheMiss, StaticData,
                 DynamicData);
     }
@@ -1220,7 +1222,7 @@ void CodeGenFunction::EmitBoundsCheckImpl(const Expr *E, llvm::Value *Bound,
   };
   llvm::Value *Check = Accessed ? Builder.CreateICmpULT(IndexVal, BoundVal)
                                 : Builder.CreateICmpULE(IndexVal, BoundVal);
-  EmitCheck(std::make_pair(Check, SanitizerKind::ArrayBounds),
+  EmitCheck(std::make_pair(Check, SanitizerKind::SO_ArrayBounds),
             SanitizerHandler::OutOfBounds, StaticData, Index);
 }
 
@@ -1960,8 +1962,8 @@ bool CodeGenFunction::EmitScalarRangeCheck(llvm::Value *Value, QualType Ty,
   }
   llvm::Constant *StaticArgs[] = {EmitCheckSourceLocation(Loc),
                                   EmitCheckTypeDescriptor(Ty)};
-  SanitizerMask Kind =
-      NeedsEnumCheck ? SanitizerKind::Enum : SanitizerKind::Bool;
+  SanitizerKind::SanitizerOrdinal Kind =
+      NeedsEnumCheck ? SanitizerKind::SO_Enum : SanitizerKind::SO_Bool;
   EmitCheck(std::make_pair(Check, Kind), SanitizerHandler::LoadInvalidValue,
             StaticArgs, EmitCheckValue(Value));
   return true;
@@ -3513,11 +3515,12 @@ enum class CheckRecoverableKind {
 };
 }
 
-static CheckRecoverableKind getRecoverableKind(SanitizerMask Kind) {
-  assert(Kind.countPopulation() == 1);
-  if (Kind == SanitizerKind::Vptr)
+static CheckRecoverableKind
+getRecoverableKind(SanitizerKind::SanitizerOrdinal Ordinal) {
+  if (Ordinal == SanitizerKind::SO_Vptr)
     return CheckRecoverableKind::AlwaysRecoverable;
-  else if (Kind == SanitizerKind::Return || Kind == SanitizerKind::Unreachable)
+  else if (Ordinal == SanitizerKind::SO_Return ||
+           Ordinal == SanitizerKind::SO_Unreachable)
     return CheckRecoverableKind::Unrecoverable;
   else
     return CheckRecoverableKind::Recoverable;
@@ -3589,7 +3592,7 @@ static void emitCheckHandlerCall(CodeGenFunction &CGF,
 }
 
 void CodeGenFunction::EmitCheck(
-    ArrayRef<std::pair<llvm::Value *, SanitizerMask>> Checked,
+    ArrayRef<std::pair<llvm::Value *, SanitizerKind::SanitizerOrdinal>> Checked,
     SanitizerHandler CheckHandler, ArrayRef<llvm::Constant *> StaticArgs,
     ArrayRef<llvm::Value *> DynamicArgs) {
   assert(IsSanitizerScope);
@@ -3715,8 +3718,9 @@ void CodeGenFunction::EmitCheck(
 }
 
 void CodeGenFunction::EmitCfiSlowPathCheck(
-    SanitizerMask Kind, llvm::Value *Cond, llvm::ConstantInt *TypeId,
-    llvm::Value *Ptr, ArrayRef<llvm::Constant *> StaticArgs) {
+    SanitizerKind::SanitizerOrdinal Ordinal, llvm::Value *Cond,
+    llvm::ConstantInt *TypeId, llvm::Value *Ptr,
+    ArrayRef<llvm::Constant *> StaticArgs) {
   llvm::BasicBlock *Cont = createBasicBlock("cfi.cont");
 
   llvm::BasicBlock *CheckBB = createBasicBlock("cfi.slowpath");
@@ -3728,7 +3732,7 @@ void CodeGenFunction::EmitCfiSlowPathCheck(
 
   EmitBlock(CheckBB);
 
-  bool WithDiag = !CGM.getCodeGenOpts().SanitizeTrap.has(Kind);
+  bool WithDiag = !CGM.getCodeGenOpts().SanitizeTrap.has(Ordinal);
 
   llvm::CallInst *CheckCall;
   llvm::FunctionCallee SlowPathFn;
@@ -3860,22 +3864,23 @@ void CodeGenFunction::EmitCfiCheckFail() {
                          {Addr, AllVtables}),
       IntPtrTy);
 
-  const std::pair<int, SanitizerMask> CheckKinds[] = {
-      {CFITCK_VCall, SanitizerKind::CFIVCall},
-      {CFITCK_NVCall, SanitizerKind::CFINVCall},
-      {CFITCK_DerivedCast, SanitizerKind::CFIDerivedCast},
-      {CFITCK_UnrelatedCast, SanitizerKind::CFIUnrelatedCast},
-      {CFITCK_ICall, SanitizerKind::CFIICall}};
-
-  SmallVector<std::pair<llvm::Value *, SanitizerMask>, 5> Checks;
-  for (auto CheckKindMaskPair : CheckKinds) {
-    int Kind = CheckKindMaskPair.first;
-    SanitizerMask Mask = CheckKindMaskPair.second;
+  const std::pair<int, SanitizerKind::SanitizerOrdinal> CheckKinds[] = {
+      {CFITCK_VCall, SanitizerKind::SO_CFIVCall},
+      {CFITCK_NVCall, SanitizerKind::SO_CFINVCall},
+      {CFITCK_DerivedCast, SanitizerKind::SO_CFIDerivedCast},
+      {CFITCK_UnrelatedCast, SanitizerKind::SO_CFIUnrelatedCast},
+      {CFITCK_ICall, SanitizerKind::SO_CFIICall}};
+
+  SmallVector<std::pair<llvm::Value *, SanitizerKind::SanitizerOrdinal>, 5>
+      Checks;
+  for (auto CheckKindOrdinalPair : CheckKinds) {
+    int Kind = CheckKindOrdinalPair.first;
+    SanitizerKind::SanitizerOrdinal Ordinal = CheckKindOrdinalPair.second;
     llvm::Value *Cond =
         Builder.CreateICmpNE(CheckKind, llvm::ConstantInt::get(Int8Ty, Kind));
-    if (CGM.getLangOpts().Sanitize.has(Mask))
-      EmitCheck(std::make_pair(Cond, Mask), SanitizerHandler::CFICheckFail, {},
-                {Data, Addr, ValidVtable});
+    if (CGM.getLangOpts().Sanitize.has(Ordinal))
+      EmitCheck(std::make_pair(Cond, Ordinal), SanitizerHandler::CFICheckFail,
+                {}, {Data, Addr, ValidVtable});
     else
       EmitTrapCheck(Cond, SanitizerHandler::CFICheckFail);
   }
@@ -3890,7 +3895,7 @@ void CodeGenFunction::EmitUnreachable(SourceLocation Loc) {
   if (SanOpts.has(SanitizerKind::Unreachable)) {
     SanitizerScope SanScope(this);
     EmitCheck(std::make_pair(static_cast<llvm::Value *>(Builder.getFalse()),
-                             SanitizerKind::Unreachable),
+                             SanitizerKind::SO_Unreachable),
               SanitizerHandler::BuiltinUnreachable,
               EmitCheckSourceLocation(Loc), {});
   }
@@ -6054,7 +6059,7 @@ RValue CodeGenFunction::EmitCall(QualType CalleeType,
           Builder.CreateICmpEQ(CalleeTypeHash, TypeHash);
       llvm::Constant *StaticData[] = {EmitCheckSourceLocation(E->getBeginLoc()),
                                       EmitCheckTypeDescriptor(CalleeType)};
-      EmitCheck(std::make_pair(CalleeTypeHashMatch, SanitizerKind::Function),
+      EmitCheck(std::make_pair(CalleeTypeHashMatch, SanitizerKind::SO_Function),
                 SanitizerHandler::FunctionTypeMismatch, StaticData,
                 {CalleePtr});
 
@@ -6091,10 +6096,10 @@ RValue CodeGenFunction::EmitCall(QualType CalleeType,
         EmitCheckTypeDescriptor(QualType(FnType, 0)),
     };
     if (CGM.getCodeGenOpts().SanitizeCfiCrossDso && CrossDsoTypeId) {
-      EmitCfiSlowPathCheck(SanitizerKind::CFIICall, TypeTest, CrossDsoTypeId,
+      EmitCfiSlowPathCheck(SanitizerKind::SO_CFIICall, TypeTest, CrossDsoTypeId,
                            CalleePtr, StaticData);
     } else {
-      EmitCheck(std::make_pair(TypeTest, SanitizerKind::CFIICall),
+      EmitCheck(std::make_pair(TypeTest, SanitizerKind::SO_CFIICall),
                 SanitizerHandler::CFICheckFail, StaticData,
                 {CalleePtr, llvm::UndefValue::get(IntPtrTy)});
     }
diff --git a/clang/lib/CodeGen/CGExprScalar.cpp b/clang/lib/CodeGen/CGExprScalar.cpp
index 090c4fb3ea39ee..ac499e490ee87e 100644
--- a/clang/lib/CodeGen/CGExprScalar.cpp
+++ b/clang/lib/CodeGen/CGExprScalar.cpp
@@ -280,8 +280,9 @@ class ScalarExprEmitter
     return CGF.EmitCheckedLValue(E, TCK);
   }
 
-  void EmitBinOpCheck(ArrayRef<std::pair<Value *, SanitizerMask>> Checks,
-                      const BinOpInfo &Info);
+  void EmitBinOpCheck(
+      ArrayRef<std::pair<Value *, SanitizerKind::SanitizerOrdinal>> Checks,
+      const BinOpInfo &Info);
 
   Value *EmitLoadOfLValue(LValue LV, SourceLocation Loc) {
     return CGF.EmitLoadOfLValue(LV, Loc).getScalarVal();
@@ -1047,14 +1048,14 @@ void ScalarExprEmitter::EmitFloatConversionCheck(
   llvm::Constant *StaticArgs[] = {CGF.EmitCheckSourceLocation(Loc),
                                   CGF.EmitCheckTypeDescriptor(OrigSrcType),
                                   CGF.EmitCheckTypeDescriptor(DstType)};
-  CGF.EmitCheck(std::make_pair(Check, SanitizerKind::FloatCastOverflow),
+  CGF.EmitCheck(std::make_pair(Check, SanitizerKind::SO_FloatCastOverflow),
                 SanitizerHandler::FloatCastOverflow, StaticArgs, OrigSrc);
 }
 
 // Should be called within CodeGenFunction::SanitizerScope RAII scope.
 // Returns 'i1 false' when the truncation Src -> Dst was lossy.
 static std::pair<ScalarExprEmitter::ImplicitConversionCheckKind,
-                 std::pair<llvm::Value *, SanitizerMask>>
+                 std::pair<llvm::Value *, SanitizerKind::SanitizerOrdinal>>
 EmitIntegerTruncationCheckHelper(Value *Src, QualType SrcType, Value *Dst,
                                  QualType DstType, CGBuilderTy &Builder) {
   llvm::Type *SrcTy = Src->getType();
@@ -1073,13 +1074,13 @@ EmitIntegerTruncationCheckHelper(Value *Src, QualType SrcType, Value *Dst,
   // If both (src and dst) types are unsigned, then it's an unsigned truncation.
   // Else, it is a signed truncation.
   ScalarExprEmitter::ImplicitConversionCheckKind Kind;
-  SanitizerMask Mask;
+  SanitizerKind::SanitizerOrdinal Ordinal;
   if (!SrcSigned && !DstSigned) {
     Kind = ScalarExprEmitter::ICCK_UnsignedIntegerTruncation;
-    Mask = SanitizerKind::ImplicitUnsignedIntegerTruncation;
+    Ordinal = SanitizerKind::SO_ImplicitUnsignedIntegerTruncation;
   } else {
     Kind = ScalarExprEmitter::ICCK_SignedIntegerTruncation;
-    Mask = SanitizerKind::ImplicitSignedIntegerTruncation;
+    Ordinal = SanitizerKind::SO_ImplicitSignedIntegerTruncation;
   }
 
   llvm::Value *Check = nullptr;
@@ -1088,7 +1089,7 @@ EmitIntegerTruncationCheckHelper(Value *Src, QualType SrcType, Value *Dst,
   // 2. Equality-compare with the original source value
   Check = Builder.CreateICmpEQ(Check, Src, "truncheck");
   // If the comparison result is 'i1 false', then the truncation was lossy.
-  return std::make_pair(Kind, std::make_pair(Check, Mask));
+  return std::make_pair(Kind, std::make_pair(Check, Ordinal));
 }
 
 static bool PromotionIsPotentiallyEligibleForImplicitIntegerConversionCheck(
@@ -1128,7 +1129,7 @@ void ScalarExprEmitter::EmitIntegerTruncationCheck(Value *Src, QualType SrcType,
   CodeGenFunction::SanitizerScope SanScope(&CGF);
 
   std::pair<ScalarExprEmitter::ImplicitConversionCheckKind,
-            std::pair<llvm::Value *, SanitizerMask>>
+            std::pair<llvm::Value *, SanitizerKind::SanitizerOrdinal>>
       Check =
           EmitIntegerTruncationCheckHelper(Src, SrcType, Dst, DstType, Builder);
   // If the comparison result is 'i1 false', then the truncation was lossy.
@@ -1138,7 +1139,8 @@ void ScalarExprEmitter::EmitIntegerTruncationCheck(Value *Src, QualType SrcType,
     return;
 
   // Does some SSCL ignore this type?
-  if (CGF.getContext().isTypeIgnoredBySanitizer(Check.second.second, DstType))
+  if (CGF.getContext().isTypeIgnoredBySanitizer(
+          SanitizerMask::bitPosToMask(Check.second.second), DstType))
     return;
 
   llvm::Constant *StaticArgs[] = {
@@ -1169,7 +1171,7 @@ static llvm::Value *EmitIsNegativeTestHelper(Value *V, QualType VType,
 // Should be called within CodeGenFunction::SanitizerScope RAII scope.
 // Returns 'i1 false' when the conversion Src -> Dst changed the sign.
 static std::pair<ScalarExprEmitter::ImplicitConversionCheckKind,
-                 std::pair<llvm::Value *, SanitizerMask>>
+                 std::pair<llvm::Value *, SanitizerKind::SanitizerOrdinal>>
 EmitIntegerSignChangeCheckHelper(Value *Src, QualType SrcType, Value *Dst,
                                  QualType DstType, CGBuilderTy &Builder) {
   llvm::Type *SrcTy = Src->getType();
@@ -1205,13 +1207,13 @@ EmitIntegerSignChangeCheckHelper(Value *Src, QualType SrcType, Value *Dst,
   // If the comparison result is 'false', then the conversion changed the sign.
   return std::make_pair(
       ScalarExprEmitter::ICCK_IntegerSignChange,
-      std::make_pair(Check, SanitizerKind::ImplicitInteg...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Jan 10, 2025

@llvm/pr-subscribers-clang-codegen

Author: Thurston Dang (thurstond)

Changes

The Checked parameter of CodeGenFunction::EmitCheck is of type ArrayRef&lt;std::pair&lt;llvm::Value *, SanitizerMask&gt;&gt;, which is overly generalized: SanitizerMask can denote that zero or more sanitizers are enabled, but EmitCheck requires that exactly one sanitizer is specified in the SanitizerMask (e.g., SanitizeTrap.has(Checked[i].second) enforces that).

This patch replaces SanitizerMask with SanitizerOrdinal in EmitCheck and code that calls it transitively. This should not affect the behavior of UBSan, but it has the advantage that specifying the wrong number of sanitizers in Checked[i].second will be detected as a compile-time error, rather than a runtime assertion failure.

Suggested by Vitaly in #122392 as an alternative to adding an explicit runtime assertion that the SanitizerMask contains exactly one sanitizer.


Patch is 35.11 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/122511.diff

11 Files Affected:

  • (modified) clang/include/clang/Basic/Sanitizers.h (+4)
  • (modified) clang/lib/CodeGen/CGBuiltin.cpp (+3-3)
  • (modified) clang/lib/CodeGen/CGCall.cpp (+6-6)
  • (modified) clang/lib/CodeGen/CGClass.cpp (+8-7)
  • (modified) clang/lib/CodeGen/CGDecl.cpp (+1-1)
  • (modified) clang/lib/CodeGen/CGExpr.cpp (+39-34)
  • (modified) clang/lib/CodeGen/CGExprScalar.cpp (+47-37)
  • (modified) clang/lib/CodeGen/CGObjC.cpp (+1-1)
  • (modified) clang/lib/CodeGen/CodeGenFunction.cpp (+5-4)
  • (modified) clang/lib/CodeGen/CodeGenFunction.h (+8-5)
  • (modified) clang/lib/CodeGen/ItaniumCXXABI.cpp (+2-2)
diff --git a/clang/include/clang/Basic/Sanitizers.h b/clang/include/clang/Basic/Sanitizers.h
index c890242269b334..a12e779215bcf0 100644
--- a/clang/include/clang/Basic/Sanitizers.h
+++ b/clang/include/clang/Basic/Sanitizers.h
@@ -161,6 +161,10 @@ struct SanitizerSet {
     return static_cast<bool>(Mask & K);
   }
 
+  bool has(SanitizerKind::SanitizerOrdinal O) const {
+    return has(SanitizerMask::bitPosToMask(O));
+  }
+
   /// Check if one or more sanitizers are enabled.
   bool hasOneOf(SanitizerMask K) const { return static_cast<bool>(Mask & K); }
 
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index ca03fb665d423d..8941e66b92a2df 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -2240,7 +2240,7 @@ Value *CodeGenFunction::EmitCheckedArgForBuiltin(const Expr *E,
   SanitizerScope SanScope(this);
   Value *Cond = Builder.CreateICmpNE(
       ArgValue, llvm::Constant::getNullValue(ArgValue->getType()));
-  EmitCheck(std::make_pair(Cond, SanitizerKind::Builtin),
+  EmitCheck(std::make_pair(Cond, SanitizerKind::SO_Builtin),
             SanitizerHandler::InvalidBuiltin,
             {EmitCheckSourceLocation(E->getExprLoc()),
              llvm::ConstantInt::get(Builder.getInt8Ty(), Kind)},
@@ -2255,7 +2255,7 @@ Value *CodeGenFunction::EmitCheckedArgForAssume(const Expr *E) {
 
   SanitizerScope SanScope(this);
   EmitCheck(
-      std::make_pair(ArgValue, SanitizerKind::Builtin),
+      std::make_pair(ArgValue, SanitizerKind::SO_Builtin),
       SanitizerHandler::InvalidBuiltin,
       {EmitCheckSourceLocation(E->getExprLoc()),
        llvm::ConstantInt::get(Builder.getInt8Ty(), BCK_AssumePassedFalse)},
@@ -2290,7 +2290,7 @@ static Value *EmitOverflowCheckedAbs(CodeGenFunction &CGF, const CallExpr *E,
 
   // TODO: support -ftrapv-handler.
   if (SanitizeOverflow) {
-    CGF.EmitCheck({{NotOverflow, SanitizerKind::SignedIntegerOverflow}},
+    CGF.EmitCheck({{NotOverflow, SanitizerKind::SO_SignedIntegerOverflow}},
                   SanitizerHandler::NegateOverflow,
                   {CGF.EmitCheckSourceLocation(E->getArg(0)->getExprLoc()),
                    CGF.EmitCheckTypeDescriptor(E->getType())},
diff --git a/clang/lib/CodeGen/CGCall.cpp b/clang/lib/CodeGen/CGCall.cpp
index 7b0ef4be986193..379ce00b739aed 100644
--- a/clang/lib/CodeGen/CGCall.cpp
+++ b/clang/lib/CodeGen/CGCall.cpp
@@ -4024,20 +4024,20 @@ void CodeGenFunction::EmitReturnValueCheck(llvm::Value *RV) {
 
   // Prefer the returns_nonnull attribute if it's present.
   SourceLocation AttrLoc;
-  SanitizerMask CheckKind;
+  SanitizerKind::SanitizerOrdinal CheckKind;
   SanitizerHandler Handler;
   if (RetNNAttr) {
     assert(!requiresReturnValueNullabilityCheck() &&
            "Cannot check nullability and the nonnull attribute");
     AttrLoc = RetNNAttr->getLocation();
-    CheckKind = SanitizerKind::ReturnsNonnullAttribute;
+    CheckKind = SanitizerKind::SO_ReturnsNonnullAttribute;
     Handler = SanitizerHandler::NonnullReturn;
   } else {
     if (auto *DD = dyn_cast<DeclaratorDecl>(CurCodeDecl))
       if (auto *TSI = DD->getTypeSourceInfo())
         if (auto FTL = TSI->getTypeLoc().getAsAdjusted<FunctionTypeLoc>())
           AttrLoc = FTL.getReturnLoc().findNullabilityLoc();
-    CheckKind = SanitizerKind::NullabilityReturn;
+    CheckKind = SanitizerKind::SO_NullabilityReturn;
     Handler = SanitizerHandler::NullabilityReturn;
   }
 
@@ -4419,15 +4419,15 @@ void CodeGenFunction::EmitNonNullArgCheck(RValue RV, QualType ArgType,
     return;
 
   SourceLocation AttrLoc;
-  SanitizerMask CheckKind;
+  SanitizerKind::SanitizerOrdinal CheckKind;
   SanitizerHandler Handler;
   if (NNAttr) {
     AttrLoc = NNAttr->getLocation();
-    CheckKind = SanitizerKind::NonnullAttribute;
+    CheckKind = SanitizerKind::SO_NonnullAttribute;
     Handler = SanitizerHandler::NonnullArg;
   } else {
     AttrLoc = PVD->getTypeSourceInfo()->getTypeLoc().findNullabilityLoc();
-    CheckKind = SanitizerKind::NullabilityArg;
+    CheckKind = SanitizerKind::SO_NullabilityArg;
     Handler = SanitizerHandler::NullabilityArg;
   }
 
diff --git a/clang/lib/CodeGen/CGClass.cpp b/clang/lib/CodeGen/CGClass.cpp
index c45688bd1ed3ce..16a1c1cebdaa44 100644
--- a/clang/lib/CodeGen/CGClass.cpp
+++ b/clang/lib/CodeGen/CGClass.cpp
@@ -2843,23 +2843,23 @@ void CodeGenFunction::EmitVTablePtrCheck(const CXXRecordDecl *RD,
       !CGM.HasHiddenLTOVisibility(RD))
     return;
 
-  SanitizerMask M;
+  SanitizerKind::SanitizerOrdinal M;
   llvm::SanitizerStatKind SSK;
   switch (TCK) {
   case CFITCK_VCall:
-    M = SanitizerKind::CFIVCall;
+    M = SanitizerKind::SO_CFIVCall;
     SSK = llvm::SanStat_CFI_VCall;
     break;
   case CFITCK_NVCall:
-    M = SanitizerKind::CFINVCall;
+    M = SanitizerKind::SO_CFINVCall;
     SSK = llvm::SanStat_CFI_NVCall;
     break;
   case CFITCK_DerivedCast:
-    M = SanitizerKind::CFIDerivedCast;
+    M = SanitizerKind::SO_CFIDerivedCast;
     SSK = llvm::SanStat_CFI_DerivedCast;
     break;
   case CFITCK_UnrelatedCast:
-    M = SanitizerKind::CFIUnrelatedCast;
+    M = SanitizerKind::SO_CFIUnrelatedCast;
     SSK = llvm::SanStat_CFI_UnrelatedCast;
     break;
   case CFITCK_ICall:
@@ -2869,7 +2869,8 @@ void CodeGenFunction::EmitVTablePtrCheck(const CXXRecordDecl *RD,
   }
 
   std::string TypeName = RD->getQualifiedNameAsString();
-  if (getContext().getNoSanitizeList().containsType(M, TypeName))
+  if (getContext().getNoSanitizeList().containsType(
+          SanitizerMask::bitPosToMask(M), TypeName))
     return;
 
   SanitizerScope SanScope(this);
@@ -2945,7 +2946,7 @@ llvm::Value *CodeGenFunction::EmitVTableTypeCheckedLoad(
   if (SanOpts.has(SanitizerKind::CFIVCall) &&
       !getContext().getNoSanitizeList().containsType(SanitizerKind::CFIVCall,
                                                      TypeName)) {
-    EmitCheck(std::make_pair(CheckResult, SanitizerKind::CFIVCall),
+    EmitCheck(std::make_pair(CheckResult, SanitizerKind::SO_CFIVCall),
               SanitizerHandler::CFICheckFail, {}, {});
   }
 
diff --git a/clang/lib/CodeGen/CGDecl.cpp b/clang/lib/CodeGen/CGDecl.cpp
index 6f3ff050cb6978..f85e0b2c404f92 100644
--- a/clang/lib/CodeGen/CGDecl.cpp
+++ b/clang/lib/CodeGen/CGDecl.cpp
@@ -762,7 +762,7 @@ void CodeGenFunction::EmitNullabilityCheck(LValue LHS, llvm::Value *RHS,
       EmitCheckSourceLocation(Loc), EmitCheckTypeDescriptor(LHS.getType()),
       llvm::ConstantInt::get(Int8Ty, 0), // The LogAlignment info is unused.
       llvm::ConstantInt::get(Int8Ty, TCK_NonnullAssign)};
-  EmitCheck({{IsNotNull, SanitizerKind::NullabilityAssign}},
+  EmitCheck({{IsNotNull, SanitizerKind::SO_NullabilityAssign}},
             SanitizerHandler::TypeMismatch, StaticData, RHS);
 }
 
diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp
index 1bad7a722da07a..060d02b7f14873 100644
--- a/clang/lib/CodeGen/CGExpr.cpp
+++ b/clang/lib/CodeGen/CGExpr.cpp
@@ -735,7 +735,8 @@ void CodeGenFunction::EmitTypeCheck(TypeCheckKind TCK, SourceLocation Loc,
 
   SanitizerScope SanScope(this);
 
-  SmallVector<std::pair<llvm::Value *, SanitizerMask>, 3> Checks;
+  SmallVector<std::pair<llvm::Value *, SanitizerKind::SanitizerOrdinal>, 3>
+      Checks;
   llvm::BasicBlock *Done = nullptr;
 
   // Quickly determine whether we have a pointer to an alloca. It's possible
@@ -767,7 +768,7 @@ void CodeGenFunction::EmitTypeCheck(TypeCheckKind TCK, SourceLocation Loc,
         Builder.CreateCondBr(IsNonNull, Rest, Done);
         EmitBlock(Rest);
       } else {
-        Checks.push_back(std::make_pair(IsNonNull, SanitizerKind::Null));
+        Checks.push_back(std::make_pair(IsNonNull, SanitizerKind::SO_Null));
       }
     }
   }
@@ -794,7 +795,8 @@ void CodeGenFunction::EmitTypeCheck(TypeCheckKind TCK, SourceLocation Loc,
       llvm::Value *Dynamic = Builder.getFalse();
       llvm::Value *LargeEnough = Builder.CreateICmpUGE(
           Builder.CreateCall(F, {Ptr, Min, NullIsUnknown, Dynamic}), Size);
-      Checks.push_back(std::make_pair(LargeEnough, SanitizerKind::ObjectSize));
+      Checks.push_back(
+          std::make_pair(LargeEnough, SanitizerKind::SO_ObjectSize));
     }
   }
 
@@ -818,7 +820,7 @@ void CodeGenFunction::EmitTypeCheck(TypeCheckKind TCK, SourceLocation Loc,
       llvm::Value *Aligned =
           Builder.CreateICmpEQ(Align, llvm::ConstantInt::get(IntPtrTy, 0));
       if (Aligned != True)
-        Checks.push_back(std::make_pair(Aligned, SanitizerKind::Alignment));
+        Checks.push_back(std::make_pair(Aligned, SanitizerKind::SO_Alignment));
     }
   }
 
@@ -902,7 +904,7 @@ void CodeGenFunction::EmitTypeCheck(TypeCheckKind TCK, SourceLocation Loc,
         llvm::ConstantInt::get(Int8Ty, TCK)
       };
       llvm::Value *DynamicData[] = { Ptr, Hash };
-      EmitCheck(std::make_pair(EqualHash, SanitizerKind::Vptr),
+      EmitCheck(std::make_pair(EqualHash, SanitizerKind::SO_Vptr),
                 SanitizerHandler::DynamicTypeCacheMiss, StaticData,
                 DynamicData);
     }
@@ -1220,7 +1222,7 @@ void CodeGenFunction::EmitBoundsCheckImpl(const Expr *E, llvm::Value *Bound,
   };
   llvm::Value *Check = Accessed ? Builder.CreateICmpULT(IndexVal, BoundVal)
                                 : Builder.CreateICmpULE(IndexVal, BoundVal);
-  EmitCheck(std::make_pair(Check, SanitizerKind::ArrayBounds),
+  EmitCheck(std::make_pair(Check, SanitizerKind::SO_ArrayBounds),
             SanitizerHandler::OutOfBounds, StaticData, Index);
 }
 
@@ -1960,8 +1962,8 @@ bool CodeGenFunction::EmitScalarRangeCheck(llvm::Value *Value, QualType Ty,
   }
   llvm::Constant *StaticArgs[] = {EmitCheckSourceLocation(Loc),
                                   EmitCheckTypeDescriptor(Ty)};
-  SanitizerMask Kind =
-      NeedsEnumCheck ? SanitizerKind::Enum : SanitizerKind::Bool;
+  SanitizerKind::SanitizerOrdinal Kind =
+      NeedsEnumCheck ? SanitizerKind::SO_Enum : SanitizerKind::SO_Bool;
   EmitCheck(std::make_pair(Check, Kind), SanitizerHandler::LoadInvalidValue,
             StaticArgs, EmitCheckValue(Value));
   return true;
@@ -3513,11 +3515,12 @@ enum class CheckRecoverableKind {
 };
 }
 
-static CheckRecoverableKind getRecoverableKind(SanitizerMask Kind) {
-  assert(Kind.countPopulation() == 1);
-  if (Kind == SanitizerKind::Vptr)
+static CheckRecoverableKind
+getRecoverableKind(SanitizerKind::SanitizerOrdinal Ordinal) {
+  if (Ordinal == SanitizerKind::SO_Vptr)
     return CheckRecoverableKind::AlwaysRecoverable;
-  else if (Kind == SanitizerKind::Return || Kind == SanitizerKind::Unreachable)
+  else if (Ordinal == SanitizerKind::SO_Return ||
+           Ordinal == SanitizerKind::SO_Unreachable)
     return CheckRecoverableKind::Unrecoverable;
   else
     return CheckRecoverableKind::Recoverable;
@@ -3589,7 +3592,7 @@ static void emitCheckHandlerCall(CodeGenFunction &CGF,
 }
 
 void CodeGenFunction::EmitCheck(
-    ArrayRef<std::pair<llvm::Value *, SanitizerMask>> Checked,
+    ArrayRef<std::pair<llvm::Value *, SanitizerKind::SanitizerOrdinal>> Checked,
     SanitizerHandler CheckHandler, ArrayRef<llvm::Constant *> StaticArgs,
     ArrayRef<llvm::Value *> DynamicArgs) {
   assert(IsSanitizerScope);
@@ -3715,8 +3718,9 @@ void CodeGenFunction::EmitCheck(
 }
 
 void CodeGenFunction::EmitCfiSlowPathCheck(
-    SanitizerMask Kind, llvm::Value *Cond, llvm::ConstantInt *TypeId,
-    llvm::Value *Ptr, ArrayRef<llvm::Constant *> StaticArgs) {
+    SanitizerKind::SanitizerOrdinal Ordinal, llvm::Value *Cond,
+    llvm::ConstantInt *TypeId, llvm::Value *Ptr,
+    ArrayRef<llvm::Constant *> StaticArgs) {
   llvm::BasicBlock *Cont = createBasicBlock("cfi.cont");
 
   llvm::BasicBlock *CheckBB = createBasicBlock("cfi.slowpath");
@@ -3728,7 +3732,7 @@ void CodeGenFunction::EmitCfiSlowPathCheck(
 
   EmitBlock(CheckBB);
 
-  bool WithDiag = !CGM.getCodeGenOpts().SanitizeTrap.has(Kind);
+  bool WithDiag = !CGM.getCodeGenOpts().SanitizeTrap.has(Ordinal);
 
   llvm::CallInst *CheckCall;
   llvm::FunctionCallee SlowPathFn;
@@ -3860,22 +3864,23 @@ void CodeGenFunction::EmitCfiCheckFail() {
                          {Addr, AllVtables}),
       IntPtrTy);
 
-  const std::pair<int, SanitizerMask> CheckKinds[] = {
-      {CFITCK_VCall, SanitizerKind::CFIVCall},
-      {CFITCK_NVCall, SanitizerKind::CFINVCall},
-      {CFITCK_DerivedCast, SanitizerKind::CFIDerivedCast},
-      {CFITCK_UnrelatedCast, SanitizerKind::CFIUnrelatedCast},
-      {CFITCK_ICall, SanitizerKind::CFIICall}};
-
-  SmallVector<std::pair<llvm::Value *, SanitizerMask>, 5> Checks;
-  for (auto CheckKindMaskPair : CheckKinds) {
-    int Kind = CheckKindMaskPair.first;
-    SanitizerMask Mask = CheckKindMaskPair.second;
+  const std::pair<int, SanitizerKind::SanitizerOrdinal> CheckKinds[] = {
+      {CFITCK_VCall, SanitizerKind::SO_CFIVCall},
+      {CFITCK_NVCall, SanitizerKind::SO_CFINVCall},
+      {CFITCK_DerivedCast, SanitizerKind::SO_CFIDerivedCast},
+      {CFITCK_UnrelatedCast, SanitizerKind::SO_CFIUnrelatedCast},
+      {CFITCK_ICall, SanitizerKind::SO_CFIICall}};
+
+  SmallVector<std::pair<llvm::Value *, SanitizerKind::SanitizerOrdinal>, 5>
+      Checks;
+  for (auto CheckKindOrdinalPair : CheckKinds) {
+    int Kind = CheckKindOrdinalPair.first;
+    SanitizerKind::SanitizerOrdinal Ordinal = CheckKindOrdinalPair.second;
     llvm::Value *Cond =
         Builder.CreateICmpNE(CheckKind, llvm::ConstantInt::get(Int8Ty, Kind));
-    if (CGM.getLangOpts().Sanitize.has(Mask))
-      EmitCheck(std::make_pair(Cond, Mask), SanitizerHandler::CFICheckFail, {},
-                {Data, Addr, ValidVtable});
+    if (CGM.getLangOpts().Sanitize.has(Ordinal))
+      EmitCheck(std::make_pair(Cond, Ordinal), SanitizerHandler::CFICheckFail,
+                {}, {Data, Addr, ValidVtable});
     else
       EmitTrapCheck(Cond, SanitizerHandler::CFICheckFail);
   }
@@ -3890,7 +3895,7 @@ void CodeGenFunction::EmitUnreachable(SourceLocation Loc) {
   if (SanOpts.has(SanitizerKind::Unreachable)) {
     SanitizerScope SanScope(this);
     EmitCheck(std::make_pair(static_cast<llvm::Value *>(Builder.getFalse()),
-                             SanitizerKind::Unreachable),
+                             SanitizerKind::SO_Unreachable),
               SanitizerHandler::BuiltinUnreachable,
               EmitCheckSourceLocation(Loc), {});
   }
@@ -6054,7 +6059,7 @@ RValue CodeGenFunction::EmitCall(QualType CalleeType,
           Builder.CreateICmpEQ(CalleeTypeHash, TypeHash);
       llvm::Constant *StaticData[] = {EmitCheckSourceLocation(E->getBeginLoc()),
                                       EmitCheckTypeDescriptor(CalleeType)};
-      EmitCheck(std::make_pair(CalleeTypeHashMatch, SanitizerKind::Function),
+      EmitCheck(std::make_pair(CalleeTypeHashMatch, SanitizerKind::SO_Function),
                 SanitizerHandler::FunctionTypeMismatch, StaticData,
                 {CalleePtr});
 
@@ -6091,10 +6096,10 @@ RValue CodeGenFunction::EmitCall(QualType CalleeType,
         EmitCheckTypeDescriptor(QualType(FnType, 0)),
     };
     if (CGM.getCodeGenOpts().SanitizeCfiCrossDso && CrossDsoTypeId) {
-      EmitCfiSlowPathCheck(SanitizerKind::CFIICall, TypeTest, CrossDsoTypeId,
+      EmitCfiSlowPathCheck(SanitizerKind::SO_CFIICall, TypeTest, CrossDsoTypeId,
                            CalleePtr, StaticData);
     } else {
-      EmitCheck(std::make_pair(TypeTest, SanitizerKind::CFIICall),
+      EmitCheck(std::make_pair(TypeTest, SanitizerKind::SO_CFIICall),
                 SanitizerHandler::CFICheckFail, StaticData,
                 {CalleePtr, llvm::UndefValue::get(IntPtrTy)});
     }
diff --git a/clang/lib/CodeGen/CGExprScalar.cpp b/clang/lib/CodeGen/CGExprScalar.cpp
index 090c4fb3ea39ee..ac499e490ee87e 100644
--- a/clang/lib/CodeGen/CGExprScalar.cpp
+++ b/clang/lib/CodeGen/CGExprScalar.cpp
@@ -280,8 +280,9 @@ class ScalarExprEmitter
     return CGF.EmitCheckedLValue(E, TCK);
   }
 
-  void EmitBinOpCheck(ArrayRef<std::pair<Value *, SanitizerMask>> Checks,
-                      const BinOpInfo &Info);
+  void EmitBinOpCheck(
+      ArrayRef<std::pair<Value *, SanitizerKind::SanitizerOrdinal>> Checks,
+      const BinOpInfo &Info);
 
   Value *EmitLoadOfLValue(LValue LV, SourceLocation Loc) {
     return CGF.EmitLoadOfLValue(LV, Loc).getScalarVal();
@@ -1047,14 +1048,14 @@ void ScalarExprEmitter::EmitFloatConversionCheck(
   llvm::Constant *StaticArgs[] = {CGF.EmitCheckSourceLocation(Loc),
                                   CGF.EmitCheckTypeDescriptor(OrigSrcType),
                                   CGF.EmitCheckTypeDescriptor(DstType)};
-  CGF.EmitCheck(std::make_pair(Check, SanitizerKind::FloatCastOverflow),
+  CGF.EmitCheck(std::make_pair(Check, SanitizerKind::SO_FloatCastOverflow),
                 SanitizerHandler::FloatCastOverflow, StaticArgs, OrigSrc);
 }
 
 // Should be called within CodeGenFunction::SanitizerScope RAII scope.
 // Returns 'i1 false' when the truncation Src -> Dst was lossy.
 static std::pair<ScalarExprEmitter::ImplicitConversionCheckKind,
-                 std::pair<llvm::Value *, SanitizerMask>>
+                 std::pair<llvm::Value *, SanitizerKind::SanitizerOrdinal>>
 EmitIntegerTruncationCheckHelper(Value *Src, QualType SrcType, Value *Dst,
                                  QualType DstType, CGBuilderTy &Builder) {
   llvm::Type *SrcTy = Src->getType();
@@ -1073,13 +1074,13 @@ EmitIntegerTruncationCheckHelper(Value *Src, QualType SrcType, Value *Dst,
   // If both (src and dst) types are unsigned, then it's an unsigned truncation.
   // Else, it is a signed truncation.
   ScalarExprEmitter::ImplicitConversionCheckKind Kind;
-  SanitizerMask Mask;
+  SanitizerKind::SanitizerOrdinal Ordinal;
   if (!SrcSigned && !DstSigned) {
     Kind = ScalarExprEmitter::ICCK_UnsignedIntegerTruncation;
-    Mask = SanitizerKind::ImplicitUnsignedIntegerTruncation;
+    Ordinal = SanitizerKind::SO_ImplicitUnsignedIntegerTruncation;
   } else {
     Kind = ScalarExprEmitter::ICCK_SignedIntegerTruncation;
-    Mask = SanitizerKind::ImplicitSignedIntegerTruncation;
+    Ordinal = SanitizerKind::SO_ImplicitSignedIntegerTruncation;
   }
 
   llvm::Value *Check = nullptr;
@@ -1088,7 +1089,7 @@ EmitIntegerTruncationCheckHelper(Value *Src, QualType SrcType, Value *Dst,
   // 2. Equality-compare with the original source value
   Check = Builder.CreateICmpEQ(Check, Src, "truncheck");
   // If the comparison result is 'i1 false', then the truncation was lossy.
-  return std::make_pair(Kind, std::make_pair(Check, Mask));
+  return std::make_pair(Kind, std::make_pair(Check, Ordinal));
 }
 
 static bool PromotionIsPotentiallyEligibleForImplicitIntegerConversionCheck(
@@ -1128,7 +1129,7 @@ void ScalarExprEmitter::EmitIntegerTruncationCheck(Value *Src, QualType SrcType,
   CodeGenFunction::SanitizerScope SanScope(&CGF);
 
   std::pair<ScalarExprEmitter::ImplicitConversionCheckKind,
-            std::pair<llvm::Value *, SanitizerMask>>
+            std::pair<llvm::Value *, SanitizerKind::SanitizerOrdinal>>
       Check =
           EmitIntegerTruncationCheckHelper(Src, SrcType, Dst, DstType, Builder);
   // If the comparison result is 'i1 false', then the truncation was lossy.
@@ -1138,7 +1139,8 @@ void ScalarExprEmitter::EmitIntegerTruncationCheck(Value *Src, QualType SrcType,
     return;
 
   // Does some SSCL ignore this type?
-  if (CGF.getContext().isTypeIgnoredBySanitizer(Check.second.second, DstType))
+  if (CGF.getContext().isTypeIgnoredBySanitizer(
+          SanitizerMask::bitPosToMask(Check.second.second), DstType))
     return;
 
   llvm::Constant *StaticArgs[] = {
@@ -1169,7 +1171,7 @@ static llvm::Value *EmitIsNegativeTestHelper(Value *V, QualType VType,
 // Should be called within CodeGenFunction::SanitizerScope RAII scope.
 // Returns 'i1 false' when the conversion Src -> Dst changed the sign.
 static std::pair<ScalarExprEmitter::ImplicitConversionCheckKind,
-                 std::pair<llvm::Value *, SanitizerMask>>
+                 std::pair<llvm::Value *, SanitizerKind::SanitizerOrdinal>>
 EmitIntegerSignChangeCheckHelper(Value *Src, QualType SrcType, Value *Dst,
                                  QualType DstType, CGBuilderTy &Builder) {
   llvm::Type *SrcTy = Src->getType();
@@ -1205,13 +1207,13 @@ EmitIntegerSignChangeCheckHelper(Value *Src, QualType SrcType, Value *Dst,
   // If the comparison result is 'false', then the conversion changed the sign.
   return std::make_pair(
       ScalarExprEmitter::ICCK_IntegerSignChange,
-      std::make_pair(Check, SanitizerKind::ImplicitInteg...
[truncated]

@vitalybuka
Copy link
Collaborator

I assume this is NFC[I], please make it in title
and there was not a single instance where >1 bits are set?

@vitalybuka
Copy link
Collaborator

I would describe reason as "to avoid ambiguity in EmitCheck about what to do if multiple bits are set".
But your description is OK as well.

@thurstond
Copy link
Contributor Author

I assume this is NFC[I], please make it in title and there was not a single instance where >1 bits are set?

I'll add NFCI to the title.

There's never any case where there's more than one (or zero) sanitizers in the SanitizerMask. The EmitCheck function would already fail a runtime assertion if there wasn't exactly one sanitizer set:

void CodeGenFunction::EmitCheck(
  ...
  for (int i = 0, n = Checked.size(); i < n; ++i) {
    ...
    llvm::Value *&Cond =
        CGM.getCodeGenOpts().SanitizeTrap.has(Checked[i].second)
  bool has(SanitizerMask K) const {
    assert(K.isPowerOf2() && "Has to be a single sanitizer.");
    return static_cast<bool>(Mask & K);
  }

@thurstond thurstond changed the title [ubsan] Use SanitizerOrdinal instead of SanitizerMask for EmitCheck (exactly one sanitizer is required) [ubsan][NFCI] Use SanitizerOrdinal instead of SanitizerMask for EmitCheck (exactly one sanitizer is required) Jan 10, 2025
@thurstond
Copy link
Contributor Author

I would describe reason as "to avoid ambiguity in EmitCheck about what to do if multiple bits are set". But your description is OK as well.

I've added that to the description.

@thurstond thurstond merged commit 55b5875 into llvm:main Jan 10, 2025
10 of 11 checks passed
BaiXilin pushed a commit to BaiXilin/llvm-fix-vnni-instr-types that referenced this pull request Jan 12, 2025
…heck (exactly one sanitizer is required) (llvm#122511)

The `Checked` parameter of `CodeGenFunction::EmitCheck` is of type
`ArrayRef<std::pair<llvm::Value *, SanitizerMask>>`, which is overly
generalized: SanitizerMask can denote that zero or more sanitizers are
enabled, but `EmitCheck` requires that exactly one sanitizer is
specified in the SanitizerMask (e.g.,
`SanitizeTrap.has(Checked[i].second)` enforces that).

This patch replaces SanitizerMask with SanitizerOrdinal in the `Checked`
parameter of `EmitCheck` and code that transitively relies on it. This
should not affect the behavior of UBSan, but it has the advantages that:
- the code is clearer: it avoids ambiguity in EmitCheck about what to do
if multiple bits are set
- specifying the wrong number of sanitizers in `Checked[i].second` will
be detected as a compile-time error, rather than a runtime assertion
failure

Suggested by Vitaly in llvm#122392
as an alternative to adding an explicit runtime assertion that the
SanitizerMask contains exactly one sanitizer.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
clang:codegen IR generation bugs: mangling, exceptions, etc. clang:frontend Language frontend issues, e.g. anything involving "Sema" clang Clang issues not falling into any other category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants