@@ -197,6 +197,11 @@ struct VectorLayout {
197
197
uint64_t SplitSize = 0 ;
198
198
};
199
199
200
+ static bool isStructOfVectors (Type *Ty) {
201
+ return isa<StructType>(Ty) && Ty->getNumContainedTypes () > 0 &&
202
+ isa<FixedVectorType>(Ty->getContainedType (0 ));
203
+ }
204
+
200
205
// / Concatenate the given fragments to a single vector value of the type
201
206
// / described in @p VS.
202
207
static Value *concatenate (IRBuilder<> &Builder, ArrayRef<Value *> Fragments,
@@ -276,6 +281,7 @@ class ScalarizerVisitor : public InstVisitor<ScalarizerVisitor, bool> {
276
281
bool visitBitCastInst (BitCastInst &BCI);
277
282
bool visitInsertElementInst (InsertElementInst &IEI);
278
283
bool visitExtractElementInst (ExtractElementInst &EEI);
284
+ bool visitExtractValueInst (ExtractValueInst &EVI);
279
285
bool visitShuffleVectorInst (ShuffleVectorInst &SVI);
280
286
bool visitPHINode (PHINode &PHI);
281
287
bool visitLoadInst (LoadInst &LI);
@@ -552,7 +558,10 @@ void ScalarizerVisitor::transferMetadataAndIRFlags(Instruction *Op,
552
558
// Determine how Ty is split, if at all.
553
559
std::optional<VectorSplit> ScalarizerVisitor::getVectorSplit (Type *Ty) {
554
560
VectorSplit Split;
555
- Split.VecTy = dyn_cast<FixedVectorType>(Ty);
561
+ if (isStructOfVectors (Ty))
562
+ Split.VecTy = cast<FixedVectorType>(Ty->getContainedType (0 ));
563
+ else
564
+ Split.VecTy = dyn_cast<FixedVectorType>(Ty);
556
565
if (!Split.VecTy )
557
566
return {};
558
567
@@ -1030,6 +1039,33 @@ bool ScalarizerVisitor::visitInsertElementInst(InsertElementInst &IEI) {
1030
1039
return true ;
1031
1040
}
1032
1041
1042
+ bool ScalarizerVisitor::visitExtractValueInst (ExtractValueInst &EVI) {
1043
+ Value *Op = EVI.getOperand (0 );
1044
+ Type *OpTy = Op->getType ();
1045
+ ValueVector Res;
1046
+ if (!isStructOfVectors (OpTy))
1047
+ return false ;
1048
+ // Note: isStructOfVectors is also used in getVectorSplit.
1049
+ // The intent is to bail on this visit if it isn't a struct
1050
+ // of vectors. Downside is that when it is true we do two
1051
+ // isStructOfVectors calls.
1052
+ std::optional<VectorSplit> VS = getVectorSplit (OpTy);
1053
+ if (!VS)
1054
+ return false ;
1055
+ Scatterer Op0 = scatter (&EVI, Op, *VS);
1056
+ assert (!EVI.getIndices ().empty () && " Make sure an index exists" );
1057
+ // Note for our use case we only care about the top level index.
1058
+ unsigned Index = EVI.getIndices ()[0 ];
1059
+ for (unsigned OpIdx = 0 ; OpIdx < Op0.size (); ++OpIdx) {
1060
+ Value *ResElem = Builder.CreateExtractValue (
1061
+ Op0[OpIdx], Index, EVI.getName () + " .elem" + std::to_string (Index));
1062
+ Res.push_back (ResElem);
1063
+ }
1064
+ // replaceUses(&EVI, Res);
1065
+ gather (&EVI, Res, *VS);
1066
+ return true ;
1067
+ }
1068
+
1033
1069
bool ScalarizerVisitor::visitExtractElementInst (ExtractElementInst &EEI) {
1034
1070
std::optional<VectorSplit> VS = getVectorSplit (EEI.getOperand (0 )->getType ());
1035
1071
if (!VS)
@@ -1196,7 +1232,7 @@ bool ScalarizerVisitor::finish() {
1196
1232
if (!Op->use_empty ()) {
1197
1233
// The value is still needed, so recreate it using a series of
1198
1234
// insertelements and/or shufflevectors.
1199
- Value *Res;
1235
+ Value *Res = nullptr ;
1200
1236
if (auto *Ty = dyn_cast<FixedVectorType>(Op->getType ())) {
1201
1237
BasicBlock *BB = Op->getParent ();
1202
1238
IRBuilder<> Builder (Op);
@@ -1209,6 +1245,35 @@ bool ScalarizerVisitor::finish() {
1209
1245
Res = concatenate (Builder, CV, VS, Op->getName ());
1210
1246
1211
1247
Res->takeName (Op);
1248
+ } else if (auto *Ty = dyn_cast<StructType>(Op->getType ())) {
1249
+ BasicBlock *BB = Op->getParent ();
1250
+ IRBuilder<> Builder (Op);
1251
+ if (isa<PHINode>(Op))
1252
+ Builder.SetInsertPoint (BB, BB->getFirstInsertionPt ());
1253
+
1254
+ // Iterate over each element in the struct
1255
+ uint NumOfStructElements = Ty->getNumElements ();
1256
+ SmallVector<ValueVector, 4 > ElemCV (NumOfStructElements);
1257
+ for (unsigned I = 0 ; I < NumOfStructElements; ++I) {
1258
+ for (auto *CVelem : CV) {
1259
+ Value *Elem = Builder.CreateExtractValue (
1260
+ CVelem, I, Op->getName () + " .elem" + std::to_string (I));
1261
+ ElemCV[I].push_back (Elem);
1262
+ }
1263
+ }
1264
+ Res = PoisonValue::get (Ty);
1265
+ for (unsigned I = 0 ; I < NumOfStructElements; ++I) {
1266
+ Type *ElemTy = Ty->getElementType (I);
1267
+ assert (isa<FixedVectorType>(ElemTy) &&
1268
+ " Only Structs of all FixedVectorType supported" );
1269
+ VectorSplit VS = *getVectorSplit (ElemTy);
1270
+ assert (VS.NumFragments == CV.size ());
1271
+
1272
+ Value *ConcatenatedVector =
1273
+ concatenate (Builder, ElemCV[I], VS, Op->getName ());
1274
+ Res = Builder.CreateInsertValue (Res, ConcatenatedVector, I,
1275
+ Op->getName () + " .insert" );
1276
+ }
1212
1277
} else {
1213
1278
assert (CV.size () == 1 && Op->getType () == CV[0 ]->getType ());
1214
1279
Res = CV[0 ];
0 commit comments