Skip to content

Commit e6c9886

Browse files
committed
[AutoDiff] Support @differentiable class initializers.
Support `@differentiable` class initializers. - Add special logic in `SILGenModule::getOrCreateCustomDerivativeThunk` to thunk user-defined class initializer derivative functions to the expected derivative type for class initializers. - Make minor vtable SILGen changes for class initializer derivatives. Also support differentiation for class-related casting instructions: `unchecked_ref_cast` and `upcast`. These instructions are generated when calling `super.init` or referencing inherited superclass members. Resolves SR-12151 and SR-12153.
1 parent 5b7a1c3 commit e6c9886

File tree

14 files changed

+172
-115
lines changed

14 files changed

+172
-115
lines changed

include/swift/AST/DiagnosticsSema.def

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3004,10 +3004,6 @@ ERROR(differentiable_attr_stored_property_variable_unsupported,none,
30043004
ERROR(differentiable_attr_class_member_no_dynamic_self,none,
30053005
"'@differentiable' attribute cannot be declared on class methods "
30063006
"returning 'Self'", ())
3007-
// TODO(TF-654): Remove when differentiation supports class initializers.
3008-
ERROR(differentiable_attr_class_init_not_yet_supported,none,
3009-
"'@differentiable' attribute does not yet support class initializers",
3010-
())
30113007
ERROR(differentiable_attr_empty_where_clause,none,
30123008
"empty 'where' clause in '@differentiable' attribute", ())
30133009
// SWIFT_ENABLE_TENSORFLOW

include/swift/SILOptimizer/Utils/Differentiation/PullbackEmitter.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -434,6 +434,16 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
434434
void visitUnconditionalCheckedCastAddrInst(
435435
UnconditionalCheckedCastAddrInst *uccai);
436436

437+
/// Handle `unchecked_ref_cast` instruction.
438+
/// Original: y = unchecked_ref_cast x
439+
/// Adjoint: adj[x] += adj[y] (assuming x' and y' have the same type)
440+
void visitUncheckedRefCastInst(UncheckedRefCastInst *urci);
441+
442+
/// Handle `upcast` instruction.
443+
/// Original: y = upcast x
444+
/// Adjoint: adj[x] += adj[y] (assuming x' and y' have the same type)
445+
void visitUpcastInst(UpcastInst *ui);
446+
437447
#define NOT_DIFFERENTIABLE(INST, DIAG) void visit##INST##Inst(INST##Inst *inst);
438448
#undef NOT_DIFFERENTIABLE
439449

lib/AST/Type.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -792,6 +792,20 @@ Type TypeBase::replaceCovariantResultType(Type newResultType,
792792
return OptionalType::get(
793793
objectType->replaceCovariantResultType(newResultType, uncurryLevel));
794794
}
795+
// SWIFT_ENABLE_TENSORFLOW
796+
// Special logic to handle JVP/VJP derivative functions of `Self`-returning
797+
// methods. This logic is hacky and not robust at all.
798+
// Consider adding a `bool isAutoDiffDerivative` flag for triggering this
799+
// logic, or creating a separate dedicated helper
800+
// `TypeBase::replaceCovariantResultTypeForAutoDiffDerivative`.
801+
if (auto tupleType = dyn_cast<TupleType>(this)) {
802+
assert(tupleType->getNumElements() == 2 &&
803+
"Tuple result is expected only for derivative functions, which "
804+
"return a two-element tuple");
805+
return TupleType::get({newResultType, tupleType->getElement(1)},
806+
getASTContext());
807+
}
808+
// SWIFT_ENABLE_TENSORFLOW END
795809

796810
return newResultType;
797811
}

lib/SIL/SILFunctionType.cpp

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2705,17 +2705,15 @@ TypeConverter::getConstantInfo(TypeExpansionContext expansion,
27052705
loweredInterfaceType);
27062706

27072707
// SWIFT_ENABLE_TENSORFLOW
2708-
// In the case of autodiff derivative functions, the above computations
2709-
// determine `silFnType` by first computing the derivative function type at
2710-
// the AST level and then lowering that. Unfortunately, the actual
2711-
// SILFunctionType for the function is determined by first lowering the
2712-
// function's AST type, and then computing the derivative function type at the
2713-
// SIL level. "Lowering" does not commute with "getting the autodiff
2714-
// associated type", so these two computations produce different results.
2715-
// Therefore `silFnType` is not the actual type of the function that
2716-
// `constant` refers to.
2708+
// For derivative functions, the above computations determine `silFnType`
2709+
// by first computing the derivative AST function type and then lowering it to
2710+
// SIL. Unfortunately, the expected derivative SIL function type is determined
2711+
// by first lowering the original function's AST type, and then computing its
2712+
// SIL derivative function type. "Lowering" does not commute with "getting the
2713+
// derivative type", so these two computations produce different results.
2714+
// Therefore, `silFnType` is not the expected SIL derivative function type.
27172715
//
2718-
// We hackily fix this problem by redoing the computation in the right order.
2716+
// We fix this problem by performing the computation in the right order.
27192717
if (auto *autoDiffFuncId = constant.autoDiffDerivativeFunctionIdentifier) {
27202718
auto origFnConstantInfo = getConstantInfo(
27212719
TypeExpansionContext::minimal(), constant.asAutoDiffOriginalFunction());
@@ -2725,6 +2723,7 @@ TypeConverter::getConstantInfo(TypeExpansionContext expansion,
27252723
loweredIndices, /*resultIndex*/ 0, autoDiffFuncId->getKind(),
27262724
*this, LookUpConformanceInModule(&M));
27272725
}
2726+
// SWIFT_ENABLE_TENSORFLOW END
27282727

27292728
LLVM_DEBUG(llvm::dbgs() << "lowering type for constant ";
27302729
constant.print(llvm::dbgs());
@@ -2938,6 +2937,30 @@ TypeConverter::getConstantOverrideInfo(TypeExpansionContext context,
29382937
*this, context, basePattern, bridgedTypes.Uncurried, base, derived,
29392938
/*reqt subs*/ None, ProtocolConformanceRef());
29402939

2940+
// SWIFT_ENABLE_TENSORFLOW
2941+
// For derivative functions, the above computations determine `fnTy`
2942+
// by first computing the derivative AST function type and then lowering it to
2943+
// SIL. Unfortunately, the expected derivative SIL function type is determined
2944+
// by first lowering the original function's AST type, and then computing its
2945+
// SIL derivative function type. "Lowering" does not commute with "getting the
2946+
// derivative type", so these two computations produce different results.
2947+
// Therefore, `fnTy` is not the expected SIL derivative function type.
2948+
//
2949+
// We fix this problem by performing the computation in the right order.
2950+
if (auto *derivedDerivID = derived.autoDiffDerivativeFunctionIdentifier) {
2951+
auto *baseDerivID = base.autoDiffDerivativeFunctionIdentifier;
2952+
assert(baseDerivID);
2953+
auto &overrideInfo =
2954+
getConstantOverrideInfo(context, derived.asAutoDiffOriginalFunction(),
2955+
base.asAutoDiffOriginalFunction());
2956+
auto loweredIndices = autodiff::getLoweredParameterIndices(
2957+
derivedDerivID->getParameterIndices(), overrideInterfaceTy);
2958+
fnTy = overrideInfo.SILFnType->getAutoDiffDerivativeFunctionType(
2959+
loweredIndices, /*resultIndex*/ 0, derivedDerivID->getKind(), *this,
2960+
LookUpConformanceInModule(&M));
2961+
}
2962+
// SWIFT_ENABLE_TENSORFLOW END
2963+
29412964
// Build the SILConstantInfo and cache it.
29422965
auto resultBuf = Context.Allocate(sizeof(SILConstantInfo),
29432966
alignof(SILConstantInfo));

lib/SILGen/SILGenPoly.cpp

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3749,8 +3749,9 @@ SILFunction *SILGenModule::getOrCreateCustomDerivativeThunk(
37493749
auto *thunk = fb.getOrCreateFunction(
37503750
loc, name, customDerivativeFn->getLinkage(), thunkFnTy, IsBare,
37513751
IsNotTransparent, customDerivativeFn->isSerialized(),
3752-
customDerivativeFn->isDynamicallyReplaceable(), customDerivativeFn->getEntryCount(),
3753-
IsThunk, customDerivativeFn->getClassSubclassScope());
3752+
customDerivativeFn->isDynamicallyReplaceable(),
3753+
customDerivativeFn->getEntryCount(), IsThunk,
3754+
customDerivativeFn->getClassSubclassScope());
37543755
thunk->setInlineStrategy(AlwaysInline);
37553756
if (!thunk->empty())
37563757
return thunk;
@@ -3762,15 +3763,39 @@ SILFunction *SILGenModule::getOrCreateCustomDerivativeThunk(
37623763
thunkSGF.collectThunkParams(loc, params, &indirectResults);
37633764

37643765
auto *fnRef = thunkSGF.B.createFunctionRef(loc, customDerivativeFn);
3765-
auto fnRefType =
3766-
fnRef->getType().castTo<SILFunctionType>();
3766+
auto fnRefType = fnRef->getType().castTo<SILFunctionType>();
37673767

37683768
// Collect thunk arguments, converting ownership.
37693769
SmallVector<SILValue, 8> arguments;
37703770
for (auto *indRes : indirectResults)
37713771
arguments.push_back(indRes);
3772-
forwardFunctionArguments(thunkSGF, loc, fnRefType, params,
3773-
arguments);
3772+
forwardFunctionArguments(thunkSGF, loc, fnRefType, params, arguments);
3773+
3774+
// Special support for thunking class initializer derivatives.
3775+
//
3776+
// User-defined custom derivatives take a metatype as the last parameter:
3777+
// - `$(Param0, Param1, ..., @thick Class.Type) -> (...)`
3778+
// But class initializers take an allocated instance as the last parameter:
3779+
// - `$(Param0, Param1, ..., @owned Class) -> (...)`
3780+
//
3781+
// Adjust forwarded arguments:
3782+
// - Pop the last `@owned Class` argument.
3783+
// - Create a `@thick Class.Type` value and pass it as the last argument.
3784+
auto *origAFD =
3785+
cast<AbstractFunctionDecl>(originalFn->getDeclContext()->getAsDecl());
3786+
if (isa<ConstructorDecl>(origAFD) &&
3787+
SILDeclRef(origAFD, SILDeclRef::Kind::Initializer).mangle() ==
3788+
originalFn->getName()) {
3789+
auto classArgument = arguments.pop_back_val();
3790+
auto *classDecl = classArgument->getType().getClassOrBoundGenericClass();
3791+
assert(classDecl && "Expected last argument to have class type");
3792+
auto classMetatype = MetatypeType::get(
3793+
classDecl->getDeclaredInterfaceType(), MetatypeRepresentation::Thick);
3794+
auto canClassMetatype = classMetatype->getCanonicalType();
3795+
auto *metatype = thunkSGF.B.createMetatype(
3796+
loc, SILType::getPrimitiveObjectType(canClassMetatype));
3797+
arguments.push_back(metatype);
3798+
}
37743799
// Apply function argument.
37753800
auto apply = thunkSGF.emitApplyWithRethrow(
37763801
loc, fnRef, /*substFnType*/ fnRef->getType(),

lib/SILOptimizer/Utils/Differentiation/PullbackEmitter.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1672,6 +1672,26 @@ void PullbackEmitter::visitUnconditionalCheckedCastAddrInst(
16721672
emitZeroIndirect(destType.getASTType(), adjDest, uccai->getLoc());
16731673
}
16741674

1675+
void PullbackEmitter::visitUncheckedRefCastInst(UncheckedRefCastInst *urci) {
1676+
auto *bb = urci->getParent();
1677+
assert(urci->getOperand()->getType().isObject());
1678+
assert(getRemappedTangentType(urci->getOperand()->getType()) ==
1679+
getRemappedTangentType(urci->getType()) &&
1680+
"Operand/result must have the same `TangentVector` type");
1681+
auto adj = getAdjointValue(bb, urci);
1682+
addAdjointValue(bb, urci->getOperand(), adj, urci->getLoc());
1683+
}
1684+
1685+
void PullbackEmitter::visitUpcastInst(UpcastInst *ui) {
1686+
auto *bb = ui->getParent();
1687+
assert(ui->getOperand()->getType().isObject());
1688+
assert(getRemappedTangentType(ui->getOperand()->getType()) ==
1689+
getRemappedTangentType(ui->getType()) &&
1690+
"Operand/result must have the same `TangentVector` type");
1691+
auto adj = getAdjointValue(bb, ui);
1692+
addAdjointValue(bb, ui->getOperand(), adj, ui->getLoc());
1693+
}
1694+
16751695
#define NOT_DIFFERENTIABLE(INST, DIAG) \
16761696
void PullbackEmitter::visit##INST##Inst(INST##Inst *inst) { \
16771697
getContext().emitNondifferentiabilityError(inst, getInvoker(), \

lib/Sema/TypeCheckAttr.cpp

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3975,26 +3975,16 @@ llvm::Expected<IndexSubset *> DifferentiableAttributeTypeCheckRequest::evaluate(
39753975
// Diagnose if original function is an invalid class member.
39763976
if (isOriginalClassMember) {
39773977
// Class methods returning dynamic `Self` are not supported.
3978-
// (For class methods, dynamic `Self` is supported only as the single
3979-
// result - tuple-returning JVPs/VJPs would not type-check.)
3980-
if (auto *originalFn = dyn_cast<FuncDecl>(original)) {
3981-
if (originalFn->hasDynamicSelfResult()) {
3978+
// For class methods, dynamic `Self` is supported only as the single
3979+
// result - tuple-returning JVPs/VJPs would not type-check.
3980+
if (auto *originalFD = dyn_cast<FuncDecl>(original)) {
3981+
if (originalFD->hasDynamicSelfResult()) {
39823982
diags.diagnose(attr->getLocation(),
39833983
diag::differentiable_attr_class_member_no_dynamic_self);
39843984
attr->setInvalid();
39853985
return nullptr;
39863986
}
39873987
}
3988-
3989-
// TODO(TF-654): Class initializers are not yet supported.
3990-
// Extra JVP/VJP type calculation logic is necessary because classes have
3991-
// both allocators and initializers.
3992-
if (auto *initDecl = dyn_cast<ConstructorDecl>(original)) {
3993-
diags.diagnose(attr->getLocation(),
3994-
diag::differentiable_attr_class_init_not_yet_supported);
3995-
attr->setInvalid();
3996-
return nullptr;
3997-
}
39983988
}
39993989

40003990
// Resolve the derivative generic signature.

test/AutoDiff/Sema/differentiable_attr_type_checking.swift

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1080,8 +1080,6 @@ class Super: Differentiable {
10801080

10811081
var base: Float
10821082

1083-
// NOTE(TF-654): Class initializers are not yet supported.
1084-
// expected-error @+1 {{'@differentiable' attribute does not yet support class initializers}}
10851083
@differentiable
10861084
init(base: Float) {
10871085
self.base = base

test/AutoDiff/downstream/class_differentiation.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ ClassTests.test("AddressOnlyTangentVector") {
9292
@differentiable
9393
var stored: T
9494

95+
@differentiable
9596
init(_ stored: T) {
9697
self.stored = stored
9798
}

0 commit comments

Comments
 (0)