Skip to content

Commit 0b7d8ab

Browse files
authored
[AutoDiff] Remove 'readnone' attribute from autoDiffCreateLinearMapContext. (#66203)
It certainly has side effects and returned value every time is different. This way we ensure multiple calls are not CSE'd or LICM'ed. Fixes #65989
1 parent 1cf20fb commit 0b7d8ab

File tree

2 files changed

+91
-1
lines changed

2 files changed

+91
-1
lines changed

include/swift/AST/Builtins.def

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -986,7 +986,7 @@ BUILTIN_MISC_OPERATION_WITH_SILGEN(CreateAsyncTaskInGroup,
986986
BUILTIN_MISC_OPERATION_WITH_SILGEN(GlobalStringTablePointer, "globalStringTablePointer", "n", Special)
987987

988988
// autoDiffCreateLinearMapContext: (Builtin.Word) -> Builtin.NativeObject
989-
BUILTIN_MISC_OPERATION_WITH_SILGEN(AutoDiffCreateLinearMapContext, "autoDiffCreateLinearMapContext", "n", Special)
989+
BUILTIN_MISC_OPERATION_WITH_SILGEN(AutoDiffCreateLinearMapContext, "autoDiffCreateLinearMapContext", "", Special)
990990

991991
// autoDiffProjectTopLevelSubcontext: (Builtin.NativeObject) -> Builtin.RawPointer
992992
BUILTIN_MISC_OPERATION_WITH_SILGEN(AutoDiffProjectTopLevelSubcontext, "autoDiffProjectTopLevelSubcontext", "n", Special)
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
// RUN: %target-swift-frontend -emit-sil -O %s | %FileCheck %s
2+
3+
// Ensure that autoDiffCreateLinearMapContext call is not LICM'ed
4+
import _Differentiation;
5+
6+
public struct R: Differentiable {
7+
@noDerivative public var z: Int
8+
}
9+
10+
public struct Z: Differentiable {
11+
public var r: [R] = []
12+
}
13+
14+
public struct B: Differentiable {
15+
public var h = [Float]();
16+
public var e = Z()
17+
}
18+
19+
public extension Array {
20+
@differentiable(reverse where Element: Differentiable)
21+
mutating func update(at x: Int, with n: Element) {
22+
self[x] = n
23+
}
24+
}
25+
26+
public extension Array where Element: Differentiable {
27+
@derivative(of: update(at:with:))
28+
mutating func v(at x: Int, with nv: Element) ->
29+
(value: Void,
30+
pullback: (inout TangentVector) -> (Element.TangentVector)) {
31+
update(at: x, with: nv);
32+
let f = count;
33+
return ((),
34+
{ v in
35+
if v.base.count < f {
36+
v.base = [Element.TangentVector](repeating: .zero, count: f)
37+
};
38+
let d = v[x];
39+
v.base[x] = .zero;
40+
return d}
41+
)
42+
}
43+
}
44+
45+
extension B {
46+
@differentiable(reverse)
47+
mutating func a() {
48+
for idx in 0 ..< withoutDerivative(at: self.e.r).count {
49+
let z = self.e.r[idx].z;
50+
let c = self.h[z];
51+
self.h.update(at: z, with: c + 2.4)
52+
}
53+
}
54+
}
55+
56+
public func b(y: B) -> (value: B,
57+
pullback: (B.TangentVector) -> (B.TangentVector)) {
58+
let s = valueWithPullback(at: y, of: s);
59+
return (value: s.value, pullback: s.pullback)
60+
}
61+
62+
@differentiable(reverse)
63+
public func s(y: B) -> B {
64+
@differentiable(reverse)
65+
func q(_ u: B) -> B {
66+
var l = u;
67+
for _ in 0 ..< 1 {
68+
l.a()
69+
};
70+
return l
71+
};
72+
let w = m(q);
73+
return w(y)
74+
}
75+
76+
// CHECK-LABEL: sil private @$s12licm_context1s1yAA1BVAE_tF1qL_yA2EFTJrSpSr :
77+
// CHECK: autoDiffCreateLinearMapContext
78+
// CHECK: autoDiffCreateLinearMapContext
79+
// CHECK-LABEL: end sil function '$s12licm_context1s1yAA1BVAE_tF1qL_yA2EFTJrSpSr'
80+
81+
func o<T, R>(_ x: T, _ f: @differentiable(reverse) (T) -> R) -> R {
82+
f(x)
83+
}
84+
85+
func m<T, R>(_ f: @escaping @differentiable(reverse) (T) -> R) -> @differentiable(reverse) (T) -> R {
86+
{ x in o(x, f) }
87+
}
88+
89+
let m = b(y: B());
90+
let grad = m.pullback(B.TangentVector(h: Array<Float>.TangentVector(), e: Z.TangentVector(r: Array<R>.TangentVector())))

0 commit comments

Comments
 (0)