@@ -165,9 +165,11 @@ SILDeserializer::SILDeserializer(
165
165
kind == sil_index_block::SIL_GLOBALVAR_NAMES ||
166
166
kind == sil_index_block::SIL_WITNESS_TABLE_NAMES ||
167
167
kind == sil_index_block::SIL_DEFAULT_WITNESS_TABLE_NAMES ||
168
- kind == sil_index_block::SIL_PROPERTY_OFFSETS)) &&
168
+ kind == sil_index_block::SIL_PROPERTY_OFFSETS ||
169
+ kind == sil_index_block::SIL_DIFFERENTIABILITY_WITNESS_NAMES)) &&
169
170
" Expect SIL_FUNC_NAMES, SIL_VTABLE_NAMES, SIL_GLOBALVAR_NAMES, \
170
- SIL_WITNESS_TABLE_NAMES, or SIL_DEFAULT_WITNESS_TABLE_NAMES." );
171
+ SIL_WITNESS_TABLE_NAMES, SIL_DEFAULT_WITNESS_TABLE_NAMES, \
172
+ SIL_PROPERTY_OFFSETS, or SIL_DIFFERENTIABILITY_WITNESS_NAMES." );
171
173
(void )prevKind;
172
174
173
175
if (kind == sil_index_block::SIL_FUNC_NAMES)
@@ -180,6 +182,8 @@ SILDeserializer::SILDeserializer(
180
182
WitnessTableList = readFuncTable (scratch, blobData);
181
183
else if (kind == sil_index_block::SIL_DEFAULT_WITNESS_TABLE_NAMES)
182
184
DefaultWitnessTableList = readFuncTable (scratch, blobData);
185
+ else if (kind == sil_index_block::SIL_DIFFERENTIABILITY_WITNESS_NAMES)
186
+ DifferentiabilityWitnessList = readFuncTable (scratch, blobData);
183
187
else if (kind == sil_index_block::SIL_PROPERTY_OFFSETS) {
184
188
// No matching 'names' block for property descriptors needed yet.
185
189
MF->allocateBuffer (Properties, scratch);
@@ -217,6 +221,12 @@ SILDeserializer::SILDeserializer(
217
221
offKind == sil_index_block::SIL_DEFAULT_WITNESS_TABLE_OFFSETS) &&
218
222
" Expect a SIL_DEFAULT_WITNESS_TABLE_OFFSETS record." );
219
223
MF->allocateBuffer (DefaultWitnessTables, scratch);
224
+ } else if (kind == sil_index_block::SIL_DIFFERENTIABILITY_WITNESS_NAMES) {
225
+ assert ((next.Kind == llvm::BitstreamEntry::Record &&
226
+ offKind ==
227
+ sil_index_block::SIL_DIFFERENTIABILITY_WITNESS_OFFSETS) &&
228
+ " Expect a SIL_DIFFERENTIABILITY_WITNESS_OFFSETS record." );
229
+ MF->allocateBuffer (DifferentiabilityWitnesses, scratch);
220
230
}
221
231
}
222
232
}
@@ -339,6 +349,24 @@ SILType SILDeserializer::getSILType(Type Ty, SILValueCategory Category,
339
349
.getCategoryType (Category);
340
350
}
341
351
352
+ // / Helper function to find a SILDifferentiabilityWitness, given its mangled
353
+ // / key.
354
+ SILDifferentiabilityWitness *
355
+ SILDeserializer::getSILDifferentiabilityWitnessForReference (
356
+ StringRef mangledKey) {
357
+ // Check to see if we have a witness under this key already.
358
+ auto *witness = SILMod.lookUpDifferentiabilityWitness (mangledKey);
359
+ if (witness)
360
+ return witness;
361
+ // Otherwise, look for a witness under this key in the module.
362
+ if (!DifferentiabilityWitnessList)
363
+ return nullptr ;
364
+ auto iter = DifferentiabilityWitnessList->find (mangledKey);
365
+ if (iter == DifferentiabilityWitnessList->end ())
366
+ return nullptr ;
367
+ return readDifferentiabilityWitness (*iter);
368
+ }
369
+
342
370
// / Helper function to find a SILFunction, given its name and type.
343
371
SILFunction *SILDeserializer::getFuncForReference (StringRef name,
344
372
SILType type) {
@@ -760,7 +788,7 @@ SILDeserializer::readSILFunctionChecked(DeclID FID, SILFunction *existingFn,
760
788
// SIL_VTABLE or SIL_GLOBALVAR or SIL_WITNESS_TABLE record also means the end
761
789
// of this SILFunction.
762
790
while (kind != SIL_FUNCTION && kind != SIL_VTABLE && kind != SIL_GLOBALVAR &&
763
- kind != SIL_WITNESS_TABLE) {
791
+ kind != SIL_WITNESS_TABLE && kind != SIL_DIFFERENTIABILITY_WITNESS ) {
764
792
if (kind == SIL_BASIC_BLOCK)
765
793
// Handle a SILBasicBlock record.
766
794
CurrentBB = readSILBasicBlock (fn, CurrentBB, scratch);
@@ -2988,6 +3016,7 @@ void SILDeserializer::readWitnessTableEntries(
2988
3016
// Another record means the end of this WitnessTable.
2989
3017
while (kind != SIL_WITNESS_TABLE &&
2990
3018
kind != SIL_DEFAULT_WITNESS_TABLE &&
3019
+ kind != SIL_DIFFERENTIABILITY_WITNESS &&
2991
3020
kind != SIL_FUNCTION) {
2992
3021
if (kind == SIL_DEFAULT_WITNESS_TABLE_NO_ENTRY) {
2993
3022
witnessEntries.push_back (SILDefaultWitnessTable::Entry ());
@@ -3343,6 +3372,138 @@ SILDeserializer::lookupDefaultWitnessTable(SILDefaultWitnessTable *existingWt) {
3343
3372
return Wt;
3344
3373
}
3345
3374
3375
+ SILDifferentiabilityWitness *
3376
+ SILDeserializer::readDifferentiabilityWitness (DeclID DId) {
3377
+ if (DId == 0 )
3378
+ return nullptr ;
3379
+ assert (DId <= DifferentiabilityWitnesses.size () &&
3380
+ " Invalid SILDifferentiabilityWitness ID" );
3381
+
3382
+ auto &diffWitnessOrOffset = DifferentiabilityWitnesses[DId - 1 ];
3383
+ if (diffWitnessOrOffset.isFullyDeserialized ())
3384
+ return diffWitnessOrOffset.get ();
3385
+
3386
+ BCOffsetRAII restoreOffset (SILCursor);
3387
+ if (auto err = SILCursor.JumpToBit (diffWitnessOrOffset.getOffset ()))
3388
+ MF->fatal (std::move (err));
3389
+ llvm::Expected<llvm::BitstreamEntry> maybeEntry =
3390
+ SILCursor.advance (AF_DontPopBlockAtEnd);
3391
+ if (!maybeEntry)
3392
+ MF->fatal (maybeEntry.takeError ());
3393
+ auto entry = maybeEntry.get ();
3394
+ if (entry.Kind == llvm::BitstreamEntry::Error) {
3395
+ LLVM_DEBUG (llvm::dbgs () << " Cursor advance error in "
3396
+ " readDefaultWitnessTable.\n " );
3397
+ return nullptr ;
3398
+ }
3399
+
3400
+ SmallVector<uint64_t , 64 > scratch;
3401
+ StringRef blobData;
3402
+ llvm::Expected<unsigned > maybeKind =
3403
+ SILCursor.readRecord (entry.ID , scratch, &blobData);
3404
+ if (!maybeKind)
3405
+ MF->fatal (maybeKind.takeError ());
3406
+ unsigned kind = maybeKind.get ();
3407
+ assert (kind == SIL_DIFFERENTIABILITY_WITNESS &&
3408
+ " Expected sil_differentiability_witness" );
3409
+ (void )kind;
3410
+
3411
+ DeclID originalNameId, jvpNameId, vjpNameId;
3412
+ unsigned rawLinkage, isDeclaration, isSerialized, numParameterIndices,
3413
+ numResultIndices;
3414
+ GenericSignatureID derivativeGenSigID;
3415
+ ArrayRef<uint64_t > rawParameterAndResultIndices;
3416
+
3417
+ DifferentiabilityWitnessLayout::readRecord (
3418
+ scratch, originalNameId, rawLinkage, isDeclaration, isSerialized,
3419
+ derivativeGenSigID, jvpNameId, vjpNameId, numParameterIndices,
3420
+ numResultIndices, rawParameterAndResultIndices);
3421
+
3422
+ if (isDeclaration) {
3423
+ assert (!isSerialized && " declaration must not be serialized" );
3424
+ }
3425
+
3426
+ auto linkageOpt = fromStableSILLinkage (rawLinkage);
3427
+ assert (linkageOpt &&
3428
+ " Expected value linkage for sil_differentiability_witness" );
3429
+ auto originalName = MF->getIdentifierText (originalNameId);
3430
+ auto jvpName = MF->getIdentifierText (jvpNameId);
3431
+ auto vjpName = MF->getIdentifierText (vjpNameId);
3432
+ auto *original = getFuncForReference (originalName);
3433
+ assert (original && " Original function must be found" );
3434
+ auto *jvp = getFuncForReference (jvpName);
3435
+ if (!jvpName.empty ()) {
3436
+ assert (!isDeclaration && " JVP must not be defined in declaration" );
3437
+ assert (jvp && " JVP function must be found if JVP name is not empty" );
3438
+ }
3439
+ auto *vjp = getFuncForReference (vjpName);
3440
+ if (!vjpName.empty ()) {
3441
+ assert (!isDeclaration && " VJP must not be defined in declaration" );
3442
+ assert (vjp && " VJP function must be found if VJP name is not empty" );
3443
+ }
3444
+ auto derivativeGenSig = MF->getGenericSignature (derivativeGenSigID);
3445
+
3446
+ SmallVector<unsigned , 8 > parameterAndResultIndices (
3447
+ rawParameterAndResultIndices.begin (), rawParameterAndResultIndices.end ());
3448
+ assert (parameterAndResultIndices.size () ==
3449
+ numParameterIndices + numResultIndices &&
3450
+ " Parameter/result indices count mismatch" );
3451
+ auto *parameterIndices = IndexSubset::get (
3452
+ MF->getContext (), original->getLoweredFunctionType ()->getNumParameters (),
3453
+ ArrayRef<unsigned >(parameterAndResultIndices)
3454
+ .take_front (numParameterIndices));
3455
+ auto *resultIndices = IndexSubset::get (
3456
+ MF->getContext (), original->getLoweredFunctionType ()->getNumResults (),
3457
+ ArrayRef<unsigned >(parameterAndResultIndices)
3458
+ .take_back (numResultIndices));
3459
+
3460
+ AutoDiffConfig config (parameterIndices, resultIndices, derivativeGenSig);
3461
+ auto *diffWitness =
3462
+ SILMod.lookUpDifferentiabilityWitness ({originalName, config});
3463
+
3464
+ // Witnesses that we deserialize are always available externally; we never
3465
+ // want to emit them ourselves.
3466
+ auto linkage = swift::addExternalToLinkage (*linkageOpt);
3467
+
3468
+ // If there is no existing differentiability witness, create one.
3469
+ if (!diffWitness)
3470
+ diffWitness = SILDifferentiabilityWitness::createDeclaration (
3471
+ SILMod, linkage, original, parameterIndices, resultIndices,
3472
+ derivativeGenSig);
3473
+
3474
+ // If the current differentiability witness is merely a declaration, and the
3475
+ // deserialized witness is a definition, upgrade the current differentiability
3476
+ // witness to a definition. This can happen in the following situations:
3477
+ // 1. The witness was just created above.
3478
+ // 2. The witness started out as a declaration (e.g. the differentiation
3479
+ // pass emitted a witness for an external function) and now we're loading
3480
+ // the definition (e.g. an optimization pass asked for the definition and
3481
+ // we found the definition serialized in this module).
3482
+ if (diffWitness->isDeclaration () && !isDeclaration)
3483
+ diffWitness->convertToDefinition (jvp, vjp, isSerialized);
3484
+
3485
+ diffWitnessOrOffset.set (diffWitness,
3486
+ /* isFullyDeserialized*/ diffWitness->isDefinition ());
3487
+ return diffWitness;
3488
+ }
3489
+
3490
+ SILDifferentiabilityWitness *SILDeserializer::lookupDifferentiabilityWitness (
3491
+ StringRef mangledDiffWitnessKey) {
3492
+ if (!DifferentiabilityWitnessList)
3493
+ return nullptr ;
3494
+ auto iter = DifferentiabilityWitnessList->find (mangledDiffWitnessKey);
3495
+ if (iter == DifferentiabilityWitnessList->end ())
3496
+ return nullptr ;
3497
+ return readDifferentiabilityWitness (*iter);
3498
+ }
3499
+
3500
+ void SILDeserializer::getAllDifferentiabilityWitnesses () {
3501
+ if (!DifferentiabilityWitnessList)
3502
+ return ;
3503
+ for (unsigned I = 0 , E = DifferentiabilityWitnesses.size (); I < E; ++I)
3504
+ readDifferentiabilityWitness (I + 1 );
3505
+ }
3506
+
3346
3507
SILDeserializer::~SILDeserializer () {
3347
3508
// Drop our references to anything we've deserialized.
3348
3509
for (auto &fnEntry : Funcs) {
0 commit comments