Skip to content

Commit 7484431

Browse files
authored
Merge pull request #21569 from marcrasi/add-missing-retains
[AutoDiff] add some missing retains
2 parents be8131b + b7f3ab4 commit 7484431

File tree

3 files changed

+30
-24
lines changed

3 files changed

+30
-24
lines changed

lib/SILOptimizer/Mandatory/TFDifferentiation.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2117,6 +2117,8 @@ class PrimalGenCloner final : public SILClonerWithScopes<PrimalGenCloner> {
21172117
SILValue origRes, clonedRes;
21182118
std::tie(origRes, clonedRes) = resultPair;
21192119
getPrimalInfo().addStaticPrimalValueDecl(origRes);
2120+
getBuilder().createRetainValue(cloned->getLoc(), clonedRes,
2121+
getBuilder().getDefaultAtomicity());
21202122
staticPrimalValues.push_back(clonedRes);
21212123
}
21222124
break;
@@ -2321,6 +2323,8 @@ class PrimalGenCloner final : public SILClonerWithScopes<PrimalGenCloner> {
23212323

23222324
// Checkpoint the original results.
23232325
getPrimalInfo().addStaticPrimalValueDecl(ai);
2326+
getBuilder().createRetainValue(ai->getLoc(), originalDirectResult,
2327+
getBuilder().getDefaultAtomicity());
23242328
staticPrimalValues.push_back(originalDirectResult);
23252329

23262330
// Checkpoint the pullback.
@@ -2484,6 +2488,8 @@ class PrimalGenCloner final : public SILClonerWithScopes<PrimalGenCloner> {
24842488
// Checkpoint original results as a tuple.
24852489
getPrimalInfo().addStaticPrimalValueDecl(ai);
24862490
auto origResAggr = joinElements(origResults, builder, primalCall->getLoc());
2491+
getBuilder().createRetainValue(ai->getLoc(), origResAggr,
2492+
getBuilder().getDefaultAtomicity());
24872493
staticPrimalValues.push_back(origResAggr);
24882494

24892495
// Some instructions that produce the callee may have been cloned.

test/AutoDiff/nested_calls.sil

Lines changed: 24 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -78,13 +78,14 @@ bb0(%0 : @trivial $Float):
7878
// CHECK-VJP: %2 = apply %1(%0) : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
7979
// CHECK-VJP: %3 = tuple_extract %2 : $(Float, @callee_guaranteed (Float) -> Float), 0
8080
// CHECK-VJP: %4 = tuple_extract %2 : $(Float, @callee_guaranteed (Float) -> Float), 1
81-
// CHECK-VJP: %5 = function_ref @nested_func_without_diffattr : $@convention(thin) (Float) -> Float
82-
// CHECK-VJP: %6 = apply %5(%3) : $@convention(thin) (Float) -> Float
83-
// CHECK-VJP: %7 = tuple (%3 : $Float, %3 : $Float)
84-
// CHECK-VJP: %8 = tuple_extract %7 : $(Float, Float), 0
85-
// CHECK-VJP: %9 = struct $AD__func_to_diff__Type__src_0_wrt_0 (%3 : $Float, %4 : $@callee_guaranteed (Float) -> Float)
86-
// CHECK-VJP: %10 = tuple (%9 : $AD__func_to_diff__Type__src_0_wrt_0, %8 : $Float)
87-
// CHECK-VJP: return %10 : $(AD__func_to_diff__Type__src_0_wrt_0, Float)
81+
// CHECK-VJP: retain_value %3 : $Float
82+
// CHECK-VJP: %6 = function_ref @nested_func_without_diffattr : $@convention(thin) (Float) -> Float
83+
// CHECK-VJP: %7 = apply %6(%3) : $@convention(thin) (Float) -> Float
84+
// CHECK-VJP: %8 = tuple (%3 : $Float, %3 : $Float)
85+
// CHECK-VJP: %9 = tuple_extract %8 : $(Float, Float), 0
86+
// CHECK-VJP: %10 = struct $AD__func_to_diff__Type__src_0_wrt_0 (%3 : $Float, %4 : $@callee_guaranteed (Float) -> Float)
87+
// CHECK-VJP: %11 = tuple (%10 : $AD__func_to_diff__Type__src_0_wrt_0, %9 : $Float)
88+
// CHECK-VJP: return %11 : $(AD__func_to_diff__Type__src_0_wrt_0, Float)
8889
// CHECK-VJP: }
8990

9091
// CHECK-VJP-LABEL: @AD__func_to_diff__adjoint_src_0_wrt_0 : $@convention(thin) (Float, AD__func_to_diff__Type__src_0_wrt_0, Float, Float) -> Float {
@@ -100,9 +101,10 @@ bb0(%0 : @trivial $Float):
100101
// CHECK-VJP: %2 = apply %1(%0) : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
101102
// CHECK-VJP: %3 = tuple_extract %2 : $(Float, @callee_guaranteed (Float) -> Float), 0
102103
// CHECK-VJP: %4 = tuple_extract %2 : $(Float, @callee_guaranteed (Float) -> Float), 1
103-
// CHECK-VJP: %5 = struct $AD__nested_func_without_diffattr__Type__src_0_wrt_0 (%3 : $Float, %4 : $@callee_guaranteed (Float) -> Float)
104-
// CHECK-VJP: %6 = tuple (%5 : $AD__nested_func_without_diffattr__Type__src_0_wrt_0, %3 : $Float)
105-
// CHECK-VJP: return %6 : $(AD__nested_func_without_diffattr__Type__src_0_wrt_0, Float)
104+
// CHECK-VJP: retain_value %3 : $Float
105+
// CHECK-VJP: %6 = struct $AD__nested_func_without_diffattr__Type__src_0_wrt_0 (%3 : $Float, %4 : $@callee_guaranteed (Float) -> Float)
106+
// CHECK-VJP: %7 = tuple (%6 : $AD__nested_func_without_diffattr__Type__src_0_wrt_0, %3 : $Float)
107+
// CHECK-VJP: return %7 : $(AD__nested_func_without_diffattr__Type__src_0_wrt_0, Float)
106108
// CHECK-VJP: }
107109

108110
// CHECK-VJP-LABEL: @AD__nested_func_without_diffattr__adjoint_src_0_wrt_0 : $@convention(thin) (Float, AD__nested_func_without_diffattr__Type__src_0_wrt_0, Float, Float) -> Float {
@@ -147,14 +149,15 @@ bb0(%0 : @trivial $Float):
147149
// CHECK-NOVJP: %2 = apply %1(%0) : $@convention(thin) (Float) -> (@owned AD__nested_func_without_diffattr__Type__src_0_wrt_0, Float)
148150
// CHECK-NOVJP: %3 = tuple_extract %2 : $(AD__nested_func_without_diffattr__Type__src_0_wrt_0, Float), 0
149151
// CHECK-NOVJP: %4 = tuple_extract %2 : $(AD__nested_func_without_diffattr__Type__src_0_wrt_0, Float), 1
152+
// CHECK-NOVJP: retain_value %4 : $Float
150153
// CHECK-NOVJP: // function_ref nested_func_without_diffattr
151-
// CHECK-NOVJP: %5 = function_ref @nested_func_without_diffattr : $@convention(thin) (Float) -> Float
152-
// CHECK-NOVJP: %6 = apply %5(%4) : $@convention(thin) (Float) -> Float
153-
// CHECK-NOVJP: %7 = tuple (%4 : $Float, %4 : $Float)
154-
// CHECK-NOVJP: %8 = tuple_extract %7 : $(Float, Float), 0
155-
// CHECK-NOVJP: %9 = struct $AD__func_to_diff__Type__src_0_wrt_0 (%3 : $AD__nested_func_without_diffattr__Type__src_0_wrt_0, %4 : $Float)
156-
// CHECK-NOVJP: %10 = tuple (%9 : $AD__func_to_diff__Type__src_0_wrt_0, %8 : $Float)
157-
// CHECK-NOVJP: return %10 : $(AD__func_to_diff__Type__src_0_wrt_0, Float)
154+
// CHECK-NOVJP: %6 = function_ref @nested_func_without_diffattr : $@convention(thin) (Float) -> Float
155+
// CHECK-NOVJP: %7 = apply %6(%4) : $@convention(thin) (Float) -> Float
156+
// CHECK-NOVJP: %8 = tuple (%4 : $Float, %4 : $Float)
157+
// CHECK-NOVJP: %9 = tuple_extract %8 : $(Float, Float), 0
158+
// CHECK-NOVJP: %10 = struct $AD__func_to_diff__Type__src_0_wrt_0 (%3 : $AD__nested_func_without_diffattr__Type__src_0_wrt_0, %4 : $Float)
159+
// CHECK-NOVJP: %11 = tuple (%10 : $AD__func_to_diff__Type__src_0_wrt_0, %9 : $Float)
160+
// CHECK-NOVJP: return %11 : $(AD__func_to_diff__Type__src_0_wrt_0, Float)
158161
// CHECK-NOVJP: }
159162

160163
// CHECK-NOVJP-LABEL: @AD__func_to_diff__adjoint_src_0_wrt_0 : $@convention(thin) (Float, AD__func_to_diff__Type__src_0_wrt_0, Float, Float) -> Float {
@@ -172,9 +175,10 @@ bb0(%0 : @trivial $Float):
172175
// CHECK-NOVJP: %2 = apply %1(%0) : $@convention(thin) (Float) -> (Float, Float)
173176
// CHECK-NOVJP: %3 = tuple_extract %2 : $(Float, Float), 0
174177
// CHECK-NOVJP: %4 = tuple_extract %2 : $(Float, Float), 1
175-
// CHECK-NOVJP: %5 = struct $AD__nested_func_without_diffattr__Type__src_0_wrt_0 (%3 : $Float, %4 : $Float)
176-
// CHECK-NOVJP: %6 = tuple (%5 : $AD__nested_func_without_diffattr__Type__src_0_wrt_0, %4 : $Float)
177-
// CHECK-NOVJP: return %6 : $(AD__nested_func_without_diffattr__Type__src_0_wrt_0, Float)
178+
// CHECK-NOVJP: retain_value %4 : $Float
179+
// CHECK-NOVJP: %6 = struct $AD__nested_func_without_diffattr__Type__src_0_wrt_0 (%3 : $Float, %4 : $Float)
180+
// CHECK-NOVJP: %7 = tuple (%6 : $AD__nested_func_without_diffattr__Type__src_0_wrt_0, %4 : $Float)
181+
// CHECK-NOVJP: return %7 : $(AD__nested_func_without_diffattr__Type__src_0_wrt_0, Float)
178182
// CHECK-NOVJP: }
179183

180184
// CHECK-NOVJP-LABEL: @AD__nested_func_without_diffattr__adjoint_src_0_wrt_0 : $@convention(thin) (Float, AD__nested_func_without_diffattr__Type__src_0_wrt_0, Float, Float) -> Float {

test/TensorFlowRuntime/tensor_autodiff_runtime.swift

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,6 @@
77
// REQUIRES: executable_test
88
// REQUIRES: swift_test_mode_optimize
99
//
10-
// FIXME: Segfault.
11-
//
12-
// XFAIL: *
13-
//
1410
// Tensor AD runtime tests.
1511

1612
import TensorFlow

0 commit comments

Comments
 (0)