Skip to content

Commit 14ee6c5

Browse files
authored
[AutoDiff] Enable @derivative attribute qualified declaration names. (#28892)
Enable qualified declaration names in `@derivative` attribute, just like `@transpose` attribute. `DerivativeAttr` now stores a base type `TypeRepr *`, which is non-null for parsed attributes that reference a qualified original declaration. Add `TypeResolutionFlags::AllowModule` flag to enable module lookup via `TypeChecker::lookupMember` given a `ModuleType`. Add tests for type-qualified and module-qualified declaration names. Resolves TF-1058.
1 parent e47543c commit 14ee6c5

File tree

10 files changed

+85
-45
lines changed

10 files changed

+85
-45
lines changed

include/swift/AST/Attr.h

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1791,6 +1791,11 @@ class DerivativeAttr final
17911791
private llvm::TrailingObjects<DerivativeAttr, ParsedAutoDiffParameter> {
17921792
friend TrailingObjects;
17931793

1794+
/// The base type repr for the referenced original function. This field is
1795+
/// non-null only for parsed attributes that reference a qualified original
1796+
/// declaration. This field is not serialized; type-checking uses it to
1797+
/// resolve the original declaration, which is serialized.
1798+
TypeRepr *BaseTypeRepr;
17941799
/// The original function name.
17951800
DeclNameRefWithLoc OriginalFunctionName;
17961801
/// The original function declaration, resolved by the type checker.
@@ -1803,23 +1808,27 @@ class DerivativeAttr final
18031808
Optional<AutoDiffDerivativeFunctionKind> Kind = None;
18041809

18051810
explicit DerivativeAttr(bool implicit, SourceLoc atLoc, SourceRange baseRange,
1806-
DeclNameRefWithLoc original,
1811+
TypeRepr *baseTypeRepr, DeclNameRefWithLoc original,
18071812
ArrayRef<ParsedAutoDiffParameter> params);
18081813

18091814
explicit DerivativeAttr(bool implicit, SourceLoc atLoc, SourceRange baseRange,
1810-
DeclNameRefWithLoc original, IndexSubset *indices);
1815+
TypeRepr *baseTypeRepr, DeclNameRefWithLoc original,
1816+
IndexSubset *parameterIndices);
18111817

18121818
public:
18131819
static DerivativeAttr *create(ASTContext &context, bool implicit,
18141820
SourceLoc atLoc, SourceRange baseRange,
1821+
TypeRepr *baseTypeRepr,
18151822
DeclNameRefWithLoc original,
18161823
ArrayRef<ParsedAutoDiffParameter> params);
18171824

18181825
static DerivativeAttr *create(ASTContext &context, bool implicit,
18191826
SourceLoc atLoc, SourceRange baseRange,
1827+
TypeRepr *baseTypeRepr,
18201828
DeclNameRefWithLoc original,
1821-
IndexSubset *indices);
1829+
IndexSubset *parameterIndices);
18221830

1831+
TypeRepr *getBaseTypeRepr() const { return BaseTypeRepr; }
18231832
DeclNameRefWithLoc getOriginalFunctionName() const {
18241833
return OriginalFunctionName;
18251834
}
@@ -1876,9 +1885,10 @@ class TransposeAttr final
18761885
private llvm::TrailingObjects<TransposeAttr, ParsedAutoDiffParameter> {
18771886
friend TrailingObjects;
18781887

1879-
/// The base type of the original function.
1880-
/// This is non-null only when the original function is not top-level (i.e. it
1881-
/// is an instance/static method).
1888+
/// The base type repr for the referenced original function. This field is
1889+
/// non-null only for parsed attributes that reference a qualified original
1890+
/// declaration. This field is not serialized; type-checking uses it to
1891+
/// resolve the original declaration, which is serialized.
18821892
TypeRepr *BaseTypeRepr;
18831893
/// The original function name.
18841894
DeclNameRefWithLoc OriginalFunctionName;
@@ -1895,7 +1905,7 @@ class TransposeAttr final
18951905

18961906
explicit TransposeAttr(bool implicit, SourceLoc atLoc, SourceRange baseRange,
18971907
TypeRepr *baseType, DeclNameRefWithLoc original,
1898-
IndexSubset *indices);
1908+
IndexSubset *parameterIndices);
18991909

19001910
public:
19011911
static TransposeAttr *create(ASTContext &context, bool implicit,
@@ -1906,7 +1916,7 @@ class TransposeAttr final
19061916
static TransposeAttr *create(ASTContext &context, bool implicit,
19071917
SourceLoc atLoc, SourceRange baseRange,
19081918
TypeRepr *baseType, DeclNameRefWithLoc original,
1909-
IndexSubset *indices);
1919+
IndexSubset *parameterIndices);
19101920

19111921
TypeRepr *getBaseTypeRepr() const { return BaseTypeRepr; }
19121922
DeclNameRefWithLoc getOriginalFunctionName() const {

lib/AST/Attr.cpp

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1518,41 +1518,43 @@ void DifferentiableAttr::print(llvm::raw_ostream &OS, const Decl *D,
15181518
}
15191519

15201520
DerivativeAttr::DerivativeAttr(bool implicit, SourceLoc atLoc,
1521-
SourceRange baseRange,
1521+
SourceRange baseRange, TypeRepr *baseTypeRepr,
15221522
DeclNameRefWithLoc originalName,
15231523
ArrayRef<ParsedAutoDiffParameter> params)
15241524
: DeclAttribute(DAK_Derivative, atLoc, baseRange, implicit),
1525-
OriginalFunctionName(std::move(originalName)),
1525+
BaseTypeRepr(baseTypeRepr), OriginalFunctionName(std::move(originalName)),
15261526
NumParsedParameters(params.size()) {
15271527
std::copy(params.begin(), params.end(),
15281528
getTrailingObjects<ParsedAutoDiffParameter>());
15291529
}
15301530

15311531
DerivativeAttr::DerivativeAttr(bool implicit, SourceLoc atLoc,
1532-
SourceRange baseRange,
1532+
SourceRange baseRange, TypeRepr *baseTypeRepr,
15331533
DeclNameRefWithLoc originalName,
1534-
IndexSubset *indices)
1534+
IndexSubset *parameterIndices)
15351535
: DeclAttribute(DAK_Derivative, atLoc, baseRange, implicit),
1536-
OriginalFunctionName(std::move(originalName)), ParameterIndices(indices) {
1537-
}
1536+
BaseTypeRepr(baseTypeRepr), OriginalFunctionName(std::move(originalName)),
1537+
ParameterIndices(parameterIndices) {}
15381538

15391539
DerivativeAttr *
15401540
DerivativeAttr::create(ASTContext &context, bool implicit, SourceLoc atLoc,
1541-
SourceRange baseRange, DeclNameRefWithLoc originalName,
1541+
SourceRange baseRange, TypeRepr *baseTypeRepr,
1542+
DeclNameRefWithLoc originalName,
15421543
ArrayRef<ParsedAutoDiffParameter> params) {
15431544
unsigned size = totalSizeToAlloc<ParsedAutoDiffParameter>(params.size());
15441545
void *mem = context.Allocate(size, alignof(DerivativeAttr));
1545-
return new (mem) DerivativeAttr(implicit, atLoc, baseRange,
1546+
return new (mem) DerivativeAttr(implicit, atLoc, baseRange, baseTypeRepr,
15461547
std::move(originalName), params);
15471548
}
15481549

15491550
DerivativeAttr *DerivativeAttr::create(ASTContext &context, bool implicit,
15501551
SourceLoc atLoc, SourceRange baseRange,
1552+
TypeRepr *baseTypeRepr,
15511553
DeclNameRefWithLoc originalName,
1552-
IndexSubset *indices) {
1554+
IndexSubset *parameterIndices) {
15531555
void *mem = context.Allocate(sizeof(DerivativeAttr), alignof(DerivativeAttr));
1554-
return new (mem) DerivativeAttr(implicit, atLoc, baseRange,
1555-
std::move(originalName), indices);
1556+
return new (mem) DerivativeAttr(implicit, atLoc, baseRange, baseTypeRepr,
1557+
std::move(originalName), parameterIndices);
15561558
}
15571559

15581560
TransposeAttr::TransposeAttr(bool implicit, SourceLoc atLoc,
@@ -1568,10 +1570,11 @@ TransposeAttr::TransposeAttr(bool implicit, SourceLoc atLoc,
15681570

15691571
TransposeAttr::TransposeAttr(bool implicit, SourceLoc atLoc,
15701572
SourceRange baseRange, TypeRepr *baseTypeRepr,
1571-
DeclNameRefWithLoc originalName, IndexSubset *indices)
1573+
DeclNameRefWithLoc originalName,
1574+
IndexSubset *parameterIndices)
15721575
: DeclAttribute(DAK_Transpose, atLoc, baseRange, implicit),
15731576
BaseTypeRepr(baseTypeRepr), OriginalFunctionName(std::move(originalName)),
1574-
ParameterIndices(indices) {}
1577+
ParameterIndices(parameterIndices) {}
15751578

15761579
TransposeAttr *TransposeAttr::create(ASTContext &context, bool implicit,
15771580
SourceLoc atLoc, SourceRange baseRange,
@@ -1588,10 +1591,10 @@ TransposeAttr *TransposeAttr::create(ASTContext &context, bool implicit,
15881591
SourceLoc atLoc, SourceRange baseRange,
15891592
TypeRepr *baseType,
15901593
DeclNameRefWithLoc originalName,
1591-
IndexSubset *indices) {
1594+
IndexSubset *parameterIndices) {
15921595
void *mem = context.Allocate(sizeof(TransposeAttr), alignof(TransposeAttr));
15931596
return new (mem) TransposeAttr(implicit, atLoc, baseRange, baseType,
1594-
std::move(originalName), indices);
1597+
std::move(originalName), parameterIndices);
15951598
}
15961599

15971600
ImplementsAttr::ImplementsAttr(SourceLoc atLoc, SourceRange range,

lib/Parse/ParseDecl.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1111,7 +1111,8 @@ static bool parseBaseTypeForQualifiedDeclName(Parser &P, TypeRepr *&baseType) {
11111111
return false;
11121112
}
11131113

1114-
/// parseQualifiedDeclName
1114+
/// Parses an optional base type, followed by a declaration name.
1115+
/// Returns true on error (if declaration name could not be parsed).
11151116
///
11161117
/// \verbatim
11171118
/// qualified-decl-name:
@@ -1120,8 +1121,6 @@ static bool parseBaseTypeForQualifiedDeclName(Parser &P, TypeRepr *&baseType) {
11201121
/// identifier generic-args? ('.' identifier generic-args?)*
11211122
/// \endverbatim
11221123
///
1123-
/// Parses an optional base type, followed by a declaration name.
1124-
/// Returns true on error (if declaration name could not be parsed).
11251124
// TODO(TF-1066): Use module qualified name syntax/parsing instead of custom
11261125
// qualified name syntax/parsing.
11271126
static bool parseQualifiedDeclName(Parser &P, Diag<> nameParseError,
@@ -1147,7 +1146,8 @@ static bool parseQualifiedDeclName(Parser &P, Diag<> nameParseError,
11471146
///
11481147
/// \verbatim
11491148
/// derivative-attribute-arguments:
1150-
/// '(' 'of' ':' decl-name (',' differentiation-params-clause)? ')'
1149+
/// '(' 'of' ':' qualified-decl-name (',' differentiation-params-clause)?
1150+
/// ')'
11511151
/// \endverbatim
11521152
ParserResult<DerivativeAttr> Parser::parseDerivativeAttribute(SourceLoc atLoc,
11531153
SourceLoc loc) {
@@ -1206,16 +1206,16 @@ ParserResult<DerivativeAttr> Parser::parseDerivativeAttribute(SourceLoc atLoc,
12061206
/*DeclModifier*/ false);
12071207
return makeParserError();
12081208
}
1209-
return ParserResult<DerivativeAttr>(
1210-
DerivativeAttr::create(Context, /*implicit*/ false, atLoc,
1211-
SourceRange(loc, rParenLoc), original, params));
1209+
return ParserResult<DerivativeAttr>(DerivativeAttr::create(
1210+
Context, /*implicit*/ false, atLoc, SourceRange(loc, rParenLoc), baseType,
1211+
original, params));
12121212
}
12131213

12141214
/// Parse a `@transpose(of:)` attribute, returning true on error.
12151215
///
12161216
/// \verbatim
12171217
/// transpose-attribute-arguments:
1218-
/// '(' 'of' ':' decl-name (',' transposed-params-clause)? ')'
1218+
/// '(' 'of' ':' qualified-decl-name (',' transposed-params-clause)? ')'
12191219
/// \endverbatim
12201220
ParserResult<TransposeAttr> Parser::parseTransposeAttribute(SourceLoc atLoc,
12211221
SourceLoc loc) {

lib/Sema/TypeCheckAttr.cpp

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3547,16 +3547,26 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D,
35473547
return derivative->getParent() == func->getParent();
35483548
};
35493549

3550-
auto lookupOptions =
3551-
defaultMemberLookupOptions | NameLookupFlags::IgnoreAccessControl;
3550+
auto resolution = TypeResolution::forContextual(derivative->getDeclContext());
3551+
Type baseType;
3552+
if (auto *baseTypeRepr = attr->getBaseTypeRepr()) {
3553+
TypeResolutionOptions options = None;
3554+
options |= TypeResolutionFlags::AllowModule;
3555+
baseType = resolution.resolveType(baseTypeRepr, options);
3556+
}
3557+
if (baseType && baseType->hasError())
3558+
return;
3559+
auto lookupOptions = attr->getBaseTypeRepr()
3560+
? defaultMemberLookupOptions
3561+
: defaultUnqualifiedLookupOptions;
35523562
auto derivativeTypeCtx = derivative->getInnermostTypeContext();
35533563
if (!derivativeTypeCtx)
35543564
derivativeTypeCtx = derivative->getParent();
35553565
assert(derivativeTypeCtx);
35563566

35573567
// Look up original function.
35583568
auto *originalAFD = findAbstractFunctionDecl(
3559-
originalName.Name, originalName.Loc.getBaseNameLoc(), /*baseType*/ Type(),
3569+
originalName.Name, originalName.Loc.getBaseNameLoc(), baseType,
35603570
derivativeTypeCtx, isValidOriginal, noneValidDiagnostic,
35613571
ambiguousDiagnostic, notFunctionDiagnostic, lookupOptions,
35623572
hasValidTypeContext, invalidTypeContextDiagnostic);
@@ -3678,7 +3688,7 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D,
36783688
}
36793689

36803690
// Reject different-file derivative registration.
3681-
// TODO(TF-1021): Lift this restriction.
3691+
// TODO(TF-1021): Lift same-file derivative registration restriction.
36823692
if (originalAFD->getParentSourceFile() != derivative->getParentSourceFile()) {
36833693
diags.diagnose(attr->getLocation(),
36843694
diag::derivative_attr_not_in_same_file_as_original);

lib/Sema/TypeCheckType.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1644,6 +1644,10 @@ Type TypeChecker::resolveIdentifierType(
16441644
if (!result) return nullptr;
16451645

16461646
if (auto moduleTy = result->getAs<ModuleType>()) {
1647+
// Allow module types only if flag is specified.
1648+
if (options.contains(TypeResolutionFlags::AllowModule))
1649+
return moduleTy;
1650+
// Otherwise, emit an error.
16471651
if (!options.contains(TypeResolutionFlags::SilenceErrors)) {
16481652
auto moduleName = moduleTy->getModule()->getName();
16491653
diags.diagnose(Components.back()->getNameLoc(),

lib/Sema/TypeCheckType.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,9 @@ enum class TypeResolutionFlags : uint16_t {
6666

6767
/// Whether we should not produce diagnostics if the type is invalid.
6868
SilenceErrors = 1 << 10,
69+
70+
/// Whether to allow module declaration types.
71+
AllowModule = 1 << 11
6972
};
7073

7174
/// Type resolution contexts that require special handling.

lib/Serialization/Deserialization.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4202,11 +4202,12 @@ llvm::Error DeclDeserializer::deserializeDeclAttributes() {
42024202
parametersBitVector[i] = parameters[i];
42034203
auto *indices = IndexSubset::get(ctx, parametersBitVector);
42044204

4205-
auto *derivAttr = DerivativeAttr::create(
4206-
ctx, isImplicit, SourceLoc(), SourceRange(), origName, indices);
4207-
derivAttr->setOriginalFunction(origDecl);
4208-
derivAttr->setDerivativeKind(*derivativeKind);
4209-
Attr = derivAttr;
4205+
auto *derivativeAttr =
4206+
DerivativeAttr::create(ctx, isImplicit, SourceLoc(), SourceRange(),
4207+
/*baseType*/ nullptr, origName, indices);
4208+
derivativeAttr->setOriginalFunction(origDecl);
4209+
derivativeAttr->setDerivativeKind(*derivativeKind);
4210+
Attr = derivativeAttr;
42104211
break;
42114212
}
42124213

test/AutoDiff/Sema/derivative_attr_type_checking.swift

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,8 @@ extension StaticMethod {
195195
return (x, { $0 })
196196
}
197197

198-
@derivative(of: foo)
198+
// Test qualified declaration name.
199+
@derivative(of: StaticMethod.foo)
199200
static func vjpFoo(x: Float) -> (value: Float, pullback: (Float) -> Float) {
200201
return (x, { $0 })
201202
}
@@ -232,6 +233,14 @@ extension InstanceMethod {
232233
return (x, { $0 + $1 })
233234
}
234235

236+
// Test qualified declaration name.
237+
@derivative(of: InstanceMethod.foo, wrt: x)
238+
func jvpFooWrtX(x: Self) -> (
239+
value: Self, differential: (TangentVector) -> (TangentVector)
240+
) {
241+
return (x, { $0 })
242+
}
243+
235244
@derivative(of: generic)
236245
func vjpGeneric<T: Differentiable>(_ x: T) -> (
237246
value: Self, pullback: (TangentVector) -> (TangentVector, T.TangentVector)

0 commit comments

Comments
 (0)