Skip to content

[EquivalenceClasses] Use DenseMap instead of std::set. (NFC) #134264

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 24 additions & 42 deletions llvm/include/llvm/ADT/EquivalenceClasses.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,14 @@
#ifndef LLVM_ADT_EQUIVALENCECLASSES_H
#define LLVM_ADT_EQUIVALENCECLASSES_H

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Allocator.h"
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <set>

namespace llvm {

Expand All @@ -33,8 +34,7 @@ namespace llvm {
///
/// This implementation is an efficient implementation that only stores one copy
/// of the element being indexed per entry in the set, and allows any arbitrary
/// type to be indexed (as long as it can be ordered with operator< or a
/// comparator is provided).
/// type to be indexed (as long as it can be implements DenseMapInfo).
///
/// Here is a simple example using integers:
///
Expand All @@ -58,18 +58,17 @@ namespace llvm {
/// 4
/// 5 1 2
///
template <class ElemTy, class Compare = std::less<ElemTy>>
class EquivalenceClasses {
template <class ElemTy> class EquivalenceClasses {
/// ECValue - The EquivalenceClasses data structure is just a set of these.
/// Each of these represents a relation for a value. First it stores the
/// value itself, which provides the ordering that the set queries. Next, it
/// provides a "next pointer", which is used to enumerate all of the elements
/// in the unioned set. Finally, it defines either a "end of list pointer" or
/// "leader pointer" depending on whether the value itself is a leader. A
/// "leader pointer" points to the node that is the leader for this element,
/// if the node is not a leader. A "end of list pointer" points to the last
/// node in the list of members of this list. Whether or not a node is a
/// leader is determined by a bit stolen from one of the pointers.
/// value itself. Next, it provides a "next pointer", which is used to
/// enumerate all of the elements in the unioned set. Finally, it defines
/// either a "end of list pointer" or "leader pointer" depending on whether
/// the value itself is a leader. A "leader pointer" points to the node that
/// is the leader for this element, if the node is not a leader. A "end of
/// list pointer" points to the last node in the list of members of this list.
/// Whether or not a node is a leader is determined by a bit stolen from one
/// of the pointers.
class ECValue {
friend class EquivalenceClasses;

Expand Down Expand Up @@ -113,36 +112,15 @@ class EquivalenceClasses {
}
};

/// A wrapper of the comparator, to be passed to the set.
struct ECValueComparator {
using is_transparent = void;

ECValueComparator() : compare(Compare()) {}

bool operator()(const ECValue &lhs, const ECValue &rhs) const {
return compare(lhs.Data, rhs.Data);
}

template <typename T>
bool operator()(const T &lhs, const ECValue &rhs) const {
return compare(lhs, rhs.Data);
}

template <typename T>
bool operator()(const ECValue &lhs, const T &rhs) const {
return compare(lhs.Data, rhs);
}

const Compare compare;
};

/// TheMapping - This implicitly provides a mapping from ElemTy values to the
/// ECValues, it just keeps the key as part of the value.
std::set<ECValue, ECValueComparator> TheMapping;
DenseMap<ElemTy, ECValue *> TheMapping;

/// List of all members, used to provide a determinstic iteration order.
SmallVector<const ECValue *> Members;

mutable BumpPtrAllocator ECValueAllocator;

public:
EquivalenceClasses() = default;
EquivalenceClasses(const EquivalenceClasses &RHS) {
Expand Down Expand Up @@ -232,10 +210,14 @@ class EquivalenceClasses {
/// insert - Insert a new value into the union/find set, ignoring the request
/// if the value already exists.
const ECValue &insert(const ElemTy &Data) {
auto I = TheMapping.insert(ECValue(Data));
if (I.second)
Members.push_back(&*I.first);
return *I.first;
auto I = TheMapping.insert({Data, nullptr});
if (!I.second)
return *I.first->second;

auto *ECV = new (ECValueAllocator) ECValue(Data);
I.first->second = ECV;
Members.push_back(ECV);
return *ECV;
}

/// findLeader - Given a value in the set, return a member iterator for the
Expand All @@ -246,7 +228,7 @@ class EquivalenceClasses {
auto I = TheMapping.find(V);
if (I == TheMapping.end())
return member_iterator(nullptr);
return findLeader(*I);
return findLeader(*I->second);
}
member_iterator findLeader(const ECValue &ECV) const {
return member_iterator(ECV.getLeader());
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Transforms/Vectorize/VectorCombine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include "llvm/Transforms/Utils/LoopUtils.h"
#include <numeric>
#include <queue>
#include <set>

#define DEBUG_TYPE "vector-combine"
#include "llvm/Transforms/Utils/InstructionWorklist.h"
Expand Down
26 changes: 0 additions & 26 deletions llvm/unittests/ADT/EquivalenceClassesTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,30 +109,4 @@ TYPED_TEST_P(ParameterizedTest, MultipleSets) {
EXPECT_FALSE(EqClasses.isEquivalent(i, j));
}

namespace {
// A dummy struct for testing EquivalenceClasses with a comparator.
struct TestStruct {
TestStruct(int value) : value(value) {}

bool operator==(const TestStruct &other) const {
return value == other.value;
}

int value;
};
// Comparator to be used in test case.
struct TestStructComparator {
bool operator()(const TestStruct &lhs, const TestStruct &rhs) const {
return lhs.value < rhs.value;
}
};
} // namespace

REGISTER_TYPED_TEST_SUITE_P(ParameterizedTest, MultipleSets);
using ParamTypes =
testing::Types<EquivalenceClasses<int>,
EquivalenceClasses<TestStruct, TestStructComparator>>;
INSTANTIATE_TYPED_TEST_SUITE_P(EquivalenceClassesTest, ParameterizedTest,
ParamTypes, );

} // llvm
Original file line number Diff line number Diff line change
Expand Up @@ -224,17 +224,8 @@ class OneShotAnalysisState : public AnalysisState {
}

private:
/// llvm::EquivalenceClasses wants comparable elements. This comparator uses
/// pointer comparison on the defining op. This is a poor man's comparison
/// but it's not like UnionFind needs ordering anyway.
struct ValueComparator {
bool operator()(const Value &lhs, const Value &rhs) const {
return lhs.getImpl() < rhs.getImpl();
}
};

using EquivalenceClassRangeType = llvm::iterator_range<
llvm::EquivalenceClasses<Value, ValueComparator>::member_iterator>;
using EquivalenceClassRangeType =
llvm::iterator_range<llvm::EquivalenceClasses<Value>::member_iterator>;
/// Check that aliasInfo for `v` exists and return a reference to it.
EquivalenceClassRangeType getAliases(Value v) const;

Expand All @@ -249,15 +240,15 @@ class OneShotAnalysisState : public AnalysisState {
/// value may alias with one of multiple other values. The concrete aliasing
/// value may not even be known at compile time. All such values are
/// considered to be aliases.
llvm::EquivalenceClasses<Value, ValueComparator> aliasInfo;
llvm::EquivalenceClasses<Value> aliasInfo;

/// Auxiliary structure to store all the equivalent buffer classes. Equivalent
/// buffer information is "must be" conservative: Only if two values are
/// guaranteed to be equivalent at runtime, they said to be equivalent. It is
/// possible that, in the presence of branches, it cannot be determined
/// statically if two values are equivalent. In that case, the values are
/// considered to be not equivalent.
llvm::EquivalenceClasses<Value, ValueComparator> equivalentInfo;
llvm::EquivalenceClasses<Value> equivalentInfo;

// Bufferization statistics.
int64_t statNumTensorOutOfPlace = 0;
Expand Down
Loading