Skip to content

Commit a8e09db

Browse files
authored
Activity of unreachable use and gemm type (rust-lang#239)
1 parent b7e2f14 commit a8e09db

File tree

4 files changed

+138
-42
lines changed

4 files changed

+138
-42
lines changed

enzyme/Enzyme/ActivityAnalysis.cpp

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,14 @@ bool ActivityAnalyzer::isConstantInstruction(TypeResults &TR, Instruction *I) {
283283
return false;
284284
}
285285

286+
if (!TR.isBlockAnalyzed(I->getParent())) {
287+
if (EnzymePrintActivity)
288+
llvm::errs() << " constant instruction as dominates unreachable " << *I
289+
<< "\n";
290+
InsertConstantInstruction(TR, I);
291+
return true;
292+
}
293+
286294
/// A store into all integral memory is inactive
287295
if (auto SI = dyn_cast<StoreInst>(I)) {
288296
auto StoreSize = SI->getParent()
@@ -429,7 +437,7 @@ bool ActivityAnalyzer::isConstantInstruction(TypeResults &TR, Instruction *I) {
429437
if (directions == DOWN && !isa<PHINode>(I)) {
430438
if (isValueInactiveFromUsers(TR, I, UseActivity::None)) {
431439
if (EnzymePrintActivity)
432-
llvm::errs() << " constant instruction[" << directions
440+
llvm::errs() << " constant instruction[" << (int)directions
433441
<< "] from users instruction " << *I << "\n";
434442
InsertConstantInstruction(TR, I);
435443
return true;
@@ -441,7 +449,7 @@ bool ActivityAnalyzer::isConstantInstruction(TypeResults &TR, Instruction *I) {
441449
if (DownHypothesis->isValueInactiveFromUsers(TR, I,
442450
UseActivity::None)) {
443451
if (EnzymePrintActivity)
444-
llvm::errs() << " constant instruction[" << directions
452+
llvm::errs() << " constant instruction[" << (int)directions
445453
<< "] from users instruction " << *I << "\n";
446454
InsertConstantInstruction(TR, I);
447455
insertConstantsFrom(TR, *DownHypothesis);
@@ -1804,6 +1812,13 @@ bool ActivityAnalyzer::isValueInactiveFromUsers(TypeResults &TR,
18041812
// if its return is used in an active way, therefore add this to
18051813
// the list of users to analyze
18061814
if (auto I = dyn_cast<Instruction>(a)) {
1815+
if (!TR.isBlockAnalyzed(I->getParent())) {
1816+
if (EnzymePrintActivity) {
1817+
llvm::errs() << "Value found constant unreachable inst use:" << *val
1818+
<< " user " << *I << "\n";
1819+
}
1820+
continue;
1821+
}
18071822
if (ConstantInstructions.count(I) &&
18081823
(I->getType()->isVoidTy() || ConstantValues.count(I))) {
18091824
if (EnzymePrintActivity) {

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 49 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -468,51 +468,58 @@ class AdjointGenerator
468468
}
469469
}
470470

471-
bool isfloat = type->isFPOrFPVectorTy();
472-
if (!isfloat && type->isIntOrIntVectorTy()) {
473-
auto LoadSize = DL.getTypeSizeInBits(type) / 8;
474-
ConcreteType vd = BaseType::Unknown;
475-
if (!OrigOffset)
476-
vd = TR.firstPointer(LoadSize, I.getOperand(0),
477-
/*errifnotfound*/ false, /*pointerIntSame*/ true);
478-
if (vd.isKnown())
479-
isfloat = vd.isFloat();
480-
else
481-
isfloat = TR.intType(LoadSize, &I, /*errIfNotFound*/ !looseTypeAnalysis)
482-
.isFloat();
483-
}
471+
// Only propagate if instruction is active. The value can be active and not
472+
// the instruction if the value is a potential pointer. This may not be
473+
// caught by type analysis if the result does not have a known type.
474+
if (!gutils->isConstantInstruction(&I)) {
475+
bool isfloat = type->isFPOrFPVectorTy();
476+
if (!isfloat && type->isIntOrIntVectorTy()) {
477+
auto LoadSize = DL.getTypeSizeInBits(type) / 8;
478+
ConcreteType vd = BaseType::Unknown;
479+
if (!OrigOffset)
480+
vd =
481+
TR.firstPointer(LoadSize, I.getOperand(0),
482+
/*errifnotfound*/ false, /*pointerIntSame*/ true);
483+
if (vd.isKnown())
484+
isfloat = vd.isFloat();
485+
else
486+
isfloat =
487+
TR.intType(LoadSize, &I, /*errIfNotFound*/ !looseTypeAnalysis)
488+
.isFloat();
489+
}
484490

485-
if (isfloat) {
491+
if (isfloat) {
486492

487-
switch (Mode) {
488-
case DerivativeMode::ForwardMode: {
489-
IRBuilder<> Builder2(&I);
490-
getForwardBuilder(Builder2);
493+
switch (Mode) {
494+
case DerivativeMode::ForwardMode: {
495+
IRBuilder<> Builder2(&I);
496+
getForwardBuilder(Builder2);
491497

492-
if (!gutils->isConstantValue(&I)) {
493-
auto diff = Builder2.CreateLoad(
494-
gutils->invertPointerM(I.getOperand(0), Builder2));
495-
setDiffe(&I, diff, Builder2);
498+
if (!gutils->isConstantValue(&I)) {
499+
auto diff = Builder2.CreateLoad(
500+
gutils->invertPointerM(I.getOperand(0), Builder2));
501+
setDiffe(&I, diff, Builder2);
502+
}
503+
break;
496504
}
497-
break;
498-
}
499-
case DerivativeMode::ReverseModeGradient:
500-
case DerivativeMode::ReverseModeCombined: {
501-
IRBuilder<> Builder2(parent);
502-
getReverseBuilder(Builder2);
505+
case DerivativeMode::ReverseModeGradient:
506+
case DerivativeMode::ReverseModeCombined: {
507+
IRBuilder<> Builder2(parent);
508+
getReverseBuilder(Builder2);
503509

504-
auto prediff = diffe(&I, Builder2);
505-
setDiffe(&I, Constant::getNullValue(type), Builder2);
510+
auto prediff = diffe(&I, Builder2);
511+
setDiffe(&I, Constant::getNullValue(type), Builder2);
506512

507-
if (!gutils->isConstantValue(I.getOperand(0))) {
508-
((DiffeGradientUtils *)gutils)
509-
->addToInvertedPtrDiffe(I.getOperand(0), prediff, Builder2,
510-
alignment, OrigOffset);
513+
if (!gutils->isConstantValue(I.getOperand(0))) {
514+
((DiffeGradientUtils *)gutils)
515+
->addToInvertedPtrDiffe(I.getOperand(0), prediff, Builder2,
516+
alignment, OrigOffset);
517+
}
518+
break;
519+
}
520+
case DerivativeMode::ReverseModePrimal:
521+
break;
511522
}
512-
break;
513-
}
514-
case DerivativeMode::ReverseModePrimal:
515-
break;
516523
}
517524
}
518525
}
@@ -4479,9 +4486,10 @@ class AdjointGenerator
44794486
}
44804487

44814488
if (funcName == "julia.pointer_from_objref") {
4482-
eraseIfUnused(*orig);
4483-
if (gutils->isConstantValue(orig))
4489+
if (gutils->isConstantValue(orig)) {
4490+
eraseIfUnused(*orig);
44844491
return;
4492+
}
44854493

44864494
Value *ptrshadow =
44874495
gutils->invertPointerM(call.getArgOperand(0), BuilderZ);
@@ -4494,6 +4502,7 @@ class AdjointGenerator
44944502
gutils->invertedPointers[orig] = val;
44954503
gutils->replaceAWithB(placeholder, val);
44964504
gutils->erase(placeholder);
4505+
eraseIfUnused(*orig);
44974506
return;
44984507
}
44994508

enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3134,6 +3134,26 @@ void TypeAnalyzer::visitCallInst(CallInst &call) {
31343134
funcName.startswith("_ZN4core3fmt")) {
31353135
return;
31363136
}
3137+
/// GEMM
3138+
if (funcName == "dgemm_64" || funcName == "dgemm_64_" ||
3139+
funcName == "dgemm" || funcName == "dgemm_") {
3140+
TypeTree ptrint;
3141+
ptrint.insert({-1}, BaseType::Pointer);
3142+
ptrint.insert({-1, 0}, BaseType::Integer);
3143+
// transa, transb, m, n, k, lda, ldb, ldc
3144+
for (int i : {0, 1, 2, 3, 4, 7, 9, 12})
3145+
updateAnalysis(call.getArgOperand(i), ptrint, &call);
3146+
3147+
TypeTree ptrdbl;
3148+
ptrdbl.insert({-1}, BaseType::Pointer);
3149+
ptrdbl.insert({-1, 0}, Type::getDoubleTy(call.getContext()));
3150+
3151+
// alpha, a, b, beta, c
3152+
for (int i : {5, 6, 8, 10, 11})
3153+
updateAnalysis(call.getArgOperand(i), ptrdbl, &call);
3154+
return;
3155+
}
3156+
31373157
/// MPI
31383158
if (funcName == "MPI_Init") {
31393159
TypeTree ptrint;
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -simplifycfg -adce -S | FileCheck %s
2+
3+
declare void @llvm.trap()
4+
5+
declare void @baduse(i64, double*)
6+
7+
; Function Attrs: nounwind readnone uwtable
8+
define double @tester({double, i64}* %x, i1 %cmp) {
9+
entry:
10+
%gep0 = getelementptr inbounds {double, i64}, {double, i64}* %x, i64 0, i32 0
11+
%gep1 = getelementptr inbounds {double, i64}, {double, i64}* %x, i64 0, i32 1
12+
%ld = load i64, i64* %gep1
13+
br i1 %cmp, label %exit, label %err
14+
15+
err:
16+
call void @baduse(i64 %ld, double* %gep0)
17+
call void @llvm.trap()
18+
unreachable
19+
20+
exit:
21+
%res = load double, double* %gep0
22+
ret double %res
23+
}
24+
25+
define double @test_derivative({double, i64}* %x, {double, i64}* %dx, i1 %cmp) {
26+
entry:
27+
%0 = tail call double (double ({double, i64}*, i1)*, ...) @__enzyme_autodiff(double ({double, i64}*, i1)* nonnull @tester, {double, i64}* %x, {double, i64}* %dx, i1 %cmp)
28+
ret double %0
29+
}
30+
31+
; Function Attrs: nounwind
32+
declare double @__enzyme_autodiff(double ({double, i64}*, i1)*, ...)
33+
34+
; CHECK: define internal void @diffetester({ double, i64 }* %x, { double, i64 }* %"x'", i1 %cmp, double %differeturn)
35+
; CHECK-NEXT: entry:
36+
; CHECK-NEXT: %"gep0'ipg" = getelementptr inbounds { double, i64 }, { double, i64 }* %"x'", i64 0, i32 0
37+
; CHECK-NEXT: %gep0 = getelementptr inbounds { double, i64 }, { double, i64 }* %x, i64 0, i32 0
38+
; CHECK-NEXT: %gep1 = getelementptr inbounds { double, i64 }, { double, i64 }* %x, i64 0, i32 1
39+
; CHECK-NEXT: %ld = load i64, i64* %gep1
40+
; CHECK-NEXT: br i1 %cmp, label %invertexit, label %err
41+
42+
; CHECK: err: ; preds = %entry
43+
; CHECK-NEXT: call void @baduse(i64 %ld, double* %gep0)
44+
; CHECK-NEXT: call void @llvm.trap()
45+
; CHECK-NEXT: unreachable
46+
47+
; CHECK: invertexit: ; preds = %entry
48+
; CHECK-NEXT: %0 = load double, double* %"gep0'ipg"
49+
; CHECK-NEXT: %1 = fadd fast double %0, %differeturn
50+
; CHECK-NEXT: store double %1, double* %"gep0'ipg"
51+
; CHECK-NEXT: ret void
52+
; CHECK-NEXT: }

0 commit comments

Comments
 (0)