@@ -1271,45 +1271,77 @@ class AdjointGenerator
1271
1271
// TODO handle pointers
1272
1272
// TODO type analysis handle structs
1273
1273
1274
- IRBuilder<> Builder2 (IVI.getParent ());
1275
- getReverseBuilder (Builder2);
1274
+ switch (Mode) {
1275
+ case DerivativeMode::ForwardMode: {
1276
+ IRBuilder<> Builder2 (&IVI);
1277
+ getForwardBuilder (Builder2);
1276
1278
1277
- Value *orig_inserted = IVI.getInsertedValueOperand ();
1278
- Value *orig_agg = IVI.getAggregateOperand ();
1279
+ Value *orig_inserted = IVI.getInsertedValueOperand ();
1280
+ Value *orig_agg = IVI.getAggregateOperand ();
1279
1281
1280
- size_t size0 = 1 ;
1281
- if (orig_inserted->getType ()->isSized ())
1282
- size0 = (gutils->newFunc ->getParent ()->getDataLayout ().getTypeSizeInBits (
1283
- orig_inserted->getType ()) +
1284
- 7 ) /
1285
- 8 ;
1282
+ Value *diff_inserted = gutils->isConstantValue (orig_inserted)
1283
+ ? ConstantFP::get (orig_inserted->getType (), 0 )
1284
+ : diffe (orig_inserted, Builder2);
1286
1285
1287
- Type *flt = nullptr ;
1288
- if (!gutils->isConstantValue (orig_inserted) &&
1289
- (flt = TR.intType (size0, orig_inserted).isFloat ())) {
1290
- auto prediff = diffe (&IVI, Builder2);
1291
- auto dindex = Builder2.CreateExtractValue (prediff, IVI.getIndices ());
1292
- addToDiffe (orig_inserted, dindex, Builder2, flt);
1286
+ Value *prediff =
1287
+ gutils->isConstantValue (orig_agg)
1288
+ ? diffe (orig_agg, Builder2)
1289
+ : ConstantAggregate::getNullValue (orig_agg->getType ());
1290
+ auto dindex =
1291
+ Builder2.CreateInsertValue (prediff, diff_inserted, IVI.getIndices ());
1292
+ setDiffe (&IVI, dindex, Builder2);
1293
+
1294
+ return ;
1293
1295
}
1296
+ case DerivativeMode::ReverseModeCombined:
1297
+ case DerivativeMode::ReverseModeGradient: {
1298
+ IRBuilder<> Builder2 (IVI.getParent ());
1299
+ getReverseBuilder (Builder2);
1294
1300
1295
- size_t size1 = 1 ;
1296
- if (orig_agg->getType ()->isSized () &&
1297
- (orig_agg->getType ()->isIntOrIntVectorTy () ||
1298
- orig_agg->getType ()->isFPOrFPVectorTy ()))
1299
- size1 = (gutils->newFunc ->getParent ()->getDataLayout ().getTypeSizeInBits (
1300
- orig_agg->getType ()) +
1301
- 7 ) /
1302
- 8 ;
1301
+ Value *orig_inserted = IVI.getInsertedValueOperand ();
1302
+ Value *orig_agg = IVI.getAggregateOperand ();
1303
1303
1304
- if (!gutils->isConstantValue (orig_agg)) {
1305
- auto prediff = diffe (&IVI, Builder2);
1306
- auto dindex = Builder2.CreateInsertValue (
1307
- prediff, Constant::getNullValue (orig_inserted->getType ()),
1308
- IVI.getIndices ());
1309
- addToDiffe (orig_agg, dindex, Builder2, TR.addingType (size1, orig_agg));
1310
- }
1304
+ size_t size0 = 1 ;
1305
+ if (orig_inserted->getType ()->isSized ())
1306
+ size0 =
1307
+ (gutils->newFunc ->getParent ()->getDataLayout ().getTypeSizeInBits (
1308
+ orig_inserted->getType ()) +
1309
+ 7 ) /
1310
+ 8 ;
1311
+
1312
+ Type *flt = nullptr ;
1313
+ if (!gutils->isConstantValue (orig_inserted) &&
1314
+ (flt = TR.intType (size0, orig_inserted).isFloat ())) {
1315
+ auto prediff = diffe (&IVI, Builder2);
1316
+ auto dindex = Builder2.CreateExtractValue (prediff, IVI.getIndices ());
1317
+ addToDiffe (orig_inserted, dindex, Builder2, flt);
1318
+ }
1319
+
1320
+ size_t size1 = 1 ;
1321
+ if (orig_agg->getType ()->isSized () &&
1322
+ (orig_agg->getType ()->isIntOrIntVectorTy () ||
1323
+ orig_agg->getType ()->isFPOrFPVectorTy ()))
1324
+ size1 =
1325
+ (gutils->newFunc ->getParent ()->getDataLayout ().getTypeSizeInBits (
1326
+ orig_agg->getType ()) +
1327
+ 7 ) /
1328
+ 8 ;
1329
+
1330
+ if (!gutils->isConstantValue (orig_agg)) {
1331
+ auto prediff = diffe (&IVI, Builder2);
1332
+ auto dindex = Builder2.CreateInsertValue (
1333
+ prediff, Constant::getNullValue (orig_inserted->getType ()),
1334
+ IVI.getIndices ());
1335
+ addToDiffe (orig_agg, dindex, Builder2, TR.addingType (size1, orig_agg));
1336
+ }
1311
1337
1312
- setDiffe (&IVI, Constant::getNullValue (IVI.getType ()), Builder2);
1338
+ setDiffe (&IVI, Constant::getNullValue (IVI.getType ()), Builder2);
1339
+ return ;
1340
+ }
1341
+ case DerivativeMode::ReverseModePrimal: {
1342
+ return ;
1343
+ }
1344
+ }
1313
1345
}
1314
1346
1315
1347
void getReverseBuilder (IRBuilder<> &Builder2, bool original = true ) {
0 commit comments