Skip to content

Commit e9aa913

Browse files
authored
[AutoDiff] [stdlib] Deprecate 'CotangentVector' in favor of 'TangentVector'. (#24825)
This PR removes the `CotangentVector` associated type and make it equal to `TangentVector`. Mathematically, types conforming to the `Differentiable` protocol represent a Riemannian manifold, whose metric is inner products, which provide an isomorphism between a tangent space and a cotangent space at a point. Some theoretical and practical reasons: * It is not possible to correctly represent dual vectors in Swift today because functions cannot conform to protocols (`(TangentVector) -> Scalar` cannot conform to `AdditiveArithmetic`). * In computation, it is always more practical to compute numerical tangent vectors instead of true cotangent vectors (partially applied dot product functions). * The mutually recursive generic constraints triggered a significant bug in the type checker ([SR-9595](https://bugs.swift.org/browse/SR-9595)) that required a lot of workarounds ([TF-213](https://bugs.swift.org/browse/TF-213)), in particular, an ugly split of 3 protocols: `__Differentiable`, `_Differentiable`, and `Differentiable`. * The split between `TangentVector` and `CotangentVector` makes user-defined conformances complicated. Changes include: * Remove `associatedtype CotangentVector` from `Differentiable`. * Define `typealias CotangentVector = TangentVector` in `Differentiable` with a deprecation message. * Merge `__Differentiable` and `_Differentiable` back into `Differentiable`. * Replace all mentions of "cotangent" with "tangent" in the code base. * Adapt derived conformances logic. * Adapt tests. Defines away [TF-213](https://bugs.swift.org/browse/TF-213).
1 parent f641a58 commit e9aa913

39 files changed

+452
-767
lines changed

include/swift/AST/ASTContext.h

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,6 @@ namespace swift {
107107
class VarDecl;
108108
class UnifiedStatsReporter;
109109
// SWIFT_ENABLE_TENSORFLOW
110-
enum class AutoDiffAssociatedVectorSpaceKind : unsigned;
111110
class VectorSpace;
112111
class AutoDiffParameterIndices;
113112
class DifferentiableAttr;
@@ -276,8 +275,7 @@ class ASTContext final {
276275
llvm::StringMap<Type> RemappedTypes;
277276

278277
/// Cache of autodiff-associated vector spaces.
279-
llvm::DenseMap<std::pair<Type, unsigned>,
280-
Optional<VectorSpace>> AutoDiffVectorSpaces;
278+
llvm::DenseMap<Type, Optional<VectorSpace>> AutoDiffVectorSpaces;
281279

282280
/// Cache of `@differentiable` attributes keyed by parameter indices. This
283281
/// helps us diagnose multiple `@differentiable`s that are with respect to the

include/swift/AST/AutoDiff.h

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -551,11 +551,6 @@ class AutoDiffAssociatedFunctionIdentifier : public llvm::FoldingSetNode {
551551
}
552552
};
553553

554-
/// The kind of an associated type.
555-
enum class AutoDiffAssociatedVectorSpaceKind : unsigned {
556-
Tangent = 0, Cotangent = 1
557-
};
558-
559554
/// Automatic differentiation utility namespace.
560555
namespace autodiff {
561556

include/swift/AST/DiagnosticsSema.def

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2715,7 +2715,7 @@ NOTE(protocol_witness_missing_specific_differentiable_attr,none,
27152715
// @differentiating
27162716
ERROR(differentiating_attr_expected_result_tuple,none,
27172717
"'@differentiating' attribute requires function to return a two-element tuple of type "
2718-
"'(value: T..., pullback: (U.CotangentVector) -> T.CotangentVector...)' or "
2718+
"'(value: T..., pullback: (U.TangentVector) -> T.TangentVector...)' or "
27192719
"'(value: T..., differential: (T.TangentVector...) -> U.TangentVector)'", ())
27202720
ERROR(differentiating_attr_invalid_result_tuple_value_label,none,
27212721
"'@differentiating' attribute requires function to return a two-element "

include/swift/AST/KnownIdentifiers.def

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -135,11 +135,9 @@ IDENTIFIER(zero)
135135
IDENTIFIER(Scalar)
136136
// Differentiable
137137
IDENTIFIER(AllDifferentiableVariables)
138-
IDENTIFIER(CotangentVector)
139138
IDENTIFIER(TangentVector)
140139
IDENTIFIER(allDifferentiableVariables)
141140
IDENTIFIER(moved)
142-
IDENTIFIER(tangentVector)
143141

144142
// Kinds of layout constraints
145143
IDENTIFIER_WITH_NAME(UnknownLayout, "_UnknownLayout")

include/swift/AST/KnownProtocols.def

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,9 +86,6 @@ PROTOCOL(TensorGroup)
8686
PROTOCOL_(TensorFlowDataTypeCompatible)
8787
PROTOCOL(TensorProtocol)
8888
PROTOCOL(VectorNumeric)
89-
// TODO(TF-213): Remove underscore `Differentiable` protocols.
90-
PROTOCOL(__Differentiable)
91-
PROTOCOL(_Differentiable)
9289
PROTOCOL(Differentiable)
9390

9491
PROTOCOL_(ObjectiveCBridgeable)

include/swift/AST/Types.h

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1096,20 +1096,17 @@ class alignas(1 << TypeAlignInBits) TypeBase {
10961096
TypeTraitResult canBeClass();
10971097

10981098
// SWIFT_ENABLE_TENSORFLOW
1099-
/// Return the associated tangent or cotangent type. Return the null type if
1100-
/// there is no associated tangent/cotangent type.
1101-
///
1102-
/// `kind` specifies whether to return the tangent or cotangent type.
1099+
/// Return the associated tangent type. Return the null type if there is no
1100+
/// associated tangent type.
11031101
///
11041102
/// If the type conforms to `Differentiable`, then the associated
1105-
/// tangent/cotangent type is the associated `TangentVector`/`CotangentVector`
1106-
/// from the `Differentiable` requirement. If the type is a tuple, then the
1107-
/// associated tangent/cotangent type is the elementwise tangent/cotangent
1108-
/// type of its elements. If the type is a builtin float, then the associated
1109-
/// tangent/cotangent type is itself. Otherwise, there is no associated type.
1103+
/// tangent type is the associated `TangentVector` from the `Differentiable`
1104+
/// requirement. If the type is a tuple, then the associated tangent type is
1105+
/// the elementwise tangent type of its elements. If the type is a builtin
1106+
/// float, then the associated tangent type is itself. Otherwise, there is no
1107+
/// associated type.
11101108
Optional<VectorSpace>
1111-
getAutoDiffAssociatedVectorSpace(AutoDiffAssociatedVectorSpaceKind kind,
1112-
LookupConformanceFn lookupConformance);
1109+
getAutoDiffAssociatedTangentSpace(LookupConformanceFn lookupConformance);
11131110

11141111
private:
11151112
// Make vanilla new/delete illegal for Types.
@@ -3074,12 +3071,12 @@ class AnyFunctionType : public TypeBase {
30743071
///
30753072
/// By default, if the original type has a self parameter list and parameter
30763073
/// indices include self, the computed associated function type will return a
3077-
/// linear map taking/returning self's tangent/cotangent *last* instead of
3078-
/// first, for consistency with SIL.
3074+
/// linear map taking/returning self's tangent *last* instead of first, for
3075+
/// consistency with SIL.
30793076
///
3080-
/// If `makeSelfParamFirst` is true, self's tangent/cotangent is reordered to
3081-
/// appear first. This should be used during type-checking, e.g.
3082-
/// type-checking `@differentiable` and `@differentiating` attributes.
3077+
/// If `makeSelfParamFirst` is true, self's tangent is reordered to appear
3078+
/// first. This should be used during type-checking, e.g. type-checking
3079+
/// `@differentiable` and `@differentiating` attributes.
30833080
///
30843081
/// \note The original function type (`self`) need not be `@differentiable`.
30853082
/// The resulting function will preserve all `ExtInfo` of the original

lib/AST/Builtins.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -985,7 +985,7 @@ static ValueDecl *getAutoDiffApplyAssociatedFunction(
985985
// rethrows -> (R, (...T.TangentVector) -> R.TangentVector)
986986
// VJP:
987987
// <...T...(arity), R> (@differentiable (...T) throws -> R, ...T)
988-
// rethrows -> (R, (R.CotangentVector) -> ...T.CotangentVector)
988+
// rethrows -> (R, (R.TangentVector) -> ...T.TangentVector)
989989
unsigned numGenericParams = 1 + arity;
990990
BuiltinGenericSignatureBuilder builder(Context, numGenericParams);
991991
// Look up the Differentiable protocol.

lib/AST/Type.cpp

Lines changed: 28 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -4364,13 +4364,12 @@ makeFunctionType(AnyFunctionType *copy, ArrayRef<AnyFunctionType::Param> params,
43644364
return FunctionType::get(params, retTy, copy->getExtInfo());
43654365
}
43664366

4367-
Optional<VectorSpace> TypeBase::getAutoDiffAssociatedVectorSpace(
4368-
AutoDiffAssociatedVectorSpaceKind kind,
4367+
Optional<VectorSpace> TypeBase::getAutoDiffAssociatedTangentSpace(
43694368
LookupConformanceFn lookupConformance) {
43704369
assert(lookupConformance);
43714370
auto &ctx = getASTContext();
43724371

4373-
std::pair<Type, unsigned> cacheKey {this, (unsigned)kind};
4372+
Type cacheKey = this;
43744373
auto lookup = ctx.AutoDiffVectorSpaces.find(cacheKey);
43754374
if (lookup != ctx.AutoDiffVectorSpaces.end())
43764375
return lookup->getSecond();
@@ -4379,24 +4378,24 @@ Optional<VectorSpace> TypeBase::getAutoDiffAssociatedVectorSpace(
43794378
return vs;
43804379
};
43814380

4382-
// Functions' tangent/cotangent is the same function except the innermost
4383-
// return type being replaced by its tangent/cotangent.
4381+
// Functions' tangent is the same function except the innermost return type
4382+
// being replaced by its tangent.
43844383
if (auto *fnTy = getAs<AnyFunctionType>()) {
4385-
auto resultSpace = fnTy->getResult()->getAutoDiffAssociatedVectorSpace(
4386-
kind, lookupConformance);
4384+
auto resultSpace = fnTy->getResult()->getAutoDiffAssociatedTangentSpace(
4385+
lookupConformance);
43874386
if (!resultSpace)
43884387
return cache(None);
43894388
return cache(VectorSpace::getFunction(
43904389
makeFunctionType(fnTy, fnTy->getParams(), resultSpace->getType(),
43914390
fnTy->getOptGenericSignature())));
43924391
}
43934392

4394-
// Tuples' tangent/cotangent is a tuple of each element's Tangent/Cotangent.
4393+
// Tuples' tangent is a tuple of each element's Tangent.
43954394
if (auto *tupleTy = getAs<TupleType>()) {
43964395
SmallVector<TupleTypeElt, 8> newElts;
43974396
for (auto elt : tupleTy->getElements()) {
43984397
auto eltSpace = elt.getType()
4399-
->getAutoDiffAssociatedVectorSpace(kind, lookupConformance);
4398+
->getAutoDiffAssociatedTangentSpace(lookupConformance);
44004399
if (!eltSpace)
44014400
continue;
44024401
newElts.push_back(elt.getWithType(eltSpace->getType()));
@@ -4410,22 +4409,12 @@ Optional<VectorSpace> TypeBase::getAutoDiffAssociatedVectorSpace(
44104409
return cache(VectorSpace::getTuple(tupleType));
44114410
}
44124411

4413-
// Find the TangentVector/CotangentVector associated type on the
4414-
// Differentiable protocol.
4412+
// Find the TangentVector associated type on the Differentiable protocol.
44154413
auto *differentiableProtocol =
4416-
ctx.getProtocol(KnownProtocolKind::__Differentiable);
4417-
assert(differentiableProtocol && "Could not find __Differentiable protocol");
4418-
Identifier associatedTypeIdentifier;
4419-
switch (kind) {
4420-
case AutoDiffAssociatedVectorSpaceKind::Tangent:
4421-
associatedTypeIdentifier = ctx.Id_TangentVector;
4422-
break;
4423-
case AutoDiffAssociatedVectorSpaceKind::Cotangent:
4424-
associatedTypeIdentifier = ctx.Id_CotangentVector;
4425-
break;
4426-
}
4414+
ctx.getProtocol(KnownProtocolKind::Differentiable);
4415+
assert(differentiableProtocol && "Could not find Differentiable protocol");
44274416
auto associatedTypeLookup =
4428-
differentiableProtocol->lookupDirect(associatedTypeIdentifier);
4417+
differentiableProtocol->lookupDirect(ctx.Id_TangentVector);
44294418
assert(associatedTypeLookup.size() == 1);
44304419
auto *dependentType = DependentMemberType::get(
44314420
differentiableProtocol->getDeclaredInterfaceType(),
@@ -4448,7 +4437,7 @@ AnyFunctionType *AnyFunctionType::getAutoDiffAssociatedFunctionType(
44484437
// JVP: (T...) -> ((R...),
44494438
// (T.TangentVector...) -> (R.TangentVector...))
44504439
// VJP: (T...) -> ((R...),
4451-
// (R.CotangentVector...) -> (T.CotangentVector...))
4440+
// (R.TangentVector...) -> (T.TangentVector...))
44524441
//
44534442
// Note that both can be written as "(T...) -> ((R...), Closure)", so we build
44544443
// "Closure" and then use common code to wrap "Closure" in the outer function
@@ -4487,23 +4476,20 @@ AnyFunctionType *AnyFunctionType::getAutoDiffAssociatedFunctionType(
44874476
SmallVector<AnyFunctionType::Param, 8> differentialParams;
44884477
for (auto wrtParamType : wrtParamTypes)
44894478
differentialParams.push_back(
4490-
AnyFunctionType::Param(wrtParamType->getAutoDiffAssociatedVectorSpace(
4491-
AutoDiffAssociatedVectorSpaceKind::Tangent, lookupConformance)
4479+
AnyFunctionType::Param(
4480+
wrtParamType->getAutoDiffAssociatedTangentSpace(lookupConformance)
44924481
->getType()));
44934482

44944483
SmallVector<TupleTypeElt, 8> differentialResults;
44954484
if (auto *resultTuple = originalResult->getAs<TupleType>()) {
44964485
auto resultTupleEltType = resultTuple->getElementType(resultIndex);
4497-
differentialResults.push_back(
4498-
resultTupleEltType->getAutoDiffAssociatedVectorSpace(
4499-
AutoDiffAssociatedVectorSpaceKind::Tangent, lookupConformance)
4500-
->getType());
4486+
differentialResults.push_back(resultTupleEltType
4487+
->getAutoDiffAssociatedTangentSpace(lookupConformance)->getType());
45014488
} else {
45024489
assert(resultIndex == 0 && "resultIndex out of bounds");
45034490
differentialResults.push_back(
4504-
originalResult->getAutoDiffAssociatedVectorSpace(
4505-
AutoDiffAssociatedVectorSpaceKind::Tangent, lookupConformance)
4506-
->getType());
4491+
originalResult->getAutoDiffAssociatedTangentSpace(lookupConformance)
4492+
->getType());
45074493
}
45084494
Type differentialResult =
45094495
differentialResults.size() > 1
@@ -4515,28 +4501,26 @@ AnyFunctionType *AnyFunctionType::getAutoDiffAssociatedFunctionType(
45154501
}
45164502
case AutoDiffAssociatedFunctionKind::VJP: {
45174503
// closure is the VJP "pullback":
4518-
// (R.CotangentVector...) -> (T.CotangentVector...)
4504+
// (R.TangentVector...) -> (T.TangentVector...)
45194505
SmallVector<AnyFunctionType::Param, 8> pullbackParams;
45204506
if (auto *resultTuple = originalResult->getAs<TupleType>()) {
45214507
auto resultTupleEltType = resultTuple->getElementType(resultIndex);
45224508
pullbackParams.push_back(
4523-
AnyFunctionType::Param(
4524-
resultTupleEltType->getAutoDiffAssociatedVectorSpace(
4525-
AutoDiffAssociatedVectorSpaceKind::Cotangent,
4526-
lookupConformance)->getType()));
4509+
AnyFunctionType::Param(resultTupleEltType
4510+
->getAutoDiffAssociatedTangentSpace(lookupConformance)
4511+
->getType()));
45274512
} else {
45284513
assert(resultIndex == 0 && "resultIndex out of bounds");
45294514
pullbackParams.push_back(
4530-
AnyFunctionType::Param(
4531-
originalResult->getAutoDiffAssociatedVectorSpace(
4532-
AutoDiffAssociatedVectorSpaceKind::Cotangent,
4533-
lookupConformance)->getType()));
4515+
AnyFunctionType::Param(originalResult
4516+
->getAutoDiffAssociatedTangentSpace(lookupConformance)
4517+
->getType()));
45344518
}
45354519

45364520
SmallVector<TupleTypeElt, 8> pullbackResults;
45374521
for (auto wrtParamType : wrtParamTypes)
4538-
pullbackResults.push_back(wrtParamType->getAutoDiffAssociatedVectorSpace(
4539-
AutoDiffAssociatedVectorSpaceKind::Cotangent, lookupConformance)
4522+
pullbackResults.push_back(wrtParamType
4523+
->getAutoDiffAssociatedTangentSpace(lookupConformance)
45404524
->getType());
45414525
Type pullbackResult = pullbackResults.size() > 1
45424526
? TupleType::get(pullbackResults, ctx)

lib/IRGen/GenMeta.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4192,9 +4192,6 @@ SpecialProtocol irgen::getSpecialProtocolID(ProtocolDecl *P) {
41924192
case KnownProtocolKind::TensorFlowDataTypeCompatible:
41934193
case KnownProtocolKind::TensorProtocol:
41944194
case KnownProtocolKind::VectorNumeric:
4195-
// TODO(TF-213): Remove underscore `Differentiable` protocols.
4196-
case KnownProtocolKind::__Differentiable:
4197-
case KnownProtocolKind::_Differentiable:
41984195
case KnownProtocolKind::Differentiable:
41994196
return SpecialProtocol::None;
42004197
}

lib/SIL/SILFunctionType.cpp

Lines changed: 25 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ CanSILFunctionType SILFunctionType::getAutoDiffAssociatedFunctionType(
154154
// JVP: (T...) -> ((R...),
155155
// (T.TangentVector...) -> (R.TangentVector...))
156156
// VJP: (T...) -> ((R...),
157-
// (R.CotangentVector...) -> (T.CotangentVector...))
157+
// (R.TangentVector...) -> (T.TangentVector...))
158158

159159
auto &ctx = getASTContext();
160160
auto &typeConverter = module.Types;
@@ -164,9 +164,10 @@ CanSILFunctionType SILFunctionType::getAutoDiffAssociatedFunctionType(
164164
whereClauseGenSig = getGenericSignature();
165165

166166
// Given a type, returns its formal SIL parameter info.
167-
auto getCotangentParameterInfoForOriginalResult = [&](
168-
CanType cotanType, ResultConvention origResConv) -> SILParameterInfo {
169-
auto &tl = typeConverter.getTypeLowering(cotanType, ResilienceExpansion::Minimal);
167+
auto getTangentParameterInfoForOriginalResult = [&](
168+
CanType tanType, ResultConvention origResConv) -> SILParameterInfo {
169+
auto &tl = typeConverter.getTypeLowering(tanType,
170+
ResilienceExpansion::Minimal);
170171
ParameterConvention conv;
171172
switch (origResConv) {
172173
case ResultConvention::Owned:
@@ -183,13 +184,14 @@ CanSILFunctionType SILFunctionType::getAutoDiffAssociatedFunctionType(
183184
conv = ParameterConvention::Indirect_In_Guaranteed;
184185
break;
185186
}
186-
return {cotanType, conv};
187+
return {tanType, conv};
187188
};
188189

189190
// Given a type, returns its formal SIL result info.
190-
auto getCotangentResultInfoForOriginalParameter = [&](
191-
CanType cotanType, ParameterConvention origParamConv) -> SILResultInfo {
192-
auto &tl = typeConverter.getTypeLowering(cotanType, ResilienceExpansion::Minimal);
191+
auto getTangentResultInfoForOriginalParameter = [&](
192+
CanType tanType, ParameterConvention origParamConv) -> SILResultInfo {
193+
auto &tl = typeConverter.getTypeLowering(tanType,
194+
ResilienceExpansion::Minimal);
193195
ResultConvention conv;
194196
switch (origParamConv) {
195197
case ParameterConvention::Direct_Owned:
@@ -207,7 +209,7 @@ CanSILFunctionType SILFunctionType::getAutoDiffAssociatedFunctionType(
207209
conv = ResultConvention::Indirect;
208210
break;
209211
}
210-
return {cotanType, conv};
212+
return {tanType, conv};
211213
};
212214

213215
// Helper function testing if we are differentiating wrt this index.
@@ -228,17 +230,15 @@ CanSILFunctionType SILFunctionType::getAutoDiffAssociatedFunctionType(
228230
SmallVector<SILParameterInfo, 8> differentialParams;
229231
for (auto &param : wrtParams) {
230232
differentialParams.push_back(
231-
{param.getType()->getAutoDiffAssociatedVectorSpace(
232-
AutoDiffAssociatedVectorSpaceKind::Tangent, lookupConformance)
233-
->getCanonicalType(),
233+
{param.getType()->getAutoDiffAssociatedTangentSpace(lookupConformance)
234+
->getCanonicalType(),
234235
param.getConvention()});
235236
}
236237
SmallVector<SILResultInfo, 8> differentialResults;
237238
auto &result = getResults()[resultIndex];
238239
differentialResults.push_back(
239-
{result.getType()->getAutoDiffAssociatedVectorSpace(
240-
AutoDiffAssociatedVectorSpaceKind::Tangent, lookupConformance)
241-
->getCanonicalType(),
240+
{result.getType()->getAutoDiffAssociatedTangentSpace(lookupConformance)
241+
->getCanonicalType(),
242242
result.getConvention()});
243243
closureType = SILFunctionType::get(
244244
/*genericSignature*/ nullptr, ExtInfo(), SILCoroutineKind::None,
@@ -249,22 +249,20 @@ CanSILFunctionType SILFunctionType::getAutoDiffAssociatedFunctionType(
249249
case AutoDiffAssociatedFunctionKind::VJP: {
250250
SmallVector<SILParameterInfo, 8> pullbackParams;
251251
auto &origRes = getResults()[resultIndex];
252-
auto cotangentAssocTy =
253-
origRes.getType()->getAutoDiffAssociatedVectorSpace(
254-
AutoDiffAssociatedVectorSpaceKind::Cotangent, lookupConformance)
255-
->getCanonicalType();
252+
auto tangentAssocTy =
253+
origRes.getType()->getAutoDiffAssociatedTangentSpace(lookupConformance)
254+
->getCanonicalType();
256255
pullbackParams.push_back(
257-
getCotangentParameterInfoForOriginalResult(cotangentAssocTy,
258-
origRes.getConvention()));
256+
getTangentParameterInfoForOriginalResult(tangentAssocTy,
257+
origRes.getConvention()));
259258
SmallVector<SILResultInfo, 8> pullbackResults;
260259
for (auto &param : wrtParams) {
261-
auto paramCotangentTy =
262-
param.getType()->getAutoDiffAssociatedVectorSpace(
263-
AutoDiffAssociatedVectorSpaceKind::Cotangent, lookupConformance)
264-
->getCanonicalType();
260+
auto paramTangentTy =
261+
param.getType()->getAutoDiffAssociatedTangentSpace(lookupConformance)
262+
->getCanonicalType();
265263
pullbackResults.push_back(
266-
getCotangentResultInfoForOriginalParameter(paramCotangentTy,
267-
param.getConvention()));
264+
getTangentResultInfoForOriginalParameter(paramTangentTy,
265+
param.getConvention()));
268266
}
269267
closureType = SILFunctionType::get(
270268
/*genericSignature*/ nullptr, ExtInfo(), SILCoroutineKind::None,

lib/SIL/SILType.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -595,7 +595,6 @@ bool SILType::isLoweringOf(SILModule &Mod, CanType formalType) {
595595
// SWIFT_ENABLE_TENSORFLOW
596596
/// Returns true if this SILType is a differentiable type.
597597
bool SILType::isDifferentiable(SILModule &M) const {
598-
return getASTType()->getAutoDiffAssociatedVectorSpace(
599-
AutoDiffAssociatedVectorSpaceKind::Tangent,
598+
return getASTType()->getAutoDiffAssociatedTangentSpace(
600599
LookUpConformanceInModule(M.getSwiftModule())).hasValue();
601600
}

0 commit comments

Comments
 (0)