]>
Commit | Line | Data |
---|---|---|
53e6db90 DC |
1 | # This module is based on the excellent work by Adam Bartoš who |
2 | # provided a lot of what went into the implementation here in | |
3 | # the discussion to issue1602 in the Python bug tracker. | |
4 | # | |
5 | # There are some general differences in regards to how this works | |
6 | # compared to the original patches as we do not need to patch | |
7 | # the entire interpreter but just work in our little world of | |
8 | # echo and prompt. | |
9 | import io | |
10 | import sys | |
11 | import time | |
12 | import typing as t | |
13 | from ctypes import byref | |
14 | from ctypes import c_char | |
15 | from ctypes import c_char_p | |
16 | from ctypes import c_int | |
17 | from ctypes import c_ssize_t | |
18 | from ctypes import c_ulong | |
19 | from ctypes import c_void_p | |
20 | from ctypes import POINTER | |
21 | from ctypes import py_object | |
22 | from ctypes import Structure | |
23 | from ctypes.wintypes import DWORD | |
24 | from ctypes.wintypes import HANDLE | |
25 | from ctypes.wintypes import LPCWSTR | |
26 | from ctypes.wintypes import LPWSTR | |
27 | ||
28 | from ._compat import _NonClosingTextIOWrapper | |
29 | ||
30 | assert sys.platform == "win32" | |
31 | import msvcrt # noqa: E402 | |
32 | from ctypes import windll # noqa: E402 | |
33 | from ctypes import WINFUNCTYPE # noqa: E402 | |
34 | ||
# Pointer-to-Py_ssize_t type, used for the shape/strides/suboffsets fields of
# the Py_buffer structure.
c_ssize_p = POINTER(c_ssize_t)

# Win32 console functions from kernel32. They are bound without
# argtypes/restype, so ctypes applies its default int conversions to
# arguments and return values.
kernel32 = windll.kernel32
GetStdHandle = kernel32.GetStdHandle
ReadConsoleW = kernel32.ReadConsoleW
WriteConsoleW = kernel32.WriteConsoleW
GetConsoleMode = kernel32.GetConsoleMode
GetLastError = kernel32.GetLastError
# Explicitly prototyped wide-string command-line helpers.
# NOTE(review): not referenced in this module's visible code; presumably used
# elsewhere in the package — confirm before removing.
GetCommandLineW = WINFUNCTYPE(LPWSTR)(("GetCommandLineW", windll.kernel32))
CommandLineToArgvW = WINFUNCTYPE(POINTER(LPWSTR), LPCWSTR, POINTER(c_int))(
    ("CommandLineToArgvW", windll.shell32)
)
LocalFree = WINFUNCTYPE(c_void_p, c_void_p)(("LocalFree", windll.kernel32))

# Standard device handles, fetched once at import time:
# -10 = STD_INPUT_HANDLE, -11 = STD_OUTPUT_HANDLE, -12 = STD_ERROR_HANDLE.
STDIN_HANDLE = GetStdHandle(-10)
STDOUT_HANDLE = GetStdHandle(-11)
STDERR_HANDLE = GetStdHandle(-12)

# Request flags passed to PyObject_GetBuffer by get_buffer().
PyBUF_SIMPLE = 0
PyBUF_WRITABLE = 1

# Win32 error codes inspected by the console reader/writer.
ERROR_SUCCESS = 0
ERROR_NOT_ENOUGH_MEMORY = 8
ERROR_OPERATION_ABORTED = 995

# C-level file descriptor numbers of the standard streams.
STDIN_FILENO = 0
STDOUT_FILENO = 1
STDERR_FILENO = 2

# Ctrl+Z: the reader treats this character at the start of input as end of file.
EOF = b"\x1a"
# Upper bound on the number of bytes handed to a single WriteConsoleW call.
MAX_BYTES_WRITTEN = 32767
66 | ||
try:
    from ctypes import pythonapi
except ImportError:
    # On PyPy we cannot get buffers so our ability to operate here is
    # severely limited. A None get_buffer disables the console streams.
    get_buffer = None
else:

    class Py_buffer(Structure):
        # ctypes mirror of CPython's C-level ``Py_buffer`` struct. It is
        # filled in by address by PyObject_GetBuffer, so the field order and
        # types must match the C definition exactly.
        _fields_ = [
            ("buf", c_void_p),
            ("obj", py_object),
            ("len", c_ssize_t),
            ("itemsize", c_ssize_t),
            ("readonly", c_int),
            ("ndim", c_int),
            ("format", c_char_p),
            ("shape", c_ssize_p),
            ("strides", c_ssize_p),
            ("suboffsets", c_ssize_p),
            ("internal", c_void_p),
        ]

    PyObject_GetBuffer = pythonapi.PyObject_GetBuffer
    PyBuffer_Release = pythonapi.PyBuffer_Release

    def get_buffer(obj, writable=False):
        """Return a ``c_char`` array aliasing the memory of *obj*.

        *obj* must support the buffer protocol. When *writable* is true a
        writable buffer is requested (``PyBUF_WRITABLE``); otherwise a simple
        buffer is requested (``PyBUF_SIMPLE``). The returned array has
        ``len`` bytes and points directly at the object's memory — no copy.

        NOTE(review): the Py_buffer view is released in ``finally`` before the
        caller touches the returned array, and ``from_address`` does not keep
        *obj* alive. The array is therefore only valid while the caller holds
        a reference to *obj* and does not resize it — confirm at call sites.
        """
        buf = Py_buffer()
        flags = PyBUF_WRITABLE if writable else PyBUF_SIMPLE
        PyObject_GetBuffer(py_object(obj), byref(buf), flags)

        try:
            buffer_type = c_char * buf.len
            return buffer_type.from_address(buf.buf)
        finally:
            PyBuffer_Release(byref(buf))
103 | ||
104 | ||
105 | class _WindowsConsoleRawIOBase(io.RawIOBase): | |
106 | def __init__(self, handle): | |
107 | self.handle = handle | |
108 | ||
109 | def isatty(self): | |
110 | super().isatty() | |
111 | return True | |
112 | ||
113 | ||
class _WindowsConsoleReader(_WindowsConsoleRawIOBase):
    """Raw reader that pulls UTF-16-LE text from the console via ReadConsoleW."""

    def readable(self):
        return True

    def readinto(self, b):
        """Read UTF-16-LE encoded console input into the writable buffer *b*.

        Returns the number of bytes stored in *b*. Returns ``0`` for an empty
        *b*, and for end of file (input that starts with Ctrl+Z, U+001A).

        Raises ``ValueError`` if ``len(b)`` is odd, because each UTF-16 code
        unit occupies two bytes. Raises ``OSError`` if ReadConsoleW fails.
        """
        bytes_to_be_read = len(b)
        if not bytes_to_be_read:
            return 0
        elif bytes_to_be_read % 2:
            raise ValueError(
                "cannot read odd number of bytes from UTF-16-LE encoded console"
            )

        buffer = get_buffer(b, writable=True)
        code_units_to_be_read = bytes_to_be_read // 2
        code_units_read = c_ulong()

        rv = ReadConsoleW(
            HANDLE(self.handle),
            buffer,
            code_units_to_be_read,
            byref(code_units_read),
            None,
        )
        # Capture the error code once. Any call made afterwards (including
        # the sleep below) could overwrite the thread's last-error value
        # before it is reported.
        error = GetLastError()
        if error == ERROR_OPERATION_ABORTED:
            # The read was aborted (Ctrl+C); wait for KeyboardInterrupt.
            time.sleep(0.1)
        if not rv:
            raise OSError(f"Windows error: {error}")

        # Ctrl+Z arrives as the code unit U+001A, i.e. the byte pair
        # b"\x1a\x00" in UTF-16-LE. Both bytes must be compared: checking
        # only the first would also treat characters such as U+041A
        # (Cyrillic "K") or U+021A as end of file.
        if code_units_read.value and buffer[:2] == EOF + b"\x00":
            return 0
        return 2 * code_units_read.value
147 | ||
148 | ||
class _WindowsConsoleWriter(_WindowsConsoleRawIOBase):
    """Raw writer that sends UTF-16-LE encoded bytes to the console via WriteConsoleW."""

    def writable(self):
        return True

    @staticmethod
    def _get_error_message(errno):
        """Turn a Win32 error code into a short human-readable message."""
        known = {
            ERROR_SUCCESS: "ERROR_SUCCESS",
            ERROR_NOT_ENOUGH_MEMORY: "ERROR_NOT_ENOUGH_MEMORY",
        }
        if errno in known:
            return known[errno]
        return f"Windows error {errno}"

    def write(self, b):
        """Write a prefix of the UTF-16-LE bytes *b* to the console.

        At most ``MAX_BYTES_WRITTEN`` bytes (rounded down to whole code units)
        are passed per call. Returns the number of bytes the console accepted,
        so callers must be prepared for short writes.

        Raises ``OSError`` when *b* is non-empty but nothing was written.
        """
        n_bytes = len(b)
        view = get_buffer(b)
        n_units = min(n_bytes, MAX_BYTES_WRITTEN) // 2
        units_done = c_ulong()

        WriteConsoleW(HANDLE(self.handle), view, n_units, byref(units_done), None)

        n_written = units_done.value * 2
        if n_bytes > 0 and n_written == 0:
            raise OSError(self._get_error_message(GetLastError()))
        return n_written
179 | ||
180 | ||
class ConsoleStream:
    """File-like proxy accepting both ``str`` and ``bytes`` writes.

    Text is written through *text_stream*; bytes go straight to
    *byte_stream*, which is exposed as ``buffer``. Any attribute not defined
    on the proxy is looked up on the text stream.
    """

    def __init__(self, text_stream: t.TextIO, byte_stream: t.BinaryIO) -> None:
        self._text_stream = text_stream
        self.buffer = byte_stream

    @property
    def name(self) -> str:
        """The ``name`` of the underlying byte stream."""
        return self.buffer.name

    def write(self, x: t.AnyStr) -> int:
        """Write *x* and return the count reported by the stream that took it."""
        if not isinstance(x, str):
            # Push out buffered text first so the bytes do not overtake it.
            # This is best effort: a failing flush must not block the write.
            try:
                self.flush()
            except Exception:
                pass
            return self.buffer.write(x)
        return self._text_stream.write(x)

    def writelines(self, lines: t.Iterable[t.AnyStr]) -> None:
        """Write every item of *lines*, in order, through :meth:`write`."""
        for chunk in lines:
            self.write(chunk)

    def __getattr__(self, name: str) -> t.Any:
        # Only consulted for attributes the proxy itself does not define.
        return getattr(self._text_stream, name)

    def isatty(self) -> bool:
        """Answer from the byte stream rather than the text wrapper."""
        return self.buffer.isatty()

    def __repr__(self):
        return f"<ConsoleStream name={self.name!r} encoding={self.encoding!r}>"
211 | ||
212 | ||
def _make_console_text_stream(
    buffered: io.BufferedIOBase, buffer_stream: t.BinaryIO
) -> t.TextIO:
    """Wrap a buffered console stream as line-buffered, strict UTF-16-LE text.

    The result is a ``ConsoleStream`` whose ``buffer`` is *buffer_stream*.
    """
    wrapper = _NonClosingTextIOWrapper(
        buffered, "utf-16-le", "strict", line_buffering=True
    )
    return t.cast(t.TextIO, ConsoleStream(wrapper, buffer_stream))


def _get_text_stdin(buffer_stream: t.BinaryIO) -> t.TextIO:
    """Build the console-backed text stream for standard input."""
    reader = io.BufferedReader(_WindowsConsoleReader(STDIN_HANDLE))
    return _make_console_text_stream(reader, buffer_stream)


def _get_text_stdout(buffer_stream: t.BinaryIO) -> t.TextIO:
    """Build the console-backed text stream for standard output."""
    writer = io.BufferedWriter(_WindowsConsoleWriter(STDOUT_HANDLE))
    return _make_console_text_stream(writer, buffer_stream)


def _get_text_stderr(buffer_stream: t.BinaryIO) -> t.TextIO:
    """Build the console-backed text stream for standard error."""
    writer = io.BufferedWriter(_WindowsConsoleWriter(STDERR_HANDLE))
    return _make_console_text_stream(writer, buffer_stream)
241 | ||
242 | ||
# Maps a standard file-descriptor number to the factory that builds the
# console-backed text stream for it. The keys use the module's named
# descriptor constants instead of repeating bare 0/1/2 literals.
_stream_factories: t.Mapping[int, t.Callable[[t.BinaryIO], t.TextIO]] = {
    STDIN_FILENO: _get_text_stdin,
    STDOUT_FILENO: _get_text_stdout,
    STDERR_FILENO: _get_text_stderr,
}
248 | ||
249 | ||
250 | def _is_console(f: t.TextIO) -> bool: | |
251 | if not hasattr(f, "fileno"): | |
252 | return False | |
253 | ||
254 | try: | |
255 | fileno = f.fileno() | |
256 | except (OSError, io.UnsupportedOperation): | |
257 | return False | |
258 | ||
259 | handle = msvcrt.get_osfhandle(fileno) | |
260 | return bool(GetConsoleMode(handle, byref(DWORD()))) | |
261 | ||
262 | ||
def _get_windows_console_stream(
    f: t.TextIO, encoding: t.Optional[str], errors: t.Optional[str]
) -> t.Optional[t.TextIO]:
    """Return a console-backed replacement for the standard stream *f*.

    Returns ``None`` in any of these cases:

    - The ctypes buffer helper is unavailable.
    - An incompatible *encoding* or *errors* was requested.
    - *f* is not a console.
    - *f* is not one of descriptors 0/1/2.
    - *f* has no ``buffer`` attribute.
    """
    eligible = (
        get_buffer is not None
        and encoding in {"utf-16-le", None}
        and errors in {"strict", None}
        and _is_console(f)
    )
    if not eligible:
        return None

    factory = _stream_factories.get(f.fileno())
    if factory is None:
        return None

    byte_stream = getattr(f, "buffer", None)
    if byte_stream is None:
        return None
    return factory(byte_stream)