Skip to content

Commit 2073f86

Browse files
committed
Address review feedback.
1 parent c3959ad commit 2073f86

File tree

7 files changed

+38
-46
lines changed

7 files changed

+38
-46
lines changed

include/swift/AST/DiagnosticsParse.def

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -687,12 +687,12 @@ ERROR(sil_witness_protocol_conformance_not_found,none,
687687
"sil protocol conformance not found", ())
688688

689689
// SIL differentiability witnesses
690-
ERROR(sil_diff_witness_expected_keyword,PointsToFirstBadToken,
690+
ERROR(sil_diff_witness_expected_token,PointsToFirstBadToken,
691691
"expected '%0' in differentiability witness", (StringRef))
692692
ERROR(sil_diff_witness_expected_index_list,PointsToFirstBadToken,
693693
"expected a space-separated list of indices, e.g. '0 1'", ())
694694
ERROR(sil_diff_witness_expected_parameter_index,PointsToFirstBadToken,
695-
"expected a parameter index to differentiate with respect to.", ())
695+
"expected a parameter index to differentiate with respect to", ())
696696
ERROR(sil_diff_witness_expected_result_index,PointsToFirstBadToken,
697697
"expected a result index to differentiate with respect to", ())
698698

include/swift/SIL/SILDifferentiabilityWitness.h

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -67,11 +67,6 @@ class SILDifferentiabilityWitness
6767
/// deserialized.
6868
DeclAttribute *attribute = nullptr;
6969

70-
static AutoDiffConfig *
71-
getAutoDiffConfig(SILModule &module, IndexSubset *parameterIndices,
72-
IndexSubset *resultIndices,
73-
GenericSignature *derivativeGenSig);
74-
7570
SILDifferentiabilityWitness(SILModule &module, SILLinkage linkage,
7671
SILFunction *originalFunction,
7772
IndexSubset *parameterIndices,
@@ -112,6 +107,8 @@ class SILDifferentiabilityWitness
112107
case AutoDiffDerivativeFunctionKind::VJP: return vjp;
113108
}
114109
}
110+
void setJVP(SILFunction *jvp) { this->jvp = jvp; }
111+
void setVJP(SILFunction *vjp) { this->vjp = vjp; }
115112
void setDerivative(AutoDiffDerivativeFunctionKind kind,
116113
SILFunction *derivative) {
117114
switch (kind) {
@@ -123,9 +120,9 @@ class SILDifferentiabilityWitness
123120
DeclAttribute *getAttribute() const { return attribute; }
124121

125122
/// Verify that the differentiability witness is well-formed.
126-
void verify(const SILModule &M) const;
123+
void verify(const SILModule &module) const;
127124

128-
void print(llvm::raw_ostream &OS, bool verbose = false) const;
125+
void print(llvm::raw_ostream &os, bool verbose = false) const;
129126
void dump() const;
130127
};
131128

lib/ParseSIL/ParseSIL.cpp

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6747,7 +6747,10 @@ bool SILParserTUState::parseSILDefaultWitnessTable(Parser &P) {
67476747

67486748
// SWIFT_ENABLE_TENSORFLOW
67496749
// TODO(TF-893): Dedupe with `SILParser::convertRequirements` upstream.
6750-
// Consider defining this as `Parser::convertRequirements`.
6750+
// Currently, this utility is defined on `SILParser`, but SIL differentiability
6751+
// witness is defined on `SILParserTUState` and only has access to `Parser`.
6752+
// Consider redefining `SILParser::convertRequirements`as
6753+
// `Parser::convertRequirements`.
67516754
static void convertRequirements(Parser &P, SILFunction *F,
67526755
ArrayRef<RequirementRepr> From,
67536756
SmallVectorImpl<Requirement> &To) {
@@ -6827,8 +6830,8 @@ bool SILParserTUState::parseSILDifferentiabilityWitness(Parser &P) {
68276830
if (!linkage)
68286831
linkage = SILLinkage::PublicExternal;
68296832

6830-
Scope S(&P, ScopeKind::TopLevel);
6831-
Scope Body(&P, ScopeKind::FunctionBody);
6833+
Scope scope(&P, ScopeKind::TopLevel);
6834+
Scope body(&P, ScopeKind::FunctionBody);
68326835

68336836
// Parse a SIL function name.
68346837
auto parseFunctionName = [&](SILFunction *&fn) -> bool {
@@ -6858,11 +6861,10 @@ bool SILParserTUState::parseSILDifferentiabilityWitness(Parser &P) {
68586861
// Parse an index subset, prefaced with the given label.
68596862
auto parseIndexSubset =
68606863
[&](StringRef label, IndexSubset *& indexSubset) -> bool {
6861-
if (P.parseToken(tok::l_square, diag::sil_diff_witness_expected_keyword,
6862-
"["))
6864+
if (P.parseToken(tok::l_square, diag::sil_diff_witness_expected_token, "["))
68636865
return true;
68646866
if (P.parseSpecificIdentifier(
6865-
label, diag::sil_diff_witness_expected_keyword, label))
6867+
label, diag::sil_diff_witness_expected_token, label))
68666868
return true;
68676869
// Parse parameter index list.
68686870
SmallVector<unsigned, 8> paramIndices;
@@ -6883,8 +6885,7 @@ bool SILParserTUState::parseSILDifferentiabilityWitness(Parser &P) {
68836885
while (P.Tok.isNot(tok::r_square))
68846886
if (parseParam())
68856887
return true;
6886-
if (P.parseToken(tok::r_square, diag::sil_diff_witness_expected_keyword,
6887-
"]"))
6888+
if (P.parseToken(tok::r_square, diag::sil_diff_witness_expected_token, "]"))
68886889
return true;
68896890
auto maxIndexRef =
68906891
std::max_element(paramIndices.begin(), paramIndices.end());
@@ -6911,8 +6912,7 @@ bool SILParserTUState::parseSILDifferentiabilityWitness(Parser &P) {
69116912
P.parseGenericWhereClause(whereLoc, derivativeRequirementReprs,
69126913
firstTypeInComplete,
69136914
/*AllowLayoutConstraints*/ false);
6914-
if (P.parseToken(tok::r_square, diag::sil_diff_witness_expected_keyword,
6915-
"]"))
6915+
if (P.parseToken(tok::r_square, diag::sil_diff_witness_expected_token, "]"))
69166916
return true;
69176917
}
69186918

@@ -6944,25 +6944,23 @@ bool SILParserTUState::parseSILDifferentiabilityWitness(Parser &P) {
69446944
SILFunction *vjp = nullptr;
69456945
if (P.Tok.is(tok::l_brace)) {
69466946
// Parse '{'.
6947-
SourceLoc lBraceLoc = P.Tok.getLoc();
6948-
P.consumeToken(tok::l_brace);
6947+
SourceLoc lBraceLoc;
6948+
P.consumeIf(tok::l_brace, lBraceLoc);
69496949
// Parse JVP (optional).
69506950
if (P.Tok.is(tok::identifier) && P.Tok.getText() == "jvp") {
69516951
P.consumeToken(tok::identifier);
6952-
if (P.parseToken(tok::colon, diag::sil_diff_witness_expected_keyword,
6953-
":"))
6952+
if (P.parseToken(tok::colon, diag::sil_diff_witness_expected_token, ":"))
69546953
return true;
6955-
Scope Body(&P, ScopeKind::FunctionBody);
6954+
Scope body(&P, ScopeKind::FunctionBody);
69566955
if (parseFunctionName(jvp))
69576956
return true;
69586957
}
69596958
// Parse VJP (optional).
69606959
if (P.Tok.is(tok::identifier) && P.Tok.getText() == "vjp") {
69616960
P.consumeToken(tok::identifier);
6962-
if (P.parseToken(tok::colon, diag::sil_diff_witness_expected_keyword,
6963-
":"))
6961+
if (P.parseToken(tok::colon, diag::sil_diff_witness_expected_token, ":"))
69646962
return true;
6965-
Scope Body(&P, ScopeKind::FunctionBody);
6963+
Scope body(&P, ScopeKind::FunctionBody);
69666964
if (parseFunctionName(vjp))
69676965
return true;
69686966
}

lib/SIL/SILDifferentiabilityWitness.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,7 @@ SILDifferentiabilityWitness *SILDifferentiabilityWitness::create(
2222
IndexSubset *parameterIndices, IndexSubset *resultIndices,
2323
GenericSignature *derivativeGenSig, SILFunction *jvp, SILFunction *vjp,
2424
bool isSerialized, DeclAttribute *attribute) {
25-
void *buf = module.allocate(sizeof(SILDifferentiabilityWitness),
26-
alignof(SILDifferentiabilityWitness));
27-
auto *diffWitness = ::new (buf) SILDifferentiabilityWitness(
25+
auto *diffWitness = new (module) SILDifferentiabilityWitness(
2826
module, linkage, originalFunction, parameterIndices, resultIndices,
2927
derivativeGenSig, jvp, vjp, isSerialized, attribute);
3028
// Register the differentiability witness in the module.

lib/SIL/SILPrinter.cpp

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3061,23 +3061,23 @@ void SILDefaultWitnessTable::dump() const {
30613061
void SILDifferentiabilityWitness::print(
30623062
llvm::raw_ostream &OS, bool verbose) const {
30633063
OS << "// differentiability witness for "
3064-
<< demangleSymbol(originalFunction->getName()) << "\n";
3065-
// sil_differentiability_witness @original-function-name : $original-sil-type
3064+
<< demangleSymbol(originalFunction->getName()) << '\n';
30663065
PrintOptions qualifiedSILTypeOptions = PrintOptions::printQualifiedSILType();
3066+
// sil_differentiability_witness (linkage)?
30673067
OS << "sil_differentiability_witness ";
30683068
printLinkage(OS, linkage, ForDefinition);
3069-
// [parameters 0 1 ...]
3069+
// [parameters ...]
30703070
OS << "[parameters ";
30713071
interleave(getParameterIndices()->getIndices(),
30723072
[&](unsigned index) { OS << index; },
3073-
[&] { OS << " "; });
3074-
// [results 0 1 ...]
3073+
[&] { OS << ' '; });
3074+
// [results ...]
30753075
OS << "] [results ";
30763076
interleave(getResultIndices()->getIndices(),
30773077
[&](unsigned index) { OS << index; },
3078-
[&] { OS << " "; });
3078+
[&] { OS << ' '; });
30793079
OS << ']';
3080-
// [where ...]
3080+
// ([where ...])?
30813081
if (auto *derivativeGenSig = getDerivativeGenericSignature()) {
30823082
ArrayRef<Requirement> requirements;
30833083
SmallVector<Requirement, 4> requirementsScratch;
@@ -3093,17 +3093,17 @@ void SILDifferentiabilityWitness::print(
30933093
}
30943094
if (!requirements.empty()) {
30953095
OS << " [where ";
3096-
auto SubPrinter = PrintOptions::printSIL();
3096+
auto subPrinter = PrintOptions::printSIL();
30973097
interleave(requirements,
30983098
[&](Requirement req) {
3099-
req.print(OS, SubPrinter);
3099+
req.print(OS, subPrinter);
31003100
return;
31013101
},
31023102
[&] { OS << ", "; });
31033103
OS << ']';
31043104
}
31053105
}
3106-
// original: @original-function-name : $original-sil-type
3106+
// @original-function-name : $original-sil-type
31073107
OS << " @" << originalFunction->getName() << " : "
31083108
<< originalFunction->getLoweredType();
31093109
// {
@@ -3112,9 +3112,9 @@ void SILDifferentiabilityWitness::print(
31123112
// }
31133113
OS << " {\n";
31143114
if (jvp)
3115-
OS << " jvp: @" << jvp->getName() << " : " << jvp->getLoweredType() << "\n";
3115+
OS << " jvp: @" << jvp->getName() << " : " << jvp->getLoweredType() << '\n';
31163116
if (vjp)
3117-
OS << " vjp: @" << vjp->getName() << " : " << vjp->getLoweredType() << "\n";
3117+
OS << " vjp: @" << vjp->getName() << " : " << vjp->getLoweredType() << '\n';
31183118
OS << "}\n\n";
31193119
}
31203120

lib/Serialization/DeserializeSIL.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3369,9 +3369,7 @@ SILDifferentiabilityWitness *SILDeserializer::lookupDifferentiabilityWitness(
33693369
auto iter = DifferentiabilityWitnessList->find(mangledDiffWitnessKey);
33703370
if (iter == DifferentiabilityWitnessList->end())
33713371
return nullptr;
3372-
3373-
auto *diffWitness = readDifferentiabilityWitness(*iter);
3374-
return diffWitness;
3372+
return readDifferentiabilityWitness(*iter);
33753373
}
33763374

33773375
void SILDeserializer::getAllDifferentiabilityWitnesses() {

lib/Serialization/SerializeSIL.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2300,6 +2300,7 @@ void SILSerializer::writeIndexTables() {
23002300
Offset.emit(ScratchRecord, sil_index_block::SIL_PROPERTY_OFFSETS,
23012301
PropertyOffset);
23022302
}
2303+
23032304
}
23042305

23052306
void SILSerializer::writeSILGlobalVar(const SILGlobalVariable &g) {
@@ -2517,7 +2518,7 @@ writeSILDifferentiabilityWitness(const SILDifferentiabilityWitness &dw) {
25172518
vjpID = S.addUniquedStringRef(vjp->getName());
25182519
}
25192520
SmallVector<unsigned, 8> parameterAndResultIndices(
2520-
dw.getParameterIndices()->begin(), dw.getParameterIndices()->end());
2521+
dw.getParameterIndices()->begin(), dw.getParameterIndices()->end());
25212522
parameterAndResultIndices.append(dw.getResultIndices()->begin(),
25222523
dw.getResultIndices()->end());
25232524
auto originalFnType = original->getLoweredFunctionType();

0 commit comments

Comments
 (0)