[API] [AD] Revamp @differentiable usages in stdlib. #21732

Merged
13 changes: 10 additions & 3 deletions lib/SILOptimizer/Mandatory/TFDifferentiation.cpp
@@ -2441,9 +2441,16 @@ class PrimalGenCloner final : public SILClonerWithScopes<PrimalGenCloner> {
newArgs.push_back(getOpValue(origArg));
assert(newArgs.size() == numVJPParams);
// Apply the VJP.
auto *vjpCall = getBuilder().createApply(ai->getLoc(), vjp,
ai->getSubstitutionMap(), newArgs,
ai->isNonThrowing());
auto substMap = ai->getSubstitutionMap();
if (auto vjpGenSig = vjpFnTy->getGenericSignature()) {
auto vjpSubstMap =
vjpGenSig->createGenericEnvironment()->getForwardingSubstitutionMap();
substMap = vjpSubstMap.subst(
[&](SubstitutableType *ty) { return Type(ty).subst(substMap); },
LookUpConformanceInModule(context.getModule().getSwiftModule()));
}
auto *vjpCall = getBuilder().createApply(ai->getLoc(), vjp, substMap,
newArgs, ai->isNonThrowing());
LLVM_DEBUG(getADDebugStream() << "Applied vjp function\n" << *vjpCall);

// Get the VJP results (original results and pullback).
5 changes: 3 additions & 2 deletions lib/Sema/TypeCheckAttr.cpp
@@ -2345,14 +2345,15 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {

// Conformance requirements are valid if:
// - The first type is a generic type parameter type.
// - The second type is a protocol type.
// - The second type is a protocol type or protocol composition type.
case RequirementKind::Conformance:
if (diagnoseDifferentiableAttrIndirectGenericType(
attr->getLocation(), req.getFirstType(),
reqRepr->getSubjectRepr()))
return false;

if (!req.getSecondType()->is<ProtocolType>()) {
if (!req.getSecondType()->is<ProtocolType>() &&
!req.getSecondType()->is<ProtocolCompositionType>()) {
TC.diagnose(attr->getLocation(),
diag::differentiable_attr_non_protocol_type_constraint_req)
.highlight(reqRepr->getSourceRange());
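
The Sema change above widens the conformance-requirement check so that the requirement's second type may be a protocol composition, not only a single protocol. A minimal standalone sketch of what such a requirement looks like in ordinary Swift, using a hypothetical stand-in protocol (`Smoothable`) rather than the real stdlib protocols, and not part of this diff:

// Stand-in protocol; the real use case composes Differentiable with FloatingPoint.
protocol Smoothable {}
extension Double : Smoothable {}
extension Float : Smoothable {}

// The second type of the conformance requirement on T is a protocol
// composition (`Smoothable & FloatingPoint`), the shape of constraint the
// attribute checker previously rejected in @differentiable where clauses.
func halve<T>(_ x: T) -> T where T : Smoothable & FloatingPoint {
  return x / 2
}

print(halve(3.0)) // 1.5
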
14 changes: 10 additions & 4 deletions stdlib/public/TensorFlow/CompositeMath.swift
@@ -19,22 +19,28 @@
/// Computes `sigmoid` of the specified tensor element-wise.
/// Specifically, computes `1 / (1 + exp(-x))`.
@inlinable @inline(__always)
public func sigmoid<T : BinaryFloatingPoint>(_ x: Tensor<T>) -> Tensor<T> {
public func sigmoid<T>(_ x: Tensor<T>) -> Tensor<T>
where T : Differentiable & FloatingPoint
{
return 1 / (1 + exp(-x))
}

/// Computes `relu` of the specified tensor element-wise.
/// Specifically, computes `max(0, x)`.
@inlinable @inline(__always)
@differentiable(adjoint: _adjointRelu(_:_:_:))
public func relu<T : BinaryFloatingPoint>(_ x: Tensor<T>) -> Tensor<T> {
public func relu<T>(_ x: Tensor<T>) -> Tensor<T>
where T : Differentiable & FloatingPoint
{
return max(0, x)
}

/// Computes the softmax of the specified tensor element-wise.
/// Specifically, computes `exp(x) / exp(x).sum()`.
@inlinable @inline(__always)
public func softmax<T : BinaryFloatingPoint>(_ x: Tensor<T>) -> Tensor<T> {
public func softmax<T>(_ x: Tensor<T>) -> Tensor<T>
where T : Differentiable & FloatingPoint
{
let expx = exp(x)
let sum = expx.sum()
return expx / sum
@@ -43,7 +49,7 @@ public func softmax<T : BinaryFloatingPoint>(_ x: Tensor<T>) -> Tensor<T> {
/// Computes the softmax of the specified tensor along the specified axis.
/// Specifically, computes `exp(x) / exp(x).sum(alongAxes: axis)`.
@inlinable @inline(__always)
public func softmax<T : BinaryFloatingPoint>(
public func softmax<T : Differentiable & FloatingPoint>(
_ x: Tensor<T>, alongAxis axis: Int32
) -> Tensor<T> {
let expx = exp(x)
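
The formulas quoted in the doc comments above can be sanity-checked at the scalar level. The sketch below is a stand-in that uses plain Double and [Double] rather than Tensor<T>, so it only illustrates the math, not the actual Tensor API:

import Foundation // exp

// Scalar stand-ins for the composite math ops documented above.
func sigmoid(_ x: Double) -> Double {
  return 1 / (1 + exp(-x))              // 1 / (1 + exp(-x))
}

func relu(_ x: Double) -> Double {
  return max(0, x)                      // max(0, x)
}

func softmax(_ x: [Double]) -> [Double] {
  let expx = x.map { exp($0) }          // exp(x)
  let sum = expx.reduce(0, +)           // exp(x).sum()
  return expx.map { $0 / sum }          // exp(x) / exp(x).sum()
}

print(sigmoid(0.0))                          // 0.5
print(relu(-3.0))                            // 0.0
print(softmax([1.0, 2.0, 3.0]).reduce(0, +)) // ≈ 1.0
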
54 changes: 24 additions & 30 deletions stdlib/public/TensorFlow/Gradients.swift
@@ -36,9 +36,6 @@
// TODO:
// - Add gradients for more ops ('sum', 'mean', etc).
// - Fix gradients for broadcasting ops (need to perform reduction).
// - When the trailing 'where' clause in @differentiable is properly
// type-checked, define constraints on BinaryFloatingPoint in original
// declarations and define adjoints on BinaryFloatingPoint.
//
// FIXME:
// - Handle scalar broadcasting.
@@ -49,7 +46,7 @@
// Elementwise binary
//===----------------------------------------------------------------------===//

extension Tensor where Scalar : Numeric {
extension Tensor where Scalar : Differentiable & FloatingPoint {
@inlinable
static func _adjointAdd(
_ seed: Tensor, _ originalValue: Tensor, _ x: Tensor, _ y: Tensor
@@ -84,7 +81,7 @@ extension Tensor where Scalar : Numeric {
}

@inlinable
func _adjointMinMax<T : Numeric & Comparable>(
func _adjointMinMax<T : Differentiable & FloatingPoint>(
_ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>, _ y: Tensor<T>
) -> (Tensor<T>, Tensor<T>) {
let denom = 1 + Tensor<T>(x .== y)
@@ -94,7 +91,7 @@ func _adjointMinMax<T : Numeric & Comparable>(
}

@inlinable
func _adjointPow<T : BinaryFloatingPoint>(
func _adjointPow<T : Differentiable & FloatingPoint>(
_ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>, _ y: Tensor<T>
) -> (Tensor<T>, Tensor<T>) {
return ((seed * y * pow(x, y-1)).unbroadcast(like: x),
@@ -105,7 +102,7 @@ func _adjointPow<T : BinaryFloatingPoint>(
// Elementwise unary
//===----------------------------------------------------------------------===//

extension Tensor where Scalar : SignedNumeric {
extension Tensor where Scalar : Differentiable & FloatingPoint {
@inlinable
static func _adjointNegate(
_ seed: Tensor, _ originalValue: Tensor, _ x: Tensor
@@ -115,90 +112,90 @@ extension Tensor where Scalar : SignedNumeric {
}

@inlinable
func _adjointLog<T : BinaryFloatingPoint>(
func _adjointLog<T : Differentiable & FloatingPoint>(
_ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
) -> Tensor<T> {
return seed / x
}

@inlinable
func _adjointSin<T : BinaryFloatingPoint>(
func _adjointSin<T : Differentiable & FloatingPoint>(
_ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
) -> Tensor<T> {
return seed * cos(x)
}

@inlinable
func _adjointCos<T : BinaryFloatingPoint>(
func _adjointCos<T : Differentiable & FloatingPoint>(
_ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
) -> Tensor<T> {
return -seed * sin(x)
}

@inlinable
func _adjointTan<T : BinaryFloatingPoint>(
func _adjointTan<T : Differentiable & FloatingPoint>(
_ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
) -> Tensor<T> {
return seed * (1 + originalValue.squared())
}

@inlinable
func _adjointSinh<T : BinaryFloatingPoint>(
func _adjointSinh<T : Differentiable & FloatingPoint>(
_ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
) -> Tensor<T> {
return seed * cosh(x)
}

@inlinable
func _adjointCosh<T : BinaryFloatingPoint>(
func _adjointCosh<T : Differentiable & FloatingPoint>(
_ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
) -> Tensor<T> {
return seed * sinh(x)
}

@inlinable
func _adjointTanh<T : BinaryFloatingPoint>(
func _adjointTanh<T : Differentiable & FloatingPoint>(
_ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
) -> Tensor<T> {
return seed * (1 - originalValue.squared())
}

@inlinable
func _adjointExp<T : BinaryFloatingPoint>(
func _adjointExp<T : Differentiable & FloatingPoint>(
_ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
) -> Tensor<T> {
return originalValue * seed
}

@inlinable
func _adjointCeil<T : BinaryFloatingPoint>(
func _adjointCeil<T : Differentiable & FloatingPoint>(
_ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
) -> Tensor<T> {
return Tensor(0).broadcast(like: x)
}

@inlinable
func _adjointFloor<T : BinaryFloatingPoint>(
func _adjointFloor<T : Differentiable & FloatingPoint>(
_ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
) -> Tensor<T> {
return Tensor(0).broadcast(like: x)
}

@inlinable
func _adjointSqrt<T : BinaryFloatingPoint>(
func _adjointSqrt<T : Differentiable & FloatingPoint>(
_ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
) -> Tensor<T> {
return seed / (2 * originalValue)
}

@inlinable
func _adjointRsqrt<T : BinaryFloatingPoint>(
func _adjointRsqrt<T : Differentiable & FloatingPoint>(
_ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
) -> Tensor<T> {
return -seed / 2 * pow(originalValue, 3)
}

func _adjointSquared<T : BinaryFloatingPoint>(
func _adjointSquared<T : Differentiable & FloatingPoint>(
_ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
) -> Tensor<T> {
return 2 * x * seed
@@ -209,7 +206,7 @@ func _adjointSquared<T : BinaryFloatingPoint>(
//===----------------------------------------------------------------------===//

@inlinable
func _adjointMatmul<Scalar : Numeric>(
func _adjointMatmul<Scalar : Differentiable & FloatingPoint>(
_ seed: Tensor<Scalar>, _ originalValue: Tensor<Scalar>,
_ left: Tensor<Scalar>, _ right: Tensor<Scalar>
) -> (Tensor<Scalar>, Tensor<Scalar>) {
@@ -220,16 +217,14 @@ func _adjointMatmul<Scalar : Numeric>(
// TODO: We have to define a custom adjoint on • because AD can't yet
// differentiate generic methods. After AD can differentiate generic methods,
// remove the custom adjoint.
extension Tensor where Scalar : Numeric {
extension Tensor where Scalar : Differentiable & FloatingPoint {
@inlinable
static func _adjointMatmulOperator(seed: Tensor, originalValue: Tensor,
lhs: Tensor, rhs: Tensor)
-> (Tensor, Tensor) {
return _adjointMatmul(seed, originalValue, lhs, rhs)
}
}

extension Tensor {
@inlinable
func _adjointTransposed(
_ seed: Tensor, _ originalValue: Tensor, _ permutations: Tensor<Int32>
@@ -243,7 +238,7 @@ extension Tensor {
// Shape transformations
//===----------------------------------------------------------------------===//

extension Tensor {
extension Tensor where Scalar : Differentiable & FloatingPoint {
@inlinable
func _adjointReshaped(
seed: Tensor, originalValue: Tensor, toShape newShape: Tensor<Int32>
@@ -265,9 +260,8 @@ extension Tensor {
// Normalization
//===----------------------------------------------------------------------===//

extension Tensor where Scalar : BinaryFloatingPoint,
Scalar : Differentiable,
Scalar.CotangentVector == Scalar {
extension Tensor where Scalar : BinaryFloatingPoint & Differentiable,
Scalar == Scalar.CotangentVector {
// TODO: Verify that these calculations are correct.
@inlinable
func _adjointBatchNormalized(
@@ -304,7 +298,7 @@ extension Tensor where Scalar : BinaryFloatingPoint,
// Convolution and pooling
//===----------------------------------------------------------------------===//

extension Tensor where Scalar : BinaryFloatingPoint {
extension Tensor where Scalar : Differentiable & FloatingPoint {
/// TensorFlow builtin conv2d gradient helper for the input.
@inlinable
@differentiable(
@@ -448,7 +442,7 @@ extension Tensor where Scalar : BinaryFloatingPoint {
//===----------------------------------------------------------------------===//

@inlinable
func _adjointRelu<T : BinaryFloatingPoint>(
func _adjointRelu<T : Differentiable & FloatingPoint>(
_ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
) -> Tensor<T> {
return Tensor(x .> 0) * seed
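
All of the adjoints above follow the same convention: given the upstream seed, the original result, and the original argument(s), return the gradient with respect to the argument(s). A scalar sketch (plain Double standing in for Tensor<T>, Foundation math functions, not part of this diff) that checks two of the formulas against a central finite difference:

import Foundation // tanh, sqrt

// Mirrors _adjointTanh: d/dx tanh(x) = 1 - tanh(x)^2, scaled by the seed.
func adjointTanh(seed: Double, originalValue: Double, x: Double) -> Double {
  return seed * (1 - originalValue * originalValue)
}

// Mirrors _adjointSqrt: d/dx sqrt(x) = 1 / (2 * sqrt(x)), scaled by the seed.
func adjointSqrt(seed: Double, originalValue: Double, x: Double) -> Double {
  return seed / (2 * originalValue)
}

// Central finite difference for comparison.
func numericalDerivative(of f: (Double) -> Double, at x: Double,
                         eps: Double = 1e-6) -> Double {
  return (f(x + eps) - f(x - eps)) / (2 * eps)
}

let x = 0.7
print(adjointTanh(seed: 1, originalValue: tanh(x), x: x),
      numericalDerivative(of: { tanh($0) }, at: x))   // both ≈ 0.6347
print(adjointSqrt(seed: 1, originalValue: sqrt(x), x: x),
      numericalDerivative(of: { sqrt($0) }, at: x))   // both ≈ 0.5976
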