Skip to content

[AutoDiff] SR-12526: cross-module @derivative deserialization #30851

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Apr 8, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/swift/AST/ASTTypeIDZone.def
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ SWIFT_TYPEID(Type)
SWIFT_TYPEID(TypePair)
SWIFT_TYPEID(TypeWitnessAndDecl)
SWIFT_TYPEID(Witness)
SWIFT_TYPEID_NAMED(AbstractFunctionDecl *, AbstractFunctionDecl)
SWIFT_TYPEID_NAMED(ClosureExpr *, ClosureExpr)
SWIFT_TYPEID_NAMED(CodeCompletionCallbacksFactory *,
CodeCompletionCallbacksFactory)
Expand Down
31 changes: 23 additions & 8 deletions include/swift/AST/Attr.h
Original file line number Diff line number Diff line change
Expand Up @@ -1865,6 +1865,7 @@ class DerivativeAttr final
: public DeclAttribute,
private llvm::TrailingObjects<DerivativeAttr, ParsedAutoDiffParameter> {
friend TrailingObjects;
friend class DerivativeAttrOriginalDeclRequest;

/// The base type for the referenced original declaration. This field is
/// non-null only for parsed attributes that reference a qualified original
Expand All @@ -1873,8 +1874,24 @@ class DerivativeAttr final
TypeRepr *BaseTypeRepr;
/// The original function name.
DeclNameRefWithLoc OriginalFunctionName;
/// The original function declaration, resolved by the type checker.
AbstractFunctionDecl *OriginalFunction = nullptr;
/// The original function.
///
/// The states are:
/// - nullptr:
/// The original function is unknown. The typechecker is responsible for
/// eventually resolving it.
/// - AbstractFunctionDecl:
/// The original function is known to be this `AbstractFunctionDecl`.
/// - LazyMemberLoader:
/// This `LazyMemberLoader` knows how to resolve the original function.
/// `ResolverContextData` is an additional piece of data that the
/// `LazyMemberLoader` needs.
// TODO(TF-1235): Making `DerivativeAttr` immutable will simplify this by
// removing the `AbstractFunctionDecl` state.
llvm::PointerUnion<AbstractFunctionDecl *, LazyMemberLoader *> OriginalFunction;
/// Data representing the original function declaration. See doc comment for
/// `OriginalFunction`.
uint64_t ResolverContextData = 0;
/// The number of parsed differentiability parameters specified in 'wrt:'.
unsigned NumParsedParameters = 0;
/// The differentiability parameter indices, resolved by the type checker.
Expand Down Expand Up @@ -1907,12 +1924,10 @@ class DerivativeAttr final
DeclNameRefWithLoc getOriginalFunctionName() const {
return OriginalFunctionName;
}
AbstractFunctionDecl *getOriginalFunction() const {
return OriginalFunction;
}
void setOriginalFunction(AbstractFunctionDecl *decl) {
OriginalFunction = decl;
}
AbstractFunctionDecl *getOriginalFunction(ASTContext &context) const;
void setOriginalFunction(AbstractFunctionDecl *decl);
void setOriginalFunctionResolver(LazyMemberLoader *resolver,
uint64_t resolverContextData);

AutoDiffDerivativeFunctionKind getDerivativeKind() const {
assert(Kind && "Derivative function kind has not yet been resolved");
Expand Down
6 changes: 6 additions & 0 deletions include/swift/AST/LazyResolver.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,12 @@ class alignas(void*) LazyMemberLoader {
loadDynamicallyReplacedFunctionDecl(const DynamicReplacementAttr *DRA,
uint64_t contextData) = 0;

/// Returns the referenced original declaration for a `@derivative(of:)`
/// attribute.
virtual AbstractFunctionDecl *
loadReferencedFunctionDecl(const DerivativeAttr *DA,
uint64_t contextData) = 0;

/// Returns the type for a given @_typeEraser() attribute.
virtual Type loadTypeEraserType(const TypeEraserAttr *TRA,
uint64_t contextData) = 0;
Expand Down
20 changes: 20 additions & 0 deletions include/swift/AST/TypeCheckRequests.h
Original file line number Diff line number Diff line change
Expand Up @@ -2130,6 +2130,26 @@ class DifferentiableAttributeTypeCheckRequest
void cacheResult(IndexSubset *value) const;
};

/// Resolves the referenced original declaration for a `@derivative` attribute.
class DerivativeAttrOriginalDeclRequest
: public SimpleRequest<DerivativeAttrOriginalDeclRequest,
AbstractFunctionDecl *(DerivativeAttr *),
RequestFlags::Cached> {
public:
using SimpleRequest::SimpleRequest;

private:
friend SimpleRequest;

// Evaluation.
AbstractFunctionDecl *evaluate(Evaluator &evaluator,
DerivativeAttr *attr) const;

public:
// Caching.
bool isCached() const { return true; }
};

/// Checks whether a type eraser has a viable initializer.
class TypeEraserHasViableInitRequest
: public SimpleRequest<TypeEraserHasViableInitRequest,
Expand Down
3 changes: 3 additions & 0 deletions include/swift/AST/TypeCheckerTypeIDZone.def
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ SWIFT_REQUEST(TypeChecker, DefaultTypeRequest,
SWIFT_REQUEST(TypeChecker, DifferentiableAttributeTypeCheckRequest,
IndexSubset *(DifferentiableAttr *),
SeparatelyCached, NoLocationInfo)
SWIFT_REQUEST(TypeChecker, DerivativeAttrOriginalDeclRequest,
AbstractFunctionDecl *(DerivativeAttr *),
Cached, NoLocationInfo)
SWIFT_REQUEST(TypeChecker, TypeEraserHasViableInitRequest,
bool(TypeEraserAttr *, ProtocolDecl *),
Cached, NoLocationInfo)
Expand Down
21 changes: 21 additions & 0 deletions lib/AST/Attr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "swift/AST/Expr.h"
#include "swift/AST/GenericEnvironment.h"
#include "swift/AST/IndexSubset.h"
#include "swift/AST/LazyResolver.h"
#include "swift/AST/Module.h"
#include "swift/AST/ParameterList.h"
#include "swift/AST/TypeCheckRequests.h"
Expand Down Expand Up @@ -1750,6 +1751,26 @@ DerivativeAttr *DerivativeAttr::create(ASTContext &context, bool implicit,
std::move(originalName), parameterIndices);
}

AbstractFunctionDecl *
DerivativeAttr::getOriginalFunction(ASTContext &context) const {
return evaluateOrDefault(
context.evaluator,
DerivativeAttrOriginalDeclRequest{const_cast<DerivativeAttr *>(this)},
nullptr);
}

void DerivativeAttr::setOriginalFunction(AbstractFunctionDecl *decl) {
assert(!OriginalFunction && "cannot overwrite original function");
OriginalFunction = decl;
}

void DerivativeAttr::setOriginalFunctionResolver(
LazyMemberLoader *resolver, uint64_t resolverContextData) {
assert(!OriginalFunction && "cannot overwrite original function");
OriginalFunction = resolver;
ResolverContextData = resolverContextData;
}

TransposeAttr::TransposeAttr(bool implicit, SourceLoc atLoc,
SourceRange baseRange, TypeRepr *baseTypeRepr,
DeclNameRefWithLoc originalName,
Expand Down
6 changes: 6 additions & 0 deletions lib/ClangImporter/ImporterImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1290,6 +1290,12 @@ class LLVM_LIBRARY_VISIBILITY ClangImporter::Implementation
llvm_unreachable("unimplemented for ClangImporter");
}

AbstractFunctionDecl *
loadReferencedFunctionDecl(const DerivativeAttr *DA,
uint64_t contextData) override {
llvm_unreachable("unimplemented for ClangImporter");
}

Type loadTypeEraserType(const TypeEraserAttr *TRA,
uint64_t contextData) override {
llvm_unreachable("unimplemented for ClangImporter");
Expand Down
2 changes: 1 addition & 1 deletion lib/SILGen/SILGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -791,7 +791,7 @@ void SILGenModule::emitDifferentiabilityWitnessesForFunction(
vjp = F;
break;
}
auto *origAFD = derivAttr->getOriginalFunction();
auto *origAFD = derivAttr->getOriginalFunction(getASTContext());
auto origDeclRef =
SILDeclRef(origAFD).asForeign(requiresForeignEntryPoint(origAFD));
auto *origFn = getFunction(origDeclRef, NotForDefinition);
Expand Down
17 changes: 15 additions & 2 deletions lib/Sema/TypeCheckAttr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3705,8 +3705,6 @@ getTransposeOriginalFunctionType(AnyFunctionType *transposeFnType,
return originalType;
}



/// Given a `@differentiable` attribute, attempts to resolve the original
/// `AbstractFunctionDecl` for which it is registered, using the declaration
/// on which it is actually declared. On error, emits diagnostic and returns
Expand Down Expand Up @@ -4454,6 +4452,21 @@ void AttributeChecker::visitDerivativeAttr(DerivativeAttr *attr) {
attr->setInvalid();
}

AbstractFunctionDecl *
DerivativeAttrOriginalDeclRequest::evaluate(Evaluator &evaluator,
DerivativeAttr *attr) const {
// If the typechecker has resolved the original function, return it.
if (auto *FD = attr->OriginalFunction.dyn_cast<AbstractFunctionDecl *>())
return FD;

// If the function can be lazily resolved, do so now.
if (auto *Resolver = attr->OriginalFunction.dyn_cast<LazyMemberLoader *>())
return Resolver->loadReferencedFunctionDecl(attr,
attr->ResolverContextData);

return nullptr;
}

/// Returns true if the given type's `TangentVector` is equal to itself in the
/// given module.
static bool tangentVectorEqualsSelf(Type type, DeclContext *DC) {
Expand Down
9 changes: 7 additions & 2 deletions lib/Serialization/Deserialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4379,7 +4379,6 @@ llvm::Error DeclDeserializer::deserializeDeclAttributes() {

DeclNameRefWithLoc origName{
DeclNameRef(MF.getDeclBaseName(origNameId)), DeclNameLoc()};
auto *origDecl = cast<AbstractFunctionDecl>(MF.getDecl(origDeclId));
auto derivativeKind =
getActualAutoDiffDerivativeFunctionKind(rawDerivativeKind);
if (!derivativeKind)
Expand All @@ -4392,7 +4391,7 @@ llvm::Error DeclDeserializer::deserializeDeclAttributes() {
auto *derivativeAttr =
DerivativeAttr::create(ctx, isImplicit, SourceLoc(), SourceRange(),
/*baseType*/ nullptr, origName, indices);
derivativeAttr->setOriginalFunction(origDecl);
derivativeAttr->setOriginalFunctionResolver(&MF, origDeclId);
derivativeAttr->setDerivativeKind(*derivativeKind);
Attr = derivativeAttr;
break;
Expand Down Expand Up @@ -5941,6 +5940,12 @@ ValueDecl *ModuleFile::loadDynamicallyReplacedFunctionDecl(
return cast<ValueDecl>(getDecl(contextData));
}

AbstractFunctionDecl *
ModuleFile::loadReferencedFunctionDecl(const DerivativeAttr *DA,
uint64_t contextData) {
return cast<AbstractFunctionDecl>(getDecl(contextData));
}

Type ModuleFile::loadTypeEraserType(const TypeEraserAttr *TRA,
uint64_t contextData) {
return getType(contextData);
Expand Down
4 changes: 4 additions & 0 deletions lib/Serialization/ModuleFile.h
Original file line number Diff line number Diff line change
Expand Up @@ -882,6 +882,10 @@ class ModuleFile
loadDynamicallyReplacedFunctionDecl(const DynamicReplacementAttr *DRA,
uint64_t contextData) override;

virtual AbstractFunctionDecl *
loadReferencedFunctionDecl(const DerivativeAttr *DA,
uint64_t contextData) override;

virtual Type loadTypeEraserType(const TypeEraserAttr *TRA,
uint64_t contextData) override;

Expand Down
7 changes: 4 additions & 3 deletions lib/Serialization/Serialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2417,12 +2417,13 @@ class Serializer::DeclSerializer : public DeclVisitor<DeclSerializer> {
case DAK_Derivative: {
auto abbrCode = S.DeclTypeAbbrCodes[DerivativeDeclAttrLayout::Code];
auto *attr = cast<DerivativeAttr>(DA);
assert(attr->getOriginalFunction() &&
auto &ctx = S.getASTContext();
assert(attr->getOriginalFunction(ctx) &&
"`@derivative` attribute should have original declaration set "
"during construction or parsing");
auto origName = attr->getOriginalFunctionName().Name.getBaseName();
IdentifierID origNameId = S.addDeclBaseNameRef(origName);
DeclID origDeclID = S.addDeclRef(attr->getOriginalFunction());
DeclID origDeclID = S.addDeclRef(attr->getOriginalFunction(ctx));
auto derivativeKind =
getRawStableAutoDiffDerivativeFunctionKind(attr->getDerivativeKind());
auto *parameterIndices = attr->getParameterIndices();
Expand Down Expand Up @@ -4862,7 +4863,7 @@ static void recordDerivativeFunctionConfig(
attr->getDerivativeGenericSignature()});
}
for (auto *attr : AFD->getAttrs().getAttributes<DerivativeAttr>()) {
auto *origAFD = attr->getOriginalFunction();
auto *origAFD = attr->getOriginalFunction(ctx);
auto mangledName = ctx.getIdentifier(Mangler.mangleDeclAsUSR(origAFD, ""));
derivativeConfigs[mangledName].insert(
{ctx.getIdentifier(attr->getParameterIndices()->getString()),
Expand Down
2 changes: 1 addition & 1 deletion lib/TBDGen/TBDGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -665,7 +665,7 @@ void TBDGenVisitor::visitAbstractFunctionDecl(AbstractFunctionDecl *AFD) {
for (const auto *derivativeAttr :
AFD->getAttrs().getAttributes<DerivativeAttr>())
addDerivativeConfiguration(
derivativeAttr->getOriginalFunction(),
derivativeAttr->getOriginalFunction(AFD->getASTContext()),
AutoDiffConfig(derivativeAttr->getParameterIndices(),
IndexSubset::get(AFD->getASTContext(), 1, {0}),
AFD->getGenericSignature()));
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
public struct Struct {
public func method(_ x: Float) -> Float { x }

public static func +(_ lhs: Self, rhs: Self) -> Self {
lhs
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,4 @@ extension Struct: Differentiable {
func vjpMethod(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
(x, { $0 })
}

@usableFromInline
@derivative(of: +)
static func vjpAdd(_ lhs: Self, rhs: Self) -> (
value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)
) {
(lhs + rhs, { v in (v, v) })
}
}
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
// RUN: %empty-directory(%t)
// RUN: %target-swift-frontend -emit-module -primary-file %S/Inputs/a.swift -emit-module-path %t/a.swiftmodule
// RUN: %target-swift-frontend -emit-module -primary-file %S/Inputs/b.swift -emit-module-path %t/b.swiftmodule -I %t
// RUN: not --crash %target-swift-frontend-typecheck -verify -I %t %s
// "-verify-ignore-unknown" is for "<unknown>:0: note: 'init()' declared here"
// RUN: %target-swift-frontend-typecheck -verify -verify-ignore-unknown -I %t %s

// SR-12526: Fix cross-module deserialization crash involving `@derivative` attribute.

import a
import b

func foo(_ s: Struct) {
// Without this error, SR-12526 does not trigger.
// expected-error @+1 {{'Struct' initializer is inaccessible due to 'internal' protection level}}
_ = Struct()
_ = s.method(1)
}