Skip to content

[AutoDiff] Propagate '@nondiff' from AST function types to SIL function types. #23854

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Apr 11, 2019
Merged
2 changes: 2 additions & 0 deletions include/swift/AST/DiagnosticsSIL.def
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,8 @@ NOTE(autodiff_protocol_member_not_differentiable,none,
NOTE(autodiff_protocol_member_subset_indices_not_differentiable,none,
"member is differentiable only with respect to a smaller subset of "
"arguments", ())
NOTE(autodiff_function_nondiff_parameter_not_differentiable,none,
"cannot differentiate with respect to a '@nondiff' parameter", ())
NOTE(autodiff_function_assoc_func_requirements_unmet,none,
"function call is not differentiable because generic requirements are not "
"met", ())
Expand Down
3 changes: 3 additions & 0 deletions include/swift/AST/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -3093,6 +3093,7 @@ class AnyFunctionType : public TypeBase {
return getExtInfo().getRepresentation();
}

// SWIFT_ENABLE_TENSORFLOW
/// Given `indices`, `differentiationOrder`, and `kind`, calculates the type
/// of the corresponding autodiff associated function.
///
Expand All @@ -3105,6 +3106,8 @@ class AnyFunctionType : public TypeBase {
LookupConformanceFn lookupConformance,
GenericSignature *whereClauseGenericSignature = nullptr);

AnyFunctionType *getWithoutDifferentiability() const;

/// \brief True if the parameter declaration it is attached to is guaranteed
/// to not persist the closure for longer than the duration of the call.
bool isNoEscape() const {
Expand Down
1 change: 0 additions & 1 deletion include/swift/SIL/SILBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,6 @@ class SILBuilder {
theFunction));
}


AutoDiffFunctionExtractInst *createAutoDiffFunctionExtractOriginal(
SILLocation loc, SILValue theFunction) {
return insert(new (getModule()) AutoDiffFunctionExtractInst(
Expand Down
2 changes: 1 addition & 1 deletion lib/AST/ASTPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2338,7 +2338,7 @@ static void printParameterFlags(ASTPrinter &printer, PrintOptions options,
printer << "@escaping ";
// SWIFT_ENABLE_TENSORFLOW
if (!options.excludeAttrKind(TAK_nondiff) && flags.isNonDifferentiable())
printer << "@nondiff";
printer << "@nondiff ";

switch (flags.getValueOwnership()) {
case ValueOwnership::Default:
Expand Down
15 changes: 6 additions & 9 deletions lib/AST/AutoDiff.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -254,15 +254,12 @@ AutoDiffParameterIndices::getLowered(AnyFunctionType *functionType) const {
}

static unsigned getNumAutoDiffParameterIndices(AnyFunctionType *fnTy) {
unsigned numAutoDiffParameterIndices = 0;
// FIXME: Compute the exact parameter count.
// Do not loop ad-infinitum; loop either 1 or 2 iterations, depending on
// whether the function is a free function/static method/instance method.
while (fnTy != nullptr) {
numAutoDiffParameterIndices += fnTy->getNumParams();
fnTy = fnTy->getResult()->getAs<AnyFunctionType>();
}
return numAutoDiffParameterIndices;
// TODO: For more correct counting, we still need to know whether it's a
// method or not.
unsigned numParameters = fnTy->getNumParams();
if (auto *innerFn = fnTy->getResult()->getAs<AnyFunctionType>())
numParameters += innerFn->getNumParams();
return numParameters;
}

/// Returns true if the given type conforms to `Differentiable` in the given
Expand Down
4 changes: 2 additions & 2 deletions lib/AST/Builtins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1075,8 +1075,8 @@ static ValueDecl *getAutoDiffApplyAssociatedFunction(
// of the associated function type.
auto *origFnTy =
firstArgGen.build(builder)->castTo<AnyFunctionType>();
origFnTy = origFnTy->withExtInfo(
origFnTy->getExtInfo().withDifferentiable(false).withNoEscape(false));
origFnTy = origFnTy->getWithoutDifferentiability()->withExtInfo(
origFnTy->getExtInfo().withNoEscape(false));
auto autodiffBuilder = AutoDiffParameterIndicesBuilder::inferParameters(
origFnTy, Context.getStdlibModule());
auto *paramIndices = autodiffBuilder.build(Context);
Expand Down
15 changes: 15 additions & 0 deletions lib/AST/Type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4304,3 +4304,18 @@ AnyFunctionType *AnyFunctionType::getAutoDiffAssociatedFunctionType(

return associatedFunction;
}

AnyFunctionType *AnyFunctionType::getWithoutDifferentiability() const {
SmallVector<Param, 8> newParams;
for (auto &param : getParams()) {
Param newParam(param.getPlainType(), param.getLabel(),
param.getParameterFlags().withNonDifferentiable(false));
newParams.push_back(newParam);
}
auto nonDiffExtInfo = getExtInfo().withDifferentiable(false);
if (isa<FunctionType>(this))
return FunctionType::get(newParams, getResult(), nonDiffExtInfo);
assert(isa<GenericFunctionType>(this));
return GenericFunctionType::get(getOptGenericSignature(), newParams,
getResult(), nonDiffExtInfo);
}
31 changes: 15 additions & 16 deletions lib/IRGen/GenDiffFunc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,16 @@ using DiffFuncIndex =
namespace {
class DiffFuncFieldInfo final : public RecordField<DiffFuncFieldInfo> {
public:
DiffFuncFieldInfo(DiffFuncIndex index, const TypeInfo &type)
: RecordField(type), Index(index) {}
DiffFuncFieldInfo(DiffFuncIndex index, const TypeInfo &type,
const SmallBitVector &parameterIndices)
: RecordField(type), Index(index), ParameterIndices(parameterIndices) {}

/// The field index.
const DiffFuncIndex Index;

/// The parameter indices.
SmallBitVector ParameterIndices;

std::string getFieldName() const {
auto extractee = std::get<0>(Index);
auto differentiationOrder = std::get<1>(Index);
Expand All @@ -59,17 +63,14 @@ class DiffFuncFieldInfo final : public RecordField<DiffFuncFieldInfo> {

SILType getType(IRGenModule &IGM, SILType t) const {
auto fnTy = t.castTo<SILFunctionType>();
auto extInfo = fnTy->getExtInfo();
auto nondiffExtInfo = extInfo.withDifferentiable(false);
auto origFnTy = fnTy->getWithExtInfo(nondiffExtInfo);
auto origFnTy = fnTy->getWithoutDifferentiability();
if (std::get<0>(Index) == AutoDiffFunctionExtractInst::Extractee::Original)
return SILType::getPrimitiveObjectType(origFnTy);
auto differentiationOrder = std::get<1>(Index);
auto kind = *std::get<0>(Index).getExtracteeAsAssociatedFunction();
auto assocTy = origFnTy->getAutoDiffAssociatedFunctionType(
SmallBitVector(origFnTy->getNumParameters(), true), /*resultIndex*/ 0,
differentiationOrder, kind, IGM.getSILModule(),
LookUpConformanceInModule(IGM.getSwiftModule()));
ParameterIndices, /*resultIndex*/ 0, differentiationOrder, kind,
IGM.getSILModule(), LookUpConformanceInModule(IGM.getSwiftModule()));
return SILType::getPrimitiveObjectType(assocTy);
}
};
Expand Down Expand Up @@ -118,14 +119,13 @@ class DiffFuncTypeBuilder
DiffFuncIndex> {

SILFunctionType *origFnTy;
SmallBitVector parameterIndices;

public:
DiffFuncTypeBuilder(IRGenModule &IGM, SILFunctionType *fnTy)
: RecordTypeBuilder(IGM) {
: RecordTypeBuilder(IGM), origFnTy(fnTy->getWithoutDifferentiability()),
parameterIndices(fnTy->getDifferentiationParameterIndices()) {
assert(fnTy->isDifferentiable());
auto extInfo = fnTy->getExtInfo();
auto nondiffExtInfo = extInfo.withDifferentiable(false);
origFnTy = fnTy->getWithExtInfo(nondiffExtInfo);
}

TypeInfo *createFixed(ArrayRef<DiffFuncFieldInfo> fields,
Expand All @@ -150,7 +150,7 @@ class DiffFuncTypeBuilder

DiffFuncFieldInfo getFieldInfo(unsigned index, DiffFuncIndex field,
const TypeInfo &fieldTI) {
return DiffFuncFieldInfo(field, fieldTI);
return DiffFuncFieldInfo(field, fieldTI, parameterIndices);
}

SILType getType(DiffFuncIndex field) {
Expand All @@ -159,9 +159,8 @@ class DiffFuncTypeBuilder
auto differentiationOrder = std::get<1>(field);
auto kind = *std::get<0>(field).getExtracteeAsAssociatedFunction();
auto assocTy = origFnTy->getAutoDiffAssociatedFunctionType(
SmallBitVector(origFnTy->getNumParameters(), true), /*resultIndex*/ 0,
differentiationOrder, kind, IGM.getSILModule(),
LookUpConformanceInModule(IGM.getSwiftModule()));
parameterIndices, /*resultIndex*/ 0, differentiationOrder, kind,
IGM.getSILModule(), LookUpConformanceInModule(IGM.getSwiftModule()));
return SILType::getPrimitiveObjectType(assocTy);
}

Expand Down
6 changes: 2 additions & 4 deletions lib/IRGen/IRGenSIL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3580,10 +3580,8 @@ void IRGenSILFunction::visitFullApplySite(FullApplySite site) {
(void)adFnExp.claimAll();
tmpCalleeLV = LoweredValue(e);

origCalleeType = origCalleeType->getWithExtInfo(
origCalleeType->getExtInfo().withDifferentiable(false));
substCalleeType = substCalleeType->getWithExtInfo(
substCalleeType->getExtInfo().withDifferentiable(false));
origCalleeType = origCalleeType->getWithoutDifferentiability();
substCalleeType = substCalleeType->getWithoutDifferentiability();
}
const LoweredValue &calleeLV =
tmpCalleeLV ? *tmpCalleeLV : getLoweredValue(site.getCallee());
Expand Down
Loading