
[AutoDiff] Adjoint buffer optimization for address projections. #25268


Conversation

dan-zheng
Contributor

Do not allocate adjoint buffers for address projections; instead, make them
projections into their base's adjoint buffer. Concretely, the adjoint of a
`struct_element_addr` into a buffer becomes a `struct_element_addr` into that
buffer's adjoint buffer, rather than a freshly allocated buffer.

This eliminates unnecessary adjoint buffer logic (allocation, initialization, copying, and cleanup).
Quick experiments (based on `struct_element_addr`):

struct FloatPair : Differentiable {
  var first, second: Float
  init(_ first: Float, _ second: Float) {
    self.first = first
    self.second = second
  }
}

struct Pair<T : Differentiable, U : Differentiable> : Differentiable {
  var first: T
  var second: U
  init(_ first: T, _ second: U) {
    self.first = first
    self.second = second
  }
}

func cond_struct_var(_ x: Float) -> Float {
  // Convoluted function returning `x + x`.
  var y = FloatPair(x, x)
  var z = FloatPair(x + x, x - x)
  if x > 0 {
    var w = y
    y.first = w.second
    y.second = w.first
    z.first = z.first - y.first
    z.second = z.second + y.first
  } else {
    z = FloatPair(x, x)
  }
  return y.first + y.second - z.first + z.second
}
print((8, 2), valueWithGradient(at: 4, in: cond_struct_var))
print((-20, 2), valueWithGradient(at: -10, in: cond_struct_var))
print((-2674, 2), valueWithGradient(at: -1337, in: cond_struct_var))

func cond_nestedstruct_var(_ x: Float) -> Float {
  // Convoluted function returning `x + x`.
  var y = FloatPair(x + x, x - x)
  var z = Pair(y, x)
  if x > 0 {
    var w = FloatPair(x, x)
    y.first = w.second
    y.second = w.first
    z.first.first = z.first.first - y.first
    z.first.second = z.first.second + y.first
  } else {
    z = Pair(FloatPair(y.first - x, y.second + x), x)
  }
  return y.first + y.second - z.first.first + z.first.second
}
print((8, 2), valueWithGradient(at: 4, in: cond_nestedstruct_var))
Debugging active values for $s6struct05cond_A4_varyS2fF (before vs after):
Active values in bb0 (7 vs 7)
Active values in bb1 (27 vs 37)
Active values in bb2 (9 vs 9)
Active values in bb3 (18 vs 22)

Debugging active values for $s6nested21cond_nestedstruct_varyS2fF (before vs after):
Active values in bb0 (10 vs 10)
Active values in bb1 (44 vs 30)
Active values in bb2 (24 vs 22)
Active values in bb3 (27 vs 21)

Similar improvements for tuple_element_addr.
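
As a sketch, the analogous tuple experiment might look like the following; it mirrors `cond_struct_var` above, and `cond_tuple_var` is a hypothetical name, not necessarily the exact code measured:

func cond_tuple_var(_ x: Float) -> Float {
  // Convoluted function returning `x + x`.
  // In-place element mutation triggers `tuple_element_addr` projections.
  var y = (x, x)
  var z = (x + x, x - x)
  if x > 0 {
    let w = y
    y.0 = w.1
    y.1 = w.0
    z.0 = z.0 - y.0
    z.1 = z.1 + y.0
  } else {
    z = (x, x)
  }
  return y.0 + y.1 - z.0 + z.1
}
print((8, 2), valueWithGradient(at: 4, in: cond_tuple_var))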

@dan-zheng dan-zheng added the tensorflow This is for "tensorflow" branch PRs. label Jun 5, 2019
@dan-zheng dan-zheng requested a review from rxwei June 5, 2019 20:33
rxwei
Contributor

Could you please add a FileCheck test for this?

@dan-zheng
Contributor Author

Could you please add a FileCheck test for this?

Sure, I'll try to add adjoint SIL checks involving control flow (a TODO item in test/AutoDiff/control_flow_sil.swift); none exist currently.

I did test with active object projections (struct_extract, tuple_extract) to verify that differentiation of those instructions is correct.
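
For illustration, here's a minimal sketch of the kind of object-projection code meant here; `FloatPair` is the struct from the description above, and `object_projections` is a hypothetical name, not the exact test added:

// Reading elements of immutable values lowers to object projections
// (`struct_extract`, `tuple_extract`), unlike the `var`-based examples
// above, which lower to address projections (`struct_element_addr`).
// Note: mandatory passes may fold trivial extracts away, as observed
// with `cond_tuple2` below.
func object_projections(_ x: Float) -> Float {
  let p = FloatPair(x, x)
  let t = (x + x, x - x)
  return p.first + p.second + t.0 + t.1
}
print((10, 4), valueWithGradient(at: 2.5, in: object_projections))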

@@ -68,6 +68,22 @@ ControlFlowTests.test("Conditionals") {
expectEqual((-20, 2), valueWithGradient(at: -10, in: cond_tuple))
expectEqual((-2674, 2), valueWithGradient(at: -1337, in: cond_tuple))

func cond_tuple2(_ x: Float) -> Float {
dan-zheng
Contributor Author

I added this test to verify that control flow AD works with active object projections (namely tuple_extract).

I expected that y0 and y1 would be lowered to active tuple_extract instructions, but it seems they're not active for some reason (maybe due to mandatory optimizations). That means this test isn't truly meaningful, but it's not bad to have. (cond_struct2 does meaningfully test active object projections.)

For reference, here's the SIL and activity info for cond_tuple2:

// cond_tuple2(_:)
sil hidden @$s5tuple11cond_tuple2yS2fF : $@convention(thin) (Float) -> Float {
// %0                                             // users: %30, %34, %21, %24, %8, %4, %24, %28, %28, %36, %2, %2, %1
bb0(%0 : $Float):
  debug_value %0 : $Float, let, name "x", argno 1 // id: %1
  %2 = tuple (%0 : $Float, %0 : $Float)           // user: %3
  debug_value %2 : $(Float, Float), let, name "y" // id: %3
  debug_value %0 : $Float, let, name "y0"         // id: %4
  %5 = metatype $@thin Float.Type
  %6 = metatype $@thick Float.Type                // user: %16
  %7 = alloc_stack $Float                         // users: %8, %18, %16
  store %0 to %7 : $*Float                        // id: %8
  %9 = integer_literal $Builtin.IntLiteral, 0     // user: %12
  %10 = metatype $@thin Float.Type                // user: %12
  // function_ref Float.init(_builtinIntegerLiteral:)
  %11 = function_ref @$sSf22_builtinIntegerLiteralSfBI_tcfC : $@convention(method) (Builtin.IntLiteral, @thin Float.Type) -> Float // user: %12
  %12 = apply %11(%9, %10) : $@convention(method) (Builtin.IntLiteral, @thin Float.Type) -> Float // user: %14
  %13 = alloc_stack $Float                        // users: %14, %17, %16
  store %12 to %13 : $*Float                      // id: %14
  // function_ref static FloatingPoint.> infix(_:_:)
  %15 = function_ref @$sSFsE1goiySbx_xtFZ : $@convention(method) <τ_0_0 where τ_0_0 : FloatingPoint> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> Bool // user: %16
  %16 = apply %15<Float>(%7, %13, %6) : $@convention(method) <τ_0_0 where τ_0_0 : FloatingPoint> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> Bool // user: %19
  dealloc_stack %13 : $*Float                     // id: %17
  dealloc_stack %7 : $*Float                      // id: %18
  %19 = struct_extract %16 : $Bool, #Bool._value  // user: %20
  cond_br %19, bb1, bb2                           // id: %20

bb1:                                              // Preds: bb0
  debug_value %0 : $Float, let, name "y1"         // id: %21
  %22 = metatype $@thin Float.Type                // user: %24
  // function_ref static Float.+ infix(_:_:)
  %23 = function_ref @$sSf1poiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float // user: %24
  %24 = apply %23(%0, %0, %22) : $@convention(method) (Float, Float, @thin Float.Type) -> Float // user: %25
  br bb3(%24 : $Float)                            // id: %25

bb2:                                              // Preds: bb0
  %26 = metatype $@thin Float.Type                // user: %28
  // function_ref static Float.+ infix(_:_:)
  %27 = function_ref @$sSf1poiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float // user: %28
  %28 = apply %27(%0, %0, %26) : $@convention(method) (Float, Float, @thin Float.Type) -> Float // users: %34, %29
  debug_value %28 : $Float, let, name "y0_double" // id: %29
  debug_value %0 : $Float, let, name "y1"         // id: %30
  %31 = metatype $@thin Float.Type                // user: %36
  %32 = metatype $@thin Float.Type                // user: %34
  // function_ref static Float.- infix(_:_:)
  %33 = function_ref @$sSf1soiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float // user: %34
  %34 = apply %33(%28, %0, %32) : $@convention(method) (Float, Float, @thin Float.Type) -> Float // user: %36
  // function_ref static Float.+ infix(_:_:)
  %35 = function_ref @$sSf1poiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float // user: %36
  %36 = apply %35(%34, %0, %31) : $@convention(method) (Float, Float, @thin Float.Type) -> Float // user: %37
  br bb3(%36 : $Float)                            // id: %37

// %38                                            // user: %39
bb3(%38 : $Float):                                // Preds: bb2 bb1
  return %38 : $Float                             // id: %39
} // end sil function '$s5tuple11cond_tuple2yS2fF'
[AD] Activity info for $s5tuple11cond_tuple2yS2fF at (source=0 parameters=(0))
bb0:
[ACTIVE] %0 = argument of bb0 : $Float                     // users: %30, %34, %21, %24, %8, %4, %24, %28, %28, %36, %2, %2, %1
[VARIED]   %2 = tuple (%0 : $Float, %0 : $Float)           // user: %3
[NONE]   %5 = metatype $@thin Float.Type
[NONE]   %6 = metatype $@thick Float.Type                // user: %16
[VARIED]   %7 = alloc_stack $Float                         // users: %8, %18, %16
[NONE]   %9 = integer_literal $Builtin.IntLiteral, 0     // user: %12
[NONE]   %10 = metatype $@thin Float.Type                // user: %12
[NONE]   // function_ref Float.init(_builtinIntegerLiteral:)
  %11 = function_ref @$sSf22_builtinIntegerLiteralSfBI_tcfC : $@convention(method) (Builtin.IntLiteral, @thin Float.Type) -> Float // user: %12
[NONE]   %12 = apply %11(%9, %10) : $@convention(method) (Builtin.IntLiteral, @thin Float.Type) -> Float // user: %14
[NONE]   %13 = alloc_stack $Float                        // users: %14, %17, %16
[NONE]   // function_ref static FloatingPoint.> infix(_:_:)
  %15 = function_ref @$sSFsE1goiySbx_xtFZ : $@convention(method) <τ_0_0 where τ_0_0 : FloatingPoint> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> Bool // user: %16
[VARIED]   %16 = apply %15<Float>(%7, %13, %6) : $@convention(method) <τ_0_0 where τ_0_0 : FloatingPoint> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> Bool // user: %19
[VARIED]   %19 = struct_extract %16 : $Bool, #Bool._value  // user: %20
bb1:
[USEFUL]   %22 = metatype $@thin Float.Type                // user: %24
[NONE]   // function_ref static Float.+ infix(_:_:)
  %23 = function_ref @$sSf1poiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float // user: %24
[ACTIVE]   %24 = apply %23(%0, %0, %22) : $@convention(method) (Float, Float, @thin Float.Type) -> Float // user: %25
bb2:
[USEFUL]   %26 = metatype $@thin Float.Type                // user: %28
[NONE]   // function_ref static Float.+ infix(_:_:)
  %27 = function_ref @$sSf1poiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float // user: %28
[ACTIVE]   %28 = apply %27(%0, %0, %26) : $@convention(method) (Float, Float, @thin Float.Type) -> Float // users: %34, %29
[USEFUL]   %31 = metatype $@thin Float.Type                // user: %36
[USEFUL]   %32 = metatype $@thin Float.Type                // user: %34
[NONE]   // function_ref static Float.- infix(_:_:)
  %33 = function_ref @$sSf1soiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float // user: %34
[ACTIVE]   %34 = apply %33(%28, %0, %32) : $@convention(method) (Float, Float, @thin Float.Type) -> Float // user: %36
[NONE]   // function_ref static Float.+ infix(_:_:)
  %35 = function_ref @$sSf1poiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float // user: %36
[ACTIVE]   %36 = apply %35(%34, %0, %31) : $@convention(method) (Float, Float, @thin Float.Type) -> Float // user: %37
bb3:
[ACTIVE] %38 = argument of bb3 : $Float                    // user: %39

dan-zheng added 2 commits June 5, 2019 15:46
Do not allocate adjoint buffers for address projections; they become
projections into their adjoint base buffer.
- Use deterministic iteration order when processing `@differentiable`
  attributes in differentiation transform.
- Add adjoint SIL tests.
@dan-zheng dan-zheng force-pushed the autodiff-optimize-address-projections branch from e8df30d to 9b36197 on June 6, 2019 00:48
@@ -1,9 +1,11 @@
// RUN: %target-swift-frontend -emit-sil -verify %s | %FileCheck %s -check-prefix=CHECK-DATA-STRUCTURES
// RUN: %target-swift-frontend -emit-sil -verify %s | %FileCheck %s -check-prefix=CHECK-SIL
// RUN: %target-swift-frontend -emit-sil -verify -Xllvm -sil-print-after=differentiation -o /dev/null 2>&1 %s | %FileCheck %s -check-prefix=CHECK-SIL
dan-zheng
Contributor Author

Note: checking `-Xllvm -sil-print-after=differentiation` here is nicer because:

  • The printed SIL matches the SIL printed by `-Xllvm -debug-only=differentiation`.
  • Plain `-emit-sil` performs further optimizations (e.g. mandatory inlining), so the printed SIL contains floating-point builtins and other noise that is harder to check.
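
A hypothetical shape for the CHECK lines in such a test (the patterns below are illustrative, not the exact ones committed in 9b36197):

// CHECK-SIL-LABEL: {{.*}}cond_struct_var{{.*}}
// With this change, the adjoint of a `struct_element_addr` should itself be
// a `struct_element_addr` into the base's adjoint buffer, not a fresh
// `alloc_stack`.
// CHECK-SIL: struct_element_addr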

@dan-zheng
Contributor Author

Could you please add a FileCheck test for this?

Control flow + address projection test added in 9b36197.

@dan-zheng dan-zheng requested a review from rxwei June 6, 2019 00:54
@dan-zheng
Contributor Author

@swift-ci Please test tensorflow

Use `llvm::SmallDenseMap` for deterministic insertion order iteration.
@dan-zheng
Contributor Author

@swift-ci Please test tensorflow

@dan-zheng dan-zheng merged commit 8c646af into swiftlang:tensorflow Jun 6, 2019
@dan-zheng dan-zheng deleted the autodiff-optimize-address-projections branch June 6, 2019 07:38