Skip to content

Commit e7d2d81

Browse files
committed
WIP
1 parent f2819a8 commit e7d2d81

19 files changed

+115
-78
lines changed

include/swift/AST/AutoDiff.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -277,8 +277,6 @@ class AutoDiffIndexSubset : public llvm::FoldingSetNode {
277277
IntRange<> range);
278278
static AutoDiffIndexSubset *get(ASTContext &ctx, unsigned capacity,
279279
ArrayRef<unsigned> indices);
280-
template<typename TBool>
281-
static AutoDiffIndexSubset *get(ASTContext &ctx, ArrayRef<TBool> bits);
282280

283281
unsigned getCapacity() const {
284282
return capacity;
@@ -297,11 +295,14 @@ class AutoDiffIndexSubset : public llvm::FoldingSetNode {
297295
return getBitWord(bitWordIndex) & (1 << offset);
298296
}
299297

298+
bool isEmpty() const;
300299
bool equals(const AutoDiffIndexSubset *other) const;
301300
bool isSubsetOf(const AutoDiffIndexSubset *other) const;
302301
bool isSupersetOf(const AutoDiffIndexSubset *other) const;
303302

304303
AutoDiffIndexSubset *adding(unsigned index, ASTContext &ctx) const;
304+
AutoDiffIndexSubset *extendingCapacity(ASTContext &ctx,
305+
unsigned newCapacity) const;
305306

306307
void Profile(llvm::FoldingSetNodeID &id) const;
307308

@@ -349,13 +350,13 @@ class AutoDiffIndexSubset : public llvm::FoldingSetNode {
349350

350351
bool operator==(const iterator &other) const {
351352
assert(&parent == &other.parent &&
352-
"Comparing iterators from different BitVectors");
353+
"Comparing iterators from different AutoDiffIndexSubsets");
353354
return current == other.current;
354355
}
355356

356357
bool operator!=(const iterator &other) const {
357358
assert(&parent == &other.parent &&
358-
"Comparing iterators from different BitVectors");
359+
"Comparing iterators from different AutoDiffIndexSubsets");
359360
return current != other.current;
360361
}
361362
};

include/swift/AST/DiagnosticsParse.def

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1535,7 +1535,9 @@ ERROR(sil_inst_autodiff_operand_list_expected_rbrace,PointsToFirstBadToken,
15351535
ERROR(sil_inst_autodiff_num_operand_list_order_mismatch,PointsToFirstBadToken,
15361536
"the number of operand lists does not match the order", ())
15371537
ERROR(sil_inst_autodiff_expected_associated_function_kind_attr,PointsToFirstBadToken,
1538-
"expects an assoiacted function kind attribute, e.g. '[jvp]'", ())
1538+
"expected an assoiacted function kind attribute, e.g. '[jvp]'", ())
1539+
ERROR(sil_inst_autodiff_expected_function_type_operand,PointsToFirstBadToken,
1540+
"expected an operand of a function type", ())
15391541

15401542
//------------------------------------------------------------------------------
15411543
// MARK: Generics parsing diagnostics

include/swift/AST/Types.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4148,7 +4148,7 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,
41484148

41494149
// SWIFT_ENABLE_TENSORFLOW
41504150
CanSILFunctionType getWithDifferentiability(
4151-
unsigned differentiationOrder, const SmallBitVector &parameterIndices);
4151+
unsigned differentiationOrder, AutoDiffIndexSubset *parameterIndices);
41524152

41534153
CanSILFunctionType getWithoutDifferentiability();
41544154

@@ -4164,7 +4164,7 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,
41644164
/// differentiate with respect to for this differentiable function type. (e.g.
41654165
/// which parameters are not @nondiff). The function type must be
41664166
/// differentiable.
4167-
SmallBitVector getDifferentiationParameterIndices() const;
4167+
AutoDiffIndexSubset *getDifferentiationParameterIndices();
41684168

41694169
/// If this is a @convention(witness_method) function with a class
41704170
/// constrained self parameter, return the class constraint for the

include/swift/SIL/SILBuilder.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -504,7 +504,7 @@ class SILBuilder {
504504

505505
/// SWIFT_ENABLE_TENSORFLOW
506506
AutoDiffFunctionInst *createAutoDiffFunction(
507-
SILLocation loc, const llvm::SmallBitVector &parameterIndices,
507+
SILLocation loc, AutoDiffIndexSubset *parameterIndices,
508508
unsigned differentiationOrder, SILValue original,
509509
ArrayRef<SILValue> associatedFunctions = {}) {
510510
return insert(AutoDiffFunctionInst::create(getModule(),

include/swift/SIL/SILFunction.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,9 @@ class SILDifferentiableAttr final {
174174
SILFunction *getOriginal() const { return Original; }
175175

176176
const SILAutoDiffIndices &getIndices() const { return indices; }
177+
void setIndices(const SILAutoDiffIndices &indices) {
178+
this->indices = indices;
179+
}
177180

178181
TrailingWhereClause *getWhereClause() const { return WhereClause; }
179182

include/swift/SIL/SILInstruction.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7721,7 +7721,7 @@ class AutoDiffFunctionInst final :
77217721
private:
77227722
friend SILBuilder;
77237723
/// Differentiation parameter indices.
7724-
SmallBitVector parameterIndices;
7724+
AutoDiffIndexSubset *parameterIndices;
77257725
/// The order of differentiation.
77267726
unsigned differentiationOrder;
77277727
/// The number of operands. The first operand is always the original function.
@@ -7730,28 +7730,28 @@ class AutoDiffFunctionInst final :
77307730
unsigned numOperands;
77317731

77327732
AutoDiffFunctionInst(SILModule &module, SILDebugLocation debugLoc,
7733-
const SmallBitVector &parameterIndices,
7733+
AutoDiffIndexSubset *parameterIndices,
77347734
unsigned differentiationOrder,
77357735
SILValue originalFunction,
77367736
ArrayRef<SILValue> associatedFunctions);
77377737

77387738
public:
77397739
static AutoDiffFunctionInst *create(SILModule &module,
77407740
SILDebugLocation debugLoc,
7741-
const SmallBitVector &parameterIndices,
7741+
AutoDiffIndexSubset *parameterIndices,
77427742
unsigned differentiationOrder,
77437743
SILValue originalFunction,
77447744
ArrayRef<SILValue> associatedFunctions);
77457745

77467746
static SILType getAutoDiffType(SILValue original,
77477747
unsigned differentiationOrder,
7748-
const SmallBitVector &parameterIndices);
7748+
AutoDiffIndexSubset *parameterIndices);
77497749

77507750
/// Returns the original function.
77517751
SILValue getOriginalFunction() const { return getAllOperands()[0].get(); }
77527752

77537753
/// Returns differentiation indices.
7754-
const SmallBitVector &getParameterIndices() const {
7754+
AutoDiffIndexSubset *getParameterIndices() const {
77557755
return parameterIndices;
77567756
}
77577757

lib/AST/AutoDiff.cpp

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -364,17 +364,6 @@ AutoDiffIndexSubset::get(ASTContext &ctx, unsigned capacity, bool includeAll) {
364364
SmallVector<unsigned, 8>(capacity, (unsigned)includeAll));
365365
}
366366

367-
template<typename TBool>
368-
AutoDiffIndexSubset *
369-
AutoDiffIndexSubset::get(ASTContext &ctx, ArrayRef<TBool> bits) {
370-
SmallVector<unsigned, 8> indices;
371-
indices.reserve(bits.size());
372-
for (auto i : indices(bits))
373-
if (bits[i])
374-
indices.push_back(i);
375-
return get(ctx, bits.size(), indices);
376-
}
377-
378367
AutoDiffIndexSubset *AutoDiffIndexSubset::get(ASTContext &ctx,
379368
unsigned capacity,
380369
IntRange<> range) {
@@ -408,6 +397,10 @@ unsigned AutoDiffIndexSubset::getNumIndices() const {
408397
});
409398
}
410399

400+
bool AutoDiffIndexSubset::isEmpty() const {
401+
return llvm::all_of(getBitWords(), [](BitWord bw) { return !(bool)bw; });
402+
}
403+
411404
bool AutoDiffIndexSubset::equals(const AutoDiffIndexSubset *other) const {
412405
return capacity == other->getCapacity() &&
413406
getBitWords().equals(other->getBitWords());
@@ -448,6 +441,15 @@ AutoDiffIndexSubset::adding(unsigned index, ASTContext &ctx) const {
448441
return get(ctx, capacity, newIndices);
449442
}
450443

444+
AutoDiffIndexSubset *AutoDiffIndexSubset::extendingCapacity(
445+
ASTContext &ctx, unsigned newCapacity) const {
446+
assert(newCapacity >= getCapacity());
447+
SmallVector<unsigned, 8> indices;
448+
for (auto index : getIndices())
449+
indices.push_back(index);
450+
return AutoDiffIndexSubset::get(ctx, newCapacity, indices);
451+
}
452+
451453
int AutoDiffIndexSubset::findNext(int startIndex) const {
452454
if (numBitWords == 0)
453455
return -1;

lib/IRGen/GenDiffFunc.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,14 +39,14 @@ namespace {
3939
class DiffFuncFieldInfo final : public RecordField<DiffFuncFieldInfo> {
4040
public:
4141
DiffFuncFieldInfo(DiffFuncIndex index, const TypeInfo &type,
42-
const SmallBitVector &parameterIndices)
42+
AutoDiffIndexSubset *parameterIndices)
4343
: RecordField(type), Index(index), ParameterIndices(parameterIndices) {}
4444

4545
/// The field index.
4646
const DiffFuncIndex Index;
4747

4848
/// The parameter indices.
49-
SmallBitVector ParameterIndices;
49+
AutoDiffIndexSubset *ParameterIndices;
5050

5151
std::string getFieldName() const {
5252
auto extractee = std::get<0>(Index);
@@ -119,7 +119,7 @@ class DiffFuncTypeBuilder
119119
DiffFuncIndex> {
120120

121121
SILFunctionType *origFnTy;
122-
SmallBitVector parameterIndices;
122+
AutoDiffIndexSubset *parameterIndices;
123123

124124
public:
125125
DiffFuncTypeBuilder(IRGenModule &IGM, SILFunctionType *fnTy)

lib/ParseSIL/ParseSIL.cpp

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -974,7 +974,7 @@ void SILParser::convertRequirements(SILFunction *F,
974974

975975
/// SWIFT_ENABLE_TENSORFLOW
976976
/// Parse a `differentiable` attribute, e.g.
977-
/// `[differentiable wrt 0, 1 adjoint @other]`.
977+
/// `[differentiable wrt 0, 1 vjp @other]`.
978978
/// Returns true on error.
979979
static bool parseDifferentiableAttr(
980980
SmallVectorImpl<SILDifferentiableAttr *> &DAs, SILParser &SP) {
@@ -1045,9 +1045,12 @@ static bool parseDifferentiableAttr(
10451045
whereLoc, requirementReprs);
10461046
}
10471047
// Create a SILDifferentiableAttr and we are done.
1048+
auto maxIndexRef = std::max_element(ParamIndices.begin(), ParamIndices.end());
1049+
auto *paramIndicesSubset = AutoDiffIndexSubset::get(
1050+
P.Context, maxIndexRef ? *maxIndexRef + 1 : 0, ParamIndices);
10481051
auto *Attr = SILDifferentiableAttr::create(
1049-
SP.SILMod, {SourceIndex, ParamIndices}, JVPName.str(), VJPName.str(),
1050-
WhereClause);
1052+
SP.SILMod, {SourceIndex, paramIndicesSubset}, JVPName.str(),
1053+
VJPName.str(), WhereClause);
10511054
DAs.push_back(Attr);
10521055
return false;
10531056
}
@@ -2873,7 +2876,7 @@ bool SILParser::parseSILInstruction(SILBuilder &B) {
28732876
// {%1 : $T, %2 : $T}, {%3 : $T, %4 : $T}
28742877
// ^ jvp ^ vjp
28752878
SourceLoc lastLoc;
2876-
SmallBitVector parameterIndices(32);
2879+
SmallVector<unsigned, 8> parameterIndices;
28772880
unsigned order = 1;
28782881
// Parse optional `[wrt <integer_literal>...]`
28792882
if (P.Tok.is(tok::l_square) &&
@@ -2882,15 +2885,12 @@ bool SILParser::parseSILInstruction(SILBuilder &B) {
28822885
P.consumeToken(tok::l_square);
28832886
P.consumeToken(tok::identifier);
28842887
// Parse indices.
2885-
unsigned size = parameterIndices.size();
28862888
while (P.Tok.is(tok::integer_literal)) {
28872889
unsigned index;
28882890
if (P.parseUnsignedInteger(index, lastLoc,
28892891
diag::sil_inst_autodiff_expected_parameter_index))
28902892
return true;
2891-
if (index >= size)
2892-
parameterIndices.resize((size *= 2));
2893-
parameterIndices.set(index);
2893+
parameterIndices.push_back(index);
28942894
}
28952895
if (P.parseToken(tok::r_square,
28962896
diag::sil_inst_autodiff_attr_expected_rsquare,
@@ -2917,8 +2917,15 @@ bool SILParser::parseSILInstruction(SILBuilder &B) {
29172917
}
29182918
// Parse the original function value.
29192919
SILValue original;
2920-
if (parseTypedValueRef(original, B))
2920+
SourceLoc originalOperandLoc;
2921+
if (parseTypedValueRef(original, originalOperandLoc, B))
29212922
return true;
2923+
auto fnType = original->getType().getAs<SILFunctionType>();
2924+
if (!fnType) {
2925+
P.diagnose(originalOperandLoc,
2926+
diag::sil_inst_autodiff_expected_function_type_operand);
2927+
return true;
2928+
}
29222929
SmallVector<SILValue, 16> associatedFunctions;
29232930
// Parse optional operand lists `with { <operand> , <operand> }, ...`.
29242931
if (P.Tok.is(tok::identifier) && P.Tok.getText() == "with") {
@@ -2950,7 +2957,10 @@ bool SILParser::parseSILInstruction(SILBuilder &B) {
29502957
}
29512958
if (parseSILDebugLocation(InstLoc, B))
29522959
return true;
2953-
ResultVal = B.createAutoDiffFunction(InstLoc, parameterIndices, order,
2960+
auto *parameterIndicesSubset =
2961+
AutoDiffIndexSubset::get(P.Context, fnType->getNumParameters(),
2962+
parameterIndices);
2963+
ResultVal = B.createAutoDiffFunction(InstLoc, parameterIndicesSubset, order,
29542964
original, associatedFunctions);
29552965
break;
29562966
}
@@ -5714,6 +5724,18 @@ bool SILParserTUState::parseDeclSIL(Parser &P) {
57145724

57155725
// SWIFT_ENABLE_TENSORFLOW
57165726
for (auto &attr : DiffAttrs) {
5727+
// Resolve parameter indices to have the right capacity, if it's
5728+
// different from the number of parameters. We have to do this because
5729+
// the parser does not know the function type before creating a
5730+
// `SILDifferentiableAttr`, so it had to find the max of all provided
5731+
// indices.
5732+
if (attr->getIndices().parameters->getCapacity() !=
5733+
SILFnType->getNumParameters()) {
5734+
auto *newParamIndices = attr->getIndices().parameters
5735+
->extendingCapacity(P.Context, SILFnType->getNumParameters());
5736+
attr->setIndices({attr->getIndices().source, newParamIndices});
5737+
}
5738+
57175739
// Resolve where clause requirements.
57185740
// If no where clause, continue.
57195741
if (!attr->getWhereClause())

lib/SIL/SILDeclRef.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -683,7 +683,7 @@ std::string SILDeclRef::mangle(ManglingKind MKind) const {
683683
getDecl()->getInterfaceType()->castTo<AnyFunctionType>();
684684
auto silParameterIndices =
685685
autoDiffAssociatedFunctionIdentifier->getParameterIndices()->getLowered(
686-
functionTy);
686+
functionTy->getASTContext(), functionTy);
687687
SILAutoDiffIndices indices(/*source*/ 0, silParameterIndices);
688688
std::string mangledKind;
689689
switch (autoDiffAssociatedFunctionIdentifier->getKind()) {

lib/SIL/SILFunctionBuilder.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ void SILFunctionBuilder::addFunctionAttributes(SILFunction *F,
8989
// Get lowered argument indices.
9090
auto paramIndices = A->getParameterIndices();
9191
auto loweredParamIndices = paramIndices->getLowered(
92+
F->getASTContext(),
9293
decl->getInterfaceType()->castTo<AnyFunctionType>());
9394
SILAutoDiffIndices indices(/*source*/ 0, loweredParamIndices);
9495
auto silDiffAttr = SILDifferentiableAttr::create(

lib/SIL/SILFunctionType.cpp

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -98,30 +98,30 @@ CanType SILFunctionType::getSelfInstanceType() const {
9898
}
9999

100100
// SWIFT_ENABLE_TENSORFLOW
101-
SmallBitVector
102-
SILFunctionType::getDifferentiationParameterIndices() const {
101+
AutoDiffIndexSubset *
102+
SILFunctionType::getDifferentiationParameterIndices() {
103103
assert(isDifferentiable());
104-
SmallBitVector result(NumParameters, true);
104+
SmallVector<unsigned, 8> result;
105105
for (auto valueAndIndex : enumerate(getParameters()))
106-
if (valueAndIndex.value().getDifferentiability() ==
106+
if (valueAndIndex.value().getDifferentiability() !=
107107
SILParameterDifferentiability::NotDifferentiable)
108-
result.reset(valueAndIndex.index());
109-
return result;
108+
result.push_back(valueAndIndex.index());
109+
return AutoDiffIndexSubset::get(getASTContext(), getNumParameters(), result);
110110
}
111111

112112
CanSILFunctionType SILFunctionType::getWithDifferentiability(
113-
unsigned differentiationOrder,
114-
const SmallBitVector &parameterIndices) {
113+
unsigned differentiationOrder, AutoDiffIndexSubset *parameterIndices) {
115114
// FIXME(rxwei): Handle differentiation order.
116115

117116
SmallVector<SILParameterInfo, 8> newParameters;
118117
for (auto paramAndIndex : enumerate(getParameters())) {
119118
auto &param = paramAndIndex.value();
120119
unsigned index = paramAndIndex.index();
121120
newParameters.push_back(param.getWithDifferentiability(
122-
index < parameterIndices.size() && parameterIndices[index]
123-
? SILParameterDifferentiability::DifferentiableOrNotApplicable
124-
: SILParameterDifferentiability::NotDifferentiable));
121+
index < parameterIndices->getCapacity() &&
122+
parameterIndices->contains(index)
123+
? SILParameterDifferentiability::DifferentiableOrNotApplicable
124+
: SILParameterDifferentiability::NotDifferentiable));
125125
}
126126

127127
auto newExtInfo = getExtInfo().withDifferentiable();
@@ -212,7 +212,8 @@ CanSILFunctionType SILFunctionType::getAutoDiffAssociatedFunctionType(
212212

213213
// Helper function testing if we are differentiating wrt this index.
214214
auto isWrtIndex = [&](unsigned index) -> bool {
215-
return index < parameterIndices.size() && parameterIndices[index];
215+
return index < parameterIndices->getCapacity() &&
216+
parameterIndices->contains(index);
216217
};
217218

218219
// Calculate WRT parameter infos, in the order that they should appear in the
@@ -2316,8 +2317,8 @@ const SILConstantInfo &TypeConverter::getConstantInfo(SILDeclRef constant) {
23162317
if (auto *autoDiffFuncId = constant.autoDiffAssociatedFunctionIdentifier) {
23172318
auto origFnConstantInfo =
23182319
getConstantInfo(constant.asAutoDiffOriginalFunction());
2319-
auto loweredIndices =
2320-
autoDiffFuncId->getParameterIndices()->getLowered(formalInterfaceType);
2320+
auto loweredIndices = autoDiffFuncId->getParameterIndices()
2321+
->getLowered(Context, formalInterfaceType);
23212322
silFnType = origFnConstantInfo.SILFnType->getAutoDiffAssociatedFunctionType(
23222323
loweredIndices, /*resultIndex*/ 0,
23232324
autoDiffFuncId->getDifferentiationOrder(), autoDiffFuncId->getKind(), M,

0 commit comments

Comments
 (0)