Skip to content

Commit 87f1bb1

Browse files
bartchr808rxwei
authored andcommitted
[AutoDiff] TF-508: Temporary Hack Fix (#25115)
[JIRA Ticket TF-508](https://bugs.swift.org/browse/TF-508) This ticket should not be closed, as this PR is a temporary hack to fix the problem and unblock me on `SIMD` vectors. The problem has to do with not being able to resolve `Self` for the actual underlying struct type. For example, in the reproducer in the JIRA ticket, it doesn't know about `MyStruct`. So here, we take a look at the first type, which we are assuming is `Self`, but more specfically the struct which lets us successfully check the requirements.
1 parent 57a92b1 commit 87f1bb1

File tree

2 files changed

+89
-22
lines changed

2 files changed

+89
-22
lines changed

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -786,7 +786,7 @@ void DifferentiationInvoker::print(llvm::raw_ostream &os) const {
786786
// subsitution map and in the given module.
787787
static bool checkRequirementsSatisfied(
788788
ArrayRef<Requirement> requirements, SubstitutionMap substMap,
789-
ModuleDecl *swiftModule) {
789+
SILFunction *original, ModuleDecl *swiftModule) {
790790
if (requirements.empty())
791791
return true;
792792
// Jointly iterate through associated function requirements/conformances.
@@ -831,10 +831,27 @@ static bool checkRequirementsSatisfied(
831831
continue;
832832
}
833833
}
834+
Optional<ProtocolConformanceRef> conformance;
835+
// FIXME terrible hack: If LHS is a dependent member, try to resolve it
836+
// using the `Self` type, assuming the first replacement type is `Self`.
837+
if (auto depMemType = firstType->getAs<DependentMemberType>()) {
838+
if (original->hasSelfParam() &&
839+
!substMap.getReplacementTypes().empty()) {
840+
if (auto substType = depMemType->substBaseType(
841+
substMap.getReplacementTypes().front(),
842+
LookUpConformanceInModule(swiftModule))) {
843+
firstType = substType;
844+
}
845+
}
846+
conformance = swiftModule->conformsToProtocol(firstType, protocol);
847+
}
834848
// Otherwise, try to look up conformance in substitution maps.
835-
auto isConformanceMet = substMap.lookupConformance(
836-
firstType->getCanonicalType(), protocol);
837-
if (!isConformanceMet)
849+
else {
850+
conformance = substMap.lookupConformance(
851+
firstType->getCanonicalType(), protocol);
852+
}
853+
854+
if (!conformance)
838855
unsatisfiedRequirements.push_back(req);
839856
continue;
840857
}
@@ -2111,7 +2128,7 @@ emitAssociatedFunctionReference(
21112128
// TODO(TF-482): Change `lookupMinimalDifferentiableAttr`.
21122129
if (!checkRequirementsSatisfied(
21132130
minimalAttr->getRequirements(),
2114-
substMap, context.getModule().getSwiftModule())) {
2131+
substMap, originalFn, context.getModule().getSwiftModule())) {
21152132
context.emitNondifferentiabilityError(original, invoker,
21162133
diag::autodiff_function_assoc_func_requirements_unmet);
21172134
return None;

test/AutoDiff/generics.swift

Lines changed: 67 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -57,25 +57,75 @@ struct SupervisedTrainer<Model : Layer> {
5757
}
5858

5959
// Tests TF-440.
60-
struct TF_440_Input<Input: Differentiable, State: Differentiable>: Differentiable {
61-
var input: Input
62-
var state: State
60+
struct TF_440_Input<Input: Differentiable, State: Differentiable>
61+
: Differentiable {
62+
var input: Input
63+
var state: State
6364
}
6465
struct TF_440<T : Differentiable> {
65-
@differentiable
66-
func applied(to input: TF_440_Input<Float, Float>) -> Float {
67-
return input.state
68-
}
69-
70-
@differentiable
71-
func applied(to input: TF_440_Input<T, Float>) -> Float {
72-
return input.state
73-
}
74-
75-
@differentiable
76-
func applied(to input: TF_440_Input<T, Float>) -> T {
77-
return input.input
78-
}
66+
@differentiable
67+
func applied(to input: TF_440_Input<Float, Float>) -> Float {
68+
return input.state
69+
}
70+
71+
@differentiable
72+
func applied(to input: TF_440_Input<T, Float>) -> Float {
73+
return input.state
74+
}
75+
76+
@differentiable
77+
func applied(to input: TF_440_Input<T, Float>) -> T {
78+
return input.input
79+
}
80+
}
81+
82+
// Tests TF-508
83+
protocol TF_508_Protocol {
84+
associatedtype Scalar: BinaryFloatingPoint
85+
}
86+
87+
extension TF_508_Protocol {
88+
@differentiable(vjp: _vjpAdd(lhs:rhs:)
89+
where Self : Differentiable,
90+
Scalar : Differentiable,
91+
Self.TangentVector : TF_508_Protocol)
92+
static func +(lhs: Self, rhs: Self) -> Self {
93+
return lhs
94+
}
95+
96+
static var zero: Self {
97+
fatalError()
98+
}
99+
100+
static func - (lhs: Self, rhs: Self) -> Self {
101+
fatalError()
102+
}
103+
}
104+
105+
extension TF_508_Protocol
106+
where Self : Differentiable,
107+
Scalar : Differentiable,
108+
TangentVector : TF_508_Protocol {
109+
static func _vjpAdd(lhs: Self, rhs: Self)
110+
-> (Self, (TangentVector) -> (TangentVector, TangentVector)) {
111+
return (lhs, { ($0, $0) })
112+
}
113+
}
114+
115+
struct TF_508_Struct<Scalar: BinaryFloatingPoint>
116+
: TF_508_Protocol & AdditiveArithmetic {}
117+
118+
extension TF_508_Struct : Differentiable
119+
where Scalar : Differentiable {
120+
typealias TangentVector = TF_508_Struct
121+
typealias AllDifferentiableVariables = TF_508_Struct
122+
}
123+
124+
let TF_508_inst = TF_508_Struct<Float>()
125+
func TF_508_func(x: TF_508_Struct<Float>, y: TF_508_Struct<Float>)
126+
-> TF_508_Struct<Float> {
127+
return x + y
79128
}
129+
let TF_508_bp = pullback(at: TF_508_inst, TF_508_inst, in: TF_508_func)
80130

81131
// TODO: add more tests.

0 commit comments

Comments
 (0)