Skip to content

Commit c8dc21d

Browse files
authored
[SelectionDAG][RISCV] Fix break of vnsrl pattern in issue #94265 (#95563)
Added a RISCV override of `isTruncateFree(SDValue, EVT)` to fix the broken vnsrl pattern described in issue #94265. Fixes #94265
1 parent c28ddf9 commit c8dc21d

File tree

4 files changed

+58
-1
lines changed

4 files changed

+58
-1
lines changed

llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2586,6 +2586,17 @@ bool TargetLowering::SimplifyDemandedBits(
25862586
break;
25872587

25882588
if (Src.getNode()->hasOneUse()) {
2589+
if (isTruncateFree(Src, VT) &&
2590+
!isTruncateFree(Src.getValueType(), VT)) {
2591+
// If truncate is only free at trunc(srl), do not turn it into
2592+
// srl(trunc). The check is done by first checking that the truncate is
2593+
// free at Src's opcode (srl), then checking that the truncate is not
2594+
// done by referencing a sub-register. In testing, if the truncs of both
2595+
// trunc(srl) and srl(trunc) are free, srl(trunc) performs better. If
2596+
// only trunc(srl)'s trunc is free, trunc(srl) is better.
2597+
break;
2598+
}
2599+
25892600
std::optional<uint64_t> ShAmtC =
25902601
TLO.DAG.getValidShiftAmount(Src, DemandedElts, Depth + 2);
25912602
if (!ShAmtC || *ShAmtC >= BitWidth)
@@ -2596,7 +2607,6 @@ bool TargetLowering::SimplifyDemandedBits(
25962607
APInt::getHighBitsSet(OperandBitWidth, OperandBitWidth - BitWidth);
25972608
HighBits.lshrInPlace(ShVal);
25982609
HighBits = HighBits.trunc(BitWidth);
2599-
26002610
if (!(HighBits & DemandedBits)) {
26012611
// None of the shifted in bits are needed. Add a truncate of the
26022612
// shift input, then shift it.

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1894,6 +1894,21 @@ bool RISCVTargetLowering::isTruncateFree(EVT SrcVT, EVT DstVT) const {
18941894
return (SrcBits == 64 && DestBits == 32);
18951895
}
18961896

1897+
bool RISCVTargetLowering::isTruncateFree(SDValue Val, EVT VT2) const {
1898+
EVT SrcVT = Val.getValueType();
1899+
// With V, truncating a vector SRL/SRA result to half the element width is free (vnsrl/vnsra).
1900+
if (Subtarget.hasStdExtV() &&
1901+
(Val.getOpcode() == ISD::SRL || Val.getOpcode() == ISD::SRA) &&
1902+
SrcVT.isVector() && VT2.isVector()) {
1903+
unsigned SrcBits = SrcVT.getVectorElementType().getSizeInBits();
1904+
unsigned DestBits = VT2.getVectorElementType().getSizeInBits();
1905+
if (SrcBits == DestBits * 2) {
1906+
return true;
1907+
}
1908+
}
1909+
return TargetLowering::isTruncateFree(Val, VT2);
1910+
}
1911+
18971912
bool RISCVTargetLowering::isZExtFree(SDValue Val, EVT VT2) const {
18981913
// Zexts are free if they can be combined with a load.
18991914
// Don't advertise i32->i64 zextload as being free for RV64. It interacts

llvm/lib/Target/RISCV/RISCVISelLowering.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -497,6 +497,7 @@ class RISCVTargetLowering : public TargetLowering {
497497
bool isLegalAddImmediate(int64_t Imm) const override;
498498
bool isTruncateFree(Type *SrcTy, Type *DstTy) const override;
499499
bool isTruncateFree(EVT SrcVT, EVT DstVT) const override;
500+
bool isTruncateFree(SDValue Val, EVT VT2) const override;
500501
bool isZExtFree(SDValue Val, EVT VT2) const override;
501502
bool isSExtCheaperThanZExt(EVT SrcVT, EVT DstVT) const override;
502503
bool signExtendConstant(const ConstantInt *CI) const override;

llvm/test/CodeGen/RISCV/pr94265.ll

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
2+
; RUN: llc < %s -mtriple=riscv32-- -mattr=+v | FileCheck -check-prefix=RV32I %s
3+
; RUN: llc < %s -mtriple=riscv64-- -mattr=+v | FileCheck -check-prefix=RV64I %s
4+
5+
define <8 x i16> @PR94265(<8 x i32> %a0) #0 {
6+
; RV32I-LABEL: PR94265:
7+
; RV32I: # %bb.0:
8+
; RV32I-NEXT: vsetivli zero, 8, e32, m2, ta, ma
9+
; RV32I-NEXT: vsra.vi v10, v8, 31
10+
; RV32I-NEXT: vsrl.vi v10, v10, 26
11+
; RV32I-NEXT: vadd.vv v8, v8, v10
12+
; RV32I-NEXT: vsetvli zero, zero, e16, m1, ta, ma
13+
; RV32I-NEXT: vnsrl.wi v10, v8, 6
14+
; RV32I-NEXT: vsll.vi v8, v10, 10
15+
; RV32I-NEXT: ret
16+
;
17+
; RV64I-LABEL: PR94265:
18+
; RV64I: # %bb.0:
19+
; RV64I-NEXT: vsetivli zero, 8, e32, m2, ta, ma
20+
; RV64I-NEXT: vsra.vi v10, v8, 31
21+
; RV64I-NEXT: vsrl.vi v10, v10, 26
22+
; RV64I-NEXT: vadd.vv v8, v8, v10
23+
; RV64I-NEXT: vsetvli zero, zero, e16, m1, ta, ma
24+
; RV64I-NEXT: vnsrl.wi v10, v8, 6
25+
; RV64I-NEXT: vsll.vi v8, v10, 10
26+
; RV64I-NEXT: ret
27+
%t1 = sdiv <8 x i32> %a0, <i32 64, i32 64, i32 64, i32 64, i32 64, i32 64, i32 64, i32 64>
28+
%t2 = trunc <8 x i32> %t1 to <8 x i16>
29+
%t3 = shl <8 x i16> %t2, <i16 10, i16 10, i16 10, i16 10, i16 10, i16 10, i16 10, i16 10>
30+
ret <8 x i16> %t3
31+
}

0 commit comments

Comments
 (0)