Skip to content

Commit 958ecc6

Browse files
authored
[AutoDiff] Fix variedness propagation for apply inout arguments. (#28352)
Propagate variedness from `apply` argument operands to `apply` inout arguments (representing results). `apply` inout arguments are now correctly marked as active, triggering non-differentiability errors. Add `ApplyInstBase::getInoutArguments` for iterating over `@inout` and `@inout_aliasable` arguments. Add non-differentiability diagnostics and activity info tests. Resolves TF-974.
1 parent f5f7dd8 commit 958ecc6

File tree

6 files changed

+176
-49
lines changed

6 files changed

+176
-49
lines changed

include/swift/SIL/SILInstruction.h

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2046,6 +2046,11 @@ class ApplyInstBase<Impl, Base, true>
20462046
ApplyInstBase(As &&...args)
20472047
: ApplyInstBase<Impl, Base, false>(std::forward<As>(args)...) {}
20482048

2049+
// SWIFT_ENABLE_TENSORFLOW
2050+
private:
2051+
const Impl &asImpl() const { return static_cast<const Impl &>(*this); }
2052+
// SWIFT_ENABLE_TENSORFLOW END
2053+
20492054
public:
20502055
using super::getCallee;
20512056
using super::getSubstCalleeType;
@@ -2133,6 +2138,35 @@ class ApplyInstBase<Impl, Base, true>
21332138
bool hasSemantics(StringRef semanticsString) const {
21342139
return doesApplyCalleeHaveSemantics(getCallee(), semanticsString);
21352140
}
2141+
2142+
// SWIFT_ENABLE_TENSORFLOW
2143+
private:
2144+
/// Predicate used to filter InoutArgumentRange.
2145+
struct OperandToInoutArgument {
2146+
ArrayRef<SILParameterInfo> paramInfos;
2147+
OperandValueArrayRef arguments;
2148+
OperandToInoutArgument(const Impl &inst)
2149+
: paramInfos(inst.getSubstCalleeConv().getParameters()),
2150+
arguments(inst.getArgumentsWithoutIndirectResults()) {
2151+
assert(paramInfos.size() == arguments.size());
2152+
}
2153+
Optional<SILValue> operator()(unsigned long i) const {
2154+
if (paramInfos[i].isIndirectMutating())
2155+
return arguments[i];
2156+
return None;
2157+
}
2158+
};
2159+
2160+
public:
2161+
using InoutArgumentRange =
2162+
OptionalTransformRange<IntRange<unsigned long>, OperandToInoutArgument>;
2163+
/// Returns all `@inout` and `@inout_aliasable` arguments passed to the
2164+
/// instruction.
2165+
InoutArgumentRange getInoutArguments() const {
2166+
return InoutArgumentRange(indices(getArgumentsWithoutIndirectResults()),
2167+
OperandToInoutArgument(asImpl()));
2168+
}
2169+
// SWIFT_ENABLE_TENSORFLOW END
21362170
};
21372171

21382172
/// ApplyInst - Represents the full application of a function value.

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 23 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1622,11 +1622,8 @@ LinearMapInfo::LinearMapInfo(ADContext &context,
16221622
/// active argument.
16231623
bool LinearMapInfo::shouldDifferentiateApplyInst(ApplyInst *ai) {
16241624
// Function applications with an inout argument should be differentiated.
1625-
auto paramInfos = ai->getSubstCalleeConv().getParameters();
1626-
auto arguments = ai->getArgumentsWithoutIndirectResults();
1627-
for (auto i : swift::indices(paramInfos))
1628-
if (paramInfos[i].isIndirectInOut() &&
1629-
activityInfo.isActive(arguments[i], indices))
1625+
for (auto inoutArg : ai->getInoutArguments())
1626+
if (activityInfo.isActive(inoutArg, indices))
16301627
return true;
16311628

16321629
bool hasActiveDirectResults = false;
@@ -1642,6 +1639,7 @@ bool LinearMapInfo::shouldDifferentiateApplyInst(ApplyInst *ai) {
16421639
if (isArrayLiteralIntrinsic(ai) && hasActiveResults)
16431640
return true;
16441641

1642+
auto arguments = ai->getArgumentsWithoutIndirectResults();
16451643
bool hasActiveArguments = llvm::any_of(arguments,
16461644
[&](SILValue arg) { return activityInfo.isActive(arg, indices); });
16471645
return hasActiveResults && hasActiveArguments;
@@ -1834,20 +1832,13 @@ void LinearMapInfo::generateDifferentiationDataStructures(
18341832
for (auto &origBB : *original) {
18351833
for (auto &inst : origBB) {
18361834
if (auto *ai = dyn_cast<ApplyInst>(&inst)) {
1837-
// Check for active 'inout' arguments.
1838-
bool isInout = false;
1839-
auto paramInfos = ai->getSubstCalleeConv().getParameters();
1840-
for (unsigned i : swift::indices(paramInfos)) {
1841-
if (paramInfos[i].isIndirectInOut() &&
1842-
activityInfo.isActive(ai->getArgumentsWithoutIndirectResults()[i],
1843-
indices)) {
1844-
// Reject functions with active inout arguments. It's not yet
1845-
// supported.
1846-
isInout = true;
1847-
break;
1848-
}
1849-
}
1850-
if (isInout)
1835+
// Skip `apply` instructions with active `inout` arguments.
1836+
// TODO(TF-129): Support `inout` argument differentiation.
1837+
bool hasActiveInoutArgument =
1838+
llvm::any_of(ai->getInoutArguments(), [&](SILValue inoutArg) {
1839+
return activityInfo.isActive(inoutArg, indices);
1840+
});
1841+
if (hasActiveInoutArgument)
18511842
continue;
18521843

18531844
// Add linear map field to struct for active `apply` instructions.
@@ -2008,10 +1999,13 @@ void DifferentiableActivityInfo::propagateVaried(
20081999
// If callee is non-varying, skip.
20092000
if (isWithoutDerivative(ai->getCallee()))
20102001
return;
2011-
// If operand is varied, set all direct and indirect results as varied.
2002+
// If operand is varied, set all direct/indirect results and inout arguments
2003+
// as varied.
20122004
if (isVaried(operand->get(), i)) {
20132005
for (auto indRes : ai->getIndirectSILResults())
20142006
propagateVariedInwardsThroughProjections(indRes, i);
2007+
for (auto inoutArg : ai->getInoutArguments())
2008+
propagateVariedInwardsThroughProjections(inoutArg, i);
20152009
forEachApplyDirectResult(ai, [&](SILValue directResult) {
20162010
setVariedAndPropagateToUsers(directResult, i);
20172011
});
@@ -3778,7 +3772,7 @@ class VJPEmitter final
37783772
sei->getLoc(), getOpValue(sei->getOperand()), newDefaultBB, caseBBs);
37793773
}
37803774

3781-
// If an `apply` has active results or active inout parameters, replace it
3775+
// If an `apply` has active results or active inout arguments, replace it
37823776
// with an `apply` of its VJP.
37833777
void visitApplyInst(ApplyInst *ai) {
37843778
// If the function should not be differentiated or its the array literal
@@ -3790,13 +3784,10 @@ class VJPEmitter final
37903784
return;
37913785
}
37923786

3793-
// Check and reject functions with active inout arguments. It's not yet
3794-
// supported.
3795-
auto paramInfos = ai->getSubstCalleeConv().getParameters();
3796-
auto paramArgs = ai->getArgumentsWithoutIndirectResults();
3797-
for (unsigned i : swift::indices(paramInfos)) {
3798-
if (paramInfos[i].isIndirectInOut() &&
3799-
activityInfo.isActive(paramArgs[i], getIndices())) {
3787+
// Diagnose functions with active inout arguments.
3788+
// TODO(TF-129): Support `inout` argument differentiation.
3789+
for (auto inoutArg : ai->getInoutArguments()) {
3790+
if (activityInfo.isActive(inoutArg, getIndices())) {
38003791
context.emitNondifferentiabilityError(ai, invoker,
38013792
diag::autodiff_cannot_differentiate_through_inout_arguments);
38023793
errorOccurred = true;
@@ -5472,13 +5463,10 @@ class JVPEmitter final
54725463
return;
54735464
}
54745465

5475-
// Check and reject functions with active inout arguments. It's not yet
5476-
// supported.
5477-
auto paramInfos = ai->getSubstCalleeConv().getParameters();
5478-
auto paramArgs = ai->getArgumentsWithoutIndirectResults();
5479-
for (unsigned i : swift::indices(paramInfos)) {
5480-
if (paramInfos[i].isIndirectInOut() &&
5481-
activityInfo.isActive(paramArgs[i], getIndices())) {
5466+
// Diagnose functions with active inout arguments.
5467+
// TODO(TF-129): Support `inout` argument differentiation.
5468+
for (auto inoutArg : ai->getInoutArguments()) {
5469+
if (activityInfo.isActive(inoutArg, getIndices())) {
54825470
context.emitNondifferentiabilityError(ai, invoker,
54835471
diag::autodiff_cannot_differentiate_through_inout_arguments);
54845472
errorOccurred = true;

lib/SILOptimizer/Mandatory/Differentiation.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,6 @@ inline void createEntryArguments(SILFunction *f) {
103103
auto *decl = new (ctx) ParamDecl(loc, loc, Identifier(), loc,
104104
Identifier(), moduleDecl);
105105
decl->setSpecifier(ParamDecl::Specifier::Default);
106-
// decl->setType(type.getASTType());
107106
entry->createFunctionArgument(type, decl);
108107
};
109108
for (auto indResTy : conv.getIndirectSILResultTypes())

test/AutoDiff/activity_analysis.swift

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: %target-swift-emit-sil -Xllvm -debug-only=differentiation 2>&1 %s | %FileCheck %s
1+
// RUN: %target-swift-emit-sil -verify -Xllvm -debug-only=differentiation 2>&1 %s | %FileCheck %s
22

33
// Check that `@noDerivative` struct projections have "NONE" activity.
44

@@ -203,3 +203,50 @@ func TF_954(_ x: Float) -> Float {
203203
// CHECK: bb5:
204204
// CHECK: [ACTIVE] %40 = begin_access [read] [static] %2 : $*Float
205205
// CHECK: [ACTIVE] %41 = load [trivial] %40 : $*Float
206+
207+
//===----------------------------------------------------------------------===//
208+
// Non-differentiable functions
209+
//===----------------------------------------------------------------------===//
210+
211+
// Check `inout` arguments.
212+
213+
// expected-error @+1 {{function is not differentiable}}
214+
@differentiable
215+
// expected-note @+1 {{when differentiating this function definition}}
216+
func activeInoutArg(_ x: Float) -> Float {
217+
var result = x
218+
// expected-note @+1 {{cannot differentiate through 'inout' arguments}}
219+
result += x
220+
return result
221+
}
222+
223+
// CHECK-LABEL: [AD] Activity info for ${{.*}}activeInoutArg{{.*}} at (source=0 parameters=(0))
224+
// CHECK: [ACTIVE] %0 = argument of bb0 : $Float
225+
// CHECK: [ACTIVE] %2 = alloc_stack $Float, var, name "result"
226+
// CHECK: [ACTIVE] %5 = begin_access [modify] [static] %2 : $*Float
227+
// CHECK: [NONE] // function_ref static Float.+= infix(_:_:)
228+
// CHECK: [NONE] %7 = apply %6(%5, %0, %4) : $@convention(method) (@inout Float, Float, @thin Float.Type) -> ()
229+
// CHECK: [ACTIVE] %9 = begin_access [read] [static] %2 : $*Float
230+
// CHECK: [ACTIVE] %10 = load [trivial] %9 : $*Float
231+
232+
// expected-error @+1 {{function is not differentiable}}
233+
@differentiable
234+
// expected-note @+1 {{when differentiating this function definition}}
235+
func activeInoutArgNonactiveInitialResult(_ x: Float) -> Float {
236+
var result: Float = 1
237+
// expected-note @+1 {{cannot differentiate through 'inout' arguments}}
238+
result += x
239+
return result
240+
}
241+
242+
// CHECK-LABEL: [AD] Activity info for ${{.*}}activeInoutArgNonactiveInitialResult{{.*}} at (source=0 parameters=(0))
243+
// CHECK-LABEL: [ACTIVE] %0 = argument of bb0 : $Float
244+
// CHECK-LABEL: [ACTIVE] %2 = alloc_stack $Float, var, name "result"
245+
// CHECK-LABEL: [NONE] // function_ref Float.init(_builtinIntegerLiteral:)
246+
// CHECK-LABEL: [USEFUL] %6 = apply %5(%3, %4) : $@convention(method) (Builtin.IntLiteral, @thin Float.Type) -> Float
247+
// CHECK-LABEL: [USEFUL] %8 = metatype $@thin Float.Type
248+
// CHECK-LABEL: [ACTIVE] %9 = begin_access [modify] [static] %2 : $*Float
249+
// CHECK-LABEL: [NONE] // function_ref static Float.+= infix(_:_:)
250+
// CHECK-LABEL: [NONE] %11 = apply %10(%9, %0, %8) : $@convention(method) (@inout Float, Float, @thin Float.Type) -> ()
251+
// CHECK-LABEL: [ACTIVE] %13 = begin_access [read] [static] %2 : $*Float
252+
// CHECK-LABEL: [ACTIVE] %14 = load [trivial] %13 : $*Float

test/AutoDiff/differentiation_transform_diagnostics.swift

Lines changed: 50 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -270,23 +270,68 @@ func roundingGivesError(x: Float) -> Float {
270270
// Inout arguments
271271
//===----------------------------------------------------------------------===//
272272

273+
// expected-error @+1 {{function is not differentiable}}
274+
@differentiable
275+
// expected-note @+1 {{when differentiating this function definition}}
273276
func activeInoutArg(_ x: Float) -> Float {
274-
var a = x
277+
var result = x
275278
// expected-note @+1 {{cannot differentiate through 'inout' arguments}}
276-
a += x
277-
return a
279+
result += x
280+
return result
278281
}
282+
279283
// expected-error @+1 {{function is not differentiable}}
280-
_ = pullback(at: .zero, in: activeInoutArg(_:))
284+
@differentiable
285+
// expected-note @+1 {{when differentiating this function definition}}
286+
func activeInoutArgNonactiveInitialResult(_ x: Float) -> Float {
287+
var result: Float = 1
288+
// expected-note @+1 {{cannot differentiate through 'inout' arguments}}
289+
result += x
290+
return result
291+
}
281292

293+
// expected-error @+1 {{function is not differentiable}}
294+
@differentiable
295+
// expected-note @+1 {{when differentiating this function definition}}
282296
func activeInoutArgTuple(_ x: Float) -> Float {
283297
var tuple = (x, x)
284298
// expected-note @+1 {{cannot differentiate through 'inout' arguments}}
285299
tuple.0 *= x
286300
return x * tuple.0
287301
}
302+
288303
// expected-error @+1 {{function is not differentiable}}
289-
_ = pullback(at: .zero, in: activeInoutArgTuple(_:))
304+
@differentiable
305+
// expected-note @+1 {{when differentiating this function definition}}
306+
func activeInoutArgControlFlow(_ array: [Float]) -> Float {
307+
var result: Float = 1
308+
for i in withoutDerivative(at: array).indices {
309+
// expected-note @+1 {{cannot differentiate through 'inout' arguments}}
310+
result += array[i]
311+
}
312+
return result
313+
}
314+
315+
// expected-error @+1 {{function is not differentiable}}
316+
@differentiable
317+
// expected-note @+1 {{when differentiating this function definition}}
318+
func activeInoutArgControlFlowComplex(_ array: [Float], _ bool: Bool) -> Float {
319+
var result: Float = 1
320+
if bool {
321+
if bool {}
322+
for i in withoutDerivative(at: array).indices {
323+
switch i % 2 {
324+
case 0: continue
325+
case 1: break
326+
default: break
327+
}
328+
result = result + 1
329+
// expected-note @+1 {{cannot differentiate through 'inout' arguments}}
330+
result += array[i]
331+
}
332+
}
333+
return result
334+
}
290335

291336
//===----------------------------------------------------------------------===//
292337
// Non-varied results

test/AutoDiff/forward_mode_diagnostics.swift

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -53,23 +53,37 @@ func calls_diff_of_nested(_ x: Float) -> Float {
5353
// Inout arguments
5454
//===----------------------------------------------------------------------===//
5555

56-
func activeInoutArg(_ x: Float) -> Float {
57-
var a = x
56+
// expected-error @+1 {{function is not differentiable}}
57+
@differentiable
58+
// expected-note @+1 {{when differentiating this function definition}}
59+
func activeInoutArgNonactiveInitialResult(_ x: Float) -> Float {
60+
var result: Float = 1
5861
// expected-note @+1 {{cannot differentiate through 'inout' arguments}}
59-
a += x
60-
return a
62+
result += x
63+
return result
6164
}
62-
// expected-error @+1 {{function is not differentiable}}
63-
_ = differential(at: .zero, in: activeInoutArg(_:))
6465

66+
// expected-error @+1 {{function is not differentiable}}
67+
@differentiable
68+
// expected-note @+1 {{when differentiating this function definition}}
6569
func activeInoutArgTuple(_ x: Float) -> Float {
6670
var tuple = (x, x)
6771
// expected-note @+1 {{cannot differentiate through 'inout' arguments}}
6872
tuple.0 *= x
6973
return x * tuple.0
7074
}
75+
7176
// expected-error @+1 {{function is not differentiable}}
72-
_ = differential(at: .zero, in: activeInoutArgTuple(_:))
77+
@differentiable
78+
// expected-note @+2 {{when differentiating this function definition}}
79+
// expected-note @+1 {{forward-mode differentiation does not yet support control flow}}
80+
func activeInoutArgControlFlow(_ array: [Float]) -> Float {
81+
var result: Float = 1
82+
for i in withoutDerivative(at: array).indices {
83+
result += array[i]
84+
}
85+
return result
86+
}
7387

7488
//===----------------------------------------------------------------------===//
7589
// Non-varied results

0 commit comments

Comments
 (0)