Skip to content

Commit 341bb33

Browse files
authored
[AutoDiff] Propagate '@nondiff' from AST function types to SIL function types. (#23854)
This is to unblock `@differentiable` functions with `@nondiff` parameters. - Propagate `@nondiff` from AST to SIL. - Add `AnyFunctionType::getWithoutDifferentiability`, which drops all `@differentiable` and `@nondiff` attributes from a function type. - Use autodiff parameter indices from function types now that `@nondiff` has been propagated. - Replace currying logic from `SILFunctionType::getAssociatedFunctionType` with lightweight logic that handles methods, which is needed for differentiable protocol requirements. - Emit an error when a `@nondiff` parameter is being differentiated with respect to. - Fix `@nondiff` AST printing. Note: - Once `SILDifferentiableFunctionType` in #23482 lands, `@nondiff` should be nuked from SIL. - Before merging, pull from the `tensorflow` branch to make sure #23887 is merged. Resolves [TF-421](https://bugs.swift.org/browse/TF-421).
1 parent 0235147 commit 341bb33

19 files changed

+194
-233
lines changed

include/swift/AST/DiagnosticsSIL.def

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -396,6 +396,8 @@ NOTE(autodiff_protocol_member_not_differentiable,none,
396396
NOTE(autodiff_protocol_member_subset_indices_not_differentiable,none,
397397
"member is differentiable only with respect to a smaller subset of "
398398
"arguments", ())
399+
NOTE(autodiff_function_nondiff_parameter_not_differentiable,none,
400+
"cannot differentiate with respect to a '@nondiff' parameter", ())
399401
NOTE(autodiff_function_assoc_func_requirements_unmet,none,
400402
"function call is not differentiable because generic requirements are not "
401403
"met", ())

include/swift/AST/Types.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3093,6 +3093,7 @@ class AnyFunctionType : public TypeBase {
30933093
return getExtInfo().getRepresentation();
30943094
}
30953095

3096+
// SWIFT_ENABLE_TENSORFLOW
30963097
/// Given `indices`, `differentiationOrder`, and `kind`, calculates the type
30973098
/// of the corresponding autodiff associated function.
30983099
///
@@ -3105,6 +3106,8 @@ class AnyFunctionType : public TypeBase {
31053106
LookupConformanceFn lookupConformance,
31063107
GenericSignature *whereClauseGenericSignature = nullptr);
31073108

3109+
AnyFunctionType *getWithoutDifferentiability() const;
3110+
31083111
/// \brief True if the parameter declaration it is attached to is guaranteed
31093112
/// to not persist the closure for longer than the duration of the call.
31103113
bool isNoEscape() const {

include/swift/SIL/SILBuilder.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -547,7 +547,6 @@ class SILBuilder {
547547
theFunction));
548548
}
549549

550-
551550
AutoDiffFunctionExtractInst *createAutoDiffFunctionExtractOriginal(
552551
SILLocation loc, SILValue theFunction) {
553552
return insert(new (getModule()) AutoDiffFunctionExtractInst(

lib/AST/ASTPrinter.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2338,7 +2338,7 @@ static void printParameterFlags(ASTPrinter &printer, PrintOptions options,
23382338
printer << "@escaping ";
23392339
// SWIFT_ENABLE_TENSORFLOW
23402340
if (!options.excludeAttrKind(TAK_nondiff) && flags.isNonDifferentiable())
2341-
printer << "@nondiff";
2341+
printer << "@nondiff ";
23422342

23432343
switch (flags.getValueOwnership()) {
23442344
case ValueOwnership::Default:

lib/AST/AutoDiff.cpp

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -254,15 +254,12 @@ AutoDiffParameterIndices::getLowered(AnyFunctionType *functionType) const {
254254
}
255255

256256
static unsigned getNumAutoDiffParameterIndices(AnyFunctionType *fnTy) {
257-
unsigned numAutoDiffParameterIndices = 0;
258-
// FIXME: Compute the exact parameter count.
259-
// Do not loop ad-infinitum; loop either 1 or 2 iterations, depending on
260-
// whether the function is a free function/static method/instance method.
261-
while (fnTy != nullptr) {
262-
numAutoDiffParameterIndices += fnTy->getNumParams();
263-
fnTy = fnTy->getResult()->getAs<AnyFunctionType>();
264-
}
265-
return numAutoDiffParameterIndices;
257+
// TODO: For more correct counting, we still need to know whether it's a
258+
// method or not.
259+
unsigned numParameters = fnTy->getNumParams();
260+
if (auto *innerFn = fnTy->getResult()->getAs<AnyFunctionType>())
261+
numParameters += innerFn->getNumParams();
262+
return numParameters;
266263
}
267264

268265
/// Returns true if the given type conforms to `Differentiable` in the given

lib/AST/Builtins.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1075,8 +1075,8 @@ static ValueDecl *getAutoDiffApplyAssociatedFunction(
10751075
// of the associated function type.
10761076
auto *origFnTy =
10771077
firstArgGen.build(builder)->castTo<AnyFunctionType>();
1078-
origFnTy = origFnTy->withExtInfo(
1079-
origFnTy->getExtInfo().withDifferentiable(false).withNoEscape(false));
1078+
origFnTy = origFnTy->getWithoutDifferentiability()->withExtInfo(
1079+
origFnTy->getExtInfo().withNoEscape(false));
10801080
auto autodiffBuilder = AutoDiffParameterIndicesBuilder::inferParameters(
10811081
origFnTy, Context.getStdlibModule());
10821082
auto *paramIndices = autodiffBuilder.build(Context);

lib/AST/Type.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4304,3 +4304,18 @@ AnyFunctionType *AnyFunctionType::getAutoDiffAssociatedFunctionType(
43044304

43054305
return associatedFunction;
43064306
}
4307+
4308+
AnyFunctionType *AnyFunctionType::getWithoutDifferentiability() const {
4309+
SmallVector<Param, 8> newParams;
4310+
for (auto &param : getParams()) {
4311+
Param newParam(param.getPlainType(), param.getLabel(),
4312+
param.getParameterFlags().withNonDifferentiable(false));
4313+
newParams.push_back(newParam);
4314+
}
4315+
auto nonDiffExtInfo = getExtInfo().withDifferentiable(false);
4316+
if (isa<FunctionType>(this))
4317+
return FunctionType::get(newParams, getResult(), nonDiffExtInfo);
4318+
assert(isa<GenericFunctionType>(this));
4319+
return GenericFunctionType::get(getOptGenericSignature(), newParams,
4320+
getResult(), nonDiffExtInfo);
4321+
}

lib/IRGen/GenDiffFunc.cpp

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,16 @@ using DiffFuncIndex =
3838
namespace {
3939
class DiffFuncFieldInfo final : public RecordField<DiffFuncFieldInfo> {
4040
public:
41-
DiffFuncFieldInfo(DiffFuncIndex index, const TypeInfo &type)
42-
: RecordField(type), Index(index) {}
41+
DiffFuncFieldInfo(DiffFuncIndex index, const TypeInfo &type,
42+
const SmallBitVector &parameterIndices)
43+
: RecordField(type), Index(index), ParameterIndices(parameterIndices) {}
4344

4445
/// The field index.
4546
const DiffFuncIndex Index;
4647

48+
/// The parameter indices.
49+
SmallBitVector ParameterIndices;
50+
4751
std::string getFieldName() const {
4852
auto extractee = std::get<0>(Index);
4953
auto differentiationOrder = std::get<1>(Index);
@@ -59,17 +63,14 @@ class DiffFuncFieldInfo final : public RecordField<DiffFuncFieldInfo> {
5963

6064
SILType getType(IRGenModule &IGM, SILType t) const {
6165
auto fnTy = t.castTo<SILFunctionType>();
62-
auto extInfo = fnTy->getExtInfo();
63-
auto nondiffExtInfo = extInfo.withDifferentiable(false);
64-
auto origFnTy = fnTy->getWithExtInfo(nondiffExtInfo);
66+
auto origFnTy = fnTy->getWithoutDifferentiability();
6567
if (std::get<0>(Index) == AutoDiffFunctionExtractInst::Extractee::Original)
6668
return SILType::getPrimitiveObjectType(origFnTy);
6769
auto differentiationOrder = std::get<1>(Index);
6870
auto kind = *std::get<0>(Index).getExtracteeAsAssociatedFunction();
6971
auto assocTy = origFnTy->getAutoDiffAssociatedFunctionType(
70-
SmallBitVector(origFnTy->getNumParameters(), true), /*resultIndex*/ 0,
71-
differentiationOrder, kind, IGM.getSILModule(),
72-
LookUpConformanceInModule(IGM.getSwiftModule()));
72+
ParameterIndices, /*resultIndex*/ 0, differentiationOrder, kind,
73+
IGM.getSILModule(), LookUpConformanceInModule(IGM.getSwiftModule()));
7374
return SILType::getPrimitiveObjectType(assocTy);
7475
}
7576
};
@@ -118,14 +119,13 @@ class DiffFuncTypeBuilder
118119
DiffFuncIndex> {
119120

120121
SILFunctionType *origFnTy;
122+
SmallBitVector parameterIndices;
121123

122124
public:
123125
DiffFuncTypeBuilder(IRGenModule &IGM, SILFunctionType *fnTy)
124-
: RecordTypeBuilder(IGM) {
126+
: RecordTypeBuilder(IGM), origFnTy(fnTy->getWithoutDifferentiability()),
127+
parameterIndices(fnTy->getDifferentiationParameterIndices()) {
125128
assert(fnTy->isDifferentiable());
126-
auto extInfo = fnTy->getExtInfo();
127-
auto nondiffExtInfo = extInfo.withDifferentiable(false);
128-
origFnTy = fnTy->getWithExtInfo(nondiffExtInfo);
129129
}
130130

131131
TypeInfo *createFixed(ArrayRef<DiffFuncFieldInfo> fields,
@@ -150,7 +150,7 @@ class DiffFuncTypeBuilder
150150

151151
DiffFuncFieldInfo getFieldInfo(unsigned index, DiffFuncIndex field,
152152
const TypeInfo &fieldTI) {
153-
return DiffFuncFieldInfo(field, fieldTI);
153+
return DiffFuncFieldInfo(field, fieldTI, parameterIndices);
154154
}
155155

156156
SILType getType(DiffFuncIndex field) {
@@ -159,9 +159,8 @@ class DiffFuncTypeBuilder
159159
auto differentiationOrder = std::get<1>(field);
160160
auto kind = *std::get<0>(field).getExtracteeAsAssociatedFunction();
161161
auto assocTy = origFnTy->getAutoDiffAssociatedFunctionType(
162-
SmallBitVector(origFnTy->getNumParameters(), true), /*resultIndex*/ 0,
163-
differentiationOrder, kind, IGM.getSILModule(),
164-
LookUpConformanceInModule(IGM.getSwiftModule()));
162+
parameterIndices, /*resultIndex*/ 0, differentiationOrder, kind,
163+
IGM.getSILModule(), LookUpConformanceInModule(IGM.getSwiftModule()));
165164
return SILType::getPrimitiveObjectType(assocTy);
166165
}
167166

lib/IRGen/IRGenSIL.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3580,10 +3580,8 @@ void IRGenSILFunction::visitFullApplySite(FullApplySite site) {
35803580
(void)adFnExp.claimAll();
35813581
tmpCalleeLV = LoweredValue(e);
35823582

3583-
origCalleeType = origCalleeType->getWithExtInfo(
3584-
origCalleeType->getExtInfo().withDifferentiable(false));
3585-
substCalleeType = substCalleeType->getWithExtInfo(
3586-
substCalleeType->getExtInfo().withDifferentiable(false));
3583+
origCalleeType = origCalleeType->getWithoutDifferentiability();
3584+
substCalleeType = substCalleeType->getWithoutDifferentiability();
35873585
}
35883586
const LoweredValue &calleeLV =
35893587
tmpCalleeLV ? *tmpCalleeLV : getLoweredValue(site.getCallee());

0 commit comments

Comments
 (0)