# mypy: allow-untyped-defs
from functools import partial

import torch
import torch.nn.functional as F

from .expanded_weights_impl import implements_per_sample_grads
from .expanded_weights_utils import (
    forward_helper,
    set_grad_sample_if_exists,
    standard_kwargs,
    unpack_expanded_weight_or_tensor,
)


@implements_per_sample_grads(F.instance_norm)
class InstanceNormPerSampleGrad(torch.autograd.Function):
    """Per-sample-gradient rule for ``F.instance_norm`` under ExpandedWeights.

    The decorator presumably registers this Function as the handler used when
    ``F.instance_norm`` is called with ExpandedWeight arguments. The forward runs
    ``torch.instance_norm`` on the unexpanded arguments. The backward returns the
    ordinary gradient w.r.t. ``input`` and attaches per-sample gradients to
    ``weight``/``bias`` through ``set_grad_sample_if_exists``.
    """

    @staticmethod
    # pyrefly: ignore [bad-override]
    def forward(ctx, kwarg_names, _, *expanded_args_and_kwargs):
        """Run instance norm and stash what the backward needs on ``ctx``.

        Args:
            ctx: autograd context; input, running stats, weight, bias, eps and
                use_input_stats are stored on it as attributes.
            kwarg_names: names of the trailing entries of
                ``expanded_args_and_kwargs`` that are keyword arguments.
            _: the op reference (unused).
            *expanded_args_and_kwargs: positional args followed by kwarg values.

        Returns:
            The output of ``torch.instance_norm``.
        """
        instance_norm = partial(torch.instance_norm, cudnn_enabled=True)
        expanded_args, expanded_kwargs = standard_kwargs(
            kwarg_names, expanded_args_and_kwargs
        )
        output = forward_helper(instance_norm, expanded_args, expanded_kwargs)
        ctx.input = expanded_args[0]
        ctx.running_mean, ctx.running_var = (
            expanded_kwargs["running_mean"],
            expanded_kwargs["running_var"],
        )
        ctx.weight, ctx.bias, ctx.eps = (
            expanded_kwargs["weight"],
            expanded_kwargs["bias"],
            expanded_kwargs["eps"],
        )
        # Whether the forward normalized with per-instance statistics (True) or
        # with the running statistics (False). The backward must differentiate
        # accordingly; default to the previous (input-stats) behavior if absent.
        ctx.use_input_stats = bool(expanded_kwargs.get("use_input_stats", True))
        return output

    @staticmethod
    # pyrefly: ignore [bad-override]
    def backward(ctx, grad_output):
        """Return input gradient; set per-sample grads on weight and bias.

        Returns:
            A tuple with one entry per forward input: ``None`` for kwarg_names
            and the op reference, the input gradient (or ``None`` when the
            input does not require grad), then ``None`` for the 7 trailing
            forward inputs.
        """
        input, running_mean, running_var = ctx.input, ctx.running_mean, ctx.running_var
        weight, bias, eps = ctx.weight, ctx.bias, ctx.eps
        use_input_stats = ctx.use_input_stats

        results: list[torch.Tensor | None] = []
        results.append(None)  # for kwarg names
        results.append(None)  # for op reference
        if input.requires_grad:
            b = input.shape[0]
            c = input.shape[1]
            # Fold the batch into the channel dim so instance norm becomes a
            # batch norm over a single "sample" with b * c channels.
            new_shape = (1, b * c, *input.shape[2:])

            weight_ = unpack_expanded_weight_or_tensor(
                weight, lambda orig_weight: orig_weight.repeat(b)
            )
            running_mean_ = running_mean.repeat(b) if running_mean is not None else None
            running_var_ = running_var.repeat(b) if running_var is not None else None
            input_reshaped = input.contiguous().view(new_shape)
            grad_output_reshaped = grad_output.contiguous().view(new_shape)
            if use_input_stats:
                reduce_dims = (0,) + tuple(range(2, input.dim()))
                mean = torch.mean(input_reshaped, reduce_dims, False)
                var = torch.var(
                    input_reshaped,
                    reduce_dims,
                    keepdim=False,
                    unbiased=False,
                )
            else:
                # The forward used the running statistics, which are constants
                # w.r.t. the input. With train=False below, the backward kernel
                # reads running_mean_/running_var_; these values mirror them.
                mean, var = running_mean_, running_var_
            rstd = 1 / torch.sqrt(var + eps)

            # Use native_batch_norm_backward because it supports all inputs. The
            # forward may have dispatched to cuDNN or MIOpen, but that choice was
            # not recorded, so the backward cannot know which kernel was used.
            res = torch.ops.aten.native_batch_norm_backward(
                grad_output_reshaped,
                input_reshaped,
                weight_,
                running_mean_,
                running_var_,
                mean,
                rstd,
                use_input_stats,
                eps,
                (True, False, False),
            )
            results.append(res[0].reshape(input.shape))
        else:
            results.append(None)

        # None for the 7 remaining forward inputs. Weight and bias get per-sample
        # gradients through grad_sample below rather than a batched gradient; the
        # other inputs are not differentiated.
        results = results + [None] * 7

        def _normalized_input():
            # Recompute the normalized (pre-affine) input the forward produced.
            if use_input_stats:
                return F.instance_norm(input, eps=eps)
            # Normalize with the running stats directly rather than calling
            # F.instance_norm with them, which would write back into the buffers.
            stat_shape = (1, -1) + (1,) * (input.dim() - 2)
            return (input - running_mean.view(stat_shape)) * torch.rsqrt(
                running_var.view(stat_shape) + eps
            )

        # set grad_sample field for weight and bias with per sample gradients:
        # sum over all dims after (sample, channel) of the respective contribution.
        set_grad_sample_if_exists(
            weight,
            lambda _: torch.einsum("ni...->ni", _normalized_input() * grad_output),
        )
        set_grad_sample_if_exists(
            bias, lambda _: torch.einsum("ni...->ni", grad_output)
        )
        return tuple(results)
