Skip to content

Commit bafacd8

Browse files
authored
[AutoDiff] [SIL] Tweak 'differentiable_function' syntax. (#27689)
In the `differentiable_function` instruction's syntax, use `parameters` instead of `wrt` and `with_derivative` instead of `with`. This aligns it with the `linear_function` instruction and with the future direction in which both parameter indices and result indices will be included in this instruction. Note: `with_derivative` is not named `with_derivatives` because VJPs will be dropped from this instruction when we complete JVP + linear map transposition. Resolves [TF-909](https://bugs.swift.org/browse/TF-909).
1 parent b987d64 commit bafacd8

12 files changed

+48
-46
lines changed

docs/SIL.rst

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5611,28 +5611,29 @@ differentiable_function
56115611
sil-differentiable-function-derivative-functions-clause?
56125612
56135613
sil-differentiable-function-parameter-indices ::=
5614-
'[' 'wrt' [0-9]+ (' ' [0-9]+)* ']'
5614+
'[' 'parameters' [0-9]+ (' ' [0-9]+)* ']'
56155615
sil-differentiable-function-derivative-functions-clause ::=
5616-
'with' '{' sil-value ':' sil-type ',' sil-value ':' sil-type '}'
5616+
'with_derivative'
5617+
'{' sil-value ':' sil-type ',' sil-value ':' sil-type '}'
56175618

5618-
differentiable_function [wrt 0] %0 : $(T) -> T \
5619-
with {%1 : $(T) -> (T, (T) -> T), %2 : $(T) -> (T, (T) -> T)}
5619+
differentiable_function [parameters 0] %0 : $(T) -> T \
5620+
with_derivative {%1 : $(T) -> (T, (T) -> T), %2 : $(T) -> (T, (T) -> T)}
56205621

56215622
Bundles a function with its derivative functions into a ``@differentiable``
56225623
function. There are two derivative functions: a Jacobian-vector product (JVP)
56235624
function and a vector-Jacobian product (VJP) function.
56245625

5625-
``[wrt ...]`` specifies parameter indices that the original function is
5626+
``[parameters ...]`` specifies parameter indices that the original function is
56265627
differentiable with respect to. When not specified, it defaults to all
56275628
parameters.
56285629

5629-
A ``with`` clause specifies the differentiation functions associated
5630-
with the original function. When a ``with`` clause is not specified, the first
5631-
operand will be differentiated to produce derivative functions, and a ``with``
5632-
clause will be added to the instruction.
5630+
A ``with_derivative`` clause specifies the derivative functions associated
5631+
with the original function. When a ``with_derivative`` clause is not specified,
5632+
the first operand will be differentiated to produce derivative functions, and a
5633+
``with_derivative`` clause will be added to the instruction.
56335634

5634-
In raw SIL, it is optional to provide a derivative function ``with`` clause.
5635-
In canonical SIL, a ``with`` clause is mandatory.
5635+
In raw SIL, it is optional to provide a derivative function ``with_derivative``
5636+
clause. In canonical SIL, a ``with_derivative`` clause is mandatory.
56365637

56375638

56385639
linear_function

lib/ParseSIL/ParseSIL.cpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2922,17 +2922,17 @@ bool SILParser::parseSILInstruction(SILBuilder &B) {
29222922

29232923
// SWIFT_ENABLE_TENSORFLOW
29242924
case SILInstructionKind::DifferentiableFunctionInst: {
2925-
// e.g. differentiable_function [wrt 0 1 2] %0 : $T
2925+
// e.g. differentiable_function [parameters 0 1 2] %0 : $T
29262926
//
2927-
// e.g. differentiable_function [wrt 0 1 2] %0 : $T with
2927+
// e.g. differentiable_function [parameters 0 1 2] %0 : $T with_derivative
29282928
// {%1 : $T, %2 : $T}
29292929
// ^ jvp ^ vjp
29302930
SourceLoc lastLoc;
29312931
SmallVector<unsigned, 8> parameterIndices;
2932-
// Parse optional `[wrt <integer_literal>...]`
2932+
// Parse optional `[parameters <integer_literal>...]`
29332933
if (P.Tok.is(tok::l_square) &&
29342934
P.peekToken().is(tok::identifier) &&
2935-
P.peekToken().getText() == "wrt") {
2935+
P.peekToken().getText() == "parameters") {
29362936
P.consumeToken(tok::l_square);
29372937
P.consumeToken(tok::identifier);
29382938
// Parse indices.
@@ -2960,8 +2960,9 @@ bool SILParser::parseSILInstruction(SILBuilder &B) {
29602960
return true;
29612961
}
29622962
Optional<std::pair<SILValue, SILValue>> derivativeFunctions = None;
2963-
// Parse an optional operand list `with { <operand> , <operand> }`.
2964-
if (P.Tok.is(tok::identifier) && P.Tok.getText() == "with") {
2963+
// Parse an optional operand list
2964+
// `with_derivative { <operand> , <operand> }`.
2965+
if (P.Tok.is(tok::identifier) && P.Tok.getText() == "with_derivative") {
29652966
P.consumeToken(tok::identifier);
29662967
// Parse derivative function values as an operand list.
29672968
// FIXME(rxwei): Change this to *not* require a type signature once

lib/SIL/SILPrinter.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1163,14 +1163,14 @@ class SILPrinter : public SILInstructionVisitor<SILPrinter> {
11631163
// SWIFT_ENABLE_TENSORFLOW
11641164
void visitDifferentiableFunctionInst(DifferentiableFunctionInst *dfi) {
11651165
if (!dfi->getParameterIndices()->isEmpty()) {
1166-
*this << "[wrt";
1166+
*this << "[parameters";
11671167
for (auto i : dfi->getParameterIndices()->getIndices())
11681168
*this << ' ' << i;
11691169
*this << "] ";
11701170
}
11711171
*this << getIDAndType(dfi->getOriginalFunction());
11721172
if (dfi->hasDerivativeFunctions()) {
1173-
*this << " with ";
1173+
*this << " with_derivative ";
11741174
*this << '{' << getIDAndType(dfi->getJVPFunction()) << ", "
11751175
<< getIDAndType(dfi->getVJPFunction()) << '}';
11761176
}

test/AutoDiff/differentiable_function_inst.sil

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,20 +17,20 @@ sil @examplemethod : $@convention(method) (Float, Float, Float) -> Float
1717
sil @test : $@convention(thin) () -> () {
1818
bb0:
1919
%0 = function_ref @examplefunc : $@convention(thin) (Float, Float, Float) -> Float
20-
%1 = differentiable_function [wrt 0 1 2] %0 : $@convention(thin) (Float, Float, Float) -> Float
20+
%1 = differentiable_function [parameters 0 1 2] %0 : $@convention(thin) (Float, Float, Float) -> Float
2121

2222
// CHECK: %2 = differentiable_function_extract [vjp] %1 : $@differentiable @convention(thin) (Float, Float, Float) -> Float
2323
%2 = differentiable_function_extract [vjp] %1 : $@differentiable @convention(thin) (Float, Float, Float) -> Float
24-
%3 = differentiable_function [wrt 0] %0 : $@convention(thin) (Float, Float, Float) -> Float
24+
%3 = differentiable_function [parameters 0] %0 : $@convention(thin) (Float, Float, Float) -> Float
2525

2626
// CHECK: %4 = differentiable_function_extract [vjp] %3 : $@differentiable @convention(thin) (Float, @nondiff Float, @nondiff Float) -> Float
2727
%4 = differentiable_function_extract [vjp] %3 : $@differentiable @convention(thin) (Float, @nondiff Float, @nondiff Float) -> Float
2828
%5 = function_ref @examplemethod : $@convention(method) (Float, Float, Float) -> Float
29-
%6 = differentiable_function [wrt 0 1 2] %5 : $@convention(method) (Float, Float, Float) -> Float
29+
%6 = differentiable_function [parameters 0 1 2] %5 : $@convention(method) (Float, Float, Float) -> Float
3030

3131
// CHECK: %7 = differentiable_function_extract [vjp] %6 : $@differentiable @convention(method) (Float, Float, Float) -> Float
3232
%7 = differentiable_function_extract [vjp] %6 : $@differentiable @convention(method) (Float, Float, Float) -> Float
33-
%8 = differentiable_function [wrt 0] %5 : $@convention(method) (Float, Float, Float) -> Float
33+
%8 = differentiable_function [parameters 0] %5 : $@convention(method) (Float, Float, Float) -> Float
3434

3535
// CHECK: %9 = differentiable_function_extract [vjp] %8 : $@differentiable @convention(method) (Float, @nondiff Float, @nondiff Float) -> Float
3636
%9 = differentiable_function_extract [vjp] %8 : $@differentiable @convention(method) (Float, @nondiff Float, @nondiff Float) -> Float
@@ -68,19 +68,19 @@ bb0(%0 : $Float):
6868
sil @make_diff_func : $@convention(thin) () -> @differentiable @convention(thin) (Float) -> Float {
6969
bb0:
7070
%orig = function_ref @foo : $@convention(thin) (Float) -> Float
71-
%undiffedFunc = differentiable_function [wrt 0] %orig : $@convention(thin) (Float) -> Float
71+
%undiffedFunc = differentiable_function [parameters 0] %orig : $@convention(thin) (Float) -> Float
7272
%vjp = function_ref @foo_vjp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
73-
%diffFunc = differentiable_function [wrt 0] %orig : $@convention(thin) (Float) -> Float with {undef : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float), %vjp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)}
73+
%diffFunc = differentiable_function [parameters 0] %orig : $@convention(thin) (Float) -> Float with_derivative {undef : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float), %vjp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)}
7474
%extractedVJP = differentiable_function_extract [vjp] %diffFunc : $@differentiable @convention(thin) (Float) -> Float
7575
%extractedOriginal = differentiable_function_extract [original] %diffFunc : $@differentiable @convention(thin) (Float) -> Float
7676
return %undiffedFunc : $@differentiable @convention(thin) (Float) -> Float
7777
}
7878

7979
// CHECK-LABEL: @make_diff_func : $@convention(thin) () -> @differentiable @convention(thin) (Float) -> Float
8080
// CHECK: [[FOO:%.*]] = function_ref @foo : $@convention(thin) (Float) -> Float
81-
// CHECK: [[UNDIFFED_FOO:%.*]] = differentiable_function [wrt 0] [[FOO]] : $@convention(thin) (Float) -> Float
81+
// CHECK: [[UNDIFFED_FOO:%.*]] = differentiable_function [parameters 0] [[FOO]] : $@convention(thin) (Float) -> Float
8282
// CHECK: [[FOO_VJP:%.*]] = function_ref @foo_vjp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
83-
// CHECK: [[DIFFED_FOO:%.*]] = differentiable_function [wrt 0] [[FOO]] : $@convention(thin) (Float) -> Float with {undef : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float), [[FOO_VJP]] : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)}
83+
// CHECK: [[DIFFED_FOO:%.*]] = differentiable_function [parameters 0] [[FOO]] : $@convention(thin) (Float) -> Float with_derivative {undef : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float), [[FOO_VJP]] : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)}
8484
// CHECK: [[EXTRACTED_VJP:%.*]] = differentiable_function_extract [vjp] [[DIFFED_FOO]] : $@differentiable @convention(thin) (Float) -> Float
8585
// CHECK: [[EXTRACTED_ORIG:%.*]] = differentiable_function_extract [original] [[DIFFED_FOO]] : $@differentiable @convention(thin) (Float) -> Float
8686
// CHECK: return [[UNDIFFED_FOO]] : $@differentiable @convention(thin) (Float) -> Float

test/AutoDiff/differentiable_function_inst_irgen.sil

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ sil @make_diff_func : $@convention(thin) () -> (@convention(thin) (Float) -> Flo
3636
bb0:
3737
%orig = function_ref @foo : $@convention(thin) (Float) -> Float
3838
%vjp = function_ref @foo_vjp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
39-
%diffFunc = differentiable_function [wrt 0] %orig : $@convention(thin) (Float) -> Float with {undef : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float), %vjp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)}
39+
%diffFunc = differentiable_function [parameters 0] %orig : $@convention(thin) (Float) -> Float with_derivative {undef : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float), %vjp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)}
4040
%extractedOrig = differentiable_function_extract [original] %diffFunc : $@differentiable @convention(thin) (Float) -> Float
4141
%extractedVJP = differentiable_function_extract [vjp] %diffFunc : $@differentiable @convention(thin) (Float) -> Float
4242
%tuple = tuple (%extractedOrig : $@convention(thin) (Float) -> Float, %extractedVJP : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float))

test/AutoDiff/differentiable_function_silgen.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ func apply() {
7575
// CHECK-SILGEN-LABEL: @{{.*}}apply{{.*}}
7676
// CHECK-SILGEN: [[ORIG:%.*]] = function_ref @{{.*}}thin{{.*}} : $@convention(thin) (Float) -> Float
7777
// CHECK-SILGEN-NEXT: [[ORIG_THICK:%.*]] = thin_to_thick_function [[ORIG]] : $@convention(thin) (Float) -> Float to $@callee_guaranteed (Float) -> Float
78-
// CHECK-SILGEN-NEXT: [[DIFFED:%.*]] = differentiable_function [wrt 0] [[ORIG_THICK]] : $@callee_guaranteed (Float) -> Float
78+
// CHECK-SILGEN-NEXT: [[DIFFED:%.*]] = differentiable_function [parameters 0] [[ORIG_THICK]] : $@callee_guaranteed (Float) -> Float
7979
// CHECK-SILGEN: [[ORIG:%.*]] = function_ref @{{.*}}thin{{.*}} : $@convention(thin) (Float) -> Float
8080
// CHECK-SILGEN-NEXT: [[ORIG_THICK:%.*]] = thin_to_thick_function [[ORIG]] : $@convention(thin) (Float) -> Float to $@callee_guaranteed (Float) -> Float
8181
// CHECK-SILGEN-NEXT: [[LIN:%.*]] = linear_function [parameters 0] [[ORIG_THICK]] : $@callee_guaranteed (Float) -> Float
@@ -110,6 +110,6 @@ func appliesReabstraction(_ f: @escaping @differentiable (Float) -> Float) {
110110
// CHECK-SILGEN: [[VJP_COPY:%.*]] = copy_value [[VJP]] : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
111111
// CHECK-SILGEN: [[REABS_VJP:%.*]] = function_ref @$sS4fIegyd_Iegydo_S4fIegnr_Iegnro_TR : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float)
112112
// CHECK-SILGEN: [[NEW_VJP:%.*]] = partial_apply [callee_guaranteed] [[REABS_VJP]]([[VJP_COPY]]) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float)
113-
// CHECK-SILGEN: [[NEW_DIFF_FUNC:%.*]] = differentiable_function [wrt 0] [[NEW_ORIG]] : $@callee_guaranteed (@in_guaranteed Float) -> @out Float with {[[NEW_JVP]] : $@callee_guaranteed (@in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float), [[NEW_VJP]] : $@callee_guaranteed (@in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float)}
113+
// CHECK-SILGEN: [[NEW_DIFF_FUNC:%.*]] = differentiable_function [parameters 0] [[NEW_ORIG]] : $@callee_guaranteed (@in_guaranteed Float) -> @out Float with_derivative {[[NEW_JVP]] : $@callee_guaranteed (@in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float), [[NEW_VJP]] : $@callee_guaranteed (@in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float)}
114114
// CHECK-SILGEN: [[DIFF_API:%.*]] = function_ref @${{.*}}pullback{{.*}}at{{.*}} : $@convention(thin) <τ_0_0, τ_0_1 where τ_0_0 : _Differentiable, τ_0_1 : _Differentiable> (@in_guaranteed τ_0_0, @guaranteed @differentiable @callee_guaranteed (@in_guaranteed τ_0_0) -> @out τ_0_1) -> @owned @callee_guaranteed (@in_guaranteed τ_0_1.TangentVector) -> @out τ_0_0.TangentVector
115115
// CHECK-SILGEN: apply [[DIFF_API]]<Float, Float>({{.*}}, [[NEW_DIFF_FUNC]]) : $@convention(thin) <τ_0_0, τ_0_1 where τ_0_0 : _Differentiable, τ_0_1 : _Differentiable> (@in_guaranteed τ_0_0, @guaranteed @differentiable @callee_guaranteed (@in_guaranteed τ_0_0) -> @out τ_0_1) -> @owned @callee_guaranteed (@in_guaranteed τ_0_1.TangentVector) -> @out τ_0_0.TangentVector

0 commit comments

Comments
 (0)