Skip to content

Commit f36e856

Browse files
committed
[ownership] Use a new ADT SwitchEnumBranch instead of SwitchEnumInstBase for generic operations on SwitchEnum{,Addr}Inst.
I have a need to have SwitchEnum{,Addr}Inst have different base classes (TermInst, OwnershipForwardingTermInst). To do this I need to add a template to SwitchEnumInstBase so I can switch that BaseTy. Sadly since we are using SwitchEnumInstBase as an ADT type as well as an actual base type for Instructions, this is impossible to do without introducing a template in a ton of places. Rather than doing that, I changed the code that was using SwitchEnumInstBase as an ADT to instead use a proper ADT SwitchEnumBranch. I am happy to change the name as possible see fit (maybe SwitchEnumTerm?).
1 parent 5c88120 commit f36e856

File tree

10 files changed

+279
-115
lines changed

10 files changed

+279
-115
lines changed

include/swift/SIL/TerminatorUtils.h

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
//===--- TerminatorUtils.h ------------------------------------------------===//
2+
//
3+
// This source file is part of the Swift.org open source project
4+
//
5+
// Copyright (c) 2014 - 2020 Apple Inc. and the Swift project authors
6+
// Licensed under Apache License v2.0 with Runtime Library Exception
7+
//
8+
// See https://swift.org/LICENSE.txt for license information
9+
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10+
//
11+
//===----------------------------------------------------------------------===//
12+
///
13+
/// \file
14+
///
15+
/// ADTs for working with various forms of terminators.
16+
///
17+
//===----------------------------------------------------------------------===//
18+
19+
#ifndef SWIFT_SIL_TERMINATORUTILS_H
20+
#define SWIFT_SIL_TERMINATORUTILS_H
21+
22+
#include "swift/Basic/LLVM.h"
23+
#include "swift/SIL/SILInstruction.h"
24+
25+
#include "llvm/ADT/PointerUnion.h"
26+
27+
namespace swift {
28+
29+
/// An ADT for writing generic code against SwitchEnumAddrInst and
30+
/// SwitchEnumInst.
31+
///
32+
/// We use this instead of SwitchEnumInstBase for this purpose in order to avoid
33+
/// the need for templating SwitchEnumInstBase from causing this ADT type of
34+
/// usage to require templates.
35+
class SwitchEnumTermInst {
36+
PointerUnion<SwitchEnumAddrInst *, SwitchEnumInst *> value;
37+
38+
public:
39+
SwitchEnumTermInst(SwitchEnumAddrInst *seai) : value(seai) {}
40+
SwitchEnumTermInst(SwitchEnumInst *seai) : value(seai) {}
41+
SwitchEnumTermInst(SILInstruction *i) : value(nullptr) {
42+
if (auto *seai = dyn_cast<SwitchEnumAddrInst>(i)) {
43+
value = seai;
44+
return;
45+
}
46+
47+
if (auto *sei = dyn_cast<SwitchEnumInst>(i)) {
48+
value = sei;
49+
return;
50+
}
51+
}
52+
53+
SwitchEnumTermInst(const SILInstruction *i)
54+
: SwitchEnumTermInst(const_cast<SILInstruction *>(i)) {}
55+
56+
operator TermInst *() const {
57+
if (auto *seai = value.dyn_cast<SwitchEnumAddrInst *>())
58+
return seai;
59+
return value.get<SwitchEnumInst *>();
60+
}
61+
62+
TermInst *operator*() const {
63+
if (auto *seai = value.dyn_cast<SwitchEnumAddrInst *>())
64+
return seai;
65+
return value.get<SwitchEnumInst *>();
66+
}
67+
68+
TermInst *operator->() const {
69+
if (auto *seai = value.dyn_cast<SwitchEnumAddrInst *>())
70+
return seai;
71+
return value.get<SwitchEnumInst *>();
72+
}
73+
74+
operator bool() const { return bool(value); }
75+
76+
SILValue getOperand() {
77+
if (auto *sei = value.dyn_cast<SwitchEnumInst *>())
78+
return sei->getOperand();
79+
return value.get<SwitchEnumAddrInst *>()->getOperand();
80+
}
81+
82+
unsigned getNumCases() {
83+
if (auto *sei = value.dyn_cast<SwitchEnumInst *>())
84+
return sei->getNumCases();
85+
return value.get<SwitchEnumAddrInst *>()->getNumCases();
86+
}
87+
88+
std::pair<EnumElementDecl *, SILBasicBlock *> getCase(unsigned i) const {
89+
if (auto *sei = value.dyn_cast<SwitchEnumInst *>())
90+
return sei->getCase(i);
91+
return value.get<SwitchEnumAddrInst *>()->getCase(i);
92+
}
93+
94+
SILBasicBlock *getCaseDestination(EnumElementDecl *decl) const {
95+
if (auto *sei = value.dyn_cast<SwitchEnumInst *>())
96+
return sei->getCaseDestination(decl);
97+
return value.get<SwitchEnumAddrInst *>()->getCaseDestination(decl);
98+
}
99+
100+
ProfileCounter getCaseCount(unsigned i) const {
101+
if (auto *sei = value.dyn_cast<SwitchEnumInst *>())
102+
return sei->getCaseCount(i);
103+
return value.get<SwitchEnumAddrInst *>()->getCaseCount(i);
104+
}
105+
106+
ProfileCounter getDefaultCount() const {
107+
if (auto *sei = value.dyn_cast<SwitchEnumInst *>())
108+
return sei->getDefaultCount();
109+
return value.get<SwitchEnumAddrInst *>()->getDefaultCount();
110+
}
111+
112+
bool hasDefault() const {
113+
if (auto *sei = value.dyn_cast<SwitchEnumInst *>())
114+
return sei->hasDefault();
115+
return value.get<SwitchEnumAddrInst *>()->hasDefault();
116+
}
117+
118+
SILBasicBlock *getDefaultBB() const {
119+
if (auto *sei = value.dyn_cast<SwitchEnumInst *>())
120+
return sei->getDefaultBB();
121+
return value.get<SwitchEnumAddrInst *>()->getDefaultBB();
122+
}
123+
124+
NullablePtr<SILBasicBlock> getDefaultBBOrNull() const {
125+
if (auto *sei = value.dyn_cast<SwitchEnumInst *>())
126+
return sei->getDefaultBBOrNull();
127+
return value.get<SwitchEnumAddrInst *>()->getDefaultBBOrNull();
128+
}
129+
130+
/// If the default refers to exactly one case decl, return it.
131+
NullablePtr<EnumElementDecl> getUniqueCaseForDefault() const {
132+
if (auto *sei = value.dyn_cast<SwitchEnumInst *>())
133+
return sei->getUniqueCaseForDefault();
134+
return value.get<SwitchEnumAddrInst *>()->getUniqueCaseForDefault();
135+
}
136+
137+
/// If the given block only has one enum element decl matched to it,
138+
/// return it.
139+
NullablePtr<EnumElementDecl>
140+
getUniqueCaseForDestination(SILBasicBlock *BB) const {
141+
if (auto *sei = value.dyn_cast<SwitchEnumInst *>())
142+
return sei->getUniqueCaseForDestination(BB);
143+
return value.get<SwitchEnumAddrInst *>()->getUniqueCaseForDestination(BB);
144+
}
145+
};
146+
147+
} // namespace swift
148+
149+
#endif

lib/IRGen/IRGenSIL.cpp

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
#include "swift/SIL/SILModule.h"
3737
#include "swift/SIL/SILType.h"
3838
#include "swift/SIL/SILVisitor.h"
39+
#include "swift/SIL/TerminatorUtils.h"
3940
#include "clang/AST/ASTContext.h"
4041
#include "clang/AST/DeclCXX.h"
4142
#include "clang/Basic/TargetInfo.h"
@@ -3516,12 +3517,12 @@ static void addIncomingSILArgumentsToPHINodes(IRGenSILFunction &IGF,
35163517
}
35173518

35183519
static llvm::BasicBlock *emitBBMapForSwitchEnum(
3519-
IRGenSILFunction &IGF,
3520-
SmallVectorImpl<std::pair<EnumElementDecl*, llvm::BasicBlock*>> &dests,
3521-
SwitchEnumInstBase *inst) {
3522-
for (unsigned i = 0, e = inst->getNumCases(); i < e; ++i) {
3523-
auto casePair = inst->getCase(i);
3524-
3520+
IRGenSILFunction &IGF,
3521+
SmallVectorImpl<std::pair<EnumElementDecl *, llvm::BasicBlock *>> &dests,
3522+
SwitchEnumTermInst inst) {
3523+
for (unsigned i = 0, e = inst.getNumCases(); i < e; ++i) {
3524+
auto casePair = inst.getCase(i);
3525+
35253526
// If the destination BB accepts the case argument, set up a waypoint BB so
35263527
// we can feed the values into the argument's PHI node(s).
35273528
//
@@ -3533,10 +3534,10 @@ static llvm::BasicBlock *emitBBMapForSwitchEnum(
35333534
else
35343535
dests.push_back({casePair.first, IGF.getLoweredBB(casePair.second).bb});
35353536
}
3536-
3537+
35373538
llvm::BasicBlock *defaultDest = nullptr;
3538-
if (inst->hasDefault())
3539-
defaultDest = IGF.getLoweredBB(inst->getDefaultBB()).bb;
3539+
if (inst.hasDefault())
3540+
defaultDest = IGF.getLoweredBB(inst.getDefaultBB()).bb;
35403541
return defaultDest;
35413542
}
35423543

lib/SIL/IR/SILPrinter.cpp

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -16,26 +16,27 @@
1616
///
1717
//===----------------------------------------------------------------------===//
1818

19-
#include "swift/Strings.h"
20-
#include "swift/Demangling/Demangle.h"
19+
#include "swift/AST/Decl.h"
20+
#include "swift/AST/GenericEnvironment.h"
21+
#include "swift/AST/Module.h"
22+
#include "swift/AST/PrintOptions.h"
23+
#include "swift/AST/ProtocolConformance.h"
24+
#include "swift/AST/Types.h"
2125
#include "swift/Basic/QuotedString.h"
22-
#include "swift/SIL/SILPrintContext.h"
26+
#include "swift/Basic/STLExtras.h"
27+
#include "swift/Demangling/Demangle.h"
2328
#include "swift/SIL/ApplySite.h"
2429
#include "swift/SIL/CFG.h"
25-
#include "swift/SIL/SILFunction.h"
2630
#include "swift/SIL/SILCoverageMap.h"
2731
#include "swift/SIL/SILDebugScope.h"
2832
#include "swift/SIL/SILDeclRef.h"
33+
#include "swift/SIL/SILFunction.h"
2934
#include "swift/SIL/SILModule.h"
30-
#include "swift/SIL/SILVisitor.h"
35+
#include "swift/SIL/SILPrintContext.h"
3136
#include "swift/SIL/SILVTable.h"
32-
#include "swift/AST/Decl.h"
33-
#include "swift/AST/GenericEnvironment.h"
34-
#include "swift/AST/Module.h"
35-
#include "swift/AST/PrintOptions.h"
36-
#include "swift/AST/ProtocolConformance.h"
37-
#include "swift/AST/Types.h"
38-
#include "swift/Basic/STLExtras.h"
37+
#include "swift/SIL/SILVisitor.h"
38+
#include "swift/SIL/TerminatorUtils.h"
39+
#include "swift/Strings.h"
3940
#include "clang/AST/ASTContext.h"
4041
#include "clang/AST/Decl.h"
4142
#include "llvm/ADT/APFloat.h"
@@ -47,11 +48,10 @@
4748
#include "llvm/ADT/StringRef.h"
4849
#include "llvm/ADT/StringSwitch.h"
4950
#include "llvm/Support/CommandLine.h"
50-
#include "llvm/Support/FormattedStream.h"
5151
#include "llvm/Support/FileSystem.h"
52+
#include "llvm/Support/FormattedStream.h"
5253
#include <set>
5354

54-
5555
using namespace swift;
5656
using ID = SILPrintContext::ID;
5757

@@ -2111,27 +2111,27 @@ class SILPrinter : public SILInstructionVisitor<SILPrinter> {
21112111
if (SII->hasDefault())
21122112
*this << ", default " << Ctx.getID(SII->getDefaultBB());
21132113
}
2114-
2115-
void printSwitchEnumInst(SwitchEnumInstBase *SOI) {
2116-
*this << getIDAndType(SOI->getOperand());
2117-
for (unsigned i = 0, e = SOI->getNumCases(); i < e; ++i) {
2114+
2115+
void printSwitchEnumInst(SwitchEnumTermInst SOI) {
2116+
*this << getIDAndType(SOI.getOperand());
2117+
for (unsigned i = 0, e = SOI.getNumCases(); i < e; ++i) {
21182118
EnumElementDecl *elt;
21192119
SILBasicBlock *dest;
2120-
std::tie(elt, dest) = SOI->getCase(i);
2120+
std::tie(elt, dest) = SOI.getCase(i);
21212121
*this << ", case " << SILDeclRef(elt, SILDeclRef::Kind::EnumElement)
21222122
<< ": " << Ctx.getID(dest);
2123-
if (SOI->getCaseCount(i)) {
2124-
*this << " !case_count(" << SOI->getCaseCount(i).getValue() << ")";
2123+
if (SOI.getCaseCount(i)) {
2124+
*this << " !case_count(" << SOI.getCaseCount(i).getValue() << ")";
21252125
}
21262126
}
2127-
if (SOI->hasDefault()) {
2128-
*this << ", default " << Ctx.getID(SOI->getDefaultBB());
2129-
if (SOI->getDefaultCount()) {
2130-
*this << " !default_count(" << SOI->getDefaultCount().getValue() << ")";
2127+
if (SOI.hasDefault()) {
2128+
*this << ", default " << Ctx.getID(SOI.getDefaultBB());
2129+
if (SOI.getDefaultCount()) {
2130+
*this << " !default_count(" << SOI.getDefaultCount().getValue() << ")";
21312131
}
21322132
}
21332133
}
2134-
2134+
21352135
void visitSwitchEnumInst(SwitchEnumInst *SOI) {
21362136
printSwitchEnumInst(SOI);
21372137
}

lib/SIL/Utils/BasicBlockUtils.cpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "swift/SIL/SILBasicBlock.h"
1818
#include "swift/SIL/SILBuilder.h"
1919
#include "swift/SIL/SILFunction.h"
20+
#include "swift/SIL/TerminatorUtils.h"
2021

2122
using namespace swift;
2223

@@ -97,6 +98,12 @@ static SILBasicBlock *getNthEdgeBlock(SwitchInstTy *S, unsigned edgeIdx) {
9798
return S->getCase(edgeIdx).second;
9899
}
99100

101+
static SILBasicBlock *getNthEdgeBlock(SwitchEnumTermInst S, unsigned edgeIdx) {
102+
if (S.getNumCases() == edgeIdx)
103+
return S.getDefaultBB();
104+
return S.getCase(edgeIdx).second;
105+
}
106+
100107
void swift::getEdgeArgs(TermInst *T, unsigned edgeIdx, SILBasicBlock *newEdgeBB,
101108
llvm::SmallVectorImpl<SILValue> &args) {
102109
switch (T->getKind()) {
@@ -159,8 +166,8 @@ void swift::getEdgeArgs(TermInst *T, unsigned edgeIdx, SILBasicBlock *newEdgeBB,
159166
// destination block to figure this out.
160167
case SILInstructionKind::SwitchEnumInst:
161168
case SILInstructionKind::SwitchEnumAddrInst: {
162-
auto SEI = cast<SwitchEnumInstBase>(T);
163-
auto *succBB = getNthEdgeBlock(SEI, edgeIdx);
169+
SwitchEnumTermInst branch(T);
170+
auto *succBB = getNthEdgeBlock(branch, edgeIdx);
164171
assert(succBB->getNumArguments() < 2 && "Can take at most one argument");
165172
if (!succBB->getNumArguments())
166173
return;

lib/SILOptimizer/Differentiation/VJPCloner.cpp

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include "swift/SILOptimizer/Differentiation/PullbackCloner.h"
2626
#include "swift/SILOptimizer/Differentiation/Thunk.h"
2727

28+
#include "swift/SIL/TerminatorUtils.h"
2829
#include "swift/SIL/TypeSubstCloner.h"
2930
#include "swift/SILOptimizer/PassManager/PrettyStackTrace.h"
3031
#include "swift/SILOptimizer/Utils/CFGOptUtils.h"
@@ -256,46 +257,44 @@ class VJPCloner::Implementation final
256257
createTrampolineBasicBlock(cbi, pbStructVal, cbi->getFalseBB()));
257258
}
258259

259-
void visitSwitchEnumInstBase(SwitchEnumInstBase *inst) {
260+
void visitSwitchEnumTermInst(SwitchEnumTermInst inst) {
260261
// Build pullback struct value for original block.
261-
auto *pbStructVal = buildPullbackValueStructValue(inst);
262+
auto *pbStructVal = buildPullbackValueStructValue(*inst);
262263

263264
// Create trampoline successor basic blocks.
264265
SmallVector<std::pair<EnumElementDecl *, SILBasicBlock *>, 4> caseBBs;
265-
for (unsigned i : range(inst->getNumCases())) {
266-
auto caseBB = inst->getCase(i);
266+
for (unsigned i : range(inst.getNumCases())) {
267+
auto caseBB = inst.getCase(i);
267268
auto *trampolineBB =
268269
createTrampolineBasicBlock(inst, pbStructVal, caseBB.second);
269270
caseBBs.push_back({caseBB.first, trampolineBB});
270271
}
271272
// Create trampoline default basic block.
272273
SILBasicBlock *newDefaultBB = nullptr;
273-
if (auto *defaultBB = inst->getDefaultBBOrNull().getPtrOrNull())
274+
if (auto *defaultBB = inst.getDefaultBBOrNull().getPtrOrNull())
274275
newDefaultBB = createTrampolineBasicBlock(inst, pbStructVal, defaultBB);
275276

276277
// Create a new `switch_enum` instruction.
277278
switch (inst->getKind()) {
278279
case SILInstructionKind::SwitchEnumInst:
279-
getBuilder().createSwitchEnum(inst->getLoc(),
280-
getOpValue(inst->getOperand()),
281-
newDefaultBB, caseBBs);
280+
getBuilder().createSwitchEnum(
281+
inst->getLoc(), getOpValue(inst.getOperand()), newDefaultBB, caseBBs);
282282
break;
283283
case SILInstructionKind::SwitchEnumAddrInst:
284-
getBuilder().createSwitchEnumAddr(inst->getLoc(),
285-
getOpValue(inst->getOperand()),
286-
newDefaultBB, caseBBs);
284+
getBuilder().createSwitchEnumAddr(
285+
inst->getLoc(), getOpValue(inst.getOperand()), newDefaultBB, caseBBs);
287286
break;
288287
default:
289288
llvm_unreachable("Expected `switch_enum` or `switch_enum_addr`");
290289
}
291290
}
292291

293292
void visitSwitchEnumInst(SwitchEnumInst *sei) {
294-
visitSwitchEnumInstBase(sei);
293+
visitSwitchEnumTermInst(sei);
295294
}
296295

297296
void visitSwitchEnumAddrInst(SwitchEnumAddrInst *seai) {
298-
visitSwitchEnumInstBase(seai);
297+
visitSwitchEnumTermInst(seai);
299298
}
300299

301300
void visitCheckedCastBranchInst(CheckedCastBranchInst *ccbi) {

0 commit comments

Comments
 (0)