# SPDX-License-Identifier: MIT
# SPDX-FileCopyrightText: 2021 Taneli Hukkinen
# Licensed to PSF under a Contributor Agreement.
from __future__ import annotations

from collections.abc import Iterable
import string
from types import MappingProxyType
from typing import Any, BinaryIO, NamedTuple

# NOTE(review): parse_value uses RE_DATETIME, RE_LOCALTIME, RE_NUMBER,
# match_to_datetime, match_to_localtime and match_to_number, which are not
# imported in the visible source — confirm where they come from (presumably
# a sibling regex helper module in this package).
from ._types import Key, ParseFloat, Pos
# Control characters: U+0000..U+001F plus DEL (U+007F).
ASCII_CTRL = frozenset(chr(i) for i in range(32)) | frozenset(chr(127))

# Neither of these sets include quotation mark or backslash. They are
# currently handled as separate cases in the parser functions.
ILLEGAL_BASIC_STR_CHARS = ASCII_CTRL - frozenset("\t")
ILLEGAL_MULTILINE_BASIC_STR_CHARS = ASCII_CTRL - frozenset("\t\n")

ILLEGAL_LITERAL_STR_CHARS = ILLEGAL_BASIC_STR_CHARS
ILLEGAL_MULTILINE_LITERAL_STR_CHARS = ILLEGAL_MULTILINE_BASIC_STR_CHARS

ILLEGAL_COMMENT_CHARS = ILLEGAL_BASIC_STR_CHARS

# Inline whitespace (space and tab), optionally extended with newline.
TOML_WS = frozenset(" \t")
TOML_WS_AND_NEWLINE = TOML_WS | frozenset("\n")
# Characters allowed in an unquoted ("bare") key.
BARE_KEY_CHARS = frozenset(string.ascii_letters + string.digits + "-_")
# A key part may also start with a quote (basic or literal quoted key).
KEY_INITIAL_CHARS = BARE_KEY_CHARS | frozenset("\"'")
HEXDIGIT_CHARS = frozenset(string.hexdigits)
# Two-character escape sequences accepted in basic strings, mapped to the
# character they denote. Read-only view so the table cannot be mutated.
BASIC_STR_ESCAPE_REPLACEMENTS = MappingProxyType(
    {
        "\\b": "\u0008",  # backspace
        "\\t": "\u0009",  # tab
        "\\n": "\u000A",  # linefeed
        "\\f": "\u000C",  # form feed
        "\\r": "\u000D",  # carriage return
        '\\"': "\u0022",  # quote
        "\\\\": "\u005C",  # backslash
    }
)
class TOMLDecodeError(ValueError):
    """An error raised if a document is not valid TOML."""
def load(__fp: BinaryIO, *, parse_float: ParseFloat = float) -> dict[str, Any]:
    """Parse TOML from a binary file object.

    Args:
        __fp: File-like object opened in binary mode. Its whole content is
            read and decoded with ``bytes.decode()`` (UTF-8 by default).
        parse_float: Callable used to convert TOML float strings.

    Returns:
        The parsed document as a ``dict``.

    Raises:
        TypeError: If ``__fp.read()`` returns an object without ``decode``
            (e.g. a ``str`` from a text-mode file).
        TOMLDecodeError: If the document is not valid TOML.
    """
    b = __fp.read()
    try:
        s = b.decode()
    except AttributeError:
        raise TypeError(
            "File must be opened in binary mode, e.g. use `open('foo.toml', 'rb')`"
        ) from None
    return loads(s, parse_float=parse_float)
def loads(__s: str, *, parse_float: ParseFloat = float) -> dict[str, Any]:  # noqa: C901
    """Parse TOML from a string.

    Args:
        __s: The TOML document. CRLF line endings are normalized to LF
            before parsing, which the TOML spec permits even inside
            string literals.
        parse_float: Callable used to convert TOML float strings. It is
            wrapped by ``make_safe_parse_float`` so that it cannot return
            a ``dict`` or ``list``.

    Returns:
        The parsed document as a ``dict``.

    Raises:
        TOMLDecodeError: If the document is not valid TOML.
    """
    # The spec allows converting "\r\n" to "\n", even in string
    # literals. Let's do so to simplify parsing.
    src = __s.replace("\r\n", "\n")
    pos = 0
    out = Output(NestedDict(), Flags())
    header: Key = ()  # key of the current [table]; () is the document root
    parse_float = make_safe_parse_float(parse_float)

    # Parse one statement at a time
    # (typically means one line in TOML source)
    while True:
        # 1. Skip line leading whitespace
        pos = skip_chars(src, pos, TOML_WS)

        # 2. Parse rules. Expect one of the following:
        #    - end of file
        #    - end of line
        #    - comment
        #    - key/value pair
        #    - append dict to list (and move to its namespace)
        #    - create dict (and move to its namespace)
        # Skip trailing whitespace when applicable.
        try:
            char = src[pos]
        except IndexError:
            break
        if char == "\n":
            pos += 1
            continue
        if char in KEY_INITIAL_CHARS:
            pos = key_value_rule(src, pos, out, header, parse_float)
            pos = skip_chars(src, pos, TOML_WS)
        elif char == "[":
            try:
                second_char: str | None = src[pos + 1]
            except IndexError:
                second_char = None
            # Namespaces created by dotted keys in the previous section are
            # locked only now that the section is closed.
            out.flags.finalize_pending()
            if second_char == "[":
                pos, header = create_list_rule(src, pos, out)
            else:
                pos, header = create_dict_rule(src, pos, out)
            pos = skip_chars(src, pos, TOML_WS)
        elif char != "#":
            raise suffixed_err(src, pos, "Invalid statement")

        # 3. Skip comment
        pos = skip_comment(src, pos)

        # 4. Expect end of line or end of file
        try:
            char = src[pos]
        except IndexError:
            break
        if char != "\n":
            raise suffixed_err(
                src, pos, "Expected newline or end of document after a statement"
            )
        pos += 1

    return out.data.dict
class Flags:
    """Flags that map to parsed keys/namespaces."""

    # Marks an immutable namespace (inline array or inline table).
    FROZEN = 0
    # Marks a nest that has been explicitly created and can no longer
    # be opened using the "[table]" syntax.
    EXPLICIT_NEST = 1

    def __init__(self) -> None:
        # Tree keyed by key part; each node holds "flags" (this key only),
        # "recursive_flags" (this key and everything below) and "nested".
        self._flags: dict[str, dict] = {}
        # Flags recorded now but only applied by finalize_pending().
        self._pending_flags: set[tuple[Key, int]] = set()

    def add_pending(self, key: Key, flag: int) -> None:
        """Record ``flag`` for ``key`` to be applied by ``finalize_pending``."""
        self._pending_flags.add((key, flag))

    def finalize_pending(self) -> None:
        """Apply all pending flags non-recursively, then clear the queue."""
        for key, flag in self._pending_flags:
            self.set(key, flag, recursive=False)
        self._pending_flags.clear()

    def unset_all(self, key: Key) -> None:
        """Drop every flag stored at ``key`` and below it.

        Does nothing if any ancestor of ``key`` has no node.
        """
        cont = self._flags
        for k in key[:-1]:
            if k not in cont:
                return
            cont = cont[k]["nested"]
        cont.pop(key[-1], None)

    def set(self, key: Key, flag: int, *, recursive: bool) -> None:  # noqa: A003
        """Set ``flag`` on ``key``, creating intermediate nodes as needed.

        With ``recursive=True`` the flag also applies to every key nested
        under ``key`` (see ``is_``).
        """
        cont = self._flags
        key_parent, key_stem = key[:-1], key[-1]
        for k in key_parent:
            if k not in cont:
                cont[k] = {"flags": set(), "recursive_flags": set(), "nested": {}}
            cont = cont[k]["nested"]
        if key_stem not in cont:
            cont[key_stem] = {"flags": set(), "recursive_flags": set(), "nested": {}}
        cont[key_stem]["recursive_flags" if recursive else "flags"].add(flag)

    def is_(self, key: Key, flag: int) -> bool:
        """Return whether ``flag`` applies to ``key``.

        True if ``key`` itself carries the flag (plain or recursive), or if
        any proper ancestor carries it as a recursive flag.
        """
        if not key:
            return False  # document root has no flags
        cont = self._flags
        for k in key[:-1]:
            if k not in cont:
                return False
            inner_cont = cont[k]
            if flag in inner_cont["recursive_flags"]:
                return True
            cont = inner_cont["nested"]
        key_stem = key[-1]
        if key_stem in cont:
            cont = cont[key_stem]
            return flag in cont["flags"] or flag in cont["recursive_flags"]
        return False
class NestedDict:
    """The parsed document, with helpers to reach or create nested tables."""

    def __init__(self) -> None:
        # The parsed content of the TOML document
        self.dict: dict[str, Any] = {}

    def get_or_create_nest(
        self,
        key: Key,
        *,
        access_lists: bool = True,
    ) -> dict:
        """Return the dict at ``key``, creating empty dicts along the way.

        When ``access_lists`` is true and a path step holds a list (an array
        of tables), descend into its last element.

        Raises:
            KeyError: If a path step holds something other than a dict
                (or a list, when ``access_lists`` is true).
        """
        cont: Any = self.dict
        for k in key:
            if k not in cont:
                cont[k] = {}
            cont = cont[k]
            if access_lists and isinstance(cont, list):
                cont = cont[-1]
            if not isinstance(cont, dict):
                raise KeyError("There is no nest behind this key")
        return cont

    def append_nest_to_list(self, key: Key) -> None:
        """Append a new empty dict to the list at ``key``, creating it if absent.

        Raises:
            KeyError: If ``key`` already holds a non-list value, or its
                parent path cannot be resolved to a dict.
        """
        cont = self.get_or_create_nest(key[:-1])
        last_key = key[-1]
        if last_key in cont:
            list_ = cont[last_key]
            if not isinstance(list_, list):
                raise KeyError("An object other than list found behind this key")
            list_.append({})
        else:
            cont[last_key] = [{}]
class Output(NamedTuple):
    """Parser state threaded through the statement rules."""

    data: NestedDict  # the document being built
    flags: Flags  # per-key flags used to reject illegal redefinitions
def skip_chars(src: str, pos: Pos, chars: Iterable[str]) -> Pos:
    """Return the first position at or after ``pos`` not holding a char in ``chars``.

    Returns ``len(src)`` if the end of the document is reached first.
    """
    try:
        while src[pos] in chars:
            pos += 1
    except IndexError:
        pass
    return pos
def skip_until(
    src: str,
    pos: Pos,
    expect: str,
    *,
    error_on: frozenset[str],
    error_on_eof: bool,
) -> Pos:
    """Return the index of the next occurrence of ``expect`` at or after ``pos``.

    If ``expect`` is not found, return ``len(src)``, or raise when
    ``error_on_eof`` is true.

    Raises:
        TOMLDecodeError: If ``expect`` is missing and ``error_on_eof`` is
            set, or if any character in ``error_on`` occurs before the match.
    """
    try:
        new_pos = src.index(expect, pos)
    except ValueError:
        new_pos = len(src)
        if error_on_eof:
            raise suffixed_err(src, new_pos, f"Expected {expect!r}") from None

    if not error_on.isdisjoint(src[pos:new_pos]):
        # Locate the first offending character to report its position.
        while src[pos] not in error_on:
            pos += 1
        raise suffixed_err(src, pos, f"Found invalid character {src[pos]!r}")
    return new_pos
def skip_comment(src: str, pos: Pos) -> Pos:
    """Skip a ``#`` comment at ``pos``, if any, up to (not past) the newline.

    Returns ``pos`` unchanged when no comment starts there.

    Raises:
        TOMLDecodeError: If the comment contains an illegal control character.
    """
    try:
        char: str | None = src[pos]
    except IndexError:
        char = None
    if char == "#":
        return skip_until(
            src, pos + 1, "\n", error_on=ILLEGAL_COMMENT_CHARS, error_on_eof=False
        )
    return pos
def skip_comments_and_array_ws(src: str, pos: Pos) -> Pos:
    """Skip any mix of whitespace, newlines and comments (as allowed in arrays).

    Repeats until a full pass makes no progress.
    """
    while True:
        pos_before_skip = pos
        pos = skip_chars(src, pos, TOML_WS_AND_NEWLINE)
        pos = skip_comment(src, pos)
        if pos == pos_before_skip:
            return pos
def create_dict_rule(src: str, pos: Pos, out: Output) -> tuple[Pos, Key]:
    """Parse a ``[table]`` header starting at the ``[`` at ``pos``.

    Creates the table in ``out.data`` and marks it explicitly declared.

    Returns:
        The position just past ``]`` and the table's key, which becomes
        the new header.

    Raises:
        TOMLDecodeError: If the table was already declared or is frozen,
            if a value occupies the path, or if ``]`` is missing.
    """
    pos += 1  # Skip "["
    pos = skip_chars(src, pos, TOML_WS)
    pos, key = parse_key(src, pos)

    if out.flags.is_(key, Flags.EXPLICIT_NEST) or out.flags.is_(key, Flags.FROZEN):
        raise suffixed_err(src, pos, f"Cannot declare {key} twice")
    out.flags.set(key, Flags.EXPLICIT_NEST, recursive=False)
    try:
        out.data.get_or_create_nest(key)
    except KeyError:
        raise suffixed_err(src, pos, "Cannot overwrite a value") from None

    if not src.startswith("]", pos):
        raise suffixed_err(src, pos, "Expected ']' at the end of a table declaration")
    return pos + 1, key
def create_list_rule(src: str, pos: Pos, out: Output) -> tuple[Pos, Key]:
    """Parse a ``[[array.of.tables]]`` header starting at ``pos``.

    Appends a fresh table to the array in ``out.data``.

    Returns:
        The position just past ``]]`` and the array's key, which becomes
        the new header.

    Raises:
        TOMLDecodeError: If the key is frozen, a non-list value occupies it,
            or ``]]`` is missing.
    """
    pos += 2  # Skip "[["
    pos = skip_chars(src, pos, TOML_WS)
    pos, key = parse_key(src, pos)

    if out.flags.is_(key, Flags.FROZEN):
        raise suffixed_err(src, pos, f"Cannot mutate immutable namespace {key}")
    # Free the namespace now that it points to another empty list item...
    out.flags.unset_all(key)
    # ...but this key precisely is still prohibited from table declaration
    out.flags.set(key, Flags.EXPLICIT_NEST, recursive=False)
    try:
        out.data.append_nest_to_list(key)
    except KeyError:
        raise suffixed_err(src, pos, "Cannot overwrite a value") from None

    if not src.startswith("]]", pos):
        raise suffixed_err(src, pos, "Expected ']]' at the end of an array declaration")
    return pos + 2, key
def key_value_rule(
    src: str, pos: Pos, out: Output, header: Key, parse_float: ParseFloat
) -> Pos:
    """Parse a ``key = value`` statement and store it under ``header``.

    Returns:
        The position just after the value.

    Raises:
        TOMLDecodeError: If a dotted key redefines an explicit table, the
            target namespace is frozen, or the key already has a value.
    """
    pos, key, value = parse_key_value_pair(src, pos, parse_float)
    key_parent, key_stem = key[:-1], key[-1]
    abs_key_parent = header + key_parent

    relative_path_cont_keys = (header + key[:i] for i in range(1, len(key)))
    for cont_key in relative_path_cont_keys:
        # Check that dotted key syntax does not redefine an existing table
        if out.flags.is_(cont_key, Flags.EXPLICIT_NEST):
            raise suffixed_err(src, pos, f"Cannot redefine namespace {cont_key}")
        # Containers in the relative path can't be opened with the table syntax or
        # dotted key/value syntax in following table sections.
        out.flags.add_pending(cont_key, Flags.EXPLICIT_NEST)

    if out.flags.is_(abs_key_parent, Flags.FROZEN):
        raise suffixed_err(
            src, pos, f"Cannot mutate immutable namespace {abs_key_parent}"
        )

    try:
        nest = out.data.get_or_create_nest(abs_key_parent)
    except KeyError:
        raise suffixed_err(src, pos, "Cannot overwrite a value") from None
    if key_stem in nest:
        raise suffixed_err(src, pos, "Cannot overwrite a value")
    # Mark inline table and array namespaces recursively immutable
    if isinstance(value, (dict, list)):
        out.flags.set(header + key, Flags.FROZEN, recursive=True)
    nest[key_stem] = value
    return pos
def parse_key_value_pair(
    src: str, pos: Pos, parse_float: ParseFloat
) -> tuple[Pos, Key, Any]:
    """Parse ``key = value`` starting at ``pos``.

    Returns:
        ``(position after value, key, value)``.

    Raises:
        TOMLDecodeError: If ``=`` does not follow the key, or the key or
            value is malformed.
    """
    pos, key = parse_key(src, pos)
    try:
        char: str | None = src[pos]
    except IndexError:
        char = None
    if char != "=":
        raise suffixed_err(src, pos, "Expected '=' after a key in a key/value pair")
    pos += 1
    pos = skip_chars(src, pos, TOML_WS)
    pos, value = parse_value(src, pos, parse_float)
    return pos, key, value
def parse_key(src: str, pos: Pos) -> tuple[Pos, Key]:
    """Parse a possibly dotted key (``a.b."c"``) starting at ``pos``.

    Whitespace around the dots is allowed and skipped.

    Returns:
        The position after the key (and trailing whitespace) and the key
        as a tuple of its parts.
    """
    pos, key_part = parse_key_part(src, pos)
    key: Key = (key_part,)
    pos = skip_chars(src, pos, TOML_WS)
    while True:
        try:
            char: str | None = src[pos]
        except IndexError:
            char = None
        if char != ".":
            return pos, key
        pos += 1
        pos = skip_chars(src, pos, TOML_WS)
        pos, key_part = parse_key_part(src, pos)
        key += (key_part,)
        pos = skip_chars(src, pos, TOML_WS)
def parse_key_part(src: str, pos: Pos) -> tuple[Pos, str]:
    """Parse one key part: a bare key, a literal string, or a basic string.

    Raises:
        TOMLDecodeError: If the character at ``pos`` cannot start a key part.
    """
    try:
        char: str | None = src[pos]
    except IndexError:
        char = None
    if char in BARE_KEY_CHARS:
        start_pos = pos
        pos = skip_chars(src, pos, BARE_KEY_CHARS)
        return pos, src[start_pos:pos]
    if char == "'":
        return parse_literal_str(src, pos)
    if char == '"':
        return parse_one_line_basic_str(src, pos)
    raise suffixed_err(src, pos, "Invalid initial character for a key part")
def parse_one_line_basic_str(src: str, pos: Pos) -> tuple[Pos, str]:
    """Parse a single-line ``"..."`` string whose opening quote is at ``pos``."""
    pos += 1  # Skip the opening quote
    return parse_basic_str(src, pos, multiline=False)
def parse_array(src: str, pos: Pos, parse_float: ParseFloat) -> tuple[Pos, list]:
    """Parse an array whose ``[`` is at ``pos``.

    Newlines, comments and a trailing comma are allowed between elements.

    Returns:
        The position just past ``]`` and the parsed list.

    Raises:
        TOMLDecodeError: If an element is invalid or the array is unclosed.
    """
    pos += 1  # Skip "["
    array: list = []

    pos = skip_comments_and_array_ws(src, pos)
    if src.startswith("]", pos):
        return pos + 1, array
    while True:
        pos, val = parse_value(src, pos, parse_float)
        array.append(val)
        pos = skip_comments_and_array_ws(src, pos)

        c = src[pos : pos + 1]
        if c == "]":
            return pos + 1, array
        if c != ",":
            raise suffixed_err(src, pos, "Unclosed array")
        pos += 1

        # A trailing comma may be followed directly by the closing bracket.
        pos = skip_comments_and_array_ws(src, pos)
        if src.startswith("]", pos):
            return pos + 1, array
def parse_inline_table(src: str, pos: Pos, parse_float: ParseFloat) -> tuple[Pos, dict]:
    """Parse an inline table whose ``{`` is at ``pos``.

    Returns:
        The position just past ``}`` and the parsed dict.

    Raises:
        TOMLDecodeError: On duplicate keys, on writes into a nested inline
            table/array value, on a value occupying a dotted path, or if the
            table is unclosed.
    """
    pos += 1  # Skip "{"
    nested_dict = NestedDict()
    flags = Flags()  # local flags: only FROZEN is tracked inside the table

    pos = skip_chars(src, pos, TOML_WS)
    if src.startswith("}", pos):
        return pos + 1, nested_dict.dict
    while True:
        pos, key, value = parse_key_value_pair(src, pos, parse_float)
        key_parent, key_stem = key[:-1], key[-1]
        if flags.is_(key, Flags.FROZEN):
            raise suffixed_err(src, pos, f"Cannot mutate immutable namespace {key}")
        try:
            nest = nested_dict.get_or_create_nest(key_parent, access_lists=False)
        except KeyError:
            raise suffixed_err(src, pos, "Cannot overwrite a value") from None
        if key_stem in nest:
            raise suffixed_err(src, pos, f"Duplicate inline table key {key_stem!r}")
        nest[key_stem] = value
        pos = skip_chars(src, pos, TOML_WS)
        c = src[pos : pos + 1]
        if c == "}":
            return pos + 1, nested_dict.dict
        if c != ",":
            raise suffixed_err(src, pos, "Unclosed inline table")
        if isinstance(value, (dict, list)):
            flags.set(key, Flags.FROZEN, recursive=True)
        pos += 1
        pos = skip_chars(src, pos, TOML_WS)
def parse_basic_str_escape(
    src: str, pos: Pos, *, multiline: bool = False
) -> tuple[Pos, str]:
    """Parse the escape sequence whose backslash is at ``pos``.

    In multiline strings, a backslash followed by whitespace and a newline is
    a line-ending backslash. It and all following whitespace/newlines are
    dropped.

    Returns:
        The position after the escape and the replacement text (may be "").

    Raises:
        TOMLDecodeError: On an unknown escape, invalid hex value, or a
            line-ending backslash followed by non-whitespace.
    """
    escape_id = src[pos : pos + 2]
    pos += 2
    if multiline and escape_id in {"\\ ", "\\\t", "\\\n"}:
        # Skip whitespace until next non-whitespace character or end of
        # the doc. Error if non-whitespace is found before newline.
        if escape_id != "\\\n":
            pos = skip_chars(src, pos, TOML_WS)
            try:
                char = src[pos]
            except IndexError:
                return pos, ""
            if char != "\n":
                raise suffixed_err(src, pos, "Unescaped '\\' in a string")
            pos += 1
        pos = skip_chars(src, pos, TOML_WS_AND_NEWLINE)
        return pos, ""
    if escape_id == "\\u":
        return parse_hex_char(src, pos, 4)
    if escape_id == "\\U":
        return parse_hex_char(src, pos, 8)
    try:
        return pos, BASIC_STR_ESCAPE_REPLACEMENTS[escape_id]
    except KeyError:
        raise suffixed_err(src, pos, "Unescaped '\\' in a string") from None
def parse_basic_str_escape_multiline(src: str, pos: Pos) -> tuple[Pos, str]:
    """Parse an escape sequence inside a multiline basic string."""
    return parse_basic_str_escape(src, pos, multiline=True)
def parse_hex_char(src: str, pos: Pos, hex_len: int) -> tuple[Pos, str]:
    """Decode the ``hex_len`` hex digits at ``pos`` into one character.

    Raises:
        TOMLDecodeError: If fewer than ``hex_len`` hex digits are present, or
            the code point is not a Unicode scalar value (e.g. a surrogate).
    """
    hex_str = src[pos : pos + hex_len]
    if len(hex_str) != hex_len or not HEXDIGIT_CHARS.issuperset(hex_str):
        raise suffixed_err(src, pos, "Invalid hex value")
    pos += hex_len
    hex_int = int(hex_str, 16)
    if not is_unicode_scalar_value(hex_int):
        raise suffixed_err(src, pos, "Escaped character is not a Unicode scalar value")
    return pos, chr(hex_int)
def parse_literal_str(src: str, pos: Pos) -> tuple[Pos, str]:
    """Parse a single-line ``'...'`` string whose opening apostrophe is at ``pos``.

    No escapes are processed.

    Raises:
        TOMLDecodeError: If the string is unterminated or contains an
            illegal control character.
    """
    pos += 1  # Skip starting apostrophe
    start_pos = pos
    pos = skip_until(
        src, pos, "'", error_on=ILLEGAL_LITERAL_STR_CHARS, error_on_eof=True
    )
    return pos + 1, src[start_pos:pos]  # Skip ending apostrophe
def parse_multiline_str(src: str, pos: Pos, *, literal: bool) -> tuple[Pos, str]:
    """Parse a ``'''``/``\"\"\"`` multiline string whose delimiter starts at ``pos``.

    A newline immediately after the opening delimiter is dropped. Up to two
    quote characters directly before the closing delimiter belong to the
    content.

    Args:
        literal: True for ``'''`` (no escapes), False for basic strings.
    """
    pos += 3  # Skip the opening delimiter
    if src.startswith("\n", pos):
        pos += 1

    if literal:
        delim = "'"
        end_pos = skip_until(
            src,
            pos,
            "'''",
            error_on=ILLEGAL_MULTILINE_LITERAL_STR_CHARS,
            error_on_eof=True,
        )
        result = src[pos:end_pos]
        pos = end_pos + 3
    else:
        delim = '"'
        pos, result = parse_basic_str(src, pos, multiline=True)

    # Add at maximum two extra apostrophes/quotes if the end sequence
    # is 4 or 5 chars long instead of just 3.
    if not src.startswith(delim, pos):
        return pos, result
    pos += 1
    if not src.startswith(delim, pos):
        return pos, result + delim
    pos += 1
    return pos, result + (delim * 2)
def parse_basic_str(src: str, pos: Pos, *, multiline: bool) -> tuple[Pos, str]:
    """Parse basic-string content starting just after the opening quote(s).

    Returns:
        The position after the closing ``"`` (or ``\"\"\"`` when
        ``multiline``) and the unescaped content.

    Raises:
        TOMLDecodeError: If the string is unterminated, contains an illegal
            character, or has an invalid escape.
    """
    if multiline:
        error_on = ILLEGAL_MULTILINE_BASIC_STR_CHARS
        parse_escapes = parse_basic_str_escape_multiline
    else:
        error_on = ILLEGAL_BASIC_STR_CHARS
        parse_escapes = parse_basic_str_escape
    result = ""
    start_pos = pos  # start of the current run of literal (unescaped) text
    while True:
        try:
            char = src[pos]
        except IndexError:
            raise suffixed_err(src, pos, "Unterminated string") from None
        if char == '"':
            if not multiline:
                return pos + 1, result + src[start_pos:pos]
            if src.startswith('"""', pos):
                return pos + 3, result + src[start_pos:pos]
            # Lone quotes are plain content inside a multiline string.
            pos += 1
            continue
        if char == "\\":
            result += src[start_pos:pos]
            pos, parsed_escape = parse_escapes(src, pos)
            result += parsed_escape
            start_pos = pos
            continue
        if char in error_on:
            raise suffixed_err(src, pos, f"Illegal character {char!r}")
        pos += 1
def parse_value(  # noqa: C901
    src: str, pos: Pos, parse_float: ParseFloat
) -> tuple[Pos, Any]:
    """Parse any TOML value starting at ``pos``.

    Returns:
        The position just after the value and the Python value.

    Raises:
        TOMLDecodeError: If no value form matches at ``pos`` or a date/time
            is out of range.
    """
    try:
        char: str | None = src[pos]
    except IndexError:
        char = None

    # IMPORTANT: order conditions based on speed of checking and likelihood

    # Basic strings
    if char == '"':
        if src.startswith('"""', pos):
            return parse_multiline_str(src, pos, literal=False)
        return parse_one_line_basic_str(src, pos)

    # Literal strings
    if char == "'":
        if src.startswith("'''", pos):
            return parse_multiline_str(src, pos, literal=True)
        return parse_literal_str(src, pos)

    # Booleans
    if char == "t":
        if src.startswith("true", pos):
            return pos + 4, True
    if char == "f":
        if src.startswith("false", pos):
            return pos + 5, False

    # Arrays
    if char == "[":
        return parse_array(src, pos, parse_float)

    # Inline tables
    if char == "{":
        return parse_inline_table(src, pos, parse_float)

    # Dates and times
    datetime_match = RE_DATETIME.match(src, pos)
    if datetime_match:
        try:
            datetime_obj = match_to_datetime(datetime_match)
        except ValueError as e:
            raise suffixed_err(src, pos, "Invalid date or datetime") from e
        return datetime_match.end(), datetime_obj
    localtime_match = RE_LOCALTIME.match(src, pos)
    if localtime_match:
        return localtime_match.end(), match_to_localtime(localtime_match)

    # Integers and "normal" floats.
    # The regex will greedily match any type starting with a decimal
    # char, so needs to be located after handling of dates and times.
    number_match = RE_NUMBER.match(src, pos)
    if number_match:
        return number_match.end(), match_to_number(number_match, parse_float)

    # Special floats
    first_three = src[pos : pos + 3]
    if first_three in {"inf", "nan"}:
        return pos + 3, parse_float(first_three)
    first_four = src[pos : pos + 4]
    if first_four in {"-inf", "+inf", "-nan", "+nan"}:
        return pos + 4, parse_float(first_four)

    raise suffixed_err(src, pos, "Invalid value")
def suffixed_err(src: str, pos: Pos, msg: str) -> TOMLDecodeError:
    """Return a `TOMLDecodeError` where error message is suffixed with
    coordinates in source.

    Lines and columns are 1-based. A position at or past the end of ``src``
    is reported as "end of document".
    """

    def coord_repr(src: str, pos: Pos) -> str:
        if pos >= len(src):
            return "end of document"
        line = src.count("\n", 0, pos) + 1
        if line == 1:
            column = pos + 1
        else:
            column = pos - src.rindex("\n", 0, pos)
        return f"line {line}, column {column}"

    return TOMLDecodeError(f"{msg} (at {coord_repr(src, pos)})")
def is_unicode_scalar_value(codepoint: int) -> bool:
    """Return whether ``codepoint`` is a Unicode scalar value.

    Scalar values are U+0000..U+10FFFF excluding the surrogate range
    U+D800..U+DFFF.
    """
    return (0 <= codepoint <= 0xD7FF) or (0xE000 <= codepoint <= 0x10FFFF)
def make_safe_parse_float(parse_float: ParseFloat) -> ParseFloat:
    """A decorator to make `parse_float` safe.

    `parse_float` must not return dicts or lists, because these types
    would be mixed with parsed TOML tables and arrays, thus confusing
    the parser. The returned decorated callable raises `ValueError`
    instead of returning illegal types.
    """
    # The default `float` callable never returns illegal types. Optimize it.
    if parse_float is float:  # type: ignore[comparison-overlap]
        return float

    def safe_parse_float(float_str: str) -> Any:
        float_value = parse_float(float_str)
        if isinstance(float_value, (dict, list)):
            raise ValueError("parse_float must not return dicts or lists")
        return float_value

    return safe_parse_float