Skip to content

[InstCombine] Widen Sel width after Cmp to generate Max/Min intrinsics. #118932

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 60 additions & 33 deletions llvm/lib/Analysis/ValueTracking.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8781,40 +8781,10 @@ static SelectPatternResult matchSelectPattern(CmpInst::Predicate Pred,
return matchFastFloatClamp(Pred, CmpLHS, CmpRHS, TrueVal, FalseVal, LHS, RHS);
}

/// Helps to match a select pattern in case of a type mismatch.
///
/// The function processes the case when type of true and false values of a
/// select instruction differs from type of the cmp instruction operands because
/// of a cast instruction. The function checks if it is legal to move the cast
/// operation after "select". If yes, it returns the new second value of
/// "select" (with the assumption that cast is moved):
/// 1. As operand of cast instruction when both values of "select" are same cast
/// instructions.
/// 2. As restored constant (by applying reverse cast operation) when the first
/// value of the "select" is a cast operation and the second value is a
/// constant.
/// NOTE: We return only the new second value because the first value could be
/// accessed as operand of cast instruction.
static Value *lookThroughCast(CmpInst *CmpI, Value *V1, Value *V2,
Instruction::CastOps *CastOp) {
auto *Cast1 = dyn_cast<CastInst>(V1);
if (!Cast1)
return nullptr;

*CastOp = Cast1->getOpcode();
Type *SrcTy = Cast1->getSrcTy();
if (auto *Cast2 = dyn_cast<CastInst>(V2)) {
// If V1 and V2 are both the same cast from the same type, look through V1.
if (*CastOp == Cast2->getOpcode() && SrcTy == Cast2->getSrcTy())
return Cast2->getOperand(0);
return nullptr;
}

auto *C = dyn_cast<Constant>(V2);
if (!C)
return nullptr;

static Value *lookThroughCastConst(CmpInst *CmpI, Type *SrcTy, Constant *C,
Instruction::CastOps *CastOp) {
const DataLayout &DL = CmpI->getDataLayout();

Constant *CastedTo = nullptr;
switch (*CastOp) {
case Instruction::ZExt:
Expand Down Expand Up @@ -8890,6 +8860,63 @@ static Value *lookThroughCast(CmpInst *CmpI, Value *V1, Value *V2,
return CastedTo;
}

/// Tries to see through a cast on the arms of a select whose arm type differs
/// from the type of the compare operands.
///
/// \p V1 must be a cast for anything to match; when it is, its opcode is
/// stored to \p CastOp. The function then decides whether that cast may be
/// sunk below the "select" and, if so, returns what the *other* select arm
/// (\p V2) becomes once the cast has been moved:
///  1. \p V2 is the same kind of cast from the same source type: the operand
///     of \p V2.
///  2. \p V2 is a constant: the constant with the reverse cast applied, as
///     computed by lookThroughCastConst().
///  3. \p V1 is a trunc, \p V2 is neither a cast nor a constant, and the
///     second operand of \p CmpI is a zext/sext of \p V2: that extended
///     compare operand.
/// Returns nullptr when none of these cases applies.
/// NOTE: Only the replacement for \p V2 is returned, because the replacement
/// for \p V1 is available as the operand of the cast instruction.
static Value *lookThroughCast(CmpInst *CmpI, Value *V1, Value *V2,
                              Instruction::CastOps *CastOp) {
  auto *FirstCast = dyn_cast<CastInst>(V1);
  if (FirstCast == nullptr)
    return nullptr;

  const Instruction::CastOps Opcode = FirstCast->getOpcode();
  *CastOp = Opcode;
  Type *FromTy = FirstCast->getSrcTy();

  // Both arms are casts: they must be the same cast from the same type for
  // V1 to be looked through.
  if (auto *SecondCast = dyn_cast<CastInst>(V2)) {
    const bool SameCast =
        SecondCast->getOpcode() == Opcode && SecondCast->getSrcTy() == FromTy;
    return SameCast ? SecondCast->getOperand(0) : nullptr;
  }

  // Cast on one arm, constant on the other.
  if (auto *ConstArm = dyn_cast<Constant>(V2))
    return lookThroughCastConst(CmpI, FromTy, ConstArm, CastOp);

  // What remains is a cast on one arm and a plain value on the other; only
  // trunc is handled.
  if (Opcode != Instruction::Trunc)
    return nullptr;

  Value *WideRHS = CmpI->getOperand(1);
  if (!match(WideRHS, m_ZExtOrSExt(m_Specific(V2))))
    return nullptr;

  // Here we have the following case:
  //   %y_ext = sext iK %y to iN
  //   %cond = cmp iN %x, %y_ext
  //   %tr = trunc iN %x to iK
  //   %narrowsel = select i1 %cond, iK %tr, iK %y
  //
  // We can always move trunc after select operation:
  //   %y_ext = sext iK %y to iN
  //   %cond = cmp iN %x, %y_ext
  //   %widesel = select i1 %cond, iN %x, iN %y_ext
  //   %tr = trunc iN %widesel to iK
  assert(V2->getType() == FirstCast->getType() &&
         "V2 and Cast1 should be the same type.");
  return WideRHS;
}
SelectPatternResult llvm::matchSelectPattern(Value *V, Value *&LHS, Value *&RHS,
Instruction::CastOps *CastOp,
unsigned Depth) {
Expand Down
56 changes: 56 additions & 0 deletions llvm/test/Transforms/InstCombine/minmax-fold.ll
Original file line number Diff line number Diff line change
Expand Up @@ -697,6 +697,34 @@ define zeroext i8 @look_through_cast2(i32 %x) {
ret i8 %res
}

; Scalar signed-min case: the compare is done in i32 on sext(%a) and %min,
; while the select picks between the narrow %a and trunc(%min). The CHECK
; lines expect a single i32 smin on the wide values followed by one trunc.
define i8 @look_through_cast_int_min(i8 %a, i32 %min) {
; CHECK-LABEL: @look_through_cast_int_min(
; CHECK-NEXT: [[A32:%.*]] = sext i8 [[A:%.*]] to i32
; CHECK-NEXT: [[SEL1:%.*]] = call i32 @llvm.smin.i32(i32 [[MIN:%.*]], i32 [[A32]])
; CHECK-NEXT: [[SEL:%.*]] = trunc i32 [[SEL1]] to i8
; CHECK-NEXT: ret i8 [[SEL]]
;
%a32 = sext i8 %a to i32
%cmp = icmp slt i32 %a32, %min
%min8 = trunc i32 %min to i8
%sel = select i1 %cmp, i8 %a, i8 %min8
ret i8 %sel
}

; Scalar signed-max case with the trunc on the true arm of the select and the
; wide value on the left of the compare. %a is zero-extended although the
; predicate is signed; the CHECK lines still expect a single i32 smax on the
; wide values followed by one trunc.
define i16 @look_through_cast_int_max(i16 %a, i32 %max) {
; CHECK-LABEL: @look_through_cast_int_max(
; CHECK-NEXT: [[A32:%.*]] = zext i16 [[A:%.*]] to i32
; CHECK-NEXT: [[SEL1:%.*]] = call i32 @llvm.smax.i32(i32 [[MAX:%.*]], i32 [[A32]])
; CHECK-NEXT: [[SEL:%.*]] = trunc i32 [[SEL1]] to i16
; CHECK-NEXT: ret i16 [[SEL]]
;
%a32 = zext i16 %a to i32
%cmp = icmp sgt i32 %max, %a32
%max8 = trunc i32 %max to i16
%sel = select i1 %cmp, i16 %max8, i16 %a
ret i16 %sel
}

define <2 x i8> @min_through_cast_vec1(<2 x i32> %x) {
; CHECK-LABEL: @min_through_cast_vec1(
; CHECK-NEXT: [[RES1:%.*]] = call <2 x i32> @llvm.smin.v2i32(<2 x i32> [[X:%.*]], <2 x i32> <i32 510, i32 511>)
Expand All @@ -721,6 +749,34 @@ define <2 x i8> @min_through_cast_vec2(<2 x i32> %x) {
ret <2 x i8> %res
}

; Vector unsigned-min case: the compare is done in <8 x i32> on sext(%a) and
; %min, while the select picks between the narrow %a and trunc(%min). %a is
; sign-extended although the predicate is unsigned; the CHECK lines expect a
; single <8 x i32> umin on the wide values followed by one trunc.
define <8 x i8> @look_through_cast_int_min_vec(<8 x i8> %a, <8 x i32> %min) {
; CHECK-LABEL: @look_through_cast_int_min_vec(
; CHECK-NEXT: [[A32:%.*]] = sext <8 x i8> [[A:%.*]] to <8 x i32>
; CHECK-NEXT: [[SEL1:%.*]] = call <8 x i32> @llvm.umin.v8i32(<8 x i32> [[MIN:%.*]], <8 x i32> [[A32]])
; CHECK-NEXT: [[SEL:%.*]] = trunc <8 x i32> [[SEL1]] to <8 x i8>
; CHECK-NEXT: ret <8 x i8> [[SEL]]
;
%a32 = sext <8 x i8> %a to <8 x i32>
%cmp = icmp ult <8 x i32> %a32, %min
%min8 = trunc <8 x i32> %min to <8 x i8>
%sel = select <8 x i1> %cmp, <8 x i8> %a, <8 x i8> %min8
ret <8 x i8> %sel
}

; Vector signed-max case widening from <8 x i32> to <8 x i64>: the compare
; uses zext(%a) and %max, while the select picks between the narrow %a and
; trunc(%max). The CHECK lines expect a single <8 x i64> smax on the wide
; values followed by one trunc.
define <8 x i32> @look_through_cast_int_max_vec(<8 x i32> %a, <8 x i64> %max) {
; CHECK-LABEL: @look_through_cast_int_max_vec(
; CHECK-NEXT: [[A32:%.*]] = zext <8 x i32> [[A:%.*]] to <8 x i64>
; CHECK-NEXT: [[SEL1:%.*]] = call <8 x i64> @llvm.smax.v8i64(<8 x i64> [[MAX:%.*]], <8 x i64> [[A32]])
; CHECK-NEXT: [[SEL:%.*]] = trunc <8 x i64> [[SEL1]] to <8 x i32>
; CHECK-NEXT: ret <8 x i32> [[SEL]]
;
%a32 = zext <8 x i32> %a to <8 x i64>
%cmp = icmp sgt <8 x i64> %a32, %max
%max8 = trunc <8 x i64> %max to <8 x i32>
%sel = select <8 x i1> %cmp, <8 x i32> %a, <8 x i32> %max8
ret <8 x i32> %sel
}

; Remove a min/max op in a sequence with a common operand.
; PR35717: https://bugs.llvm.org/show_bug.cgi?id=35717

Expand Down
Loading