|
23 | 23 | // LLVM instructions.
|
24 | 24 | //
|
25 | 25 | //===----------------------------------------------------------------------===//
|
| 26 | +#include "llvm/ADT/ArrayRef.h" |
26 | 27 | #include "llvm/ADT/SmallVector.h"
|
27 | 28 | #include "llvm/Analysis/ValueTracking.h"
|
| 29 | +#include "llvm/IR/Constants.h" |
| 30 | +#include "llvm/IR/DerivedTypes.h" |
28 | 31 | #include "llvm/IR/Value.h"
|
29 | 32 | #include "llvm/Transforms/Utils/Cloning.h"
|
30 | 33 |
|
@@ -4045,6 +4048,160 @@ class AdjointGenerator
|
4045 | 4048 | return;
|
4046 | 4049 | }
|
4047 | 4050 |
|
| 4051 | + if ((funcName == "cblas_ddot" || funcName == "cblas_sdot") && |
| 4052 | + called->isDeclaration()) { |
| 4053 | + Type *innerType; |
| 4054 | + std::string dfuncName; |
| 4055 | + if (funcName == "cblas_ddot") { |
| 4056 | + innerType = Type::getDoubleTy(call.getContext()); |
| 4057 | + dfuncName = "cblas_daxpy"; |
| 4058 | + } else if (funcName == "cblas_sdot") { |
| 4059 | + innerType = Type::getFloatTy(call.getContext()); |
| 4060 | + dfuncName = "cblas_saxpy"; |
| 4061 | + } else { |
| 4062 | + assert(false && "Unreachable"); |
| 4063 | + } |
| 4064 | + Type *castvals[2] = {call.getArgOperand(1)->getType(), |
| 4065 | + call.getArgOperand(3)->getType()}; |
| 4066 | + auto *cachetype = StructType::get(call.getContext(), ArrayRef(castvals)); |
| 4067 | + Value *undefinit = UndefValue::get(cachetype); |
| 4068 | + Value *cacheval; |
| 4069 | + auto in_arg = call.getCalledFunction()->arg_begin(); |
| 4070 | + in_arg++; |
| 4071 | + Argument *xfuncarg = in_arg; |
| 4072 | + in_arg++; |
| 4073 | + in_arg++; |
| 4074 | + Argument *yfuncarg = in_arg; |
| 4075 | + bool xcache = !gutils->isConstantValue(call.getArgOperand(3)) && |
| 4076 | + uncacheable_args.find(xfuncarg)->second; |
| 4077 | + bool ycache = !gutils->isConstantValue(call.getArgOperand(1)) && |
| 4078 | + uncacheable_args.find(yfuncarg)->second; |
| 4079 | + if ((Mode == DerivativeMode::ReverseModeCombined || |
| 4080 | + Mode == DerivativeMode::ReverseModePrimal) && |
| 4081 | + (xcache || ycache)) { |
| 4082 | + BuilderZ.SetInsertPoint(gutils->getNewFromOriginal(&call)); |
| 4083 | + Value *arg1, *arg2; |
| 4084 | + auto size = ConstantExpr::getSizeOf(innerType); |
| 4085 | + if (xcache) { |
| 4086 | + auto dmemcpy = getOrInsertMemcpyStrided( |
| 4087 | + *BuilderZ.GetInsertBlock()->getParent()->getParent(), |
| 4088 | + PointerType::getUnqual(innerType), 0, 0); |
| 4089 | + auto malins = CallInst::CreateMalloc( |
| 4090 | + gutils->getNewFromOriginal(&call), size->getType(), innerType, |
| 4091 | + size, call.getArgOperand(0), nullptr, ""); |
| 4092 | + arg1 = |
| 4093 | + BuilderZ.CreateBitCast(malins, call.getArgOperand(1)->getType()); |
| 4094 | + SmallVector<Value *, 4> args; |
| 4095 | + args.push_back(arg1); |
| 4096 | + args.push_back(gutils->getNewFromOriginal(call.getArgOperand(1))); |
| 4097 | + args.push_back(call.getArgOperand(0)); |
| 4098 | + args.push_back(call.getArgOperand(2)); |
| 4099 | + BuilderZ.CreateCall(dmemcpy, args); |
| 4100 | + } |
| 4101 | + if (ycache) { |
| 4102 | + auto dmemcpy = getOrInsertMemcpyStrided( |
| 4103 | + *BuilderZ.GetInsertBlock()->getParent()->getParent(), |
| 4104 | + PointerType::getUnqual(innerType), 0, 0); |
| 4105 | + auto malins = CallInst::CreateMalloc( |
| 4106 | + gutils->getNewFromOriginal(&call), size->getType(), innerType, |
| 4107 | + size, call.getArgOperand(0), nullptr, ""); |
| 4108 | + arg2 = |
| 4109 | + BuilderZ.CreateBitCast(malins, call.getArgOperand(3)->getType()); |
| 4110 | + SmallVector<Value *, 4> args; |
| 4111 | + args.push_back(arg2); |
| 4112 | + args.push_back(gutils->getNewFromOriginal(call.getArgOperand(3))); |
| 4113 | + args.push_back(call.getArgOperand(0)); |
| 4114 | + args.push_back(call.getArgOperand(4)); |
| 4115 | + BuilderZ.CreateCall(dmemcpy, args); |
| 4116 | + } |
| 4117 | + if (xcache && ycache) { |
| 4118 | + auto valins1 = BuilderZ.CreateInsertValue(undefinit, arg1, 0); |
| 4119 | + cacheval = BuilderZ.CreateInsertValue(valins1, arg2, 1); |
| 4120 | + } else if (xcache) |
| 4121 | + cacheval = arg1; |
| 4122 | + else if (ycache) |
| 4123 | + cacheval = arg2; |
| 4124 | + gutils->cacheForReverse(BuilderZ, cacheval, |
| 4125 | + getIndex(&call, CacheType::Tape)); |
| 4126 | + } |
| 4127 | + if (Mode == DerivativeMode::ReverseModeCombined || |
| 4128 | + Mode == DerivativeMode::ReverseModeGradient) { |
| 4129 | + IRBuilder<> Builder2(call.getParent()); |
| 4130 | + getReverseBuilder(Builder2); |
| 4131 | + auto derivcall = gutils->oldFunc->getParent()->getOrInsertFunction( |
| 4132 | + dfuncName, Builder2.getVoidTy(), Builder2.getInt32Ty(), innerType, |
| 4133 | + call.getArgOperand(1)->getType(), Builder2.getInt32Ty(), |
| 4134 | + call.getArgOperand(3)->getType(), Builder2.getInt32Ty()); |
| 4135 | + Value *structarg1; |
| 4136 | + Value *structarg2; |
| 4137 | + if (xcache || ycache) { |
| 4138 | + if (Mode == DerivativeMode::ReverseModeGradient && |
| 4139 | + (!gutils->isConstantValue(call.getArgOperand(1)) || |
| 4140 | + !gutils->isConstantValue(call.getArgOperand(3)))) { |
| 4141 | + cacheval = Builder2.CreatePHI(cachetype, 0); |
| 4142 | + } |
| 4143 | + cacheval = gutils->cacheForReverse(Builder2, cacheval, |
| 4144 | + getIndex(&call, CacheType::Tape)); |
| 4145 | + if (xcache && ycache) { |
| 4146 | + structarg1 = BuilderZ.CreateExtractValue(cacheval, 0); |
| 4147 | + structarg2 = BuilderZ.CreateExtractValue(cacheval, 1); |
| 4148 | + } else if (xcache) |
| 4149 | + structarg1 = cacheval; |
| 4150 | + else if (ycache) |
| 4151 | + structarg2 = cacheval; |
| 4152 | + } |
| 4153 | + if (!xcache) |
| 4154 | + structarg1 = lookup( |
| 4155 | + gutils->getNewFromOriginal(orig->getArgOperand(1)), Builder2); |
| 4156 | + if (!ycache) |
| 4157 | + structarg2 = lookup( |
| 4158 | + gutils->getNewFromOriginal(orig->getArgOperand(3)), Builder2); |
| 4159 | + CallInst *firstdcall, *seconddcall; |
| 4160 | + if (!gutils->isConstantValue(call.getArgOperand(3))) { |
| 4161 | + Value *estride; |
| 4162 | + if (xcache) |
| 4163 | + estride = Builder2.getInt32(1); |
| 4164 | + else |
| 4165 | + estride = lookup(gutils->getNewFromOriginal(orig->getArgOperand(2)), |
| 4166 | + Builder2); |
| 4167 | + SmallVector<Value *, 6> args1 = { |
| 4168 | + lookup(gutils->getNewFromOriginal(orig->getArgOperand(0)), |
| 4169 | + Builder2), |
| 4170 | + diffe(orig, Builder2), |
| 4171 | + structarg1, |
| 4172 | + estride, |
| 4173 | + gutils->invertPointerM(orig->getArgOperand(3), Builder2), |
| 4174 | + lookup(gutils->getNewFromOriginal(orig->getArgOperand(4)), |
| 4175 | + Builder2)}; |
| 4176 | + firstdcall = Builder2.CreateCall(derivcall, args1); |
| 4177 | + } |
| 4178 | + if (!gutils->isConstantValue(call.getArgOperand(1))) { |
| 4179 | + Value *estride; |
| 4180 | + if (ycache) |
| 4181 | + estride = Builder2.getInt32(1); |
| 4182 | + else |
| 4183 | + estride = lookup(gutils->getNewFromOriginal(orig->getArgOperand(4)), |
| 4184 | + Builder2); |
| 4185 | + SmallVector<Value *, 6> args2 = { |
| 4186 | + lookup(gutils->getNewFromOriginal(orig->getArgOperand(0)), |
| 4187 | + Builder2), |
| 4188 | + diffe(orig, Builder2), |
| 4189 | + structarg2, |
| 4190 | + estride, |
| 4191 | + gutils->invertPointerM(orig->getArgOperand(1), Builder2), |
| 4192 | + lookup(gutils->getNewFromOriginal(orig->getArgOperand(2)), |
| 4193 | + Builder2)}; |
| 4194 | + seconddcall = Builder2.CreateCall(derivcall, args2); |
| 4195 | + } |
| 4196 | + setDiffe(orig, Constant::getNullValue(orig->getType()), Builder2); |
| 4197 | + if (xcache) |
| 4198 | + CallInst::CreateFree(structarg1, firstdcall->getNextNode()); |
| 4199 | + if (ycache) |
| 4200 | + CallInst::CreateFree(structarg2, seconddcall->getNextNode()); |
| 4201 | + } |
| 4202 | + return; |
| 4203 | + } |
| 4204 | + |
4048 | 4205 | if (funcName == "printf" || funcName == "puts" ||
|
4049 | 4206 | funcName.startswith("_ZN3std2io5stdio6_print") ||
|
4050 | 4207 | funcName.startswith("_ZN4core3fmt")) {
|
|
0 commit comments