Skip to content

Commit cadc7c1

Browse files
committed
[AutoDiff] Re-enable LoadableByAddress.
The LoadableByAddress transform rewrites function types. This caused verification failures in lowered SIL for `differentiable_function` and `differentiable_function_extract` instructions, which have precise type verification rules. Now, verification has been relaxed for these instructions in lowered SIL. - `differentiable_function_extract` can now have an explicit extractee type in lowered SIL: parsing/printing/serialization support have been added. - LoadableByAddress rewrites `differentiable_function_extract` instructions using explicit address-lowered function type, similar to other function conversion instructions. - Add SIL FileCheck and runtime tests.
1 parent 1862c95 commit cadc7c1

16 files changed

+234
-37
lines changed

docs/SIL.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5673,17 +5673,23 @@ differentiable_function_extract
56735673
sil-instruction ::= 'differentiable_function_extract'
56745674
'[' sil-differentiable-function-extractee ']'
56755675
sil-value ':' sil-type
5676+
('as' sil-type)?
56765677

56775678
sil-differentiable-function-extractee ::= 'original' | 'jvp' | 'vjp'
56785679

56795680
differentiable_function_extract [original] %0 : $@differentiable (T) -> T
56805681
differentiable_function_extract [jvp] %0 : $@differentiable (T) -> T
56815682
differentiable_function_extract [vjp] %0 : $@differentiable (T) -> T
5683+
differentiable_function_extract [jvp] %0 : $@differentiable (T) -> T \
5684+
as $(@in_constant T) -> (T, (T.TangentVector) -> T.TangentVector)
56825685

56835686
Extracts the original function or a derivative function from the given
56845687
``@differentiable`` function. It must be provided with an extractee:
56855688
``[original]``, ``[jvp]`` or ``[vjp]``.
56865689

5690+
An explicit extractee type may be provided in lowered SIL. This is currently
5691+
used by the LoadableByAddress transformation, which rewrites function types.
5692+
56875693

56885694
linear_function_extract
56895695
```````````````````````

include/swift/AST/SILOptions.h

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -143,10 +143,7 @@ class SILOptions {
143143
bool EnableDynamicReplacementCanCallPreviousImplementation = true;
144144

145145
/// Enable large loadable types IRGen pass.
146-
// bool EnableLargeLoadableTypes = true;
147-
// FIXME(TF-11, SR-9849): Disabled because LoadableByAddress cannot handle
148-
// some functions that return closures that take/return large loadable types.
149-
bool EnableLargeLoadableTypes = false;
146+
bool EnableLargeLoadableTypes = true;
150147

151148
/// Should the default pass pipelines strip ownership during the diagnostic
152149
/// pipeline or after serialization.

include/swift/SIL/SILBuilder.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -526,12 +526,14 @@ class SILBuilder {
526526
getModule(), getSILDebugLocation(Loc), ParameterIndices,
527527
OriginalFunction, TransposeFunction, hasOwnership()));
528528
}
529-
529+
530+
/// Note: explicit extractee type may be specified only in lowered SIL.
530531
DifferentiableFunctionExtractInst *createDifferentiableFunctionExtract(
531532
SILLocation Loc, NormalDifferentiableFunctionTypeComponent Extractee,
532-
SILValue TheFunction) {
533+
SILValue TheFunction, Optional<SILType> ExtracteeType = None) {
533534
return insert(new (getModule()) DifferentiableFunctionExtractInst(
534-
getModule(), getSILDebugLocation(Loc), Extractee, TheFunction));
535+
getModule(), getSILDebugLocation(Loc), Extractee, TheFunction,
536+
ExtracteeType));
535537
}
536538

537539
LinearFunctionExtractInst *createLinearFunctionExtract(

include/swift/SIL/SILInstruction.h

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7969,34 +7969,38 @@ class DifferentiableFunctionExtractInst
79697969
SingleValueInstruction> {
79707970
private:
79717971
/// The extractee.
7972-
NormalDifferentiableFunctionTypeComponent extractee;
7972+
NormalDifferentiableFunctionTypeComponent Extractee;
79737973
/// The list containing the `@differentiable` function operand.
7974-
FixedOperandList<1> operands;
7974+
FixedOperandList<1> Operands;
7975+
/// True if the instruction has an explicit extractee type.
7976+
bool HasExplicitExtracteeType;
79757977

79767978
static SILType
79777979
getExtracteeType(
79787980
SILValue function, NormalDifferentiableFunctionTypeComponent extractee,
79797981
SILModule &module);
79807982

79817983
public:
7984+
/// Note: explicit extractee type may be specified only in lowered SIL.
79827985
explicit DifferentiableFunctionExtractInst(
79837986
SILModule &module, SILDebugLocation debugLoc,
79847987
NormalDifferentiableFunctionTypeComponent extractee,
7985-
SILValue theFunction);
7988+
SILValue theFunction, Optional<SILType> extracteeType = None);
79867989

79877990
NormalDifferentiableFunctionTypeComponent getExtractee() const {
7988-
return extractee;
7991+
return Extractee;
79897992
}
79907993

79917994
AutoDiffDerivativeFunctionKind getDerivativeFunctionKind() const {
7992-
auto kind = extractee.getAsDerivativeFunctionKind();
7995+
auto kind = Extractee.getAsDerivativeFunctionKind();
79937996
assert(kind);
79947997
return *kind;
79957998
}
79967999

7997-
SILValue getFunctionOperand() const { return operands[0].get(); }
7998-
ArrayRef<Operand> getAllOperands() const { return operands.asArray(); }
7999-
MutableArrayRef<Operand> getAllOperands() { return operands.asArray(); }
8000+
SILValue getFunctionOperand() const { return Operands[0].get(); }
8001+
ArrayRef<Operand> getAllOperands() const { return Operands.asArray(); }
8002+
MutableArrayRef<Operand> getAllOperands() { return Operands.asArray(); }
8003+
bool hasExplicitExtracteeType() const { return HasExplicitExtracteeType; }
80008004
};
80018005

80028006
/// `linear_function_extract` - given an `@differentiable(linear)` function

lib/AST/ASTPrinter.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4511,10 +4511,10 @@ void SILParameterInfo::print(ASTPrinter &Printer,
45114511
const PrintOptions &Opts) const {
45124512
/// SWIFT_ENABLE_TENSORFLOW
45134513
switch (getDifferentiability()) {
4514-
case SILParameterDifferentiability::NotDifferentiable:
4514+
case SILParameterDifferentiability::NotDifferentiable:
45154515
Printer << "@nondiff ";
45164516
break;
4517-
default:
4517+
default:
45184518
break;
45194519
}
45204520
Printer << getStringForParameterConvention(getConvention());

lib/IRGen/LoadableByAddress.cpp

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -381,14 +381,23 @@ SILParameterInfo LargeSILTypeMapper::getNewParameter(GenericEnvironment *env,
381381
} else if (isLargeLoadableType(env, storageType, IGM)) {
382382
if (param.getConvention() == ParameterConvention::Direct_Guaranteed)
383383
return SILParameterInfo(storageType.getASTType(),
384-
ParameterConvention::Indirect_In_Guaranteed);
384+
// SWIFT_ENABLE_TENSORFLOW
385+
ParameterConvention::Indirect_In_Guaranteed,
386+
param.getDifferentiability());
387+
// SWIFT_ENABLE_TENSORFLOW_END
385388
else
386389
return SILParameterInfo(storageType.getASTType(),
387-
ParameterConvention::Indirect_In_Constant);
390+
// SWIFT_ENABLE_TENSORFLOW
391+
ParameterConvention::Indirect_In_Constant,
392+
param.getDifferentiability());
393+
// SWIFT_ENABLE_TENSORFLOW_END
388394
} else {
389395
auto newType = getNewSILType(env, storageType, IGM);
390396
return SILParameterInfo(newType.getASTType(),
391-
param.getConvention());
397+
// SWIFT_ENABLE_TENSORFLOW
398+
param.getConvention(),
399+
param.getDifferentiability());
400+
// SWIFT_ENABLE_TENSORFLOW_END
392401
}
393402
}
394403

@@ -2757,8 +2766,10 @@ bool LoadableByAddress::recreateConvInstr(SILInstruction &I,
27572766
}
27582767
case SILInstructionKind::DifferentiableFunctionExtractInst: {
27592768
auto instr = cast<DifferentiableFunctionExtractInst>(convInstr);
2769+
// Rewrite `differentiable_function_extract` with explicit extractee type.
27602770
newInstr = convBuilder.createDifferentiableFunctionExtract(
2761-
instr->getLoc(), instr->getExtractee(), instr->getFunctionOperand());
2771+
instr->getLoc(), instr->getExtractee(), instr->getFunctionOperand(),
2772+
newType);
27622773
break;
27632774
}
27642775
case SILInstructionKind::LinearFunctionExtractInst: {

lib/ParseSIL/ParseSIL.cpp

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3050,7 +3050,8 @@ bool SILParser::parseSILInstruction(SILBuilder &B) {
30503050

30513051
case SILInstructionKind::DifferentiableFunctionExtractInst: {
30523052
// Parse the rest of the instruction: an extractee, a differentiable
3053-
// function operand, and a debug location.
3053+
// function operand, an optional explicit extractee type, and a debug
3054+
// location.
30543055
NormalDifferentiableFunctionTypeComponent extractee;
30553056
StringRef extracteeNames[3] = {"original", "jvp", "vjp"};
30563057
SILValue functionOperand;
@@ -3062,11 +3063,19 @@ bool SILParser::parseSILInstruction(SILBuilder &B) {
30623063
P.parseToken(tok::r_square, diag::sil_autodiff_expected_rsquare,
30633064
"extractee kind"))
30643065
return true;
3065-
if (parseTypedValueRef(functionOperand, B) ||
3066-
parseSILDebugLocation(InstLoc, B))
3066+
if (parseTypedValueRef(functionOperand, B))
3067+
return true;
3068+
// Parse an optional explicit extractee type.
3069+
Optional<SILType> extracteeType = None;
3070+
if (P.consumeIf(tok::kw_as)) {
3071+
extracteeType = SILType();
3072+
if (parseSILType(*extracteeType))
3073+
return true;
3074+
}
3075+
if (parseSILDebugLocation(InstLoc, B))
30673076
return true;
30683077
ResultVal = B.createDifferentiableFunctionExtract(
3069-
InstLoc, extractee, functionOperand);
3078+
InstLoc, extractee, functionOperand, extracteeType);
30703079
break;
30713080
}
30723081

lib/SIL/SILInstructions.cpp

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -685,10 +685,21 @@ getExtracteeType(
685685

686686
DifferentiableFunctionExtractInst::DifferentiableFunctionExtractInst(
687687
SILModule &module, SILDebugLocation debugLoc,
688-
NormalDifferentiableFunctionTypeComponent extractee, SILValue theFunction)
688+
NormalDifferentiableFunctionTypeComponent extractee, SILValue theFunction,
689+
Optional<SILType> extracteeType)
689690
: InstructionBase(debugLoc,
690-
getExtracteeType(theFunction, extractee, module)),
691-
extractee(extractee), operands(this, theFunction) {}
691+
extracteeType
692+
? *extracteeType
693+
: getExtracteeType(theFunction, extractee, module)),
694+
Extractee(extractee), Operands(this, theFunction),
695+
HasExplicitExtracteeType(extracteeType.hasValue()) {
696+
#ifndef NDEBUG
697+
if (extracteeType.hasValue()) {
698+
assert(module.getStage() == SILStage::Lowered &&
699+
"Explicit type is valid only in lowered SIL");
700+
}
701+
#endif
702+
}
692703

693704
SILType LinearFunctionExtractInst::
694705
getExtracteeType(

lib/SIL/SILPrinter.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1251,6 +1251,10 @@ class SILPrinter : public SILInstructionVisitor<SILPrinter> {
12511251
}
12521252
*this << "] ";
12531253
*this << getIDAndType(dfei->getFunctionOperand());
1254+
if (dfei->hasExplicitExtracteeType()) {
1255+
*this << " as ";
1256+
*this << dfei->getType();
1257+
}
12541258
}
12551259

12561260
void visitLinearFunctionExtractInst(LinearFunctionExtractInst *lfei) {

lib/SIL/SILVerifier.cpp

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1498,8 +1498,12 @@ class SILVerifier : public SILVerifierBase<SILVerifier> {
14981498
require(origTy, "The original function must have a function type");
14991499
require(!origTy->isDifferentiable(),
15001500
"The original function must not be @differentiable");
1501-
if (F.getModule().getStage() == SILStage::Canonical ||
1502-
dfi->hasDerivativeFunctions()) {
1501+
// Skip lowered SIL: LoadableByAddress changes parameter/result conventions.
1502+
// TODO: Check that derivative function types match excluding
1503+
// parameter/result conventions in lowered SIL.
1504+
if (F.getModule().getStage() == SILStage::Lowered)
1505+
return;
1506+
if (dfi->hasDerivativeFunctions()) {
15031507
auto jvp = dfi->getJVPFunction();
15041508
auto jvpType = jvp->getType().getAs<SILFunctionType>();
15051509
require(jvpType, "The JVP function must have a function type");
@@ -1533,8 +1537,12 @@ class SILVerifier : public SILVerifierBase<SILVerifier> {
15331537
require(origTy, "The original function must have a function type");
15341538
require(!origTy->isDifferentiable(),
15351539
"The original function must not be differentiable");
1536-
if (F.getModule().getStage() == SILStage::Canonical ||
1537-
lfi->hasTransposeFunction()) {
1540+
// Skip lowered SIL: LoadableByAddress changes parameter/result conventions.
1541+
// TODO: Check that transpose function type matches excluding
1542+
// parameter/result conventions in lowered SIL.
1543+
if (F.getModule().getStage() == SILStage::Lowered)
1544+
return;
1545+
if (lfi->hasTransposeFunction()) {
15381546
auto transpose = lfi->getTransposeFunction();
15391547
auto transposeType = transpose->getType().getAs<SILFunctionType>();
15401548
require(transposeType,

lib/Serialization/DeserializeSIL.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1146,7 +1146,8 @@ bool SILDeserializer::readSILInstruction(SILFunction *Fn, SILBasicBlock *BB,
11461146
break;
11471147
case SIL_INST_DIFFERENTIABLE_FUNCTION_EXTRACT:
11481148
SILInstDifferentiableFunctionExtractLayout::readRecord(
1149-
scratch, TyID, TyCategory, ValID, /*extractee*/ Attr);
1149+
scratch, TyID, TyCategory, ValID, /*extractee*/ Attr,
1150+
/*hasExplicitExtracteeType*/ Attr2);
11501151
RawOpCode = (unsigned)SILInstructionKind::DifferentiableFunctionExtractInst;
11511152
break;
11521153
case SIL_INST_LINEAR_FUNCTION_EXTRACT:
@@ -1609,8 +1610,12 @@ bool SILDeserializer::readSILInstruction(SILFunction *Fn, SILBasicBlock *BB,
16091610
auto silTy = getSILType(astTy, SILValueCategory::Object);
16101611
auto val = getLocalValue(ValID, silTy);
16111612
NormalDifferentiableFunctionTypeComponent extractee(Attr);
1613+
Optional<SILType> explicitExtracteeType = None;
1614+
if (Attr2)
1615+
explicitExtracteeType = silTy;
16121616
ResultVal =
1613-
Builder.createDifferentiableFunctionExtract(Loc, extractee, val);
1617+
Builder.createDifferentiableFunctionExtract(Loc, extractee, val,
1618+
explicitExtracteeType);
16141619
break;
16151620
}
16161621
case SILInstructionKind::LinearFunctionExtractInst: {

lib/Serialization/ModuleFormat.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ const uint16_t SWIFTMODULE_VERSION_MAJOR = 0;
5252
/// describe what change you made. The content of this comment isn't important;
5353
/// it just ensures a conflict if two people change the module format.
5454
/// Don't worry about adhering to the 80-column limit for this line.
55-
const uint16_t SWIFTMODULE_VERSION_MINOR = 523; // differentiable_function and differentiable_function_extract instructions
55+
const uint16_t SWIFTMODULE_VERSION_MINOR = 524; // differentiable_function_extract explicit extractee type
5656

5757
/// A standard hash seed used for all string hashes in a serialized module.
5858
///

lib/Serialization/SILFormat.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -456,7 +456,8 @@ namespace sil_block {
456456
TypeIDField,
457457
SILTypeCategoryField,
458458
ValueIDField,
459-
BCFixed<2> // extractee
459+
BCFixed<2>, // extractee
460+
BCFixed<1> // has explicit extractee type?
460461
>;
461462

462463
using SILInstLinearFunctionExtractLayout = BCRecordLayout<

lib/Serialization/SerializeSIL.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1043,7 +1043,7 @@ void SILSerializer::writeSILInstruction(const SILInstruction &SI) {
10431043
SILInstDifferentiableFunctionExtractLayout::emitRecord(Out, ScratchRecord,
10441044
SILAbbrCodes[SILInstDifferentiableFunctionExtractLayout::Code],
10451045
operandTypeRef, (unsigned)operandType.getCategory(), operandRef,
1046-
rawExtractee);
1046+
rawExtractee, (unsigned)dfei->hasExplicitExtracteeType());
10471047
break;
10481048
}
10491049
case SILInstructionKind::LinearFunctionExtractInst: {
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
// RUN: %target-sil-opt %s | %target-sil-opt | %FileCheck %s
2+
3+
// Test `differentiable_function_extract` with explicit lowered type.
4+
// SIL generated via `%target-sil-opt -loadable-address %s`.
5+
// Note: SIL serialization/deserialization does not support lowered SIL.
6+
7+
sil_stage lowered
8+
9+
import Swift
10+
import Builtin
11+
12+
struct Large : _Differentiable {
13+
@_hasStorage @noDerivative let a: Float { get }
14+
@_hasStorage @noDerivative let b: Float { get }
15+
@_hasStorage @noDerivative let c: Float { get }
16+
@_hasStorage @noDerivative let d: Float { get }
17+
@_hasStorage @noDerivative let e: Float { get }
18+
init(a: Float, b: Float, c: Float, d: Float, e: Float)
19+
struct TangentVector : _Differentiable, AdditiveArithmetic {
20+
init()
21+
typealias TangentVector = Large.TangentVector
22+
static var zero: Large.TangentVector { get }
23+
static func + (lhs: Large.TangentVector, rhs: Large.TangentVector) -> Large.TangentVector
24+
static func - (lhs: Large.TangentVector, rhs: Large.TangentVector) -> Large.TangentVector
25+
@_implements(Equatable, ==(_:_:)) static func __derived_struct_equals(_ a: Large.TangentVector, _ b: Large.TangentVector) -> Bool
26+
}
27+
mutating func move(along direction: Large.TangentVector)
28+
}
29+
30+
sil @examplefunc : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
31+
sil @examplemethod : $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
32+
33+
// CHECK-LABEL: sil @test
34+
sil @test : $@convention(thin) () -> () {
35+
bb0:
36+
%0 = function_ref @examplefunc : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
37+
%1 = differentiable_function [parameters 0 1 2] %0 : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
38+
%2 = differentiable_function_extract [vjp] %1 : $@differentiable @convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large as $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> (Large.TangentVector, Large.TangentVector, Large.TangentVector))
39+
40+
// CHECK: %1 = differentiable_function [parameters 0 1 2] %0 : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
41+
// CHECK: %2 = differentiable_function_extract [vjp] %1 : $@differentiable @convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large as $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> (Large.TangentVector, Large.TangentVector, Large.TangentVector))
42+
43+
%3 = differentiable_function [parameters 0] %0 : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
44+
%4 = differentiable_function_extract [vjp] %3 : $@differentiable @convention(thin) (@in_constant Large, @nondiff @in_constant Large, @nondiff @in_constant Large) -> @out Large as $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> Large.TangentVector)
45+
46+
// CHECK: %3 = differentiable_function [parameters 0] %0 : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
47+
// CHECK: %4 = differentiable_function_extract [vjp] %3 : $@differentiable @convention(thin) (@in_constant Large, @nondiff @in_constant Large, @nondiff @in_constant Large) -> @out Large as $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> Large.TangentVector)
48+
49+
%5 = function_ref @examplemethod : $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
50+
%6 = differentiable_function [parameters 0 1 2] %5 : $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
51+
%7 = differentiable_function_extract [vjp] %6 : $@differentiable @convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large as $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> (Large.TangentVector, Large.TangentVector, Large.TangentVector))
52+
53+
// CHECK: %6 = differentiable_function [parameters 0 1 2] %5 : $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
54+
// CHECK: %7 = differentiable_function_extract [vjp] %6 : $@differentiable @convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large as $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> (Large.TangentVector, Large.TangentVector, Large.TangentVector))
55+
56+
%8 = differentiable_function [parameters 0] %5 : $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
57+
%9 = differentiable_function_extract [vjp] %8 : $@differentiable @convention(method) (@in_constant Large, @nondiff @in_constant Large, @nondiff @in_constant Large) -> @out Large as $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> Large.TangentVector)
58+
59+
// CHECK: %8 = differentiable_function [parameters 0] %5 : $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
60+
// CHECK: %9 = differentiable_function_extract [vjp] %8 : $@differentiable @convention(method) (@in_constant Large, @nondiff @in_constant Large, @nondiff @in_constant Large) -> @out Large as $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> Large.TangentVector)
61+
62+
%10 = tuple ()
63+
return %10 : $()
64+
}

0 commit comments

Comments
 (0)