@@ -2231,7 +2231,13 @@ class AdjointGenerator
2231
2231
IRBuilder<> Builder2 (&MTI);
2232
2232
getForwardBuilder (Builder2);
2233
2233
auto ddst = gutils->invertPointerM (orig_dst, Builder2);
2234
+ if (ddst->getType ()->isIntegerTy ())
2235
+ ddst = Builder2.CreateIntToPtr (ddst,
2236
+ Type::getInt8PtrTy (ddst->getContext ()));
2234
2237
auto dsrc = gutils->invertPointerM (orig_src, Builder2);
2238
+ if (dsrc->getType ()->isIntegerTy ())
2239
+ dsrc = Builder2.CreateIntToPtr (dsrc,
2240
+ Type::getInt8PtrTy (dsrc->getContext ()));
2235
2241
2236
2242
auto call =
2237
2243
Builder2.CreateMemCpy (ddst, dstAlign, dsrc, srcAlign, new_size);
@@ -6059,59 +6065,21 @@ class AdjointGenerator
6059
6065
subretType = DIFFE_TYPE::OUT_DIFF;
6060
6066
}
6061
6067
6062
- auto found = customCallHandlers.find (funcName.str ());
6063
- if (found != customCallHandlers.end ()) {
6064
- IRBuilder<> Builder2 (call.getParent ());
6065
- if (Mode == DerivativeMode::ReverseModeGradient ||
6066
- Mode == DerivativeMode::ReverseModeCombined)
6067
- getReverseBuilder (Builder2);
6068
-
6069
- Value *invertedReturn = nullptr ;
6070
- bool hasNonReturnUse = false ;
6071
- auto ifound = gutils->invertedPointers .find (orig);
6072
- if (ifound != gutils->invertedPointers .end ()) {
6073
- // ! We only need the shadow pointer for non-forward Mode if it is used
6074
- // ! in a non return setting
6075
- hasNonReturnUse = subretType == DIFFE_TYPE::DUP_ARG;
6076
- if (hasNonReturnUse)
6068
+ if (Mode == DerivativeMode::ForwardMode) {
6069
+ auto found = customFwdCallHandlers.find (funcName.str ());
6070
+ if (found != customFwdCallHandlers.end ()) {
6071
+ Value *invertedReturn = nullptr ;
6072
+ auto ifound = gutils->invertedPointers .find (orig);
6073
+ if (ifound != gutils->invertedPointers .end ()) {
6077
6074
invertedReturn = cast<PHINode>(&*ifound->second );
6078
- }
6075
+ }
6079
6076
6080
- Value *normalReturn = subretused ? newCall : nullptr ;
6077
+ Value *normalReturn = subretused ? newCall : nullptr ;
6081
6078
6082
- Value *tape = nullptr ;
6079
+ found-> second (BuilderZ, orig, *gutils, normalReturn, invertedReturn) ;
6083
6080
6084
- if (Mode == DerivativeMode::ReverseModePrimal ||
6085
- Mode == DerivativeMode::ReverseModeCombined) {
6086
- found->second .first (BuilderZ, orig, *gutils, normalReturn,
6087
- invertedReturn, tape);
6088
- if (tape)
6089
- gutils->cacheForReverse (BuilderZ, tape,
6090
- getIndex (orig, CacheType::Tape));
6091
- }
6092
-
6093
- if (Mode == DerivativeMode::ReverseModeGradient ||
6094
- Mode == DerivativeMode::ReverseModeCombined) {
6095
- if (Mode == DerivativeMode::ReverseModeGradient &&
6096
- augmentedReturn->tapeIndices .find (std::make_pair (
6097
- orig, CacheType::Tape)) != augmentedReturn->tapeIndices .end ()) {
6098
- tape = BuilderZ.CreatePHI (Type::getInt32Ty (orig->getContext ()), 0 );
6099
- tape = gutils->cacheForReverse (BuilderZ, tape,
6100
- getIndex (orig, CacheType::Tape),
6101
- /* ignoreType*/ true );
6102
- }
6103
- if (tape)
6104
- tape = gutils->lookupM (tape, Builder2);
6105
- found->second .second (Builder2, orig, *(DiffeGradientUtils *)gutils,
6106
- tape);
6107
- }
6108
-
6109
- if (ifound != gutils->invertedPointers .end ()) {
6110
- auto placeholder = cast<PHINode>(&*ifound->second );
6111
- if (!hasNonReturnUse) {
6112
- gutils->invertedPointers .erase (ifound);
6113
- gutils->erase (placeholder);
6114
- } else {
6081
+ if (ifound != gutils->invertedPointers .end ()) {
6082
+ auto placeholder = cast<PHINode>(&*ifound->second );
6115
6083
if (invertedReturn && invertedReturn != placeholder) {
6116
6084
if (invertedReturn->getType () != orig->getType ()) {
6117
6085
llvm::errs () << " o: " << *orig << " \n " ;
@@ -6126,50 +6094,143 @@ class AdjointGenerator
6126
6094
assert (invertedReturn->getType () == orig->getType ());
6127
6095
placeholder->replaceAllUsesWith (invertedReturn);
6128
6096
gutils->erase (placeholder);
6129
- } else
6130
- invertedReturn = placeholder;
6131
-
6132
- invertedReturn = gutils->cacheForReverse (
6133
- BuilderZ, invertedReturn, getIndex (orig, CacheType::Shadow));
6134
-
6135
- gutils->invertedPointers .insert (std::make_pair (
6136
- (const Value *)orig, InvertedPointerVH (gutils, invertedReturn)));
6097
+ gutils->invertedPointers .insert (
6098
+ std::make_pair ((const Value *)orig,
6099
+ InvertedPointerVH (gutils, invertedReturn)));
6100
+ } else {
6101
+ gutils->invertedPointers .erase (orig);
6102
+ gutils->erase (placeholder);
6103
+ }
6137
6104
}
6138
- }
6139
-
6140
- bool primalNeededInReverse;
6141
6105
6142
- if (gutils->knownRecomputeHeuristic .count (orig)) {
6143
- primalNeededInReverse = !gutils->knownRecomputeHeuristic [orig];
6144
- } else {
6145
- std::map<UsageKey, bool > Seen;
6146
- for (auto pair : gutils->knownRecomputeHeuristic )
6147
- if (!pair.second )
6148
- Seen[UsageKey (pair.first , ValueType::Primal)] = false ;
6149
- primalNeededInReverse = is_value_needed_in_reverse<ValueType::Primal>(
6150
- TR, gutils, orig, Mode, Seen, oldUnreachable);
6151
- }
6152
- if (subretused && primalNeededInReverse) {
6153
- if (normalReturn != newCall) {
6154
- assert (normalReturn->getType () == newCall->getType ());
6155
- gutils->replaceAWithB (newCall, normalReturn);
6156
- BuilderZ.SetInsertPoint (newCall->getNextNode ());
6157
- gutils->erase (newCall);
6158
- }
6159
- normalReturn = gutils->cacheForReverse (BuilderZ, normalReturn,
6160
- getIndex (orig, CacheType::Self));
6161
- } else {
6162
6106
if (normalReturn && normalReturn != newCall) {
6163
6107
assert (normalReturn->getType () == newCall->getType ());
6164
6108
assert (Mode != DerivativeMode::ReverseModeGradient);
6165
6109
gutils->replaceAWithB (newCall, normalReturn);
6166
- BuilderZ.SetInsertPoint (newCall->getNextNode ());
6167
6110
gutils->erase (newCall);
6168
- } else if (!orig->mayWriteToMemory () ||
6169
- Mode == DerivativeMode::ReverseModeGradient)
6170
- eraseIfUnused (*orig, /* erase*/ true , /* check*/ false );
6111
+ }
6112
+ eraseIfUnused (*orig);
6113
+ return ;
6114
+ }
6115
+ }
6116
+
6117
+ if (Mode == DerivativeMode::ReverseModePrimal ||
6118
+ Mode == DerivativeMode::ReverseModeCombined ||
6119
+ Mode == DerivativeMode::ReverseModeGradient) {
6120
+ auto found = customCallHandlers.find (funcName.str ());
6121
+ if (found != customCallHandlers.end ()) {
6122
+ IRBuilder<> Builder2 (call.getParent ());
6123
+ if (Mode == DerivativeMode::ReverseModeGradient ||
6124
+ Mode == DerivativeMode::ReverseModeCombined)
6125
+ getReverseBuilder (Builder2);
6126
+
6127
+ Value *invertedReturn = nullptr ;
6128
+ bool hasNonReturnUse = false ;
6129
+ auto ifound = gutils->invertedPointers .find (orig);
6130
+ if (ifound != gutils->invertedPointers .end ()) {
6131
+ // ! We only need the shadow pointer for non-forward Mode if it is used
6132
+ // ! in a non return setting
6133
+ hasNonReturnUse = subretType == DIFFE_TYPE::DUP_ARG;
6134
+ if (hasNonReturnUse)
6135
+ invertedReturn = cast<PHINode>(&*ifound->second );
6136
+ }
6137
+
6138
+ Value *normalReturn = subretused ? newCall : nullptr ;
6139
+
6140
+ Value *tape = nullptr ;
6141
+
6142
+ if (Mode == DerivativeMode::ReverseModePrimal ||
6143
+ Mode == DerivativeMode::ReverseModeCombined) {
6144
+ found->second .first (BuilderZ, orig, *gutils, normalReturn,
6145
+ invertedReturn, tape);
6146
+ if (tape)
6147
+ gutils->cacheForReverse (BuilderZ, tape,
6148
+ getIndex (orig, CacheType::Tape));
6149
+ }
6150
+
6151
+ if (Mode == DerivativeMode::ReverseModeGradient ||
6152
+ Mode == DerivativeMode::ReverseModeCombined) {
6153
+ if (Mode == DerivativeMode::ReverseModeGradient &&
6154
+ augmentedReturn->tapeIndices .find (
6155
+ std::make_pair (orig, CacheType::Tape)) !=
6156
+ augmentedReturn->tapeIndices .end ()) {
6157
+ tape = BuilderZ.CreatePHI (Type::getInt32Ty (orig->getContext ()), 0 );
6158
+ tape = gutils->cacheForReverse (BuilderZ, tape,
6159
+ getIndex (orig, CacheType::Tape),
6160
+ /* ignoreType*/ true );
6161
+ }
6162
+ if (tape)
6163
+ tape = gutils->lookupM (tape, Builder2);
6164
+ found->second .second (Builder2, orig, *(DiffeGradientUtils *)gutils,
6165
+ tape);
6166
+ }
6167
+
6168
+ if (ifound != gutils->invertedPointers .end ()) {
6169
+ auto placeholder = cast<PHINode>(&*ifound->second );
6170
+ if (!hasNonReturnUse) {
6171
+ gutils->invertedPointers .erase (ifound);
6172
+ gutils->erase (placeholder);
6173
+ } else {
6174
+ if (invertedReturn && invertedReturn != placeholder) {
6175
+ if (invertedReturn->getType () != orig->getType ()) {
6176
+ llvm::errs () << " o: " << *orig << " \n " ;
6177
+ llvm::errs () << " ot: " << *orig->getType () << " \n " ;
6178
+ llvm::errs () << " ir: " << *invertedReturn << " \n " ;
6179
+ llvm::errs () << " irt: " << *invertedReturn->getType () << " \n " ;
6180
+ llvm::errs () << " p: " << *placeholder << " \n " ;
6181
+ llvm::errs () << " PT: " << *placeholder->getType () << " \n " ;
6182
+ llvm::errs () << " newCall: " << *newCall << " \n " ;
6183
+ llvm::errs () << " newCallT: " << *newCall->getType () << " \n " ;
6184
+ }
6185
+ assert (invertedReturn->getType () == orig->getType ());
6186
+ placeholder->replaceAllUsesWith (invertedReturn);
6187
+ gutils->erase (placeholder);
6188
+ } else
6189
+ invertedReturn = placeholder;
6190
+
6191
+ invertedReturn = gutils->cacheForReverse (
6192
+ BuilderZ, invertedReturn, getIndex (orig, CacheType::Shadow));
6193
+
6194
+ gutils->invertedPointers .insert (
6195
+ std::make_pair ((const Value *)orig,
6196
+ InvertedPointerVH (gutils, invertedReturn)));
6197
+ }
6198
+ }
6199
+
6200
+ bool primalNeededInReverse;
6201
+
6202
+ if (gutils->knownRecomputeHeuristic .count (orig)) {
6203
+ primalNeededInReverse = !gutils->knownRecomputeHeuristic [orig];
6204
+ } else {
6205
+ std::map<UsageKey, bool > Seen;
6206
+ for (auto pair : gutils->knownRecomputeHeuristic )
6207
+ if (!pair.second )
6208
+ Seen[UsageKey (pair.first , ValueType::Primal)] = false ;
6209
+ primalNeededInReverse = is_value_needed_in_reverse<ValueType::Primal>(
6210
+ TR, gutils, orig, Mode, Seen, oldUnreachable);
6211
+ }
6212
+ if (subretused && primalNeededInReverse) {
6213
+ if (normalReturn != newCall) {
6214
+ assert (normalReturn->getType () == newCall->getType ());
6215
+ gutils->replaceAWithB (newCall, normalReturn);
6216
+ BuilderZ.SetInsertPoint (newCall->getNextNode ());
6217
+ gutils->erase (newCall);
6218
+ }
6219
+ normalReturn = gutils->cacheForReverse (
6220
+ BuilderZ, normalReturn, getIndex (orig, CacheType::Self));
6221
+ } else {
6222
+ if (normalReturn && normalReturn != newCall) {
6223
+ assert (normalReturn->getType () == newCall->getType ());
6224
+ assert (Mode != DerivativeMode::ReverseModeGradient);
6225
+ gutils->replaceAWithB (newCall, normalReturn);
6226
+ BuilderZ.SetInsertPoint (newCall->getNextNode ());
6227
+ gutils->erase (newCall);
6228
+ } else if (!orig->mayWriteToMemory () ||
6229
+ Mode == DerivativeMode::ReverseModeGradient)
6230
+ eraseIfUnused (*orig, /* erase*/ true , /* check*/ false );
6231
+ }
6232
+ return ;
6171
6233
}
6172
- return ;
6173
6234
}
6174
6235
6175
6236
if (Mode != DerivativeMode::ReverseModePrimal && called) {
@@ -7875,6 +7936,9 @@ class AdjointGenerator
7875
7936
argsInverted.push_back (DIFFE_TYPE::DUP_ARG);
7876
7937
}
7877
7938
}
7939
+ if (!called)
7940
+ llvm::errs () << *called << " \n " ;
7941
+ assert (called);
7878
7942
7879
7943
auto newcalled = gutils->Logic .CreateForwardDiff (
7880
7944
cast<Function>(called), subretType, argsInverted, gutils->TLI ,
0 commit comments