Skip to content

Commit b70026c

Browse files
committed
[ScalarizeMaskedMemIntrin] Bitcast the mask to the scalar domain and use scalar bit tests for the branches.
X86 at least is able to use movmsk or kmov to move the mask to the scalar domain. Then we can just use test instructions to test individual bits. This is more efficient than extracting each mask element individually. I special cased v1i1 to use the previous behavior. This avoids poor type legalization of bitcast of v1i1 to i1. I've skipped expandload/compressstore as I think we need to handle constant masks for those better first. Many tests end up with duplicate test instructions due to tail duplication in the branch folding pass. But the same thing happens when constructing similar code in C. So its not unique to the scalarization. Not sure if this lowering code will also be good for other targets, but we're only testing X86 today. Differential Revision: https://reviews.llvm.org/D65319 llvm-svn: 367489
1 parent b51dc64 commit b70026c

14 files changed

+20932
-24392
lines changed

llvm/lib/CodeGen/ScalarizeMaskedMemIntrin.cpp

Lines changed: 72 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -173,15 +173,30 @@ static void scalarizeMaskedLoad(CallInst *CI, bool &ModifiedDT) {
173173
return;
174174
}
175175

176+
// If the mask is not v1i1, use scalar bit test operations. This generates
177+
// better results on X86 at least.
178+
Value *SclrMask;
179+
if (VectorWidth != 1) {
180+
Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
181+
SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
182+
}
183+
176184
for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
177185
// Fill the "else" block, created in the previous iteration
178186
//
179187
// %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
180-
// %mask_1 = extractelement <16 x i1> %mask, i32 Idx
188+
// %mask_1 = and i16 %scalar_mask, i32 1 << Idx
189+
// %cond = icmp ne i16 %mask_1, 0
181190
// br i1 %mask_1, label %cond.load, label %else
182191
//
183-
184-
Value *Predicate = Builder.CreateExtractElement(Mask, Idx);
192+
Value *Predicate;
193+
if (VectorWidth != 1) {
194+
Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
195+
Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
196+
Builder.getIntN(VectorWidth, 0));
197+
} else {
198+
Predicate = Builder.CreateExtractElement(Mask, Idx);
199+
}
185200

186201
// Create "cond" block
187202
//
@@ -290,13 +305,29 @@ static void scalarizeMaskedStore(CallInst *CI, bool &ModifiedDT) {
290305
return;
291306
}
292307

308+
// If the mask is not v1i1, use scalar bit test operations. This generates
309+
// better results on X86 at least.
310+
Value *SclrMask;
311+
if (VectorWidth != 1) {
312+
Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
313+
SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
314+
}
315+
293316
for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
294317
// Fill the "else" block, created in the previous iteration
295318
//
296-
// %mask_1 = extractelement <16 x i1> %mask, i32 Idx
319+
// %mask_1 = and i16 %scalar_mask, i32 1 << Idx
320+
// %cond = icmp ne i16 %mask_1, 0
297321
// br i1 %mask_1, label %cond.store, label %else
298322
//
299-
Value *Predicate = Builder.CreateExtractElement(Mask, Idx);
323+
Value *Predicate;
324+
if (VectorWidth != 1) {
325+
Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
326+
Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
327+
Builder.getIntN(VectorWidth, 0));
328+
} else {
329+
Predicate = Builder.CreateExtractElement(Mask, Idx);
330+
}
300331

301332
// Create "cond" block
302333
//
@@ -392,15 +423,30 @@ static void scalarizeMaskedGather(CallInst *CI, bool &ModifiedDT) {
392423
return;
393424
}
394425

426+
// If the mask is not v1i1, use scalar bit test operations. This generates
427+
// better results on X86 at least.
428+
Value *SclrMask;
429+
if (VectorWidth != 1) {
430+
Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
431+
SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
432+
}
433+
395434
for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
396435
// Fill the "else" block, created in the previous iteration
397436
//
398-
// %Mask1 = extractelement <16 x i1> %Mask, i32 1
437+
// %Mask1 = and i16 %scalar_mask, i32 1 << Idx
438+
// %cond = icmp ne i16 %mask_1, 0
399439
// br i1 %Mask1, label %cond.load, label %else
400440
//
401441

402-
Value *Predicate =
403-
Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
442+
Value *Predicate;
443+
if (VectorWidth != 1) {
444+
Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
445+
Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
446+
Builder.getIntN(VectorWidth, 0));
447+
} else {
448+
Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
449+
}
404450

405451
// Create "cond" block
406452
//
@@ -499,14 +545,29 @@ static void scalarizeMaskedScatter(CallInst *CI, bool &ModifiedDT) {
499545
return;
500546
}
501547

548+
// If the mask is not v1i1, use scalar bit test operations. This generates
549+
// better results on X86 at least.
550+
Value *SclrMask;
551+
if (VectorWidth != 1) {
552+
Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
553+
SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
554+
}
555+
502556
for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
503557
// Fill the "else" block, created in the previous iteration
504558
//
505-
// %Mask1 = extractelement <16 x i1> %Mask, i32 Idx
559+
// %Mask1 = and i16 %scalar_mask, i32 1 << Idx
560+
// %cond = icmp ne i16 %mask_1, 0
506561
// br i1 %Mask1, label %cond.store, label %else
507562
//
508-
Value *Predicate =
509-
Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
563+
Value *Predicate;
564+
if (VectorWidth != 1) {
565+
Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
566+
Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
567+
Builder.getIntN(VectorWidth, 0));
568+
} else {
569+
Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
570+
}
510571

511572
// Create "cond" block
512573
//

0 commit comments

Comments
 (0)