[InstCombine] Add combines/simplifications for llvm.ptrmask #67166

2 changes: 1 addition & 1 deletion clang/test/CodeGen/arm64_32-vaarg.c
@@ -29,7 +29,7 @@ long long test_longlong(OneLongLong input, va_list *mylist) {
// CHECK-LABEL: define{{.*}} i64 @test_longlong(i64 %input
// CHECK: [[STARTPTR:%.*]] = load ptr, ptr %mylist
// CHECK: [[ALIGN_TMP:%.+]] = getelementptr inbounds i8, ptr [[STARTPTR]], i32 7
// CHECK: [[ALIGNED_ADDR:%.+]] = tail call ptr @llvm.ptrmask.p0.i32(ptr nonnull [[ALIGN_TMP]], i32 -8)
// CHECK: [[ALIGNED_ADDR:%.+]] = tail call align 8 ptr @llvm.ptrmask.p0.i32(ptr nonnull [[ALIGN_TMP]], i32 -8)
// CHECK: [[NEXT:%.*]] = getelementptr inbounds i8, ptr [[ALIGNED_ADDR]], i32 8
// CHECK: store ptr [[NEXT]], ptr %mylist

38 changes: 38 additions & 0 deletions llvm/lib/Analysis/InstructionSimplify.cpp
@@ -6411,6 +6411,44 @@ static Value *simplifyBinaryIntrinsic(Function *F, Value *Op0, Value *Op1,
return Constant::getNullValue(ReturnType);
break;
}
case Intrinsic::ptrmask: {
if (isa<PoisonValue>(Op0) || isa<PoisonValue>(Op1))
return PoisonValue::get(Op0->getType());

// NOTE: We can't apply these simplifications based on the value of Op1
// because we need to preserve provenance.
if (Q.isUndefValue(Op0) || match(Op0, m_Zero()))
return Constant::getNullValue(Op0->getType());

assert(Op1->getType()->getScalarSizeInBits() ==
Q.DL.getIndexTypeSizeInBits(Op0->getType()) &&
"Invalid mask width");
// If the index width (mask size) is less than the pointer size, the mask
// is 1-extended.
if (match(Op1, m_PtrToInt(m_Specific(Op0))))
return Op0;

// NOTE: We may have attributes associated with the return value of the
// llvm.ptrmask intrinsic that will be lost when we just return the
// operand. We should try to preserve them.
if (match(Op1, m_AllOnes()) || Q.isUndefValue(Op1))
return Op0;

Constant *C;
if (match(Op1, m_ImmConstant(C))) {
KnownBits PtrKnown = computeKnownBits(Op0, /*Depth=*/0, Q);
// See if we are only masking off bits we know are already zero due to
// alignment.
APInt IrrelevantPtrBits =
PtrKnown.Zero.zextOrTrunc(C->getType()->getScalarSizeInBits());
C = ConstantFoldBinaryOpOperands(
Instruction::Or, C, ConstantInt::get(C->getType(), IrrelevantPtrBits),
Q.DL);
if (C != nullptr && C->isAllOnesValue())
return Op0;
}
break;
}
case Intrinsic::smax:
case Intrinsic::smin:
case Intrinsic::umax:
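As a quick illustration of the new InstructionSimplify folds, here is a minimal IR sketch (the function names are hypothetical, and the last fold assumes the align 8 parameter attribute is visible to computeKnownBits):

declare ptr @llvm.ptrmask.p0.i64(ptr, i64)

; A null (or undef) pointer operand folds to null. Note that nothing is
; folded based on the mask value alone, since the result must keep the
; pointer operand's provenance.
define ptr @fold_null_ptr(i64 %m) {
  %r = call ptr @llvm.ptrmask.p0.i64(ptr null, i64 %m)
  ret ptr %r                 ; simplifies to: ret ptr null
}

; An all-ones (or undef) mask is a no-op.
define ptr @fold_allones_mask(ptr %p) {
  %r = call ptr @llvm.ptrmask.p0.i64(ptr %p, i64 -1)
  ret ptr %r                 ; simplifies to: ret ptr %p
}

; Masking off only bits already known to be zero (here from align 8) is
; likewise a no-op.
define ptr @fold_known_aligned(ptr align 8 %p) {
  %r = call ptr @llvm.ptrmask.p0.i64(ptr %p, i64 -8)
  ret ptr %r                 ; simplifies to: ret ptr %p
}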
38 changes: 35 additions & 3 deletions llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -1962,17 +1962,49 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
break;
}
case Intrinsic::ptrmask: {
unsigned BitWidth = DL.getPointerTypeSizeInBits(II->getType());
KnownBits Known(BitWidth);
if (SimplifyDemandedInstructionBits(*II, Known))
return II;

Value *InnerPtr, *InnerMask;
bool Changed = false;
// Combine:
// (ptrmask (ptrmask p, A), B)
// -> (ptrmask p, (and A, B))
if (match(II->getArgOperand(0),
m_OneUse(m_Intrinsic<Intrinsic::ptrmask>(m_Value(InnerPtr),
m_Value(InnerMask))))) {
assert(II->getArgOperand(1)->getType() == InnerMask->getType() &&
"Mask types must match");
// TODO: If InnerMask == Op1, we could copy attributes from inner
// callsite -> outer callsite.
Value *NewMask = Builder.CreateAnd(II->getArgOperand(1), InnerMask);
replaceOperand(CI, 0, InnerPtr);
replaceOperand(CI, 1, NewMask);
Changed = true;
}

// See if we can deduce non-null.
if (!CI.hasRetAttr(Attribute::NonNull) &&
(Known.isNonZero() ||
isKnownNonZero(II, DL, /*Depth*/ 0, &AC, II, &DT))) {
CI.addRetAttr(Attribute::NonNull);
Changed = true;
}

unsigned NewAlignmentLog =
std::min(Value::MaxAlignmentExponent,
std::min(BitWidth - 1, Known.countMinTrailingZeros()));
// Known bits will capture whether we had alignment information associated
// with the pointer argument.
if (NewAlignmentLog > Log2(CI.getRetAlign().valueOrOne())) {
CI.addRetAttr(Attribute::getWithAlignment(
CI.getContext(), Align(uint64_t(1) << NewAlignmentLog)));
Changed = true;
}
if (Changed)
return &CI;
break;
}
case Intrinsic::uadd_with_overflow:
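In IR terms, the combine and the new attribute inference look roughly like this (a sketch with hypothetical function names; the nonnull inference additionally relies on known bits or isKnownNonZero proving the masked result cannot be null):

declare ptr @llvm.ptrmask.p0.i64(ptr, i64)

; (ptrmask (ptrmask p, A), B) -> (ptrmask p, (and A, B)), provided the inner
; call has a single use.
define ptr @fold_nested(ptr %p, i64 %a, i64 %b) {
  %inner = call ptr @llvm.ptrmask.p0.i64(ptr %p, i64 %a)
  %outer = call ptr @llvm.ptrmask.p0.i64(ptr %inner, i64 %b)
  ; becomes:
  ;   %mask  = and i64 %b, %a
  ;   %outer = call ptr @llvm.ptrmask.p0.i64(ptr %p, i64 %mask)
  ret ptr %outer
}

; Known bits of the result are used to annotate the call: with mask -8 the
; low three result bits are zero, so the call gains an align 8 return
; attribute.
define ptr @annotate_align(ptr %p) {
  %r = call ptr @llvm.ptrmask.p0.i64(ptr %p, i64 -8)
  ; becomes: %r = call align 8 ptr @llvm.ptrmask.p0.i64(ptr %p, i64 -8)
  ret ptr %r
}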
9 changes: 9 additions & 0 deletions llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -1933,6 +1933,15 @@ Instruction *InstCombinerImpl::visitPtrToInt(PtrToIntInst &CI) {
return CastInst::CreateIntegerCast(P, Ty, /*isSigned=*/false);
}

// (ptrtoint (ptrmask P, M))
// -> (and (ptrtoint P), M)
// This is generally beneficial as `and` is better supported than `ptrmask`.
Value *Ptr, *Mask;
if (match(SrcOp, m_OneUse(m_Intrinsic<Intrinsic::ptrmask>(m_Value(Ptr),
m_Value(Mask)))) &&
Mask->getType() == Ty)
return BinaryOperator::CreateAnd(Builder.CreatePtrToInt(Ptr, Ty), Mask);

if (auto *GEP = dyn_cast<GetElementPtrInst>(SrcOp)) {
// Fold ptrtoint(gep null, x) to multiply + constant if the GEP has one use.
// While this can increase the number of instructions it doesn't actually
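A sketch of the new cast fold (hypothetical function name). Once the value has been cast to an integer, provenance no longer constrains it, so the mask can become an ordinary and:

declare ptr @llvm.ptrmask.p0.i64(ptr, i64)

; (ptrtoint (ptrmask P, M)) -> (and (ptrtoint P), M), assuming the ptrmask
; has a single use and M's type matches the ptrtoint result type.
define i64 @fold_ptrtoint(ptr %p, i64 %m) {
  %masked = call ptr @llvm.ptrmask.p0.i64(ptr %p, i64 %m)
  %int = ptrtoint ptr %masked to i64
  ; becomes:
  ;   %pi  = ptrtoint ptr %p to i64
  ;   %int = and i64 %pi, %m
  ret i64 %int
}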
1 change: 1 addition & 0 deletions llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -544,6 +544,7 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
/// Tries to simplify operands to an integer instruction based on its
/// demanded bits.
bool SimplifyDemandedInstructionBits(Instruction &Inst);
bool SimplifyDemandedInstructionBits(Instruction &Inst, KnownBits &Known);

Value *SimplifyDemandedVectorElts(Value *V, APInt DemandedElts,
APInt &UndefElts, unsigned Depth = 0,
71 changes: 63 additions & 8 deletions llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
@@ -48,15 +48,20 @@ static bool ShrinkDemandedConstant(Instruction *I, unsigned OpNo,
return true;
}

/// Returns the bitwidth of the given scalar or pointer type. For vector types,
/// returns the element type's bitwidth.
static unsigned getBitWidth(Type *Ty, const DataLayout &DL) {
if (unsigned BitWidth = Ty->getScalarSizeInBits())
return BitWidth;

return DL.getPointerTypeSizeInBits(Ty);
}

/// Inst is an integer instruction that SimplifyDemandedBits knows about. See if
/// the instruction has any properties that allow us to simplify its operands.
bool InstCombinerImpl::SimplifyDemandedInstructionBits(Instruction &Inst) {
unsigned BitWidth = Inst.getType()->getScalarSizeInBits();
KnownBits Known(BitWidth);
APInt DemandedMask(APInt::getAllOnes(BitWidth));

bool InstCombinerImpl::SimplifyDemandedInstructionBits(Instruction &Inst,
KnownBits &Known) {
APInt DemandedMask(APInt::getAllOnes(Known.getBitWidth()));
Value *V = SimplifyDemandedUseBits(&Inst, DemandedMask, Known,
0, &Inst);
if (!V) return false;
@@ -65,6 +70,13 @@ bool InstCombinerImpl::SimplifyDemandedInstructionBits(Instruction &Inst) {
return true;
}

/// Inst is an integer instruction that SimplifyDemandedBits knows about. See if
/// the instruction has any properties that allow us to simplify its operands.
bool InstCombinerImpl::SimplifyDemandedInstructionBits(Instruction &Inst) {
KnownBits Known(getBitWidth(Inst.getType(), DL));
return SimplifyDemandedInstructionBits(Inst, Known);
}

/// This form of SimplifyDemandedBits simplifies the specified instruction
/// operand if possible, updating it in place. It returns true if it made any
/// change and false otherwise.
@@ -143,7 +155,6 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
return SimplifyMultipleUseDemandedBits(I, DemandedMask, Known, Depth, CxtI);

KnownBits LHSKnown(BitWidth), RHSKnown(BitWidth);

// If this is the root being simplified, allow it to have multiple uses,
// just set the DemandedMask to all bits so that we can try to simplify the
// operands. This allows visitTruncInst (for example) to simplify the
@@ -893,6 +904,48 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
}
break;
}
case Intrinsic::ptrmask: {
unsigned MaskWidth = I->getOperand(1)->getType()->getScalarSizeInBits();
RHSKnown = KnownBits(MaskWidth);
// If either the LHS or the RHS is zero, the result is zero.
if (SimplifyDemandedBits(I, 0, DemandedMask, LHSKnown, Depth + 1) ||
SimplifyDemandedBits(
I, 1, (DemandedMask & ~LHSKnown.Zero).zextOrTrunc(MaskWidth),
RHSKnown, Depth + 1))
return I;

// TODO: Should be 1-extend
RHSKnown = RHSKnown.anyextOrTrunc(BitWidth);
assert(!RHSKnown.hasConflict() && "Bits known to be one AND zero?");
assert(!LHSKnown.hasConflict() && "Bits known to be one AND zero?");

Known = LHSKnown & RHSKnown;
KnownBitsComputed = true;

// If the client is only demanding bits we know to be zero, return
// `llvm.ptrmask(p, 0)`. We can't return `null` here due to pointer
// provenance, but making the mask zero will be easily optimizable in
// the backend.
if (DemandedMask.isSubsetOf(Known.Zero) &&
!match(I->getOperand(1), m_Zero()))
return replaceOperand(
*I, 1, Constant::getNullValue(I->getOperand(1)->getType()));

// Mask in demanded space does nothing.
// NOTE: We may have attributes associated with the return value of the
// llvm.ptrmask intrinsic that will be lost when we just return the
// operand. We should try to preserve them.
if (DemandedMask.isSubsetOf(RHSKnown.One | LHSKnown.Zero))
return I->getOperand(0);

// If the RHS is a constant, see if we can simplify it.
if (ShrinkDemandedConstant(
I, 1, (DemandedMask & ~LHSKnown.Zero).zextOrTrunc(MaskWidth)))
return I;

break;
}

case Intrinsic::fshr:
case Intrinsic::fshl: {
const APInt *SA;
@@ -978,8 +1031,10 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
}

// If the client is only demanding bits that we know, return the known
// constant.
if (DemandedMask.isSubsetOf(Known.Zero|Known.One))
// constant. We can't directly simplify pointers as a constant because of
// pointer provenance.
// TODO: We could return `(inttoptr const)` for pointers.
if (!V->getType()->isPointerTy() &&
    DemandedMask.isSubsetOf(Known.Zero | Known.One))
return Constant::getIntegerValue(VTy, Known.One);
return nullptr;
}
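The demanded-bits path is what lets visitCallInst drop a ptrmask whose mask only clears bits the pointer is already known to have zero, as in this sketch (hypothetical name; this mirrors the align-addr.ll updates below):

declare ptr @llvm.ptrmask.p0.i64(ptr, i64)

; %p is known 8-aligned, so LHSKnown.Zero covers every bit the -8 mask would
; clear; the "mask in demanded space does nothing" check returns the pointer
; operand and the call disappears.
define <16 x i8> @drop_redundant_mask(ptr align 8 %p) {
  %aligned = call ptr @llvm.ptrmask.p0.i64(ptr %p, i64 -8)
  %v = load <16 x i8>, ptr %aligned, align 1
  ret <16 x i8> %v             ; the load now uses %p directly
}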
5 changes: 2 additions & 3 deletions llvm/test/Transforms/InferAddressSpaces/AMDGPU/ptrmask.ll
@@ -320,8 +320,7 @@ define i8 @ptrmask_cast_local_to_flat_const_mask_fffffffffffffffe(ptr addrspace(

define i8 @ptrmask_cast_local_to_flat_const_mask_ffffffffffffffff(ptr addrspace(3) %src.ptr) {
; CHECK-LABEL: @ptrmask_cast_local_to_flat_const_mask_ffffffffffffffff(
; CHECK-NEXT: [[TMP1:%.*]] = call ptr addrspace(3) @llvm.ptrmask.p3.i32(ptr addrspace(3) [[SRC_PTR:%.*]], i32 -1)
; CHECK-NEXT: [[LOAD:%.*]] = load i8, ptr addrspace(3) [[TMP1]], align 1
; CHECK-NEXT: [[LOAD:%.*]] = load i8, ptr addrspace(3) [[SRC_PTR:%.*]], align 1
; CHECK-NEXT: ret i8 [[LOAD]]
;
%cast = addrspacecast ptr addrspace(3) %src.ptr to ptr
@@ -333,7 +332,7 @@ define i8 @ptrmask_cast_local_to_flat_const_mask_ffffffffffffffff(ptr addrspace(
; Make sure non-constant masks can also be handled.
define i8 @ptrmask_cast_local_to_flat_load_range_mask(ptr addrspace(3) %src.ptr, ptr addrspace(1) %mask.ptr) {
; CHECK-LABEL: @ptrmask_cast_local_to_flat_load_range_mask(
; CHECK-NEXT: [[LOAD_MASK:%.*]] = load i64, ptr addrspace(1) [[MASK_PTR:%.*]], align 8, !range !0
; CHECK-NEXT: [[LOAD_MASK:%.*]] = load i64, ptr addrspace(1) [[MASK_PTR:%.*]], align 8, !range [[RNG0:![0-9]+]]
; CHECK-NEXT: [[TMP1:%.*]] = trunc i64 [[LOAD_MASK]] to i32
; CHECK-NEXT: [[TMP2:%.*]] = call ptr addrspace(3) @llvm.ptrmask.p3.i32(ptr addrspace(3) [[SRC_PTR:%.*]], i32 [[TMP1]])
; CHECK-NEXT: [[LOAD:%.*]] = load i8, ptr addrspace(3) [[TMP2]], align 1
16 changes: 6 additions & 10 deletions llvm/test/Transforms/InstCombine/align-addr.ll
@@ -135,7 +135,7 @@ define <16 x i8> @ptrmask_align_unknown_ptr_align1(ptr align 1 %ptr, i64 %mask)

define <16 x i8> @ptrmask_align_unknown_ptr_align8(ptr align 8 %ptr, i64 %mask) {
; CHECK-LABEL: @ptrmask_align_unknown_ptr_align8(
; CHECK-NEXT: [[ALIGNED:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[PTR:%.*]], i64 [[MASK:%.*]])
; CHECK-NEXT: [[ALIGNED:%.*]] = call align 8 ptr @llvm.ptrmask.p0.i64(ptr [[PTR:%.*]], i64 [[MASK:%.*]])
; CHECK-NEXT: [[LOAD:%.*]] = load <16 x i8>, ptr [[ALIGNED]], align 1
; CHECK-NEXT: ret <16 x i8> [[LOAD]]
;
@@ -147,7 +147,7 @@ define <16 x i8> @ptrmask_align_unknown_ptr_align8(ptr align 8 %ptr, i64 %mask)
; Increase load align from 1 to 2
define <16 x i8> @ptrmask_align2_ptr_align1(ptr align 1 %ptr) {
; CHECK-LABEL: @ptrmask_align2_ptr_align1(
; CHECK-NEXT: [[ALIGNED:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[PTR:%.*]], i64 -2)
; CHECK-NEXT: [[ALIGNED:%.*]] = call align 2 ptr @llvm.ptrmask.p0.i64(ptr [[PTR:%.*]], i64 -2)
; CHECK-NEXT: [[LOAD:%.*]] = load <16 x i8>, ptr [[ALIGNED]], align 1
; CHECK-NEXT: ret <16 x i8> [[LOAD]]
;
@@ -159,7 +159,7 @@ define <16 x i8> @ptrmask_align2_ptr_align1(ptr align 1 %ptr) {
; Increase load align from 1 to 4
define <16 x i8> @ptrmask_align4_ptr_align1(ptr align 1 %ptr) {
; CHECK-LABEL: @ptrmask_align4_ptr_align1(
; CHECK-NEXT: [[ALIGNED:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[PTR:%.*]], i64 -4)
; CHECK-NEXT: [[ALIGNED:%.*]] = call align 4 ptr @llvm.ptrmask.p0.i64(ptr [[PTR:%.*]], i64 -4)
; CHECK-NEXT: [[LOAD:%.*]] = load <16 x i8>, ptr [[ALIGNED]], align 1
; CHECK-NEXT: ret <16 x i8> [[LOAD]]
;
@@ -171,7 +171,7 @@ define <16 x i8> @ptrmask_align4_ptr_align1(ptr align 1 %ptr) {
; Increase load align from 1 to 8
define <16 x i8> @ptrmask_align8_ptr_align1(ptr align 1 %ptr) {
; CHECK-LABEL: @ptrmask_align8_ptr_align1(
; CHECK-NEXT: [[ALIGNED:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[PTR:%.*]], i64 -8)
; CHECK-NEXT: [[ALIGNED:%.*]] = call align 8 ptr @llvm.ptrmask.p0.i64(ptr [[PTR:%.*]], i64 -8)
; CHECK-NEXT: [[LOAD:%.*]] = load <16 x i8>, ptr [[ALIGNED]], align 1
; CHECK-NEXT: ret <16 x i8> [[LOAD]]
;
@@ -181,11 +181,9 @@ define <16 x i8> @ptrmask_align8_ptr_align1(ptr align 1 %ptr) {
}

; Underlying alignment already the same as forced alignment by ptrmask
; TODO: Should be able to drop the ptrmask
define <16 x i8> @ptrmask_align8_ptr_align8(ptr align 8 %ptr) {
; CHECK-LABEL: @ptrmask_align8_ptr_align8(
; CHECK-NEXT: [[ALIGNED:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[PTR:%.*]], i64 -8)
; CHECK-NEXT: [[LOAD:%.*]] = load <16 x i8>, ptr [[ALIGNED]], align 1
; CHECK-NEXT: [[LOAD:%.*]] = load <16 x i8>, ptr [[PTR:%.*]], align 1
; CHECK-NEXT: ret <16 x i8> [[LOAD]]
;
%aligned = call ptr @llvm.ptrmask.p0.i64(ptr %ptr, i64 -8)
@@ -194,11 +192,9 @@ define <16 x i8> @ptrmask_align8_ptr_align8(ptr align 8 %ptr) {
}

; Underlying alignment greater than alignment forced by ptrmask
; TODO: Should be able to drop the ptrmask
define <16 x i8> @ptrmask_align8_ptr_align16(ptr align 16 %ptr) {
; CHECK-LABEL: @ptrmask_align8_ptr_align16(
; CHECK-NEXT: [[ALIGNED:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[PTR:%.*]], i64 -8)
; CHECK-NEXT: [[LOAD:%.*]] = load <16 x i8>, ptr [[ALIGNED]], align 1
; CHECK-NEXT: [[LOAD:%.*]] = load <16 x i8>, ptr [[PTR:%.*]], align 1
; CHECK-NEXT: ret <16 x i8> [[LOAD]]
;
%aligned = call ptr @llvm.ptrmask.p0.i64(ptr %ptr, i64 -8)
27 changes: 26 additions & 1 deletion llvm/test/Transforms/InstCombine/consecutive-ptrmask.ll
@@ -4,7 +4,8 @@
target datalayout = "p1:64:64:64:32"

declare ptr @llvm.ptrmask.p0.i64(ptr, i64)
declare ptr addrspace(1) @llvm.ptrmask.p1.i32(ptr addrspace(1), i32)
declare <2 x ptr addrspace(1)> @llvm.ptrmask.v2p1.v2i32(<2 x ptr addrspace(1)>, <2 x i32>)
declare <2 x ptr> @llvm.ptrmask.v2p0.v2i64(<2 x ptr>, <2 x i64>)
declare void @use.ptr(ptr)

@@ -57,3 +58,27 @@ define ptr addrspace(1) @fold_2x_smaller_index_type(ptr addrspace(1) %p, i32 %m0
%p1 = call ptr addrspace(1) @llvm.ptrmask.p1.i32(ptr addrspace(1) %p0, i32 %m1)
ret ptr addrspace(1) %p1
}

define <2 x ptr> @fold_2x_vec_i64(<2 x ptr> %p, <2 x i64> %m0) {
; CHECK-LABEL: define <2 x ptr> @fold_2x_vec_i64
; CHECK-SAME: (<2 x ptr> [[P:%.*]], <2 x i64> [[M0:%.*]]) {
; CHECK-NEXT: [[TMP1:%.*]] = and <2 x i64> [[M0]], <i64 -2, i64 -2>
; CHECK-NEXT: [[P1:%.*]] = call align 2 <2 x ptr> @llvm.ptrmask.v2p0.v2i64(<2 x ptr> [[P]], <2 x i64> [[TMP1]])
; CHECK-NEXT: ret <2 x ptr> [[P1]]
;
%p0 = call <2 x ptr> @llvm.ptrmask.v2p0.v2i64(<2 x ptr> %p, <2 x i64> %m0)
%p1 = call <2 x ptr> @llvm.ptrmask.v2p0.v2i64(<2 x ptr> %p0, <2 x i64> <i64 -2, i64 -2>)
ret <2 x ptr> %p1
}

define <2 x ptr addrspace(1)> @fold_2x_vec_i32_undef(<2 x ptr addrspace(1)> %p, <2 x i32> %m0) {
; CHECK-LABEL: define <2 x ptr addrspace(1)> @fold_2x_vec_i32_undef
; CHECK-SAME: (<2 x ptr addrspace(1)> [[P:%.*]], <2 x i32> [[M0:%.*]]) {
; CHECK-NEXT: [[TMP1:%.*]] = and <2 x i32> [[M0]], <i32 -2, i32 undef>
; CHECK-NEXT: [[P1:%.*]] = call <2 x ptr addrspace(1)> @llvm.ptrmask.v2p1.v2i32(<2 x ptr addrspace(1)> [[P]], <2 x i32> [[TMP1]])
; CHECK-NEXT: ret <2 x ptr addrspace(1)> [[P1]]
;
%p0 = call <2 x ptr addrspace(1)> @llvm.ptrmask.v2p1.v2i32(<2 x ptr addrspace(1)> %p, <2 x i32> %m0)
%p1 = call <2 x ptr addrspace(1)> @llvm.ptrmask.v2p1.v2i32(<2 x ptr addrspace(1)> %p0, <2 x i32> <i32 -2, i32 undef>)
ret <2 x ptr addrspace(1)> %p1
}