This commit is contained in:
ton
2024-10-07 10:13:40 +07:00
parent aa1631742f
commit 3a7d696db6
9729 changed files with 1832837 additions and 161742 deletions

View File

@@ -218,8 +218,12 @@ def _should_unflatten_callable_args(typ, args):
For example::
assert collections.abc.Callable[[int, int], str].__args__ == (int, int, str)
assert collections.abc.Callable[ParamSpec, str].__args__ == (ParamSpec, str)
>>> import collections.abc
>>> P = ParamSpec('P')
>>> collections.abc.Callable[[int, int], str].__args__ == (int, int, str)
True
>>> collections.abc.Callable[P, str].__args__ == (P, str)
True
As a result, if we need to reconstruct the Callable from its __args__,
we need to unflatten it.
@@ -261,7 +265,10 @@ def _collect_parameters(args):
For example::
assert _collect_parameters((T, Callable[P, T])) == (T, P)
>>> P = ParamSpec('P')
>>> T = TypeVar('T')
>>> _collect_parameters((T, Callable[P, T]))
(~T, ~P)
"""
parameters = []
for t in args:
@@ -307,19 +314,33 @@ def _unpack_args(args):
newargs.append(arg)
return newargs
def _deduplicate(params):
def _deduplicate(params, *, unhashable_fallback=False):
# Weed out strict duplicates, preserving the first of each occurrence.
all_params = set(params)
if len(all_params) < len(params):
new_params = []
for t in params:
if t in all_params:
new_params.append(t)
all_params.remove(t)
params = new_params
assert not all_params, all_params
return params
try:
return dict.fromkeys(params)
except TypeError:
if not unhashable_fallback:
raise
# Happens for cases like `Annotated[dict, {'x': IntValidator()}]`
return _deduplicate_unhashable(params)
def _deduplicate_unhashable(unhashable_params):
new_unhashable = []
for t in unhashable_params:
if t not in new_unhashable:
new_unhashable.append(t)
return new_unhashable
def _compare_args_orderless(first_args, second_args):
first_unhashable = _deduplicate_unhashable(first_args)
second_unhashable = _deduplicate_unhashable(second_args)
t = list(second_unhashable)
try:
for elem in first_unhashable:
t.remove(elem)
except ValueError:
return False
return not t
def _remove_dups_flatten(parameters):
"""Internal helper for Union creation and substitution.
@@ -334,7 +355,7 @@ def _remove_dups_flatten(parameters):
else:
params.append(p)
return tuple(_deduplicate(params))
return tuple(_deduplicate(params, unhashable_fallback=True))
def _flatten_literal_params(parameters):
@@ -382,7 +403,8 @@ def _tp_cache(func=None, /, *, typed=False):
return decorator
def _eval_type(t, globalns, localns, recursive_guard=frozenset()):
def _eval_type(t, globalns, localns, type_params=None, *, recursive_guard=frozenset()):
"""Evaluate all forward references in the given type t.
For use of globalns and localns see the docstring for get_type_hints().
@@ -390,7 +412,7 @@ def _eval_type(t, globalns, localns, recursive_guard=frozenset()):
ForwardRef.
"""
if isinstance(t, ForwardRef):
return t._evaluate(globalns, localns, recursive_guard)
return t._evaluate(globalns, localns, type_params, recursive_guard=recursive_guard)
if isinstance(t, (_GenericAlias, GenericAlias, types.UnionType)):
if isinstance(t, GenericAlias):
args = tuple(
@@ -404,7 +426,13 @@ def _eval_type(t, globalns, localns, recursive_guard=frozenset()):
t = t.__origin__[args]
if is_unpacked:
t = Unpack[t]
ev_args = tuple(_eval_type(a, globalns, localns, recursive_guard) for a in t.__args__)
ev_args = tuple(
_eval_type(
a, globalns, localns, type_params, recursive_guard=recursive_guard
)
for a in t.__args__
)
if ev_args == t.__args__:
return t
if isinstance(t, GenericAlias):
@@ -523,7 +551,7 @@ class Any(metaclass=_AnyMeta):
def __new__(cls, *args, **kwargs):
if cls is Any:
raise TypeError("Any cannot be instantiated")
return super().__new__(cls, *args, **kwargs)
return super().__new__(cls)
@_SpecialForm
@@ -825,22 +853,25 @@ def TypeGuard(self, parameters):
2. If the return value is ``True``, the type of its argument
is the type inside ``TypeGuard``.
For example::
For example::
def is_str(val: Union[str, float]):
# "isinstance" type guard
if isinstance(val, str):
# Type of ``val`` is narrowed to ``str``
...
else:
# Else, type of ``val`` is narrowed to ``float``.
...
def is_str_list(val: list[object]) -> TypeGuard[list[str]]:
'''Determines whether all objects in the list are strings'''
return all(isinstance(x, str) for x in val)
def func1(val: list[object]):
if is_str_list(val):
# Type of ``val`` is narrowed to ``list[str]``.
print(" ".join(val))
else:
# Type of ``val`` remains as ``list[object]``.
print("Not a list of strings!")
Strict type narrowing is not enforced -- ``TypeB`` need not be a narrower
form of ``TypeA`` (it can even be a wider form) and this may lead to
type-unsafe results. The main reason is to allow for things like
narrowing ``List[object]`` to ``List[str]`` even though the latter is not
a subtype of the former, since ``List`` is invariant. The responsibility of
narrowing ``list[object]`` to ``list[str]`` even though the latter is not
a subtype of the former, since ``list`` is invariant. The responsibility of
writing type-safe type guards is left to the user.
``TypeGuard`` also works with type variables. For more information, see
@@ -865,7 +896,7 @@ class ForwardRef(_Final, _root=True):
# If we do `def f(*args: *Ts)`, then we'll have `arg = '*Ts'`.
# Unfortunately, this isn't a valid expression on its own, so we
# do the unpacking manually.
if arg[0] == '*':
if arg.startswith('*'):
arg_to_compile = f'({arg},)[0]' # E.g. (*Ts,)[0] or (*tuple[int, int],)[0]
else:
arg_to_compile = arg
@@ -882,7 +913,7 @@ class ForwardRef(_Final, _root=True):
self.__forward_is_class__ = is_class
self.__forward_module__ = module
def _evaluate(self, globalns, localns, recursive_guard):
def _evaluate(self, globalns, localns, type_params=None, *, recursive_guard):
if self.__forward_arg__ in recursive_guard:
return self
if not self.__forward_evaluated__ or localns is not globalns:
@@ -896,6 +927,22 @@ class ForwardRef(_Final, _root=True):
globalns = getattr(
sys.modules.get(self.__forward_module__, None), '__dict__', globalns
)
# type parameters require some special handling,
# as they exist in their own scope
# but `eval()` does not have a dedicated parameter for that scope.
# For classes, names in type parameter scopes should override
# names in the global scope (which here are called `localns`!),
# but should in turn be overridden by names in the class scope
# (which here are called `globalns`!)
if type_params:
globalns, localns = dict(globalns), dict(localns)
for param in type_params:
param_name = param.__name__
if not self.__forward_is_class__ or param_name not in globalns:
globalns[param_name] = param
localns.pop(param_name, None)
type_ = _type_check(
eval(self.__forward_code__, globalns, localns),
"Forward references must evaluate to types.",
@@ -903,7 +950,11 @@ class ForwardRef(_Final, _root=True):
allow_special_forms=self.__forward_is_class__,
)
self.__forward_value__ = _eval_type(
type_, globalns, localns, recursive_guard | {self.__forward_arg__}
type_,
globalns,
localns,
type_params,
recursive_guard=(recursive_guard | {self.__forward_arg__}),
)
self.__forward_evaluated__ = True
return self.__forward_value__
@@ -1133,7 +1184,9 @@ class _BaseGenericAlias(_Final, _root=True):
result = self.__origin__(*args, **kwargs)
try:
result.__orig_class__ = self
except AttributeError:
# Some objects raise TypeError (or something even more exotic)
# if you try to set attributes on them; we guard against that here
except Exception:
pass
return result
@@ -1539,7 +1592,10 @@ class _UnionGenericAlias(_NotIterable, _GenericAlias, _root=True):
def __eq__(self, other):
if not isinstance(other, (_UnionGenericAlias, types.UnionType)):
return NotImplemented
return set(self.__args__) == set(other.__args__)
try: # fast path
return set(self.__args__) == set(other.__args__)
except TypeError: # not hashable, slow path
return _compare_args_orderless(self.__args__, other.__args__)
def __hash__(self):
return hash(frozenset(self.__args__))
@@ -1657,8 +1713,9 @@ class _UnpackGenericAlias(_GenericAlias, _root=True):
assert self.__origin__ is Unpack
assert len(self.__args__) == 1
arg, = self.__args__
if isinstance(arg, _GenericAlias):
assert arg.__origin__ is tuple
if isinstance(arg, (_GenericAlias, types.GenericAlias)):
if arg.__origin__ is not tuple:
raise TypeError("Unpack[...] must be used with a tuple type")
return arg.__args__
return None
@@ -1676,7 +1733,7 @@ class _TypingEllipsis:
_TYPING_INTERNALS = frozenset({
'__parameters__', '__orig_bases__', '__orig_class__',
'_is_protocol', '_is_runtime_protocol', '__protocol_attrs__',
'__callable_proto_members_only__', '__type_params__',
'__non_callable_proto_members__', '__type_params__',
})
_SPECIAL_NAMES = frozenset({
@@ -1813,11 +1870,6 @@ class _ProtocolMeta(ABCMeta):
super().__init__(*args, **kwargs)
if getattr(cls, "_is_protocol", False):
cls.__protocol_attrs__ = _get_protocol_attrs(cls)
# PEP 544 prohibits using issubclass()
# with protocols that have non-method members.
cls.__callable_proto_members_only__ = all(
callable(getattr(cls, attr, None)) for attr in cls.__protocol_attrs__
)
def __subclasscheck__(cls, other):
if cls is Protocol:
@@ -1829,18 +1881,19 @@ class _ProtocolMeta(ABCMeta):
if not isinstance(other, type):
# Same error message as for issubclass(1, int).
raise TypeError('issubclass() arg 1 must be a class')
if (
not cls.__callable_proto_members_only__
and cls.__dict__.get("__subclasshook__") is _proto_hook
):
raise TypeError(
"Protocols with non-method members don't support issubclass()"
)
if not getattr(cls, '_is_runtime_protocol', False):
raise TypeError(
"Instance and class checks can only be used with "
"@runtime_checkable protocols"
)
if (
# this attribute is set by @runtime_checkable:
cls.__non_callable_proto_members__
and cls.__dict__.get("__subclasshook__") is _proto_hook
):
raise TypeError(
"Protocols with non-method members don't support issubclass()"
)
return super().__subclasscheck__(other)
def __instancecheck__(cls, instance):
@@ -1868,7 +1921,8 @@ class _ProtocolMeta(ABCMeta):
val = getattr_static(instance, attr)
except AttributeError:
break
if val is None and callable(getattr(cls, attr, None)):
# this attribute is set by @runtime_checkable:
if val is None and attr not in cls.__non_callable_proto_members__:
break
else:
return True
@@ -2058,9 +2112,14 @@ class Annotated:
def __new__(cls, *args, **kwargs):
raise TypeError("Type Annotated cannot be instantiated.")
@_tp_cache
def __class_getitem__(cls, params):
if not isinstance(params, tuple) or len(params) < 2:
if not isinstance(params, tuple):
params = (params,)
return cls._class_getitem_inner(cls, *params)
@_tp_cache(typed=True)
def _class_getitem_inner(cls, *params):
if len(params) < 2:
raise TypeError("Annotated[...] should be used "
"with at least two arguments (a type and an "
"annotation).")
@@ -2101,6 +2160,22 @@ def runtime_checkable(cls):
raise TypeError('@runtime_checkable can be only applied to protocol classes,'
' got %r' % cls)
cls._is_runtime_protocol = True
# PEP 544 prohibits using issubclass()
# with protocols that have non-method members.
# See gh-113320 for why we compute this attribute here,
# rather than in `_ProtocolMeta.__init__`
cls.__non_callable_proto_members__ = set()
for attr in cls.__protocol_attrs__:
try:
is_callable = callable(getattr(cls, attr, None))
except Exception as e:
raise TypeError(
f"Failed to determine whether protocol member {attr!r} "
"is a method member"
) from e
else:
if not is_callable:
cls.__non_callable_proto_members__.add(attr)
return cls
@@ -2194,7 +2269,7 @@ def get_type_hints(obj, globalns=None, localns=None, include_extras=False):
value = type(None)
if isinstance(value, str):
value = ForwardRef(value, is_argument=False, is_class=True)
value = _eval_type(value, base_globals, base_locals)
value = _eval_type(value, base_globals, base_locals, base.__type_params__)
hints[name] = value
return hints if include_extras else {k: _strip_annotations(t) for k, t in hints.items()}
@@ -2220,6 +2295,7 @@ def get_type_hints(obj, globalns=None, localns=None, include_extras=False):
raise TypeError('{!r} is not a module, class, method, '
'or function.'.format(obj))
hints = dict(hints)
type_params = getattr(obj, "__type_params__", ())
for name, value in hints.items():
if value is None:
value = type(None)
@@ -2231,7 +2307,7 @@ def get_type_hints(obj, globalns=None, localns=None, include_extras=False):
is_argument=not isinstance(obj, types.ModuleType),
is_class=False,
)
hints[name] = _eval_type(value, globalns, localns)
hints[name] = _eval_type(value, globalns, localns, type_params)
return hints if include_extras else {k: _strip_annotations(t) for k, t in hints.items()}
@@ -2268,14 +2344,15 @@ def get_origin(tp):
Examples::
assert get_origin(Literal[42]) is Literal
assert get_origin(int) is None
assert get_origin(ClassVar[int]) is ClassVar
assert get_origin(Generic) is Generic
assert get_origin(Generic[T]) is Generic
assert get_origin(Union[T, int]) is Union
assert get_origin(List[Tuple[T, T]][int]) is list
assert get_origin(P.args) is P
>>> P = ParamSpec('P')
>>> assert get_origin(Literal[42]) is Literal
>>> assert get_origin(int) is None
>>> assert get_origin(ClassVar[int]) is ClassVar
>>> assert get_origin(Generic) is Generic
>>> assert get_origin(Generic[T]) is Generic
>>> assert get_origin(Union[T, int]) is Union
>>> assert get_origin(List[Tuple[T, T]][int]) is list
>>> assert get_origin(P.args) is P
"""
if isinstance(tp, _AnnotatedAlias):
return Annotated
@@ -2296,11 +2373,12 @@ def get_args(tp):
Examples::
assert get_args(Dict[str, int]) == (str, int)
assert get_args(int) == ()
assert get_args(Union[int, Union[T, int], str][int]) == (int, str)
assert get_args(Union[int, Tuple[T, int]][str]) == (int, Tuple[str, int])
assert get_args(Callable[[], T][int]) == ([], int)
>>> T = TypeVar('T')
>>> assert get_args(Dict[str, int]) == (str, int)
>>> assert get_args(int) == ()
>>> assert get_args(Union[int, Union[T, int], str][int]) == (int, str)
>>> assert get_args(Union[int, Tuple[T, int]][str]) == (int, Tuple[str, int])
>>> assert get_args(Callable[[], T][int]) == ([], int)
"""
if isinstance(tp, _AnnotatedAlias):
return (tp.__origin__,) + tp.__metadata__
@@ -2319,12 +2397,15 @@ def is_typeddict(tp):
For example::
class Film(TypedDict):
title: str
year: int
is_typeddict(Film) # => True
is_typeddict(Union[list, str]) # => False
>>> from typing import TypedDict
>>> class Film(TypedDict):
... title: str
... year: int
...
>>> is_typeddict(Film)
True
>>> is_typeddict(dict)
False
"""
return isinstance(tp, _TypedDictMeta)
@@ -2834,8 +2915,14 @@ class _TypedDictMeta(type):
for base in bases:
annotations.update(base.__dict__.get('__annotations__', {}))
required_keys.update(base.__dict__.get('__required_keys__', ()))
optional_keys.update(base.__dict__.get('__optional_keys__', ()))
base_required = base.__dict__.get('__required_keys__', set())
required_keys |= base_required
optional_keys -= base_required
base_optional = base.__dict__.get('__optional_keys__', set())
required_keys -= base_optional
optional_keys |= base_optional
annotations.update(own_annotations)
for annotation_key, annotation_type in own_annotations.items():
@@ -2847,14 +2934,23 @@ class _TypedDictMeta(type):
annotation_origin = get_origin(annotation_type)
if annotation_origin is Required:
required_keys.add(annotation_key)
is_required = True
elif annotation_origin is NotRequired:
optional_keys.add(annotation_key)
elif total:
is_required = False
else:
is_required = total
if is_required:
required_keys.add(annotation_key)
optional_keys.discard(annotation_key)
else:
optional_keys.add(annotation_key)
required_keys.discard(annotation_key)
assert required_keys.isdisjoint(optional_keys), (
f"Required keys overlap with optional keys in {name}:"
f" {required_keys=}, {optional_keys=}"
)
tp_dict.__annotations__ = annotations
tp_dict.__required_keys__ = frozenset(required_keys)
tp_dict.__optional_keys__ = frozenset(optional_keys)
@@ -2881,15 +2977,15 @@ def TypedDict(typename, fields=None, /, *, total=True, **kwargs):
Usage::
class Point2D(TypedDict):
x: int
y: int
label: str
a: Point2D = {'x': 1, 'y': 2, 'label': 'good'} # OK
b: Point2D = {'z': 3, 'label': 'bad'} # Fails type check
assert Point2D(x=1, y=2, label='first') == dict(x=1, y=2, label='first')
>>> class Point2D(TypedDict):
... x: int
... y: int
... label: str
...
>>> a: Point2D = {'x': 1, 'y': 2, 'label': 'good'} # OK
>>> b: Point2D = {'z': 3, 'label': 'bad'} # Fails type check
>>> Point2D(x=1, y=2, label='first') == dict(x=1, y=2, label='first')
True
The type info can be accessed via the Point2D.__annotations__ dict, and
the Point2D.__required_keys__ and Point2D.__optional_keys__ frozensets.
@@ -3209,11 +3305,11 @@ class TextIO(IO[str]):
class _DeprecatedType(type):
def __getattribute__(cls, name):
if name not in ("__dict__", "__module__") and name in cls.__dict__:
if name not in {"__dict__", "__module__", "__doc__"} and name in cls.__dict__:
warnings.warn(
f"{cls.__name__} is deprecated, import directly "
f"from typing instead. {cls.__name__} will be removed "
"in Python 3.12.",
"in Python 3.13.",
DeprecationWarning,
stacklevel=2,
)
@@ -3248,7 +3344,7 @@ sys.modules[re.__name__] = re
def reveal_type[T](obj: T, /) -> T:
"""Reveal the inferred type of a variable.
"""Ask a static type checker to reveal the inferred type of an expression.
When a static type checker encounters a call to ``reveal_type()``,
it will emit the inferred type of the argument::
@@ -3260,7 +3356,7 @@ def reveal_type[T](obj: T, /) -> T:
will produce output similar to 'Revealed type is "builtins.int"'.
At runtime, the function prints the runtime type of the
argument and returns it unchanged.
argument and returns the argument unchanged.
"""
print(f"Runtime type is {type(obj).__name__!r}", file=sys.stderr)
return obj
@@ -3364,7 +3460,7 @@ def override[F: _Func](method: F, /) -> F:
Usage::
class Base:
def method(self) -> None: ...
def method(self) -> None:
pass
class Child(Base):