Skip to content

Commit 391bd68

Browse files
pytorchbot and Github Executorch authored
add a bunch of bounds checking to pytree
Pull Request resolved: #7654 It's possible to pass arbitrary string input to pytree from Python; let's not have a bunch of low-hanging memory safety issues. ghstack-source-id: 265152272 @exported-using-ghexport Differential Revision: [D68166303](https://our.internmc.facebook.com/intern/diff/D68166303/) Co-authored-by: Github Executorch <[email protected]>
1 parent 456928f commit 391bd68

File tree

2 files changed

+71
-14
lines changed

2 files changed

+71
-14
lines changed

extension/pytree/pytree.h

Lines changed: 51 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include <cstdint>
1515
#include <cstring>
1616
#include <memory>
17+
#include <stdexcept>
1718
#include <string>
1819
#include <variant>
1920

@@ -60,7 +61,7 @@ struct Key {
6061
std::variant<std::monostate, KeyInt, KeyStr> repr_;
6162

6263
public:
63-
Key() {}
64+
Key() = default;
6465
/*implicit*/ Key(KeyInt key) : repr_(key) {}
6566
/*implicit*/ Key(KeyStr key) : repr_(std::move(key)) {}
6667

@@ -131,7 +132,7 @@ struct ContainerHandle {
131132
using leaf_type = T;
132133
std::unique_ptr<container_type> handle;
133134

134-
ContainerHandle() {}
135+
ContainerHandle() = default;
135136

136137
template <typename... Args>
137138
ContainerHandle(Args... args)
@@ -427,6 +428,22 @@ struct arr {
427428
return data_[idx];
428429
}
429430

431+
T& at(size_t idx) {
432+
if (idx >= size()) {
433+
throw std::out_of_range(
434+
"bounds check failed in pytree arr at index " + std::to_string(idx));
435+
}
436+
return data_[idx];
437+
}
438+
439+
const T& at(size_t idx) const {
440+
if (idx >= size()) {
441+
throw std::out_of_range(
442+
"bounds check failed in pytree arr at index " + std::to_string(idx));
443+
}
444+
return data_[idx];
445+
}
446+
430447
inline T* data() {
431448
return data_.get();
432449
}
@@ -458,7 +475,7 @@ struct arr {
458475

459476
inline size_t read_number(const StrTreeSpec& spec, size_t& read_idx) {
460477
size_t num = 0;
461-
while (isdigit(spec[read_idx])) {
478+
while (isdigit(spec.at(read_idx))) {
462479
num = 10 * num + (spec[read_idx] - '0');
463480
read_idx++;
464481
}
@@ -470,19 +487,22 @@ inline arr<size_t> read_node_layout(const StrTreeSpec& spec, size_t& read_idx) {
470487
arr<size_t> ret(child_num);
471488

472489
size_t child_idx = 0;
473-
while (spec[read_idx] == Config::kChildrenDataSep) {
490+
while (spec.at(read_idx) == Config::kChildrenDataSep) {
474491
++read_idx;
475-
ret[child_idx++] = read_number(spec, read_idx);
492+
ret.at(child_idx++) = read_number(spec, read_idx);
476493
}
477494
return ret;
478495
}
479496

497+
// spec_data comes from pre_parse, which guarantees 1)
498+
// spec_data.size() == spec.size() and 2) contents of spec_data are
499+
// in-bounds indices for spec, so we omit bounds checks for spec_data.
480500
template <typename Aux>
481501
TreeSpec<Aux> from_str_internal(
482502
const StrTreeSpec& spec,
483503
size_t read_idx,
484504
const arr<size_t>& spec_data) {
485-
const auto kind_char = spec[read_idx];
505+
const auto kind_char = spec.at(read_idx);
486506
switch (kind_char) {
487507
case Config::kTuple:
488508
case Config::kNamedTuple:
@@ -496,7 +516,7 @@ TreeSpec<Aux> from_str_internal(
496516
} else if (Config::kCustom == kind_char) {
497517
kind = Kind::Custom;
498518
read_idx++;
499-
assert(spec[read_idx] == '(');
519+
assert(spec.at(read_idx) == '(');
500520
auto type_str_end = spec_data[read_idx];
501521
read_idx++;
502522
custom_type = spec.substr(read_idx, type_str_end - read_idx);
@@ -515,10 +535,15 @@ TreeSpec<Aux> from_str_internal(
515535
size_t leaves_offset = 0;
516536

517537
if (size > 0) {
518-
while (spec[read_idx] != Config::kNodeDataEnd) {
538+
while (spec.at(read_idx) != Config::kNodeDataEnd) {
519539
// NOLINTNEXTLINE
520540
auto next_delim_idx = spec_data[read_idx];
521541
read_idx++;
542+
if (child_idx >= size) {
543+
throw std::out_of_range(
544+
"bounds check failed writing to pytree item at index " +
545+
std::to_string(child_idx));
546+
}
522547
c->items[child_idx] =
523548
from_str_internal<Aux>(spec, read_idx, spec_data);
524549
read_idx = next_delim_idx;
@@ -541,11 +566,16 @@ TreeSpec<Aux> from_str_internal(
541566
size_t leaves_offset = 0;
542567

543568
if (size > 0) {
544-
while (spec[read_idx] != Config::kNodeDataEnd) {
569+
while (spec.at(read_idx) != Config::kNodeDataEnd) {
545570
// NOLINTNEXTLINE
546571
auto next_delim_idx = spec_data[read_idx];
547572
read_idx++;
548-
if (spec[read_idx] == Config::kDictStrKeyQuote) {
573+
if (child_idx >= size) {
574+
throw std::out_of_range(
575+
"bounds check failed decoding pytree dict at index " +
576+
std::to_string(child_idx));
577+
}
578+
if (spec.at(read_idx) == Config::kDictStrKeyQuote) {
549579
auto key_delim_idx = spec_data[read_idx];
550580
read_idx++;
551581
const size_t key_len = key_delim_idx - read_idx;
@@ -562,7 +592,7 @@ TreeSpec<Aux> from_str_internal(
562592
c->items[child_idx] =
563593
from_str_internal<Aux>(spec, read_idx, spec_data);
564594
read_idx = next_delim_idx;
565-
leaves_offset += layout[child_idx++];
595+
leaves_offset += layout.at(child_idx++);
566596
}
567597
} else {
568598
read_idx++;
@@ -605,7 +635,9 @@ struct stack final {
605635
}
606636
};
607637

638+
// We guarantee that the indices in the result are in bounds.
608639
inline arr<size_t> pre_parse(const StrTreeSpec& spec) {
640+
// Invariant: indices in stack are in bounds.
609641
stack<std::pair<size_t, size_t>> stack;
610642
size_t i = 0;
611643
const size_t size = spec.size();
@@ -627,11 +659,16 @@ inline arr<size_t> pre_parse(const StrTreeSpec& spec) {
627659
case Config::kDictStrKeyQuote: {
628660
size_t idx = i;
629661
i++;
630-
while (spec[i] != Config::kDictStrKeyQuote) {
662+
while (spec.at(i) != Config::kDictStrKeyQuote) {
631663
i++;
632664
}
633-
ret[idx] = i;
634-
ret[i] = idx;
665+
if (i >= size) {
666+
throw std::out_of_range(
667+
"bounds check failed while parsing dictionary key at index " +
668+
std::to_string(i));
669+
}
670+
ret.at(idx) = i;
671+
ret.at(i) = idx;
635672
break;
636673
}
637674
case Config::kChildrenSep: {

extension/pytree/test/test_pytree.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ using Leaf = int32_t;
2222
TEST(PyTreeTest, ArrBasic) {
2323
arr<int> x(5);
2424
ASSERT_EQ(x.size(), 5);
25+
EXPECT_THROW(x.at(5), std::out_of_range);
2526
for (int ii = 0; ii < x.size(); ++ii) {
2627
x[ii] = 2 * ii;
2728
}
@@ -197,3 +198,22 @@ TEST(pytree, FlattenNestedDict) {
197198
ASSERT_EQ(*leaves[i], items[i]);
198199
}
199200
}
201+
202+
TEST(pytree, EmptySpec) {
203+
Leaf items[1] = {9};
204+
EXPECT_THROW(unflatten("", items), std::out_of_range);
205+
}
206+
207+
TEST(pytree, BoundsCheckListLayout) {
208+
// Malformed: the layout declares one child, but the spec contains two.
209+
std::string spec = "L1#1($,$)";
210+
Leaf items[2] = {11, 12};
211+
EXPECT_THROW(unflatten(spec, items), std::out_of_range);
212+
}
213+
214+
TEST(pytree, BoundsCheckDictLayout) {
215+
// Malformed: the layout declares one child, but the spec contains two.
216+
std::string spec = "D1#1('key0':$,'key1':$)";
217+
Leaf items[2] = {11, 12};
218+
EXPECT_THROW(unflatten(spec, items), std::out_of_range);
219+
}

0 commit comments

Comments
 (0)