@@ -117,8 +117,8 @@ class MemCmpExpansion {
117
117
Value *Lhs = nullptr ;
118
118
Value *Rhs = nullptr ;
119
119
};
120
- LoadPair getLoadPair (Type *LoadSizeType, bool NeedsBSwap, Type *CmpSizeType ,
121
- unsigned OffsetBytes);
120
+ LoadPair getLoadPair (Type *LoadSizeType, Type *BSwapSizeType ,
121
+ Type *CmpSizeType, unsigned OffsetBytes);
122
122
123
123
static LoadEntryVector
124
124
computeGreedyLoadSequence (uint64_t Size, llvm::ArrayRef<unsigned > LoadSizes,
@@ -128,6 +128,11 @@ class MemCmpExpansion {
128
128
unsigned MaxNumLoads,
129
129
unsigned &NumLoadsNonOneByte);
130
130
131
+ static void optimiseLoadSequence (
132
+ LoadEntryVector &LoadSequence,
133
+ const TargetTransformInfo::MemCmpExpansionOptions &Options,
134
+ bool IsUsedForZeroCmp);
135
+
131
136
public:
132
137
MemCmpExpansion (CallInst *CI, uint64_t Size,
133
138
const TargetTransformInfo::MemCmpExpansionOptions &Options,
@@ -210,6 +215,37 @@ MemCmpExpansion::computeOverlappingLoadSequence(uint64_t Size,
210
215
return LoadSequence;
211
216
}
212
217
218
+ void MemCmpExpansion::optimiseLoadSequence (
219
+ LoadEntryVector &LoadSequence,
220
+ const TargetTransformInfo::MemCmpExpansionOptions &Options,
221
+ bool IsUsedForZeroCmp) {
222
+ // This part of code attempts to optimize the LoadSequence by merging allowed
223
+ // subsequences into single loads of allowed sizes from
224
+ // `MemCmpExpansionOptions::AllowedTailExpansions`. If it is for zero
225
+ // comparison or if no allowed tail expansions are specified, we exit early.
226
+ if (IsUsedForZeroCmp || Options.AllowedTailExpansions .empty ())
227
+ return ;
228
+
229
+ while (LoadSequence.size () >= 2 ) {
230
+ auto Last = LoadSequence[LoadSequence.size () - 1 ];
231
+ auto PreLast = LoadSequence[LoadSequence.size () - 2 ];
232
+
233
+ // Exit the loop if the two sequences are not contiguous
234
+ if (PreLast.Offset + PreLast.LoadSize != Last.Offset )
235
+ break ;
236
+
237
+ auto LoadSize = Last.LoadSize + PreLast.LoadSize ;
238
+ if (find (Options.AllowedTailExpansions , LoadSize) ==
239
+ Options.AllowedTailExpansions .end ())
240
+ break ;
241
+
242
+ // Remove the last two sequences and replace with the combined sequence
243
+ LoadSequence.pop_back ();
244
+ LoadSequence.pop_back ();
245
+ LoadSequence.emplace_back (PreLast.Offset , LoadSize);
246
+ }
247
+ }
248
+
213
249
// Initialize the basic block structure required for expansion of memcmp call
214
250
// with given maximum load size and memcmp size parameter.
215
251
// This structure includes:
@@ -255,6 +291,7 @@ MemCmpExpansion::MemCmpExpansion(
255
291
}
256
292
}
257
293
assert (LoadSequence.size () <= Options.MaxNumLoads && " broken invariant" );
294
+ optimiseLoadSequence (LoadSequence, Options, IsUsedForZeroCmp);
258
295
}
259
296
260
297
unsigned MemCmpExpansion::getNumBlocks () {
@@ -278,7 +315,7 @@ void MemCmpExpansion::createResultBlock() {
278
315
}
279
316
280
317
MemCmpExpansion::LoadPair MemCmpExpansion::getLoadPair (Type *LoadSizeType,
281
- bool NeedsBSwap ,
318
+ Type *BSwapSizeType ,
282
319
Type *CmpSizeType,
283
320
unsigned OffsetBytes) {
284
321
// Get the memory source at offset `OffsetBytes`.
@@ -307,16 +344,22 @@ MemCmpExpansion::LoadPair MemCmpExpansion::getLoadPair(Type *LoadSizeType,
307
344
if (!Rhs)
308
345
Rhs = Builder.CreateAlignedLoad (LoadSizeType, RhsSource, RhsAlign);
309
346
347
+ // Zero extend if Byte Swap intrinsic has different type
348
+ if (BSwapSizeType && LoadSizeType != BSwapSizeType) {
349
+ Lhs = Builder.CreateZExt (Lhs, BSwapSizeType);
350
+ Rhs = Builder.CreateZExt (Rhs, BSwapSizeType);
351
+ }
352
+
310
353
// Swap bytes if required.
311
- if (NeedsBSwap ) {
312
- Function *Bswap = Intrinsic::getDeclaration (CI-> getModule (),
313
- Intrinsic::bswap, LoadSizeType );
354
+ if (BSwapSizeType ) {
355
+ Function *Bswap = Intrinsic::getDeclaration (
356
+ CI-> getModule (), Intrinsic::bswap, BSwapSizeType );
314
357
Lhs = Builder.CreateCall (Bswap, Lhs);
315
358
Rhs = Builder.CreateCall (Bswap, Rhs);
316
359
}
317
360
318
361
// Zero extend if required.
319
- if (CmpSizeType != nullptr && CmpSizeType != LoadSizeType ) {
362
+ if (CmpSizeType != nullptr && CmpSizeType != Lhs-> getType () ) {
320
363
Lhs = Builder.CreateZExt (Lhs, CmpSizeType);
321
364
Rhs = Builder.CreateZExt (Rhs, CmpSizeType);
322
365
}
@@ -332,7 +375,7 @@ void MemCmpExpansion::emitLoadCompareByteBlock(unsigned BlockIndex,
332
375
BasicBlock *BB = LoadCmpBlocks[BlockIndex];
333
376
Builder.SetInsertPoint (BB);
334
377
const LoadPair Loads =
335
- getLoadPair (Type::getInt8Ty (CI->getContext ()), /* NeedsBSwap= */ false ,
378
+ getLoadPair (Type::getInt8Ty (CI->getContext ()), nullptr ,
336
379
Type::getInt32Ty (CI->getContext ()), OffsetBytes);
337
380
Value *Diff = Builder.CreateSub (Loads.Lhs , Loads.Rhs );
338
381
@@ -385,11 +428,12 @@ Value *MemCmpExpansion::getCompareLoadPairs(unsigned BlockIndex,
385
428
IntegerType *const MaxLoadType =
386
429
NumLoads == 1 ? nullptr
387
430
: IntegerType::get (CI->getContext (), MaxLoadSize * 8 );
431
+
388
432
for (unsigned i = 0 ; i < NumLoads; ++i, ++LoadIndex) {
389
433
const LoadEntry &CurLoadEntry = LoadSequence[LoadIndex];
390
434
const LoadPair Loads = getLoadPair (
391
- IntegerType::get (CI->getContext (), CurLoadEntry.LoadSize * 8 ),
392
- /* NeedsBSwap= */ false , MaxLoadType, CurLoadEntry.Offset );
435
+ IntegerType::get (CI->getContext (), CurLoadEntry.LoadSize * 8 ), nullptr ,
436
+ MaxLoadType, CurLoadEntry.Offset );
393
437
394
438
if (NumLoads != 1 ) {
395
439
// If we have multiple loads per block, we need to generate a composite
@@ -475,14 +519,20 @@ void MemCmpExpansion::emitLoadCompareBlock(unsigned BlockIndex) {
475
519
476
520
Type *LoadSizeType =
477
521
IntegerType::get (CI->getContext (), CurLoadEntry.LoadSize * 8 );
478
- Type *MaxLoadType = IntegerType::get (CI->getContext (), MaxLoadSize * 8 );
522
+ Type *BSwapSizeType =
523
+ DL.isLittleEndian ()
524
+ ? IntegerType::get (CI->getContext (),
525
+ PowerOf2Ceil (CurLoadEntry.LoadSize * 8 ))
526
+ : nullptr ;
527
+ Type *MaxLoadType = IntegerType::get (
528
+ CI->getContext (),
529
+ std::max (MaxLoadSize, (unsigned )PowerOf2Ceil (CurLoadEntry.LoadSize )) * 8 );
479
530
assert (CurLoadEntry.LoadSize <= MaxLoadSize && " Unexpected load type" );
480
531
481
532
Builder.SetInsertPoint (LoadCmpBlocks[BlockIndex]);
482
533
483
- const LoadPair Loads =
484
- getLoadPair (LoadSizeType, /* NeedsBSwap=*/ DL.isLittleEndian (), MaxLoadType,
485
- CurLoadEntry.Offset );
534
+ const LoadPair Loads = getLoadPair (LoadSizeType, BSwapSizeType, MaxLoadType,
535
+ CurLoadEntry.Offset );
486
536
487
537
// Add the loaded values to the phi nodes for calculating memcmp result only
488
538
// if result is not used in a zero equality.
@@ -587,19 +637,24 @@ Value *MemCmpExpansion::getMemCmpEqZeroOneBlock() {
587
637
// / A memcmp expansion that only has one block of load and compare can bypass
588
638
// / the compare, branch, and phi IR that is required in the general case.
589
639
Value *MemCmpExpansion::getMemCmpOneBlock () {
590
- Type *LoadSizeType = IntegerType::get (CI->getContext (), Size * 8 );
591
640
bool NeedsBSwap = DL.isLittleEndian () && Size != 1 ;
641
+ Type *LoadSizeType = IntegerType::get (CI->getContext (), Size * 8 );
642
+ Type *BSwapSizeType =
643
+ NeedsBSwap ? IntegerType::get (CI->getContext (), PowerOf2Ceil (Size * 8 ))
644
+ : nullptr ;
645
+ Type *MaxLoadType =
646
+ IntegerType::get (CI->getContext (),
647
+ std::max (MaxLoadSize, (unsigned )PowerOf2Ceil (Size)) * 8 );
592
648
593
649
// The i8 and i16 cases don't need compares. We zext the loaded values and
594
650
// subtract them to get the suitable negative, zero, or positive i32 result.
595
- if (Size < 4 ) {
596
- const LoadPair Loads =
597
- getLoadPair (LoadSizeType, NeedsBSwap, Builder.getInt32Ty (),
598
- /* Offset*/ 0 );
651
+ if (Size == 1 || Size == 2 ) {
652
+ const LoadPair Loads = getLoadPair (LoadSizeType, BSwapSizeType,
653
+ Builder.getInt32Ty (), /* Offset*/ 0 );
599
654
return Builder.CreateSub (Loads.Lhs , Loads.Rhs );
600
655
}
601
656
602
- const LoadPair Loads = getLoadPair (LoadSizeType, NeedsBSwap, LoadSizeType ,
657
+ const LoadPair Loads = getLoadPair (LoadSizeType, BSwapSizeType, MaxLoadType ,
603
658
/* Offset*/ 0 );
604
659
// The result of memcmp is negative, zero, or positive, so produce that by
605
660
// subtracting 2 extended compare bits: sub (ugt, ult).
0 commit comments