# mypy: allow-untyped-defs

import torch.distributed.rpc as rpc
import torch.distributed.rpc._testing  # noqa: F401
from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
    RpcAgentTestFixture,
)


# The following message types are currently retried in the RREF protocol and
# distributed autograd. Thus only these messages should be tested with the
# Faulty RPC Agent.
retryable_message_types = [
    "RREF_FORK_REQUEST",
    "RREF_CHILD_ACCEPT",
    "RREF_USER_DELETE",
    "CLEANUP_AUTOGRAD_CONTEXT_REQ",
]

# The following messages incur the corresponding delay in seconds while being
# processed in FaultyTensorPipeAgent's enqueueSend() function.
default_messages_to_delay = {
    "PYTHON_CALL": 1.5,  # Python UDF
    "SCRIPT_CALL": 1.5,  # Script/Builtin
}


class FaultyRpcAgentTestFixture(RpcAgentTestFixture):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.messages_to_fail = retryable_message_types
        self.messages_to_delay = default_messages_to_delay

    @property
    def rpc_backend(self):
        return rpc.backend_registry.BackendType["FAULTY_TENSORPIPE"]

    @property
    def rpc_backend_options(self):
        return rpc.backend_registry.construct_rpc_backend_options(
            self.rpc_backend,
            init_method=self.init_method,
            num_worker_threads=8,
            num_fail_sends=3,
            messages_to_fail=self.messages_to_fail,
            messages_to_delay=self.messages_to_delay,
        )

    def setup_fault_injection(self, faulty_messages, messages_to_delay):
        if faulty_messages is not None:
            self.messages_to_fail = faulty_messages
        if messages_to_delay is not None:
            self.messages_to_delay = messages_to_delay

    def get_shutdown_error_regex(self):
        error_regexes = [
            "Exception in thread pool task",
            "Connection reset by peer",
            "Connection closed by peer",
        ]
        return "|".join([f"({error_str})" for error_str in error_regexes])

    def get_timeout_error_regex(self):
        return "RPC ran for more than"
