]>
Commit | Line | Data |
---|---|---|
53e6db90 DC |
1 | """ |
2 | This module contains implementations for the termui module. To keep the | |
3 | import time of Click down, some infrequently used functionality is | |
4 | placed in this module and only imported as needed. | |
5 | """ | |
6 | import contextlib | |
7 | import math | |
8 | import os | |
9 | import sys | |
10 | import time | |
11 | import typing as t | |
12 | from gettext import gettext as _ | |
13 | from io import StringIO | |
14 | from types import TracebackType | |
15 | ||
16 | from ._compat import _default_text_stdout | |
17 | from ._compat import CYGWIN | |
18 | from ._compat import get_best_encoding | |
19 | from ._compat import isatty | |
20 | from ._compat import open_stream | |
21 | from ._compat import strip_ansi | |
22 | from ._compat import term_len | |
23 | from ._compat import WIN | |
24 | from .exceptions import ClickException | |
25 | from .utils import echo | |
26 | ||
# Type of the items yielded by a ProgressBar.
V = t.TypeVar("V")

# Escape sequences written around each progress-bar redraw. On non-"nt"
# platforms, "\033[?25l" hides the terminal cursor while the bar is drawn and
# "\033[?25h" shows it again before the final newline; "nt" gets only the
# plain carriage return / newline.
BEFORE_BAR, AFTER_BAR = (
    ("\r", "\n") if os.name == "nt" else ("\r\033[?25l", "\033[?25h\n")
)
35 | ||
36 | ||
class ProgressBar(t.Generic[V]):
    """A text progress bar that wraps an iterable or counts explicit steps.

    Must be used as a context manager: entering renders the initial state
    and exiting calls :meth:`render_finish`. Iterating over the bar (or
    calling :meth:`update`) advances it. If ``file`` is not a TTY the bar is
    "hidden": only the label is echoed, once per change, and no bar is drawn.

    :param iterable: Items to yield. If ``None``, ``length`` is required and
        ``range(length)`` is iterated instead.
    :param length: Total number of steps. Defaults to
        ``operator.length_hint(iterable)``; when that is unknown the bar has
        no fixed end and shows a single moving marker instead.
    :param fill_char: Character for the completed part of the bar.
    :param empty_char: Character for the remaining part of the bar.
    :param bar_template: ``%``-format string using the ``label``, ``bar``
        and ``info`` keys.
    :param info_sep: Separator placed between the info fields.
    :param show_eta: Show the estimated time remaining once it is known.
    :param show_percent: Show the percentage. ``None`` means "show it when
        the length is known and ``show_pos`` is off".
    :param show_pos: Show the position as ``pos`` or ``pos/length``.
    :param item_show_func: Called with the current item; a non-``None``
        return value is appended to the info fields.
    :param label: Text for the ``label`` slot of the template.
    :param file: Output stream. Defaults to stdout, or an in-memory buffer
        when no stdout is available.
    :param color: Passed through to :func:`echo`.
    :param update_min_steps: Only advance and redraw once this many steps
        have accumulated.
    :param width: Bar width in characters; ``0`` means size the bar to the
        remaining terminal width on every redraw.
    """

    def __init__(
        self,
        iterable: t.Optional[t.Iterable[V]],
        length: t.Optional[int] = None,
        fill_char: str = "#",
        empty_char: str = " ",
        bar_template: str = "%(bar)s",
        info_sep: str = " ",
        show_eta: bool = True,
        show_percent: t.Optional[bool] = None,
        show_pos: bool = False,
        item_show_func: t.Optional[t.Callable[[t.Optional[V]], t.Optional[str]]] = None,
        label: t.Optional[str] = None,
        file: t.Optional[t.TextIO] = None,
        color: t.Optional[bool] = None,
        update_min_steps: int = 1,
        width: int = 30,
    ) -> None:
        self.fill_char = fill_char
        self.empty_char = empty_char
        self.bar_template = bar_template
        self.info_sep = info_sep
        self.show_eta = show_eta
        self.show_percent = show_percent
        self.show_pos = show_pos
        self.item_show_func = item_show_func
        self.label: str = label or ""

        if file is None:
            file = _default_text_stdout()

            # There are no standard streams attached to write to. For example,
            # pythonw on Windows.
            if file is None:
                file = StringIO()

        self.file = file
        self.color = color
        self.update_min_steps = update_min_steps
        # Steps accumulated since the last redraw (see update()).
        self._completed_intervals = 0
        self.width: int = width
        # width == 0 requests sizing the bar from the terminal on each redraw.
        self.autowidth: bool = width == 0

        if length is None:
            from operator import length_hint

            # length_hint returns the -1 default when the size is unknown.
            length = length_hint(iterable, -1)

            if length == -1:
                length = None
        if iterable is None:
            if length is None:
                raise TypeError("iterable or length is required")
            iterable = t.cast(t.Iterable[V], range(length))
        self.iter: t.Iterable[V] = iter(iterable)
        self.length = length
        self.pos = 0
        # Rolling window of recent per-step durations, used for the ETA.
        self.avg: t.List[float] = []
        self.last_eta: float
        self.start: float
        self.start = self.last_eta = time.time()
        self.eta_known: bool = False
        self.finished: bool = False
        # Longest rendered line so far; used to blank leftovers on redraw.
        self.max_width: t.Optional[int] = None
        self.entered: bool = False
        self.current_item: t.Optional[V] = None
        self.is_hidden: bool = not isatty(self.file)
        # Last text written, so identical redraws can be skipped.
        self._last_line: t.Optional[str] = None

    def __enter__(self) -> "ProgressBar[V]":
        self.entered = True
        self.render_progress()
        return self

    def __exit__(
        self,
        exc_type: t.Optional[t.Type[BaseException]],
        exc_value: t.Optional[BaseException],
        tb: t.Optional[TracebackType],
    ) -> None:
        self.render_finish()

    def __iter__(self) -> t.Iterator[V]:
        if not self.entered:
            raise RuntimeError("You need to use progress bars in a with block.")
        self.render_progress()
        return self.generator()

    def __next__(self) -> V:
        # Iteration is defined in terms of a generator function,
        # returned by iter(self); use that to define next(). This works
        # because `self.iter` is an iterable consumed by that generator,
        # so it is re-entry safe. Calling `next(self.generator())`
        # twice works and does "what you want".
        return next(iter(self))

    def render_finish(self) -> None:
        """Terminate the bar's line (no-op when the bar is hidden)."""
        if self.is_hidden:
            return
        self.file.write(AFTER_BAR)
        self.file.flush()

    @property
    def pct(self) -> float:
        """Completed fraction in ``[0.0, 1.0]``; ``1.0`` once finished."""
        if self.finished:
            return 1.0
        return min(self.pos / (float(self.length or 1) or 1), 1.0)

    @property
    def time_per_iteration(self) -> float:
        """Mean of the recorded step durations, or ``0.0`` if none yet."""
        if not self.avg:
            return 0.0
        return sum(self.avg) / float(len(self.avg))

    @property
    def eta(self) -> float:
        """Estimated seconds remaining; ``0.0`` if unknown length or finished."""
        if self.length is not None and not self.finished:
            return self.time_per_iteration * (self.length - self.pos)
        return 0.0

    def format_eta(self) -> str:
        """Format :attr:`eta` as ``[Nd ]HH:MM:SS``, or ``""`` if not known."""
        if self.eta_known:
            # Named `remaining` rather than `t` to avoid shadowing the
            # module-level `typing as t` import.
            remaining = int(self.eta)
            seconds = remaining % 60
            remaining //= 60
            minutes = remaining % 60
            remaining //= 60
            hours = remaining % 24
            remaining //= 24
            if remaining > 0:
                return f"{remaining}d {hours:02}:{minutes:02}:{seconds:02}"
            else:
                return f"{hours:02}:{minutes:02}:{seconds:02}"
        return ""

    def format_pos(self) -> str:
        """Return ``"pos"`` or ``"pos/length"`` when the length is known."""
        pos = str(self.pos)
        if self.length is not None:
            pos += f"/{self.length}"
        return pos

    def format_pct(self) -> str:
        """Return the percentage right-aligned in a 4-character field."""
        return f"{int(self.pct * 100): 4}%"[1:]

    def format_bar(self) -> str:
        """Render the bar itself (``self.width`` cells of fill/empty chars)."""
        if self.length is not None:
            bar_length = int(self.pct * self.width)
            bar = self.fill_char * bar_length
            bar += self.empty_char * (self.width - bar_length)
        elif self.finished:
            bar = self.fill_char * self.width
        else:
            # Unknown length: animate a single marker across the bar.
            chars = list(self.empty_char * (self.width or 1))
            if self.time_per_iteration != 0:
                marker = int(
                    (math.cos(self.pos * self.time_per_iteration) / 2.0 + 0.5)
                    * self.width
                )
                # cos() can return exactly 1.0 (e.g. when pos is 0), which
                # yields marker == width; clamp it to the last cell.
                chars[min(marker, len(chars) - 1)] = self.fill_char
            bar = "".join(chars)
        return bar

    def format_progress_line(self) -> str:
        """Fill ``bar_template`` with the label, bar and info fields."""
        show_percent = self.show_percent

        info_bits = []
        if self.length is not None and show_percent is None:
            show_percent = not self.show_pos

        if self.show_pos:
            info_bits.append(self.format_pos())
        if show_percent:
            info_bits.append(self.format_pct())
        if self.show_eta and self.eta_known and not self.finished:
            info_bits.append(self.format_eta())
        if self.item_show_func is not None:
            item_info = self.item_show_func(self.current_item)
            if item_info is not None:
                info_bits.append(item_info)

        return (
            self.bar_template
            % {
                "label": self.label,
                "bar": self.format_bar(),
                "info": self.info_sep.join(info_bits),
            }
        ).rstrip()

    def render_progress(self) -> None:
        """Redraw the bar in place, skipping the write if nothing changed."""
        import shutil

        if self.is_hidden:
            # Only output the label as it changes if the output is not a
            # TTY. Use file=stderr if you expect to be piping stdout.
            if self._last_line != self.label:
                self._last_line = self.label
                echo(self.label, file=self.file, color=self.color)

            return

        buf = []
        # Update width in case the terminal has been resized
        if self.autowidth:
            old_width = self.width
            # Measure everything except the bar by rendering a zero-width bar.
            self.width = 0
            clutter_length = term_len(self.format_progress_line())
            new_width = max(0, shutil.get_terminal_size().columns - clutter_length)
            if new_width < old_width:
                # The line is getting shorter: blank out the previous line.
                buf.append(BEFORE_BAR)
                buf.append(" " * (self.max_width or 0))
                self.max_width = new_width
            self.width = new_width

        clear_width = self.width
        if self.max_width is not None:
            clear_width = self.max_width

        buf.append(BEFORE_BAR)
        line = self.format_progress_line()
        line_len = term_len(line)
        if self.max_width is None or self.max_width < line_len:
            self.max_width = line_len

        buf.append(line)
        # Pad with spaces to overwrite any leftovers of a longer previous line.
        buf.append(" " * (clear_width - line_len))
        line = "".join(buf)

        # Render the line only if it changed.
        if line != self._last_line:
            self._last_line = line
            echo(line, file=self.file, color=self.color, nl=False)
            self.file.flush()

    def make_step(self, n_steps: int) -> None:
        """Advance ``pos`` by *n_steps* and refresh the ETA sample.

        A new timing sample is recorded at most once per second.
        """
        self.pos += n_steps
        if self.length is not None and self.pos >= self.length:
            self.finished = True

        if (time.time() - self.last_eta) < 1.0:
            return

        self.last_eta = time.time()

        # self.avg is a rolling list of length <= 7 of steps where steps are
        # defined as time elapsed divided by the total progress through
        # self.length.
        if self.pos:
            step = (time.time() - self.start) / self.pos
        else:
            step = time.time() - self.start

        self.avg = self.avg[-6:] + [step]

        self.eta_known = self.length is not None

    def update(self, n_steps: int, current_item: t.Optional[V] = None) -> None:
        """Update the progress bar by advancing a specified number of
        steps, and optionally set the ``current_item`` for this new
        position.

        :param n_steps: Number of steps to advance.
        :param current_item: Optional item to set as ``current_item``
            for the updated position.

        .. versionchanged:: 8.0
            Added the ``current_item`` optional parameter.

        .. versionchanged:: 8.0
            Only render when the number of steps meets the
            ``update_min_steps`` threshold.
        """
        if current_item is not None:
            self.current_item = current_item

        self._completed_intervals += n_steps

        if self._completed_intervals >= self.update_min_steps:
            self.make_step(self._completed_intervals)
            self.render_progress()
            self._completed_intervals = 0

    def finish(self) -> None:
        """Mark the bar complete and clear the ETA and current item."""
        self.eta_known = False
        self.current_item = None
        self.finished = True

    def generator(self) -> t.Iterator[V]:
        """Return a generator which yields the items added to the bar
        during construction, and updates the progress bar *after* the
        yielded block returns.
        """
        # WARNING: the iterator interface for `ProgressBar` relies on
        # this and only works because this is a simple generator which
        # doesn't create or manage additional state. If this function
        # changes, the impact should be evaluated both against
        # `iter(bar)` and `next(bar)`. `next()` in particular may call
        # `self.generator()` repeatedly, and this must remain safe in
        # order for that interface to work.
        if not self.entered:
            raise RuntimeError("You need to use progress bars in a with block.")

        if self.is_hidden:
            yield from self.iter
        else:
            for rv in self.iter:
                self.current_item = rv

                # This allows show_item_func to be updated before the
                # item is processed. Only trigger at the beginning of
                # the update interval.
                if self._completed_intervals == 0:
                    self.render_progress()

                yield rv
                self.update(1)

            self.finish()
            self.render_progress()
358 | ||
359 | ||
def pager(generator: t.Iterable[str], color: t.Optional[bool] = None) -> None:
    """Choose a paging strategy for *generator* and run it.

    Order of preference: plain output when stdin/stdout are not both TTYs,
    the ``$PAGER`` command, plain output for ``TERM`` of "dumb"/"emacs",
    ``more`` via a temp file on Windows/OS2, then ``less``, then ``more``,
    and finally plain output.
    """
    out = _default_text_stdout()

    # There are no standard streams attached to write to. For example,
    # pythonw on Windows.
    if out is None:
        out = StringIO()

    interactive = isatty(sys.stdin) and isatty(out)
    if not interactive:
        return _nullpager(out, generator, color)

    user_cmd = (os.environ.get("PAGER", None) or "").strip()
    if user_cmd:
        runner = _tempfilepager if WIN else _pipepager
        return runner(generator, user_cmd, color)

    if os.environ.get("TERM") in ("dumb", "emacs"):
        return _nullpager(out, generator, color)

    if WIN or sys.platform.startswith("os2"):
        return _tempfilepager(generator, "more <", color)

    can_shell = hasattr(os, "system")
    if can_shell and os.system("(less) 2>/dev/null") == 0:
        return _pipepager(generator, "less", color)

    import tempfile

    # Probe whether `more` works by running it on an empty scratch file.
    handle, probe_path = tempfile.mkstemp()
    os.close(handle)
    try:
        if can_shell and os.system(f'more "{probe_path}"') == 0:
            return _pipepager(generator, "more", color)
        return _nullpager(out, generator, color)
    finally:
        os.unlink(probe_path)
393 | ||
394 | ||
def _pipepager(generator: t.Iterable[str], cmd: str, color: t.Optional[bool]) -> None:
    """Page through text by feeding it to another program. Invoking a
    pager through this might support colors.

    :param generator: Chunks of text to send to the pager.
    :param cmd: Shell command line of the pager.
    :param color: Keep ANSI styles if true, strip them if false. ``None``
        lets ``less`` enable color when its flags allow raw control chars.
    """
    import subprocess

    env = dict(os.environ)

    # If we're piping to less we might support colors under the
    # condition that either no flags are configured at all (then we pass
    # -R ourselves) or the configured flags already contain r/R.
    cmd_detail = cmd.rsplit("/", 1)[-1].split()
    if color is None and cmd_detail and cmd_detail[0] == "less":
        less_flags = f"{os.environ.get('LESS', '')}{' '.join(cmd_detail[1:])}"
        if not less_flags:
            env["LESS"] = "-R"
            color = True
        elif "r" in less_flags or "R" in less_flags:
            color = True

    c = subprocess.Popen(cmd, shell=True, stdin=subprocess.PIPE, env=env)
    stdin = t.cast(t.BinaryIO, c.stdin)
    encoding = get_best_encoding(stdin)
    try:
        for text in generator:
            if not color:
                text = strip_ansi(text)

            stdin.write(text.encode(encoding, "replace"))
    except (OSError, KeyboardInterrupt):
        # The pager went away (e.g. broken pipe) or the user interrupted;
        # stop feeding it, but still close the pipe and wait below.
        pass
    finally:
        # Always close our end of the pipe so the pager sees EOF and the
        # handle is not leaked. close() flushes, which raises
        # BrokenPipeError if the pager has already exited.
        try:
            stdin.close()
        except OSError:
            pass

    # Less doesn't respect ^C, but catches it for its own UI purposes (aborting
    # search or other commands inside less).
    #
    # That means when the user hits ^C, the parent process (click) terminates,
    # but less is still alive, paging the output and messing up the terminal.
    #
    # If the user wants to make the pager exit on ^C, they should set
    # `LESS='-K'`. It's not our decision to make.
    while True:
        try:
            c.wait()
        except KeyboardInterrupt:
            pass
        else:
            break
443 | ||
444 | ||
def _tempfilepager(
    generator: t.Iterable[str], cmd: str, color: t.Optional[bool]
) -> None:
    """Page through text by invoking a program on a temporary file.

    The whole text is written to a temporary file, which is passed to
    *cmd* as ``{cmd} "{filename}"`` and deleted afterwards.

    :param generator: Chunks of text to page.
    :param cmd: Shell command prefix, e.g. ``"more <"``.
    :param color: Keep ANSI styles if true, strip them otherwise.
    """
    import tempfile

    # TODO: This never terminates if the passed generator never terminates.
    text = "".join(generator)
    if not color:
        text = strip_ansi(text)
    encoding = get_best_encoding(sys.stdout)

    # Create the file only after the text is ready, and unlink it on every
    # path so a failing write cannot leave it behind.
    fd, filename = tempfile.mkstemp()
    try:
        # Write through the descriptor mkstemp returned instead of opening
        # the path a second time; the `with` closes it before the pager runs.
        with os.fdopen(fd, "wb") as f:
            # "replace" (as in _pipepager) so characters the console encoding
            # cannot represent do not abort paging.
            f.write(text.encode(encoding, "replace"))
        os.system(f'{cmd} "{filename}"')
    finally:
        os.unlink(filename)
464 | ||
465 | ||
466 | def _nullpager( | |
467 | stream: t.TextIO, generator: t.Iterable[str], color: t.Optional[bool] | |
468 | ) -> None: | |
469 | """Simply print unformatted text. This is the ultimate fallback.""" | |
470 | for text in generator: | |
471 | if not color: | |
472 | text = strip_ansi(text) | |
473 | stream.write(text) | |
474 | ||
475 | ||
class Editor:
    """Edit text in an external editor and read back the result.

    :param editor: Editor command line. If ``None``, it is resolved by
        :meth:`get_editor`.
    :param env: Extra environment variables merged over ``os.environ``
        for the editor process.
    :param require_save: If true, :meth:`edit` returns ``None`` when the
        file's modification time did not change (i.e. the user did not save).
    :param extension: Suffix for the temporary file created by :meth:`edit`.
    """

    def __init__(
        self,
        editor: t.Optional[str] = None,
        env: t.Optional[t.Mapping[str, str]] = None,
        require_save: bool = True,
        extension: str = ".txt",
    ) -> None:
        self.editor = editor
        self.env = env
        self.require_save = require_save
        self.extension = extension

    def get_editor(self) -> str:
        """Return the editor command to run.

        Preference: the explicit ``editor``, then ``$VISUAL``, then
        ``$EDITOR``, then ``notepad`` on Windows, then the first of
        ``sensible-editor``/``vim``/``nano`` found on ``PATH``, else ``vi``.
        """
        import shutil

        if self.editor is not None:
            return self.editor
        for key in "VISUAL", "EDITOR":
            rv = os.environ.get(key)
            if rv:
                return rv
        if WIN:
            return "notepad"
        for editor in "sensible-editor", "vim", "nano":
            # shutil.which searches PATH directly, without spawning a shell
            # or depending on an external `which` binary being installed.
            if shutil.which(editor) is not None:
                return editor
        return "vi"

    def edit_file(self, filename: str) -> None:
        """Open *filename* in the editor and block until it exits.

        :raises ClickException: if the editor cannot be started or exits
            with a non-zero status.
        """
        import subprocess

        editor = self.get_editor()
        environ: t.Optional[t.Dict[str, str]] = None

        if self.env:
            environ = os.environ.copy()
            environ.update(self.env)

        try:
            # The editor string may contain arguments, so it goes through
            # the shell; the filename is quoted.
            c = subprocess.Popen(f'{editor} "{filename}"', env=environ, shell=True)
            exit_code = c.wait()
            if exit_code != 0:
                raise ClickException(
                    _("{editor}: Editing failed").format(editor=editor)
                )
        except OSError as e:
            raise ClickException(
                _("{editor}: Editing failed: {e}").format(editor=editor, e=e)
            ) from e

    def edit(self, text: t.Optional[t.AnyStr]) -> t.Optional[t.AnyStr]:
        """Edit *text* in a temporary file and return the edited content.

        ``str`` input is written as UTF-8 (with BOM and CRLF line endings on
        Windows) and returned as ``str`` with ``\\n`` line endings; bytes
        input is written and returned unchanged in form. Returns ``None``
        if ``require_save`` is set and the file was not saved.
        """
        import tempfile

        if not text:
            data = b""
        elif isinstance(text, (bytes, bytearray)):
            data = text
        else:
            # Non-empty str from here on; make sure it ends with a newline.
            if not text.endswith("\n"):
                text += "\n"

            if WIN:
                data = text.replace("\n", "\r\n").encode("utf-8-sig")
            else:
                data = text.encode("utf-8")

        fd, name = tempfile.mkstemp(prefix="editor-", suffix=self.extension)
        f: t.BinaryIO

        try:
            with os.fdopen(fd, "wb") as f:
                f.write(data)

            # If the filesystem resolution is 1 second, like Mac OS
            # 10.12 Extended, or 2 seconds, like FAT32, and the editor
            # closes very fast, require_save can fail. Set the modified
            # time to be 2 seconds in the past to work around this.
            os.utime(name, (os.path.getatime(name), os.path.getmtime(name) - 2))
            # Depending on the resolution, the exact value might not be
            # recorded, so get the new recorded value.
            timestamp = os.path.getmtime(name)

            self.edit_file(name)

            if self.require_save and os.path.getmtime(name) == timestamp:
                return None

            with open(name, "rb") as f:
                rv = f.read()

            if isinstance(text, (bytes, bytearray)):
                return rv

            return rv.decode("utf-8-sig").replace("\r\n", "\n")  # type: ignore
        finally:
            os.unlink(name)
571 | ||
572 | ||
def open_url(url: str, wait: bool = False, locate: bool = False) -> int:
    """Open *url* with the platform's default handler.

    :param url: URL or path; a ``file://`` prefix is stripped and
        percent-decoded where a local path is needed.
    :param wait: Wait for the launched program to exit, where supported.
    :param locate: Reveal the file in a file manager instead of opening it.
    :return: The launcher's exit status, or ``0``/``1`` for the
        ``xdg-open`` / ``webbrowser`` fallbacks.
    """
    import subprocess

    def _unquote_file(url: str) -> str:
        # Turn a file:// URL into a plain, percent-decoded path.
        from urllib.parse import unquote

        if url.startswith("file://"):
            url = unquote(url[7:])

        return url

    if sys.platform == "darwin":
        args = ["open"]
        if wait:
            args.append("-W")
        if locate:
            args.append("-R")
        args.append(_unquote_file(url))
        # Discard `open`'s stderr without managing a /dev/null handle by hand.
        return subprocess.Popen(args, stderr=subprocess.DEVNULL).wait()
    elif WIN:
        # Double quotes are removed because the URL is embedded in a
        # double-quoted shell argument.
        if locate:
            url = _unquote_file(url.replace('"', ""))
            args = f'explorer /select,"{url}"'
        else:
            url = url.replace('"', "")
            wait_str = "/WAIT" if wait else ""
            args = f'start {wait_str} "" "{url}"'
        return os.system(args)
    elif CYGWIN:
        if locate:
            url = os.path.dirname(_unquote_file(url).replace('"', ""))
            args = f'cygstart "{url}"'
        else:
            url = url.replace('"', "")
            wait_str = "-w" if wait else ""
            args = f'cygstart {wait_str} "{url}"'
        return os.system(args)

    try:
        if locate:
            url = os.path.dirname(_unquote_file(url)) or "."
        else:
            url = _unquote_file(url)
        c = subprocess.Popen(["xdg-open", url])
        if wait:
            return c.wait()
        return 0
    except OSError:
        # xdg-open is unavailable; fall back to the webbrowser module for
        # plain web URLs when no waiting or locating was requested.
        if url.startswith(("http://", "https://")) and not locate and not wait:
            import webbrowser

            webbrowser.open(url)
            return 0
        return 1
631 | ||
632 | ||
633 | def _translate_ch_to_exc(ch: str) -> t.Optional[BaseException]: | |
634 | if ch == "\x03": | |
635 | raise KeyboardInterrupt() | |
636 | ||
637 | if ch == "\x04" and not WIN: # Unix-like, Ctrl+D | |
638 | raise EOFError() | |
639 | ||
640 | if ch == "\x1a" and WIN: # Windows, Ctrl+Z | |
641 | raise EOFError() | |
642 | ||
643 | return None | |
644 | ||
645 | ||
if WIN:
    import msvcrt

    @contextlib.contextmanager
    def raw_terminal() -> t.Iterator[int]:
        """Placeholder on Windows: nothing is switched, ``-1`` is yielded.

        ``getchar`` on Windows reads keys through ``msvcrt`` rather than a
        file descriptor.
        """
        yield -1

    def getchar(echo: bool) -> str:
        """Read one key press from the console.

        :param echo: Echo the character (``getwche``) or read it silently
            (``getwch``).
        :raises KeyboardInterrupt: on Ctrl+C.
        :raises EOFError: on Ctrl+Z.
        """
        # The function `getch` will return a bytes object corresponding to
        # the pressed character. Since Windows 10 build 1803, it will also
        # return \x00 when called a second time after pressing a regular key.
        #
        # `getwch` does not share this probably-bugged behavior. Moreover, it
        # returns a Unicode object by default, which is what we want.
        #
        # Either of these functions will return \x00 or \xe0 to indicate
        # a special key, and you need to call the same function again to get
        # the "rest" of the code. The fun part is that \u00e0 is
        # "latin small letter a with grave", so if you type that on a French
        # keyboard, you _also_ get a \xe0.
        # E.g., consider the Up arrow. This returns \xe0 and then \x48. The
        # resulting Unicode string reads as "a with grave" + "capital H".
        # This is indistinguishable from when the user actually types
        # "a with grave" and then "capital H".
        #
        # When \xe0 is returned, we assume it's part of a special-key sequence
        # and call `getwch` again, but that means that when the user types
        # the \u00e0 character, `getchar` doesn't return until a second
        # character is typed.
        # The alternative is returning immediately, but that would mess up
        # cross-platform handling of arrow keys and others that start with
        # \xe0. Another option is using `getch`, but then we can't reliably
        # read non-ASCII characters, because return values of `getch` are
        # limited to the current 8-bit codepage.
        #
        # Anyway, Click doesn't claim to do this Right(tm), and using `getwch`
        # is doing the right thing in more situations than with `getch`.
        func: t.Callable[[], str]

        if echo:
            func = msvcrt.getwche  # type: ignore
        else:
            func = msvcrt.getwch  # type: ignore

        rv = func()

        if rv in ("\x00", "\xe0"):
            # \x00 and \xe0 are control characters that indicate special key,
            # see above.
            rv += func()

        _translate_ch_to_exc(rv)
        return rv

else:
    import tty
    import termios

    @contextlib.contextmanager
    def raw_terminal() -> t.Iterator[int]:
        """Put the terminal into raw mode for the duration of the block.

        Uses stdin when it is a TTY, otherwise opens ``/dev/tty``. Yields
        the file descriptor. The previous terminal attributes are restored
        on exit, and ``termios.error`` is silently ignored.
        """
        f: t.Optional[t.TextIO]
        fd: int

        if not isatty(sys.stdin):
            f = open("/dev/tty")
            fd = f.fileno()
        else:
            fd = sys.stdin.fileno()
            f = None

        try:
            old_settings = termios.tcgetattr(fd)

            try:
                tty.setraw(fd)
                yield fd
            finally:
                termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
                sys.stdout.flush()
        except termios.error:
            pass
        finally:
            # Close the /dev/tty handle on every exit path, including when
            # the with-body raises or tcgetattr fails; previously it was
            # only closed on the success path and leaked otherwise.
            if f is not None:
                f.close()

    def getchar(echo: bool) -> str:
        """Read a key press from the terminal in raw mode.

        Reads up to 32 bytes at once, so the result may contain more than
        one character (e.g. an escape sequence).

        :param echo: Write the characters to stdout if it is a TTY.
        :raises KeyboardInterrupt: on Ctrl+C.
        :raises EOFError: on Ctrl+D.
        """
        with raw_terminal() as fd:
            ch = os.read(fd, 32).decode(get_best_encoding(sys.stdin), "replace")

            if echo and isatty(sys.stdout):
                sys.stdout.write(ch)

            _translate_ch_to_exc(ch)
            return ch