Commit caa124b
[InstCombine] Zero-extend shift amounts in narrow funnel shift ops
When simplifying a wide rotate into a narrow funnel shift, the shift amount was unconditionally truncated; that is incorrect when the shift amount's type is narrower than the target bit width (a `trunc` to a wider type is not a valid cast). This is fixed by zero-extending `ShAmt` in that case.

Fixes: #71463.
Proof: https://alive2.llvm.org/ce/z/5draKz.
1 parent: 16ef496

2 files changed: +55 −25 lines changed

llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp

Lines changed: 14 additions & 5 deletions
@@ -502,11 +502,20 @@ Instruction *InstCombinerImpl::narrowFunnelShift(TruncInst &Trunc) {
   if (!MaskedValueIsZero(ShVal1, HiBitMask, 0, &Trunc))
     return nullptr;
 
-  // We have an unnecessarily wide rotate!
-  // trunc (or (shl ShVal0, ShAmt), (lshr ShVal1, BitWidth - ShAmt))
-  // Narrow the inputs and convert to funnel shift intrinsic:
-  // llvm.fshl.i8(trunc(ShVal), trunc(ShVal), trunc(ShAmt))
-  Value *NarrowShAmt = Builder.CreateTrunc(ShAmt, DestTy);
+  // Adjust the width of ShAmt for narrowed funnel shift operation:
+  // - Zero-extend if ShAmt is narrower than the destination type.
+  // - Truncate if ShAmt is wider, discarding non-significant high-order bits.
+  // This prepares ShAmt for llvm.fshl.i8(trunc(ShVal), trunc(ShVal),
+  // zext/trunc(ShAmt)).
+  Value *NarrowShAmt;
+  if (ShAmt->getType()->getScalarSizeInBits() < NarrowWidth) {
+    // If ShAmt is narrower than the destination type, zero-extend it.
+    NarrowShAmt = Builder.CreateZExt(ShAmt, DestTy, "shamt.zext");
+  } else {
+    // If ShAmt is wider than the destination type, truncate it.
+    NarrowShAmt = Builder.CreateTrunc(ShAmt, DestTy, "shamt.trunc");
+  }
+
   Value *X, *Y;
   X = Y = Builder.CreateTrunc(ShVal0, DestTy);
   if (ShVal0 != ShVal1)
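
For context, here is a minimal standalone sketch of the same width adjustment. It is illustrative only, not the committed code: the function name is invented, and it assumes LLVM's IRBuilder API with `DestTy` being the narrowed integer type.

#include "llvm/IR/IRBuilder.h"
using namespace llvm;

// Sketch: widen or narrow a shift amount to the narrowed rotate's width.
static Value *adjustShAmtWidth(IRBuilderBase &Builder, Value *ShAmt,
                               Type *DestTy) {
  unsigned NarrowWidth = DestTy->getScalarSizeInBits();
  if (ShAmt->getType()->getScalarSizeInBits() < NarrowWidth)
    // Zero-extension preserves the amount's value; a trunc to a wider
    // type would be an invalid cast and trip the cast assertion.
    return Builder.CreateZExt(ShAmt, DestTy);
  // Truncation only discards high-order bits that the rotate-pattern
  // match has already shown to be non-significant. IRBuilder returns the
  // value unchanged when the types already match, so the equal-width
  // case is a no-op.
  return Builder.CreateTrunc(ShAmt, DestTy);
}

The two branches are, in effect, what IRBuilder's existing `CreateZExtOrTrunc` helper does in a single call; the patch presumably spells out both directions so each can carry its own instruction name and comment.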

llvm/test/Transforms/InstCombine/rotate.ll

Lines changed: 41 additions & 20 deletions
@@ -421,8 +421,8 @@ define <2 x i16> @rotate_left_commute_16bit_vec(<2 x i16> %v, <2 x i32> %shift)
 
 define i8 @rotate_right_8bit(i8 %v, i3 %shift) {
 ; CHECK-LABEL: @rotate_right_8bit(
-; CHECK-NEXT:    [[TMP1:%.*]] = zext i3 [[SHIFT:%.*]] to i8
-; CHECK-NEXT:    [[CONV2:%.*]] = call i8 @llvm.fshr.i8(i8 [[V:%.*]], i8 [[V]], i8 [[TMP1]])
+; CHECK-NEXT:    [[SHAMT_TRUNC:%.*]] = zext i3 [[SHIFT:%.*]] to i8
+; CHECK-NEXT:    [[CONV2:%.*]] = call i8 @llvm.fshr.i8(i8 [[V:%.*]], i8 [[V]], i8 [[SHAMT_TRUNC]])
 ; CHECK-NEXT:    ret i8 [[CONV2]]
 ;
   %and = zext i3 %shift to i32
@@ -441,10 +441,10 @@ define i8 @rotate_right_8bit(i8 %v, i3 %shift) {
 define i8 @rotate_right_commute_8bit_unmasked_shl(i32 %v, i32 %shift) {
 ; CHECK-LABEL: @rotate_right_commute_8bit_unmasked_shl(
 ; CHECK-NEXT:    [[TMP1:%.*]] = trunc i32 [[SHIFT:%.*]] to i8
-; CHECK-NEXT:    [[TMP2:%.*]] = and i8 [[TMP1]], 3
-; CHECK-NEXT:    [[TMP3:%.*]] = trunc i32 [[V:%.*]] to i8
-; CHECK-NEXT:    [[TMP4:%.*]] = trunc i32 [[V]] to i8
-; CHECK-NEXT:    [[CONV2:%.*]] = call i8 @llvm.fshr.i8(i8 [[TMP3]], i8 [[TMP4]], i8 [[TMP2]])
+; CHECK-NEXT:    [[SHAMT_TRUNC:%.*]] = and i8 [[TMP1]], 3
+; CHECK-NEXT:    [[TMP2:%.*]] = trunc i32 [[V:%.*]] to i8
+; CHECK-NEXT:    [[TMP3:%.*]] = trunc i32 [[V]] to i8
+; CHECK-NEXT:    [[CONV2:%.*]] = call i8 @llvm.fshr.i8(i8 [[TMP2]], i8 [[TMP3]], i8 [[SHAMT_TRUNC]])
 ; CHECK-NEXT:    ret i8 [[CONV2]]
 ;
   %and = and i32 %shift, 3
@@ -462,10 +462,10 @@ define i8 @rotate_right_commute_8bit_unmasked_shl(i32 %v, i32 %shift) {
 define i8 @rotate_right_commute_8bit(i32 %v, i32 %shift) {
 ; CHECK-LABEL: @rotate_right_commute_8bit(
 ; CHECK-NEXT:    [[TMP1:%.*]] = trunc i32 [[SHIFT:%.*]] to i8
-; CHECK-NEXT:    [[TMP2:%.*]] = and i8 [[TMP1]], 3
-; CHECK-NEXT:    [[TMP3:%.*]] = trunc i32 [[V:%.*]] to i8
-; CHECK-NEXT:    [[TMP4:%.*]] = trunc i32 [[V]] to i8
-; CHECK-NEXT:    [[CONV2:%.*]] = call i8 @llvm.fshr.i8(i8 [[TMP3]], i8 [[TMP4]], i8 [[TMP2]])
+; CHECK-NEXT:    [[SHAMT_TRUNC:%.*]] = and i8 [[TMP1]], 3
+; CHECK-NEXT:    [[TMP2:%.*]] = trunc i32 [[V:%.*]] to i8
+; CHECK-NEXT:    [[TMP3:%.*]] = trunc i32 [[V]] to i8
+; CHECK-NEXT:    [[CONV2:%.*]] = call i8 @llvm.fshr.i8(i8 [[TMP2]], i8 [[TMP3]], i8 [[SHAMT_TRUNC]])
 ; CHECK-NEXT:    ret i8 [[CONV2]]
 ;
   %and = and i32 %shift, 3
@@ -483,8 +483,8 @@ define i8 @rotate_right_commute_8bit(i32 %v, i32 %shift) {
 
 define i8 @rotate8_not_safe(i8 %v, i32 %shamt) {
 ; CHECK-LABEL: @rotate8_not_safe(
-; CHECK-NEXT:    [[TMP1:%.*]] = trunc i32 [[SHAMT:%.*]] to i8
-; CHECK-NEXT:    [[RET:%.*]] = call i8 @llvm.fshl.i8(i8 [[V:%.*]], i8 [[V]], i8 [[TMP1]])
+; CHECK-NEXT:    [[SHAMT_TRUNC:%.*]] = trunc i32 [[SHAMT:%.*]] to i8
+; CHECK-NEXT:    [[RET:%.*]] = call i8 @llvm.fshl.i8(i8 [[V:%.*]], i8 [[V]], i8 [[SHAMT_TRUNC]])
 ; CHECK-NEXT:    ret i8 [[RET]]
 ;
   %conv = zext i8 %v to i32
@@ -597,8 +597,8 @@ define i8 @rotateright_8_neg_mask_commute(i8 %v, i8 %shamt) {
 
 define i16 @rotateright_16_neg_mask_wide_amount(i16 %v, i32 %shamt) {
 ; CHECK-LABEL: @rotateright_16_neg_mask_wide_amount(
-; CHECK-NEXT:    [[TMP1:%.*]] = trunc i32 [[SHAMT:%.*]] to i16
-; CHECK-NEXT:    [[RET:%.*]] = call i16 @llvm.fshr.i16(i16 [[V:%.*]], i16 [[V]], i16 [[TMP1]])
+; CHECK-NEXT:    [[SHAMT_TRUNC:%.*]] = trunc i32 [[SHAMT:%.*]] to i16
+; CHECK-NEXT:    [[RET:%.*]] = call i16 @llvm.fshr.i16(i16 [[V:%.*]], i16 [[V]], i16 [[SHAMT_TRUNC]])
 ; CHECK-NEXT:    ret i16 [[RET]]
 ;
   %neg = sub i32 0, %shamt
@@ -614,8 +614,8 @@ define i16 @rotateright_16_neg_mask_wide_amount(i16 %v, i32 %shamt) {
 
 define i16 @rotateright_16_neg_mask_wide_amount_commute(i16 %v, i32 %shamt) {
 ; CHECK-LABEL: @rotateright_16_neg_mask_wide_amount_commute(
-; CHECK-NEXT:    [[TMP1:%.*]] = trunc i32 [[SHAMT:%.*]] to i16
-; CHECK-NEXT:    [[RET:%.*]] = call i16 @llvm.fshr.i16(i16 [[V:%.*]], i16 [[V]], i16 [[TMP1]])
+; CHECK-NEXT:    [[SHAMT_TRUNC:%.*]] = trunc i32 [[SHAMT:%.*]] to i16
+; CHECK-NEXT:    [[RET:%.*]] = call i16 @llvm.fshr.i16(i16 [[V:%.*]], i16 [[V]], i16 [[SHAMT_TRUNC]])
 ; CHECK-NEXT:    ret i16 [[RET]]
 ;
   %neg = sub i32 0, %shamt
@@ -648,8 +648,8 @@ define i64 @rotateright_64_zext_neg_mask_amount(i64 %0, i32 %1) {
 
 define i8 @rotateleft_8_neg_mask_wide_amount(i8 %v, i32 %shamt) {
 ; CHECK-LABEL: @rotateleft_8_neg_mask_wide_amount(
-; CHECK-NEXT:    [[TMP1:%.*]] = trunc i32 [[SHAMT:%.*]] to i8
-; CHECK-NEXT:    [[RET:%.*]] = call i8 @llvm.fshl.i8(i8 [[V:%.*]], i8 [[V]], i8 [[TMP1]])
+; CHECK-NEXT:    [[SHAMT_TRUNC:%.*]] = trunc i32 [[SHAMT:%.*]] to i8
+; CHECK-NEXT:    [[RET:%.*]] = call i8 @llvm.fshl.i8(i8 [[V:%.*]], i8 [[V]], i8 [[SHAMT_TRUNC]])
 ; CHECK-NEXT:    ret i8 [[RET]]
 ;
   %neg = sub i32 0, %shamt
@@ -665,8 +665,8 @@ define i8 @rotateleft_8_neg_mask_wide_amount(i8 %v, i32 %shamt) {
 
 define i8 @rotateleft_8_neg_mask_wide_amount_commute(i8 %v, i32 %shamt) {
 ; CHECK-LABEL: @rotateleft_8_neg_mask_wide_amount_commute(
-; CHECK-NEXT:    [[TMP1:%.*]] = trunc i32 [[SHAMT:%.*]] to i8
-; CHECK-NEXT:    [[RET:%.*]] = call i8 @llvm.fshl.i8(i8 [[V:%.*]], i8 [[V]], i8 [[TMP1]])
+; CHECK-NEXT:    [[SHAMT_TRUNC:%.*]] = trunc i32 [[SHAMT:%.*]] to i8
+; CHECK-NEXT:    [[RET:%.*]] = call i8 @llvm.fshl.i8(i8 [[V:%.*]], i8 [[V]], i8 [[SHAMT_TRUNC]])
 ; CHECK-NEXT:    ret i8 [[RET]]
 ;
   %neg = sub i32 0, %shamt
@@ -957,3 +957,24 @@ define i8 @unmasked_shlop_unmasked_shift_amount(i32 %x, i32 %shamt) {
   %t8 = trunc i32 %t7 to i8
   ret i8 %t8
 }
+
+define i1 @check_rotate_masked_16bit(i8 %0, i32 %1) {
+; CHECK-LABEL: @check_rotate_masked_16bit(
+; CHECK-NEXT:    [[TMP3:%.*]] = and i32 [[TMP1:%.*]], 1
+; CHECK-NEXT:    [[TMP4:%.*]] = icmp eq i32 [[TMP3]], 0
+; CHECK-NEXT:    ret i1 [[TMP4]]
+;
+  %3 = and i32 %1, 1
+  %4 = and i8 %0, 15
+  %5 = zext i8 %4 to i32
+  %6 = lshr i32 %3, %5
+  %7 = sub i8 0, %0
+  %8 = and i8 %7, 15
+  %9 = zext i8 %8 to i32
+  %10 = shl nuw nsw i32 %3, %9
+  %11 = or i32 %6, %10
+  %12 = trunc i32 %11 to i16
+  %13 = sext i16 %12 to i64
+  %14 = icmp uge i64 0, %13
+  ret i1 %14
+}
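
A note on the new regression test: as the commit message describes, the shift amounts here are `i8` values that reach the wide rotate only through `zext` to i32, while the final `trunc` narrows the rotate to i16. Once the pattern match looks through those zexts, `ShAmt` is narrower (i8) than the i16 destination, which is exactly the case the old unconditional `CreateTrunc` mishandled and the new zero-extension path covers.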
