import abc
import importlib
import pickle
from typing import Any

import torch


def _serialize_triton_kernel(kernel: Any) -> tuple[str, str]:
    """
    Serialize a triton kernel by extracting its module path and function name.
    Returns (module_path, function_name) tuple.

    Triton JITFunction objects contain unpicklable _thread.RLock objects, so we
    serialize the import path instead and reimport on load.

    Raises:
        RuntimeError: If the kernel cannot be serialized (missing attributes).
    """
    fn = getattr(kernel, "fn", None)
    module_path = fn and getattr(fn, "__module__", None)
    func_name = fn and getattr(fn, "__name__", None)
    if fn is None or module_path is None or func_name is None:
        raise RuntimeError(
            f"Kernel fn missing __module__ or __name__: "
            f"module={module_path}, name={func_name}. "
            f"Cannot serialize for precompilation."
        )
    return (module_path, func_name)


def _deserialize_triton_kernel(kernel_info: tuple[str, str]) -> Any:
    """
    Deserialize a triton kernel by reimporting from its module.
    kernel_info is (module_path, function_name) tuple.
    """
    module_path, func_name = kernel_info
    module = importlib.import_module(module_path)
    kernel = getattr(module, func_name)
    return kernel


# Note: [Triton Kernel Side Table Serialization]
#
# When dynamo captures user-defined triton kernels, it creates FX graph nodes
# (triton_kernel_wrapper_mutation/functional) with a `kernel_idx` parameter that
# references the global `kernel_side_table` in triton_kernel_wrap.py. This side
# table maps integer indices to actual triton kernel objects.
#
# For kernels that go through inductor's codegen path, this is fine - inductor
# looks up the kernel from the side table at codegen time and embeds the kernel
# source code directly into the generated wrapper. The compiled code doesn't
# need the side table at runtime.
#
# However, not all triton kernels go through inductor codegen. When using
# regional_inductor, only annotated regions are compiled by inductor. Triton
# kernels outside these regions are executed via the FX interpreter, which
# calls the higher-order op directly and needs the kernel to be in the side
# table at runtime.
#
# When serializing/deserializing bundled AOT artifacts across process boundaries,
# the kernel_side_table is empty in the new process, causing:
#   AssertionError: Kernel index X not found in id_to_kernel
#
# To fix this, we capture the kernel_side_table state during serialization and
# restore it during deserialization. Kernels are serialized by their import path
# (module_path, function_name) since triton JITFunction objects contain
# unpicklable RLock objects.


class SerializableCallable(abc.ABC):
    """
    Abstract interface for a callable that can be converted to bytes and back.

    Concrete subclasses must provide both classmethods and ``__call__``.
    """

    @classmethod
    @abc.abstractmethod
    def serialize_compile_artifacts(cls, fn: Any) -> bytes:
        """Return the byte representation of ``fn``."""

    @classmethod
    @abc.abstractmethod
    def deserialize_compile_artifacts(cls, data: bytes) -> Any:
        """Rebuild a callable from ``data`` produced by the serializer."""

    @abc.abstractmethod
    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Invoke the wrapped callable with the given arguments."""


class GraphModuleSerializableCallable(SerializableCallable):
    """
    Serializable callable backed by a ``torch.fx.GraphModule``.
    """

    # Node metadata entries that are left out of the serialized graph.
    _STRIPPED_META_KEYS = (
        "nn_module_stack",
        "source_fn_stack",
        "example_value",
    )

    def __init__(self, graph_module: torch.fx.GraphModule) -> None:
        """Store ``graph_module``; it must be a ``torch.fx.GraphModule``."""
        assert isinstance(graph_module, torch.fx.GraphModule)
        self.graph_module = graph_module

    @classmethod
    def serialize_compile_artifacts(
        cls, fn: "GraphModuleSerializableCallable"
    ) -> bytes:
        """
        Pickle ``fn``'s instance state, with the graph module encoded by
        ``GraphPickler``.

        The metadata keys in ``_STRIPPED_META_KEYS`` are removed from every
        node only for the duration of the dump and are put back afterwards, so
        serializing does not alter the live graph held by ``fn``.

        Args:
            fn: The instance to serialize.

        Returns:
            Pickled bytes of the instance ``__dict__`` in which
            ``graph_module`` has been replaced by its ``GraphPickler`` bytes.
        """
        from torch.fx._graph_pickler import GraphPickler, Options

        # Shallow copy: `graph_module` below is the same object `fn` holds.
        state = fn.__dict__.copy()
        graph_module = state["graph_module"]

        # (node, key, value) for every entry temporarily removed.
        # NOTE(review): presumably these entries are dropped because they are
        # not needed / not picklable in the artifact — confirm.
        stashed: list[tuple[Any, str, Any]] = []
        try:
            for node in graph_module.graph.nodes:
                for key in cls._STRIPPED_META_KEYS:
                    if key in node.meta:
                        stashed.append((node, key, node.meta.pop(key)))

            state["graph_module"] = GraphPickler.dumps(
                graph_module, Options(ops_filter=None)
            )
        finally:
            # Restore even if dumping fails so the caller's graph is intact.
            for node, key, value in stashed:
                node.meta[key] = value
        return pickle.dumps(state)

    @classmethod
    def deserialize_compile_artifacts(cls, data: bytes) -> Any:
        """
        Rebuild an instance from bytes produced by
        ``serialize_compile_artifacts``.

        The graph module is loaded with ``GraphPickler`` under a fresh
        ``FakeTensorMode`` (with a new ``ShapeEnv``) and recompiled before the
        instance is constructed from the restored state.
        """
        from torch._subclasses import FakeTensorMode
        from torch.fx._graph_pickler import GraphPickler
        from torch.fx.experimental.symbolic_shapes import ShapeEnv

        # NOTE: pickle.loads can execute arbitrary code; `data` must come from
        # a trusted source.
        state = pickle.loads(data)

        fake_mode = FakeTensorMode(shape_env=ShapeEnv())
        state["graph_module"] = GraphPickler.loads(state["graph_module"], fake_mode)
        assert isinstance(state["graph_module"], torch.fx.GraphModule)
        state["graph_module"].recompile()

        # The state keys are the instance attributes, which match __init__'s
        # keyword parameters.
        return cls(**state)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Run the graph module's ``forward`` with the given arguments."""
        # NOTE(review): calls `forward` directly rather than the module's
        # own __call__ — presumably intentional; confirm.
        return self.graph_module.forward(*args, **kwargs)


class BundledAOTAutogradSerializableCallable(SerializableCallable):
    """
    Represents a serializable callable generated by compile_fx.
    This class wraps around the compiled function generated by AOTAutograd.

    Attribute lookups that fail on the wrapper are forwarded to the wrapped
    ``compiled_fn``.

    TODO: Instead of using PrecompileContext to grab it from AOTAutograd,
    this object should be what's *returned* by aot_module_simplified.
    We'll do that refactor in a later PR.
    """

    def __init__(self, compiled_fn: Any) -> None:
        """
        Takes in a BundledAOTAutogradCacheArtifact, which is the serialized form
        of a compiled function generated by AOTAutograd.

        ``compiled_fn`` must expose a ``serialize`` attribute.
        """
        assert hasattr(compiled_fn, "serialize")
        self.compiled_fn = compiled_fn

    def __getattr__(self, attr: str) -> Any:
        """
        Forward lookups of attributes not found on the wrapper to
        ``compiled_fn``.

        Raises:
            AttributeError: If ``compiled_fn`` itself has not been set, or the
                wrapped object lacks ``attr``.
        """
        # __getattr__ only runs after normal lookup fails. If `compiled_fn` is
        # the missing attribute (instance created without __init__, e.g. while
        # copy/pickle probe a blank instance), reading `self.compiled_fn` below
        # would re-enter this method forever and end in RecursionError.
        if attr == "compiled_fn":
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {attr!r}"
            )
        return getattr(self.compiled_fn, attr)

    @classmethod
    def serialize_compile_artifacts(
        cls, fn: "BundledAOTAutogradSerializableCallable"
    ) -> bytes:
        """
        Serialize ``fn.compiled_fn`` together with the triton kernel side table.

        Returns:
            Pickled ``(serialized_entry, triton_kernels, triton_constant_args)``
            tuple, where ``triton_kernels`` maps each side-table index to a
            ``(module_path, function_name)`` pair.

        Raises:
            RuntimeError: Propagated from ``_serialize_triton_kernel`` if any
                kernel in the side table cannot be serialized.
        """
        from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table

        # See Note: [Triton Kernel Side Table Serialization]
        # Capture triton kernel side table state BEFORE serialization.
        triton_kernels: dict[int, tuple[str, str]] = {
            idx: _serialize_triton_kernel(kernel)
            for idx, kernel in kernel_side_table.id_to_kernel.items()
        }
        triton_constant_args: dict[int, dict[str, Any]] = dict(
            kernel_side_table.constant_args
        )

        with torch._functorch.config.patch("bundled_autograd_cache", True):
            serialized_entry = fn.compiled_fn.serialize()
            # Bundle the triton kernel side table with the serialized entry
            bundle = (serialized_entry, triton_kernels, triton_constant_args)
            return pickle.dumps(bundle)

    @classmethod
    def deserialize_compile_artifacts(cls, data: bytes) -> Any:
        """
        Rebuild an instance from bytes produced by
        ``serialize_compile_artifacts``.

        Side-table entries stored in the bundle are written into the global
        ``kernel_side_table`` (overwriting any entry at the same index) before
        the compiled function is deserialized.
        """
        from torch._functorch._aot_autograd.aot_autograd_result import (
            deserialize_bundled_cache_entry,
        )
        from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table

        # NOTE: pickle.loads can execute arbitrary code; `data` must come from
        # a trusted source.
        bundle = pickle.loads(data)

        # Handle both old format (just entry) and new format (entry, kernels, const_args)
        if isinstance(bundle, tuple) and len(bundle) == 3:
            entry, triton_kernels, triton_constant_args = bundle
        else:
            # Backwards compatibility with old serialized artifacts
            entry = bundle
            # pyrefly: ignore [implicit-any]
            triton_kernels = {}
            # pyrefly: ignore [implicit-any]
            triton_constant_args = {}

        # See Note: [Triton Kernel Side Table Serialization]
        # Restore triton kernel side table BEFORE deserializing the compiled function.
        # The compiled function may reference kernels by index if any triton kernels
        # don't go through inductor codegen (e.g., triton kernels outside of
        # regional_inductor compiled regions).
        for idx, kernel_info in triton_kernels.items():
            kernel = _deserialize_triton_kernel(kernel_info)
            kernel_side_table.id_to_kernel[idx] = kernel
            kernel_side_table.kernel_to_id[kernel] = idx

        for idx, args in triton_constant_args.items():
            kernel_side_table.constant_args[idx] = args

        compiled_fn = deserialize_bundled_cache_entry(entry)
        return cls(compiled_fn)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Invoke the wrapped compiled function with the given arguments."""
        return self.compiled_fn(*args, **kwargs)
