[AutoDiff] Fix derivative generic signature same-type requirements. #28772

Merged
merged 1 commit into swiftlang:tensorflow from derivative-gen-sig
Dec 14, 2019

Conversation

dan-zheng (Contributor) commented on Dec 13, 2019

Fix derivative generic signature calculation when same-type requirements bind
all generic parameters to concrete types, i.e. when all generic parameters are
concrete.

Declarations whose generic signatures have all concrete generic parameters are
lowered as SIL functions with no generic signature: they are specialized with
the concrete types from the same-type requirements.

For @differentiable attributes: when the original generic signature and the
derivative generic signature are equal and all generic parameters are concrete,
do not set the attribute's derivative generic signature.
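
For instance (a hedged sketch with hypothetical names, not taken from this PR's tests): in the declaration below, the original and derivative generic signatures are both <Scalar where Scalar == Float>, so with this change the attribute is lowered without a derivative generic signature.

struct Box<Scalar> {}
extension Box: Differentiable where Scalar == Float {}

extension Box where Scalar == Float {
  // Both signatures are `<Scalar where Scalar == Float>`: equal and fully
  // concrete, so no derivative generic signature is attached.
  @differentiable
  func identity() -> Box { self }
}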

Update SIL infrastructure to handle derivative generic signatures with all
concrete generic parameters. In such cases:

  • SIL derivative function types are specialized with concrete types and have
    no generic signature.
  • SIL differentiability witnesses have a derivative generic signature iff it
    differs from the original generic signature. Witness generic signatures should
    be used for remapping types during the differentiation transform.

Resolves TF-1059 and TF-1062.
Exposes SR-11950: SIL parser crash for SILGen round-trip.


Examples

TF-1059

struct Tensor<Scalar> {}
extension Tensor: Differentiable where Scalar == Float {}

extension Tensor where Scalar == Float {
  @differentiable(vjp: _vjpAdd)
  static func + (_ lhs: Tensor, _ rhs: Tensor) -> Tensor {
    return lhs
  }

  static func _vjpAdd(lhs: Tensor, rhs: Tensor)
    -> (Tensor, (TangentVector) -> (TangentVector, TangentVector)) {
    return (lhs + rhs, { v in (v, v) })
  }
}

let _: @differentiable (Tensor<Float>, Tensor<Float>) -> Tensor<Float> = { $0 + $1 }

Before:

$ swift tf-1059.swift
tf-1059.swift:16:79: error: expression is not differentiable
let _: @differentiable (Tensor<Float>, Tensor<Float>) -> Tensor<Float> = { $0 + $1 }
                                                                              ^
tf-1059.swift:16:79: note: function call is not differentiable because generic requirements are not met: 'Scalar == Float'
let _: @differentiable (Tensor<Float>, Tensor<Float>) -> Tensor<Float> = { $0 + $1 }
                                                                              ^

After: runs successfully.
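
With the fix, the closure can also be differentiated through the standard library's differentiation APIs. A minimal usage sketch (not part of the PR), assuming the tensorflow branch's pullback(at:_:in:) function:

// The closure converts to `@differentiable (Tensor<Float>, Tensor<Float>) -> Tensor<Float>`
// and the registered `_vjpAdd` pullback `{ v in (v, v) }` is invoked.
let pb = pullback(at: Tensor<Float>(), Tensor<Float>()) { $0 + $1 }
let (dlhs, drhs) = pb(.zero)  // both results are `.zero`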

TF-1062

struct Tensor<T>: Differentiable {}
extension Tensor: Equatable where T: Equatable {}
extension Tensor where T: AdditiveArithmetic {
  static func + (_ lhs: Self, _ rhs: Self) -> Self {
    lhs
  }
}
extension Tensor where T == Float {
  @derivative(of: +)
  static func vjpAdd(_ lhs: Self, _ rhs: Self) -> (
    value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)
  ) {
    return (lhs + rhs, { v in (v, v) })
  }
}

Before:

$ swift tf-1062.swift
tf-1062.swift:9:19: error: could not find function '+' with expected type '<T where T == Float> (Tensor<T>.Type) -> (Tensor<T>, Tensor<T>) -> Tensor<T>'
  @derivative(of: +)
                  ^

After: runs successfully.
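
As with TF-1059, a minimal usage sketch (not part of the PR), assuming the tensorflow branch's valueWithPullback(at:_:in:) function, showing that the derivative registered with @derivative(of: +) is picked up for Tensor<Float>:

let (value, pb) = valueWithPullback(at: Tensor<Float>(), Tensor<Float>()) { $0 + $1 }
let (dlhs, drhs) = pb(.zero)  // the registered pullback returns `(v, v)`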

dan-zheng added the tensorflow label (This is for "tensorflow" branch PRs.) on Dec 13, 2019
@@ -3816,8 +3848,20 @@ void AttributeChecker::visitDerivativeAttr(DerivativeAttr *attr) {
auto vectorTy = valueResultConf.getTypeWitnessByName(
    valueResultType, Ctx.Id_TangentVector);

// Compute the actual differential/pullback type that we use for comparison
dan-zheng (Contributor, Author) commented on this diff:

Note: this code is copied from #28762 - it's needed to fix TF-1062. #28762 should be merged first!

/// - Return the witness derivative generic signature if it exists.
/// - Otherwise, return the original function's generic signature.
CanGenericSignature
getDerivativeGenericSignature(SILDifferentiabilityWitness *witness,
dan-zheng (Contributor, Author) commented on this diff:

Note: this helper is removed because the witness derivative generic signature can be used directly, now that derivative generic signatures are propagated correctly. There's no need to fall back to the original generic signature.

dan-zheng requested review from marcrasi and rxwei and removed the request for marcrasi on December 13, 2019 at 12:06
rxwei (Contributor) commented on Dec 13, 2019

Will this fix @differentiable(where T == Float)?

dan-zheng (Contributor, Author) replied:

> Will this fix @differentiable(where T == Float)?

Yes, that's exactly what's tested.
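
For reference, a hedged illustration (hypothetical names, not from the PR's test suite) of that spelling: a where clause whose same-type requirement binds every generic parameter to a concrete type, which is exactly the case this change handles.

struct Wrapper<T> {}
extension Wrapper: Differentiable where T == Float {}

extension Wrapper {
  // Derivative generic signature `<T where T == Float>` is fully concrete
  // and differs from the original signature `<T>`, so the SIL
  // differentiability witness keeps it.
  @differentiable(where T == Float)
  func identity() -> Wrapper { self }
}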

dan-zheng (Contributor, Author) commented:

@swift-ci Please test tensorflow

dan-zheng merged commit 01de397 into swiftlang:tensorflow on Dec 14, 2019
dan-zheng deleted the derivative-gen-sig branch on December 14, 2019 at 00:22
dan-zheng added a commit to dan-zheng/swift that referenced this pull request Dec 18, 2019
Upstream `@derivative` attribute type-checking fixes regarding derivative
generic signatures with all concrete generic parameters.

Cherry-picked from:
- swiftlang#28762
- swiftlang#28772