Skip to content

[AutoDiff] Simplify AD-related SILGen logic. #26430

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 31, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions lib/SILGen/SILGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -826,8 +826,6 @@ void SILGenModule::postEmitFunction(SILDeclRef constant,
expectedJVPType);
}
silDiffAttr->setJVPName(jvpThunk->getName());
// Unset JVP so that TBDGen triggers.
diffAttr->setJVPFunction(nullptr);
}
// Thunk VJP method, if it is defined.
if (auto *vjpDecl = diffAttr->getVJPFunction()) {
Expand All @@ -845,8 +843,6 @@ void SILGenModule::postEmitFunction(SILDeclRef constant,
expectedVJPType);
}
silDiffAttr->setVJPName(vjpThunk->getName());
// Unset VJP so that TBDGen triggers.
diffAttr->setVJPFunction(nullptr);
}
}
}
Expand Down
33 changes: 4 additions & 29 deletions lib/SILGen/SILGenType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,36 +88,11 @@ SILGenModule::emitVTableMethod(ClassDecl *theClass,
implFn = getDynamicThunk(derived, Types.getConstantInfo(derived).SILFnType);
// SWIFT_ENABLE_TENSORFLOW
} else if (auto *adafi = derived.autoDiffAssociatedFunctionIdentifier) {
auto *decl = derived.getDecl();
auto *DA = *llvm::find_if(
decl->getAttrs().getAttributes<DifferentiableAttr>(),
[&](const DifferentiableAttr *attr) {
return attr->getParameterIndices() == adafi->getParameterIndices();
});
assert(DA && "Expected `@differentiable` attribute");
// Get autodiff associated function declaration, if it exists.
FuncDecl *assocDecl = nullptr;
switch (adafi->getKind()) {
case AutoDiffAssociatedFunctionKind::JVP:
assocDecl = DA->getJVPFunction();
break;
case AutoDiffAssociatedFunctionKind::VJP:
assocDecl = DA->getVJPFunction();
break;
}
// If declaration exists, get corresponding SIL function.
if (assocDecl) {
SILDeclRef assocRef(assocDecl, SILDeclRef::Kind::Func);
implFn = getFunction(assocRef, NotForDefinition);
}
// Otherwise, create an autodiff vtable entry thunk. The thunk contains an
// `autodiff_function` instruction, which is later filled during
// For JVP/VJP methods, create a vtable entry thunk. The thunk contains an
// `autodiff_function` instruction, which is later filled during the
// differentiation transform.
// TODO(TF-524): Generalize canonical JVP/VJP thunk generation.
else {
implFn = getOrCreateAutoDiffClassMethodThunk(
derived, Types.getConstantInfo(derived).SILFnType);
}
implFn = getOrCreateAutoDiffClassMethodThunk(
derived, Types.getConstantInfo(derived).SILFnType);
// SWIFT_ENABLE_TENSORFLOW END
} else {
implFn = getFunction(derived, NotForDefinition);
Expand Down
15 changes: 9 additions & 6 deletions test/AutoDiff/differentiable_attr_silgen_cross_module.swift
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,23 @@

import differentiable_attr_silgen_other_module

// After SILGen, SIL `[differentiable]` attribute should have JVP/VJP names
// only if the AST `@differentiable` attribute does.
// The differentiation pass is guaranteed to fill in SIL `[differentiable]`
// attribute JVP/VJP names.
// After SILGen, a SIL `[differentiable]` attribute on a function from the
// current module should have JVP/VJP names only if the AST `@differentiable`
// attribute does.

// For external functions, `[differentiable]` attribute JVP/VJP names should
// always exist. The differentiation pass is guaranteed to fill in
// `[differentiable]` attribute JVP/VJP names.

_ = pullback(at: Wrapper(1)) { x in x + x * x }

// CHECK-SILGEN-LABEL: // static Wrapper.* infix(_:_:)
// CHECK-SILGEN-NEXT: sil [differentiable source 0 wrt 0, 1] @$s39differentiable_attr_silgen_other_module7WrapperV1moiyA2C_ACtFZ : $@convention(method) (Wrapper, Wrapper, @thin Wrapper.Type) -> Wrapper
// CHECK-SILGEN-NEXT: sil [differentiable source 0 wrt 0, 1 jvp @AD__$s39differentiable_attr_silgen_other_module7WrapperV1moiyA2C_ACtFZ__jvp_src_0_wrt_0_1 vjp @AD__$s39differentiable_attr_silgen_other_module7WrapperV1moiyA2C_ACtFZ__vjp_src_0_wrt_0_1] @$s39differentiable_attr_silgen_other_module7WrapperV1moiyA2C_ACtFZ : $@convention(method) (Wrapper, Wrapper, @thin Wrapper.Type) -> Wrapper
// CHECK-SIL-LABEL: // static Wrapper.* infix(_:_:)
// CHECK-SIL-NEXT: sil [differentiable source 0 wrt 0, 1 jvp @AD__$s39differentiable_attr_silgen_other_module7WrapperV1moiyA2C_ACtFZ__jvp_src_0_wrt_0_1 vjp @AD__$s39differentiable_attr_silgen_other_module7WrapperV1moiyA2C_ACtFZ__vjp_src_0_wrt_0_1] @$s39differentiable_attr_silgen_other_module7WrapperV1moiyA2C_ACtFZ : $@convention(method) (Wrapper, Wrapper, @thin Wrapper.Type) -> Wrapper

// CHECK-SILGEN-LABEL: // static Wrapper.+ infix(_:_:)
// CHECK-SILGEN-NEXT: sil [differentiable source 0 wrt 0, 1] @$s39differentiable_attr_silgen_other_module7WrapperV1poiyA2C_ACtFZ : $@convention(method) (Wrapper, Wrapper, @thin Wrapper.Type) -> Wrapper
// CHECK-SILGEN-NEXT: sil [differentiable source 0 wrt 0, 1 vjp @AD__$s39differentiable_attr_silgen_other_module7WrapperV1poiyA2C_ACtFZ__vjp_src_0_wrt_0_1] @$s39differentiable_attr_silgen_other_module7WrapperV1poiyA2C_ACtFZ : $@convention(method) (Wrapper, Wrapper, @thin Wrapper.Type) -> Wrapper
// CHECK-SIL-LABEL: // static Wrapper.+ infix(_:_:)
// CHECK-SIL-NEXT: sil [differentiable source 0 wrt 0, 1 jvp @AD__$s39differentiable_attr_silgen_other_module7WrapperV1poiyA2C_ACtFZ__jvp_src_0_wrt_0_1 vjp @AD__$s39differentiable_attr_silgen_other_module7WrapperV1poiyA2C_ACtFZ__vjp_src_0_wrt_0_1] @$s39differentiable_attr_silgen_other_module7WrapperV1poiyA2C_ACtFZ : $@convention(method) (Wrapper, Wrapper, @thin Wrapper.Type) -> Wrapper