"""Typings for function definitions."""

from __future__ import annotations

from typing import TypeVar, Union

from onnxscript import (
    BFLOAT16,
    BOOL,
    COMPLEX128,
    COMPLEX64,
    DOUBLE,
    FLOAT,
    FLOAT16,
    INT16,
    INT32,
    INT64,
    INT8,
    STRING,
    UINT8,
)


# NOTE: Unsigned types other than UINT8 are omitted because PyTorch does not use them.
# More details: https://pytorch.org/docs/stable/tensors.html

# The aliases below are unions of onnxscript tensor types. ``typing.Union``
# flattens nested unions in place and keeps member order, so building the wider
# aliases out of ``IntType`` yields exactly the same ``__args__`` as listing
# every member by hand.

# Signed integer dtypes. Declared first so ``TensorType`` can reuse it.
IntType = Union[INT8, INT16, INT32, INT64]  # noqa: UP007

# Floating-point dtypes. The leading underscore marks it non-public; in this
# file it is the bound of ``TFloat``.
_FloatType = Union[FLOAT16, FLOAT, DOUBLE, BFLOAT16]  # noqa: UP007

# Every tensor dtype supported here; UINT8 is the only unsigned member.
TensorType = Union[  # noqa: UP007
    BFLOAT16,
    BOOL,
    COMPLEX64,
    COMPLEX128,
    DOUBLE,
    FLOAT,
    FLOAT16,
    IntType,  # expands in place to INT8, INT16, INT32, INT64
    UINT8,
]

# Floating-point plus signed integer dtypes (no BOOL, complex, or UINT8).
RealType = Union[BFLOAT16, FLOAT16, FLOAT, DOUBLE, IntType]  # noqa: UP007

# Type variables for annotating function signatures. Each TypeVar is bound to a
# union of tensor dtypes; reusing one TypeVar across several inputs/outputs of a
# signature ties them to the same type (standard TypeVar semantics).
# NOTE(review): the exact membership of each ``bound`` union is presumably read
# by code elsewhere to derive dtype constraints — confirm before editing a bound.
TTensor = TypeVar("TTensor", bound=TensorType)
# Duplicate TTensor for inputs/outputs that accept the same set of types as TTensor
# but do not constrain the type to be the same as the other inputs/outputs
TTensor2 = TypeVar("TTensor2", bound=TensorType)
# Any ``TensorType`` member, or STRING.
TTensorOrString = TypeVar("TTensorOrString", bound=Union[TensorType, STRING])  # noqa: UP007
# Floating-point dtypes only: FLOAT16, FLOAT, DOUBLE, BFLOAT16 (via ``_FloatType``).
TFloat = TypeVar("TFloat", bound=_FloatType)
# FLOAT, FLOAT16, DOUBLE, INT8, UINT8.
# NOTE(review): despite the name, this bound also admits INT8 and does not admit
# BFLOAT16 — confirm whether that is intended before relying on the name.
TFloatOrUInt8 = TypeVar(
    "TFloatOrUInt8",
    bound=Union[FLOAT, FLOAT16, DOUBLE, INT8, UINT8],  # noqa: UP007
)
# Signed integer dtypes: INT8, INT16, INT32, INT64 (via ``IntType``).
TInt = TypeVar("TInt", bound=IntType)
# Floating-point plus signed integer dtypes (``RealType``).
TReal = TypeVar("TReal", bound=RealType)
# ``RealType`` without INT8 and INT16.
TRealUnlessInt16OrInt8 = TypeVar(
    "TRealUnlessInt16OrInt8",
    bound=Union[FLOAT16, FLOAT, DOUBLE, BFLOAT16, INT32, INT64],  # noqa: UP007
)
# ``RealType`` without FLOAT16 and INT8.
# NOTE(review): BFLOAT16 is excluded here as well, which the name does not say —
# confirm whether that is intentional.
TRealUnlessFloat16OrInt8 = TypeVar(
    "TRealUnlessFloat16OrInt8",
    bound=Union[DOUBLE, FLOAT, INT16, INT32, INT64],  # noqa: UP007
)
# ``RealType`` plus UINT8.
TRealOrUInt8 = TypeVar("TRealOrUInt8", bound=Union[RealType, UINT8])  # noqa: UP007
# FLOAT and DOUBLE only (excludes the 16-bit FLOAT16 and BFLOAT16).
TFloatHighPrecision = TypeVar("TFloatHighPrecision", bound=Union[FLOAT, DOUBLE])  # noqa: UP007
