@@ -4097,7 +4097,7 @@ SDValue AMDGPUTargetLowering::performShlCombine(SDNode *N,
4097
4097
if (VT.getScalarType () != MVT::i64 )
4098
4098
return SDValue ();
4099
4099
4100
- // i64 (shl x, C) -> (build_pair 0, (shl x, C -32))
4100
+ // i64 (shl x, C) -> (build_pair 0, (shl x, C - 32))
4101
4101
4102
4102
// On some subtargets, 64-bit shift is a quarter rate instruction. In the
4103
4103
// common case, splitting this into a move and a 32-bit shift is faster and
@@ -4117,12 +4117,12 @@ SDValue AMDGPUTargetLowering::performShlCombine(SDNode *N,
4117
4117
ShiftAmt = DAG.getConstant (RHSVal - TargetScalarType.getSizeInBits (), SL,
4118
4118
TargetType);
4119
4119
} else {
4120
- SDValue truncShiftAmt = DAG.getNode (ISD::TRUNCATE, SL, TargetType, RHS);
4120
+ SDValue TruncShiftAmt = DAG.getNode (ISD::TRUNCATE, SL, TargetType, RHS);
4121
4121
const SDValue ShiftMask =
4122
4122
DAG.getConstant (TargetScalarType.getSizeInBits () - 1 , SL, TargetType);
4123
4123
// This AND instruction will clamp out of bounds shift values.
4124
4124
// It will also be removed during later instruction selection.
4125
- ShiftAmt = DAG.getNode (ISD::AND, SL, TargetType, truncShiftAmt , ShiftMask);
4125
+ ShiftAmt = DAG.getNode (ISD::AND, SL, TargetType, TruncShiftAmt , ShiftMask);
4126
4126
}
4127
4127
4128
4128
SDValue Lo = DAG.getNode (ISD::TRUNCATE, SL, TargetType, LHS);
@@ -4181,50 +4181,105 @@ SDValue AMDGPUTargetLowering::performSraCombine(SDNode *N,
4181
4181
4182
4182
SDValue AMDGPUTargetLowering::performSrlCombine (SDNode *N,
4183
4183
DAGCombinerInfo &DCI) const {
4184
- auto *RHS = dyn_cast<ConstantSDNode>(N->getOperand (1 ));
4185
- if (!RHS)
4186
- return SDValue ();
4187
-
4184
+ SDValue RHS = N->getOperand (1 );
4185
+ ConstantSDNode *CRHS = dyn_cast<ConstantSDNode>(RHS);
4188
4186
EVT VT = N->getValueType (0 );
4189
4187
SDValue LHS = N->getOperand (0 );
4190
- unsigned ShiftAmt = RHS->getZExtValue ();
4191
4188
SelectionDAG &DAG = DCI.DAG ;
4192
4189
SDLoc SL (N);
4190
+ unsigned RHSVal;
4191
+
4192
+ if (CRHS) {
4193
+ RHSVal = CRHS->getZExtValue ();
4193
4194
4194
- // fold (srl (and x, c1 << c2), c2) -> (and (srl(x, c2), c1)
4195
- // this improves the ability to match BFE patterns in isel.
4196
- if (LHS.getOpcode () == ISD::AND) {
4197
- if (auto *Mask = dyn_cast<ConstantSDNode>(LHS.getOperand (1 ))) {
4198
- unsigned MaskIdx, MaskLen;
4199
- if (Mask->getAPIntValue ().isShiftedMask (MaskIdx, MaskLen) &&
4200
- MaskIdx == ShiftAmt) {
4201
- return DAG.getNode (
4202
- ISD::AND, SL, VT,
4203
- DAG.getNode (ISD::SRL, SL, VT, LHS.getOperand (0 ), N->getOperand (1 )),
4204
- DAG.getNode (ISD::SRL, SL, VT, LHS.getOperand (1 ), N->getOperand (1 )));
4195
+ // Fold (srl (and x, c1 << c2), c2) -> (and (srl x, c2), c1).
4196
+ // This improves the ability to match BFE patterns in isel.
4197
+ if (LHS.getOpcode () == ISD::AND) {
4198
+ if (auto *Mask = dyn_cast<ConstantSDNode>(LHS.getOperand (1 ))) {
4199
+ unsigned MaskIdx, MaskLen;
4200
+ if (Mask->getAPIntValue ().isShiftedMask (MaskIdx, MaskLen) &&
4201
+ MaskIdx == RHSVal) {
4202
+ return DAG.getNode (ISD::AND, SL, VT,
4203
+ DAG.getNode (ISD::SRL, SL, VT, LHS.getOperand (0 ),
4204
+ N->getOperand (1 )),
4205
+ DAG.getNode (ISD::SRL, SL, VT, LHS.getOperand (1 ),
4206
+ N->getOperand (1 )));
4207
+ }
4205
4208
}
4206
4209
}
4207
4210
}
4208
4211
4209
- if (VT != MVT::i64 )
4212
+ if (VT. getScalarType () != MVT::i64 )
4210
4213
return SDValue ();
4211
4214
4212
- if (ShiftAmt < 32 )
4215
+ // For C >= 32:
4216
+ // i64 (srl x, C) -> (build_pair (srl hi_32(x), C - 32), 0)
4217
+
4218
+ // On some subtargets, 64-bit shift is a quarter rate instruction. In the
4219
+ // common case, splitting this into a move and a 32-bit shift is faster and
4220
+ // the same code size.
4221
+ KnownBits Known = DAG.computeKnownBits (RHS);
4222
+
4223
+ EVT ElementType = VT.getScalarType ();
4224
+ EVT TargetScalarType = ElementType.getHalfSizedIntegerVT (*DAG.getContext ());
4225
+ EVT TargetType = VT.isVector () ? VT.changeVectorElementType (TargetScalarType)
4226
+ : TargetScalarType;
4227
+
4228
+ if (Known.getMinValue ().getZExtValue () < TargetScalarType.getSizeInBits ())
4213
4229
return SDValue ();
4214
4230
4215
- // srl i64:x, C for C >= 32
4216
- // =>
4217
- // build_pair (srl hi_32(x), C - 32), 0
4218
- SDValue Zero = DAG.getConstant (0 , SL, MVT::i32 );
4231
+ SDValue ShiftAmt;
4232
+ if (CRHS) {
4233
+ ShiftAmt = DAG.getConstant (RHSVal - TargetScalarType.getSizeInBits (), SL,
4234
+ TargetType);
4235
+ } else {
4236
+ SDValue TruncShiftAmt = DAG.getNode (ISD::TRUNCATE, SL, TargetType, RHS);
4237
+ const SDValue ShiftMask =
4238
+ DAG.getConstant (TargetScalarType.getSizeInBits () - 1 , SL, TargetType);
4239
+ // This AND instruction will clamp out of bounds shift values.
4240
+ // It will also be removed during later instruction selection.
4241
+ ShiftAmt = DAG.getNode (ISD::AND, SL, TargetType, TruncShiftAmt, ShiftMask);
4242
+ }
4243
+
4244
+ const SDValue Zero = DAG.getConstant (0 , SL, TargetScalarType);
4245
+ EVT ConcatType;
4246
+ SDValue Hi;
4247
+ SDLoc LHSSL (LHS);
4248
+ // Bitcast LHS into ConcatType so hi-half of source can be extracted into Hi
4249
+ if (VT.isVector ()) {
4250
+ unsigned NElts = TargetType.getVectorNumElements ();
4251
+ ConcatType = TargetType.getDoubleNumVectorElementsVT (*DAG.getContext ());
4252
+ SDValue SplitLHS = DAG.getNode (ISD::BITCAST, LHSSL, ConcatType, LHS);
4253
+ SmallVector<SDValue, 8 > HiOps (NElts);
4254
+ SmallVector<SDValue, 16 > HiAndLoOps;
4219
4255
4220
- SDValue Hi = getHiHalf64 (LHS, DAG);
4256
+ DAG.ExtractVectorElements (SplitLHS, HiAndLoOps, /* Start=*/ 0 , NElts * 2 );
4257
+ for (unsigned I = 0 ; I != NElts; ++I)
4258
+ HiOps[I] = HiAndLoOps[2 * I + 1 ];
4259
+ Hi = DAG.getNode (ISD::BUILD_VECTOR, LHSSL, TargetType, HiOps);
4260
+ } else {
4261
+ const SDValue One = DAG.getConstant (1 , LHSSL, TargetScalarType);
4262
+ ConcatType = EVT::getVectorVT (*DAG.getContext (), TargetType, 2 );
4263
+ SDValue SplitLHS = DAG.getNode (ISD::BITCAST, LHSSL, ConcatType, LHS);
4264
+ Hi = DAG.getNode (ISD::EXTRACT_VECTOR_ELT, LHSSL, TargetType, SplitLHS, One);
4265
+ }
4221
4266
4222
- SDValue NewConst = DAG.getConstant (ShiftAmt - 32 , SL, MVT::i32 );
4223
- SDValue NewShift = DAG.getNode (ISD::SRL, SL, MVT::i32 , Hi, NewConst);
4267
+ SDValue NewShift = DAG.getNode (ISD::SRL, SL, TargetType, Hi, ShiftAmt);
4224
4268
4225
- SDValue BuildPair = DAG.getBuildVector (MVT::v2i32, SL, {NewShift, Zero});
4269
+ SDValue Vec;
4270
+ if (VT.isVector ()) {
4271
+ unsigned NElts = TargetType.getVectorNumElements ();
4272
+ SmallVector<SDValue, 8 > LoOps;
4273
+ SmallVector<SDValue, 16 > HiAndLoOps (NElts * 2 , Zero);
4226
4274
4227
- return DAG.getNode (ISD::BITCAST, SL, MVT::i64 , BuildPair);
4275
+ DAG.ExtractVectorElements (NewShift, LoOps, 0 , NElts);
4276
+ for (unsigned I = 0 ; I != NElts; ++I)
4277
+ HiAndLoOps[2 * I] = LoOps[I];
4278
+ Vec = DAG.getNode (ISD::BUILD_VECTOR, SL, ConcatType, HiAndLoOps);
4279
+ } else {
4280
+ Vec = DAG.getBuildVector (ConcatType, SL, {NewShift, Zero});
4281
+ }
4282
+ return DAG.getNode (ISD::BITCAST, SL, VT, Vec);
4228
4283
}
4229
4284
4230
4285
SDValue AMDGPUTargetLowering::performTruncateCombine (
@@ -5209,21 +5264,18 @@ SDValue AMDGPUTargetLowering::PerformDAGCombine(SDNode *N,
5209
5264
5210
5265
break ;
5211
5266
}
5212
- case ISD::SHL: {
5267
+ case ISD::SHL:
5268
+ case ISD::SRL: {
5213
5269
// Range metadata can be invalidated when loads are converted to legal types
5214
5270
// (e.g. v2i64 -> v4i32).
5215
- // Try to convert vector shl before type legalization so that range metadata
5216
- // can be utilized.
5271
+ // Try to convert vector shl/srl before type legalization so that range
5272
+ // metadata can be utilized.
5217
5273
if (!(N->getValueType (0 ).isVector () &&
5218
5274
DCI.getDAGCombineLevel () == BeforeLegalizeTypes) &&
5219
5275
DCI.getDAGCombineLevel () < AfterLegalizeDAG)
5220
5276
break ;
5221
- return performShlCombine (N, DCI);
5222
- }
5223
- case ISD::SRL: {
5224
- if (DCI.getDAGCombineLevel () < AfterLegalizeDAG)
5225
- break ;
5226
-
5277
+ if (N->getOpcode () == ISD::SHL)
5278
+ return performShlCombine (N, DCI);
5227
5279
return performSrlCombine (N, DCI);
5228
5280
}
5229
5281
case ISD::SRA: {
0 commit comments