@@ -992,7 +992,7 @@ static bool parseDifferentiableAttr(
992
992
if (P.parseSpecificIdentifier (
993
993
" source" , diag::sil_attr_differentiable_expected_keyword, " source" ) ||
994
994
P.parseUnsignedInteger (SourceIndex, LastLoc,
995
- diag::sil_attr_differentiable_expected_source_index ))
995
+ diag::sil_autodiff_expected_result_index ))
996
996
return true ;
997
997
// Parse 'wrt'.
998
998
if (P.parseSpecificIdentifier (
@@ -1005,7 +1005,7 @@ static bool parseDifferentiableAttr(
1005
1005
unsigned Index;
1006
1006
// TODO: Reject non-ascending parameter index lists.
1007
1007
if (P.parseUnsignedInteger (Index, LastLoc,
1008
- diag::sil_attr_differentiable_expected_parameter_index ))
1008
+ diag::sil_autodiff_expected_parameter_index ))
1009
1009
return true ;
1010
1010
ParamIndices.push_back (Index);
1011
1011
return false ;
@@ -2939,12 +2939,11 @@ bool SILParser::parseSILInstruction(SILBuilder &B) {
2939
2939
while (P.Tok .is (tok::integer_literal)) {
2940
2940
unsigned index;
2941
2941
if (P.parseUnsignedInteger (index, lastLoc,
2942
- diag::sil_inst_autodiff_expected_parameter_index ))
2942
+ diag::sil_autodiff_expected_parameter_index ))
2943
2943
return true ;
2944
2944
parameterIndices.push_back (index);
2945
2945
}
2946
- if (P.parseToken (tok::r_square,
2947
- diag::sil_inst_autodiff_attr_expected_rsquare,
2946
+ if (P.parseToken (tok::r_square, diag::sil_autodiff_expected_rsquare,
2948
2947
" parameter index list" ))
2949
2948
return true ;
2950
2949
}
@@ -3002,12 +3001,11 @@ bool SILParser::parseSILInstruction(SILBuilder &B) {
3002
3001
while (P.Tok .is (tok::integer_literal)) {
3003
3002
unsigned index;
3004
3003
if (P.parseUnsignedInteger (index, lastLoc,
3005
- diag::sil_inst_autodiff_expected_parameter_index ))
3004
+ diag::sil_autodiff_expected_parameter_index ))
3006
3005
return true ;
3007
3006
parameterIndices.push_back (index);
3008
3007
}
3009
- if (P.parseToken (tok::r_square,
3010
- diag::sil_inst_autodiff_attr_expected_rsquare,
3008
+ if (P.parseToken (tok::r_square, diag::sil_autodiff_expected_rsquare,
3011
3009
" parameter index list" ))
3012
3010
return true ;
3013
3011
}
@@ -3050,8 +3048,7 @@ bool SILParser::parseSILInstruction(SILBuilder &B) {
3050
3048
diag::sil_inst_autodiff_expected_differentiable_extractee_kind) ||
3051
3049
parseSILIdentifierSwitch (extractee, extracteeNames,
3052
3050
diag::sil_inst_autodiff_expected_differentiable_extractee_kind) ||
3053
- P.parseToken (tok::r_square,
3054
- diag::sil_inst_autodiff_attr_expected_rsquare,
3051
+ P.parseToken (tok::r_square, diag::sil_autodiff_expected_rsquare,
3055
3052
" extractee kind" ))
3056
3053
return true ;
3057
3054
if (parseTypedValueRef (functionOperand, B) ||
@@ -3073,8 +3070,7 @@ bool SILParser::parseSILInstruction(SILBuilder &B) {
3073
3070
diag::sil_inst_autodiff_expected_linear_extractee_kind) ||
3074
3071
parseSILIdentifierSwitch (extractee, extracteeNames,
3075
3072
diag::sil_inst_autodiff_expected_linear_extractee_kind) ||
3076
- P.parseToken (tok::r_square,
3077
- diag::sil_inst_autodiff_attr_expected_rsquare,
3073
+ P.parseToken (tok::r_square, diag::sil_autodiff_expected_rsquare,
3078
3074
" extractee kind" ))
3079
3075
return true ;
3080
3076
if (parseTypedValueRef (functionOperand, B) ||
@@ -6941,47 +6937,35 @@ bool SILParserTUState::parseSILDifferentiabilityWitness(Parser &P) {
6941
6937
};
6942
6938
6943
6939
SourceLoc lastLoc = P.getEndOfPreviousLoc ();
6944
- // Parse an index subset, prefaced with the given label.
6945
- auto parseIndexSubset =
6946
- [&](StringRef label, IndexSubset *& indexSubset) -> bool {
6947
- if (P.parseToken (tok::l_square, diag::sil_diff_witness_expected_token, " [" ))
6948
- return true ;
6949
- if (P.parseSpecificIdentifier (
6950
- label, diag::sil_diff_witness_expected_token, label))
6951
- return true ;
6952
- // Parse parameter index list.
6953
- SmallVector<unsigned , 8 > paramIndices;
6954
- // Function that parses an index into `paramIndices`. Returns true on error.
6955
- auto parseParam = [&]() -> bool {
6940
+ // Parse an index set, prefaced with the given label.
6941
+ auto parseIndexSet = [&](StringRef label, SmallVectorImpl<unsigned > &indices,
6942
+ const Diagnostic &parseIndexDiag) -> bool {
6943
+ // Parse `[<label> <integer_literal>...]`.
6944
+ if (P.parseToken (tok::l_square, diag::sil_autodiff_expected_lsquare,
6945
+ " index list" ) ||
6946
+ P.parseSpecificIdentifier (
6947
+ label, diag::sil_autodiff_expected_index_list_label, label))
6948
+ return true ;
6949
+ while (P.Tok .is (tok::integer_literal)) {
6956
6950
unsigned index;
6957
- // TODO: Reject non-ascending index lists.
6958
- if (P.parseUnsignedInteger (index, lastLoc,
6959
- diag::sil_diff_witness_expected_index_list))
6951
+ if (P.parseUnsignedInteger (index, lastLoc, parseIndexDiag))
6960
6952
return true ;
6961
- paramIndices.push_back (index);
6962
- return false ;
6963
- };
6964
- // Parse first.
6965
- if (parseParam ())
6966
- return true ;
6967
- // Parse rest.
6968
- while (P.Tok .isNot (tok::r_square))
6969
- if (parseParam ())
6970
- return true ;
6971
- if (P.parseToken (tok::r_square, diag::sil_diff_witness_expected_token, " ]" ))
6953
+ indices.push_back (index);
6954
+ }
6955
+ if (P.parseToken (tok::r_square, diag::sil_autodiff_expected_rsquare,
6956
+ " index list" ))
6972
6957
return true ;
6973
- auto maxIndexRef =
6974
- std::max_element (paramIndices.begin (), paramIndices.end ());
6975
- indexSubset = IndexSubset::get (
6976
- P.Context , maxIndexRef ? *maxIndexRef + 1 : 0 , paramIndices);
6977
6958
return false ;
6978
6959
};
6979
6960
// Parse parameter and result indices.
6980
- IndexSubset *parameterIndices = nullptr ;
6981
- IndexSubset *resultIndices = nullptr ;
6982
- if (parseIndexSubset (" parameters" , parameterIndices))
6961
+ SmallVector<unsigned , 8 > parameterIndices;
6962
+ SmallVector<unsigned , 8 > resultIndices;
6963
+ // Parse parameter and result indices.
6964
+ if (parseIndexSet (" parameters" , parameterIndices,
6965
+ diag::sil_autodiff_expected_parameter_index))
6983
6966
return true ;
6984
- if (parseIndexSubset (" results" , resultIndices))
6967
+ if (parseIndexSet (" results" , resultIndices,
6968
+ diag::sil_autodiff_expected_result_index))
6985
6969
return true ;
6986
6970
6987
6971
// Parse a trailing 'where' clause (optional).
@@ -6995,7 +6979,8 @@ bool SILParserTUState::parseSILDifferentiabilityWitness(Parser &P) {
6995
6979
P.parseGenericWhereClause (whereLoc, derivativeRequirementReprs,
6996
6980
firstTypeInComplete,
6997
6981
/* AllowLayoutConstraints*/ false );
6998
- if (P.parseToken (tok::r_square, diag::sil_diff_witness_expected_token, " ]" ))
6982
+ if (P.parseToken (tok::r_square, diag::sil_autodiff_expected_rsquare,
6983
+ " where clause" ))
6999
6984
return true ;
7000
6985
}
7001
6986
@@ -7053,8 +7038,13 @@ bool SILParserTUState::parseSILDifferentiabilityWitness(Parser &P) {
7053
7038
return true ;
7054
7039
}
7055
7040
7041
+ auto origFnType = originalFn->getLoweredFunctionType ();
7042
+ auto *parameterIndexSet = IndexSubset::get (
7043
+ P.Context , origFnType->getNumParameters (), parameterIndices);
7044
+ auto *resultIndexSet = IndexSubset::get (
7045
+ P.Context , origFnType->getNumResults (), resultIndices);
7056
7046
SILDifferentiabilityWitness::create (
7057
- M, *linkage, originalFn, parameterIndices, resultIndices ,
7047
+ M, *linkage, originalFn, parameterIndexSet, resultIndexSet ,
7058
7048
derivativeGenSig, jvp, vjp, isSerialized);
7059
7049
return false ;
7060
7050
}
0 commit comments