@@ -4151,32 +4151,96 @@ SDValue AMDGPUTargetLowering::performShlCombine(SDNode *N,
4151
4151
4152
4152
SDValue AMDGPUTargetLowering::performSraCombine (SDNode *N,
4153
4153
DAGCombinerInfo &DCI) const {
4154
- if (N->getValueType (0 ) != MVT::i64 )
4154
+ SDValue RHS = N->getOperand (1 );
4155
+ ConstantSDNode *CRHS = dyn_cast<ConstantSDNode>(RHS);
4156
+ EVT VT = N->getValueType (0 );
4157
+ SDValue LHS = N->getOperand (0 );
4158
+ SelectionDAG &DAG = DCI.DAG ;
4159
+ SDLoc SL (N);
4160
+
4161
+ if (VT.getScalarType () != MVT::i64 )
4155
4162
return SDValue ();
4156
4163
4157
- const ConstantSDNode *RHS = dyn_cast<ConstantSDNode>(N->getOperand (1 ));
4158
- if (!RHS)
4164
+ // For C >= 32
4165
+ // i64 (sra x, C) -> (build_pair (sra hi_32(x), C - 32), (sra hi_32(x), 31))
4166
+
4167
+ // On some subtargets, 64-bit shift is a quarter rate instruction. In the
4168
+ // common case, splitting this into a move and a 32-bit shift is faster and
4169
+ // the same code size.
4170
+ KnownBits Known = DAG.computeKnownBits (RHS);
4171
+
4172
+ EVT ElementType = VT.getScalarType ();
4173
+ EVT TargetScalarType = ElementType.getHalfSizedIntegerVT (*DAG.getContext ());
4174
+ EVT TargetType = VT.isVector () ? VT.changeVectorElementType (TargetScalarType)
4175
+ : TargetScalarType;
4176
+
4177
+ if (Known.getMinValue ().getZExtValue () < TargetScalarType.getSizeInBits ())
4159
4178
return SDValue ();
4160
4179
4161
- SelectionDAG &DAG = DCI.DAG ;
4162
- SDLoc SL (N);
4163
- unsigned RHSVal = RHS->getZExtValue ();
4180
+ SDValue ShiftFullAmt =
4181
+ DAG.getConstant (TargetScalarType.getSizeInBits () - 1 , SL, TargetType);
4182
+ SDValue ShiftAmt;
4183
+ if (CRHS) {
4184
+ unsigned RHSVal = CRHS->getZExtValue ();
4185
+ ShiftAmt = DAG.getConstant (RHSVal - TargetScalarType.getSizeInBits (), SL,
4186
+ TargetType);
4187
+ } else if (Known.getMinValue ().getZExtValue () ==
4188
+ (ElementType.getSizeInBits () - 1 )) {
4189
+ ShiftAmt = ShiftFullAmt;
4190
+ } else {
4191
+ SDValue truncShiftAmt = DAG.getNode (ISD::TRUNCATE, SL, TargetType, RHS);
4192
+ const SDValue ShiftMask =
4193
+ DAG.getConstant (TargetScalarType.getSizeInBits () - 1 , SL, TargetType);
4194
+ // This AND instruction will clamp out of bounds shift values.
4195
+ // It will also be removed during later instruction selection.
4196
+ ShiftAmt = DAG.getNode (ISD::AND, SL, TargetType, truncShiftAmt, ShiftMask);
4197
+ }
4164
4198
4165
- // For C >= 32
4166
- // (sra i64:x, C) -> build_pair (sra hi_32(x), C - 32), (sra hi_32(x), 31)
4167
- if (RHSVal >= 32 ) {
4168
- SDValue Hi = getHiHalf64 (N->getOperand (0 ), DAG);
4169
- Hi = DAG.getFreeze (Hi);
4170
- SDValue HiShift = DAG.getNode (ISD::SRA, SL, MVT::i32 , Hi,
4171
- DAG.getConstant (31 , SL, MVT::i32 ));
4172
- SDValue LoShift = DAG.getNode (ISD::SRA, SL, MVT::i32 , Hi,
4173
- DAG.getConstant (RHSVal - 32 , SL, MVT::i32 ));
4199
+ EVT ConcatType;
4200
+ SDValue Hi;
4201
+ SDLoc LHSSL (LHS);
4202
+ // Bitcast LHS into ConcatType so hi-half of source can be extracted into Hi
4203
+ if (VT.isVector ()) {
4204
+ unsigned NElts = TargetType.getVectorNumElements ();
4205
+ ConcatType = TargetType.getDoubleNumVectorElementsVT (*DAG.getContext ());
4206
+ SDValue SplitLHS = DAG.getNode (ISD::BITCAST, LHSSL, ConcatType, LHS);
4207
+ SmallVector<SDValue, 8 > HiOps (NElts);
4208
+ SmallVector<SDValue, 16 > HiAndLoOps;
4174
4209
4175
- SDValue BuildVec = DAG.getBuildVector (MVT::v2i32, SL, {LoShift, HiShift});
4176
- return DAG.getNode (ISD::BITCAST, SL, MVT::i64 , BuildVec);
4210
+ DAG.ExtractVectorElements (SplitLHS, HiAndLoOps, 0 , NElts * 2 );
4211
+ for (unsigned I = 0 ; I != NElts; ++I) {
4212
+ HiOps[I] = HiAndLoOps[2 * I + 1 ];
4213
+ }
4214
+ Hi = DAG.getNode (ISD::BUILD_VECTOR, LHSSL, TargetType, HiOps);
4215
+ } else {
4216
+ const SDValue One = DAG.getConstant (1 , LHSSL, TargetScalarType);
4217
+ ConcatType = EVT::getVectorVT (*DAG.getContext (), TargetType, 2 );
4218
+ SDValue SplitLHS = DAG.getNode (ISD::BITCAST, LHSSL, ConcatType, LHS);
4219
+ Hi = DAG.getNode (ISD::EXTRACT_VECTOR_ELT, LHSSL, TargetType, SplitLHS, One);
4177
4220
}
4221
+ Hi = DAG.getFreeze (Hi);
4178
4222
4179
- return SDValue ();
4223
+ SDValue HiShift = DAG.getNode (ISD::SRA, SL, TargetType, Hi, ShiftFullAmt);
4224
+ SDValue NewShift = DAG.getNode (ISD::SRA, SL, TargetType, Hi, ShiftAmt);
4225
+
4226
+ SDValue Vec;
4227
+ if (VT.isVector ()) {
4228
+ unsigned NElts = TargetType.getVectorNumElements ();
4229
+ SmallVector<SDValue, 8 > HiOps;
4230
+ SmallVector<SDValue, 8 > LoOps;
4231
+ SmallVector<SDValue, 16 > HiAndLoOps (NElts * 2 );
4232
+
4233
+ DAG.ExtractVectorElements (HiShift, HiOps, 0 , NElts);
4234
+ DAG.ExtractVectorElements (NewShift, LoOps, 0 , NElts);
4235
+ for (unsigned I = 0 ; I != NElts; ++I) {
4236
+ HiAndLoOps[2 * I + 1 ] = HiOps[I];
4237
+ HiAndLoOps[2 * I] = LoOps[I];
4238
+ }
4239
+ Vec = DAG.getNode (ISD::BUILD_VECTOR, SL, ConcatType, HiAndLoOps);
4240
+ } else {
4241
+ Vec = DAG.getBuildVector (ConcatType, SL, {NewShift, HiShift});
4242
+ }
4243
+ return DAG.getNode (ISD::BITCAST, SL, VT, Vec);
4180
4244
}
4181
4245
4182
4246
SDValue AMDGPUTargetLowering::performSrlCombine (SDNode *N,
@@ -4213,7 +4277,7 @@ SDValue AMDGPUTargetLowering::performSrlCombine(SDNode *N,
4213
4277
return SDValue ();
4214
4278
4215
4279
// for C >= 32
4216
- // i64 (srl x, C) -> (build_pair (srl hi_32(x), C -32), 0)
4280
+ // i64 (srl x, C) -> (build_pair (srl hi_32(x), C - 32), 0)
4217
4281
4218
4282
// On some subtargets, 64-bit shift is a quarter rate instruction. In the
4219
4283
// common case, splitting this into a move and a 32-bit shift is faster and
@@ -5265,25 +5329,22 @@ SDValue AMDGPUTargetLowering::PerformDAGCombine(SDNode *N,
5265
5329
break ;
5266
5330
}
5267
5331
case ISD::SHL:
5332
+ case ISD::SRA:
5268
5333
case ISD::SRL: {
5269
5334
// Range metadata can be invalidated when loads are converted to legal types
5270
5335
// (e.g. v2i64 -> v4i32).
5271
- // Try to convert vector shl/srl before type legalization so that range
5336
+ // Try to convert vector shl/sra/srl before type legalization so that range
5272
5337
// metadata can be utilized.
5273
5338
if (!(N->getValueType (0 ).isVector () &&
5274
5339
DCI.getDAGCombineLevel () == BeforeLegalizeTypes) &&
5275
5340
DCI.getDAGCombineLevel () < AfterLegalizeDAG)
5276
5341
break ;
5277
5342
if (N->getOpcode () == ISD::SHL)
5278
5343
return performShlCombine (N, DCI);
5344
+ if (N->getOpcode () == ISD::SRA)
5345
+ return performSraCombine (N, DCI);
5279
5346
return performSrlCombine (N, DCI);
5280
5347
}
5281
- case ISD::SRA: {
5282
- if (DCI.getDAGCombineLevel () < AfterLegalizeDAG)
5283
- break ;
5284
-
5285
- return performSraCombine (N, DCI);
5286
- }
5287
5348
case ISD::TRUNCATE:
5288
5349
return performTruncateCombine (N, DCI);
5289
5350
case ISD::MUL:
0 commit comments