Skip to content

Commit eefe9a0

Browse files
author
marcrasi
authored
Merge pull request #30851 from apple/derivative-attr-serialization
[AutoDiff] SR-12526: cross-module @Derivative deserialization
2 parents 99993a9 + 7abf8ae commit eefe9a0

File tree

16 files changed

+116
-30
lines changed

16 files changed

+116
-30
lines changed

include/swift/AST/ASTTypeIDZone.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ SWIFT_TYPEID(Type)
2929
SWIFT_TYPEID(TypePair)
3030
SWIFT_TYPEID(TypeWitnessAndDecl)
3131
SWIFT_TYPEID(Witness)
32+
SWIFT_TYPEID_NAMED(AbstractFunctionDecl *, AbstractFunctionDecl)
3233
SWIFT_TYPEID_NAMED(ClosureExpr *, ClosureExpr)
3334
SWIFT_TYPEID_NAMED(CodeCompletionCallbacksFactory *,
3435
CodeCompletionCallbacksFactory)

include/swift/AST/Attr.h

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1865,6 +1865,7 @@ class DerivativeAttr final
18651865
: public DeclAttribute,
18661866
private llvm::TrailingObjects<DerivativeAttr, ParsedAutoDiffParameter> {
18671867
friend TrailingObjects;
1868+
friend class DerivativeAttrOriginalDeclRequest;
18681869

18691870
/// The base type for the referenced original declaration. This field is
18701871
/// non-null only for parsed attributes that reference a qualified original
@@ -1873,8 +1874,24 @@ class DerivativeAttr final
18731874
TypeRepr *BaseTypeRepr;
18741875
/// The original function name.
18751876
DeclNameRefWithLoc OriginalFunctionName;
1876-
/// The original function declaration, resolved by the type checker.
1877-
AbstractFunctionDecl *OriginalFunction = nullptr;
1877+
/// The original function.
1878+
///
1879+
/// The states are:
1880+
/// - nullptr:
1881+
/// The original function is unknown. The typechecker is responsible for
1882+
/// eventually resolving it.
1883+
/// - AbstractFunctionDecl:
1884+
/// The original function is known to be this `AbstractFunctionDecl`.
1885+
/// - LazyMemberLoader:
1886+
/// This `LazyMemberLoader` knows how to resolve the original function.
1887+
/// `ResolverContextData` is an additional piece of data that the
1888+
/// `LazyMemberLoader` needs.
1889+
// TODO(TF-1235): Making `DerivativeAttr` immutable will simplify this by
1890+
// removing the `AbstractFunctionDecl` state.
1891+
llvm::PointerUnion<AbstractFunctionDecl *, LazyMemberLoader *> OriginalFunction;
1892+
/// Data representing the original function declaration. See doc comment for
1893+
/// `OriginalFunction`.
1894+
uint64_t ResolverContextData = 0;
18781895
/// The number of parsed differentiability parameters specified in 'wrt:'.
18791896
unsigned NumParsedParameters = 0;
18801897
/// The differentiability parameter indices, resolved by the type checker.
@@ -1907,12 +1924,10 @@ class DerivativeAttr final
19071924
DeclNameRefWithLoc getOriginalFunctionName() const {
19081925
return OriginalFunctionName;
19091926
}
1910-
AbstractFunctionDecl *getOriginalFunction() const {
1911-
return OriginalFunction;
1912-
}
1913-
void setOriginalFunction(AbstractFunctionDecl *decl) {
1914-
OriginalFunction = decl;
1915-
}
1927+
AbstractFunctionDecl *getOriginalFunction(ASTContext &context) const;
1928+
void setOriginalFunction(AbstractFunctionDecl *decl);
1929+
void setOriginalFunctionResolver(LazyMemberLoader *resolver,
1930+
uint64_t resolverContextData);
19161931

19171932
AutoDiffDerivativeFunctionKind getDerivativeKind() const {
19181933
assert(Kind && "Derivative function kind has not yet been resolved");

include/swift/AST/LazyResolver.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,12 @@ class alignas(void*) LazyMemberLoader {
106106
loadDynamicallyReplacedFunctionDecl(const DynamicReplacementAttr *DRA,
107107
uint64_t contextData) = 0;
108108

109+
/// Returns the referenced original declaration for a `@derivative(of:)`
110+
/// attribute.
111+
virtual AbstractFunctionDecl *
112+
loadReferencedFunctionDecl(const DerivativeAttr *DA,
113+
uint64_t contextData) = 0;
114+
109115
/// Returns the type for a given @_typeEraser() attribute.
110116
virtual Type loadTypeEraserType(const TypeEraserAttr *TRA,
111117
uint64_t contextData) = 0;

include/swift/AST/TypeCheckRequests.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2130,6 +2130,26 @@ class DifferentiableAttributeTypeCheckRequest
21302130
void cacheResult(IndexSubset *value) const;
21312131
};
21322132

2133+
/// Resolves the referenced original declaration for a `@derivative` attribute.
2134+
class DerivativeAttrOriginalDeclRequest
2135+
: public SimpleRequest<DerivativeAttrOriginalDeclRequest,
2136+
AbstractFunctionDecl *(DerivativeAttr *),
2137+
RequestFlags::Cached> {
2138+
public:
2139+
using SimpleRequest::SimpleRequest;
2140+
2141+
private:
2142+
friend SimpleRequest;
2143+
2144+
// Evaluation.
2145+
AbstractFunctionDecl *evaluate(Evaluator &evaluator,
2146+
DerivativeAttr *attr) const;
2147+
2148+
public:
2149+
// Caching.
2150+
bool isCached() const { return true; }
2151+
};
2152+
21332153
/// Checks whether a type eraser has a viable initializer.
21342154
class TypeEraserHasViableInitRequest
21352155
: public SimpleRequest<TypeEraserHasViableInitRequest,

include/swift/AST/TypeCheckerTypeIDZone.def

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,9 @@ SWIFT_REQUEST(TypeChecker, DefaultTypeRequest,
4949
SWIFT_REQUEST(TypeChecker, DifferentiableAttributeTypeCheckRequest,
5050
IndexSubset *(DifferentiableAttr *),
5151
SeparatelyCached, NoLocationInfo)
52+
SWIFT_REQUEST(TypeChecker, DerivativeAttrOriginalDeclRequest,
53+
AbstractFunctionDecl *(DerivativeAttr *),
54+
Cached, NoLocationInfo)
5255
SWIFT_REQUEST(TypeChecker, TypeEraserHasViableInitRequest,
5356
bool(TypeEraserAttr *, ProtocolDecl *),
5457
Cached, NoLocationInfo)

lib/AST/Attr.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "swift/AST/Expr.h"
2222
#include "swift/AST/GenericEnvironment.h"
2323
#include "swift/AST/IndexSubset.h"
24+
#include "swift/AST/LazyResolver.h"
2425
#include "swift/AST/Module.h"
2526
#include "swift/AST/ParameterList.h"
2627
#include "swift/AST/TypeCheckRequests.h"
@@ -1750,6 +1751,26 @@ DerivativeAttr *DerivativeAttr::create(ASTContext &context, bool implicit,
17501751
std::move(originalName), parameterIndices);
17511752
}
17521753

1754+
AbstractFunctionDecl *
1755+
DerivativeAttr::getOriginalFunction(ASTContext &context) const {
1756+
return evaluateOrDefault(
1757+
context.evaluator,
1758+
DerivativeAttrOriginalDeclRequest{const_cast<DerivativeAttr *>(this)},
1759+
nullptr);
1760+
}
1761+
1762+
void DerivativeAttr::setOriginalFunction(AbstractFunctionDecl *decl) {
1763+
assert(!OriginalFunction && "cannot overwrite original function");
1764+
OriginalFunction = decl;
1765+
}
1766+
1767+
void DerivativeAttr::setOriginalFunctionResolver(
1768+
LazyMemberLoader *resolver, uint64_t resolverContextData) {
1769+
assert(!OriginalFunction && "cannot overwrite original function");
1770+
OriginalFunction = resolver;
1771+
ResolverContextData = resolverContextData;
1772+
}
1773+
17531774
TransposeAttr::TransposeAttr(bool implicit, SourceLoc atLoc,
17541775
SourceRange baseRange, TypeRepr *baseTypeRepr,
17551776
DeclNameRefWithLoc originalName,

lib/ClangImporter/ImporterImpl.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1290,6 +1290,12 @@ class LLVM_LIBRARY_VISIBILITY ClangImporter::Implementation
12901290
llvm_unreachable("unimplemented for ClangImporter");
12911291
}
12921292

1293+
AbstractFunctionDecl *
1294+
loadReferencedFunctionDecl(const DerivativeAttr *DA,
1295+
uint64_t contextData) override {
1296+
llvm_unreachable("unimplemented for ClangImporter");
1297+
}
1298+
12931299
Type loadTypeEraserType(const TypeEraserAttr *TRA,
12941300
uint64_t contextData) override {
12951301
llvm_unreachable("unimplemented for ClangImporter");

lib/SILGen/SILGen.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -791,7 +791,7 @@ void SILGenModule::emitDifferentiabilityWitnessesForFunction(
791791
vjp = F;
792792
break;
793793
}
794-
auto *origAFD = derivAttr->getOriginalFunction();
794+
auto *origAFD = derivAttr->getOriginalFunction(getASTContext());
795795
auto origDeclRef =
796796
SILDeclRef(origAFD).asForeign(requiresForeignEntryPoint(origAFD));
797797
auto *origFn = getFunction(origDeclRef, NotForDefinition);

lib/Sema/TypeCheckAttr.cpp

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3705,8 +3705,6 @@ getTransposeOriginalFunctionType(AnyFunctionType *transposeFnType,
37053705
return originalType;
37063706
}
37073707

3708-
3709-
37103708
/// Given a `@differentiable` attribute, attempts to resolve the original
37113709
/// `AbstractFunctionDecl` for which it is registered, using the declaration
37123710
/// on which it is actually declared. On error, emits diagnostic and returns
@@ -4454,6 +4452,21 @@ void AttributeChecker::visitDerivativeAttr(DerivativeAttr *attr) {
44544452
attr->setInvalid();
44554453
}
44564454

4455+
AbstractFunctionDecl *
4456+
DerivativeAttrOriginalDeclRequest::evaluate(Evaluator &evaluator,
4457+
DerivativeAttr *attr) const {
4458+
// If the typechecker has resolved the original function, return it.
4459+
if (auto *FD = attr->OriginalFunction.dyn_cast<AbstractFunctionDecl *>())
4460+
return FD;
4461+
4462+
// If the function can be lazily resolved, do so now.
4463+
if (auto *Resolver = attr->OriginalFunction.dyn_cast<LazyMemberLoader *>())
4464+
return Resolver->loadReferencedFunctionDecl(attr,
4465+
attr->ResolverContextData);
4466+
4467+
return nullptr;
4468+
}
4469+
44574470
/// Returns true if the given type's `TangentVector` is equal to itself in the
44584471
/// given module.
44594472
static bool tangentVectorEqualsSelf(Type type, DeclContext *DC) {

lib/Serialization/Deserialization.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4379,7 +4379,6 @@ llvm::Error DeclDeserializer::deserializeDeclAttributes() {
43794379

43804380
DeclNameRefWithLoc origName{
43814381
DeclNameRef(MF.getDeclBaseName(origNameId)), DeclNameLoc()};
4382-
auto *origDecl = cast<AbstractFunctionDecl>(MF.getDecl(origDeclId));
43834382
auto derivativeKind =
43844383
getActualAutoDiffDerivativeFunctionKind(rawDerivativeKind);
43854384
if (!derivativeKind)
@@ -4392,7 +4391,7 @@ llvm::Error DeclDeserializer::deserializeDeclAttributes() {
43924391
auto *derivativeAttr =
43934392
DerivativeAttr::create(ctx, isImplicit, SourceLoc(), SourceRange(),
43944393
/*baseType*/ nullptr, origName, indices);
4395-
derivativeAttr->setOriginalFunction(origDecl);
4394+
derivativeAttr->setOriginalFunctionResolver(&MF, origDeclId);
43964395
derivativeAttr->setDerivativeKind(*derivativeKind);
43974396
Attr = derivativeAttr;
43984397
break;
@@ -5941,6 +5940,12 @@ ValueDecl *ModuleFile::loadDynamicallyReplacedFunctionDecl(
59415940
return cast<ValueDecl>(getDecl(contextData));
59425941
}
59435942

5943+
AbstractFunctionDecl *
5944+
ModuleFile::loadReferencedFunctionDecl(const DerivativeAttr *DA,
5945+
uint64_t contextData) {
5946+
return cast<AbstractFunctionDecl>(getDecl(contextData));
5947+
}
5948+
59445949
Type ModuleFile::loadTypeEraserType(const TypeEraserAttr *TRA,
59455950
uint64_t contextData) {
59465951
return getType(contextData);

lib/Serialization/ModuleFile.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -882,6 +882,10 @@ class ModuleFile
882882
loadDynamicallyReplacedFunctionDecl(const DynamicReplacementAttr *DRA,
883883
uint64_t contextData) override;
884884

885+
virtual AbstractFunctionDecl *
886+
loadReferencedFunctionDecl(const DerivativeAttr *DA,
887+
uint64_t contextData) override;
888+
885889
virtual Type loadTypeEraserType(const TypeEraserAttr *TRA,
886890
uint64_t contextData) override;
887891

lib/Serialization/Serialization.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2417,12 +2417,13 @@ class Serializer::DeclSerializer : public DeclVisitor<DeclSerializer> {
24172417
case DAK_Derivative: {
24182418
auto abbrCode = S.DeclTypeAbbrCodes[DerivativeDeclAttrLayout::Code];
24192419
auto *attr = cast<DerivativeAttr>(DA);
2420-
assert(attr->getOriginalFunction() &&
2420+
auto &ctx = S.getASTContext();
2421+
assert(attr->getOriginalFunction(ctx) &&
24212422
"`@derivative` attribute should have original declaration set "
24222423
"during construction or parsing");
24232424
auto origName = attr->getOriginalFunctionName().Name.getBaseName();
24242425
IdentifierID origNameId = S.addDeclBaseNameRef(origName);
2425-
DeclID origDeclID = S.addDeclRef(attr->getOriginalFunction());
2426+
DeclID origDeclID = S.addDeclRef(attr->getOriginalFunction(ctx));
24262427
auto derivativeKind =
24272428
getRawStableAutoDiffDerivativeFunctionKind(attr->getDerivativeKind());
24282429
auto *parameterIndices = attr->getParameterIndices();
@@ -4862,7 +4863,7 @@ static void recordDerivativeFunctionConfig(
48624863
attr->getDerivativeGenericSignature()});
48634864
}
48644865
for (auto *attr : AFD->getAttrs().getAttributes<DerivativeAttr>()) {
4865-
auto *origAFD = attr->getOriginalFunction();
4866+
auto *origAFD = attr->getOriginalFunction(ctx);
48664867
auto mangledName = ctx.getIdentifier(Mangler.mangleDeclAsUSR(origAFD, ""));
48674868
derivativeConfigs[mangledName].insert(
48684869
{ctx.getIdentifier(attr->getParameterIndices()->getString()),

lib/TBDGen/TBDGen.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -665,7 +665,7 @@ void TBDGenVisitor::visitAbstractFunctionDecl(AbstractFunctionDecl *AFD) {
665665
for (const auto *derivativeAttr :
666666
AFD->getAttrs().getAttributes<DerivativeAttr>())
667667
addDerivativeConfiguration(
668-
derivativeAttr->getOriginalFunction(),
668+
derivativeAttr->getOriginalFunction(AFD->getASTContext()),
669669
AutoDiffConfig(derivativeAttr->getParameterIndices(),
670670
IndexSubset::get(AFD->getASTContext(), 1, {0}),
671671
AFD->getGenericSignature()));
Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,3 @@
11
public struct Struct {
22
public func method(_ x: Float) -> Float { x }
3-
4-
public static func +(_ lhs: Self, rhs: Self) -> Self {
5-
lhs
6-
}
73
}

test/AutoDiff/Sema/DerivativeRegistrationCrossModule/Inputs/b.swift

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,4 @@ extension Struct: Differentiable {
1010
func vjpMethod(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
1111
(x, { $0 })
1212
}
13-
14-
@usableFromInline
15-
@derivative(of: +)
16-
static func vjpAdd(_ lhs: Self, rhs: Self) -> (
17-
value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)
18-
) {
19-
(lhs + rhs, { v in (v, v) })
20-
}
2113
}
Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
11
// RUN: %empty-directory(%t)
22
// RUN: %target-swift-frontend -emit-module -primary-file %S/Inputs/a.swift -emit-module-path %t/a.swiftmodule
33
// RUN: %target-swift-frontend -emit-module -primary-file %S/Inputs/b.swift -emit-module-path %t/b.swiftmodule -I %t
4-
// RUN: not --crash %target-swift-frontend-typecheck -verify -I %t %s
4+
// "-verify-ignore-unknown" is for "<unknown>:0: note: 'init()' declared here"
5+
// RUN: %target-swift-frontend-typecheck -verify -verify-ignore-unknown -I %t %s
56

67
// SR-12526: Fix cross-module deserialization crash involving `@derivative` attribute.
78

89
import a
910
import b
1011

1112
func foo(_ s: Struct) {
13+
// Without this error, SR-12526 does not trigger.
14+
// expected-error @+1 {{'Struct' initializer is inaccessible due to 'internal' protection level}}
1215
_ = Struct()
1316
_ = s.method(1)
1417
}

0 commit comments

Comments
 (0)