Skip to content

Commit 15b1f39

Browse files
pytorchbotswolchok
andauthored
Use std::variant to implement pytree Key (#6792)
Pull Request resolved: #6701 Key was a struct that should've been a union; std::variant makes using a union much easier. ghstack-source-id: 253128071 @exported-using-ghexport Differential Revision: [D65575184](https://our.internmc.facebook.com/intern/diff/D65575184/) Co-authored-by: Scott Wolchok <[email protected]>
1 parent 5e03714 commit 15b1f39

File tree

1 file changed

+14
-28
lines changed

1 file changed

+14
-28
lines changed

extension/pytree/pytree.h

Lines changed: 14 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include <cstring>
1616
#include <memory>
1717
#include <string>
18+
#include <variant>
1819

1920
// NB: This is a local, pytree FunctionRef and not from the ExecuTorch runtime.
2021
#include <executorch/extension/pytree/function_ref.h>
@@ -55,51 +56,36 @@ using KeyInt = int32_t;
5556
struct Key {
5657
enum class Kind : uint8_t { None, Int, Str } kind_;
5758

58-
KeyInt as_int_ = {};
59-
KeyStr as_str_ = {};
59+
private:
60+
std::variant<std::monostate, KeyInt, KeyStr> repr_;
6061

61-
Key() : kind_(Kind::None) {}
62-
/*implicit*/ Key(KeyInt key) : kind_(Kind::Int), as_int_(std::move(key)) {}
63-
/*implicit*/ Key(KeyStr key) : kind_(Kind::Str), as_str_(std::move(key)) {}
62+
public:
63+
Key() {}
64+
/*implicit*/ Key(KeyInt key) : repr_(key) {}
65+
/*implicit*/ Key(KeyStr key) : repr_(std::move(key)) {}
6466

65-
const Kind& kind() const {
66-
return kind_;
67+
Kind kind() const {
68+
return static_cast<Kind>(repr_.index());
6769
}
6870

69-
const KeyInt& as_int() const {
70-
pytree_assert(kind_ == Key::Kind::Int);
71-
return as_int_;
71+
KeyInt as_int() const {
72+
return std::get<KeyInt>(repr_);
7273
}
7374

74-
operator const KeyInt&() const {
75+
operator KeyInt() const {
7576
return as_int();
7677
}
7778

7879
const KeyStr& as_str() const {
79-
pytree_assert(kind_ == Key::Kind::Str);
80-
return as_str_;
80+
return std::get<KeyStr>(repr_);
8181
}
8282

8383
operator const KeyStr&() const {
8484
return as_str();
8585
}
8686

8787
bool operator==(const Key& rhs) const {
88-
if (kind_ != rhs.kind_) {
89-
return false;
90-
}
91-
switch (kind_) {
92-
case Kind::Str: {
93-
return as_str_ == rhs.as_str_;
94-
}
95-
case Kind::Int: {
96-
return as_int_ == rhs.as_int_;
97-
}
98-
case Kind::None: {
99-
return true;
100-
}
101-
}
102-
pytree_unreachable();
88+
return repr_ == rhs.repr_;
10389
}
10490

10591
bool operator!=(const Key& rhs) const {

0 commit comments

Comments
 (0)