Skip to content

Commit bac6a7f

Browse files
committed
Use std::variant to implement pytree Key
Key was a struct that should've been a union; std::variant makes using a union much easier. Differential Revision: [D65575184](https://our.internmc.facebook.com/intern/diff/D65575184/) ghstack-source-id: 252232728 Pull Request resolved: #6701
1 parent 03b1ef2 commit bac6a7f

File tree

1 file changed

+14
-29
lines changed

1 file changed

+14
-29
lines changed

extension/pytree/pytree.h

Lines changed: 14 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include <cstring>
1616
#include <memory>
1717
#include <string>
18+
#include <variant>
1819

1920
// NB: This is a local, pytree FunctionRef and not from the ExecuTorch runtime.
2021
#include <executorch/extension/pytree/function_ref.h>
@@ -54,52 +55,36 @@ using KeyInt = int32_t;
5455

5556
struct Key {
5657
enum class Kind : uint8_t { None, Int, Str } kind_;
58+
private:
59+
std::variant<std::monostate, KeyInt, KeyStr> repr_;
5760

58-
KeyInt as_int_ = {};
59-
KeyStr as_str_ = {};
60-
61-
Key() : kind_(Kind::None) {}
62-
/*implicit*/ Key(KeyInt key) : kind_(Kind::Int), as_int_(std::move(key)) {}
63-
/*implicit*/ Key(KeyStr key) : kind_(Kind::Str), as_str_(std::move(key)) {}
61+
public:
62+
Key() {}
63+
/*implicit*/ Key(KeyInt key) : repr_(key) {}
64+
/*implicit*/ Key(KeyStr key) : repr_(std::move(key)) {}
6465

65-
const Kind& kind() const {
66-
return kind_;
66+
Kind kind() const {
67+
return static_cast<Kind>(repr_.index());
6768
}
6869

69-
const KeyInt& as_int() const {
70-
pytree_assert(kind_ == Key::Kind::Int);
71-
return as_int_;
70+
KeyInt as_int() const {
71+
return std::get<KeyInt>(repr_);
7272
}
7373

74-
operator const KeyInt&() const {
74+
operator KeyInt() const {
7575
return as_int();
7676
}
7777

7878
const KeyStr& as_str() const {
79-
pytree_assert(kind_ == Key::Kind::Str);
80-
return as_str_;
79+
return std::get<KeyStr>(repr_);
8180
}
8281

8382
operator const KeyStr&() const {
8483
return as_str();
8584
}
8685

8786
bool operator==(const Key& rhs) const {
88-
if (kind_ != rhs.kind_) {
89-
return false;
90-
}
91-
switch (kind_) {
92-
case Kind::Str: {
93-
return as_str_ == rhs.as_str_;
94-
}
95-
case Kind::Int: {
96-
return as_int_ == rhs.as_int_;
97-
}
98-
case Kind::None: {
99-
return true;
100-
}
101-
}
102-
pytree_unreachable();
87+
return repr_ == rhs.repr_;
10388
}
10489

10590
bool operator!=(const Key& rhs) const {

0 commit comments

Comments
 (0)