Skip to content

Commit 0ace614

Browse files
committed
Fix bug and add test case.
1 parent 87f1bb1 commit 0ace614

File tree

2 files changed

+14
-1
lines changed

2 files changed

+14
-1
lines changed

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3235,7 +3235,8 @@ class VJPEmitter final
32353235
// This instruction is active. Determine the appropriate differentiation
32363236
// strategy, and use it.
32373237
auto *structDecl = sei->getStructDecl();
3238-
if (structDecl->getAttrs().hasAttribute<FieldwiseDifferentiableAttr>()) {
3238+
if (structDecl->getEffectiveAccess() <= AccessLevel::Internal
3239+
|| structDecl->getAttrs().hasAttribute<FieldwiseDifferentiableAttr>()) {
32393240
strategies[sei] = StructExtractDifferentiationStrategy::Fieldwise;
32403241
SILClonerWithScopes::visitStructExtractInst(sei);
32413242
return;

test/AutoDiff/generics.swift

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,4 +128,16 @@ func TF_508_func(x: TF_508_Struct<Float>, y: TF_508_Struct<Float>)
128128
}
129129
let TF_508_bp = pullback(at: TF_508_inst, TF_508_inst, in: TF_508_func)
130130

131+
// TF-523
132+
struct A : Differentiable & AdditiveArithmetic {
133+
var a: Float = 1
134+
typealias TangentVector = A
135+
typealias AllDifferentiableVariables = A
136+
}
137+
138+
@differentiable
139+
func f(_ x: A) -> Float {
140+
return x.a * 2
141+
}
142+
131143
// TODO: add more tests.

0 commit comments

Comments
 (0)