Skip to content

Commit 571e9d9

Browse files
committed
Fix @differentiable attribute requirement type-checking in non-primary files.
1 parent a26b443 commit 571e9d9

File tree

3 files changed

+75
-0
lines changed

3 files changed

+75
-0
lines changed

lib/Sema/TypeCheckProtocol.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -544,6 +544,17 @@ swift::matchWitness(
544544
auto *witnessAFD = dyn_cast<AbstractFunctionDecl>(witness);
545545
if (auto *witnessASD = dyn_cast<AbstractStorageDecl>(witness))
546546
witnessAFD = witnessASD->getAccessor(AccessorKind::Get);
547+
// NOTE: Validate `@differentiable` attributes by calling
548+
// `getParameterIndices`. This is important for type-checking
549+
// `@differentiable` attributes in non-primary files to skip invalid
550+
// attributes and to resolve derivative configurations, used below.
551+
for (auto *witnessDiffAttr :
552+
witnessAttrs.getAttributes<DifferentiableAttr>()) {
553+
(void)witnessDiffAttr->getParameterIndices();
554+
}
555+
for (auto *reqDiffAttr : reqAttrs.getAttributes<DifferentiableAttr>()) {
556+
(void)reqDiffAttr->getParameterIndices();
557+
}
547558
for (auto *reqDiffAttr : reqAttrs.getAttributes<DifferentiableAttr>()) {
548559
bool foundExactAttr = false;
549560
bool foundSupersetAttr = false;
@@ -583,6 +594,14 @@ swift::matchWitness(
583594
}
584595
}
585596
if (!success) {
597+
LLVM_DEBUG({
598+
llvm::dbgs() << "Protocol requirement match failure: missing "
599+
"`@differentiable` attribute for witness ";
600+
witnessAFD->dumpRef(llvm::dbgs());
601+
llvm::dbgs() << " from requirement ";
602+
req->dumpRef(llvm::dbgs());
603+
llvm::dbgs() << '\n';
604+
});
586605
if (auto *vdWitness = dyn_cast<VarDecl>(witness))
587606
return RequirementMatch(
588607
getStandinForAccessor(vdWitness, AccessorKind::Get),
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
public protocol Layer: Differentiable {
2+
associatedtype Input: Differentiable
3+
associatedtype Output: Differentiable
4+
5+
@differentiable
6+
func instanceMethod(_ input: Input) -> Output
7+
8+
@differentiable
9+
var computedProperty: Output { get }
10+
}
11+
12+
struct DummyLayer: Layer {
13+
@differentiable
14+
func instanceMethod(_ input: Float) -> Float {
15+
return input
16+
}
17+
18+
@differentiable
19+
var computedProperty: Float { 1 }
20+
}
21+
22+
public extension Differentiable {
23+
@differentiable
24+
func sequenced<L: Layer>(through layer: L) -> L.Output where L.Input == Self {
25+
return layer.instanceMethod(self)
26+
}
27+
}
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
// RUN: %target-swift-frontend -typecheck -verify %S/Inputs/differentiable_attr_type_checking_non_primary_file.swift -primary-file %s
2+
3+
// Test TF-1043: Type-checking protocol requirement `@differentiable` attributes
4+
// from non-primary files.
5+
6+
struct OuterLayer: Layer {
7+
typealias Input = Float
8+
typealias Output = Float
9+
10+
var dummy: DummyLayer
11+
12+
@differentiable
13+
var computedProperty: Output {
14+
// NOTE(TF-1043): Old misleading error:
15+
// error: 'Int' is not convertible to 'Float'
16+
// return Float(1).sequenced(through: dummy)
17+
// ^~~~~~~~
18+
return Float(1).sequenced(through: dummy)
19+
}
20+
21+
@differentiable
22+
func instanceMethod(_ input: Input) -> Output {
23+
// NOTE(TF-1043): Old misleading error:
24+
// error: type of expression is ambiguous without more context
25+
// return input.sequenced(through: dummy)
26+
// ~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~
27+
return input.sequenced(through: dummy)
28+
}
29+
}

0 commit comments

Comments
 (0)