Skip to content

[RISCV] Porting hasAllNBitUsers to RISCV GISel for instruction select #124678

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 88 additions & 0 deletions llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,21 @@ class RISCVInstructionSelector : public InstructionSelector {
const TargetRegisterClass *
getRegClassForTypeOnBank(LLT Ty, const RegisterBank &RB) const;

static constexpr unsigned MaxRecursionDepth = 6;

// const MachineInstr &MI
bool hasAllNBitUsers(const MachineInstr &MI, unsigned Bits,
const unsigned Depth = 0) const;
bool hasAllBUsers(const MachineInstr &MI) const {
return hasAllNBitUsers(MI, 8);
}
bool hasAllHUsers(const MachineInstr &MI) const {
return hasAllNBitUsers(MI, 16);
}
bool hasAllWUsers(const MachineInstr &MI) const {
return hasAllNBitUsers(MI, 32);
}

bool isRegInGprb(Register Reg) const;
bool isRegInFprb(Register Reg) const;

Expand Down Expand Up @@ -186,6 +201,79 @@ RISCVInstructionSelector::RISCVInstructionSelector(
{
}

// Mimics optimizations in ISel and RISCVOptWInst Pass
bool RISCVInstructionSelector::hasAllNBitUsers(const MachineInstr &MI,
unsigned Bits,
const unsigned Depth) const {

assert((MI.getOpcode() == TargetOpcode::G_ADD ||
MI.getOpcode() == TargetOpcode::G_SUB ||
MI.getOpcode() == TargetOpcode::G_MUL ||
MI.getOpcode() == TargetOpcode::G_SHL ||
MI.getOpcode() == TargetOpcode::G_LSHR ||
MI.getOpcode() == TargetOpcode::G_AND ||
MI.getOpcode() == TargetOpcode::G_OR ||
MI.getOpcode() == TargetOpcode::G_XOR ||
MI.getOpcode() == TargetOpcode::G_SEXT_INREG || Depth != 0) &&
"Unexpected opcode");

if (Depth >= RISCVInstructionSelector::MaxRecursionDepth)
return false;

auto DestReg = MI.getOperand(0).getReg();
for (auto &UserOp : MRI->use_nodbg_operands(DestReg)) {
assert(UserOp.getParent() && "UserOp must have a parent");
const MachineInstr &UserMI = *UserOp.getParent();
unsigned OpIdx = UserOp.getOperandNo();

switch (UserMI.getOpcode()) {
default:
return false;
case RISCV::ADDW:
case RISCV::ADDIW:
case RISCV::SUBW:
if (Bits >= 32)
break;
return false;
case RISCV::SLL:
case RISCV::SRA:
case RISCV::SRL:
// Shift amount operands only use log2(Xlen) bits.
if (OpIdx == 2 && Bits >= Log2_32(Subtarget->getXLen()))
break;
return false;
case RISCV::SLLI:
// SLLI only uses the lower (XLen - ShAmt) bits.
if (Bits >= Subtarget->getXLen() - UserMI.getOperand(2).getImm())
break;
return false;
case RISCV::ANDI:
if (Bits >= (unsigned)llvm::bit_width<uint64_t>(
(uint64_t)UserMI.getOperand(2).getImm()))
break;
goto RecCheck;
case RISCV::AND:
case RISCV::OR:
case RISCV::XOR:
RecCheck:
if (hasAllNBitUsers(UserMI, Bits, Depth + 1))
break;
return false;
case RISCV::SRLI: {
unsigned ShAmt = UserMI.getOperand(2).getImm();
// If we are shifting right by less than Bits, and users don't demand any
// bits that were shifted into [Bits-1:0], then we can consider this as an
// N-Bit user.
if (Bits > ShAmt && hasAllNBitUsers(UserMI, Bits - ShAmt, Depth + 1))
break;
return false;
}
}
}

return true;
}

InstructionSelector::ComplexRendererFns
RISCVInstructionSelector::selectShiftMask(MachineOperand &Root,
unsigned ShiftWidth) const {
Expand Down
14 changes: 9 additions & 5 deletions llvm/lib/Target/RISCV/RISCVInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -1943,17 +1943,21 @@ def : Pat<(i64 (shl (and GPR:$rs1, 0xffffffff), uimm5:$shamt)),

class binop_allhusers<SDPatternOperator operator>
: PatFrag<(ops node:$lhs, node:$rhs),
(XLenVT (operator node:$lhs, node:$rhs)), [{
(XLenVT(operator node:$lhs, node:$rhs)), [{
return hasAllHUsers(Node);
}]>;
}]> {
let GISelPredicateCode = [{ return hasAllHUsers(MI); }];
}

// PatFrag to allow ADDW/SUBW/MULW/SLLW to be selected from i64 add/sub/mul/shl
// if only the lower 32 bits of their result is used.
class binop_allwusers<SDPatternOperator operator>
: PatFrag<(ops node:$lhs, node:$rhs),
(i64 (operator node:$lhs, node:$rhs)), [{
: PatFrag<(ops node:$lhs, node:$rhs), (i64(operator node:$lhs, node:$rhs)),
[{
return hasAllWUsers(Node);
}]>;
}]> {
let GISelPredicateCode = [{ return hasAllWUsers(MI); }];
}

def sexti32_allwusers : PatFrag<(ops node:$src),
(sext_inreg node:$src, i32), [{
Expand Down
2 changes: 1 addition & 1 deletion llvm/test/CodeGen/RISCV/GlobalISel/combine.ll
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ define i32 @constant_to_rhs(i32 %x) {
; RV64-O0: # %bb.0:
; RV64-O0-NEXT: mv a1, a0
; RV64-O0-NEXT: li a0, 1
; RV64-O0-NEXT: add a0, a0, a1
; RV64-O0-NEXT: addw a0, a0, a1
; RV64-O0-NEXT: sext.w a0, a0
; RV64-O0-NEXT: ret
;
Expand Down
Loading