Skip to content

Commit bbb7196

Browse files
committed
AutoDiff: Workaround for performing generic signature queries on the wrong signature
1 parent b57be6d commit bbb7196

File tree

9 files changed

+64
-24
lines changed

9 files changed

+64
-24
lines changed

lib/SIL/IR/SILFunctionType.cpp

Lines changed: 56 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include "swift/AST/Module.h"
2727
#include "swift/AST/ModuleLoader.h"
2828
#include "swift/AST/ProtocolConformance.h"
29+
#include "swift/AST/TypeCheckRequests.h"
2930
#include "swift/ClangImporter/ClangImporter.h"
3031
#include "swift/SIL/SILModule.h"
3132
#include "swift/SIL/SILType.h"
@@ -360,6 +361,41 @@ getSemanticResults(SILFunctionType *functionType, IndexSubset *parameterIndices,
360361
IndexSubset::get(C, parameterIndices->getCapacity(), inoutParamIndices);
361362
}
362363

364+
static CanGenericSignature buildDifferentiableGenericSignature(CanGenericSignature sig,
365+
CanType tanType) {
366+
if (!sig)
367+
return sig;
368+
369+
llvm::DenseSet<CanType> types;
370+
371+
auto &ctx = tanType->getASTContext();
372+
373+
(void) tanType.findIf([&](Type t) -> bool {
374+
if (auto *dmt = t->getAs<DependentMemberType>()) {
375+
if (dmt->getName() == ctx.Id_TangentVector)
376+
types.insert(dmt->getBase()->getCanonicalType());
377+
}
378+
379+
return false;
380+
});
381+
382+
SmallVector<Requirement, 2> reqs;
383+
auto *proto = ctx.getProtocol(KnownProtocolKind::Differentiable);
384+
assert(proto != nullptr);
385+
386+
for (auto type : types) {
387+
if (!sig->requiresProtocol(type, proto)) {
388+
reqs.push_back(Requirement(RequirementKind::Conformance, type,
389+
proto->getDeclaredInterfaceType()));
390+
}
391+
}
392+
393+
return evaluateOrDefault(
394+
ctx.evaluator,
395+
AbstractGenericSignatureRequest{sig.getPointer(), {}, reqs},
396+
GenericSignature()).getCanonicalSignature();
397+
}
398+
363399
/// Returns the differential type for the given original function type,
364400
/// parameter indices, and result index.
365401
static CanSILFunctionType getAutoDiffDifferentialType(
@@ -371,10 +407,11 @@ static CanSILFunctionType getAutoDiffDifferentialType(
371407
auto getTangentParameterConvention =
372408
[&](CanType tanType,
373409
ParameterConvention origParamConv) -> ParameterConvention {
374-
tanType =
375-
tanType->getCanonicalType(originalFnTy->getSubstGenericSignature());
376-
AbstractionPattern pattern(originalFnTy->getSubstGenericSignature(),
377-
tanType);
410+
auto sig = buildDifferentiableGenericSignature(
411+
originalFnTy->getSubstGenericSignature(), tanType);
412+
413+
tanType = tanType->getCanonicalType(sig);
414+
AbstractionPattern pattern(sig, tanType);
378415
auto &tl =
379416
TC.getTypeLowering(pattern, tanType, TypeExpansionContext::minimal());
380417
// When the tangent type is address only, we must ensure that the tangent
@@ -398,10 +435,11 @@ static CanSILFunctionType getAutoDiffDifferentialType(
398435
auto getTangentResultConvention =
399436
[&](CanType tanType,
400437
ResultConvention origResConv) -> ResultConvention {
401-
tanType =
402-
tanType->getCanonicalType(originalFnTy->getSubstGenericSignature());
403-
AbstractionPattern pattern(originalFnTy->getSubstGenericSignature(),
404-
tanType);
438+
auto sig = buildDifferentiableGenericSignature(
439+
originalFnTy->getSubstGenericSignature(), tanType);
440+
441+
tanType = tanType->getCanonicalType(sig);
442+
AbstractionPattern pattern(sig, tanType);
405443
auto &tl =
406444
TC.getTypeLowering(pattern, tanType, TypeExpansionContext::minimal());
407445
// When the tangent type is address only, we must ensure that the tangent
@@ -530,10 +568,11 @@ static CanSILFunctionType getAutoDiffPullbackType(
530568
auto getTangentParameterConventionForOriginalResult =
531569
[&](CanType tanType,
532570
ResultConvention origResConv) -> ParameterConvention {
533-
tanType =
534-
tanType->getCanonicalType(originalFnTy->getSubstGenericSignature());
535-
AbstractionPattern pattern(originalFnTy->getSubstGenericSignature(),
536-
tanType);
571+
auto sig = buildDifferentiableGenericSignature(
572+
originalFnTy->getSubstGenericSignature(), tanType);
573+
574+
tanType = tanType->getCanonicalType(sig);
575+
AbstractionPattern pattern(sig, tanType);
537576
auto &tl =
538577
TC.getTypeLowering(pattern, tanType, TypeExpansionContext::minimal());
539578
ParameterConvention conv;
@@ -560,10 +599,11 @@ static CanSILFunctionType getAutoDiffPullbackType(
560599
auto getTangentResultConventionForOriginalParameter =
561600
[&](CanType tanType,
562601
ParameterConvention origParamConv) -> ResultConvention {
563-
tanType =
564-
tanType->getCanonicalType(originalFnTy->getSubstGenericSignature());
565-
AbstractionPattern pattern(originalFnTy->getSubstGenericSignature(),
566-
tanType);
602+
auto sig = buildDifferentiableGenericSignature(
603+
originalFnTy->getSubstGenericSignature(), tanType);
604+
605+
tanType = tanType->getCanonicalType(sig);
606+
AbstractionPattern pattern(sig, tanType);
567607
auto &tl =
568608
TC.getTypeLowering(pattern, tanType, TypeExpansionContext::minimal());
569609
ResultConvention conv;

test/AutoDiff/SIL/Parse/sildeclref.sil

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: %target-sil-opt %s -module-name=sildeclref_parse | %target-sil-opt -module-name=sildeclref_parse | %FileCheck %s
1+
// RUN: %target-sil-opt %s -module-name=sildeclref_parse -requirement-machine=off | %target-sil-opt -module-name=sildeclref_parse -requirement-machine=off | %FileCheck %s
22
// Parse AutoDiff derivative SILDeclRefs via `witness_method` and `class_method` instructions.
33

44
import Swift

test/AutoDiff/SILGen/vtable.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: %target-swift-frontend -emit-silgen %s | %FileCheck %s
1+
// RUN: %target-swift-frontend -emit-silgen %s -requirement-machine=off | %FileCheck %s
22

33
// Test derivative function vtable entries for `@differentiable` class members:
44
// - Methods.

test/AutoDiff/SILOptimizer/differentiation_diagnostics.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: %target-swift-frontend -emit-sil -verify %s
1+
// RUN: %target-swift-frontend -emit-sil -verify -requirement-machine=off %s
22

33
// Test differentiation transform diagnostics.
44

test/AutoDiff/SILOptimizer/semantic_member_accessors_sil.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: %target-swift-frontend -emit-sil -Xllvm -sil-print-after=differentiation %s -module-name null -o /dev/null 2>&1 | %FileCheck %s
1+
// RUN: %target-swift-frontend -emit-sil -Xllvm -sil-print-after=differentiation %s -module-name null -o /dev/null -requirement-machine=off 2>&1 | %FileCheck %s
22

33
// Test differentiation of semantic member accessors:
44
// - Stored property accessors.

test/AutoDiff/compiler_crashers_fixed/sr12744-unhandled-pullback-indirect-result.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: %target-swift-frontend -emit-sil -verify %s
1+
// RUN: %target-swift-frontend -emit-sil -verify -requirement-machine=off %s
22

33
// SR-12744: Pullback generation crash for unhandled indirect result.
44
// May be due to inconsistent derivative function type calculation logic in

test/AutoDiff/mangling.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: %target-swift-frontend -emit-sil -enable-experimental-forward-mode-differentiation -module-name=mangling -verify %s | %FileCheck %s
1+
// RUN: %target-swift-frontend -emit-sil -enable-experimental-forward-mode-differentiation -module-name=mangling -verify -requirement-machine=off %s | %FileCheck %s
22

33
import _Differentiation
44

test/AutoDiff/validation-test/forward_mode_simple.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: %target-run-simple-swift(-Xfrontend -enable-experimental-forward-mode-differentiation)
1+
// RUN: %target-run-simple-swift(-Xfrontend -enable-experimental-forward-mode-differentiation -Xfrontend -requirement-machine=off)
22
// REQUIRES: executable_test
33

44
import StdlibUnittest

test/AutoDiff/validation-test/inout_parameters.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: %target-run-simple-swift
1+
// RUN: %target-run-simple-swift(-Xfrontend -requirement-machine=off)
22
// REQUIRES: executable_test
33

44
// Would fail due to unavailability of swift_autoDiffCreateLinearMapContext.

0 commit comments

Comments
 (0)