]>
Commit | Line | Data |
---|---|---|
53e6db90 DC |
1 | from typing import TYPE_CHECKING, Any, Awaitable, Callable, Iterable, TypeVar |
2 | ||
3 | from aiohttp.web_request import Request | |
4 | from aiohttp.web_response import StreamResponse | |
5 | ||
if TYPE_CHECKING:
    # For type checkers, declare ``middleware`` as a signature-preserving
    # decorator so that decorated middleware keeps its precise type.
    F = TypeVar("F", bound=Callable[..., Any])
    middleware: Callable[[F], F]
else:
    try:
        from aiohttp.web_middlewares import middleware
    except ImportError:
        # ``@middleware`` is deprecated and its behaviour is the default since
        # aiohttp 4.0, so if it no longer exists, fall back to a no-op
        # decorator for forward compatibility.

        def middleware(f: Any) -> Any:
            """Return *f* unchanged (no-op stand-in for aiohttp's ``@middleware``)."""
            return f
16 | ||
# Shape of an aiohttp request handler: takes the incoming Request and
# asynchronously produces a (stream) response.
Handler = Callable[[Request], Awaitable[StreamResponse]]

# Shape of a new-style middleware: receives the request plus the next handler
# in the chain, and asynchronously produces a (stream) response.
Middleware = Callable[
    [Request, Handler],
    Awaitable[StreamResponse],
]
19 | ||
20 | ||
def cors(
    allow_headers: Iterable[str],
    *,
    allow_methods: Iterable[str] = ("OPTIONS", "POST"),
) -> Middleware:
    """Build a middleware that adds permissive CORS headers to responses.

    A preflight request (``OPTIONS`` carrying ``Access-Control-Request-Method``)
    is answered directly with an empty ``StreamResponse`` and never reaches the
    wrapped handler. Every other request is forwarded to the handler. If the
    request has a non-empty ``Origin`` header, the response gets
    ``Access-Control-Allow-Origin: *`` and ``Access-Control-Expose-Headers: *``.
    ``OPTIONS`` responses additionally list the allowed headers and methods.

    Args:
        allow_headers: Header names for ``Access-Control-Allow-Headers``.
            Consumed once, when the middleware is built, so a one-shot
            iterator (e.g. a generator) is safe to pass.
        allow_methods: Method names for ``Access-Control-Allow-Methods``.
            Defaults to ``("OPTIONS", "POST")``. Also consumed once.

    Returns:
        The ``@middleware``-decorated coroutine function ``impl``.
    """
    # Join eagerly: joining inside ``impl`` would exhaust a one-shot iterable
    # on the first OPTIONS request and send an empty header on every later one.
    # It also avoids rebuilding identical strings per request.
    allow_headers_value = ", ".join(allow_headers)
    allow_methods_value = ", ".join(allow_methods)

    @middleware
    async def impl(request: Request, handler: Handler) -> StreamResponse:
        is_options = request.method == "OPTIONS"
        is_preflight = (
            is_options and "Access-Control-Request-Method" in request.headers
        )
        # Preflight requests are answered here with an empty response; the
        # application handler is only awaited for non-preflight requests.
        resp = StreamResponse() if is_preflight else await handler(request)

        # Without a (non-empty) Origin header, return the response unchanged.
        origin = request.headers.get("Origin")
        if not origin:
            return resp

        resp.headers["Access-Control-Allow-Origin"] = "*"
        resp.headers["Access-Control-Expose-Headers"] = "*"
        if is_options:
            resp.headers["Access-Control-Allow-Headers"] = allow_headers_value
            resp.headers["Access-Control-Allow-Methods"] = allow_methods_value

        return resp

    return impl