]>
Commit | Line | Data |
---|---|---|
53e6db90 DC |
1 | import io |
2 | import json | |
3 | import platform | |
4 | import re | |
5 | import sys | |
6 | import tokenize | |
7 | import traceback | |
8 | from contextlib import contextmanager | |
9 | from dataclasses import replace | |
10 | from datetime import datetime, timezone | |
11 | from enum import Enum | |
12 | from json.decoder import JSONDecodeError | |
13 | from pathlib import Path | |
14 | from typing import ( | |
15 | Any, | |
16 | Dict, | |
17 | Generator, | |
18 | Iterator, | |
19 | List, | |
20 | MutableMapping, | |
21 | Optional, | |
22 | Pattern, | |
23 | Sequence, | |
24 | Set, | |
25 | Sized, | |
26 | Tuple, | |
27 | Union, | |
28 | ) | |
29 | ||
30 | import click | |
31 | from click.core import ParameterSource | |
32 | from mypy_extensions import mypyc_attr | |
33 | from pathspec import PathSpec | |
34 | from pathspec.patterns.gitwildmatch import GitWildMatchPatternError | |
35 | ||
36 | from _black_version import version as __version__ | |
37 | from black.cache import Cache | |
38 | from black.comments import normalize_fmt_off | |
39 | from black.const import ( | |
40 | DEFAULT_EXCLUDES, | |
41 | DEFAULT_INCLUDES, | |
42 | DEFAULT_LINE_LENGTH, | |
43 | STDIN_PLACEHOLDER, | |
44 | ) | |
45 | from black.files import ( | |
46 | find_project_root, | |
47 | find_pyproject_toml, | |
48 | find_user_pyproject_toml, | |
49 | gen_python_files, | |
50 | get_gitignore, | |
51 | normalize_path_maybe_ignore, | |
52 | parse_pyproject_toml, | |
53 | wrap_stream_for_windows, | |
54 | ) | |
55 | from black.handle_ipynb_magics import ( | |
56 | PYTHON_CELL_MAGICS, | |
57 | TRANSFORMED_MAGICS, | |
58 | jupyter_dependencies_are_installed, | |
59 | mask_cell, | |
60 | put_trailing_semicolon_back, | |
61 | remove_trailing_semicolon, | |
62 | unmask_cell, | |
63 | ) | |
64 | from black.linegen import LN, LineGenerator, transform_line | |
65 | from black.lines import EmptyLineTracker, LinesBlock | |
66 | from black.mode import FUTURE_FLAG_TO_FEATURE, VERSION_TO_FEATURES, Feature | |
67 | from black.mode import Mode as Mode # re-exported | |
68 | from black.mode import TargetVersion, supports_feature | |
69 | from black.nodes import ( | |
70 | STARS, | |
71 | is_number_token, | |
72 | is_simple_decorator_expression, | |
73 | is_string_token, | |
74 | syms, | |
75 | ) | |
76 | from black.output import color_diff, diff, dump_to_file, err, ipynb_diff, out | |
77 | from black.parsing import InvalidInput # noqa F401 | |
78 | from black.parsing import lib2to3_parse, parse_ast, stringify_ast | |
79 | from black.report import Changed, NothingChanged, Report | |
80 | from black.trans import iter_fexpr_spans | |
81 | from blib2to3.pgen2 import token | |
82 | from blib2to3.pytree import Leaf, Node | |
83 | ||
# True when this module was loaded from a compiled extension module (a `.pyd`
# on Windows or a `.so` elsewhere) rather than from plain Python source.
COMPILED = Path(__file__).suffix in {".pyd", ".so"}

# Readability aliases used in signatures throughout this module.
FileContent = str  # full text of a source file
Encoding = str  # name of a text encoding
NewLine = str  # newline sequence used by a file
90 | ||
91 | ||
class WriteBack(Enum):
    """What to do with formatting results: write, diff, or only check."""

    NO = 0
    YES = 1
    DIFF = 2
    CHECK = 3
    COLOR_DIFF = 4

    @classmethod
    def from_configuration(
        cls, *, check: bool, diff: bool, color: bool = False
    ) -> "WriteBack":
        """Pick the write-back mode matching the --check/--diff/--color flags.

        `diff` takes precedence over `check`; `color` only matters together
        with `diff`. With neither `diff` nor `check`, files are written back.
        """
        if diff:
            return cls.COLOR_DIFF if color else cls.DIFF
        return cls.CHECK if check else cls.YES
110 | ||
111 | ||
# Legacy name for `Mode`, left in place so existing integrations that import
# `FileMode` keep working.
FileMode = Mode
114 | ||
115 | ||
def read_pyproject_toml(
    ctx: click.Context, param: click.Parameter, value: Optional[str]
) -> Optional[str]:
    """Inject Black configuration from "pyproject.toml" into defaults in `ctx`.

    Used as the callback of the ``--config`` option. When `value` is empty, the
    file is looked up with :func:`find_pyproject_toml` using the ``src`` and
    ``stdin_filename`` values already present in ``ctx.params``. Settings read
    from the file are merged over any existing ``ctx.default_map``.

    Returns the path to a successfully found and read configuration file, None
    otherwise.

    Raises:
        click.FileError: the configuration file could not be read or parsed.
        click.BadOptionUsage: a config key has the wrong type
            (``target-version`` must be a list; the regex options ``exclude``,
            ``extend-exclude``, ``include`` and ``force-exclude`` must be
            strings).
    """
    if not value:
        value = find_pyproject_toml(
            ctx.params.get("src", ()), ctx.params.get("stdin_filename", None)
        )
        if value is None:
            return None

    try:
        config = parse_pyproject_toml(value)
    except (OSError, ValueError) as e:
        raise click.FileError(
            filename=value, hint=f"Error reading configuration file: {e}"
        ) from None

    if not config:
        return None
    else:
        # Sanitize the values to be Click friendly. For more information please see:
        # https://github.com/psf/black/issues/1458
        # https://github.com/pallets/click/issues/1567
        config = {
            k: str(v) if not isinstance(v, (list, dict)) else v
            for k, v in config.items()
        }

    target_version = config.get("target_version")
    if target_version is not None and not isinstance(target_version, list):
        raise click.BadOptionUsage(
            "target-version", "Config key target-version must be a list"
        )

    # Every regex option is compiled by `validate_regex`, so each one must be a
    # plain string. Previously only `exclude`/`extend_exclude` were checked, so
    # a list or table given for `include`/`force_exclude` slipped through with
    # no clear error.
    for key in ("exclude", "extend_exclude", "include", "force_exclude"):
        setting = config.get(key)
        if setting is not None and not isinstance(setting, str):
            option_name = key.replace("_", "-")
            raise click.BadOptionUsage(
                option_name, f"Config key {option_name} must be a string"
            )

    default_map: Dict[str, Any] = {}
    if ctx.default_map:
        default_map.update(ctx.default_map)
    # File settings win over any defaults that were already present.
    default_map.update(config)

    ctx.default_map = default_map
    return value
172 | ||
173 | ||
def target_version_option_callback(
    c: click.Context, p: Union[click.Option, click.Parameter], v: Tuple[str, ...]
) -> List[TargetVersion]:
    """Translate the lower-cased --target-version choices into `TargetVersion`s.

    Kept as a named function instead of a lambda: mypy could not infer the
    lambda's type correctly, which caused trouble for mypyc.
    """
    versions: List[TargetVersion] = []
    for choice in v:
        versions.append(TargetVersion[choice.upper()])
    return versions
183 | ||
184 | ||
def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
    """Compile `regex`, switching on verbose mode for multi-line patterns.

    A pattern containing a newline is prefixed with the inline ``(?x)`` flag,
    so whitespace and ``#`` comments in it are ignored by the regex engine.
    """
    pattern_source = "(?x)" + regex if "\n" in regex else regex
    compiled: Pattern[str] = re.compile(pattern_source)
    return compiled
194 | ||
195 | ||
def validate_regex(
    ctx: click.Context,
    param: click.Parameter,
    value: Optional[str],
) -> Optional[Pattern[str]]:
    """Click callback: compile an optional regex option value.

    Returns None when the option was not given. An invalid pattern is
    reported to the user as a ``click.BadParameter``.
    """
    if value is None:
        return None
    try:
        return re_compile_maybe_verbose(value)
    except re.error as e:
        raise click.BadParameter(f"Not a valid regular expression: {e}") from None
205 | ||
206 | ||
# Command-line entry point. Each click decorator below declares one CLI option
# (or the SRC argument); click passes the parsed values to `main` as keyword
# arguments.
@click.command(
    context_settings={"help_option_names": ["-h", "--help"]},
    # While Click does set this field automatically using the docstring, mypyc
    # (annoyingly) strips 'em so we need to set it here too.
    help="The uncompromising code formatter.",
)
@click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
@click.option(
    "-l",
    "--line-length",
    type=int,
    default=DEFAULT_LINE_LENGTH,
    help="How many characters per line to allow.",
    show_default=True,
)
@click.option(
    "-t",
    "--target-version",
    type=click.Choice([v.name.lower() for v in TargetVersion]),
    # Converts the chosen names into TargetVersion members.
    callback=target_version_option_callback,
    multiple=True,
    help=(
        "Python versions that should be supported by Black's output. By default, Black"
        " will try to infer this from the project metadata in pyproject.toml. If this"
        " does not yield conclusive results, Black will use per-file auto-detection."
    ),
)
@click.option(
    "--pyi",
    is_flag=True,
    help=(
        "Format all input files like typing stubs regardless of file extension (useful"
        " when piping source on standard input)."
    ),
)
@click.option(
    "--ipynb",
    is_flag=True,
    help=(
        "Format all input files like Jupyter Notebooks regardless of file extension "
        "(useful when piping source on standard input)."
    ),
)
@click.option(
    "--python-cell-magics",
    multiple=True,
    help=(
        "When processing Jupyter Notebooks, add the given magic to the list"
        f" of known python-magics ({', '.join(sorted(PYTHON_CELL_MAGICS))})."
        " Useful for formatting cells with custom python magics."
    ),
    default=[],
)
@click.option(
    "-x",
    "--skip-source-first-line",
    is_flag=True,
    help="Skip the first line of the source code.",
)
@click.option(
    "-S",
    "--skip-string-normalization",
    is_flag=True,
    help="Don't normalize string quotes or prefixes.",
)
@click.option(
    "-C",
    "--skip-magic-trailing-comma",
    is_flag=True,
    help="Don't use trailing commas as a reason to split lines.",
)
@click.option(
    "--experimental-string-processing",
    is_flag=True,
    hidden=True,
    help="(DEPRECATED and now included in --preview) Normalize string literals.",
)
@click.option(
    "--preview",
    is_flag=True,
    help=(
        "Enable potentially disruptive style changes that may be added to Black's main"
        " functionality in the next major release."
    ),
)
@click.option(
    "--check",
    is_flag=True,
    help=(
        "Don't write the files back, just return the status. Return code 0 means"
        " nothing would change. Return code 1 means some files would be reformatted."
        " Return code 123 means there was an internal error."
    ),
)
@click.option(
    "--diff",
    is_flag=True,
    help="Don't write the files back, just output a diff for each file on stdout.",
)
@click.option(
    "--color/--no-color",
    is_flag=True,
    help="Show colored diff. Only applies when `--diff` is given.",
)
@click.option(
    "--fast/--safe",
    is_flag=True,
    help="If --fast given, skip temporary sanity checks. [default: --safe]",
)
@click.option(
    "--required-version",
    type=str,
    help=(
        "Require a specific version of Black to be running (useful for unifying results"
        " across many environments e.g. with a pyproject.toml file). It can be"
        " either a major version number or an exact version."
    ),
)
# The regex options below are compiled by `validate_regex` before reaching `main`.
@click.option(
    "--include",
    type=str,
    default=DEFAULT_INCLUDES,
    callback=validate_regex,
    help=(
        "A regular expression that matches files and directories that should be"
        " included on recursive searches. An empty value means all files are included"
        " regardless of the name. Use forward slashes for directories on all platforms"
        " (Windows, too). Exclusions are calculated first, inclusions later."
    ),
    show_default=True,
)
# No click default here: `get_sources` treats a None `exclude` as "use
# DEFAULT_EXCLUDES plus .gitignore", so the default is only shown in the help.
@click.option(
    "--exclude",
    type=str,
    callback=validate_regex,
    help=(
        "A regular expression that matches files and directories that should be"
        " excluded on recursive searches. An empty value means no paths are excluded."
        " Use forward slashes for directories on all platforms (Windows, too)."
        " Exclusions are calculated first, inclusions later. [default:"
        f" {DEFAULT_EXCLUDES}]"
    ),
    show_default=False,
)
@click.option(
    "--extend-exclude",
    type=str,
    callback=validate_regex,
    help=(
        "Like --exclude, but adds additional files and directories on top of the"
        " excluded ones. (Useful if you simply want to add to the default)"
    ),
)
@click.option(
    "--force-exclude",
    type=str,
    callback=validate_regex,
    help=(
        "Like --exclude, but files and directories matching this regex will be "
        "excluded even when they are passed explicitly as arguments."
    ),
)
# Eager, like `src` below: `read_pyproject_toml` (the eager --config callback)
# reads both "stdin_filename" and "src" from ctx.params to locate pyproject.toml.
@click.option(
    "--stdin-filename",
    type=str,
    is_eager=True,
    help=(
        "The name of the file when passing it through stdin. Useful to make "
        "sure Black will respect --force-exclude option on some "
        "editors that rely on using stdin."
    ),
)
@click.option(
    "-W",
    "--workers",
    type=click.IntRange(min=1),
    default=None,
    help=(
        "Number of parallel workers [default: BLACK_NUM_WORKERS environment variable "
        "or number of CPUs in the system]"
    ),
)
@click.option(
    "-q",
    "--quiet",
    is_flag=True,
    help=(
        "Don't emit non-error messages to stderr. Errors are still emitted; silence"
        " those with 2>/dev/null."
    ),
)
@click.option(
    "-v",
    "--verbose",
    is_flag=True,
    help=(
        "Also emit messages to stderr about files that were not changed or were ignored"
        " due to exclusion patterns."
    ),
)
@click.version_option(
    version=__version__,
    message=(
        f"%(prog)s, %(version)s (compiled: {'yes' if COMPILED else 'no'})\n"
        f"Python ({platform.python_implementation()}) {platform.python_version()}"
    ),
)
@click.argument(
    "src",
    nargs=-1,
    type=click.Path(
        exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
    ),
    is_eager=True,
    metavar="SRC ...",
)
# The callback merges the file's settings into ctx.default_map, so they act as
# defaults for the options above.
@click.option(
    "--config",
    type=click.Path(
        exists=True,
        file_okay=True,
        dir_okay=False,
        readable=True,
        allow_dash=False,
        path_type=str,
    ),
    is_eager=True,
    callback=read_pyproject_toml,
    help="Read configuration from FILE path.",
)
@click.pass_context
def main(  # noqa: C901
    ctx: click.Context,
    code: Optional[str],
    line_length: int,
    target_version: List[TargetVersion],
    check: bool,
    diff: bool,
    color: bool,
    fast: bool,
    pyi: bool,
    ipynb: bool,
    python_cell_magics: Sequence[str],
    skip_source_first_line: bool,
    skip_string_normalization: bool,
    skip_magic_trailing_comma: bool,
    experimental_string_processing: bool,
    preview: bool,
    quiet: bool,
    verbose: bool,
    required_version: Optional[str],
    include: Pattern[str],
    exclude: Optional[Pattern[str]],
    extend_exclude: Optional[Pattern[str]],
    force_exclude: Optional[Pattern[str]],
    stdin_filename: Optional[str],
    workers: Optional[int],
    src: Tuple[str, ...],
    config: Optional[str],
) -> None:
    """The uncompromising code formatter."""
    # ctx.obj is a dict used to expose the detected project root ("root").
    ctx.ensure_object(dict)

    # Exactly one of SRC paths or --code must be supplied.
    if src and code is not None:
        out(
            main.get_usage(ctx)
            + "\n\n'SRC' and 'code' cannot be passed simultaneously."
        )
        ctx.exit(1)
    if not src and code is None:
        out(main.get_usage(ctx) + "\n\nOne of 'SRC' or 'code' is required.")
        ctx.exit(1)

    # There is no project root when formatting a --code string.
    root, method = (
        find_project_root(src, stdin_filename) if code is None else (None, None)
    )
    ctx.obj["root"] = root

    if verbose:
        if root:
            out(
                f"Identified `{root}` as project root containing a {method}.",
                fg="blue",
            )

        if config:
            # Explain where the configuration came from: the user-level file,
            # automatic discovery (default / default_map), or an explicit path.
            config_source = ctx.get_parameter_source("config")
            user_level_config = str(find_user_pyproject_toml())
            if config == user_level_config:
                out(
                    "Using configuration from user-level config at "
                    f"'{user_level_config}'.",
                    fg="blue",
                )
            elif config_source in (
                ParameterSource.DEFAULT,
                ParameterSource.DEFAULT_MAP,
            ):
                out("Using configuration from project root.", fg="blue")
            else:
                out(f"Using configuration in '{config}'.", fg="blue")
            if ctx.default_map:
                for param, value in ctx.default_map.items():
                    out(f"{param}: {value}")

    error_msg = "Oh no! 💥 💔 💥"
    # --required-version may be the exact version or just its major component.
    if (
        required_version
        and required_version != __version__
        and required_version != __version__.split(".")[0]
    ):
        err(
            f"{error_msg} The required version `{required_version}` does not match"
            f" the running version `{__version__}`!"
        )
        ctx.exit(1)
    if ipynb and pyi:
        err("Cannot pass both `pyi` and `ipynb` flags!")
        ctx.exit(1)

    write_back = WriteBack.from_configuration(check=check, diff=diff, color=color)
    if target_version:
        versions = set(target_version)
    else:
        # We'll autodetect later.
        versions = set()
    mode = Mode(
        target_versions=versions,
        line_length=line_length,
        is_pyi=pyi,
        is_ipynb=ipynb,
        skip_source_first_line=skip_source_first_line,
        string_normalization=not skip_string_normalization,
        magic_trailing_comma=not skip_magic_trailing_comma,
        experimental_string_processing=experimental_string_processing,
        preview=preview,
        python_cell_magics=set(python_cell_magics),
    )

    if code is not None:
        # Run in quiet mode by default with -c; the extra output isn't useful.
        # You can still pass -v to get verbose output.
        quiet = True

    report = Report(check=check, diff=diff, quiet=quiet, verbose=verbose)

    if code is not None:
        reformat_code(
            content=code, fast=fast, write_back=write_back, mode=mode, report=report
        )
    else:
        assert root is not None  # root is only None if code is not None
        try:
            sources = get_sources(
                root=root,
                src=src,
                quiet=quiet,
                verbose=verbose,
                include=include,
                exclude=exclude,
                extend_exclude=extend_exclude,
                force_exclude=force_exclude,
                report=report,
                stdin_filename=stdin_filename,
            )
        except GitWildMatchPatternError:
            # Raised for a malformed .gitignore pattern; exit with status 1.
            ctx.exit(1)

        # Exits with status 0 when there is nothing to format.
        path_empty(
            sources,
            "No Python files are present to be formatted. Nothing to do 😴",
            quiet,
            verbose,
            ctx,
        )

        # A single source is formatted in-process; several go through
        # black.concurrency (imported lazily, only when needed).
        if len(sources) == 1:
            reformat_one(
                src=sources.pop(),
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
            )
        else:
            from black.concurrency import reformat_many

            reformat_many(
                sources=sources,
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
                workers=workers,
            )

    if verbose or not quiet:
        if code is None and (verbose or report.change_count or report.failure_count):
            out()
        out(error_msg if report.return_code else "All done! ✨ 🍰 ✨")
        if code is None:
            click.echo(str(report), err=True)
    # The report's return code becomes the process exit status.
    ctx.exit(report.return_code)
610 | ||
611 | ||
def get_sources(
    *,
    root: Path,
    src: Tuple[str, ...],
    quiet: bool,
    verbose: bool,
    include: Pattern[str],
    exclude: Optional[Pattern[str]],
    extend_exclude: Optional[Pattern[str]],
    force_exclude: Optional[Pattern[str]],
    report: "Report",
    stdin_filename: Optional[str],
) -> Set[Path]:
    """Compute the set of files to be formatted.

    Each entry of `src` is handled as follows:

    * ``-`` together with `stdin_filename`, or an existing file: added directly
      unless it matches `force_exclude`. Stdin entries are tagged with
      ``STDIN_PLACEHOLDER`` so the user-supplied name is preserved.
    * an existing directory: walked with :func:`gen_python_files` using
      `include`, `exclude` (``DEFAULT_EXCLUDES`` plus .gitignore rules when no
      explicit exclude was given), `extend_exclude` and `force_exclude`.
    * a bare ``-``: read from stdin.
    * anything else: reported as an invalid path.

    Files and directories for which :func:`normalize_path_maybe_ignore` returns
    None are skipped. Notebooks are skipped if the Jupyter extras are missing.
    """
    sources: Set[Path] = set()

    # Only when the user did not pass --exclude do we also honour .gitignore.
    using_default_exclude = exclude is None
    exclude = re_compile_maybe_verbose(DEFAULT_EXCLUDES) if exclude is None else exclude
    gitignore: Optional[Dict[Path, PathSpec]] = None
    root_gitignore = get_gitignore(root)

    for s in src:
        if s == "-" and stdin_filename:
            # Code comes from stdin, but filtering uses the user-supplied name.
            p = Path(stdin_filename)
            is_stdin = True
        else:
            p = Path(s)
            is_stdin = False

        if is_stdin or p.is_file():
            normalized_path: Optional[str] = normalize_path_maybe_ignore(
                p, root, report
            )
            if normalized_path is None:
                if verbose:
                    # Show the offending input; `normalized_path` is None here.
                    out(f'Skipping invalid source: "{p}"', fg="red")
                continue
            if verbose:
                out(f'Found input source: "{normalized_path}"', fg="blue")

            normalized_path = "/" + normalized_path
            # Hard-exclude any files that matches the `--force-exclude` regex.
            if force_exclude:
                force_exclude_match = force_exclude.search(normalized_path)
            else:
                force_exclude_match = None
            if force_exclude_match and force_exclude_match.group(0):
                report.path_ignored(p, "matches the --force-exclude regular expression")
                continue

            if is_stdin:
                # Tag the path so it is read from stdin later while keeping the
                # original file name available.
                p = Path(f"{STDIN_PLACEHOLDER}{p}")

            if p.suffix == ".ipynb" and not jupyter_dependencies_are_installed(
                warn=verbose or not quiet
            ):
                continue

            sources.add(p)
        elif p.is_dir():
            p_relative = normalize_path_maybe_ignore(p, root, report)
            if p_relative is None:
                # The helper chose to ignore this directory. Skip it just like an
                # ignored file instead of failing an assertion (or, under -O,
                # evaluating `root / None`).
                if verbose:
                    out(f'Skipping invalid source: "{p}"', fg="red")
                continue
            p = root / p_relative
            if verbose:
                out(f'Found input source directory: "{p}"', fg="blue")

            if using_default_exclude:
                gitignore = {
                    root: root_gitignore,
                    p: get_gitignore(p),
                }
            sources.update(
                gen_python_files(
                    p.iterdir(),
                    root,
                    include,
                    exclude,
                    extend_exclude,
                    force_exclude,
                    report,
                    gitignore,
                    verbose=verbose,
                    quiet=quiet,
                )
            )
        elif s == "-":
            if verbose:
                out("Found input source stdin", fg="blue")
            sources.add(p)
        else:
            err(f"invalid path: {s}")

    return sources
705 | ||
706 | ||
def path_empty(
    src: Sized, msg: str, quiet: bool, verbose: bool, ctx: click.Context
) -> None:
    """Exit successfully (status 0) when `src` contains nothing to format.

    `msg` is printed before exiting unless output is silenced, i.e. `quiet`
    is set and `verbose` is not. Returns normally when `src` is non-empty.
    """
    if src:
        return
    if verbose or not quiet:
        out(msg)
    ctx.exit(0)
717 | ||
718 | ||
def reformat_code(
    content: str, fast: bool, write_back: WriteBack, mode: Mode, report: Report
) -> None:
    """Reformat the in-memory string `content` and print the result.

    This is the ``--code`` counterpart of `reformat_one`: no child processes
    are spawned. `fast`, `write_back` and `mode` are forwarded to
    :func:`format_stdin_to_stdout`. The outcome, or any exception raised, is
    recorded in `report` under the pseudo-path ``<string>``.
    """
    string_path = Path("<string>")
    try:
        was_changed = format_stdin_to_stdout(
            content=content, fast=fast, write_back=write_back, mode=mode
        )
        report.done(string_path, Changed.YES if was_changed else Changed.NO)
    except Exception as exc:
        if report.verbose:
            traceback.print_exc()
        report.failed(string_path, str(exc))
741 | ||
742 | ||
# diff-shades depends on being able to monkeypatch this function to operate. I know
# it's not ideal, but this shouldn't cause any issues ... hopefully. ~ichard26
@mypyc_attr(patchable=True)
def reformat_one(
    src: Path, fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
) -> None:
    """Reformat the single source `src` in the current process.

    `src` is either a real file path, ``-`` for stdin, or a stdin path tagged
    with ``STDIN_PLACEHOLDER`` (the tag is stripped to recover the user-facing
    file name). Stdin sources are piped through :func:`format_stdin_to_stdout`;
    files go through :func:`format_file_in_place`, with the cache consulted
    and updated unless a diff was requested. The result or failure is recorded
    in `report`.
    """
    try:
        changed = Changed.NO

        raw_src = str(src)
        is_stdin = raw_src == "-"
        if not is_stdin and raw_src.startswith(STDIN_PLACEHOLDER):
            is_stdin = True
            # Use the original name again in case we want to print something
            # to the user
            src = Path(raw_src[len(STDIN_PLACEHOLDER) :])

        if is_stdin:
            # The file name decides stub/notebook handling for piped input.
            if src.suffix == ".pyi":
                mode = replace(mode, is_pyi=True)
            elif src.suffix == ".ipynb":
                mode = replace(mode, is_ipynb=True)
            if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
                changed = Changed.YES
        else:
            cache = Cache.read(mode)
            wants_diff = write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF)
            if not wants_diff and not cache.is_changed(src):
                # Unchanged since the last successful run: skip formatting.
                changed = Changed.CACHED
            elif format_file_in_place(
                src, fast=fast, write_back=write_back, mode=mode
            ):
                changed = Changed.YES

            wrote_new_output = (
                write_back is WriteBack.YES and changed is not Changed.CACHED
            )
            checked_clean = write_back is WriteBack.CHECK and changed is Changed.NO
            if wrote_new_output or checked_clean:
                cache.write([src])
        report.done(src, changed)
    except Exception as exc:
        if report.verbose:
            traceback.print_exc()
        report.failed(src, str(exc))
792 | ||
793 | ||
def format_file_in_place(
    src: Path,
    fast: bool,
    mode: Mode,
    write_back: WriteBack = WriteBack.NO,
    lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
) -> bool:
    """Format the file at `src`; return True if formatting changed it.

    With `write_back` YES the reformatted code replaces the file contents; with
    DIFF or COLOR_DIFF a (colored) diff is printed to stdout instead. A
    ``.pyi`` or ``.ipynb`` suffix switches `mode` to stub or notebook handling.
    `mode` and `fast` are forwarded to :func:`format_file_contents`.

    Raises ValueError when a notebook cannot be parsed as JSON.
    """
    suffix = src.suffix
    if suffix == ".pyi":
        mode = replace(mode, is_pyi=True)
    elif suffix == ".ipynb":
        mode = replace(mode, is_ipynb=True)

    # Modification time of the original, used to label the "before" diff side.
    then = datetime.fromtimestamp(src.stat().st_mtime, timezone.utc)
    with open(src, "rb") as buf:
        # --skip-source-first-line: set the first line aside, unformatted.
        header = buf.readline() if mode.skip_source_first_line else b""
        src_contents, encoding, newline = decode_bytes(buf.read())

    try:
        dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
    except NothingChanged:
        return False
    except JSONDecodeError:
        raise ValueError(
            f"File '{src}' cannot be parsed as valid Jupyter notebook."
        ) from None

    # Put the skipped first line back so writes and diffs cover the whole file.
    header_text = header.decode(encoding)
    src_contents = header_text + src_contents
    dst_contents = header_text + dst_contents

    if write_back == WriteBack.YES:
        with open(src, "w", encoding=encoding, newline=newline) as out_file:
            out_file.write(dst_contents)
    elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
        now = datetime.now(timezone.utc)
        before_label = f"{src}\t{then}"
        after_label = f"{src}\t{now}"
        if mode.is_ipynb:
            diff_contents = ipynb_diff(
                src_contents, dst_contents, before_label, after_label
            )
        else:
            diff_contents = diff(src_contents, dst_contents, before_label, after_label)

        if write_back == WriteBack.COLOR_DIFF:
            diff_contents = color_diff(diff_contents)

        # Hold `lock` (when provided) while writing so output from concurrent
        # callers sharing the lock does not interleave.
        with lock or nullcontext():
            stream = io.TextIOWrapper(
                sys.stdout.buffer,
                encoding=encoding,
                newline=newline,
                write_through=True,
            )
            stream = wrap_stream_for_windows(stream)
            stream.write(diff_contents)
            # Detach so discarding the wrapper doesn't close sys.stdout.buffer.
            stream.detach()

    return True
856 | ||
857 | ||
def format_stdin_to_stdout(
    fast: bool,
    *,
    content: Optional[str] = None,
    write_back: WriteBack = WriteBack.NO,
    mode: Mode,
) -> bool:
    """Format file on stdin. Return True if changed.

    If content is None, it's read from sys.stdin.

    If `write_back` is YES, write reformatted code back to stdout. If it is DIFF,
    write a diff to stdout. The `mode` argument is passed to
    :func:`format_file_contents`.

    Output is produced in a ``finally`` clause, so it is written on every exit
    path. If formatting raises anything other than NothingChanged, `dst` still
    equals the input: with YES the original text is echoed back unchanged, and
    the exception then propagates to the caller.
    """
    # Timestamp for the "before" side of a diff.
    then = datetime.now(timezone.utc)

    if content is None:
        src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
    else:
        # Text passed via --code: assume UTF-8 and no newline translation.
        src, encoding, newline = content, "utf-8", ""

    # Default output is the unmodified input (used on NothingChanged/errors).
    dst = src
    try:
        dst = format_file_contents(src, fast=fast, mode=mode)
        return True

    except NothingChanged:
        return False

    finally:
        # Wrap the raw stdout buffer so output uses the input's encoding and
        # newline style.
        f = io.TextIOWrapper(
            sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
        )
        if write_back == WriteBack.YES:
            # Make sure there's a newline after the content
            if dst and dst[-1] != "\n":
                dst += "\n"
            f.write(dst)
        elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
            now = datetime.now(timezone.utc)
            src_name = f"STDIN\t{then}"
            dst_name = f"STDOUT\t{now}"
            d = diff(src, dst, src_name, dst_name)
            if write_back == WriteBack.COLOR_DIFF:
                d = color_diff(d)
            f = wrap_stream_for_windows(f)
            f.write(d)
        # Detach so discarding the wrapper doesn't close sys.stdout.buffer.
        f.detach()
907 | ||
908 | ||
def check_stability_and_equivalence(
    src_contents: str, dst_contents: str, *, mode: Mode
) -> None:
    """Verify that reformatting `src_contents` into `dst_contents` was safe.

    Two safety checks run in order:

    1. :func:`assert_equivalent`: the output must be equivalent to the input.
    2. :func:`assert_stable`: formatting the output again under `mode` must
       not change it.

    Either check raises AssertionError on failure.
    """
    assert_equivalent(src_contents, dst_contents)
    assert_stable(src_contents, dst_contents, mode=mode)
920 | ||
921 | ||
def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
    """Return the reformatted version of `src_contents`.

    Notebooks (``mode.is_ipynb``) go through :func:`format_ipynb_string`;
    everything else goes through :func:`format_str`. Raises NothingChanged
    when the output equals the input. Unless `fast` is set, non-notebook
    output is additionally verified with
    :func:`check_stability_and_equivalence`.
    """
    formatted = (
        format_ipynb_string(src_contents, fast=fast, mode=mode)
        if mode.is_ipynb
        else format_str(src_contents, mode=mode)
    )
    if formatted == src_contents:
        raise NothingChanged

    # Jupyter notebooks will already have been checked by format_ipynb_string.
    if not (fast or mode.is_ipynb):
        check_stability_and_equivalence(src_contents, formatted, mode=mode)
    return formatted
940 | ||
941 | ||
def validate_cell(src: str, mode: Mode) -> None:
    """Reject notebook cells that cannot be round-tripped safely.

    A cell is skipped (by raising `NothingChanged`) when either:

    - it already contains text matching an IPython ``TransformerManager``
      transformation, or
    - it starts with a ``%%`` cell magic that is not known to hold Python code
      (neither in ``PYTHON_CELL_MAGICS`` nor in ``mode.python_cell_magics``),
      since such bodies may confuse tokenize_rt with their indentation.

    The first rule exists because transformation is not reversible. ``!ls``
    becomes ``get_ipython().system('ls')``, but a cell that literally
    contained ``get_ipython().system('ls')`` transforms to the same text:

    >>> TransformerManager().transform_cell("get_ipython().system('ls')")
    "get_ipython().system('ls')\n"
    >>> TransformerManager().transform_cell("!ls")
    "get_ipython().system('ls')\n"

    There is no way to tell the two apart afterwards, so such cells are left
    untouched.
    """
    if any(magic in src for magic in TRANSFORMED_MAGICS):
        raise NothingChanged

    if src[:2] == "%%":
        # Name of the cell magic, e.g. "%%timeit -n 1" -> "timeit".
        cell_magic = src.split()[0][2:]
        if cell_magic not in PYTHON_CELL_MAGICS | mode.python_cell_magics:
            raise NothingChanged
966 | ||
967 | ||
def format_cell(src: str, *, fast: bool, mode: Mode) -> str:
    """Format the source of a single Jupyter code cell.

    The pipeline is:

    1. validate the cell (`validate_cell`);
    2. strip a trailing semicolon, remembering whether one was present;
    3. mask IPython magics so the code parses as plain Python;
    4. format the masked code (and, unless `fast`, verify it);
    5. unmask the magics and restore the trailing semicolon;
    6. drop trailing newlines.

    Cells that fail to parse while masking are left alone, since they may
    contain automagics or multi-line magics, which are not supported.

    Raises:
        NothingChanged: if the cell is skipped or formatting does not alter it.
    """
    validate_cell(src, mode)
    without_semicolon, had_semicolon = remove_trailing_semicolon(src)

    try:
        masked_src, replacements = mask_cell(without_semicolon)
    except SyntaxError:
        raise NothingChanged from None

    masked_dst = format_str(masked_src, mode=mode)
    if not fast:
        check_stability_and_equivalence(masked_src, masked_dst, mode=mode)

    unmasked = unmask_cell(masked_dst, replacements)
    result = put_trailing_semicolon_back(unmasked, had_semicolon).rstrip("\n")
    if result == src:
        raise NothingChanged from None
    return result
1003 | ||
1004 | ||
def validate_metadata(nb: MutableMapping[str, Any]) -> None:
    """Refuse to format notebooks explicitly declared as non-Python.

    Every nbformat metadata field is optional
    (https://nbformat.readthedocs.io/en/latest/format_description.html).
    A notebook whose ``metadata.language_info.name`` is missing is therefore
    treated as Python and formatted anyway.

    Raises:
        NothingChanged: if the declared language is present and not "python".
    """
    metadata = nb.get("metadata", {})
    language_info = metadata.get("language_info", {})
    language = language_info.get("name", None)

    if language is None or language == "python":
        return
    raise NothingChanged from None
1015 | ||
1016 | ||
def format_ipynb_string(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
    """Format the JSON text of a Jupyter notebook.

    Only cells whose ``cell_type`` is ``"code"`` are touched, and only when the
    notebook's metadata does not declare a non-Python language. The notebook
    is re-serialised with ``indent=1`` and ``ensure_ascii=False``. A trailing
    newline on the original text is preserved.

    Raises:
        NothingChanged: if the input is empty, the notebook is skipped, or no
            cell changed.
    """
    if not src_contents:
        raise NothingChanged

    keep_final_newline = src_contents.endswith("\n")
    notebook = json.loads(src_contents)
    validate_metadata(notebook)

    changed_any = False
    for cell in notebook["cells"]:
        if cell.get("cell_type") != "code":
            continue
        try:
            formatted = format_cell("".join(cell["source"]), fast=fast, mode=mode)
        except NothingChanged:
            continue
        # nbformat stores cell source as a list of lines with their newlines.
        cell["source"] = formatted.splitlines(keepends=True)
        changed_any = True

    if not changed_any:
        raise NothingChanged

    serialised = json.dumps(notebook, indent=1, ensure_ascii=False)
    return serialised + "\n" if keep_final_newline else serialised
1047 | ||
1048 | ||
def format_str(src_contents: str, *, mode: Mode) -> str:
    """Return `src_contents` reformatted according to `mode`.

    `mode` controls formatting options such as the maximum line length.
    Example:

    >>> import black
    >>> print(black.format_str("def f(arg:str='')->None:...", mode=black.Mode()))
    def f(arg: str = "") -> None:
        ...

    A more complex example:

    >>> print(
    ...   black.format_str(
    ...     "def f(arg:str='')->None: hey",
    ...     mode=black.Mode(
    ...       target_versions={black.TargetVersion.PY36},
    ...       line_length=10,
    ...       string_normalization=False,
    ...       is_pyi=False,
    ...     ),
    ...   ),
    ... )
    def f(
        arg: str = '',
    ) -> None:
        hey

    """
    first_pass = _format_str_once(src_contents, mode=mode)
    if first_pass == src_contents:
        return first_pass
    # Work-around: optional trailing commas added on the first pass turn into
    # forced trailing commas on the next pass, which interact differently with
    # optional parentheses. A second pass is therefore forced whenever the
    # first one changed anything.
    return _format_str_once(first_pass, mode=mode)
1086 | ||
1087 | ||
def _format_str_once(src_contents: str, *, mode: Mode) -> str:
    """Run exactly one formatting pass over `src_contents` and return the result.

    Leading whitespace is stripped before parsing. If no lines are produced,
    the result is empty, or a single newline (in the source's CRLF/LF style)
    when the source contained at least one line break.
    """
    tree = lib2to3_parse(src_contents.lstrip(), mode.target_versions)

    # Target versions are inferred before normalize_fmt_off() mutates the tree.
    if mode.target_versions:
        versions = mode.target_versions
    else:
        versions = detect_target_versions(
            tree, future_imports=get_future_imports(tree)
        )

    def _supported(*candidates: Feature) -> Set[Feature]:
        # Subset of `candidates` that every selected target version supports.
        return {feat for feat in candidates if supports_feature(versions, feat)}

    cm_features = _supported(Feature.PARENTHESIZED_CONTEXT_MANAGERS)
    normalize_fmt_off(tree)
    line_generator = LineGenerator(mode=mode, features=cm_features)
    empty_line_tracker = EmptyLineTracker(mode=mode)
    split_features = _supported(
        Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF
    )

    blocks: List[LinesBlock] = []
    for logical_line in line_generator.visit(tree):
        current_block = empty_line_tracker.maybe_empty_lines(logical_line)
        blocks.append(current_block)
        current_block.content_lines.extend(
            str(out_line)
            for out_line in transform_line(
                logical_line, mode=mode, features=split_features
            )
        )

    if blocks:
        # No blank lines are emitted after the final block.
        blocks[-1].after = 0

    output_lines: List[str] = []
    for finished in blocks:
        output_lines.extend(finished.all_lines())

    if output_lines:
        return "".join(output_lines)

    # Nothing was generated: decode_bytes reports the source's newline style
    # (CRLF or LF) and yields universal-newline text to look for a line break.
    normalized_content, _, newline = decode_bytes(src_contents.encode("utf-8"))
    return newline if "\n" in normalized_content else ""
1131 | ||
1132 | ||
def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
    """Decode raw source bytes, reporting the encoding and newline style used.

    The encoding is detected with `tokenize.detect_encoding` (coding cookie or
    BOM, defaulting to UTF-8). The newline style is taken from the first line:
    ``"\\r\\n"`` if it ends in CRLF, otherwise ``"\\n"``. The returned text is
    read with universal newlines, so it only ever contains ``"\\n"``.

    Returns:
        A ``(decoded_contents, encoding, newline)`` tuple. Empty input yields
        ``("", <detected encoding>, "\\n")``.
    """
    buffer = io.BytesIO(src)
    encoding, first_lines = tokenize.detect_encoding(buffer.readline)
    if not first_lines:
        return "", encoding, "\n"

    newline = "\r\n" if first_lines[0].endswith(b"\r\n") else "\n"
    buffer.seek(0)
    with io.TextIOWrapper(buffer, encoding) as wrapper:
        text = wrapper.read()
    return text, encoding, newline
1148 | ||
1149 | ||
def get_features_used(  # noqa: C901
    node: Node, *, future_imports: Optional[Set[str]] = None
) -> Set[Feature]:
    """Return a set of (relatively) new Python features used in this file.

    Currently looking for:
    - f-strings (any letter-case combination of the f / rf / fr prefixes);
    - self-documenting expressions in f-strings (f"{x=}");
    - underscores in numeric literals;
    - trailing commas after * or ** in function signatures and calls;
    - positional only arguments in function signatures and lambdas;
    - assignment expression;
    - relaxed decorator syntax;
    - usage of __future__ flags (annotations);
    - star-unpacking in return / yield statements;
    - unparenthesized tuples on the right-hand side of annotated assignments;
    - parenthesized context managers;
    - match statements;
    - except* clause;
    - variadic generics;
    - type parameter syntax (``type`` statements and type parameter lists).

    Args:
        node: root of the syntax tree to scan (walked in pre-order).
        future_imports: names imported from ``__future__``; those listed in
            ``FUTURE_FLAG_TO_FEATURE`` are added as features.
    """
    features: Set[Feature] = set()
    if future_imports:
        features |= {
            FUTURE_FLAG_TO_FEATURE[future_import]
            for future_import in future_imports
            if future_import in FUTURE_FLAG_TO_FEATURE
        }

    for n in node.pre_order():
        if is_string_token(n):
            # String prefixes are case-insensitive in Python, so "Rf", "rF",
            # "fR" and "Fr" are as valid as "rf"/"RF"; normalise case first.
            value_head = n.value[:2].lower()
            if value_head in {'f"', "f'", "rf", "fr"}:
                features.add(Feature.F_STRINGS)
                if Feature.DEBUG_F_STRINGS not in features:
                    for span_beg, span_end in iter_fexpr_spans(n.value):
                        # `{expr=}`: the text before the closing brace ends in "=".
                        if n.value[span_beg : span_end - 1].rstrip().endswith("="):
                            features.add(Feature.DEBUG_F_STRINGS)
                            break

        elif is_number_token(n):
            if "_" in n.value:
                features.add(Feature.NUMERIC_UNDERSCORES)

        elif n.type == token.SLASH:
            if n.parent and n.parent.type in {
                syms.typedargslist,
                syms.arglist,
                syms.varargslist,
            }:
                features.add(Feature.POS_ONLY_ARGUMENTS)

        elif n.type == token.COLONEQUAL:
            features.add(Feature.ASSIGNMENT_EXPRESSIONS)

        elif n.type == syms.decorator:
            if len(n.children) > 1 and not is_simple_decorator_expression(
                n.children[1]
            ):
                features.add(Feature.RELAXED_DECORATORS)

        elif (
            n.type in {syms.typedargslist, syms.arglist}
            and n.children
            and n.children[-1].type == token.COMMA
        ):
            if n.type == syms.typedargslist:
                feature = Feature.TRAILING_COMMA_IN_DEF
            else:
                feature = Feature.TRAILING_COMMA_IN_CALL

            # The trailing comma only matters when combined with */** arguments.
            for ch in n.children:
                if ch.type in STARS:
                    features.add(feature)

                if ch.type == syms.argument:
                    for argch in ch.children:
                        if argch.type in STARS:
                            features.add(feature)

        elif (
            n.type in {syms.return_stmt, syms.yield_expr}
            and len(n.children) >= 2
            and n.children[1].type == syms.testlist_star_expr
            and any(child.type == syms.star_expr for child in n.children[1].children)
        ):
            features.add(Feature.UNPACKING_ON_FLOW)

        elif (
            n.type == syms.annassign
            and len(n.children) >= 4
            and n.children[3].type == syms.testlist_star_expr
        ):
            features.add(Feature.ANN_ASSIGN_EXTENDED_RHS)

        elif (
            n.type == syms.with_stmt
            and len(n.children) > 2
            and n.children[1].type == syms.atom
        ):
            atom_children = n.children[1].children
            if (
                len(atom_children) == 3
                and atom_children[0].type == token.LPAR
                and atom_children[1].type == syms.testlist_gexp
                and atom_children[2].type == token.RPAR
            ):
                features.add(Feature.PARENTHESIZED_CONTEXT_MANAGERS)

        elif n.type == syms.match_stmt:
            features.add(Feature.PATTERN_MATCHING)

        elif (
            n.type == syms.except_clause
            and len(n.children) >= 2
            and n.children[1].type == token.STAR
        ):
            features.add(Feature.EXCEPT_STAR)

        elif n.type in {syms.subscriptlist, syms.trailer} and any(
            child.type == syms.star_expr for child in n.children
        ):
            features.add(Feature.VARIADIC_GENERICS)

        elif (
            n.type == syms.tname_star
            and len(n.children) == 3
            and n.children[2].type == syms.star_expr
        ):
            features.add(Feature.VARIADIC_GENERICS)

        elif n.type in (syms.type_stmt, syms.typeparams):
            features.add(Feature.TYPE_PARAMS)

    return features
1284 | ||
1285 | ||
def detect_target_versions(
    node: Node, *, future_imports: Optional[Set[str]] = None
) -> Set[TargetVersion]:
    """Return every target version that supports all features used in `node`.

    Features are collected with `get_features_used`. A version qualifies when
    its entry in ``VERSION_TO_FEATURES`` contains all of them.
    """
    used = get_features_used(node, future_imports=future_imports)
    compatible: Set[TargetVersion] = set()
    for target in TargetVersion:
        if used.issubset(VERSION_TO_FEATURES[target]):
            compatible.add(target)
    return compatible
1294 | ||
1295 | ||
def get_future_imports(node: Node) -> Set[str]:
    """Collect the names imported via ``from __future__ import ...`` in `node`.

    Only the leading run of simple statements is inspected. Bare string
    statements (such as a module docstring) are skipped. The scan stops at the
    first statement that is neither a bare string nor a ``from __future__``
    import.

    Raises:
        AssertionError: if an import list has an unexpected shape.
    """
    future_names: Set[str] = set()

    def _names_from(children: List[LN]) -> Generator[str, None, None]:
        # Walk an import list, descending into grouped `a, b as c` forms.
        for child in children:
            if isinstance(child, Leaf):
                # Names are yielded; punctuation leaves (commas, parens) are skipped.
                if child.type == token.NAME:
                    yield child.value
                continue

            if child.type == syms.import_as_names:
                yield from _names_from(child.children)
            elif child.type == syms.import_as_name:
                # `x as y`: the feature is named by `x`, not by the alias.
                original = child.children[0]
                assert isinstance(original, Leaf), "Invalid syntax parsing imports"
                assert original.type == token.NAME, "Invalid syntax parsing imports"
                yield original.value
            else:
                raise AssertionError("Invalid syntax parsing imports")

    for stmt in node.children:
        if stmt.type != syms.simple_stmt:
            break

        head = stmt.children[0]
        if isinstance(head, Leaf):
            is_bare_string = (
                len(stmt.children) == 2
                and head.type == token.STRING
                and stmt.children[1].type == token.NEWLINE
            )
            if not is_bare_string:
                break
            continue

        if head.type != syms.import_from:
            break

        module = head.children[1]
        if not isinstance(module, Leaf) or module.value != "__future__":
            break

        # children: "from", module, "import", <names...>
        future_names |= set(_names_from(head.children[3:]))

    return future_names
1344 | ||
1345 | ||
def assert_equivalent(src: str, dst: str) -> None:
    """Check that `src` and `dst` parse to equivalent ASTs.

    Raises:
        AssertionError: if `src` cannot be parsed, if `dst` cannot be parsed
            (the traceback and `dst` are dumped to a log file), or if the
            stringified ASTs differ (their diff is dumped to a log file).
    """
    try:
        src_ast = parse_ast(src)
    except Exception as exc:
        raise AssertionError(
            "cannot use --safe with this file; failed to parse source file AST: "
            f"{exc}\n"
            "This could be caused by running Black with an older Python version "
            "that does not support new syntax used in your source file."
        ) from exc

    try:
        dst_ast = parse_ast(dst)
    except Exception as exc:
        traceback_text = "".join(traceback.format_tb(exc.__traceback__))
        log = dump_to_file(traceback_text, dst)
        raise AssertionError(
            f"INTERNAL ERROR: Black produced invalid code: {exc}. "
            "Please report a bug on https://github.com/psf/black/issues. "
            f"This invalid output might be helpful: {log}"
        ) from None

    src_dump, dst_dump = ("\n".join(stringify_ast(tree)) for tree in (src_ast, dst_ast))
    if src_dump == dst_dump:
        return

    log = dump_to_file(diff(src_dump, dst_dump, "src", "dst"))
    raise AssertionError(
        "INTERNAL ERROR: Black produced code that is not equivalent to the"
        " source. Please report a bug on "
        f"https://github.com/psf/black/issues. This diff might be helpful: {log}"
    ) from None
1377 | ||
1378 | ||
def assert_stable(src: str, dst: str, mode: Mode) -> None:
    """Check that formatting `dst` one more time leaves it unchanged.

    Raises:
        AssertionError: if the extra pass changes `dst`. The mode and both
            diffs (source -> first pass, first pass -> second pass) are dumped
            to a log file first.
    """
    # Deliberately a single pass: format_str() formats twice internally, which
    # could hide output that flip-flops between two versions.
    second_pass = _format_str_once(dst, mode=mode)
    if second_pass == dst:
        return

    log = dump_to_file(
        str(mode),
        diff(src, dst, "source", "first pass"),
        diff(dst, second_pass, "first pass", "second pass"),
    )
    raise AssertionError(
        "INTERNAL ERROR: Black produced different code on the second pass of the"
        " formatter. Please report a bug on https://github.com/psf/black/issues."
        f" This diff might be helpful: {log}"
    ) from None
1396 | ||
1397 | ||
@contextmanager
def nullcontext() -> Iterator[None]:
    """Context manager that does nothing and binds ``None`` to its ``as`` target.

    Stand-in for ``contextlib.nullcontext`` (added to the standard library in
    Python 3.7).
    """
    yield None
1405 | ||
1406 | ||
def patched_main() -> None:
    """Entry point: call ``freeze_support()`` for frozen builds, then run `main`."""
    # PyInstaller patches multiprocessing so that freeze_support() is required
    # even outside Windows, so it is always called when running frozen.
    if getattr(sys, "frozen", False):
        import multiprocessing

        multiprocessing.freeze_support()

    main()
1416 | ||
1417 | ||
if __name__ == "__main__":
    # Allow running this module directly. patched_main() performs the
    # frozen-executable setup before delegating to main().
    patched_main()