Skip to content

Commit 876a43c

Browse files
committed
Sync with Tomli
1 parent 1c9b341 commit 876a43c

File tree

3 files changed

+39
-5
lines changed

3 files changed

+39
-5
lines changed

Doc/library/tomllib.rst

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,7 @@ This module defines the following functions:
4444
float to be decoded. By default, this is equivalent to ``float(num_str)``.
4545
This can be used to use another datatype or parser for TOML floats
4646
(e.g. :class:`decimal.Decimal`). The callable must not return a
47-
:class:`dict`, a :class:`list`, or anything that has the ``append``
48-
attribute. These illegal types produce undefined behavior.
47+
:class:`dict` or a :class:`list`, else a :exc:`ValueError` is raised.
4948

5049
A :exc:`TOMLDecodeError` will be raised on an invalid TOML document.
5150

Lib/test/test_tomllib/test_error.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,3 +41,17 @@ def test_invalid_char_quotes(self):
4141

4242
def test_module_name(self):
4343
self.assertEqual(tomllib.TOMLDecodeError().__module__, tomllib.__name__)
44+
45+
def test_invalid_parse_float(self):
46+
def dict_returner(s: str) -> dict:
47+
return {}
48+
49+
def list_returner(s: str) -> list:
50+
return []
51+
52+
for invalid_parse_float in (dict_returner, list_returner):
53+
with self.assertRaises(ValueError) as exc_info:
54+
tomllib.loads("f=0.1", parse_float=invalid_parse_float)
55+
self.assertEqual(
56+
str(exc_info.exception), "parse_float must not return dicts or lists"
57+
)

Lib/tomllib/_parser.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ def loads(s: str, /, *, parse_float: ParseFloat = float) -> dict[str, Any]: # n
7575
pos = 0
7676
out = Output(NestedDict(), Flags())
7777
header: Key = ()
78+
parse_float = make_safe_parse_float(parse_float)
7879

7980
# Parse one statement at a time
8081
# (typically means one line in TOML source)
@@ -216,10 +217,9 @@ def append_nest_to_list(self, key: Key) -> None:
216217
last_key = key[-1]
217218
if last_key in cont:
218219
list_ = cont[last_key]
219-
try:
220-
list_.append({})
221-
except AttributeError:
220+
if not isinstance(list_, list):
222221
raise KeyError("An object other than list found behind this key")
222+
list_.append({})
223223
else:
224224
cont[last_key] = [{}]
225225

@@ -668,3 +668,24 @@ def coord_repr(src: str, pos: Pos) -> str:
668668

669669
def is_unicode_scalar_value(codepoint: int) -> bool:
670670
return (0 <= codepoint <= 55295) or (57344 <= codepoint <= 1114111)
671+
672+
673+
def make_safe_parse_float(parse_float: ParseFloat) -> ParseFloat:
674+
"""A decorator to make `parse_float` safe.
675+
676+
`parse_float` must not return dicts or lists, because these types
677+
would be mixed with parsed TOML tables and arrays, thus confusing
678+
the parser. The returned decorated callable raises `ValueError`
679+
instead of returning illegal types.
680+
"""
681+
# The default `float` callable never returns illegal types. Optimize it.
682+
if parse_float is float: # type: ignore[comparison-overlap]
683+
return float
684+
685+
def safe_parse_float(float_str: str) -> Any:
686+
float_value = parse_float(float_str)
687+
if isinstance(float_value, (dict, list)):
688+
raise ValueError("parse_float must not return dicts or lists")
689+
return float_value
690+
691+
return safe_parse_float

0 commit comments

Comments
 (0)