Skip to content

Commit bb12649

Browse files
authored
ForwardMode InsertValueInst (rust-lang#328)
* implement InsertValueInst * add test
1 parent 4c9a28c commit bb12649

File tree

2 files changed

+96
-32
lines changed

2 files changed

+96
-32
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 64 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1271,45 +1271,77 @@ class AdjointGenerator
12711271
// TODO handle pointers
12721272
// TODO type analysis handle structs
12731273

1274-
IRBuilder<> Builder2(IVI.getParent());
1275-
getReverseBuilder(Builder2);
1274+
switch (Mode) {
1275+
case DerivativeMode::ForwardMode: {
1276+
IRBuilder<> Builder2(&IVI);
1277+
getForwardBuilder(Builder2);
12761278

1277-
Value *orig_inserted = IVI.getInsertedValueOperand();
1278-
Value *orig_agg = IVI.getAggregateOperand();
1279+
Value *orig_inserted = IVI.getInsertedValueOperand();
1280+
Value *orig_agg = IVI.getAggregateOperand();
12791281

1280-
size_t size0 = 1;
1281-
if (orig_inserted->getType()->isSized())
1282-
size0 = (gutils->newFunc->getParent()->getDataLayout().getTypeSizeInBits(
1283-
orig_inserted->getType()) +
1284-
7) /
1285-
8;
1282+
Value *diff_inserted = gutils->isConstantValue(orig_inserted)
1283+
? ConstantFP::get(orig_inserted->getType(), 0)
1284+
: diffe(orig_inserted, Builder2);
12861285

1287-
Type *flt = nullptr;
1288-
if (!gutils->isConstantValue(orig_inserted) &&
1289-
(flt = TR.intType(size0, orig_inserted).isFloat())) {
1290-
auto prediff = diffe(&IVI, Builder2);
1291-
auto dindex = Builder2.CreateExtractValue(prediff, IVI.getIndices());
1292-
addToDiffe(orig_inserted, dindex, Builder2, flt);
1286+
Value *prediff =
1287+
gutils->isConstantValue(orig_agg)
1288+
? diffe(orig_agg, Builder2)
1289+
: ConstantAggregate::getNullValue(orig_agg->getType());
1290+
auto dindex =
1291+
Builder2.CreateInsertValue(prediff, diff_inserted, IVI.getIndices());
1292+
setDiffe(&IVI, dindex, Builder2);
1293+
1294+
return;
12931295
}
1296+
case DerivativeMode::ReverseModeCombined:
1297+
case DerivativeMode::ReverseModeGradient: {
1298+
IRBuilder<> Builder2(IVI.getParent());
1299+
getReverseBuilder(Builder2);
12941300

1295-
size_t size1 = 1;
1296-
if (orig_agg->getType()->isSized() &&
1297-
(orig_agg->getType()->isIntOrIntVectorTy() ||
1298-
orig_agg->getType()->isFPOrFPVectorTy()))
1299-
size1 = (gutils->newFunc->getParent()->getDataLayout().getTypeSizeInBits(
1300-
orig_agg->getType()) +
1301-
7) /
1302-
8;
1301+
Value *orig_inserted = IVI.getInsertedValueOperand();
1302+
Value *orig_agg = IVI.getAggregateOperand();
13031303

1304-
if (!gutils->isConstantValue(orig_agg)) {
1305-
auto prediff = diffe(&IVI, Builder2);
1306-
auto dindex = Builder2.CreateInsertValue(
1307-
prediff, Constant::getNullValue(orig_inserted->getType()),
1308-
IVI.getIndices());
1309-
addToDiffe(orig_agg, dindex, Builder2, TR.addingType(size1, orig_agg));
1310-
}
1304+
size_t size0 = 1;
1305+
if (orig_inserted->getType()->isSized())
1306+
size0 =
1307+
(gutils->newFunc->getParent()->getDataLayout().getTypeSizeInBits(
1308+
orig_inserted->getType()) +
1309+
7) /
1310+
8;
1311+
1312+
Type *flt = nullptr;
1313+
if (!gutils->isConstantValue(orig_inserted) &&
1314+
(flt = TR.intType(size0, orig_inserted).isFloat())) {
1315+
auto prediff = diffe(&IVI, Builder2);
1316+
auto dindex = Builder2.CreateExtractValue(prediff, IVI.getIndices());
1317+
addToDiffe(orig_inserted, dindex, Builder2, flt);
1318+
}
1319+
1320+
size_t size1 = 1;
1321+
if (orig_agg->getType()->isSized() &&
1322+
(orig_agg->getType()->isIntOrIntVectorTy() ||
1323+
orig_agg->getType()->isFPOrFPVectorTy()))
1324+
size1 =
1325+
(gutils->newFunc->getParent()->getDataLayout().getTypeSizeInBits(
1326+
orig_agg->getType()) +
1327+
7) /
1328+
8;
1329+
1330+
if (!gutils->isConstantValue(orig_agg)) {
1331+
auto prediff = diffe(&IVI, Builder2);
1332+
auto dindex = Builder2.CreateInsertValue(
1333+
prediff, Constant::getNullValue(orig_inserted->getType()),
1334+
IVI.getIndices());
1335+
addToDiffe(orig_agg, dindex, Builder2, TR.addingType(size1, orig_agg));
1336+
}
13111337

1312-
setDiffe(&IVI, Constant::getNullValue(IVI.getType()), Builder2);
1338+
setDiffe(&IVI, Constant::getNullValue(IVI.getType()), Builder2);
1339+
return;
1340+
}
1341+
case DerivativeMode::ReverseModePrimal: {
1342+
return;
1343+
}
1344+
}
13131345
}
13141346

13151347
void getReverseBuilder(IRBuilder<> &Builder2, bool original = true) {
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -simplifycfg -early-cse -S | FileCheck %s
2+
3+
define { double, double } @squared(double %x) {
4+
entry:
5+
%mul = fmul double %x, %x
6+
%mul2 = fmul double %mul, %x
7+
%.fca.0.insert = insertvalue { double, double } undef, double %mul, 0
8+
%.fca.1.insert = insertvalue { double, double } %.fca.0.insert, double %mul2, 1
9+
ret { double, double } %.fca.1.insert
10+
}
11+
12+
define { double, double } @dsquared(double %x) {
13+
entry:
14+
%call = call { double, double } (i8*, ...) @__enzyme_fwddiff(i8* bitcast ({ double, double } (double)* @squared to i8*), double %x, double 1.0)
15+
ret { double, double } %call
16+
}
17+
18+
declare { double, double } @__enzyme_fwddiff(i8*, ...)
19+
20+
21+
22+
; CHECK: define internal {{(dso_local )?}}{ double, double } @fwddiffesquared(double %x, double %"x'")
23+
; CHECK-NEXT: entry:
24+
; CHECK-NEXT: %mul = fmul double %x, %x
25+
; CHECK-NEXT: %0 = fmul fast double %"x'", %x
26+
; CHECK-NEXT: %1 = fadd fast double %0, %0
27+
; CHECK-NEXT: %2 = fmul fast double %1, %x
28+
; CHECK-NEXT: %3 = fmul fast double %"x'", %mul
29+
; CHECK-NEXT: %4 = fadd fast double %2, %3
30+
; CHECK-NEXT: %5 = insertvalue { double, double } zeroinitializer, double %4, 1
31+
; CHECK-NEXT: ret { double, double } %5
32+
; CHECK-NEXT: }

0 commit comments

Comments
 (0)