@@ -200,11 +200,17 @@ struct VectorLayout {
200
200
static bool isStructAllVectors (Type *Ty) {
201
201
if (!isa<StructType>(Ty))
202
202
return false ;
203
-
204
- for (unsigned I = 0 ; I < Ty->getNumContainedTypes (); I++)
205
- if (!isa<FixedVectorType>(Ty->getContainedType (I)))
203
+ if (Ty->getNumContainedTypes () < 1 )
204
+ return false ;
205
+ FixedVectorType *VecTy = dyn_cast<FixedVectorType>(Ty->getContainedType (0 ));
206
+ if (!VecTy)
207
+ return false ;
208
+ unsigned VecSize = VecTy->getNumElements ();
209
+ for (unsigned I = 1 ; I < Ty->getNumContainedTypes (); I++) {
210
+ VecTy = dyn_cast<FixedVectorType>(Ty->getContainedType (I));
211
+ if (!VecTy || VecSize != VecTy->getNumElements ())
206
212
return false ;
207
-
213
+ }
208
214
return true ;
209
215
}
210
216
@@ -679,8 +685,9 @@ bool ScalarizerVisitor::splitBinary(Instruction &I, const Splitter &Split) {
679
685
bool ScalarizerVisitor::isTriviallyScalarizable (Intrinsic::ID ID) {
680
686
if (isTriviallyVectorizable (ID))
681
687
return true ;
688
+ // TODO: investigate vectorizable frexp
682
689
switch (ID) {
683
- case Intrinsic::frexp:
690
+ case Intrinsic::frexp:
684
691
return true ;
685
692
}
686
693
return Intrinsic::isTargetIntrinsic (ID) &&
@@ -690,10 +697,10 @@ bool ScalarizerVisitor::isTriviallyScalarizable(Intrinsic::ID ID) {
690
697
// / If a call to a vector typed intrinsic function, split into a scalar call per
691
698
// / element if possible for the intrinsic.
692
699
bool ScalarizerVisitor::splitCall (CallInst &CI) {
693
- Type* CallType = CI.getType ();
694
- bool areAllVectors = isStructAllVectors (CallType);
695
- std::optional<VectorSplit> VS;
696
- if (areAllVectors )
700
+ Type * CallType = CI.getType ();
701
+ bool AreAllVectors = isStructAllVectors (CallType);
702
+ std::optional<VectorSplit> VS;
703
+ if (AreAllVectors )
697
704
VS = getVectorSplit (CallType->getContainedType (0 ));
698
705
else
699
706
VS = getVectorSplit (CallType);
@@ -721,12 +728,12 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) {
721
728
if (isVectorIntrinsicWithOverloadTypeAtArg (ID, -1 ))
722
729
Tys.push_back (VS->SplitTy );
723
730
724
- if (areAllVectors ) {
725
- Type* PrevType = CallType->getContainedType (0 );
726
- Type* CallType = CI.getType ();
727
- for (unsigned I = 1 ; I < CallType->getNumContainedTypes (); I++) {
728
- Type* CurrType = cast<FixedVectorType>(CallType->getContainedType (I));
729
- if (PrevType != CurrType) {
731
+ if (AreAllVectors ) {
732
+ Type * PrevType = CallType->getContainedType (0 );
733
+ Type * CallType = CI.getType ();
734
+ for (unsigned I = 1 ; I < CallType->getNumContainedTypes (); I++) {
735
+ Type * CurrType = cast<FixedVectorType>(CallType->getContainedType (I));
736
+ if (PrevType != CurrType) {
730
737
std::optional<VectorSplit> CurrVS = getVectorSplit (CurrType);
731
738
Tys.push_back (CurrVS->SplitTy );
732
739
PrevType = CurrType;
@@ -1070,7 +1077,7 @@ bool ScalarizerVisitor::visitExtractValueInst(ExtractValueInst &EVI) {
1070
1077
ValueVector Res;
1071
1078
if (!isStructAllVectors (OpTy))
1072
1079
return false ;
1073
- Type* VecType = cast<FixedVectorType>(OpTy->getContainedType (0 ));
1080
+ Type * VecType = cast<FixedVectorType>(OpTy->getContainedType (0 ));
1074
1081
std::optional<VectorSplit> VS = getVectorSplit (VecType);
1075
1082
if (!VS)
1076
1083
return false ;
@@ -1084,7 +1091,7 @@ bool ScalarizerVisitor::visitExtractValueInst(ExtractValueInst &EVI) {
1084
1091
Op0[OpIdx], Index, EVI.getName () + " .elem" + std::to_string (Index));
1085
1092
Res.push_back (ResElem);
1086
1093
}
1087
- // replaceUses(&EVI, Res);
1094
+
1088
1095
gather (&EVI, Res, *VS);
1089
1096
return true ;
1090
1097
}
0 commit comments