@@ -3360,13 +3360,23 @@ static CanSILFunctionType buildWithoutActuallyEscapingThunkType(
3360
3360
// SWIFT_ENABLE_TENSORFLOW
3361
3361
// / Given a value, extracts all elements to `result` from this value if it's a
3362
3362
// / tuple. Otherwise, add this value directly to `result`.
3363
- static void extractAllElements (SILValue val, SILBuilder &builder,
3363
+ static void extractAllElements (SILValue val, SILLocation loc,
3364
+ SILBuilder &builder,
3364
3365
SmallVectorImpl<SILValue> &result) {
3365
- if (auto tupleType = val->getType ().getAs <TupleType>())
3366
- for (auto i : range (tupleType->getNumElements ()))
3367
- result.push_back (builder.createTupleExtract (val.getLoc (), val, i));
3368
- else
3366
+ auto &fn = builder.getFunction ();
3367
+ auto tupleType = val->getType ().getAs <TupleType>();
3368
+ if (!tupleType) {
3369
3369
result.push_back (val);
3370
+ return ;
3371
+ }
3372
+ if (!fn.hasOwnership ()) {
3373
+ for (auto i : range (tupleType->getNumElements ()))
3374
+ result.push_back (builder.createTupleExtract (loc, val, i));
3375
+ return ;
3376
+ }
3377
+ if (tupleType->getNumElements () == 0 )
3378
+ return ;
3379
+ builder.emitDestructureValueOperation (loc, val, result);
3370
3380
}
3371
3381
3372
3382
// SWIFT_ENABLE_TENSORFLOW
@@ -3408,17 +3418,18 @@ SILGenFunction::getThunkedAutoDiffLinearMap(
3408
3418
thunkType, fromInterfaceType, toInterfaceType,
3409
3419
Type (), getModule ().getSwiftModule ());
3410
3420
// TODO(TF-685): Use principled thunk mangling.
3411
- if (reorderSelf) {
3412
- switch (assocFnKind) {
3413
- case AutoDiffAssociatedFunctionKind::JVP:
3414
- name += " _differential" ;
3415
- break ;
3416
- case AutoDiffAssociatedFunctionKind::VJP:
3417
- name += " _pullback" ;
3418
- break ;
3419
- }
3420
- name = " AD__" + name + " _self_reordering_thunk" ;
3421
+ switch (assocFnKind) {
3422
+ case AutoDiffAssociatedFunctionKind::JVP:
3423
+ name += " _differential" ;
3424
+ break ;
3425
+ case AutoDiffAssociatedFunctionKind::VJP:
3426
+ name += " _pullback" ;
3427
+ break ;
3421
3428
}
3429
+ name = " AD__" + name;
3430
+ if (reorderSelf)
3431
+ name += " _self_reordering" ;
3432
+ name += " _thunk" ;
3422
3433
3423
3434
// Create the thunk.
3424
3435
auto loc = F.getLocation ();
@@ -3441,7 +3452,6 @@ SILGenFunction::getThunkedAutoDiffLinearMap(
3441
3452
if (!thunk->empty ())
3442
3453
return getThunkedResult ();
3443
3454
thunk->setGenericEnvironment (genericEnv);
3444
- thunk->setOwnershipEliminated ();
3445
3455
3446
3456
SILGenFunction thunkSGF (SGM, *thunk, FunctionDC);
3447
3457
SmallVector<ManagedValue, 4 > params;
@@ -3452,8 +3462,10 @@ SILGenFunction::getThunkedAutoDiffLinearMap(
3452
3462
SILFunctionConventions toConv (toType, getModule ());
3453
3463
assert (toConv.useLoweredAddresses ());
3454
3464
3455
- SmallVector<SILArgument *, 4 > thunkArguments (thunk->getArguments ().begin (),
3456
- thunk->getArguments ().end ());
3465
+ SmallVector<ManagedValue, 4 > thunkArguments;
3466
+ for (auto *indRes : thunkIndirectResults)
3467
+ thunkArguments.push_back (ManagedValue::forLValue (indRes));
3468
+ thunkArguments.append (params.begin (), params.end ());
3457
3469
SmallVector<SILParameterInfo, 4 > toParameters (toConv.getParameters ().begin (),
3458
3470
toConv.getParameters ().end ());
3459
3471
SmallVector<SILResultInfo, 4 > toResults (toConv.getResults ().begin (),
@@ -3472,17 +3484,13 @@ SILGenFunction::getThunkedAutoDiffLinearMap(
3472
3484
thunkArguments.begin () + numIndirectResults - 1 ,
3473
3485
thunkArguments.begin () + numIndirectResults);
3474
3486
}
3475
- std::rotate (toResults.begin (),
3476
- toResults.end () - 1 ,
3477
- toResults.end ());
3487
+ std::rotate (toResults.begin (), toResults.end () - 1 , toResults.end ());
3478
3488
}
3479
3489
if (reorderSelf && assocFnKind == AutoDiffAssociatedFunctionKind::JVP &&
3480
3490
thunkArguments.size () > 1 ) {
3481
3491
std::rotate (thunkArguments.begin () + numIndirectResults,
3482
- thunkArguments.end () - 2 ,
3483
- thunkArguments.end () - 1 );
3484
- std::rotate (toParameters.begin (),
3485
- toParameters.end () - 1 ,
3492
+ thunkArguments.end () - 2 , thunkArguments.end () - 1 );
3493
+ std::rotate (toParameters.begin (), toParameters.end () - 1 ,
3486
3494
toParameters.end ());
3487
3495
}
3488
3496
@@ -3506,7 +3514,8 @@ SILGenFunction::getThunkedAutoDiffLinearMap(
3506
3514
SmallVector<SILValue, 4 > arguments;
3507
3515
auto toArgIter = thunkArguments.begin ();
3508
3516
auto useNextArgument = [&]() {
3509
- arguments.push_back (*toArgIter++);
3517
+ auto nextArgument = *toArgIter++;
3518
+ arguments.push_back (nextArgument.getValue ());
3510
3519
};
3511
3520
3512
3521
SmallVector<AllocStackInst *, 4 > localAllocations;
@@ -3555,19 +3564,17 @@ SILGenFunction::getThunkedAutoDiffLinearMap(
3555
3564
if (!paramTy.hasArchetype ())
3556
3565
paramTy = thunk->mapTypeIntoContext (paramTy);
3557
3566
assert (paramTy.isAddress ());
3558
- auto * toArg = *toArgIter++;
3567
+ auto toArg = ( *toArgIter++). getValue () ;
3559
3568
auto *buf = createAllocStack (toArg->getType ());
3560
- thunkSGF.B .createStore (
3561
- loc, toArg, buf, StoreOwnershipQualifier::Unqualified);
3569
+ thunkSGF.B .createStore (loc, toArg, buf, StoreOwnershipQualifier::Init);
3562
3570
arguments.push_back (buf);
3563
3571
continue ;
3564
3572
}
3565
3573
// Convert direct parameter to indirect parameter.
3566
3574
assert (toParam.isFormalIndirect ());
3567
- auto *toArg = *toArgIter++;
3568
- auto *load =
3569
- thunkSGF.B .createLoad (loc, toArg, LoadOwnershipQualifier::Unqualified);
3570
- arguments.push_back (load);
3575
+ auto toArg = (*toArgIter++).getValue ();
3576
+ auto load = thunkSGF.emitManagedLoadBorrow (loc, toArg);
3577
+ arguments.push_back (load.getValue ());
3571
3578
}
3572
3579
3573
3580
auto *linearMapArg = thunk->getArgumentsWithoutIndirectResults ().back ();
@@ -3578,7 +3585,7 @@ SILGenFunction::getThunkedAutoDiffLinearMap(
3578
3585
SmallVector<SILValue, 4 > results;
3579
3586
// Extract all direct results.
3580
3587
SmallVector<SILValue, 4 > directResults;
3581
- extractAllElements (apply, thunkSGF.B , directResults);
3588
+ extractAllElements (apply, loc, thunkSGF.B , directResults);
3582
3589
3583
3590
// Handle self reordering.
3584
3591
// For pullbacks: rotate direct results if self is direct.
@@ -3629,11 +3636,16 @@ SILGenFunction::getThunkedAutoDiffLinearMap(
3629
3636
SILType resultTy = toConv.getSILType (toRes);
3630
3637
assert (resultTy.isAddress ());
3631
3638
auto indRes = *toIndResultsIter++;
3632
- thunkSGF.B .createStore (loc, *fromDirResultsIter++, indRes,
3633
- StoreOwnershipQualifier::Unqualified);
3639
+ thunkSGF.emitSemanticStore (loc, *fromDirResultsIter++, indRes,
3640
+ thunkSGF.getTypeLowering (resultTy),
3641
+ IsInitialization);
3634
3642
}
3635
3643
auto retVal = joinElements (results, thunkSGF.B , loc);
3636
3644
3645
+ // Emit cleanups.
3646
+ thunkSGF.Cleanups .emitCleanupsForReturn (
3647
+ CleanupLocation::get (loc), NotForUnwind);
3648
+
3637
3649
// Deallocate local allocations.
3638
3650
for (auto *alloc : reversed (localAllocations))
3639
3651
thunkSGF.B .createDeallocStack (loc, alloc);
@@ -3758,8 +3770,8 @@ SILGenModule::getOrCreateAutoDiffAssociatedFunctionThunk(
3758
3770
thunk->getForwardingSubstitutionMap (), arguments);
3759
3771
3760
3772
SmallVector<SILValue, 8 > directResults;
3761
- extractAllElements (apply, thunkSGF.B , directResults);
3762
- auto linearMap = ManagedValue::forBorrowedObjectRValue (directResults.back ());
3773
+ extractAllElements (apply, loc, thunkSGF.B , directResults);
3774
+ auto linearMap = ManagedValue::forUnmanaged (directResults.back ());
3763
3775
auto linearMapFnType = linearMap.getType ().castTo <SILFunctionType>();
3764
3776
auto targetLinearMapFnType = thunk->mapTypeIntoContext (
3765
3777
origAssocFnType->getResults ().back ().getSILStorageType ())
0 commit comments