Skip to content

Commit 84873fb

Browse files
committed
IDE: implement an IDE action to collect types of all expressions in a source file.
This is libIDE side implementation for collecting all type information in a source file. When several expression share the same source range, we always report the type of the outermost expression. rdar://35199889
1 parent 0e91b49 commit 84873fb

File tree

5 files changed

+194
-0
lines changed

5 files changed

+194
-0
lines changed

include/swift/Sema/IDETypeChecking.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,28 @@ namespace swift {
161161
bool shouldPrintRequirement(ExtensionDecl *ED, StringRef Req);
162162
bool hasMergeGroup(MergeGroupKind Kind);
163163
};
164+
165+
/// Reported type for an expression. This expression is represented by offset
166+
/// length in the source buffer;
167+
struct ExpressionTypeInfo {
168+
169+
/// The start of the expression;
170+
uint32_t offset;
171+
172+
/// The length of the expression;
173+
uint32_t length;
174+
175+
/// The start of the printed type in a separately given string buffer.
176+
uint32_t typeOffset;
177+
178+
/// The length of the printed type
179+
uint32_t typeLength;
180+
};
181+
182+
/// Collect type information for every expression in \c SF; all types will
183+
/// be printed to \c OS.
184+
ArrayRef<ExpressionTypeInfo> collectExpressionType(SourceFile &SF,
185+
std::vector<ExpressionTypeInfo> &scratch, llvm::raw_ostream &OS);
164186
}
165187

166188
#endif

lib/IDE/IDETypeChecking.cpp

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
#include "swift/AST/NameLookup.h"
2424
#include "swift/AST/ProtocolConformance.h"
2525
#include "swift/Sema/IDETypeChecking.h"
26+
#include "swift/IDE/SourceEntityWalker.h"
27+
#include "swift/Parse/Lexer.h"
2628

2729
using namespace swift;
2830

@@ -577,3 +579,82 @@ collectDefaultImplementationForProtocolMembers(ProtocolDecl *PD,
577579
for (auto *IP : PD->getInheritedProtocols())
578580
HandleMembers(IP->getMembers());
579581
}
582+
583+
/// This walker will traverse the AST and report types for every expression.
584+
class ExpressionTypeCollector: public SourceEntityWalker {
585+
SourceManager &SM;
586+
unsigned int BufferId;
587+
std::vector<ExpressionTypeInfo> &Results;
588+
589+
// This is to where we print all types.
590+
llvm::raw_ostream &OS;
591+
592+
// Map from a printed type to the offset in OS where the type starts.
593+
llvm::StringMap<uint32_t> TypeOffsets;
594+
595+
// This keeps track of whether we have a type reported for a given
596+
// [offset, length].
597+
llvm::DenseMap<unsigned, llvm::DenseSet<unsigned>> AllPrintedTypes;
598+
599+
bool shouldReport(unsigned Offset, unsigned Length, Expr *E) {
600+
// We shouldn't report null types.
601+
if (E->getType().isNull())
602+
return false;
603+
604+
// If we have already reported types for this source range, we shouldn't
605+
// report again. This makes sure we always report the outtermost type of
606+
// several overlapping expressions.
607+
auto &Bucket = AllPrintedTypes[Offset];
608+
return Bucket.find(Length) == Bucket.end();
609+
}
610+
611+
// Find an existing offset in the type buffer otherwise print the type to
612+
// the buffer.
613+
uint32_t getTypeOffsets(StringRef PrintedType) {
614+
auto It = TypeOffsets.find(PrintedType);
615+
if (It == TypeOffsets.end()) {
616+
TypeOffsets[PrintedType] = OS.tell();
617+
OS << PrintedType;
618+
}
619+
return TypeOffsets[PrintedType];
620+
}
621+
622+
public:
623+
ExpressionTypeCollector(SourceFile &SF, std::vector<ExpressionTypeInfo> &Results,
624+
llvm::raw_ostream &OS): SM(SF.getASTContext().SourceMgr),
625+
BufferId(*SF.getBufferID()),
626+
Results(Results), OS(OS) {}
627+
bool walkToExprPre(Expr *E) override {
628+
if (E->getSourceRange().isInvalid())
629+
return true;
630+
CharSourceRange Range =
631+
Lexer::getCharSourceRangeFromSourceRange(SM, E->getSourceRange());
632+
unsigned Offset = SM.getLocOffsetInBuffer(Range.getStart(), BufferId);
633+
unsigned Length = Range.getByteLength();
634+
if (!shouldReport(Offset, Length, E))
635+
return true;
636+
// Print the type to a temporary buffer.
637+
SmallString<64> Buffer;
638+
{
639+
llvm::raw_svector_ostream OS(Buffer);
640+
E->getType()->getRValueType()->reconstituteSugar(true)->print(OS);
641+
}
642+
643+
// Add the type information to the result list.
644+
Results.push_back({Offset, Length, getTypeOffsets(Buffer.str()),
645+
static_cast<uint32_t>(Buffer.size())});
646+
647+
// Keep track of that we have a type reported for this range.
648+
AllPrintedTypes[Offset].insert(Length);
649+
return true;
650+
}
651+
};
652+
653+
ArrayRef<ExpressionTypeInfo>
654+
swift::collectExpressionType(SourceFile &SF,
655+
std::vector<ExpressionTypeInfo> &Scratch,
656+
llvm::raw_ostream &OS) {
657+
ExpressionTypeCollector Walker(SF, Scratch, OS);
658+
Walker.walk(SF);
659+
return Scratch;
660+
}

test/IDE/Inputs/ExprType.swift

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
2+
func foo() -> Int { return 1 }
3+
4+
func bar(f: Float) -> Float { return bar(f: 1) }
5+
6+
protocol P {}
7+
8+
func fooP(_ p: P) { fooP(p) }
9+
10+
class C {}
11+
12+
func ArrayC(_ a: [C]) {
13+
_ = a.count
14+
_ = a.description.count.advanced(by: 1).description
15+
}
16+
17+
struct S {
18+
let val = 4
19+
}
20+
func DictS(_ a: [Int: S]) {
21+
_ = a[2]?.val.advanced(by: 1).byteSwapped
22+
}

test/IDE/expr_type.swift

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
// RUN: %target-swift-ide-test -print-expr-type -source-filename %S/Inputs/ExprType.swift -swift-version 5 | %FileCheck %s
2+
3+
// CHECK: func foo() -> Int { return <expr type:"Int">1</expr> }
4+
// CHECK: func bar(f: Float) -> Float { return <expr type:"Float"><expr type:"(Float) -> Float">bar</expr><expr type:"(f: Float)">(f: <expr type:"Float">1</expr>)</expr></expr> }
5+
// CHECK: func fooP(_ p: P) { <expr type:"()"><expr type:"(P) -> ()">fooP</expr><expr type:"(P)">(<expr type:"P">p</expr>)</expr></expr> }
6+
// CHECK: <expr type:"()"><expr type:"Int">_</expr> = <expr type:"Int"><expr type:"[C]">a</expr>.count</expr></expr>
7+
// CHECK: <expr type:"()"><expr type:"String">_</expr> = <expr type:"String"><expr type:"Int"><expr type:"(Int) -> Int"><expr type:"Int"><expr type:"String"><expr type:"[C]">a</expr>.description</expr>.count</expr>.<expr type:"(Int) -> (Int) -> Int">advanced</expr></expr><expr type:"(by: Int)">(by: <expr type:"Int">1</expr>)</expr></expr>.description</expr></expr>
8+
// CHECK: <expr type:"()"><expr type:"Int?">_</expr> = <expr type:"Int?"><expr type:"Int"><expr type:"(Int) -> Int"><expr type:"Int"><expr type:"S"><expr type:"S?"><expr type:"[Int : S]">a</expr><expr type:"(Int)">[<expr type:"Int">2</expr>]</expr></expr>?</expr>.val</expr>.<expr type:"(Int) -> (Int) -> Int">advanced</expr></expr><expr type:"(by: Int)">(by: <expr type:"Int">1</expr>)</expr></expr>.byteSwapped</expr></expr>

tools/swift-ide-test/swift-ide-test.cpp

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ enum class ActionType {
9595
PrintLocalTypes,
9696
PrintTypeInterface,
9797
PrintIndexedSymbols,
98+
PrintExpressionTypes,
9899
TestCreateCompilerInvocation,
99100
CompilerInvocationFromModule,
100101
GenerateModuleAPIDescription,
@@ -228,6 +229,9 @@ Action(llvm::cl::desc("Mode:"), llvm::cl::init(ActionType::None),
228229
clEnumValN(ActionType::TypeContextInfo,
229230
"type-context-info",
230231
"Perform expression context info analysis"),
232+
clEnumValN(ActionType::PrintExpressionTypes,
233+
"print-expr-type",
234+
"Print types for all expressions in the file"),
231235
clEnumValN(ActionType::ConformingMethodList,
232236
"conforming-methods",
233237
"Perform conforming method analysis for expression")));
@@ -1709,6 +1713,57 @@ static int doPrintAST(const CompilerInvocation &InitInvok,
17091713
return EXIT_SUCCESS;
17101714
}
17111715

1716+
static int doPrintExpressionTypes(const CompilerInvocation &InitInvok,
1717+
StringRef SourceFilename) {
1718+
CompilerInvocation Invocation(InitInvok);
1719+
Invocation.getFrontendOptions().InputsAndOutputs.addPrimaryInputFile(SourceFilename);
1720+
CompilerInstance CI;
1721+
1722+
// Display diagnostics to stderr.
1723+
PrintingDiagnosticConsumer PrintDiags;
1724+
CI.addDiagnosticConsumer(&PrintDiags);
1725+
if (CI.setup(Invocation))
1726+
return EXIT_FAILURE;
1727+
CI.performSema();
1728+
std::vector<ExpressionTypeInfo> Scratch;
1729+
1730+
// Buffer for where types will be printed.
1731+
llvm::SmallString<256> TypeBuffer;
1732+
llvm::raw_svector_ostream OS(TypeBuffer);
1733+
SourceFile &SF = *CI.getPrimarySourceFile();
1734+
auto Source = SF.getASTContext().SourceMgr.getRangeForBuffer(*SF.getBufferID()).str();
1735+
std::vector<std::pair<unsigned, std::string>> SortedTags;
1736+
1737+
// Collect all tags of expressions.
1738+
for (auto R: collectExpressionType(*CI.getPrimarySourceFile(), Scratch, OS)) {
1739+
SortedTags.push_back({R.offset,
1740+
(llvm::Twine("<expr type:\"") + TypeBuffer.str().substr(R.typeOffset,
1741+
R.typeLength) + "\">").str()});
1742+
SortedTags.push_back({R.offset + R.length, "</expr>"});
1743+
}
1744+
// Sort tags by offset.
1745+
std::stable_sort(SortedTags.begin(), SortedTags.end(),
1746+
[](std::pair<unsigned, std::string> T1, std::pair<unsigned, std::string> T2) {
1747+
return T1.first < T2.first;
1748+
});
1749+
1750+
ArrayRef<std::pair<unsigned, std::string>> SortedTagsRef = SortedTags;
1751+
unsigned Cur = 0;
1752+
do {
1753+
// Print tags that are due at current offset.
1754+
while(!SortedTagsRef.empty() && SortedTagsRef.front().first == Cur) {
1755+
llvm::outs() << SortedTagsRef.front().second;
1756+
SortedTagsRef = SortedTagsRef.drop_front();
1757+
}
1758+
auto Start = Cur;
1759+
// Change current offset to the start offset of next tag.
1760+
Cur = SortedTagsRef.empty() ? Source.size() : SortedTagsRef.front().first;
1761+
// Print the source before next tag.
1762+
llvm::outs() << Source.substr(Start, Cur - Start);
1763+
} while(!SortedTagsRef.empty());
1764+
return EXIT_SUCCESS;
1765+
}
1766+
17121767
static int doPrintLocalTypes(const CompilerInvocation &InitInvok,
17131768
const std::vector<std::string> ModulesToPrint) {
17141769
using NodeKind = Demangle::Node::Kind;
@@ -3353,6 +3408,12 @@ int main(int argc, char *argv[]) {
33533408
options::CodeCompletionDiagnostics);
33543409
break;
33553410

3411+
case ActionType::PrintExpressionTypes:
3412+
ExitCode = doPrintExpressionTypes(InitInvok,
3413+
options::SourceFilename);
3414+
break;
3415+
3416+
33563417
case ActionType::ConformingMethodList:
33573418
if (options::CodeCompletionToken.empty()) {
33583419
llvm::errs() << "token name required\n";

0 commit comments

Comments
 (0)