]>
Commit | Line | Data |
---|---|---|
1 | # SPDX-License-Identifier: MIT | |
2 | # SPDX-FileCopyrightText: 2021 Taneli Hukkinen | |
3 | # Licensed to PSF under a Contributor Agreement. | |
4 | ||
5 | from __future__ import annotations | |
6 | ||
7 | from collections.abc import Iterable | |
8 | import string | |
9 | from types import MappingProxyType | |
10 | from typing import Any, BinaryIO, NamedTuple | |
11 | ||
12 | from ._re import ( | |
13 | RE_DATETIME, | |
14 | RE_LOCALTIME, | |
15 | RE_NUMBER, | |
16 | match_to_datetime, | |
17 | match_to_localtime, | |
18 | match_to_number, | |
19 | ) | |
20 | from ._types import Key, ParseFloat, Pos | |
21 | ||
# Control characters U+0000..U+001F plus DEL (U+007F); never allowed raw
# in TOML strings or comments.
ASCII_CTRL = frozenset(chr(i) for i in range(32)) | frozenset(chr(127))

# Neither of these sets include quotation mark or backslash. They are
# currently handled as separate cases in the parser functions.
ILLEGAL_BASIC_STR_CHARS = ASCII_CTRL - frozenset("\t")
ILLEGAL_MULTILINE_BASIC_STR_CHARS = ASCII_CTRL - frozenset("\t\n")

ILLEGAL_LITERAL_STR_CHARS = ILLEGAL_BASIC_STR_CHARS
ILLEGAL_MULTILINE_LITERAL_STR_CHARS = ILLEGAL_MULTILINE_BASIC_STR_CHARS

ILLEGAL_COMMENT_CHARS = ILLEGAL_BASIC_STR_CHARS

# TOML whitespace is space and tab only; newline terminates a statement.
TOML_WS = frozenset(" \t")
TOML_WS_AND_NEWLINE = TOML_WS | frozenset("\n")
# Characters permitted in an unquoted (bare) key.
BARE_KEY_CHARS = frozenset(string.ascii_letters + string.digits + "-_")
# A key part may begin with a bare-key char or a quote (quoted key).
KEY_INITIAL_CHARS = BARE_KEY_CHARS | frozenset("\"'")
HEXDIGIT_CHARS = frozenset(string.hexdigits)

# Two-character escape sequences and their replacements. \uXXXX and
# \UXXXXXXXX escapes are handled separately (see parse_hex_char).
BASIC_STR_ESCAPE_REPLACEMENTS = MappingProxyType(
    {
        "\\b": "\u0008",  # backspace
        "\\t": "\u0009",  # tab
        "\\n": "\u000A",  # linefeed
        "\\f": "\u000C",  # form feed
        "\\r": "\u000D",  # carriage return
        '\\"': "\u0022",  # quote
        "\\\\": "\u005C",  # backslash
    }
)
51 | ||
52 | ||
class TOMLDecodeError(ValueError):
    """An error raised if a document is not valid TOML.

    Subclasses ValueError, so callers may catch either type.
    """
55 | ||
56 | ||
def load(__fp: BinaryIO, *, parse_float: ParseFloat = float) -> dict[str, Any]:
    """Parse TOML from a binary file object."""
    raw = __fp.read()
    try:
        decoded = raw.decode()
    except AttributeError:
        # A text-mode file yields `str`, which has no `.decode` method.
        raise TypeError(
            "File must be opened in binary mode, e.g. use `open('foo.toml', 'rb')`"
        ) from None
    return loads(decoded, parse_float=parse_float)
67 | ||
68 | ||
def loads(__s: str, *, parse_float: ParseFloat = float) -> dict[str, Any]:  # noqa: C901
    """Parse TOML from a string.

    Returns the document as a plain dict. `parse_float` is applied to the
    source text of each TOML float (it must not return a dict or list;
    see `make_safe_parse_float`). Raises `TOMLDecodeError` (via
    `suffixed_err`) for invalid TOML.
    """

    # The spec allows converting "\r\n" to "\n", even in string
    # literals. Let's do so to simplify parsing.
    src = __s.replace("\r\n", "\n")
    pos = 0
    out = Output(NestedDict(), Flags())
    header: Key = ()  # key of the table section currently being filled
    parse_float = make_safe_parse_float(parse_float)

    # Parse one statement at a time
    # (typically means one line in TOML source)
    while True:
        # 1. Skip line leading whitespace
        pos = skip_chars(src, pos, TOML_WS)

        # 2. Parse rules. Expect one of the following:
        #    - end of file
        #    - end of line
        #    - comment
        #    - key/value pair
        #    - append dict to list (and move to its namespace)
        #    - create dict (and move to its namespace)
        # Skip trailing whitespace when applicable.
        try:
            char = src[pos]
        except IndexError:
            break
        if char == "\n":
            pos += 1
            continue
        if char in KEY_INITIAL_CHARS:
            pos = key_value_rule(src, pos, out, header, parse_float)
            pos = skip_chars(src, pos, TOML_WS)
        elif char == "[":
            try:
                second_char: str | None = src[pos + 1]
            except IndexError:
                second_char = None
            # Leaving a table section: apply flags queued by dotted keys.
            out.flags.finalize_pending()
            if second_char == "[":
                pos, header = create_list_rule(src, pos, out)
            else:
                pos, header = create_dict_rule(src, pos, out)
            pos = skip_chars(src, pos, TOML_WS)
        elif char != "#":
            raise suffixed_err(src, pos, "Invalid statement")

        # 3. Skip comment
        pos = skip_comment(src, pos)

        # 4. Expect end of line or end of file
        try:
            char = src[pos]
        except IndexError:
            break
        if char != "\n":
            raise suffixed_err(
                src, pos, "Expected newline or end of document after a statement"
            )
        pos += 1

    return out.data.dict
133 | ||
134 | ||
class Flags:
    """Flags that map to parsed keys/namespaces."""

    # Marks an immutable namespace (inline array or inline table).
    FROZEN = 0
    # Marks a nest that has been explicitly created and can no longer
    # be opened using the "[table]" syntax.
    EXPLICIT_NEST = 1

    def __init__(self) -> None:
        # Tree of {"flags": set, "recursive_flags": set, "nested": {...}} nodes.
        self._flags: dict[str, dict] = {}
        # (key, flag) pairs queued by `add_pending`, applied by `finalize_pending`.
        self._pending_flags: set[tuple[Key, int]] = set()

    def add_pending(self, key: Key, flag: int) -> None:
        """Queue `flag` for `key`, to be applied when the section ends."""
        self._pending_flags.add((key, flag))

    def finalize_pending(self) -> None:
        """Apply (non-recursively) and clear all queued flags."""
        while self._pending_flags:
            pending_key, pending_flag = self._pending_flags.pop()
            self.set(pending_key, pending_flag, recursive=False)

    def unset_all(self, key: Key) -> None:
        """Remove every flag stored at or below `key` (no-op if unknown)."""
        node = self._flags
        for part in key[:-1]:
            if part not in node:
                return
            node = node[part]["nested"]
        node.pop(key[-1], None)

    def set(self, key: Key, flag: int, *, recursive: bool) -> None:  # noqa: A003
        """Set `flag` on `key`, creating intermediate nodes as needed."""
        node = self._flags
        parents, stem = key[:-1], key[-1]
        for part in parents:
            node = node.setdefault(
                part, {"flags": set(), "recursive_flags": set(), "nested": {}}
            )["nested"]
        entry = node.setdefault(
            stem, {"flags": set(), "recursive_flags": set(), "nested": {}}
        )
        entry["recursive_flags" if recursive else "flags"].add(flag)

    def is_(self, key: Key, flag: int) -> bool:
        """Return True if `flag` applies to `key` directly or via an ancestor."""
        if not key:
            return False  # document root has no flags
        node = self._flags
        for part in key[:-1]:
            if part not in node:
                return False
            entry = node[part]
            # A recursive flag on any ancestor covers this key too.
            if flag in entry["recursive_flags"]:
                return True
            node = entry["nested"]
        stem = key[-1]
        if stem not in node:
            return False
        entry = node[stem]
        return flag in entry["flags"] or flag in entry["recursive_flags"]
191 | ||
192 | ||
class NestedDict:
    """A plain-dict tree with helpers for building nested tables/arrays."""

    def __init__(self) -> None:
        # The parsed content of the TOML document
        self.dict: dict[str, Any] = {}

    def get_or_create_nest(
        self,
        key: Key,
        *,
        access_lists: bool = True,
    ) -> dict:
        """Return the dict at `key`, creating empty dicts along the path.

        When `access_lists` is true, a list on the path is entered via its
        last element (the current array-of-tables item). Raises KeyError
        if the path runs into a non-dict value.
        """
        node: Any = self.dict
        for part in key:
            if part not in node:
                node[part] = {}
            node = node[part]
            if access_lists and isinstance(node, list):
                node = node[-1]
            if not isinstance(node, dict):
                raise KeyError("There is no nest behind this key")
        return node

    def append_nest_to_list(self, key: Key) -> None:
        """Append a fresh table to the list at `key`, creating the list if needed."""
        parent = self.get_or_create_nest(key[:-1])
        stem = key[-1]
        if stem not in parent:
            parent[stem] = [{}]
        else:
            target = parent[stem]
            if not isinstance(target, list):
                raise KeyError("An object other than list found behind this key")
            target.append({})
225 | ||
226 | ||
class Output(NamedTuple):
    """Aggregate parser state: the document tree plus its namespace flags."""

    data: NestedDict
    flags: Flags
230 | ||
231 | ||
def skip_chars(src: str, pos: Pos, chars: Iterable[str]) -> Pos:
    """Advance `pos` past any run of characters contained in `chars`."""
    length = len(src)
    while pos < length and src[pos] in chars:
        pos += 1
    return pos
239 | ||
240 | ||
def skip_until(
    src: str,
    pos: Pos,
    expect: str,
    *,
    error_on: frozenset[str],
    error_on_eof: bool,
) -> Pos:
    """Return the index of the next occurrence of `expect` at or after `pos`.

    If `expect` is absent, return len(src) or raise, per `error_on_eof`.
    Raises if any character in `error_on` occurs before the match.
    """
    new_pos = src.find(expect, pos)
    if new_pos == -1:
        new_pos = len(src)
        if error_on_eof:
            raise suffixed_err(src, new_pos, f"Expected {expect!r}") from None

    if not error_on.isdisjoint(src[pos:new_pos]):
        # Locate the first offending character for an accurate error position.
        while src[pos] not in error_on:
            pos += 1
        raise suffixed_err(src, pos, f"Found invalid character {src[pos]!r}")
    return new_pos
261 | ||
262 | ||
def skip_comment(src: str, pos: Pos) -> Pos:
    """If a '#' comment starts at `pos`, skip to end of line; else return `pos`."""
    if pos < len(src) and src[pos] == "#":
        return skip_until(
            src, pos + 1, "\n", error_on=ILLEGAL_COMMENT_CHARS, error_on_eof=False
        )
    return pos
273 | ||
274 | ||
def skip_comments_and_array_ws(src: str, pos: Pos) -> Pos:
    """Skip any interleaving of whitespace, newlines and comments."""
    while True:
        prev_pos = pos
        pos = skip_comment(src, skip_chars(src, pos, TOML_WS_AND_NEWLINE))
        if pos == prev_pos:  # nothing consumed this round: done
            return pos
282 | ||
283 | ||
def create_dict_rule(src: str, pos: Pos, out: Output) -> tuple[Pos, Key]:
    """Handle a `[table]` header: create/enter its namespace.

    Returns the position after the closing bracket and the header key.
    """
    pos += 1  # Skip "["
    pos = skip_chars(src, pos, TOML_WS)
    pos, key = parse_key(src, pos)

    # A table may be declared only once, and never inside a frozen
    # (inline) namespace.
    if out.flags.is_(key, Flags.EXPLICIT_NEST) or out.flags.is_(key, Flags.FROZEN):
        raise suffixed_err(src, pos, f"Cannot declare {key} twice")
    out.flags.set(key, Flags.EXPLICIT_NEST, recursive=False)
    try:
        out.data.get_or_create_nest(key)
    except KeyError:
        raise suffixed_err(src, pos, "Cannot overwrite a value") from None

    if not src.startswith("]", pos):
        raise suffixed_err(src, pos, "Expected ']' at the end of a table declaration")
    return pos + 1, key
300 | ||
301 | ||
def create_list_rule(src: str, pos: Pos, out: Output) -> tuple[Pos, Key]:
    """Handle an `[[array-of-tables]]` header: append a new table item.

    Returns the position after the closing brackets and the header key.
    """
    pos += 2  # Skip "[["
    pos = skip_chars(src, pos, TOML_WS)
    pos, key = parse_key(src, pos)

    if out.flags.is_(key, Flags.FROZEN):
        raise suffixed_err(src, pos, f"Cannot mutate immutable namespace {key}")
    # Free the namespace now that it points to another empty list item...
    out.flags.unset_all(key)
    # ...but this key precisely is still prohibited from table declaration
    out.flags.set(key, Flags.EXPLICIT_NEST, recursive=False)
    try:
        out.data.append_nest_to_list(key)
    except KeyError:
        raise suffixed_err(src, pos, "Cannot overwrite a value") from None

    if not src.startswith("]]", pos):
        raise suffixed_err(src, pos, "Expected ']]' at the end of an array declaration")
    return pos + 2, key
321 | ||
322 | ||
def key_value_rule(
    src: str, pos: Pos, out: Output, header: Key, parse_float: ParseFloat
) -> Pos:
    """Handle one `key = value` statement under the current table `header`.

    Returns the position after the parsed value.
    """
    pos, key, value = parse_key_value_pair(src, pos, parse_float)
    key_parent, key_stem = key[:-1], key[-1]
    abs_key_parent = header + key_parent

    relative_path_cont_keys = (header + key[:i] for i in range(1, len(key)))
    for cont_key in relative_path_cont_keys:
        # Check that dotted key syntax does not redefine an existing table
        if out.flags.is_(cont_key, Flags.EXPLICIT_NEST):
            raise suffixed_err(src, pos, f"Cannot redefine namespace {cont_key}")
        # Containers in the relative path can't be opened with the table syntax or
        # dotted key/value syntax in following table sections.
        out.flags.add_pending(cont_key, Flags.EXPLICIT_NEST)

    if out.flags.is_(abs_key_parent, Flags.FROZEN):
        raise suffixed_err(
            src, pos, f"Cannot mutate immutable namespace {abs_key_parent}"
        )

    try:
        nest = out.data.get_or_create_nest(abs_key_parent)
    except KeyError:
        raise suffixed_err(src, pos, "Cannot overwrite a value") from None
    if key_stem in nest:
        raise suffixed_err(src, pos, "Cannot overwrite a value")
    # Mark inline table and array namespaces recursively immutable
    if isinstance(value, (dict, list)):
        out.flags.set(header + key, Flags.FROZEN, recursive=True)
    nest[key_stem] = value
    return pos
355 | ||
356 | ||
def parse_key_value_pair(
    src: str, pos: Pos, parse_float: ParseFloat
) -> tuple[Pos, Key, Any]:
    """Parse `key = value`, returning the position after the value."""
    pos, key = parse_key(src, pos)
    # Empty slice at EOF compares unequal to "=", matching the EOF error path.
    if src[pos : pos + 1] != "=":
        raise suffixed_err(src, pos, "Expected '=' after a key in a key/value pair")
    pos = skip_chars(src, pos + 1, TOML_WS)
    pos, value = parse_value(src, pos, parse_float)
    return pos, key, value
371 | ||
372 | ||
def parse_key(src: str, pos: Pos) -> tuple[Pos, Key]:
    """Parse a (possibly dotted) key into a tuple of string parts."""
    pos, part = parse_key_part(src, pos)
    parts = [part]
    pos = skip_chars(src, pos, TOML_WS)
    # Consume ". part" segments until the next character is not a dot.
    while src[pos : pos + 1] == ".":
        pos = skip_chars(src, pos + 1, TOML_WS)
        pos, part = parse_key_part(src, pos)
        parts.append(part)
        pos = skip_chars(src, pos, TOML_WS)
    return pos, tuple(parts)
389 | ||
390 | ||
def parse_key_part(src: str, pos: Pos) -> tuple[Pos, str]:
    """Parse one bare, literal-quoted ('...') or basic-quoted ("...") key part."""
    char = src[pos] if pos < len(src) else None
    if char in BARE_KEY_CHARS:
        end = skip_chars(src, pos, BARE_KEY_CHARS)
        return end, src[pos:end]
    if char == "'":
        return parse_literal_str(src, pos)
    if char == '"':
        return parse_one_line_basic_str(src, pos)
    raise suffixed_err(src, pos, "Invalid initial character for a key part")
405 | ||
406 | ||
def parse_one_line_basic_str(src: str, pos: Pos) -> tuple[Pos, str]:
    # `pos` points at the opening quote; skip it, then parse up to the
    # closing quote with single-line rules.
    pos += 1
    return parse_basic_str(src, pos, multiline=False)
410 | ||
411 | ||
def parse_array(src: str, pos: Pos, parse_float: ParseFloat) -> tuple[Pos, list]:
    """Parse an array `[...]`; `pos` points at the opening bracket."""
    pos += 1
    array: list = []

    pos = skip_comments_and_array_ws(src, pos)
    if src.startswith("]", pos):
        return pos + 1, array
    while True:
        pos, val = parse_value(src, pos, parse_float)
        array.append(val)
        pos = skip_comments_and_array_ws(src, pos)

        # Empty slice at EOF falls through to the "Unclosed array" error.
        c = src[pos : pos + 1]
        if c == "]":
            return pos + 1, array
        if c != ",":
            raise suffixed_err(src, pos, "Unclosed array")
        pos += 1

        # A trailing comma before the closing bracket is allowed.
        pos = skip_comments_and_array_ws(src, pos)
        if src.startswith("]", pos):
            return pos + 1, array
434 | ||
435 | ||
def parse_inline_table(src: str, pos: Pos, parse_float: ParseFloat) -> tuple[Pos, dict]:
    """Parse an inline table `{...}`; `pos` points at the opening brace."""
    pos += 1
    # Inline tables are self-contained: track nesting and flags locally.
    nested_dict = NestedDict()
    flags = Flags()

    pos = skip_chars(src, pos, TOML_WS)
    if src.startswith("}", pos):
        return pos + 1, nested_dict.dict
    while True:
        pos, key, value = parse_key_value_pair(src, pos, parse_float)
        key_parent, key_stem = key[:-1], key[-1]
        if flags.is_(key, Flags.FROZEN):
            raise suffixed_err(src, pos, f"Cannot mutate immutable namespace {key}")
        try:
            nest = nested_dict.get_or_create_nest(key_parent, access_lists=False)
        except KeyError:
            raise suffixed_err(src, pos, "Cannot overwrite a value") from None
        if key_stem in nest:
            raise suffixed_err(src, pos, f"Duplicate inline table key {key_stem!r}")
        nest[key_stem] = value
        pos = skip_chars(src, pos, TOML_WS)
        c = src[pos : pos + 1]
        if c == "}":
            return pos + 1, nested_dict.dict
        if c != ",":
            raise suffixed_err(src, pos, "Unclosed inline table")
        # Freeze container values before parsing the next pair.
        if isinstance(value, (dict, list)):
            flags.set(key, Flags.FROZEN, recursive=True)
        pos += 1
        pos = skip_chars(src, pos, TOML_WS)
466 | ||
467 | ||
def parse_basic_str_escape(
    src: str, pos: Pos, *, multiline: bool = False
) -> tuple[Pos, str]:
    """Parse one backslash escape; `pos` points at the backslash.

    Returns the position after the escape and its replacement string.
    """
    escape_id = src[pos : pos + 2]
    pos += 2
    # In multiline strings, "\" followed by whitespace and a newline is a
    # line-ending backslash: it swallows all following whitespace.
    if multiline and escape_id in {"\\ ", "\\\t", "\\\n"}:
        # Skip whitespace until next non-whitespace character or end of
        # the doc. Error if non-whitespace is found before newline.
        if escape_id != "\\\n":
            pos = skip_chars(src, pos, TOML_WS)
            try:
                char = src[pos]
            except IndexError:
                return pos, ""
            if char != "\n":
                raise suffixed_err(src, pos, "Unescaped '\\' in a string")
            pos += 1
        pos = skip_chars(src, pos, TOML_WS_AND_NEWLINE)
        return pos, ""
    if escape_id == "\\u":
        return parse_hex_char(src, pos, 4)
    if escape_id == "\\U":
        return parse_hex_char(src, pos, 8)
    try:
        return pos, BASIC_STR_ESCAPE_REPLACEMENTS[escape_id]
    except KeyError:
        raise suffixed_err(src, pos, "Unescaped '\\' in a string") from None
495 | ||
496 | ||
def parse_basic_str_escape_multiline(src: str, pos: Pos) -> tuple[Pos, str]:
    # Convenience wrapper: escape parsing with multiline rules enabled.
    return parse_basic_str_escape(src, pos, multiline=True)
499 | ||
500 | ||
def parse_hex_char(src: str, pos: Pos, hex_len: int) -> tuple[Pos, str]:
    """Decode exactly `hex_len` hex digits at `pos` into one character."""
    digits = src[pos : pos + hex_len]
    if len(digits) != hex_len or not HEXDIGIT_CHARS.issuperset(digits):
        raise suffixed_err(src, pos, "Invalid hex value")
    end = pos + hex_len
    codepoint = int(digits, 16)
    if not is_unicode_scalar_value(codepoint):
        # Report the error at the position just past the digits.
        raise suffixed_err(src, end, "Escaped character is not a Unicode scalar value")
    return end, chr(codepoint)
510 | ||
511 | ||
def parse_literal_str(src: str, pos: Pos) -> tuple[Pos, str]:
    """Parse a single-line literal ('...') string starting at `pos`."""
    start = pos + 1  # skip the opening apostrophe
    end = skip_until(
        src, start, "'", error_on=ILLEGAL_LITERAL_STR_CHARS, error_on_eof=True
    )
    return end + 1, src[start:end]  # skip the closing apostrophe
519 | ||
520 | ||
def parse_multiline_str(src: str, pos: Pos, *, literal: bool) -> tuple[Pos, str]:
    """Parse a multiline string; `pos` points at the opening delimiter.

    `literal` selects '''...''' (no escapes) vs triple-quote basic strings.
    """
    pos += 3
    # A newline immediately following the opening delimiter is trimmed.
    if src.startswith("\n", pos):
        pos += 1

    if literal:
        delim = "'"
        end_pos = skip_until(
            src,
            pos,
            "'''",
            error_on=ILLEGAL_MULTILINE_LITERAL_STR_CHARS,
            error_on_eof=True,
        )
        result = src[pos:end_pos]
        pos = end_pos + 3
    else:
        delim = '"'
        pos, result = parse_basic_str(src, pos, multiline=True)

    # Add at maximum two extra apostrophes/quotes if the end sequence
    # is 4 or 5 chars long instead of just 3.
    if not src.startswith(delim, pos):
        return pos, result
    pos += 1
    if not src.startswith(delim, pos):
        return pos, result + delim
    pos += 1
    return pos, result + (delim * 2)
550 | ||
551 | ||
def parse_basic_str(src: str, pos: Pos, *, multiline: bool) -> tuple[Pos, str]:
    """Parse a basic string body; `pos` is just past the opening quote(s).

    Handles backslash escapes and rejects illegal control characters.
    """
    if multiline:
        error_on = ILLEGAL_MULTILINE_BASIC_STR_CHARS
        parse_escapes = parse_basic_str_escape_multiline
    else:
        error_on = ILLEGAL_BASIC_STR_CHARS
        parse_escapes = parse_basic_str_escape
    result = ""
    start_pos = pos
    while True:
        try:
            char = src[pos]
        except IndexError:
            raise suffixed_err(src, pos, "Unterminated string") from None
        if char == '"':
            if not multiline:
                return pos + 1, result + src[start_pos:pos]
            if src.startswith('"""', pos):
                return pos + 3, result + src[start_pos:pos]
            # A lone quote inside a multiline string is plain content.
            pos += 1
            continue
        if char == "\\":
            # Flush the literal run before the escape, then decode it.
            result += src[start_pos:pos]
            pos, parsed_escape = parse_escapes(src, pos)
            result += parsed_escape
            start_pos = pos
            continue
        if char in error_on:
            raise suffixed_err(src, pos, f"Illegal character {char!r}")
        pos += 1
582 | ||
583 | ||
def parse_value(  # noqa: C901
    src: str, pos: Pos, parse_float: ParseFloat
) -> tuple[Pos, Any]:
    """Parse any TOML value starting at `pos`.

    Returns the position after the value and the parsed Python object.
    """
    try:
        char: str | None = src[pos]
    except IndexError:
        char = None

    # IMPORTANT: order conditions based on speed of checking and likelihood

    # Basic strings
    if char == '"':
        if src.startswith('"""', pos):
            return parse_multiline_str(src, pos, literal=False)
        return parse_one_line_basic_str(src, pos)

    # Literal strings
    if char == "'":
        if src.startswith("'''", pos):
            return parse_multiline_str(src, pos, literal=True)
        return parse_literal_str(src, pos)

    # Booleans
    if char == "t":
        if src.startswith("true", pos):
            return pos + 4, True
    if char == "f":
        if src.startswith("false", pos):
            return pos + 5, False

    # Arrays
    if char == "[":
        return parse_array(src, pos, parse_float)

    # Inline tables
    if char == "{":
        return parse_inline_table(src, pos, parse_float)

    # Dates and times
    datetime_match = RE_DATETIME.match(src, pos)
    if datetime_match:
        try:
            datetime_obj = match_to_datetime(datetime_match)
        except ValueError as e:
            raise suffixed_err(src, pos, "Invalid date or datetime") from e
        return datetime_match.end(), datetime_obj
    localtime_match = RE_LOCALTIME.match(src, pos)
    if localtime_match:
        return localtime_match.end(), match_to_localtime(localtime_match)

    # Integers and "normal" floats.
    # The regex will greedily match any type starting with a decimal
    # char, so needs to be located after handling of dates and times.
    number_match = RE_NUMBER.match(src, pos)
    if number_match:
        return number_match.end(), match_to_number(number_match, parse_float)

    # Special floats
    first_three = src[pos : pos + 3]
    if first_three in {"inf", "nan"}:
        return pos + 3, parse_float(first_three)
    first_four = src[pos : pos + 4]
    if first_four in {"-inf", "+inf", "-nan", "+nan"}:
        return pos + 4, parse_float(first_four)

    raise suffixed_err(src, pos, "Invalid value")
650 | ||
651 | ||
def suffixed_err(src: str, pos: Pos, msg: str) -> TOMLDecodeError:
    """Return a `TOMLDecodeError` with source coordinates appended to `msg`.

    The error is returned (not raised) so call sites can write
    `raise suffixed_err(...)`.
    """
    if pos >= len(src):
        location = "end of document"
    else:
        line = src.count("\n", 0, pos) + 1
        # Columns are 1-based; line 1 has no preceding newline to anchor on.
        column = pos + 1 if line == 1 else pos - src.rindex("\n", 0, pos)
        location = f"line {line}, column {column}"
    return TOMLDecodeError(f"{msg} (at {location})")
667 | ||
668 | ||
def is_unicode_scalar_value(codepoint: int) -> bool:
    """True for Unicode scalar values: code points excluding the surrogates."""
    # Surrogates occupy U+D800 (55296) through U+DFFF (57343).
    return 0 <= codepoint <= 0xD7FF or 0xE000 <= codepoint <= 0x10FFFF
671 | ||
672 | ||
def make_safe_parse_float(parse_float: ParseFloat) -> ParseFloat:
    """A decorator to make `parse_float` safe.

    `parse_float` must not return dicts or lists, because these types
    would be mixed with parsed TOML tables and arrays, thus confusing
    the parser. The returned decorated callable raises `ValueError`
    instead of returning illegal types.
    """
    # The default `float` callable never returns illegal types. Optimize it.
    if parse_float is float:  # type: ignore[comparison-overlap]
        return float

    def safe_parse_float(float_str: str) -> Any:
        parsed = parse_float(float_str)
        if isinstance(parsed, (dict, list)):
            raise ValueError("parse_float must not return dicts or lists")
        return parsed

    return safe_parse_float