174
174
#include " llvm/IR/Function.h"
175
175
#include " llvm/IR/GetElementPtrTypeIterator.h"
176
176
#include " llvm/IR/IRBuilder.h"
177
+ #include " llvm/IR/InstIterator.h"
177
178
#include " llvm/IR/InstrTypes.h"
178
179
#include " llvm/IR/Instruction.h"
179
180
#include " llvm/IR/Instructions.h"
190
191
#include " llvm/Support/ErrorHandling.h"
191
192
#include " llvm/Support/raw_ostream.h"
192
193
#include " llvm/Transforms/Scalar.h"
194
+ #include " llvm/Transforms/Utils/BasicBlockUtils.h"
193
195
#include " llvm/Transforms/Utils/Local.h"
194
196
#include < cassert>
195
197
#include < cstdint>
198
200
using namespace llvm ;
199
201
using namespace llvm ::PatternMatch;
200
202
203
+ #define DEBUG_TYPE " separate-offset-gep"
204
+
201
205
static cl::opt<bool > DisableSeparateConstOffsetFromGEP (
202
206
" disable-separate-const-offset-from-gep" , cl::init(false ),
203
207
cl::desc(" Do not separate the constant offset from a GEP instruction" ),
@@ -488,6 +492,42 @@ class SeparateConstOffsetFromGEP {
488
492
DenseMap<ExprKey, SmallVector<Instruction *, 2 >> DominatingSubs;
489
493
};
490
494
495
+ // / A helper class that aims to convert xor operations into or operations when
496
+ // / their operands are disjoint and the result is used in a GEP's index. This
497
+ // / can then enable further GEP optimizations by effectively turning BaseVal |
498
+ // / Const into BaseVal + Const when they are disjoint, which
499
+ // / SeparateConstOffsetFromGEP can then process. This is a common pattern that
500
+ // / sets up a grid of memory accesses across a wave where each thread acesses
501
+ // / data at various offsets.
502
+ class XorToOrDisjointTransformer {
503
+ public:
504
+ XorToOrDisjointTransformer (Function &F, DominatorTree &DT,
505
+ const DataLayout &DL)
506
+ : F(F), DT(DT), DL(DL) {}
507
+
508
+ bool run ();
509
+
510
+ private:
511
+ Function &F;
512
+ DominatorTree &DT;
513
+ const DataLayout &DL;
514
+ // / Maps a common operand to all Xor instructions
515
+ using XorOpList = SmallVector<std::pair<BinaryOperator *, APInt>, 8 >;
516
+ using XorBaseValInst = DenseMap<Instruction *, XorOpList>;
517
+ XorBaseValInst XorGroups;
518
+
519
+ // / Checks if the given value has at least one GetElementPtr user
520
+ static bool hasGEPUser (const Value *V);
521
+
522
+ // / Helper function to check if BaseXor dominates all XORs in the group
523
+ bool dominatesAllXors (BinaryOperator *BaseXor, const XorOpList &XorsInGroup);
524
+
525
+ // / Processes a group of XOR instructions that share the same non-constant
526
+ // / base operand. Returns true if this group's processing modified the
527
+ // / function.
528
+ bool processXorGroup (Instruction *OriginalBaseInst, XorOpList &XorsInGroup);
529
+ };
530
+
491
531
} // end anonymous namespace
492
532
493
533
char SeparateConstOffsetFromGEPLegacyPass::ID = 0 ;
@@ -1223,6 +1263,154 @@ bool SeparateConstOffsetFromGEP::splitGEP(GetElementPtrInst *GEP) {
1223
1263
return true ;
1224
1264
}
1225
1265
1266
+ // Helper function to check if an instruction has at least one GEP user
1267
+ bool XorToOrDisjointTransformer::hasGEPUser (const Value *V) {
1268
+ return llvm::any_of (V->users (), [](const User *U) {
1269
+ return isa<llvm::GetElementPtrInst>(U);
1270
+ });
1271
+ }
1272
+
1273
+ bool XorToOrDisjointTransformer::dominatesAllXors (
1274
+ BinaryOperator *BaseXor, const XorOpList &XorsInGroup) {
1275
+ return llvm::all_of (XorsInGroup, [&](const auto &XorEntry) {
1276
+ BinaryOperator *XorInst = XorEntry.first ;
1277
+ // Do not evaluate the BaseXor, otherwise we end up cloning it.
1278
+ return XorInst == BaseXor || DT.dominates (BaseXor, XorInst);
1279
+ });
1280
+ }
1281
+
1282
+ bool XorToOrDisjointTransformer::processXorGroup (Instruction *OriginalBaseInst,
1283
+ XorOpList &XorsInGroup) {
1284
+ bool Changed = false ;
1285
+ if (XorsInGroup.size () <= 1 )
1286
+ return false ;
1287
+
1288
+ // Sort XorsInGroup by the constant offset value in increasing order.
1289
+ llvm::sort (XorsInGroup, [](const auto &A, const auto &B) {
1290
+ return A.second .slt (B.second );
1291
+ });
1292
+
1293
+ // Dominance check
1294
+ // The "base" XOR for dominance purposes is the one with the smallest
1295
+ // constant.
1296
+ BinaryOperator *XorWithSmallConst = XorsInGroup[0 ].first ;
1297
+
1298
+ if (!dominatesAllXors (XorWithSmallConst, XorsInGroup)) {
1299
+ LLVM_DEBUG (dbgs () << DEBUG_TYPE
1300
+ << " : Cloning and inserting XOR with smallest constant ("
1301
+ << *XorWithSmallConst
1302
+ << " ) as it does not dominate all other XORs"
1303
+ << " in function " << F.getName () << " \n " );
1304
+
1305
+ BinaryOperator *ClonedXor =
1306
+ cast<BinaryOperator>(XorWithSmallConst->clone ());
1307
+ ClonedXor->setName (XorWithSmallConst->getName () + " .dom_clone" );
1308
+ ClonedXor->insertAfter (OriginalBaseInst);
1309
+ LLVM_DEBUG (dbgs () << " Cloned Inst: " << *ClonedXor << " \n " );
1310
+ Changed = true ;
1311
+ XorWithSmallConst = ClonedXor;
1312
+ }
1313
+
1314
+ SmallVector<Instruction *, 8 > InstructionsToErase;
1315
+ const APInt SmallestConst =
1316
+ cast<ConstantInt>(XorWithSmallConst->getOperand (1 ))->getValue ();
1317
+
1318
+ // Main transformation loop: Iterate over the original XORs in the sorted
1319
+ // group.
1320
+ for (const auto &XorEntry : XorsInGroup) {
1321
+ BinaryOperator *XorInst = XorEntry.first ; // Original XOR instruction
1322
+ const APInt ConstOffsetVal = XorEntry.second ;
1323
+
1324
+ // Do not process the one with smallest constant as it is the base.
1325
+ if (XorInst == XorWithSmallConst)
1326
+ continue ;
1327
+
1328
+ // Disjointness Check 1
1329
+ APInt NewConstVal = ConstOffsetVal - SmallestConst;
1330
+ if ((NewConstVal & SmallestConst) != 0 ) {
1331
+ LLVM_DEBUG (dbgs () << DEBUG_TYPE << " : Cannot transform XOR in function "
1332
+ << F.getName () << " :\n "
1333
+ << " New Const: " << NewConstVal
1334
+ << " Smallest Const: " << SmallestConst
1335
+ << " are not disjoint \n " );
1336
+ continue ;
1337
+ }
1338
+
1339
+ // Disjointness Check 2
1340
+ if (MaskedValueIsZero (XorWithSmallConst, NewConstVal, SimplifyQuery (DL),
1341
+ 0 )) {
1342
+ LLVM_DEBUG (dbgs () << DEBUG_TYPE
1343
+ << " : Transforming XOR to OR (disjoint) in function "
1344
+ << F.getName () << " :\n "
1345
+ << " Xor: " << *XorInst << " \n "
1346
+ << " Base Val: " << *XorWithSmallConst << " \n "
1347
+ << " New Const: " << NewConstVal << " \n " );
1348
+
1349
+ auto *NewOrInst = BinaryOperator::CreateDisjointOr (
1350
+ XorWithSmallConst,
1351
+ ConstantInt::get (OriginalBaseInst->getType (), NewConstVal),
1352
+ XorInst->getName () + " .or_disjoint" , XorInst->getIterator ());
1353
+
1354
+ NewOrInst->copyMetadata (*XorInst);
1355
+ XorInst->replaceAllUsesWith (NewOrInst);
1356
+ LLVM_DEBUG (dbgs () << " New Inst: " << *NewOrInst << " \n " );
1357
+ InstructionsToErase.push_back (XorInst); // Mark original XOR for deletion
1358
+
1359
+ Changed = true ;
1360
+ } else {
1361
+ LLVM_DEBUG (
1362
+ dbgs () << DEBUG_TYPE
1363
+ << " : Cannot transform XOR (not proven disjoint) in function "
1364
+ << F.getName () << " :\n "
1365
+ << " Xor: " << *XorInst << " \n "
1366
+ << " Base Val: " << *XorWithSmallConst << " \n "
1367
+ << " New Const: " << NewConstVal << " \n " );
1368
+ }
1369
+ }
1370
+
1371
+ for (Instruction *I : InstructionsToErase)
1372
+ I->eraseFromParent ();
1373
+
1374
+ return Changed;
1375
+ }
1376
+
1377
+ // Try to transform XOR(A, B+C) in to XOR(A,C) + B where XOR(A,C) becomes
1378
+ // the base for memory operations. This transformation is true under the
1379
+ // following conditions
1380
+ // Check 1 - B and C are disjoint.
1381
+ // Check 2 - XOR(A,C) and B are disjoint.
1382
+ //
1383
+ // This transformation is beneficial particularly for GEPs because:
1384
+ // 1. OR operations often map better to addressing modes than XOR
1385
+ // 2. Disjoint OR operations preserve the semantics of the original XOR
1386
+ // 3. This can enable further optimizations in the GEP offset folding pipeline
1387
+ bool XorToOrDisjointTransformer::run () {
1388
+ bool Changed = false ;
1389
+
1390
+ // Collect all candidate XORs
1391
+ for (Instruction &I : instructions (F)) {
1392
+ Instruction *Op0 = nullptr ;
1393
+ ConstantInt *C1 = nullptr ;
1394
+ BinaryOperator *MatchedXorOp = nullptr ;
1395
+
1396
+ // Attempt to match the instruction 'I' as XOR operation.
1397
+ if (match (&I, m_CombineAnd (m_Xor (m_Instruction (Op0), m_ConstantInt (C1)),
1398
+ m_BinOp (MatchedXorOp))) &&
1399
+ hasGEPUser (MatchedXorOp))
1400
+ XorGroups[Op0].emplace_back (MatchedXorOp, C1->getValue ());
1401
+ }
1402
+
1403
+ if (XorGroups.empty ())
1404
+ return false ;
1405
+
1406
+ // Process each group of XORs
1407
+ for (auto &[OriginalBaseInst, XorsInGroup] : XorGroups)
1408
+ if (processXorGroup (OriginalBaseInst, XorsInGroup))
1409
+ Changed = true ;
1410
+
1411
+ return Changed;
1412
+ }
1413
+
1226
1414
bool SeparateConstOffsetFromGEPLegacyPass::runOnFunction (Function &F) {
1227
1415
if (skipFunction (F))
1228
1416
return false ;
@@ -1242,6 +1430,11 @@ bool SeparateConstOffsetFromGEP::run(Function &F) {
1242
1430
1243
1431
DL = &F.getDataLayout ();
1244
1432
bool Changed = false ;
1433
+
1434
+ // Decompose xor into "or disjoint" if possible.
1435
+ XorToOrDisjointTransformer XorTransformer (F, *DT, *DL);
1436
+ Changed |= XorTransformer.run ();
1437
+
1245
1438
for (BasicBlock &B : F) {
1246
1439
if (!DT->isReachableFromEntry (&B))
1247
1440
continue ;
0 commit comments