Skip to content

Commit 794fa9c

Browse files
authored
Fix vector bugs (rust-lang#632)
* Fix vector * Fix reverse chunk bug
1 parent 12ebbf4 commit 794fa9c

File tree

3 files changed

+37
-24
lines changed

3 files changed

+37
-24
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11418,13 +11418,15 @@ class AdjointGenerator
1141811418

1141911419
if (subretused) {
1142011420
Value *dcall = nullptr;
11421+
assert(returnIdx);
1142111422
dcall = (returnIdx.getValue() < 0)
1142211423
? augmentcall
1142311424
: BuilderZ.CreateExtractValue(
1142411425
augmentcall, {(unsigned)returnIdx.getValue()});
1142511426
gutils->originalToNewFn[orig] = dcall;
1142611427
gutils->newToOriginalFn.erase(newCall);
1142711428
gutils->newToOriginalFn[dcall] = orig;
11429+
1142811430
assert(dcall->getType() == orig->getType());
1142911431
assert(dcall);
1143011432

enzyme/Enzyme/FunctionUtils.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1770,7 +1770,8 @@ FunctionType *getFunctionTypeForClone(
17701770
std::vector<Type *> RetTypes;
17711771
if (returnValue == ReturnType::ArgsWithReturn ||
17721772
returnValue == ReturnType::Return) {
1773-
if (returnType != DIFFE_TYPE::CONSTANT) {
1773+
if (returnType != DIFFE_TYPE::CONSTANT &&
1774+
returnType != DIFFE_TYPE::OUT_DIFF) {
17741775
RetTypes.push_back(
17751776
GradientUtils::getShadowType(FTy->getReturnType(), width));
17761777
} else {
@@ -1779,7 +1780,8 @@ FunctionType *getFunctionTypeForClone(
17791780
} else if (returnValue == ReturnType::ArgsWithTwoReturns ||
17801781
returnValue == ReturnType::TwoReturns) {
17811782
RetTypes.push_back(FTy->getReturnType());
1782-
if (returnType != DIFFE_TYPE::CONSTANT) {
1783+
if (returnType != DIFFE_TYPE::CONSTANT &&
1784+
returnType != DIFFE_TYPE::OUT_DIFF) {
17831785
RetTypes.push_back(
17841786
GradientUtils::getShadowType(FTy->getReturnType(), width));
17851787
} else {
@@ -1822,7 +1824,8 @@ FunctionType *getFunctionTypeForClone(
18221824
RetTypes.push_back(
18231825
GradientUtils::getShadowType(FTy->getReturnType(), width));
18241826
} else if (returnValue == ReturnType::TapeAndReturn) {
1825-
if (returnType != DIFFE_TYPE::CONSTANT)
1827+
if (returnType != DIFFE_TYPE::CONSTANT &&
1828+
returnType != DIFFE_TYPE::OUT_DIFF)
18261829
RetTypes.push_back(
18271830
GradientUtils::getShadowType(FTy->getReturnType(), width));
18281831
else

enzyme/Enzyme/GradientUtils.cpp

Lines changed: 29 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -832,37 +832,42 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
832832
if (pidx == nullptr)
833833
goto endCheck;
834834

835-
if (pidx->getType() != dli->getOperand(0)->getType()) {
835+
if (pidx->getType() != getShadowType(dli->getOperand(0)->getType())) {
836836
llvm::errs() << "dli: " << *dli << "\n";
837837
llvm::errs() << "dli->getOperand(0): " << *dli->getOperand(0) << "\n";
838838
llvm::errs() << "pidx: " << *pidx << "\n";
839839
}
840-
assert(pidx->getType() == dli->getOperand(0)->getType());
840+
assert(pidx->getType() == getShadowType(dli->getOperand(0)->getType()));
841+
842+
Value *toreturn = applyChainRule(
843+
dli->getType(), BuilderM,
844+
[&](Value *pidx) {
841845
#if LLVM_VERSION_MAJOR > 7
842-
auto toreturn =
843-
BuilderM.CreateLoad(pidx->getType()->getPointerElementType(), pidx,
844-
phi->getName() + "_unwrap");
846+
auto toreturn = BuilderM.CreateLoad(dli->getType(), pidx,
847+
phi->getName() + "_unwrap");
845848
#else
846-
auto toreturn = BuilderM.CreateLoad(pidx, phi->getName() + "_unwrap");
849+
auto toreturn =
850+
BuilderM.CreateLoad(pidx, phi->getName() + "_unwrap");
847851
#endif
848-
if (auto newi = dyn_cast<Instruction>(toreturn)) {
849-
newi->copyIRFlags(dli);
850-
unwrappedLoads[toreturn] = dli;
851-
}
852+
if (auto newi = dyn_cast<Instruction>(toreturn)) {
853+
newi->copyIRFlags(dli);
854+
unwrappedLoads[toreturn] = dli;
855+
}
852856
#if LLVM_VERSION_MAJOR >= 10
853-
toreturn->setAlignment(dli->getAlign());
857+
toreturn->setAlignment(dli->getAlign());
854858
#else
855-
toreturn->setAlignment(dli->getAlignment());
859+
toreturn->setAlignment(dli->getAlignment());
856860
#endif
857-
toreturn->setVolatile(dli->isVolatile());
858-
toreturn->setOrdering(dli->getOrdering());
859-
toreturn->setSyncScopeID(dli->getSyncScopeID());
860-
toreturn->setDebugLoc(getNewFromOriginal(dli->getDebugLoc()));
861-
toreturn->setMetadata(LLVMContext::MD_tbaa,
862-
dli->getMetadata(LLVMContext::MD_tbaa));
863-
toreturn->setMetadata(
864-
LLVMContext::MD_invariant_group,
865-
dli->getMetadata(LLVMContext::MD_invariant_group));
861+
toreturn->setVolatile(dli->isVolatile());
862+
toreturn->setOrdering(dli->getOrdering());
863+
toreturn->setSyncScopeID(dli->getSyncScopeID());
864+
toreturn->setDebugLoc(getNewFromOriginal(dli->getDebugLoc()));
865+
toreturn->setMetadata(LLVMContext::MD_tbaa,
866+
dli->getMetadata(LLVMContext::MD_tbaa));
867+
return toreturn;
868+
},
869+
pidx);
870+
866871
// TODO adding to cache only legal if no alias of any future writes
867872
if (permitCache)
868873
unwrap_cache[BuilderM.GetInsertBlock()][idx.first][idx.second] =
@@ -4825,6 +4830,9 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM,
48254830
new_op->setMetadata(LLVMContext::MD_invariant_group, invgroup);
48264831
}
48274832
}
4833+
if (op->getType() != inst->getType()) {
4834+
llvm::errs() << " op: " << *op << " inst: " << *inst << "\n";
4835+
}
48284836
assert(op->getType() == inst->getType());
48294837
if (!reduceRegister)
48304838
lookup_cache[BuilderM.GetInsertBlock()][val] = op;

0 commit comments

Comments
 (0)