Skip to content

[AutoDiff] Adjoint buffer optimization for address projections. #25268

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 20 additions & 14 deletions lib/SILOptimizer/Mandatory/Differentiation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ using namespace swift;
using llvm::DenseMap;
using llvm::SmallDenseMap;
using llvm::SmallDenseSet;
using llvm::SmallMapVector;
using llvm::SmallSet;

/// This flag is used to disable `autodiff_function_extract` instruction folding
Expand Down Expand Up @@ -844,7 +845,9 @@ class ADContext {
SmallPtrSet<AutoDiffFunctionInst *, 32> processedAutoDiffFunctionInsts;

/// Mapping from `[differentiable]` attributes to invokers.
DenseMap<SILDifferentiableAttr *, DifferentiationInvoker> invokers;
/// `SmallMapVector` is used for deterministic insertion order iteration.
SmallMapVector<SILDifferentiableAttr *, DifferentiationInvoker, 32>
invokers;

/// Mapping from `autodiff_function` instructions to result indices.
DenseMap<AutoDiffFunctionInst *, unsigned> resultIndices;
Expand Down Expand Up @@ -902,7 +905,8 @@ class ADContext {
return processedAutoDiffFunctionInsts;
}

DenseMap<SILDifferentiableAttr *, DifferentiationInvoker> &getInvokers() {
llvm::SmallMapVector<SILDifferentiableAttr *, DifferentiationInvoker, 32> &
getInvokers() {
return invokers;
}

Expand Down Expand Up @@ -957,8 +961,8 @@ class ADContext {
}

void cleanUp() {
for (auto invokerInfo : invokers) {
auto *attr = invokerInfo.getFirst();
for (auto invokerPair : invokers) {
auto *attr = std::get<0>(invokerPair);
auto *original = attr->getOriginal();
LLVM_DEBUG(getADDebugStream()
<< "Removing [differentiable] attribute for "
Expand Down Expand Up @@ -4141,6 +4145,11 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
auto addActiveValue = [&](SILValue v) {
if (visited.count(v))
return;
// Skip address projections.
// Address projections do not need their own adjoint buffers; they
// become projections into their adjoint base buffer.
if (Projection::isAddressProjection(v))
return;
visited.insert(v);
bbActiveValues.push_back(v);
};
Expand Down Expand Up @@ -6352,17 +6361,14 @@ void Differentiation::run() {
// A global differentiation context.
ADContext context(*this);

// Handle all the instructions and attributes in the module that trigger
// differentiation.
// Register all `@differentiable` attributes and `autodiff_function`
// instructions in the module that trigger differentiation.
for (SILFunction &f : module) {
// If `f` has a `[differentiable]` attribute, register `f` and the attribute
// with an invoker.
for (auto *diffAttr : f.getDifferentiableAttrs()) {
DifferentiationInvoker invoker(diffAttr);
auto insertion =
context.getInvokers().try_emplace(diffAttr, invoker);
assert(insertion.second &&
assert(!context.getInvokers().count(diffAttr) &&
"[differentiable] attribute already has an invoker");
context.getInvokers().insert({diffAttr, invoker});
continue;
}
for (SILBasicBlock &bb : f)
Expand All @@ -6387,10 +6393,10 @@ void Differentiation::run() {
bool errorOccurred = false;

// Process all `[differentiable]` attributes.
for (auto invokerInfo : context.getInvokers()) {
auto *attr = invokerInfo.first;
for (auto invokerPair : context.getInvokers()) {
auto *attr = invokerPair.first;
auto *original = attr->getOriginal();
auto invoker = invokerInfo.second;
auto invoker = invokerPair.second;
errorOccurred |=
context.processDifferentiableAttribute(original, attr, invoker);
}
Expand Down
32 changes: 32 additions & 0 deletions test/AutoDiff/control_flow.swift
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,22 @@ ControlFlowTests.test("Conditionals") {
expectEqual((-20, 2), valueWithGradient(at: -10, in: cond_tuple))
expectEqual((-2674, 2), valueWithGradient(at: -1337, in: cond_tuple))

func cond_tuple2(_ x: Float) -> Float {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added this test to verify that control flow AD works with active object projections (namely tuple_extract).

I expected that y0 and y1 would be lowered to active tuple_extract instructions, but it seems they're not active for some reason (maybe due to mandatory optimizations). That means this test isn't truly meaningful, but it's not bad to have. (cond_struct2 does meaningfully test active object projections.)

For reference, here's the SIL and activity info for cond_tuple2:

// cond_tuple2(_:)
sil hidden @$s5tuple11cond_tuple2yS2fF : $@convention(thin) (Float) -> Float {
// %0                                             // users: %30, %34, %21, %24, %8, %4, %2
4, %28, %28, %36, %2, %2, %1
bb0(%0 : $Float):
  debug_value %0 : $Float, let, name "x", argno 1 // id: %1
  %2 = tuple (%0 : $Float, %0 : $Float)           // user: %3
  debug_value %2 : $(Float, Float), let, name "y" // id: %3
  debug_value %0 : $Float, let, name "y0"         // id: %4
  %5 = metatype $@thin Float.Type
  %6 = metatype $@thick Float.Type                // user: %16
  %7 = alloc_stack $Float                         // users: %8, %18, %16
  store %0 to %7 : $*Float                        // id: %8
  %9 = integer_literal $Builtin.IntLiteral, 0     // user: %12
  %10 = metatype $@thin Float.Type                // user: %12
  // function_ref Float.init(_builtinIntegerLiteral:)
  %11 = function_ref @$sSf22_builtinIntegerLiteralSfBI_tcfC : $@convention(method) (Builtin.IntLiteral, @thin Float.Type) -> Float // user: %12
  %12 = apply %11(%9, %10) : $@convention(method) (Builtin.IntLiteral, @thin Float.Type) -> Float // user: %14
  %13 = alloc_stack $Float                        // users: %14, %17, %16
  store %12 to %13 : $*Float                      // id: %14
  // function_ref static FloatingPoint.> infix(_:_:)
  %15 = function_ref @$sSFsE1goiySbx_xtFZ : $@convention(method) <τ_0_0 where τ_0_0 : FloatingPoint> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> Bool // user: %16
  %16 = apply %15<Float>(%7, %13, %6) : $@convention(method) <τ_0_0 where τ_0_0 : FloatingPoint> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> Bool // user: %19
  dealloc_stack %13 : $*Float                     // id: %17
  dealloc_stack %7 : $*Float                      // id: %18
  %19 = struct_extract %16 : $Bool, #Bool._value  // user: %20
  cond_br %19, bb1, bb2                           // id: %20

bb1:                                              // Preds: bb0
  debug_value %0 : $Float, let, name "y1"         // id: %21
  %22 = metatype $@thin Float.Type                // user: %24
  // function_ref static Float.+ infix(_:_:)
  %23 = function_ref @$sSf1poiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float // user: %24
  %24 = apply %23(%0, %0, %22) : $@convention(method) (Float, Float, @thin Float.Type) -> Float // user: %25
  br bb3(%24 : $Float)                            // id: %25

bb2:                                              // Preds: bb0
  %26 = metatype $@thin Float.Type                // user: %28
  // function_ref static Float.+ infix(_:_:)
  %27 = function_ref @$sSf1poiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float // user: %28
  %28 = apply %27(%0, %0, %26) : $@convention(method) (Float, Float, @thin Float.Type) -> Float // users: %34, %29
  debug_value %28 : $Float, let, name "y0_double" // id: %29
  debug_value %0 : $Float, let, name "y1"         // id: %30
  %31 = metatype $@thin Float.Type                // user: %36
  %32 = metatype $@thin Float.Type                // user: %34
  // function_ref static Float.- infix(_:_:)
  %33 = function_ref @$sSf1soiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float // user: %34
  %34 = apply %33(%28, %0, %32) : $@convention(method) (Float, Float, @thin Float.Type) -> Float // user: %36
  // function_ref static Float.+ infix(_:_:)
  %35 = function_ref @$sSf1poiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float // user: %36
  %36 = apply %35(%34, %0, %31) : $@convention(method) (Float, Float, @thin Float.Type) -> Float // user: %37
  br bb3(%36 : $Float)                            // id: %37

// %38                                            // user: %39
bb3(%38 : $Float):                                // Preds: bb2 bb1
  return %38 : $Float                             // id: %39
} // end sil function '$s5tuple11cond_tuple2yS2fF'
[AD] Activity info for $s5tuple11cond_tuple2yS2fF at (source=0 parameters=(0))
bb0:
[ACTIVE] %0 = argument of bb0 : $Float                     // users: %30, %34, %21, %24, %8, %4, %24, %28, %28, %36, %2, %2, %1
[VARIED]   %2 = tuple (%0 : $Float, %0 : $Float)           // user: %3
[NONE]   %5 = metatype $@thin Float.Type
[NONE]   %6 = metatype $@thick Float.Type                // user: %16
[VARIED]   %7 = alloc_stack $Float                         // users: %8, %18, %16
[NONE]   %9 = integer_literal $Builtin.IntLiteral, 0     // user: %12
[NONE]   %10 = metatype $@thin Float.Type                // user: %12
[NONE]   // function_ref Float.init(_builtinIntegerLiteral:)
  %11 = function_ref @$sSf22_builtinIntegerLiteralSfBI_tcfC : $@convention(method) (Builtin.IntLiteral, @thin Float.Type) -> Float // user: %12
[NONE]   %12 = apply %11(%9, %10) : $@convention(method) (Builtin.IntLiteral, @thin Float.Type) -> Float // user: %14
[NONE]   %13 = alloc_stack $Float                        // users: %14, %17, %16
[NONE]   // function_ref static FloatingPoint.> infix(_:_:)
  %15 = function_ref @$sSFsE1goiySbx_xtFZ : $@convention(method) <τ_0_0 where τ_0_0 : FloatingPoint> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> Bool // user: %16
[VARIED]   %16 = apply %15<Float>(%7, %13, %6) : $@convention(method) <τ_0_0 where τ_0_0 : FloatingPoint> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> Bool // user: %19
[VARIED]   %19 = struct_extract %16 : $Bool, #Bool._value  // user: %20
bb1:
[USEFUL]   %22 = metatype $@thin Float.Type                // user: %24
[NONE]   // function_ref static Float.+ infix(_:_:)
  %23 = function_ref @$sSf1poiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float // user: %24
[ACTIVE]   %24 = apply %23(%0, %0, %22) : $@convention(method) (Float, Float, @thin Float.Type) -> Float // user: %25
bb2:
[USEFUL]   %26 = metatype $@thin Float.Type                // user: %28
[NONE]   // function_ref static Float.+ infix(_:_:)
  %27 = function_ref @$sSf1poiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float // user: %28
[ACTIVE]   %28 = apply %27(%0, %0, %26) : $@convention(method) (Float, Float, @thin Float.Type) -> Float // users: %34, %29
[USEFUL]   %31 = metatype $@thin Float.Type                // user: %36
[USEFUL]   %32 = metatype $@thin Float.Type                // user: %34
[NONE]   // function_ref static Float.- infix(_:_:)
  %33 = function_ref @$sSf1soiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float // user: %34
[ACTIVE]   %34 = apply %33(%28, %0, %32) : $@convention(method) (Float, Float, @thin Float.Type) -> Float // user: %36
[NONE]   // function_ref static Float.+ infix(_:_:)
  %35 = function_ref @$sSf1poiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float // user: %36
[ACTIVE]   %36 = apply %35(%34, %0, %31) : $@convention(method) (Float, Float, @thin Float.Type) -> Float // user: %37
bb3:
[ACTIVE] %38 = argument of bb3 : $Float                    // user: %39

// Convoluted function returning `x + x`.
let y: (Float, Float) = (x, x)
let y0 = y.0
if x > 0 {
let y1 = y.1
return y0 + y1
}
let y0_double = y0 + y.0
let y1 = y.1
return y0_double - y1 + y.0
}
expectEqual((8, 2), valueWithGradient(at: 4, in: cond_tuple2))
expectEqual((-20, 2), valueWithGradient(at: -10, in: cond_tuple2))
expectEqual((-2674, 2), valueWithGradient(at: -1337, in: cond_tuple2))

func cond_tuple_var(_ x: Float) -> Float {
// Convoluted function returning `x + x`.
var y: (Float, Float) = (x, x)
Expand Down Expand Up @@ -135,6 +151,22 @@ ControlFlowTests.test("Conditionals") {
expectEqual((-20, 2), valueWithGradient(at: -10, in: cond_struct))
expectEqual((-2674, 2), valueWithGradient(at: -1337, in: cond_struct))

func cond_struct2(_ x: Float) -> Float {
// Convoluted function returning `x + x`.
let y = FloatPair(x, x)
let y0 = y.first
if x > 0 {
let y1 = y.second
return y0 + y1
}
let y0_double = y0 + y.first
let y1 = y.second
return y0_double - y1 + y.first
}
expectEqual((8, 2), valueWithGradient(at: 4, in: cond_struct2))
expectEqual((-20, 2), valueWithGradient(at: -10, in: cond_struct2))
expectEqual((-2674, 2), valueWithGradient(at: -1337, in: cond_struct2))

func cond_struct_var(_ x: Float) -> Float {
// Convoluted function returning `x + x`.
var y = FloatPair(x, x)
Expand Down
93 changes: 89 additions & 4 deletions test/AutoDiff/control_flow_sil.swift
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
// RUN: %target-swift-frontend -emit-sil -verify %s | %FileCheck %s -check-prefix=CHECK-DATA-STRUCTURES
// RUN: %target-swift-frontend -emit-sil -verify %s | %FileCheck %s -check-prefix=CHECK-SIL
// RUN: %target-swift-frontend -emit-sil -verify -Xllvm -sil-print-after=differentiation -o /dev/null 2>&1 %s | %FileCheck %s -check-prefix=CHECK-SIL
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: checking -Xllvm -sil-print-after=differentiation here is nicer because:

  • The printed SIL matches the SIL printed by -Xllvm -debug-only=differentiation.
  • -emit-sil performs further optimizations (e.g. mandatory inlining), so SIL contains floating-point builtins, etc.


// TODO: Add adjoint SIL FileCheck tests.
// TODO: Add FileCheck tests.

// Test conditional: a simple if-diamond.
//===----------------------------------------------------------------------===//
// Conditionals
//===----------------------------------------------------------------------===//

@differentiable
@_silgen_name("cond")
Expand Down Expand Up @@ -40,7 +42,7 @@ func cond(_ x: Float) -> Float {
// CHECK-DATA-STRUCTURES: @_hasStorage var predecessor: _AD__cond_bb3__Pred__src_0_wrt_0 { get set }
// CHECK-DATA-STRUCTURES: }

// CHECK-SIL-LABEL: sil hidden @AD__cond__vjp_src_0_wrt_0
// CHECK-SIL-LABEL: sil hidden @AD__cond__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
// CHECK-SIL: bb0([[INPUT_ARG:%.*]] : $Float):
// CHECK-SIL: [[BB0_PB_STRUCT:%.*]] = struct $_AD__cond_bb0__PB__src_0_wrt_0 ()
// CHECK-SIL: [[BB1_PRED:%.*]] = enum $_AD__cond_bb1__Pred__src_0_wrt_0, #_AD__cond_bb1__Pred__src_0_wrt_0.bb0!enumelt.1, [[BB0_PB_STRUCT]]
Expand All @@ -64,6 +66,40 @@ func cond(_ x: Float) -> Float {
// CHECK-SIL: [[VJP_RESULT:%.*]] = tuple ([[ORIG_RES]] : $Float, [[PB]] : $@callee_guaranteed (Float) -> Float)
// CHECK-SIL: return [[VJP_RESULT]]

// CHECK-SIL-LABEL: sil hidden @AD__cond__adjoint_src_0_wrt_0 : $@convention(thin) (Float, @guaranteed _AD__cond_bb3__PB__src_0_wrt_0) -> Float {
// CHECK-SIL: bb0([[SEED:%.*]] : $Float, [[BB3_PB_STRUCT:%.*]] : $_AD__cond_bb3__PB__src_0_wrt_0):
// CHECK-SIL: [[BB3_PRED:%.*]] = struct_extract %1 : $_AD__cond_bb3__PB__src_0_wrt_0, #_AD__cond_bb3__PB__src_0_wrt_0.predecessor
// CHECK-SIL: switch_enum [[BB3_PRED]] : $_AD__cond_bb3__Pred__src_0_wrt_0, case #_AD__cond_bb3__Pred__src_0_wrt_0.bb2!enumelt.1: bb3, case #_AD__cond_bb3__Pred__src_0_wrt_0.bb1!enumelt.1: bb1

// CHECK-SIL: bb1([[BB3_PRED1_TRAMP_PB_STRUCT:%.*]] : $_AD__cond_bb1__PB__src_0_wrt_0):
// CHECK-SIL: br bb2([[BB3_PRED1_TRAMP_PB_STRUCT]] : $_AD__cond_bb1__PB__src_0_wrt_0, {{%.*}} : $Float, {{%.*}}: $Float)

// CHECK-SIL: bb2([[BB1_PB_STRUCT:%.*]] : $_AD__cond_bb1__PB__src_0_wrt_0, {{%.*}} : $Float, {{%.*}} : $Float):
// CHECK-SIL: [[BB1_PB:%.*]] = struct_extract [[BB1_PB_STRUCT]]
// CHECK-SIL: [[BB1_ADJVALS:%.*]] = apply [[BB1_PB]]([[SEED]]) : $@callee_guaranteed (Float) -> (Float, Float)
// CHECK-SIL: [[BB1_PRED:%.*]] = struct_extract [[BB1_PB_STRUCT]]
// CHECK-SIL: [[BB1_PB_STRUCT_DATA:%.*]] = unchecked_enum_data [[BB1_PRED]]
// CHECK-SIL: br bb5([[BB1_PB_STRUCT_DATA]] : $_AD__cond_bb0__PB__src_0_wrt_0)

// CHECK-SIL: bb3([[BB3_PRED2_TRAMP_PB_STRUCT:%.*]] : $_AD__cond_bb2__PB__src_0_wrt_0):
// CHECK-SIL: br bb4([[BB3_PRED2_TRAMP_PB_STRUCT]] : $_AD__cond_bb2__PB__src_0_wrt_0, {{%.*}} : $Float, {{%.*}}: $Float)

// CHECK-SIL: bb4([[BB2_PB_STRUCT:%.*]] : $_AD__cond_bb2__PB__src_0_wrt_0, {{%.*}} : $Float, {{%.*}} : $Float):
// CHECK-SIL: [[BB2_PB:%.*]] = struct_extract [[BB2_PB_STRUCT]]
// CHECK-SIL: [[BB2_ADJVALS:%.*]] = apply [[BB2_PB]]([[SEED]]) : $@callee_guaranteed (Float) -> (Float, Float)
// CHECK-SIL: [[BB2_PRED:%.*]] = struct_extract [[BB2_PB_STRUCT]]
// CHECK-SIL: [[BB2_PB_STRUCT_DATA:%.*]] = unchecked_enum_data [[BB2_PRED]]
// CHECK-SIL: br bb6([[BB2_PB_STRUCT_DATA]] : $_AD__cond_bb0__PB__src_0_wrt_0)

// CHECK-SIL: bb5([[BB1_PRED0_TRAMP_PB_STRUCT:%.*]] : $_AD__cond_bb0__PB__src_0_wrt_0):
// CHECK-SIL: br bb7([[BB1_PRED0_TRAMP_PB_STRUCT]] : $_AD__cond_bb0__PB__src_0_wrt_0, {{%.*}} : $Float)

// CHECK-SIL: bb6([[BB2_PRED0_TRAMP_PB_STRUCT:%.*]] : $_AD__cond_bb0__PB__src_0_wrt_0):
// CHECK-SIL: br bb7([[BB2_PRED0_TRAMP_PB_STRUCT]] : $_AD__cond_bb0__PB__src_0_wrt_0, {{%.*}} : $Float)

// CHECK-SIL: bb7([[BB0_PB_STRUCT:%.*]] : $_AD__cond_bb0__PB__src_0_wrt_0, {{%.*}} : $Float):
// CHECK-SIL: return {{%.*}} : $Float

@differentiable
@_silgen_name("nested_cond")
func nested_cond(_ x: Float, _ y: Float) -> Float {
Expand All @@ -89,3 +125,52 @@ func nested_cond_generic<T : Differentiable & FloatingPoint>(_ x: T, _ y: T) ->
}
return y
}

// Test control flow + tuple buffer.
// Verify that adjoint buffers are not allocated for address projections.

@differentiable
@_silgen_name("cond_tuple_var")
func cond_tuple_var(_ x: Float) -> Float {
// expected-warning @+1 {{variable 'y' was never mutated; consider changing to 'let' constant}}
var y = (x, x)
if x > 0 {
return y.0
}
return y.1
}
// CHECK-SIL-LABEL: sil hidden @AD__cond_tuple_var__adjoint_src_0_wrt_0 : $@convention(thin) (Float, @guaranteed _AD__cond_tuple_var_bb3__PB__src_0_wrt_0) -> Float {
// CHECK-SIL: bb0([[SEED:%.*]] : $Float, [[BB3_PB_STRUCT:%.*]] : $_AD__cond_tuple_var_bb3__PB__src_0_wrt_0):
// CHECK-SIL: [[BB3_PRED:%.*]] = struct_extract %1 : $_AD__cond_tuple_var_bb3__PB__src_0_wrt_0, #_AD__cond_tuple_var_bb3__PB__src_0_wrt_0.predecessor
// CHECK-SIL: copy_addr {{%.*}} to {{%.*}} : $*(Float, Float)
// CHECK-SIL-NOT: copy_addr {{%.*}} to {{%.*}} : $*Float
// CHECK-SIL: switch_enum [[BB3_PRED]] : $_AD__cond_tuple_var_bb3__Pred__src_0_wrt_0, case #_AD__cond_tuple_var_bb3__Pred__src_0_wrt_0.bb2!enumelt.1: bb3, case #_AD__cond_tuple_var_bb3__Pred__src_0_wrt_0.bb1!enumelt.1: bb1

// CHECK-SIL: bb1([[BB3_PRED1_TRAMP_PB_STRUCT:%.*]] : $_AD__cond_tuple_var_bb1__PB__src_0_wrt_0):
// CHECK-SIL: br bb2([[BB3_PRED1_TRAMP_PB_STRUCT]] : $_AD__cond_tuple_var_bb1__PB__src_0_wrt_0, {{%.*}} : $Float, {{%.*}}: $Float)

// CHECK-SIL: bb2([[BB1_PB_STRUCT:%.*]] : $_AD__cond_tuple_var_bb1__PB__src_0_wrt_0, {{%.*}} : $Float, {{%.*}} : $Float):
// CHECK-SIL: [[BB1_PRED:%.*]] = struct_extract [[BB1_PB_STRUCT]]
// CHECK-SIL: copy_addr {{%.*}} to {{%.*}} : $*(Float, Float)
// CHECK-SIL-NOT: copy_addr {{%.*}} to {{%.*}} : $*Float
// CHECK-SIL: [[BB1_PB_STRUCT_DATA:%.*]] = unchecked_enum_data [[BB1_PRED]]
// CHECK-SIL: br bb5([[BB1_PB_STRUCT_DATA]] : $_AD__cond_tuple_var_bb0__PB__src_0_wrt_0)

// CHECK-SIL: bb3([[BB3_PRED2_TRAMP_PB_STRUCT:%.*]] : $_AD__cond_tuple_var_bb2__PB__src_0_wrt_0):
// CHECK-SIL: br bb4([[BB3_PRED2_TRAMP_PB_STRUCT]] : $_AD__cond_tuple_var_bb2__PB__src_0_wrt_0, {{%.*}} : $Float, {{%.*}}: $Float)

// CHECK-SIL: bb4([[BB2_PB_STRUCT:%.*]] : $_AD__cond_tuple_var_bb2__PB__src_0_wrt_0, {{%.*}} : $Float, {{%.*}} : $Float):
// CHECK-SIL: [[BB2_PRED:%.*]] = struct_extract [[BB2_PB_STRUCT]]
// CHECK-SIL: copy_addr {{%.*}} to {{%.*}} : $*(Float, Float)
// CHECK-SIL-NOT: copy_addr {{%.*}} to {{%.*}} : $*Float
// CHECK-SIL: [[BB2_PB_STRUCT_DATA:%.*]] = unchecked_enum_data [[BB2_PRED]]
// CHECK-SIL: br bb6([[BB2_PB_STRUCT_DATA]] : $_AD__cond_tuple_var_bb0__PB__src_0_wrt_0)

// CHECK-SIL: bb5([[BB1_PRED0_TRAMP_PB_STRUCT:%.*]] : $_AD__cond_tuple_var_bb0__PB__src_0_wrt_0):
// CHECK-SIL: br bb7([[BB1_PRED0_TRAMP_PB_STRUCT]] : $_AD__cond_tuple_var_bb0__PB__src_0_wrt_0, {{%.*}} : $Float)

// CHECK-SIL: bb6([[BB2_PRED0_TRAMP_PB_STRUCT:%.*]] : $_AD__cond_tuple_var_bb0__PB__src_0_wrt_0):
// CHECK-SIL: br bb7([[BB2_PRED0_TRAMP_PB_STRUCT]] : $_AD__cond_tuple_var_bb0__PB__src_0_wrt_0, {{%.*}} : $Float)

// CHECK-SIL: bb7([[BB0_PB_STRUCT:%.*]] : $_AD__cond_tuple_var_bb0__PB__src_0_wrt_0, {{%.*}} : $Float):
// CHECK-SIL: return {{%.*}} : $Float