Skip to content

Commit 1ac9b54

Browse files
committed
[RISCV] Lower GREVI and GORCI as custom nodes
This moves the recognition of GREVI and GORCI from TableGen patterns into a DAGCombine. This is done primarily to match "deeper" patterns in the future, like (grevi (grevi x, 1) 2) -> (grevi x, 3). TableGen is not best suited to matching patterns such as these as the compile time of the DAG matchers quickly gets out of hand due to the expansion of commutative permutations. Reviewed By: craig.topper Differential Revision: https://reviews.llvm.org/D91259
1 parent 5b7bd89 commit 1ac9b54

File tree

4 files changed

+241
-124
lines changed

4 files changed

+241
-124
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 211 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,10 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
259259

260260
// We can use any register for comparisons
261261
setHasMultipleConditionRegisters();
262+
263+
if (Subtarget.hasStdExtZbp()) {
264+
setTargetDAGCombine(ISD::OR);
265+
}
262266
}
263267

264268
EVT RISCVTargetLowering::getSetCCResultType(const DataLayout &DL, LLVMContext &,
@@ -904,6 +908,10 @@ static RISCVISD::NodeType getRISCVWOpcode(unsigned Opcode) {
904908
return RISCVISD::DIVUW;
905909
case ISD::UREM:
906910
return RISCVISD::REMUW;
911+
case RISCVISD::GREVI:
912+
return RISCVISD::GREVIW;
913+
case RISCVISD::GORCI:
914+
return RISCVISD::GORCIW;
907915
}
908916
}
909917

@@ -1026,7 +1034,186 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
10261034
Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, FPConv));
10271035
break;
10281036
}
1037+
case RISCVISD::GREVI:
1038+
case RISCVISD::GORCI: {
1039+
assert(N->getValueType(0) == MVT::i32 && Subtarget.is64Bit() &&
1040+
"Unexpected custom legalisation");
1041+
// This is similar to customLegalizeToWOp, except that we pass the second
1042+
// operand (a TargetConstant) straight through: it is already of type
1043+
// XLenVT.
1044+
SDLoc DL(N);
1045+
RISCVISD::NodeType WOpcode = getRISCVWOpcode(N->getOpcode());
1046+
SDValue NewOp0 =
1047+
DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, N->getOperand(0));
1048+
SDValue NewRes =
1049+
DAG.getNode(WOpcode, DL, MVT::i64, NewOp0, N->getOperand(1));
1050+
// ReplaceNodeResults requires we maintain the same type for the return
1051+
// value.
1052+
Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, NewRes));
1053+
break;
1054+
}
1055+
}
1056+
}
1057+
1058+
// A structure to hold one of the bit-manipulation patterns below. Together, a
1059+
// SHL and non-SHL pattern may form a bit-manipulation pair on a single source:
1060+
// (or (and (shl x, 1), 0xAAAAAAAA),
1061+
// (and (srl x, 1), 0x55555555))
1062+
struct RISCVBitmanipPat {
1063+
SDValue Op;
1064+
unsigned ShAmt;
1065+
bool IsSHL;
1066+
1067+
bool formsPairWith(const RISCVBitmanipPat &Other) const {
1068+
return Op == Other.Op && ShAmt == Other.ShAmt && IsSHL != Other.IsSHL;
1069+
}
1070+
};
1071+
1072+
// Matches any of the following bit-manipulation patterns:
1073+
// (and (shl x, 1), (0x55555555 << 1))
1074+
// (and (srl x, 1), 0x55555555)
1075+
// (shl (and x, 0x55555555), 1)
1076+
// (srl (and x, (0x55555555 << 1)), 1)
1077+
// where the shift amount and mask may vary thus:
1078+
// [1] = 0x55555555 / 0xAAAAAAAA
1079+
// [2] = 0x33333333 / 0xCCCCCCCC
1080+
// [4] = 0x0F0F0F0F / 0xF0F0F0F0
1081+
// [8] = 0x00FF00FF / 0xFF00FF00
1082+
// [16] = 0x0000FFFF / 0xFFFFFFFF
1083+
// [32] = 0x00000000FFFFFFFF / 0xFFFFFFFF00000000 (for RV64)
1084+
static Optional<RISCVBitmanipPat> matchRISCVBitmanipPat(SDValue Op) {
1085+
Optional<uint64_t> Mask;
1086+
// Optionally consume a mask around the shift operation.
1087+
if (Op.getOpcode() == ISD::AND && isa<ConstantSDNode>(Op.getOperand(1))) {
1088+
Mask = Op.getConstantOperandVal(1);
1089+
Op = Op.getOperand(0);
1090+
}
1091+
if (Op.getOpcode() != ISD::SHL && Op.getOpcode() != ISD::SRL)
1092+
return None;
1093+
bool IsSHL = Op.getOpcode() == ISD::SHL;
1094+
1095+
if (!isa<ConstantSDNode>(Op.getOperand(1)))
1096+
return None;
1097+
auto ShAmt = Op.getConstantOperandVal(1);
1098+
1099+
if (!isPowerOf2_64(ShAmt))
1100+
return None;
1101+
1102+
// These are the unshifted masks which we use to match bit-manipulation
1103+
// patterns. They may be shifted left in certain circumstances.
1104+
static const uint64_t BitmanipMasks[] = {
1105+
0x5555555555555555ULL, 0x3333333333333333ULL, 0x0F0F0F0F0F0F0F0FULL,
1106+
0x00FF00FF00FF00FFULL, 0x0000FFFF0000FFFFULL, 0x00000000FFFFFFFFULL,
1107+
};
1108+
1109+
unsigned MaskIdx = Log2_64(ShAmt);
1110+
if (MaskIdx >= array_lengthof(BitmanipMasks))
1111+
return None;
1112+
1113+
auto Src = Op.getOperand(0);
1114+
1115+
unsigned Width = Op.getValueType() == MVT::i64 ? 64 : 32;
1116+
auto ExpMask = BitmanipMasks[MaskIdx] & maskTrailingOnes<uint64_t>(Width);
1117+
1118+
// The expected mask is shifted left when the AND is found around SHL
1119+
// patterns.
1120+
// ((x >> 1) & 0x55555555)
1121+
// ((x << 1) & 0xAAAAAAAA)
1122+
bool SHLExpMask = IsSHL;
1123+
1124+
if (!Mask) {
1125+
// Sometimes LLVM keeps the mask as an operand of the shift, typically when
1126+
// the mask is all ones: consume that now.
1127+
if (Src.getOpcode() == ISD::AND && isa<ConstantSDNode>(Src.getOperand(1))) {
1128+
Mask = Src.getConstantOperandVal(1);
1129+
Src = Src.getOperand(0);
1130+
// The expected mask is now in fact shifted left for SRL, so reverse the
1131+
// decision.
1132+
// ((x & 0xAAAAAAAA) >> 1)
1133+
// ((x & 0x55555555) << 1)
1134+
SHLExpMask = !SHLExpMask;
1135+
} else {
1136+
// Use a default shifted mask of all-ones if there's no AND, truncated
1137+
// down to the expected width. This simplifies the logic later on.
1138+
Mask = maskTrailingOnes<uint64_t>(Width);
1139+
*Mask &= (IsSHL ? *Mask << ShAmt : *Mask >> ShAmt);
1140+
}
10291141
}
1142+
1143+
if (SHLExpMask)
1144+
ExpMask <<= ShAmt;
1145+
1146+
if (Mask != ExpMask)
1147+
return None;
1148+
1149+
return RISCVBitmanipPat{Src, (unsigned)ShAmt, IsSHL};
1150+
}
1151+
1152+
// Match the following pattern as a GREVI(W) operation
1153+
// (or (BITMANIP_SHL x), (BITMANIP_SRL x))
1154+
static SDValue combineORToGREV(SDValue Op, SelectionDAG &DAG,
1155+
const RISCVSubtarget &Subtarget) {
1156+
if (Op.getSimpleValueType() == Subtarget.getXLenVT() ||
1157+
(Subtarget.is64Bit() && Op.getSimpleValueType() == MVT::i32)) {
1158+
auto LHS = matchRISCVBitmanipPat(Op.getOperand(0));
1159+
auto RHS = matchRISCVBitmanipPat(Op.getOperand(1));
1160+
if (LHS && RHS && LHS->formsPairWith(*RHS)) {
1161+
SDLoc DL(Op);
1162+
return DAG.getNode(
1163+
RISCVISD::GREVI, DL, Op.getValueType(), LHS->Op,
1164+
DAG.getTargetConstant(LHS->ShAmt, DL, Subtarget.getXLenVT()));
1165+
}
1166+
}
1167+
return SDValue();
1168+
}
1169+
1170+
// Matches any the following pattern as a GORCI(W) operation
1171+
// 1. (or (GREVI x, shamt), x)
1172+
// 2. (or x, (GREVI x, shamt))
1173+
// 3. (or (or (BITMANIP_SHL x), x), (BITMANIP_SRL x))
1174+
// Note that with the variant of 3.,
1175+
// (or (or (BITMANIP_SHL x), (BITMANIP_SRL x)), x)
1176+
// the inner pattern will first be matched as GREVI and then the outer
1177+
// pattern will be matched to GORC via the first rule above.
1178+
static SDValue combineORToGORC(SDValue Op, SelectionDAG &DAG,
1179+
const RISCVSubtarget &Subtarget) {
1180+
if (Op.getSimpleValueType() == Subtarget.getXLenVT() ||
1181+
(Subtarget.is64Bit() && Op.getSimpleValueType() == MVT::i32)) {
1182+
SDLoc DL(Op);
1183+
SDValue Op0 = Op.getOperand(0);
1184+
SDValue Op1 = Op.getOperand(1);
1185+
1186+
// Check for either commutable permutation of (or (GREVI x, shamt), x)
1187+
for (const auto &OpPair :
1188+
{std::make_pair(Op0, Op1), std::make_pair(Op1, Op0)}) {
1189+
if (OpPair.first.getOpcode() == RISCVISD::GREVI &&
1190+
OpPair.first.getOperand(0) == OpPair.second)
1191+
return DAG.getNode(RISCVISD::GORCI, DL, Op.getValueType(),
1192+
OpPair.second, OpPair.first.getOperand(1));
1193+
}
1194+
1195+
// OR is commutable so canonicalize its OR operand to the left
1196+
if (Op0.getOpcode() != ISD::OR && Op1.getOpcode() == ISD::OR)
1197+
std::swap(Op0, Op1);
1198+
if (Op0.getOpcode() != ISD::OR)
1199+
return SDValue();
1200+
SDValue OrOp0 = Op0.getOperand(0);
1201+
SDValue OrOp1 = Op0.getOperand(1);
1202+
auto LHS = matchRISCVBitmanipPat(OrOp0);
1203+
// OR is commutable so swap the operands and try again: x might have been
1204+
// on the left
1205+
if (!LHS) {
1206+
std::swap(OrOp0, OrOp1);
1207+
LHS = matchRISCVBitmanipPat(OrOp0);
1208+
}
1209+
auto RHS = matchRISCVBitmanipPat(Op1);
1210+
if (LHS && RHS && LHS->formsPairWith(*RHS) && LHS->Op == OrOp1) {
1211+
return DAG.getNode(
1212+
RISCVISD::GORCI, DL, Op.getValueType(), LHS->Op,
1213+
DAG.getTargetConstant(LHS->ShAmt, DL, Subtarget.getXLenVT()));
1214+
}
1215+
}
1216+
return SDValue();
10301217
}
10311218

10321219
SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
@@ -1094,6 +1281,18 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
10941281
}
10951282
break;
10961283
}
1284+
case RISCVISD::GREVIW:
1285+
case RISCVISD::GORCIW: {
1286+
// Only the lower 32 bits of the first operand are read
1287+
SDValue Op0 = N->getOperand(0);
1288+
APInt Mask = APInt::getLowBitsSet(Op0.getValueSizeInBits(), 32);
1289+
if (SimplifyDemandedBits(Op0, Mask, DCI)) {
1290+
if (N->getOpcode() != ISD::DELETED_NODE)
1291+
DCI.AddToWorklist(N);
1292+
return SDValue(N, 0);
1293+
}
1294+
break;
1295+
}
10971296
case RISCVISD::FMV_X_ANYEXTW_RV64: {
10981297
SDLoc DL(N);
10991298
SDValue Op0 = N->getOperand(0);
@@ -1124,6 +1323,12 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
11241323
return DAG.getNode(ISD::AND, DL, MVT::i64, NewFMV,
11251324
DAG.getConstant(~SignBit, DL, MVT::i64));
11261325
}
1326+
case ISD::OR:
1327+
if (auto GREV = combineORToGREV(SDValue(N, 0), DCI.DAG, Subtarget))
1328+
return GREV;
1329+
if (auto GORC = combineORToGORC(SDValue(N, 0), DCI.DAG, Subtarget))
1330+
return GORC;
1331+
break;
11271332
}
11281333

11291334
return SDValue();
@@ -1187,6 +1392,8 @@ unsigned RISCVTargetLowering::ComputeNumSignBitsForTargetNode(
11871392
case RISCVISD::DIVW:
11881393
case RISCVISD::DIVUW:
11891394
case RISCVISD::REMUW:
1395+
case RISCVISD::GREVIW:
1396+
case RISCVISD::GORCIW:
11901397
// TODO: As the result is sign-extended, this is conservatively correct. A
11911398
// more precise answer could be calculated for SRAW depending on known
11921399
// bits in the shift amount.
@@ -2625,6 +2832,10 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
26252832
NODE_NAME_CASE(FMV_W_X_RV64)
26262833
NODE_NAME_CASE(FMV_X_ANYEXTW_RV64)
26272834
NODE_NAME_CASE(READ_CYCLE_WIDE)
2835+
NODE_NAME_CASE(GREVI)
2836+
NODE_NAME_CASE(GREVIW)
2837+
NODE_NAME_CASE(GORCI)
2838+
NODE_NAME_CASE(GORCIW)
26282839
}
26292840
// clang-format on
26302841
return nullptr;

llvm/lib/Target/RISCV/RISCVISelLowering.h

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,19 @@ enum NodeType : unsigned {
5151
FMV_X_ANYEXTW_RV64,
5252
// READ_CYCLE_WIDE - A read of the 64-bit cycle CSR on a 32-bit target
5353
// (returns (Lo, Hi)). It takes a chain operand.
54-
READ_CYCLE_WIDE
54+
READ_CYCLE_WIDE,
55+
// Generalized Reverse and Generalized Or-Combine - directly matching the
56+
// semantics of the named RISC-V instructions. Lowered as custom nodes as
57+
// TableGen chokes when faced with commutative permutations in deeply-nested
58+
// DAGs. Each node takes an input operand and a TargetConstant immediate
59+
// shift amount, and outputs a bit-manipulated version of input. All operands
60+
// are of type XLenVT.
61+
GREVI,
62+
GREVIW,
63+
GORCI,
64+
GORCIW,
5565
};
56-
}
66+
} // namespace RISCVISD
5767

5868
class RISCVTargetLowering : public TargetLowering {
5969
const RISCVSubtarget &Subtarget;

0 commit comments

Comments
 (0)