Skip to content

Commit 4265835

Browse files
authored
Merge pull request #30826 from dan-zheng/SR-12526
[AutoDiff] Add SR-12526 negative test.
2 parents 5b39fa7 + 52b771d commit 4265835

File tree

5 files changed

+43
-11
lines changed

5 files changed

+43
-11
lines changed

lib/Sema/TypeCheckAttr.cpp

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4423,14 +4423,6 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D,
44234423
return true;
44244424
}
44254425

4426-
// Reject different-file derivative registration.
4427-
// TODO(TF-1021): Lift same-file derivative registration restriction.
4428-
if (originalAFD->getParentSourceFile() != derivative->getParentSourceFile()) {
4429-
diags.diagnose(attr->getLocation(),
4430-
diag::derivative_attr_not_in_same_file_as_original);
4431-
return true;
4432-
}
4433-
44344426
// Reject duplicate `@derivative` attributes.
44354427
auto &derivativeAttrs = Ctx.DerivativeAttrs[std::make_tuple(
44364428
originalAFD, resolvedDiffParamIndices, kind)];
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
public struct Struct {
2+
public func method(_ x: Float) -> Float { x }
3+
4+
public static func +(_ lhs: Self, rhs: Self) -> Self {
5+
lhs
6+
}
7+
}
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import _Differentiation
2+
import a
3+
4+
extension Struct: Differentiable {
5+
public struct TangentVector: Differentiable & AdditiveArithmetic {}
6+
public mutating func move(along _: TangentVector) {}
7+
8+
@usableFromInline
9+
@derivative(of: method, wrt: x)
10+
func vjpMethod(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
11+
(x, { $0 })
12+
}
13+
14+
@usableFromInline
15+
@derivative(of: +)
16+
static func vjpAdd(_ lhs: Self, rhs: Self) -> (
17+
value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)
18+
) {
19+
(lhs + rhs, { v in (v, v) })
20+
}
21+
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
// RUN: %empty-directory(%t)
2+
// RUN: %target-swift-frontend -emit-module -primary-file %S/Inputs/a.swift -emit-module-path %t/a.swiftmodule
3+
// RUN: %target-swift-frontend -emit-module -primary-file %S/Inputs/b.swift -emit-module-path %t/b.swiftmodule -I %t
4+
// RUN: not --crash %target-swift-frontend-typecheck -verify -I %t %s
5+
6+
// SR-12526: Fix cross-module deserialization crash involving `@derivative` attribute.
7+
8+
import a
9+
import b
10+
11+
func foo(_ s: Struct) {
12+
_ = Struct()
13+
_ = s.method(1)
14+
}

test/AutoDiff/Sema/derivative_attr_type_checking.swift

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -743,11 +743,9 @@ extension InoutParameters {
743743
}
744744
}
745745

746-
// Test cross-file derivative registration. Currently unsupported.
747-
// TODO(TF-1021): Lift this restriction.
746+
// Test cross-file derivative registration.
748747

749748
extension FloatingPoint where Self: Differentiable {
750-
// expected-error @+1 {{derivative not in the same file as the original function}}
751749
@derivative(of: rounded)
752750
func vjpRounded() -> (
753751
value: Self,

0 commit comments

Comments
 (0)