]>
Commit | Line | Data |
---|---|---|
53e6db90 DC |
1 | """ |
2 | PEP 0484 ( https://www.python.org/dev/peps/pep-0484/ ) describes type hints | |
3 | through function annotations. There is a strong suggestion in this document | |
4 | that only the type of type hinting defined in PEP0484 should be allowed | |
5 | as annotations in future python versions. | |
6 | """ | |
7 | ||
8 | import re | |
9 | from inspect import Parameter | |
10 | ||
11 | from parso import ParserSyntaxError, parse | |
12 | ||
13 | from jedi.inference.cache import inference_state_method_cache | |
14 | from jedi.inference.base_value import ValueSet, NO_VALUES | |
15 | from jedi.inference.gradual.base import DefineGenericBaseClass, GenericClass | |
16 | from jedi.inference.gradual.generics import TupleGenericManager | |
17 | from jedi.inference.gradual.type_var import TypeVar | |
18 | from jedi.inference.helpers import is_string | |
19 | from jedi.inference.compiled import builtin_from_name | |
20 | from jedi.inference.param import get_executed_param_names | |
21 | from jedi import debug | |
22 | from jedi import parser_utils | |
23 | ||
24 | ||
def infer_annotation(context, annotation):
    """
    Infer the values that an annotation node stands for.

    For a statement such as::

        foo: int = 3

    this infers the ``int`` part.

    If the annotation infers to exactly one value and that value is a
    string, the string is treated as a forward reference: it is parsed and
    the parsed node is inferred instead. If inference does not lead to
    exactly one value, a warning is logged and the inferred value set is
    returned unchanged.
    """
    inferred = context.infer_node(annotation)
    if len(inferred) != 1:
        debug.warning("Inferred typing index %s should lead to 1 object, "
                      " not %s" % (annotation, inferred))
        return inferred

    only_value = next(iter(inferred))
    if not is_string(only_value):
        return inferred

    forward_node = _get_forward_reference_node(context, only_value.get_safe_value())
    if forward_node is None:
        # The string was not parseable; fall back to the string value itself.
        return inferred
    return context.infer_node(forward_node)
46 | ||
47 | ||
def _infer_annotation_string(context, string, index=None):
    """
    Infer an annotation that is given as source text.

    ``string`` is parsed via `_get_forward_reference_node` and the resulting
    node is inferred in ``context``. If the text cannot be parsed,
    ``NO_VALUES`` is returned.

    When ``index`` is given, only values whose ``array_type`` is ``'tuple'``
    and that actually contain an item at position ``index`` are kept, and the
    item at that position is returned for each of them.
    """
    node = _get_forward_reference_node(context, string)
    if node is None:
        return NO_VALUES

    value_set = context.infer_node(node)
    if index is None:
        return value_set

    def has_item_at_index(value):
        # Position `index` only exists in a tuple with MORE than `index`
        # items. The previous `>=` comparison let tuples with exactly `index`
        # items through, for which the lookup below has nothing to return.
        return (
            value.array_type == 'tuple'
            and len(list(value.py__iter__())) > index
        )

    return value_set.filter(has_item_at_index).py__simple_getitem__(index)
62 | ||
63 | ||
def _get_forward_reference_node(context, string):
    """
    Parse ``string`` as an ``eval_input`` and attach it to the context's tree.

    Returns the parsed root node, or ``None`` (after logging a warning) if
    the string is not syntactically valid.
    """
    try:
        parsed = context.inference_state.grammar.parse(
            string, start_symbol='eval_input', error_recovery=False)
    except ParserSyntaxError:
        debug.warning('Annotation not parsed: %s' % string)
        return None

    # Move the freshly parsed tree by the line number at which the enclosing
    # module ends, then hook it up below the context's tree node.
    root = context.tree_node.get_root_node()
    parser_utils.move(parsed, root.end_pos[0])
    parsed.parent = context.tree_node
    return parsed
79 | ||
80 | ||
def _split_comment_param_declaration(decl_text):
    """
    Split the parameter part of a type comment into one string per parameter.

    The split happens on commas, but generic expressions stay in one piece:
    "foo, Bar[baz, biz]" becomes ['foo', 'Bar[baz, biz]'].

    An empty list is returned if the text is not valid Python (a warning is
    logged in that case) or if the parsed node has no children to split.
    """
    param_node_types = ('name', 'atom_expr', 'power')

    try:
        first = parse(decl_text, error_recovery=False).children[0]
    except ParserSyntaxError:
        debug.warning('Comment annotation is not valid Python: %s' % decl_text)
        return []

    # A single declaration, no comma separated list.
    if first.type in param_node_types:
        return [first.get_code().strip()]

    try:
        parts = first.children
    except AttributeError:
        # Leaf nodes have no children, so there is nothing to split.
        return []

    return [
        part.get_code().strip()
        for part in parts
        if part.type in param_node_types
    ]
110 | ||
111 | ||
@inference_state_method_cache()
def infer_param(function_value, param, ignore_stars=False):
    """
    Infer the annotated type of ``param``, accounting for ``*``/``**``.

    The plain annotation values come from `_infer_param`. Unless
    ``ignore_stars`` is set (or nothing could be inferred), they are wrapped:

    - ``star_count == 1``: a generic builtin ``tuple`` of the values.
    - ``star_count == 2``: a generic builtin ``dict`` with ``str`` keys and
      the values as dict values.

    Parameters without stars get the plain values back.
    """
    annotated = _infer_param(function_value, param)
    if ignore_stars or not annotated:
        return annotated

    inference_state = function_value.inference_state
    stars = param.star_count
    if stars == 1:
        container = builtin_from_name(inference_state, 'tuple')
        generics = (annotated,)
    elif stars == 2:
        container = builtin_from_name(inference_state, 'dict')
        generics = (
            ValueSet([builtin_from_name(inference_state, 'str')]),
            annotated,
        )
    else:
        return annotated

    return ValueSet([GenericClass(container, TupleGenericManager(generics))])
135 | ||
136 | ||
def _infer_param(function_value, param):
    """
    Infer the type of a function parameter from its annotation.

    A regular annotation on the parameter wins. Without one, a type comment
    of the form ``# type: (...) -> ...`` on the line of the function
    definition is consulted and the entry matching the parameter's position
    is inferred. ``NO_VALUES`` is returned whenever no usable annotation is
    found.
    """
    annotation = param.annotation
    if annotation is not None:
        # Annotations are like default params and resolve in the same way.
        return infer_annotation(function_value.get_default_param_context(), annotation)

    # No Python 3-style annotation: fall back to a comment annotation.
    # Collect the params in the order they would appear in a type comment.
    sibling_params = [
        candidate for candidate in param.parent.children
        if candidate.type == 'param'
    ]

    funcdef = param.parent.parent
    comment = parser_utils.get_following_comment_same_line(funcdef)
    if comment is None:
        return NO_VALUES

    match = re.match(r"^#\s*type:\s*\(([^#]*)\)\s*->", comment)
    if not match:
        return NO_VALUES
    declared = _split_comment_param_declaration(match.group(1))

    # Position of the param that is being investigated.
    position = sibling_params.index(param)
    if len(declared) != len(sibling_params):
        debug.warning(
            "Comments length != Params length %s %s",
            declared, sibling_params
        )
    if function_value.is_bound_method():
        if position == 0:
            # Assume it's self, which is already handled.
            return NO_VALUES
        # The type comment does not list self, so shift by one.
        position -= 1
    if position >= len(declared):
        return NO_VALUES

    param_comment = declared[position]
    return _infer_annotation_string(
        function_value.get_default_param_context(),
        param_comment,
    )
184 | ||
185 | ||
def py__annotations__(funcdef):
    """
    Collect the annotation nodes of a function definition.

    Returns a dict that maps each annotated parameter name to its annotation
    node. The return annotation, if present, is stored under ``'return'``.
    Parameters without an annotation are left out.
    """
    annotations = {}
    for param in funcdef.get_params():
        node = param.annotation
        if node is None:
            continue
        annotations[param.name.value] = node

    returns = funcdef.annotation
    if returns:
        annotations['return'] = returns
    return annotations
197 | ||
198 | ||
def resolve_forward_references(context, all_annotations):
    """
    Replace string annotations by the nodes that they refer to.

    ``all_annotations`` maps names to annotation nodes (or ``None``). A new
    dict with the same keys is returned. Nodes that are not of type
    ``'string'`` are passed through untouched; string nodes are evaluated
    and parsed, and become ``None`` if their content is not a valid
    annotation.
    """
    def resolve(node):
        is_string_node = node is not None and node.type == 'string'
        if not is_string_node:
            return node

        literal = context.inference_state.compiled_subprocess.safe_literal_eval(
            node.value,
        )
        parsed = _get_forward_reference_node(context, literal)
        if parsed is None:
            # There was a string, but it's not a valid annotation.
            return None

        # The forward reference tree has an additional root node
        # ('eval_input') that is not wanted here. Its first child is the
        # equivalent of what `py__annotations__` returns for an annotation
        # that is not quoted.
        return parsed.children[0]

    resolved = {}
    for name, node in all_annotations.items():
        resolved[name] = resolve(node)
    return resolved
223 | ||
224 | ||
@inference_state_method_cache()
def infer_return_types(function, arguments):
    """
    Infer the return values of a function from its type annotations.

    A return annotation (with forward references resolved) is preferred. If
    it contains type vars that are still unknown, they are inferred from the
    given ``arguments`` and filled in. Without a return annotation, a
    ``# type: (...) -> ...`` comment on the line of the function definition
    is used. ``NO_VALUES`` is returned if neither exists.
    """
    context = function.get_default_param_context()
    all_annotations = resolve_forward_references(
        context,
        py__annotations__(function.tree_node),
    )
    annotation = all_annotations.get("return", None)

    if annotation is not None:
        unknown_type_vars = find_unknown_type_vars(context, annotation)
        annotation_values = infer_annotation(context, annotation)
        if not unknown_type_vars:
            return annotation_values.execute_annotation()

        type_var_dict = infer_type_vars_for_execution(function, arguments, all_annotations)

        def specialize(ann):
            # Only generic classes and type vars can take the inferred type
            # vars; everything else is kept as it is.
            if isinstance(ann, (DefineGenericBaseClass, TypeVar)):
                return ann.define_generics(type_var_dict)
            return ValueSet({ann})

        return ValueSet.from_sets(
            specialize(ann) for ann in annotation_values
        ).execute_annotation()

    # There is no Python 3-type annotation, so look for an annotation
    # comment instead.
    comment = parser_utils.get_following_comment_same_line(function.tree_node)
    if comment is None:
        return NO_VALUES

    match = re.match(r"^#\s*type:\s*\([^#]*\)\s*->\s*([^#]*)", comment)
    if not match:
        return NO_VALUES

    return_text = match.group(1).strip()
    return _infer_annotation_string(context, return_text).execute_annotation()
266 | ||
267 | ||
def infer_type_vars_for_execution(function, arguments, annotation_dict):
    """
    Infer type vars that are only defined by the function itself.

    Some functions (`iter` is an example) use type vars that do not come
    from a class. For those the following happens here:

    1. Look for type vars in the annotation of each executed param.
    2. Infer them from the values that the param was executed with.
    3. Merge everything that was found into a single dict, which is
       returned.
    """
    context = function.get_default_param_context()

    type_var_results = {}
    for param_name in get_executed_param_names(function, arguments):
        try:
            annotation_node = annotation_dict[param_name.string_name]
        except KeyError:
            continue

        if not find_unknown_type_vars(context, annotation_node):
            continue

        # There are unknown type vars; infer them from the actual values.
        annotation_value_set = context.infer_node(annotation_node)
        kind = param_name.get_kind()
        actual_value_set = param_name.infer()
        if kind is Parameter.VAR_POSITIONAL:
            actual_value_set = actual_value_set.merge_types_of_iterate()
        elif kind is Parameter.VAR_KEYWORD:
            # TODO _dict_values is not public.
            actual_value_set = actual_value_set.try_merge('_dict_values')

        merge_type_var_dicts(
            type_var_results,
            annotation_value_set.infer_type_vars(actual_value_set),
        )
    return type_var_results
304 | ||
305 | ||
def infer_return_for_callable(arguments, param_values, result_values):
    """
    Infer the result of calling a callable annotation with ``arguments``.

    Type vars are inferred from every value in ``param_values`` whose
    ``array_type`` is ``'list'`` and are then filled into those
    ``result_values`` that are generic classes or type vars. The resulting
    values are executed as annotations.
    """
    type_vars = {}
    list_params = (pv for pv in param_values if pv.array_type == 'list')
    for pv in list_params:
        type_vars.update(_infer_type_vars_for_callable(arguments, pv.py__iter__()))

    def specialize(value):
        if isinstance(value, (DefineGenericBaseClass, TypeVar)):
            return value.define_generics(type_vars)
        return ValueSet({value})

    return ValueSet.from_sets(
        specialize(value) for value in result_values
    ).execute_annotation()
319 | ||
320 | ||
def _infer_type_vars_for_callable(arguments, lazy_params):
    """
    Infer the type vars of a Callable from the arguments it is called with:

        def x() -> Callable[[Callable[..., _T]], _T]: ...

    The unpacked ``arguments`` are paired up with ``lazy_params`` by
    position. The returned dict contains the merged type vars of all pairs.
    """
    type_var_results = {}
    pairs = zip(arguments.unpack(), lazy_params)
    for (_, lazy_argument), lazy_param in pairs:
        param_values = lazy_param.infer()
        # What the param is actually called with.
        argument_values = lazy_argument.infer()
        merge_type_var_dicts(
            type_var_results,
            param_values.infer_type_vars(argument_values),
        )
    return type_var_results
337 | ||
338 | ||
def merge_type_var_dicts(base_dict, new_dict):
    """
    Merge the entries of ``new_dict`` into ``base_dict`` (in place).

    Values of names that already exist in ``base_dict`` are combined with
    ``|=``, other names are added. Entries of ``new_dict`` with a falsy value
    are skipped.
    """
    for name, values in new_dict.items():
        if not values:
            continue
        try:
            base_dict[name] |= values
        except KeyError:
            # First time that this type var shows up.
            base_dict[name] = values
346 | ||
347 | ||
def merge_pairwise_generics(annotation_value, annotated_argument_class):
    """
    Pair up the generics of an annotation with those of an argument class.

    Only the generic parameters directly within the annotation and the
    argument's type are looked at. They are matched by position in order to
    find the concrete values of the annotation's parameters.

    For example, given the following code:

        def values(mapping: Mapping[K, V]) -> List[V]: ...

        for val in values({1: 'a'}):
            val

    this function should receive representations of `Mapping[K, V]` and
    `Mapping[int, str]`, from which it determines that `K` is `int` and `V`
    is `str`.

    Walking the MRO of the argument type to find the type that matches the
    annotation (here: finding `Mapping[int, str]` as a parent of
    `Dict[int, str]`) is the responsibility of the caller.

    Parameters
    ----------

    `annotation_value`: represents the annotation whose concrete parameter
        types should be inferred.

    `annotated_argument_class`: represents the annotated class of the
        argument that is passed to the object annotated by
        `annotation_value`. If this is not a `DefineGenericBaseClass`, an
        empty dict is returned.
    """
    if not isinstance(annotated_argument_class, DefineGenericBaseClass):
        return {}

    type_var_dict = {}
    generic_pairs = zip(
        annotation_value.get_generics(),
        annotated_argument_class.get_generics(),
    )
    for annotation_generic_set, actual_generic_set in generic_pairs:
        inferred = annotation_generic_set.infer_type_vars(
            actual_generic_set.execute_annotation()
        )
        merge_type_var_dicts(type_var_dict, inferred)

    return type_var_dict
398 | ||
399 | ||
def find_type_from_comment_hint_for(context, node, name):
    """
    Look up a type comment for ``name``, using ``node.children[1]`` as the
    list of assigned names.
    """
    varlist = node.children[1]
    return _find_type_from_comment_hint(context, node, varlist, name)
402 | ||
403 | ||
def find_type_from_comment_hint_with(context, node, name):
    """
    Look up a type comment for ``name``, using the third child of
    ``node.children[1]`` as the list of assigned names.

    Nodes with more than four children (multiple with_items) get no type
    hint; an empty list is returned for them.
    """
    if len(node.children) > 4:
        # Multiple with_items are not supported for now.
        return []

    with_item = node.children[1]
    assert len(with_item.children) == 3, \
        "Can only be here when children[1] is 'foo() as f'"
    return _find_type_from_comment_hint(context, node, with_item.children[2], name)
413 | ||
414 | ||
def find_type_from_comment_hint_assign(context, node, name):
    """
    Look up a type comment for ``name``, using ``node.children[0]`` as the
    list of assigned names.
    """
    varlist = node.children[0]
    return _find_type_from_comment_hint(context, node, varlist, name)
417 | ||
418 | ||
def _find_type_from_comment_hint(context, node, varlist, name):
    """
    Infer the type of ``name`` from a ``# type: ...`` comment behind ``node``.

    If ``varlist`` is a list of names (something like ``a, b = 1, 2``), the
    position of ``name`` in that list selects the matching item of the tuple
    in the comment. An empty list is returned if ``name`` is not part of the
    list or if there is no matching comment on the line.
    """
    index = None
    if varlist.type in ("testlist_star_expr", "exprlist", "testlist"):
        # Count the entries in front of `name`, ignoring operator children.
        index = 0
        name_found = False
        for child in varlist.children:
            if child == name:
                name_found = True
                break
            if child.type != "operator":
                index += 1
        if not name_found:
            return []

    comment = parser_utils.get_following_comment_same_line(node)
    if comment is None:
        return []

    match = re.match(r"^#\s*type:\s*([^#]*)", comment)
    if match is None:
        return []

    hint = match.group(1).strip()
    return _infer_annotation_string(context, hint, index).execute_annotation()
442 | ||
443 | ||
def find_unknown_type_vars(context, node):
    """
    Return the type vars that are used in the annotation ``node``.

    ``atom_expr``/``power`` nodes are only searched within a trailing
    ``[...]`` (recursively); every other node is inferred and the type vars
    among the inferred values are collected. The result is a list in the
    order of discovery.
    """
    collected = []  # A list and not a set, because the order matters.

    def visit(current):
        if current.type not in ('atom_expr', 'power'):
            collected[:] = _filter_type_vars(context.infer_node(current), collected)
            return

        last = current.children[-1]
        if last.type == 'trailer' and last.children[0] == '[':
            for subscript_node in _unpack_subscriptlist(last.children[1]):
                visit(subscript_node)

    visit(node)
    return collected
457 | ||
458 | ||
def _filter_type_vars(value_set, found=()):
    """
    Return ``found`` as a new list, extended by the type vars of ``value_set``.

    Only `TypeVar` instances that are not yet in ``found`` are appended;
    ``found`` itself is not modified.
    """
    additional = [
        value for value in value_set
        if isinstance(value, TypeVar) and value not in found
    ]
    return list(found) + additional
465 | ||
466 | ||
def _unpack_subscriptlist(subscriptlist):
    """
    Yield the entries of a subscript, leaving out ``subscript`` nodes.

    For a ``subscriptlist`` node every second child is considered (the
    children in between are skipped). Any other node is treated as a single
    entry.
    """
    if subscriptlist.type == 'subscriptlist':
        candidates = subscriptlist.children[::2]
    else:
        candidates = [subscriptlist]

    for candidate in candidates:
        if candidate.type != 'subscript':
            yield candidate