@@ -38,6 +38,8 @@ using namespace clang;
38
38
using namespace ento ;
39
39
using namespace taint ;
40
40
41
+ using llvm::ImmutableSet;
42
+
41
43
namespace {
42
44
43
45
class GenericTaintChecker ;
@@ -434,7 +436,9 @@ template <> struct ScalarEnumerationTraits<TaintConfiguration::VariadicType> {
434
436
// / to the call post-visit. The values are signed integers, which are either
435
437
// / ReturnValueIndex, or indexes of the pointer/reference argument, which
436
438
// / points to data, which should be tainted on return.
437
- REGISTER_SET_WITH_PROGRAMSTATE (TaintArgsOnPostVisit, ArgIdxTy)
439
+ REGISTER_MAP_WITH_PROGRAMSTATE (TaintArgsOnPostVisit, const LocationContext *,
440
+ ImmutableSet<ArgIdxTy>)
441
+ REGISTER_SET_FACTORY_WITH_PROGRAMSTATE(ArgIdxFactory, ArgIdxTy)
438
442
439
443
void GenericTaintRuleParser::validateArgVector(const std::string &Option,
440
444
const ArgVecTy &Args) const {
@@ -685,22 +689,26 @@ void GenericTaintChecker::checkPostCall(const CallEvent &Call,
685
689
// Set the marked values as tainted. The return value only accessible from
686
690
// checkPostStmt.
687
691
ProgramStateRef State = C.getState ();
692
+ const StackFrameContext *CurrentFrame = C.getStackFrame ();
688
693
689
694
// Depending on what was tainted at pre-visit, we determined a set of
690
695
// arguments which should be tainted after the function returns. These are
691
696
// stored in the state as TaintArgsOnPostVisit set.
692
- TaintArgsOnPostVisitTy TaintArgs = State->get <TaintArgsOnPostVisit>();
693
- if (TaintArgs.isEmpty ())
697
+ TaintArgsOnPostVisitTy TaintArgsMap = State->get <TaintArgsOnPostVisit>();
698
+
699
+ const ImmutableSet<ArgIdxTy> *TaintArgs = TaintArgsMap.lookup (CurrentFrame);
700
+ if (!TaintArgs)
694
701
return ;
702
+ assert (!TaintArgs->isEmpty ());
695
703
696
704
LLVM_DEBUG (for (ArgIdxTy I
697
- : TaintArgs) {
705
+ : * TaintArgs) {
698
706
llvm::dbgs () << " PostCall<" ;
699
707
Call.dump (llvm::dbgs ());
700
708
llvm::dbgs () << " > actually wants to taint arg index: " << I << ' \n ' ;
701
709
});
702
710
703
- for (ArgIdxTy ArgNum : TaintArgs) {
711
+ for (ArgIdxTy ArgNum : * TaintArgs) {
704
712
// Special handling for the tainted return value.
705
713
if (ArgNum == ReturnValueIndex) {
706
714
State = addTaint (State, Call.getReturnValue ());
@@ -714,7 +722,7 @@ void GenericTaintChecker::checkPostCall(const CallEvent &Call,
714
722
}
715
723
716
724
// Clear up the taint info from the state.
717
- State = State->remove <TaintArgsOnPostVisit>();
725
+ State = State->remove <TaintArgsOnPostVisit>(CurrentFrame );
718
726
C.addTransition (State);
719
727
}
720
728
@@ -776,28 +784,33 @@ void GenericTaintRule::process(const GenericTaintChecker &Checker,
776
784
};
777
785
778
786
// / Propagate taint where it is necessary.
787
+ auto &F = State->getStateManager ().get_context <ArgIdxFactory>();
788
+ ImmutableSet<ArgIdxTy> Result = F.getEmptySet ();
779
789
ForEachCallArg (
780
- [this , &State, WouldEscape, &Call](ArgIdxTy I, const Expr *E, SVal V) {
790
+ [this , WouldEscape, &Call, &Result, &F](ArgIdxTy I, const Expr *E,
791
+ SVal V) {
781
792
if (PropDstArgs.contains (I)) {
782
793
LLVM_DEBUG (llvm::dbgs () << " PreCall<" ; Call.dump (llvm::dbgs ());
783
794
llvm::dbgs ()
784
795
<< " > prepares tainting arg index: " << I << ' \n ' ;);
785
- State = State-> add <TaintArgsOnPostVisit>( I);
796
+ Result = F. add (Result, I);
786
797
}
787
798
788
799
// TODO: We should traverse all reachable memory regions via the
789
800
// escaping parameter. Instead of doing that we simply mark only the
790
801
// referred memory region as tainted.
791
802
if (WouldEscape (V, E->getType ())) {
792
- LLVM_DEBUG (if (!State-> contains <TaintArgsOnPostVisit> (I)) {
803
+ LLVM_DEBUG (if (!Result. contains (I)) {
793
804
llvm::dbgs () << " PreCall<" ;
794
805
Call.dump (llvm::dbgs ());
795
806
llvm::dbgs () << " > prepares tainting arg index: " << I << ' \n ' ;
796
807
});
797
- State = State-> add <TaintArgsOnPostVisit>( I);
808
+ Result = F. add (Result, I);
798
809
}
799
810
});
800
811
812
+ if (!Result.isEmpty ())
813
+ State = State->set <TaintArgsOnPostVisit>(C.getStackFrame (), Result);
801
814
C.addTransition (State);
802
815
}
803
816
@@ -888,7 +901,11 @@ void GenericTaintChecker::taintUnsafeSocketProtocol(const CallEvent &Call,
888
901
if (SafeProtocol)
889
902
return ;
890
903
891
- C.addTransition (C.getState ()->add <TaintArgsOnPostVisit>(ReturnValueIndex));
904
+ ProgramStateRef State = C.getState ();
905
+ auto &F = State->getStateManager ().get_context <ArgIdxFactory>();
906
+ ImmutableSet<ArgIdxTy> Result = F.add (F.getEmptySet (), ReturnValueIndex);
907
+ State = State->set <TaintArgsOnPostVisit>(C.getStackFrame (), Result);
908
+ C.addTransition (State);
892
909
}
893
910
894
911
// / Checker registration
0 commit comments