@@ -832,37 +832,42 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
832
832
if (pidx == nullptr )
833
833
goto endCheck;
834
834
835
- if (pidx->getType () != dli->getOperand (0 )->getType ()) {
835
+ if (pidx->getType () != getShadowType ( dli->getOperand (0 )->getType () )) {
836
836
llvm::errs () << " dli: " << *dli << " \n " ;
837
837
llvm::errs () << " dli->getOperand(0): " << *dli->getOperand (0 ) << " \n " ;
838
838
llvm::errs () << " pidx: " << *pidx << " \n " ;
839
839
}
840
- assert (pidx->getType () == dli->getOperand (0 )->getType ());
840
+ assert (pidx->getType () == getShadowType (dli->getOperand (0 )->getType ()));
841
+
842
+ Value *toreturn = applyChainRule (
843
+ dli->getType (), BuilderM,
844
+ [&](Value *pidx) {
841
845
#if LLVM_VERSION_MAJOR > 7
842
- auto toreturn =
843
- BuilderM.CreateLoad (pidx->getType ()->getPointerElementType (), pidx,
844
- phi->getName () + " _unwrap" );
846
+ auto toreturn = BuilderM.CreateLoad (dli->getType (), pidx,
847
+ phi->getName () + " _unwrap" );
845
848
#else
846
- auto toreturn = BuilderM.CreateLoad (pidx, phi->getName () + " _unwrap" );
849
+ auto toreturn =
850
+ BuilderM.CreateLoad (pidx, phi->getName () + " _unwrap" );
847
851
#endif
848
- if (auto newi = dyn_cast<Instruction>(toreturn)) {
849
- newi->copyIRFlags (dli);
850
- unwrappedLoads[toreturn] = dli;
851
- }
852
+ if (auto newi = dyn_cast<Instruction>(toreturn)) {
853
+ newi->copyIRFlags (dli);
854
+ unwrappedLoads[toreturn] = dli;
855
+ }
852
856
#if LLVM_VERSION_MAJOR >= 10
853
- toreturn->setAlignment (dli->getAlign ());
857
+ toreturn->setAlignment (dli->getAlign ());
854
858
#else
855
- toreturn->setAlignment (dli->getAlignment ());
859
+ toreturn->setAlignment (dli->getAlignment ());
856
860
#endif
857
- toreturn->setVolatile (dli->isVolatile ());
858
- toreturn->setOrdering (dli->getOrdering ());
859
- toreturn->setSyncScopeID (dli->getSyncScopeID ());
860
- toreturn->setDebugLoc (getNewFromOriginal (dli->getDebugLoc ()));
861
- toreturn->setMetadata (LLVMContext::MD_tbaa,
862
- dli->getMetadata (LLVMContext::MD_tbaa));
863
- toreturn->setMetadata (
864
- LLVMContext::MD_invariant_group,
865
- dli->getMetadata (LLVMContext::MD_invariant_group));
861
+ toreturn->setVolatile (dli->isVolatile ());
862
+ toreturn->setOrdering (dli->getOrdering ());
863
+ toreturn->setSyncScopeID (dli->getSyncScopeID ());
864
+ toreturn->setDebugLoc (getNewFromOriginal (dli->getDebugLoc ()));
865
+ toreturn->setMetadata (LLVMContext::MD_tbaa,
866
+ dli->getMetadata (LLVMContext::MD_tbaa));
867
+ return toreturn;
868
+ },
869
+ pidx);
870
+
866
871
// TODO adding to cache only legal if no alias of any future writes
867
872
if (permitCache)
868
873
unwrap_cache[BuilderM.GetInsertBlock ()][idx.first ][idx.second ] =
@@ -4825,6 +4830,9 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM,
4825
4830
new_op->setMetadata (LLVMContext::MD_invariant_group, invgroup);
4826
4831
}
4827
4832
}
4833
+ if (op->getType () != inst->getType ()) {
4834
+ llvm::errs () << " op: " << *op << " inst: " << *inst << " \n " ;
4835
+ }
4828
4836
assert (op->getType () == inst->getType ());
4829
4837
if (!reduceRegister)
4830
4838
lookup_cache[BuilderM.GetInsertBlock ()][val] = op;
0 commit comments