]>
Commit | Line | Data |
---|---|---|
53e6db90 DC |
1 | """Caching of formatted files with feature-based invalidation.""" |
2 | import hashlib | |
3 | import os | |
4 | import pickle | |
5 | import sys | |
6 | import tempfile | |
7 | from dataclasses import dataclass, field | |
8 | from pathlib import Path | |
9 | from typing import Dict, Iterable, NamedTuple, Set, Tuple | |
10 | ||
11 | from platformdirs import user_cache_dir | |
12 | ||
13 | from _black_version import version as __version__ | |
14 | from black.mode import Mode | |
15 | ||
16 | if sys.version_info >= (3, 11): | |
17 | from typing import Self | |
18 | else: | |
19 | from typing_extensions import Self | |
20 | ||
21 | ||
# Snapshot of one source file for cache invalidation: modification time,
# size, and a SHA-256 content digest. Field order matches the raw tuples
# persisted in the pickle file.
FileData = NamedTuple(
    "FileData",
    [("st_mtime", float), ("st_size", int), ("hash", str)],
)
26 | ||
27 | ||
def get_cache_dir() -> Path:
    """Return the directory where black keeps its cache files.

    The `BLACK_CACHE_DIR` environment variable overrides the location on
    every platform; otherwise the per-user cache directory for the black
    application (versioned) is used.

    The result is captured once in the module constant
    `black.cache.CACHE_DIR` so the lookup is not repeated.
    """
    # NOTE: kept as a function mainly so tests can exercise the lookup.
    override = os.environ.get("BLACK_CACHE_DIR")
    if override is not None:
        return Path(override)
    return Path(user_cache_dir("black", version=__version__))
42 | ||
43 | ||
# Resolved once at import time; all per-mode cache files live under it.
CACHE_DIR = get_cache_dir()
45 | ||
46 | ||
def get_cache_file(mode: Mode) -> Path:
    """Return the path of the cache file for the given *mode*."""
    filename = f"cache.{mode.get_cache_key()}.pickle"
    return CACHE_DIR / filename
49 | ||
50 | ||
@dataclass
class Cache:
    """Per-mode cache of already-formatted files.

    Maps resolved source paths to ``FileData`` records (mtime, size,
    content hash) so unchanged files can be skipped on later runs.
    """

    mode: Mode
    cache_file: Path
    file_data: Dict[str, FileData] = field(default_factory=dict)

    @classmethod
    def read(cls, mode: Mode) -> Self:
        """Read the cache if it exists and is well formed.

        If it is not well formed, the call to write later should
        resolve the issue.
        """
        cache_file = get_cache_file(mode)
        if not cache_file.exists():
            return cls(mode, cache_file)

        with cache_file.open("rb") as fobj:
            try:
                data: Dict[str, Tuple[float, int, str]] = pickle.load(fobj)
                file_data = {k: FileData(*v) for k, v in data.items()}
            except (pickle.UnpicklingError, ValueError, IndexError):
                # Corrupt or incompatible cache: start empty; a later
                # write() replaces the bad file.
                return cls(mode, cache_file)

        return cls(mode, cache_file, file_data)

    @staticmethod
    def hash_digest(path: Path) -> str:
        """Return the SHA-256 hex digest of the file at *path*."""
        data = path.read_bytes()
        return hashlib.sha256(data).hexdigest()

    @staticmethod
    def get_file_data(path: Path) -> FileData:
        """Return file data (mtime, size, content digest) for *path*."""
        stat = path.stat()
        # Named "digest" so the builtin `hash` is not shadowed.
        digest = Cache.hash_digest(path)
        return FileData(stat.st_mtime, stat.st_size, digest)

    def is_changed(self, source: Path) -> bool:
        """Check if source has changed compared to the cached version."""
        res_src = source.resolve()
        old = self.file_data.get(str(res_src))
        if old is None:
            # Never seen before: must be formatted.
            return True

        st = res_src.stat()
        if st.st_size != old.st_size:
            return True
        if int(st.st_mtime) != int(old.st_mtime):
            # mtime moved (second granularity): fall back to the content
            # hash before declaring the file changed.
            new_hash = Cache.hash_digest(res_src)
            if new_hash != old.hash:
                return True
        return False

    def filtered_cached(self, sources: Iterable[Path]) -> Tuple[Set[Path], Set[Path]]:
        """Split an iterable of paths in `sources` into two sets.

        The first contains paths of files that modified on disk or are not in the
        cache. The other contains paths to non-modified files.
        """
        changed: Set[Path] = set()
        done: Set[Path] = set()
        for src in sources:
            if self.is_changed(src):
                changed.add(src)
            else:
                done.add(src)
        return changed, done

    def write(self, sources: Iterable[Path]) -> None:
        """Update the cache file data and write a new cache file."""
        # Pass the mapping positionally: update(**mapping) pointlessly
        # unpacks every path into a keyword argument.
        self.file_data.update(
            {str(src.resolve()): Cache.get_file_data(src) for src in sources}
        )
        try:
            CACHE_DIR.mkdir(parents=True, exist_ok=True)
            with tempfile.NamedTemporaryFile(
                dir=str(self.cache_file.parent), delete=False
            ) as f:
                # We store raw tuples in the cache because pickling NamedTuples
                # doesn't work with mypyc on Python 3.8, and because it's faster.
                data: Dict[str, Tuple[float, int, str]] = {
                    k: (*v,) for k, v in self.file_data.items()
                }
                pickle.dump(data, f, protocol=4)
            # Replace only after the `with` closes the temp file: the pickle
            # buffer is then fully flushed, and os.replace on Windows cannot
            # rename a file that is still open.
            os.replace(f.name, self.cache_file)
        except OSError:
            # Best effort: failing to persist the cache must not fail the run.
            pass