]>
Commit | Line | Data |
---|---|---|
1 | import contextlib | |
2 | import io | |
3 | import os | |
4 | import shlex | |
5 | import shutil | |
6 | import sys | |
7 | import tempfile | |
8 | import typing as t | |
9 | from types import TracebackType | |
10 | ||
11 | from . import formatting | |
12 | from . import termui | |
13 | from . import utils | |
14 | from ._compat import _find_binary_reader | |
15 | ||
16 | if t.TYPE_CHECKING: | |
17 | from .core import BaseCommand | |
18 | ||
19 | ||
class EchoingStdin:
    """Binary input wrapper that copies everything read from it to an output.

    Each chunk returned by one of the read methods below is also written to
    ``output``, unless echoing is paused by setting ``_paused`` to ``True``.
    Attributes not defined on the wrapper are looked up on the wrapped input.
    """

    def __init__(self, input: t.BinaryIO, output: t.BinaryIO) -> None:
        self._input = input
        self._output = output
        # Toggled externally to temporarily suppress echoing.
        self._paused = False

    def __getattr__(self, x: str) -> t.Any:
        # Only consulted for names the wrapper itself lacks (e.g. ``tell``).
        return getattr(self._input, x)

    def _echo(self, rv: bytes) -> bytes:
        """Copy ``rv`` to the output unless paused, then return it unchanged."""
        if self._paused:
            return rv
        self._output.write(rv)
        return rv

    def read(self, n: int = -1) -> bytes:
        chunk = self._input.read(n)
        return self._echo(chunk)

    def read1(self, n: int = -1) -> bytes:
        chunk = self._input.read1(n)  # type: ignore
        return self._echo(chunk)

    def readline(self, n: int = -1) -> bytes:
        line = self._input.readline(n)
        return self._echo(line)

    def readlines(self) -> t.List[bytes]:
        return list(map(self._echo, self._input.readlines()))

    def __iter__(self) -> t.Iterator[bytes]:
        # ``map`` calls ``iter()`` on the input immediately, matching an
        # eagerly-bound generator expression.
        return map(self._echo, self._input)

    def __repr__(self) -> str:
        return repr(self._input)
52 | ||
53 | ||
54 | @contextlib.contextmanager | |
55 | def _pause_echo(stream: t.Optional[EchoingStdin]) -> t.Iterator[None]: | |
56 | if stream is None: | |
57 | yield | |
58 | else: | |
59 | stream._paused = True | |
60 | yield | |
61 | stream._paused = False | |
62 | ||
63 | ||
64 | class _NamedTextIOWrapper(io.TextIOWrapper): | |
65 | def __init__( | |
66 | self, buffer: t.BinaryIO, name: str, mode: str, **kwargs: t.Any | |
67 | ) -> None: | |
68 | super().__init__(buffer, **kwargs) | |
69 | self._name = name | |
70 | self._mode = mode | |
71 | ||
72 | @property | |
73 | def name(self) -> str: | |
74 | return self._name | |
75 | ||
76 | @property | |
77 | def mode(self) -> str: | |
78 | return self._mode | |
79 | ||
80 | ||
def make_input_stream(
    input: t.Optional[t.Union[str, bytes, t.IO[t.Any]]], charset: str
) -> t.BinaryIO:
    """Turn ``input`` into a binary stream that can back ``stdin``.

    - A file-like object (anything with ``read``) is resolved to its binary
      reader via ``_find_binary_reader``; ``TypeError`` is raised if none is
      found.
    - ``None`` becomes an empty ``io.BytesIO``.
    - ``str`` is encoded with ``charset``; other values (``bytes``) are
      wrapped in ``io.BytesIO`` as-is.
    """
    if hasattr(input, "read"):
        reader = _find_binary_reader(t.cast(t.IO[t.Any], input))
        if reader is None:
            raise TypeError("Could not find binary reader for input stream.")
        return reader

    if input is None:
        data = b""
    elif isinstance(input, str):
        data = input.encode(charset)
    else:
        data = input
    return io.BytesIO(data)
99 | ||
100 | ||
class Result:
    """Holds the captured result of an invoked CLI script."""

    def __init__(
        self,
        runner: "CliRunner",
        stdout_bytes: bytes,
        stderr_bytes: t.Optional[bytes],
        return_value: t.Any,
        exit_code: int,
        exception: t.Optional[BaseException],
        exc_info: t.Optional[
            t.Tuple[t.Type[BaseException], BaseException, TracebackType]
        ] = None,
    ):
        #: The runner that produced this result; its ``charset`` is used
        #: to decode the captured output.
        self.runner = runner
        #: Raw bytes captured from standard output.
        self.stdout_bytes = stdout_bytes
        #: Raw bytes captured from standard error, or ``None`` when stderr
        #: was not captured separately.
        self.stderr_bytes = stderr_bytes
        #: Whatever the invoked command returned.
        #:
        #: .. versionadded:: 8.0
        self.return_value = return_value
        #: The integer exit code.
        self.exit_code = exit_code
        #: The exception raised during the invocation, if any.
        self.exception = exception
        #: The ``sys.exc_info()`` triple, if available.
        self.exc_info = exc_info

    def _decode(self, data: bytes) -> str:
        # Undecodable bytes become U+FFFD; CRLF is normalized to LF.
        text = data.decode(self.runner.charset, "replace")
        return text.replace("\r\n", "\n")

    @property
    def output(self) -> str:
        """The (standard) output as unicode string."""
        return self.stdout

    @property
    def stdout(self) -> str:
        """The standard output as unicode string."""
        return self._decode(self.stdout_bytes)

    @property
    def stderr(self) -> str:
        """The standard error as unicode string.

        :raises ValueError: if stderr was not captured separately.
        """
        captured = self.stderr_bytes
        if captured is None:
            raise ValueError("stderr not separately captured")
        return self._decode(captured)

    def __repr__(self) -> str:
        if self.exception:
            summary = repr(self.exception)
        else:
            summary = "okay"
        return f"<{type(self).__name__} {summary}>"
157 | ||
158 | ||
class CliRunner:
    """The CLI runner provides functionality to invoke a Click command line
    script for unit testing purposes in an isolated environment. This only
    works in single-threaded systems without any concurrency as it changes the
    global interpreter state.

    :param charset: the character set for the input and output data.
    :param env: a dictionary with environment variables for overriding.
                A value of ``None`` unsets that variable during isolation.
    :param echo_stdin: if this is set to `True`, then reading from stdin writes
                       to stdout. This is useful for showing examples in
                       some circumstances. Note that regular prompts
                       will automatically echo the input.
    :param mix_stderr: if this is set to `False`, then stdout and stderr are
                       preserved as independent streams. This is useful for
                       Unix-philosophy apps that have predictable stdout and
                       noisy stderr, such that each may be measured
                       independently
    """

    def __init__(
        self,
        charset: str = "utf-8",
        env: t.Optional[t.Mapping[str, t.Optional[str]]] = None,
        echo_stdin: bool = False,
        mix_stderr: bool = True,
    ) -> None:
        self.charset = charset
        self.env: t.Mapping[str, t.Optional[str]] = env or {}
        self.echo_stdin = echo_stdin
        self.mix_stderr = mix_stderr

    def get_default_prog_name(self, cli: "BaseCommand") -> str:
        """Given a command object it will return the default program name
        for it. The default is the `name` attribute or ``"root"`` if not
        set.
        """
        return cli.name or "root"

    def make_env(
        self, overrides: t.Optional[t.Mapping[str, t.Optional[str]]] = None
    ) -> t.Mapping[str, t.Optional[str]]:
        """Returns the environment overrides for invoking a script.

        The runner's own ``env`` is copied and ``overrides`` (if any) is
        applied on top; neither input mapping is modified.
        """
        rv = dict(self.env)
        if overrides:
            rv.update(overrides)
        return rv

    @contextlib.contextmanager
    def isolation(
        self,
        input: t.Optional[t.Union[str, bytes, t.IO[t.Any]]] = None,
        env: t.Optional[t.Mapping[str, t.Optional[str]]] = None,
        color: bool = False,
    ) -> t.Iterator[t.Tuple[io.BytesIO, t.Optional[io.BytesIO]]]:
        """A context manager that sets up the isolation for invoking of a
        command line tool. This sets up stdin with the given input data
        and `os.environ` with the overrides from the given dictionary.
        This also rebinds some internals in Click to be mocked (like the
        prompt functionality).

        This is automatically done in the :meth:`invoke` method.

        Yields a ``(stdout_bytes, stderr_bytes)`` pair of ``io.BytesIO``
        buffers; the second item is ``None`` when stderr is mixed into
        stdout. All patched global state is restored on exit, even if
        setup fails part-way.

        :param input: the input stream to put into sys.stdin.
        :param env: the environment overrides as dictionary.
        :param color: whether the output should contain color codes. The
                      application can still override this explicitly.

        .. versionchanged:: 8.0
            ``stderr`` is opened with ``errors="backslashreplace"``
            instead of the default ``"strict"``.

        .. versionchanged:: 4.0
            Added the ``color`` parameter.
        """
        # --- Pure setup: nothing below touches global state yet. ---
        bytes_input = make_input_stream(input, self.charset)
        echo_input = None

        env = self.make_env(env)

        bytes_output = io.BytesIO()

        if self.echo_stdin:
            bytes_input = echo_input = t.cast(
                t.BinaryIO, EchoingStdin(bytes_input, bytes_output)
            )

        text_input = _NamedTextIOWrapper(
            bytes_input, encoding=self.charset, name="<stdin>", mode="r"
        )

        if self.echo_stdin:
            # Force unbuffered reads, otherwise TextIOWrapper reads a
            # large chunk which is echoed early.
            text_input._CHUNK_SIZE = 1  # type: ignore

        text_output = _NamedTextIOWrapper(
            bytes_output, encoding=self.charset, name="<stdout>", mode="w"
        )

        bytes_error = None
        if self.mix_stderr:
            text_error = text_output
        else:
            bytes_error = io.BytesIO()
            text_error = _NamedTextIOWrapper(
                bytes_error,
                encoding=self.charset,
                name="<stderr>",
                mode="w",
                errors="backslashreplace",
            )

        @_pause_echo(echo_input)  # type: ignore
        def visible_input(prompt: t.Optional[str] = None) -> str:
            sys.stdout.write(prompt or "")
            val = text_input.readline().rstrip("\r\n")
            sys.stdout.write(f"{val}\n")
            sys.stdout.flush()
            return val

        @_pause_echo(echo_input)  # type: ignore
        def hidden_input(prompt: t.Optional[str] = None) -> str:
            sys.stdout.write(f"{prompt or ''}\n")
            sys.stdout.flush()
            return text_input.readline().rstrip("\r\n")

        @_pause_echo(echo_input)  # type: ignore
        def _getchar(echo: bool) -> str:
            char = sys.stdin.read(1)

            if echo:
                sys.stdout.write(char)

            sys.stdout.flush()
            return char

        default_color = color

        def should_strip_ansi(
            stream: t.Optional[t.IO[t.Any]] = None, color: t.Optional[bool] = None
        ) -> bool:
            if color is None:
                return not default_color
            return not color

        def set_env_var(key: str, value: t.Optional[str]) -> None:
            # ``None`` means "make sure the variable is unset". Deletion is
            # best-effort: the variable may not exist in the first place.
            if value is None:
                try:
                    del os.environ[key]
                except Exception:
                    pass
            else:
                os.environ[key] = value

        # Snapshot every piece of global state before modifying any of it, so
        # the ``finally`` below can restore everything even if patching or
        # env setup raises part-way through.
        old_stdin = sys.stdin
        old_stdout = sys.stdout
        old_stderr = sys.stderr
        old_forced_width = formatting.FORCED_WIDTH
        old_visible_prompt_func = termui.visible_prompt_func
        old_hidden_prompt_func = termui.hidden_prompt_func
        old__getchar_func = termui._getchar
        old_should_strip_ansi = utils.should_strip_ansi  # type: ignore
        old_env: t.Dict[str, t.Optional[str]] = {}

        try:
            formatting.FORCED_WIDTH = 80
            sys.stdin = text_input
            sys.stdout = text_output
            sys.stderr = text_error
            termui.visible_prompt_func = visible_input
            termui.hidden_prompt_func = hidden_input
            termui._getchar = _getchar
            utils.should_strip_ansi = should_strip_ansi  # type: ignore

            for key, value in env.items():
                # Record the previous value first so it is restored even if
                # setting the new one fails.
                old_env[key] = os.environ.get(key)
                set_env_var(key, value)

            yield (bytes_output, bytes_error)
        finally:
            for key, value in old_env.items():
                set_env_var(key, value)
            sys.stdout = old_stdout
            sys.stderr = old_stderr
            sys.stdin = old_stdin
            termui.visible_prompt_func = old_visible_prompt_func
            termui.hidden_prompt_func = old_hidden_prompt_func
            termui._getchar = old__getchar_func
            utils.should_strip_ansi = old_should_strip_ansi  # type: ignore
            formatting.FORCED_WIDTH = old_forced_width

    def invoke(
        self,
        cli: "BaseCommand",
        args: t.Optional[t.Union[str, t.Sequence[str]]] = None,
        input: t.Optional[t.Union[str, bytes, t.IO[t.Any]]] = None,
        env: t.Optional[t.Mapping[str, t.Optional[str]]] = None,
        catch_exceptions: bool = True,
        color: bool = False,
        **extra: t.Any,
    ) -> "Result":
        """Invokes a command in an isolated environment. The arguments are
        forwarded directly to the command line script, the `extra` keyword
        arguments are passed to the :meth:`~clickpkg.Command.main` function of
        the command.

        This returns a :class:`Result` object.

        :param cli: the command to invoke
        :param args: the arguments to invoke. It may be given as an iterable
                     or a string. When given as string it will be interpreted
                     as a Unix shell command. More details at
                     :func:`shlex.split`.
        :param input: the input data for `sys.stdin`.
        :param env: the environment overrides.
        :param catch_exceptions: Whether to catch any other exceptions than
                                 ``SystemExit``.
        :param extra: the keyword arguments to pass to :meth:`main`.
        :param color: whether the output should contain color codes. The
                      application can still override this explicitly.

        .. versionchanged:: 8.0
            The result object has the ``return_value`` attribute with
            the value returned from the invoked command.

        .. versionchanged:: 4.0
            Added the ``color`` parameter.

        .. versionchanged:: 3.0
            Added the ``catch_exceptions`` parameter.

        .. versionchanged:: 3.0
            The result object has the ``exc_info`` attribute with the
            traceback if available.
        """
        exc_info = None
        with self.isolation(input=input, env=env, color=color) as outstreams:
            return_value = None
            exception: t.Optional[BaseException] = None
            exit_code = 0

            if isinstance(args, str):
                args = shlex.split(args)

            try:
                prog_name = extra.pop("prog_name")
            except KeyError:
                prog_name = self.get_default_prog_name(cli)

            try:
                return_value = cli.main(args=args or (), prog_name=prog_name, **extra)
            except SystemExit as e:
                exc_info = sys.exc_info()
                e_code = t.cast(t.Optional[t.Union[int, t.Any]], e.code)

                if e_code is None:
                    e_code = 0

                if e_code != 0:
                    exception = e

                if not isinstance(e_code, int):
                    # Non-integer exit "codes" (e.g. a message) are printed
                    # and mapped to exit status 1.
                    sys.stdout.write(str(e_code))
                    sys.stdout.write("\n")
                    e_code = 1

                exit_code = e_code

            except Exception as e:
                if not catch_exceptions:
                    raise
                exception = e
                exit_code = 1
                exc_info = sys.exc_info()
            finally:
                # Both streams are TextIOWrappers that buffer writes; flush
                # each so everything reaches its BytesIO before it is read.
                # When stderr is mixed, this flushes the same object twice.
                sys.stdout.flush()
                sys.stderr.flush()
                stdout = outstreams[0].getvalue()
                if self.mix_stderr:
                    stderr = None
                else:
                    stderr = outstreams[1].getvalue()  # type: ignore

        return Result(
            runner=self,
            stdout_bytes=stdout,
            stderr_bytes=stderr,
            return_value=return_value,
            exit_code=exit_code,
            exception=exception,
            exc_info=exc_info,  # type: ignore
        )

    @contextlib.contextmanager
    def isolated_filesystem(
        self, temp_dir: t.Optional[t.Union[str, "os.PathLike[str]"]] = None
    ) -> t.Iterator[str]:
        """A context manager that creates a temporary directory and
        changes the current working directory to it. This isolates tests
        that affect the contents of the CWD to prevent them from
        interfering with each other.

        Yields the path of the created directory. The previous working
        directory is restored on exit.

        :param temp_dir: Create the temporary directory under this
            directory. If given, the created directory is not removed
            when exiting.

        .. versionchanged:: 8.0
            Added the ``temp_dir`` parameter.
        """
        cwd = os.getcwd()
        dt = tempfile.mkdtemp(dir=temp_dir)

        try:
            # chdir inside ``try`` so a failure here still cleans up ``dt``.
            os.chdir(dt)
            yield dt
        finally:
            os.chdir(cwd)

            if temp_dir is None:
                # Cleanup is best-effort; a leftover temp dir is not fatal.
                with contextlib.suppress(OSError):
                    shutil.rmtree(dt)