@@ -108,8 +108,15 @@ static bool isNeg(Value *V);
108
108
static Value *getNegOperand (Value *V);
109
109
110
110
namespace {
111
+ template <typename T, typename IterT> // Helper: find the first element of one range that also occurs in another.
112
+ std::optional<T> findCommonBetweenCollections (IterT A, IterT B) { // A and B are ranges of the same type, taken by value (despite the IterT name).
113
+ auto Common = llvm::find_if (A, [B](T I){return llvm::is_contained (B, I);}); // Walk A in order; each element is passed as a T and looked up in the lambda's own copy of B.
114
+ if (Common != A.end ()) // The search stopped before the end of A, so a shared element exists.
115
+ return std::make_optional (*Common); // Return the matching element of A.
116
+ return std::nullopt; // No element of A is contained in B.
117
+ }
111
118
112
- class ComplexDeinterleavingLegacyPass : public FunctionPass {
119
+ class ComplexDeinterleavingLegacyPass : public FunctionPass {
113
120
public:
114
121
static char ID;
115
122
@@ -337,7 +344,7 @@ class ComplexDeinterleavingGraph {
337
344
NodePtr identifyPartialReduction (Value *R, Value *I);
338
345
NodePtr identifyDotProduct (Value *Inst);
339
346
340
- NodePtr identifyNode (Value *R, Value *I, bool *FromCache = nullptr );
347
+ NodePtr identifyNode (Value *R, Value *I);
341
348
342
349
// / Determine if a sum of complex numbers can be formed from \p RealAddends
343
350
// / and \p ImagAddends. If \p Accumulator is not null, add the result to it.
@@ -902,16 +909,16 @@ ComplexDeinterleavingGraph::identifySymmetricOperation(Instruction *Real,
902
909
903
910
ComplexDeinterleavingGraph::NodePtr
904
911
ComplexDeinterleavingGraph::identifyDotProduct (Value *V) {
905
- auto *Inst = cast<Instruction>(V);
906
912
907
913
if (!TL->isComplexDeinterleavingOperationSupported (
908
- ComplexDeinterleavingOperation::CDot, Inst ->getType ())) {
914
+ ComplexDeinterleavingOperation::CDot, V ->getType ())) {
909
915
LLVM_DEBUG (dbgs () << " Target doesn't support complex deinterleaving "
910
916
" operation CDot with the type "
911
- << *Inst ->getType () << " \n " );
917
+ << *V ->getType () << " \n " );
912
918
return nullptr ;
913
919
}
914
920
921
+ auto *Inst = cast<Instruction>(V);
915
922
auto *RealUser = cast<Instruction>(*Inst->user_begin ());
916
923
917
924
NodePtr CN =
@@ -987,13 +994,26 @@ ComplexDeinterleavingGraph::identifyDotProduct(Value *V) {
987
994
BReal = UnwrapCast (BReal);
988
995
BImag = UnwrapCast (BImag);
989
996
990
- bool WasANodeFromCache = false ;
991
- NodePtr Node = identifyNode (AReal, AImag, &WasANodeFromCache);
997
+ VectorType *VTy = cast<VectorType>(V->getType ());
998
+ Type *ExpectedOperandTy = VectorType::getSubdividedVectorType (VTy, 2 );
999
+ if (AReal->getType () != ExpectedOperandTy)
1000
+ return nullptr ;
1001
+ if (AImag->getType () != ExpectedOperandTy)
1002
+ return nullptr ;
1003
+ if (BReal->getType () != ExpectedOperandTy)
1004
+ return nullptr ;
1005
+ if (BImag->getType () != ExpectedOperandTy)
1006
+ return nullptr ;
1007
+
1008
+ if (Phi->getType () != VTy && RealUser->getType () != VTy)
1009
+ return nullptr ;
1010
+
1011
+ NodePtr Node = identifyNode (AReal, AImag);
992
1012
993
1013
// In the case that a node was identified to figure out the rotation, ensure
994
1014
// that trying to identify a node with AReal and AImag post-unwrap results in
995
1015
// the same node
996
- if (Node && ANode && !WasANodeFromCache ) {
1016
+ if (ANode && Node != ANode ) {
997
1017
LLVM_DEBUG (
998
1018
dbgs ()
999
1019
<< " Identified node is different from previously identified node. "
@@ -1010,38 +1030,17 @@ ComplexDeinterleavingGraph::identifyDotProduct(Value *V) {
1010
1030
1011
1031
ComplexDeinterleavingGraph::NodePtr
1012
1032
ComplexDeinterleavingGraph::identifyPartialReduction (Value *R, Value *I) {
1013
- if (!I->hasOneUser ())
1033
+ // Partial reductions don't support non-vector types, so check these first
1034
+ if (!isa<VectorType>(R->getType ()) || !isa<VectorType>(I->getType ()))
1014
1035
return nullptr ;
1015
1036
1016
- VectorType *RealTy = dyn_cast<VectorType>(R->getType ());
1017
- if (!RealTy)
1018
- return nullptr ;
1019
- VectorType *ImagTy = dyn_cast<VectorType>(I->getType ());
1020
- if (!ImagTy)
1021
- return nullptr ;
1022
-
1023
- if (RealTy->isScalableTy () != ImagTy->isScalableTy ())
1024
- return nullptr ;
1025
- if (RealTy->getElementType () != ImagTy->getElementType ())
1026
- return nullptr ;
1027
-
1028
- // `I` is known to only have one user, so iterate over the Phi (R) users to
1029
- // find the common user between R and I
1030
- auto *CommonUser = *I->user_begin ();
1031
- bool CommonUserFound = false ;
1032
- for (auto *User : R->users ()) {
1033
- if (User == CommonUser) {
1034
- CommonUserFound = true ;
1035
- break ;
1036
- }
1037
- }
1038
-
1039
- if (!CommonUserFound)
1037
+ auto CommonUser = findCommonBetweenCollections<Value*>(R->users (), I->users ());
1038
+ if (!CommonUser)
1040
1039
return nullptr ;
1041
1040
1042
- auto *IInst = dyn_cast<IntrinsicInst>(CommonUser);
1041
+ auto *IInst = dyn_cast<IntrinsicInst>(* CommonUser);
1043
1042
if (!IInst || IInst->getIntrinsicID () !=
1044
- Intrinsic::experimental_vector_partial_reduce_add)
1043
+ Intrinsic::experimental_vector_partial_reduce_add)
1045
1044
return nullptr ;
1046
1045
1047
1046
if (NodePtr CN = identifyDotProduct (IInst))
@@ -1051,12 +1050,10 @@ ComplexDeinterleavingGraph::identifyPartialReduction(Value *R, Value *I) {
1051
1050
}
1052
1051
1053
1052
ComplexDeinterleavingGraph::NodePtr
1054
- ComplexDeinterleavingGraph::identifyNode (Value *R, Value *I, bool *FromCache ) {
1053
+ ComplexDeinterleavingGraph::identifyNode (Value *R, Value *I) {
1055
1054
auto It = CachedResult.find ({R, I});
1056
1055
if (It != CachedResult.end ()) {
1057
1056
LLVM_DEBUG (dbgs () << " - Folding to existing node\n " );
1058
- if (FromCache != nullptr )
1059
- *FromCache = true ;
1060
1057
return It->second ;
1061
1058
}
1062
1059
0 commit comments