Skip to content

[AutoDiff] Re-enable LoadableByAddress. #27923

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/SIL.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5673,17 +5673,23 @@ differentiable_function_extract
sil-instruction ::= 'differentiable_function_extract'
'[' sil-differentiable-function-extractee ']'
sil-value ':' sil-type
('as' sil-type)?

sil-differentiable-function-extractee ::= 'original' | 'jvp' | 'vjp'

differentiable_function_extract [original] %0 : $@differentiable (T) -> T
differentiable_function_extract [jvp] %0 : $@differentiable (T) -> T
differentiable_function_extract [vjp] %0 : $@differentiable (T) -> T
differentiable_function_extract [jvp] %0 : $@differentiable (T) -> T \
as $(@in_constant T) -> (T, (T.TangentVector) -> T.TangentVector)

Extracts the original function or a derivative function from the given
``@differentiable`` function. It must be provided with an extractee:
``[original]``, ``[jvp]`` or ``[vjp]``.

An explicit extractee type may be provided in lowered SIL. This is currently
used by the LoadableByAddress transformation, which rewrites function types.


linear_function_extract
```````````````````````
Expand Down
5 changes: 1 addition & 4 deletions include/swift/AST/SILOptions.h
Original file line number Diff line number Diff line change
Expand Up @@ -143,10 +143,7 @@ class SILOptions {
bool EnableDynamicReplacementCanCallPreviousImplementation = true;

/// Enable large loadable types IRGen pass.
// bool EnableLargeLoadableTypes = true;
// FIXME(TF-11, SR-9849): Disabled because LoadableByAddress cannot handle
// some functions that return closures that take/return large loadable types.
bool EnableLargeLoadableTypes = false;
bool EnableLargeLoadableTypes = true;

/// Should the default pass pipelines strip ownership during the diagnostic
/// pipeline or after serialization.
Expand Down
8 changes: 5 additions & 3 deletions include/swift/SIL/SILBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -526,12 +526,14 @@ class SILBuilder {
getModule(), getSILDebugLocation(Loc), ParameterIndices,
OriginalFunction, TransposeFunction, hasOwnership()));
}


/// Note: explicit extractee type may be specified only in lowered SIL.
DifferentiableFunctionExtractInst *createDifferentiableFunctionExtract(
SILLocation Loc, NormalDifferentiableFunctionTypeComponent Extractee,
SILValue TheFunction) {
SILValue TheFunction, Optional<SILType> ExtracteeType = None) {
return insert(new (getModule()) DifferentiableFunctionExtractInst(
getModule(), getSILDebugLocation(Loc), Extractee, TheFunction));
getModule(), getSILDebugLocation(Loc), Extractee, TheFunction,
ExtracteeType));
}

LinearFunctionExtractInst *createLinearFunctionExtract(
Expand Down
20 changes: 12 additions & 8 deletions include/swift/SIL/SILInstruction.h
Original file line number Diff line number Diff line change
Expand Up @@ -7969,34 +7969,38 @@ class DifferentiableFunctionExtractInst
SingleValueInstruction> {
private:
/// The extractee.
NormalDifferentiableFunctionTypeComponent extractee;
NormalDifferentiableFunctionTypeComponent Extractee;
/// The list containing the `@differentiable` function operand.
FixedOperandList<1> operands;
FixedOperandList<1> Operands;
/// True if the instruction has an explicit extractee type.
bool HasExplicitExtracteeType;

static SILType
getExtracteeType(
SILValue function, NormalDifferentiableFunctionTypeComponent extractee,
SILModule &module);

public:
/// Note: explicit extractee type may be specified only in lowered SIL.
explicit DifferentiableFunctionExtractInst(
SILModule &module, SILDebugLocation debugLoc,
NormalDifferentiableFunctionTypeComponent extractee,
SILValue theFunction);
SILValue theFunction, Optional<SILType> extracteeType = None);

NormalDifferentiableFunctionTypeComponent getExtractee() const {
return extractee;
return Extractee;
}

AutoDiffDerivativeFunctionKind getDerivativeFunctionKind() const {
auto kind = extractee.getAsDerivativeFunctionKind();
auto kind = Extractee.getAsDerivativeFunctionKind();
assert(kind);
return *kind;
}

SILValue getFunctionOperand() const { return operands[0].get(); }
ArrayRef<Operand> getAllOperands() const { return operands.asArray(); }
MutableArrayRef<Operand> getAllOperands() { return operands.asArray(); }
SILValue getFunctionOperand() const { return Operands[0].get(); }
ArrayRef<Operand> getAllOperands() const { return Operands.asArray(); }
MutableArrayRef<Operand> getAllOperands() { return Operands.asArray(); }
bool hasExplicitExtracteeType() const { return HasExplicitExtracteeType; }
};

/// `linear_function_extract` - given an `@differentiable(linear)` function
Expand Down
4 changes: 2 additions & 2 deletions lib/AST/ASTPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4511,10 +4511,10 @@ void SILParameterInfo::print(ASTPrinter &Printer,
const PrintOptions &Opts) const {
/// SWIFT_ENABLE_TENSORFLOW
switch (getDifferentiability()) {
case SILParameterDifferentiability::NotDifferentiable:
case SILParameterDifferentiability::NotDifferentiable:
Printer << "@nondiff ";
break;
default:
default:
break;
}
Printer << getStringForParameterConvention(getConvention());
Expand Down
19 changes: 15 additions & 4 deletions lib/IRGen/LoadableByAddress.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -381,14 +381,23 @@ SILParameterInfo LargeSILTypeMapper::getNewParameter(GenericEnvironment *env,
} else if (isLargeLoadableType(env, storageType, IGM)) {
if (param.getConvention() == ParameterConvention::Direct_Guaranteed)
return SILParameterInfo(storageType.getASTType(),
ParameterConvention::Indirect_In_Guaranteed);
// SWIFT_ENABLE_TENSORFLOW
ParameterConvention::Indirect_In_Guaranteed,
param.getDifferentiability());
// SWIFT_ENABLE_TENSORFLOW_END
else
return SILParameterInfo(storageType.getASTType(),
ParameterConvention::Indirect_In_Constant);
// SWIFT_ENABLE_TENSORFLOW
ParameterConvention::Indirect_In_Constant,
param.getDifferentiability());
// SWIFT_ENABLE_TENSORFLOW_END
} else {
auto newType = getNewSILType(env, storageType, IGM);
return SILParameterInfo(newType.getASTType(),
param.getConvention());
// SWIFT_ENABLE_TENSORFLOW
param.getConvention(),
param.getDifferentiability());
// SWIFT_ENABLE_TENSORFLOW_END
}
}

Expand Down Expand Up @@ -2757,8 +2766,10 @@ bool LoadableByAddress::recreateConvInstr(SILInstruction &I,
}
case SILInstructionKind::DifferentiableFunctionExtractInst: {
auto instr = cast<DifferentiableFunctionExtractInst>(convInstr);
// Rewrite `differentiable_function_extract` with explicit extractee type.
newInstr = convBuilder.createDifferentiableFunctionExtract(
instr->getLoc(), instr->getExtractee(), instr->getFunctionOperand());
instr->getLoc(), instr->getExtractee(), instr->getFunctionOperand(),
newType);
break;
}
case SILInstructionKind::LinearFunctionExtractInst: {
Expand Down
17 changes: 13 additions & 4 deletions lib/ParseSIL/ParseSIL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3050,7 +3050,8 @@ bool SILParser::parseSILInstruction(SILBuilder &B) {

case SILInstructionKind::DifferentiableFunctionExtractInst: {
// Parse the rest of the instruction: an extractee, a differentiable
// function operand, and a debug location.
// function operand, an optional explicit extractee type, and a debug
// location.
NormalDifferentiableFunctionTypeComponent extractee;
StringRef extracteeNames[3] = {"original", "jvp", "vjp"};
SILValue functionOperand;
Expand All @@ -3062,11 +3063,19 @@ bool SILParser::parseSILInstruction(SILBuilder &B) {
P.parseToken(tok::r_square, diag::sil_autodiff_expected_rsquare,
"extractee kind"))
return true;
if (parseTypedValueRef(functionOperand, B) ||
parseSILDebugLocation(InstLoc, B))
if (parseTypedValueRef(functionOperand, B))
return true;
// Parse an optional explicit extractee type.
Optional<SILType> extracteeType = None;
if (P.consumeIf(tok::kw_as)) {
extracteeType = SILType();
if (parseSILType(*extracteeType))
return true;
}
if (parseSILDebugLocation(InstLoc, B))
return true;
ResultVal = B.createDifferentiableFunctionExtract(
InstLoc, extractee, functionOperand);
InstLoc, extractee, functionOperand, extracteeType);
break;
}

Expand Down
17 changes: 14 additions & 3 deletions lib/SIL/SILInstructions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -685,10 +685,21 @@ getExtracteeType(

DifferentiableFunctionExtractInst::DifferentiableFunctionExtractInst(
SILModule &module, SILDebugLocation debugLoc,
NormalDifferentiableFunctionTypeComponent extractee, SILValue theFunction)
NormalDifferentiableFunctionTypeComponent extractee, SILValue theFunction,
Optional<SILType> extracteeType)
: InstructionBase(debugLoc,
getExtracteeType(theFunction, extractee, module)),
extractee(extractee), operands(this, theFunction) {}
extracteeType
? *extracteeType
: getExtracteeType(theFunction, extractee, module)),
Extractee(extractee), Operands(this, theFunction),
HasExplicitExtracteeType(extracteeType.hasValue()) {
#ifndef NDEBUG
if (extracteeType.hasValue()) {
assert(module.getStage() == SILStage::Lowered &&
"Explicit type is valid only in lowered SIL");
}
#endif
}

SILType LinearFunctionExtractInst::
getExtracteeType(
Expand Down
4 changes: 4 additions & 0 deletions lib/SIL/SILPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1251,6 +1251,10 @@ class SILPrinter : public SILInstructionVisitor<SILPrinter> {
}
*this << "] ";
*this << getIDAndType(dfei->getFunctionOperand());
if (dfei->hasExplicitExtracteeType()) {
*this << " as ";
*this << dfei->getType();
}
}

void visitLinearFunctionExtractInst(LinearFunctionExtractInst *lfei) {
Expand Down
16 changes: 12 additions & 4 deletions lib/SIL/SILVerifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1498,8 +1498,12 @@ class SILVerifier : public SILVerifierBase<SILVerifier> {
require(origTy, "The original function must have a function type");
require(!origTy->isDifferentiable(),
"The original function must not be @differentiable");
if (F.getModule().getStage() == SILStage::Canonical ||
dfi->hasDerivativeFunctions()) {
// Skip lowered SIL: LoadableByAddress changes parameter/result conventions.
// TODO: Check that derivative function types match excluding
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: robust verification here should be aware of the lowered SIL stage and check that address-lowered SIL function types match.

In the meantime, we could use a custom function for recursively checking SIL function type equality, ignoring parameter/result conventions. SILFunctionType::isABICompatibleWith is not suitable because it checks parameter/result conventions and returns an opaque ABICompatibilityCheckResult.

// parameter/result conventions in lowered SIL.
if (F.getModule().getStage() == SILStage::Lowered)
return;
if (dfi->hasDerivativeFunctions()) {
auto jvp = dfi->getJVPFunction();
auto jvpType = jvp->getType().getAs<SILFunctionType>();
require(jvpType, "The JVP function must have a function type");
Expand Down Expand Up @@ -1533,8 +1537,12 @@ class SILVerifier : public SILVerifierBase<SILVerifier> {
require(origTy, "The original function must have a function type");
require(!origTy->isDifferentiable(),
"The original function must not be differentiable");
if (F.getModule().getStage() == SILStage::Canonical ||
lfi->hasTransposeFunction()) {
// Skip lowered SIL: LoadableByAddress changes parameter/result conventions.
// TODO: Check that transpose function type matches excluding
// parameter/result conventions in lowered SIL.
if (F.getModule().getStage() == SILStage::Lowered)
return;
if (lfi->hasTransposeFunction()) {
auto transpose = lfi->getTransposeFunction();
auto transposeType = transpose->getType().getAs<SILFunctionType>();
require(transposeType,
Expand Down
9 changes: 7 additions & 2 deletions lib/Serialization/DeserializeSIL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1146,7 +1146,8 @@ bool SILDeserializer::readSILInstruction(SILFunction *Fn, SILBasicBlock *BB,
break;
case SIL_INST_DIFFERENTIABLE_FUNCTION_EXTRACT:
SILInstDifferentiableFunctionExtractLayout::readRecord(
scratch, TyID, TyCategory, ValID, /*extractee*/ Attr);
scratch, TyID, TyCategory, ValID, /*extractee*/ Attr,
/*hasExplicitExtracteeType*/ Attr2);
RawOpCode = (unsigned)SILInstructionKind::DifferentiableFunctionExtractInst;
break;
case SIL_INST_LINEAR_FUNCTION_EXTRACT:
Expand Down Expand Up @@ -1609,8 +1610,12 @@ bool SILDeserializer::readSILInstruction(SILFunction *Fn, SILBasicBlock *BB,
auto silTy = getSILType(astTy, SILValueCategory::Object);
auto val = getLocalValue(ValID, silTy);
NormalDifferentiableFunctionTypeComponent extractee(Attr);
Optional<SILType> explicitExtracteeType = None;
if (Attr2)
explicitExtracteeType = silTy;
ResultVal =
Builder.createDifferentiableFunctionExtract(Loc, extractee, val);
Builder.createDifferentiableFunctionExtract(Loc, extractee, val,
explicitExtracteeType);
break;
}
case SILInstructionKind::LinearFunctionExtractInst: {
Expand Down
2 changes: 1 addition & 1 deletion lib/Serialization/ModuleFormat.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ const uint16_t SWIFTMODULE_VERSION_MAJOR = 0;
/// describe what change you made. The content of this comment isn't important;
/// it just ensures a conflict if two people change the module format.
/// Don't worry about adhering to the 80-column limit for this line.
const uint16_t SWIFTMODULE_VERSION_MINOR = 523; // differentiable_function and differentiable_function_extract instructions
const uint16_t SWIFTMODULE_VERSION_MINOR = 524; // differentiable_function_extract explicit extractee type

/// A standard hash seed used for all string hashes in a serialized module.
///
Expand Down
3 changes: 2 additions & 1 deletion lib/Serialization/SILFormat.h
Original file line number Diff line number Diff line change
Expand Up @@ -456,7 +456,8 @@ namespace sil_block {
TypeIDField,
SILTypeCategoryField,
ValueIDField,
BCFixed<2> // extractee
BCFixed<2>, // extractee
BCFixed<1> // has explicit extractee type?
>;

using SILInstLinearFunctionExtractLayout = BCRecordLayout<
Expand Down
2 changes: 1 addition & 1 deletion lib/Serialization/SerializeSIL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1043,7 +1043,7 @@ void SILSerializer::writeSILInstruction(const SILInstruction &SI) {
SILInstDifferentiableFunctionExtractLayout::emitRecord(Out, ScratchRecord,
SILAbbrCodes[SILInstDifferentiableFunctionExtractLayout::Code],
operandTypeRef, (unsigned)operandType.getCategory(), operandRef,
rawExtractee);
rawExtractee, (unsigned)dfei->hasExplicitExtracteeType());
break;
}
case SILInstructionKind::LinearFunctionExtractInst: {
Expand Down
64 changes: 64 additions & 0 deletions test/AutoDiff/differentiable_function_inst_lowered.sil
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
// RUN: %target-sil-opt %s | %target-sil-opt | %FileCheck %s

// Test `differentiable_function_extract` with explicit lowered type.
// SIL generated via `%target-sil-opt -loadable-address %s`.
// Note: SIL serialization/deserialization does not support lowered SIL.

sil_stage lowered

import Swift
import Builtin

struct Large : _Differentiable {
@_hasStorage @noDerivative let a: Float { get }
@_hasStorage @noDerivative let b: Float { get }
@_hasStorage @noDerivative let c: Float { get }
@_hasStorage @noDerivative let d: Float { get }
@_hasStorage @noDerivative let e: Float { get }
init(a: Float, b: Float, c: Float, d: Float, e: Float)
struct TangentVector : _Differentiable, AdditiveArithmetic {
init()
typealias TangentVector = Large.TangentVector
static var zero: Large.TangentVector { get }
static func + (lhs: Large.TangentVector, rhs: Large.TangentVector) -> Large.TangentVector
static func - (lhs: Large.TangentVector, rhs: Large.TangentVector) -> Large.TangentVector
@_implements(Equatable, ==(_:_:)) static func __derived_struct_equals(_ a: Large.TangentVector, _ b: Large.TangentVector) -> Bool
}
mutating func move(along direction: Large.TangentVector)
}

sil @examplefunc : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
sil @examplemethod : $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large

// CHECK-LABEL: sil @test
sil @test : $@convention(thin) () -> () {
bb0:
%0 = function_ref @examplefunc : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
%1 = differentiable_function [parameters 0 1 2] %0 : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
%2 = differentiable_function_extract [vjp] %1 : $@differentiable @convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large as $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> (Large.TangentVector, Large.TangentVector, Large.TangentVector))

// CHECK: %1 = differentiable_function [parameters 0 1 2] %0 : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
// CHECK: %2 = differentiable_function_extract [vjp] %1 : $@differentiable @convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large as $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> (Large.TangentVector, Large.TangentVector, Large.TangentVector))

%3 = differentiable_function [parameters 0] %0 : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
%4 = differentiable_function_extract [vjp] %3 : $@differentiable @convention(thin) (@in_constant Large, @nondiff @in_constant Large, @nondiff @in_constant Large) -> @out Large as $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> Large.TangentVector)

// CHECK: %3 = differentiable_function [parameters 0] %0 : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
// CHECK: %4 = differentiable_function_extract [vjp] %3 : $@differentiable @convention(thin) (@in_constant Large, @nondiff @in_constant Large, @nondiff @in_constant Large) -> @out Large as $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> Large.TangentVector)

%5 = function_ref @examplemethod : $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
%6 = differentiable_function [parameters 0 1 2] %5 : $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
%7 = differentiable_function_extract [vjp] %6 : $@differentiable @convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large as $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> (Large.TangentVector, Large.TangentVector, Large.TangentVector))

// CHECK: %6 = differentiable_function [parameters 0 1 2] %5 : $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
// CHECK: %7 = differentiable_function_extract [vjp] %6 : $@differentiable @convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large as $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> (Large.TangentVector, Large.TangentVector, Large.TangentVector))

%8 = differentiable_function [parameters 0] %5 : $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
%9 = differentiable_function_extract [vjp] %8 : $@differentiable @convention(method) (@in_constant Large, @nondiff @in_constant Large, @nondiff @in_constant Large) -> @out Large as $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> Large.TangentVector)

// CHECK: %8 = differentiable_function [parameters 0] %5 : $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
// CHECK: %9 = differentiable_function_extract [vjp] %8 : $@differentiable @convention(method) (@in_constant Large, @nondiff @in_constant Large, @nondiff @in_constant Large) -> @out Large as $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> Large.TangentVector)

%10 = tuple ()
return %10 : $()
}
Loading