@@ -166,18 +166,20 @@ void DifferentiableActivityInfo::propagateVaried(
166
166
setVariedAndPropagateToUsers (teai, i);
167
167
}
168
168
}
169
- // Handle `struct_extract` and `struct_element_addr` instructions.
169
+ // Handle element projection instructions:
170
+ // `struct_extract`, `struct_element_addr`, `ref_element_addr`.
170
171
// - If the field is marked `@noDerivative`, do not set the result as
171
172
// varied because it does not need a derivative.
172
173
// - Otherwise, propagate variedness from operand to result as usual.
173
- #define PROPAGATE_VARIED_FOR_STRUCT_EXTRACTION (INST ) \
174
- else if (auto *sei = dyn_cast<INST##Inst>(inst)) { \
175
- if (isVaried (sei ->getOperand (), i) && \
176
- !sei ->getField ()->getAttrs ().hasAttribute <NoDerivativeAttr>()) \
177
- setVariedAndPropagateToUsers (sei , i); \
174
+ #define PROPAGATE_VARIED_FOR_ELEMENT_PROJECTION (INST ) \
175
+ else if (auto *projInst = dyn_cast<INST##Inst>(inst)) { \
176
+ if (isVaried (projInst ->getOperand (), i) && \
177
+ !projInst ->getField ()->getAttrs ().hasAttribute <NoDerivativeAttr>()) \
178
+ setVariedAndPropagateToUsers (projInst , i); \
178
179
}
179
- PROPAGATE_VARIED_FOR_STRUCT_EXTRACTION (StructExtract)
180
- PROPAGATE_VARIED_FOR_STRUCT_EXTRACTION (StructElementAddr)
180
+ PROPAGATE_VARIED_FOR_ELEMENT_PROJECTION (StructExtract)
181
+ PROPAGATE_VARIED_FOR_ELEMENT_PROJECTION (StructElementAddr)
182
+ PROPAGATE_VARIED_FOR_ELEMENT_PROJECTION (RefElementAddr)
181
183
#undef PROPAGATE_VARIED_FOR_STRUCT_EXTRACTION
182
184
// Handle `br`.
183
185
else if (auto *bi = dyn_cast<BranchInst>(inst)) {
@@ -222,13 +224,14 @@ static Optional<AccessorKind> getAccessorKind(SILFunction *fn) {
222
224
void DifferentiableActivityInfo::propagateVariedInwardsThroughProjections (
223
225
SILValue value, unsigned independentVariableIndex) {
224
226
auto i = independentVariableIndex;
225
- // Skip `@noDerivative` struct projections.
227
+ // Skip `@noDerivative` projections.
226
228
#define SKIP_NODERIVATIVE (INST ) \
227
- if (auto *sei = dyn_cast<INST##Inst>(value)) \
228
- if (sei ->getField ()->getAttrs ().hasAttribute <NoDerivativeAttr>()) \
229
+ if (auto *projInst = dyn_cast<INST##Inst>(value)) \
230
+ if (projInst ->getField ()->getAttrs ().hasAttribute <NoDerivativeAttr>()) \
229
231
return ;
230
232
SKIP_NODERIVATIVE (StructExtract)
231
233
SKIP_NODERIVATIVE (StructElementAddr)
234
+ SKIP_NODERIVATIVE (RefElementAddr)
232
235
#undef SKIP_NODERIVATIVE
233
236
// Set value as varied and propagate to users.
234
237
setVariedAndPropagateToUsers (value, i);
@@ -274,7 +277,8 @@ void DifferentiableActivityInfo::setUsefulAndPropagateToOperands(
274
277
// Skip already-useful values to prevent infinite recursion.
275
278
if (isUseful (value, dependentVariableIndex))
276
279
return ;
277
- if (value->getType ().isAddress ()) {
280
+ if (value->getType ().isAddress () ||
281
+ value->getType ().getClassOrBoundGenericClass ()) {
278
282
propagateUsefulThroughAddress (value, dependentVariableIndex);
279
283
return ;
280
284
}
@@ -331,15 +335,16 @@ void DifferentiableActivityInfo::propagateUseful(
331
335
PROPAGATE_USEFUL_THROUGH_STORE (CopyAddr)
332
336
PROPAGATE_USEFUL_THROUGH_STORE (UnconditionalCheckedCastAddr)
333
337
#undef PROPAGATE_USEFUL_THROUGH_STORE
334
- // Handle struct element extraction , skipping `@noDerivative` fields:
335
- // `struct_extract`, `struct_element_addr`.
336
- #define PROPAGATE_USEFUL_THROUGH_STRUCT_EXTRACTION (INST ) \
337
- else if (auto *sei = dyn_cast<INST##Inst>(inst)) { \
338
- if (!sei ->getField ()->getAttrs ().hasAttribute <NoDerivativeAttr>()) \
339
- setUsefulAndPropagateToOperands (sei ->getOperand (), i); \
338
+ // Handle element projections , skipping `@noDerivative` fields:
339
+ // `struct_extract`, `struct_element_addr`, `ref_element_addr` .
340
+ #define PROPAGATE_USEFUL_THROUGH_ELEMENT_PROJECTION (INST ) \
341
+ else if (auto *projInst = dyn_cast<INST##Inst>(inst)) { \
342
+ if (!projInst ->getField ()->getAttrs ().hasAttribute <NoDerivativeAttr>()) \
343
+ setUsefulAndPropagateToOperands (projInst ->getOperand (), i); \
340
344
}
341
- PROPAGATE_USEFUL_THROUGH_STRUCT_EXTRACTION (StructExtract)
342
- PROPAGATE_USEFUL_THROUGH_STRUCT_EXTRACTION (StructElementAddr)
345
+ PROPAGATE_USEFUL_THROUGH_ELEMENT_PROJECTION (StructExtract)
346
+ PROPAGATE_USEFUL_THROUGH_ELEMENT_PROJECTION (StructElementAddr)
347
+ PROPAGATE_USEFUL_THROUGH_ELEMENT_PROJECTION (RefElementAddr)
343
348
#undef PROPAGATE_USEFUL_THROUGH_STRUCT_EXTRACTION
344
349
// Handle everything else.
345
350
else {
@@ -350,7 +355,8 @@ void DifferentiableActivityInfo::propagateUseful(
350
355
351
356
void DifferentiableActivityInfo::propagateUsefulThroughAddress (
352
357
SILValue value, unsigned dependentVariableIndex) {
353
- assert (value->getType ().isAddress ());
358
+ assert (value->getType ().isAddress () ||
359
+ value->getType ().getClassOrBoundGenericClass ());
354
360
// Skip already-useful values to prevent infinite recursion.
355
361
if (isUseful (value, dependentVariableIndex))
356
362
return ;
@@ -364,13 +370,15 @@ void DifferentiableActivityInfo::propagateUsefulThroughAddress(
364
370
propagateUseful (use->getUser (), dependentVariableIndex);
365
371
for (auto res : use->getUser ()->getResults ()) {
366
372
#define SKIP_NODERIVATIVE (INST ) \
367
- if (auto *sei = dyn_cast<INST##Inst>(res)) \
368
- if (sei ->getField ()->getAttrs ().hasAttribute <NoDerivativeAttr>()) \
373
+ if (auto *projInst = dyn_cast<INST##Inst>(res)) \
374
+ if (projInst ->getField ()->getAttrs ().hasAttribute <NoDerivativeAttr>()) \
369
375
continue ;
370
376
SKIP_NODERIVATIVE (StructExtract)
371
377
SKIP_NODERIVATIVE (StructElementAddr)
378
+ SKIP_NODERIVATIVE (RefElementAddr)
372
379
#undef SKIP_NODERIVATIVE
373
- if (Projection::isAddressProjection (res) || isa<BeginAccessInst>(res))
380
+ if (Projection::isAddressProjection (res) || isa<BeginAccessInst>(res) ||
381
+ isa<BeginBorrowInst>(res))
374
382
propagateUsefulThroughAddress (res, dependentVariableIndex);
375
383
}
376
384
}
0 commit comments