Skip to content

[AutoDiff] Compute derivative types using requirements from archetypes. #39728

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 14, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion include/swift/AST/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -4653,7 +4653,9 @@ class SILFunctionType final
AutoDiffDerivativeFunctionKind kind, Lowering::TypeConverter &TC,
LookupConformanceFn lookupConformance,
CanGenericSignature derivativeFunctionGenericSignature = nullptr,
bool isReabstractionThunk = false);
bool isReabstractionThunk = false,
CanType origTypeOfAbstraction = CanType());


/// Returns the type of the transpose function for the given parameter
/// indices, transpose function generic signature (optional), and other
Expand Down
37 changes: 28 additions & 9 deletions lib/SIL/IR/SILFunctionType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,8 @@ getSemanticResults(SILFunctionType *functionType, IndexSubset *parameterIndices,
}

static CanGenericSignature buildDifferentiableGenericSignature(CanGenericSignature sig,
CanType tanType) {
CanType tanType,
CanType origTypeOfAbstraction) {
if (!sig)
return sig;

Expand Down Expand Up @@ -390,6 +391,20 @@ static CanGenericSignature buildDifferentiableGenericSignature(CanGenericSignatu
}
}

if (origTypeOfAbstraction) {
(void) origTypeOfAbstraction.findIf([&](Type t) -> bool {
if (auto *at = t->getAs<ArchetypeType>()) {
types.insert(at->getInterfaceType()->getCanonicalType());
for (auto *proto : at->getConformsTo()) {
reqs.push_back(Requirement(RequirementKind::Conformance,
at->getInterfaceType(),
proto->getDeclaredInterfaceType()));
}
}
return false;
});
}

return evaluateOrDefault(
ctx.evaluator,
AbstractGenericSignatureRequest{sig.getPointer(), {}, reqs},
Expand Down Expand Up @@ -427,14 +442,15 @@ static CanType getAutoDiffTangentTypeForLinearMap(
static CanSILFunctionType getAutoDiffDifferentialType(
SILFunctionType *originalFnTy, IndexSubset *parameterIndices,
IndexSubset *resultIndices, LookupConformanceFn lookupConformance,
CanType origTypeOfAbstraction,
TypeConverter &TC) {
// Given the tangent type and the corresponding original parameter's
// convention, returns the tangent parameter's convention.
auto getTangentParameterConvention =
[&](CanType tanType,
ParameterConvention origParamConv) -> ParameterConvention {
auto sig = buildDifferentiableGenericSignature(
originalFnTy->getSubstGenericSignature(), tanType);
originalFnTy->getSubstGenericSignature(), tanType, origTypeOfAbstraction);

tanType = tanType->getCanonicalType(sig);
AbstractionPattern pattern(sig, tanType);
Expand Down Expand Up @@ -462,7 +478,7 @@ static CanSILFunctionType getAutoDiffDifferentialType(
[&](CanType tanType,
ResultConvention origResConv) -> ResultConvention {
auto sig = buildDifferentiableGenericSignature(
originalFnTy->getSubstGenericSignature(), tanType);
originalFnTy->getSubstGenericSignature(), tanType, origTypeOfAbstraction);

tanType = tanType->getCanonicalType(sig);
AbstractionPattern pattern(sig, tanType);
Expand Down Expand Up @@ -565,7 +581,7 @@ static CanSILFunctionType getAutoDiffDifferentialType(
static CanSILFunctionType getAutoDiffPullbackType(
SILFunctionType *originalFnTy, IndexSubset *parameterIndices,
IndexSubset *resultIndices, LookupConformanceFn lookupConformance,
TypeConverter &TC) {
CanType origTypeOfAbstraction, TypeConverter &TC) {
auto &ctx = originalFnTy->getASTContext();
SmallVector<GenericTypeParamType *, 4> substGenericParams;
SmallVector<Requirement, 4> substRequirements;
Expand All @@ -582,7 +598,7 @@ static CanSILFunctionType getAutoDiffPullbackType(
[&](CanType tanType,
ResultConvention origResConv) -> ParameterConvention {
auto sig = buildDifferentiableGenericSignature(
originalFnTy->getSubstGenericSignature(), tanType);
originalFnTy->getSubstGenericSignature(), tanType, origTypeOfAbstraction);

tanType = tanType->getCanonicalType(sig);
AbstractionPattern pattern(sig, tanType);
Expand Down Expand Up @@ -613,7 +629,7 @@ static CanSILFunctionType getAutoDiffPullbackType(
[&](CanType tanType,
ParameterConvention origParamConv) -> ResultConvention {
auto sig = buildDifferentiableGenericSignature(
originalFnTy->getSubstGenericSignature(), tanType);
originalFnTy->getSubstGenericSignature(), tanType, origTypeOfAbstraction);

tanType = tanType->getCanonicalType(sig);
AbstractionPattern pattern(sig, tanType);
Expand Down Expand Up @@ -780,7 +796,8 @@ CanSILFunctionType SILFunctionType::getAutoDiffDerivativeFunctionType(
AutoDiffDerivativeFunctionKind kind, TypeConverter &TC,
LookupConformanceFn lookupConformance,
CanGenericSignature derivativeFnInvocationGenSig,
bool isReabstractionThunk) {
bool isReabstractionThunk,
CanType origTypeOfAbstraction) {
assert(parameterIndices);
assert(!parameterIndices->isEmpty() && "Parameter indices must not be empty");
assert(resultIndices);
Expand Down Expand Up @@ -810,12 +827,14 @@ CanSILFunctionType SILFunctionType::getAutoDiffDerivativeFunctionType(
case AutoDiffDerivativeFunctionKind::JVP:
closureType =
getAutoDiffDifferentialType(constrainedOriginalFnTy, parameterIndices,
resultIndices, lookupConformance, TC);
resultIndices, lookupConformance,
origTypeOfAbstraction, TC);
break;
case AutoDiffDerivativeFunctionKind::VJP:
closureType =
getAutoDiffPullbackType(constrainedOriginalFnTy, parameterIndices,
resultIndices, lookupConformance, TC);
resultIndices, lookupConformance,
origTypeOfAbstraction, TC);
break;
}
// Compute the derivative function parameters.
Expand Down
14 changes: 9 additions & 5 deletions lib/SIL/IR/TypeLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -331,19 +331,23 @@ namespace {
CanSILFunctionType type, AbstractionPattern origType) {
auto &M = TC.M;
auto origTy = type->getWithoutDifferentiability();
// Pass the `AbstractionPattern` generic signature to
// `SILFunctionType:getAutoDiffDerivativeFunctionType` for correct type
// lowering.
// Pass the original type of abstraction pattern to
// `SILFunctionType:getAutoDiffDerivativeFunctionType` to get the
// necessary generic requirements.
auto origTypeOfAbstraction =
origType.hasGenericSignature() ? origType.getType() : CanType();
auto jvpTy = origTy->getAutoDiffDerivativeFunctionType(
type->getDifferentiabilityParameterIndices(),
type->getDifferentiabilityResultIndices(),
AutoDiffDerivativeFunctionKind::JVP, TC,
LookUpConformanceInModule(&M), CanGenericSignature());
LookUpConformanceInModule(&M), CanGenericSignature(),
false, origTypeOfAbstraction);
auto vjpTy = origTy->getAutoDiffDerivativeFunctionType(
type->getDifferentiabilityParameterIndices(),
type->getDifferentiabilityResultIndices(),
AutoDiffDerivativeFunctionKind::VJP, TC,
LookUpConformanceInModule(&M), CanGenericSignature());
LookUpConformanceInModule(&M), CanGenericSignature(),
false, origTypeOfAbstraction);
RecursiveProperties props;
props.addSubobject(classifyType(origType, origTy, TC, Expansion));
props.addSubobject(classifyType(origType, jvpTy, TC, Expansion));
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: %target-swift-frontend -emit-sil -Xllvm -sil-print-after=differentiation %s -module-name null -o /dev/null -requirement-machine=off 2>&1 | %FileCheck %s
// RUN: %target-swift-frontend -emit-sil -Xllvm -sil-print-after=differentiation %s -module-name null -o /dev/null 2>&1 | %FileCheck %s

// Test differentiation of semantic member accessors:
// - Stored property accessors.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// RUN: %empty-directory(%t)
// RUN: not --crash %target-build-swift -emit-module -module-name pr32302 -emit-module-path %t/pr32302.swiftmodule -swift-version 5 -c %S/pr32302-autodiff-generictypeparamdecl-has-incorrect-depth.swift -Xfrontend -requirement-machine=off
// RUN: not --crash %target-build-swift -emit-module -module-name pr32302 -emit-module-path %t/pr32302.swiftmodule -swift-version 5 -c %S/pr32302-autodiff-generictypeparamdecl-has-incorrect-depth.swift
// XFAIL: *

// pr32302 / pr32343 / pr38745 : reproduce assert with _Differentiation where
Expand Down Expand Up @@ -28,7 +28,7 @@ extension Differentiable {
// GenericTypeParamDecl has incorrect depth
// Please submit a bug report (https://swift.org/contributing/#reporting-bugs) and include the project and the crash backtrace.
// Stack dump:
// 0. Program arguments: /work/software/swift-stocktoolchain/build/ds/swift-linux-x86_64/bin/swift-frontend -frontend -merge-modules -emit-module /tmp/pr32302-autodiff-generictypeparamdecl-has-incorrect-depth-acc95c.swiftmodule -parse-as-library -disable-diagnostic-passes -disable-sil-perf-optzns -target x86_64-unknown-linux-gnu -warn-on-potentially-unavailable-enum-case -disable-objc-interop -module-cache-path /work/software/swift-stocktoolchain/build/ds/swift-linux-x86_64/swift-test-results/x86_64-unknown-linux-gnu/clang-module-cache -swift-version 5 -define-availability "SwiftStdlib 5.5:macOS 12.0, iOS 15.0, watchOS 8.0, tvOS 15.0" -requirement-machine=off -emit-module-doc-path /work/software/swift-stocktoolchain/build/ds/swift-linux-x86_64/test-linux-x86_64/AutoDiff/compiler_crashers/Output/pr32302-autodiff-generictypeparamdecl-has-incorrect-depth.swift.tmp/pr32302.swiftdoc -emit-module-source-info-path /work/software/swift-stocktoolchain/build/ds/swift-linux-x86_64/test-linux-x86_64/AutoDiff/compiler_crashers/Output/pr32302-autodiff-generictypeparamdecl-has-incorrect-depth.swift.tmp/pr32302.swiftsourceinfo -module-name pr32302 -o /work/software/swift-stocktoolchain/build/ds/swift-linux-x86_64/test-linux-x86_64/AutoDiff/compiler_crashers/Output/pr32302-autodiff-generictypeparamdecl-has-incorrect-depth.swift.tmp/pr32302.swiftmodule
// 0. Program arguments: /work/software/swift-stocktoolchain/build/ds/swift-linux-x86_64/bin/swift-frontend -frontend -merge-modules -emit-module /tmp/pr32302-autodiff-generictypeparamdecl-has-incorrect-depth-acc95c.swiftmodule -parse-as-library -disable-diagnostic-passes -disable-sil-perf-optzns -target x86_64-unknown-linux-gnu -warn-on-potentially-unavailable-enum-case -disable-objc-interop -module-cache-path /work/software/swift-stocktoolchain/build/ds/swift-linux-x86_64/swift-test-results/x86_64-unknown-linux-gnu/clang-module-cache -swift-version 5 -define-availability "SwiftStdlib 5.5:macOS 12.0, iOS 15.0, watchOS 8.0, tvOS 15.0" -emit-module-doc-path /work/software/swift-stocktoolchain/build/ds/swift-linux-x86_64/test-linux-x86_64/AutoDiff/compiler_crashers/Output/pr32302-autodiff-generictypeparamdecl-has-incorrect-depth.swift.tmp/pr32302.swiftdoc -emit-module-source-info-path /work/software/swift-stocktoolchain/build/ds/swift-linux-x86_64/test-linux-x86_64/AutoDiff/compiler_crashers/Output/pr32302-autodiff-generictypeparamdecl-has-incorrect-depth.swift.tmp/pr32302.swiftsourceinfo -module-name pr32302 -o /work/software/swift-stocktoolchain/build/ds/swift-linux-x86_64/test-linux-x86_64/AutoDiff/compiler_crashers/Output/pr32302-autodiff-generictypeparamdecl-has-incorrect-depth.swift.tmp/pr32302.swiftmodule
// 1. Swift version 5.6-dev (LLVM ba0b85f590c1ba2, Swift 319b3e64aaeb252)
// 2. Compiling with the current language version
// 3. While verifying GenericTypeParamDecl 'τ_1_0' (in module 'pr32302')
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// RUN: %target-build-swift %s

import _Differentiation

public protocol Layer {
associatedtype Input: Differentiable
associatedtype Output: Differentiable
func callAsFunction(_ input: Input) -> Output
}

public class Function<Input: Differentiable, Output: Differentiable>: Layer {
public typealias Body = @differentiable(reverse) (Input) -> Output

@noDerivative public let body: Body

public init(_ body: @escaping Body) {
self.body = body
}

@differentiable(reverse)
public func callAsFunction(_ input: Input) -> Output {
body(input)
}
}