[ScalarizeMaskedMemIntr] Don't use a scalar mask on GPUs #104842

Merged
merged 2 commits on Aug 23, 2024
136 changes: 83 additions & 53 deletions llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
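The gist of the change, as a standalone sketch: every scalarize* helper now receives a HasBranchDivergence flag and only takes the bitcast-to-scalar-mask shortcut when the target has no branch divergence. The helper below is illustrative only; its name, parameters, and the per-call bitcast are not part of the patch, and the real pass additionally remaps Idx for big-endian layouts via adjustForEndian().

#include "llvm/ADT/APInt.h"
#include "llvm/IR/IRBuilder.h"

using namespace llvm;

// Build the i1 predicate guarding the per-lane block for lane Idx of a
// <VectorWidth x i1> mask.
static Value *buildLanePredicate(IRBuilder<> &Builder, Value *Mask,
                                 unsigned VectorWidth, unsigned Idx,
                                 bool HasBranchDivergence) {
  if (VectorWidth != 1 && !HasBranchDivergence) {
    // CPU path (e.g. X86): bitcast <N x i1> to iN and test single bits.
    Value *SclrMask = Builder.CreateBitCast(
        Mask, Builder.getIntNTy(VectorWidth), "scalar_mask");
    Value *Bit = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
    return Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Bit),
                                Builder.getIntN(VectorWidth, 0));
  }
  // GPU / divergent path (and the v1i1 case): keep the mask in vector form
  // and extract the lane directly, so no scalar iN mask ends up living in a
  // per-lane register.
  return Builder.CreateExtractElement(Mask, Idx, "cond");
}

In the patched helpers the bitcast is created once before the lane loop, and a null SclrMask doubles as the "divergent or v1i1" marker, which is why the per-lane checks below change from if (VectorWidth != 1) to if (SclrMask != nullptr).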
@@ -69,10 +69,11 @@ class ScalarizeMaskedMemIntrinLegacyPass : public FunctionPass {

static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT,
const TargetTransformInfo &TTI, const DataLayout &DL,
DomTreeUpdater *DTU);
bool HasBranchDivergence, DomTreeUpdater *DTU);
static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
const TargetTransformInfo &TTI,
const DataLayout &DL, DomTreeUpdater *DTU);
const DataLayout &DL, bool HasBranchDivergence,
DomTreeUpdater *DTU);

char ScalarizeMaskedMemIntrinLegacyPass::ID = 0;

@@ -141,8 +142,9 @@ static unsigned adjustForEndian(const DataLayout &DL, unsigned VectorWidth,
// %10 = extractelement <16 x i1> %mask, i32 2
// br i1 %10, label %cond.load4, label %else5
//
static void scalarizeMaskedLoad(const DataLayout &DL, CallInst *CI,
DomTreeUpdater *DTU, bool &ModifiedDT) {
static void scalarizeMaskedLoad(const DataLayout &DL, bool HasBranchDivergence,
CallInst *CI, DomTreeUpdater *DTU,
bool &ModifiedDT) {
Value *Ptr = CI->getArgOperand(0);
Value *Alignment = CI->getArgOperand(1);
Value *Mask = CI->getArgOperand(2);
@@ -221,25 +223,26 @@ static void scalarizeMaskedLoad(const DataLayout &DL, CallInst *CI,
return;
}
// If the mask is not v1i1, use scalar bit test operations. This generates
// better results on X86 at least.
// Note: this produces worse code on AMDGPU, where the "i1" is implicitly SIMD
// - what's a good way to detect this?
Value *SclrMask;
if (VectorWidth != 1) {
// better results on X86 at least. However, don't do this on GPUs and other
// machines with divergence, as there each i1 needs a vector register.
Value *SclrMask = nullptr;
if (VectorWidth != 1 && !HasBranchDivergence) {
Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
}

for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
// Fill the "else" block, created in the previous iteration
//
// %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
// %mask_1 = and i16 %scalar_mask, i32 1 << Idx
// %cond = icmp ne i16 %mask_1, 0
// br i1 %mask_1, label %cond.load, label %else
// %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
// %mask_1 = and i16 %scalar_mask, i32 1 << Idx
// %cond = icmp ne i16 %mask_1, 0
// br i1 %mask_1, label %cond.load, label %else
//
// On GPUs, use
// %cond = extractelement %mask, Idx
// instead
Value *Predicate;
if (VectorWidth != 1) {
if (SclrMask != nullptr) {
Value *Mask = Builder.getInt(APInt::getOneBitSet(
VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
@@ -312,8 +315,9 @@ static void scalarizeMaskedLoad(const DataLayout &DL, CallInst *CI,
// store i32 %6, i32* %7
// br label %else2
// . . .
static void scalarizeMaskedStore(const DataLayout &DL, CallInst *CI,
DomTreeUpdater *DTU, bool &ModifiedDT) {
static void scalarizeMaskedStore(const DataLayout &DL, bool HasBranchDivergence,
CallInst *CI, DomTreeUpdater *DTU,
bool &ModifiedDT) {
Value *Src = CI->getArgOperand(0);
Value *Ptr = CI->getArgOperand(1);
Value *Alignment = CI->getArgOperand(2);
@@ -378,10 +382,10 @@ static void scalarizeMaskedStore(const DataLayout &DL, CallInst *CI,
}

// If the mask is not v1i1, use scalar bit test operations. This generates
// better results on X86 at least.

Value *SclrMask;
if (VectorWidth != 1) {
// better results on X86 at least. However, don't do this on GPUs or other
// machines with branch divergence, as there each i1 takes up a register.
Value *SclrMask = nullptr;
if (VectorWidth != 1 && !HasBranchDivergence) {
Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
}
@@ -393,8 +397,11 @@ static void scalarizeMaskedStore(const DataLayout &DL, CallInst *CI,
// %cond = icmp ne i16 %mask_1, 0
// br i1 %mask_1, label %cond.store, label %else
//
// On GPUs, use
// %cond = extractelement %mask, Idx
// instead
Value *Predicate;
if (VectorWidth != 1) {
if (SclrMask != nullptr) {
Value *Mask = Builder.getInt(APInt::getOneBitSet(
VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
@@ -461,7 +468,8 @@ static void scalarizeMaskedStore(const DataLayout &DL, CallInst *CI,
// . . .
// %Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src
// ret <16 x i32> %Result
static void scalarizeMaskedGather(const DataLayout &DL, CallInst *CI,
static void scalarizeMaskedGather(const DataLayout &DL,
bool HasBranchDivergence, CallInst *CI,
DomTreeUpdater *DTU, bool &ModifiedDT) {
Value *Ptrs = CI->getArgOperand(0);
Value *Alignment = CI->getArgOperand(1);
@@ -500,9 +508,10 @@ static void scalarizeMaskedGather(const DataLayout &DL, CallInst *CI,
}

// If the mask is not v1i1, use scalar bit test operations. This generates
// better results on X86 at least.
Value *SclrMask;
if (VectorWidth != 1) {
// better results on X86 at least. However, don't do this on GPUs or other
// machines with branch divergence, as there, each i1 takes up a register.
Value *SclrMask = nullptr;
if (VectorWidth != 1 && !HasBranchDivergence) {
Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
}
@@ -514,9 +523,12 @@ static void scalarizeMaskedGather(const DataLayout &DL, CallInst *CI,
// %cond = icmp ne i16 %mask_1, 0
// br i1 %Mask1, label %cond.load, label %else
//
// On GPUs, use
// %cond = extractelement %mask, Idx
// instead

Value *Predicate;
if (VectorWidth != 1) {
if (SclrMask != nullptr) {
Value *Mask = Builder.getInt(APInt::getOneBitSet(
VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
@@ -591,7 +603,8 @@ static void scalarizeMaskedGather(const DataLayout &DL, CallInst *CI,
// store i32 %Elt1, i32* %Ptr1, align 4
// br label %else2
// . . .
static void scalarizeMaskedScatter(const DataLayout &DL, CallInst *CI,
static void scalarizeMaskedScatter(const DataLayout &DL,
bool HasBranchDivergence, CallInst *CI,
DomTreeUpdater *DTU, bool &ModifiedDT) {
Value *Src = CI->getArgOperand(0);
Value *Ptrs = CI->getArgOperand(1);
@@ -629,8 +642,8 @@ static void scalarizeMaskedScatter(const DataLayout &DL, CallInst *CI,

// If the mask is not v1i1, use scalar bit test operations. This generates
// better results on X86 at least.
Value *SclrMask;
if (VectorWidth != 1) {
Value *SclrMask = nullptr;
if (VectorWidth != 1 && !HasBranchDivergence) {
Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
}
@@ -642,8 +655,11 @@ static void scalarizeMaskedScatter(const DataLayout &DL, CallInst *CI,
// %cond = icmp ne i16 %mask_1, 0
// br i1 %Mask1, label %cond.store, label %else
//
// On GPUs, use
// %cond = extractelement %mask, Idx
// instead
Value *Predicate;
if (VectorWidth != 1) {
if (SclrMask != nullptr) {
Value *Mask = Builder.getInt(APInt::getOneBitSet(
VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
@@ -681,7 +697,8 @@ static void scalarizeMaskedScatter(const DataLayout &DL, CallInst *CI,
ModifiedDT = true;
}

static void scalarizeMaskedExpandLoad(const DataLayout &DL, CallInst *CI,
static void scalarizeMaskedExpandLoad(const DataLayout &DL,
bool HasBranchDivergence, CallInst *CI,
DomTreeUpdater *DTU, bool &ModifiedDT) {
Value *Ptr = CI->getArgOperand(0);
Value *Mask = CI->getArgOperand(1);
@@ -738,23 +755,27 @@ static void scalarizeMaskedExpandLoad(const DataLayout &DL, CallInst *CI,
}

// If the mask is not v1i1, use scalar bit test operations. This generates
// better results on X86 at least.
Value *SclrMask;
if (VectorWidth != 1) {
// better results on X86 at least. However, don't do this on GPUs or other
// machines with branch divergence, as there, each i1 takes up a register.
Value *SclrMask = nullptr;
if (VectorWidth != 1 && !HasBranchDivergence) {
Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
}

for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
// Fill the "else" block, created in the previous iteration
//
// %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
// %mask_1 = extractelement <16 x i1> %mask, i32 Idx
// br i1 %mask_1, label %cond.load, label %else
// %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
// %mask_1 = extractelement <16 x i1> %mask, i32 Idx
// br i1 %mask_1, label %cond.load, label %else
//
// On GPUs, use
// %cond = extractelement %mask, Idx
// instead

Value *Predicate;
if (VectorWidth != 1) {
if (SclrMask != nullptr) {
Value *Mask = Builder.getInt(APInt::getOneBitSet(
VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
@@ -813,7 +834,8 @@ static void scalarizeMaskedExpandLoad(const DataLayout &DL, CallInst *CI,
ModifiedDT = true;
}

static void scalarizeMaskedCompressStore(const DataLayout &DL, CallInst *CI,
static void scalarizeMaskedCompressStore(const DataLayout &DL,
bool HasBranchDivergence, CallInst *CI,
DomTreeUpdater *DTU,
bool &ModifiedDT) {
Value *Src = CI->getArgOperand(0);
@@ -855,9 +877,10 @@ static void scalarizeMaskedCompressStore(const DataLayout &DL, CallInst *CI,
}

// If the mask is not v1i1, use scalar bit test operations. This generates
// better results on X86 at least.
Value *SclrMask;
if (VectorWidth != 1) {
// better results on X86 at least. However, don't do this on GPUs or other
// machines with branch divergence, as there, each i1 takes up a register.
Value *SclrMask = nullptr;
if (VectorWidth != 1 && !HasBranchDivergence) {
Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
}
@@ -868,8 +891,11 @@ static void scalarizeMaskedCompressStore(const DataLayout &DL, CallInst *CI,
// %mask_1 = extractelement <16 x i1> %mask, i32 Idx
// br i1 %mask_1, label %cond.store, label %else
//
// On GPUs, use
// %cond = extractelement %mask, Idx
// instead
Value *Predicate;
if (VectorWidth != 1) {
if (SclrMask != nullptr) {
Value *Mask = Builder.getInt(APInt::getOneBitSet(
VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
@@ -993,12 +1019,13 @@ static bool runImpl(Function &F, const TargetTransformInfo &TTI,
bool EverMadeChange = false;
bool MadeChange = true;
auto &DL = F.getDataLayout();
bool HasBranchDivergence = TTI.hasBranchDivergence(&F);
while (MadeChange) {
MadeChange = false;
for (BasicBlock &BB : llvm::make_early_inc_range(F)) {
bool ModifiedDTOnIteration = false;
MadeChange |= optimizeBlock(BB, ModifiedDTOnIteration, TTI, DL,
DTU ? &*DTU : nullptr);
HasBranchDivergence, DTU ? &*DTU : nullptr);

// Restart BB iteration if the dominator tree of the Function was changed
if (ModifiedDTOnIteration)
@@ -1032,13 +1059,14 @@ ScalarizeMaskedMemIntrinPass::run(Function &F, FunctionAnalysisManager &AM) {

static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT,
const TargetTransformInfo &TTI, const DataLayout &DL,
DomTreeUpdater *DTU) {
bool HasBranchDivergence, DomTreeUpdater *DTU) {
bool MadeChange = false;

BasicBlock::iterator CurInstIterator = BB.begin();
while (CurInstIterator != BB.end()) {
if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++))
MadeChange |= optimizeCallInst(CI, ModifiedDT, TTI, DL, DTU);
MadeChange |=
optimizeCallInst(CI, ModifiedDT, TTI, DL, HasBranchDivergence, DTU);
if (ModifiedDT)
return true;
}
@@ -1048,7 +1076,8 @@ static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT,

static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
const TargetTransformInfo &TTI,
const DataLayout &DL, DomTreeUpdater *DTU) {
const DataLayout &DL, bool HasBranchDivergence,
DomTreeUpdater *DTU) {
IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
if (II) {
// The scalarization code below does not work for scalable vectors.
@@ -1071,14 +1100,14 @@ static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
CI->getType(),
cast<ConstantInt>(CI->getArgOperand(1))->getAlignValue()))
return false;
scalarizeMaskedLoad(DL, CI, DTU, ModifiedDT);
scalarizeMaskedLoad(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
return true;
case Intrinsic::masked_store:
if (TTI.isLegalMaskedStore(
CI->getArgOperand(0)->getType(),
cast<ConstantInt>(CI->getArgOperand(2))->getAlignValue()))
return false;
scalarizeMaskedStore(DL, CI, DTU, ModifiedDT);
scalarizeMaskedStore(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
return true;
case Intrinsic::masked_gather: {
MaybeAlign MA =
@@ -1089,7 +1118,7 @@ static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
if (TTI.isLegalMaskedGather(LoadTy, Alignment) &&
!TTI.forceScalarizeMaskedGather(cast<VectorType>(LoadTy), Alignment))
return false;
scalarizeMaskedGather(DL, CI, DTU, ModifiedDT);
scalarizeMaskedGather(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
return true;
}
case Intrinsic::masked_scatter: {
@@ -1102,22 +1131,23 @@ static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
!TTI.forceScalarizeMaskedScatter(cast<VectorType>(StoreTy),
Alignment))
return false;
scalarizeMaskedScatter(DL, CI, DTU, ModifiedDT);
scalarizeMaskedScatter(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
return true;
}
case Intrinsic::masked_expandload:
if (TTI.isLegalMaskedExpandLoad(
CI->getType(),
CI->getAttributes().getParamAttrs(0).getAlignment().valueOrOne()))
return false;
scalarizeMaskedExpandLoad(DL, CI, DTU, ModifiedDT);
scalarizeMaskedExpandLoad(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
return true;
case Intrinsic::masked_compressstore:
if (TTI.isLegalMaskedCompressStore(
CI->getArgOperand(0)->getType(),
CI->getAttributes().getParamAttrs(1).getAlignment().valueOrOne()))
return false;
scalarizeMaskedCompressStore(DL, CI, DTU, ModifiedDT);
scalarizeMaskedCompressStore(DL, HasBranchDivergence, CI, DTU,
ModifiedDT);
return true;
}
}
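For reference, the scalarized control flow the pass emits implements the usual masked-load semantics lane by lane. A plain C++ rendering of what the emitted IR computes (a reference model, not code from this patch):

#include <array>
#include <cstddef>

// Masked load, lane by lane: masked-off lanes keep the pass-through value.
// On divergent targets the per-lane test corresponds to extracting lane Idx
// of the vector mask; on CPUs it is a bit test on the bitcast scalar mask.
// The result is the same either way.
template <typename T, std::size_t N>
std::array<T, N> maskedLoadRef(const T *Ptr, const std::array<bool, N> &Mask,
                               const std::array<T, N> &PassThru) {
  std::array<T, N> Result = PassThru;
  for (std::size_t Idx = 0; Idx < N; ++Idx)
    if (Mask[Idx])              // per-lane predicate
      Result[Idx] = Ptr[Idx];   // conditional element load
  return Result;
}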