diff --git a/src/stubgen.py b/src/stubgen.py index 9bfd4c26..5c480670 100755 --- a/src/stubgen.py +++ b/src/stubgen.py @@ -56,18 +56,31 @@ class and repeatedly call ``.put()`` to register modules or contents within the import argparse import builtins import enum -from inspect import Signature, Parameter, signature, ismodule, getmembers -import textwrap import importlib import importlib.machinery import importlib.util +import re +import sys +import textwrap import types import typing from dataclasses import dataclass -from typing import Dict, Sequence, List, Optional, Tuple, cast, Generator, Any, Callable, Union, Protocol, Literal +from inspect import Parameter, Signature, getmembers, ismodule, signature from pathlib import Path -import re -import sys +from typing import ( + Any, + Callable, + Dict, + Generator, + List, + Literal, + Optional, + Protocol, + Sequence, + Tuple, + Union, + cast, +) if sys.version_info < (3, 9): from typing import Match, Pattern @@ -250,11 +263,28 @@ def __init__( + sep_after ) - # Precompile RE to extract nanobind nd-arrays - self.ndarray_re = re.compile( - sep_before + r"(numpy\.ndarray|ndarray|torch\.Tensor)\[([^\]]*)\]" + # Precompile RE to extract known nd-arrays + self.known_ndarray_re = re.compile( + sep_before + + "(" + + "|".join( + [ + r"numpy\.ndarray", + r"torch\.Tensor", + r"tensorflow\.python\.framework\.ops\.EagerTensor", + r"jaxlib\.xla_extension\.DeviceArray", + ] + ) + + ")" + + r"\[([^\]]*)\]" ) + # Precompile RE to extract nanobind nd-arrays + self.nb_ndarray_re = re.compile(sep_before + "(ndarray)" + r"\[([^\]]*)\]") + + # Insert ndarray class + self.ndarray_class = False + # Types which moved from typing.* to collections.abc in Python 3.9 self.abc_re = re.compile( 'typing.(AsyncGenerator|AsyncIterable|AsyncIterator|Awaitable|Callable|' @@ -606,7 +636,10 @@ def simplify_types(self, s: str) -> str: - "NoneType" -> "None" - - "ndarray[...]" -> "Annotated[ArrayLike, dict(...)]" + - "<known array type>[...]" -> "Annotated[<known array type>, dict(...)]" +
+ - "ndarray[...]" -> "Annotated[NDArray, dict(...)]" + (with array protocol class added at top) - "collections.abc.X" -> "X" (with "from collections.abc import X" added at top) @@ -616,22 +649,62 @@ def simplify_types(self, s: str) -> str: changed to 'collections.abc' on newer Python versions) """ - # Process nd-array type annotations so that MyPy accepts them - def process_ndarray(m: Match[str]) -> str: - s = m.group(2) + # Process nd-array type annotations with metadata + def process_known_ndarray(m: Match[str]) -> str: + ndarray_type = m.group(1) + meta = m.group(2) - ndarray = self.import_object("numpy.typing", "ArrayLike") - assert ndarray - s = re.sub(r"dtype=([\w]*)\b", r"dtype='\g<1>'", s) - s = s.replace("*", "None") + if not meta: + return ndarray_type - if s: + if ndarray_type == "numpy.ndarray": + dm = re.search(r"dtype=([\w]*)\b", meta) + if dm and dm.group(1): + dtype = dm.group(1).replace("bool", "bool_") + ndarray_type = f"numpy.typing.NDArray[numpy.{dtype}]" + + meta = re.sub(r"dtype=([\w]*)\b", r"dtype='\g<1>'", meta) + meta = meta.replace("*", "None") + + if sys.version_info >= (3, 9, 0): annotated = self.import_object("typing", "Annotated") - return f"{annotated}[{ndarray}, dict({s})]" else: - return ndarray + annotated = self.import_object("typing_extensions", "Annotated") + return f"{annotated}[{ndarray_type}, dict({meta})]" + + s = self.known_ndarray_re.sub(process_known_ndarray, s) + + # Process nb-ndarray type annotations with metadata + def process_nb_ndarray(m: Match[str]) -> str: + ndarray_type = "NDArray" + meta = m.group(2) - s = self.ndarray_re.sub(process_ndarray, s) + self.ndarray_class = True + + self.import_object("typing", "Protocol") + if sys.version_info >= (3, 12, 0): + self.import_object("collections.abc", "Buffer") + else: + self.import_object("typing_extensions", "Buffer") + if sys.version_info >= (3, 10, 0): + self.import_object("typing", "TypeAlias") + else: + self.import_object("typing", "Union") + 
self.import_object("typing_extensions", "TypeAlias") + + if not meta: + return ndarray_type + + meta = re.sub(r"dtype=([\w]*)\b", r"dtype='\g<1>'", meta) + meta = meta.replace("*", "None") + + if sys.version_info >= (3, 9, 0): + annotated = self.import_object("typing", "Annotated") + else: + annotated = self.import_object("typing_extensions", "Annotated") + return f"{annotated}[{ndarray_type}, dict({meta})]" + + s = self.nb_ndarray_re.sub(process_nb_ndarray, s) if sys.version_info >= (3, 9, 0): s = self.abc_re.sub(r'collections.abc.\1', s) @@ -1143,12 +1216,31 @@ def get(self) -> str: s += items_v0 if len(items_v0) <= 70 else items_v1 s += "\n\n" + s += self.put_ndarray_class() # Append the main generated stub s += self.output return s.rstrip() + "\n" + def put_ndarray_class(self) -> str: + s = "" + if not self.ndarray_class: + return s + + s += "class DLPackBuffer(Protocol):\n" + s += " def __dlpack__(self) -> object: ...\n" + s += "\n" + if sys.version_info >= (3, 12, 0): + s += "type NDArray = Buffer | DLPackBuffer\n" + elif sys.version_info >= (3, 10, 0): + s += "NDArray: TypeAlias = Buffer | DLPackBuffer\n" + else: + s += "NDArray: TypeAlias = Union[Buffer, DLPackBuffer]\n" + s += "\n" + + return s + def parse_options(args: List[str]) -> argparse.Namespace: parser = argparse.ArgumentParser( prog="python -m nanobind.stubgen", diff --git a/tests/test_ndarray_ext.pyi.ref b/tests/test_ndarray_ext.pyi.ref index aecbe0cf..a62465cf 100644 --- a/tests/test_ndarray_ext.pyi.ref +++ b/tests/test_ndarray_ext.pyi.ref @@ -1,160 +1,168 @@ -from typing import Annotated, overload +from collections.abc import Buffer -from numpy.typing import ArrayLike +import numpy +import numpy.typing +import torch +from typing import Annotated, Protocol, overload + +class DLPackBuffer(Protocol): + def __dlpack__(self) -> object: ... + +type NDArray = Buffer | DLPackBuffer class Cls: def __init__(self) -> None: ... - def f1(self) -> Annotated[ArrayLike, dict(dtype='float32')]: ... 
+ def f1(self) -> Annotated[numpy.typing.NDArray[numpy.float32], dict(dtype='float32')]: ... - def f2(self) -> Annotated[ArrayLike, dict(dtype='float32')]: ... + def f2(self) -> Annotated[numpy.typing.NDArray[numpy.float32], dict(dtype='float32')]: ... - def f1_ri(self) -> Annotated[ArrayLike, dict(dtype='float32')]: ... + def f1_ri(self) -> Annotated[numpy.typing.NDArray[numpy.float32], dict(dtype='float32')]: ... - def f2_ri(self) -> Annotated[ArrayLike, dict(dtype='float32')]: ... + def f2_ri(self) -> Annotated[numpy.typing.NDArray[numpy.float32], dict(dtype='float32')]: ... - def f3_ri(self, arg: object, /) -> Annotated[ArrayLike, dict(dtype='float32')]: ... + def f3_ri(self, arg: object, /) -> Annotated[numpy.typing.NDArray[numpy.float32], dict(dtype='float32')]: ... -def accept_ro(arg: Annotated[ArrayLike, dict(dtype='float32', writable=False, shape=(2))], /) -> float: ... +def accept_ro(arg: Annotated[NDArray, dict(dtype='float32', writable=False, shape=(2))], /) -> float: ... -def accept_rw(arg: Annotated[ArrayLike, dict(dtype='float32', shape=(2))], /) -> float: ... +def accept_rw(arg: Annotated[NDArray, dict(dtype='float32', shape=(2))], /) -> float: ... -def cast(arg: bool, /) -> ArrayLike: ... +def cast(arg: bool, /) -> numpy.ndarray: ... def check(arg: object, /) -> bool: ... -def check_bool(arg: ArrayLike, /) -> bool: ... +def check_bool(arg: NDArray, /) -> bool: ... @overload -def check_device(arg: Annotated[ArrayLike, dict(device='cpu')], /) -> str: ... +def check_device(arg: Annotated[NDArray, dict(device='cpu')], /) -> str: ... @overload -def check_device(arg: Annotated[ArrayLike, dict(device='cuda')], /) -> str: ... +def check_device(arg: Annotated[NDArray, dict(device='cuda')], /) -> str: ... -def check_float(arg: ArrayLike, /) -> bool: ... +def check_float(arg: NDArray, /) -> bool: ... @overload -def check_order(arg: Annotated[ArrayLike, dict(order='C')], /) -> str: ... +def check_order(arg: Annotated[NDArray, dict(order='C')], /) -> str: ... 
@overload -def check_order(arg: Annotated[ArrayLike, dict(order='F')], /) -> str: ... +def check_order(arg: Annotated[NDArray, dict(order='F')], /) -> str: ... @overload -def check_order(arg: ArrayLike, /) -> str: ... +def check_order(arg: NDArray, /) -> str: ... -def check_ro_by_const_ref_const_float64(arg: Annotated[ArrayLike, dict(dtype='float64', writable=False, shape=(None))], /) -> bool: ... +def check_ro_by_const_ref_const_float64(arg: Annotated[NDArray, dict(dtype='float64', writable=False, shape=(None))], /) -> bool: ... -def check_ro_by_const_ref_ro(arg: Annotated[ArrayLike, dict(writable=False)], /) -> bool: ... +def check_ro_by_const_ref_ro(arg: Annotated[NDArray, dict(writable=False)], /) -> bool: ... -def check_ro_by_rvalue_ref_const_float64(arg: Annotated[ArrayLike, dict(dtype='float64', writable=False, shape=(None))], /) -> bool: ... +def check_ro_by_rvalue_ref_const_float64(arg: Annotated[NDArray, dict(dtype='float64', writable=False, shape=(None))], /) -> bool: ... -def check_ro_by_rvalue_ref_ro(arg: Annotated[ArrayLike, dict(writable=False)], /) -> bool: ... +def check_ro_by_rvalue_ref_ro(arg: Annotated[NDArray, dict(writable=False)], /) -> bool: ... -def check_ro_by_value_const_float64(arg: Annotated[ArrayLike, dict(dtype='float64', writable=False, shape=(None))], /) -> bool: ... +def check_ro_by_value_const_float64(arg: Annotated[NDArray, dict(dtype='float64', writable=False, shape=(None))], /) -> bool: ... -def check_ro_by_value_ro(arg: Annotated[ArrayLike, dict(writable=False)], /) -> bool: ... +def check_ro_by_value_ro(arg: Annotated[NDArray, dict(writable=False)], /) -> bool: ... -def check_rw_by_const_ref(arg: ArrayLike, /) -> bool: ... +def check_rw_by_const_ref(arg: NDArray, /) -> bool: ... -def check_rw_by_const_ref_float64(arg: Annotated[ArrayLike, dict(dtype='float64', shape=(None))], /) -> bool: ... +def check_rw_by_const_ref_float64(arg: Annotated[NDArray, dict(dtype='float64', shape=(None))], /) -> bool: ... 
-def check_rw_by_rvalue_ref(arg: ArrayLike, /) -> bool: ... +def check_rw_by_rvalue_ref(arg: NDArray, /) -> bool: ... -def check_rw_by_rvalue_ref_float64(arg: Annotated[ArrayLike, dict(dtype='float64', shape=(None))], /) -> bool: ... +def check_rw_by_rvalue_ref_float64(arg: Annotated[NDArray, dict(dtype='float64', shape=(None))], /) -> bool: ... -def check_rw_by_value(arg: ArrayLike, /) -> bool: ... +def check_rw_by_value(arg: NDArray, /) -> bool: ... -def check_rw_by_value_float64(arg: Annotated[ArrayLike, dict(dtype='float64', shape=(None))], /) -> bool: ... +def check_rw_by_value_float64(arg: Annotated[NDArray, dict(dtype='float64', shape=(None))], /) -> bool: ... -def check_shape_ptr(arg: ArrayLike, /) -> bool: ... +def check_shape_ptr(arg: NDArray, /) -> bool: ... -def check_stride_ptr(arg: ArrayLike, /) -> bool: ... +def check_stride_ptr(arg: NDArray, /) -> bool: ... def destruct_count() -> int: ... -def fill_view_1(x: ArrayLike) -> None: ... +def fill_view_1(x: NDArray) -> None: ... -def fill_view_2(x: Annotated[ArrayLike, dict(dtype='float32', shape=(None, None), device='cpu')]) -> None: ... +def fill_view_2(x: Annotated[NDArray, dict(dtype='float32', shape=(None, None), device='cpu')]) -> None: ... -def fill_view_3(x: Annotated[ArrayLike, dict(dtype='float32', shape=(3, 4), order='C', device='cpu')]) -> None: ... +def fill_view_3(x: Annotated[NDArray, dict(dtype='float32', shape=(3, 4), order='C', device='cpu')]) -> None: ... -def fill_view_4(x: Annotated[ArrayLike, dict(dtype='float32', shape=(3, 4), order='F', device='cpu')]) -> None: ... +def fill_view_4(x: Annotated[NDArray, dict(dtype='float32', shape=(3, 4), order='F', device='cpu')]) -> None: ... -def fill_view_5(x: Annotated[ArrayLike, dict(dtype='complex64', shape=(2, 2), order='C', device='cpu')]) -> None: ... +def fill_view_5(x: Annotated[NDArray, dict(dtype='complex64', shape=(2, 2), order='C', device='cpu')]) -> None: ... 
-def fill_view_6(x: Annotated[ArrayLike, dict(dtype='complex64', shape=(2, 2), order='C', device='cpu')]) -> None: ... +def fill_view_6(x: Annotated[NDArray, dict(dtype='complex64', shape=(2, 2), order='C', device='cpu')]) -> None: ... -def get_is_valid(array: Annotated[ArrayLike, dict(writable=False)] | None) -> bool: ... +def get_is_valid(array: Annotated[NDArray, dict(writable=False)] | None) -> bool: ... -def get_itemsize(array: ArrayLike | None) -> int: ... +def get_itemsize(array: NDArray | None) -> int: ... -def get_nbytes(array: ArrayLike | None) -> int: ... +def get_nbytes(array: NDArray | None) -> int: ... -def get_shape(array: Annotated[ArrayLike, dict(writable=False)]) -> list: ... +def get_shape(array: Annotated[NDArray, dict(writable=False)]) -> list: ... -def get_size(array: ArrayLike | None) -> int: ... +def get_size(array: NDArray | None) -> int: ... -def get_stride(array: ArrayLike, i: int) -> int: ... +def get_stride(array: NDArray, i: int) -> int: ... -def implicit(array: Annotated[ArrayLike, dict(dtype='float32', order='C', shape=(2, 2))]) -> int: ... +def implicit(array: Annotated[NDArray, dict(dtype='float32', order='C', shape=(2, 2))]) -> int: ... @overload -def initialize(arg: Annotated[ArrayLike, dict(dtype='float32', shape=(10), device='cpu')], /) -> None: ... +def initialize(arg: Annotated[NDArray, dict(dtype='float32', shape=(10), device='cpu')], /) -> None: ... @overload -def initialize(arg: Annotated[ArrayLike, dict(dtype='float32', shape=(10, None), device='cpu')], /) -> None: ... +def initialize(arg: Annotated[NDArray, dict(dtype='float32', shape=(10, None), device='cpu')], /) -> None: ... -def inspect_ndarray(arg: ArrayLike, /) -> None: ... +def inspect_ndarray(arg: NDArray, /) -> None: ... -def make_contig(arg: Annotated[ArrayLike, dict(order='C')], /) -> Annotated[ArrayLike, dict(order='C')]: ... +def make_contig(arg: Annotated[NDArray, dict(order='C')], /) -> Annotated[NDArray, dict(order='C')]: ... 
-def noimplicit(array: Annotated[ArrayLike, dict(dtype='float32', order='C', shape=(2, 2))]) -> int: ... +def noimplicit(array: Annotated[NDArray, dict(dtype='float32', order='C', shape=(2, 2))]) -> int: ... -def noop_2d_f_contig(arg: Annotated[ArrayLike, dict(dtype='float32', shape=(None, None), order='F')], /) -> None: ... +def noop_2d_f_contig(arg: Annotated[NDArray, dict(dtype='float32', shape=(None, None), order='F')], /) -> None: ... -def noop_3d_c_contig(arg: Annotated[ArrayLike, dict(dtype='float32', shape=(None, None, None), order='C')], /) -> None: ... +def noop_3d_c_contig(arg: Annotated[NDArray, dict(dtype='float32', shape=(None, None, None), order='C')], /) -> None: ... -def pass_bool(array: Annotated[ArrayLike, dict(dtype='bool')]) -> None: ... +def pass_bool(array: Annotated[NDArray, dict(dtype='bool')]) -> None: ... -def pass_complex64(array: Annotated[ArrayLike, dict(dtype='complex64')]) -> None: ... +def pass_complex64(array: Annotated[NDArray, dict(dtype='complex64')]) -> None: ... -def pass_complex64_const(array: Annotated[ArrayLike, dict(dtype='complex64', writable=False)]) -> None: ... +def pass_complex64_const(array: Annotated[NDArray, dict(dtype='complex64', writable=False)]) -> None: ... -def pass_float32(array: Annotated[ArrayLike, dict(dtype='float32')]) -> None: ... +def pass_float32(array: Annotated[NDArray, dict(dtype='float32')]) -> None: ... -def pass_float32_const(array: Annotated[ArrayLike, dict(dtype='float32', writable=False)]) -> None: ... +def pass_float32_const(array: Annotated[NDArray, dict(dtype='float32', writable=False)]) -> None: ... -def pass_float32_shaped(array: Annotated[ArrayLike, dict(dtype='float32', shape=(3, None, 4))]) -> None: ... +def pass_float32_shaped(array: Annotated[NDArray, dict(dtype='float32', shape=(3, None, 4))]) -> None: ... -def pass_float32_shaped_ordered(array: Annotated[ArrayLike, dict(dtype='float32', order='C', shape=(None, None, 4))]) -> None: ... 
+def pass_float32_shaped_ordered(array: Annotated[NDArray, dict(dtype='float32', order='C', shape=(None, None, 4))]) -> None: ... -def pass_uint32(array: Annotated[ArrayLike, dict(dtype='uint32')]) -> None: ... +def pass_uint32(array: Annotated[NDArray, dict(dtype='uint32')]) -> None: ... -def passthrough(arg: ArrayLike, /) -> ArrayLike: ... +def passthrough(arg: NDArray, /) -> NDArray: ... -def passthrough_arg_none(arg: ArrayLike | None) -> ArrayLike: ... +def passthrough_arg_none(arg: NDArray | None) -> NDArray: ... -def passthrough_copy(arg: ArrayLike, /) -> ArrayLike: ... +def passthrough_copy(arg: NDArray, /) -> NDArray: ... -def process(arg: Annotated[ArrayLike, dict(dtype='uint8', shape=(None, None, 3), order='C', device='cpu')], /) -> None: ... +def process(arg: Annotated[NDArray, dict(dtype='uint8', shape=(None, None, 3), order='C', device='cpu')], /) -> None: ... -def ret_array_scalar() -> Annotated[ArrayLike, dict(dtype='float32')]: ... +def ret_array_scalar() -> Annotated[numpy.typing.NDArray[numpy.float32], dict(dtype='float32')]: ... -def ret_numpy() -> Annotated[ArrayLike, dict(dtype='float32', shape=(2, 4))]: ... +def ret_numpy() -> Annotated[numpy.typing.NDArray[numpy.float32], dict(dtype='float32', shape=(2, 4))]: ... -def ret_numpy_const() -> Annotated[ArrayLike, dict(dtype='float32', writable=False, shape=(2, 4))]: ... +def ret_numpy_const() -> Annotated[numpy.typing.NDArray[numpy.float32], dict(dtype='float32', writable=False, shape=(2, 4))]: ... -def ret_numpy_const_ref() -> Annotated[ArrayLike, dict(dtype='float32', writable=False, shape=(2, 4))]: ... +def ret_numpy_const_ref() -> Annotated[numpy.typing.NDArray[numpy.float32], dict(dtype='float32', writable=False, shape=(2, 4))]: ... -def ret_numpy_half() -> Annotated[ArrayLike, dict(dtype='float16', shape=(2, 4))]: ... +def ret_numpy_half() -> Annotated[numpy.typing.NDArray[numpy.float16], dict(dtype='float16', shape=(2, 4))]: ... 
-def ret_pytorch() -> Annotated[ArrayLike, dict(dtype='float32', shape=(2, 4))]: ... +def ret_pytorch() -> Annotated[torch.Tensor, dict(dtype='float32', shape=(2, 4))]: ... -def return_dlpack() -> Annotated[ArrayLike, dict(dtype='float32', shape=(2, 4))]: ... +def return_dlpack() -> Annotated[NDArray, dict(dtype='float32', shape=(2, 4))]: ... @overload -def set_item(arg0: Annotated[ArrayLike, dict(dtype='float64', shape=(None), order='C')], arg1: int, /) -> None: ... +def set_item(arg0: Annotated[NDArray, dict(dtype='float64', shape=(None), order='C')], arg1: int, /) -> None: ... @overload -def set_item(arg0: Annotated[ArrayLike, dict(dtype='complex128', shape=(None), order='C')], arg1: int, /) -> None: ... +def set_item(arg0: Annotated[NDArray, dict(dtype='complex128', shape=(None), order='C')], arg1: int, /) -> None: ... diff --git a/tests/test_stubs.py b/tests/test_stubs.py index c6e1121f..36e03204 100644 --- a/tests/test_stubs.py +++ b/tests/test_stubs.py @@ -19,6 +19,12 @@ def remove_platform_dependent(s): v.startswith('def test_slots()') or \ v.startswith('TypeAlias'): i += 2 + elif v.startswith('from typing') or \ + v.startswith('from typing_extensions') or \ + v.startswith('from collections.abc'): + i += 1 + elif v.startswith('class DLPackBuffer(Protocol)'): + i += 5 else: s2.append(v) i += 1