Skip to content

Commit 306c387

Browse files
committed
Add nvvm sqrt
1 parent 0d40ec5 commit 306c387

File tree

4 files changed

+14
-11
lines changed

4 files changed

+14
-11
lines changed

enzyme/CMakeLists.txt

Lines changed: 0 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -75,14 +75,6 @@ else()
7575
file(READ ${LLVM_IDIR}/llvm/Analysis/ScalarEvolutionExpander.h INPUT_TEXT)
7676
endif()
7777

78-
find_package(MPI)
79-
if (${MPI_FOUND})
80-
add_definitions(-DBUILDMPI)
81-
include_directories(SYSTEM ${MPI_C_INCLUDE_PATH})
82-
else()
83-
set(MPI_C_LIBRARIES "")
84-
endif()
85-
8678
if (${LLVM_VERSION_MAJOR} LESS 12)
8779
string(REPLACE "#define LLVM_ANALYSIS_SCALAREVOLUTIONEXPANDER_H" "#define LLVM_ANALYSIS_SCALAREVOLUTIONEXPANDER_H\n#include \"SCEV/ScalarEvolution.h\"" INPUT_TEXT "${INPUT_TEXT}")
8880
string(REPLACE "LLVM_ANALYSIS" "FAKELLVM_ANALYSIS" INPUT_TEXT "${INPUT_TEXT}")

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1956,6 +1956,7 @@ class AdjointGenerator
19561956
case Intrinsic::nearbyint:
19571957
case Intrinsic::round:
19581958
case Intrinsic::sqrt:
1959+
case Intrinsic::nvvm_sqrt_rn_d:
19591960
case Intrinsic::fma:
19601961
return;
19611962
default:
@@ -2071,6 +2072,7 @@ class AdjointGenerator
20712072
return;
20722073
}
20732074

2075+
case Intrinsic::nvvm_sqrt_rn_d:
20742076
case Intrinsic::sqrt: {
20752077
if (vdiff && !gutils->isConstantValue(orig_ops[0])) {
20762078
SmallVector<Value *, 2> args = {

enzyme/Enzyme/Enzyme.cpp

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -147,7 +147,7 @@ class Enzyme : public ModulePass {
147147
bool AtomicAdd = Arch == Triple::nvptx || Arch == Triple::nvptx64 ||
148148
Arch == Triple::amdgcn;
149149

150-
std::map<int, Type*> byVal;
150+
std::map<int, Type *> byVal;
151151
for (unsigned i = 1; i < CI->getNumArgOperands(); ++i) {
152152
Value *res = CI->getArgOperand(i);
153153

@@ -358,9 +358,11 @@ class Enzyme : public ModulePass {
358358
}
359359
res = Builder.CreateBitCast(res, PTy);
360360
}
361+
#if LLVM_VERSION_MAJOR >= 9
361362
if (CI->isByValArgument(i)) {
362363
byVal[args.size()] = CI->getParamByValType(i);
363364
}
365+
#endif
364366
args.push_back(res);
365367
if (ty == DIFFE_TYPE::DUP_ARG || ty == DIFFE_TYPE::DUP_NONEED) {
366368
++i;
@@ -479,9 +481,13 @@ class Enzyme : public ModulePass {
479481
CallInst *diffret = cast<CallInst>(Builder.CreateCall(newFunc, args));
480482
diffret->setCallingConv(CI->getCallingConv());
481483
diffret->setDebugLoc(CI->getDebugLoc());
484+
#if LLVM_VERSION_MAJOR >= 9
482485
for (auto pair : byVal) {
483-
diffret->addParamAttr(pair.first, Attribute::getWithByValType(diffret->getContext(), pair.second));
486+
diffret->addParamAttr(
487+
pair.first,
488+
Attribute::getWithByValType(diffret->getContext(), pair.second));
484489
}
490+
#endif
485491

486492
if (!diffret->getType()->isEmptyTy() && !diffret->getType()->isVoidTy()) {
487493
unsigned idxs[] = {0};

enzyme/test/Enzyme/ReverseMode/byval.ll

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -490,8 +490,11 @@ attributes #11 = { noreturn }
490490
!21 = !{!22, !22, i64 0}
491491
!22 = !{!"int", !5, i64 0}
492492

493+
; CHECK: define dso_local double @_Z14dcar_erg_atpos3card(%class.car* nocapture readonly byval(%class.car) align 8 %car1, double %pos)
494+
; CHECK: call { double } @diffe_Z13car_erg_atpos3card(%class.car* nonnull byval(%class.car) %car1, double %pos, double 1.000000e+00)
495+
493496
; CHECK: define internal { double } @diffe_Z13car_erg_atpos3card(%class.car* byval(%class.car) align 8 %car1, double %pos, double %differeturn)
494497
; CHECK-NEXT: entry:
495498
; CHECK-NEXT: %"car1'ipa" = alloca %class.car
496499
; CHECK-NEXT: %0 = bitcast %class.car* %"car1'ipa" to i8*
497-
; CHECK-NEXT: call void @llvm.memset.p0i8.i64(i8* nonnull %0, i8 0, i64 24, i1 false)
500+
; CHECK-NEXT: call void @llvm.memset.p0i8.i64(i8* nonnull %0, i8 0, i64 24, i1 false)

0 commit comments

Comments
 (0)