Skip to content

Commit 9600858

Browse files
committed
Merge branch 'tensorflow' of github.com:apple/swift into tensorflow-merge
2 parents 0f13e79 + 958ecc6 commit 9600858

File tree

7 files changed

+576
-229
lines changed

7 files changed

+576
-229
lines changed

include/swift/SIL/SILInstruction.h

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2046,6 +2046,11 @@ class ApplyInstBase<Impl, Base, true>
20462046
ApplyInstBase(As &&...args)
20472047
: ApplyInstBase<Impl, Base, false>(std::forward<As>(args)...) {}
20482048

2049+
// SWIFT_ENABLE_TENSORFLOW
2050+
private:
2051+
const Impl &asImpl() const { return static_cast<const Impl &>(*this); }
2052+
// SWIFT_ENABLE_TENSORFLOW END
2053+
20492054
public:
20502055
using super::getCallee;
20512056
using super::getSubstCalleeType;
@@ -2133,6 +2138,35 @@ class ApplyInstBase<Impl, Base, true>
21332138
bool hasSemantics(StringRef semanticsString) const {
21342139
return doesApplyCalleeHaveSemantics(getCallee(), semanticsString);
21352140
}
2141+
2142+
// SWIFT_ENABLE_TENSORFLOW
2143+
private:
2144+
/// Predicate used to filter InoutArgumentRange.
2145+
struct OperandToInoutArgument {
2146+
ArrayRef<SILParameterInfo> paramInfos;
2147+
OperandValueArrayRef arguments;
2148+
OperandToInoutArgument(const Impl &inst)
2149+
: paramInfos(inst.getSubstCalleeConv().getParameters()),
2150+
arguments(inst.getArgumentsWithoutIndirectResults()) {
2151+
assert(paramInfos.size() == arguments.size());
2152+
}
2153+
Optional<SILValue> operator()(unsigned long i) const {
2154+
if (paramInfos[i].isIndirectMutating())
2155+
return arguments[i];
2156+
return None;
2157+
}
2158+
};
2159+
2160+
public:
2161+
using InoutArgumentRange =
2162+
OptionalTransformRange<IntRange<unsigned long>, OperandToInoutArgument>;
2163+
/// Returns all `@inout` and `@inout_aliasable` arguments passed to the
2164+
/// instruction.
2165+
InoutArgumentRange getInoutArguments() const {
2166+
return InoutArgumentRange(indices(getArgumentsWithoutIndirectResults()),
2167+
OperandToInoutArgument(asImpl()));
2168+
}
2169+
// SWIFT_ENABLE_TENSORFLOW END
21362170
};
21372171

21382172
/// ApplyInst - Represents the full application of a function value.

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 214 additions & 175 deletions
Large diffs are not rendered by default.

lib/SILOptimizer/Mandatory/Differentiation.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,6 @@ inline void createEntryArguments(SILFunction *f) {
103103
auto *decl = new (ctx) ParamDecl(loc, loc, Identifier(), loc,
104104
Identifier(), moduleDecl);
105105
decl->setSpecifier(ParamDecl::Specifier::Default);
106-
// decl->setType(type.getASTType());
107106
entry->createFunctionArgument(type, decl);
108107
};
109108
for (auto indResTy : conv.getIndirectSILResultTypes())

test/AutoDiff/activity_analysis.swift

Lines changed: 56 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: %target-swift-emit-sil -Xllvm -debug-only=differentiation 2>&1 %s | %FileCheck %s
1+
// RUN: %target-swift-emit-sil -verify -Xllvm -debug-only=differentiation 2>&1 %s | %FileCheck %s
22

33
// Check that `@noDerivative` struct projections have "NONE" activity.
44

@@ -72,9 +72,9 @@ func testArrayUninitializedIntrinsic(_ x: Float, _ y: Float) -> [Float] {
7272
// CHECK: [ACTIVE] %6 = apply %5<Float>(%4) : $@convention(thin) <τ_0_0> (Builtin.Word) -> (@owned Array<τ_0_0>, Builtin.RawPointer)
7373
// CHECK: [ACTIVE] (**%7**, %8) = destructure_tuple %6 : $(Array<Float>, Builtin.RawPointer)
7474
// CHECK: [VARIED] (%7, **%8**) = destructure_tuple %6 : $(Array<Float>, Builtin.RawPointer)
75-
// CHECK: [VARIED] %9 = pointer_to_address %8 : $Builtin.RawPointer to [strict] $*Float
75+
// CHECK: [ACTIVE] %9 = pointer_to_address %8 : $Builtin.RawPointer to [strict] $*Float
7676
// CHECK: [VARIED] %11 = integer_literal $Builtin.Word, 1
77-
// CHECK: [VARIED] %12 = index_addr %9 : $*Float, %11 : $Builtin.Word
77+
// CHECK: [ACTIVE] %12 = index_addr %9 : $*Float, %11 : $Builtin.Word
7878

7979
@differentiable(where T: Differentiable)
8080
func testArrayUninitializedIntrinsicGeneric<T>(_ x: T, _ y: T) -> [T] {
@@ -89,9 +89,9 @@ func testArrayUninitializedIntrinsicGeneric<T>(_ x: T, _ y: T) -> [T] {
8989
// CHECK: [ACTIVE] %6 = apply %5<T>(%4) : $@convention(thin) <τ_0_0> (Builtin.Word) -> (@owned Array<τ_0_0>, Builtin.RawPointer)
9090
// CHECK: [ACTIVE] (**%7**, %8) = destructure_tuple %6 : $(Array<T>, Builtin.RawPointer)
9191
// CHECK: [VARIED] (%7, **%8**) = destructure_tuple %6 : $(Array<T>, Builtin.RawPointer)
92-
// CHECK: [VARIED] %9 = pointer_to_address %8 : $Builtin.RawPointer to [strict] $*T
92+
// CHECK: [ACTIVE] %9 = pointer_to_address %8 : $Builtin.RawPointer to [strict] $*T
9393
// CHECK: [VARIED] %11 = integer_literal $Builtin.Word, 1
94-
// CHECK: [VARIED] %12 = index_addr %9 : $*T, %11 : $Builtin.Word
94+
// CHECK: [ACTIVE] %12 = index_addr %9 : $*T, %11 : $Builtin.Word
9595

9696
// TF-952: Test array literal initialized from an address (e.g. `var`).
9797
@differentiable
@@ -114,10 +114,10 @@ func testArrayUninitializedIntrinsicAddress(_ x: Float, _ y: Float) -> [Float] {
114114
// CHECK: [ACTIVE] %17 = apply %16<Float>(%15) : $@convention(thin) <τ_0_0> (Builtin.Word) -> (@owned Array<τ_0_0>, Builtin.RawPointer)
115115
// CHECK: [ACTIVE] (**%18**, %19) = destructure_tuple %17 : $(Array<Float>, Builtin.RawPointer)
116116
// CHECK: [VARIED] (%18, **%19**) = destructure_tuple %17 : $(Array<Float>, Builtin.RawPointer)
117-
// CHECK: [VARIED] %20 = pointer_to_address %19 : $Builtin.RawPointer to [strict] $*Float
117+
// CHECK: [ACTIVE] %20 = pointer_to_address %19 : $Builtin.RawPointer to [strict] $*Float
118118
// CHECK: [ACTIVE] %21 = begin_access [read] [static] %4 : $*Float
119119
// CHECK: [VARIED] %24 = integer_literal $Builtin.Word, 1
120-
// CHECK: [VARIED] %25 = index_addr %20 : $*Float, %24 : $Builtin.Word
120+
// CHECK: [ACTIVE] %25 = index_addr %20 : $*Float, %24 : $Builtin.Word
121121
// CHECK: [ACTIVE] %26 = begin_access [read] [static] %4 : $*Float
122122

123123
// TF-952: Test array literal initialized with function call results.
@@ -133,11 +133,11 @@ func testArrayUninitializedIntrinsicFunctionResult(_ x: Float, _ y: Float) -> [F
133133
// [ACTIVE] %6 = apply %5<Float>(%4) : $@convention(thin) <τ_0_0> (Builtin.Word) -> (@owned Array<τ_0_0>, Builtin.RawPointer)
134134
// [ACTIVE] (**%7**, %8) = destructure_tuple %6 : $(Array<Float>, Builtin.RawPointer)
135135
// [VARIED] (%7, **%8**) = destructure_tuple %6 : $(Array<Float>, Builtin.RawPointer)
136-
// [VARIED] %9 = pointer_to_address %8 : $Builtin.RawPointer to [strict] $*Float
136+
// [ACTIVE] %9 = pointer_to_address %8 : $Builtin.RawPointer to [strict] $*Float
137137
// [NONE] // function_ref static Float.* infix(_:_:)
138138
// [ACTIVE] %12 = apply %11(%0, %1, %10) : $@convention(method) (Float, Float, @thin Float.Type) -> Float
139139
// [VARIED] %14 = integer_literal $Builtin.Word, 1
140-
// [VARIED] %15 = index_addr %9 : $*Float, %14 : $Builtin.Word
140+
// [ACTIVE] %15 = index_addr %9 : $*Float, %14 : $Builtin.Word
141141
// [USEFUL] %16 = metatype $@thin Float.Type
142142
// [NONE] // function_ref static Float.* infix(_:_:)
143143
// [ACTIVE] %18 = apply %17(%0, %1, %16) : $@convention(method) (Float, Float, @thin Float.Type) -> Float
@@ -203,3 +203,50 @@ func TF_954(_ x: Float) -> Float {
203203
// CHECK: bb5:
204204
// CHECK: [ACTIVE] %40 = begin_access [read] [static] %2 : $*Float
205205
// CHECK: [ACTIVE] %41 = load [trivial] %40 : $*Float
206+
207+
//===----------------------------------------------------------------------===//
208+
// Non-differentiable functions
209+
//===----------------------------------------------------------------------===//
210+
211+
// Check `inout` arguments.
212+
213+
// expected-error @+1 {{function is not differentiable}}
214+
@differentiable
215+
// expected-note @+1 {{when differentiating this function definition}}
216+
func activeInoutArg(_ x: Float) -> Float {
217+
var result = x
218+
// expected-note @+1 {{cannot differentiate through 'inout' arguments}}
219+
result += x
220+
return result
221+
}
222+
223+
// CHECK-LABEL: [AD] Activity info for ${{.*}}activeInoutArg{{.*}} at (source=0 parameters=(0))
224+
// CHECK: [ACTIVE] %0 = argument of bb0 : $Float
225+
// CHECK: [ACTIVE] %2 = alloc_stack $Float, var, name "result"
226+
// CHECK: [ACTIVE] %5 = begin_access [modify] [static] %2 : $*Float
227+
// CHECK: [NONE] // function_ref static Float.+= infix(_:_:)
228+
// CHECK: [NONE] %7 = apply %6(%5, %0, %4) : $@convention(method) (@inout Float, Float, @thin Float.Type) -> ()
229+
// CHECK: [ACTIVE] %9 = begin_access [read] [static] %2 : $*Float
230+
// CHECK: [ACTIVE] %10 = load [trivial] %9 : $*Float
231+
232+
// expected-error @+1 {{function is not differentiable}}
233+
@differentiable
234+
// expected-note @+1 {{when differentiating this function definition}}
235+
func activeInoutArgNonactiveInitialResult(_ x: Float) -> Float {
236+
var result: Float = 1
237+
// expected-note @+1 {{cannot differentiate through 'inout' arguments}}
238+
result += x
239+
return result
240+
}
241+
242+
// CHECK-LABEL: [AD] Activity info for ${{.*}}activeInoutArgNonactiveInitialResult{{.*}} at (source=0 parameters=(0))
243+
// CHECK-LABEL: [ACTIVE] %0 = argument of bb0 : $Float
244+
// CHECK-LABEL: [ACTIVE] %2 = alloc_stack $Float, var, name "result"
245+
// CHECK-LABEL: [NONE] // function_ref Float.init(_builtinIntegerLiteral:)
246+
// CHECK-LABEL: [USEFUL] %6 = apply %5(%3, %4) : $@convention(method) (Builtin.IntLiteral, @thin Float.Type) -> Float
247+
// CHECK-LABEL: [USEFUL] %8 = metatype $@thin Float.Type
248+
// CHECK-LABEL: [ACTIVE] %9 = begin_access [modify] [static] %2 : $*Float
249+
// CHECK-LABEL: [NONE] // function_ref static Float.+= infix(_:_:)
250+
// CHECK-LABEL: [NONE] %11 = apply %10(%9, %0, %8) : $@convention(method) (@inout Float, Float, @thin Float.Type) -> ()
251+
// CHECK-LABEL: [ACTIVE] %13 = begin_access [read] [static] %2 : $*Float
252+
// CHECK-LABEL: [ACTIVE] %14 = load [trivial] %13 : $*Float

0 commit comments

Comments
 (0)