from __future__ import annotations

from typing import Any, TYPE_CHECKING

from torch.fx.experimental.unification import Var  # type: ignore[attr-defined]

from ._compatibility import compatibility


if TYPE_CHECKING:
    from collections.abc import Sequence

    from torch.fx.experimental.migrate_gradual_types.constraint import DVar


# Names exported by ``from <this module> import *``.
__all__ = [
    "Dyn",
    "TensorType",
    "is_consistent",
    "is_more_precise",
]


@compatibility(is_backward_compatible=False)
class TensorType:
    """
    TensorType defines a type for tensors, which consists of a list of dimensions.

    Per the ``__args__`` annotation, each dimension is an ``int``, ``Dyn``
    (unknown size) or a ``DVar``. Instances can be built by calling the class
    with a sequence of dimensions or by subscripting the class. Two instances
    compare (and hash) equal when their dimension sequences are element-wise
    equal.

    Example:
        class M(torch.nn.Module):
            def forward(self, x:TensorType((1,2,3, Dyn)), y:TensorType((1,2,3, Dyn))):
                return torch.add(x, y)

        TensorType[1, 2, Dyn] == TensorType((1, 2, Dyn))  # True
    """

    __args__: Sequence[DVar | int | _DynType]

    def __init__(self, dim: Sequence[Any]) -> None:
        # ``__origin__``/``__args__`` reuse the attribute names of ``typing``
        # generic aliases. NOTE(review): presumably this lets annotation-
        # inspecting code treat TensorType like a generic alias -- confirm.
        self.__origin__ = TensorType
        self.__args__ = dim

    def __repr__(self) -> str:
        return f"TensorType[{self.__args__}]"

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, self.__class__):
            # Let the other operand decide; Python falls back to identity,
            # which yields False for unrelated objects.
            return NotImplemented
        # Compare as lists so that a tuple and a list holding the same
        # dimensions are considered equal.
        return list(self.__args__) == list(other.__args__)

    def __hash__(self) -> int:
        # Defining ``__eq__`` alone makes a class unhashable. Hash the
        # dimensions as a tuple, which agrees with the list-based ``__eq__``.
        # Raises TypeError if any dimension is itself unhashable.
        return hash(tuple(self.__args__))

    @staticmethod
    def __class_getitem__(*args: object) -> TensorType:
        # ``TensorType[1, 2, 3]`` passes its subscript as a single tuple
        # argument; unwrap it so the dimensions become ``(1, 2, 3)``.
        # ``TensorType[1]`` passes the bare item and yields ``(1,)``.
        if len(args) == 1 and isinstance(args[0], tuple):
            args = args[0]
        return TensorType(tuple(args))


class _DynType:
    """
    _DynType defines a type which stands for the absence of type information.

    Every instance compares equal to every other instance and all instances
    share one hash. The module-level ``Dyn`` instance can therefore be compared
    with ``==`` and used as a set member or dictionary key.
    """

    def __init__(self) -> None:
        # NOTE(review): presumably set so that code which reads ``__name__``
        # from an annotation (as it would from a class) also works on ``Dyn``
        # -- confirm against callers.
        self.__name__ = "_DynType"

    def __eq__(self, other: object) -> bool:
        return isinstance(other, self.__class__)

    def __hash__(self) -> int:
        # Defining ``__eq__`` without ``__hash__`` makes instances unhashable.
        # All instances compare equal, so they must share a single hash; use
        # the base class object (not ``type(self)``) so that stays true even
        # for instances of a subclass.
        return hash(_DynType)

    def __str__(self) -> str:
        return "Dyn"

    def __repr__(self) -> str:
        return "Dyn"


# Shared instance marking an unknown type (or an unknown tensor dimension).
Dyn = _DynType()


@compatibility(is_backward_compatible=False)
def is_consistent(t1: object, t2: object) -> bool:
    """
    Decide whether t1 is consistent with t2, written ``t1 ~ t2``.

    Consistency is reflexive and symmetric, but not transitive. Two types are
    consistent when they are equal, when either of them is ``Dyn`` or a
    unification ``Var``, or when both are TensorTypes of the same rank whose
    dimensions are pairwise consistent.

    Returns:
        True if t1 and t2 are consistent, False otherwise.

    Example:
        Dyn ~ TensorType((1,2,3))
        int ~ Dyn
        int ~ int
        TensorType((1,Dyn,3)) ~ TensorType((1,2,3))
    """
    # Equal types are trivially consistent.
    if t1 == t2:
        return True

    # Dyn (no information) is consistent with anything ...
    if t1 == Dyn or t2 == Dyn:
        return True

    # ... and so is a not-yet-solved unification variable.
    if isinstance(t1, Var) or isinstance(t2, Var):
        return True

    if isinstance(t1, TensorType) and isinstance(t2, TensorType):
        # Tensors of different rank can never be consistent.
        if len(t1.__args__) != len(t2.__args__):
            return False
        # Otherwise every pair of corresponding dimensions must be consistent.
        return all(map(is_consistent, t1.__args__, t2.__args__))

    return False


@compatibility(is_backward_compatible=False)
def is_more_precise(t1: object, t2: object) -> bool:
    """
    A binary relation denoted by <= that determines if t1 is more precise than t2.
    The relation is reflexive and transitive.
    returns True if t1 is more precise than t2 and False otherwise.

    Every type is more precise than ``Dyn``. Two TensorTypes are related when
    they have the same rank and each dimension of t1 is more precise than the
    corresponding dimension of t2. Any other pair of distinct types is
    unrelated; in particular ``Dyn`` is not more precise than a concrete type.

    Example:
        TensorType((1,2,3)) <= Dyn
        int <= Dyn
        int <= int
        TensorType((1,2,3)) <= TensorType((1,Dyn,3))
    """
    # Reflexivity: every type is as precise as itself.
    if t1 == t2:
        return True

    # Dyn carries no information, so anything is at least as precise as it.
    if isinstance(t2, _DynType):
        return True

    # Tensor types must agree on rank and be related dimension by dimension.
    if isinstance(t1, TensorType) and isinstance(t2, TensorType):
        return len(t1.__args__) == len(t2.__args__) and all(
            is_more_precise(elem1, elem2)
            for elem1, elem2 in zip(t1.__args__, t2.__args__)
        )

    # NOTE(review): unlike is_consistent, unification Vars get no special
    # treatment here; a Var is only related to itself or to Dyn.
    return False
