@@ -4655,32 +4655,16 @@ class AdjointGenerator
4655
4655
args.push_back (lookup (argi, Builder2));
4656
4656
}
4657
4657
4658
- if (gutils->isConstantValue (call.getArgOperand (i)) && !foreignFunction) {
4659
- argsInverted.push_back (DIFFE_TYPE::CONSTANT);
4658
+ auto argTy = gutils->getDiffeType (call.getArgOperand (i), foreignFunction);
4659
+ argsInverted.push_back (argTy);
4660
+
4661
+ if (argTy == DIFFE_TYPE::CONSTANT) {
4660
4662
continue ;
4661
4663
}
4662
4664
4663
4665
auto argType = argi->getType ();
4664
4666
4665
- if (!argType->isFPOrFPVectorTy () &&
4666
- TR.query (call.getArgOperand (i)).Inner0 ().isPossiblePointer ()) {
4667
- DIFFE_TYPE ty = DIFFE_TYPE::DUP_ARG;
4668
- if (argType->isPointerTy ()) {
4669
- #if LLVM_VERSION_MAJOR >= 12
4670
- auto at = getUnderlyingObject (call.getArgOperand (i), 100 );
4671
- #else
4672
- auto at = GetUnderlyingObject (
4673
- call.getArgOperand (i),
4674
- gutils->oldFunc ->getParent ()->getDataLayout (), 100 );
4675
- #endif
4676
- if (auto arg = dyn_cast<Argument>(at)) {
4677
- if (constant_args[arg->getArgNo ()] == DIFFE_TYPE::DUP_NONEED) {
4678
- ty = DIFFE_TYPE::DUP_NONEED;
4679
- }
4680
- }
4681
- }
4682
- argsInverted.push_back (ty);
4683
-
4667
+ if (argTy == DIFFE_TYPE::DUP_ARG || argTy == DIFFE_TYPE::DUP_NONEED) {
4684
4668
if (Mode != DerivativeMode::ReverseModePrimal) {
4685
4669
IRBuilder<> Builder2 (call.getParent ());
4686
4670
getReverseBuilder (Builder2);
@@ -4699,7 +4683,6 @@ class AdjointGenerator
4699
4683
assert (TR.query (call.getArgOperand (i)).Inner0 ().isFloat ());
4700
4684
OutTypes.push_back (call.getArgOperand (i));
4701
4685
OutFPTypes.push_back (argType);
4702
- argsInverted.push_back (DIFFE_TYPE::OUT_DIFF);
4703
4686
assert (whatType (argType, Mode) == DIFFE_TYPE::OUT_DIFF ||
4704
4687
whatType (argType, Mode) == DIFFE_TYPE::CONSTANT);
4705
4688
}
@@ -8484,37 +8467,10 @@ class AdjointGenerator
8484
8467
funcName = called->getName ();
8485
8468
}
8486
8469
8487
- bool subretused = unnecessaryValues.find (orig) == unnecessaryValues.end ();
8488
- if (gutils->knownRecomputeHeuristic .find (orig) !=
8489
- gutils->knownRecomputeHeuristic .end ()) {
8490
- if (!gutils->knownRecomputeHeuristic [orig]) {
8491
- subretused = true ;
8492
- }
8493
- }
8470
+ bool subretused = false ;
8494
8471
bool shadowReturnUsed = false ;
8495
-
8496
- DIFFE_TYPE subretType;
8497
- if (gutils->isConstantValue (orig)) {
8498
- subretType = DIFFE_TYPE::CONSTANT;
8499
- } else {
8500
- if (Mode == DerivativeMode::ForwardMode ||
8501
- Mode == DerivativeMode::ForwardModeSplit) {
8502
- subretType = DIFFE_TYPE::DUP_ARG;
8503
- shadowReturnUsed = true ;
8504
- } else {
8505
- if (!orig->getType ()->isFPOrFPVectorTy () &&
8506
- TR.query (orig).Inner0 ().isPossiblePointer ()) {
8507
- if (is_value_needed_in_reverse<ValueType::Shadow>(gutils, orig, Mode,
8508
- oldUnreachable)) {
8509
- subretType = DIFFE_TYPE::DUP_ARG;
8510
- shadowReturnUsed = true ;
8511
- } else
8512
- subretType = DIFFE_TYPE::CONSTANT;
8513
- } else {
8514
- subretType = DIFFE_TYPE::OUT_DIFF;
8515
- }
8516
- }
8517
- }
8472
+ DIFFE_TYPE subretType =
8473
+ gutils->getReturnDiffeType (orig, &subretused, &shadowReturnUsed);
8518
8474
8519
8475
if (Mode == DerivativeMode::ForwardMode) {
8520
8476
auto found = customFwdCallHandlers.find (funcName.str ());
@@ -8576,22 +8532,9 @@ class AdjointGenerator
8576
8532
getReverseBuilder (Builder2);
8577
8533
8578
8534
Value *invertedReturn = nullptr ;
8579
- bool hasNonReturnUse = false ;
8580
8535
auto ifound = gutils->invertedPointers .find (orig);
8581
8536
if (ifound != gutils->invertedPointers .end ()) {
8582
- // ! We only need the shadow pointer for non-forward Mode if it is used
8583
- // ! in a non return setting
8584
- if (!gutils->isConstantValue (orig)) {
8585
- if (!orig->getType ()->isFPOrFPVectorTy () &&
8586
- TR.query (orig).Inner0 ().isPossiblePointer ()) {
8587
- if (is_value_needed_in_reverse<ValueType::Shadow>(
8588
- gutils, orig, DerivativeMode::ReverseModePrimal,
8589
- oldUnreachable)) {
8590
- hasNonReturnUse = true ;
8591
- }
8592
- }
8593
- }
8594
- if (hasNonReturnUse)
8537
+ if (shadowReturnUsed)
8595
8538
invertedReturn = cast<PHINode>(&*ifound->second );
8596
8539
}
8597
8540
@@ -8627,7 +8570,7 @@ class AdjointGenerator
8627
8570
8628
8571
if (ifound != gutils->invertedPointers .end ()) {
8629
8572
auto placeholder = cast<PHINode>(&*ifound->second );
8630
- if (!hasNonReturnUse ) {
8573
+ if (!shadowReturnUsed ) {
8631
8574
gutils->invertedPointers .erase (ifound);
8632
8575
gutils->erase (placeholder);
8633
8576
} else {
@@ -8687,8 +8630,7 @@ class AdjointGenerator
8687
8630
gutils->replaceAWithB (newCall, normalReturn);
8688
8631
BuilderZ.SetInsertPoint (newCall->getNextNode ());
8689
8632
gutils->erase (newCall);
8690
- } else if ((!orig->mayWriteToMemory () ||
8691
- Mode == DerivativeMode::ReverseModeGradient) &&
8633
+ } else if (Mode == DerivativeMode::ReverseModeGradient &&
8692
8634
!orig->getType ()->isTokenTy ())
8693
8635
eraseIfUnused (*orig, /* erase*/ true , /* check*/ false );
8694
8636
}
@@ -11244,47 +11186,18 @@ class AdjointGenerator
11244
11186
#endif
11245
11187
args.push_back (argi);
11246
11188
11247
- if (gutils->isConstantValue (orig->getArgOperand (i)) &&
11248
- !foreignFunction) {
11249
- argsInverted.push_back (DIFFE_TYPE::CONSTANT);
11189
+ auto argTy =
11190
+ gutils->getDiffeType (orig->getArgOperand (i), foreignFunction);
11191
+ argsInverted.push_back (argTy);
11192
+
11193
+ if (argTy == DIFFE_TYPE::CONSTANT) {
11250
11194
continue ;
11251
11195
}
11252
11196
11253
- auto argType = argi->getType ();
11254
-
11255
- if (!argType->isFPOrFPVectorTy () &&
11256
- (TR.query (orig->getArgOperand (i)).Inner0 ().isPossiblePointer () ||
11257
- foreignFunction)) {
11258
- DIFFE_TYPE ty = DIFFE_TYPE::DUP_ARG;
11259
- if (argType->isPointerTy ()) {
11260
- #if LLVM_VERSION_MAJOR >= 12
11261
- auto at = getUnderlyingObject (orig->getArgOperand (i), 100 );
11262
- #else
11263
- auto at = GetUnderlyingObject (
11264
- orig->getArgOperand (i),
11265
- gutils->oldFunc ->getParent ()->getDataLayout (), 100 );
11266
- #endif
11267
- if (auto arg = dyn_cast<Argument>(at)) {
11268
- if (constant_args[arg->getArgNo ()] == DIFFE_TYPE::DUP_NONEED) {
11269
- ty = DIFFE_TYPE::DUP_NONEED;
11270
- }
11271
- }
11272
- }
11273
- args.push_back (
11274
- gutils->invertPointerM (orig->getArgOperand (i), Builder2));
11275
- argsInverted.push_back (ty);
11197
+ assert (argTy == DIFFE_TYPE::DUP_ARG || argTy == DIFFE_TYPE::DUP_NONEED);
11276
11198
11277
- // Note sometimes whattype mistakenly says something should be
11278
- // constant [because composed of integer pointers alone]
11279
- assert (whatType (argType, Mode) == DIFFE_TYPE::DUP_ARG ||
11280
- whatType (argType, Mode) == DIFFE_TYPE::CONSTANT);
11281
- } else {
11282
- if (foreignFunction)
11283
- assert (!argType->isIntOrIntVectorTy ());
11284
-
11285
- args.push_back (diffe (orig->getArgOperand (i), Builder2));
11286
- argsInverted.push_back (DIFFE_TYPE::DUP_ARG);
11287
- }
11199
+ args.push_back (
11200
+ gutils->invertPointerM (orig->getArgOperand (i), Builder2));
11288
11201
}
11289
11202
11290
11203
Optional<int > tapeIdx;
@@ -11478,33 +11391,18 @@ class AdjointGenerator
11478
11391
args.push_back (lookup (argi, Builder2));
11479
11392
}
11480
11393
11481
- if (gutils->isConstantValue (orig->getArgOperand (i)) && !foreignFunction) {
11482
- argsInverted.push_back (DIFFE_TYPE::CONSTANT);
11394
+ auto argTy =
11395
+ gutils->getDiffeType (orig->getArgOperand (i), foreignFunction);
11396
+
11397
+ argsInverted.push_back (argTy);
11398
+
11399
+ if (argTy == DIFFE_TYPE::CONSTANT) {
11483
11400
continue ;
11484
11401
}
11485
11402
11486
11403
auto argType = argi->getType ();
11487
11404
11488
- if (!argType->isFPOrFPVectorTy () &&
11489
- (TR.query (orig->getArgOperand (i)).Inner0 ().isPossiblePointer () ||
11490
- foreignFunction)) {
11491
- DIFFE_TYPE ty = DIFFE_TYPE::DUP_ARG;
11492
- if (argType->isPointerTy ()) {
11493
- #if LLVM_VERSION_MAJOR >= 12
11494
- auto at = getUnderlyingObject (orig->getArgOperand (i), 100 );
11495
- #else
11496
- auto at = GetUnderlyingObject (
11497
- orig->getArgOperand (i),
11498
- gutils->oldFunc ->getParent ()->getDataLayout (), 100 );
11499
- #endif
11500
- if (auto arg = dyn_cast<Argument>(at)) {
11501
- if (constant_args[arg->getArgNo ()] == DIFFE_TYPE::DUP_NONEED) {
11502
- ty = DIFFE_TYPE::DUP_NONEED;
11503
- }
11504
- }
11505
- }
11506
- argsInverted.push_back (ty);
11507
-
11405
+ if (argTy == DIFFE_TYPE::DUP_ARG || argTy == DIFFE_TYPE::DUP_NONEED) {
11508
11406
if (Mode != DerivativeMode::ReverseModePrimal) {
11509
11407
IRBuilder<> Builder2 (call.getParent ());
11510
11408
getReverseBuilder (Builder2);
@@ -11522,7 +11420,6 @@ class AdjointGenerator
11522
11420
} else {
11523
11421
if (foreignFunction)
11524
11422
assert (!argType->isIntOrIntVectorTy ());
11525
- argsInverted.push_back (DIFFE_TYPE::OUT_DIFF);
11526
11423
assert (whatType (argType, Mode) == DIFFE_TYPE::OUT_DIFF ||
11527
11424
whatType (argType, Mode) == DIFFE_TYPE::CONSTANT);
11528
11425
}
0 commit comments