Skip to content

Commit 1c64770

Browse files
authored
Merge pull request #30873 from dan-zheng/autodiff-sil
[AutoDiff] Minor SILOptimizer changes.
2 parents f27f1cd + 6f4b812 commit 1c64770

File tree

3 files changed

+13
-9
lines changed

3 files changed

+13
-9
lines changed

include/swift/SILOptimizer/Utils/Differentiation/AdjointValue.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ class AdjointValue final {
148148
std::get<1>(elt).print(s);
149149
},
150150
[&s] { s << ", "; });
151-
} else if (auto tupleType = getType().getAs<TupleType>()) {
151+
} else if (getType().is<TupleType>()) {
152152
s << "Tuple>(";
153153
interleave(
154154
base->value.aggregate,

lib/SILOptimizer/Utils/Differentiation/LinearMapInfo.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -122,9 +122,11 @@ LinearMapInfo::createBranchingTraceDecl(SILBasicBlock *originalBB,
122122
auto &file = getSynthesizedFile();
123123
// Create a branching trace enum.
124124
Mangle::ASTMangler mangler;
125+
auto originalFnTy = original->getLoweredFunctionType();
126+
auto numResults = originalFnTy->getNumResults() +
127+
originalFnTy->getNumIndirectMutatingParameters();
125128
auto *resultIndices = IndexSubset::get(
126-
original->getASTContext(),
127-
original->getLoweredFunctionType()->getNumResults(), indices.source);
129+
original->getASTContext(), numResults, indices.source);
128130
auto *parameterIndices = indices.parameters;
129131
AutoDiffConfig config(parameterIndices, resultIndices, genericSig);
130132
auto enumName = mangler.mangleAutoDiffGeneratedDeclaration(
@@ -193,9 +195,11 @@ LinearMapInfo::createLinearMapStruct(SILBasicBlock *originalBB,
193195
auto &file = getSynthesizedFile();
194196
// Create a linear map struct.
195197
Mangle::ASTMangler mangler;
198+
auto originalFnTy = original->getLoweredFunctionType();
199+
auto numResults = originalFnTy->getNumResults() +
200+
originalFnTy->getNumIndirectMutatingParameters();
196201
auto *resultIndices = IndexSubset::get(
197-
original->getASTContext(),
198-
original->getLoweredFunctionType()->getNumResults(), indices.source);
202+
original->getASTContext(), numResults, indices.source);
199203
auto *parameterIndices = indices.parameters;
200204
AutoDiffConfig config(parameterIndices, resultIndices, genericSig);
201205
auto structName = mangler.mangleAutoDiffGeneratedDeclaration(

lib/SILOptimizer/Utils/Differentiation/PullbackEmitter.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1676,14 +1676,14 @@ void PullbackEmitter::visitBeginBorrowInst(BeginBorrowInst *bbi) {
16761676
void PullbackEmitter::visitBeginAccessInst(BeginAccessInst *bai) {
16771677
// Check for non-differentiable writes.
16781678
if (bai->getAccessKind() == SILAccessKind::Modify) {
1679-
if (auto *gai = dyn_cast<GlobalAddrInst>(bai->getSource())) {
1679+
if (isa<GlobalAddrInst>(bai->getSource())) {
16801680
getContext().emitNondifferentiabilityError(
16811681
bai, getInvoker(),
16821682
diag::autodiff_cannot_differentiate_writes_to_global_variables);
16831683
errorOccurred = true;
16841684
return;
16851685
}
1686-
if (auto *pbi = dyn_cast<ProjectBoxInst>(bai->getSource())) {
1686+
if (isa<ProjectBoxInst>(bai->getSource())) {
16871687
getContext().emitNondifferentiabilityError(
16881688
bai, getInvoker(),
16891689
diag::autodiff_cannot_differentiate_writes_to_mutable_captures);
@@ -1904,7 +1904,7 @@ AdjointValue PullbackEmitter::accumulateAdjointsDirect(AdjointValue lhs,
19041904
SmallVector<AdjointValue, 8> newElements;
19051905
auto lhsTy = lhsVal->getType().getASTType();
19061906
auto lhsValCopy = builder.emitCopyValueOperation(loc, lhsVal);
1907-
if (auto *tupTy = lhsTy->getAs<TupleType>()) {
1907+
if (lhsTy->is<TupleType>()) {
19081908
auto elts = builder.createDestructureTuple(loc, lhsValCopy);
19091909
llvm::for_each(elts->getResults(),
19101910
[this](SILValue result) { recordTemporary(result); });
@@ -1913,7 +1913,7 @@ AdjointValue PullbackEmitter::accumulateAdjointsDirect(AdjointValue lhs,
19131913
newElements.push_back(accumulateAdjointsDirect(
19141914
makeConcreteAdjointValue(elts->getResult(i)), rhsElt, loc));
19151915
}
1916-
} else if (auto *structDecl = lhsTy->getStructOrBoundGenericStruct()) {
1916+
} else if (lhsTy->getStructOrBoundGenericStruct()) {
19171917
auto elts =
19181918
builder.createDestructureStruct(lhsVal.getLoc(), lhsValCopy);
19191919
llvm::for_each(elts->getResults(),

0 commit comments

Comments
 (0)