Skip to content

Commit 646ffe3

Browse files
committed
Address comments, and refactor where certain checks are performed
1 parent 073fa96 commit 646ffe3

File tree

3 files changed

+255
-38
lines changed

3 files changed

+255
-38
lines changed

llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp

Lines changed: 35 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -108,8 +108,15 @@ static bool isNeg(Value *V);
108108
static Value *getNegOperand(Value *V);
109109

110110
namespace {
111+
template<typename T, typename IterT>
112+
std::optional<T> findCommonBetweenCollections(IterT A, IterT B) {
113+
auto Common = llvm::find_if(A, [B](T I){return llvm::is_contained(B, I);});
114+
if (Common != A.end())
115+
return std::make_optional(*Common);
116+
return std::nullopt;
117+
}
111118

112-
class ComplexDeinterleavingLegacyPass : public FunctionPass {
119+
class ComplexDeinterleavingLegacyPass : public FunctionPass {
113120
public:
114121
static char ID;
115122

@@ -337,7 +344,7 @@ class ComplexDeinterleavingGraph {
337344
NodePtr identifyPartialReduction(Value *R, Value *I);
338345
NodePtr identifyDotProduct(Value *Inst);
339346

340-
NodePtr identifyNode(Value *R, Value *I, bool *FromCache = nullptr);
347+
NodePtr identifyNode(Value *R, Value *I);
341348

342349
/// Determine if a sum of complex numbers can be formed from \p RealAddends
343350
/// and \p ImagAddens. If \p Accumulator is not null, add the result to it.
@@ -902,16 +909,16 @@ ComplexDeinterleavingGraph::identifySymmetricOperation(Instruction *Real,
902909

903910
ComplexDeinterleavingGraph::NodePtr
904911
ComplexDeinterleavingGraph::identifyDotProduct(Value *V) {
905-
auto *Inst = cast<Instruction>(V);
906912

907913
if (!TL->isComplexDeinterleavingOperationSupported(
908-
ComplexDeinterleavingOperation::CDot, Inst->getType())) {
914+
ComplexDeinterleavingOperation::CDot, V->getType())) {
909915
LLVM_DEBUG(dbgs() << "Target doesn't support complex deinterleaving "
910916
"operation CDot with the type "
911-
<< *Inst->getType() << "\n");
917+
<< *V->getType() << "\n");
912918
return nullptr;
913919
}
914920

921+
auto *Inst = cast<Instruction>(V);
915922
auto *RealUser = cast<Instruction>(*Inst->user_begin());
916923

917924
NodePtr CN =
@@ -987,13 +994,26 @@ ComplexDeinterleavingGraph::identifyDotProduct(Value *V) {
987994
BReal = UnwrapCast(BReal);
988995
BImag = UnwrapCast(BImag);
989996

990-
bool WasANodeFromCache = false;
991-
NodePtr Node = identifyNode(AReal, AImag, &WasANodeFromCache);
997+
VectorType *VTy = cast<VectorType>(V->getType());
998+
Type *ExpectedOperandTy = VectorType::getSubdividedVectorType(VTy, 2);
999+
if (AReal->getType() != ExpectedOperandTy)
1000+
return nullptr;
1001+
if (AImag->getType() != ExpectedOperandTy)
1002+
return nullptr;
1003+
if (BReal->getType() != ExpectedOperandTy)
1004+
return nullptr;
1005+
if (BImag->getType() != ExpectedOperandTy)
1006+
return nullptr;
1007+
1008+
if (Phi->getType() != VTy && RealUser->getType() != VTy)
1009+
return nullptr;
1010+
1011+
NodePtr Node = identifyNode(AReal, AImag);
9921012

9931013
// In the case that a node was identified to figure out the rotation, ensure
9941014
// that trying to identify a node with AReal and AImag post-unwrap results in
9951015
// the same node
996-
if (Node && ANode && !WasANodeFromCache) {
1016+
if (ANode && Node != ANode) {
9971017
LLVM_DEBUG(
9981018
dbgs()
9991019
<< "Identified node is different from previously identified node. "
@@ -1010,38 +1030,17 @@ ComplexDeinterleavingGraph::identifyDotProduct(Value *V) {
10101030

10111031
ComplexDeinterleavingGraph::NodePtr
10121032
ComplexDeinterleavingGraph::identifyPartialReduction(Value *R, Value *I) {
1013-
if (!I->hasOneUser())
1033+
// Partial reductions don't support non-vector types, so check these first
1034+
if (!isa<VectorType>(R->getType()) || !isa<VectorType>(I->getType()))
10141035
return nullptr;
10151036

1016-
VectorType *RealTy = dyn_cast<VectorType>(R->getType());
1017-
if (!RealTy)
1018-
return nullptr;
1019-
VectorType *ImagTy = dyn_cast<VectorType>(I->getType());
1020-
if (!ImagTy)
1021-
return nullptr;
1022-
1023-
if (RealTy->isScalableTy() != ImagTy->isScalableTy())
1024-
return nullptr;
1025-
if (RealTy->getElementType() != ImagTy->getElementType())
1026-
return nullptr;
1027-
1028-
// `I` is known to only have one user, so iterate over the Phi (R) users to
1029-
// find the common user between R and I
1030-
auto *CommonUser = *I->user_begin();
1031-
bool CommonUserFound = false;
1032-
for (auto *User : R->users()) {
1033-
if (User == CommonUser) {
1034-
CommonUserFound = true;
1035-
break;
1036-
}
1037-
}
1038-
1039-
if (!CommonUserFound)
1037+
auto CommonUser = findCommonBetweenCollections<Value*>(R->users(), I->users());
1038+
if (!CommonUser)
10401039
return nullptr;
10411040

1042-
auto *IInst = dyn_cast<IntrinsicInst>(CommonUser);
1041+
auto *IInst = dyn_cast<IntrinsicInst>(*CommonUser);
10431042
if (!IInst || IInst->getIntrinsicID() !=
1044-
Intrinsic::experimental_vector_partial_reduce_add)
1043+
Intrinsic::experimental_vector_partial_reduce_add)
10451044
return nullptr;
10461045

10471046
if (NodePtr CN = identifyDotProduct(IInst))
@@ -1051,12 +1050,10 @@ ComplexDeinterleavingGraph::identifyPartialReduction(Value *R, Value *I) {
10511050
}
10521051

10531052
ComplexDeinterleavingGraph::NodePtr
1054-
ComplexDeinterleavingGraph::identifyNode(Value *R, Value *I, bool *FromCache) {
1053+
ComplexDeinterleavingGraph::identifyNode(Value *R, Value *I) {
10551054
auto It = CachedResult.find({R, I});
10561055
if (It != CachedResult.end()) {
10571056
LLVM_DEBUG(dbgs() << " - Folding to existing node\n");
1058-
if (FromCache != nullptr)
1059-
*FromCache = true;
10601057
return It->second;
10611058
}
10621059

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29414,6 +29414,10 @@ bool AArch64TargetLowering::isComplexDeinterleavingOperationSupported(
2941429414
return 8 <= ScalarWidth && ScalarWidth <= 64;
2941529415
}
2941629416

29417+
// CDot is not supported outside of scalable/sve scopes
29418+
if (Operation == ComplexDeinterleavingOperation::CDot)
29419+
return false;
29420+
2941729421
return (ScalarTy->isHalfTy() && Subtarget->hasFullFP16()) ||
2941829422
ScalarTy->isFloatTy() || ScalarTy->isDoubleTy();
2941929423
}

0 commit comments

Comments
 (0)