@@ -6805,17 +6805,17 @@ static void convertRequirements(Parser &P, SILFunction *F,
6805
6805
6806
6806
// / decl-sil-differentiability-witness ::=
6807
6807
// / 'sil_differentiability_witness'
6808
+ // / '[' 'parameters' index-subset ']'
6809
+ // / '[' 'results' index-subset ']'
6810
+ // / ('[' 'where' derivatve-generic-signature-requirements ']')?
6808
6811
// / sil-function-name ':' sil-type
6809
- // / 'parameters' autodiff-index-subset
6810
- // / 'results' autodiff-index-subset
6811
- // / ('where' generic-signature)?
6812
6812
// / '{'
6813
6813
// / ('jvp' sil-function-name ':' sil-type)?
6814
6814
// / ('vjp' sil-function-name ':' sil-type)?
6815
6815
// / '}'
6816
6816
// /
6817
- // / autodiff- index-subset ::=
6818
- // / '(' [0-9]+ (',', [0-9]+)* ')'
6817
+ // / index-subset ::=
6818
+ // / [0-9]+ (' ' [0-9]+)*
6819
6819
bool SILParserTUState::parseSILDifferentiabilityWitness (Parser &P) {
6820
6820
P.consumeToken (tok::kw_sil_differentiability_witness);
6821
6821
SILParser State (P);
@@ -6853,21 +6853,17 @@ bool SILParserTUState::parseSILDifferentiabilityWitness(Parser &P) {
6853
6853
State.TUState .PotentialZombieFns .insert (fn);
6854
6854
return false ;
6855
6855
};
6856
- // Parse original function name.
6857
- SILFunction *originalFn;
6858
- if (parseFunctionName (originalFn))
6859
- return true ;
6860
6856
6861
6857
SourceLoc lastLoc = P.getEndOfPreviousLoc ();
6862
6858
// Parse an index subset, prefaced with the given label.
6863
6859
auto parseIndexSubset =
6864
6860
[&](StringRef label, IndexSubset *& indexSubset) -> bool {
6861
+ if (P.parseToken (tok::l_square, diag::sil_diff_witness_expected_keyword,
6862
+ " [" ))
6863
+ return true ;
6865
6864
if (P.parseSpecificIdentifier (
6866
6865
label, diag::sil_diff_witness_expected_keyword, label))
6867
6866
return true ;
6868
- if (P.parseToken (tok::l_paren, diag::sil_diff_witness_expected_keyword,
6869
- " (" ))
6870
- return true ;
6871
6867
// Parse parameter index list.
6872
6868
SmallVector<unsigned , 8 > paramIndices;
6873
6869
// Function that parses an index into `paramIndices`. Returns true on error.
@@ -6884,11 +6880,11 @@ bool SILParserTUState::parseSILDifferentiabilityWitness(Parser &P) {
6884
6880
if (parseParam ())
6885
6881
return true ;
6886
6882
// Parse rest.
6887
- while (P.consumeIf (tok::comma ))
6883
+ while (P.Tok . isNot (tok::r_square ))
6888
6884
if (parseParam ())
6889
6885
return true ;
6890
- if (P.parseToken (tok::r_paren , diag::sil_diff_witness_expected_keyword,
6891
- " ( " ))
6886
+ if (P.parseToken (tok::r_square , diag::sil_diff_witness_expected_keyword,
6887
+ " ] " ))
6892
6888
return true ;
6893
6889
auto maxIndexRef =
6894
6890
std::max_element (paramIndices.begin (), paramIndices.end ());
@@ -6907,18 +6903,33 @@ bool SILParserTUState::parseSILDifferentiabilityWitness(Parser &P) {
6907
6903
// Parse a trailing 'where' clause (optional).
6908
6904
// This represents derivative generic signature requirements.
6909
6905
GenericSignature *derivativeGenSig = nullptr ;
6910
- if (P.Tok .is (tok::kw_where)) {
6911
- SourceLoc whereLoc;
6912
- SmallVector<RequirementRepr, 4 > requirementReprs;
6906
+ SourceLoc whereLoc;
6907
+ SmallVector<RequirementRepr, 4 > derivativeRequirementReprs;
6908
+ if (P.Tok .is (tok::l_square) && P.peekToken ().is (tok::kw_where)) {
6909
+ P.consumeToken (tok::l_square);
6913
6910
bool firstTypeInComplete;
6914
- P.parseGenericWhereClause (whereLoc, requirementReprs, firstTypeInComplete,
6911
+ P.parseGenericWhereClause (whereLoc, derivativeRequirementReprs,
6912
+ firstTypeInComplete,
6915
6913
/* AllowLayoutConstraints*/ false );
6916
- auto *whereClause = TrailingWhereClause::create (
6917
- originalFn->getModule ().getASTContext (), whereLoc, requirementReprs);
6914
+ if (P.parseToken (tok::r_square, diag::sil_diff_witness_expected_keyword,
6915
+ " ]" ))
6916
+ return true ;
6917
+ }
6918
+
6919
+ // Parse original function name.
6920
+ SILFunction *originalFn;
6921
+ if (parseFunctionName (originalFn))
6922
+ return true ;
6923
+
6924
+ // Resolve derivative requirements.
6925
+ if (!derivativeRequirementReprs.empty ()) {
6918
6926
SmallVector<Requirement, 4 > requirements;
6927
+ auto *whereClause = TrailingWhereClause::create (
6928
+ originalFn->getModule ().getASTContext (), whereLoc,
6929
+ derivativeRequirementReprs);
6919
6930
convertRequirements (P, originalFn, whereClause->getRequirements (),
6920
6931
requirements);
6921
- assert (requirements.size () == requirementReprs .size ());
6932
+ assert (requirements.size () == derivativeRequirementReprs .size ());
6922
6933
derivativeGenSig = evaluateOrDefault (
6923
6934
P.Context .evaluator ,
6924
6935
AbstractGenericSignatureRequest{
0 commit comments