[AutoDiff] Support @differentiable class initializers. #29754

Status: Closed · 2 commits
4 changes: 0 additions & 4 deletions include/swift/AST/DiagnosticsSema.def
@@ -3004,10 +3004,6 @@ ERROR(differentiable_attr_stored_property_variable_unsupported,none,
ERROR(differentiable_attr_class_member_no_dynamic_self,none,
"'@differentiable' attribute cannot be declared on class methods "
"returning 'Self'", ())
// TODO(TF-654): Remove when differentiation supports class initializers.
ERROR(differentiable_attr_class_init_not_yet_supported,none,
"'@differentiable' attribute does not yet support class initializers",
())
ERROR(differentiable_attr_empty_where_clause,none,
"empty 'where' clause in '@differentiable' attribute", ())
// SWIFT_ENABLE_TENSORFLOW
10 changes: 10 additions & 0 deletions include/swift/SILOptimizer/Utils/Differentiation/PullbackEmitter.h
@@ -434,6 +434,16 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
void visitUnconditionalCheckedCastAddrInst(
UnconditionalCheckedCastAddrInst *uccai);

/// Handle `unchecked_ref_cast` instruction.
/// Original: y = unchecked_ref_cast x
/// Adjoint: adj[x] += adj[y] (assuming x' and y' have the same type)
void visitUncheckedRefCastInst(UncheckedRefCastInst *urci);

/// Handle `upcast` instruction.
/// Original: y = upcast x
/// Adjoint: adj[x] += adj[y] (assuming x' and y' have the same type)
void visitUpcastInst(UpcastInst *ui);

#define NOT_DIFFERENTIABLE(INST, DIAG) void visit##INST##Inst(INST##Inst *inst);
#undef NOT_DIFFERENTIABLE

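The adjoint rule stated in the two doc comments above, `adj[x] += adj[y]`, can be sketched as a standalone model. This is illustrative only — the names below are hypothetical, not compiler API — and it assumes, as the comments do, that operand and result share a tangent type:

```swift
// Toy model of the pullback rule for `upcast` / `unchecked_ref_cast`:
// a reference cast does not change the underlying value, so the adjoint
// of the operand is just the adjoint of the result, accumulated additively.
struct Tangent {
    var value: Float
}

// Hypothetical accumulator mirroring `adj[x] += adj[y]`.
func accumulateCastAdjoint(_ adjResult: Tangent, into adjOperand: inout Tangent) {
    adjOperand.value += adjResult.value
}

var adjX = Tangent(value: 1.0)   // adj[x], already holding an earlier contribution
let adjY = Tangent(value: 2.5)   // adj[y], flowing back through the cast
accumulateCastAdjoint(adjY, into: &adjX)
print(adjX.value)                // 3.5
```

Because the cast is the identity on values, the only work in the real visitor is type bookkeeping: asserting that the remapped tangent types match and forwarding the adjoint.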
14 changes: 14 additions & 0 deletions lib/AST/Type.cpp
@@ -792,6 +792,20 @@ Type TypeBase::replaceCovariantResultType(Type newResultType,
return OptionalType::get(
objectType->replaceCovariantResultType(newResultType, uncurryLevel));
}
// SWIFT_ENABLE_TENSORFLOW
Contributor Author commented:

Note: this added logic in TypeBase::replaceCovariantResultType is hacky. It's necessary for TypeConverter::getConstantOverrideInfo (derivative vtable entry SILGen) to work properly.

I can't think of a more robust alternative that doesn't involve ~60 lines of code dupe.


One alternative is to drop support for @differentiable initializers in non-final classes.

Initializers return a dynamic Self type, so perhaps initializer derivatives shouldn't be supported in non-final classes: derivatives would return a tuple (Self, ...), and type checking accepts covariant Self only as a top-level method result type.

If there are no use cases for @differentiable initializers in non-final classes, I think dropping support is fine.
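The type-checking constraint invoked in this comment can be made concrete with a small sketch (a hypothetical class, not from the patch); the rule described here is that covariant `Self` is accepted only as a top-level method result type:

```swift
class Base {
    // Covariant `Self` as the sole, top-level result: accepted.
    func identity() -> Self {
        return self
    }
    // A derivative of a `Self`-returning method would instead need a result
    // shaped like `(Self, pullbackClosure)` -- covariant `Self` nested inside
    // a tuple -- which is the form the comment says would not type-check for
    // non-final classes.
}

final class Leaf: Base {}

let leaf = Leaf()
// `identity()` returns the dynamic type of the receiver.
print(leaf.identity() === leaf)  // true
```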

Contributor commented:

I believe we should ban @differentiable on everything that produces a covariant result.

// Special logic to handle JVP/VJP derivative functions of `Self`-returning
// methods. This logic is hacky and not robust at all.
// Consider adding a `bool isAutoDiffDerivative` flag for triggering this
// logic, or creating a separate dedicated helper
// `TypeBase::replaceCovariantResultTypeForAutoDiffDerivative`.
if (auto tupleType = dyn_cast<TupleType>(this)) {
Contributor commented:

This is in TypeBase. getAs() would be more idiomatic.

Suggested change:
-  if (auto tupleType = dyn_cast<TupleType>(this)) {
+  if (auto tupleType = getAs<TupleType>()) {

assert(tupleType->getNumElements() == 2 &&
"Tuple result is expected only for derivative functions, which "
"return a two-element tuple");
return TupleType::get({newResultType, tupleType->getElement(1)},
getASTContext());
}
// SWIFT_ENABLE_TENSORFLOW END

return newResultType;
}
43 changes: 33 additions & 10 deletions lib/SIL/SILFunctionType.cpp
@@ -2705,17 +2705,15 @@ TypeConverter::getConstantInfo(TypeExpansionContext expansion,
loweredInterfaceType);

// SWIFT_ENABLE_TENSORFLOW
// In the case of autodiff derivative functions, the above computations
// determine `silFnType` by first computing the derivative function type at
// the AST level and then lowering that. Unfortunately, the actual
// SILFunctionType for the function is determined by first lowering the
// function's AST type, and then computing the derivative function type at the
// SIL level. "Lowering" does not commute with "getting the autodiff
// associated type", so these two computations produce different results.
// Therefore `silFnType` is not the actual type of the function that
// `constant` refers to.
// For derivative functions, the above computations determine `silFnType`
// by first computing the derivative AST function type and then lowering it to
// SIL. Unfortunately, the expected derivative SIL function type is determined
// by first lowering the original function's AST type, and then computing its
// SIL derivative function type. "Lowering" does not commute with "getting the
// derivative type", so these two computations produce different results.
// Therefore, `silFnType` is not the expected SIL derivative function type.
//
// We hackily fix this problem by redoing the computation in the right order.
// We fix this problem by performing the computation in the right order.
if (auto *autoDiffFuncId = constant.autoDiffDerivativeFunctionIdentifier) {
auto origFnConstantInfo = getConstantInfo(
TypeExpansionContext::minimal(), constant.asAutoDiffOriginalFunction());
@@ -2725,6 +2723,7 @@ TypeConverter::getConstantInfo(TypeExpansionContext expansion,
loweredIndices, /*resultIndex*/ 0, autoDiffFuncId->getKind(),
*this, LookUpConformanceInModule(&M));
}
// SWIFT_ENABLE_TENSORFLOW END

LLVM_DEBUG(llvm::dbgs() << "lowering type for constant ";
constant.print(llvm::dbgs());
@@ -2938,6 +2937,30 @@ TypeConverter::getConstantOverrideInfo(TypeExpansionContext context,
*this, context, basePattern, bridgedTypes.Uncurried, base, derived,
/*reqt subs*/ None, ProtocolConformanceRef());

// SWIFT_ENABLE_TENSORFLOW
// For derivative functions, the above computations determine `fnTy`
// by first computing the derivative AST function type and then lowering it to
// SIL. Unfortunately, the expected derivative SIL function type is determined
// by first lowering the original function's AST type, and then computing its
// SIL derivative function type. "Lowering" does not commute with "getting the
// derivative type", so these two computations produce different results.
// Therefore, `fnTy` is not the expected SIL derivative function type.
//
// We fix this problem by performing the computation in the right order.
if (auto *derivedDerivID = derived.autoDiffDerivativeFunctionIdentifier) {
auto *baseDerivID = base.autoDiffDerivativeFunctionIdentifier;
assert(baseDerivID);
auto &overrideInfo =
getConstantOverrideInfo(context, derived.asAutoDiffOriginalFunction(),
base.asAutoDiffOriginalFunction());
auto loweredIndices = autodiff::getLoweredParameterIndices(
derivedDerivID->getParameterIndices(), overrideInterfaceTy);
fnTy = overrideInfo.SILFnType->getAutoDiffDerivativeFunctionType(
loweredIndices, /*resultIndex*/ 0, derivedDerivID->getKind(), *this,
LookUpConformanceInModule(&M));
}
// SWIFT_ENABLE_TENSORFLOW END

// Build the SILConstantInfo and cache it.
auto resultBuf = Context.Allocate(sizeof(SILConstantInfo),
alignof(SILConstantInfo));
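The non-commutativity described in the comments above can be modeled abstractly. A toy sketch — plain strings standing in for types, no compiler API — showing only that applying two transformations in different orders need not agree:

```swift
// "Lowering" and "taking the derivative function type" as opaque
// transformations; the point is just that their order matters.
func lower(_ type: String) -> String { "lowered(\(type))" }
func derivativeType(_ type: String) -> String { "derivative(\(type))" }

// Derivative at the AST level, then lowered (what the generic path computes):
let astFirst = lower(derivativeType("T"))   // "lowered(derivative(T))"
// Lowered first, then derivative at the SIL level (the expected type):
let silFirst = derivativeType(lower("T"))   // "derivative(lowered(T))"
print(astFirst == silFirst)                 // false
```

The fix in both `getConstantInfo` and `getConstantOverrideInfo` is the same: recompute in the second order, i.e. look up the original function's lowered type first and then call `getAutoDiffDerivativeFunctionType` on it.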
37 changes: 31 additions & 6 deletions lib/SILGen/SILGenPoly.cpp
@@ -3749,8 +3749,9 @@ SILFunction *SILGenModule::getOrCreateCustomDerivativeThunk(
auto *thunk = fb.getOrCreateFunction(
loc, name, customDerivativeFn->getLinkage(), thunkFnTy, IsBare,
IsNotTransparent, customDerivativeFn->isSerialized(),
customDerivativeFn->isDynamicallyReplaceable(), customDerivativeFn->getEntryCount(),
IsThunk, customDerivativeFn->getClassSubclassScope());
customDerivativeFn->isDynamicallyReplaceable(),
customDerivativeFn->getEntryCount(), IsThunk,
customDerivativeFn->getClassSubclassScope());
thunk->setInlineStrategy(AlwaysInline);
if (!thunk->empty())
return thunk;
@@ -3762,15 +3763,39 @@ SILFunction *SILGenModule::getOrCreateCustomDerivativeThunk(
thunkSGF.collectThunkParams(loc, params, &indirectResults);

auto *fnRef = thunkSGF.B.createFunctionRef(loc, customDerivativeFn);
auto fnRefType =
fnRef->getType().castTo<SILFunctionType>();
auto fnRefType = fnRef->getType().castTo<SILFunctionType>();

// Collect thunk arguments, converting ownership.
SmallVector<SILValue, 8> arguments;
for (auto *indRes : indirectResults)
arguments.push_back(indRes);
forwardFunctionArguments(thunkSGF, loc, fnRefType, params,
arguments);
forwardFunctionArguments(thunkSGF, loc, fnRefType, params, arguments);

// Special support for thunking class initializer derivatives.
//
// User-defined custom derivatives take a metatype as the last parameter:
// - `$(Param0, Param1, ..., @thick Class.Type) -> (...)`
// But class initializers take an allocated instance as the last parameter:
// - `$(Param0, Param1, ..., @owned Class) -> (...)`
//
// Adjust forwarded arguments:
// - Pop the last `@owned Class` argument.
// - Create a `@thick Class.Type` value and pass it as the last argument.
auto *origAFD =
cast<AbstractFunctionDecl>(originalFn->getDeclContext()->getAsDecl());
if (isa<ConstructorDecl>(origAFD) &&
SILDeclRef(origAFD, SILDeclRef::Kind::Initializer).mangle() ==
originalFn->getName()) {
auto classArgument = arguments.pop_back_val();
auto *classDecl = classArgument->getType().getClassOrBoundGenericClass();
assert(classDecl && "Expected last argument to have class type");
auto classMetatype = MetatypeType::get(
classDecl->getDeclaredInterfaceType(), MetatypeRepresentation::Thick);
auto canClassMetatype = classMetatype->getCanonicalType();
auto *metatype = thunkSGF.B.createMetatype(
loc, SILType::getPrimitiveObjectType(canClassMetatype));
arguments.push_back(metatype);
}
// Apply function argument.
auto apply = thunkSGF.emitApplyWithRethrow(
loc, fnRef, /*substFnType*/ fnRef->getType(),
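The parameter mismatch this thunk bridges can be illustrated at the Swift level with a hypothetical final class (not from the patch). The allocating entry point — the shape a user-defined custom derivative is written against — takes a metatype as its last input, while the initializing entry receives an already-allocated instance:

```swift
final class Model {
    var weight: Float
    init(weight: Float) {
        self.weight = weight
    }
}

// Allocating-style call: the last parameter is a thick metatype value.
// Calling `init` through a metatype value is permitted here because
// `Model` is final (otherwise the initializer would have to be `required`).
let allocating: (Float, Model.Type) -> Model = { weight, type in
    type.init(weight: weight)
}

let model = allocating(3, Model.self)
print(model.weight)  // 3.0
```

The thunk performs the reverse adjustment in SIL: it pops the `@owned Class` argument from the initializing convention and synthesizes a `@thick Class.Type` value to match the custom derivative's allocating-style signature.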
28 changes: 24 additions & 4 deletions lib/SILOptimizer/Utils/Differentiation/PullbackEmitter.cpp
@@ -376,7 +376,8 @@ SILValue PullbackEmitter::getAdjointProjection(SILBasicBlock *origBB,
auto *tanField = cast<VarDecl>(tanFieldLookup.front());
// Create a local allocation for the element adjoint buffer.
auto eltTanType = tanField->getValueInterfaceType()->getCanonicalType();
auto eltTanSILType = SILType::getPrimitiveAddressType(eltTanType);
auto eltTanSILType =
remapType(SILType::getPrimitiveAddressType(eltTanType));
auto *eltAdjBuffer = createFunctionLocalAllocation(eltTanSILType, loc);
builder.emitScopedBorrowOperation(
loc, adjClass, [&](SILValue borrowedAdjClass) {
@@ -1090,7 +1091,7 @@ PullbackEmitter::getArrayAdjointElementBuffer(SILValue arrayAdjoint,
auto arrayTanType = cast<StructType>(arrayAdjoint->getType().getASTType());
auto arrayType = arrayTanType->getParent()->castTo<BoundGenericStructType>();
auto eltTanType = arrayType->getGenericArgs().front()->getCanonicalType();
auto eltTanSILType = SILType::getPrimitiveAddressType(eltTanType);
auto eltTanSILType = remapType(SILType::getPrimitiveAddressType(eltTanType));
// Get `function_ref` and generic signature of
// `Array.TangentVector.subscript.getter`.
auto *arrayTanStructDecl = arrayTanType->getStructOrBoundGenericStruct();
@@ -1602,12 +1603,11 @@ void PullbackEmitter::visitLoadOperation(SingleValueInstruction *inst) {
void PullbackEmitter::visitStoreOperation(SILBasicBlock *bb, SILLocation loc,
SILValue origSrc, SILValue origDest) {
auto &adjBuf = getAdjointBuffer(bb, origDest);
auto bufType = remapType(adjBuf->getType());
auto adjVal =
builder.emitLoadValueOperation(loc, adjBuf, LoadOwnershipQualifier::Take);
recordTemporary(adjVal);
addAdjointValue(bb, origSrc, makeConcreteAdjointValue(adjVal), loc);
emitZeroIndirect(bufType.getASTType(), adjBuf, loc);
emitZeroIndirect(adjBuf->getType().getASTType(), adjBuf, loc);
}

void PullbackEmitter::visitStoreInst(StoreInst *si) {
@@ -1672,6 +1672,26 @@ void PullbackEmitter::visitUnconditionalCheckedCastAddrInst(
emitZeroIndirect(destType.getASTType(), adjDest, uccai->getLoc());
}

void PullbackEmitter::visitUncheckedRefCastInst(UncheckedRefCastInst *urci) {
auto *bb = urci->getParent();
assert(urci->getOperand()->getType().isObject());
assert(getRemappedTangentType(urci->getOperand()->getType()) ==
getRemappedTangentType(urci->getType()) &&
"Operand/result must have the same `TangentVector` type");
auto adj = getAdjointValue(bb, urci);
addAdjointValue(bb, urci->getOperand(), adj, urci->getLoc());
}

void PullbackEmitter::visitUpcastInst(UpcastInst *ui) {
auto *bb = ui->getParent();
assert(ui->getOperand()->getType().isObject());
assert(getRemappedTangentType(ui->getOperand()->getType()) ==
getRemappedTangentType(ui->getType()) &&
"Operand/result must have the same `TangentVector` type");
auto adj = getAdjointValue(bb, ui);
addAdjointValue(bb, ui->getOperand(), adj, ui->getLoc());
}

#define NOT_DIFFERENTIABLE(INST, DIAG) \
void PullbackEmitter::visit##INST##Inst(INST##Inst *inst) { \
getContext().emitNondifferentiabilityError(inst, getInvoker(), \
18 changes: 4 additions & 14 deletions lib/Sema/TypeCheckAttr.cpp
@@ -3975,26 +3975,16 @@ llvm::Expected<IndexSubset *> DifferentiableAttributeTypeCheckRequest::evaluate(
// Diagnose if original function is an invalid class member.
if (isOriginalClassMember) {
// Class methods returning dynamic `Self` are not supported.
// (For class methods, dynamic `Self` is supported only as the single
// result - tuple-returning JVPs/VJPs would not type-check.)
if (auto *originalFn = dyn_cast<FuncDecl>(original)) {
if (originalFn->hasDynamicSelfResult()) {
// For class methods, dynamic `Self` is supported only as the single
// result - tuple-returning JVPs/VJPs would not type-check.
if (auto *originalFD = dyn_cast<FuncDecl>(original)) {
if (originalFD->hasDynamicSelfResult()) {
diags.diagnose(attr->getLocation(),
diag::differentiable_attr_class_member_no_dynamic_self);
attr->setInvalid();
return nullptr;
}
}

// TODO(TF-654): Class initializers are not yet supported.
// Extra JVP/VJP type calculation logic is necessary because classes have
// both allocators and initializers.
if (auto *initDecl = dyn_cast<ConstructorDecl>(original)) {
diags.diagnose(attr->getLocation(),
diag::differentiable_attr_class_init_not_yet_supported);
attr->setInvalid();
return nullptr;
}
}

// Resolve the derivative generic signature.
2 changes: 0 additions & 2 deletions test/AutoDiff/Sema/differentiable_attr_type_checking.swift
@@ -1080,8 +1080,6 @@ class Super: Differentiable {

var base: Float

// NOTE(TF-654): Class initializers are not yet supported.
// expected-error @+1 {{'@differentiable' attribute does not yet support class initializers}}
@differentiable
init(base: Float) {
self.base = base
29 changes: 29 additions & 0 deletions test/AutoDiff/downstream/class_differentiation.swift
@@ -16,10 +16,16 @@ ClassTests.test("TrivialMember") {
@noDerivative
final var noDerivative: Float = 1

@differentiable
init(_ float: Float) {
self.float = float
}

@differentiable
convenience init(convenience x: Float) {
self.init(x)
}

@differentiable
func method(_ x: Float) -> Float {
x * float
@@ -44,6 +50,7 @@ }
}
// Test class initializer differentiation.
expectEqual(10, pullback(at: 3, in: { C($0) })(.init(float: 10)))
expectEqual(10, pullback(at: 3, in: { C(convenience: $0) })(.init(float: 10)))
// Test class method differentiation.
expectEqual((.init(float: 3), 10), gradient(at: C(10), 3, in: { c, x in c.method(x) }))
expectEqual(.init(float: 0), gradient(at: C(10), in: { c in c.testNoDerivative() }))
@@ -56,6 +63,7 @@ ClassTests.test("NontrivialMember") {
@differentiable
var float: Tracked<Float>

@differentiable
init(_ float: Tracked<Float>) {
self.float = float
}
@@ -84,6 +92,26 @@ }
gradient(at: C(10), C(20), in: { c1, c2 in C.controlFlow(c1, c2, true) }))
}

ClassTests.test("GenericNontrivialMember") {
final class C<T: Differentiable>: Differentiable where T == T.TangentVector {
@differentiable
var x: Tracked<T>

@differentiable
init(_ x: T) {
self.x = Tracked(x)
}

@differentiable
convenience init(convenience x: T) {
self.init(x)
}
}
// Test class initializer differentiation.
expectEqual(10, pullback(at: 3, in: { C<Float>($0) })(.init(x: 10)))
expectEqual(10, pullback(at: 3, in: { C<Float>(convenience: $0) })(.init(x: 10)))
}

// TF-1149: Test class with loadable type but address-only `TangentVector` type.
// TODO(TF-1149): Uncomment when supported.
/*
@@ -92,6 +120,7 @@ ClassTests.test("AddressOnlyTangentVector") {
@differentiable
var stored: T

@differentiable
init(_ stored: T) {
self.stored = stored
}