
Commit 25d976b

[ScalarizeMaskedMemIntr] Don't use a scalar mask on GPUs (#104842)
ScalarizeMaskedMemIntrin contains an optimization where the <N x i1> mask is bitcast into an iN and bit-tests with powers of two are then used to decide whether to load/store/... or not. However, on machines with branch divergence (mainly GPUs), this is a mis-optimization: each i1 in the mask will be stored in a condition register, and each of these "i1"s is likely to be a word or two wide, making the bit operations counterproductive. Therefore, amend this pass to skip the optimization on targets that it pessimizes.

Pre-commit tests: #104645
1 parent ecfceb8 commit 25d976b
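
For context, a minimal IR sketch of the two per-lane guards this pass can emit, assuming a hypothetical <4 x i32> masked load; only lane 0 is shown, the remaining lanes repeat the pattern, and the function and value names are illustrative rather than taken from the commit:

declare <4 x i32> @llvm.masked.load.v4i32.p0(ptr, i32 immarg, <4 x i1>, <4 x i32>)

define <4 x i32> @example(ptr %p, <4 x i1> %mask, <4 x i32> %passthru) {
  ; Input the pass scalarizes when the target has no legal masked load.
  %v = call <4 x i32> @llvm.masked.load.v4i32.p0(ptr %p, i32 4, <4 x i1> %mask, <4 x i32> %passthru)
  ret <4 x i32> %v
}

; Lane guard without branch divergence (bit-test on the bitcast mask):
;   %scalar_mask = bitcast <4 x i1> %mask to i4
;   %bit0  = and i4 %scalar_mask, 1
;   %cond0 = icmp ne i4 %bit0, 0
;   br i1 %cond0, label %cond.load, label %else
;
; Lane guard with branch divergence (what this patch switches to):
;   %cond0 = extractelement <4 x i1> %mask, i64 0
;   br i1 %cond0, label %cond.load, label %else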

File tree

5 files changed (+115, -109 lines)


llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp

Lines changed: 83 additions & 53 deletions
@@ -69,10 +69,11 @@ class ScalarizeMaskedMemIntrinLegacyPass : public FunctionPass {
 
 static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT,
                           const TargetTransformInfo &TTI, const DataLayout &DL,
-                          DomTreeUpdater *DTU);
+                          bool HasBranchDivergence, DomTreeUpdater *DTU);
 static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
                              const TargetTransformInfo &TTI,
-                             const DataLayout &DL, DomTreeUpdater *DTU);
+                             const DataLayout &DL, bool HasBranchDivergence,
+                             DomTreeUpdater *DTU);
 
 char ScalarizeMaskedMemIntrinLegacyPass::ID = 0;
 
@@ -141,8 +142,9 @@ static unsigned adjustForEndian(const DataLayout &DL, unsigned VectorWidth,
 // %10 = extractelement <16 x i1> %mask, i32 2
 // br i1 %10, label %cond.load4, label %else5
 //
-static void scalarizeMaskedLoad(const DataLayout &DL, CallInst *CI,
-                                DomTreeUpdater *DTU, bool &ModifiedDT) {
+static void scalarizeMaskedLoad(const DataLayout &DL, bool HasBranchDivergence,
+                                CallInst *CI, DomTreeUpdater *DTU,
+                                bool &ModifiedDT) {
   Value *Ptr = CI->getArgOperand(0);
   Value *Alignment = CI->getArgOperand(1);
   Value *Mask = CI->getArgOperand(2);
@@ -221,25 +223,26 @@ static void scalarizeMaskedLoad(const DataLayout &DL, CallInst *CI,
     return;
   }
   // If the mask is not v1i1, use scalar bit test operations. This generates
-  // better results on X86 at least.
-  // Note: this produces worse code on AMDGPU, where the "i1" is implicitly SIMD
-  // - what's a good way to detect this?
-  Value *SclrMask;
-  if (VectorWidth != 1) {
+  // better results on X86 at least. However, don't do this on GPUs and other
+  // machines with divergence, as there each i1 needs a vector register.
+  Value *SclrMask = nullptr;
+  if (VectorWidth != 1 && !HasBranchDivergence) {
     Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
     SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
   }
 
   for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
     // Fill the "else" block, created in the previous iteration
     //
-    // %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
-    // %mask_1 = and i16 %scalar_mask, i32 1 << Idx
-    // %cond = icmp ne i16 %mask_1, 0
-    // br i1 %mask_1, label %cond.load, label %else
+    // %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else,
+    // %else ] %mask_1 = and i16 %scalar_mask, i32 1 << Idx %cond = icmp ne i16
+    // %mask_1, 0 br i1 %mask_1, label %cond.load, label %else
     //
+    // On GPUs, use
+    // %cond = extrectelement %mask, Idx
+    // instead
     Value *Predicate;
-    if (VectorWidth != 1) {
+    if (SclrMask != nullptr) {
       Value *Mask = Builder.getInt(APInt::getOneBitSet(
           VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
       Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
@@ -312,8 +315,9 @@ static void scalarizeMaskedLoad(const DataLayout &DL, CallInst *CI,
 // store i32 %6, i32* %7
 // br label %else2
 // . . .
-static void scalarizeMaskedStore(const DataLayout &DL, CallInst *CI,
-                                 DomTreeUpdater *DTU, bool &ModifiedDT) {
+static void scalarizeMaskedStore(const DataLayout &DL, bool HasBranchDivergence,
+                                 CallInst *CI, DomTreeUpdater *DTU,
+                                 bool &ModifiedDT) {
   Value *Src = CI->getArgOperand(0);
   Value *Ptr = CI->getArgOperand(1);
   Value *Alignment = CI->getArgOperand(2);
@@ -378,10 +382,10 @@ static void scalarizeMaskedStore(const DataLayout &DL, CallInst *CI,
   }
 
   // If the mask is not v1i1, use scalar bit test operations. This generates
-  // better results on X86 at least.
-
-  Value *SclrMask;
-  if (VectorWidth != 1) {
+  // better results on X86 at least. However, don't do this on GPUs or other
+  // machines with branch divergence, as there each i1 takes up a register.
+  Value *SclrMask = nullptr;
+  if (VectorWidth != 1 && !HasBranchDivergence) {
     Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
     SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
   }
@@ -393,8 +397,11 @@ static void scalarizeMaskedStore(const DataLayout &DL, CallInst *CI,
     // %cond = icmp ne i16 %mask_1, 0
     // br i1 %mask_1, label %cond.store, label %else
     //
+    // On GPUs, use
+    // %cond = extrectelement %mask, Idx
+    // instead
     Value *Predicate;
-    if (VectorWidth != 1) {
+    if (SclrMask != nullptr) {
       Value *Mask = Builder.getInt(APInt::getOneBitSet(
           VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
       Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
@@ -461,7 +468,8 @@ static void scalarizeMaskedStore(const DataLayout &DL, CallInst *CI,
 // . . .
 // %Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src
 // ret <16 x i32> %Result
-static void scalarizeMaskedGather(const DataLayout &DL, CallInst *CI,
+static void scalarizeMaskedGather(const DataLayout &DL,
+                                  bool HasBranchDivergence, CallInst *CI,
                                   DomTreeUpdater *DTU, bool &ModifiedDT) {
   Value *Ptrs = CI->getArgOperand(0);
   Value *Alignment = CI->getArgOperand(1);
@@ -500,9 +508,10 @@ static void scalarizeMaskedGather(const DataLayout &DL, CallInst *CI,
   }
 
   // If the mask is not v1i1, use scalar bit test operations. This generates
-  // better results on X86 at least.
-  Value *SclrMask;
-  if (VectorWidth != 1) {
+  // better results on X86 at least. However, don't do this on GPUs or other
+  // machines with branch divergence, as there, each i1 takes up a register.
+  Value *SclrMask = nullptr;
+  if (VectorWidth != 1 && !HasBranchDivergence) {
     Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
     SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
   }
@@ -514,9 +523,12 @@ static void scalarizeMaskedGather(const DataLayout &DL, CallInst *CI,
     // %cond = icmp ne i16 %mask_1, 0
    // br i1 %Mask1, label %cond.load, label %else
     //
+    // On GPUs, use
+    // %cond = extrectelement %mask, Idx
+    // instead
 
     Value *Predicate;
-    if (VectorWidth != 1) {
+    if (SclrMask != nullptr) {
       Value *Mask = Builder.getInt(APInt::getOneBitSet(
           VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
       Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
@@ -591,7 +603,8 @@ static void scalarizeMaskedGather(const DataLayout &DL, CallInst *CI,
 // store i32 %Elt1, i32* %Ptr1, align 4
 // br label %else2
 // . . .
-static void scalarizeMaskedScatter(const DataLayout &DL, CallInst *CI,
+static void scalarizeMaskedScatter(const DataLayout &DL,
+                                   bool HasBranchDivergence, CallInst *CI,
                                    DomTreeUpdater *DTU, bool &ModifiedDT) {
   Value *Src = CI->getArgOperand(0);
   Value *Ptrs = CI->getArgOperand(1);
@@ -629,8 +642,8 @@ static void scalarizeMaskedScatter(const DataLayout &DL, CallInst *CI,
 
   // If the mask is not v1i1, use scalar bit test operations. This generates
   // better results on X86 at least.
-  Value *SclrMask;
-  if (VectorWidth != 1) {
+  Value *SclrMask = nullptr;
+  if (VectorWidth != 1 && !HasBranchDivergence) {
     Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
     SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
   }
@@ -642,8 +655,11 @@ static void scalarizeMaskedScatter(const DataLayout &DL, CallInst *CI,
     // %cond = icmp ne i16 %mask_1, 0
     // br i1 %Mask1, label %cond.store, label %else
     //
+    // On GPUs, use
+    // %cond = extrectelement %mask, Idx
+    // instead
     Value *Predicate;
-    if (VectorWidth != 1) {
+    if (SclrMask != nullptr) {
       Value *Mask = Builder.getInt(APInt::getOneBitSet(
           VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
       Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
@@ -681,7 +697,8 @@ static void scalarizeMaskedScatter(const DataLayout &DL, CallInst *CI,
   ModifiedDT = true;
 }
 
-static void scalarizeMaskedExpandLoad(const DataLayout &DL, CallInst *CI,
+static void scalarizeMaskedExpandLoad(const DataLayout &DL,
+                                      bool HasBranchDivergence, CallInst *CI,
                                       DomTreeUpdater *DTU, bool &ModifiedDT) {
   Value *Ptr = CI->getArgOperand(0);
   Value *Mask = CI->getArgOperand(1);
@@ -738,23 +755,27 @@ static void scalarizeMaskedExpandLoad(const DataLayout &DL, CallInst *CI,
   }
 
   // If the mask is not v1i1, use scalar bit test operations. This generates
-  // better results on X86 at least.
-  Value *SclrMask;
-  if (VectorWidth != 1) {
+  // better results on X86 at least. However, don't do this on GPUs or other
+  // machines with branch divergence, as there, each i1 takes up a register.
+  Value *SclrMask = nullptr;
+  if (VectorWidth != 1 && !HasBranchDivergence) {
     Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
     SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
   }
 
   for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
     // Fill the "else" block, created in the previous iteration
     //
-    // %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
-    // %mask_1 = extractelement <16 x i1> %mask, i32 Idx
-    // br i1 %mask_1, label %cond.load, label %else
+    // %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else,
+    // %else ] %mask_1 = extractelement <16 x i1> %mask, i32 Idx br i1 %mask_1,
+    // label %cond.load, label %else
     //
+    // On GPUs, use
+    // %cond = extrectelement %mask, Idx
+    // instead
 
     Value *Predicate;
-    if (VectorWidth != 1) {
+    if (SclrMask != nullptr) {
       Value *Mask = Builder.getInt(APInt::getOneBitSet(
           VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
       Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
@@ -813,7 +834,8 @@ static void scalarizeMaskedExpandLoad(const DataLayout &DL, CallInst *CI,
   ModifiedDT = true;
 }
 
-static void scalarizeMaskedCompressStore(const DataLayout &DL, CallInst *CI,
+static void scalarizeMaskedCompressStore(const DataLayout &DL,
+                                         bool HasBranchDivergence, CallInst *CI,
                                          DomTreeUpdater *DTU,
                                          bool &ModifiedDT) {
   Value *Src = CI->getArgOperand(0);
@@ -855,9 +877,10 @@ static void scalarizeMaskedCompressStore(const DataLayout &DL, CallInst *CI,
   }
 
   // If the mask is not v1i1, use scalar bit test operations. This generates
-  // better results on X86 at least.
-  Value *SclrMask;
-  if (VectorWidth != 1) {
+  // better results on X86 at least. However, don't do this on GPUs or other
+  // machines with branch divergence, as there, each i1 takes up a register.
+  Value *SclrMask = nullptr;
+  if (VectorWidth != 1 && !HasBranchDivergence) {
     Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
     SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
   }
@@ -868,8 +891,11 @@ static void scalarizeMaskedCompressStore(const DataLayout &DL, CallInst *CI,
     // %mask_1 = extractelement <16 x i1> %mask, i32 Idx
     // br i1 %mask_1, label %cond.store, label %else
     //
+    // On GPUs, use
+    // %cond = extrectelement %mask, Idx
+    // instead
     Value *Predicate;
-    if (VectorWidth != 1) {
+    if (SclrMask != nullptr) {
       Value *Mask = Builder.getInt(APInt::getOneBitSet(
           VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
       Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
@@ -993,12 +1019,13 @@ static bool runImpl(Function &F, const TargetTransformInfo &TTI,
   bool EverMadeChange = false;
   bool MadeChange = true;
   auto &DL = F.getDataLayout();
+  bool HasBranchDivergence = TTI.hasBranchDivergence(&F);
   while (MadeChange) {
     MadeChange = false;
     for (BasicBlock &BB : llvm::make_early_inc_range(F)) {
      bool ModifiedDTOnIteration = false;
       MadeChange |= optimizeBlock(BB, ModifiedDTOnIteration, TTI, DL,
-                                  DTU ? &*DTU : nullptr);
+                                  HasBranchDivergence, DTU ? &*DTU : nullptr);
 
       // Restart BB iteration if the dominator tree of the Function was changed
       if (ModifiedDTOnIteration)
@@ -1032,13 +1059,14 @@ ScalarizeMaskedMemIntrinPass::run(Function &F, FunctionAnalysisManager &AM) {
 
 static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT,
                           const TargetTransformInfo &TTI, const DataLayout &DL,
-                          DomTreeUpdater *DTU) {
+                          bool HasBranchDivergence, DomTreeUpdater *DTU) {
   bool MadeChange = false;
 
   BasicBlock::iterator CurInstIterator = BB.begin();
   while (CurInstIterator != BB.end()) {
     if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++))
-      MadeChange |= optimizeCallInst(CI, ModifiedDT, TTI, DL, DTU);
+      MadeChange |=
+          optimizeCallInst(CI, ModifiedDT, TTI, DL, HasBranchDivergence, DTU);
     if (ModifiedDT)
       return true;
   }
@@ -1048,7 +1076,8 @@ static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT,
 
 static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
                              const TargetTransformInfo &TTI,
-                             const DataLayout &DL, DomTreeUpdater *DTU) {
+                             const DataLayout &DL, bool HasBranchDivergence,
+                             DomTreeUpdater *DTU) {
   IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
   if (II) {
     // The scalarization code below does not work for scalable vectors.
@@ -1071,14 +1100,14 @@ static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
               CI->getType(),
               cast<ConstantInt>(CI->getArgOperand(1))->getAlignValue()))
         return false;
-      scalarizeMaskedLoad(DL, CI, DTU, ModifiedDT);
+      scalarizeMaskedLoad(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
       return true;
     case Intrinsic::masked_store:
       if (TTI.isLegalMaskedStore(
              CI->getArgOperand(0)->getType(),
              cast<ConstantInt>(CI->getArgOperand(2))->getAlignValue()))
        return false;
-      scalarizeMaskedStore(DL, CI, DTU, ModifiedDT);
+      scalarizeMaskedStore(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
       return true;
     case Intrinsic::masked_gather: {
       MaybeAlign MA =
@@ -1089,7 +1118,7 @@ static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
       if (TTI.isLegalMaskedGather(LoadTy, Alignment) &&
          !TTI.forceScalarizeMaskedGather(cast<VectorType>(LoadTy), Alignment))
        return false;
-      scalarizeMaskedGather(DL, CI, DTU, ModifiedDT);
+      scalarizeMaskedGather(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
       return true;
     }
     case Intrinsic::masked_scatter: {
@@ -1102,22 +1131,23 @@ static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
          !TTI.forceScalarizeMaskedScatter(cast<VectorType>(StoreTy),
                                           Alignment))
        return false;
-      scalarizeMaskedScatter(DL, CI, DTU, ModifiedDT);
+      scalarizeMaskedScatter(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
       return true;
    }
    case Intrinsic::masked_expandload:
      if (TTI.isLegalMaskedExpandLoad(
             CI->getType(),
             CI->getAttributes().getParamAttrs(0).getAlignment().valueOrOne()))
        return false;
-      scalarizeMaskedExpandLoad(DL, CI, DTU, ModifiedDT);
+      scalarizeMaskedExpandLoad(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
      return true;
    case Intrinsic::masked_compressstore:
      if (TTI.isLegalMaskedCompressStore(
             CI->getArgOperand(0)->getType(),
             CI->getAttributes().getParamAttrs(1).getAlignment().valueOrOne()))
        return false;
-      scalarizeMaskedCompressStore(DL, CI, DTU, ModifiedDT);
+      scalarizeMaskedCompressStore(DL, HasBranchDivergence, CI, DTU,
+                                   ModifiedDT);
      return true;
    }
  }
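
For reference, a rough sketch of the output shape this pass now produces on a branch-divergent target for a 2-lane masked store. The block and value names and the alignment handling are illustrative and may differ from the pass's exact output; the opt invocation assumes the pass's new-pass-manager name scalarize-masked-mem-intrin and a stand-alone input file input.ll:

; Reproduce with something like:
;   opt -S -mtriple=amdgcn-amd-amdhsa -passes=scalarize-masked-mem-intrin input.ll

define void @store2(<2 x i32> %val, ptr %p, <2 x i1> %mask) {
entry:
  ; Lane 0: the guard is read straight out of the mask vector; no iN bitcast
  ; or bit-test is emitted when the target reports branch divergence.
  %c0 = extractelement <2 x i1> %mask, i64 0
  br i1 %c0, label %cond.store, label %else

cond.store:
  %e0 = extractelement <2 x i32> %val, i64 0
  store i32 %e0, ptr %p, align 4
  br label %else

else:
  ; Lane 1: same pattern, next mask element and next pointer.
  %c1 = extractelement <2 x i1> %mask, i64 1
  br i1 %c1, label %cond.store1, label %else2

cond.store1:
  %e1 = extractelement <2 x i32> %val, i64 1
  %p1 = getelementptr inbounds i32, ptr %p, i64 1
  store i32 %e1, ptr %p1, align 4
  br label %else2

else2:
  ret void
}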
