Skip to content

Commit 5e347a1

Browse files
authored
Implementvector mode: memory (rust-lang#465)
1 parent 702a89f commit 5e347a1

File tree

9 files changed

+728
-222
lines changed

9 files changed

+728
-222
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 88 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -585,8 +585,9 @@ class AdjointGenerator
585585
// the instruction if the value is a potential pointer. This may not be
586586
// caught by type analysis is the result does not have a known type.
587587
if (!gutils->isConstantInstruction(&I)) {
588-
Type *isfloat =
589-
type->isFPOrFPVectorTy() ? type->getScalarType() : nullptr;
588+
Type *isfloat = I.getType()->isFPOrFPVectorTy()
589+
? I.getType()->getScalarType()
590+
: nullptr;
590591
if (!isfloat && type->isIntOrIntVectorTy()) {
591592
auto LoadSize = DL.getTypeSizeInBits(type) / 8;
592593
ConcreteType vd = BaseType::Unknown;
@@ -610,40 +611,48 @@ class AdjointGenerator
610611
getForwardBuilder(Builder2);
611612

612613
if (!gutils->isConstantValue(&I)) {
614+
Value *ip = gutils->invertPointerM(I.getOperand(0), Builder2);
615+
613616
Value *diff;
614617
if (!mask) {
618+
619+
auto rule = [&](Value *ip) {
615620
#if LLVM_VERSION_MAJOR > 7
616-
auto LI = Builder2.CreateLoad(
617-
cast<PointerType>(I.getOperand(0)->getType())
618-
->getElementType(),
619-
gutils->invertPointerM(I.getOperand(0), Builder2));
621+
auto LI = Builder2.CreateLoad(I.getType(), ip);
620622
#else
621-
auto LI = Builder2.CreateLoad(
622-
gutils->invertPointerM(I.getOperand(0), Builder2));
623+
auto LI = Builder2.CreateLoad(ip);
623624
#endif
624-
if (alignment)
625+
if (alignment)
625626
#if LLVM_VERSION_MAJOR >= 10
626-
LI->setAlignment(*alignment);
627+
LI->setAlignment(*alignment);
627628
#else
628-
LI->setAlignment(alignment);
629+
LI->setAlignment(alignment);
629630
#endif
630-
diff = LI;
631+
return LI;
632+
};
633+
634+
diff = applyChainRule(I.getType(), Builder2, rule, ip);
635+
631636
} else {
632-
Type *tys[] = {I.getType(), I.getOperand(0)->getType()};
633-
auto F = Intrinsic::getDeclaration(gutils->oldFunc->getParent(),
634-
Intrinsic::masked_load, tys);
637+
auto mi = diffe(orig_maskInit, Builder2);
638+
639+
auto rule = [&](Value *ip, Value *mi) {
640+
Type *tys[] = {I.getType(), I.getOperand(0)->getType()};
641+
auto F = Intrinsic::getDeclaration(gutils->oldFunc->getParent(),
642+
Intrinsic::masked_load, tys);
635643
#if LLVM_VERSION_MAJOR >= 10
636-
Value *alignv =
637-
ConstantInt::get(Type::getInt32Ty(mask->getContext()),
638-
alignment ? alignment->value() : 0);
644+
Value *alignv =
645+
ConstantInt::get(Type::getInt32Ty(mask->getContext()),
646+
alignment ? alignment->value() : 0);
639647
#else
640-
Value *alignv = ConstantInt::get(
641-
Type::getInt32Ty(mask->getContext()), alignment);
648+
Value *alignv = ConstantInt::get(
649+
Type::getInt32Ty(mask->getContext()), alignment);
642650
#endif
643-
Value *args[] = {
644-
gutils->invertPointerM(I.getOperand(0), Builder2), alignv,
645-
mask, diffe(orig_maskInit, Builder2)};
646-
diff = Builder2.CreateCall(F, args);
651+
Value *args[] = {ip, alignv, mask, mi};
652+
return Builder2.CreateCall(F, args);
653+
};
654+
655+
diff = applyChainRule(I.getType(), Builder2, rule, ip, mi);
647656
}
648657
setDiffe(&I, diff, Builder2);
649658
}
@@ -869,10 +878,13 @@ class AdjointGenerator
869878
IRBuilder<> Builder2(&I);
870879
getForwardBuilder(Builder2);
871880

872-
Value *diff = constantval ? Constant::getNullValue(valType)
881+
Type *diffeTy = gutils->getShadowType(valType);
882+
883+
Value *diff = constantval ? Constant::getNullValue(diffeTy)
873884
: diffe(orig_val, Builder2);
874885
gutils->setPtrDiffe(orig_ptr, diff, Builder2, align, isVolatile,
875886
ordering, syncScope, mask);
887+
876888
break;
877889
}
878890
}
@@ -889,6 +901,14 @@ class AdjointGenerator
889901

890902
if (constantval) {
891903
valueop = val;
904+
if (gutils->getWidth() > 1) {
905+
Value *array =
906+
UndefValue::get(gutils->getShadowType(val->getType()));
907+
for (unsigned i = 0; i < gutils->getWidth(); ++i) {
908+
array = storeBuilder.CreateInsertValue(array, val, {i});
909+
}
910+
valueop = array;
911+
}
892912
} else {
893913
valueop = gutils->invertPointerM(orig_val, storeBuilder);
894914
}
@@ -915,37 +935,38 @@ class AdjointGenerator
915935
}
916936
case DerivativeMode::ForwardModeSplit:
917937
case DerivativeMode::ForwardMode: {
918-
break;
919-
}
920-
}
938+
BasicBlock *oBB = phi.getParent();
939+
BasicBlock *nBB = gutils->getNewFromOriginal(oBB);
921940

922-
BasicBlock *oBB = phi.getParent();
923-
BasicBlock *nBB = gutils->getNewFromOriginal(oBB);
941+
IRBuilder<> diffeBuilder(nBB->getFirstNonPHI());
942+
diffeBuilder.setFastMathFlags(getFast());
924943

925-
IRBuilder<> diffeBuilder(nBB->getFirstNonPHI());
926-
diffeBuilder.setFastMathFlags(getFast());
944+
IRBuilder<> phiBuilder(&phi);
945+
getForwardBuilder(phiBuilder);
927946

928-
IRBuilder<> phiBuilder(&phi);
929-
getForwardBuilder(phiBuilder);
947+
Type *diffeType = gutils->getShadowType(phi.getType());
930948

931-
auto newPhi = phiBuilder.CreatePHI(phi.getType(), 1, phi.getName() + "'");
932-
for (unsigned int i = 0; i < phi.getNumIncomingValues(); ++i) {
933-
auto val = phi.getIncomingValue(i);
934-
auto block = phi.getIncomingBlock(i);
949+
auto newPhi = phiBuilder.CreatePHI(diffeType, 1, phi.getName() + "'");
950+
for (unsigned int i = 0; i < phi.getNumIncomingValues(); ++i) {
951+
auto val = phi.getIncomingValue(i);
952+
auto block = phi.getIncomingBlock(i);
935953

936-
auto newBlock = gutils->getNewFromOriginal(block);
937-
IRBuilder<> pBuilder(newBlock->getTerminator());
938-
pBuilder.setFastMathFlags(getFast());
954+
auto newBlock = gutils->getNewFromOriginal(block);
955+
IRBuilder<> pBuilder(newBlock->getTerminator());
956+
pBuilder.setFastMathFlags(getFast());
939957

940-
if (gutils->isConstantValue(val)) {
941-
newPhi->addIncoming(Constant::getNullValue(val->getType()), newBlock);
942-
} else {
943-
auto diff = diffe(val, pBuilder);
944-
newPhi->addIncoming(diff, newBlock);
958+
if (gutils->isConstantValue(val)) {
959+
newPhi->addIncoming(Constant::getNullValue(diffeType), newBlock);
960+
} else {
961+
auto diff = diffe(val, pBuilder);
962+
newPhi->addIncoming(diff, newBlock);
963+
}
945964
}
946-
}
947965

948-
setDiffe(&phi, newPhi, diffeBuilder);
966+
setDiffe(&phi, newPhi, diffeBuilder);
967+
return;
968+
}
969+
}
949970
}
950971

951972
void visitCastInst(llvm::CastInst &I) {
@@ -2589,11 +2610,26 @@ class AdjointGenerator
25892610
dsrc = Builder2.CreateIntToPtr(dsrc,
25902611
Type::getInt8PtrTy(dsrc->getContext()));
25912612

2592-
auto call =
2593-
Builder2.CreateMemCpy(ddst, dstAlign, dsrc, srcAlign, new_size);
2594-
call->setAttributes(MTI.getAttributes());
2595-
call->setTailCallKind(MTI.getTailCallKind());
2613+
auto rule = [&](Value *ddst, Value *dsrc) {
2614+
CallInst *call;
2615+
if (ID == Intrinsic::memmove) {
2616+
call =
2617+
Builder2.CreateMemMove(ddst, dstAlign, dsrc, srcAlign, new_size);
2618+
} else {
2619+
call =
2620+
Builder2.CreateMemCpy(ddst, dstAlign, dsrc, srcAlign, new_size);
2621+
}
2622+
call->setAttributes(MTI.getAttributes());
2623+
call->setMetadata(LLVMContext::MD_tbaa,
2624+
MTI.getMetadata(LLVMContext::MD_tbaa));
2625+
call->setMetadata(LLVMContext::MD_tbaa_struct,
2626+
MTI.getMetadata(LLVMContext::MD_tbaa_struct));
2627+
call->setMetadata(LLVMContext::MD_invariant_group,
2628+
MTI.getMetadata(LLVMContext::MD_invariant_group));
2629+
call->setTailCallKind(MTI.getTailCallKind());
2630+
};
25962631

2632+
applyChainRule(Builder2, rule, ddst, dsrc);
25972633
return;
25982634
}
25992635

0 commit comments

Comments
 (0)