Skip to content

Commit 0d8331c

Browse files
committed
[flang] Refine symbol sorting
Replace semantics::SymbolSet with alternatives that clarify whether the set should order its contents by source position or not. This matters because positionally-ordered sets must not be used for Symbols that might be subjected to name replacement during name resolution, and address-ordered sets must not be used (without sorting) in circumstances where the order of their contents affects the output of the compiler. All set<> and map<> instances in the compiler that are keyed by Symbols now have explicit Compare types in their template instantiations. Symbol::operator< is no more. Differential Revision: https://reviews.llvm.org/D98878
1 parent 4c782a2 commit 0d8331c

File tree

17 files changed

+124
-85
lines changed

17 files changed

+124
-85
lines changed

flang/include/flang/Evaluate/constant.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -195,8 +195,11 @@ class Constant<Type<TypeCategory::Character, KIND>> : public ConstantBounds {
195195
};
196196

197197
class StructureConstructor;
198-
using StructureConstructorValues =
199-
std::map<SymbolRef, common::CopyableIndirection<Expr<SomeType>>>;
198+
struct ComponentCompare {
199+
bool operator()(SymbolRef x, SymbolRef y) const;
200+
};
201+
using StructureConstructorValues = std::map<SymbolRef,
202+
common::CopyableIndirection<Expr<SomeType>>, ComponentCompare>;
200203

201204
template <>
202205
class Constant<SomeDerived>

flang/include/flang/Evaluate/tools.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -839,10 +839,12 @@ template <typename A> SymbolVector GetSymbolVector(const A &x) {
839839
const Symbol *GetLastTarget(const SymbolVector &);
840840

841841
// Collects all of the Symbols in an expression
842-
template <typename A> semantics::SymbolSet CollectSymbols(const A &);
843-
extern template semantics::SymbolSet CollectSymbols(const Expr<SomeType> &);
844-
extern template semantics::SymbolSet CollectSymbols(const Expr<SomeInteger> &);
845-
extern template semantics::SymbolSet CollectSymbols(
842+
template <typename A> semantics::UnorderedSymbolSet CollectSymbols(const A &);
843+
extern template semantics::UnorderedSymbolSet CollectSymbols(
844+
const Expr<SomeType> &);
845+
extern template semantics::UnorderedSymbolSet CollectSymbols(
846+
const Expr<SomeInteger> &);
847+
extern template semantics::UnorderedSymbolSet CollectSymbols(
846848
const Expr<SubscriptInteger> &);
847849

848850
// Predicate: does a variable contain a vector-valued subscript (not a triplet)?

flang/include/flang/Semantics/semantics.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -198,8 +198,9 @@ class SemanticsContext {
198198
parser::CharBlock location;
199199
IndexVarKind kind;
200200
};
201-
std::map<SymbolRef, const IndexVarInfo> activeIndexVars_;
202-
SymbolSet errorSymbols_;
201+
std::map<SymbolRef, const IndexVarInfo, SymbolAddressCompare>
202+
activeIndexVars_;
203+
UnorderedSymbolSet errorSymbols_;
203204
std::set<std::string> tempNames_;
204205
};
205206

flang/include/flang/Semantics/symbol.h

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -596,13 +596,6 @@ class Symbol {
596596
bool operator==(const Symbol &that) const { return this == &that; }
597597
bool operator!=(const Symbol &that) const { return !(*this == that); }
598598

599-
// Symbol comparison is based on the order of cooked source
600-
// stream creation and, when both are from the same cooked source,
601-
// their positions in that cooked source stream.
602-
// (This function is implemented in Evaluate/tools.cpp to
603-
// satisfy complicated shared library interdependency.)
604-
bool operator<(const Symbol &) const;
605-
606599
int Rank() const {
607600
return std::visit(
608601
common::visitors{
@@ -767,13 +760,40 @@ inline const DeclTypeSpec *Symbol::GetType() const {
767760
details_);
768761
}
769762

770-
inline bool operator<(SymbolRef x, SymbolRef y) {
771-
return *x < *y; // name source position ordering
772-
}
773-
inline bool operator<(MutableSymbolRef x, MutableSymbolRef y) {
774-
return *x < *y; // name source position ordering
763+
// Sets and maps keyed by Symbols
764+
765+
struct SymbolAddressCompare {
766+
bool operator()(const SymbolRef &x, const SymbolRef &y) const {
767+
return &*x < &*y;
768+
}
769+
bool operator()(const MutableSymbolRef &x, const MutableSymbolRef &y) const {
770+
return &*x < &*y;
771+
}
772+
};
773+
774+
// Symbol comparison is based on the order of cooked source
775+
// stream creation and, when both are from the same cooked source,
776+
// their positions in that cooked source stream.
777+
// Don't use this comparator or OrderedSymbolSet to hold
778+
// Symbols that might be subject to ReplaceName().
779+
struct SymbolSourcePositionCompare {
780+
// These functions are implemented in Evaluate/tools.cpp to
781+
// satisfy complicated shared library interdependency.
782+
bool operator()(const SymbolRef &, const SymbolRef &) const;
783+
bool operator()(const MutableSymbolRef &, const MutableSymbolRef &) const;
784+
};
785+
786+
using UnorderedSymbolSet = std::set<SymbolRef, SymbolAddressCompare>;
787+
using OrderedSymbolSet = std::set<SymbolRef, SymbolSourcePositionCompare>;
788+
789+
template <typename A>
790+
OrderedSymbolSet OrderBySourcePosition(const A &container) {
791+
OrderedSymbolSet result;
792+
for (SymbolRef x : container) {
793+
result.emplace(x);
794+
}
795+
return result;
775796
}
776-
using SymbolSet = std::set<SymbolRef>;
777797

778798
} // namespace Fortran::semantics
779799

flang/lib/Evaluate/characteristics.cpp

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -343,30 +343,29 @@ bool DummyProcedure::operator==(const DummyProcedure &that) const {
343343
procedure.value() == that.procedure.value();
344344
}
345345

346-
static std::string GetSeenProcs(const semantics::SymbolSet &seenProcs) {
346+
static std::string GetSeenProcs(
347+
const semantics::UnorderedSymbolSet &seenProcs) {
347348
// Sort the symbols so that they appear in the same order on all platforms
348-
std::vector<SymbolRef> sorter{seenProcs.begin(), seenProcs.end()};
349-
std::sort(sorter.begin(), sorter.end());
350-
349+
auto ordered{semantics::OrderBySourcePosition(seenProcs)};
351350
std::string result;
352351
llvm::interleave(
353-
sorter,
352+
ordered,
354353
[&](const SymbolRef p) { result += '\'' + p->name().ToString() + '\''; },
355354
[&]() { result += ", "; });
356355
return result;
357356
}
358357

359-
// These functions with arguments of type SymbolSet are used with mutually
360-
// recursive calls when characterizing a Procedure, a DummyArgument, or a
361-
// DummyProcedure to detect circularly defined procedures as required by
358+
// These functions with arguments of type UnorderedSymbolSet are used with
359+
// mutually recursive calls when characterizing a Procedure, a DummyArgument,
360+
// or a DummyProcedure to detect circularly defined procedures as required by
362361
// 15.4.3.6, paragraph 2.
363362
static std::optional<DummyArgument> CharacterizeDummyArgument(
364363
const semantics::Symbol &symbol, FoldingContext &context,
365-
semantics::SymbolSet &seenProcs);
364+
semantics::UnorderedSymbolSet &seenProcs);
366365

367366
static std::optional<Procedure> CharacterizeProcedure(
368367
const semantics::Symbol &original, FoldingContext &context,
369-
semantics::SymbolSet &seenProcs) {
368+
semantics::UnorderedSymbolSet &seenProcs) {
370369
Procedure result;
371370
const auto &symbol{original.GetUltimate()};
372371
if (seenProcs.find(symbol) != seenProcs.end()) {
@@ -475,7 +474,7 @@ static std::optional<Procedure> CharacterizeProcedure(
475474

476475
static std::optional<DummyProcedure> CharacterizeDummyProcedure(
477476
const semantics::Symbol &symbol, FoldingContext &context,
478-
semantics::SymbolSet &seenProcs) {
477+
semantics::UnorderedSymbolSet &seenProcs) {
479478
if (auto procedure{CharacterizeProcedure(symbol, context, seenProcs)}) {
480479
// Dummy procedures may not be elemental. Elemental dummy procedure
481480
// interfaces are errors when the interface is not intrinsic, and that
@@ -516,7 +515,7 @@ bool DummyArgument::operator==(const DummyArgument &that) const {
516515

517516
static std::optional<DummyArgument> CharacterizeDummyArgument(
518517
const semantics::Symbol &symbol, FoldingContext &context,
519-
semantics::SymbolSet &seenProcs) {
518+
semantics::UnorderedSymbolSet &seenProcs) {
520519
auto name{symbol.name().ToString()};
521520
if (symbol.has<semantics::ObjectEntityDetails>()) {
522521
if (auto obj{DummyDataObject::Characterize(symbol, context)}) {
@@ -779,7 +778,7 @@ bool Procedure::CanOverride(
779778

780779
std::optional<Procedure> Procedure::Characterize(
781780
const semantics::Symbol &original, FoldingContext &context) {
782-
semantics::SymbolSet seenProcs;
781+
semantics::UnorderedSymbolSet seenProcs;
783782
return CharacterizeProcedure(original, context, seenProcs);
784783
}
785784

flang/lib/Evaluate/constant.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,5 +315,9 @@ std::size_t Constant<SomeDerived>::CopyFrom(const Constant<SomeDerived> &source,
315315
return Base::CopyFrom(source, count, resultSubscripts, dimOrder);
316316
}
317317

318+
bool ComponentCompare::operator()(SymbolRef x, SymbolRef y) const {
319+
return semantics::SymbolSourcePositionCompare{}(x, y);
320+
}
321+
318322
INSTANTIATE_CONSTANT_TEMPLATES
319323
} // namespace Fortran::evaluate

flang/lib/Evaluate/tools.cpp

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -782,20 +782,22 @@ const Symbol *GetLastTarget(const SymbolVector &symbols) {
782782
}
783783

784784
struct CollectSymbolsHelper
785-
: public SetTraverse<CollectSymbolsHelper, semantics::SymbolSet> {
786-
using Base = SetTraverse<CollectSymbolsHelper, semantics::SymbolSet>;
785+
: public SetTraverse<CollectSymbolsHelper, semantics::UnorderedSymbolSet> {
786+
using Base = SetTraverse<CollectSymbolsHelper, semantics::UnorderedSymbolSet>;
787787
CollectSymbolsHelper() : Base{*this} {}
788788
using Base::operator();
789-
semantics::SymbolSet operator()(const Symbol &symbol) const {
789+
semantics::UnorderedSymbolSet operator()(const Symbol &symbol) const {
790790
return {symbol};
791791
}
792792
};
793-
template <typename A> semantics::SymbolSet CollectSymbols(const A &x) {
793+
template <typename A> semantics::UnorderedSymbolSet CollectSymbols(const A &x) {
794794
return CollectSymbolsHelper{}(x);
795795
}
796-
template semantics::SymbolSet CollectSymbols(const Expr<SomeType> &);
797-
template semantics::SymbolSet CollectSymbols(const Expr<SomeInteger> &);
798-
template semantics::SymbolSet CollectSymbols(const Expr<SubscriptInteger> &);
796+
template semantics::UnorderedSymbolSet CollectSymbols(const Expr<SomeType> &);
797+
template semantics::UnorderedSymbolSet CollectSymbols(
798+
const Expr<SomeInteger> &);
799+
template semantics::UnorderedSymbolSet CollectSymbols(
800+
const Expr<SubscriptInteger> &);
799801

800802
// HasVectorSubscript()
801803
struct HasVectorSubscriptHelper : public AnyTraverse<HasVectorSubscriptHelper> {
@@ -1177,7 +1179,7 @@ const Symbol &GetUsedModule(const UseDetails &details) {
11771179
}
11781180

11791181
static const Symbol *FindFunctionResult(
1180-
const Symbol &original, SymbolSet &seen) {
1182+
const Symbol &original, UnorderedSymbolSet &seen) {
11811183
const Symbol &root{GetAssociationRoot(original)};
11821184
;
11831185
if (!seen.insert(root).second) {
@@ -1199,16 +1201,23 @@ static const Symbol *FindFunctionResult(
11991201
}
12001202

12011203
const Symbol *FindFunctionResult(const Symbol &symbol) {
1202-
SymbolSet seen;
1204+
UnorderedSymbolSet seen;
12031205
return FindFunctionResult(symbol, seen);
12041206
}
12051207

12061208
// These are here in Evaluate/tools.cpp so that Evaluate can use
12071209
// them; they cannot be defined in symbol.h due to the dependence
12081210
// on Scope.
12091211

1210-
bool Symbol::operator<(const Symbol &that) const {
1211-
return GetSemanticsContext().allCookedSources().Precedes(name_, that.name_);
1212+
bool SymbolSourcePositionCompare::operator()(
1213+
const SymbolRef &x, const SymbolRef &y) const {
1214+
return x->GetSemanticsContext().allCookedSources().Precedes(
1215+
x->name(), y->name());
1216+
}
1217+
bool SymbolSourcePositionCompare::operator()(
1218+
const MutableSymbolRef &x, const MutableSymbolRef &y) const {
1219+
return x->GetSemanticsContext().allCookedSources().Precedes(
1220+
x->name(), y->name());
12121221
}
12131222

12141223
SemanticsContext &Symbol::GetSemanticsContext() const {

flang/lib/Parser/provenance.cpp

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -602,16 +602,15 @@ void AllCookedSources::Dump(llvm::raw_ostream &o) const {
602602
}
603603

604604
bool AllCookedSources::Precedes(CharBlock x, CharBlock y) const {
605-
const CookedSource *ySource{Find(y)};
606605
if (const CookedSource * xSource{Find(x)}) {
607-
if (ySource) {
608-
int xNum{xSource->number()};
609-
int yNum{ySource->number()};
610-
return xNum < yNum || (xNum == yNum && x.begin() < y.begin());
606+
if (xSource->AsCharBlock().Contains(y)) {
607+
return x.begin() < y.begin();
608+
} else if (const CookedSource * ySource{Find(y)}) {
609+
return xSource->number() < ySource->number();
611610
} else {
612611
return true; // by fiat, all cooked source < anything outside
613612
}
614-
} else if (ySource) {
613+
} else if (Find(y)) {
615614
return false;
616615
} else {
617616
// Both names are compiler-created (SaveTempName).

flang/lib/Semantics/check-declarations.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,8 @@ class CheckHelper {
110110
// that has a symbol.
111111
const Symbol *innermostSymbol_{nullptr};
112112
// Cache of calls to Procedure::Characterize(Symbol)
113-
std::map<SymbolRef, std::optional<Procedure>> characterizeCache_;
113+
std::map<SymbolRef, std::optional<Procedure>, SymbolAddressCompare>
114+
characterizeCache_;
114115
};
115116

116117
class DistinguishabilityHelper {

flang/lib/Semantics/check-do-forall.cpp

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -548,9 +548,9 @@ class DoContext {
548548
// the names up in the scope that encloses the DO construct to avoid getting
549549
// the local versions of them. Then follow the host-, use-, and
550550
// construct-associations to get the root symbols
551-
SymbolSet GatherLocals(
551+
UnorderedSymbolSet GatherLocals(
552552
const std::list<parser::LocalitySpec> &localitySpecs) const {
553-
SymbolSet symbols;
553+
UnorderedSymbolSet symbols;
554554
const Scope &parentScope{
555555
context_.FindScope(currentStatementSourcePosition_).parent()};
556556
// Loop through the LocalitySpec::Local locality-specs
@@ -568,8 +568,9 @@ class DoContext {
568568
return symbols;
569569
}
570570

571-
static SymbolSet GatherSymbolsFromExpression(const parser::Expr &expression) {
572-
SymbolSet result;
571+
static UnorderedSymbolSet GatherSymbolsFromExpression(
572+
const parser::Expr &expression) {
573+
UnorderedSymbolSet result;
573574
if (const auto *expr{GetExpr(expression)}) {
574575
for (const Symbol &symbol : evaluate::CollectSymbols(*expr)) {
575576
result.insert(ResolveAssociations(symbol));
@@ -580,8 +581,9 @@ class DoContext {
580581

581582
// C1121 - procedures in mask must be pure
582583
void CheckMaskIsPure(const parser::ScalarLogicalExpr &mask) const {
583-
SymbolSet references{GatherSymbolsFromExpression(mask.thing.thing.value())};
584-
for (const Symbol &ref : references) {
584+
UnorderedSymbolSet references{
585+
GatherSymbolsFromExpression(mask.thing.thing.value())};
586+
for (const Symbol &ref : OrderBySourcePosition(references)) {
585587
if (IsProcedure(ref) && !IsPureProcedure(ref)) {
586588
context_.SayWithDecl(ref, parser::Unwrap<parser::Expr>(mask)->source,
587589
"%s mask expression may not reference impure procedure '%s'"_err_en_US,
@@ -591,10 +593,10 @@ class DoContext {
591593
}
592594
}
593595

594-
void CheckNoCollisions(const SymbolSet &refs, const SymbolSet &uses,
595-
parser::MessageFixedText &&errorMessage,
596+
void CheckNoCollisions(const UnorderedSymbolSet &refs,
597+
const UnorderedSymbolSet &uses, parser::MessageFixedText &&errorMessage,
596598
const parser::CharBlock &refPosition) const {
597-
for (const Symbol &ref : refs) {
599+
for (const Symbol &ref : OrderBySourcePosition(refs)) {
598600
if (uses.find(ref) != uses.end()) {
599601
context_.SayWithDecl(ref, refPosition, std::move(errorMessage),
600602
LoopKindName(), ref.name());
@@ -603,17 +605,17 @@ class DoContext {
603605
}
604606
}
605607

606-
void HasNoReferences(
607-
const SymbolSet &indexNames, const parser::ScalarIntExpr &expr) const {
608+
void HasNoReferences(const UnorderedSymbolSet &indexNames,
609+
const parser::ScalarIntExpr &expr) const {
608610
CheckNoCollisions(GatherSymbolsFromExpression(expr.thing.thing.value()),
609611
indexNames,
610612
"%s limit expression may not reference index variable '%s'"_err_en_US,
611613
expr.thing.thing.value().source);
612614
}
613615

614616
// C1129, names in local locality-specs can't be in mask expressions
615-
void CheckMaskDoesNotReferenceLocal(
616-
const parser::ScalarLogicalExpr &mask, const SymbolSet &localVars) const {
617+
void CheckMaskDoesNotReferenceLocal(const parser::ScalarLogicalExpr &mask,
618+
const UnorderedSymbolSet &localVars) const {
617619
CheckNoCollisions(GatherSymbolsFromExpression(mask.thing.thing.value()),
618620
localVars,
619621
"%s mask expression references variable '%s'"
@@ -623,8 +625,8 @@ class DoContext {
623625

624626
// C1129, names in local locality-specs can't be in limit or step
625627
// expressions
626-
void CheckExprDoesNotReferenceLocal(
627-
const parser::ScalarIntExpr &expr, const SymbolSet &localVars) const {
628+
void CheckExprDoesNotReferenceLocal(const parser::ScalarIntExpr &expr,
629+
const UnorderedSymbolSet &localVars) const {
628630
CheckNoCollisions(GatherSymbolsFromExpression(expr.thing.thing.value()),
629631
localVars,
630632
"%s expression references variable '%s'"
@@ -663,7 +665,7 @@ class DoContext {
663665
CheckMaskIsPure(*mask);
664666
}
665667
auto &controls{std::get<std::list<parser::ConcurrentControl>>(header.t)};
666-
SymbolSet indexNames;
668+
UnorderedSymbolSet indexNames;
667669
for (const parser::ConcurrentControl &control : controls) {
668670
const auto &indexName{std::get<parser::Name>(control.t)};
669671
if (indexName.symbol) {
@@ -697,7 +699,7 @@ class DoContext {
697699
const auto &localitySpecs{
698700
std::get<std::list<parser::LocalitySpec>>(concurrent.t)};
699701
if (!localitySpecs.empty()) {
700-
const SymbolSet &localVars{GatherLocals(localitySpecs)};
702+
const UnorderedSymbolSet &localVars{GatherLocals(localitySpecs)};
701703
for (const auto &c : GetControls(control)) {
702704
CheckExprDoesNotReferenceLocal(std::get<1>(c.t), localVars);
703705
CheckExprDoesNotReferenceLocal(std::get<2>(c.t), localVars);
@@ -733,7 +735,7 @@ class DoContext {
733735
void CheckForallIndexesUsed(const evaluate::Assignment &assignment) {
734736
SymbolVector indexVars{context_.GetIndexVars(IndexVarKind::FORALL)};
735737
if (!indexVars.empty()) {
736-
SymbolSet symbols{evaluate::CollectSymbols(assignment.lhs)};
738+
UnorderedSymbolSet symbols{evaluate::CollectSymbols(assignment.lhs)};
737739
std::visit(
738740
common::visitors{
739741
[&](const evaluate::Assignment::BoundsSpec &spec) {

flang/lib/Semantics/check-omp-structure.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -630,7 +630,7 @@ void OmpStructureChecker::Leave(const parser::OmpClauseList &) {
630630
}
631631
}
632632
// A list-item cannot appear in more than one aligned clause
633-
semantics::SymbolSet alignedVars;
633+
semantics::UnorderedSymbolSet alignedVars;
634634
auto clauseAll = FindClauses(llvm::omp::Clause::OMPC_aligned);
635635
for (auto itr = clauseAll.first; itr != clauseAll.second; ++itr) {
636636
const auto &alignedClause{

0 commit comments

Comments
 (0)