Skip to content

Commit d1e4a2d

Browse files
authored
[flang] Fix spurious error with separate module procedures (#106768)
When the implementation of one SMP apparently references another in what might be a specification expression, semantics may need to resolve it as a forward reference, and to allow for the replacement of a SubprogramNameDetails place-holding symbol with the final SubprogramDetails symbol. Otherwise, as in the bug report below, confusing error messages may result. (The reference in question isn't really in the specification part of a subprogram, but due to the syntactic ambiguity between the array element assignment statement and a statement function definition, it appears to be so at the time that the reference is processed.) I needed to make DumpSymbols() available via SemanticsContext to analyze this bug, and left that new API in place to make things easier next time. Fixes #106705.
1 parent 840da2e commit d1e4a2d

File tree

5 files changed

+53
-18
lines changed

5 files changed

+53
-18
lines changed

flang/include/flang/Semantics/expression.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -354,7 +354,7 @@ class ExpressionAnalyzer {
354354
parser::CharBlock, const ProcedureDesignator &, ActualArguments &);
355355
using AdjustActuals =
356356
std::optional<std::function<bool(const Symbol &, ActualArguments &)>>;
357-
bool ResolveForward(const Symbol &);
357+
const Symbol *ResolveForward(const Symbol &);
358358
std::pair<const Symbol *, bool /* failure due ambiguity */> ResolveGeneric(
359359
const Symbol &, const ActualArguments &, const AdjustActuals &,
360360
bool isSubroutine, bool mightBeStructureConstructor = false);

flang/include/flang/Semantics/semantics.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,8 @@ class SemanticsContext {
257257
void NoteDefinedSymbol(const Symbol &);
258258
bool IsSymbolDefined(const Symbol &) const;
259259

260+
void DumpSymbols(llvm::raw_ostream &);
261+
260262
private:
261263
struct ScopeIndexComparator {
262264
bool operator()(parser::CharBlock, parser::CharBlock) const;

flang/lib/Semantics/expression.cpp

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2650,9 +2650,9 @@ static int ComputeCudaMatchingDistance(
26502650
// Handles a forward reference to a module function from what must
26512651
// be a specification expression. Return false if the symbol is
26522652
// an invalid forward reference.
2653-
bool ExpressionAnalyzer::ResolveForward(const Symbol &symbol) {
2653+
const Symbol *ExpressionAnalyzer::ResolveForward(const Symbol &symbol) {
26542654
if (context_.HasError(symbol)) {
2655-
return false;
2655+
return nullptr;
26562656
}
26572657
if (const auto *details{
26582658
symbol.detailsIf<semantics::SubprogramNameDetails>()}) {
@@ -2661,28 +2661,33 @@ bool ExpressionAnalyzer::ResolveForward(const Symbol &symbol) {
26612661
// checking a specification expression in a sibling module
26622662
// procedure. Resolve its names now so that its interface
26632663
// is known.
2664+
const semantics::Scope &scope{symbol.owner()};
26642665
semantics::ResolveSpecificationParts(context_, symbol);
2665-
if (symbol.has<semantics::SubprogramNameDetails>()) {
2666+
const Symbol *resolved{nullptr};
2667+
if (auto iter{scope.find(symbol.name())}; iter != scope.cend()) {
2668+
resolved = &*iter->second;
2669+
}
2670+
if (!resolved || resolved->has<semantics::SubprogramNameDetails>()) {
26662671
// When the symbol hasn't had its details updated, we must have
26672672
// already been in the process of resolving the function's
26682673
// specification part; but recursive function calls are not
26692674
// allowed in specification parts (10.1.11 para 5).
26702675
Say("The module function '%s' may not be referenced recursively in a specification expression"_err_en_US,
26712676
symbol.name());
26722677
context_.SetError(symbol);
2673-
return false;
26742678
}
2679+
return resolved;
26752680
} else if (inStmtFunctionDefinition_) {
26762681
semantics::ResolveSpecificationParts(context_, symbol);
26772682
CHECK(symbol.has<semantics::SubprogramDetails>());
26782683
} else { // 10.1.11 para 4
26792684
Say("The internal function '%s' may not be referenced in a specification expression"_err_en_US,
26802685
symbol.name());
26812686
context_.SetError(symbol);
2682-
return false;
2687+
return nullptr;
26832688
}
26842689
}
2685-
return true;
2690+
return &symbol;
26862691
}
26872692

26882693
// Resolve a call to a generic procedure with given actual arguments.
@@ -2709,20 +2714,21 @@ std::pair<const Symbol *, bool> ExpressionAnalyzer::ResolveGeneric(
27092714
}
27102715
if (const auto *details{ultimate.detailsIf<semantics::GenericDetails>()}) {
27112716
for (const Symbol &specific0 : details->specificProcs()) {
2712-
const Symbol &specific{BypassGeneric(specific0)};
2713-
if (isSubroutine != !IsFunction(specific)) {
2717+
const Symbol &specific1{BypassGeneric(specific0)};
2718+
if (isSubroutine != !IsFunction(specific1)) {
27142719
continue;
27152720
}
2716-
if (!ResolveForward(specific)) {
2721+
const Symbol *specific{ResolveForward(specific1)};
2722+
if (!specific) {
27172723
continue;
27182724
}
27192725
if (std::optional<characteristics::Procedure> procedure{
27202726
characteristics::Procedure::Characterize(
2721-
ProcedureDesignator{specific}, context_.foldingContext(),
2727+
ProcedureDesignator{*specific}, context_.foldingContext(),
27222728
/*emitError=*/false)}) {
27232729
ActualArguments localActuals{actuals};
2724-
if (specific.has<semantics::ProcBindingDetails>()) {
2725-
if (!adjustActuals.value()(specific, localActuals)) {
2730+
if (specific->has<semantics::ProcBindingDetails>()) {
2731+
if (!adjustActuals.value()(*specific, localActuals)) {
27262732
continue;
27272733
}
27282734
}
@@ -2751,9 +2757,9 @@ std::pair<const Symbol *, bool> ExpressionAnalyzer::ResolveGeneric(
27512757
}
27522758
if (!procedure->IsElemental()) {
27532759
// takes priority over elemental match
2754-
nonElemental = &specific;
2760+
nonElemental = specific;
27552761
} else {
2756-
elemental = &specific;
2762+
elemental = specific;
27572763
}
27582764
crtMatchingDistance = ComputeCudaMatchingDistance(
27592765
context_.languageFeatures(), *procedure, localActuals);
@@ -2866,7 +2872,12 @@ auto ExpressionAnalyzer::GetCalleeAndArguments(const parser::Name &name,
28662872
if (context_.HasError(symbol)) {
28672873
return std::nullopt; // also handles null symbol
28682874
}
2869-
const Symbol &ultimate{DEREF(symbol).GetUltimate()};
2875+
symbol = ResolveForward(*symbol);
2876+
if (!symbol) {
2877+
return std::nullopt;
2878+
}
2879+
name.symbol = const_cast<Symbol *>(symbol);
2880+
const Symbol &ultimate{symbol->GetUltimate()};
28702881
CheckForBadRecursion(name.source, ultimate);
28712882
bool dueToAmbiguity{false};
28722883
bool isGenericInterface{ultimate.has<semantics::GenericDetails>()};

flang/lib/Semantics/semantics.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -655,10 +655,12 @@ void Semantics::EmitMessages(llvm::raw_ostream &os) {
655655
context_.messages().Emit(os, context_.allCookedSources());
656656
}
657657

658-
void Semantics::DumpSymbols(llvm::raw_ostream &os) {
659-
DoDumpSymbols(os, context_.globalScope());
658+
void SemanticsContext::DumpSymbols(llvm::raw_ostream &os) {
659+
DoDumpSymbols(os, globalScope());
660660
}
661661

662+
void Semantics::DumpSymbols(llvm::raw_ostream &os) { context_.DumpSymbols(os); }
663+
662664
void Semantics::DumpSymbolsSources(llvm::raw_ostream &os) const {
663665
NameToSymbolMap symbols;
664666
GetSymbolNames(context_.globalScope(), symbols);

flang/test/Semantics/smp-proc-ref.f90

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
!RUN: %flang_fc1 -fsyntax-only %s
2+
module m
3+
real :: qux(10)
4+
interface
5+
module subroutine bar(i)
6+
end
7+
module function baz()
8+
end
9+
end interface
10+
end
11+
12+
submodule(m) sm
13+
contains
14+
module procedure bar
15+
qux(i) = baz() ! ensure no bogus error here
16+
end
17+
module procedure baz
18+
baz = 1.
19+
end
20+
end

0 commit comments

Comments
 (0)