@@ -11731,6 +11731,7 @@ class AdjointGenerator
11731
11731
}
11732
11732
11733
11733
Value *newcalled = nullptr ;
11734
+ FunctionType *FT = nullptr ;
11734
11735
11735
11736
if (called) {
11736
11737
newcalled = gutils->Logic .CreateForwardDiff (
@@ -11739,6 +11740,7 @@ class AdjointGenerator
11739
11740
((DiffeGradientUtils *)gutils)->FreeMemory , gutils->getWidth (),
11740
11741
tape ? tape->getType () : nullptr , nextTypeInfo, uncacheable_args,
11741
11742
/* augmented*/ subdata);
11743
+ FT = cast<Function>(newcalled)->getFunctionType ();
11742
11744
} else {
11743
11745
#if LLVM_VERSION_MAJOR >= 11
11744
11746
auto callval = orig->getCalledOperand ();
@@ -11766,10 +11768,10 @@ class AdjointGenerator
11766
11768
? (retActive ? ReturnType::TwoReturns : ReturnType::Return)
11767
11769
: (retActive ? ReturnType::Return : ReturnType::Void);
11768
11770
11769
- FunctionType *FTy = getFunctionTypeForClone (
11771
+ FT = getFunctionTypeForClone (
11770
11772
ft, Mode, gutils->getWidth (), tape ? tape->getType () : nullptr ,
11771
11773
argsInverted, false , subretVal, subretType);
11772
- PointerType *fptype = PointerType::getUnqual (FTy );
11774
+ PointerType *fptype = PointerType::getUnqual (FT );
11773
11775
newcalled = BuilderZ.CreatePointerCast (newcalled,
11774
11776
PointerType::getUnqual (fptype));
11775
11777
#if LLVM_VERSION_MAJOR > 7
@@ -11780,8 +11782,7 @@ class AdjointGenerator
11780
11782
}
11781
11783
11782
11784
assert (newcalled);
11783
- FunctionType *FT =
11784
- cast<FunctionType>(newcalled->getType ()->getPointerElementType ());
11785
+ assert (FT);
11785
11786
11786
11787
SmallVector<ValueType, 2 > BundleTypes;
11787
11788
for (auto A : argsInverted)
@@ -12055,6 +12056,7 @@ class AdjointGenerator
12055
12056
if (modifyPrimal) {
12056
12057
12057
12058
Value *newcalled = nullptr ;
12059
+ FunctionType *FT = nullptr ;
12058
12060
const AugmentedReturn *fnandtapetype = nullptr ;
12059
12061
12060
12062
if (!called) {
@@ -12096,17 +12098,28 @@ class AdjointGenerator
12096
12098
FunctionType *ft = nullptr ;
12097
12099
if (auto F = dyn_cast<Function>(callval))
12098
12100
ft = F->getFunctionType ();
12099
- else
12100
- ft = cast<FunctionType>(callval->getType ()->getPointerElementType ());
12101
+ else {
12102
+ #if LLVM_VERSION_MAJOR >= 15
12103
+ if (orig->getContext ().supportsTypedPointers ()) {
12104
+ #endif
12105
+ ft =
12106
+ cast<FunctionType>(callval->getType ()->getPointerElementType ());
12107
+ #if LLVM_VERSION_MAJOR >= 15
12108
+ } else {
12109
+ ft = orig->getFunctionType ();
12110
+ }
12111
+ #endif
12112
+ }
12101
12113
12102
12114
std::set<llvm::Type *> seen;
12103
12115
DIFFE_TYPE subretType = whatType (orig->getType (), Mode,
12104
12116
/* intAreConstant*/ false , seen);
12105
12117
auto res = getDefaultFunctionTypeForAugmentation (
12106
12118
ft, /* returnUsed*/ true , /* subretType*/ subretType);
12107
- auto fptype = PointerType::getUnqual ( FunctionType::get (
12119
+ FT = FunctionType::get (
12108
12120
StructType::get (newcalled->getContext (), res.second ), res.first ,
12109
- ft->isVarArg ()));
12121
+ ft->isVarArg ());
12122
+ auto fptype = PointerType::getUnqual (FT);
12110
12123
newcalled = BuilderZ.CreatePointerCast (newcalled,
12111
12124
PointerType::getUnqual (fptype));
12112
12125
#if LLVM_VERSION_MAJOR > 7
@@ -12149,6 +12162,7 @@ class AdjointGenerator
12149
12162
assert (subdata);
12150
12163
fnandtapetype = subdata;
12151
12164
newcalled = subdata->fn ;
12165
+ FT = cast<Function>(newcalled)->getFunctionType ();
12152
12166
12153
12167
auto found = subdata->returns .find (AugmentedStruct::DifferentialReturn);
12154
12168
if (found != subdata->returns .end ()) {
@@ -12172,11 +12186,7 @@ class AdjointGenerator
12172
12186
// sub_index_map = fnandtapetype.tapeIndices;
12173
12187
12174
12188
assert (newcalled);
12175
- FunctionType *FT = nullptr ;
12176
- if (auto F = dyn_cast<Function>(newcalled))
12177
- FT = F->getFunctionType ();
12178
- else
12179
- FT = cast<FunctionType>(newcalled->getType ()->getPointerElementType ());
12189
+ assert (FT);
12180
12190
12181
12191
// llvm::errs() << "seeing sub_index_map of " << sub_index_map->size()
12182
12192
// << " in ap " << cast<Function>(called)->getName() << "\n";
@@ -12483,6 +12493,7 @@ class AdjointGenerator
12483
12493
getReverseBuilder (Builder2);
12484
12494
12485
12495
Value *newcalled = nullptr ;
12496
+ FunctionType *FT = nullptr ;
12486
12497
12487
12498
DerivativeMode subMode = (replaceFunction || !modifyPrimal)
12488
12499
? DerivativeMode::ReverseModeCombined
@@ -12505,6 +12516,7 @@ class AdjointGenerator
12505
12516
TR.analyzer .interprocedural , subdata);
12506
12517
if (!newcalled)
12507
12518
return ;
12519
+ FT = cast<Function>(newcalled)->getFunctionType ();
12508
12520
} else {
12509
12521
12510
12522
assert (subMode != DerivativeMode::ReverseModeCombined);
@@ -12522,15 +12534,17 @@ class AdjointGenerator
12522
12534
assert (!gutils->isConstantValue (callval));
12523
12535
newcalled = lookup (gutils->invertPointerM (callval, Builder2), Builder2);
12524
12536
12525
- auto ft = cast<FunctionType>(callval->getType ()->getPointerElementType ());
12537
+ auto ft = orig->getFunctionType ();
12538
+ // cast<FunctionType>(callval->getType()->getPointerElementType());
12526
12539
12527
12540
auto res =
12528
12541
getDefaultFunctionTypeForGradient (ft, /* subretType*/ subretType);
12529
12542
// TODO Note there is empty tape added here, replace with generic
12530
12543
res.first .push_back (Type::getInt8PtrTy (newcalled->getContext ()));
12531
- auto fptype = PointerType::getUnqual ( FunctionType::get (
12544
+ FT = FunctionType::get (
12532
12545
StructType::get (newcalled->getContext (), res.second ), res.first ,
12533
- ft->isVarArg ()));
12546
+ ft->isVarArg ());
12547
+ auto fptype = PointerType::getUnqual (FT);
12534
12548
newcalled =
12535
12549
Builder2.CreatePointerCast (newcalled, PointerType::getUnqual (fptype));
12536
12550
#if LLVM_VERSION_MAJOR > 7
@@ -12554,13 +12568,7 @@ class AdjointGenerator
12554
12568
}
12555
12569
12556
12570
assert (newcalled);
12557
- // if (auto NC = dyn_cast<Function>(newcalled)) {
12558
- FunctionType *FT = nullptr ;
12559
- if (auto F = dyn_cast<Function>(newcalled))
12560
- FT = F->getFunctionType ();
12561
- else {
12562
- FT = cast<FunctionType>(newcalled->getType ()->getPointerElementType ());
12563
- }
12571
+ assert (FT);
12564
12572
12565
12573
if (false ) {
12566
12574
badfn:;
0 commit comments