Skip to content

Commit b82d755

Browse files
authored
[AutoDiff] Decouple test from stdlib dependency. (#26428)
Decouple test/AutoDiff/differentiable_attr_silgen_cross_module.swift from stdlib dependency (Float operations).
1 parent ead5f4d commit b82d755

File tree

2 files changed

+52
-13
lines changed

2 files changed

+52
-13
lines changed
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
public struct Wrapper : Differentiable, AdditiveArithmetic {
2+
public var x: Float
3+
public init(_ x: Float) {
4+
self.x = x
5+
}
6+
7+
public static func + (lhs: Wrapper, rhs: Wrapper) -> Wrapper {
8+
return Wrapper(lhs.x + rhs.x)
9+
}
10+
11+
@differentiating(+)
12+
public static func vjpAdd(lhs: Wrapper, rhs: Wrapper)
13+
-> (value: Wrapper, pullback: (Wrapper) -> (Wrapper, Wrapper)) {
14+
return (lhs + rhs, { v in (v, v) })
15+
}
16+
17+
public static func * (lhs: Wrapper, rhs: Wrapper) -> Wrapper {
18+
return Wrapper(lhs.x * rhs.x)
19+
}
20+
21+
@differentiating(*)
22+
public static func jvpMultiply(lhs: Wrapper, rhs: Wrapper)
23+
-> (value: Wrapper, differential: (Wrapper, Wrapper) -> Wrapper) {
24+
return (lhs * rhs, { dlhs, drhs in dlhs * rhs + lhs * drhs })
25+
}
26+
27+
@differentiating(*)
28+
public static func vjpMultiply(lhs: Wrapper, rhs: Wrapper)
29+
-> (value: Wrapper, pullback: (Wrapper) -> (Wrapper, Wrapper)) {
30+
return (lhs * rhs, { v in (v * rhs, v * lhs) })
31+
}
32+
}

test/AutoDiff/differentiable_attr_silgen_cross_module.swift

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,24 @@
1-
// RUN: %target-swift-frontend -emit-silgen -verify %s | %FileCheck %s -check-prefix=CHECK-SILGEN
2-
// RUN: %target-swift-frontend -emit-sil -verify %s | %FileCheck %s -check-prefix=CHECK-SIL
1+
// RUN: %empty-directory(%t)
2+
// RUN: %target-swift-frontend -emit-module -primary-file %S/Inputs/differentiable_attr_silgen_other_module.swift -emit-module-path %t/differentiable_attr_silgen_other_module.swiftmodule
3+
// RUN: %target-swift-frontend -emit-silgen -verify -I %t -primary-file %s | %FileCheck %s -check-prefix=CHECK-SILGEN
4+
// RUN: %target-swift-frontend -emit-sil -verify -I %t -primary-file %s | %FileCheck %s -check-prefix=CHECK-SIL
35

4-
// After SILGen, SIL `[differentiable]` should have jvp/vjp names only if the AST `@differentiable` attribute does.
5-
// The differentiation pass is guaranteed to fill in jvp/vjp names.
6+
import differentiable_attr_silgen_other_module
67

7-
_ = gradient(at: Float(1)) { x in x + x * x }
8+
// After SILGen, SIL `[differentiable]` attribute should have JVP/VJP names
9+
// only if the AST `@differentiable` attribute does.
10+
// The differentiation pass is guaranteed to fill in SIL `[differentiable]`
11+
// attribute JVP/VJP names.
812

9-
// CHECK-SILGEN-LABEL: // static Float.* infix(_:_:)
10-
// CHECK-SILGEN-NEXT: sil [transparent] [serialized] [differentiable source 0 wrt 0, 1] @$sSf1moiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float
11-
// CHECK-SIL-LABEL: // static Float.* infix(_:_:)
12-
// CHECK-SIL-NEXT: sil public_external [transparent] [serialized] [differentiable source 0 wrt 0, 1 jvp @AD__$sSf1moiyS2f_SftFZ__jvp_src_0_wrt_0_1 vjp @AD__$sSf1moiyS2f_SftFZ__vjp_src_0_wrt_0_1] [differentiable source 0 wrt 0, 1 jvp @AD__$sSf1moiyS2f_SftFZ__jvp_src_0_wrt_0_1 vjp @AD__$sSf1moiyS2f_SftFZ__vjp_src_0_wrt_0_1] @$sSf1moiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float
13+
_ = pullback(at: Wrapper(1)) { x in x + x * x }
14+
15+
// CHECK-SILGEN-LABEL: // static Wrapper.* infix(_:_:)
16+
// CHECK-SILGEN-NEXT: sil [differentiable source 0 wrt 0, 1] @$s39differentiable_attr_silgen_other_module7WrapperV1moiyA2C_ACtFZ : $@convention(method) (Wrapper, Wrapper, @thin Wrapper.Type) -> Wrapper
17+
// CHECK-SIL-LABEL: // static Wrapper.* infix(_:_:)
18+
// CHECK-SIL-NEXT: sil [differentiable source 0 wrt 0, 1 jvp @AD__$s39differentiable_attr_silgen_other_module7WrapperV1moiyA2C_ACtFZ__jvp_src_0_wrt_0_1 vjp @AD__$s39differentiable_attr_silgen_other_module7WrapperV1moiyA2C_ACtFZ__vjp_src_0_wrt_0_1] @$s39differentiable_attr_silgen_other_module7WrapperV1moiyA2C_ACtFZ : $@convention(method) (Wrapper, Wrapper, @thin Wrapper.Type) -> Wrapper
19+
20+
// CHECK-SILGEN-LABEL: // static Wrapper.+ infix(_:_:)
21+
// CHECK-SILGEN-NEXT: sil [differentiable source 0 wrt 0, 1] @$s39differentiable_attr_silgen_other_module7WrapperV1poiyA2C_ACtFZ : $@convention(method) (Wrapper, Wrapper, @thin Wrapper.Type) -> Wrapper
22+
// CHECK-SIL-LABEL: // static Wrapper.+ infix(_:_:)
23+
// CHECK-SIL-NEXT: sil [differentiable source 0 wrt 0, 1 jvp @AD__$s39differentiable_attr_silgen_other_module7WrapperV1poiyA2C_ACtFZ__jvp_src_0_wrt_0_1 vjp @AD__$s39differentiable_attr_silgen_other_module7WrapperV1poiyA2C_ACtFZ__vjp_src_0_wrt_0_1] @$s39differentiable_attr_silgen_other_module7WrapperV1poiyA2C_ACtFZ : $@convention(method) (Wrapper, Wrapper, @thin Wrapper.Type) -> Wrapper
1324

14-
// CHECK-SILGEN-LABEL: // static Float.+ infix(_:_:)
15-
// CHECK-SILGEN-NEXT: sil [transparent] [serialized] [differentiable source 0 wrt 0, 1] @$sSf1poiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float
16-
// CHECK-SIL-LABEL: // static Float.+ infix(_:_:)
17-
// CHECK-SIL-NEXT: sil public_external [transparent] [serialized] [differentiable source 0 wrt 0, 1 jvp @AD__$sSf1poiyS2f_SftFZ__jvp_src_0_wrt_0_1 vjp @AD__$sSf1poiyS2f_SftFZ__vjp_src_0_wrt_0_1] [differentiable source 0 wrt 0, 1 jvp @AD__$sSf1poiyS2f_SftFZ__jvp_src_0_wrt_0_1 vjp @AD__$sSf1poiyS2f_SftFZ__vjp_src_0_wrt_0_1] @$sSf1poiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float

0 commit comments

Comments
 (0)