Skip to content

Commit d6d79c6

Browse files
committed
Merge two fields into a PointerUnion in SILDeclRef to save space
1 parent 2a2cf91 commit d6d79c6

File tree

9 files changed

+60
-42
lines changed

9 files changed

+60
-42
lines changed

include/swift/SIL/SILDeclRef.h

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -150,15 +150,28 @@ struct SILDeclRef {
150150
unsigned isForeign : 1;
151151
/// The default argument index for a default argument getter.
152152
unsigned defaultArgIndex : 10;
153+
154+
PointerUnion<AutoDiffDerivativeFunctionIdentifier *,
155+
const GenericSignatureImpl *>
156+
pointer;
157+
153158
/// The derivative function identifier.
154-
AutoDiffDerivativeFunctionIdentifier *derivativeFunctionIdentifier = nullptr;
159+
AutoDiffDerivativeFunctionIdentifier * getDerivativeFunctionIdentifier() const {
160+
if (!pointer.is<AutoDiffDerivativeFunctionIdentifier *>())
161+
return nullptr;
162+
return pointer.get<AutoDiffDerivativeFunctionIdentifier *>();
163+
}
155164

156-
GenericSignature specializedSignature;
165+
GenericSignature getSpecializedSignature() const {
166+
if (!pointer.is<const GenericSignatureImpl *>())
167+
return GenericSignature();
168+
else
169+
return GenericSignature(pointer.get<const GenericSignatureImpl *>());
170+
}
157171

158172
/// Produces a null SILDeclRef.
159173
SILDeclRef()
160-
: loc(), kind(Kind::Func), isForeign(0), defaultArgIndex(0),
161-
derivativeFunctionIdentifier(nullptr) {}
174+
: loc(), kind(Kind::Func), isForeign(0), defaultArgIndex(0) {}
162175

163176
/// Produces a SILDeclRef of the given kind for the given decl.
164177
explicit SILDeclRef(
@@ -294,7 +307,7 @@ struct SILDeclRef {
294307
return loc.getOpaqueValue() == rhs.loc.getOpaqueValue() &&
295308
kind == rhs.kind && isForeign == rhs.isForeign &&
296309
defaultArgIndex == rhs.defaultArgIndex &&
297-
derivativeFunctionIdentifier == rhs.derivativeFunctionIdentifier;
310+
pointer == rhs.pointer;
298311
}
299312
bool operator!=(SILDeclRef rhs) const {
300313
return !(*this == rhs);
@@ -309,7 +322,7 @@ struct SILDeclRef {
309322
/// decl.
310323
SILDeclRef asForeign(bool foreign = true) const {
311324
return SILDeclRef(loc.getOpaqueValue(), kind, foreign, defaultArgIndex,
312-
derivativeFunctionIdentifier);
325+
pointer.get<AutoDiffDerivativeFunctionIdentifier *>());
313326
}
314327

315328
/// Returns the entry point for the corresponding autodiff derivative
@@ -318,16 +331,16 @@ struct SILDeclRef {
318331
AutoDiffDerivativeFunctionIdentifier *derivativeId) const {
319332
assert(derivativeId);
320333
SILDeclRef declRef = *this;
321-
declRef.derivativeFunctionIdentifier = derivativeId;
334+
declRef.pointer = derivativeId;
322335
return declRef;
323336
}
324337

325338
/// Returns the entry point for the original function corresponding to an
326339
/// autodiff derivative function.
327340
SILDeclRef asAutoDiffOriginalFunction() const {
328-
assert(derivativeFunctionIdentifier);
341+
assert(pointer.get<AutoDiffDerivativeFunctionIdentifier *>());
329342
SILDeclRef declRef = *this;
330-
declRef.derivativeFunctionIdentifier = nullptr;
343+
declRef.pointer = (AutoDiffDerivativeFunctionIdentifier *)nullptr;
331344
return declRef;
332345
}
333346

@@ -405,13 +418,14 @@ struct SILDeclRef {
405418
bool canBeDynamicReplacement() const;
406419

407420
bool isAutoDiffDerivativeFunction() const {
408-
return derivativeFunctionIdentifier != nullptr;
421+
return pointer.is<AutoDiffDerivativeFunctionIdentifier *>() &&
422+
pointer.get<AutoDiffDerivativeFunctionIdentifier *>() != nullptr;
409423
}
410424

411425
AutoDiffDerivativeFunctionIdentifier *
412426
getAutoDiffDerivativeFunctionIdentifier() const {
413427
assert(isAutoDiffDerivativeFunction());
414-
return derivativeFunctionIdentifier;
428+
return pointer.get<AutoDiffDerivativeFunctionIdentifier *>();
415429
}
416430

417431
private:
@@ -422,7 +436,7 @@ struct SILDeclRef {
422436
AutoDiffDerivativeFunctionIdentifier *derivativeId)
423437
: loc(Loc::getFromOpaqueValue(opaqueLoc)), kind(kind),
424438
isForeign(isForeign), defaultArgIndex(defaultArgIndex),
425-
derivativeFunctionIdentifier(derivativeId) {}
439+
pointer(derivativeId) {}
426440
};
427441

428442
inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, SILDeclRef C) {
@@ -457,7 +471,7 @@ template<> struct DenseMapInfo<swift::SILDeclRef> {
457471
? UnsignedInfo::getHashValue(Val.defaultArgIndex)
458472
: 0;
459473
unsigned h4 = UnsignedInfo::getHashValue(Val.isForeign);
460-
unsigned h5 = PointerInfo::getHashValue(Val.derivativeFunctionIdentifier);
474+
unsigned h5 = PointerInfo::getHashValue(Val.pointer.getOpaqueValue());
461475
return h1 ^ (h2 << 4) ^ (h3 << 9) ^ (h4 << 7) ^ (h5 << 11);
462476
}
463477
static bool isEqual(swift::SILDeclRef const &LHS,

lib/IRGen/GenMeta.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1683,7 +1683,7 @@ namespace {
16831683
void emitNonoverriddenMethod(SILDeclRef fn) {
16841684
// TODO: Derivative functions do not distinguish themselves in the mangled
16851685
// names of method descriptor symbols yet, causing symbol name collisions.
1686-
if (fn.derivativeFunctionIdentifier)
1686+
if (fn.getDerivativeFunctionIdentifier())
16871687
return;
16881688

16891689
HasNonoverriddenMethods = true;

lib/SIL/IR/SILDeclRef.cpp

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -119,10 +119,11 @@ bool swift::requiresForeignEntryPoint(ValueDecl *vd) {
119119
SILDeclRef::SILDeclRef(ValueDecl *vd, SILDeclRef::Kind kind, bool isForeign,
120120
AutoDiffDerivativeFunctionIdentifier *derivativeId)
121121
: loc(vd), kind(kind), isForeign(isForeign), defaultArgIndex(0),
122-
derivativeFunctionIdentifier(derivativeId) {}
122+
pointer(derivativeId) {}
123123

124124
SILDeclRef::SILDeclRef(SILDeclRef::Loc baseLoc, bool asForeign)
125-
: defaultArgIndex(0), derivativeFunctionIdentifier(nullptr) {
125+
: defaultArgIndex(0),
126+
pointer((AutoDiffDerivativeFunctionIdentifier *)nullptr) {
126127
if (auto *vd = baseLoc.dyn_cast<ValueDecl*>()) {
127128
if (auto *fd = dyn_cast<FuncDecl>(vd)) {
128129
// Map FuncDecls directly to Func SILDeclRefs.
@@ -164,7 +165,7 @@ SILDeclRef::SILDeclRef(SILDeclRef::Loc baseLoc, bool asForeign)
164165
SILDeclRef::SILDeclRef(SILDeclRef::Loc baseLoc,
165166
GenericSignature prespecializedSig)
166167
: SILDeclRef(baseLoc, false) {
167-
specializedSignature = prespecializedSig;
168+
pointer = prespecializedSig.getPointer();
168169
}
169170

170171
Optional<AnyFunctionRef> SILDeclRef::getAnyFunctionRef() const {
@@ -232,7 +233,7 @@ bool SILDeclRef::isImplicit() const {
232233
SILLinkage SILDeclRef::getLinkage(ForDefinition_t forDefinition) const {
233234

234235
// Prespecializations are public.
235-
if (specializedSignature) {
236+
if (getSpecializedSignature()) {
236237
return SILLinkage::Public;
237238
}
238239

@@ -678,6 +679,7 @@ std::string SILDeclRef::mangle(ManglingKind MKind) const {
678679
using namespace Mangle;
679680
ASTMangler mangler;
680681

682+
auto *derivativeFunctionIdentifier = getDerivativeFunctionIdentifier();
681683
if (derivativeFunctionIdentifier) {
682684
std::string originalMangled = asAutoDiffOriginalFunction().mangle(MKind);
683685
auto *silParameterIndices = autodiff::getLoweredParameterIndices(
@@ -716,14 +718,15 @@ std::string SILDeclRef::mangle(ManglingKind MKind) const {
716718
}
717719

718720
// Mangle prespecializations.
719-
if (specializedSignature) {
721+
if (getSpecializedSignature()) {
720722
SILDeclRef nonSpecializedDeclRef = *this;
721-
nonSpecializedDeclRef.specializedSignature = GenericSignature();
723+
nonSpecializedDeclRef.pointer =
724+
(AutoDiffDerivativeFunctionIdentifier *)nullptr;
722725
auto mangledNonSpecializedString = nonSpecializedDeclRef.mangle();
723726
auto *funcDecl = cast<AbstractFunctionDecl>(getDecl());
724727
auto genericSig = funcDecl->getGenericSignature();
725728
return GenericSpecializationMangler::manglePrespecialization(
726-
mangledNonSpecializedString, genericSig, specializedSignature);
729+
mangledNonSpecializedString, genericSig, getSpecializedSignature());
727730
}
728731

729732
ASTMangler::SymbolKind SKind = ASTMangler::SymbolKind::Default;
@@ -818,7 +821,7 @@ std::string SILDeclRef::mangle(ManglingKind MKind) const {
818821
// Returns true if the given JVP/VJP SILDeclRef requires a new vtable entry.
819822
// FIXME(TF-1213): Also consider derived declaration `@derivative` attributes.
820823
static bool derivativeFunctionRequiresNewVTableEntry(SILDeclRef declRef) {
821-
assert(declRef.derivativeFunctionIdentifier &&
824+
assert(declRef.getDerivativeFunctionIdentifier() &&
822825
"Expected a derivative function SILDeclRef");
823826
auto overridden = declRef.getOverridden();
824827
if (!overridden)
@@ -828,7 +831,7 @@ static bool derivativeFunctionRequiresNewVTableEntry(SILDeclRef declRef) {
828831
declRef.getDecl()->getAttrs().getAttributes<DifferentiableAttr>(),
829832
[&](const DifferentiableAttr *derivedDiffAttr) {
830833
return derivedDiffAttr->getParameterIndices() ==
831-
declRef.derivativeFunctionIdentifier->getParameterIndices();
834+
declRef.getDerivativeFunctionIdentifier()->getParameterIndices();
832835
});
833836
assert(derivedDiffAttr && "Expected `@differentiable` attribute");
834837
// Otherwise, if the base `@differentiable` attribute specifies a derivative
@@ -838,7 +841,7 @@ static bool derivativeFunctionRequiresNewVTableEntry(SILDeclRef declRef) {
838841
overridden.getDecl()->getAttrs().getAttributes<DifferentiableAttr>();
839842
for (auto *baseDiffAttr : baseDiffAttrs) {
840843
if (baseDiffAttr->getParameterIndices() ==
841-
declRef.derivativeFunctionIdentifier->getParameterIndices())
844+
declRef.getDerivativeFunctionIdentifier()->getParameterIndices())
842845
return false;
843846
}
844847
// Otherwise, if there is no base `@differentiable` attribute exists, then a
@@ -847,7 +850,7 @@ static bool derivativeFunctionRequiresNewVTableEntry(SILDeclRef declRef) {
847850
}
848851

849852
bool SILDeclRef::requiresNewVTableEntry() const {
850-
if (derivativeFunctionIdentifier)
853+
if (getDerivativeFunctionIdentifier())
851854
if (derivativeFunctionRequiresNewVTableEntry(*this))
852855
return true;
853856
if (!hasDecl())
@@ -928,15 +931,16 @@ SILDeclRef SILDeclRef::getNextOverriddenVTableEntry() const {
928931

929932
// JVPs/VJPs are overridden only if the base declaration has a
930933
// `@differentiable` attribute with the same parameter indices.
931-
if (derivativeFunctionIdentifier) {
934+
if (getDerivativeFunctionIdentifier()) {
932935
auto overriddenAttrs =
933936
overridden.getDecl()->getAttrs().getAttributes<DifferentiableAttr>();
934937
for (const auto *attr : overriddenAttrs) {
935938
if (attr->getParameterIndices() !=
936-
derivativeFunctionIdentifier->getParameterIndices())
939+
getDerivativeFunctionIdentifier()->getParameterIndices())
937940
continue;
938-
auto *overriddenDerivativeId = overridden.derivativeFunctionIdentifier;
939-
overridden.derivativeFunctionIdentifier =
941+
auto *overriddenDerivativeId =
942+
overridden.getDerivativeFunctionIdentifier();
943+
overridden.pointer =
940944
AutoDiffDerivativeFunctionIdentifier::get(
941945
overriddenDerivativeId->getKind(),
942946
overriddenDerivativeId->getParameterIndices(),

lib/SIL/IR/SILFunctionType.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3266,7 +3266,7 @@ TypeConverter::getConstantInfo(TypeExpansionContext expansion,
32663266
// preserving SIL typing invariants.
32673267
//
32683268
// Always use (ad) to compute lowered derivative function types.
3269-
if (auto *derivativeId = constant.derivativeFunctionIdentifier) {
3269+
if (auto *derivativeId = constant.getDerivativeFunctionIdentifier()) {
32703270
// Get lowered original function type.
32713271
auto origFnConstantInfo = getConstantInfo(
32723272
TypeExpansionContext::minimal(), constant.asAutoDiffOriginalFunction());

lib/SIL/IR/SILPrinter.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -349,19 +349,19 @@ void SILDeclRef::print(raw_ostream &OS) const {
349349
if (isForeign)
350350
OS << (isDot ? '.' : '!') << "foreign";
351351

352-
if (derivativeFunctionIdentifier) {
352+
if (getDerivativeFunctionIdentifier()) {
353353
OS << ((isDot || isForeign) ? '.' : '!');
354-
switch (derivativeFunctionIdentifier->getKind()) {
354+
switch (getDerivativeFunctionIdentifier()->getKind()) {
355355
case AutoDiffDerivativeFunctionKind::JVP:
356356
OS << "jvp.";
357357
break;
358358
case AutoDiffDerivativeFunctionKind::VJP:
359359
OS << "vjp.";
360360
break;
361361
}
362-
OS << derivativeFunctionIdentifier->getParameterIndices()->getString();
362+
OS << getDerivativeFunctionIdentifier()->getParameterIndices()->getString();
363363
if (auto derivativeGenSig =
364-
derivativeFunctionIdentifier->getDerivativeGenericSignature()) {
364+
getDerivativeFunctionIdentifier()->getDerivativeGenericSignature()) {
365365
OS << "." << derivativeGenSig;
366366
}
367367
}

lib/SIL/IR/TypeLowering.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2520,7 +2520,7 @@ getFunctionInterfaceTypeWithCaptures(TypeConverter &TC,
25202520
}
25212521

25222522
CanAnyFunctionType TypeConverter::makeConstantInterfaceType(SILDeclRef c) {
2523-
if (auto *derivativeId = c.derivativeFunctionIdentifier) {
2523+
if (auto *derivativeId = c.getDerivativeFunctionIdentifier()) {
25242524
auto originalFnTy =
25252525
makeConstantInterfaceType(c.asAutoDiffOriginalFunction());
25262526
auto *derivativeFnTy = originalFnTy->getAutoDiffDerivativeFunctionType(

lib/SILGen/SILGenPoly.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4577,7 +4577,7 @@ getWitnessFunctionRef(SILGenFunction &SGF,
45774577
SILLocation loc) {
45784578
switch (witnessKind) {
45794579
case WitnessDispatchKind::Static:
4580-
if (auto *derivativeId = witness.derivativeFunctionIdentifier) {
4580+
if (auto *derivativeId = witness.getDerivativeFunctionIdentifier()) {
45814581
auto originalFn =
45824582
SGF.emitGlobalFunctionRef(loc, witness.asAutoDiffOriginalFunction());
45834583
auto *loweredParamIndices = autodiff::getLoweredParameterIndices(
@@ -4594,7 +4594,7 @@ getWitnessFunctionRef(SILGenFunction &SGF,
45944594
}
45954595
return SGF.emitGlobalFunctionRef(loc, witness);
45964596
case WitnessDispatchKind::Dynamic:
4597-
assert(!witness.derivativeFunctionIdentifier);
4597+
assert(!witness.getDerivativeFunctionIdentifier());
45984598
return SGF.emitDynamicMethodRef(loc, witness, witnessFTy).getValue();
45994599
case WitnessDispatchKind::Witness: {
46004600
auto typeAndConf =
@@ -4609,7 +4609,7 @@ getWitnessFunctionRef(SILGenFunction &SGF,
46094609
// If `witness` is a derivative function `SILDeclRef`, replace the
46104610
// derivative function identifier's generic signature with the witness thunk
46114611
// substitution map's generic signature.
4612-
if (auto *derivativeId = witness.derivativeFunctionIdentifier) {
4612+
if (auto *derivativeId = witness.getDerivativeFunctionIdentifier()) {
46134613
auto *newDerivativeId = AutoDiffDerivativeFunctionIdentifier::get(
46144614
derivativeId->getKind(), derivativeId->getParameterIndices(),
46154615
witnessSubs.getGenericSignature(), SGF.getASTContext());

lib/SILGen/SILGenThunk.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ getOrCreateReabstractionThunk(CanSILFunctionType thunkType,
175175

176176
SILFunction *SILGenModule::getOrCreateAutoDiffClassMethodThunk(
177177
SILDeclRef derivativeFnDeclRef, CanSILFunctionType constantTy) {
178-
auto *derivativeId = derivativeFnDeclRef.derivativeFunctionIdentifier;
178+
auto *derivativeId = derivativeFnDeclRef.getDerivativeFunctionIdentifier();
179179
assert(derivativeId);
180180
auto *derivativeFnDecl = derivativeFnDeclRef.getDecl();
181181

lib/SILGen/SILGenType.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ SILGenModule::emitVTableMethod(ClassDecl *theClass,
9090
implFn = getDynamicThunk(
9191
derived, Types.getConstantInfo(TypeExpansionContext::minimal(), derived)
9292
.SILFnType);
93-
} else if (auto *derivativeId = derived.derivativeFunctionIdentifier) {
93+
} else if (auto *derivativeId = derived.getDerivativeFunctionIdentifier()) {
9494
// For JVP/VJP methods, create a vtable entry thunk. The thunk contains an
9595
// `differentiable_function` instruction, which is later filled during the
9696
// differentiation transform.
@@ -168,7 +168,7 @@ SILGenModule::emitVTableMethod(ClassDecl *theClass,
168168
base.kind == SILDeclRef::Kind::Allocator);
169169
}
170170
// TODO(TF-685): Use proper autodiff thunk mangling.
171-
if (auto *derivativeId = derived.derivativeFunctionIdentifier) {
171+
if (auto *derivativeId = derived.getDerivativeFunctionIdentifier()) {
172172
switch (derivativeId->getKind()) {
173173
case AutoDiffDerivativeFunctionKind::JVP:
174174
name += "_jvp";
@@ -743,7 +743,7 @@ SILFunction *SILGenModule::emitProtocolWitness(
743743
std::string nameBuffer =
744744
NewMangler.mangleWitnessThunk(manglingConformance, requirement.getDecl());
745745
// TODO(TF-685): Proper mangling for derivative witness thunks.
746-
if (auto *derivativeId = requirement.derivativeFunctionIdentifier) {
746+
if (auto *derivativeId = requirement.getDerivativeFunctionIdentifier()) {
747747
std::string kindString;
748748
switch (derivativeId->getKind()) {
749749
case AutoDiffDerivativeFunctionKind::JVP:

0 commit comments

Comments
 (0)