Skip to content

Commit a2ae0f2

Browse files
committed
Finish parsing/printing/serialization.
1 parent ab304f1 commit a2ae0f2

14 files changed

+509
-38
lines changed

include/swift/AST/DiagnosticsParse.def

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -686,6 +686,20 @@ ERROR(sil_witness_assoc_conf_not_found,none,
686686
ERROR(sil_witness_protocol_conformance_not_found,none,
687687
"sil protocol conformance not found", ())
688688

689+
// [differentiable ...] (sil-decl attr)
690+
ERROR(sil_diff_witness_expected_keyword,PointsToFirstBadToken,
691+
"expected '%0' in differentiability witness", (StringRef))
692+
ERROR(sil_diff_witness_expected_parameter_list,PointsToFirstBadToken,
693+
"expected an comma-separated list of parameter indices, e.g. (0, 1)", ())
694+
ERROR(sil_diff_witness_expected_rsquare,PointsToFirstBadToken,
695+
"expected ']' to end 'differentiable' attribute", ())
696+
ERROR(sil_diff_witness_expected_parameter_index,PointsToFirstBadToken,
697+
"expected the index of a parameter to differentiate w.r.t.", ())
698+
ERROR(sil_diff_witness_expected_source_index,PointsToFirstBadToken,
699+
"expected the index of a result to differentiate from", ())
700+
701+
// SIL differentiability witnesses
702+
689703
// SIL Coverage Map
690704
ERROR(sil_coverage_func_not_found, none,
691705
"sil function not found %0", (Identifier))

include/swift/SIL/SILDifferentiabilityWitness.h

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
//
33
// This source file is part of the Swift.org open source project
44
//
5-
// Copyright (c) 2014 - 2017 Apple Inc. and the Swift project authors
5+
// Copyright (c) 2014 - 2019 Apple Inc. and the Swift project authors
66
// Licensed under Apache License v2.0 with Runtime Library Exception
77
//
88
// See https://swift.org/LICENSE.txt for license information
@@ -42,8 +42,10 @@ class SILDifferentiabilityWitness
4242
public SILAllocated<SILDifferentiabilityWitness>
4343
{
4444
private:
45-
/// The module which contains the SIL differentiability witness.
45+
/// The module which contains the differentiability witness.
4646
SILModule &module;
47+
/// The linkage of the differentiability witness.
48+
SILLinkage linkage;
4749
/// The original function.
4850
SILFunction *originalFunction;
4951
/// The parameter indices.
@@ -60,19 +62,31 @@ class SILDifferentiabilityWitness
6062
/// devirtualization from another module.
6163
bool serialized;
6264

63-
SILDifferentiabilityWitness(SILModule &module, SILFunction *originalFunction,
65+
SILDifferentiabilityWitness(SILModule &module, SILLinkage linkage,
66+
SILFunction *originalFunction,
6467
AutoDiffIndexSubset *parameterIndices,
6568
AutoDiffIndexSubset *resultIndices,
6669
GenericSignature *derivativeGenSig,
6770
SILFunction *jvp, SILFunction *vjp,
6871
bool isSerialized)
69-
: module(module), originalFunction(originalFunction),
72+
: module(module), linkage(linkage), originalFunction(originalFunction),
7073
parameterIndices(parameterIndices), resultIndices(resultIndices),
7174
derivativeGenericSignature(derivativeGenSig), jvp(jvp), vjp(vjp),
7275
serialized(isSerialized) {}
7376

7477
public:
78+
/// The key type, used for uniquing `SILDifferentiabilityWitness` in
79+
/// `SILModule`, original function, parameter indices, result indices, and
80+
/// derivative generic signature.
81+
using Key = std::tuple<const SILFunction *, AutoDiffIndexSubset *,
82+
AutoDiffIndexSubset *, GenericSignature *>;
83+
Key getKey() {
84+
return std::make_tuple(originalFunction, parameterIndices, resultIndices,
85+
derivativeGenericSignature);
86+
}
87+
7588
SILModule &getModule() const { return module; }
89+
SILLinkage getLinkage() const { return linkage; }
7690
SILFunction *getOriginalFunction() const { return originalFunction; }
7791
AutoDiffIndexSubset *getParameterIndices() const {
7892
return parameterIndices;
@@ -88,7 +102,7 @@ class SILDifferentiabilityWitness
88102
bool isSerialized() const { return serialized; }
89103

90104
static SILDifferentiabilityWitness *create(
91-
SILModule &module, SILFunction *originalFunction,
105+
SILModule &module, SILLinkage linkage, SILFunction *originalFunction,
92106
AutoDiffIndexSubset *parameterIndices, AutoDiffIndexSubset *resultIndices,
93107
GenericSignature *derivativeGenSig, SILFunction *jvp, SILFunction *vjp,
94108
bool isSerialized);

include/swift/SIL/SILModule.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,9 @@ class SILModule {
145145
friend SILProperty;
146146
friend SILUndef;
147147
friend SILWitnessTable;
148+
// SWIFT_ENABLE_TENSORFLOW
149+
friend SILDifferentiabilityWitness;
150+
// SWIFT_ENABLE_TENSORFLOW END
148151
friend Lowering::SILGenModule;
149152
friend Lowering::TypeConverter;
150153
class SerializationCallback;
@@ -202,10 +205,9 @@ class SILModule {
202205

203206
// SWIFT_ENABLE_TENSORFLOW
204207
/// Lookup table for SIL differentiability witnesses from original functions.
205-
/// Indexed by original function, parameter indices, result indices, and
206-
/// derivative generic signature.
207-
llvm::DenseMap<std::tuple<const SILFunction *, AutoDiffIndexSubset *,
208-
AutoDiffIndexSubset *, GenericSignature *>,
208+
/// Indexed by key type: original function, parameter indices, result indices,
209+
/// and derivative generic signature.
210+
llvm::DenseMap<SILDifferentiabilityWitness::Key,
209211
SILDifferentiabilityWitness *>
210212
DifferentiabilityWitnessMap;
211213

include/swift/Serialization/SerializedSILLoader.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,11 @@ class SerializedSILLoader {
9999
/// Deserialize all Properties in all SILModules.
100100
void getAllProperties();
101101

102+
// SWIFT_ENABLE_TENSORFLOW
103+
/// Deserialize all DifferentiabilityWitnesses in all SILModules.
104+
void getAllDifferentiabilityWitnesses();
105+
// SWIFT_ENABLE_TENSORFLOW END
106+
102107
SerializedSILLoader(const SerializedSILLoader &) = delete;
103108
SerializedSILLoader(SerializedSILLoader &&) = delete;
104109
SerializedSILLoader &operator=(const SerializedSILLoader &) = delete;

lib/Parse/ParseDecl.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,9 @@ bool Parser::parseTopLevel() {
9494
CASE_SIL(sil_global, SILGlobal)
9595
CASE_SIL(sil_witness_table, SILWitnessTable)
9696
CASE_SIL(sil_default_witness_table, SILDefaultWitnessTable)
97+
// SWIFT_ENABLE_TENSORFLOW
98+
CASE_SIL(sil_differentiability_witness, SILDifferentiabilityWitness)
99+
// SWIFT_ENABLE_TENSORFLOW END
97100
CASE_SIL(sil_coverage_map, SILCoverageMap)
98101
CASE_SIL(sil_property, SILProperty)
99102
CASE_SIL(sil_scope, SILScope)

lib/ParseSIL/ParseSIL.cpp

Lines changed: 211 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6745,21 +6745,225 @@ bool SILParserTUState::parseSILDefaultWitnessTable(Parser &P) {
67456745
return false;
67466746
}
67476747

6748+
// SWIFT_ENABLE_TENSORFLOW
6749+
// TODO: Dedupe with `SILParser::convertRequirements` upstream.
6750+
// Consider defining this as `Parser::convertRequirements`.
6751+
static void convertRequirements(Parser &P, SILFunction *F,
6752+
ArrayRef<RequirementRepr> From,
6753+
SmallVectorImpl<Requirement> &To) {
6754+
if (From.empty()) {
6755+
To.clear();
6756+
return;
6757+
}
6758+
6759+
auto *GenericEnv = F->getLoweredFunctionType()
6760+
->getGenericSignature()
6761+
->getGenericEnvironment();
6762+
assert(GenericEnv);
6763+
(void)GenericEnv;
6764+
6765+
IdentTypeReprLookup PerformLookup(P);
6766+
// Use parser lexical scopes to resolve references
6767+
// to the generic parameters.
6768+
auto ResolveToInterfaceType = [&](TypeLoc Ty) -> Type {
6769+
Ty.getTypeRepr()->walk(PerformLookup);
6770+
swift::performTypeLocChecking(P.Context, Ty, /*isSILMode*/ true,
6771+
/*isSILType*/ true, GenericEnv, &P.SF);
6772+
assert(Ty.getType());
6773+
return Ty.getType()->mapTypeOutOfContext();
6774+
};
6775+
6776+
for (auto &Req : From) {
6777+
if (Req.getKind() == RequirementReprKind::SameType) {
6778+
auto FirstType = ResolveToInterfaceType(Req.getFirstTypeLoc());
6779+
auto SecondType = ResolveToInterfaceType(Req.getSecondTypeLoc());
6780+
Requirement ConvertedRequirement(RequirementKind::SameType, FirstType,
6781+
SecondType);
6782+
To.push_back(ConvertedRequirement);
6783+
continue;
6784+
}
6785+
6786+
if (Req.getKind() == RequirementReprKind::TypeConstraint) {
6787+
auto Subject = ResolveToInterfaceType(Req.getSubjectLoc());
6788+
auto Constraint = ResolveToInterfaceType(Req.getConstraintLoc());
6789+
Requirement ConvertedRequirement(RequirementKind::Conformance, Subject,
6790+
Constraint);
6791+
To.push_back(ConvertedRequirement);
6792+
continue;
6793+
}
6794+
6795+
if (Req.getKind() == RequirementReprKind::LayoutConstraint) {
6796+
auto Subject = ResolveToInterfaceType(Req.getSubjectLoc());
6797+
Requirement ConvertedRequirement(RequirementKind::Layout, Subject,
6798+
Req.getLayoutConstraint());
6799+
To.push_back(ConvertedRequirement);
6800+
continue;
6801+
}
6802+
llvm_unreachable("Unsupported requirement kind");
6803+
}
6804+
}
6805+
67486806
/// decl-sil-differentiability-witness ::=
67496807
/// 'sil_differentiability_witness'
6750-
/// sil-function-name
6751-
/// 'wrt' autodiff-index-subset
6752-
/// 'sources' autodiff-index-subset
6753-
/// ('derivative_generic_signature' generic-signature)?
6754-
/// '{' ('jvp' sil-function-name)? ('vjp' sil-function-name)? '}'
6808+
/// sil-function-name ':' sil-type
6809+
/// 'parameters' autodiff-index-subset
6810+
/// 'results' autodiff-index-subset
6811+
/// ('where' generic-signature)?
6812+
/// '{'
6813+
/// ('jvp' sil-function-name ':' sil-type)?
6814+
/// ('vjp' sil-function-name ':' sil-type)?
6815+
/// '}'
67556816
///
67566817
/// autodiff-index-subset ::=
6757-
/// [0-9]+ (',', [0-9]+)*
6818+
/// '(' [0-9]+ (',', [0-9]+)* ')'
67586819
bool SILParserTUState::parseSILDifferentiabilityWitness(Parser &P) {
67596820
P.consumeToken(tok::kw_sil_differentiability_witness);
6760-
// TODO(TF-867): Implement parsing. Test round-tripping with printing.
6821+
SILParser State(P);
6822+
6823+
// Parse the linkage.
6824+
Optional<SILLinkage> linkage;
6825+
if (parseSILLinkage(linkage, P))
6826+
return true;
6827+
if (!linkage)
6828+
linkage = SILLinkage::PublicExternal;
6829+
6830+
Scope S(&P, ScopeKind::TopLevel);
6831+
Scope Body(&P, ScopeKind::FunctionBody);
6832+
6833+
auto parseFunctionNameAndType = [&](SILFunction *&fn) -> bool {
6834+
Identifier name;
6835+
SILType ty;
6836+
SourceLoc fnNameLoc = P.Tok.getLoc();
6837+
// We need to turn on InSILBody to parse the function reference.
6838+
Lexer::SILBodyRAII tmp(*P.L);
6839+
GenericEnvironment *ignoredEnv;
6840+
if ((State.parseGlobalName(name)) ||
6841+
P.parseToken(tok::colon, diag::expected_sil_colon_value_ref) ||
6842+
State.parseSILType(ty, ignoredEnv, /*IsFuncDecl*/ true))
6843+
return true;
6844+
6845+
// The function doesn't exist yet. Create a zombie forward declaration.
6846+
auto fnType = ty.getAs<SILFunctionType>();
6847+
if (!fnType || !ty.isObject()) {
6848+
P.diagnose(fnNameLoc, diag::expected_sil_function_type);
6849+
return true;
6850+
}
6851+
fn = State.getGlobalNameForReference(name, fnType, fnNameLoc, true);
6852+
State.TUState.PotentialZombieFns.insert(fn);
6853+
return false;
6854+
};
6855+
6856+
SourceLoc lastLoc = P.getEndOfPreviousLoc();
6857+
6858+
SILFunction *originalFn;
6859+
if (parseFunctionNameAndType(originalFn))
6860+
return true;
6861+
6862+
auto parseAutoDiffIndexSubset =
6863+
[&](StringRef label, AutoDiffIndexSubset *& paramIndexSubset) -> bool {
6864+
if (P.parseSpecificIdentifier(
6865+
label, diag::sil_diff_witness_expected_keyword, label))
6866+
return true;
6867+
if (P.parseToken(tok::l_paren, diag::sil_diff_witness_expected_keyword,
6868+
"("))
6869+
return true;
6870+
// Parse parameter index list.
6871+
SmallVector<unsigned, 8> paramIndices;
6872+
// Function that parses an index into `paramIndices`. Returns true on error.
6873+
auto parseParam = [&]() -> bool {
6874+
unsigned index;
6875+
// TODO: Reject non-ascending parameter index lists.
6876+
if (P.parseUnsignedInteger(index, lastLoc,
6877+
diag::sil_diff_witness_expected_parameter_list))
6878+
return true;
6879+
paramIndices.push_back(index);
6880+
return false;
6881+
};
6882+
// Parse first.
6883+
if (parseParam())
6884+
return true;
6885+
// Parse rest.
6886+
while (P.consumeIf(tok::comma))
6887+
if (parseParam())
6888+
return true;
6889+
if (P.parseToken(tok::r_paren, diag::sil_diff_witness_expected_keyword,
6890+
"("))
6891+
return true;
6892+
auto maxIndexRef =
6893+
std::max_element(paramIndices.begin(), paramIndices.end());
6894+
paramIndexSubset = AutoDiffIndexSubset::get(
6895+
P.Context, maxIndexRef ? *maxIndexRef + 1 : 0, paramIndices);
6896+
return false;
6897+
};
6898+
AutoDiffIndexSubset *parameterIndices = nullptr;
6899+
AutoDiffIndexSubset *resultIndices = nullptr;
6900+
if (parseAutoDiffIndexSubset("parameters", parameterIndices))
6901+
return true;
6902+
if (parseAutoDiffIndexSubset("results", resultIndices))
6903+
return true;
6904+
6905+
GenericSignature *derivativeGenSig = nullptr;
6906+
// Parse a trailing 'where' clause if any.
6907+
if (P.Tok.is(tok::kw_where)) {
6908+
SourceLoc whereLoc;
6909+
SmallVector<RequirementRepr, 4> requirementReprs;
6910+
bool firstTypeInComplete;
6911+
P.parseGenericWhereClause(whereLoc, requirementReprs, firstTypeInComplete,
6912+
/*AllowLayoutConstraints*/ false);
6913+
auto *whereClause = TrailingWhereClause::create(
6914+
originalFn->getModule().getASTContext(), whereLoc, requirementReprs);
6915+
SmallVector<Requirement, 4> requirements;
6916+
convertRequirements(P, originalFn, whereClause->getRequirements(),
6917+
requirements);
6918+
assert(requirements.size() == requirementReprs.size());
6919+
derivativeGenSig = evaluateOrDefault(
6920+
P.Context.evaluator,
6921+
AbstractGenericSignatureRequest{
6922+
originalFn->getLoweredFunctionType()->getGenericSignature(),
6923+
/*addedGenericParams=*/{},
6924+
std::move(requirements)},
6925+
nullptr);
6926+
}
6927+
6928+
SILFunction *jvp = nullptr;
6929+
SILFunction *vjp = nullptr;
6930+
if (P.Tok.is(tok::l_brace)) {
6931+
SourceLoc LBraceLoc = P.Tok.getLoc();
6932+
P.consumeToken(tok::l_brace);
6933+
6934+
if (P.Tok.is(tok::identifier) && P.Tok.getText() == "jvp") {
6935+
P.consumeToken(tok::identifier);
6936+
if (P.parseToken(tok::colon, diag::sil_diff_witness_expected_keyword,
6937+
":"))
6938+
return true;
6939+
Scope Body(&P, ScopeKind::FunctionBody);
6940+
if (parseFunctionNameAndType(jvp))
6941+
return true;
6942+
}
6943+
6944+
if (P.Tok.is(tok::identifier) && P.Tok.getText() == "vjp") {
6945+
P.consumeToken(tok::identifier);
6946+
if (P.parseToken(tok::colon, diag::sil_diff_witness_expected_keyword,
6947+
":"))
6948+
return true;
6949+
Scope Body(&P, ScopeKind::FunctionBody);
6950+
if (parseFunctionNameAndType(vjp))
6951+
return true;
6952+
}
6953+
6954+
if (P.parseMatchingToken(tok::r_brace, lastLoc, diag::expected_sil_rbrace,
6955+
LBraceLoc))
6956+
return true;
6957+
}
6958+
6959+
// TODO: Parse `isSerialized` flag.
6960+
bool isSerialized = false;
6961+
SILDifferentiabilityWitness::create(
6962+
M, *linkage, originalFn, parameterIndices, resultIndices,
6963+
derivativeGenSig, jvp, vjp, isSerialized);
67616964
return false;
67626965
}
6966+
// SWIFT_ENABLE_TENSORFLOW END
67636967

67646968
llvm::Optional<llvm::coverage::Counter> SILParser::parseSILCoverageExpr(
67656969
llvm::coverage::CounterExpressionBuilder &Builder) {

lib/SIL/SILDifferentiabilityWitness.cpp

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
//
33
// This source file is part of the Swift.org open source project
44
//
5-
// Copyright (c) 2014 - 2017 Apple Inc. and the Swift project authors
5+
// Copyright (c) 2014 - 2019 Apple Inc. and the Swift project authors
66
// Licensed under Apache License v2.0 with Runtime Library Exception
77
//
88
// See https://swift.org/LICENSE.txt for license information
@@ -18,15 +18,19 @@
1818
using namespace swift;
1919

2020
SILDifferentiabilityWitness *SILDifferentiabilityWitness::create(
21-
SILModule &module, SILFunction *originalFunction,
21+
SILModule &module, SILLinkage linkage, SILFunction *originalFunction,
2222
AutoDiffIndexSubset *parameterIndices, AutoDiffIndexSubset *resultIndices,
2323
GenericSignature *derivativeGenSig, SILFunction *jvp, SILFunction *vjp,
2424
bool isSerialized) {
2525
void *buf = module.allocate(sizeof(SILDifferentiabilityWitness),
2626
alignof(SILDifferentiabilityWitness));
27-
SILDifferentiabilityWitness *dw = ::new (buf)
28-
SILDifferentiabilityWitness(module, originalFunction, parameterIndices,
29-
resultIndices, derivativeGenSig, jvp, vjp,
30-
isSerialized);
31-
return dw;
27+
auto *diffWitness = ::new (buf) SILDifferentiabilityWitness(
28+
module, linkage, originalFunction, parameterIndices, resultIndices,
29+
derivativeGenSig, jvp, vjp, isSerialized);
30+
// Register the differentiability witness in the module.
31+
assert(!module.DifferentiabilityWitnessMap.count(diffWitness->getKey()) &&
32+
"Cannot create duplicate differentiability witness in a module");
33+
module.DifferentiabilityWitnessMap[diffWitness->getKey()] = diffWitness;
34+
module.getDifferentiabilityWitnessList().push_back(diffWitness);
35+
return diffWitness;
3236
}

0 commit comments

Comments
 (0)