Skip to content

Commit eb0e7ac

Browse files
committed
[InstCombine] canEvaluateTruncated - use KnownBits to check for in-range shift amounts
Currently canEvaluateTruncated can only attempt to truncate shifts if they are scalar/uniform constant amounts that are in range. This patch replaces the constant extraction code with KnownBits handling, using KnownBits::getMaxValue to check that the amounts are in range. This enables support for non-uniform constant cases, and also variable shift amounts that have been masked somehow. Annoyingly, this still won't work for vectors with (demanded) undefs, as KnownBits returns nothing in those cases, but it's a definite improvement on what we currently have. Differential Revision: https://reviews.llvm.org/D83127
1 parent 53422e8 commit eb0e7ac

File tree

4 files changed

+77
-89
lines changed

4 files changed

+77
-89
lines changed

llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp

Lines changed: 28 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -377,29 +377,31 @@ static bool canEvaluateTruncated(Value *V, Type *Ty, InstCombiner &IC,
377377
break;
378378
}
379379
case Instruction::Shl: {
380-
// If we are truncating the result of this SHL, and if it's a shift of a
381-
// constant amount, we can always perform a SHL in a smaller type.
382-
const APInt *Amt;
383-
if (match(I->getOperand(1), m_APInt(Amt))) {
384-
uint32_t BitWidth = Ty->getScalarSizeInBits();
385-
if (Amt->getLimitedValue(BitWidth) < BitWidth)
386-
return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI);
387-
}
380+
// If we are truncating the result of this SHL, and if it's a shift of an
381+
// inrange amount, we can always perform a SHL in a smaller type.
382+
uint32_t BitWidth = Ty->getScalarSizeInBits();
383+
KnownBits AmtKnownBits =
384+
llvm::computeKnownBits(I->getOperand(1), IC.getDataLayout());
385+
if (AmtKnownBits.getMaxValue().ult(BitWidth))
386+
return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI) &&
387+
canEvaluateTruncated(I->getOperand(1), Ty, IC, CxtI);
388388
break;
389389
}
390390
case Instruction::LShr: {
391391
// If this is a truncate of a logical shr, we can truncate it to a smaller
392392
// lshr iff we know that the bits we would otherwise be shifting in are
393393
// already zeros.
394-
const APInt *Amt;
395-
if (match(I->getOperand(1), m_APInt(Amt))) {
396-
uint32_t OrigBitWidth = OrigTy->getScalarSizeInBits();
397-
uint32_t BitWidth = Ty->getScalarSizeInBits();
398-
if (Amt->getLimitedValue(BitWidth) < BitWidth &&
399-
IC.MaskedValueIsZero(I->getOperand(0),
400-
APInt::getBitsSetFrom(OrigBitWidth, BitWidth), 0, CxtI)) {
401-
return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI);
402-
}
394+
// TODO: It is enough to check that the bits we would be shifting in are
395+
// zero - use AmtKnownBits.getMaxValue().
396+
uint32_t OrigBitWidth = OrigTy->getScalarSizeInBits();
397+
uint32_t BitWidth = Ty->getScalarSizeInBits();
398+
KnownBits AmtKnownBits =
399+
llvm::computeKnownBits(I->getOperand(1), IC.getDataLayout());
400+
APInt ShiftedBits = APInt::getBitsSetFrom(OrigBitWidth, BitWidth);
401+
if (AmtKnownBits.getMaxValue().ult(BitWidth) &&
402+
IC.MaskedValueIsZero(I->getOperand(0), ShiftedBits, 0, CxtI)) {
403+
return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI) &&
404+
canEvaluateTruncated(I->getOperand(1), Ty, IC, CxtI);
403405
}
404406
break;
405407
}
@@ -409,15 +411,15 @@ static bool canEvaluateTruncated(Value *V, Type *Ty, InstCombiner &IC,
409411
// original type and the sign bit of the truncate type are similar.
410412
// TODO: It is enough to check that the bits we would be shifting in are
411413
// similar to sign bit of the truncate type.
412-
const APInt *Amt;
413-
if (match(I->getOperand(1), m_APInt(Amt))) {
414-
uint32_t OrigBitWidth = OrigTy->getScalarSizeInBits();
415-
uint32_t BitWidth = Ty->getScalarSizeInBits();
416-
if (Amt->getLimitedValue(BitWidth) < BitWidth &&
417-
OrigBitWidth - BitWidth <
418-
IC.ComputeNumSignBits(I->getOperand(0), 0, CxtI))
419-
return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI);
420-
}
414+
uint32_t OrigBitWidth = OrigTy->getScalarSizeInBits();
415+
uint32_t BitWidth = Ty->getScalarSizeInBits();
416+
KnownBits AmtKnownBits =
417+
llvm::computeKnownBits(I->getOperand(1), IC.getDataLayout());
418+
unsigned ShiftedBits = OrigBitWidth - BitWidth;
419+
if (AmtKnownBits.getMaxValue().ult(BitWidth) &&
420+
ShiftedBits < IC.ComputeNumSignBits(I->getOperand(0), 0, CxtI))
421+
return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI) &&
422+
canEvaluateTruncated(I->getOperand(1), Ty, IC, CxtI);
421423
break;
422424
}
423425
case Instruction::Trunc:

llvm/test/Transforms/InstCombine/2008-01-21-MulTrunc.ll

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,10 @@ define <2 x i16> @test1_vec(<2 x i16> %a) {
3535

3636
define <2 x i16> @test1_vec_nonuniform(<2 x i16> %a) {
3737
; CHECK-LABEL: @test1_vec_nonuniform(
38-
; CHECK-NEXT: [[B:%.*]] = zext <2 x i16> [[A:%.*]] to <2 x i32>
39-
; CHECK-NEXT: [[C:%.*]] = lshr <2 x i32> [[B]], <i32 8, i32 9>
40-
; CHECK-NEXT: [[D:%.*]] = mul nuw nsw <2 x i32> [[B]], <i32 5, i32 6>
41-
; CHECK-NEXT: [[E:%.*]] = or <2 x i32> [[C]], [[D]]
42-
; CHECK-NEXT: [[F:%.*]] = trunc <2 x i32> [[E]] to <2 x i16>
43-
; CHECK-NEXT: ret <2 x i16> [[F]]
38+
; CHECK-NEXT: [[C:%.*]] = lshr <2 x i16> [[A:%.*]], <i16 8, i16 9>
39+
; CHECK-NEXT: [[D:%.*]] = mul <2 x i16> [[A]], <i16 5, i16 6>
40+
; CHECK-NEXT: [[E:%.*]] = or <2 x i16> [[C]], [[D]]
41+
; CHECK-NEXT: ret <2 x i16> [[E]]
4442
;
4543
%b = zext <2 x i16> %a to <2 x i32>
4644
%c = lshr <2 x i32> %b, <i32 8, i32 9>

llvm/test/Transforms/InstCombine/cast.ll

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -502,12 +502,10 @@ define <2 x i16> @test40vec(<2 x i16> %a) {
502502

503503
define <2 x i16> @test40vec_nonuniform(<2 x i16> %a) {
504504
; ALL-LABEL: @test40vec_nonuniform(
505-
; ALL-NEXT: [[T:%.*]] = zext <2 x i16> [[A:%.*]] to <2 x i32>
506-
; ALL-NEXT: [[T21:%.*]] = lshr <2 x i32> [[T]], <i32 9, i32 10>
507-
; ALL-NEXT: [[T5:%.*]] = shl <2 x i32> [[T]], <i32 8, i32 9>
508-
; ALL-NEXT: [[T32:%.*]] = or <2 x i32> [[T21]], [[T5]]
509-
; ALL-NEXT: [[R:%.*]] = trunc <2 x i32> [[T32]] to <2 x i16>
510-
; ALL-NEXT: ret <2 x i16> [[R]]
505+
; ALL-NEXT: [[T21:%.*]] = lshr <2 x i16> [[A:%.*]], <i16 9, i16 10>
506+
; ALL-NEXT: [[T5:%.*]] = shl <2 x i16> [[A]], <i16 8, i16 9>
507+
; ALL-NEXT: [[T32:%.*]] = or <2 x i16> [[T21]], [[T5]]
508+
; ALL-NEXT: ret <2 x i16> [[T32]]
511509
;
512510
%t = zext <2 x i16> %a to <2 x i32>
513511
%t21 = lshr <2 x i32> %t, <i32 9, i32 10>

llvm/test/Transforms/InstCombine/trunc.ll

Lines changed: 41 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -286,12 +286,11 @@ define <2 x i64> @test8_vec(<2 x i32> %A, <2 x i32> %B) {
286286

287287
define <2 x i64> @test8_vec_nonuniform(<2 x i32> %A, <2 x i32> %B) {
288288
; CHECK-LABEL: @test8_vec_nonuniform(
289-
; CHECK-NEXT: [[C:%.*]] = zext <2 x i32> [[A:%.*]] to <2 x i128>
290-
; CHECK-NEXT: [[D:%.*]] = zext <2 x i32> [[B:%.*]] to <2 x i128>
291-
; CHECK-NEXT: [[E:%.*]] = shl <2 x i128> [[D]], <i128 32, i128 48>
292-
; CHECK-NEXT: [[F:%.*]] = or <2 x i128> [[E]], [[C]]
293-
; CHECK-NEXT: [[G:%.*]] = trunc <2 x i128> [[F]] to <2 x i64>
294-
; CHECK-NEXT: ret <2 x i64> [[G]]
289+
; CHECK-NEXT: [[C:%.*]] = zext <2 x i32> [[A:%.*]] to <2 x i64>
290+
; CHECK-NEXT: [[D:%.*]] = zext <2 x i32> [[B:%.*]] to <2 x i64>
291+
; CHECK-NEXT: [[E:%.*]] = shl <2 x i64> [[D]], <i64 32, i64 48>
292+
; CHECK-NEXT: [[F:%.*]] = or <2 x i64> [[E]], [[C]]
293+
; CHECK-NEXT: ret <2 x i64> [[F]]
295294
;
296295
%C = zext <2 x i32> %A to <2 x i128>
297296
%D = zext <2 x i32> %B to <2 x i128>
@@ -343,12 +342,11 @@ define i8 @test10(i32 %X) {
343342

344343
define i64 @test11(i32 %A, i32 %B) {
345344
; CHECK-LABEL: @test11(
346-
; CHECK-NEXT: [[C:%.*]] = zext i32 [[A:%.*]] to i128
345+
; CHECK-NEXT: [[C:%.*]] = zext i32 [[A:%.*]] to i64
347346
; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[B:%.*]], 31
348-
; CHECK-NEXT: [[E:%.*]] = zext i32 [[TMP1]] to i128
349-
; CHECK-NEXT: [[F:%.*]] = shl i128 [[C]], [[E]]
350-
; CHECK-NEXT: [[G:%.*]] = trunc i128 [[F]] to i64
351-
; CHECK-NEXT: ret i64 [[G]]
347+
; CHECK-NEXT: [[E:%.*]] = zext i32 [[TMP1]] to i64
348+
; CHECK-NEXT: [[F:%.*]] = shl i64 [[C]], [[E]]
349+
; CHECK-NEXT: ret i64 [[F]]
352350
;
353351
%C = zext i32 %A to i128
354352
%D = zext i32 %B to i128
@@ -360,12 +358,11 @@ define i64 @test11(i32 %A, i32 %B) {
360358

361359
define <2 x i64> @test11_vec(<2 x i32> %A, <2 x i32> %B) {
362360
; CHECK-LABEL: @test11_vec(
363-
; CHECK-NEXT: [[C:%.*]] = zext <2 x i32> [[A:%.*]] to <2 x i128>
361+
; CHECK-NEXT: [[C:%.*]] = zext <2 x i32> [[A:%.*]] to <2 x i64>
364362
; CHECK-NEXT: [[TMP1:%.*]] = and <2 x i32> [[B:%.*]], <i32 31, i32 31>
365-
; CHECK-NEXT: [[E:%.*]] = zext <2 x i32> [[TMP1]] to <2 x i128>
366-
; CHECK-NEXT: [[F:%.*]] = shl <2 x i128> [[C]], [[E]]
367-
; CHECK-NEXT: [[G:%.*]] = trunc <2 x i128> [[F]] to <2 x i64>
368-
; CHECK-NEXT: ret <2 x i64> [[G]]
363+
; CHECK-NEXT: [[E:%.*]] = zext <2 x i32> [[TMP1]] to <2 x i64>
364+
; CHECK-NEXT: [[F:%.*]] = shl <2 x i64> [[C]], [[E]]
365+
; CHECK-NEXT: ret <2 x i64> [[F]]
369366
;
370367
%C = zext <2 x i32> %A to <2 x i128>
371368
%D = zext <2 x i32> %B to <2 x i128>
@@ -377,12 +374,11 @@ define <2 x i64> @test11_vec(<2 x i32> %A, <2 x i32> %B) {
377374

378375
define <2 x i64> @test11_vec_nonuniform(<2 x i32> %A, <2 x i32> %B) {
379376
; CHECK-LABEL: @test11_vec_nonuniform(
380-
; CHECK-NEXT: [[C:%.*]] = zext <2 x i32> [[A:%.*]] to <2 x i128>
377+
; CHECK-NEXT: [[C:%.*]] = zext <2 x i32> [[A:%.*]] to <2 x i64>
381378
; CHECK-NEXT: [[TMP1:%.*]] = and <2 x i32> [[B:%.*]], <i32 31, i32 15>
382-
; CHECK-NEXT: [[E:%.*]] = zext <2 x i32> [[TMP1]] to <2 x i128>
383-
; CHECK-NEXT: [[F:%.*]] = shl <2 x i128> [[C]], [[E]]
384-
; CHECK-NEXT: [[G:%.*]] = trunc <2 x i128> [[F]] to <2 x i64>
385-
; CHECK-NEXT: ret <2 x i64> [[G]]
379+
; CHECK-NEXT: [[E:%.*]] = zext <2 x i32> [[TMP1]] to <2 x i64>
380+
; CHECK-NEXT: [[F:%.*]] = shl <2 x i64> [[C]], [[E]]
381+
; CHECK-NEXT: ret <2 x i64> [[F]]
386382
;
387383
%C = zext <2 x i32> %A to <2 x i128>
388384
%D = zext <2 x i32> %B to <2 x i128>
@@ -411,12 +407,11 @@ define <2 x i64> @test11_vec_undef(<2 x i32> %A, <2 x i32> %B) {
411407

412408
define i64 @test12(i32 %A, i32 %B) {
413409
; CHECK-LABEL: @test12(
414-
; CHECK-NEXT: [[C:%.*]] = zext i32 [[A:%.*]] to i128
410+
; CHECK-NEXT: [[C:%.*]] = zext i32 [[A:%.*]] to i64
415411
; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[B:%.*]], 31
416-
; CHECK-NEXT: [[E:%.*]] = zext i32 [[TMP1]] to i128
417-
; CHECK-NEXT: [[F:%.*]] = lshr i128 [[C]], [[E]]
418-
; CHECK-NEXT: [[G:%.*]] = trunc i128 [[F]] to i64
419-
; CHECK-NEXT: ret i64 [[G]]
412+
; CHECK-NEXT: [[E:%.*]] = zext i32 [[TMP1]] to i64
413+
; CHECK-NEXT: [[F:%.*]] = lshr i64 [[C]], [[E]]
414+
; CHECK-NEXT: ret i64 [[F]]
420415
;
421416
%C = zext i32 %A to i128
422417
%D = zext i32 %B to i128
@@ -428,12 +423,11 @@ define i64 @test12(i32 %A, i32 %B) {
428423

429424
define <2 x i64> @test12_vec(<2 x i32> %A, <2 x i32> %B) {
430425
; CHECK-LABEL: @test12_vec(
431-
; CHECK-NEXT: [[C:%.*]] = zext <2 x i32> [[A:%.*]] to <2 x i128>
426+
; CHECK-NEXT: [[C:%.*]] = zext <2 x i32> [[A:%.*]] to <2 x i64>
432427
; CHECK-NEXT: [[TMP1:%.*]] = and <2 x i32> [[B:%.*]], <i32 31, i32 31>
433-
; CHECK-NEXT: [[E:%.*]] = zext <2 x i32> [[TMP1]] to <2 x i128>
434-
; CHECK-NEXT: [[F:%.*]] = lshr <2 x i128> [[C]], [[E]]
435-
; CHECK-NEXT: [[G:%.*]] = trunc <2 x i128> [[F]] to <2 x i64>
436-
; CHECK-NEXT: ret <2 x i64> [[G]]
428+
; CHECK-NEXT: [[E:%.*]] = zext <2 x i32> [[TMP1]] to <2 x i64>
429+
; CHECK-NEXT: [[F:%.*]] = lshr <2 x i64> [[C]], [[E]]
430+
; CHECK-NEXT: ret <2 x i64> [[F]]
437431
;
438432
%C = zext <2 x i32> %A to <2 x i128>
439433
%D = zext <2 x i32> %B to <2 x i128>
@@ -445,12 +439,11 @@ define <2 x i64> @test12_vec(<2 x i32> %A, <2 x i32> %B) {
445439

446440
define <2 x i64> @test12_vec_nonuniform(<2 x i32> %A, <2 x i32> %B) {
447441
; CHECK-LABEL: @test12_vec_nonuniform(
448-
; CHECK-NEXT: [[C:%.*]] = zext <2 x i32> [[A:%.*]] to <2 x i128>
442+
; CHECK-NEXT: [[C:%.*]] = zext <2 x i32> [[A:%.*]] to <2 x i64>
449443
; CHECK-NEXT: [[TMP1:%.*]] = and <2 x i32> [[B:%.*]], <i32 31, i32 15>
450-
; CHECK-NEXT: [[E:%.*]] = zext <2 x i32> [[TMP1]] to <2 x i128>
451-
; CHECK-NEXT: [[F:%.*]] = lshr <2 x i128> [[C]], [[E]]
452-
; CHECK-NEXT: [[G:%.*]] = trunc <2 x i128> [[F]] to <2 x i64>
453-
; CHECK-NEXT: ret <2 x i64> [[G]]
444+
; CHECK-NEXT: [[E:%.*]] = zext <2 x i32> [[TMP1]] to <2 x i64>
445+
; CHECK-NEXT: [[F:%.*]] = lshr <2 x i64> [[C]], [[E]]
446+
; CHECK-NEXT: ret <2 x i64> [[F]]
454447
;
455448
%C = zext <2 x i32> %A to <2 x i128>
456449
%D = zext <2 x i32> %B to <2 x i128>
@@ -479,12 +472,11 @@ define <2 x i64> @test12_vec_undef(<2 x i32> %A, <2 x i32> %B) {
479472

480473
define i64 @test13(i32 %A, i32 %B) {
481474
; CHECK-LABEL: @test13(
482-
; CHECK-NEXT: [[C:%.*]] = sext i32 [[A:%.*]] to i128
475+
; CHECK-NEXT: [[C:%.*]] = sext i32 [[A:%.*]] to i64
483476
; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[B:%.*]], 31
484-
; CHECK-NEXT: [[E:%.*]] = zext i32 [[TMP1]] to i128
485-
; CHECK-NEXT: [[F:%.*]] = ashr i128 [[C]], [[E]]
486-
; CHECK-NEXT: [[G:%.*]] = trunc i128 [[F]] to i64
487-
; CHECK-NEXT: ret i64 [[G]]
477+
; CHECK-NEXT: [[E:%.*]] = zext i32 [[TMP1]] to i64
478+
; CHECK-NEXT: [[F:%.*]] = ashr i64 [[C]], [[E]]
479+
; CHECK-NEXT: ret i64 [[F]]
488480
;
489481
%C = sext i32 %A to i128
490482
%D = zext i32 %B to i128
@@ -496,12 +488,11 @@ define i64 @test13(i32 %A, i32 %B) {
496488

497489
define <2 x i64> @test13_vec(<2 x i32> %A, <2 x i32> %B) {
498490
; CHECK-LABEL: @test13_vec(
499-
; CHECK-NEXT: [[C:%.*]] = sext <2 x i32> [[A:%.*]] to <2 x i128>
491+
; CHECK-NEXT: [[C:%.*]] = sext <2 x i32> [[A:%.*]] to <2 x i64>
500492
; CHECK-NEXT: [[TMP1:%.*]] = and <2 x i32> [[B:%.*]], <i32 31, i32 31>
501-
; CHECK-NEXT: [[E:%.*]] = zext <2 x i32> [[TMP1]] to <2 x i128>
502-
; CHECK-NEXT: [[F:%.*]] = ashr <2 x i128> [[C]], [[E]]
503-
; CHECK-NEXT: [[G:%.*]] = trunc <2 x i128> [[F]] to <2 x i64>
504-
; CHECK-NEXT: ret <2 x i64> [[G]]
493+
; CHECK-NEXT: [[E:%.*]] = zext <2 x i32> [[TMP1]] to <2 x i64>
494+
; CHECK-NEXT: [[F:%.*]] = ashr <2 x i64> [[C]], [[E]]
495+
; CHECK-NEXT: ret <2 x i64> [[F]]
505496
;
506497
%C = sext <2 x i32> %A to <2 x i128>
507498
%D = zext <2 x i32> %B to <2 x i128>
@@ -513,12 +504,11 @@ define <2 x i64> @test13_vec(<2 x i32> %A, <2 x i32> %B) {
513504

514505
define <2 x i64> @test13_vec_nonuniform(<2 x i32> %A, <2 x i32> %B) {
515506
; CHECK-LABEL: @test13_vec_nonuniform(
516-
; CHECK-NEXT: [[C:%.*]] = sext <2 x i32> [[A:%.*]] to <2 x i128>
507+
; CHECK-NEXT: [[C:%.*]] = sext <2 x i32> [[A:%.*]] to <2 x i64>
517508
; CHECK-NEXT: [[TMP1:%.*]] = and <2 x i32> [[B:%.*]], <i32 31, i32 15>
518-
; CHECK-NEXT: [[E:%.*]] = zext <2 x i32> [[TMP1]] to <2 x i128>
519-
; CHECK-NEXT: [[F:%.*]] = ashr <2 x i128> [[C]], [[E]]
520-
; CHECK-NEXT: [[G:%.*]] = trunc <2 x i128> [[F]] to <2 x i64>
521-
; CHECK-NEXT: ret <2 x i64> [[G]]
509+
; CHECK-NEXT: [[E:%.*]] = zext <2 x i32> [[TMP1]] to <2 x i64>
510+
; CHECK-NEXT: [[F:%.*]] = ashr <2 x i64> [[C]], [[E]]
511+
; CHECK-NEXT: ret <2 x i64> [[F]]
522512
;
523513
%C = sext <2 x i32> %A to <2 x i128>
524514
%D = zext <2 x i32> %B to <2 x i128>

0 commit comments

Comments
 (0)