Skip to content

Commit a75e86a

Browse files
authored
Fix noundef and other errors (rust-lang#227)
* Fix noundef and other errors * Fix lookup * Only run gmm on LLVM versions without the TBAA bug
1 parent 00b5819 commit a75e86a

File tree

6 files changed

+67
-7
lines changed

6 files changed

+67
-7
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -560,6 +560,8 @@ class AdjointGenerator
560560
if (FT) {
561561
//! Only need to update the reverse function
562562
switch (Mode) {
563+
case DerivativeMode::ReverseModePrimal:
564+
break;
563565
case DerivativeMode::ReverseModeGradient:
564566
case DerivativeMode::ReverseModeCombined: {
565567
IRBuilder<> Builder2(SI.getParent());
@@ -2750,9 +2752,9 @@ class AdjointGenerator
27502752

27512753
// llvm::Optional<std::map<std::pair<Instruction*, std::string>, unsigned>>
27522754
// sub_index_map;
2753-
Optional<int> tapeIdx;
2754-
Optional<int> returnIdx;
2755-
Optional<int> differetIdx;
2755+
// Optional<int> tapeIdx;
2756+
// Optional<int> returnIdx;
2757+
// Optional<int> differetIdx;
27562758

27572759
const AugmentedReturn *subdata = nullptr;
27582760
if (Mode == DerivativeMode::ReverseModeGradient) {
@@ -3469,7 +3471,7 @@ class AdjointGenerator
34693471
for (int i = 0; i < 7; i++)
34703472
types[i] = args[i]->getType();
34713473

3472-
FunctionType *FT = FunctionType::get(root->getType(), types);
3474+
FunctionType *FT = FunctionType::get(root->getType(), types, false);
34733475
Builder2.CreateCall(
34743476
called->getParent()->getOrInsertFunction("MPI_Reduce", FT), args);
34753477
}

enzyme/Enzyme/EnzymeLogic.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1687,6 +1687,13 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
16871687
gutils->newFunc->removeAttribute(llvm::AttributeList::ReturnIndex,
16881688
llvm::Attribute::NoAlias);
16891689
}
1690+
#if LLVM_VERSION_MAJOR >= 11
1691+
if (gutils->newFunc->hasAttribute(llvm::AttributeList::ReturnIndex,
1692+
llvm::Attribute::NoUndef)) {
1693+
gutils->newFunc->removeAttribute(llvm::AttributeList::ReturnIndex,
1694+
llvm::Attribute::NoUndef);
1695+
}
1696+
#endif
16901697
if (gutils->newFunc->hasAttribute(llvm::AttributeList::ReturnIndex,
16911698
llvm::Attribute::NonNull)) {
16921699
gutils->newFunc->removeAttribute(llvm::AttributeList::ReturnIndex,
@@ -2005,6 +2012,13 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
20052012
NewF->removeAttribute(llvm::AttributeList::ReturnIndex,
20062013
llvm::Attribute::NoAlias);
20072014
}
2015+
#if LLVM_VERSION_MAJOR >= 11
2016+
if (NewF->hasAttribute(llvm::AttributeList::ReturnIndex,
2017+
llvm::Attribute::NoUndef)) {
2018+
NewF->removeAttribute(llvm::AttributeList::ReturnIndex,
2019+
llvm::Attribute::NoUndef);
2020+
}
2021+
#endif
20082022
if (NewF->hasAttribute(llvm::AttributeList::ReturnIndex,
20092023
llvm::Attribute::ZExt)) {
20102024
NewF->removeAttribute(llvm::AttributeList::ReturnIndex,
@@ -3168,6 +3182,13 @@ Function *EnzymeLogic::CreatePrimalAndGradient(
31683182
gutils->newFunc->removeAttribute(llvm::AttributeList::ReturnIndex,
31693183
llvm::Attribute::NoAlias);
31703184
}
3185+
#if LLVM_VERSION_MAJOR >= 11
3186+
if (gutils->newFunc->hasAttribute(llvm::AttributeList::ReturnIndex,
3187+
llvm::Attribute::NoUndef)) {
3188+
gutils->newFunc->removeAttribute(llvm::AttributeList::ReturnIndex,
3189+
llvm::Attribute::NoUndef);
3190+
}
3191+
#endif
31713192
if (gutils->newFunc->hasAttribute(llvm::AttributeList::ReturnIndex,
31723193
llvm::Attribute::NonNull)) {
31733194
gutils->newFunc->removeAttribute(llvm::AttributeList::ReturnIndex,

enzyme/Enzyme/GradientUtils.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3632,7 +3632,7 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM,
36323632
});
36333633
if (failed)
36343634
break;
3635-
IRBuilder<> nv(ctx->getTerminator());
3635+
IRBuilder<> nv(nctx->getTerminator());
36363636
Value *nlim = unwrapM(lim, nv,
36373637
/*available*/ ValueToValueMapTy(),
36383638
UnwrapMode::AttemptFullUnwrapWithLookup);

enzyme/benchmarks/gmm/Makefile.make

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# RUN: cd %desired_wd/gmm && LD_LIBRARY_PATH="%bldpath:$LD_LIBRARY_PATH" BENCH="%bench" BENCHLINK="%blink" LOAD="%loadEnzyme" make -B gmm-unopt.ll gmm-raw.ll results.txt -f %s
1+
# RUN: if [ %llvmver -ge 12 ] || [ %llvmver -le 9 ]; then cd %desired_wd/gmm && LD_LIBRARY_PATH="%bldpath:$LD_LIBRARY_PATH" BENCH="%bench" BENCHLINK="%blink" LOAD="%loadEnzyme" make -B gmm-unopt.ll gmm-raw.ll results.txt -f %s; fi
22

33
.PHONY: clean
44

enzyme/test/Enzyme/ReverseMode/mpi_bcast.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ declare double @__enzyme_autodiff(i8*, ...)
5252
; CHECK-NEXT: %8 = tail call i8* @malloc(i64 %7)
5353
; CHECK-NEXT: call void @"__enzyme_mpi_sumFloat@doubleinitializer"()
5454
; CHECK-NEXT: %9 = load i8*, i8** @"__enzyme_mpi_sumFloat@double"
55-
; CHECK-NEXT: %10 = call i32 (...) @MPI_Reduce(i8* %"'ipc", i8* %8, i32 1, %struct.ompi_datatype_t* bitcast (%struct.ompi_predefined_datatype_t* @ompi_mpi_double to %struct.ompi_datatype_t*), i8* %9, i32 0, %struct.ompi_communicator_t* bitcast (%struct.ompi_predefined_communicator_t* @ompi_mpi_comm_world to %struct.ompi_communicator_t*))
55+
; CHECK-NEXT: %10 = call i32 @MPI_Reduce(i8* %"'ipc", i8* %8, i32 1, %struct.ompi_datatype_t* bitcast (%struct.ompi_predefined_datatype_t* @ompi_mpi_double to %struct.ompi_datatype_t*), i8* %9, i32 0, %struct.ompi_communicator_t* bitcast (%struct.ompi_predefined_communicator_t* @ompi_mpi_comm_world to %struct.ompi_communicator_t*))
5656
; CHECK-NEXT: %11 = call i32 @MPI_Comm_rank(%struct.ompi_communicator_t* bitcast (%struct.ompi_predefined_communicator_t* @ompi_mpi_comm_world to %struct.ompi_communicator_t*), i32* %1)
5757
; CHECK-NEXT: %12 = load i32, i32* %1
5858
; CHECK-NEXT: %13 = icmp eq i32 %12, 0
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
; RUN: if [ %llvmver -ge 11 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -sroa -simplifycfg -instcombine -adce -S | FileCheck %s; fi
2+
3+
; this test should ensure that the noundef return attribute on @cst is not carried over to the generated augmented primal, whose return type becomes a struct
4+
5+
define void @caller(double* %in_W, double* %in_Wp) {
6+
entry:
7+
call void @__enzyme_autodiff(i8* bitcast (void (double*)* @matvec to i8*), double* nonnull %in_W, double* nonnull %in_Wp) #8
8+
ret void
9+
}
10+
11+
declare void @__enzyme_autodiff(i8*, double*, double*)
12+
13+
define noalias noundef nonnull align 8 double* @cst(double* noalias %W) {
14+
entry:
15+
ret double* %W
16+
}
17+
18+
define internal void @matvec(double* noalias %W) {
19+
entry:
20+
%ptr = call double* @cst(double* %W)
21+
%ld = load double, double* %ptr, align 8
22+
%mul = fmul double %ld, %ld
23+
store double %mul, double* %W
24+
ret void
25+
}
26+
27+
; CHECK: define internal { double*, double* } @augmented_cst(double* noalias %W, double* %"W'")
28+
; CHECK-NEXT: entry:
29+
; CHECK-NEXT: %.fca.0.insert = insertvalue { double*, double* } undef, double* %W, 0
30+
; CHECK-NEXT: %.fca.1.insert = insertvalue { double*, double* } %.fca.0.insert, double* %"W'", 1
31+
; CHECK-NEXT: ret { double*, double* } %.fca.1.insert
32+
; CHECK-NEXT: }
33+
34+
; CHECK: define internal void @diffecst(double* noalias %W, double* %"W'")
35+
; CHECK-NEXT: entry:
36+
; CHECK-NEXT: ret void
37+
; CHECK-NEXT: }

0 commit comments

Comments
 (0)