@@ -333,7 +333,7 @@ struct DifferentiationInvoker {
333
333
334
334
// / The `[differentiable]` attribute associated with the
335
335
// / `SILDifferentiableAttribute` case.
336
- SILDifferentiableAttr * silDifferentiableAttribute;
336
+ SILDifferentiableAttr *silDifferentiableAttribute;
337
337
Value (SILDifferentiableAttr *attr) : silDifferentiableAttribute (attr) {}
338
338
} value;
339
339
@@ -896,18 +896,18 @@ class ADContext {
896
896
// / Get or create an associated function index subset thunk from
897
897
// / `actualIndices` to `desiredIndices` for the given associated function
898
898
// / value and original function operand.
899
- // / Calls `getOrCreateLinearMapIndexSubsetThunk ` to thunk the linear map
900
- // / returned by the associated function.
899
+ // / Calls `getOrCreateSubsetParametersThunkForLinearMap ` to thunk the linear
900
+ // / map returned by the associated function.
901
901
std::pair<SILFunction *, SubstitutionMap>
902
- getOrCreateAssociatedFunctionIndexSubsetThunk (
902
+ getOrCreateSubsetParametersThunkForAssociatedFunction (
903
903
SILValue origFnOperand, SILValue assocFn,
904
904
AutoDiffAssociatedFunctionKind kind, SILAutoDiffIndices desiredIndices,
905
905
SILAutoDiffIndices actualIndices);
906
906
907
907
// / Get or create an associated function index subset thunk from
908
908
// / `actualIndices` to `desiredIndices` for the given associated function
909
909
// / value and original function operand.
910
- SILFunction *getOrCreateLinearMapIndexSubsetThunk (
910
+ SILFunction *getOrCreateSubsetParametersThunkForLinearMap (
911
911
SILFunction *assocFn, CanSILFunctionType linearMapType,
912
912
CanSILFunctionType targetType, AutoDiffAssociatedFunctionKind kind,
913
913
SILAutoDiffIndices desiredIndices, SILAutoDiffIndices actualIndices);
@@ -1986,6 +1986,41 @@ emitAssociatedFunctionReference(
1986
1986
return None;
1987
1987
}
1988
1988
1989
+ // / Emit a zero value into the given buffer access by calling
1990
+ // / `AdditiveArithmetic.zero`. The given type must conform to
1991
+ // / `AdditiveArithmetic`.
1992
+ static void emitZeroIntoBuffer (
1993
+ SILBuilder &builder, CanType type, SILValue bufferAccess,
1994
+ SILLocation loc) {
1995
+ auto &astCtx = builder.getASTContext ();
1996
+ auto *swiftMod = builder.getModule ().getSwiftModule ();
1997
+ auto &typeConverter = builder.getModule ().Types ;
1998
+ // Look up conformance to `AdditiveArithmetic`.
1999
+ auto *additiveArithmeticProto =
2000
+ astCtx.getProtocol (KnownProtocolKind::AdditiveArithmetic);
2001
+ auto confRef = swiftMod->lookupConformance (type, additiveArithmeticProto);
2002
+ assert (confRef.hasValue () && " Missing conformance to `AdditiveArithmetic`" );
2003
+ // Look up `AdditiveArithmetic.zero.getter`.
2004
+ auto zeroDeclLookup = additiveArithmeticProto->lookupDirect (astCtx.Id_zero );
2005
+ auto *zeroDecl = cast<VarDecl>(zeroDeclLookup.front ());
2006
+ assert (zeroDecl->isProtocolRequirement ());
2007
+ auto *accessorDecl = zeroDecl->getAccessor (AccessorKind::Get);
2008
+ SILDeclRef accessorDeclRef (accessorDecl, SILDeclRef::Kind::Func);
2009
+ auto silFnType = typeConverter.getConstantType (accessorDeclRef);
2010
+ // %wm = witness_method ...
2011
+ auto *getter = builder.createWitnessMethod (
2012
+ loc, type, *confRef, accessorDeclRef, silFnType);
2013
+ // %metatype = metatype $T
2014
+ auto metatypeType = CanMetatypeType::get (
2015
+ type, MetatypeRepresentation::Thick);
2016
+ auto metatype = builder.createMetatype (
2017
+ loc, SILType::getPrimitiveObjectType (metatypeType));
2018
+ auto subMap = SubstitutionMap::getProtocolSubstitutions (
2019
+ additiveArithmeticProto, type, *confRef);
2020
+ builder.createApply (loc, getter, subMap, {bufferAccess, metatype},
2021
+ /* isNonThrowing*/ false );
2022
+ }
2023
+
1989
2024
// ===----------------------------------------------------------------------===//
1990
2025
// Thunk helpers
1991
2026
// ===----------------------------------------------------------------------===//
@@ -4823,35 +4858,9 @@ void AdjointEmitter::emitZeroIndirect(CanType type, SILValue bufferAccess,
4823
4858
LookUpConformanceInModule (swiftMod));
4824
4859
assert (tangentSpace && " No tangent space for this type" );
4825
4860
switch (tangentSpace->getKind ()) {
4826
- case VectorSpace::Kind::Vector: {
4827
- // Look up conformance to `AdditiveArithmetic`.
4828
- auto *additiveArithmeticProto =
4829
- getASTContext ().getProtocol (KnownProtocolKind::AdditiveArithmetic);
4830
- auto confRef = swiftMod->lookupConformance (type, additiveArithmeticProto);
4831
- assert (confRef.hasValue () && " Missing conformance to `AdditiveArithmetic`" );
4832
- // Look up `AdditiveArithmetic.zero.getter`.
4833
- auto zeroDeclLookup =
4834
- additiveArithmeticProto->lookupDirect (getASTContext ().Id_zero );
4835
- auto *zeroDecl = cast<VarDecl>(zeroDeclLookup.front ());
4836
- assert (zeroDecl->isProtocolRequirement ());
4837
- auto *accessorDecl = zeroDecl->getAccessor (AccessorKind::Get);
4838
- SILDeclRef accessorDeclRef (accessorDecl, SILDeclRef::Kind::Func);
4839
- auto silFnType =
4840
- getContext ().getTypeConverter ().getConstantType (accessorDeclRef);
4841
- // %wm = witness_method ...
4842
- auto *getter = builder.createWitnessMethod (
4843
- loc, type, *confRef, accessorDeclRef, silFnType);
4844
- // %metatype = metatype $T
4845
- auto metatypeType = CanMetatypeType::get (
4846
- type, MetatypeRepresentation::Thick);
4847
- auto metatype = builder.createMetatype (
4848
- loc, SILType::getPrimitiveObjectType (metatypeType));
4849
- auto subMap = SubstitutionMap::getProtocolSubstitutions (
4850
- additiveArithmeticProto, type, *confRef);
4851
- builder.createApply (loc, getter, subMap, {bufferAccess, metatype},
4852
- /* isNonThrowing*/ false );
4861
+ case VectorSpace::Kind::Vector:
4862
+ emitZeroIntoBuffer (builder, type, bufferAccess, loc);
4853
4863
return ;
4854
- }
4855
4864
case VectorSpace::Kind::Tuple: {
4856
4865
auto tupleType = tangentSpace->getTuple ();
4857
4866
SmallVector<SILValue, 8 > zeroElements;
@@ -5173,8 +5182,8 @@ bool VJPEmitter::run() {
5173
5182
// Create entry BB and arguments.
5174
5183
auto *entry = vjp->createBasicBlock ();
5175
5184
createEntryArguments (vjp);
5176
- auto entryArgs = map< SmallVector<SILValue, 4 >>(
5177
- entry->getArguments (), [](SILArgument *arg) { return arg; } );
5185
+ SmallVector<SILValue, 4 > entryArgs (entry-> getArguments (). begin (),
5186
+ entry->getArguments (). end () );
5178
5187
5179
5188
auto vjpGenericSig = vjp->getLoweredFunctionType ()->getGenericSignature ();
5180
5189
auto *primalValueStructDecl =
@@ -5313,7 +5322,7 @@ static SILFunction* createJVP(
5313
5322
jvpGenericSig);
5314
5323
5315
5324
SILOptFunctionBuilder fb (context.getTransform ());
5316
- auto linkage = autodiff::getAutoDiffFunctionLinkage (
5325
+ auto linkage = autodiff::getAutoDiffAssociatedFunctionLinkage (
5317
5326
original->getLinkage (), isExported);
5318
5327
auto *jvp = fb.createFunction (linkage, jvpName, jvpType, jvpGenericEnv,
5319
5328
original->getLocation (), original->isBare (),
@@ -5372,7 +5381,7 @@ static SILFunction *createEmptyVJP(
5372
5381
LookUpConformanceInModule (module .getSwiftModule ()), vjpGenericSig);
5373
5382
5374
5383
SILOptFunctionBuilder fb (context.getTransform ());
5375
- auto linkage = autodiff::getAutoDiffFunctionLinkage (
5384
+ auto linkage = autodiff::getAutoDiffAssociatedFunctionLinkage (
5376
5385
original->getLinkage (), isExported);
5377
5386
auto *vjp = fb.createFunction (linkage, vjpName, vjpType, vjpGenericEnv,
5378
5387
original->getLocation (), original->isBare (),
@@ -5475,12 +5484,10 @@ class Differentiation : public SILModuleTransform {
5475
5484
} // end anonymous namespace
5476
5485
5477
5486
SILFunction *
5478
- ADContext::getOrCreateLinearMapIndexSubsetThunk (
5487
+ ADContext::getOrCreateSubsetParametersThunkForLinearMap (
5479
5488
SILFunction *parentThunk, CanSILFunctionType linearMapType,
5480
5489
CanSILFunctionType targetType, AutoDiffAssociatedFunctionKind kind,
5481
5490
SILAutoDiffIndices desiredIndices, SILAutoDiffIndices actualIndices) {
5482
- auto &astCtx = getASTContext ();
5483
-
5484
5491
SubstitutionMap interfaceSubs = parentThunk->getForwardingSubstitutionMap ();
5485
5492
GenericEnvironment *genericEnv = parentThunk->getGenericEnvironment ();
5486
5493
auto thunkType = buildThunkType (
@@ -5536,53 +5543,26 @@ ADContext::getOrCreateLinearMapIndexSubsetThunk(
5536
5543
LookUpConformanceInModule (swiftMod));
5537
5544
assert (tangentSpace && " No tangent space for this type" );
5538
5545
switch (tangentSpace->getKind ()) {
5539
- case VectorSpace::Kind::Vector: {
5540
- auto *buff = builder.createAllocStack (loc, zeroSILObjType);
5541
- localAllocations.push_back (buff);
5542
- // Look up conformance to `AdditiveArithmetic`.
5543
- auto *additiveArithmeticProto =
5544
- astCtx.getProtocol (KnownProtocolKind::AdditiveArithmetic);
5545
- auto confRef = swiftMod->lookupConformance (
5546
- zeroType, additiveArithmeticProto);
5547
- assert (confRef.hasValue () &&
5548
- " Missing conformance to `AdditiveArithmetic`" );
5549
- // Look up `AdditiveArithmetic.zero.getter`.
5550
- auto zeroDeclLookup =
5551
- additiveArithmeticProto->lookupDirect (astCtx.Id_zero );
5552
- auto *zeroDecl = cast<VarDecl>(zeroDeclLookup.front ());
5553
- assert (zeroDecl->isProtocolRequirement ());
5554
- auto *accessorDecl = zeroDecl->getAccessor (AccessorKind::Get);
5555
- SILDeclRef accessorDeclRef (accessorDecl, SILDeclRef::Kind::Func);
5556
- auto silFnType =
5557
- getTypeConverter ().getConstantType (accessorDeclRef);
5558
- // %wm = witness_method ...
5559
- auto *getter = builder.createWitnessMethod (
5560
- loc, zeroType, *confRef, accessorDeclRef, silFnType);
5561
- // %metatype = metatype $T
5562
- auto metatypeType = CanMetatypeType::get (
5563
- zeroType, MetatypeRepresentation::Thick);
5564
- auto metatype = builder.createMetatype (
5565
- loc, SILType::getPrimitiveObjectType (metatypeType));
5566
- auto subMap = SubstitutionMap::getProtocolSubstitutions (
5567
- additiveArithmeticProto, zeroType, *confRef);
5568
- builder.createApply (loc, getter, subMap, {buff, metatype},
5569
- /* isNonThrowing*/ false );
5570
- if (zeroSILType.isAddress ())
5571
- arguments.push_back (buff);
5572
- else {
5573
- auto loq = getBufferLOQ (buff->getType ().getASTType (), *thunk);
5574
- auto *arg = builder.createLoad (loc, buff, loq);
5575
- arguments.push_back (arg);
5576
- }
5577
- break ;
5578
- }
5579
- case VectorSpace::Kind::Tuple: {
5580
- llvm_unreachable (
5581
- " Unimplemented: Handle zero initialization for tuples" );
5546
+ case VectorSpace::Kind::Vector: {
5547
+ auto *buf = builder.createAllocStack (loc, zeroSILObjType);
5548
+ localAllocations.push_back (buf);
5549
+ emitZeroIntoBuffer (builder, zeroType, buf, loc);
5550
+ if (zeroSILType.isAddress ())
5551
+ arguments.push_back (buf);
5552
+ else {
5553
+ auto loq = getBufferLOQ (buf->getType ().getASTType (), *thunk);
5554
+ auto *arg = builder.createLoad (loc, buf, loq);
5555
+ arguments.push_back (arg);
5582
5556
}
5583
- case VectorSpace::Kind::Function:
5584
- llvm_unreachable (
5585
- " Unimplemented: Emit thunks for abstracting zero initialization" );
5557
+ break ;
5558
+ }
5559
+ case VectorSpace::Kind::Tuple: {
5560
+ llvm_unreachable (
5561
+ " Unimplemented: Handle zero initialization for tuples" );
5562
+ }
5563
+ case VectorSpace::Kind::Function:
5564
+ llvm_unreachable (
5565
+ " Unimplemented: Emit thunks for abstracting zero initialization" );
5586
5566
}
5587
5567
};
5588
5568
@@ -5698,7 +5678,7 @@ ADContext::getOrCreateLinearMapIndexSubsetThunk(
5698
5678
}
5699
5679
5700
5680
std::pair<SILFunction *, SubstitutionMap>
5701
- ADContext::getOrCreateAssociatedFunctionIndexSubsetThunk (
5681
+ ADContext::getOrCreateSubsetParametersThunkForAssociatedFunction (
5702
5682
SILValue origFnOperand, SILValue assocFn,
5703
5683
AutoDiffAssociatedFunctionKind kind, SILAutoDiffIndices desiredIndices,
5704
5684
SILAutoDiffIndices actualIndices) {
@@ -5819,7 +5799,7 @@ ADContext::getOrCreateAssociatedFunctionIndexSubsetThunk(
5819
5799
auto linearMapTargetType = targetType->getResults ().back ().getSILStorageType ()
5820
5800
.castTo <SILFunctionType>();
5821
5801
5822
- auto *innerThunk = getOrCreateLinearMapIndexSubsetThunk (
5802
+ auto *innerThunk = getOrCreateSubsetParametersThunkForLinearMap (
5823
5803
thunk, linearMapType, linearMapTargetType, kind,
5824
5804
desiredIndices, actualIndices);
5825
5805
@@ -5903,8 +5883,8 @@ SILValue ADContext::promoteToDifferentiableFunction(
5903
5883
return nullptr ;
5904
5884
5905
5885
auto *newThunkRef = builder.createFunctionRef (loc, newThunk);
5906
- auto arguments = map< SmallVector<SILValue, 8 >>(
5907
- ai->getArguments (), [](SILValue v) { return v; } );
5886
+ SmallVector<SILValue, 8 > arguments (ai-> getArguments (). begin (),
5887
+ ai->getArguments (). end () );
5908
5888
auto *newApply = builder.createApply (
5909
5889
ai->getLoc (), newThunkRef, ai->getSubstitutionMap (), arguments,
5910
5890
ai->isNonThrowing ());
@@ -5947,7 +5927,7 @@ SILValue ADContext::promoteToDifferentiableFunction(
5947
5927
SILFunction *thunk;
5948
5928
SubstitutionMap interfaceSubs;
5949
5929
std::tie (thunk, interfaceSubs) =
5950
- getOrCreateAssociatedFunctionIndexSubsetThunk (
5930
+ getOrCreateSubsetParametersThunkForAssociatedFunction (
5951
5931
origFnOperand, assocFn, assocFnKind, desiredIndices,
5952
5932
actualIndices);
5953
5933
auto *thunkFRI = builder.createFunctionRef (loc, thunk);
@@ -6078,8 +6058,8 @@ void Differentiation::run() {
6078
6058
context.getAutoDiffFunctionInsts ().pop_back ();
6079
6059
// Skip instructions that have been set to nullptr by
6080
6060
// `processAutoDiffFunctionInst`.
6081
- if (adfi)
6082
- errorOccurred |= context.processAutoDiffFunctionInst (adfi);
6061
+ if (! adfi) continue ;
6062
+ errorOccurred |= context.processAutoDiffFunctionInst (adfi);
6083
6063
}
6084
6064
6085
6065
// If any error occurred while processing `[differentiable]` attributes or
0 commit comments