Skip to content

[DAG] Use SDValue for PatFrag checks #137519

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions llvm/include/llvm/CodeGen/SelectionDAGISel.h
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,7 @@ class SelectionDAGISel {
/// It runs node predicate number PredNo and returns true if it succeeds or
/// false if it fails. The number is a private implementation
/// detail to the code tblgen produces.
virtual bool CheckNodePredicate(SDNode *N, unsigned PredNo) const {
virtual bool CheckNodePredicate(SDValue Op, unsigned PredNo) const {
llvm_unreachable("Tblgen should generate the implementation of this!");
}

Expand All @@ -436,7 +436,7 @@ class SelectionDAGISel {
/// false if it fails. The number is a private implementation detail to the
/// code tblgen produces.
virtual bool CheckNodePredicateWithOperands(
SDNode *N, unsigned PredNo,
SDValue Op, unsigned PredNo,
const SmallVectorImpl<SDValue> &Operands) const {
llvm_unreachable("Tblgen should generate the implementation of this!");
}
Expand Down
11 changes: 5 additions & 6 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2897,11 +2897,11 @@ CheckPatternPredicate(unsigned Opcode, const unsigned char *MatcherTable,
LLVM_ATTRIBUTE_ALWAYS_INLINE static bool
CheckNodePredicate(unsigned Opcode, const unsigned char *MatcherTable,
unsigned &MatcherIndex, const SelectionDAGISel &SDISel,
SDNode *N) {
SDValue Op) {
unsigned PredNo = Opcode == SelectionDAGISel::OPC_CheckPredicate
? MatcherTable[MatcherIndex++]
: Opcode - SelectionDAGISel::OPC_CheckPredicate0;
return SDISel.CheckNodePredicate(N, PredNo);
return SDISel.CheckNodePredicate(Op, PredNo);
}

LLVM_ATTRIBUTE_ALWAYS_INLINE static bool
Expand Down Expand Up @@ -3062,7 +3062,7 @@ static unsigned IsPredicateKnownToFail(const unsigned char *Table,
case SelectionDAGISel::OPC_CheckPredicate5:
case SelectionDAGISel::OPC_CheckPredicate6:
case SelectionDAGISel::OPC_CheckPredicate7:
Result = !::CheckNodePredicate(Opcode, Table, Index, SDISel, N.getNode());
Result = !::CheckNodePredicate(Opcode, Table, Index, SDISel, N);
return Index;
case SelectionDAGISel::OPC_CheckOpcode:
Result = !::CheckOpcode(Table, Index, N.getNode());
Expand Down Expand Up @@ -3575,8 +3575,7 @@ void SelectionDAGISel::SelectCodeCommon(SDNode *NodeToMatch,
case SelectionDAGISel::OPC_CheckPredicate6:
case SelectionDAGISel::OPC_CheckPredicate7:
case OPC_CheckPredicate:
if (!::CheckNodePredicate(Opcode, MatcherTable, MatcherIndex, *this,
N.getNode()))
if (!::CheckNodePredicate(Opcode, MatcherTable, MatcherIndex, *this, N))
break;
continue;
case OPC_CheckPredicateWithOperands: {
Expand All @@ -3587,7 +3586,7 @@ void SelectionDAGISel::SelectCodeCommon(SDNode *NodeToMatch,
Operands.push_back(RecordedNodes[MatcherTable[MatcherIndex++]].first);

unsigned PredNo = MatcherTable[MatcherIndex++];
if (!CheckNodePredicateWithOperands(N.getNode(), PredNo, Operands))
if (!CheckNodePredicateWithOperands(N, PredNo, Operands))
break;
continue;
}
Expand Down
16 changes: 8 additions & 8 deletions llvm/lib/Target/AArch64/AArch64InstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -683,24 +683,24 @@ defm trunc_masked_scatter_i32 : masked_gather_scatter<trunc_masked_scatter_i32>;

// top16Zero - answer true if the upper 16 bits of $src are 0, false otherwise
def top16Zero: PatLeaf<(i32 GPR32:$src), [{
return SDValue(N,0)->getValueType(0) == MVT::i32 &&
CurDAG->MaskedValueIsZero(SDValue(N,0), APInt::getHighBitsSet(32, 16));
return Op.getValueType() == MVT::i32 &&
CurDAG->MaskedValueIsZero(Op, APInt::getHighBitsSet(32, 16));
}]>;

// top32Zero - answer true if the upper 32 bits of $src are 0, false otherwise
def top32Zero: PatLeaf<(i64 GPR64:$src), [{
return SDValue(N,0)->getValueType(0) == MVT::i64 &&
CurDAG->MaskedValueIsZero(SDValue(N,0), APInt::getHighBitsSet(64, 32));
return Op.getValueType() == MVT::i64 &&
CurDAG->MaskedValueIsZero(Op, APInt::getHighBitsSet(64, 32));
}]>;

// topbitsallzero - Return true if all bits except the lowest bit are known zero
def topbitsallzero32: PatLeaf<(i32 GPR32:$src), [{
return SDValue(N,0)->getValueType(0) == MVT::i32 &&
CurDAG->MaskedValueIsZero(SDValue(N,0), APInt::getHighBitsSet(32, 31));
return Op.getValueType() == MVT::i32 &&
CurDAG->MaskedValueIsZero(Op, APInt::getHighBitsSet(32, 31));
}]>;
def topbitsallzero64: PatLeaf<(i64 GPR64:$src), [{
return SDValue(N,0)->getValueType(0) == MVT::i64 &&
CurDAG->MaskedValueIsZero(SDValue(N,0), APInt::getHighBitsSet(64, 63));
return Op.getValueType() == MVT::i64 &&
CurDAG->MaskedValueIsZero(Op, APInt::getHighBitsSet(64, 63));
}]>;

// Node definitions.
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Target/AMDGPU/SIInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -969,7 +969,7 @@ def MFMALdScaleXForm : SDNodeXForm<timm, [{
def is_canonicalized : PatLeaf<(fAny srcvalue:$src), [{
const SITargetLowering &Lowering =
*static_cast<const SITargetLowering *>(getTargetLowering());
return Lowering.isCanonicalized(*CurDAG, SDValue(N, 0));
return Lowering.isCanonicalized(*CurDAG, Op);
}]> {
let GISelPredicateCode = [{
const SITargetLowering *TLI = static_cast<const SITargetLowering *>(
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Target/AMDGPU/SIInstructions.td
Original file line number Diff line number Diff line change
Expand Up @@ -3861,7 +3861,7 @@ def : AMDGPUPat <
>;

def uint5Bits : PatLeaf<(i32 VGPR_32:$width), [{
return CurDAG->computeKnownBits(SDValue(N, 0)).countMaxActiveBits() <= 5;
return CurDAG->computeKnownBits(Op).countMaxActiveBits() <= 5;
}]>;

// x & (-1 >> (bitwidth - y))
Expand Down
10 changes: 5 additions & 5 deletions llvm/lib/Target/ARM/ARMInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,7 @@ def imm16_31 : ImmLeaf<i32, [{

// sext_16_node predicate - True if the SDNode is sign-extended 16 or more bits.
def sext_16_node : PatLeaf<(i32 GPR:$a), [{
return CurDAG->ComputeNumSignBits(SDValue(N,0)) >= 17;
return CurDAG->ComputeNumSignBits(Op) >= 17;
}]>;

def sext_bottom_16 : PatFrag<(ops node:$a),
Expand Down Expand Up @@ -451,14 +451,14 @@ def lo16AllZero : PatLeaf<(i32 imm), [{

// top16Zero - answer true if the upper 16 bits of $src are 0, false otherwise
def top16Zero: PatLeaf<(i32 GPR:$src), [{
return !SDValue(N,0)->getValueType(0).isVector() &&
CurDAG->MaskedValueIsZero(SDValue(N,0), APInt::getHighBitsSet(32, 16));
return !Op.getValueType().isVector() &&
CurDAG->MaskedValueIsZero(Op, APInt::getHighBitsSet(32, 16));
}]>;

// topbitsallzero - Return true if all bits except the lowest bit are known zero
def topbitsallzero32 : PatLeaf<(i32 GPRwithZR:$src), [{
return SDValue(N,0)->getValueType(0) == MVT::i32 &&
CurDAG->MaskedValueIsZero(SDValue(N,0), APInt::getHighBitsSet(32, 31));
return Op.getValueType() == MVT::i32 &&
CurDAG->MaskedValueIsZero(Op, APInt::getHighBitsSet(32, 31));
}]>;

class BinOpFrag<dag res> : PatFrag<(ops node:$LHS, node:$RHS), res>;
Expand Down
8 changes: 4 additions & 4 deletions llvm/lib/Target/RISCV/RISCVInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -1337,7 +1337,7 @@ def ext_oneuse : unop_oneuse<ext>;
def fpext_oneuse : unop_oneuse<any_fpextend>;

def 33signbits_node : PatLeaf<(i64 GPR:$src), [{
return CurDAG->ComputeNumSignBits(SDValue(N, 0)) > 32;
return CurDAG->ComputeNumSignBits(Op) > 32;
}]>;

class immop_oneuse<ImmLeaf leaf> : PatLeaf<(leaf), [{
Expand Down Expand Up @@ -1977,7 +1977,7 @@ def : Pat<(i64 (shl (and GPR:$rs1, 0xffffffff), uimm5:$shamt)),
class binop_allhusers<SDPatternOperator operator>
: PatFrag<(ops node:$lhs, node:$rhs),
(XLenVT (operator node:$lhs, node:$rhs)), [{
return hasAllHUsers(Node);
return hasAllHUsers(N);
}]> {
let GISelPredicateCode = [{ return hasAllHUsers(MI); }];
}
Expand All @@ -1987,14 +1987,14 @@ class binop_allhusers<SDPatternOperator operator>
class binop_allwusers<SDPatternOperator operator>
: PatFrag<(ops node:$lhs, node:$rhs), (i64 (operator node:$lhs, node:$rhs)),
[{
return hasAllWUsers(Node);
return hasAllWUsers(N);
}]> {
let GISelPredicateCode = [{ return hasAllWUsers(MI); }];
}

def sexti32_allwusers : PatFrag<(ops node:$src),
(sext_inreg node:$src, i32), [{
return hasAllWUsers(Node);
return hasAllWUsers(N);
}]>;

def ImmSExt32 : SDNodeXForm<imm, [{
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Target/WebAssembly/WebAssemblyInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ def TypeIndex : Operand<i32>;

// TODO: Find more places to use this.
def bool_node : PatLeaf<(i32 I32:$cond), [{
return CurDAG->computeKnownBits(SDValue(N, 0)).countMinLeadingZeros() == 31;
return CurDAG->computeKnownBits(Op).countMinLeadingZeros() == 31;
}]>;

//===----------------------------------------------------------------------===//
Expand Down
4 changes: 2 additions & 2 deletions llvm/lib/Target/X86/X86InstrSSE.td
Original file line number Diff line number Diff line change
Expand Up @@ -5705,7 +5705,7 @@ let Predicates = [UseSSE41, OptForSize] in {
// commuting would change which operand is inverted.
def X86ptest_commutable : PatFrag<(ops node:$src1, node:$src2),
(X86ptest node:$src1, node:$src2), [{
return onlyUsesZeroFlag(SDValue(Node, 0));
return onlyUsesZeroFlag(SDValue(N, 0));
}]>;

// ptest instruction we'll lower to this in X86ISelLowering primarily from
Expand Down Expand Up @@ -5772,7 +5772,7 @@ multiclass avx_bittest<bits<8> opc, string OpcodeStr, RegisterClass RC,
// used, commuting would change which operand is inverted.
def X86testp_commutable : PatFrag<(ops node:$src1, node:$src2),
(X86testp node:$src1, node:$src2), [{
return onlyUsesZeroFlag(SDValue(Node, 0));
return onlyUsesZeroFlag(SDValue(N, 0));
}]>;

let Defs = [EFLAGS], Predicates = [HasAVX] in {
Expand Down
16 changes: 16 additions & 0 deletions llvm/test/CodeGen/AArch64/aarch64-mull-masks.ll
Original file line number Diff line number Diff line change
Expand Up @@ -2098,3 +2098,19 @@ B:
%t = icmp eq i64 0, %3
br i1 %t, label %A, label %B
}

define i64 @pr137274(ptr %ptr) {
; CHECK-LABEL: pr137274:
; CHECK: // %bb.0:
; CHECK-NEXT: ldr x8, [x0]
; CHECK-NEXT: ldr w9, [x8, #8]!
; CHECK-NEXT: mul x0, x8, x9
; CHECK-NEXT: ret
%l0 = load i64, ptr %ptr, align 8
%add = add i64 %l0, 8
%i1 = inttoptr i64 %add to ptr
%l2 = load i32, ptr %i1, align 4
%conv = zext i32 %l2 to i64
%mul = mul i64 %add, %conv
ret i64 %mul
}
2 changes: 1 addition & 1 deletion llvm/test/TableGen/HasNoUse.td
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def NO_RET_ATOMIC_ADD : I<(outs), (ins GPR32Op:$src0, GPR32Op:$src1), []>;

// SDAG: case 0: {
// SDAG-NEXT: // Predicate_atomic_load_add_no_ret_i32
// SDAG-NEXT: SDNode *N = Node;
// SDAG-NEXT: SDNode *N = Op.getNode();
// SDAG-NEXT: (void)N;
// SDAG-NEXT: if (cast<MemSDNode>(N)->getMemoryVT() != MVT::i32) return false;
// SDAG-NEXT: if (N->hasAnyUseOfValue(0)) return false;
Expand Down
4 changes: 2 additions & 2 deletions llvm/test/TableGen/address-space-patfrags.td
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def inst_d : Instruction {
// SDAG: case 0: {
// SDAG-NEXT: // Predicate_pat_frag_b
// SDAG-NEXT: // Predicate_truncstorei16_addrspace
// SDAG-NEXT: SDNode *N = Node;
// SDAG-NEXT: SDNode *N = Op.getNode();
// SDAG-NEXT: (void)N;
// SDAG-NEXT: unsigned AddrSpace = cast<MemSDNode>(N)->getAddressSpace();
// SDAG-NEXT: if (AddrSpace != 123 && AddrSpace != 455)
Expand All @@ -71,7 +71,7 @@ def : Pat <

// SDAG: case 4: {
// SDAG: // Predicate_pat_frag_a
// SDAG-NEXT: SDNode *N = Node;
// SDAG-NEXT: SDNode *N = Op.getNode();
// SDAG-NEXT: (void)N;
// SDAG-NEXT: if (cast<MemSDNode>(N)->getAlign() < Align(2))
// SDAG-NEXT: return false;
Expand Down
10 changes: 5 additions & 5 deletions llvm/utils/TableGen/Common/CodeGenDAGPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1375,11 +1375,11 @@ std::string TreePredicateFn::getCodeToRunOnSDNode() const {

std::string Result = (" " + getImmType() + " Imm = ").str();
if (immCodeUsesAPFloat())
Result += "cast<ConstantFPSDNode>(Node)->getValueAPF();\n";
Result += "cast<ConstantFPSDNode>(Op.getNode())->getValueAPF();\n";
else if (immCodeUsesAPInt())
Result += "Node->getAsAPIntVal();\n";
Result += "Op->getAsAPIntVal();\n";
else
Result += "cast<ConstantSDNode>(Node)->getSExtValue();\n";
Result += "cast<ConstantSDNode>(Op.getNode())->getSExtValue();\n";
return Result + ImmCode;
}

Expand Down Expand Up @@ -1410,9 +1410,9 @@ std::string TreePredicateFn::getCodeToRunOnSDNode() const {

std::string Result;
if (ClassName == "SDNode")
Result = " SDNode *N = Node;\n";
Result = " SDNode *N = Op.getNode();\n";
else
Result = " auto *N = cast<" + ClassName.str() + ">(Node);\n";
Result = " auto *N = cast<" + ClassName.str() + ">(Op.getNode());\n";

return (Twine(Result) + " (void)N;\n" + getPredCode()).str();
}
Expand Down
4 changes: 2 additions & 2 deletions llvm/utils/TableGen/DAGISelMatcherEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1149,11 +1149,11 @@ void MatcherTableEmitter::EmitPredicateFunctions(raw_ostream &OS) {

// Emit Node predicates.
EmitNodePredicatesFunction(
NodePredicates, "CheckNodePredicate(SDNode *Node, unsigned PredNo) const",
NodePredicates, "CheckNodePredicate(SDValue Op, unsigned PredNo) const",
OS);
EmitNodePredicatesFunction(
NodePredicatesWithOperands,
"CheckNodePredicateWithOperands(SDNode *Node, unsigned PredNo, "
"CheckNodePredicateWithOperands(SDValue Op, unsigned PredNo, "
"const SmallVectorImpl<SDValue> &Operands) const",
OS);

Expand Down