Skip to content

Commit 79cdb96

Browse files
dan-zhengrxwei
authored andcommitted
[AutoDiff] Simplify AD-related SILGen logic. (#26430)
- Do not unset `@differentiable` attribute functions for consistent same-module and cross-module `[differentiable]` SILGen behavior. - Always create vtable entry thunks for JVPs/VJPs. - Previously, vtable entry thunks were conditionally generated based on an ad-hoc condition. - Update tests. Todo: remove JVP/VJP names from `[differentiable]` attribute.
1 parent b82d755 commit 79cdb96

File tree

3 files changed

+13
-39
lines changed

3 files changed

+13
-39
lines changed

lib/SILGen/SILGen.cpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -826,8 +826,6 @@ void SILGenModule::postEmitFunction(SILDeclRef constant,
826826
expectedJVPType);
827827
}
828828
silDiffAttr->setJVPName(jvpThunk->getName());
829-
// Unset JVP so that TBDGen triggers.
830-
diffAttr->setJVPFunction(nullptr);
831829
}
832830
// Thunk VJP method, if it is defined.
833831
if (auto *vjpDecl = diffAttr->getVJPFunction()) {
@@ -845,8 +843,6 @@ void SILGenModule::postEmitFunction(SILDeclRef constant,
845843
expectedVJPType);
846844
}
847845
silDiffAttr->setVJPName(vjpThunk->getName());
848-
// Unset VJP so that TBDGen triggers.
849-
diffAttr->setVJPFunction(nullptr);
850846
}
851847
}
852848
}

lib/SILGen/SILGenType.cpp

Lines changed: 4 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -88,36 +88,11 @@ SILGenModule::emitVTableMethod(ClassDecl *theClass,
8888
implFn = getDynamicThunk(derived, Types.getConstantInfo(derived).SILFnType);
8989
// SWIFT_ENABLE_TENSORFLOW
9090
} else if (auto *adafi = derived.autoDiffAssociatedFunctionIdentifier) {
91-
auto *decl = derived.getDecl();
92-
auto *DA = *llvm::find_if(
93-
decl->getAttrs().getAttributes<DifferentiableAttr>(),
94-
[&](const DifferentiableAttr *attr) {
95-
return attr->getParameterIndices() == adafi->getParameterIndices();
96-
});
97-
assert(DA && "Expected `@differentiable` attribute");
98-
// Get autodiff associated function declaration, if it exists.
99-
FuncDecl *assocDecl = nullptr;
100-
switch (adafi->getKind()) {
101-
case AutoDiffAssociatedFunctionKind::JVP:
102-
assocDecl = DA->getJVPFunction();
103-
break;
104-
case AutoDiffAssociatedFunctionKind::VJP:
105-
assocDecl = DA->getVJPFunction();
106-
break;
107-
}
108-
// If declaration exists, get corresponding SIL function.
109-
if (assocDecl) {
110-
SILDeclRef assocRef(assocDecl, SILDeclRef::Kind::Func);
111-
implFn = getFunction(assocRef, NotForDefinition);
112-
}
113-
// Otherwise, create an autodiff vtable entry thunk. The thunk contains an
114-
// `autodiff_function` instruction, which is later filled during
91+
// For JVP/VJP methods, create a vtable entry thunk. The thunk contains an
92+
// `autodiff_function` instruction, which is later filled during the
11593
// differentiation transform.
116-
// TODO(TF-524): Generalize canonical JVP/VJP thunk generation.
117-
else {
118-
implFn = getOrCreateAutoDiffClassMethodThunk(
119-
derived, Types.getConstantInfo(derived).SILFnType);
120-
}
94+
implFn = getOrCreateAutoDiffClassMethodThunk(
95+
derived, Types.getConstantInfo(derived).SILFnType);
12196
// SWIFT_ENABLE_TENSORFLOW END
12297
} else {
12398
implFn = getFunction(derived, NotForDefinition);

test/AutoDiff/differentiable_attr_silgen_cross_module.swift

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,20 +5,23 @@
55

66
import differentiable_attr_silgen_other_module
77

8-
// After SILGen, SIL `[differentiable]` attribute should have JVP/VJP names
9-
// only if the AST `@differentiable` attribute does.
10-
// The differentiation pass is guaranteed to fill in SIL `[differentiable]`
11-
// attribute JVP/VJP names.
8+
// After SILGen, a SIL `[differentiable]` attribute on a function from the
9+
// current module should have JVP/VJP names only if the AST `@differentiable`
10+
// attribute does.
11+
12+
// For external functions, `[differentiable]` attribute JVP/VJP names should
13+
// always exist. The differentiation pass is guaranteed to fill in
14+
// `[differentiable]` attribute JVP/VJP names.
1215

1316
_ = pullback(at: Wrapper(1)) { x in x + x * x }
1417

1518
// CHECK-SILGEN-LABEL: // static Wrapper.* infix(_:_:)
16-
// CHECK-SILGEN-NEXT: sil [differentiable source 0 wrt 0, 1] @$s39differentiable_attr_silgen_other_module7WrapperV1moiyA2C_ACtFZ : $@convention(method) (Wrapper, Wrapper, @thin Wrapper.Type) -> Wrapper
19+
// CHECK-SILGEN-NEXT: sil [differentiable source 0 wrt 0, 1 jvp @AD__$s39differentiable_attr_silgen_other_module7WrapperV1moiyA2C_ACtFZ__jvp_src_0_wrt_0_1 vjp @AD__$s39differentiable_attr_silgen_other_module7WrapperV1moiyA2C_ACtFZ__vjp_src_0_wrt_0_1] @$s39differentiable_attr_silgen_other_module7WrapperV1moiyA2C_ACtFZ : $@convention(method) (Wrapper, Wrapper, @thin Wrapper.Type) -> Wrapper
1720
// CHECK-SIL-LABEL: // static Wrapper.* infix(_:_:)
1821
// CHECK-SIL-NEXT: sil [differentiable source 0 wrt 0, 1 jvp @AD__$s39differentiable_attr_silgen_other_module7WrapperV1moiyA2C_ACtFZ__jvp_src_0_wrt_0_1 vjp @AD__$s39differentiable_attr_silgen_other_module7WrapperV1moiyA2C_ACtFZ__vjp_src_0_wrt_0_1] @$s39differentiable_attr_silgen_other_module7WrapperV1moiyA2C_ACtFZ : $@convention(method) (Wrapper, Wrapper, @thin Wrapper.Type) -> Wrapper
1922

2023
// CHECK-SILGEN-LABEL: // static Wrapper.+ infix(_:_:)
21-
// CHECK-SILGEN-NEXT: sil [differentiable source 0 wrt 0, 1] @$s39differentiable_attr_silgen_other_module7WrapperV1poiyA2C_ACtFZ : $@convention(method) (Wrapper, Wrapper, @thin Wrapper.Type) -> Wrapper
24+
// CHECK-SILGEN-NEXT: sil [differentiable source 0 wrt 0, 1 vjp @AD__$s39differentiable_attr_silgen_other_module7WrapperV1poiyA2C_ACtFZ__vjp_src_0_wrt_0_1] @$s39differentiable_attr_silgen_other_module7WrapperV1poiyA2C_ACtFZ : $@convention(method) (Wrapper, Wrapper, @thin Wrapper.Type) -> Wrapper
2225
// CHECK-SIL-LABEL: // static Wrapper.+ infix(_:_:)
2326
// CHECK-SIL-NEXT: sil [differentiable source 0 wrt 0, 1 jvp @AD__$s39differentiable_attr_silgen_other_module7WrapperV1poiyA2C_ACtFZ__jvp_src_0_wrt_0_1 vjp @AD__$s39differentiable_attr_silgen_other_module7WrapperV1poiyA2C_ACtFZ__vjp_src_0_wrt_0_1] @$s39differentiable_attr_silgen_other_module7WrapperV1poiyA2C_ACtFZ : $@convention(method) (Wrapper, Wrapper, @thin Wrapper.Type) -> Wrapper
2427

0 commit comments

Comments
 (0)