@@ -10946,6 +10946,9 @@ calculateSrcByte(const SDValue Op, uint64_t DestByte, uint64_t SrcIndex = 0,
10946
10946
if (Op.getValueSizeInBits() < 8)
10947
10947
return std::nullopt;
10948
10948
10949
+ if (Op.getValueType().isVector())
10950
+ return ByteProvider<SDValue>::getSrc(Op, DestByte, SrcIndex, IsSigned);
10951
+
10949
10952
switch (Op->getOpcode()) {
10950
10953
case ISD::TRUNCATE: {
10951
10954
return calculateSrcByte(Op->getOperand(0), DestByte, SrcIndex, IsSigned,
@@ -11031,8 +11034,12 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
11031
11034
if (Index > BitWidth / 8 - 1)
11032
11035
return std::nullopt;
11033
11036
11037
+ bool IsVec = Op.getValueType().isVector();
11034
11038
switch (Op.getOpcode()) {
11035
11039
case ISD::OR: {
11040
+ if (IsVec)
11041
+ return std::nullopt;
11042
+
11036
11043
auto RHS = calculateByteProvider(Op.getOperand(1), Index, Depth + 1,
11037
11044
StartingIndex, IsSigned);
11038
11045
if (!RHS)
@@ -11053,6 +11060,9 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
11053
11060
}
11054
11061
11055
11062
case ISD::AND: {
11063
+ if (IsVec)
11064
+ return std::nullopt;
11065
+
11056
11066
auto BitMaskOp = dyn_cast<ConstantSDNode>(Op->getOperand(1));
11057
11067
if (!BitMaskOp)
11058
11068
return std::nullopt;
@@ -11073,6 +11083,9 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
11073
11083
}
11074
11084
11075
11085
case ISD::FSHR: {
11086
+ if (IsVec)
11087
+ return std::nullopt;
11088
+
11076
11089
// fshr(X,Y,Z): (X << (BW - (Z % BW))) | (Y >> (Z % BW))
11077
11090
auto ShiftOp = dyn_cast<ConstantSDNode>(Op->getOperand(2));
11078
11091
if (!ShiftOp || Op.getValueType().isVector())
@@ -11098,6 +11111,9 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
11098
11111
11099
11112
case ISD::SRA:
11100
11113
case ISD::SRL: {
11114
+ if (IsVec)
11115
+ return std::nullopt;
11116
+
11101
11117
auto ShiftOp = dyn_cast<ConstantSDNode>(Op->getOperand(1));
11102
11118
if (!ShiftOp)
11103
11119
return std::nullopt;
@@ -11123,6 +11139,9 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
11123
11139
}
11124
11140
11125
11141
case ISD::SHL: {
11142
+ if (IsVec)
11143
+ return std::nullopt;
11144
+
11126
11145
auto ShiftOp = dyn_cast<ConstantSDNode>(Op->getOperand(1));
11127
11146
if (!ShiftOp)
11128
11147
return std::nullopt;
@@ -11147,6 +11166,9 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
11147
11166
case ISD::SIGN_EXTEND_INREG:
11148
11167
case ISD::AssertZext:
11149
11168
case ISD::AssertSext: {
11169
+ if (IsVec)
11170
+ return std::nullopt;
11171
+
11150
11172
SDValue NarrowOp = Op->getOperand(0);
11151
11173
unsigned NarrowBitWidth = NarrowOp.getValueSizeInBits();
11152
11174
if (Op->getOpcode() == ISD::SIGN_EXTEND_INREG ||
@@ -11177,6 +11199,9 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
11177
11199
}
11178
11200
11179
11201
case ISD::TRUNCATE: {
11202
+ if (IsVec)
11203
+ return std::nullopt;
11204
+
11180
11205
uint64_t NarrowByteWidth = BitWidth / 8;
11181
11206
11182
11207
if (NarrowByteWidth >= Index) {
@@ -11223,9 +11248,13 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
11223
11248
return std::nullopt;
11224
11249
}
11225
11250
11226
- case ISD::BSWAP:
11251
+ case ISD::BSWAP: {
11252
+ if (IsVec)
11253
+ return std::nullopt;
11254
+
11227
11255
return calculateByteProvider(Op->getOperand(0), BitWidth / 8 - Index - 1,
11228
11256
Depth + 1, StartingIndex, IsSigned);
11257
+ }
11229
11258
11230
11259
case ISD::EXTRACT_VECTOR_ELT: {
11231
11260
auto IdxOp = dyn_cast<ConstantSDNode>(Op->getOperand(1));
@@ -11242,6 +11271,9 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
11242
11271
}
11243
11272
11244
11273
case AMDGPUISD::PERM: {
11274
+ if (IsVec)
11275
+ return std::nullopt;
11276
+
11245
11277
auto PermMask = dyn_cast<ConstantSDNode>(Op->getOperand(2));
11246
11278
if (!PermMask)
11247
11279
return std::nullopt;
@@ -11339,25 +11371,55 @@ static bool hasNon16BitAccesses(uint64_t PermMask, SDValue &Op,
11339
11371
static SDValue getDWordFromOffset(SelectionDAG &DAG, SDLoc SL, SDValue Src,
11340
11372
unsigned DWordOffset) {
11341
11373
SDValue Ret;
11342
- if (Src.getValueSizeInBits() <= 32)
11343
- return DAG.getBitcastedAnyExtOrTrunc(Src, SL, MVT::i32);
11344
11374
11345
- if (Src.getValueSizeInBits() >= 256) {
11346
- assert(!(Src.getValueSizeInBits() % 32));
11347
- Ret = DAG.getBitcast(
11348
- MVT::getVectorVT(MVT::i32, Src.getValueSizeInBits() / 32), Src);
11349
- return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL, MVT::i32, Ret,
11350
- DAG.getConstant(DWordOffset, SL, MVT::i32));
11351
- }
11375
+ auto TypeSize = Src.getValueSizeInBits().getFixedValue();
11376
+ // ByteProvider must be at least 8 bits
11377
+ assert(Src.getValueSizeInBits().isKnownMultipleOf(8));
11352
11378
11353
- Ret = DAG.getBitcastedAnyExtOrTrunc(
11354
- Src, SL, MVT::getIntegerVT(Src.getValueSizeInBits()));
11355
- if (DWordOffset) {
11356
- auto Shifted = DAG.getNode(ISD::SRL, SL, Ret.getValueType(), Ret,
11357
- DAG.getConstant(DWordOffset * 32, SL, MVT::i32));
11358
- return DAG.getNode(ISD::TRUNCATE, SL, MVT::i32, Shifted);
11359
- }
11379
+ if (TypeSize <= 32)
11380
+ return DAG.getBitcastedAnyExtOrTrunc(Src, SL, MVT::i32);
11360
11381
11382
+ if (Src.getValueType().isVector()) {
11383
+ auto ScalarTySize = Src.getScalarValueSizeInBits();
11384
+ auto ScalarTy = Src.getValueType().getScalarType();
11385
+ if (ScalarTySize == 32) {
11386
+ return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL, MVT::i32, Src,
11387
+ DAG.getConstant(DWordOffset, SL, MVT::i32));
11388
+ }
11389
+ if (ScalarTySize > 32) {
11390
+ Ret = DAG.getNode(
11391
+ ISD::EXTRACT_VECTOR_ELT, SL, ScalarTy, Src,
11392
+ DAG.getConstant(DWordOffset / (ScalarTySize / 32), SL, MVT::i32));
11393
+ auto ShiftVal = 32 * (DWordOffset % (ScalarTySize / 32));
11394
+ if (ShiftVal)
11395
+ Ret = DAG.getNode(ISD::SRL, SL, Ret.getValueType(), Ret,
11396
+ DAG.getConstant(ShiftVal, SL, MVT::i32));
11397
+ return DAG.getBitcastedAnyExtOrTrunc(Ret, SL, MVT::i32);
11398
+ }
11399
+
11400
+ assert(ScalarTySize < 32);
11401
+ auto NumElements = TypeSize / ScalarTySize;
11402
+ auto Trunc32Elements = (ScalarTySize * NumElements) / 32;
11403
+ auto NormalizedTrunc = Trunc32Elements * 32 / ScalarTySize;
11404
+ auto NumElementsIn32 = 32 / ScalarTySize;
11405
+ auto NumAvailElements = DWordOffset < Trunc32Elements
11406
+ ? NumElementsIn32
11407
+ : NumElements - NormalizedTrunc;
11408
+
11409
+ SmallVector<SDValue, 4> VecSrcs;
11410
+ DAG.ExtractVectorElements(Src, VecSrcs, DWordOffset * NumElementsIn32,
11411
+ NumAvailElements);
11412
+
11413
+ Ret = DAG.getBuildVector(
11414
+ MVT::getVectorVT(MVT::getIntegerVT(ScalarTySize), NumAvailElements), SL,
11415
+ VecSrcs);
11416
+ return Ret = DAG.getBitcastedAnyExtOrTrunc(Ret, SL, MVT::i32);
11417
+ }
11418
+
11419
+ /// Scalar Type
11420
+ auto ShiftVal = 32 * DWordOffset;
11421
+ Ret = DAG.getNode(ISD::SRL, SL, Src.getValueType(), Src,
11422
+ DAG.getConstant(ShiftVal, SL, MVT::i32));
11361
11423
return DAG.getBitcastedAnyExtOrTrunc(Ret, SL, MVT::i32);
11362
11424
}
11363
11425
@@ -11426,13 +11488,12 @@ static SDValue matchPERM(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
11426
11488
return DAG.getBitcast(MVT::getIntegerVT(32), Op);
11427
11489
}
11428
11490
11429
- SDValue OtherOp =
11430
- SecondSrc.has_value() ? *PermNodes[SecondSrc->first].Src : Op;
11491
+ SDValue OtherOp = SecondSrc ? *PermNodes[SecondSrc->first].Src : Op;
11431
11492
11432
- if (SecondSrc)
11493
+ if (SecondSrc) {
11433
11494
OtherOp = getDWordFromOffset(DAG, DL, OtherOp, SecondSrc->second);
11434
-
11435
- assert(Op.getValueSizeInBits() == 32);
11495
+ assert(OtherOp.getValueSizeInBits() == 32);
11496
+ }
11436
11497
11437
11498
if (hasNon16BitAccesses(PermMask, Op, OtherOp)) {
11438
11499
0 commit comments