3 from concurrent
.futures
import Executor
, ProcessPoolExecutor
4 from datetime
import datetime
, timezone
5 from functools
import partial
6 from multiprocessing
import freeze_support
7 from typing
import Set
, Tuple
10 from aiohttp
import web
12 from .middlewares
import cors
13 except ImportError as ie
:
15 f
"aiohttp dependency is not installed: {ie}. "
16 + "Please re-install black with the '[d]' extra install "
17 + "to obtain aiohttp_cors: `pip install black[d]`"
23 from _black_version
import version
as __version__
24 from black
.concurrency
import maybe_install_uvloop
26 # This is used internally by tests to shut down the server prematurely
27 _stop_signal
= asyncio
.Event()
30 PROTOCOL_VERSION_HEADER
= "X-Protocol-Version"
31 LINE_LENGTH_HEADER
= "X-Line-Length"
32 PYTHON_VARIANT_HEADER
= "X-Python-Variant"
33 SKIP_SOURCE_FIRST_LINE
= "X-Skip-Source-First-Line"
34 SKIP_STRING_NORMALIZATION_HEADER
= "X-Skip-String-Normalization"
35 SKIP_MAGIC_TRAILING_COMMA
= "X-Skip-Magic-Trailing-Comma"
37 FAST_OR_SAFE_HEADER
= "X-Fast-Or-Safe"
38 DIFF_HEADER
= "X-Diff"
41 PROTOCOL_VERSION_HEADER
,
43 PYTHON_VARIANT_HEADER
,
44 SKIP_SOURCE_FIRST_LINE
,
45 SKIP_STRING_NORMALIZATION_HEADER
,
46 SKIP_MAGIC_TRAILING_COMMA
,
53 BLACK_VERSION_HEADER
= "X-Black-Version"
56 class InvalidVariantHeader(Exception):
60 @click.command(context_settings
={"help_option_names": ["-h", "--help"]})
64 help="Address to bind the server to.",
69 "--bind-port", type=int, help="Port to listen on", default
=45484, show_default
=True
71 @click.version_option(version
=black
.__version
__)
72 def main(bind_host
: str, bind_port
: int) -> None:
73 logging
.basicConfig(level
=logging
.INFO
)
75 ver
= black
.__version
__
76 black
.out(f
"blackd version {ver} listening on {bind_host} port {bind_port}")
77 web
.run_app(app
, host
=bind_host
, port
=bind_port
, handle_signals
=True, print=None)
80 def make_app() -> web
.Application
:
81 app
= web
.Application(
82 middlewares
=[cors(allow_headers
=(*BLACK_HEADERS
, "Content-Type"))]
84 executor
= ProcessPoolExecutor()
85 app
.add_routes([web
.post("/", partial(handle
, executor
=executor
))])
89 async def handle(request
: web
.Request
, executor
: Executor
) -> web
.Response
:
90 headers
= {BLACK_VERSION_HEADER
: __version__
}
92 if request
.headers
.get(PROTOCOL_VERSION_HEADER
, "1") != "1":
94 status
=501, text
="This server only supports protocol version 1"
98 request
.headers
.get(LINE_LENGTH_HEADER
, black
.DEFAULT_LINE_LENGTH
)
101 return web
.Response(status
=400, text
="Invalid line length header value")
103 if PYTHON_VARIANT_HEADER
in request
.headers
:
104 value
= request
.headers
[PYTHON_VARIANT_HEADER
]
106 pyi
, versions
= parse_python_variant_header(value
)
107 except InvalidVariantHeader
as e
:
110 text
=f
"Invalid value for {PYTHON_VARIANT_HEADER}: {e.args[0]}",
116 skip_string_normalization
= bool(
117 request
.headers
.get(SKIP_STRING_NORMALIZATION_HEADER
, False)
119 skip_magic_trailing_comma
= bool(
120 request
.headers
.get(SKIP_MAGIC_TRAILING_COMMA
, False)
122 skip_source_first_line
= bool(
123 request
.headers
.get(SKIP_SOURCE_FIRST_LINE
, False)
125 preview
= bool(request
.headers
.get(PREVIEW
, False))
127 if request
.headers
.get(FAST_OR_SAFE_HEADER
, "safe") == "fast":
129 mode
= black
.FileMode(
130 target_versions
=versions
,
132 line_length
=line_length
,
133 skip_source_first_line
=skip_source_first_line
,
134 string_normalization
=not skip_string_normalization
,
135 magic_trailing_comma
=not skip_magic_trailing_comma
,
138 req_bytes
= await request
.content
.read()
139 charset
= request
.charset
if request
.charset
is not None else "utf8"
140 req_str
= req_bytes
.decode(charset
)
141 then
= datetime
.now(timezone
.utc
)
144 if skip_source_first_line
:
145 first_newline_position
: int = req_str
.find("\n") + 1
146 header
= req_str
[:first_newline_position
]
147 req_str
= req_str
[first_newline_position
:]
149 loop
= asyncio
.get_event_loop()
150 formatted_str
= await loop
.run_in_executor(
151 executor
, partial(black
.format_file_contents
, req_str
, fast
=fast
, mode
=mode
)
154 # Preserve CRLF line endings
155 nl
= req_str
.find("\n")
156 if nl
> 0 and req_str
[nl
- 1] == "\r":
157 formatted_str
= formatted_str
.replace("\n", "\r\n")
158 # If, after swapping line endings, nothing changed, then say so
159 if formatted_str
== req_str
:
160 raise black
.NothingChanged
162 # Put the source first line back
163 req_str
= header
+ req_str
164 formatted_str
= header
+ formatted_str
166 # Only output the diff in the HTTP response
167 only_diff
= bool(request
.headers
.get(DIFF_HEADER
, False))
169 now
= datetime
.now(timezone
.utc
)
170 src_name
= f
"In\t{then}"
171 dst_name
= f
"Out\t{now}"
172 loop
= asyncio
.get_event_loop()
173 formatted_str
= await loop
.run_in_executor(
175 partial(black
.diff
, req_str
, formatted_str
, src_name
, dst_name
),
179 content_type
=request
.content_type
,
184 except black
.NothingChanged
:
185 return web
.Response(status
=204, headers
=headers
)
186 except black
.InvalidInput
as e
:
187 return web
.Response(status
=400, headers
=headers
, text
=str(e
))
188 except Exception as e
:
189 logging
.exception("Exception during handling a request")
190 return web
.Response(status
=500, headers
=headers
, text
=str(e
))
193 def parse_python_variant_header(value
: str) -> Tuple
[bool, Set
[black
.TargetVersion
]]:
198 for version
in value
.split(","):
199 if version
.startswith("py"):
200 version
= version
[len("py") :]
202 major_str
, *rest
= version
.split(".")
204 major_str
= version
[0]
205 rest
= [version
[1:]] if len(version
) > 1 else []
207 major
= int(major_str
)
208 if major
not in (2, 3):
209 raise InvalidVariantHeader("major version must be 2 or 3")
213 raise InvalidVariantHeader("Python 2 is not supported")
215 # Default to lowest supported minor version.
216 minor
= 7 if major
== 2 else 3
217 version_str
= f
"PY{major}{minor}"
218 if major
== 3 and not hasattr(black
.TargetVersion
, version_str
):
219 raise InvalidVariantHeader(f
"3.{minor} is not supported")
220 versions
.add(black
.TargetVersion
[version_str
])
221 except (KeyError, ValueError):
222 raise InvalidVariantHeader("expected e.g. '3.7', 'py3.5'") from None
223 return False, versions
226 def patched_main() -> None:
227 maybe_install_uvloop()
232 if __name__
== "__main__":