from __future__ import annotations

import warnings
from contextlib import contextmanager
from typing import Any, cast, TYPE_CHECKING

import torch
import torch.utils._pytree as pytree
from torch._guards import detect_fake_mode
from torch._library.opaque_object import is_opaque_type
from torch._opaque_base import OpaqueBase
from torch._subclasses import FakeTensor, FakeTensorMode
from torch.fx.experimental.proxy_tensor import _pytree_subclasses_that_lose_info
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from torch.utils._python_dispatch import is_traceable_wrapper_subclass

from .. import config
from .descriptors import BufferAOTInput, DifferentiableAOTInput, ParamAOTInput
from .schemas import AOTConfig, FakifiedFlatArgs


if TYPE_CHECKING:
    from collections.abc import Generator, KeysView


static_inputs_log = torch._logging.getArtifactLogger(
    __name__, "cudagraph_static_inputs"
)


def process_inputs(
    flat_args: list[Any],
    aot_config: AOTConfig,
    fake_mode: FakeTensorMode,
    shape_env: ShapeEnv | None,
    ignore_shape_env: bool = False,
) -> tuple[FakifiedFlatArgs, list[int]]:
    """Convert real tensor inputs into fake tensors for AOT autograd tracing.

    Called at compile time (not runtime) to produce the fake inputs that AOT
    autograd traces through. Each real tensor is converted to a FakeTensor
    via ``fake_mode.from_tensor``, preserving shape, dtype, device, and
    symbolic shape information from the ShapeEnv. Non-tensor inputs (ints,
    SymInts, ScriptObjects) are converted or passed through as appropriate.

    Tensor subclass inputs (DTensor, etc.) are fakified recursively by
    walking their ``__tensor_flatten__`` attrs. AsyncCollectiveTensors are
    resolved via ``trigger_wait()`` before fakification so they don't appear
    in the traced metadata (see below).

    Called from ``aot_function``, ``aot_module_simplified``, and
    ``aot_export_module`` — anywhere AOT autograd needs fake inputs before
    graph capture.

    Returns:
        A tuple of (fakified_args, act_input_indices) where act_input_indices
        records which positions held AsyncCollectiveTensors. These indices are
        stored on ViewAndMutationMeta so that the runtime wrapper can emit
        direct trigger_wait() calls on those positions.
    """
    # Resolve AsyncCollectiveTensors before tracing. ACTs are transient
    # eager-mode wrappers for async collective overlap; if they leak into the
    # traced graph as input types, AOT autograd records them in
    # SubclassCreationMeta for output tangent metadata. At runtime, autograd
    # produces plain tensor tangents, causing a type mismatch. Unwrapping
    # here prevents ACT from appearing in the traced metadata.
    try:
        from torch.distributed._functional_collectives import AsyncCollectiveTensor
    except ImportError:
        AsyncCollectiveTensor = None

    act_input_indices: list[int] = []
    if AsyncCollectiveTensor is not None:
        for i, a in enumerate(flat_args):
            if isinstance(a, AsyncCollectiveTensor):
                act_input_indices.append(i)
                flat_args[i] = a.trigger_wait()

    with fake_mode:

        def convert(idx: int, x: Any) -> Any:
            nonlocal ignore_shape_env
            if shape_env is not None and not ignore_shape_env:
                from torch._dynamo.source import ConstantSource

                if isinstance(x, int):
                    # We always specialize on scalar values in export.
                    if aot_config.is_export:
                        return x
                    source = ConstantSource(f"sym_{idx}")
                    return shape_env.create_symintnode(
                        shape_env.create_symbol(x, source, positive=x >= 0),
                        hint=x,
                        source=source,
                    )
            if isinstance(x, torch.ScriptObject) or is_opaque_type(type(x)):
                return torch._library.fake_class_registry.maybe_to_fake_obj(
                    fake_mode, x
                )
            if not isinstance(x, torch.Tensor):
                return x
            if isinstance(x, FakeTensor):
                # In the case of cross compilation we will have example inputs
                # with a different fake mode than our tracing fake mode.
                # In these cases we want to clone the fake tensor into our
                # inner fake mode.
                if x.fake_mode is not fake_mode:
                    return fake_mode.from_tensor(x)
                return x
            if is_traceable_wrapper_subclass(x):
                attrs, _ = x.__tensor_flatten__()
                # See if all inner tensors are FakeTensors from this mode
                all_this_fake = True
                for a in attrs:
                    match getattr(x, a):
                        case FakeTensor() as v:
                            if v.fake_mode is not fake_mode:
                                # FakeTensor subclass from a different mode.
                                # Fall through to refakify.
                                all_this_fake = False
                                break
                        case torch.Tensor():
                            all_this_fake = False
                            break
                        case OpaqueBase():
                            pass
                        case unexpected:
                            raise AssertionError(
                                f"expected Tensor or OpaqueBase, got {type(unexpected)}"
                            )

                if all_this_fake:
                    return x

            # see note [Tensor Fakification and Symbol Caching]
            symbolic_context = None
            source = None
            trace = True
            if tracing_context := torch._guards.TracingContext.try_get():
                if x in tracing_context.tensor_to_context:
                    symbolic_context = tracing_context.tensor_to_context[x]
                    source = symbolic_context.tensor_source
                    # We already fakeified this tensor in Dynamo, don't
                    # dump the trace for it again
                    trace = False
            if (
                idx < aot_config.num_params_buffers
                and config.static_weight_shapes
                and not symbolic_context
            ):
                # TODO: Ensure that this codepath is never exercised from
                # Dynamo
                return fake_mode.from_tensor(x, static_shapes=True)

            result = fake_mode.from_tensor(
                x,
                static_shapes=ignore_shape_env,
                symbolic_context=symbolic_context,
                source=source,
                trace=trace,
            )
            return result

        return FakifiedFlatArgs(
            [convert(idx, x) for idx, x in enumerate(flat_args)]
        ), act_input_indices


def construct_fake_mode(
    flat_args: list[Any], aot_config: AOTConfig
) -> tuple[FakeTensorMode, ShapeEnv | None]:
    fake_mode = detect_fake_mode(flat_args)
    if fake_mode is None:
        shape_env = ShapeEnv() if aot_config.dynamic_shapes else None
        fake_mode = FakeTensorMode(shape_env=shape_env)
    else:
        shape_env = fake_mode.shape_env
    return (fake_mode, shape_env)


def _try_get_metadata_from_dynamo(
    mod: torch.nn.Module,
    param_keys: KeysView[str],
    full_args_num: int,
    full_args_descs: list[DifferentiableAOTInput],
) -> tuple[list[torch._guards.Source | None] | None, list[int]]:
    """
    Recover the metadata Dynamo forwards to AOTDispatch via GraphModule fields.

    We first check that ``mod`` really comes from Dynamo, then handle the
    cases where part of that metadata is missing.

    Args:
        mod: The module being compiled.
        param_keys: Names of the params/buffers lifted as the leading AOT
            inputs.
        full_args_num: Expected total number of AOT inputs (lifted
            params/buffers followed by the graph placeholders).
        full_args_descs: Per-input descriptors. They are only consulted when
            ``mod`` has no ``_param_name_to_source`` (the export case).

    Returns:
        aot_autograd_arg_pos_to_source: used to dedup params and their guards
            (None when ``mod`` is not a Dynamo graph, or came from export)
        static_input_indices: used to identify static inputs for cudagraphs

    Raises:
        AssertionError: if the Dynamo metadata is inconsistent (unknown param
            name, missing or duplicated source, placeholder without
            ``_dynamo_source``, or a count mismatch with ``full_args_num``).
    """
    # Note [Assumption on Dynamo Metadata]
    # This function assumes a graph module from dynamo provides
    # `dynamo_compile_id` in its meta, `_param_name_to_source`, and a
    # `_dynamo_source` attribute on every placeholder node. When gm is
    # modified (e.g., DDPOptimizer via split_module), this metadata needs to be
    # propagated for the graph to be recognized as a dynamo graph.

    from_dynamo = isinstance(mod, torch.fx.GraphModule) and (
        "dynamo_compile_id" in mod.meta
    )
    if not from_dynamo:
        # graph was not captured by dynamo
        return None, []

    if not hasattr(mod, "_param_name_to_source"):
        # Came from export: there are no param sources, and the params and
        # buffers are the static inputs.
        return None, [
            pos
            for pos, desc in enumerate(full_args_descs)
            if isinstance(desc, (ParamAOTInput, BufferAOTInput))
        ]

    # From here on the graph came from dynamo:
    # (1) we care about guards, so building the position -> source map for
    #     downstream dedup guards is safe;
    # (2) Dynamo logic protects the 1:1 sizing checked at the end.
    # Static input indices for cudagraphs are collected along the way.
    name_to_source = cast(dict[str, torch._guards.Source], mod._param_name_to_source)
    seen: set[torch._guards.Source | None] = set()
    arg_pos_to_source: list[torch._guards.Source | None] = []
    static_indices: list[int] = []

    # AOT positions [0, len(param_keys)) are the inputs lifted by aotdispatch;
    # they are always static.
    for param_pos, param_name in enumerate(param_keys):
        if param_name not in name_to_source:
            raise AssertionError(f"{param_name} not found in param_name_to_source")
        param_source = name_to_source[param_name]
        if param_source in seen:
            raise AssertionError(f"source {param_source} already in seen_sources")
        if param_source is None:
            raise AssertionError(f"source must not be None for {param_name}")
        seen.add(param_source)
        arg_pos_to_source.append(param_source)
        static_indices.append(param_pos)

    # Dynamo graph input i is AOT input i + len(param_keys): it comes after
    # the params/buffers that dynamo baked into the OutputGraph.
    # TODO(mlazos): Revisit if this is still needed. With Dynamo install ID
    # matched tensors back into the Fx graph, this might not be necessary.
    offset = len(param_keys)
    placeholders = mod.graph.find_nodes(op="placeholder")
    for graph_pos, node in enumerate(placeholders):
        if not hasattr(node, "_dynamo_source"):
            raise AssertionError(f"node {node} must have _dynamo_source attribute")
        # The source points back to user code. The ddp optimizer may turn
        # intermediate values into submodule placeholders, which have no
        # source (None) and so are exempt from the duplicate check.
        node_source = node._dynamo_source
        if node_source is not None and node_source in seen:
            raise AssertionError(f"source {node_source} already in seen_sources")
        seen.add(node_source)
        arg_pos_to_source.append(node_source)

        source_name = node_source.name if node_source else str(node_source)
        aot_pos = graph_pos + offset
        is_static = "tensor_dict" in node.meta and node.meta["tensor_dict"].get(
            "_dynamo_static_input_type", None
        )
        if not is_static:
            static_inputs_log.debug(
                "Non-static input pos %s for source %s", aot_pos, source_name
            )
            continue
        static_inputs_log.debug(
            "Adding static input pos %s for source %s", aot_pos, source_name
        )
        static_indices.append(aot_pos)

    num_sources = len(arg_pos_to_source)
    if full_args_num != num_sources:
        raise AssertionError(
            f"full_args_num={full_args_num} != len(aot_autograd_arg_pos_to_source)={num_sources}"
        )
    return arg_pos_to_source, static_indices


@contextmanager
def _detect_attribute_assignment(mod: torch.nn.Module) -> Generator[None, None, None]:
    # Do not allow assignment of tensor attributes during export unless
    # the attribute is registered as a buffer.

    NN_MODULE_STD_ATTRS = [
        "_backward_hooks",
        "_backward_pre_hooks",
        "_buffers",
        "_forward_hooks",
        "_forward_hooks_always_called",
        "_forward_hooks_with_kwargs",
        "_forward_pre_hooks",
        "_forward_pre_hooks_with_kwargs",
        "_is_full_backward_hook",
        "_load_state_dict_post_hooks",
        "_load_state_dict_pre_hooks",
        "_modules",
        "_non_persistent_buffers_set",
        "_parameters",
        "_state_dict_hooks",
        "_state_dict_pre_hooks",
        "training",
    ]
    NN_MODULE_LAZY_STD_ATTRS = [
        "_initialize_hook",
        "_load_hook",
    ]
    STD_ATTRS = {
        *NN_MODULE_STD_ATTRS,
        *NN_MODULE_LAZY_STD_ATTRS,
    }

    def _get_attributes(mod: torch.nn.Module) -> dict[str, Any]:
        # return any attributes of a module that are not standard attributes
        return {k: v for k, v in mod.__dict__.items() if k not in STD_ATTRS}

    def _get_all_module_attributes(mod: torch.nn.Module) -> dict[str, dict[str, Any]]:
        # return attributes from all modules and submodules
        result = {}
        for name, submodule in mod.named_modules():
            result[name] = _get_attributes(submodule)
        return result

    def _restore_all_module_attributes(
        mod: torch.nn.Module, snapshot: dict[str, dict[str, Any]]
    ) -> None:
        # restore attributes to all modules and submodules
        for name, submodule in mod.named_modules():
            if name in snapshot:
                submodule.__dict__.update(snapshot[name])

    # save state of attributes before enter
    snapshot = pytree.tree_map(
        lambda x: x,
        _get_all_module_attributes(mod),
        is_leaf=lambda x: type(x) in _pytree_subclasses_that_lose_info,
    )
    try:
        yield
    finally:
        # after exit, compare state of attributes with snapshot
        # to detect which tensor attributes were assigned

        def _collect_assigned_tensor_attributes(
            snapshot: dict[str, dict[str, Any]], new_attrs: dict[str, dict[str, Any]]
        ) -> list[str]:
            assigned_tensor_attributes = []

            def _compare_values(path: str, old_val: Any, new_val: Any) -> None:
                """Recursively compare values, handling containers."""
                # Same object, no change
                if old_val is new_val:
                    return

                if old_val is None or new_val is None:
                    if isinstance(new_val, torch.Tensor):
                        assigned_tensor_attributes.append(path)
                    return

                # Check if it's a tensor that was reassigned
                if isinstance(new_val, torch.Tensor):
                    assigned_tensor_attributes.append(path)
                    return

                # Handle dict containers
                if isinstance(old_val, dict) and isinstance(new_val, dict):
                    all_keys = set(old_val.keys()) | set(new_val.keys())
                    for key in all_keys:
                        old_item = old_val.get(key)
                        new_item = new_val.get(key)
                        _compare_values(f"{path}[{key!r}]", old_item, new_item)
                    return

                # Handle list/tuple containers
                if isinstance(old_val, (list, tuple)) and isinstance(
                    new_val, (list, tuple)
                ):
                    # Different lengths = mutation happened
                    max_len = max(len(old_val), len(new_val))
                    for i in range(max_len):
                        old_item = old_val[i] if i < len(old_val) else None
                        new_item = new_val[i] if i < len(new_val) else None
                        _compare_values(f"{path}[{i}]", old_item, new_item)
                    return

                # For other types, just check if they're different objects
                # (we don't care about non-tensor mutations)

            for module_name in snapshot.keys() | new_attrs.keys():
                old_module_attrs = snapshot.get(module_name, {})
                new_module_attrs = new_attrs.get(module_name, {})

                for attr_name in old_module_attrs.keys() | new_module_attrs.keys():
                    module_prefix = f"self.{module_name}." if module_name else "self."
                    full_path = f"{module_prefix}{attr_name}"

                    old_val = old_module_attrs.get(attr_name)
                    new_val = new_module_attrs.get(attr_name)
                    _compare_values(full_path, old_val, new_val)

            return assigned_tensor_attributes

        new_attrs = _get_all_module_attributes(mod)
        assigned_tensor_attributes = _collect_assigned_tensor_attributes(
            snapshot, new_attrs
        )
        # restore state of all attributes (including, e.g., of primitive types)
        _restore_all_module_attributes(mod, snapshot)

        if assigned_tensor_attributes:
            if len(assigned_tensor_attributes) > 1:
                noun, verb = "attributes", "were"
            else:
                noun, verb = "attribute", "was"
            warnings.warn(
                f"The tensor {noun} {', '.join(assigned_tensor_attributes)} {verb} assigned during export. "
                "Such attributes must be registered as buffers using the `register_buffer` API "
                "(https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer).",
                stacklevel=2,
            )
