[AutoDiff] Adjoint buffer optimization for address projections. #25268
Could you please add a filecheck test for this?
Sure, I'll try to add adjoint SIL checks involving control flow (todo item in I did test with active object projections (
@@ -68,6 +68,22 @@ ControlFlowTests.test("Conditionals") {
expectEqual((-20, 2), valueWithGradient(at: -10, in: cond_tuple))
expectEqual((-2674, 2), valueWithGradient(at: -1337, in: cond_tuple))

func cond_tuple2(_ x: Float) -> Float {
I added this test to verify that control flow AD works with active object projections (namely `tuple_extract`).
I expected that `y0` and `y1` would be lowered to active `tuple_extract` instructions, but it seems they're not active for some reason (maybe due to mandatory optimizations). That means this test isn't truly meaningful, but it's not bad to have. (`cond_struct2` does meaningfully test active object projections.)
For reference, here's the SIL and activity info for `cond_tuple2`:
// cond_tuple2(_:)
sil hidden @$s5tuple11cond_tuple2yS2fF : $@convention(thin) (Float) -> Float {
// %0 // users: %30, %34, %21, %24, %8, %4, %24, %28, %28, %36, %2, %2, %1
bb0(%0 : $Float):
debug_value %0 : $Float, let, name "x", argno 1 // id: %1
%2 = tuple (%0 : $Float, %0 : $Float) // user: %3
debug_value %2 : $(Float, Float), let, name "y" // id: %3
debug_value %0 : $Float, let, name "y0" // id: %4
%5 = metatype $@thin Float.Type
%6 = metatype $@thick Float.Type // user: %16
%7 = alloc_stack $Float // users: %8, %18, %16
store %0 to %7 : $*Float // id: %8
%9 = integer_literal $Builtin.IntLiteral, 0 // user: %12
%10 = metatype $@thin Float.Type // user: %12
// function_ref Float.init(_builtinIntegerLiteral:)
%11 = function_ref @$sSf22_builtinIntegerLiteralSfBI_tcfC : $@convention(method) (Builtin.IntLiteral, @thin Float.Type) -> Float // user: %12
%12 = apply %11(%9, %10) : $@convention(method) (Builtin.IntLiteral, @thin Float.Type) -> Float // user: %14
%13 = alloc_stack $Float // users: %14, %17, %16
store %12 to %13 : $*Float // id: %14
// function_ref static FloatingPoint.> infix(_:_:)
%15 = function_ref @$sSFsE1goiySbx_xtFZ : $@convention(method) <τ_0_0 where τ_0_0 : FloatingPoint> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> Bool // user: %16
%16 = apply %15<Float>(%7, %13, %6) : $@convention(method) <τ_0_0 where τ_0_0 : FloatingPoint> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> Bool // user: %19
dealloc_stack %13 : $*Float // id: %17
dealloc_stack %7 : $*Float // id: %18
%19 = struct_extract %16 : $Bool, #Bool._value // user: %20
cond_br %19, bb1, bb2 // id: %20
bb1: // Preds: bb0
debug_value %0 : $Float, let, name "y1" // id: %21
%22 = metatype $@thin Float.Type // user: %24
// function_ref static Float.+ infix(_:_:)
%23 = function_ref @$sSf1poiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float // user: %24
%24 = apply %23(%0, %0, %22) : $@convention(method) (Float, Float, @thin Float.Type) -> Float // user: %25
br bb3(%24 : $Float) // id: %25
bb2: // Preds: bb0
%26 = metatype $@thin Float.Type // user: %28
// function_ref static Float.+ infix(_:_:)
%27 = function_ref @$sSf1poiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float // user: %28
%28 = apply %27(%0, %0, %26) : $@convention(method) (Float, Float, @thin Float.Type) -> Float // users: %34, %29
debug_value %28 : $Float, let, name "y0_double" // id: %29
debug_value %0 : $Float, let, name "y1" // id: %30
%31 = metatype $@thin Float.Type // user: %36
%32 = metatype $@thin Float.Type // user: %34
// function_ref static Float.- infix(_:_:)
%33 = function_ref @$sSf1soiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float // user: %34
%34 = apply %33(%28, %0, %32) : $@convention(method) (Float, Float, @thin Float.Type) -> Float // user: %36
// function_ref static Float.+ infix(_:_:)
%35 = function_ref @$sSf1poiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float // user: %36
%36 = apply %35(%34, %0, %31) : $@convention(method) (Float, Float, @thin Float.Type) -> Float // user: %37
br bb3(%36 : $Float) // id: %37
// %38 // user: %39
bb3(%38 : $Float): // Preds: bb2 bb1
return %38 : $Float // id: %39
} // end sil function '$s5tuple11cond_tuple2yS2fF'
[AD] Activity info for $s5tuple11cond_tuple2yS2fF at (source=0 parameters=(0))
bb0:
[ACTIVE] %0 = argument of bb0 : $Float // users: %30, %34, %21, %24, %8, %4, %24, %28, %28, %36, %2, %2, %1
[VARIED] %2 = tuple (%0 : $Float, %0 : $Float) // user: %3
[NONE] %5 = metatype $@thin Float.Type
[NONE] %6 = metatype $@thick Float.Type // user: %16
[VARIED] %7 = alloc_stack $Float // users: %8, %18, %16
[NONE] %9 = integer_literal $Builtin.IntLiteral, 0 // user: %12
[NONE] %10 = metatype $@thin Float.Type // user: %12
[NONE] // function_ref Float.init(_builtinIntegerLiteral:)
%11 = function_ref @$sSf22_builtinIntegerLiteralSfBI_tcfC : $@convention(method) (Builtin.IntLiteral, @thin Float.Type) -> Float // user: %12
[NONE] %12 = apply %11(%9, %10) : $@convention(method) (Builtin.IntLiteral, @thin Float.Type) -> Float // user: %14
[NONE] %13 = alloc_stack $Float // users: %14, %17, %16
[NONE] // function_ref static FloatingPoint.> infix(_:_:)
%15 = function_ref @$sSFsE1goiySbx_xtFZ : $@convention(method) <τ_0_0 where τ_0_0 : FloatingPoint> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> Bool // user: %16
[VARIED] %16 = apply %15<Float>(%7, %13, %6) : $@convention(method) <τ_0_0 where τ_0_0 : FloatingPoint> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> Bool // user: %19
[VARIED] %19 = struct_extract %16 : $Bool, #Bool._value // user: %20
bb1:
[USEFUL] %22 = metatype $@thin Float.Type // user: %24
[NONE] // function_ref static Float.+ infix(_:_:)
%23 = function_ref @$sSf1poiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float // user: %24
[ACTIVE] %24 = apply %23(%0, %0, %22) : $@convention(method) (Float, Float, @thin Float.Type) -> Float // user: %25
bb2:
[USEFUL] %26 = metatype $@thin Float.Type // user: %28
[NONE] // function_ref static Float.+ infix(_:_:)
%27 = function_ref @$sSf1poiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float // user: %28
[ACTIVE] %28 = apply %27(%0, %0, %26) : $@convention(method) (Float, Float, @thin Float.Type) -> Float // users: %34, %29
[USEFUL] %31 = metatype $@thin Float.Type // user: %36
[USEFUL] %32 = metatype $@thin Float.Type // user: %34
[NONE] // function_ref static Float.- infix(_:_:)
%33 = function_ref @$sSf1soiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float // user: %34
[ACTIVE] %34 = apply %33(%28, %0, %32) : $@convention(method) (Float, Float, @thin Float.Type) -> Float // user: %36
[NONE] // function_ref static Float.+ infix(_:_:)
%35 = function_ref @$sSf1poiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float // user: %36
[ACTIVE] %36 = apply %35(%34, %0, %31) : $@convention(method) (Float, Float, @thin Float.Type) -> Float // user: %37
bb3:
[ACTIVE] %38 = argument of bb3 : $Float // user: %39
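For readers without the test file at hand, here is a rough Swift sketch of a function that would lower to SIL of this shape. The body is an assumption reconstructed from the debug names in the dump (`x`, `y`, `y0`, `y1`, `y0_double`), not a copy of the actual test source. `centralDifference` is a hypothetical helper that checks the slope numerically, so the snippet does not depend on `valueWithGradient`.

```swift
// Assumed reconstruction of the test function from the SIL debug names.
// Both branches reduce to `x + x`, so the derivative is 2 everywhere.
func cond_tuple2(_ x: Float) -> Float {
  let y = (x, x)
  let y0 = y.0
  if x > 0 {
    let y1 = y.1
    return y0 + y1
  }
  let y0_double = y0 + y.0
  let y1 = y.1
  return y0_double - y1 + y.0
}

// Hypothetical helper: central finite difference, used only as a sanity check.
func centralDifference(
  _ f: (Float) -> Float, at x: Float, step h: Float = 0.5
) -> Float {
  return (f(x + h) - f(x - h)) / (2 * h)
}

print(cond_tuple2(-10), centralDifference(cond_tuple2, at: -10))
```

Every tuple element access in this sketch reads a copy of `x`, which fits the dump above: it contains no `tuple_extract` at all and uses `%0` directly.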
Do not allocate adjoint buffers for address projections; they become projections into their adjoint base buffer.
- Use deterministic iteration order when processing `@differentiable` attributes in differentiation transform.
- Add adjoint SIL tests.
e8df30d to 9b36197
@@ -1,9 +1,11 @@
// RUN: %target-swift-frontend -emit-sil -verify %s | %FileCheck %s -check-prefix=CHECK-DATA-STRUCTURES
// RUN: %target-swift-frontend -emit-sil -verify %s | %FileCheck %s -check-prefix=CHECK-SIL
// RUN: %target-swift-frontend -emit-sil -verify -Xllvm -sil-print-after=differentiation -o /dev/null 2>&1 %s | %FileCheck %s -check-prefix=CHECK-SIL
Note: checking `-Xllvm -sil-print-after=differentiation` here is nicer because:
- The printed SIL matches the SIL printed by `-Xllvm -debug-only=differentiation`.
- `-emit-sil` performs further optimizations (e.g. mandatory inlining), so SIL contains floating-point builtins, etc.
Control flow + address projection test added in 9b36197.
@swift-ci Please test tensorflow
Use `llvm::SmallDenseMap` for deterministic insertion order iteration.
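As a language-neutral illustration of what insertion-order iteration means, here is a small Swift sketch. It is an analogy only: `InsertionOrderedMap` is a hypothetical type made up for this example and says nothing about how the C++ container in the commit behaves. The idea is to pair a hash map with an array that records keys in the order they were first inserted.

```swift
// Hypothetical illustration type: a dictionary plus an array of keys in
// first-insertion order, so iteration order never depends on hashing.
struct InsertionOrderedMap<Key: Hashable, Value> {
  private(set) var keys: [Key] = []
  private var storage: [Key: Value] = [:]

  // Inserts or updates a value; a key keeps its original position on update.
  mutating func insert(_ value: Value, forKey key: Key) {
    if storage.updateValue(value, forKey: key) == nil {
      keys.append(key)
    }
  }

  subscript(key: Key) -> Value? {
    return storage[key]
  }

  // Key-value pairs in first-insertion order.
  var orderedPairs: [(Key, Value)] {
    return keys.map { ($0, storage[$0]!) }
  }
}

var attributes = InsertionOrderedMap<String, Int>()
attributes.insert(2, forKey: "c")
attributes.insert(0, forKey: "a")
attributes.insert(1, forKey: "b")
attributes.insert(5, forKey: "c")  // update; "c" stays first

for (name, index) in attributes.orderedPairs {
  print(name, index)
}
```

A plain Swift `Dictionary` does not guarantee any particular iteration order, which is why the key array is kept alongside it.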
@swift-ci Please test tensorflow
Do not allocate adjoint buffers for address projections; they become
projections into their adjoint base buffer.
This reduces unnecessary adjoint buffer logic (allocation/initialization/copying/cleanup).
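The difference can be pictured with a plain Swift analogy. This is a conceptual sketch, not the SIL the differentiation transform emits. `PairAdjoint` and both functions are hypothetical names invented for the example: the struct stands in for the adjoint buffer of a two-field struct, and `seed` stands in for an incoming adjoint value for the field `x`.

```swift
// Hypothetical stand-in for the adjoint buffer of a struct with two fields.
struct PairAdjoint: Equatable {
  var x: Float = 0
  var y: Float = 0
}

// Before (conceptually): the field projection gets its own adjoint buffer,
// which must be allocated, zero-initialized, accumulated into, and then
// copied back into the base adjoint before being cleaned up.
func accumulateWithSeparateBuffer(_ seed: Float, into base: inout PairAdjoint) {
  var fieldAdjoint: Float = 0   // allocation + zero initialization
  fieldAdjoint += seed          // accumulation into the separate buffer
  base.x += fieldAdjoint        // copy back into the base adjoint
}                               // separate buffer goes away here (cleanup)

// After (conceptually): the adjoint of the projection is just a projection
// into the base adjoint, so accumulation happens in place.
func accumulateViaProjection(_ seed: Float, into base: inout PairAdjoint) {
  base.x += seed
}

var before = PairAdjoint()
var after = PairAdjoint()
accumulateWithSeparateBuffer(3, into: &before)
accumulateViaProjection(3, into: &after)
print(before == after)
```

Both paths leave the base adjoint in the same state; the second simply skips the intermediate buffer and the bookkeeping around it.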
Quick experiments (based on `struct_element_addr`):
Similar improvements for `tuple_element_addr`.