"""Static typing helpers."""

from __future__ import annotations

from types import EllipsisType
from typing import Protocol, TypeAlias

# TODO: import `override` from `typing` once Python >=3.12 is the minimum supported version
from typing_extensions import override

# TODO: replace these hand-written protocols with array-api-typing once it is available

class Array(Protocol):
    """Structural type for an array object of the Python array API standard.

    Any object providing the members below type-checks as an ``Array``.
    Members that arrays wrapped by array-api-compat frequently do not
    implement are listed at the bottom as comments only, so that such arrays
    still match this protocol.
    """

    # --- Attributes --------------------------------------------------------
    @property
    def dtype(self) -> DType: ...
    @property
    def device(self) -> Device: ...
    @property
    def ndim(self) -> int: ...
    # Individual dimensions, and therefore ``size``, may be ``None``
    # (presumably for lazy arrays whose shape is not yet known).
    @property
    def shape(self) -> tuple[int | None, ...]: ...
    @property
    def size(self) -> int | None: ...
    @property
    def mT(self) -> Array: ...  # pylint: disable=invalid-name
    @property
    def T(self) -> Array: ...  # pylint: disable=invalid-name

    # --- Unary operators ---------------------------------------------------
    def __pos__(self) -> Array: ...
    def __neg__(self) -> Array: ...
    def __abs__(self) -> Array: ...
    def __invert__(self) -> Array: ...

    # --- Arithmetic operators, each forward method next to its reflection ---
    # ``complex`` in an annotation also admits ``int``/``float``/``bool``
    # scalars for static type checkers (numeric promotion).
    def __add__(self, other: Array | complex, /) -> Array: ...
    def __radd__(self, other: Array | complex, /) -> Array: ...
    def __sub__(self, other: Array | complex, /) -> Array: ...
    def __rsub__(self, other: Array | complex, /) -> Array: ...
    def __mul__(self, other: Array | complex, /) -> Array: ...
    def __rmul__(self, other: Array | complex, /) -> Array: ...
    def __truediv__(self, other: Array | complex, /) -> Array: ...
    def __rtruediv__(self, other: Array | complex, /) -> Array: ...
    def __floordiv__(self, other: Array | complex, /) -> Array: ...
    def __rfloordiv__(self, other: Array | complex, /) -> Array: ...
    def __mod__(self, other: Array | complex, /) -> Array: ...
    def __rmod__(self, other: Array | complex, /) -> Array: ...
    def __pow__(self, other: Array | complex, /) -> Array: ...
    def __rpow__(self, other: Array | complex, /) -> Array: ...
    # Matrix multiplication takes no Python scalars.
    def __matmul__(self, other: Array, /) -> Array: ...
    def __rmatmul__(self, other: Array, /) -> Array: ...

    # --- Bitwise and shift operators (only Array or int/bool operands) -----
    def __and__(self, other: Array | int, /) -> Array: ...
    def __rand__(self, other: Array | int, /) -> Array: ...
    def __or__(self, other: Array | int, /) -> Array: ...
    def __ror__(self, other: Array | int, /) -> Array: ...
    def __xor__(self, other: Array | int, /) -> Array: ...
    def __rxor__(self, other: Array | int, /) -> Array: ...
    def __lshift__(self, other: Array | int, /) -> Array: ...
    def __rlshift__(self, other: Array | int, /) -> Array: ...
    def __rshift__(self, other: Array | int, /) -> Array: ...
    def __rrshift__(self, other: Array | int, /) -> Array: ...

    # --- Comparison operators: these return an Array, not a bool ------------
    def __lt__(self, other: Array | complex, /) -> Array: ...
    def __le__(self, other: Array | complex, /) -> Array: ...
    def __gt__(self, other: Array | complex, /) -> Array: ...
    def __ge__(self, other: Array | complex, /) -> Array: ...
    # ``object.__eq__``/``object.__ne__`` are typed to return ``bool``, so
    # these overrides are deliberately incompatible; the trailing comments
    # silence mypy and pyright respectively.
    @override
    def __eq__(self, other: Array | complex, /) -> Array: ...  # type: ignore[override]  # pyright: ignore[reportIncompatibleMethodOverride]
    @override
    def __ne__(self, other: Array | complex, /) -> Array: ...  # type: ignore[override]  # pyright: ignore[reportIncompatibleMethodOverride]

    # --- Indexing ----------------------------------------------------------
    # An Array is not required to be Sized or Iterable, hence no
    # ``__len__`` / ``__iter__`` here.
    def __getitem__(self, key: GetIndex, /) -> Array: ...
    def __setitem__(self, key: SetIndex, value: Array | complex, /) -> None: ...

    # --- Conversion to Python scalars (may raise on lazy arrays) -----------
    def __bool__(self) -> bool: ...
    def __int__(self) -> int: ...
    def __index__(self) -> int: ...
    def __float__(self) -> float: ...
    def __complex__(self) -> complex: ...

    # --- Deliberately NOT declared -----------------------------------------
    # Arrays wrapped by array-api-compat frequently lack these, so requiring
    # them would make such arrays fail to match. Kept here for reference:
    # def __array_namespace__(
    #     self, /, *, api_version: str | None = None
    # ) -> ModuleType: ...
    # def __dlpack__(
    #     self,
    #     /,
    #     *,
    #     stream: int | Any | None = None,
    #     max_version: tuple[int, int] | None = None,
    #     dl_device: tuple[int, int] | None = None,  # tuple[Enum, int]
    #     copy: bool | None = None,
    # ) -> Any: ...
    # def __dlpack_device__(self, /) -> tuple[int, int]: ...  # tuple[Enum, int]
    # def to_device(
    #     self, device: Device, /, *, stream: int | Any | None = None
    # ) -> Array: ...

class DType(Protocol):
    """Opaque structural type for an array's data type.

    It declares no members, so every object satisfies it for a static type
    checker. Its purpose is to give ``dtype`` annotations a distinct,
    self-documenting name. It is not ``runtime_checkable``.
    """

class Device(Protocol):
    """Opaque structural type for the device an array lives on.

    It declares no members, so every object satisfies it for a static type
    checker. Its purpose is to give ``device`` annotations a distinct,
    self-documenting name. It is not ``runtime_checkable``.
    """

# A single component of a ``__setitem__`` key.
_SetKeyPart: TypeAlias = int | slice | EllipsisType | Array
# A single component of a ``__getitem__`` key: the same choices as above,
# plus ``None``. Member order matches the original inline union.
_GetKeyPart: TypeAlias = int | slice | EllipsisType | None | Array

# Key accepted by ``Array.__setitem__``: one component or a tuple of them.
SetIndex: TypeAlias = _SetKeyPart | tuple[_SetKeyPart, ...]
# Key accepted by ``Array.__getitem__``: any ``SetIndex``, a bare ``None``,
# or a tuple whose components may also be ``None``.
GetIndex: TypeAlias = SetIndex | None | tuple[_GetKeyPart, ...]

# Public names exported by ``from <this module> import *``.
__all__ = [
    "Array",
    "DType",
    "Device",
    "GetIndex",
    "SetIndex",
]
