Skip to content

Commit 3e460d3

Browse files
authored
More opaque pointer support (rust-lang#978)
1 parent 0c3425d commit 3e460d3

File tree

2 files changed

+81
-51
lines changed

2 files changed

+81
-51
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 31 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -11731,6 +11731,7 @@ class AdjointGenerator
1173111731
}
1173211732

1173311733
Value *newcalled = nullptr;
11734+
FunctionType *FT = nullptr;
1173411735

1173511736
if (called) {
1173611737
newcalled = gutils->Logic.CreateForwardDiff(
@@ -11739,6 +11740,7 @@ class AdjointGenerator
1173911740
((DiffeGradientUtils *)gutils)->FreeMemory, gutils->getWidth(),
1174011741
tape ? tape->getType() : nullptr, nextTypeInfo, uncacheable_args,
1174111742
/*augmented*/ subdata);
11743+
FT = cast<Function>(newcalled)->getFunctionType();
1174211744
} else {
1174311745
#if LLVM_VERSION_MAJOR >= 11
1174411746
auto callval = orig->getCalledOperand();
@@ -11766,10 +11768,10 @@ class AdjointGenerator
1176611768
? (retActive ? ReturnType::TwoReturns : ReturnType::Return)
1176711769
: (retActive ? ReturnType::Return : ReturnType::Void);
1176811770

11769-
FunctionType *FTy = getFunctionTypeForClone(
11771+
FT = getFunctionTypeForClone(
1177011772
ft, Mode, gutils->getWidth(), tape ? tape->getType() : nullptr,
1177111773
argsInverted, false, subretVal, subretType);
11772-
PointerType *fptype = PointerType::getUnqual(FTy);
11774+
PointerType *fptype = PointerType::getUnqual(FT);
1177311775
newcalled = BuilderZ.CreatePointerCast(newcalled,
1177411776
PointerType::getUnqual(fptype));
1177511777
#if LLVM_VERSION_MAJOR > 7
@@ -11780,8 +11782,7 @@ class AdjointGenerator
1178011782
}
1178111783

1178211784
assert(newcalled);
11783-
FunctionType *FT =
11784-
cast<FunctionType>(newcalled->getType()->getPointerElementType());
11785+
assert(FT);
1178511786

1178611787
SmallVector<ValueType, 2> BundleTypes;
1178711788
for (auto A : argsInverted)
@@ -12055,6 +12056,7 @@ class AdjointGenerator
1205512056
if (modifyPrimal) {
1205612057

1205712058
Value *newcalled = nullptr;
12059+
FunctionType *FT = nullptr;
1205812060
const AugmentedReturn *fnandtapetype = nullptr;
1205912061

1206012062
if (!called) {
@@ -12096,17 +12098,28 @@ class AdjointGenerator
1209612098
FunctionType *ft = nullptr;
1209712099
if (auto F = dyn_cast<Function>(callval))
1209812100
ft = F->getFunctionType();
12099-
else
12100-
ft = cast<FunctionType>(callval->getType()->getPointerElementType());
12101+
else {
12102+
#if LLVM_VERSION_MAJOR >= 15
12103+
if (orig->getContext().supportsTypedPointers()) {
12104+
#endif
12105+
ft =
12106+
cast<FunctionType>(callval->getType()->getPointerElementType());
12107+
#if LLVM_VERSION_MAJOR >= 15
12108+
} else {
12109+
ft = orig->getFunctionType();
12110+
}
12111+
#endif
12112+
}
1210112113

1210212114
std::set<llvm::Type *> seen;
1210312115
DIFFE_TYPE subretType = whatType(orig->getType(), Mode,
1210412116
/*intAreConstant*/ false, seen);
1210512117
auto res = getDefaultFunctionTypeForAugmentation(
1210612118
ft, /*returnUsed*/ true, /*subretType*/ subretType);
12107-
auto fptype = PointerType::getUnqual(FunctionType::get(
12119+
FT = FunctionType::get(
1210812120
StructType::get(newcalled->getContext(), res.second), res.first,
12109-
ft->isVarArg()));
12121+
ft->isVarArg());
12122+
auto fptype = PointerType::getUnqual(FT);
1211012123
newcalled = BuilderZ.CreatePointerCast(newcalled,
1211112124
PointerType::getUnqual(fptype));
1211212125
#if LLVM_VERSION_MAJOR > 7
@@ -12149,6 +12162,7 @@ class AdjointGenerator
1214912162
assert(subdata);
1215012163
fnandtapetype = subdata;
1215112164
newcalled = subdata->fn;
12165+
FT = cast<Function>(newcalled)->getFunctionType();
1215212166

1215312167
auto found = subdata->returns.find(AugmentedStruct::DifferentialReturn);
1215412168
if (found != subdata->returns.end()) {
@@ -12172,11 +12186,7 @@ class AdjointGenerator
1217212186
// sub_index_map = fnandtapetype.tapeIndices;
1217312187

1217412188
assert(newcalled);
12175-
FunctionType *FT = nullptr;
12176-
if (auto F = dyn_cast<Function>(newcalled))
12177-
FT = F->getFunctionType();
12178-
else
12179-
FT = cast<FunctionType>(newcalled->getType()->getPointerElementType());
12189+
assert(FT);
1218012190

1218112191
// llvm::errs() << "seeing sub_index_map of " << sub_index_map->size()
1218212192
// << " in ap " << cast<Function>(called)->getName() << "\n";
@@ -12483,6 +12493,7 @@ class AdjointGenerator
1248312493
getReverseBuilder(Builder2);
1248412494

1248512495
Value *newcalled = nullptr;
12496+
FunctionType *FT = nullptr;
1248612497

1248712498
DerivativeMode subMode = (replaceFunction || !modifyPrimal)
1248812499
? DerivativeMode::ReverseModeCombined
@@ -12505,6 +12516,7 @@ class AdjointGenerator
1250512516
TR.analyzer.interprocedural, subdata);
1250612517
if (!newcalled)
1250712518
return;
12519+
FT = cast<Function>(newcalled)->getFunctionType();
1250812520
} else {
1250912521

1251012522
assert(subMode != DerivativeMode::ReverseModeCombined);
@@ -12522,15 +12534,17 @@ class AdjointGenerator
1252212534
assert(!gutils->isConstantValue(callval));
1252312535
newcalled = lookup(gutils->invertPointerM(callval, Builder2), Builder2);
1252412536

12525-
auto ft = cast<FunctionType>(callval->getType()->getPointerElementType());
12537+
auto ft = orig->getFunctionType();
12538+
// cast<FunctionType>(callval->getType()->getPointerElementType());
1252612539

1252712540
auto res =
1252812541
getDefaultFunctionTypeForGradient(ft, /*subretType*/ subretType);
1252912542
// TODO Note there is empty tape added here, replace with generic
1253012543
res.first.push_back(Type::getInt8PtrTy(newcalled->getContext()));
12531-
auto fptype = PointerType::getUnqual(FunctionType::get(
12544+
FT = FunctionType::get(
1253212545
StructType::get(newcalled->getContext(), res.second), res.first,
12533-
ft->isVarArg()));
12546+
ft->isVarArg());
12547+
auto fptype = PointerType::getUnqual(FT);
1253412548
newcalled =
1253512549
Builder2.CreatePointerCast(newcalled, PointerType::getUnqual(fptype));
1253612550
#if LLVM_VERSION_MAJOR > 7
@@ -12554,13 +12568,7 @@ class AdjointGenerator
1255412568
}
1255512569

1255612570
assert(newcalled);
12557-
// if (auto NC = dyn_cast<Function>(newcalled)) {
12558-
FunctionType *FT = nullptr;
12559-
if (auto F = dyn_cast<Function>(newcalled))
12560-
FT = F->getFunctionType();
12561-
else {
12562-
FT = cast<FunctionType>(newcalled->getType()->getPointerElementType());
12563-
}
12571+
assert(FT);
1256412572

1256512573
if (false) {
1256612574
badfn:;

enzyme/Enzyme/GradientUtils.cpp

Lines changed: 50 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1964,35 +1964,51 @@ Value *GradientUtils::cacheForReverse(IRBuilder<> &BuilderQ, Value *malloc,
19641964
ret = (idx < 0) ? tape
19651965
: entryBuilder.CreateExtractValue(tape, {(unsigned)idx});
19661966

1967-
Type *innerType = ret->getType();
1968-
for (size_t i = 0,
1969-
limit = getSubLimits(
1970-
/*inForwardPass*/ true, nullptr,
1971-
LimitContext(
1972-
/*ReverseLimit*/ reverseBlocks.size() > 0,
1973-
BuilderQ.GetInsertBlock()))
1974-
.size();
1975-
i < limit; ++i) {
1976-
if (!isa<PointerType>(innerType)) {
1977-
llvm::errs() << "mod: "
1978-
<< *BuilderQ.GetInsertBlock()->getParent()->getParent()
1979-
<< "\n";
1980-
llvm::errs() << "fn: " << *BuilderQ.GetInsertBlock()->getParent()
1981-
<< "\n";
1982-
llvm::errs() << "bq insertblock: " << *BuilderQ.GetInsertBlock()
1983-
<< "\n";
1984-
llvm::errs() << "ret: " << *ret << " type: " << *ret->getType()
1985-
<< "\n";
1986-
llvm::errs() << "innerType: " << *innerType << "\n";
1987-
if (malloc)
1988-
llvm::errs() << " malloc: " << *malloc << " i=" << i
1989-
<< " / lim = " << limit << "\n";
1967+
assert(malloc);
1968+
1969+
Type *innerType = nullptr;
1970+
1971+
#if LLVM_VERSION_MAJOR >= 15
1972+
if (ret->getContext().supportsTypedPointers()) {
1973+
#endif
1974+
innerType = ret->getType();
1975+
for (size_t i = 0,
1976+
limit = getSubLimits(
1977+
/*inForwardPass*/ true, nullptr,
1978+
LimitContext(
1979+
/*ReverseLimit*/ reverseBlocks.size() > 0,
1980+
BuilderQ.GetInsertBlock()))
1981+
.size();
1982+
i < limit; ++i) {
1983+
if (!isa<PointerType>(innerType)) {
1984+
llvm::errs() << "mod: "
1985+
<< *BuilderQ.GetInsertBlock()->getParent()->getParent()
1986+
<< "\n";
1987+
llvm::errs() << "fn: " << *BuilderQ.GetInsertBlock()->getParent()
1988+
<< "\n";
1989+
llvm::errs() << "bq insertblock: " << *BuilderQ.GetInsertBlock()
1990+
<< "\n";
1991+
llvm::errs() << "ret: " << *ret << " type: " << *ret->getType()
1992+
<< "\n";
1993+
llvm::errs() << "innerType: " << *innerType << "\n";
1994+
if (malloc)
1995+
llvm::errs() << " malloc: " << *malloc << " i=" << i
1996+
<< " / lim = " << limit << "\n";
1997+
}
1998+
assert(isa<PointerType>(innerType));
1999+
innerType = innerType->getPointerElementType();
19902000
}
1991-
assert(isa<PointerType>(innerType));
1992-
innerType = innerType->getPointerElementType();
2001+
#if LLVM_VERSION_MAJOR >= 15
2002+
} else {
2003+
assert(!ignoreType);
2004+
if (EfficientBoolCache && malloc->getType()->isIntegerTy() &&
2005+
cast<IntegerType>(malloc->getType())->getBitWidth() == 1)
2006+
innerType = Type::getInt8Ty(malloc->getContext());
2007+
else
2008+
innerType = malloc->getType();
19932009
}
2010+
#endif
19942011

1995-
assert(malloc);
19962012
if (!ignoreType) {
19972013
if (EfficientBoolCache && malloc->getType()->isIntegerTy() &&
19982014
cast<IntegerType>(malloc->getType())->getBitWidth() == 1 &&
@@ -2021,7 +2037,13 @@ Value *GradientUtils::cacheForReverse(IRBuilder<> &BuilderQ, Value *malloc,
20212037
bool isi1 = !ignoreType && malloc->getType()->isIntegerTy() &&
20222038
cast<IntegerType>(malloc->getType())->getBitWidth() == 1;
20232039
assert(isa<PointerType>(cache->getType()));
2024-
assert(cache->getType()->getPointerElementType() == ret->getType());
2040+
#if LLVM_VERSION_MAJOR >= 15
2041+
if (cache->getContext().supportsTypedPointers()) {
2042+
#endif
2043+
assert(cache->getType()->getPointerElementType() == ret->getType());
2044+
#if LLVM_VERSION_MAJOR >= 15
2045+
}
2046+
#endif
20252047
entryBuilder.CreateStore(ret, cache);
20262048

20272049
auto v =
@@ -3849,7 +3871,7 @@ Constant *GradientUtils::GetOrCreateShadowConstant(
38493871
if (arg->isConstant() || arg->hasInternalLinkage() ||
38503872
arg->hasPrivateLinkage() ||
38513873
(arg->hasExternalLinkage() && arg->hasInitializer())) {
3852-
Type *type = arg->getType()->getPointerElementType();
3874+
Type *type = arg->getValueType();
38533875
auto shadow = new GlobalVariable(
38543876
*arg->getParent(), type, arg->isConstant(), arg->getLinkage(),
38553877
Constant::getNullValue(type), arg->getName() + "_shadow", arg,

0 commit comments

Comments
 (0)