]>
Commit | Line | Data |
---|---|---|
1 | """Builds on top of nodes.py to track brackets.""" | |
2 | ||
3 | from dataclasses import dataclass, field | |
4 | from typing import Dict, Final, Iterable, List, Optional, Sequence, Set, Tuple, Union | |
5 | ||
6 | from black.nodes import ( | |
7 | BRACKET, | |
8 | CLOSING_BRACKETS, | |
9 | COMPARATORS, | |
10 | LOGIC_OPERATORS, | |
11 | MATH_OPERATORS, | |
12 | OPENING_BRACKETS, | |
13 | UNPACKING_PARENTS, | |
14 | VARARGS_PARENTS, | |
15 | is_vararg, | |
16 | syms, | |
17 | ) | |
18 | from blib2to3.pgen2 import token | |
19 | from blib2to3.pytree import Leaf, Node | |
20 | ||
# Type aliases used by the bracket tracker.
LN = Union[Leaf, Node]  # any element of the syntax tree
Depth = int  # bracket nesting level
LeafID = int  # result of id() on a Leaf
NodeType = int  # numeric token / grammar-symbol type
Priority = int  # delimiter priority; higher numbers are higher priority


# Delimiter priorities, listed from highest to lowest.
COMPREHENSION_PRIORITY: Final = 20
COMMA_PRIORITY: Final = 18
TERNARY_PRIORITY: Final = 16
LOGIC_PRIORITY: Final = 14
STRING_PRIORITY: Final = 12
COMPARATOR_PRIORITY: Final = 10
# Arithmetic/bitwise operators, ordered like Python's operator precedence:
# the loosest-binding operator gets the highest priority.
MATH_PRIORITIES: Final = {
    token.VBAR: 9,  # |
    token.CIRCUMFLEX: 8,  # ^
    token.AMPER: 7,  # &
    token.LEFTSHIFT: 6,  # <<
    token.RIGHTSHIFT: 6,  # >>
    token.PLUS: 5,  # +
    token.MINUS: 5,  # -
    token.STAR: 4,  # *
    token.SLASH: 4,  # /
    token.DOUBLESLASH: 4,  # //
    token.PERCENT: 4,  # %
    token.AT: 4,  # @
    token.TILDE: 3,  # ~
    token.DOUBLESTAR: 2,  # **
}
DOT_PRIORITY: Final = 1
52 | ||
53 | ||
class BracketMatchError(Exception):
    """Raised when a closing bracket cannot be matched to an opening bracket.

    `BracketTracker.mark()` raises this when it meets a closing bracket for
    which no corresponding opening bracket was recorded at the enclosing depth.
    """
56 | ||
57 | ||
@dataclass
class BracketTracker:
    """Keeps track of brackets on a line.

    Leaves are fed one at a time to `mark()`, which annotates them with
    bracket depth / matching information and records the delimiters (split
    points) it finds at depth 0.
    """

    # Current nesting level. Besides real brackets, it is also raised between
    # `for` and `in`, and between `lambda` and its `:`.
    depth: int = 0
    # (depth, closing bracket type) -> the still-open opening bracket leaf.
    bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = field(default_factory=dict)
    # id() of a leaf -> priority of breaking the line right after that leaf.
    # Entries are only added while `depth` is 0.
    delimiters: Dict[LeafID, Priority] = field(default_factory=dict)
    # The last leaf processed by `mark()`.
    previous: Optional[Leaf] = None
    # Depths opened by `for`, each closed again by the matching `in`.
    _for_loop_depths: List[int] = field(default_factory=list)
    # Depths opened by `lambda`, each closed again by the matching `:`.
    _lambda_argument_depths: List[int] = field(default_factory=list)
    # Bracket leaves with an empty `value` seen by `mark()`.
    invisible: List[Leaf] = field(default_factory=list)

    def mark(self, leaf: Leaf) -> None:
        """Mark `leaf` with bracket-related metadata. Keep track of delimiters.

        All leaves receive an int `bracket_depth` field that stores how deep
        within brackets a given leaf is. 0 means there are no enclosing brackets
        that started on this line.

        If a leaf is itself a closing bracket and there is a matching opening
        bracket earlier, it receives an `opening_bracket` field with which it
        forms a pair. This is a one-directional link to avoid reference cycles.
        Closing brackets without an opening bracket happen on lines continued
        from previous breaks, e.g. `) -> "ReturnType":` as part of a funcdef
        where the return type annotation is placed on a line after the closing
        RPAR of the parameters. Such an unmatched closing bracket at depth 0 is
        skipped without receiving any metadata.

        If a leaf is a delimiter (a token on which Black can split the line if
        needed) and it's on depth 0, it is recorded in the tracker's
        `delimiters` field, keyed by the `id()` of the leaf after which the
        break would go: the previous leaf for split-before delimiters, the leaf
        itself for split-after delimiters.

        Comment leaves are ignored entirely.

        Raises BracketMatchError when a closing bracket has no matching opening
        bracket recorded at the enclosing depth.
        """
        if leaf.type == token.COMMENT:
            return

        # Unmatched closer continuing a previous line: nothing to pair it with.
        if (
            self.depth == 0
            and leaf.type in CLOSING_BRACKETS
            and (self.depth, leaf.type) not in self.bracket_match
        ):
            return

        self.maybe_decrement_after_for_loop_variable(leaf)
        self.maybe_decrement_after_lambda_arguments(leaf)
        if leaf.type in CLOSING_BRACKETS:
            self.depth -= 1
            try:
                opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
            except KeyError as e:
                raise BracketMatchError(
                    "Unable to match a closing bracket to the following opening"
                    f" bracket: {leaf}"
                ) from e
            leaf.opening_bracket = opening_bracket
            if not leaf.value:
                self.invisible.append(leaf)
        leaf.bracket_depth = self.depth
        if self.depth == 0:
            delim = is_split_before_delimiter(leaf, self.previous)
            if delim and self.previous is not None:
                # A break *before* `leaf` is a break *after* the previous leaf.
                self.delimiters[id(self.previous)] = delim
            else:
                delim = is_split_after_delimiter(leaf, self.previous)
                if delim:
                    self.delimiters[id(leaf)] = delim
        if leaf.type in OPENING_BRACKETS:
            # Key by the closing type so the closer can look its opener up.
            self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
            self.depth += 1
            if not leaf.value:
                self.invisible.append(leaf)
        self.previous = leaf
        self.maybe_increment_lambda_arguments(leaf)
        self.maybe_increment_for_loop_variable(leaf)

    def any_open_brackets(self) -> bool:
        """Return True if there is an yet unmatched open bracket on the line."""
        return bool(self.bracket_match)

    def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> Priority:
        """Return the highest priority of a delimiter found on the line.

        Values are consistent with what `is_split_*_delimiter()` return.
        Delimiters whose leaf id appears in `exclude` are ignored. `exclude`
        may be any iterable, including a one-shot generator; it is read once.

        Raises ValueError on no delimiters (including when all are excluded).
        """
        # Materialize once: repeated `in` tests against an iterator would
        # consume it, and against a list would cost O(len(exclude)) each.
        excluded = set(exclude)
        return max(v for k, v in self.delimiters.items() if k not in excluded)

    def delimiter_count_with_priority(self, priority: Priority = 0) -> int:
        """Return the number of delimiters with the given `priority`.

        If no `priority` is passed, defaults to max priority on the line.
        Returns 0 when the line has no delimiters at all.
        """
        if not self.delimiters:
            return 0

        priority = priority or self.max_delimiter_priority()
        return sum(1 for p in self.delimiters.values() if p == priority)

    def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
        """In a for loop, or comprehension, the variables are often unpacks.

        To avoid splitting on the comma in this situation, increase the depth of
        tokens between `for` and `in`. Returns True if the depth was raised.
        """
        if leaf.type == token.NAME and leaf.value == "for":
            self.depth += 1
            self._for_loop_depths.append(self.depth)
            return True

        return False

    def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
        """Undo `maybe_increment_for_loop_variable` at the matching `in`.

        Only an `in` seen at exactly the depth opened by the latest `for`
        counts. Returns True if the depth was lowered.
        """
        if (
            self._for_loop_depths
            and self._for_loop_depths[-1] == self.depth
            and leaf.type == token.NAME
            and leaf.value == "in"
        ):
            self.depth -= 1
            self._for_loop_depths.pop()
            return True

        return False

    def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
        """In a lambda expression, there might be more than one argument.

        To avoid splitting on the comma in this situation, increase the depth of
        tokens between `lambda` and `:`. Returns True if the depth was raised.
        """
        if leaf.type == token.NAME and leaf.value == "lambda":
            self.depth += 1
            self._lambda_argument_depths.append(self.depth)
            return True

        return False

    def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
        """Undo `maybe_increment_lambda_arguments` at the lambda's `:`.

        Only a colon seen at exactly the depth opened by the latest `lambda`
        counts (colons inside nested brackets sit deeper). Returns True if the
        depth was lowered.
        """
        if (
            self._lambda_argument_depths
            and self._lambda_argument_depths[-1] == self.depth
            and leaf.type == token.COLON
        ):
            self.depth -= 1
            self._lambda_argument_depths.pop()
            return True

        return False

    def get_open_lsqb(self) -> Optional[Leaf]:
        """Return the opening `[` registered one level below the current depth.

        That is the innermost open bracket when it is a square bracket;
        otherwise None.
        """
        return self.bracket_match.get((self.depth - 1, token.RSQB))
209 | ||
210 | ||
def is_split_after_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> Priority:
    """Report how strongly a line may be broken right *after* `leaf`.

    Only commas act as split-after delimiters; every other leaf yields 0.
    `previous` is accepted for signature parity with
    `is_split_before_delimiter` and is not consulted.

    Higher numbers are higher priority.
    """
    return COMMA_PRIORITY if leaf.type == token.COMMA else 0
223 | ||
224 | ||
def is_split_before_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> Priority:
    """Report how strongly a line may be broken right *before* `leaf`.

    Returns a ``*_PRIORITY`` constant (or a ``MATH_PRIORITIES`` value) when a
    break before `leaf` is a legitimate split point, and 0 otherwise.
    `previous` is the leaf preceding `leaf` on the line, or None.

    The rules are checked in a fixed order and the first match wins; in
    particular the vararg rule must precede the arithmetic rule because `*`
    and `**` are also math operators.

    Higher numbers are higher priority.
    """
    # `*` / `**` used for packing or unpacking are not arithmetic here, so
    # they never act as delimiters.
    if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
        return 0

    parent = leaf.parent
    # None stands for "no parent"; real node types are ints.
    parent_type: Optional[NodeType] = parent.type if parent else None
    kind = leaf.type

    # Attribute access: only a dot at the start of the line or right after a
    # closing bracket, and never one inside an import path.
    if kind == token.DOT:
        follows_bracket = previous is None or previous.type in CLOSING_BRACKETS
        if (
            parent_type is not None
            and parent_type not in {syms.import_from, syms.dotted_name}
            and follows_bracket
        ):
            return DOT_PRIORITY

    # Binary operators; unary ones (`factor`) and starred expressions excluded.
    if kind in MATH_OPERATORS:
        if parent_type is not None and parent_type not in {
            syms.factor,
            syms.star_expr,
        }:
            return MATH_PRIORITIES[kind]

    if kind in COMPARATORS:
        return COMPARATOR_PRIORITY

    # Implicitly concatenated adjacent string literals.
    if kind == token.STRING and previous is not None and previous.type == token.STRING:
        return STRING_PRIORITY

    # Every remaining rule is about keywords.
    if kind not in {token.NAME, token.ASYNC}:
        return 0

    word = leaf.value
    # The text of `previous` when it is a NAME leaf, otherwise None.
    previous_name = (
        previous.value
        if previous is not None and previous.type == token.NAME
        else None
    )

    # Start of a comprehension clause. In `async for`, the split belongs
    # before `async`, so the `for` right after it is not a delimiter.
    opens_comprehension = kind == token.ASYNC or (
        word == "for" and parent_type in {syms.comp_for, syms.old_comp_for}
    )
    if opens_comprehension:
        sibling = leaf.prev_sibling
        if not (isinstance(sibling, Leaf) and sibling.value == "async"):
            return COMPREHENSION_PRIORITY

    if word == "if" and parent_type in {syms.comp_if, syms.old_comp_if}:
        return COMPREHENSION_PRIORITY

    if word in {"if", "else"} and parent_type == syms.test:
        return TERNARY_PRIORITY

    if word == "is":
        return COMPARATOR_PRIORITY

    # `in` / `not` as comparison operators, except the second word of the
    # two-word operators `not in` and `is not`.
    if (
        word == "in"
        and parent_type in {syms.comp_op, syms.comparison}
        and previous_name != "not"
    ):
        return COMPARATOR_PRIORITY

    if word == "not" and parent_type == syms.comp_op and previous_name != "is":
        return COMPARATOR_PRIORITY

    if word in LOGIC_OPERATORS and parent_type is not None:
        return LOGIC_PRIORITY

    return 0
319 | ||
320 | ||
def max_delimiter_priority_in_atom(node: LN) -> Priority:
    """Return the highest delimiter priority found inside a parenthesized atom.

    Only an atom whose first and last children are `(` and `)` is inspected.
    For any other node, or when the parentheses hold no delimiters, the
    result is 0.
    """
    if node.type != syms.atom:
        return 0

    children = node.children
    if children[0].type != token.LPAR or children[-1].type != token.RPAR:
        return 0

    tracker = BracketTracker()
    for child in children[1:-1]:
        # A Leaf is marked directly; a Node contributes all of its leaves.
        inner_leaves = (child,) if isinstance(child, Leaf) else child.leaves()
        for leaf in inner_leaves:
            tracker.mark(leaf)

    # max_delimiter_priority() would raise ValueError on an empty mapping.
    if not tracker.delimiters:
        return 0
    return tracker.max_delimiter_priority()
347 | ||
348 | ||
def get_leaves_inside_matching_brackets(leaves: Sequence[Leaf]) -> Set[LeafID]:
    """Collect the ids of leaves enclosed by matching bracket pairs.

    Scanning begins at the first opening bracket, so unmatched closing
    brackets at the head of `leaves` are ignored. The scan stops at the first
    closing bracket that does not match the innermost open bracket. A leaf's
    id is included when the leaf lies within a matched pair found by the scan;
    the two bracket leaves of that pair are included as well.
    """
    first_open = next(
        (pos for pos, leaf in enumerate(leaves) if leaf.type in OPENING_BRACKETS),
        None,
    )
    if first_open is None:
        return set()

    # Each entry: (closing bracket type expected, index of the opening bracket).
    pending: List[Tuple[int, int]] = []
    enclosed: Set[LeafID] = set()
    for pos in range(first_open, len(leaves)):
        kind = leaves[pos].type
        if kind in OPENING_BRACKETS:
            pending.append((BRACKET[kind], pos))
        elif kind in CLOSING_BRACKETS:
            if not pending or pending[-1][0] != kind:
                # Unopened or mismatched closing bracket: stop scanning.
                break
            _, opened_at = pending.pop()
            enclosed.update(id(leaves[j]) for j in range(opened_at, pos + 1))
    return enclosed
373 | else: | |
374 | break | |
375 | return ids |