Skip to content

Commit 709d0f9

Browse files
authored
---
yaml --- r: 338939 b: refs/heads/rxwei-patch-1 c: 7c6d2fa h: refs/heads/master i: 338937: e846ad6 338935: b6e2678
1 parent 8232160 commit 709d0f9

File tree

4 files changed

+113
-36
lines changed

4 files changed

+113
-36
lines changed

[refs]

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1015,7 +1015,7 @@ refs/tags/swift-DEVELOPMENT-SNAPSHOT-2018-08-18-a: b10b1fce14385faa6d44f6b933e95
10151015
refs/heads/rdar-43033749-fix-batch-mode-no-diags-swift-5.0-branch: a14e64eaad30de89f0f5f0b2a782eed7ecdcb255
10161016
refs/heads/revert-19006-error-bridging-integer-type: 8a9065a3696535305ea53fe9b71f91cbe6702019
10171017
refs/heads/revert-19050-revert-19006-error-bridging-integer-type: ecf752d54b05dd0a20f510f0bfa54a3fec3bcaca
1018-
refs/heads/rxwei-patch-1: b14f205d0a0907c06d6ae5d0501ede2f6e3756c8
1018+
refs/heads/rxwei-patch-1: 7c6d2fa0b744f651fe88424628bd4363d54a688a
10191019
refs/heads/shahmishal-patch-1: e58ec0f7488258d42bef51bc3e6d7b3dc74d7b2a
10201020
refs/heads/typelist-existential: 4046359efd541fb5c72d69a92eefc0a784df8f5e
10211021
refs/tags/swift-4.2-DEVELOPMENT-SNAPSHOT-2018-08-20-a: 4319ba09e4fb8650ee86061075c74a016b6baab9

branches/rxwei-patch-1/lib/Sema/TypeCheckAttr.cpp

Lines changed: 62 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -3053,15 +3053,6 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
30533053
// Set the checked differentiation parameter indices in the attribute.
30543054
attr->setParameterIndices(checkedWrtParamIndices);
30553055

3056-
auto insertion =
3057-
ctx.DifferentiableAttrs.try_emplace({D, checkedWrtParamIndices}, attr);
3058-
// `@differentiable` attributes are uniqued by their parameter indices.
3059-
// Reject duplicate attributes for the same decl and parameter indices pair.
3060-
if (!insertion.second && insertion.first->getSecond() != attr) {
3061-
diagnoseAndRemoveAttr(attr, diag::differentiable_attr_duplicate);
3062-
return;
3063-
}
3064-
30653056
// Check that original function's result type conforms to `Differentiable`.
30663057
if (whereClauseGenEnv) {
30673058
auto originalResultInterfaceType = !originalResultTy->hasTypeParameter()
@@ -3079,9 +3070,7 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
30793070
}
30803071

30813072
// Checks that the `candidate` function type equals the `required` function
3082-
// type, disregarding parameter labels.
3083-
//
3084-
// Precondition: `required` has no parameter labels.
3073+
// type, disregarding parameter labels and tuple result labels.
30853074
std::function<bool(CanAnyFunctionType, CanType)> checkFunctionSignature;
30863075
checkFunctionSignature = [&](CanAnyFunctionType required,
30873076
CanType candidate) -> bool {
@@ -3096,21 +3085,31 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
30963085
required.getOptGenericSignature())
30973086
return false;
30983087

3099-
// Check that parameter types match (disregards labels).
3100-
if (candidateFnTy.getParams().size() != required.getParams().size())
3088+
// Check that parameter types match, disregarding labels.
3089+
if (!std::equal(required.getParams().begin(), required.getParams().end(),
3090+
candidateFnTy.getParams().begin(),
3091+
[](AnyFunctionType::Param x, AnyFunctionType::Param y) {
3092+
return x.getPlainType()->isEqual(y.getPlainType());
3093+
}))
31013094
return false;
3102-
for (auto paramPair : llvm::zip(candidateFnTy.getParams(),
3103-
required.getParams()))
3104-
if (!std::get<0>(paramPair).getPlainType()->isEqual(
3105-
std::get<1>(paramPair).getPlainType()))
3106-
return false;
31073095

3108-
// If required result type is non-function, check that result types match
3109-
// exactly.
3096+
// If required result type is non-function, check that result types match.
3097+
// If result types are tuple types, ignore labels.
31103098
CanAnyFunctionType requiredResultFnTy =
31113099
dyn_cast<AnyFunctionType>(required.getResult());
3112-
if (!requiredResultFnTy)
3113-
return required.getResult() == candidateFnTy.getResult();
3100+
if (!requiredResultFnTy) {
3101+
auto requiredResultTupleTy = required.getResult()->getAs<TupleType>();
3102+
auto candidateResultTupleTy =
3103+
candidateFnTy.getResult()->getAs<TupleType>();
3104+
if (!requiredResultTupleTy || !candidateResultTupleTy)
3105+
return required.getResult()->isEqual(candidateFnTy.getResult());
3106+
// If result types are tuple types, check that element types match,
3107+
// ignoring labels.
3108+
return std::equal(requiredResultTupleTy->getElementTypes().begin(),
3109+
requiredResultTupleTy->getElementTypes().end(),
3110+
candidateResultTupleTy->getElementTypes().begin(),
3111+
[](Type x, Type y) { return x->isEqual(y); });
3112+
}
31143113

31153114
// Required result type is a function. Recurse.
31163115
return checkFunctionSignature(requiredResultFnTy,
@@ -3168,6 +3167,15 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
31683167
// Memorize the vjp reference in the attribute.
31693168
attr->setVJPFunction(vjp);
31703169
}
3170+
3171+
auto insertion =
3172+
ctx.DifferentiableAttrs.try_emplace({D, checkedWrtParamIndices}, attr);
3173+
// `@differentiable` attributes are uniqued by their parameter indices.
3174+
// Reject duplicate attributes for the same decl and parameter indices pair.
3175+
if (!insertion.second && insertion.first->getSecond() != attr) {
3176+
diagnoseAndRemoveAttr(attr, diag::differentiable_attr_duplicate);
3177+
return;
3178+
}
31713179
}
31723180

31733181
// SWIFT_ENABLE_TENSORFLOW
@@ -3489,33 +3497,50 @@ void AttributeChecker::visitDifferentiatingAttr(DifferentiatingAttr *attr) {
34893497
derivativeRequirements.push_back(req);
34903498
}
34913499

3492-
// Add the derivative to a `@differentiable` attribute on the original
3493-
// function with the same differentiation parameters. If no such
3494-
// `@differentiable` attribute exists, create one.
3500+
// Try to find a `@differentiable` attribute on the original function with the
3501+
// same differentiation parameters.
34953502
DifferentiableAttr *da = nullptr;
34963503
for (auto *cda : originalFn->getAttrs().getAttributes<DifferentiableAttr>())
34973504
if (checkedWrtParamIndices == cda->getParameterIndices())
34983505
da = const_cast<DifferentiableAttr *>(cda);
3506+
// If the original function does not have a `@differentiable` attribute with
3507+
// the same differentiation parameters, create one.
34993508
if (!da) {
35003509
da = DifferentiableAttr::create(ctx, /*implicit*/ true, attr->AtLoc,
35013510
attr->getRange(), checkedWrtParamIndices,
3502-
None, None, derivativeRequirements);
3511+
/*jvp*/ None, /*vjp*/ None,
3512+
derivativeRequirements);
3513+
switch (kind) {
3514+
case AutoDiffAssociatedFunctionKind::JVP:
3515+
da->setJVPFunction(derivative);
3516+
break;
3517+
case AutoDiffAssociatedFunctionKind::VJP:
3518+
da->setVJPFunction(derivative);
3519+
break;
3520+
}
35033521
auto insertion = ctx.DifferentiableAttrs.try_emplace(
35043522
{originalFn, checkedWrtParamIndices}, da);
3505-
// `@differentiable` attributes are uniqued by their parameter indices.
3506-
// Reject duplicate attributes for the same decl and parameter indices pair.
3523+
// Valid `@differentiable` attributes are uniqued by their parameter
3524+
// indices. Reject duplicate attributes for the same decl and parameter
3525+
// indices pair.
35073526
if (!insertion.second && insertion.first->getSecond() != da) {
35083527
diagnoseAndRemoveAttr(da, diag::differentiable_attr_duplicate);
35093528
return;
35103529
}
35113530
originalFn->getAttrs().add(da);
3531+
return;
35123532
}
3513-
// Check if the `@differentiable` attribute already has a registered
3514-
// derivative. If so, emit an error on the `@differentiating` attribute.
3515-
// Otherwise, register the derivative in the `@differentiable` attribute.
3533+
// If the original function has a `@differentiable` attribute with the same
3534+
// differentiation parameters, check if the `@differentiable` attribute
3535+
// already has a different registered derivative. If so, emit an error on the
3536+
// `@differentiating` attribute. Otherwise, register the derivative in the
3537+
// `@differentiable` attribute.
35163538
switch (kind) {
35173539
case AutoDiffAssociatedFunctionKind::JVP:
3518-
if (da->getJVP() || da->getJVPFunction()) {
3540+
// If there's a different registered derivative, emit an error.
3541+
if ((da->getJVP() &&
3542+
da->getJVP()->Name.getBaseName() != derivative->getBaseName()) ||
3543+
(da->getJVPFunction() && da->getJVPFunction() != derivative)) {
35193544
diagnoseAndRemoveAttr(
35203545
attr, diag::differentiating_attr_original_already_has_derivative,
35213546
originalFn->getFullName());
@@ -3524,7 +3549,10 @@ void AttributeChecker::visitDifferentiatingAttr(DifferentiatingAttr *attr) {
35243549
da->setJVPFunction(derivative);
35253550
break;
35263551
case AutoDiffAssociatedFunctionKind::VJP:
3527-
if (da->getVJP() || da->getVJPFunction()) {
3552+
// If there's a different registered derivative, emit an error.
3553+
if ((da->getVJP() &&
3554+
da->getVJP()->Name.getBaseName() != derivative->getBaseName()) ||
3555+
(da->getVJPFunction() && da->getVJPFunction() != derivative)) {
35283556
diagnoseAndRemoveAttr(
35293557
attr, diag::differentiating_attr_original_already_has_derivative,
35303558
originalFn->getFullName());

branches/rxwei-patch-1/test/AutoDiff/differentiable_attr_type_checking.swift

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -505,6 +505,41 @@ func vjpWhere1<T : Differentiable>(x: T) -> (T, (T.CotangentVector) -> T.Cotange
505505
return (x, { v in v })
506506
}
507507

508+
// Test derivative functions with result tuple type labels.
509+
@differentiable(jvp: jvpResultLabels, vjp: vjpResultLabels)
510+
func derivativeResultLabels(_ x: Float) -> Float {
511+
return x
512+
}
513+
func jvpResultLabels(_ x: Float) -> (value: Float, differential: (Float) -> Float) {
514+
return (x, { $0 })
515+
}
516+
func vjpResultLabels(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
517+
return (x, { $0 })
518+
}
519+
struct ResultLabelTest {
520+
@differentiable(jvp: jvpResultLabels, vjp: vjpResultLabels)
521+
static func derivativeResultLabels(_ x: Float) -> Float {
522+
return x
523+
}
524+
static func jvpResultLabels(_ x: Float) -> (value: Float, differential: (Float) -> Float) {
525+
return (x, { $0 })
526+
}
527+
static func vjpResultLabels(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
528+
return (x, { $0 })
529+
}
530+
531+
@differentiable(jvp: jvpResultLabels, vjp: vjpResultLabels)
532+
func derivativeResultLabels(_ x: Float) -> Float {
533+
return x
534+
}
535+
func jvpResultLabels(_ x: Float) -> (value: Float, differential: (Float) -> Float) {
536+
return (x, { $0 })
537+
}
538+
func vjpResultLabels(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
539+
return (x, { $0 })
540+
}
541+
}
542+
508543
struct Tensor<Scalar> : AdditiveArithmetic {}
509544
extension Tensor : Differentiable where Scalar : Differentiable {
510545
typealias TangentVector = Tensor
@@ -544,7 +579,6 @@ extension FloatingPoint {
544579
}
545580

546581
protocol MethodDiffReq {
547-
// expected-error @+1 {{'vjpFoo' does not have expected type '<Self where Self : Differentiable, Self : MethodDiffReq> (Self) -> () -> (Self, (Self.CotangentVector) -> Self.CotangentVector)'}}
548582
@differentiable(wrt: self, vjp: vjpFoo where Self : Differentiable)
549583
func foo() -> Self
550584
}

branches/rxwei-patch-1/test/AutoDiff/differentiating_attr_type_checking.swift

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,3 +280,18 @@ extension InstanceMethodProto where Self : Differentiable {
280280
return (bar(), { _ in .zero })
281281
}
282282
}
283+
284+
// Test consistent usages of `@differentiable` and `@differentiating` where
285+
// derivative functions are specified in both attributes.
286+
@differentiable(jvp: jvpConsistent, vjp: vjpConsistent)
287+
func consistentSpecifiedDerivatives(_ x: Float) -> Float {
288+
return x
289+
}
290+
@differentiating(consistentSpecifiedDerivatives)
291+
func jvpConsistent(_ x: Float) -> (value: Float, differential: (Float) -> Float) {
292+
return (x, { $0 })
293+
}
294+
@differentiating(consistentSpecifiedDerivatives(_:))
295+
func vjpConsistent(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
296+
return (x, { $0 })
297+
}

0 commit comments

Comments
 (0)