Skip to content

Commit 835f1c0

Browse files
committed
Clean up.
1 parent 843b631 commit 835f1c0

File tree

2 files changed

+19
-14
lines changed

2 files changed

+19
-14
lines changed

lib/ParseSIL/ParseSIL.cpp

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6830,7 +6830,8 @@ bool SILParserTUState::parseSILDifferentiabilityWitness(Parser &P) {
68306830
Scope S(&P, ScopeKind::TopLevel);
68316831
Scope Body(&P, ScopeKind::FunctionBody);
68326832

6833-
auto parseFunctionNameAndType = [&](SILFunction *&fn) -> bool {
6833+
// Parse a SIL function name.
6834+
auto parseFunctionName = [&](SILFunction *&fn) -> bool {
68346835
Identifier name;
68356836
SILType ty;
68366837
SourceLoc fnNameLoc = P.Tok.getLoc();
@@ -6852,13 +6853,13 @@ bool SILParserTUState::parseSILDifferentiabilityWitness(Parser &P) {
68526853
State.TUState.PotentialZombieFns.insert(fn);
68536854
return false;
68546855
};
6855-
6856-
SourceLoc lastLoc = P.getEndOfPreviousLoc();
6857-
6856+
// Parse original function name.
68586857
SILFunction *originalFn;
6859-
if (parseFunctionNameAndType(originalFn))
6858+
if (parseFunctionName(originalFn))
68606859
return true;
68616860

6861+
SourceLoc lastLoc = P.getEndOfPreviousLoc();
6862+
// Parse an index subset, prefaced with the given label.
68626863
auto parseAutoDiffIndexSubset =
68636864
[&](StringRef label, AutoDiffIndexSubset *& paramIndexSubset) -> bool {
68646865
if (P.parseSpecificIdentifier(
@@ -6895,15 +6896,17 @@ bool SILParserTUState::parseSILDifferentiabilityWitness(Parser &P) {
68956896
P.Context, maxIndexRef ? *maxIndexRef + 1 : 0, paramIndices);
68966897
return false;
68976898
};
6899+
// Parse parameter and result indices.
68986900
AutoDiffIndexSubset *parameterIndices = nullptr;
68996901
AutoDiffIndexSubset *resultIndices = nullptr;
69006902
if (parseAutoDiffIndexSubset("parameters", parameterIndices))
69016903
return true;
69026904
if (parseAutoDiffIndexSubset("results", resultIndices))
69036905
return true;
69046906

6907+
// Parse a trailing 'where' clause (optional).
6908+
// This represents derivative generic signature requirements.
69056909
GenericSignature *derivativeGenSig = nullptr;
6906-
// Parse a trailing 'where' clause if any.
69076910
if (P.Tok.is(tok::kw_where)) {
69086911
SourceLoc whereLoc;
69096912
SmallVector<RequirementRepr, 4> requirementReprs;
@@ -6925,34 +6928,36 @@ bool SILParserTUState::parseSILDifferentiabilityWitness(Parser &P) {
69256928
nullptr);
69266929
}
69276930

6931+
// Parse differentiability witness body.
69286932
SILFunction *jvp = nullptr;
69296933
SILFunction *vjp = nullptr;
69306934
if (P.Tok.is(tok::l_brace)) {
6931-
SourceLoc LBraceLoc = P.Tok.getLoc();
6935+
// Parse '{'.
6936+
SourceLoc lBraceLoc = P.Tok.getLoc();
69326937
P.consumeToken(tok::l_brace);
6933-
6938+
// Parse JVP (optional).
69346939
if (P.Tok.is(tok::identifier) && P.Tok.getText() == "jvp") {
69356940
P.consumeToken(tok::identifier);
69366941
if (P.parseToken(tok::colon, diag::sil_diff_witness_expected_keyword,
69376942
":"))
69386943
return true;
69396944
Scope Body(&P, ScopeKind::FunctionBody);
6940-
if (parseFunctionNameAndType(jvp))
6945+
if (parseFunctionName(jvp))
69416946
return true;
69426947
}
6943-
6948+
// Parse VJP (optional).
69446949
if (P.Tok.is(tok::identifier) && P.Tok.getText() == "vjp") {
69456950
P.consumeToken(tok::identifier);
69466951
if (P.parseToken(tok::colon, diag::sil_diff_witness_expected_keyword,
69476952
":"))
69486953
return true;
69496954
Scope Body(&P, ScopeKind::FunctionBody);
6950-
if (parseFunctionNameAndType(vjp))
6955+
if (parseFunctionName(vjp))
69516956
return true;
69526957
}
6953-
6958+
// Parse '}'.
69546959
if (P.parseMatchingToken(tok::r_brace, lastLoc, diag::expected_sil_rbrace,
6955-
LBraceLoc))
6960+
lBraceLoc))
69566961
return true;
69576962
}
69586963

test/AutoDiff/sil_differentiability_witness_parse.sil

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ sil_differentiability_witness hidden @foo : $@convention(thin) <τ_0_0> (@in_gua
120120
vjp: @AD__foo__vjp_src_0_wrt_0 : $@convention(thin) <τ_0_0 where τ_0_0 : _Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector) -> (@out τ_0_0.TangentVector, Float))
121121
}
122122

123-
// CHECK: sil_differentiability_witness hidden @foo : $@convention(thin) <τ_0_0> (@in_guaranteed τ_0_0, Float) -> @out τ_0_0 parameters (0, 1) results (0) where τ_0_0 : _Differentiable {
123+
// CHECK-LABEL: sil_differentiability_witness hidden @foo : $@convention(thin) <τ_0_0> (@in_guaranteed τ_0_0, Float) -> @out τ_0_0 parameters (0, 1) results (0) where τ_0_0 : _Differentiable {
124124
// CHECK: jvp: @AD__foo__jvp_src_0_wrt_0 : $@convention(thin) <τ_0_0 where τ_0_0 : _Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector)
125125
// CHECK: vjp: @AD__foo__vjp_src_0_wrt_0 : $@convention(thin) <τ_0_0 where τ_0_0 : _Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector) -> (@out τ_0_0.TangentVector, Float))
126126
// CHECK: }

0 commit comments

Comments
 (0)