Skip to content

Commit ff17a4e

Browse files
committed
cblas_ddot/cblas_sdot -> cblas_daxpy/cblas_saxpy
Derivative of cblas_ddot can be calculated with calls to cblas_daxpy. Similarly for cblas_sdot.
1 parent 494a7a4 commit ff17a4e

22 files changed

+2310
-0
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,11 @@
2323
// LLVM instructions.
2424
//
2525
//===----------------------------------------------------------------------===//
26+
#include "llvm/ADT/ArrayRef.h"
2627
#include "llvm/ADT/SmallVector.h"
2728
#include "llvm/Analysis/ValueTracking.h"
29+
#include "llvm/IR/Constants.h"
30+
#include "llvm/IR/DerivedTypes.h"
2831
#include "llvm/IR/Value.h"
2932
#include "llvm/Transforms/Utils/Cloning.h"
3033

@@ -4045,6 +4048,160 @@ class AdjointGenerator
40454048
return;
40464049
}
40474050

4051+
if ((funcName == "cblas_ddot" || funcName == "cblas_sdot") &&
4052+
called->isDeclaration()) {
4053+
Type *innerType;
4054+
std::string dfuncName;
4055+
if (funcName == "cblas_ddot") {
4056+
innerType = Type::getDoubleTy(call.getContext());
4057+
dfuncName = "cblas_daxpy";
4058+
} else if (funcName == "cblas_sdot") {
4059+
innerType = Type::getFloatTy(call.getContext());
4060+
dfuncName = "cblas_saxpy";
4061+
} else {
4062+
assert(false && "Unreachable");
4063+
}
4064+
Type *castvals[2] = {call.getArgOperand(1)->getType(),
4065+
call.getArgOperand(3)->getType()};
4066+
auto *cachetype = StructType::get(call.getContext(), ArrayRef(castvals));
4067+
Value *undefinit = UndefValue::get(cachetype);
4068+
Value *cacheval;
4069+
auto in_arg = call.getCalledFunction()->arg_begin();
4070+
in_arg++;
4071+
Argument *xfuncarg = in_arg;
4072+
in_arg++;
4073+
in_arg++;
4074+
Argument *yfuncarg = in_arg;
4075+
bool xcache = !gutils->isConstantValue(call.getArgOperand(3)) &&
4076+
uncacheable_args.find(xfuncarg)->second;
4077+
bool ycache = !gutils->isConstantValue(call.getArgOperand(1)) &&
4078+
uncacheable_args.find(yfuncarg)->second;
4079+
if ((Mode == DerivativeMode::ReverseModeCombined ||
4080+
Mode == DerivativeMode::ReverseModePrimal) &&
4081+
(xcache || ycache)) {
4082+
BuilderZ.SetInsertPoint(gutils->getNewFromOriginal(&call));
4083+
Value *arg1, *arg2;
4084+
auto size = ConstantExpr::getSizeOf(innerType);
4085+
if (xcache) {
4086+
auto dmemcpy = getOrInsertMemcpyStrided(
4087+
*BuilderZ.GetInsertBlock()->getParent()->getParent(),
4088+
PointerType::getUnqual(innerType), 0, 0);
4089+
auto malins = CallInst::CreateMalloc(
4090+
gutils->getNewFromOriginal(&call), size->getType(), innerType,
4091+
size, call.getArgOperand(0), nullptr, "");
4092+
arg1 =
4093+
BuilderZ.CreateBitCast(malins, call.getArgOperand(1)->getType());
4094+
SmallVector<Value *, 4> args;
4095+
args.push_back(arg1);
4096+
args.push_back(gutils->getNewFromOriginal(call.getArgOperand(1)));
4097+
args.push_back(call.getArgOperand(0));
4098+
args.push_back(call.getArgOperand(2));
4099+
BuilderZ.CreateCall(dmemcpy, args);
4100+
}
4101+
if (ycache) {
4102+
auto dmemcpy = getOrInsertMemcpyStrided(
4103+
*BuilderZ.GetInsertBlock()->getParent()->getParent(),
4104+
PointerType::getUnqual(innerType), 0, 0);
4105+
auto malins = CallInst::CreateMalloc(
4106+
gutils->getNewFromOriginal(&call), size->getType(), innerType,
4107+
size, call.getArgOperand(0), nullptr, "");
4108+
arg2 =
4109+
BuilderZ.CreateBitCast(malins, call.getArgOperand(3)->getType());
4110+
SmallVector<Value *, 4> args;
4111+
args.push_back(arg2);
4112+
args.push_back(gutils->getNewFromOriginal(call.getArgOperand(3)));
4113+
args.push_back(call.getArgOperand(0));
4114+
args.push_back(call.getArgOperand(4));
4115+
BuilderZ.CreateCall(dmemcpy, args);
4116+
}
4117+
if (xcache && ycache) {
4118+
auto valins1 = BuilderZ.CreateInsertValue(undefinit, arg1, 0);
4119+
cacheval = BuilderZ.CreateInsertValue(valins1, arg2, 1);
4120+
} else if (xcache)
4121+
cacheval = arg1;
4122+
else if (ycache)
4123+
cacheval = arg2;
4124+
gutils->cacheForReverse(BuilderZ, cacheval,
4125+
getIndex(&call, CacheType::Tape));
4126+
}
4127+
if (Mode == DerivativeMode::ReverseModeCombined ||
4128+
Mode == DerivativeMode::ReverseModeGradient) {
4129+
IRBuilder<> Builder2(call.getParent());
4130+
getReverseBuilder(Builder2);
4131+
auto derivcall = gutils->oldFunc->getParent()->getOrInsertFunction(
4132+
dfuncName, Builder2.getVoidTy(), Builder2.getInt32Ty(), innerType,
4133+
call.getArgOperand(1)->getType(), Builder2.getInt32Ty(),
4134+
call.getArgOperand(3)->getType(), Builder2.getInt32Ty());
4135+
Value *structarg1;
4136+
Value *structarg2;
4137+
if (xcache || ycache) {
4138+
if (Mode == DerivativeMode::ReverseModeGradient &&
4139+
(!gutils->isConstantValue(call.getArgOperand(1)) ||
4140+
!gutils->isConstantValue(call.getArgOperand(3)))) {
4141+
cacheval = Builder2.CreatePHI(cachetype, 0);
4142+
}
4143+
cacheval = gutils->cacheForReverse(Builder2, cacheval,
4144+
getIndex(&call, CacheType::Tape));
4145+
if (xcache && ycache) {
4146+
structarg1 = BuilderZ.CreateExtractValue(cacheval, 0);
4147+
structarg2 = BuilderZ.CreateExtractValue(cacheval, 1);
4148+
} else if (xcache)
4149+
structarg1 = cacheval;
4150+
else if (ycache)
4151+
structarg2 = cacheval;
4152+
}
4153+
if (!xcache)
4154+
structarg1 = lookup(
4155+
gutils->getNewFromOriginal(orig->getArgOperand(1)), Builder2);
4156+
if (!ycache)
4157+
structarg2 = lookup(
4158+
gutils->getNewFromOriginal(orig->getArgOperand(3)), Builder2);
4159+
CallInst *firstdcall, *seconddcall;
4160+
if (!gutils->isConstantValue(call.getArgOperand(3))) {
4161+
Value *estride;
4162+
if (xcache)
4163+
estride = Builder2.getInt32(1);
4164+
else
4165+
estride = lookup(gutils->getNewFromOriginal(orig->getArgOperand(2)),
4166+
Builder2);
4167+
SmallVector<Value *, 6> args1 = {
4168+
lookup(gutils->getNewFromOriginal(orig->getArgOperand(0)),
4169+
Builder2),
4170+
diffe(orig, Builder2),
4171+
structarg1,
4172+
estride,
4173+
gutils->invertPointerM(orig->getArgOperand(3), Builder2),
4174+
lookup(gutils->getNewFromOriginal(orig->getArgOperand(4)),
4175+
Builder2)};
4176+
firstdcall = Builder2.CreateCall(derivcall, args1);
4177+
}
4178+
if (!gutils->isConstantValue(call.getArgOperand(1))) {
4179+
Value *estride;
4180+
if (ycache)
4181+
estride = Builder2.getInt32(1);
4182+
else
4183+
estride = lookup(gutils->getNewFromOriginal(orig->getArgOperand(4)),
4184+
Builder2);
4185+
SmallVector<Value *, 6> args2 = {
4186+
lookup(gutils->getNewFromOriginal(orig->getArgOperand(0)),
4187+
Builder2),
4188+
diffe(orig, Builder2),
4189+
structarg2,
4190+
estride,
4191+
gutils->invertPointerM(orig->getArgOperand(1), Builder2),
4192+
lookup(gutils->getNewFromOriginal(orig->getArgOperand(2)),
4193+
Builder2)};
4194+
seconddcall = Builder2.CreateCall(derivcall, args2);
4195+
}
4196+
setDiffe(orig, Constant::getNullValue(orig->getType()), Builder2);
4197+
if (xcache)
4198+
CallInst::CreateFree(structarg1, firstdcall->getNextNode());
4199+
if (ycache)
4200+
CallInst::CreateFree(structarg2, seconddcall->getNextNode());
4201+
}
4202+
return;
4203+
}
4204+
40484205
if (funcName == "printf" || funcName == "puts" ||
40494206
funcName.startswith("_ZN3std2io5stdio6_print") ||
40504207
funcName.startswith("_ZN4core3fmt")) {

enzyme/Enzyme/Enzyme.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -821,6 +821,13 @@ class Enzyme : public ModulePass {
821821
F.addFnAttr(Attribute::ReadOnly);
822822
F.addFnAttr(Attribute::InaccessibleMemOnly);
823823
}
824+
if ((Fn->getName() == "cblas_ddot" || Fn->getName() == "cblas_sdot") &&
825+
Fn->isDeclaration()) {
826+
CI->addParamAttr(1, Attribute::ReadOnly);
827+
CI->addParamAttr(1, Attribute::NoCapture);
828+
CI->addParamAttr(3, Attribute::ReadOnly);
829+
CI->addParamAttr(3, Attribute::NoCapture);
830+
}
824831
if (Fn->getName() == "frexp" || Fn->getName() == "frexpf" ||
825832
Fn->getName() == "frexpl") {
826833
CI->addAttribute(AttributeList::FunctionIndex, Attribute::ArgMemOnly);

enzyme/Enzyme/Utils.cpp

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,100 @@ Function *getOrInsertDifferentialFloatMemcpy(Module &M, PointerType *T,
158158
return F;
159159
}
160160

161+
Function *getOrInsertMemcpyStrided(Module &M, PointerType *T, unsigned dstalign,
162+
unsigned srcalign) {
163+
Type *elementType = T->getElementType();
164+
assert(elementType->isFloatingPointTy());
165+
std::string name = "__enzyme_memcpy_" + tofltstr(elementType) + "da" +
166+
std::to_string(dstalign) + "sa" +
167+
std::to_string(srcalign) + "stride";
168+
FunctionType *FT = FunctionType::get(Type::getVoidTy(M.getContext()),
169+
{T, T, Type::getInt32Ty(M.getContext()),
170+
Type::getInt32Ty(M.getContext())},
171+
false);
172+
173+
#if LLVM_VERSION_MAJOR >= 9
174+
Function *F = cast<Function>(M.getOrInsertFunction(name, FT).getCallee());
175+
#else
176+
Function *F = cast<Function>(M.getOrInsertFunction(name, FT));
177+
#endif
178+
179+
if (!F->empty())
180+
return F;
181+
182+
F->setLinkage(Function::LinkageTypes::InternalLinkage);
183+
F->addFnAttr(Attribute::ArgMemOnly);
184+
F->addFnAttr(Attribute::NoUnwind);
185+
F->addParamAttr(0, Attribute::NoCapture);
186+
F->addParamAttr(1, Attribute::NoCapture);
187+
F->addParamAttr(0, Attribute::WriteOnly);
188+
F->addParamAttr(1, Attribute::ReadOnly);
189+
190+
BasicBlock *entry = BasicBlock::Create(M.getContext(), "entry", F);
191+
BasicBlock *body = BasicBlock::Create(M.getContext(), "for.body", F);
192+
BasicBlock *end = BasicBlock::Create(M.getContext(), "for.end", F);
193+
194+
auto dst = F->arg_begin();
195+
dst->setName("dst");
196+
auto src = dst + 1;
197+
src->setName("src");
198+
auto num = src + 1;
199+
num->setName("num");
200+
auto stride = num + 1;
201+
stride->setName("stride");
202+
203+
{
204+
IRBuilder<> B(entry);
205+
B.CreateCondBr(B.CreateICmpEQ(num, ConstantInt::get(num->getType(), 0)),
206+
end, body);
207+
}
208+
209+
{
210+
IRBuilder<> B(body);
211+
B.setFastMathFlags(getFast());
212+
PHINode *idx = B.CreatePHI(num->getType(), 2, "idx");
213+
PHINode *sidx = B.CreatePHI(num->getType(), 2, "sidx");
214+
idx->addIncoming(ConstantInt::get(num->getType(), 0), entry);
215+
sidx->addIncoming(ConstantInt::get(num->getType(), 0), entry);
216+
217+
Value *dsti = B.CreateGEP(dst, idx, "dst.i");
218+
219+
Value *srci = B.CreateGEP(src, sidx, "src.i");
220+
LoadInst *srcl = B.CreateLoad(srci, "src.i.l");
221+
222+
StoreInst *dsts = B.CreateStore(srcl, dsti);
223+
224+
if (dstalign) {
225+
#if LLVM_VERSION_MAJOR >= 10
226+
dsts->setAlignment(Align(dstalign));
227+
#else
228+
dsts->setAlignment(dstalign);
229+
#endif
230+
}
231+
if (srcalign) {
232+
#if LLVM_VERSION_MAJOR >= 10
233+
srcl->setAlignment(Align(srcalign));
234+
#else
235+
srcl->setAlignment(srcalign);
236+
#endif
237+
}
238+
239+
Value *next =
240+
B.CreateNUWAdd(idx, ConstantInt::get(num->getType(), 1), "idx.next");
241+
Value *snext = B.CreateNUWAdd(sidx, stride, "sidx.next");
242+
idx->addIncoming(next, body);
243+
sidx->addIncoming(snext, body);
244+
B.CreateCondBr(B.CreateICmpEQ(num, next), end, body);
245+
}
246+
247+
{
248+
IRBuilder<> B(end);
249+
B.CreateRetVoid();
250+
}
251+
252+
return F;
253+
}
254+
161255
// TODO implement differential memmove
162256
Function *getOrInsertDifferentialFloatMemmove(Module &M, PointerType *T,
163257
unsigned dstalign,

enzyme/Enzyme/Utils.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -519,6 +519,10 @@ llvm::Function *getOrInsertDifferentialFloatMemcpy(llvm::Module &M,
519519
unsigned dstalign,
520520
unsigned srcalign);
521521

522+
/// Create function for type that performs memcpy with a stride
523+
llvm::Function *getOrInsertMemcpyStrided(llvm::Module &M, llvm::PointerType *T,
524+
unsigned dstalign, unsigned srcalign);
525+
522526
/// Create function for type that performs the derivative memmove on floating
523527
/// point memory
524528
llvm::Function *getOrInsertDifferentialFloatMemmove(llvm::Module &M,
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
;RUN: %opt < %s %loadEnzyme -enzyme -mem2reg -instsimplify -simplifycfg -S | FileCheck %s
2+
3+
;#include <cblas.h>
4+
;
5+
;extern double __enzyme_autodiff(double*, double*, double*);
6+
;
7+
;double g(double *restrict m) {
8+
; double n[3] = {4, 5, 6};
9+
; double x = cblas_ddot(3, m, 1, n, 1);
10+
; double y = x*x;
11+
; return y;
12+
;}
13+
;
14+
;int main() {
15+
; double m[3] = {1, 2, 3};
16+
; double m1[3] = {0.};
17+
; double z = __enzyme_autodiff((double*)g, m, m1);
18+
;}
19+
20+
target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128"
21+
target triple = "x86_64-unknown-linux-gnu"
22+
23+
@__const.g.n = private unnamed_addr constant [3 x double] [double 4.000000e+00, double 5.000000e+00, double 6.000000e+00], align 16
24+
@__const.main.m = private unnamed_addr constant [3 x double] [double 1.000000e+00, double 2.000000e+00, double 3.000000e+00], align 16
25+
26+
define dso_local double @g(double* noalias %m) {
27+
entry:
28+
%m.addr = alloca double*, align 8
29+
%n = alloca [3 x double], align 16
30+
%x = alloca double, align 8
31+
%y = alloca double, align 8
32+
store double* %m, double** %m.addr, align 8
33+
%0 = bitcast [3 x double]* %n to i8*
34+
call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 16 %0, i8* align 16 bitcast ([3 x double]* @__const.g.n to i8*), i64 24, i1 false)
35+
%1 = load double*, double** %m.addr, align 8
36+
%arraydecay = getelementptr inbounds [3 x double], [3 x double]* %n, i32 0, i32 0
37+
%call = call double @cblas_ddot(i32 3, double* %1, i32 1, double* %arraydecay, i32 1)
38+
store double %call, double* %x, align 8
39+
%2 = load double, double* %x, align 8
40+
%3 = load double, double* %x, align 8
41+
%mul = fmul double %2, %3
42+
store double %mul, double* %y, align 8
43+
%4 = load double, double* %y, align 8
44+
ret double %4
45+
}
46+
47+
declare void @llvm.memcpy.p0i8.p0i8.i64(i8* nocapture writeonly, i8* nocapture readonly, i64, i1)
48+
49+
declare dso_local double @cblas_ddot(i32, double*, i32, double*, i32)
50+
51+
define dso_local i32 @main() {
52+
entry:
53+
%m = alloca [3 x double], align 16
54+
%m1 = alloca [3 x double], align 16
55+
%z = alloca double, align 8
56+
%0 = bitcast [3 x double]* %m to i8*
57+
call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 16 %0, i8* align 16 bitcast ([3 x double]* @__const.main.m to i8*), i64 24, i1 false)
58+
%1 = bitcast [3 x double]* %m1 to i8*
59+
call void @llvm.memset.p0i8.i64(i8* align 16 %1, i8 0, i64 24, i1 false)
60+
%arraydecay = getelementptr inbounds [3 x double], [3 x double]* %m, i32 0, i32 0
61+
%arraydecay1 = getelementptr inbounds [3 x double], [3 x double]* %m1, i32 0, i32 0
62+
%call = call double @__enzyme_autodiff(double* bitcast (double (double*)* @g to double*), double* %arraydecay, double* %arraydecay1)
63+
store double %call, double* %z, align 8
64+
ret i32 0
65+
}
66+
67+
declare void @llvm.memset.p0i8.i64(i8* nocapture writeonly, i8, i64, i1)
68+
69+
declare dso_local double @__enzyme_autodiff(double*, double*, double*)
70+
71+
;CHECK:define internal void @diffeg(double* noalias %m, double* %"m'", double %differeturn) {
72+
;CHECK-NEXT:entry:
73+
;CHECK-NEXT: %n = alloca [3 x double], align 16
74+
;CHECK-NEXT: %0 = bitcast [3 x double]* %n to i8*
75+
;CHECK-NEXT: call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 16 %0, i8* align 16 bitcast ([3 x double]* @__const.g.n to i8*), i64 24, i1 false)
76+
;CHECK-NEXT: %arraydecay = getelementptr inbounds [3 x double], [3 x double]* %n, i32 0, i32 0
77+
;CHECK-NEXT: %call = call double @cblas_ddot(i32 3, double* nocapture readonly %m, i32 1, double* nocapture readonly %arraydecay, i32 1)
78+
;CHECK-NEXT: %m0diffecall = fmul fast double %differeturn, %call
79+
;CHECK-NEXT: %m1diffecall = fmul fast double %differeturn, %call
80+
;CHECK-NEXT: %1 = fadd fast double %m0diffecall, %m1diffecall
81+
;CHECK-NEXT: call void @cblas_daxpy(i32 3, double %1, double* %arraydecay, i32 1, double* %"m'", i32 1)
82+
;CHECK-NEXT: ret void
83+
;CHECK-NEXT:}
84+
85+
;CHECK: declare void @cblas_daxpy(i32, double, double*, i32, double*, i32)

0 commit comments

Comments
 (0)