]>
Commit | Line | Data |
---|---|---|
53e6db90 DC |
1 | import asyncio |
2 | import logging | |
3 | from concurrent.futures import Executor, ProcessPoolExecutor | |
4 | from datetime import datetime, timezone | |
5 | from functools import partial | |
6 | from multiprocessing import freeze_support | |
7 | from typing import Set, Tuple | |
8 | ||
9 | try: | |
10 | from aiohttp import web | |
11 | ||
12 | from .middlewares import cors | |
13 | except ImportError as ie: | |
14 | raise ImportError( | |
15 | f"aiohttp dependency is not installed: {ie}. " | |
16 | + "Please re-install black with the '[d]' extra install " | |
17 | + "to obtain aiohttp_cors: `pip install black[d]`" | |
18 | ) from None | |
19 | ||
20 | import click | |
21 | ||
22 | import black | |
23 | from _black_version import version as __version__ | |
24 | from black.concurrency import maybe_install_uvloop | |
25 | ||
# Lets the test suite stop the server before it would shut down on its own.
_stop_signal = asyncio.Event()

# --- Request headers ---------------------------------------------------------
PROTOCOL_VERSION_HEADER = "X-Protocol-Version"  # only protocol "1" is served
LINE_LENGTH_HEADER = "X-Line-Length"  # parsed with int()
PYTHON_VARIANT_HEADER = "X-Python-Variant"  # "pyi" or comma-separated versions
SKIP_SOURCE_FIRST_LINE = "X-Skip-Source-First-Line"
SKIP_STRING_NORMALIZATION_HEADER = "X-Skip-String-Normalization"
SKIP_MAGIC_TRAILING_COMMA = "X-Skip-Magic-Trailing-Comma"
PREVIEW = "X-Preview"
FAST_OR_SAFE_HEADER = "X-Fast-Or-Safe"  # "fast" or "safe" (the default)
DIFF_HEADER = "X-Diff"  # answer with a diff instead of the formatted source

# Every request header blackd reads, in declaration order; the CORS
# middleware is configured to allow all of them.
BLACK_HEADERS = [
    PROTOCOL_VERSION_HEADER,
    LINE_LENGTH_HEADER,
    PYTHON_VARIANT_HEADER,
    SKIP_SOURCE_FIRST_LINE,
    SKIP_STRING_NORMALIZATION_HEADER,
    SKIP_MAGIC_TRAILING_COMMA,
    PREVIEW,
    FAST_OR_SAFE_HEADER,
    DIFF_HEADER,
]

# --- Response headers --------------------------------------------------------
BLACK_VERSION_HEADER = "X-Black-Version"  # carries the running black version
54 | ||
55 | ||
class InvalidVariantHeader(Exception):
    """Signals an X-Python-Variant header value that cannot be parsed.

    The first positional argument is a short, human-readable reason that the
    request handler echoes back to the client.
    """
58 | ||
59 | ||
@click.command(context_settings={"help_option_names": ["-h", "--help"]})
@click.option(
    "--bind-host",
    type=str,
    default="localhost",
    show_default=True,
    help="Address to bind the server to.",
)
@click.option(
    "--bind-port",
    type=int,
    default=45484,
    show_default=True,
    help="Port to listen on",
)
@click.version_option(version=black.__version__)
def main(bind_host: str, bind_port: int) -> None:
    # Command-line entry point: start the blackd HTTP server on the given
    # host/port. (Deliberately no docstring: click would show it in --help.)
    logging.basicConfig(level=logging.INFO)
    application = make_app()
    banner = (
        f"blackd version {black.__version__} "
        f"listening on {bind_host} port {bind_port}"
    )
    black.out(banner)
    # NOTE(review): print=None presumably suppresses aiohttp's own startup
    # message, since the banner above is printed instead.
    web.run_app(
        application,
        host=bind_host,
        port=bind_port,
        handle_signals=True,
        print=None,
    )
78 | ||
79 | ||
def make_app() -> web.Application:
    """Build the blackd aiohttp application.

    The application serves a single ``POST /`` route handled by ``handle``,
    wrapped in the CORS middleware that allows every header in
    ``BLACK_HEADERS`` plus ``Content-Type``. Formatting runs in a
    ``ProcessPoolExecutor`` owned by the application; the pool is shut down
    when the application is cleaned up.
    """
    app = web.Application(
        middlewares=[cors(allow_headers=(*BLACK_HEADERS, "Content-Type"))]
    )
    executor = ProcessPoolExecutor()

    async def _shutdown_executor(_app: web.Application) -> None:
        # Without this, every application instance (e.g. one per test case)
        # keeps its worker processes alive until the interpreter exits.
        executor.shutdown(wait=True)

    app.on_cleanup.append(_shutdown_executor)
    app.add_routes([web.post("/", partial(handle, executor=executor))])
    return app
87 | ||
88 | ||
async def handle(request: web.Request, executor: Executor) -> web.Response:
    """Format the source in the request body according to the request headers.

    Formatting options come from the ``X-*`` headers listed in
    ``BLACK_HEADERS``. The CPU-bound work (``black.format_file_contents`` and,
    when requested, ``black.diff``) runs in ``executor`` so the event loop is
    not blocked.

    Args:
        request: The incoming request; its body is the source to format.
        executor: Executor the formatting runs in (``make_app`` passes a
            ``ProcessPoolExecutor``).

    Returns:
        * 200 with the formatted source, or a diff when ``X-Diff`` is set;
        * 204 when the source is already formatted;
        * 400 for an invalid header value, a body that cannot be decoded with
          its declared charset, or source black rejects as invalid input;
        * 500 for any other error (the exception is logged);
        * 501 when a protocol version other than ``"1"`` is requested.
        Every response carries the ``X-Black-Version`` header.
    """
    headers = {BLACK_VERSION_HEADER: __version__}
    try:
        if request.headers.get(PROTOCOL_VERSION_HEADER, "1") != "1":
            return web.Response(
                status=501,
                headers=headers,
                text="This server only supports protocol version 1",
            )
        try:
            line_length = int(
                request.headers.get(LINE_LENGTH_HEADER, black.DEFAULT_LINE_LENGTH)
            )
        except ValueError:
            return web.Response(
                status=400, headers=headers, text="Invalid line length header value"
            )

        if PYTHON_VARIANT_HEADER in request.headers:
            value = request.headers[PYTHON_VARIANT_HEADER]
            try:
                pyi, versions = parse_python_variant_header(value)
            except InvalidVariantHeader as e:
                return web.Response(
                    status=400,
                    headers=headers,
                    text=f"Invalid value for {PYTHON_VARIANT_HEADER}: {e.args[0]}",
                )
        else:
            pyi = False
            versions = set()

        # Flag headers are enabled by ANY non-empty value, so "0" or "false"
        # also enable them. NOTE(review): existing wire behaviour; confirm
        # against the blackd docs before tightening it.
        skip_string_normalization = bool(
            request.headers.get(SKIP_STRING_NORMALIZATION_HEADER, False)
        )
        skip_magic_trailing_comma = bool(
            request.headers.get(SKIP_MAGIC_TRAILING_COMMA, False)
        )
        skip_source_first_line = bool(
            request.headers.get(SKIP_SOURCE_FIRST_LINE, False)
        )
        preview = bool(request.headers.get(PREVIEW, False))
        fast = request.headers.get(FAST_OR_SAFE_HEADER, "safe") == "fast"
        mode = black.FileMode(
            target_versions=versions,
            is_pyi=pyi,
            line_length=line_length,
            skip_source_first_line=skip_source_first_line,
            string_normalization=not skip_string_normalization,
            magic_trailing_comma=not skip_magic_trailing_comma,
            preview=preview,
        )

        req_bytes = await request.content.read()
        charset = request.charset if request.charset is not None else "utf8"
        try:
            req_str = req_bytes.decode(charset)
        except (LookupError, UnicodeDecodeError):
            # Unknown charset name, or a body that is not valid in the declared
            # charset: that is a client error, not a server failure.
            return web.Response(
                status=400,
                headers=headers,
                text=f"Request body could not be decoded as {charset!r}",
            )
        then = datetime.now(timezone.utc)

        header = ""
        if skip_source_first_line:
            # Pass the first line (including its newline) through untouched.
            # If there is no newline at all, the whole body is that first line.
            first_line, newline, req_str = req_str.partition("\n")
            header = first_line + newline

        loop = asyncio.get_running_loop()
        formatted_str = await loop.run_in_executor(
            executor, partial(black.format_file_contents, req_str, fast=fast, mode=mode)
        )

        # Preserve CRLF line endings: decide from the first line of the source.
        nl = req_str.find("\n")
        if nl > 0 and req_str[nl - 1] == "\r":
            formatted_str = formatted_str.replace("\n", "\r\n")
            # If, after swapping line endings, nothing changed, then say so
            if formatted_str == req_str:
                raise black.NothingChanged

        # Put the source first line back
        req_str = header + req_str
        formatted_str = header + formatted_str

        # Only output the diff in the HTTP response
        only_diff = bool(request.headers.get(DIFF_HEADER, False))
        if only_diff:
            now = datetime.now(timezone.utc)
            src_name = f"In\t{then}"
            dst_name = f"Out\t{now}"
            formatted_str = await loop.run_in_executor(
                executor,
                partial(black.diff, req_str, formatted_str, src_name, dst_name),
            )

        return web.Response(
            content_type=request.content_type,
            charset=charset,
            headers=headers,
            text=formatted_str,
        )
    except black.NothingChanged:
        return web.Response(status=204, headers=headers)
    except black.InvalidInput as e:
        return web.Response(status=400, headers=headers, text=str(e))
    except Exception as e:
        logging.exception("Exception during handling a request")
        return web.Response(status=500, headers=headers, text=str(e))
191 | ||
192 | ||
def parse_python_variant_header(value: str) -> Tuple[bool, Set[black.TargetVersion]]:
    """Parse an ``X-Python-Variant`` header value.

    ``value`` is either ``"pyi"`` (format as a stub file) or a comma-separated
    list of Python 3 versions, each optionally prefixed with ``py`` and
    written with or without a dot, e.g. ``"3.7"``, ``"py3.8"``, ``"py38"`` or
    ``"3"``. Whitespace around each entry is ignored. A bare major version
    defaults to minor version 3.

    Returns:
        ``(is_pyi, target_versions)``; ``target_versions`` is empty for
        ``"pyi"``.

    Raises:
        InvalidVariantHeader: if an entry is empty, is not numeric, names
            Python 2 or a major version other than 3, or names a version with
            no matching ``black.TargetVersion`` member.
    """
    if value == "pyi":
        return True, set()

    versions = set()
    for version in value.split(","):
        # Tolerate the space clients commonly put after commas ("3.7, 3.8").
        version = version.strip()
        if version.startswith("py"):
            version = version[len("py") :]
        if not version:
            # "", "py" or a stray comma. Without this check version[0] below
            # raised IndexError, which surfaced as a 500 instead of a 400.
            raise InvalidVariantHeader("expected e.g. '3.7', 'py3.5'")
        if "." in version:
            major_str, *rest = version.split(".")
        else:
            # Dotless form: the first character is the major version and the
            # remainder (if any) the minor version, e.g. "38" -> 3 and 8.
            major_str = version[0]
            rest = [version[1:]] if len(version) > 1 else []
        try:
            major = int(major_str)
            if major == 2:
                raise InvalidVariantHeader("Python 2 is not supported")
            if major != 3:
                raise InvalidVariantHeader("major version must be 2 or 3")
            # Default to the lowest Python 3 minor version.
            minor = int(rest[0]) if rest else 3
            version_str = f"PY{major}{minor}"
            if not hasattr(black.TargetVersion, version_str):
                raise InvalidVariantHeader(f"3.{minor} is not supported")
            versions.add(black.TargetVersion[version_str])
        except (KeyError, ValueError):
            raise InvalidVariantHeader("expected e.g. '3.7', 'py3.5'") from None
    return False, versions
224 | ||
225 | ||
def patched_main() -> None:
    """Console entry point for blackd.

    Performs process-level setup and then runs the ``main`` click command.
    """
    # NOTE(review): presumably switches asyncio to uvloop when it is
    # available; confirm in black.concurrency.maybe_install_uvloop.
    maybe_install_uvloop()
    # multiprocessing.freeze_support() is needed when running as a frozen
    # Windows executable (make_app() uses a ProcessPoolExecutor); per the
    # stdlib docs it has no effect otherwise.
    freeze_support()
    main()
230 | ||
231 | ||
if __name__ == "__main__":
    # Run the server when this module is executed directly.
    patched_main()