# mypy: allow-untyped-defs
import functools
import math
import traceback
from dataclasses import dataclass, field
from enum import auto, Enum
from typing import Any

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._composable.contract import _get_registry
from torch.distributed.tensor import DeviceMesh, DTensor, Shard
from torch.distributed.tensor._dtensor_spec import DTensorSpec

from ._fsdp_api import DataParallelMeshDims


def _dynamo_disable(func):
    """Disable dynamo tracing for FSDP hooks."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return torch._dynamo.disable(
            func, recursive=True, reason="skipping FSDP hooks"
        )(*args, **kwargs)

    return wrapper


@dataclass
class DataParallelMeshInfo:
    """Describe how data parallelism is laid out on a ``DeviceMesh``.

    At least one of ``shard_mesh_dim`` and ``replicate_mesh_dim`` must be
    given. ``is_spmd_mesh`` is derived in ``__post_init__`` from whether
    ``dp_mesh_dims`` was supplied.
    """

    mesh: DeviceMesh
    # Mesh dim used for sharding (consumed by FSDPMeshInfo), if any.
    shard_mesh_dim: int | None = None
    # Mesh dim used for replication (consumed by DDPMeshInfo), if any.
    replicate_mesh_dim: int | None = None
    dp_mesh_dims: DataParallelMeshDims | None = None
    # Complete SPMD mesh (every non-PP dim, e.g. DP + TP) that parameters are
    # distributed over. Supplying a submesh that leaves out a dim such as TP
    # results in incorrect behavior.
    spmd_mesh: DeviceMesh | None = field(default=None, repr=False)
    # Derived flag; not a constructor argument.
    is_spmd_mesh: bool = field(default=False, init=False, repr=False)

    def __post_init__(self):
        data_parallel_dims = (self.shard_mesh_dim, self.replicate_mesh_dim)
        if all(dim is None for dim in data_parallel_dims):
            raise AssertionError(
                "At least one of shard_mesh_dim and replicate_mesh_dim must not be None"
            )
        self.is_spmd_mesh = self.dp_mesh_dims is not None


@dataclass
class FSDPMeshInfo(DataParallelMeshInfo):
    """Mesh info for sharded data parallelism.

    Requires ``shard_mesh_dim`` and caches, for that mesh dim, its size
    (``shard_mesh_size``), process group (``shard_process_group``) and this
    process's rank within that group (``shard_mesh_rank``).
    """

    def __post_init__(self):
        super().__post_init__()
        shard_dim = self.shard_mesh_dim
        if shard_dim is None:
            raise AssertionError("Expects non-None shard_mesh_dim")
        self.shard_mesh_size: int = self.mesh.size(shard_dim)
        shard_group = self.mesh.get_group(shard_dim)
        self.shard_process_group = shard_group
        self.shard_mesh_rank: int = shard_group.rank()


@dataclass
class DDPMeshInfo(DataParallelMeshInfo):
    """Mesh info for replicated data parallelism.

    Requires ``replicate_mesh_dim`` and caches, for that mesh dim, its size
    (``replicate_mesh_size``), process group (``replicate_process_group``)
    and this process's rank within that group (``replicate_mesh_rank``).
    """

    def __post_init__(self):
        super().__post_init__()
        replicate_dim = self.replicate_mesh_dim
        if replicate_dim is None:
            raise AssertionError("Expects non-None replicate_mesh_dim")
        self.replicate_mesh_size: int = self.mesh.size(replicate_dim)
        replicate_group = self.mesh.get_group(replicate_dim)
        self.replicate_process_group = replicate_group
        self.replicate_mesh_rank: int = replicate_group.rank()


@dataclass
class HSDPMeshInfo(FSDPMeshInfo, DDPMeshInfo):
    """Mesh info for hybrid sharding: both shard and replicate dims are required.

    It carries the cached attributes of both ``FSDPMeshInfo`` and
    ``DDPMeshInfo``.
    """

    def __post_init__(self):  # pylint:disable=useless-parent-delegation
        # MRO walk: FSDPMeshInfo -> DDPMeshInfo -> DataParallelMeshInfo, so
        # the shard-side and replicate-side validation and caching both run.
        super(HSDPMeshInfo, self).__post_init__()


class TrainingState(Enum):
    """Lifecycle phase of an FSDP state or of one of its parameter groups."""

    # Entered at pre-forward and held until post-forward.
    FORWARD = auto()
    # Entered when parameters are unsharded during backward.
    PRE_BACKWARD = auto()
    # Entered when parameters are resharded and gradients reduced.
    POST_BACKWARD = auto()
    # Resting phase: before/after forward, or before pre-backward and after
    # post-backward.
    IDLE = auto()


def _raise_assert_with_print(*args: Any, **kwargs: Any):
    print(f"[Rank {dist.get_rank()}] ", end="")
    print(*args, **kwargs)
    traceback.print_stack()
    raise AssertionError(*args, **kwargs)


def _is_composable_with_fsdp(module: nn.Module) -> bool:
    registry = _get_registry(module)
    if registry is None:
        return True
    # Registry keys by function name
    return "replicate" not in registry


def _get_dim0_padded_size(tensor_size: torch.Size, dim0_factor: int) -> torch.Size:
    padded_dim0 = math.ceil(tensor_size[0] / dim0_factor) * dim0_factor
    return torch.Size([padded_dim0]) + tensor_size[1:]


def _chunk_with_empty(
    tensor: torch.Tensor, num_chunks: int, dim: int
) -> list[torch.Tensor]:
    chunks = list(torch.chunk(tensor, num_chunks, dim=dim))
    while len(chunks) < num_chunks:
        chunks.append(chunks[0].new_empty(0))
    return chunks


def _get_dim_chunked_size(
    chunk: torch.Tensor, unchunked_size: torch.Size, dim: int
) -> torch.Size:
    if chunk.numel() > 0:
        return chunk.size()
    # For 0 numel, we need to preserve nonzero-sized dims for DTensor APIs
    # pyrefly: ignore [bad-return]
    return unchunked_size[:dim] + torch.Size([0]) + unchunked_size[dim + 1 :]


def _from_local_no_grad(
    local_tensor: torch.Tensor,
    sharding_spec: DTensorSpec,
) -> DTensor:
    """
    Wrap ``local_tensor`` as a ``DTensor`` described by ``sharding_spec``.

    Lighter-weight counterpart of ``DTensor.from_local()``: in eager mode it
    saves CPU overhead by skipping default-argument handling, and the result
    is not differentiable with respect to ``local_tensor``. The output's
    ``requires_grad`` mirrors ``local_tensor.requires_grad``.
    """
    requires_grad = local_tensor.requires_grad
    # pyrefly: ignore [bad-argument-type]
    return DTensor(
        # The local tensor is passed through as-is; wrapping it in a new
        # tensor variable (e.g. via `view_as()`) is avoided because this
        # path is intentionally non-differentiable.
        # pyrefly: ignore [bad-argument-count]
        local_tensor,
        sharding_spec,
        # pyrefly: ignore [unexpected-keyword]
        requires_grad=requires_grad,
    )


def _to_dtype_if_needed(
    tensor: torch.Tensor, dtype: torch.dtype | None
) -> torch.Tensor:
    if dtype is not None and tensor.dtype != dtype:
        return tensor.to(dtype)
    return tensor


def _cast_fp_tensor(dtype: torch.dtype, x: torch.Tensor) -> torch.Tensor:
    if (
        not isinstance(x, torch.Tensor)
        or not torch.is_floating_point(x)
        or x.dtype == dtype
    ):
        return x
    return x.to(dtype)


def is_bw() -> bool:
    """Return whether the autograd engine currently reports an active graph task.

    ``-1`` from ``torch._C._current_graph_task_id()`` is treated as "not in
    backward".
    """
    graph_task_id = torch._C._current_graph_task_id()
    return graph_task_id != -1


@dataclass
class ShardPlacementResult:
    """A shard placement together with the FSDP mesh info it applies to.

    ``placement`` may be ``None``, meaning no explicit placement was chosen
    and the caller's default sharding applies.
    """

    # Explicit shard placement, or None for the default.
    placement: Shard | None
    # Mesh info (shard size/group/rank) the placement refers to.
    mesh_info: FSDPMeshInfo


# Accepted return types of a ``shard_placement_fn``: a bare ``Shard`` (applied
# on the default mesh), a ``ShardPlacementResult`` (placement plus explicit
# mesh info), or ``None`` (no explicit placement; use the default).
# ``resolve_shard_placement`` normalizes each case into a ShardPlacementResult.
ShardPlacementFnResult = Shard | ShardPlacementResult | None


def resolve_shard_placement(
    result: ShardPlacementFnResult,
    default_mesh_info: FSDPMeshInfo,
) -> ShardPlacementResult:
    """Resolve the shard_placement_fn result to a ShardPlacementResult.

    Handles different input types and applies defaults:
    - None: Use default sharding (Shard(0)) on default mesh
    - Shard: Use specified shard dimension on default mesh
    - ShardPlacementResult: Use as-is

    Args:
        result: The return value from shard_placement_fn, or None if no fn provided.
        default_mesh_info: The default FSDPMeshInfo to use if not specified.

    Returns:
        A ShardPlacementResult with placement and mesh_info.
    """
    if result is None:
        return ShardPlacementResult(placement=None, mesh_info=default_mesh_info)
    if isinstance(result, Shard):
        return ShardPlacementResult(placement=result, mesh_info=default_mesh_info)
    if isinstance(result, ShardPlacementResult):
        return result
    raise ValueError(f"Invalid shard_placement_fn result: {result}")
