@@ -69,10 +69,11 @@ class ScalarizeMaskedMemIntrinLegacyPass : public FunctionPass {
 static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT,
                           const TargetTransformInfo &TTI, const DataLayout &DL,
-                          DomTreeUpdater *DTU);
+                          bool HasBranchDivergence, DomTreeUpdater *DTU);
 static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
                              const TargetTransformInfo &TTI,
-                             const DataLayout &DL, DomTreeUpdater *DTU);
+                             const DataLayout &DL, bool HasBranchDivergence,
+                             DomTreeUpdater *DTU);
 
 char ScalarizeMaskedMemIntrinLegacyPass::ID = 0;
 
@@ -141,8 +142,9 @@ static unsigned adjustForEndian(const DataLayout &DL, unsigned VectorWidth,
 //  %10 = extractelement <16 x i1> %mask, i32 2
 //  br i1 %10, label %cond.load4, label %else5
 //
-static void scalarizeMaskedLoad(const DataLayout &DL, CallInst *CI,
-                                DomTreeUpdater *DTU, bool &ModifiedDT) {
+static void scalarizeMaskedLoad(const DataLayout &DL, bool HasBranchDivergence,
+                                CallInst *CI, DomTreeUpdater *DTU,
+                                bool &ModifiedDT) {
   Value *Ptr = CI->getArgOperand(0);
   Value *Alignment = CI->getArgOperand(1);
   Value *Mask = CI->getArgOperand(2);
@@ -221,25 +223,26 @@ static void scalarizeMaskedLoad(const DataLayout &DL, CallInst *CI,
     return;
   }
 
   // If the mask is not v1i1, use scalar bit test operations. This generates
-  // better results on X86 at least.
-  // Note: this produces worse code on AMDGPU, where the "i1" is implicitly SIMD
-  // - what's a good way to detect this?
-  Value *SclrMask;
-  if (VectorWidth != 1) {
+  // better results on X86 at least. However, don't do this on GPUs and other
+  // machines with divergence, as there each i1 needs a vector register.
+  Value *SclrMask = nullptr;
+  if (VectorWidth != 1 && !HasBranchDivergence) {
     Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
     SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
   }
 
   for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
     // Fill the "else" block, created in the previous iteration
     //
-    //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
-    //  %mask_1 = and i16 %scalar_mask, i32 1 << Idx
-    //  %cond = icmp ne i16 %mask_1, 0
-    //  br i1 %mask_1, label %cond.load, label %else
+    // %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else,
+    // %else ] %mask_1 = and i16 %scalar_mask, i32 1 << Idx %cond = icmp ne i16
+    // %mask_1, 0 br i1 %mask_1, label %cond.load, label %else
     //
+    // On GPUs, use
+    // %cond = extrectelement %mask, Idx
+    // instead
     Value *Predicate;
-    if (VectorWidth != 1) {
+    if (SclrMask != nullptr) {
       Value *Mask = Builder.getInt(APInt::getOneBitSet(
           VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
       Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
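The comments in this hunk only gesture at the two predicate forms, so here is a minimal hand-written LLVM IR sketch of how one lane's branch condition is computed on each path. The function names and the fixed lane index 2 are invented for illustration; this is not taken from the patch or its tests.

; Bit-test path (no branch divergence): bitcast the mask once, then test one
; bit per lane with a scalar and/icmp, as the pass does when SclrMask is set.
define i1 @lane2_cond_bittest(<16 x i1> %mask) {
  %scalar_mask = bitcast <16 x i1> %mask to i16
  %mask_bit = and i16 %scalar_mask, 4        ; 1 << 2 (adjustForEndian handles byte order)
  %cond = icmp ne i16 %mask_bit, 0
  ret i1 %cond
}

; Divergent path (SclrMask == nullptr): read the lane directly, so no wide
; scalar mask has to stay live in a (vector) register across the whole loop.
define i1 @lane2_cond_divergent(<16 x i1> %mask) {
  %cond = extractelement <16 x i1> %mask, i32 2
  ret i1 %cond
}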
@@ -312,8 +315,9 @@ static void scalarizeMaskedLoad(const DataLayout &DL, CallInst *CI,
 //   store i32 %6, i32* %7
 //   br label %else2
 //   . . .
-static void scalarizeMaskedStore(const DataLayout &DL, CallInst *CI,
-                                 DomTreeUpdater *DTU, bool &ModifiedDT) {
+static void scalarizeMaskedStore(const DataLayout &DL, bool HasBranchDivergence,
+                                 CallInst *CI, DomTreeUpdater *DTU,
+                                 bool &ModifiedDT) {
   Value *Src = CI->getArgOperand(0);
   Value *Ptr = CI->getArgOperand(1);
   Value *Alignment = CI->getArgOperand(2);
@@ -378,10 +382,10 @@ static void scalarizeMaskedStore(const DataLayout &DL, CallInst *CI,
   }
 
   // If the mask is not v1i1, use scalar bit test operations. This generates
-  // better results on X86 at least.
-
-  Value *SclrMask;
-  if (VectorWidth != 1) {
+  // better results on X86 at least. However, don't do this on GPUs or other
+  // machines with branch divergence, as there each i1 takes up a register.
+  Value *SclrMask = nullptr;
+  if (VectorWidth != 1 && !HasBranchDivergence) {
     Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
     SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
   }
@@ -393,8 +397,11 @@ static void scalarizeMaskedStore(const DataLayout &DL, CallInst *CI,
     //  %cond = icmp ne i16 %mask_1, 0
     //  br i1 %mask_1, label %cond.store, label %else
     //
+    // On GPUs, use
+    // %cond = extrectelement %mask, Idx
+    // instead
     Value *Predicate;
-    if (VectorWidth != 1) {
+    if (SclrMask != nullptr) {
       Value *Mask = Builder.getInt(APInt::getOneBitSet(
           VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
       Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
@@ -461,7 +468,8 @@ static void scalarizeMaskedStore(const DataLayout &DL, CallInst *CI,
 //   . . .
 //   %Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src
 //   ret <16 x i32> %Result
-static void scalarizeMaskedGather(const DataLayout &DL, CallInst *CI,
+static void scalarizeMaskedGather(const DataLayout &DL,
+                                  bool HasBranchDivergence, CallInst *CI,
                                   DomTreeUpdater *DTU, bool &ModifiedDT) {
   Value *Ptrs = CI->getArgOperand(0);
   Value *Alignment = CI->getArgOperand(1);
@@ -500,9 +508,10 @@ static void scalarizeMaskedGather(const DataLayout &DL, CallInst *CI,
   }
 
   // If the mask is not v1i1, use scalar bit test operations. This generates
-  // better results on X86 at least.
-  Value *SclrMask;
-  if (VectorWidth != 1) {
+  // better results on X86 at least. However, don't do this on GPUs or other
+  // machines with branch divergence, as there, each i1 takes up a register.
+  Value *SclrMask = nullptr;
+  if (VectorWidth != 1 && !HasBranchDivergence) {
     Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
     SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
   }
@@ -514,9 +523,12 @@ static void scalarizeMaskedGather(const DataLayout &DL, CallInst *CI,
     //  %cond = icmp ne i16 %mask_1, 0
     //  br i1 %Mask1, label %cond.load, label %else
     //
+    // On GPUs, use
+    // %cond = extrectelement %mask, Idx
+    // instead
 
     Value *Predicate;
-    if (VectorWidth != 1) {
+    if (SclrMask != nullptr) {
       Value *Mask = Builder.getInt(APInt::getOneBitSet(
           VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
       Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
@@ -591,7 +603,8 @@ static void scalarizeMaskedGather(const DataLayout &DL, CallInst *CI,
 //   store i32 %Elt1, i32* %Ptr1, align 4
 //   br label %else2
 //   . . .
-static void scalarizeMaskedScatter(const DataLayout &DL, CallInst *CI,
+static void scalarizeMaskedScatter(const DataLayout &DL,
+                                   bool HasBranchDivergence, CallInst *CI,
                                    DomTreeUpdater *DTU, bool &ModifiedDT) {
   Value *Src = CI->getArgOperand(0);
   Value *Ptrs = CI->getArgOperand(1);
@@ -629,8 +642,8 @@ static void scalarizeMaskedScatter(const DataLayout &DL, CallInst *CI,
 
   // If the mask is not v1i1, use scalar bit test operations. This generates
   // better results on X86 at least.
-  Value *SclrMask;
-  if (VectorWidth != 1) {
+  Value *SclrMask = nullptr;
+  if (VectorWidth != 1 && !HasBranchDivergence) {
     Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
     SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
   }
@@ -642,8 +655,11 @@ static void scalarizeMaskedScatter(const DataLayout &DL, CallInst *CI,
     //  %cond = icmp ne i16 %mask_1, 0
     //  br i1 %Mask1, label %cond.store, label %else
     //
+    // On GPUs, use
+    // %cond = extrectelement %mask, Idx
+    // instead
     Value *Predicate;
-    if (VectorWidth != 1) {
+    if (SclrMask != nullptr) {
       Value *Mask = Builder.getInt(APInt::getOneBitSet(
           VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
       Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
@@ -681,7 +697,8 @@ static void scalarizeMaskedScatter(const DataLayout &DL, CallInst *CI,
   ModifiedDT = true;
 }
 
-static void scalarizeMaskedExpandLoad(const DataLayout &DL, CallInst *CI,
+static void scalarizeMaskedExpandLoad(const DataLayout &DL,
+                                      bool HasBranchDivergence, CallInst *CI,
                                       DomTreeUpdater *DTU, bool &ModifiedDT) {
   Value *Ptr = CI->getArgOperand(0);
   Value *Mask = CI->getArgOperand(1);
@@ -738,23 +755,27 @@ static void scalarizeMaskedExpandLoad(const DataLayout &DL, CallInst *CI,
   }
 
   // If the mask is not v1i1, use scalar bit test operations. This generates
-  // better results on X86 at least.
-  Value *SclrMask;
-  if (VectorWidth != 1) {
+  // better results on X86 at least. However, don't do this on GPUs or other
+  // machines with branch divergence, as there, each i1 takes up a register.
+  Value *SclrMask = nullptr;
+  if (VectorWidth != 1 && !HasBranchDivergence) {
     Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
     SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
   }
 
   for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
     // Fill the "else" block, created in the previous iteration
     //
-    //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
-    //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
-    //  br i1 %mask_1, label %cond.load, label %else
+    // %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else,
+    // %else ] %mask_1 = extractelement <16 x i1> %mask, i32 Idx br i1 %mask_1,
+    // label %cond.load, label %else
     //
+    // On GPUs, use
+    // %cond = extrectelement %mask, Idx
+    // instead
 
     Value *Predicate;
-    if (VectorWidth != 1) {
+    if (SclrMask != nullptr) {
       Value *Mask = Builder.getInt(APInt::getOneBitSet(
           VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
       Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
@@ -813,7 +834,8 @@ static void scalarizeMaskedExpandLoad(const DataLayout &DL, CallInst *CI,
   ModifiedDT = true;
 }
 
-static void scalarizeMaskedCompressStore(const DataLayout &DL, CallInst *CI,
+static void scalarizeMaskedCompressStore(const DataLayout &DL,
+                                         bool HasBranchDivergence, CallInst *CI,
                                          DomTreeUpdater *DTU,
                                          bool &ModifiedDT) {
   Value *Src = CI->getArgOperand(0);
@@ -855,9 +877,10 @@ static void scalarizeMaskedCompressStore(const DataLayout &DL, CallInst *CI,
   }
 
   // If the mask is not v1i1, use scalar bit test operations. This generates
-  // better results on X86 at least.
-  Value *SclrMask;
-  if (VectorWidth != 1) {
+  // better results on X86 at least. However, don't do this on GPUs or other
+  // machines with branch divergence, as there, each i1 takes up a register.
+  Value *SclrMask = nullptr;
+  if (VectorWidth != 1 && !HasBranchDivergence) {
     Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
     SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
   }
@@ -868,8 +891,11 @@ static void scalarizeMaskedCompressStore(const DataLayout &DL, CallInst *CI,
     //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
     //  br i1 %mask_1, label %cond.store, label %else
     //
+    // On GPUs, use
+    // %cond = extrectelement %mask, Idx
+    // instead
     Value *Predicate;
-    if (VectorWidth != 1) {
+    if (SclrMask != nullptr) {
       Value *Mask = Builder.getInt(APInt::getOneBitSet(
           VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
       Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
@@ -993,12 +1019,13 @@ static bool runImpl(Function &F, const TargetTransformInfo &TTI,
   bool EverMadeChange = false;
   bool MadeChange = true;
   auto &DL = F.getDataLayout();
+  bool HasBranchDivergence = TTI.hasBranchDivergence(&F);
   while (MadeChange) {
     MadeChange = false;
     for (BasicBlock &BB : llvm::make_early_inc_range(F)) {
       bool ModifiedDTOnIteration = false;
       MadeChange |= optimizeBlock(BB, ModifiedDTOnIteration, TTI, DL,
-                                  DTU ? &*DTU : nullptr);
+                                  HasBranchDivergence, DTU ? &*DTU : nullptr);
 
       // Restart BB iteration if the dominator tree of the Function was changed
       if (ModifiedDTOnIteration)
@@ -1032,13 +1059,14 @@ ScalarizeMaskedMemIntrinPass::run(Function &F, FunctionAnalysisManager &AM) {
 
 static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT,
                           const TargetTransformInfo &TTI, const DataLayout &DL,
-                          DomTreeUpdater *DTU) {
+                          bool HasBranchDivergence, DomTreeUpdater *DTU) {
   bool MadeChange = false;
 
   BasicBlock::iterator CurInstIterator = BB.begin();
   while (CurInstIterator != BB.end()) {
     if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++))
-      MadeChange |= optimizeCallInst(CI, ModifiedDT, TTI, DL, DTU);
+      MadeChange |=
+          optimizeCallInst(CI, ModifiedDT, TTI, DL, HasBranchDivergence, DTU);
     if (ModifiedDT)
       return true;
   }
@@ -1048,7 +1076,8 @@ static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT,
 
 static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
                              const TargetTransformInfo &TTI,
-                             const DataLayout &DL, DomTreeUpdater *DTU) {
+                             const DataLayout &DL, bool HasBranchDivergence,
+                             DomTreeUpdater *DTU) {
   IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
   if (II) {
     // The scalarization code below does not work for scalable vectors.
@@ -1071,14 +1100,14 @@ static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
               CI->getType(),
               cast<ConstantInt>(CI->getArgOperand(1))->getAlignValue()))
         return false;
-      scalarizeMaskedLoad(DL, CI, DTU, ModifiedDT);
+      scalarizeMaskedLoad(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
       return true;
     case Intrinsic::masked_store:
       if (TTI.isLegalMaskedStore(
              CI->getArgOperand(0)->getType(),
              cast<ConstantInt>(CI->getArgOperand(2))->getAlignValue()))
        return false;
-      scalarizeMaskedStore(DL, CI, DTU, ModifiedDT);
+      scalarizeMaskedStore(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
       return true;
     case Intrinsic::masked_gather: {
       MaybeAlign MA =
@@ -1102,22 +1131,23 @@ static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
           !TTI.forceScalarizeMaskedScatter(cast<VectorType>(StoreTy),
                                            Alignment))
         return false;
-      scalarizeMaskedScatter(DL, CI, DTU, ModifiedDT);
+      scalarizeMaskedScatter(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
       return true;
     }
     case Intrinsic::masked_expandload:
      if (TTI.isLegalMaskedExpandLoad(
              CI->getType(),
              CI->getAttributes().getParamAttrs(0).getAlignment().valueOrOne()))
        return false;
-      scalarizeMaskedExpandLoad(DL, CI, DTU, ModifiedDT);
+      scalarizeMaskedExpandLoad(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
       return true;
     case Intrinsic::masked_compressstore:
      if (TTI.isLegalMaskedCompressStore(
              CI->getArgOperand(0)->getType(),
              CI->getAttributes().getParamAttrs(1).getAlignment().valueOrOne()))
        return false;
-      scalarizeMaskedCompressStore(DL, CI, DTU, ModifiedDT);
+      scalarizeMaskedCompressStore(DL, HasBranchDivergence, CI, DTU,
+                                   ModifiedDT);
      return true;
     }
   }
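With the flag threaded from runImpl down through optimizeBlock and optimizeCallInst into every scalarization helper, a divergent target gets the extractelement-guarded control flow everywhere. As a wrap-up, here is a hand-written sketch of the shape scalarizeMaskedLoad now produces for a <2 x i32> masked load when HasBranchDivergence is true; it is not one of the patch's regression tests, and all block and value names are invented. Note the absence of any scalar_mask bitcast.

define <2 x i32> @scalarized_masked_load_sketch(ptr %p, <2 x i1> %mask, <2 x i32> %passthru) {
entry:
  ; Lane 0: guard with a direct lane read instead of an and/icmp bit test.
  %c0 = extractelement <2 x i1> %mask, i64 0
  br i1 %c0, label %cond.load, label %else

cond.load:
  %e0 = load i32, ptr %p, align 4
  %ins0 = insertelement <2 x i32> %passthru, i32 %e0, i64 0
  br label %else

else:
  %res0 = phi <2 x i32> [ %ins0, %cond.load ], [ %passthru, %entry ]
  ; Lane 1: same pattern, reading element 1 of the mask.
  %c1 = extractelement <2 x i1> %mask, i64 1
  br i1 %c1, label %cond.load1, label %merge

cond.load1:
  %p1 = getelementptr inbounds i32, ptr %p, i32 1
  %e1 = load i32, ptr %p1, align 4
  %ins1 = insertelement <2 x i32> %res0, i32 %e1, i64 1
  br label %merge

merge:
  %res1 = phi <2 x i32> [ %ins1, %cond.load1 ], [ %res0, %else ]
  ret <2 x i32> %res1
}

Running the equivalent llvm.masked.load call through opt with -passes=scalarize-masked-mem-intrin on a divergent triple should yield this form, whereas an X86-like target would still guard each lane with a bit test on an i2 scalar_mask.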