Skip to content

Commit 646096e

Browse files
committed
Address review comments.
- Create common function `emitZeroIntoBuffer`. - Shared by `AdjointEmitter::emitZeroIndirect` and `buildZeroArgument` lambda. - Minor naming/style changes.
1 parent 10522e5 commit 646096e

File tree

4 files changed

+79
-99
lines changed

4 files changed

+79
-99
lines changed

include/swift/AST/AutoDiff.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -578,12 +578,12 @@ bool getBuiltinAutoDiffApplyConfig(StringRef operationName,
578578
unsigned &arity, unsigned &order,
579579
bool &rethrows);
580580

581-
/// Computes the correct linkage for associated functions given the linkage of
581+
/// Computes the correct linkage for an associated function given the linkage of
582582
/// the original function. If the original linkage is not external and
583583
/// `isAssocFnExported` is true, use the original function's linkage. Otherwise,
584584
/// return hidden linkage.
585-
SILLinkage getAutoDiffFunctionLinkage(SILLinkage originalLinkage,
586-
bool isAssocFnExported);
585+
SILLinkage getAutoDiffAssociatedFunctionLinkage(SILLinkage originalLinkage,
586+
bool isAssocFnExported);
587587

588588
} // end namespace autodiff
589589

lib/AST/AutoDiff.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,8 @@ bool autodiff::getBuiltinAutoDiffApplyConfig(
8989
return operationName.empty();
9090
}
9191

92-
SILLinkage autodiff::getAutoDiffFunctionLinkage(SILLinkage originalLinkage,
93-
bool isAssocFnExported) {
92+
SILLinkage autodiff::getAutoDiffAssociatedFunctionLinkage(
93+
SILLinkage originalLinkage, bool isAssocFnExported) {
9494
// If the original is defined externally, then the AD pass is just generating
9595
// associated functions for use in the current module and therefore these
9696
// associated functions should not be visible outside the module.

lib/SILGen/SILGenPoly.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3530,7 +3530,7 @@ SILGenModule::getOrCreateAutoDiffAssociatedFunctionReorderingThunk(
35303530

35313531
auto loc = assocFn->getLocation();
35323532
SILGenFunctionBuilder fb(*this);
3533-
auto linkage = autodiff::getAutoDiffFunctionLinkage(
3533+
auto linkage = autodiff::getAutoDiffAssociatedFunctionLinkage(
35343534
original->getLinkage(), /*isAssocFnExported*/ true);
35353535
auto *thunk = fb.getOrCreateFunction(
35363536
loc, name, linkage, targetType, IsBare, IsNotTransparent,

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 73 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,7 @@ struct DifferentiationInvoker {
333333

334334
/// The `[differentiable]` attribute associated with the
335335
/// `SILDifferentiableAttribute` case.
336-
SILDifferentiableAttr * silDifferentiableAttribute;
336+
SILDifferentiableAttr *silDifferentiableAttribute;
337337
Value(SILDifferentiableAttr *attr) : silDifferentiableAttribute(attr) {}
338338
} value;
339339

@@ -896,18 +896,18 @@ class ADContext {
896896
/// Get or create an associated function index subset thunk from
897897
/// `actualIndices` to `desiredIndices` for the given associated function
898898
/// value and original function operand.
899-
/// Calls `getOrCreateLinearMapIndexSubsetThunk` to thunk the linear map
900-
/// returned by the associated function.
899+
/// Calls `getOrCreateSubsetParametersThunkForLinearMap` to thunk the linear
900+
/// map returned by the associated function.
901901
std::pair<SILFunction *, SubstitutionMap>
902-
getOrCreateAssociatedFunctionIndexSubsetThunk(
902+
getOrCreateSubsetParametersThunkForAssociatedFunction(
903903
SILValue origFnOperand, SILValue assocFn,
904904
AutoDiffAssociatedFunctionKind kind, SILAutoDiffIndices desiredIndices,
905905
SILAutoDiffIndices actualIndices);
906906

907907
/// Get or create an associated function index subset thunk from
908908
/// `actualIndices` to `desiredIndices` for the given associated function
909909
/// value and original function operand.
910-
SILFunction *getOrCreateLinearMapIndexSubsetThunk(
910+
SILFunction *getOrCreateSubsetParametersThunkForLinearMap(
911911
SILFunction *assocFn, CanSILFunctionType linearMapType,
912912
CanSILFunctionType targetType, AutoDiffAssociatedFunctionKind kind,
913913
SILAutoDiffIndices desiredIndices, SILAutoDiffIndices actualIndices);
@@ -1986,6 +1986,41 @@ emitAssociatedFunctionReference(
19861986
return None;
19871987
}
19881988

1989+
/// Emit a zero value into the given buffer access by calling
1990+
/// `AdditiveArithmetic.zero`. The given type must conform to
1991+
/// `AdditiveArithmetic`.
1992+
static void emitZeroIntoBuffer(
1993+
SILBuilder &builder, CanType type, SILValue bufferAccess,
1994+
SILLocation loc) {
1995+
auto &astCtx = builder.getASTContext();
1996+
auto *swiftMod = builder.getModule().getSwiftModule();
1997+
auto &typeConverter = builder.getModule().Types;
1998+
// Look up conformance to `AdditiveArithmetic`.
1999+
auto *additiveArithmeticProto =
2000+
astCtx.getProtocol(KnownProtocolKind::AdditiveArithmetic);
2001+
auto confRef = swiftMod->lookupConformance(type, additiveArithmeticProto);
2002+
assert(confRef.hasValue() && "Missing conformance to `AdditiveArithmetic`");
2003+
// Look up `AdditiveArithmetic.zero.getter`.
2004+
auto zeroDeclLookup = additiveArithmeticProto->lookupDirect(astCtx.Id_zero);
2005+
auto *zeroDecl = cast<VarDecl>(zeroDeclLookup.front());
2006+
assert(zeroDecl->isProtocolRequirement());
2007+
auto *accessorDecl = zeroDecl->getAccessor(AccessorKind::Get);
2008+
SILDeclRef accessorDeclRef(accessorDecl, SILDeclRef::Kind::Func);
2009+
auto silFnType = typeConverter.getConstantType(accessorDeclRef);
2010+
// %wm = witness_method ...
2011+
auto *getter = builder.createWitnessMethod(
2012+
loc, type, *confRef, accessorDeclRef, silFnType);
2013+
// %metatype = metatype $T
2014+
auto metatypeType = CanMetatypeType::get(
2015+
type, MetatypeRepresentation::Thick);
2016+
auto metatype = builder.createMetatype(
2017+
loc, SILType::getPrimitiveObjectType(metatypeType));
2018+
auto subMap = SubstitutionMap::getProtocolSubstitutions(
2019+
additiveArithmeticProto, type, *confRef);
2020+
builder.createApply(loc, getter, subMap, {bufferAccess, metatype},
2021+
/*isNonThrowing*/ false);
2022+
}
2023+
19892024
//===----------------------------------------------------------------------===//
19902025
// Thunk helpers
19912026
//===----------------------------------------------------------------------===//
@@ -4823,35 +4858,9 @@ void AdjointEmitter::emitZeroIndirect(CanType type, SILValue bufferAccess,
48234858
LookUpConformanceInModule(swiftMod));
48244859
assert(tangentSpace && "No tangent space for this type");
48254860
switch (tangentSpace->getKind()) {
4826-
case VectorSpace::Kind::Vector: {
4827-
// Look up conformance to `AdditiveArithmetic`.
4828-
auto *additiveArithmeticProto =
4829-
getASTContext().getProtocol(KnownProtocolKind::AdditiveArithmetic);
4830-
auto confRef = swiftMod->lookupConformance(type, additiveArithmeticProto);
4831-
assert(confRef.hasValue() && "Missing conformance to `AdditiveArithmetic`");
4832-
// Look up `AdditiveArithmetic.zero.getter`.
4833-
auto zeroDeclLookup =
4834-
additiveArithmeticProto->lookupDirect(getASTContext().Id_zero);
4835-
auto *zeroDecl = cast<VarDecl>(zeroDeclLookup.front());
4836-
assert(zeroDecl->isProtocolRequirement());
4837-
auto *accessorDecl = zeroDecl->getAccessor(AccessorKind::Get);
4838-
SILDeclRef accessorDeclRef(accessorDecl, SILDeclRef::Kind::Func);
4839-
auto silFnType =
4840-
getContext().getTypeConverter().getConstantType(accessorDeclRef);
4841-
// %wm = witness_method ...
4842-
auto *getter = builder.createWitnessMethod(
4843-
loc, type, *confRef, accessorDeclRef, silFnType);
4844-
// %metatype = metatype $T
4845-
auto metatypeType = CanMetatypeType::get(
4846-
type, MetatypeRepresentation::Thick);
4847-
auto metatype = builder.createMetatype(
4848-
loc, SILType::getPrimitiveObjectType(metatypeType));
4849-
auto subMap = SubstitutionMap::getProtocolSubstitutions(
4850-
additiveArithmeticProto, type, *confRef);
4851-
builder.createApply(loc, getter, subMap, {bufferAccess, metatype},
4852-
/*isNonThrowing*/ false);
4861+
case VectorSpace::Kind::Vector:
4862+
emitZeroIntoBuffer(builder, type, bufferAccess, loc);
48534863
return;
4854-
}
48554864
case VectorSpace::Kind::Tuple: {
48564865
auto tupleType = tangentSpace->getTuple();
48574866
SmallVector<SILValue, 8> zeroElements;
@@ -5173,8 +5182,8 @@ bool VJPEmitter::run() {
51735182
// Create entry BB and arguments.
51745183
auto *entry = vjp->createBasicBlock();
51755184
createEntryArguments(vjp);
5176-
auto entryArgs = map<SmallVector<SILValue, 4>>(
5177-
entry->getArguments(), [](SILArgument *arg) { return arg; });
5185+
SmallVector<SILValue, 4> entryArgs(entry->getArguments().begin(),
5186+
entry->getArguments().end());
51785187

51795188
auto vjpGenericSig = vjp->getLoweredFunctionType()->getGenericSignature();
51805189
auto *primalValueStructDecl =
@@ -5313,7 +5322,7 @@ static SILFunction* createJVP(
53135322
jvpGenericSig);
53145323

53155324
SILOptFunctionBuilder fb(context.getTransform());
5316-
auto linkage = autodiff::getAutoDiffFunctionLinkage(
5325+
auto linkage = autodiff::getAutoDiffAssociatedFunctionLinkage(
53175326
original->getLinkage(), isExported);
53185327
auto *jvp = fb.createFunction(linkage, jvpName, jvpType, jvpGenericEnv,
53195328
original->getLocation(), original->isBare(),
@@ -5372,7 +5381,7 @@ static SILFunction *createEmptyVJP(
53725381
LookUpConformanceInModule(module.getSwiftModule()), vjpGenericSig);
53735382

53745383
SILOptFunctionBuilder fb(context.getTransform());
5375-
auto linkage = autodiff::getAutoDiffFunctionLinkage(
5384+
auto linkage = autodiff::getAutoDiffAssociatedFunctionLinkage(
53765385
original->getLinkage(), isExported);
53775386
auto *vjp = fb.createFunction(linkage, vjpName, vjpType, vjpGenericEnv,
53785387
original->getLocation(), original->isBare(),
@@ -5475,12 +5484,10 @@ class Differentiation : public SILModuleTransform {
54755484
} // end anonymous namespace
54765485

54775486
SILFunction *
5478-
ADContext::getOrCreateLinearMapIndexSubsetThunk(
5487+
ADContext::getOrCreateSubsetParametersThunkForLinearMap(
54795488
SILFunction *parentThunk, CanSILFunctionType linearMapType,
54805489
CanSILFunctionType targetType, AutoDiffAssociatedFunctionKind kind,
54815490
SILAutoDiffIndices desiredIndices, SILAutoDiffIndices actualIndices) {
5482-
auto &astCtx = getASTContext();
5483-
54845491
SubstitutionMap interfaceSubs = parentThunk->getForwardingSubstitutionMap();
54855492
GenericEnvironment *genericEnv = parentThunk->getGenericEnvironment();
54865493
auto thunkType = buildThunkType(
@@ -5536,53 +5543,26 @@ ADContext::getOrCreateLinearMapIndexSubsetThunk(
55365543
LookUpConformanceInModule(swiftMod));
55375544
assert(tangentSpace && "No tangent space for this type");
55385545
switch (tangentSpace->getKind()) {
5539-
case VectorSpace::Kind::Vector: {
5540-
auto *buff = builder.createAllocStack(loc, zeroSILObjType);
5541-
localAllocations.push_back(buff);
5542-
// Look up conformance to `AdditiveArithmetic`.
5543-
auto *additiveArithmeticProto =
5544-
astCtx.getProtocol(KnownProtocolKind::AdditiveArithmetic);
5545-
auto confRef = swiftMod->lookupConformance(
5546-
zeroType, additiveArithmeticProto);
5547-
assert(confRef.hasValue() &&
5548-
"Missing conformance to `AdditiveArithmetic`");
5549-
// Look up `AdditiveArithmetic.zero.getter`.
5550-
auto zeroDeclLookup =
5551-
additiveArithmeticProto->lookupDirect(astCtx.Id_zero);
5552-
auto *zeroDecl = cast<VarDecl>(zeroDeclLookup.front());
5553-
assert(zeroDecl->isProtocolRequirement());
5554-
auto *accessorDecl = zeroDecl->getAccessor(AccessorKind::Get);
5555-
SILDeclRef accessorDeclRef(accessorDecl, SILDeclRef::Kind::Func);
5556-
auto silFnType =
5557-
getTypeConverter().getConstantType(accessorDeclRef);
5558-
// %wm = witness_method ...
5559-
auto *getter = builder.createWitnessMethod(
5560-
loc, zeroType, *confRef, accessorDeclRef, silFnType);
5561-
// %metatype = metatype $T
5562-
auto metatypeType = CanMetatypeType::get(
5563-
zeroType, MetatypeRepresentation::Thick);
5564-
auto metatype = builder.createMetatype(
5565-
loc, SILType::getPrimitiveObjectType(metatypeType));
5566-
auto subMap = SubstitutionMap::getProtocolSubstitutions(
5567-
additiveArithmeticProto, zeroType, *confRef);
5568-
builder.createApply(loc, getter, subMap, {buff, metatype},
5569-
/*isNonThrowing*/ false);
5570-
if (zeroSILType.isAddress())
5571-
arguments.push_back(buff);
5572-
else {
5573-
auto loq = getBufferLOQ(buff->getType().getASTType(), *thunk);
5574-
auto *arg = builder.createLoad(loc, buff, loq);
5575-
arguments.push_back(arg);
5576-
}
5577-
break;
5578-
}
5579-
case VectorSpace::Kind::Tuple: {
5580-
llvm_unreachable(
5581-
"Unimplemented: Handle zero initialization for tuples");
5546+
case VectorSpace::Kind::Vector: {
5547+
auto *buf = builder.createAllocStack(loc, zeroSILObjType);
5548+
localAllocations.push_back(buf);
5549+
emitZeroIntoBuffer(builder, zeroType, buf, loc);
5550+
if (zeroSILType.isAddress())
5551+
arguments.push_back(buf);
5552+
else {
5553+
auto loq = getBufferLOQ(buf->getType().getASTType(), *thunk);
5554+
auto *arg = builder.createLoad(loc, buf, loq);
5555+
arguments.push_back(arg);
55825556
}
5583-
case VectorSpace::Kind::Function:
5584-
llvm_unreachable(
5585-
"Unimplemented: Emit thunks for abstracting zero initialization");
5557+
break;
5558+
}
5559+
case VectorSpace::Kind::Tuple: {
5560+
llvm_unreachable(
5561+
"Unimplemented: Handle zero initialization for tuples");
5562+
}
5563+
case VectorSpace::Kind::Function:
5564+
llvm_unreachable(
5565+
"Unimplemented: Emit thunks for abstracting zero initialization");
55865566
}
55875567
};
55885568

@@ -5698,7 +5678,7 @@ ADContext::getOrCreateLinearMapIndexSubsetThunk(
56985678
}
56995679

57005680
std::pair<SILFunction *, SubstitutionMap>
5701-
ADContext::getOrCreateAssociatedFunctionIndexSubsetThunk(
5681+
ADContext::getOrCreateSubsetParametersThunkForAssociatedFunction(
57025682
SILValue origFnOperand, SILValue assocFn,
57035683
AutoDiffAssociatedFunctionKind kind, SILAutoDiffIndices desiredIndices,
57045684
SILAutoDiffIndices actualIndices) {
@@ -5819,7 +5799,7 @@ ADContext::getOrCreateAssociatedFunctionIndexSubsetThunk(
58195799
auto linearMapTargetType = targetType->getResults().back().getSILStorageType()
58205800
.castTo<SILFunctionType>();
58215801

5822-
auto *innerThunk = getOrCreateLinearMapIndexSubsetThunk(
5802+
auto *innerThunk = getOrCreateSubsetParametersThunkForLinearMap(
58235803
thunk, linearMapType, linearMapTargetType, kind,
58245804
desiredIndices, actualIndices);
58255805

@@ -5903,8 +5883,8 @@ SILValue ADContext::promoteToDifferentiableFunction(
59035883
return nullptr;
59045884

59055885
auto *newThunkRef = builder.createFunctionRef(loc, newThunk);
5906-
auto arguments = map<SmallVector<SILValue, 8>>(
5907-
ai->getArguments(), [](SILValue v) { return v; });
5886+
SmallVector<SILValue, 8> arguments(ai->getArguments().begin(),
5887+
ai->getArguments().end());
59085888
auto *newApply = builder.createApply(
59095889
ai->getLoc(), newThunkRef, ai->getSubstitutionMap(), arguments,
59105890
ai->isNonThrowing());
@@ -5947,7 +5927,7 @@ SILValue ADContext::promoteToDifferentiableFunction(
59475927
SILFunction *thunk;
59485928
SubstitutionMap interfaceSubs;
59495929
std::tie(thunk, interfaceSubs) =
5950-
getOrCreateAssociatedFunctionIndexSubsetThunk(
5930+
getOrCreateSubsetParametersThunkForAssociatedFunction(
59515931
origFnOperand, assocFn, assocFnKind, desiredIndices,
59525932
actualIndices);
59535933
auto *thunkFRI = builder.createFunctionRef(loc, thunk);
@@ -6078,8 +6058,8 @@ void Differentiation::run() {
60786058
context.getAutoDiffFunctionInsts().pop_back();
60796059
// Skip instructions that have been set to nullptr by
60806060
// `processAutoDiffFunctionInst`.
6081-
if (adfi)
6082-
errorOccurred |= context.processAutoDiffFunctionInst(adfi);
6061+
if (!adfi) continue;
6062+
errorOccurred |= context.processAutoDiffFunctionInst(adfi);
60836063
}
60846064

60856065
// If any error occurred while processing `[differentiable]` attributes or

0 commit comments

Comments
 (0)