[AutoDiff] Constrain wrt parameters to conform to Differentiable. #26426

Merged
merged 2 commits into from
Jul 31, 2019

dan-zheng
Contributor

Constrain all wrt parameters to conform to Differentiable when computing
AD associated function generic signatures.

This fixes crashes when differentiating generic original functions that
do not constrain parameters to be Differentiable, e.g. an unconstrained
identity function.
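
To illustrate why the constraint is needed, here is a minimal sketch in plain Swift. It does not use the real `Differentiable` protocol or any compiler-generated code: `ToyDifferentiable` and `idVJP` are hypothetical stand-ins written by hand. The point it tries to show is that the derivative function's type mentions `T.TangentVector`, so its generic signature has to be the original signature plus a conformance requirement on the wrt parameter.

```swift
// Hypothetical stand-in for `Differentiable`; only the piece this sketch needs.
protocol ToyDifferentiable {
    associatedtype TangentVector
}

extension Double: ToyDifferentiable {
    typealias TangentVector = Double
}

// Original function: `T` is completely unconstrained.
func id<T>(_ x: T) -> T { x }

// Hand-written analogue of an associated (VJP) function for `id`.
// The pullback type refers to `T.TangentVector`, which only exists when
// `T: ToyDifferentiable`, so the constraint must be added to the signature.
func idVJP<T: ToyDifferentiable>(
    _ x: T
) -> (value: T, pullback: (T.TangentVector) -> T.TangentVector) {
    // The derivative of the identity is the identity on tangent vectors.
    (value: id(x), pullback: { v in v })
}

let (value, pullback) = idVJP(3.0)
print(value, pullback(1.0))
```

Without the `T: ToyDifferentiable` requirement, `idVJP` would not type-check, which is the hand-written counterpart of the mismatch between the original and associated function signatures that this change addresses.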

Gardening included:

  • Remove unused isSerialized flag from
    SILGenModule::getOrCreateAutoDiffAssociatedFunctionThunk.
  • Rename whereClauseGenericSignature in SIL to
    associatedFunctionGenericSignature.
    • The generic signature does not necessarily come from the where
      clause of a [differentiable] attribute.

Resolves TF-691 and TF-697.


TF-691 reproducer:

// tf-691.swift
func id<T>(_ x: T) -> T { x }
print(gradient(at: 0, in: { x in id(x) }))

Before:

$ swift tf-691.swift
Stack dump:
0.	Program arguments: /Library/Developer/Toolchains/swift-tensorflow-DEVELOPMENT-2019-07-25-a.xctoolchain/usr/bin/swift -frontend -interpret tf-691.swift -enable-objc-interop -sdk /Applications/Xcode-beta.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX10.15.sdk -color-diagnostics -module-name main
1.	Swift version 5.1-dev (LLVM 200186e28b, Swift 3416770d73)
2.	While running pass #28 SILModuleTransform "Differentiation".
0  swift                    0x000000010ae97ad5 llvm::sys::PrintStackTrace(llvm::raw_ostream&) + 37
1  swift                    0x000000010ae96b18 llvm::sys::RunSignalHandlers() + 248
2  swift                    0x000000010ae980c8 SignalHandler(int) + 264
3  libsystem_platform.dylib 0x00007fff65936b5d _sigtramp + 29
4  libsystem_platform.dylib 0x00007fc6038c0430 _sigtramp + 2650314992
5  swift                    0x0000000107a8a803 (anonymous namespace)::SILTypeSubstituter::substSILFunctionType(swift::CanTypeWrapper<swift::SILFunctionType>) + 243
6  swift                    0x0000000107a8a6d7 swift::SILFunctionType::substGenericArgs(swift::SILModule&, swift::SubstitutionMap) + 231
7  swift                    0x0000000107afc59d swift::SILType::substGenericArgs(swift::SILModule&, swift::SubstitutionMap) const + 77
8  swift                    0x0000000107ab1dc6 swift::PartialApplyInst::create(swift::SILDebugLocation, swift::SILValue, llvm::ArrayRef<swift::SILValue>, swift::SubstitutionMap, swift::ParameterConvention, swift::SILFunction&, swift::SILOpenedArchetypesState&, swift::GenericSpecializationInformation const*, swift::PartialApplyInst::OnStackKind) + 70
9  swift                    0x00000001077cfc1d swift::SILInstructionVisitor<(anonymous namespace)::VJPEmitter, void>::visit(swift::SILInstruction*) + 50573
10 swift                    0x00000001077af6b8 (anonymous namespace)::ADContext::processDifferentiableAttribute(swift::SILFunction*, swift::SILDifferentiableAttr*, (anonymous namespace)::DifferentiationInvoker) + 10856
11 swift                    0x00000001077fe1ca (anonymous namespace)::ADContext::promoteToDifferentiableFunction(swift::AutoDiffFunctionInst*, swift::SILBuilder&, swift::SILLocation, (anonymous namespace)::DifferentiationInvoker) + 4954
12 swift                    0x00000001077b263a (anonymous namespace)::ADContext::processAutoDiffFunctionInst(swift::AutoDiffFunctionInst*) + 378
13 swift                    0x00000001077ac501 (anonymous namespace)::Differentiation::run() + 2913

After:

$ swift tf-691.swift
1.0

TF-697 reproducer:

// tf-697.swift
// TF-697: Test generic requirements of generated AD associated function.
protocol Module: Differentiable where AllDifferentiableVariables == TangentVector {
  associatedtype Input
  associatedtype Output: Differentiable

  @differentiable(wrt: self)
  func callAsFunction(_ input: Input) -> Output
}
protocol Layer: Module where Input: Differentiable {
  @differentiable
  func callLayer(_ input: Input) -> Output
}
struct Sequential<Layer1: Module, Layer2: Layer>: Module
  where Layer1.Output == Layer2.Input {
  var layer1: Layer1
  var layer2: Layer2

  @differentiable(wrt: self)
  func callAsFunction(_ input: Layer1.Input) -> Layer2.Output {
      layer2(layer1(input))
  }
}
extension Sequential: Layer where Layer1: Layer {
  @differentiable
  func callAsFunction(_ input: Layer1.Input) -> Layer2.Output {
      layer2(layer1(input))
  }
}

Before:

$ swift tf-697.swift
SIL verification failed: JVP type does not match expected JVP type
  $@convention(method) <τ_0_0, τ_0_1 where τ_0_0 : Module, τ_0_1 : Layer, τ_0_0.AllDifferentiableVariables == τ_0_0.TangentVector.AllDifferentiableVariables, τ_0_0.Input : Differentiable, τ_0_0.Output == τ_0_1.Input, τ_0_0.AllDifferentiableVariables.VectorSpaceScalar == τ_0_1.AllDifferentiableVariables.VectorSpaceScalar> (@in_guaranteed τ_0_0.Input, @in_guaranteed Sequential<τ_0_0, τ_0_1>) -> (@out τ_0_1.Output, @owned @callee_guaranteed (@in_guaranteed Sequential<τ_0_0, τ_0_1>.AllDifferentiableVariables) -> @out τ_0_1.Output.TangentVector)
  $@convention(method) <τ_0_0, τ_0_1 where τ_0_0 : Module, τ_0_1 : Layer, τ_0_0.AllDifferentiableVariables == τ_0_0.TangentVector.AllDifferentiableVariables, τ_0_0.Output == τ_0_1.Input, τ_0_0.AllDifferentiableVariables.VectorSpaceScalar == τ_0_1.AllDifferentiableVariables.VectorSpaceScalar> (@in_guaranteed τ_0_0.Input, @in_guaranteed Sequential<τ_0_0, τ_0_1>) -> (@out τ_0_1.Output, @owned @callee_guaranteed (@in_guaranteed Sequential<τ_0_0, τ_0_1>.AllDifferentiableVariables) -> @out τ_0_1.Output.TangentVector)
...

After:

$ swift tf-697.swift
# No error.

@dan-zheng dan-zheng added the tensorflow This is for "tensorflow" branch PRs. label Jul 31, 2019
// Returns the canonical generic signature for an autodiff associated function
// given an existing associated function generic signature. All differentiation
// parameters are constrained to conform to `Differentiable`.
static CanGenericSignature getAssociatedFunctionGenericSignature(

I thought about unifying getAssociatedFunctionGenericSignature defined here in lib/SIL/SILFunctionType.cpp (1) with the one defined in lib/SILOptimizer/Differentiation.cpp (2), but the two are not quite compatible.

(1) takes a pre-existing associated function generic signature, but (2) builds the associated function generic signature given the original function's generic signature and [differentiable] attribute requirements.

The part that can be shared is the "constrain all wrt parameters to conform to Differentiable" logic - I don't think sharing that short piece of code is worthwhile.
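
The relationship between the two helpers can be sketched with a toy model. Everything below is an assumption made for illustration: `ToySignature`, `constrainingWrtParameters`, and `buildFromOriginal` are hypothetical names, requirements are modeled as plain strings, and none of this mirrors the actual compiler data structures. It only shows the shape of the argument: (1) starts from an existing associated function signature, (2) starts from the original signature plus attribute requirements, and both end with the same short constraining step.

```swift
// Toy model of a generic signature: parameter names plus requirement strings.
struct ToySignature: Equatable {
    var genericParams: [String]
    var requirements: Set<String>
}

// Shared step: require every wrt parameter type to be `Differentiable`.
func constrainingWrtParameters(
    _ signature: ToySignature, wrtParameterTypes: [String]
) -> ToySignature {
    var result = signature
    for type in wrtParameterTypes {
        result.requirements.insert("\(type): Differentiable")
    }
    return result
}

// Analogue of (2): build the associated function signature from the original
// signature and the attribute's requirements, then apply the shared step.
func buildFromOriginal(
    _ original: ToySignature,
    attributeRequirements: [String],
    wrtParameterTypes: [String]
) -> ToySignature {
    var signature = original
    signature.requirements.formUnion(attributeRequirements)
    return constrainingWrtParameters(signature, wrtParameterTypes: wrtParameterTypes)
}

// Signature of an unconstrained identity function, `<T>`.
let idSignature = ToySignature(genericParams: ["T"], requirements: [])

// Analogue of (1): constrain a pre-existing signature directly.
let fromExisting = constrainingWrtParameters(idSignature, wrtParameterTypes: ["T"])
let fromOriginal = buildFromOriginal(
    idSignature, attributeRequirements: [], wrtParameterTypes: ["T"])

print(fromExisting == fromOriginal)
```

In this model the shared step is a three-line loop, which is consistent with the judgment above that factoring it out is not worth the coupling between the two files.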

@dan-zheng dan-zheng requested a review from rxwei July 31, 2019 02:50
@dan-zheng

@swift-ci Please test tensorflow

@dan-zheng

Merging to unblock swift-DEVELOPMENT-SNAPSHOT-2019-07-28-a merge (#26398), happy to address feedback later.

@dan-zheng dan-zheng merged commit ead5f4d into swiftlang:tensorflow Jul 31, 2019
@dan-zheng dan-zheng deleted the TF-697 branch July 31, 2019 06:47