Skip to content

Commit 3f8711f

Browse files
dtcxzyw authored and tstellar committed
[InstCombine] Fix miscompilation in PR83947 (#83993)
https://github.com/llvm/llvm-project/blob/762f762504967efbe159db5c737154b989afc9bb/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp#L394-L407 Comment from @topperc: > This transform assumes the mask is a non-zero splat. We only know it's a splat and not provably all 0s. The mask is a constexpr that includes the address of the global variable. We can't resolve the constant expression to an exact value. Fixes #83947.
1 parent 9b9aee1 commit 3f8711f

File tree

5 files changed

+110
-6
lines changed

5 files changed

+110
-6
lines changed

llvm/include/llvm/Analysis/VectorUtils.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -406,6 +406,11 @@ bool maskIsAllZeroOrUndef(Value *Mask);
406406
/// lanes can be assumed active.
407407
bool maskIsAllOneOrUndef(Value *Mask);
408408

409+
/// Given a mask vector of i1, return true if any of the elements of this
410+
/// predicate mask are known to be true or undef. That is, return true if at
411+
/// least one lane can be assumed active.
412+
bool maskContainsAllOneOrUndef(Value *Mask);
413+
409414
/// Given a mask vector of the form <Y x i1>, return an APInt (of bitwidth Y)
410415
/// for each lane which may be active.
411416
APInt possiblyDemandedEltsInMask(Value *Mask);

llvm/lib/Analysis/VectorUtils.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1012,6 +1012,31 @@ bool llvm::maskIsAllOneOrUndef(Value *Mask) {
10121012
return true;
10131013
}
10141014

1015+
/// Return true if at least one lane of the i1 vector \p Mask is known to be
/// true or undef.
///
/// The check is conservative: it returns false whenever the answer cannot be
/// established from a constant mask. That covers a non-constant mask, a
/// scalable-vector constant that is not entirely all-ones/undef, and a
/// fixed-width constant in which no individual element is all-ones or undef
/// (for example, elements that are unresolved constant expressions).
///
/// \param Mask the predicate mask; asserted to be a vector of i1.
/// \returns true only when some lane is provably all-ones or undef.
bool llvm::maskContainsAllOneOrUndef(Value *Mask) {
1016+
assert(isa<VectorType>(Mask->getType()) &&
1017+
isa<IntegerType>(Mask->getType()->getScalarType()) &&
1018+
cast<IntegerType>(Mask->getType()->getScalarType())->getBitWidth() ==
1019+
1 &&
1020+
"Mask must be a vector of i1");
1021+
1022+
// Only constant masks can be inspected; anything else is "not known".
auto *ConstMask = dyn_cast<Constant>(Mask);
1023+
if (!ConstMask)
1024+
return false;
1025+
// Whole-vector fast path: the entire constant is all-ones or undef.
if (ConstMask->isAllOnesValue() || isa<UndefValue>(ConstMask))
1026+
return true;
1027+
// The per-element scan below needs a fixed element count, so give up on
// scalable vectors that did not match the whole-vector check above.
if (isa<ScalableVectorType>(ConstMask->getType()))
1028+
return false;
1029+
for (unsigned
1030+
I = 0,
1031+
E = cast<FixedVectorType>(ConstMask->getType())->getNumElements();
1032+
I != E; ++I) {
1033+
// getAggregateElement may yield null; such lanes are simply skipped.
if (auto *MaskElt = ConstMask->getAggregateElement(I))
1034+
if (MaskElt->isAllOnesValue() || isa<UndefValue>(MaskElt))
1035+
return true;
1036+
}
1037+
// No lane could be proven active.
return false;
1038+
}
1039+
10151040
/// TODO: This is a lot like known bits, but for
10161041
/// vectors. Is there something we can common this with?
10171042
APInt llvm::possiblyDemandedEltsInMask(Value *Mask) {

llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -412,11 +412,14 @@ Instruction *InstCombinerImpl::simplifyMaskedScatter(IntrinsicInst &II) {
412412
if (auto *SplatPtr = getSplatValue(II.getArgOperand(1))) {
413413
// scatter(splat(value), splat(ptr), non-zero-mask) -> store value, ptr
414414
if (auto *SplatValue = getSplatValue(II.getArgOperand(0))) {
415-
Align Alignment = cast<ConstantInt>(II.getArgOperand(2))->getAlignValue();
416-
StoreInst *S =
417-
new StoreInst(SplatValue, SplatPtr, /*IsVolatile=*/false, Alignment);
418-
S->copyMetadata(II);
419-
return S;
415+
if (maskContainsAllOneOrUndef(ConstMask)) {
416+
Align Alignment =
417+
cast<ConstantInt>(II.getArgOperand(2))->getAlignValue();
418+
StoreInst *S = new StoreInst(SplatValue, SplatPtr, /*IsVolatile=*/false,
419+
Alignment);
420+
S->copyMetadata(II);
421+
return S;
422+
}
420423
}
421424
// scatter(vector, splat(ptr), splat(true)) -> store extract(vector,
422425
// lastlane), ptr

llvm/test/Transforms/InstCombine/masked_intrinsics.ll

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,11 @@ entry:
292292
define void @scatter_nxv4i16_uniform_vals_uniform_ptrs_all_active_mask(ptr %dst, i16 %val) {
293293
; CHECK-LABEL: @scatter_nxv4i16_uniform_vals_uniform_ptrs_all_active_mask(
294294
; CHECK-NEXT: entry:
295-
; CHECK-NEXT: store i16 [[VAL:%.*]], ptr [[DST:%.*]], align 2
295+
; CHECK-NEXT: [[BROADCAST_SPLATINSERT:%.*]] = insertelement <vscale x 4 x ptr> poison, ptr [[DST:%.*]], i64 0
296+
; CHECK-NEXT: [[BROADCAST_SPLAT:%.*]] = shufflevector <vscale x 4 x ptr> [[BROADCAST_SPLATINSERT]], <vscale x 4 x ptr> poison, <vscale x 4 x i32> zeroinitializer
297+
; CHECK-NEXT: [[BROADCAST_VALUE:%.*]] = insertelement <vscale x 4 x i16> poison, i16 [[VAL:%.*]], i64 0
298+
; CHECK-NEXT: [[BROADCAST_SPLATVALUE:%.*]] = shufflevector <vscale x 4 x i16> [[BROADCAST_VALUE]], <vscale x 4 x i16> poison, <vscale x 4 x i32> zeroinitializer
299+
; CHECK-NEXT: call void @llvm.masked.scatter.nxv4i16.nxv4p0(<vscale x 4 x i16> [[BROADCAST_SPLATVALUE]], <vscale x 4 x ptr> [[BROADCAST_SPLAT]], i32 2, <vscale x 4 x i1> shufflevector (<vscale x 4 x i1> insertelement (<vscale x 4 x i1> zeroinitializer, i1 true, i32 0), <vscale x 4 x i1> zeroinitializer, <vscale x 4 x i32> zeroinitializer))
296300
; CHECK-NEXT: ret void
297301
;
298302
entry:
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
2+
; RUN: opt -S -passes=instcombine < %s | FileCheck %s
3+
4+
@c = global i32 0, align 4
5+
@b = global i32 0, align 4
6+
7+
define void @masked_scatter1() {
8+
; CHECK-LABEL: define void @masked_scatter1() {
9+
; CHECK-NEXT: call void @llvm.masked.scatter.nxv4i32.nxv4p0(<vscale x 4 x i32> zeroinitializer, <vscale x 4 x ptr> shufflevector (<vscale x 4 x ptr> insertelement (<vscale x 4 x ptr> poison, ptr @c, i64 0), <vscale x 4 x ptr> poison, <vscale x 4 x i32> zeroinitializer), i32 4, <vscale x 4 x i1> shufflevector (<vscale x 4 x i1> insertelement (<vscale x 4 x i1> poison, i1 icmp eq (ptr getelementptr inbounds (i32, ptr @b, i64 1), ptr @c), i64 0), <vscale x 4 x i1> poison, <vscale x 4 x i32> zeroinitializer))
10+
; CHECK-NEXT: ret void
11+
;
12+
call void @llvm.masked.scatter.nxv4i32.nxv4p0(<vscale x 4 x i32> zeroinitializer, <vscale x 4 x ptr> splat (ptr @c), i32 4, <vscale x 4 x i1> splat (i1 icmp eq (ptr getelementptr (i32, ptr @b, i64 1), ptr @c)))
13+
ret void
14+
}
15+
16+
define void @masked_scatter2() {
17+
; CHECK-LABEL: define void @masked_scatter2() {
18+
; CHECK-NEXT: store i32 0, ptr @c, align 4
19+
; CHECK-NEXT: ret void
20+
;
21+
call void @llvm.masked.scatter.v2i32.v2p0(<2 x i32> zeroinitializer, <2 x ptr> splat (ptr @c), i32 4, <2 x i1> splat (i1 true))
22+
ret void
23+
}
24+
25+
define void @masked_scatter3() {
26+
; CHECK-LABEL: define void @masked_scatter3() {
27+
; CHECK-NEXT: store i32 0, ptr @c, align 4
28+
; CHECK-NEXT: ret void
29+
;
30+
call void @llvm.masked.scatter.v2i32.v2p0(<2 x i32> zeroinitializer, <2 x ptr> splat (ptr @c), i32 4, <2 x i1> undef)
31+
ret void
32+
}
33+
34+
define void @masked_scatter4() {
35+
; CHECK-LABEL: define void @masked_scatter4() {
36+
; CHECK-NEXT: ret void
37+
;
38+
call void @llvm.masked.scatter.v2i32.v2p0(<2 x i32> zeroinitializer, <2 x ptr> splat (ptr @c), i32 4, <2 x i1> splat (i1 false))
39+
ret void
40+
}
41+
42+
define void @masked_scatter5() {
43+
; CHECK-LABEL: define void @masked_scatter5() {
44+
; CHECK-NEXT: store i32 0, ptr @c, align 4
45+
; CHECK-NEXT: ret void
46+
;
47+
call void @llvm.masked.scatter.v2i32.v2p0(<2 x i32> zeroinitializer, <2 x ptr> splat (ptr @c), i32 4, <2 x i1> <i1 true, i1 false>)
48+
ret void
49+
}
50+
51+
define void @masked_scatter6() {
52+
; CHECK-LABEL: define void @masked_scatter6() {
53+
; CHECK-NEXT: store i32 0, ptr @c, align 4
54+
; CHECK-NEXT: ret void
55+
;
56+
call void @llvm.masked.scatter.v2i32.v2p0(<2 x i32> zeroinitializer, <2 x ptr> splat (ptr @c), i32 4, <2 x i1> <i1 undef, i1 false>)
57+
ret void
58+
}
59+
60+
define void @masked_scatter7() {
61+
; CHECK-LABEL: define void @masked_scatter7() {
62+
; CHECK-NEXT: call void @llvm.masked.scatter.v2i32.v2p0(<2 x i32> zeroinitializer, <2 x ptr> <ptr @c, ptr @c>, i32 4, <2 x i1> <i1 icmp eq (ptr getelementptr inbounds (i32, ptr @b, i64 1), ptr @c), i1 icmp eq (ptr getelementptr inbounds (i32, ptr @b, i64 1), ptr @c)>)
63+
; CHECK-NEXT: ret void
64+
;
65+
call void @llvm.masked.scatter.v2i32.v2p0(<2 x i32> zeroinitializer, <2 x ptr> splat (ptr @c), i32 4, <2 x i1> splat (i1 icmp eq (ptr getelementptr (i32, ptr @b, i64 1), ptr @c)))
66+
ret void
67+
}

0 commit comments

Comments
 (0)