Skip to content

Commit 5b7a1c3

Browse files
authored
[AutoDiff] Support ref_element_addr differentiation. (#29749)
Support differentiation of `ref_element_addr`: class stored property references. Activity analysis: - Propagate activity for `ref_element_addr` like `struct_element_addr`. - Do not propagate activity for `ref_element_addr` to `@noDerivative` members. Pullback generation rules: ``` Original: y = ref_element_addr x, <n> Adjoint: adj[x] += struct (0, ..., #field': adj[y], ..., 0) ^~~~~~~ field in tangent space corresponding to #field ``` - Add `ref_element_addr` case to `PullbackEmitter::getAdjointProjection`. - The adjoint projection of a `ref_element_addr` is a local allocation initialized with the corresponding field value from the class's base adjoint value. Exposes TF-1149: cannot differentiate active value with loadable type but address-only `TangentVector` type. Diagnose for now. Resolves SR-12152.
1 parent 709d7d5 commit 5b7a1c3

File tree

12 files changed

+390
-78
lines changed

12 files changed

+390
-78
lines changed

include/swift/AST/DiagnosticsSIL.def

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -491,6 +491,12 @@ WARNING(autodiff_nonvaried_result_fixit,none,
491491
"result does not depend on differentiation arguments and will always "
492492
"have a zero derivative; do you want to use 'withoutDerivative(at:)'?",
493493
())
494+
// TODO(TF-1149): Remove this diagnostic.
495+
NOTE(autodiff_loadable_value_addressonly_tangent_unsupported,none,
496+
"cannot yet differentiate value whose type %0 has a compile-time known "
497+
"size, but whose 'TangentVector' contains stored properties of unknown "
498+
"size; consider modifying %1 to use fewer generic parameters in stored "
499+
"properties", (Type, Type))
494500
NOTE(autodiff_enums_unsupported,none,
495501
"differentiating enum values is not yet supported", ())
496502
NOTE(autodiff_global_let_closure_not_differentiable,none,
@@ -543,9 +549,6 @@ NOTE(autodiff_jvp_control_flow_not_supported,none,
543549
"forward-mode differentiation does not yet support control flow", ())
544550
NOTE(autodiff_control_flow_not_supported,none,
545551
"cannot differentiate unsupported control flow", ())
546-
// TODO(TF-645): Remove when differentiation supports `ref_element_addr`.
547-
NOTE(autodiff_class_property_not_supported,none,
548-
"differentiating class properties is not yet supported", ())
549552
// TODO(TF-1080): Remove when differentiation supports `begin_apply`.
550553
NOTE(autodiff_coroutines_not_supported,none,
551554
"differentiation of coroutine calls is not yet supported", ())

include/swift/SILOptimizer/Analysis/DifferentiableActivityAnalysis.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -157,8 +157,8 @@ class DifferentiableActivityInfo {
157157
/// Propagates variedness from the given operand to its user's results.
158158
void propagateVaried(Operand *operand, unsigned independentVariableIndex);
159159
/// Marks the given value as varied and recursively propagates variedness
160-
/// inwards (to operands) through projections. Skips `@noDerivative` struct
161-
/// field projections.
160+
/// inwards (to operands) through projections. Skips `@noDerivative` field
161+
/// projections.
162162
void
163163
propagateVariedInwardsThroughProjections(SILValue value,
164164
unsigned independentVariableIndex);
@@ -172,9 +172,9 @@ class DifferentiableActivityInfo {
172172
unsigned dependentVariableIndex);
173173
/// Propagates usefulnesss to the operands of the given instruction.
174174
void propagateUseful(SILInstruction *inst, unsigned dependentVariableIndex);
175-
/// Marks the given address as useful and recursively propagates usefulness
176-
/// inwards (to operands) through projections. Skips `@noDerivative` struct
177-
/// field projections.
175+
/// Marks the given address or class-typed value as useful and recursively
176+
/// propagates usefulness inwards (to operands) through projections. Skips
177+
/// `@noDerivative` field projections.
178178
void propagateUsefulThroughAddress(SILValue value,
179179
unsigned dependentVariableIndex);
180180
/// If the given value is an `array.uninitialized_intrinsic` application,

include/swift/SILOptimizer/Utils/Differentiation/PullbackEmitter.h

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -266,10 +266,17 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
266266

267267
SILBasicBlock::iterator getNextFunctionLocalAllocationInsertionPoint();
268268

269+
/// Creates and returns a local allocation with the given type.
270+
///
271+
/// Local allocations are created uninitialized in the pullback entry and
272+
/// deallocated in the pullback exit. All local allocations not in
273+
/// `destroyedLocalAllocations` are also destroyed in the pullback exit.
274+
AllocStackInst *createFunctionLocalAllocation(SILType type, SILLocation loc);
275+
269276
SILValue &getAdjointBuffer(SILBasicBlock *origBB, SILValue originalBuffer);
270277

271-
// Accumulates `rhsBufferAccess` into the adjoint buffer corresponding to
272-
// `originalBuffer`.
278+
/// Accumulates `rhsBufferAccess` into the adjoint buffer corresponding to
279+
/// `originalBuffer`.
273280
void addToAdjointBuffer(SILBasicBlock *origBB, SILValue originalBuffer,
274281
SILValue rhsBufferAccess, SILLocation loc);
275282

@@ -354,6 +361,13 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
354361
/// field in tangent space corresponding to #field
355362
void visitStructExtractInst(StructExtractInst *sei);
356363

364+
/// Handle `ref_element_addr` instruction.
365+
/// Original: y = ref_element_addr x, <n>
366+
/// Adjoint: adj[x] += struct (0, ..., #field': adj[y], ..., 0)
367+
/// ^~~~~~~
368+
/// field in tangent space corresponding to #field
369+
void visitRefElementAddrInst(RefElementAddrInst *reai);
370+
357371
/// Handle `tuple` instruction.
358372
/// Original: y = tuple (x0, x1, x2, ...)
359373
/// Adjoint: (adj[x0], adj[x1], adj[x2], ...) += destructure_tuple adj[y]
@@ -421,8 +435,6 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
421435
UnconditionalCheckedCastAddrInst *uccai);
422436

423437
#define NOT_DIFFERENTIABLE(INST, DIAG) void visit##INST##Inst(INST##Inst *inst);
424-
425-
NOT_DIFFERENTIABLE(RefElementAddr, autodiff_class_property_not_supported)
426438
#undef NOT_DIFFERENTIABLE
427439

428440
#define NO_ADJOINT(INST) \

lib/SIL/SILFunctionType.cpp

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -275,9 +275,12 @@ CanSILFunctionType SILFunctionType::getAutoDiffDerivativeFunctionType(
275275
switch (origResConv) {
276276
case ResultConvention::Owned:
277277
case ResultConvention::Autoreleased:
278-
conv = tl.isTrivial()
279-
? ParameterConvention::Direct_Unowned
280-
: ParameterConvention::Direct_Guaranteed;
278+
if (tl.isAddressOnly()) {
279+
conv = ParameterConvention::Indirect_In_Guaranteed;
280+
} else {
281+
conv = tl.isTrivial() ? ParameterConvention::Direct_Unowned
282+
: ParameterConvention::Direct_Guaranteed;
283+
}
281284
break;
282285
case ResultConvention::Unowned:
283286
case ResultConvention::UnownedInnerPointer:
@@ -301,9 +304,12 @@ CanSILFunctionType SILFunctionType::getAutoDiffDerivativeFunctionType(
301304
case ParameterConvention::Direct_Owned:
302305
case ParameterConvention::Direct_Guaranteed:
303306
case ParameterConvention::Direct_Unowned:
304-
conv = tl.isTrivial()
305-
? ResultConvention::Unowned
306-
: ResultConvention::Owned;
307+
if (tl.isAddressOnly()) {
308+
conv = ResultConvention::Indirect;
309+
} else {
310+
conv = tl.isTrivial() ? ResultConvention::Unowned
311+
: ResultConvention::Owned;
312+
}
307313
break;
308314
case ParameterConvention::Indirect_In:
309315
case ParameterConvention::Indirect_Inout:

lib/SILOptimizer/Analysis/DifferentiableActivityAnalysis.cpp

Lines changed: 32 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -166,18 +166,20 @@ void DifferentiableActivityInfo::propagateVaried(
166166
setVariedAndPropagateToUsers(teai, i);
167167
}
168168
}
169-
// Handle `struct_extract` and `struct_element_addr` instructions.
169+
// Handle element projection instructions:
170+
// `struct_extract`, `struct_element_addr`, `ref_element_addr`.
170171
// - If the field is marked `@noDerivative`, do not set the result as
171172
// varied because it does not need a derivative.
172173
// - Otherwise, propagate variedness from operand to result as usual.
173-
#define PROPAGATE_VARIED_FOR_STRUCT_EXTRACTION(INST) \
174-
else if (auto *sei = dyn_cast<INST##Inst>(inst)) { \
175-
if (isVaried(sei->getOperand(), i) && \
176-
!sei->getField()->getAttrs().hasAttribute<NoDerivativeAttr>()) \
177-
setVariedAndPropagateToUsers(sei, i); \
174+
#define PROPAGATE_VARIED_FOR_ELEMENT_PROJECTION(INST) \
175+
else if (auto *projInst = dyn_cast<INST##Inst>(inst)) { \
176+
if (isVaried(projInst->getOperand(), i) && \
177+
!projInst->getField()->getAttrs().hasAttribute<NoDerivativeAttr>()) \
178+
setVariedAndPropagateToUsers(projInst, i); \
178179
}
179-
PROPAGATE_VARIED_FOR_STRUCT_EXTRACTION(StructExtract)
180-
PROPAGATE_VARIED_FOR_STRUCT_EXTRACTION(StructElementAddr)
180+
PROPAGATE_VARIED_FOR_ELEMENT_PROJECTION(StructExtract)
181+
PROPAGATE_VARIED_FOR_ELEMENT_PROJECTION(StructElementAddr)
182+
PROPAGATE_VARIED_FOR_ELEMENT_PROJECTION(RefElementAddr)
181183
#undef PROPAGATE_VARIED_FOR_STRUCT_EXTRACTION
182184
// Handle `br`.
183185
else if (auto *bi = dyn_cast<BranchInst>(inst)) {
@@ -222,13 +224,14 @@ static Optional<AccessorKind> getAccessorKind(SILFunction *fn) {
222224
void DifferentiableActivityInfo::propagateVariedInwardsThroughProjections(
223225
SILValue value, unsigned independentVariableIndex) {
224226
auto i = independentVariableIndex;
225-
// Skip `@noDerivative` struct projections.
227+
// Skip `@noDerivative` projections.
226228
#define SKIP_NODERIVATIVE(INST) \
227-
if (auto *sei = dyn_cast<INST##Inst>(value)) \
228-
if (sei->getField()->getAttrs().hasAttribute<NoDerivativeAttr>()) \
229+
if (auto *projInst = dyn_cast<INST##Inst>(value)) \
230+
if (projInst->getField()->getAttrs().hasAttribute<NoDerivativeAttr>()) \
229231
return;
230232
SKIP_NODERIVATIVE(StructExtract)
231233
SKIP_NODERIVATIVE(StructElementAddr)
234+
SKIP_NODERIVATIVE(RefElementAddr)
232235
#undef SKIP_NODERIVATIVE
233236
// Set value as varied and propagate to users.
234237
setVariedAndPropagateToUsers(value, i);
@@ -274,7 +277,8 @@ void DifferentiableActivityInfo::setUsefulAndPropagateToOperands(
274277
// Skip already-useful values to prevent infinite recursion.
275278
if (isUseful(value, dependentVariableIndex))
276279
return;
277-
if (value->getType().isAddress()) {
280+
if (value->getType().isAddress() ||
281+
value->getType().getClassOrBoundGenericClass()) {
278282
propagateUsefulThroughAddress(value, dependentVariableIndex);
279283
return;
280284
}
@@ -331,15 +335,16 @@ void DifferentiableActivityInfo::propagateUseful(
331335
PROPAGATE_USEFUL_THROUGH_STORE(CopyAddr)
332336
PROPAGATE_USEFUL_THROUGH_STORE(UnconditionalCheckedCastAddr)
333337
#undef PROPAGATE_USEFUL_THROUGH_STORE
334-
// Handle struct element extraction, skipping `@noDerivative` fields:
335-
// `struct_extract`, `struct_element_addr`.
336-
#define PROPAGATE_USEFUL_THROUGH_STRUCT_EXTRACTION(INST) \
337-
else if (auto *sei = dyn_cast<INST##Inst>(inst)) { \
338-
if (!sei->getField()->getAttrs().hasAttribute<NoDerivativeAttr>()) \
339-
setUsefulAndPropagateToOperands(sei->getOperand(), i); \
338+
// Handle element projections, skipping `@noDerivative` fields:
339+
// `struct_extract`, `struct_element_addr`, `ref_element_addr`.
340+
#define PROPAGATE_USEFUL_THROUGH_ELEMENT_PROJECTION(INST) \
341+
else if (auto *projInst = dyn_cast<INST##Inst>(inst)) { \
342+
if (!projInst->getField()->getAttrs().hasAttribute<NoDerivativeAttr>()) \
343+
setUsefulAndPropagateToOperands(projInst->getOperand(), i); \
340344
}
341-
PROPAGATE_USEFUL_THROUGH_STRUCT_EXTRACTION(StructExtract)
342-
PROPAGATE_USEFUL_THROUGH_STRUCT_EXTRACTION(StructElementAddr)
345+
PROPAGATE_USEFUL_THROUGH_ELEMENT_PROJECTION(StructExtract)
346+
PROPAGATE_USEFUL_THROUGH_ELEMENT_PROJECTION(StructElementAddr)
347+
PROPAGATE_USEFUL_THROUGH_ELEMENT_PROJECTION(RefElementAddr)
343348
#undef PROPAGATE_USEFUL_THROUGH_STRUCT_EXTRACTION
344349
// Handle everything else.
345350
else {
@@ -350,7 +355,8 @@ void DifferentiableActivityInfo::propagateUseful(
350355

351356
void DifferentiableActivityInfo::propagateUsefulThroughAddress(
352357
SILValue value, unsigned dependentVariableIndex) {
353-
assert(value->getType().isAddress());
358+
assert(value->getType().isAddress() ||
359+
value->getType().getClassOrBoundGenericClass());
354360
// Skip already-useful values to prevent infinite recursion.
355361
if (isUseful(value, dependentVariableIndex))
356362
return;
@@ -364,13 +370,15 @@ void DifferentiableActivityInfo::propagateUsefulThroughAddress(
364370
propagateUseful(use->getUser(), dependentVariableIndex);
365371
for (auto res : use->getUser()->getResults()) {
366372
#define SKIP_NODERIVATIVE(INST) \
367-
if (auto *sei = dyn_cast<INST##Inst>(res)) \
368-
if (sei->getField()->getAttrs().hasAttribute<NoDerivativeAttr>()) \
373+
if (auto *projInst = dyn_cast<INST##Inst>(res)) \
374+
if (projInst->getField()->getAttrs().hasAttribute<NoDerivativeAttr>()) \
369375
continue;
370376
SKIP_NODERIVATIVE(StructExtract)
371377
SKIP_NODERIVATIVE(StructElementAddr)
378+
SKIP_NODERIVATIVE(RefElementAddr)
372379
#undef SKIP_NODERIVATIVE
373-
if (Projection::isAddressProjection(res) || isa<BeginAccessInst>(res))
380+
if (Projection::isAddressProjection(res) || isa<BeginAccessInst>(res) ||
381+
isa<BeginBorrowInst>(res))
374382
propagateUsefulThroughAddress(res, dependentVariableIndex);
375383
}
376384
}

0 commit comments

Comments
 (0)