-
-
Notifications
You must be signed in to change notification settings - Fork 3k
Rework starargs with union argument #19651
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
245cba3
f2d254c
df7b3f0
5e239a3
6cbb237
115bf00
08db3da
fac1e07
68bc671
e8dcf88
fe4289a
0139c12
109b4f3
0a6d757
e169d3e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,27 +3,39 @@ | |
from __future__ import annotations | ||
|
||
from collections.abc import Sequence | ||
from typing import TYPE_CHECKING, Callable | ||
from typing import TYPE_CHECKING, Callable, cast | ||
from typing_extensions import NewType, TypeGuard, TypeIs | ||
|
||
from mypy import nodes | ||
from mypy.maptype import map_instance_to_supertype | ||
from mypy.typeops import make_simplified_union | ||
from mypy.types import ( | ||
AnyType, | ||
CallableType, | ||
Instance, | ||
ParamSpecType, | ||
ProperType, | ||
TupleType, | ||
Type, | ||
TypedDictType, | ||
TypeOfAny, | ||
TypeVarId, | ||
TypeVarTupleType, | ||
TypeVarType, | ||
UnionType, | ||
UnpackType, | ||
flatten_nested_tuples, | ||
get_proper_type, | ||
) | ||
|
||
if TYPE_CHECKING: | ||
from mypy.infer import ArgumentInferContext | ||
|
||
|
||
IterableType = NewType("IterableType", Instance) | ||
"""Represents an instance of `Iterable[T]`.""" | ||
|
||
|
||
def map_actuals_to_formals( | ||
actual_kinds: list[nodes.ArgKind], | ||
actual_names: Sequence[str | None] | None, | ||
|
@@ -54,6 +66,17 @@ def map_actuals_to_formals( | |
elif actual_kind == nodes.ARG_STAR: | ||
# We need to know the actual type to map varargs. | ||
actualt = get_proper_type(actual_arg_type(ai)) | ||
|
||
# Special case for union of equal sized tuples. | ||
if ( | ||
isinstance(actualt, UnionType) | ||
and actualt.items | ||
and is_equal_sized_tuples( | ||
proper_types := [get_proper_type(t) for t in actualt.items] | ||
) | ||
): | ||
# pick an arbitrary member | ||
actualt = proper_types[0] | ||
if isinstance(actualt, TupleType): | ||
# A tuple actual maps to a fixed number of formals. | ||
for _ in range(len(actualt.items)): | ||
|
@@ -193,32 +216,20 @@ def expand_actual_type( | |
original_actual = actual_type | ||
actual_type = get_proper_type(actual_type) | ||
if actual_kind == nodes.ARG_STAR: | ||
if isinstance(actual_type, TypeVarTupleType): | ||
# This code path is hit when *Ts is passed to a callable and various | ||
# special-handling didn't catch this. The best thing we can do is to use | ||
# the upper bound. | ||
actual_type = get_proper_type(actual_type.upper_bound) | ||
if isinstance(actual_type, Instance) and actual_type.args: | ||
from mypy.subtypes import is_subtype | ||
|
||
if is_subtype(actual_type, self.context.iterable_type): | ||
return map_instance_to_supertype( | ||
actual_type, self.context.iterable_type.type | ||
).args[0] | ||
else: | ||
# We cannot properly unpack anything other | ||
# than `Iterable` type with `*`. | ||
# Just return `Any`, other parts of code would raise | ||
# a different error for improper use. | ||
return AnyType(TypeOfAny.from_error) | ||
elif isinstance(actual_type, TupleType): | ||
# parse *args as one of the following: | ||
# IterableType | TupleType | ParamSpecType | AnyType | ||
star_args_type = self.parse_star_args_type(actual_type) | ||
|
||
if self.is_iterable_instance_type(star_args_type): | ||
return star_args_type.args[0] | ||
elif isinstance(star_args_type, TupleType): | ||
# Get the next tuple item of a tuple *arg. | ||
if self.tuple_index >= len(actual_type.items): | ||
if self.tuple_index >= len(star_args_type.items): | ||
# Exhausted a tuple -- continue to the next *args. | ||
self.tuple_index = 1 | ||
else: | ||
self.tuple_index += 1 | ||
item = actual_type.items[self.tuple_index - 1] | ||
item = star_args_type.items[self.tuple_index - 1] | ||
if isinstance(item, UnpackType) and not allow_unpack: | ||
# An unpack item that doesn't have special handling, use upper bound as above. | ||
unpacked = get_proper_type(item.type) | ||
|
@@ -232,9 +243,9 @@ def expand_actual_type( | |
) | ||
item = fallback.args[0] | ||
return item | ||
elif isinstance(actual_type, ParamSpecType): | ||
elif isinstance(star_args_type, ParamSpecType): | ||
randolf-scholz marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# ParamSpec is valid in *args but it can't be unpacked. | ||
return actual_type | ||
return star_args_type | ||
else: | ||
return AnyType(TypeOfAny.from_error) | ||
elif actual_kind == nodes.ARG_STAR2: | ||
|
@@ -265,3 +276,195 @@ def expand_actual_type( | |
else: | ||
# No translation for other kinds -- 1:1 mapping. | ||
return original_actual | ||
|
||
def is_iterable(self, typ: Type) -> bool: | ||
"""Check if the type is an iterable, i.e. implements the Iterable Protocol.""" | ||
from mypy.subtypes import is_subtype | ||
|
||
return is_subtype(typ, self.context.iterable_type) | ||
|
||
def is_iterable_instance_type(self, typ: Type) -> TypeIs[IterableType]: | ||
"""Check if the type is an Iterable[T].""" | ||
p_t = get_proper_type(typ) | ||
return isinstance(p_t, Instance) and p_t.type == self.context.iterable_type.type | ||
randolf-scholz marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
def _make_iterable_instance_type(self, arg: Type) -> IterableType: | ||
value = Instance(self.context.iterable_type.type, [arg]) | ||
return cast(IterableType, value) | ||
|
||
def _solve_as_iterable(self, typ: Type) -> IterableType | AnyType: | ||
r"""Use the solver to cast a type as Iterable[T]. | ||
|
||
Returns `AnyType` if solving fails. | ||
""" | ||
from mypy.constraints import infer_constraints_for_callable | ||
from mypy.nodes import ARG_POS | ||
from mypy.solve import solve_constraints | ||
|
||
# We first create an upcast function: | ||
# def [T] (Iterable[T]) -> Iterable[T]: ... | ||
# and then solve for T, given the input type as the argument. | ||
T = TypeVarType( | ||
"T", | ||
"T", | ||
TypeVarId(-1), | ||
values=[], | ||
upper_bound=AnyType(TypeOfAny.from_omitted_generics), | ||
default=AnyType(TypeOfAny.from_omitted_generics), | ||
) | ||
target = self._make_iterable_instance_type(T) | ||
upcast_callable = CallableType( | ||
variables=[T], | ||
arg_types=[target], | ||
arg_kinds=[ARG_POS], | ||
arg_names=[None], | ||
ret_type=target, | ||
fallback=self.context.function_type, | ||
) | ||
constraints = infer_constraints_for_callable( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can |
||
upcast_callable, [typ], [ARG_POS], [None], [[0]], self.context | ||
) | ||
|
||
(sol,), _ = solve_constraints([T], constraints) | ||
|
||
if sol is None: # solving failed, return AnyType fallback | ||
return AnyType(TypeOfAny.from_error) | ||
return self._make_iterable_instance_type(sol) | ||
|
||
def as_iterable_type(self, typ: Type) -> IterableType | AnyType: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can we reuse logic from |
||
"""Reinterpret a type as Iterable[T], or return AnyType if not possible. | ||
|
||
This function specially handles certain types like UnionType, TupleType, and UnpackType. | ||
Otherwise, the upcasting is performed using the solver. | ||
""" | ||
p_t = get_proper_type(typ) | ||
if self.is_iterable_instance_type(p_t) or isinstance(p_t, AnyType): | ||
return p_t | ||
elif isinstance(p_t, UnionType): | ||
# If the type is a union, map each item to the iterable supertype. | ||
# the return the combined iterable type Iterable[A] | Iterable[B] -> Iterable[A | B] | ||
converted_types = [self.as_iterable_type(get_proper_type(item)) for item in p_t.items] | ||
|
||
if any(not self.is_iterable_instance_type(it) for it in converted_types): | ||
# if any item could not be interpreted as Iterable[T], we return AnyType | ||
return AnyType(TypeOfAny.from_error) | ||
else: | ||
# all items are iterable, return Iterable[T₁ | T₂ | ... | Tₙ] | ||
randolf-scholz marked this conversation as resolved.
Show resolved
Hide resolved
|
||
iterable_types = cast(list[IterableType], converted_types) | ||
arg = make_simplified_union([it.args[0] for it in iterable_types]) | ||
return self._make_iterable_instance_type(arg) | ||
elif isinstance(p_t, TupleType): | ||
# maps tuple[A, B, C] -> Iterable[A | B | C] | ||
# note: proper_elements may contain UnpackType, for instance with | ||
# tuple[None, *tuple[None, ...]].. | ||
proper_elements = [get_proper_type(t) for t in flatten_nested_tuples(p_t.items)] | ||
args: list[Type] = [] | ||
for p_e in proper_elements: | ||
if isinstance(p_e, UnpackType): | ||
r = self.as_iterable_type(p_e) | ||
if self.is_iterable_instance_type(r): | ||
args.append(r.args[0]) | ||
else: | ||
# this *should* never happen, since UnpackType should | ||
# only contain TypeVarTuple or a variable length tuple. | ||
# However, we could get an `AnyType(TypeOfAny.from_error)` | ||
# if for some reason the solver was triggered and failed. | ||
args.append(r) | ||
else: | ||
args.append(p_e) | ||
return self._make_iterable_instance_type(make_simplified_union(args)) | ||
elif isinstance(p_t, UnpackType): | ||
return self.as_iterable_type(p_t.type) | ||
elif isinstance(p_t, (TypeVarType, TypeVarTupleType)): | ||
return self.as_iterable_type(p_t.upper_bound) | ||
elif self.is_iterable(p_t): | ||
# TODO: add a 'fast path' (needs measurement) that uses the map_instance_to_supertype | ||
# mechanism? (Only if it works: gh-19662) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
return self._solve_as_iterable(p_t) | ||
return AnyType(TypeOfAny.from_error) | ||
|
||
def parse_star_args_type( | ||
self, typ: Type | ||
) -> TupleType | IterableType | ParamSpecType | AnyType: | ||
"""Parse the type of a ``*args`` argument. | ||
|
||
Returns one of TupleType, IterableType, ParamSpecType or AnyType. | ||
Returns AnyType(TypeOfAny.from_error) if the type cannot be parsed or is invalid. | ||
""" | ||
p_t = get_proper_type(typ) | ||
if isinstance(p_t, (TupleType, ParamSpecType, AnyType)): | ||
# just return the type as-is | ||
return p_t | ||
elif isinstance(p_t, TypeVarTupleType): | ||
return self.parse_star_args_type(p_t.upper_bound) | ||
elif isinstance(p_t, UnionType): | ||
proper_items = [get_proper_type(t) for t in p_t.items] | ||
# consider 2 cases: | ||
# 1. Union of equal sized tuples, e.g. tuple[A, B] | tuple[None, None] | ||
# In this case transform union of same-sized tuples into a tuple of unions | ||
# e.g. tuple[A, B] | tuple[None, None] -> tuple[A | None, B | None] | ||
if is_equal_sized_tuples(proper_items): | ||
|
||
tuple_args: list[Type] = [ | ||
make_simplified_union(items) for items in zip(*(t.items for t in proper_items)) | ||
] | ||
actual_type = TupleType( | ||
tuple_args, | ||
# use Iterable[A | B | C] as the fallback type | ||
fallback=Instance( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. AFAIC we only use fallback to store There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
self.context.iterable_type.type, [UnionType.make_union(tuple_args)] | ||
), | ||
) | ||
return actual_type | ||
# 2. Union of iterable types, e.g. Iterable[A] | Iterable[B] | ||
# In this case return Iterable[A | B] | ||
# Note that this covers unions of differently sized tuples as well. | ||
else: | ||
converted_types = [self.as_iterable_type(p_i) for p_i in proper_items] | ||
if all(self.is_iterable_instance_type(it) for it in converted_types): | ||
# all items are iterable, return Iterable[T1 | T2 | ... | Tn] | ||
iterables = cast(list[IterableType], converted_types) | ||
arg = make_simplified_union([it.args[0] for it in iterables]) | ||
return self._make_iterable_instance_type(arg) | ||
else: | ||
# some items in the union are not iterable, return AnyType | ||
return AnyType(TypeOfAny.from_error) | ||
elif self.is_iterable_instance_type(parsed := self.as_iterable_type(p_t)): | ||
# in all other cases, we try to reinterpret the type as Iterable[T] | ||
return parsed | ||
return AnyType(TypeOfAny.from_error) | ||
|
||
|
||
def is_equal_sized_tuples(types: Sequence[ProperType]) -> TypeGuard[Sequence[TupleType]]: | ||
"""Check if all types are tuples of the same size. | ||
|
||
We use `flatten_nested_tuples` to deal with nested tuples. | ||
Note that the result may still contain | ||
""" | ||
if not types: | ||
return True | ||
|
||
iterator = iter(types) | ||
typ = next(iterator) | ||
if not isinstance(typ, TupleType): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If you refactor L449-L458 into a helper function, this method would look like
|
||
return False | ||
flattened_elements = flatten_nested_tuples(typ.items) | ||
if any( | ||
isinstance(get_proper_type(member), (UnpackType, TypeVarTupleType)) | ||
for member in flattened_elements | ||
): | ||
# this can happen e.g. with tuple[int, *tuple[int, ...], int] | ||
return False | ||
size = len(flattened_elements) | ||
|
||
for typ in iterator: | ||
if not isinstance(typ, TupleType): | ||
return False | ||
flattened_elements = flatten_nested_tuples(typ.items) | ||
if len(flattened_elements) != size or any( | ||
isinstance(get_proper_type(member), (UnpackType, TypeVarTupleType)) | ||
for member in flattened_elements | ||
): | ||
# this can happen e.g. with tuple[int, *tuple[int, ...], int] | ||
return False | ||
return True |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2285,7 +2285,9 @@ def infer_function_type_arguments_pass2( | |
def argument_infer_context(self) -> ArgumentInferContext: | ||
if self._arg_infer_context_cache is None: | ||
self._arg_infer_context_cache = ArgumentInferContext( | ||
self.chk.named_type("typing.Mapping"), self.chk.named_type("typing.Iterable") | ||
self.chk.named_type("typing.Mapping"), | ||
self.chk.named_type("typing.Iterable"), | ||
self.chk.named_type("builtins.function"), | ||
) | ||
return self._arg_infer_context_cache | ||
|
||
|
@@ -2670,6 +2672,30 @@ def check_arg( | |
original_caller_type = get_proper_type(original_caller_type) | ||
callee_type = get_proper_type(callee_type) | ||
|
||
if isinstance(callee_type, UnpackType) and not isinstance(caller_type, UnpackType): | ||
# it can happen that the caller_type got expanded. | ||
# since this is from a callable definition, it should be one of the following: | ||
# - TupleType, TypeVarTupleType, or a variable length tuple Instance. | ||
unpack_arg = get_proper_type(callee_type.type) | ||
if isinstance(unpack_arg, TypeVarTupleType): | ||
# substitute with upper bound of the TypeVarTuple | ||
unpack_arg = get_proper_type(unpack_arg.upper_bound) | ||
# note: not using elif, since in the future upper bound may be a finite tuple | ||
if isinstance(unpack_arg, Instance) and unpack_arg.type.fullname == "builtins.tuple": | ||
callee_type = get_proper_type(unpack_arg.args[0]) | ||
elif isinstance(unpack_arg, TupleType): | ||
# this branch should currently never hit, but it may hit in the future, | ||
# if it will ever be allowed to upper bound TypeVarTuple with a tuple type. | ||
elements = flatten_nested_tuples(unpack_arg.items) | ||
if m < len(elements): | ||
# pick the corresponding item from the tuple | ||
callee_type = get_proper_type(elements[m]) | ||
else: | ||
self.chk.fail(message_registry.TUPLE_INDEX_OUT_OF_RANGE, context) | ||
return | ||
else: | ||
raise TypeError(f"did not expect unpack_arg to be of type {type(unpack_arg)=}") | ||
|
||
if isinstance(caller_type, DeletedType): | ||
self.msg.deleted_as_rvalue(caller_type, context) | ||
# Only non-abstract non-protocol class can be given where Type[...] is expected... | ||
|
@@ -5225,29 +5251,33 @@ def visit_tuple_expr(self, e: TupleExpr) -> Type: | |
ctx = ctx_item.type | ||
else: | ||
ctx = None | ||
tt = self.accept(item.expr, ctx) | ||
tt = get_proper_type(tt) | ||
if isinstance(tt, TupleType): | ||
if find_unpack_in_list(tt.items) is not None: | ||
original_arg_type = self.accept(item.expr, ctx) | ||
# convert arg type to one of TupleType, IterableType, AnyType or | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ...or? :) |
||
arg_type_expander = ArgTypeExpander(self.argument_infer_context()) | ||
star_args_type = arg_type_expander.parse_star_args_type(original_arg_type) | ||
if isinstance(star_args_type, TupleType): | ||
if find_unpack_in_list(star_args_type.items) is not None: | ||
if seen_unpack_in_items: | ||
# Multiple unpack items are not allowed in tuples, | ||
# fall back to instance type. | ||
return self.check_lst_expr(e, "builtins.tuple", "<tuple>") | ||
else: | ||
seen_unpack_in_items = True | ||
items.extend(tt.items) | ||
items.extend(star_args_type.items) | ||
# Note: this logic depends on full structure match in tuple_context_matches(). | ||
if unpack_in_context: | ||
j += 1 | ||
else: | ||
# If there is an unpack in expressions, but not in context, this will | ||
# result in an error later, just do something predictable here. | ||
j += len(tt.items) | ||
j += len(star_args_type.items) | ||
else: | ||
if allow_precise_tuples and not seen_unpack_in_items: | ||
# Handle (x, *y, z), where y is e.g. tuple[Y, ...]. | ||
if isinstance(tt, Instance) and self.chk.type_is_iterable(tt): | ||
item_type = self.chk.iterable_item_type(tt, e) | ||
if isinstance(star_args_type, Instance) and self.chk.type_is_iterable( | ||
star_args_type | ||
): | ||
item_type = self.chk.iterable_item_type(star_args_type, e) | ||
mapped = self.chk.named_generic_type("builtins.tuple", [item_type]) | ||
items.append(UnpackType(mapped)) | ||
seen_unpack_in_items = True | ||
|
Uh oh!
There was an error while loading. Please reload this page.