@@ -259,6 +259,10 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
 
   // We can use any register for comparisons
   setHasMultipleConditionRegisters();
+
+  if (Subtarget.hasStdExtZbp()) {
+    setTargetDAGCombine(ISD::OR);
+  }
 }
 
 EVT RISCVTargetLowering::getSetCCResultType(const DataLayout &DL, LLVMContext &,
@@ -904,6 +908,10 @@ static RISCVISD::NodeType getRISCVWOpcode(unsigned Opcode) {
     return RISCVISD::DIVUW;
   case ISD::UREM:
     return RISCVISD::REMUW;
+  case RISCVISD::GREVI:
+    return RISCVISD::GREVIW;
+  case RISCVISD::GORCI:
+    return RISCVISD::GORCIW;
   }
 }
 
@@ -1026,7 +1034,186 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
     Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, FPConv));
     break;
   }
+  case RISCVISD::GREVI:
+  case RISCVISD::GORCI: {
+    assert(N->getValueType(0) == MVT::i32 && Subtarget.is64Bit() &&
+           "Unexpected custom legalisation");
+    // This is similar to customLegalizeToWOp, except that we pass the second
+    // operand (a TargetConstant) straight through: it is already of type
+    // XLenVT.
+    SDLoc DL(N);
+    RISCVISD::NodeType WOpcode = getRISCVWOpcode(N->getOpcode());
+    SDValue NewOp0 =
+        DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, N->getOperand(0));
+    SDValue NewRes =
+        DAG.getNode(WOpcode, DL, MVT::i64, NewOp0, N->getOperand(1));
+    // ReplaceNodeResults requires we maintain the same type for the return
+    // value.
+    Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, NewRes));
+    break;
+  }
+  }
+}
+
+// A structure to hold one of the bit-manipulation patterns below. Together, a
+// SHL and non-SHL pattern may form a bit-manipulation pair on a single source:
+//   (or (and (shl x, 1), 0xAAAAAAAA),
+//       (and (srl x, 1), 0x55555555))
+struct RISCVBitmanipPat {
+  SDValue Op;
+  unsigned ShAmt;
+  bool IsSHL;
+
+  bool formsPairWith(const RISCVBitmanipPat &Other) const {
+    return Op == Other.Op && ShAmt == Other.ShAmt && IsSHL != Other.IsSHL;
+  }
+};
+
+// Matches any of the following bit-manipulation patterns:
+//   (and (shl x, 1), (0x55555555 << 1))
+//   (and (srl x, 1), 0x55555555)
+//   (shl (and x, 0x55555555), 1)
+//   (srl (and x, (0x55555555 << 1)), 1)
+// where the shift amount and mask may vary thus:
+//   [1]  = 0x55555555 / 0xAAAAAAAA
+//   [2]  = 0x33333333 / 0xCCCCCCCC
+//   [4]  = 0x0F0F0F0F / 0xF0F0F0F0
+//   [8]  = 0x00FF00FF / 0xFF00FF00
+//   [16] = 0x0000FFFF / 0xFFFF0000
+//   [32] = 0x00000000FFFFFFFF / 0xFFFFFFFF00000000 (for RV64)
+static Optional<RISCVBitmanipPat> matchRISCVBitmanipPat(SDValue Op) {
+  Optional<uint64_t> Mask;
+  // Optionally consume a mask around the shift operation.
+  if (Op.getOpcode() == ISD::AND && isa<ConstantSDNode>(Op.getOperand(1))) {
+    Mask = Op.getConstantOperandVal(1);
+    Op = Op.getOperand(0);
+  }
+  if (Op.getOpcode() != ISD::SHL && Op.getOpcode() != ISD::SRL)
+    return None;
+  bool IsSHL = Op.getOpcode() == ISD::SHL;
+
+  if (!isa<ConstantSDNode>(Op.getOperand(1)))
+    return None;
+  auto ShAmt = Op.getConstantOperandVal(1);
+
+  if (!isPowerOf2_64(ShAmt))
+    return None;
+
+  // These are the unshifted masks which we use to match bit-manipulation
+  // patterns. They may be shifted left in certain circumstances.
+  static const uint64_t BitmanipMasks[] = {
+      0x5555555555555555ULL, 0x3333333333333333ULL, 0x0F0F0F0F0F0F0F0FULL,
+      0x00FF00FF00FF00FFULL, 0x0000FFFF0000FFFFULL, 0x00000000FFFFFFFFULL,
+  };
+
+  unsigned MaskIdx = Log2_64(ShAmt);
+  if (MaskIdx >= array_lengthof(BitmanipMasks))
+    return None;
+
+  auto Src = Op.getOperand(0);
+
+  unsigned Width = Op.getValueType() == MVT::i64 ? 64 : 32;
+  auto ExpMask = BitmanipMasks[MaskIdx] & maskTrailingOnes<uint64_t>(Width);
+
+  // The expected mask is shifted left when the AND is found around SHL
+  // patterns.
+  //   ((x >> 1) & 0x55555555)
+  //   ((x << 1) & 0xAAAAAAAA)
+  bool SHLExpMask = IsSHL;
+
+  if (!Mask) {
+    // Sometimes LLVM keeps the mask as an operand of the shift, typically when
+    // the mask is all ones: consume that now.
+    if (Src.getOpcode() == ISD::AND && isa<ConstantSDNode>(Src.getOperand(1))) {
+      Mask = Src.getConstantOperandVal(1);
+      Src = Src.getOperand(0);
+      // The expected mask is now in fact shifted left for SRL, so reverse the
+      // decision.
+      //   ((x & 0xAAAAAAAA) >> 1)
+      //   ((x & 0x55555555) << 1)
+      SHLExpMask = !SHLExpMask;
+    } else {
+      // Use a default shifted mask of all-ones if there's no AND, truncated
+      // down to the expected width. This simplifies the logic later on.
+      Mask = maskTrailingOnes<uint64_t>(Width);
+      *Mask &= (IsSHL ? *Mask << ShAmt : *Mask >> ShAmt);
+    }
   }
+
+  if (SHLExpMask)
+    ExpMask <<= ShAmt;
+
+  if (Mask != ExpMask)
+    return None;
+
+  return RISCVBitmanipPat{Src, (unsigned)ShAmt, IsSHL};
+}
+
+// Match the following pattern as a GREVI(W) operation
+//   (or (BITMANIP_SHL x), (BITMANIP_SRL x))
+static SDValue combineORToGREV(SDValue Op, SelectionDAG &DAG,
+                               const RISCVSubtarget &Subtarget) {
+  if (Op.getSimpleValueType() == Subtarget.getXLenVT() ||
+      (Subtarget.is64Bit() && Op.getSimpleValueType() == MVT::i32)) {
+    auto LHS = matchRISCVBitmanipPat(Op.getOperand(0));
+    auto RHS = matchRISCVBitmanipPat(Op.getOperand(1));
+    if (LHS && RHS && LHS->formsPairWith(*RHS)) {
+      SDLoc DL(Op);
+      return DAG.getNode(
+          RISCVISD::GREVI, DL, Op.getValueType(), LHS->Op,
+          DAG.getTargetConstant(LHS->ShAmt, DL, Subtarget.getXLenVT()));
+    }
+  }
+  return SDValue();
+}
+
+// Matches any of the following patterns as a GORCI(W) operation
+// 1.  (or (GREVI x, shamt), x)
+// 2.  (or x, (GREVI x, shamt))
+// 3.  (or (or (BITMANIP_SHL x), x), (BITMANIP_SRL x))
+// Note that with the variant of 3.,
+//     (or (or (BITMANIP_SHL x), (BITMANIP_SRL x)), x)
+// the inner pattern will first be matched as GREVI and then the outer
+// pattern will be matched to GORC via the first rule above.
+static SDValue combineORToGORC(SDValue Op, SelectionDAG &DAG,
+                               const RISCVSubtarget &Subtarget) {
+  if (Op.getSimpleValueType() == Subtarget.getXLenVT() ||
+      (Subtarget.is64Bit() && Op.getSimpleValueType() == MVT::i32)) {
+    SDLoc DL(Op);
+    SDValue Op0 = Op.getOperand(0);
+    SDValue Op1 = Op.getOperand(1);
+
+    // Check for either commutable permutation of (or (GREVI x, shamt), x)
+    for (const auto &OpPair :
+         {std::make_pair(Op0, Op1), std::make_pair(Op1, Op0)}) {
+      if (OpPair.first.getOpcode() == RISCVISD::GREVI &&
+          OpPair.first.getOperand(0) == OpPair.second)
+        return DAG.getNode(RISCVISD::GORCI, DL, Op.getValueType(),
+                           OpPair.second, OpPair.first.getOperand(1));
+    }
+
+    // OR is commutable so canonicalize its OR operand to the left
+    if (Op0.getOpcode() != ISD::OR && Op1.getOpcode() == ISD::OR)
+      std::swap(Op0, Op1);
+    if (Op0.getOpcode() != ISD::OR)
+      return SDValue();
+    SDValue OrOp0 = Op0.getOperand(0);
+    SDValue OrOp1 = Op0.getOperand(1);
+    auto LHS = matchRISCVBitmanipPat(OrOp0);
+    // OR is commutable so swap the operands and try again: x might have been
+    // on the left
+    if (!LHS) {
+      std::swap(OrOp0, OrOp1);
+      LHS = matchRISCVBitmanipPat(OrOp0);
+    }
+    auto RHS = matchRISCVBitmanipPat(Op1);
+    if (LHS && RHS && LHS->formsPairWith(*RHS) && LHS->Op == OrOp1) {
+      return DAG.getNode(
+          RISCVISD::GORCI, DL, Op.getValueType(), LHS->Op,
+          DAG.getTargetConstant(LHS->ShAmt, DL, Subtarget.getXLenVT()));
+    }
+  }
+  return SDValue();
 }
 
 SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
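As a rough illustration of the computation these combines recognize (a sketch only, not part of the patch; the standalone helpers below are invented for exposition), one GREVI/GORCI stage for a power-of-two shift amount on a 32-bit value looks like this:

// Sketch: one stage of the generalized bit reverse / or-combine, using the
// same unshifted masks as matchRISCVBitmanipPat. ShAmt must be 1, 2, 4, 8 or 16.
#include <cstdint>

static uint32_t grevStage(uint32_t X, unsigned ShAmt) {
  static const uint32_t Masks[] = {0x55555555u, 0x33333333u, 0x0F0F0F0Fu,
                                   0x00FF00FFu, 0x0000FFFFu};
  unsigned Idx = 0;
  while ((1u << Idx) != ShAmt) // Log2 of the power-of-two shift amount.
    ++Idx;
  uint32_t Lo = Masks[Idx];
  // (or (and (shl x, ShAmt), Lo << ShAmt), (and (srl x, ShAmt), Lo))
  return ((X & Lo) << ShAmt) | ((X >> ShAmt) & Lo);
}

static uint32_t gorcStage(uint32_t X, unsigned ShAmt) {
  // GORC additionally ORs the original value back in; that is the extra
  // 'x' operand captured by the combineORToGORC rules.
  return X | grevStage(X, ShAmt);
}

Composing grevStage over all five shift amounts reverses all 32 bits, so open-coded bit-reverse/or-combine trees of this shape are what registering ISD::OR for target DAG combining (first hunk) lets the backend fold into single GREVI/GORCI nodes, one stage at a time.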
@@ -1094,6 +1281,18 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
     }
     break;
   }
+  case RISCVISD::GREVIW:
+  case RISCVISD::GORCIW: {
+    // Only the lower 32 bits of the first operand are read
+    SDValue Op0 = N->getOperand(0);
+    APInt Mask = APInt::getLowBitsSet(Op0.getValueSizeInBits(), 32);
+    if (SimplifyDemandedBits(Op0, Mask, DCI)) {
+      if (N->getOpcode() != ISD::DELETED_NODE)
+        DCI.AddToWorklist(N);
+      return SDValue(N, 0);
+    }
+    break;
+  }
   case RISCVISD::FMV_X_ANYEXTW_RV64: {
     SDLoc DL(N);
     SDValue Op0 = N->getOperand(0);
@@ -1124,6 +1323,12 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
     return DAG.getNode(ISD::AND, DL, MVT::i64, NewFMV,
                        DAG.getConstant(~SignBit, DL, MVT::i64));
   }
+  case ISD::OR:
+    if (auto GREV = combineORToGREV(SDValue(N, 0), DCI.DAG, Subtarget))
+      return GREV;
+    if (auto GORC = combineORToGORC(SDValue(N, 0), DCI.DAG, Subtarget))
+      return GORC;
+    break;
   }
 
   return SDValue();
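A small self-contained check (again only a sketch, not part of the patch; the names are invented) of the property the new GREVIW/GORCIW demanded-bits case relies on: the W-form operation reads only the low 32 bits of its operand and produces a result sign-extended from bit 31, which is also why these nodes can sit next to DIVW/REMUW in ComputeNumSignBitsForTargetNode below.

#include <cassert>
#include <cstdint>

// W-form stage on RV64: the upper 32 bits of X are never read, and the 32-bit
// result is sign-extended, mirroring the APInt::getLowBitsSet(..., 32) mask.
static int64_t grevwStage(int64_t X, unsigned ShAmt) {
  static const uint32_t Masks[] = {0x55555555u, 0x33333333u, 0x0F0F0F0Fu,
                                   0x00FF00FFu, 0x0000FFFFu};
  unsigned Idx = 0;
  while ((1u << Idx) != ShAmt)
    ++Idx;
  uint32_t V = static_cast<uint32_t>(X); // only the low 32 bits matter
  uint32_t R = ((V & Masks[Idx]) << ShAmt) | ((V >> ShAmt) & Masks[Idx]);
  return static_cast<int32_t>(R);        // sign-extend the 32-bit result
}

int main() {
  int64_t A = 0x12345678;
  int64_t B = A | (0x7EADBEEFLL << 32); // different upper 32 bits
  assert(grevwStage(A, 4) == grevwStage(B, 4));
  return 0;
}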
@@ -1187,6 +1392,8 @@ unsigned RISCVTargetLowering::ComputeNumSignBitsForTargetNode(
   case RISCVISD::DIVW:
   case RISCVISD::DIVUW:
   case RISCVISD::REMUW:
+  case RISCVISD::GREVIW:
+  case RISCVISD::GORCIW:
     // TODO: As the result is sign-extended, this is conservatively correct. A
     // more precise answer could be calculated for SRAW depending on known
     // bits in the shift amount.
@@ -2625,6 +2832,10 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
   NODE_NAME_CASE(FMV_W_X_RV64)
   NODE_NAME_CASE(FMV_X_ANYEXTW_RV64)
   NODE_NAME_CASE(READ_CYCLE_WIDE)
+  NODE_NAME_CASE(GREVI)
+  NODE_NAME_CASE(GREVIW)
+  NODE_NAME_CASE(GORCI)
+  NODE_NAME_CASE(GORCIW)
   }
   // clang-format on
   return nullptr;