# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
import math
from collections.abc import Sequence
from dataclasses import dataclass
from enum import Enum
from typing import Any, cast

import torch
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
from torch.distributed.tensor._op_schema import (
    OpSchema,
    OpSpec,
    OpStrategy,
    PlacementList,
    RuntimeSchemaInfo,
    TupleStrategy,
)
from torch.distributed.tensor._ops.single_dim_strategy import (
    _ShardingPlaceholder,
    register_single_dim_strategy,
)
from torch.distributed.tensor._ops.utils import (
    as_list,
    expand_to_full_mesh_op_strategy,
    generate_redistribute_costs,
    is_tensor_evenly_shardable,
    is_tensor_evenly_shardable_on_dim,
    normalize_dim,
    normalize_dims,
    register_op_strategy,
)
from torch.distributed.tensor._utils import normalize_to_torch_size
from torch.distributed.tensor.placement_types import (
    _is_shard_like,
    _StridedShard,
    Partial,
    Placement,
    Replicate,
    Shard,
)


# Short aliases for the operator namespaces whose overloads are registered
# with sharding strategies in this file.
aten = torch.ops.aten
prims = torch.ops.prims


class Reduction(Enum):
    """Integer-coded reduction modes: no reduction, mean, or sum."""

    # NOTE(review): not referenced in the visible part of this file; presumably
    # consumed by loss-op strategies further down -- confirm before changing
    # the integer values.
    NONE = 0
    MEAN = 1
    SUM = 2


@dataclass(frozen=True)
class NormReduction:
    """
    Reduction-op descriptor for a vector norm of order ``norm_type``.

    Used in place of a string reduce-op name; it is translated into a Partial
    placement by get_placement_from_reduction_op.
    """

    # Order p of the norm being computed.
    norm_type: int | float


# A reduction is identified either by a NormReduction (norms of a given order)
# or by a string reduce-op name such as "sum", "max" or "avg".
ReductionOpType = NormReduction | str


@dataclass(frozen=True)
class _NormPartial(Partial):
    """
    This placement is used for partial p-norm (p not in {inf, -inf, 0}).

    For p-norms, the p-norm over n elements computes (sum_i x_i^p)^(1/p).
    For example, consider 2 ranks, a (4,) tensor sharded on dim-0, and 2-norm:
        Rank 0: [t1, t2] | Rank 1: [t3, t4]
    After computing 2-norm per gradient (partial placement):
        Rank 0: [sqrt(t1^2 + t2^2)] | Rank 1: [sqrt(t3^2 + t4^2)]
    Converting from partial to replicate wants to ultimately get:
        Rank 0/1: [sqrt(t1^2 + t2^2 + t3^2 + t4^2)]
    This is achieved by: x^p -> allreduce sum -> x^(1/p).
    """

    norm_type: int | float = 2

    def __init__(self, norm_type: int | float = 2):
        super().__init__("sum")
        object.__setattr__(self, "norm_type", norm_type)

    def _partition_value(
        self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
    ) -> torch.Tensor:
        return tensor / math.pow(mesh.size(mesh_dim), 1 / self.norm_type)

    def _reduce_shard_value(
        self,
        tensor: torch.Tensor,
        mesh: DeviceMesh,
        mesh_dim: int,
        shard_spec: Placement,
    ) -> torch.Tensor:
        if not isinstance(shard_spec, Shard):
            raise AssertionError(f"Expected Shard, got {type(shard_spec)}")
        tensor = self._pre_reduce_transform(tensor)
        reduced_tensor = super()._reduce_shard_value(tensor, mesh, mesh_dim, shard_spec)
        return self._post_reduce_transform(reduced_tensor)

    def _reduce_value(
        self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
    ) -> torch.Tensor:
        tensor = self._pre_reduce_transform(tensor)
        reduced_tensor = super()._reduce_value(tensor, mesh, mesh_dim)
        return self._post_reduce_transform(reduced_tensor)

    def _pre_reduce_transform(self, tensor: torch.Tensor) -> torch.Tensor:
        return tensor**self.norm_type

    def _post_reduce_transform(self, tensor: torch.Tensor) -> torch.Tensor:
        return tensor ** (1.0 / self.norm_type)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, _NormPartial):
            return False
        return self.norm_type == other.norm_type

    def __hash__(self) -> int:
        return 1 + hash(self.norm_type)

    def __repr__(self) -> str:
        return f"_NormPartial({self.norm_type})"

    def __str__(self) -> str:
        return f"_NormP({self.norm_type})"


def _infer_reduction_dims(dims_arg: object, ndim: int) -> list[int] | None:
    if dims_arg is None:
        return None
    dims = cast(list[int], as_list(dims_arg))
    dims = cast(list[int], normalize_dims(dims, ndim))
    empty_dims = [[0], [-1], []]
    if ndim == 0 and dims_arg in empty_dims:
        return None
    return dims


def _infer_reduce_dims_map(
    reduction_dims: list[int], input_ndim: int, keep_dim=False
) -> list[int]:
    reduction_dims_map = []
    new_dim_count = 0
    for input_dim in range(input_ndim):
        if input_dim in reduction_dims and not keep_dim:
            # if input dim in reduction dims, mark it as -1
            reduction_dims_map.append(-1)
        else:
            # otherwise mark it as the new dim
            reduction_dims_map.append(new_dim_count)
            new_dim_count += 1

    return reduction_dims_map


def _replicate_dims_start_at(
    placements: Sequence[Placement], start_dim: int = 0
) -> tuple[Placement, ...]:
    """
    Replace with Replicate every partial placement and every shard-like
    placement whose tensor dim is ``>= start_dim``; keep all others as-is.
    """

    def _must_replicate(p: Placement) -> bool:
        return p.is_partial() or (_is_shard_like(p) and p.dim >= start_dim)

    return tuple(Replicate() if _must_replicate(p) else p for p in placements)


# return new_placements which align with placements but skip the skipped_dim
# Precondition: no shard-like placement on skipped_dim (callers must
# replicate it first via replicate_reduction_dims).
def _skip_dim(
    placements: tuple[Placement, ...], skipped_dim: int
) -> tuple[Placement, ...]:
    new_placements: list[Placement] = []
    for p in placements:
        if isinstance(p, _StridedShard) and p.dim >= skipped_dim:
            new_placements.append(_StridedShard(p.dim - 1, split_factor=p.split_factor))
        elif isinstance(p, Shard) and p.dim >= skipped_dim:
            new_placements.append(Shard(p.dim - 1))
        else:
            new_placements.append(p)
    return tuple(new_placements)


def replicate_reduction_dims(
    placements: tuple[Placement, ...], reduction_dims: list[int]
) -> tuple[Placement, ...]:
    """
    Return placements suitable as input to a reduction that is not
    reduction-linear: partial placements and shard-like placements on any of
    ``reduction_dims`` become Replicate; everything else is kept.
    """

    def _needs_replication(p: Placement) -> bool:
        if p.is_partial():
            return True
        return _is_shard_like(p) and p.dim in reduction_dims

    return tuple(
        Replicate() if _needs_replication(p) else p for p in placements
    )


def map_placements_after_reduction(
    placements: tuple[Placement, ...],
    reduction_dims: list[int],
    reduction_dims_map: list[int],
    reduction_op: ReductionOpType,
) -> tuple[Placement, ...]:
    """
    Translate input placements into placements of the reduction's output.

    Replicate and Partial placements pass through unchanged.  A shard-like
    placement on a reduced dim becomes the partial placement for
    ``reduction_op``; a shard-like placement on a surviving dim is re-indexed
    through ``reduction_dims_map``.

    Raises:
        AssertionError: for a placement that is neither Replicate, Partial
            nor shard-like.
    """
    mapped: list[Placement] = []
    for p in placements:
        if isinstance(p, (Replicate, Partial)):
            mapped.append(p)
            continue

        if not _is_shard_like(p):
            raise AssertionError(f"Expected Shard/_StridedShard, got {type(p)}")

        src_dim = p.dim
        dst_dim = reduction_dims_map[src_dim]
        # -1 means the dim was collapsed; with keepdims=True the dim survives
        # in the map but is still reduced, hence the second check.
        if dst_dim == -1 or src_dim in reduction_dims:
            mapped.append(get_placement_from_reduction_op(reduction_op))
        elif isinstance(p, _StridedShard):
            mapped.append(_StridedShard(dst_dim, split_factor=p.split_factor))
        elif isinstance(p, Shard):
            mapped.append(Shard(dst_dim))
    return tuple(mapped)


def get_placement_from_reduction_op(reduction_op: ReductionOpType) -> Placement:
    """
    Return the partial placement that represents a pending ``reduction_op``.

    String reduce ops map to ``Partial(reduction_op)``.  A NormReduction maps
    to ``_NormPartial`` of the same order, except order 0 which maps to
    ``Partial("sum")``.
    """
    if not isinstance(reduction_op, NormReduction):
        return Partial(reduction_op)
    order = reduction_op.norm_type
    if order == 0:
        # return P(sum) for easier reduction_linear handling.
        return Partial("sum")
    return _NormPartial(norm_type=order)


def common_reduction_strategy(
    input_strategy: OpStrategy,
    reduce_dims: list[int],
    keep_dim: bool = False,
    reduction_linear: bool = True,
    reduction_op: ReductionOpType = "sum",
) -> OpStrategy:
    """
    Build the output strategy of a reduction by following each candidate
    placement of its input.

    reduction_linear means that the reduction `f` follows this rule:
        f([f(a), f(b)]) = f([a, b])

    reduction linear should be super set of linearity.

    Args:
        input_strategy: candidate placements of the tensor being reduced.
        reduce_dims: normalized tensor dims that are reduced.
        keep_dim: whether reduced dims are kept in the output.
        reduction_linear: whether the caller considers the reduction linear.
            It may still be treated as non-linear for an individual candidate
            (see below), in which case that candidate's pending partials and
            shards on ``reduce_dims`` are replicated first.
        reduction_op: reduce-op name or NormReduction describing the op.

    Returns:
        An OpStrategy with one OpSpec per input candidate.
    """
    reduction_strategy = OpStrategy([])

    for op_spec in input_strategy.strategies:
        src_spec = op_spec.output_spec

        # Linearity is decided per candidate.  Writing the verdict back to
        # `reduction_linear` would let one non-linear candidate force every
        # later candidate to redistribute as well.
        spec_is_linear = reduction_linear

        if spec_is_linear and reduction_op == "avg":
            # reduce(avg) is not linear for unevenly sharded tensors
            shape = list(src_spec.tensor_meta.shape)  # type:ignore[union-attr]
            spec_is_linear = all(
                is_tensor_evenly_shardable_on_dim(shape, src_spec, dim)
                for dim in reduce_dims
            )

        if spec_is_linear:
            for p in src_spec.placements:
                if not isinstance(p, Partial):
                    continue
                # when the partial reduction op matches the global reduction
                # op, we can delay redistribution (i.e max, max).
                # _NormPartial reports reduce_op == "sum" but is combined as
                # (sum x^p)^(1/p), so it must never be mistaken for a plain
                # pending sum.
                if isinstance(p, _NormPartial) or p.reduce_op != reduction_op:
                    spec_is_linear = False
                    break

        if spec_is_linear:
            input_placements = src_spec.placements
        else:
            # input placements for this strategy should clear out pending sum
            # and sharding on the reduction dimension
            input_placements = replicate_reduction_dims(
                src_spec.placements, reduce_dims
            )

        input_spec = DTensorSpec(
            mesh=input_strategy.mesh,
            placements=input_placements,
            tensor_meta=src_spec.tensor_meta,
        )

        reduce_dims_map = _infer_reduce_dims_map(reduce_dims, input_spec.ndim, keep_dim)
        out_placements = map_placements_after_reduction(
            input_spec.placements, reduce_dims, reduce_dims_map, reduction_op
        )
        redistribute_cost = [generate_redistribute_costs(input_strategy, input_spec)]
        reduction_strategy.strategies.append(
            OpSpec(
                output_specs=DTensorSpec(
                    mesh=input_strategy.mesh,
                    placements=out_placements,
                ),
                input_specs=(input_spec,),
                redistribute_cost=redistribute_cost,
            )
        )

    return reduction_strategy


# Maps every reduction overload handled by linear_reduction_strategy to the
# reduce-op name that is handed to common_reduction_strategy for it.
LINEAR_REDUCTION_OP_MAP = {
    aten.all.default: "product",
    aten.all.dim: "product",
    aten.sum.default: "sum",
    aten.sum.dim_IntList: "sum",
    prims.sum.default: "sum",
    aten.any.default: "sum",
    aten.any.dim: "sum",
    aten.any.dims: "sum",
    aten.any.out: "sum",
    # These are only valid when there is no padding
    aten.prod.default: "product",
    aten.prod.dim_int: "product",
    aten.prod.int_out: "product",
    prims.prod.default: "product",
    # avg is only linear when there is no padding
    aten.mean.default: "avg",
    aten.mean.dim: "avg",
    aten.mean.out: "avg",
    # NOTE(review): confirm that max.out / min.out are reduction overloads and
    # not the binary elementwise ones; linear_reduction_strategy interprets
    # args[1] as the reduction dims.
    aten.max.default: "max",
    aten.max.out: "max",
    aten.min.default: "min",
    aten.min.out: "min",
    aten.amax.default: "max",
    aten.amax.out: "max",
    aten.amin.default: "min",
    aten.amin.out: "min",
    aten.nansum.default: "sum",
}

# argmax/argmin return indices which cannot be combined with P(max/min).
# They need special handling that forces redistribution on reduction dims.
# The mapped reduce-op name is forwarded by argmax_argmin_strategy only for
# consistency with the other reduction tables.
ARGMAX_ARGMIN_OPS = {
    aten.argmax.default: "max",
    aten.argmin.default: "min",
}


@register_op_strategy(
    list(LINEAR_REDUCTION_OP_MAP), schema_info=RuntimeSchemaInfo(1)
)
def linear_reduction_strategy(op_schema: OpSchema) -> OpStrategy:
    """
    Strategy for the reductions listed in LINEAR_REDUCTION_OP_MAP.

    The optional reduction dims are read from args[1] (all dims when absent or
    inferred as None) and keepdim from args[2]; the reduce op is looked up in
    LINEAR_REDUCTION_OP_MAP by overload.

    Raises:
        AssertionError: if the first argument is not an OpStrategy.
    """
    args = op_schema.args_schema
    input_strategy = args[0]
    if not isinstance(input_strategy, OpStrategy):
        raise AssertionError(f"Expected OpStrategy, got {type(input_strategy)}")

    ndim = input_strategy.ndim
    explicit_dims = _infer_reduction_dims(args[1], ndim) if len(args) > 1 else None
    if explicit_dims is None:
        reduce_dims = list(range(ndim))
    else:
        reduce_dims = explicit_dims

    return common_reduction_strategy(
        input_strategy,
        reduce_dims,
        keep_dim=bool(args[2]) if len(args) > 2 else False,
        reduction_linear=True,
        reduction_op=LINEAR_REDUCTION_OP_MAP[op_schema.op],
    )


# max.dim/min.dim return (values, indices). Indices are local to each shard
# and cannot be combined across ranks, so the reduction dim is never offered
# as a sharding option (same approach as argmax/argmin).
@register_single_dim_strategy(
    [aten.max.dim, aten.min.dim], schema_info=RuntimeSchemaInfo(1)
)
def max_min_dim_single_dim_strategy(
    op: torch._ops.OpOverload,
    args_schema: tuple[Any, ...],
    kwargs_schema: dict[str, Any],
) -> list[list[Placement | _ShardingPlaceholder]]:
    """
    List the single-mesh-dim shardings for max.dim / min.dim.

    One entry is produced per non-reduction input dim; each entry is
    ``[values, indices, input]``, with the outputs sharded on the dim the
    input dim maps to once the reduction dim is dropped or kept.

    Raises:
        AssertionError: if the first argument is not a TensorMeta.
    """
    input_meta = args_schema[0]
    if not isinstance(input_meta, TensorMeta):
        raise AssertionError(f"Expected TensorMeta, got {type(input_meta)}")

    ndim = len(input_meta.shape)
    dim = normalize_dim(cast(int, args_schema[1]), ndim)
    keep_dim = bool(args_schema[2]) if len(args_schema) > 2 else False

    def _out_dim(in_dim: int) -> int:
        # dims after the reduced one shift left unless the dim is kept
        if keep_dim or in_dim < dim:
            return in_dim
        return in_dim - 1

    return [
        [
            _ShardingPlaceholder(_out_dim(in_dim)),
            _ShardingPlaceholder(_out_dim(in_dim)),
            _ShardingPlaceholder(in_dim),
        ]
        for in_dim in range(ndim)
        if in_dim != dim
    ]


@register_op_strategy(list(ARGMAX_ARGMIN_OPS), schema_info=RuntimeSchemaInfo(1))
def argmax_argmin_strategy(op_schema: OpSchema) -> OpStrategy:
    """
    Strategy for argmax/argmin.

    The results are indices rather than values: they are local to each shard
    and a max/min across ranks would be meaningless, so P(max/min) outputs are
    not allowed.  The reduction is therefore declared non-linear, which makes
    the input get redistributed on the reduction dims.

    Raises:
        AssertionError: if the first argument is not an OpStrategy.
    """
    args = op_schema.args_schema
    input_strategy = args[0]
    if not isinstance(input_strategy, OpStrategy):
        raise AssertionError(f"Expected OpStrategy, got {type(input_strategy)}")

    ndim = input_strategy.ndim
    explicit_dims = _infer_reduction_dims(args[1], ndim) if len(args) > 1 else None
    reduce_dims = explicit_dims if explicit_dims is not None else list(range(ndim))
    keep_dim = bool(args[2]) if len(args) > 2 else False

    return common_reduction_strategy(
        input_strategy,
        reduce_dims,
        keep_dim=keep_dim,
        # Force redistribution - indices can't use P(max/min)
        reduction_linear=False,
        # The reduce op has no effect here: with reduction_linear=False every
        # Shard on a reduction dim is turned into Replicate before the output
        # placements are mapped, so nothing is left to become Partial.  It is
        # passed only for consistency.
        reduction_op=ARGMAX_ARGMIN_OPS[op_schema.op],
    )


@register_op_strategy(
    [aten.cumsum.default, aten.cumprod.default, aten.logcumsumexp.default],
    schema_info=RuntimeSchemaInfo(1),
)
def scan_strategy(op_schema: OpSchema) -> OpStrategy:
    """
    Strategy for cumulative scans (cumsum, cumprod, logcumsumexp).

    The scan is treated as a non-linear, keepdim reduction over the scan dim,
    so pending partials and shards on that dim are replicated while shards on
    other dims are kept.

    Raises:
        AssertionError: if the input is not an OpStrategy or the dim is not
            an int.
    """
    args_schema = op_schema.args_schema
    input_strategy = args_schema[0]
    if not isinstance(input_strategy, OpStrategy):
        raise AssertionError(f"Expected OpStrategy, got {type(input_strategy)}")
    dim = args_schema[1]
    if not isinstance(dim, int):
        raise AssertionError(f"Expected int, got {type(dim)}")
    # Shard placements carry non-negative dims, so a negative scan dim must be
    # normalized first; otherwise a shard on the scanned dim would not be
    # recognized and would be left sharded.
    dim = normalize_dim(dim, input_strategy.ndim)
    return common_reduction_strategy(
        input_strategy, [dim], keep_dim=True, reduction_linear=False
    )


@register_op_strategy(
    [aten.median.default, aten.nanmedian.default],
    schema_info=RuntimeSchemaInfo(1),
)
def global_median_strategy(op_schema: OpSchema) -> OpStrategy:
    """
    Strategy for the whole-tensor median/nanmedian overloads: a non-linear
    reduction over every input dim.
    """
    input_strategy = cast(OpStrategy, op_schema.args_schema[0])
    every_dim = list(range(input_strategy.ndim))
    return common_reduction_strategy(
        input_strategy,
        every_dim,
        reduction_linear=False,
    )


@register_single_dim_strategy(
    [aten.median.dim, aten.nanmedian.dim, aten.mode.default],
    schema_info=RuntimeSchemaInfo(1),
)
def dim_reduction_with_indices_strategy(
    op: torch._ops.OpOverload,
    args_schema: tuple[Any, ...],
    kwargs_schema: dict[str, Any],
) -> list[list[Placement | _ShardingPlaceholder]]:
    """
    List the single-mesh-dim shardings for median.dim, nanmedian.dim and mode.

    The reduction dim is taken from args[1] (default -1) and keepdim from
    args[2].  Every other input dim yields one entry: two output placeholders
    on the corresponding output dim followed by the input placeholder.

    Raises:
        AssertionError: if the first argument is not a TensorMeta.
    """
    input_meta = args_schema[0]
    if not isinstance(input_meta, TensorMeta):
        raise AssertionError(f"Expected TensorMeta, got {type(input_meta)}")

    ndim = len(input_meta.shape)
    raw_dim = cast(int, args_schema[1]) if len(args_schema) > 1 else -1
    dim = normalize_dim(raw_dim, ndim)
    keep_dim = bool(args_schema[2]) if len(args_schema) > 2 else False

    strategies: list[list[Placement | _ShardingPlaceholder]] = []
    for in_dim in (d for d in range(ndim) if d != dim):
        # dims behind the reduced one move up by one unless it is kept
        shifts = not keep_dim and in_dim >= dim
        out_dim = in_dim - 1 if shifts else in_dim
        strategies.append(
            [
                _ShardingPlaceholder(out_dim),
                _ShardingPlaceholder(out_dim),
                _ShardingPlaceholder(in_dim),
            ]
        )
    return strategies


@register_single_dim_strategy(
    [aten.kthvalue.default],
    schema_info=RuntimeSchemaInfo(2),
)
def kthvalue_strategy(
    op: torch._ops.OpOverload,
    args_schema: tuple[Any, ...],
    kwargs_schema: dict[str, Any],
) -> list[list[Placement | _ShardingPlaceholder]]:
    """
    List the single-mesh-dim shardings for kthvalue.

    The reduction dim is taken from args[2] (default -1) and keepdim from
    args[3].  Every other input dim yields one entry: two output placeholders
    on the corresponding output dim followed by the input placeholder.

    Raises:
        AssertionError: if the first argument is not a TensorMeta.
    """
    input_meta = args_schema[0]
    if not isinstance(input_meta, TensorMeta):
        raise AssertionError(f"Expected TensorMeta, got {type(input_meta)}")

    ndim = len(input_meta.shape)
    if len(args_schema) > 2:
        requested_dim = cast(int, args_schema[2])
    else:
        requested_dim = -1
    dim = normalize_dim(requested_dim, ndim)
    keep_dim = bool(args_schema[3]) if len(args_schema) > 3 else False

    result: list[list[Placement | _ShardingPlaceholder]] = []
    for in_dim in range(ndim):
        if in_dim != dim:
            if keep_dim or in_dim < dim:
                out_dim = in_dim
            else:
                out_dim = in_dim - 1
            result.append(
                [
                    _ShardingPlaceholder(out_dim),
                    _ShardingPlaceholder(out_dim),
                    _ShardingPlaceholder(in_dim),
                ]
            )
    return result


@register_op_strategy(
    [aten.cummax.default, aten.cummin.default],
    schema_info=RuntimeSchemaInfo(1),
)
def cummax_cummin_strategy(op_schema: OpSchema) -> OpStrategy:
    """Strategy for cummax/cummin: delegate to sort_strategy on args[1]."""
    return sort_strategy(op_schema, cast(int, op_schema.args_schema[1]))


@register_op_strategy(
    [
        aten.std.correction,
        aten.std.correction_out,
        aten.var.correction,
        aten.var.correction_out,
        aten.var_mean.correction,
        aten.var_mean.correction_out,
        prims.var.default,
    ],
    schema_info=RuntimeSchemaInfo(1, ["keepdim"]),
)
def std_var_reduction_strategy(op_schema: OpSchema) -> OpStrategy:
    """
    Strategy for std/var/var_mean: a non-linear reduction.

    The optional reduction dims come from args[1] (all dims when absent or
    inferred as None) and keepdim from the ``keepdim`` keyword argument.

    Raises:
        AssertionError: if the first argument is not an OpStrategy.
    """
    args = op_schema.args_schema
    input_strategy = args[0]
    if not isinstance(input_strategy, OpStrategy):
        raise AssertionError(f"Expected OpStrategy, got {type(input_strategy)}")

    ndim = input_strategy.ndim
    explicit_dims = _infer_reduction_dims(args[1], ndim) if len(args) > 1 else None
    reduce_dims = explicit_dims if explicit_dims is not None else list(range(ndim))

    return common_reduction_strategy(
        input_strategy,
        reduce_dims,
        keep_dim=cast(bool, op_schema.kwargs_schema.get("keepdim", False)),
        reduction_linear=False,
    )


def _get_norm_reduction_op(norm_type: int | float | str) -> ReductionOpType:
    """Get the reduction op for vector/foreach norm based on norm_type.

    For inf/-inf norms (given as floats or as the strings "inf"/"-inf"),
    returns the simple reduction ops "max" and "min".
    For other norms (including 0), returns NormReduction which produces the
    appropriate Partial placement via get_placement_from_reduction_op.

    Raises:
        AssertionError: if norm_type is neither an infinity spelling nor an
            int/float.
    """
    if norm_type in (float("inf"), "inf"):
        return "max"
    if norm_type in (float("-inf"), "-inf"):
        return "min"
    if not isinstance(norm_type, (int, float)):
        # Only the two infinity strings are meaningful; report the offending
        # value like the other argument checks in this file do.
        raise AssertionError(
            f"Expected int or float norm_type (or 'inf'/'-inf'), got {norm_type!r}"
        )
    return NormReduction(norm_type)


@register_op_strategy(
    [aten.linalg_vector_norm.default, aten.norm.Scalar],
    schema_info=RuntimeSchemaInfo(1),
)
def vector_norm_strategy(op_schema: OpSchema) -> OpStrategy:
    """
    Strategy for vector norms.

    Positional arguments after the input are optional: norm order (default 2),
    dim (default None, meaning all dims) and keepdim (default False).  The
    reduce op is derived from the norm order via _get_norm_reduction_op.

    Raises:
        AssertionError: if the input is not an OpStrategy or the norm order
            is not an int, float or str.
    """
    args = op_schema.args_schema
    num_args = len(args)

    input_strategy = args[0]
    if not isinstance(input_strategy, OpStrategy):
        raise AssertionError(f"Expected OpStrategy, got {type(input_strategy)}")

    norm_type = 2 if num_args <= 1 else args[1]
    if not isinstance(norm_type, (int, float, str)):
        raise AssertionError(f"Expected int, float, or str, got {type(norm_type)}")

    dim = None if num_args <= 2 else args[2]
    keepdim = False if num_args <= 3 else args[3]

    ndim = input_strategy.ndim
    explicit_dims = _infer_reduction_dims(dim, ndim)
    if explicit_dims is None:
        reduce_dims = list(range(ndim))
    else:
        reduce_dims = explicit_dims

    return common_reduction_strategy(
        input_strategy,
        reduce_dims,
        keep_dim=cast(bool, keepdim),
        reduction_op=_get_norm_reduction_op(norm_type),
    )


@register_op_strategy(
    [aten._foreach_norm.Scalar], schema_info=RuntimeSchemaInfo(1, needs_pytree=True)
)
def foreach_norm_strategy(op_schema: OpSchema) -> TupleStrategy:
    """
    Strategy for _foreach_norm: every tensor in the input list is reduced over
    all of its dims with the reduce op derived from the norm order
    (args[1], default 2).

    Raises:
        AssertionError: if the input is not a TupleStrategy of OpStrategy
            children, or the norm order is not an int, float or str.
    """
    args = op_schema.args_schema
    input_tuple_strategy = args[0]
    if not isinstance(input_tuple_strategy, TupleStrategy):
        raise AssertionError(
            f"Expected TupleStrategy, got {type(input_tuple_strategy)}"
        )
    norm_type = 2 if len(args) <= 1 else args[1]
    if not isinstance(norm_type, (int, float, str)):
        raise AssertionError(f"Expected int, float, or str, got {type(norm_type)}")

    def _norm_of(child: object) -> OpStrategy:
        if not isinstance(child, OpStrategy):
            raise AssertionError(f"Expected OpStrategy, got {type(child)}")
        return common_reduction_strategy(
            child,
            list(range(child.ndim)),
            reduction_op=_get_norm_reduction_op(norm_type),
        )

    return TupleStrategy([_norm_of(child) for child in input_tuple_strategy.children])


@register_op_strategy([aten.linalg__powsum.default], schema_info=RuntimeSchemaInfo(1))
def powsum_strategy(op_schema: OpSchema) -> OpStrategy:
    """
    Strategy for linalg__powsum, i.e. sum(|x|^ord) without the final root.

    Because no root is taken, the output can always be combined across ranks
    with Partial("sum").  dim is read from args[2] (default None, meaning all
    dims) and keepdim from args[3] (default False).

    Raises:
        AssertionError: if the first argument is not an OpStrategy.
    """
    args = op_schema.args_schema
    num_args = len(args)

    input_strategy = args[0]
    if not isinstance(input_strategy, OpStrategy):
        raise AssertionError(f"Expected OpStrategy, got {type(input_strategy)}")

    dim = None if num_args <= 2 else args[2]
    keepdim = False if num_args <= 3 else args[3]

    ndim = input_strategy.ndim
    explicit_dims = _infer_reduction_dims(dim, ndim)
    reduce_dims = explicit_dims if explicit_dims is not None else list(range(ndim))

    return common_reduction_strategy(
        input_strategy,
        reduce_dims,
        keep_dim=cast(bool, keepdim),
        reduction_linear=True,
        reduction_op="sum",
    )


@register_op_strategy(
    [aten._foreach_powsum.Scalar], schema_info=RuntimeSchemaInfo(1, needs_pytree=True)
)
def foreach_powsum_strategy(op_schema: OpSchema) -> TupleStrategy:
    """
    Strategy for _foreach_powsum, i.e. sum(|x|^ord) per tensor in the list.

    Each tensor is reduced over all of its dims; the output can always be
    combined across ranks with Partial("sum").

    Raises:
        AssertionError: if the input is not a TupleStrategy of OpStrategy
            children.
    """
    input_tuple_strategy = op_schema.args_schema[0]
    if not isinstance(input_tuple_strategy, TupleStrategy):
        raise AssertionError(
            f"Expected TupleStrategy, got {type(input_tuple_strategy)}"
        )

    def _powsum_of(child: object) -> OpStrategy:
        if not isinstance(child, OpStrategy):
            raise AssertionError(f"Expected OpStrategy, got {type(child)}")
        return common_reduction_strategy(
            child,
            list(range(child.ndim)),
            reduction_linear=True,
            reduction_op="sum",
        )

    return TupleStrategy(
        [_powsum_of(child) for child in input_tuple_strategy.children]
    )


@register_op_strategy(
    [aten._foreach_max.default], schema_info=RuntimeSchemaInfo(1, needs_pytree=True)
)
def foreach_max_strategy(op_schema: OpSchema) -> TupleStrategy:
    """
    Strategy for _foreach_max: each tensor in the list is reduced over all of
    its dims with a linear "max" reduction.

    Raises:
        AssertionError: if the input is not a TupleStrategy of OpStrategy
            children.
    """
    input_tuple_strategy = op_schema.args_schema[0]
    if not isinstance(input_tuple_strategy, TupleStrategy):
        raise AssertionError(
            f"Expected TupleStrategy, got {type(input_tuple_strategy)}"
        )

    per_tensor: list[OpStrategy] = []
    for child in input_tuple_strategy.children:
        if not isinstance(child, OpStrategy):
            raise AssertionError(f"Expected OpStrategy, got {type(child)}")
        # every dim is reduced, leaving one value per tensor
        all_dims = list(range(child.ndim))
        per_tensor.append(
            common_reduction_strategy(
                child,
                all_dims,
                reduction_linear=True,
                reduction_op="max",
            )
        )
    return TupleStrategy(per_tensor)


@register_op_strategy(
    [
        aten._linalg_svd.default,
        aten.linalg_qr.default,
        # TODO: The diagonal ops can have an improved sharding strategy for
        # shard placements that does not require redistributing to replicate.
        aten.diagonal_copy.default,
        aten.diag_embed.default,
        aten.diag.default,
        aten.diagonal.default,
        aten.tril.default,
        aten.triu.default,
        aten._linalg_eigh.default,
    ],
    schema_info=RuntimeSchemaInfo(1),
)
def linalg_replicate_strategy(op_schema: OpSchema) -> OpStrategy:
    """
    Replicate-only strategy for linear algebra ops such as SVD or QR, for
    which no simple sharded computation is available: for every input
    candidate, both input and output are fully replicated on the mesh.

    Raises:
        AssertionError: if the first argument is not an OpStrategy.
    """
    input_strategy = op_schema.args_schema[0]
    if not isinstance(input_strategy, OpStrategy):
        raise AssertionError(f"Expected OpStrategy, got {type(input_strategy)}")
    mesh = input_strategy.mesh
    all_replicate = tuple(Replicate() for _ in range(mesh.ndim))

    def _replicated(candidate: OpSpec) -> OpSpec:
        spec = DTensorSpec(
            mesh=mesh,
            placements=all_replicate,
            tensor_meta=candidate.output_spec.tensor_meta,
        )
        return OpSpec(
            output_specs=spec,
            input_specs=(spec,),
            redistribute_cost=[generate_redistribute_costs(input_strategy, spec)],
        )

    return OpStrategy([_replicated(c) for c in input_strategy.strategies])


# Maps each pooling op to its spatial rank (number of spatial dimensions).
# Batched inputs have layout (N, C, *spatial) with ndim = spatial_rank + 2;
# unbatched inputs drop the batch dim giving ndim = spatial_rank + 1.
POOL_SPATIAL_RANK: dict[torch._ops.OpOverload, int] = {
    aten.avg_pool1d.default: 1,
    aten.avg_pool2d.default: 2,
    aten.avg_pool3d.default: 3,
    aten.adaptive_avg_pool1d.default: 1,
    aten._adaptive_avg_pool2d.default: 2,
    aten._adaptive_avg_pool3d.default: 3,
    aten.adaptive_max_pool1d.default: 1,
    aten.adaptive_max_pool2d.default: 2,
    aten.adaptive_max_pool3d.default: 3,
    aten.fractional_max_pool2d.default: 2,
    aten.fractional_max_pool3d.default: 3,
    aten.max_pool1d_with_indices.default: 1,
    aten.max_pool2d_with_indices.default: 2,
    aten.max_pool3d_with_indices.default: 3,
}

# Average-pooling ops: pooling_strategy additionally lets Partial("sum") and
# Partial("avg") pass through these.
AVG_POOL_OPS = [
    aten.avg_pool1d.default,
    aten.avg_pool2d.default,
    aten.avg_pool3d.default,
    aten.adaptive_avg_pool1d.default,
    aten._adaptive_avg_pool2d.default,
    aten._adaptive_avg_pool3d.default,
]

# Max-pooling ops: pooling_strategy treats these as producing two outputs.
MAX_POOL_OPS = [
    aten.adaptive_max_pool1d.default,
    aten.adaptive_max_pool2d.default,
    aten.adaptive_max_pool3d.default,
    aten.fractional_max_pool2d.default,
    aten.fractional_max_pool3d.default,
    aten.max_pool1d_with_indices.default,
    aten.max_pool2d_with_indices.default,
    aten.max_pool3d_with_indices.default,
]


@register_op_strategy(
    AVG_POOL_OPS + MAX_POOL_OPS,
    schema_info=RuntimeSchemaInfo(1),
)
def pooling_strategy(op_schema: OpSchema) -> OpStrategy:
    """
    Strategy for the average- and max-pooling ops.

    Every candidate uses one and the same placement for all outputs and all
    tensor inputs on a mesh dim: Replicate, Shard(0), for average pooling also
    Partial("sum") / Partial("avg"), and for batched inputs Shard(1).
    """
    input_strategy = cast(OpStrategy, op_schema.args_schema[0])
    mesh = input_strategy.mesh
    op = op_schema.op

    num_outputs = 1 if op not in MAX_POOL_OPS else 2
    num_inputs = len(op_schema.args_strategy) + len(op_schema.kwargs_strategy)
    width = num_outputs + num_inputs

    def _uniform(placement: Placement) -> PlacementList:
        # the same placement for every output and every input
        return [placement] * width

    single_mesh_dim_strategies: list[PlacementList] = [
        _uniform(Replicate()),
        _uniform(Shard(0)),
    ]
    if op in AVG_POOL_OPS:
        # avg_pool is linear: Partial(sum) and Partial(avg) pass through unchanged.
        single_mesh_dim_strategies.extend(
            [_uniform(Partial("sum")), _uniform(Partial("avg"))]
        )
    # S(1) is safe when dim 1 is the channel dim (pooling never touches it).
    # Batched inputs have layout (N, C, *spatial) with ndim = spatial_rank + 2.
    if input_strategy.ndim >= POOL_SPATIAL_RANK[op] + 2:
        single_mesh_dim_strategies.append(_uniform(Shard(1)))

    return expand_to_full_mesh_op_strategy(
        mesh, op_schema, single_mesh_dim_strategies, input_index=num_outputs
    )


@register_op_strategy(
    [aten._log_softmax.default, aten._softmax.default, aten._safe_softmax.default],
    schema_info=RuntimeSchemaInfo(1),
)
def softmax_strategy(op_schema: OpSchema) -> OpStrategy:
    """
    Strategy for the forward softmax variants.

    For every input candidate the input is required to have no pending
    partials and no shard on the softmax dim; the output uses that same spec.
    """
    input_strategy, softmax_dim, *_ = op_schema.args_schema
    input_strategy = cast(OpStrategy, input_strategy)
    softmax_dim = normalize_dim(cast(int, softmax_dim), input_strategy.ndim)

    candidates = []
    for candidate in input_strategy.strategies:
        src_spec = candidate.output_spec
        # make sure input is replicated along the softmax dim
        target_spec = DTensorSpec(
            mesh=input_strategy.mesh,
            placements=replicate_reduction_dims(src_spec.placements, [softmax_dim]),
            tensor_meta=src_spec.tensor_meta,
        )
        candidates.append(
            OpSpec(
                output_specs=target_spec,
                input_specs=[target_spec],
                redistribute_cost=[
                    generate_redistribute_costs(input_strategy, target_spec)
                ],
            )
        )

    return OpStrategy(candidates)


@register_op_strategy(
    [
        aten._log_softmax_backward_data.default,
        aten._softmax_backward_data.default,
    ],
    schema_info=RuntimeSchemaInfo(2),
)
def softmax_backward_strategy(op_schema: OpSchema) -> OpStrategy:
    """Build the sharding strategy for the softmax backward ops.

    Candidates of ``grad_out`` and ``out`` are paired positionally. For each
    pair, the spec with more shards is followed (``grad_out`` wins ties), then
    replicated along the softmax dim; both inputs and the output use the
    resulting placements.

    Args:
        op_schema: schema with exactly four args:
            ``(grad_out, out, softmax_dim, <ignored>)``.

    Returns:
        One ``OpSpec`` per candidate pair.
    """
    raw_grad_out, raw_out, raw_dim, _ = op_schema.args_schema
    grad_out_strategy = cast(OpStrategy, raw_grad_out)
    out_strategy = cast(OpStrategy, raw_out)
    softmax_dim = normalize_dim(cast(int, raw_dim), grad_out_strategy.ndim)
    mesh = grad_out_strategy.mesh

    grad_in_strategy = OpStrategy([])
    candidate_pairs = zip(grad_out_strategy.strategies, out_strategy.strategies)
    for grad_out_candidate, out_candidate in candidate_pairs:
        grad_out_src_spec = grad_out_candidate.output_spec
        out_src_spec = out_candidate.output_spec

        # Follow whichever operand is sharded more; grad_out wins on a tie.
        if out_src_spec.num_shards > grad_out_src_spec.num_shards:
            leader_spec = out_src_spec
        else:
            leader_spec = grad_out_src_spec

        # Inputs must be replicated along the softmax dim.
        tgt_spec = DTensorSpec(
            mesh=mesh,
            placements=replicate_reduction_dims(leader_spec.placements, [softmax_dim]),
        )
        # Same placements for both inputs, each keeping its own tensor meta.
        input_specs = tuple(
            DTensorSpec(
                mesh=tgt_spec.mesh,
                placements=tgt_spec.placements,
                tensor_meta=src_spec.tensor_meta,
            )
            for src_spec in (grad_out_src_spec, out_src_spec)
        )
        costs = [
            generate_redistribute_costs(grad_out_strategy, tgt_spec),
            generate_redistribute_costs(out_strategy, tgt_spec),
        ]
        grad_in_strategy.strategies.append(
            OpSpec(
                output_specs=tgt_spec,
                input_specs=input_specs,
                redistribute_cost=costs,
            )
        )

    return grad_in_strategy


@register_op_strategy(
    [aten.nll_loss_forward.default, aten.nll_loss2d_forward.default],
    schema_info=RuntimeSchemaInfo(3),
)
def nll_loss_forward_strategy(op_schema: OpSchema) -> OpStrategy:
    """Build the sharding strategy for ``nll_loss_forward`` / ``nll_loss2d_forward``.

    For each candidate placement of the input:
      * the input is replicated along the channel dim (dim 1 for inputs with
        at least two dims, otherwise dim 0);
      * the target follows the input placements with the channel dim skipped;
      * the weight, when given, gets the placements returned by
        ``_replicate_dims_start_at``;
      * with no reduction, the output follows the target spec and
        ``total_weight`` is replicated on every mesh dim; otherwise both
        outputs come from ``map_placements_after_reduction`` over all target
        dims ("avg" or "sum" for the output, always "sum" for total_weight).

    Args:
        op_schema: schema with exactly five args:
            ``(input, target, weight, reduction, <ignored>)``.

    Returns:
        One ``OpSpec`` per input candidate, with
        ``output_specs=(output, total_weight)``.

    Raises:
        AssertionError: if the arg count is not 5, or a non-None weight is
            not an ``OpStrategy``.
        ValueError: for mean reduction when the target cannot be evenly
            sharded under the expected spec.
    """
    mesh = op_schema.get_mesh_from_args()

    if len(op_schema.args_schema) != 5:
        raise AssertionError(f"Expected 5 args, got {len(op_schema.args_schema)}")

    (
        input_strategy,
        target_strategy,
        weight_strategy,
        reduction,
        _,
    ) = op_schema.args_schema
    input_strategy = cast(OpStrategy, input_strategy)
    target_strategy = cast(OpStrategy, target_strategy)
    reduction = cast(int, reduction)

    input_shape = input_strategy.shape
    channel_dim = 1 if len(input_shape) >= 2 else 0

    output_strategy = OpStrategy([])
    for idx, input_placement_strategy in enumerate(input_strategy.strategies):
        op_args_target_specs = []
        redistribute_costs = []

        # make sure input is replicated along the channel dim
        input_src_spec = input_placement_strategy.output_spec
        input_expected_spec = DTensorSpec(
            mesh=mesh,
            placements=replicate_reduction_dims(
                input_src_spec.placements, [channel_dim]
            ),
            tensor_meta=input_src_spec.tensor_meta,
        )
        op_args_target_specs.append(input_expected_spec)
        redistribute_costs.append(
            generate_redistribute_costs(input_strategy, input_expected_spec)
        )

        # target doesn't have channel dim, and it follows input on other dims
        target_src_spec = target_strategy.strategies[idx].output_spec
        target_expected_spec = DTensorSpec(
            mesh=mesh,
            placements=_skip_dim(input_expected_spec.placements, channel_dim),
            tensor_meta=target_src_spec.tensor_meta,
        )
        op_args_target_specs.append(target_expected_spec)
        redistribute_costs.append(
            generate_redistribute_costs(target_strategy, target_expected_spec)
        )

        # weight tensor, if given, has to be a Tensor of size input_shape[channel_dim]
        # make sure it is replicated
        if weight_strategy is not None:
            if not isinstance(weight_strategy, OpStrategy):
                raise AssertionError(
                    f"Expected OpStrategy, got {type(weight_strategy)}"
                )
            weight_src_spec = weight_strategy.strategies[idx].output_spec
            weight_expected_spec = DTensorSpec(
                mesh=mesh,
                placements=_replicate_dims_start_at(weight_src_spec.placements),
                tensor_meta=weight_src_spec.tensor_meta,
            )
            op_args_target_specs.append(weight_expected_spec)
            redistribute_costs.append(
                generate_redistribute_costs(weight_strategy, weight_expected_spec)
            )

        if reduction == Reduction.NONE.value:
            # No reduction: the per-element loss is laid out like the target.
            output_expected_spec = target_expected_spec
            total_weight_expected_spec = DTensorSpec(
                mesh=mesh, placements=tuple([Replicate()] * mesh.ndim)
            )
        else:
            if reduction == Reduction.MEAN.value:
                reduction_op = "avg"
                if not is_tensor_evenly_shardable(
                    target_expected_spec.shape, target_expected_spec
                ):
                    # Adjacent literals (not a backslash continuation) so the
                    # source indentation does not leak into the message.
                    raise ValueError(
                        "The intermediate results of nll_loss cannot be evenly "
                        "sharded, resulting in biased mean result."
                    )
            else:  # reduction == Reduction.SUM.value:
                reduction_op = "sum"
            # The loss is reduced over every dim of the target.
            reduce_dims = list(range(target_expected_spec.ndim))
            reduce_dims_map = _infer_reduce_dims_map(
                reduce_dims, target_expected_spec.ndim, keep_dim=False
            )
            out_placements = map_placements_after_reduction(
                target_expected_spec.placements,
                reduce_dims,
                reduce_dims_map,
                reduction_op,
            )
            output_expected_spec = DTensorSpec(
                mesh=mesh,
                placements=out_placements,
            )

            # whether reduction is sum or mean, the total weight has to be summed up if not replicated
            total_weight_placements = map_placements_after_reduction(
                target_expected_spec.placements,
                reduce_dims,
                reduce_dims_map,
                "sum",
            )
            total_weight_expected_spec = DTensorSpec(
                mesh=mesh,
                placements=total_weight_placements,
            )

        output_strategy.strategies.append(
            OpSpec(
                output_specs=(output_expected_spec, total_weight_expected_spec),
                input_specs=op_args_target_specs,
                redistribute_cost=redistribute_costs,
            )
        )

    return output_strategy


@register_op_strategy(
    [aten.nll_loss_backward.default, aten.nll_loss2d_backward.default],
    schema_info=RuntimeSchemaInfo(4),
)
def nll_loss_backward_strategy(op_schema: OpSchema) -> OpStrategy:
    """Build the sharding strategy for ``nll_loss_backward`` / ``nll_loss2d_backward``.

    For each candidate placement of the input, the expected input specs are
    assembled in argument order ``[grad_out, input, target, weight?,
    total_weight]`` and the gradient w.r.t. the input reuses the expected
    input spec.

    Args:
        op_schema: schema with exactly seven args: ``(grad_out, input,
            target, weight, reduction, <ignored>, total_weight)``.

    Returns:
        One ``OpSpec`` per input candidate.

    Raises:
        AssertionError: if the arg count is not 7, or a non-None weight is
            not an ``OpStrategy``.
    """
    # The forward op has already validated the mesh, so skip validation here.
    mesh = op_schema.get_mesh_from_args(validate=False)

    num_args = len(op_schema.args_schema)
    if num_args != 7:
        raise AssertionError(f"Expected 7 args, got {num_args}")
    (
        raw_grad_out,
        raw_input,
        raw_target,
        weight_strategy,
        raw_reduction,
        _,
        raw_total_weight,
    ) = op_schema.args_schema
    grad_out_strategy = cast(OpStrategy, raw_grad_out)
    input_strategy = cast(OpStrategy, raw_input)
    target_strategy = cast(OpStrategy, raw_target)
    reduction = cast(int, raw_reduction)
    total_weight_strategy = cast(OpStrategy, raw_total_weight)

    has_reduction = reduction != Reduction.NONE.value
    is_mean = reduction == Reduction.MEAN.value
    channel_dim = 1 if len(input_strategy.shape) >= 2 else 0

    grad_in_strategy = OpStrategy([])
    for idx, input_candidate in enumerate(input_strategy.strategies):
        # Input: replicated along the channel dim.
        input_src_spec = input_candidate.output_spec
        input_expected_spec = DTensorSpec(
            mesh=mesh,
            placements=replicate_reduction_dims(
                input_src_spec.placements, [channel_dim]
            ),
            tensor_meta=input_src_spec.tensor_meta,
        )
        input_cost = generate_redistribute_costs(input_strategy, input_expected_spec)

        # Target: no channel dim; follows the input on all remaining dims.
        target_src_spec = target_strategy.strategies[idx].output_spec
        target_expected_spec = DTensorSpec(
            mesh=mesh,
            placements=_skip_dim(input_expected_spec.placements, channel_dim),
            tensor_meta=target_src_spec.tensor_meta,
        )
        target_cost = generate_redistribute_costs(
            target_strategy, target_expected_spec
        )

        # grad_out: a replicated scalar when the loss was reduced, otherwise
        # it is laid out like the target.
        grad_out_src_spec = grad_out_strategy.strategies[idx].output_spec
        if has_reduction:
            grad_out_expected_spec = DTensorSpec(
                mesh=mesh,
                placements=_replicate_dims_start_at(grad_out_src_spec.placements),
                tensor_meta=grad_out_src_spec.tensor_meta,
            )
        else:
            grad_out_expected_spec = target_expected_spec
        grad_out_cost = generate_redistribute_costs(
            grad_out_strategy, grad_out_expected_spec
        )

        # Specs and costs are kept in the op's argument order.
        arg_specs = [grad_out_expected_spec, input_expected_spec, target_expected_spec]
        arg_costs = [grad_out_cost, input_cost, target_cost]

        # Weight (optional): a tensor of size input_shape[channel_dim] that
        # has to be replicated.
        if weight_strategy is not None:
            if not isinstance(weight_strategy, OpStrategy):
                raise AssertionError(
                    f"Expected OpStrategy, got {type(weight_strategy)}"
                )
            weight_src_spec = weight_strategy.strategies[idx].output_spec
            weight_expected_spec = DTensorSpec(
                mesh=mesh,
                placements=_replicate_dims_start_at(weight_src_spec.placements),
                tensor_meta=weight_src_spec.tensor_meta,
            )
            arg_specs.append(weight_expected_spec)
            arg_costs.append(
                generate_redistribute_costs(weight_strategy, weight_expected_spec)
            )

        # total_weight is only consumed by the backward kernel for
        # reduction='mean'; for 'sum' / 'none' it is left where it is.
        total_weight_src_spec = total_weight_strategy.strategies[idx].output_spec
        if is_mean:
            total_weight_expected_spec = DTensorSpec(
                mesh=mesh,
                placements=_replicate_dims_start_at(total_weight_src_spec.placements),
                tensor_meta=total_weight_src_spec.tensor_meta,
            )
        else:
            total_weight_expected_spec = total_weight_src_spec
        arg_specs.append(total_weight_expected_spec)
        arg_costs.append(
            generate_redistribute_costs(
                total_weight_strategy, total_weight_expected_spec
            )
        )

        # The input gradient is laid out exactly like the expected input.
        grad_in_strategy.strategies.append(
            OpSpec(
                output_specs=input_expected_spec,
                input_specs=arg_specs,
                redistribute_cost=arg_costs,
            )
        )

    return grad_in_strategy


@register_single_dim_strategy(
    [aten.native_layer_norm.default],
    schema_info=RuntimeSchemaInfo(1),
)
def layer_norm_single_dim_strategy(
    op: torch._ops.OpOverload,
    args_schema: tuple[Any, ...],
    kwargs_schema: dict[str, Any],
) -> list[list[Placement | _ShardingPlaceholder]]:
    """Single-dim sharding rules for ``native_layer_norm``.

    One rule is produced per leading (non-normalized) dim of the input. Each
    rule is laid out as ``[out, mean, rstd, input, weight?, bias?]``: the
    first four entries shard on that dim, and weight / bias, when present,
    are replicated.
    """
    input_meta = args_schema[0]
    normalized_shape = args_schema[1]
    optional_params = (args_schema[2], args_schema[3])  # (weight, bias)

    # Dims in front of the normalized trailing dims are the shardable ones.
    num_leading_dims = len(input_meta.shape) - len(
        normalize_to_torch_size(normalized_shape)
    )
    num_present_params = sum(meta is not None for meta in optional_params)

    strategies: list[list[Placement | _ShardingPlaceholder]] = []
    for dim in range(num_leading_dims):
        # out, mean, rstd, input
        rule: list[Placement | _ShardingPlaceholder] = [
            _ShardingPlaceholder(dim) for _ in range(4)
        ]
        # weight / bias (only those actually passed)
        rule.extend(Replicate() for _ in range(num_present_params))
        strategies.append(rule)
    return strategies


@register_single_dim_strategy(
    [aten._fused_rms_norm.default],
    schema_info=RuntimeSchemaInfo(1),
)
def rms_norm_single_dim_strategy(
    op: torch._ops.OpOverload,
    args_schema: tuple[Any, ...],
    kwargs_schema: dict[str, Any],
) -> list[list[Placement | _ShardingPlaceholder]]:
    """Single-dim sharding rules for ``_fused_rms_norm``.

    One rule is produced per leading (non-normalized) dim of the input. Each
    rule is laid out as ``[out, rrms, input, weight?]``: the first three
    entries shard on that dim, and the weight, when present, is replicated.
    """
    input_meta = args_schema[0]
    normalized_shape = args_schema[1]
    has_weight = args_schema[2] is not None

    # Dims in front of the normalized trailing dims are the shardable ones.
    num_leading_dims = len(input_meta.shape) - len(
        normalize_to_torch_size(normalized_shape)
    )

    strategies: list[list[Placement | _ShardingPlaceholder]] = []
    for dim in range(num_leading_dims):
        # out, rrms, input
        rule: list[Placement | _ShardingPlaceholder] = [
            _ShardingPlaceholder(dim) for _ in range(3)
        ]
        if has_weight:
            rule.append(Replicate())
        strategies.append(rule)
    return strategies


@register_single_dim_strategy(
    [aten.native_layer_norm_backward.default],
    schema_info=RuntimeSchemaInfo(2),
)
def layer_norm_bwd_single_dim_strategy(
    op: torch._ops.OpOverload,
    args_schema: tuple[Any, ...],
    kwargs_schema: dict[str, Any],
) -> list[list[Placement | _ShardingPlaceholder | None]]:
    """Single-dim sharding rules for ``native_layer_norm_backward``.

    One rule is produced per leading (non-normalized) dim of the input, laid
    out as::

        [d_input, d_weight, d_bias, grad_out, input, mean, rstd, weight?, bias?]

    The three output slots are always present; ``d_weight`` / ``d_bias`` are
    ``Partial("sum")`` when the matching parameter was passed and ``None``
    otherwise. Weight and bias inputs, when present, are replicated.
    """
    input_meta = args_schema[1]
    normalized_shape = args_schema[2]
    # args_schema[3] is mean and args_schema[4] is rstd; neither is needed here.
    optional_params = (args_schema[5], args_schema[6])  # (weight, bias)

    num_leading_dims = len(input_meta.shape) - len(
        normalize_to_torch_size(normalized_shape)
    )

    strategies: list[list[Placement | _ShardingPlaceholder | None]] = []
    for dim in range(num_leading_dims):
        # Outputs: d_input, then d_weight / d_bias (None if param is absent).
        rule: list[Placement | _ShardingPlaceholder | None] = [
            _ShardingPlaceholder(dim)
        ]
        rule.extend(
            Partial("sum") if meta is not None else None for meta in optional_params
        )
        # Inputs: grad_out, input, mean, rstd.
        rule.extend(_ShardingPlaceholder(dim) for _ in range(4))
        # Inputs: weight / bias, only for the ones that were passed.
        rule.extend(Replicate() for meta in optional_params if meta is not None)
        strategies.append(rule)

    return strategies


@register_single_dim_strategy(
    [aten._fused_rms_norm_backward.default],
    schema_info=RuntimeSchemaInfo(2),
)
def rms_norm_bwd_single_dim_strategy(
    op: torch._ops.OpOverload,
    args_schema: tuple[Any, ...],
    kwargs_schema: dict[str, Any],
) -> list[list[Placement | _ShardingPlaceholder | None]]:
    """Single-dim sharding rules for ``_fused_rms_norm_backward``.

    One rule is produced per leading (non-normalized) dim of the input, laid
    out as ``[d_input, d_weight, grad_out, input, rstd, weight?]``. Both
    output slots are always present; ``d_weight`` is ``Partial("sum")`` when a
    weight was passed and ``None`` otherwise. The weight input, when present,
    is replicated.
    """
    input_meta = args_schema[1]
    normalized_shape = args_schema[2]
    # args_schema[3] is rstd; it is not needed here.
    has_weight = args_schema[4] is not None

    num_leading_dims = len(input_meta.shape) - len(
        normalize_to_torch_size(normalized_shape)
    )

    strategies: list[list[Placement | _ShardingPlaceholder | None]] = []
    for dim in range(num_leading_dims):
        # Outputs: d_input, d_weight.
        rule: list[Placement | _ShardingPlaceholder | None] = [
            _ShardingPlaceholder(dim),
            Partial("sum") if has_weight else None,
        ]
        # Inputs: grad_out, input, rstd.
        rule.extend(_ShardingPlaceholder(dim) for _ in range(3))
        if has_weight:
            rule.append(Replicate())
        strategies.append(rule)

    return strategies


def sort_strategy(op_schema: OpSchema, sort_dim: int) -> OpStrategy:
    """Shared strategy builder for ops that sort/select along one dim.

    Each per-mesh-dim candidate has three placements (two outputs followed by
    the input). Full replication is always offered, plus sharding on every
    tensor dim other than ``sort_dim``.

    Args:
        op_schema: schema whose first arg is the input strategy.
        sort_dim: dim being sorted along; may be negative.
    """
    input_strategy = cast(OpStrategy, op_schema.args_schema[0])
    ndim = input_strategy.ndim
    sort_dim = normalize_dim(sort_dim, ndim)

    single_mesh_dim_strategies: list[PlacementList] = [[Replicate()] * 3]
    # Sharding is allowed on every dim except the one being sorted.
    single_mesh_dim_strategies.extend(
        [Shard(dim)] * 3 for dim in range(ndim) if dim != sort_dim
    )
    return expand_to_full_mesh_op_strategy(
        input_strategy.mesh, op_schema, single_mesh_dim_strategies, input_index=2
    )


@register_op_strategy(
    [aten.topk.default],
    schema_info=RuntimeSchemaInfo(2),
)
def topk_strategy(op_schema: OpSchema) -> OpStrategy:
    """Sharding strategy for ``topk``: delegate to ``sort_strategy``.

    The dim is read from the third positional arg and falls back to -1 when
    it was not passed.
    """
    args = op_schema.args_schema
    if len(args) > 2:
        topk_dim = cast(int, args[2])
    else:
        topk_dim = -1
    return sort_strategy(op_schema, topk_dim)


@register_op_strategy(
    aten.sort.default,
    schema_info=RuntimeSchemaInfo(
        1,
    ),
)
def sort_default_strategy(op_schema: OpSchema) -> OpStrategy:
    """Sharding strategy for ``sort.default``: delegate to ``sort_strategy``.

    The dim is read from the second positional arg and falls back to -1 when
    it was not passed.

    Raises:
        AssertionError: if the first arg is not an ``OpStrategy``.
    """
    args = op_schema.args_schema
    input_strategy = args[0]
    if not isinstance(input_strategy, OpStrategy):
        raise AssertionError(f"Expected OpStrategy, got {type(input_strategy)}")
    sort_dim = cast(int, args[1]) if len(args) > 1 else -1
    return sort_strategy(op_schema, sort_dim)


@register_op_strategy(
    aten.sort.stable,
    schema_info=RuntimeSchemaInfo(
        1,
        static_kwargkey=["dim", "descending", "stable"],
    ),
)
def sort_stable_strategy(op_schema: OpSchema) -> OpStrategy:
    """Sharding strategy for ``sort.stable``: delegate to ``sort_strategy``.

    The dim is read from the ``dim`` kwarg and falls back to -1 when it was
    not passed.

    Raises:
        AssertionError: if the first arg is not an ``OpStrategy``.
    """
    input_strategy = op_schema.args_schema[0]
    if not isinstance(input_strategy, OpStrategy):
        raise AssertionError(f"Expected OpStrategy, got {type(input_strategy)}")
    kwargs = op_schema.kwargs_schema
    sort_dim = cast(int, kwargs["dim"]) if "dim" in kwargs else -1
    return sort_strategy(op_schema, sort_dim)


@register_op_strategy(
    [aten.histc.default],
    # strategy choice depends on the value of 'min' and 'max' args, which are position 2 and 3
    schema_info=RuntimeSchemaInfo(2),
)
def histc_strategy(op_schema: OpSchema) -> OpStrategy:
    """Build the sharding strategy for ``histc``.

    Each per-mesh-dim candidate is ``[output, input]``. Full replication is
    always offered. Sharded input with a ``Partial()`` output is offered on
    every input dim only when the bin range is fixed by the caller, i.e.
    ``min`` and ``max`` were both passed and differ.

    Args:
        op_schema: schema with args ``(input, bins?, min?, max?)``.

    Returns:
        The per-mesh-dim candidates expanded over the input's mesh.
    """
    args = op_schema.args_schema
    input_strategy = cast(OpStrategy, args[0])
    single_mesh_dim_strategies: list[PlacementList] = [[Replicate(), Replicate()]]

    # histc can support sharded input and partial output on any input dim, provided the
    # bin boundaries are identical on every rank. That requires user-specified min and
    # max. When min == max (including the 0/0 defaults passed explicitly), histc derives
    # the range from the true min and max of the local data, which differ across ranks
    # and would lead to an incorrect final result.
    has_fixed_range = len(args) == 4 and args[2] != args[3]
    if has_fixed_range:
        for dim in range(input_strategy.ndim):
            dim_shardings: PlacementList = [Partial(), Shard(dim)]
            single_mesh_dim_strategies.append(dim_shardings)

    return expand_to_full_mesh_op_strategy(
        input_strategy.mesh, op_schema, single_mesh_dim_strategies
    )


@register_op_strategy(
    [aten.logsumexp.default],
    schema_info=RuntimeSchemaInfo(
        # static_argnum is the position where non-Tensor args begin.
        static_argnum=1,
        # static_kwargkey is the name of kwargs to hash (which determines
        # whether sharding prop can be cached).
        static_kwargkey=["keepdim"],
    ),
)
def logsumexp_strategy(op_schema: OpSchema) -> OpStrategy:
    """Implements the sharding propagation strategy for logsumexp.

    The reduction is delegated to ``common_reduction_strategy`` with
    ``reduction_linear=False``.

    Args:
        op_schema: schema with args ``(input, dim, keepdim?)``; ``keepdim``
            is also accepted as a kwarg.

    Raises:
        AssertionError: if fewer than two args are given, the first arg is
            not an ``OpStrategy``, or the reduction dims cannot be inferred.
    """

    # args_schema holds the input strategy followed by the non-tensor args
    # (dim and, when passed positionally, keepdim).
    args_schema = op_schema.args_schema
    if not len(args_schema) > 1:
        raise AssertionError(
            f"Expected more than 1 arg (input and dim are required), got {len(args_schema)}"
        )

    input_strategy = args_schema[0]
    if not isinstance(input_strategy, OpStrategy):
        raise AssertionError(f"Expected OpStrategy, got {type(input_strategy)}")

    dims_arg = args_schema[1]
    reduce_dims = _infer_reduction_dims(dims_arg, input_strategy.ndim)
    if reduce_dims is None:
        raise AssertionError("Expected reduce_dims to not be None")

    # keepdim is a positional parameter of the aten op, so it can show up as
    # the third positional arg; only fall back to the kwarg when it does not.
    if len(args_schema) > 2:
        keep_dim = bool(args_schema[2])
    else:
        keep_dim = cast(bool, op_schema.kwargs_schema.get("keepdim", False))
    return common_reduction_strategy(
        input_strategy,
        reduce_dims,
        keep_dim=keep_dim,
        reduction_linear=False,
    )


# Number of placement slots in each single-dim strategy of a linalg op, i.e.
# the count of outputs plus the count of inputs given in the group comments
# below. `linalg_batch_dim_strategy` looks ops up here and rejects any op
# that is missing.
_LINALG_NUM_PLACEMENTS = {
    # 1 in 1 out
    aten.cholesky.default: 2,
    aten.cholesky_inverse.default: 2,
    aten.linalg_matrix_exp.default: 2,
    # 2 in 1 out
    aten.cholesky_solve.default: 3,
    aten.linalg_householder_product.default: 3,
    aten.linalg_solve_triangular.default: 3,
    # 3 in 1 out
    aten.linalg_ldl_solve.default: 4,
    aten.linalg_lu_solve.default: 4,
    aten.ormqr.default: 4,
    # 1 in 2 out
    aten.geqrf.default: 3,
    aten.linalg_cholesky_ex.default: 3,
    aten.linalg_eig.default: 3,
    aten.linalg_inv_ex.default: 3,
    # 2 in 2 out
    aten.triangular_solve.default: 4,
    # 1 in 3 out
    aten._linalg_det.default: 4,
    aten.linalg_ldl_factor_ex.default: 4,
    aten.linalg_lu.default: 4,
    aten.linalg_lu_factor_ex.default: 4,
    # 2 in 3 out
    aten.lu_unpack.default: 5,
    # 1 in 4 out
    aten._linalg_slogdet.default: 5,
    # 2 in 4 out
    aten._linalg_solve_ex.default: 6,
    # 1 in (a single placement slot, for the input only)
    aten._linalg_check_errors.default: 1,
}


def _linalg_batch_dim_strategies(
    ndim: int, n_placements: int
) -> list[list[Placement | _ShardingPlaceholder]]:
    """Build single-dim strategies for linalg ops that operate on the last 1-2 dims.

    Every dim except the last two is treated as a batch dim. For each batch
    dim one strategy is returned in which all ``n_placements`` slots (outputs
    and inputs alike) are sharded on that dim. The result is empty when
    ``ndim <= 2``.
    """
    num_batch_dims = ndim - 2
    return [
        [_ShardingPlaceholder(batch_dim)] * n_placements
        for batch_dim in range(num_batch_dims)
    ]


def _get_ndim(tensor_meta: Any) -> int:
    if not isinstance(tensor_meta, TensorMeta):
        raise AssertionError(f"Expected TensorMeta, got {type(tensor_meta)}")
    return len(tensor_meta.shape)


@register_single_dim_strategy(
    [
        aten.cholesky.default,
        aten.cholesky_inverse.default,
        aten.linalg_matrix_exp.default,
        aten.cholesky_solve.default,
        aten.linalg_householder_product.default,
        aten.linalg_solve_triangular.default,
        aten.linalg_ldl_solve.default,
        aten.linalg_lu_solve.default,
        aten.ormqr.default,
        aten.geqrf.default,
        aten.linalg_cholesky_ex.default,
        aten.linalg_eig.default,
        aten.linalg_inv_ex.default,
        aten.triangular_solve.default,
        aten._linalg_det.default,
        aten.linalg_ldl_factor_ex.default,
        aten.linalg_lu.default,
        aten.linalg_lu_factor_ex.default,
        aten.lu_unpack.default,
        aten._linalg_slogdet.default,
        aten._linalg_solve_ex.default,
        aten._linalg_check_errors.default,
    ],
    schema_info=RuntimeSchemaInfo(1),
)
def linalg_batch_dim_strategy(
    op: torch._ops.OpOverload,
    args_schema: tuple[Any, ...],
    kwargs_schema: dict[str, Any],
) -> list[list[Placement | _ShardingPlaceholder]]:
    """Single-dim sharding rules shared by the batched linalg ops.

    Every op gets batch-dim sharding (all dims of the first arg except the
    last two). A few solver-style ops get extra rules on top, in which a
    ``Partial`` right-hand side yields a ``Partial`` result while the other
    tensors stay replicated.

    Raises:
        AssertionError: if the first arg is not a ``TensorMeta`` or ``op`` has
            no entry in ``_LINALG_NUM_PLACEMENTS``.
    """
    ndim = _get_ndim(args_schema[0])
    n_placements = _LINALG_NUM_PLACEMENTS.get(op)
    if n_placements is None:
        raise AssertionError(f"Expected op in _LINALG_NUM_PLACEMENTS, got {op}")

    strategies = _linalg_batch_dim_strategies(ndim, n_placements=n_placements)

    # (first, second, third) tensor inputs -> result, where only the third
    # input may be partial: linalg_lu_solve(LU, pivots, B),
    # linalg_ldl_solve(LD, pivots, B) and ormqr(a, tau, C).
    partial_third_input_ops = (
        aten.linalg_lu_solve.default,
        aten.linalg_ldl_solve.default,
        aten.ormqr.default,
    )

    if op == aten.linalg_solve_triangular.default:
        # solve_triangular(A, B) -> result: linear in B
        strategies += [
            [Partial(), Replicate(), Partial()],
            [Partial("avg"), Replicate(), Partial("avg")],
        ]
        # A replicated, B sharded on batch dims (B may have more batch dims than A)
        num_b_batch_dims = _get_ndim(args_schema[1]) - 2
        strategies += [
            [_ShardingPlaceholder(dim), Replicate(), _ShardingPlaceholder(dim)]
            for dim in range(num_b_batch_dims)
        ]
    elif op == aten.cholesky_solve.default:
        # cholesky_solve(B, A) -> result  (B is arg0)
        strategies.append([Partial(), Partial(), Replicate()])
    elif op in partial_third_input_ops:
        strategies.append([Partial(), Replicate(), Replicate(), Partial()])
    elif op == aten._linalg_solve_ex.default:
        # _linalg_solve_ex(A, B) -> (result, LU, pivots, info)
        replicated_slots = [Replicate() for _ in range(4)]  # LU, pivots, info, A
        strategies.append([Partial(), *replicated_slots, Partial()])

    return strategies


# linalg_pinv takes optional tensor kwargs atol and rtol (scalar tensors when given).
# Schema: (Tensor self, *, Tensor? atol=None, Tensor? rtol=None, bool hermitian) -> Tensor
# With atol/rtol left as None there is a single tensor input; each one that is
# present adds one more input placement.
@register_single_dim_strategy(
    [aten.linalg_pinv.atol_rtol_tensor],
    schema_info=RuntimeSchemaInfo(1),
)
def linalg_pinv_strategy(
    op: torch._ops.OpOverload,
    args_schema: tuple[Any, ...],
    kwargs_schema: dict[str, Any],
) -> list[list[Placement | _ShardingPlaceholder]]:
    """Single-dim sharding rules for ``linalg_pinv.atol_rtol_tensor``.

    One rule per batch dim (every dim of the input except the last two), laid
    out as ``[output, input, atol?, rtol?]``: output and input shard on the
    batch dim, and each tolerance tensor that is present is replicated.
    """
    ndim = _get_ndim(args_schema[0])
    # Only tolerances passed as tensors occupy a placement slot.
    num_tolerance_tensors = sum(
        isinstance(kwargs_schema.get(key), TensorMeta) for key in ("atol", "rtol")
    )
    return [
        [_ShardingPlaceholder(dim), _ShardingPlaceholder(dim)]
        # atol / rtol are scalar tensors, so they are always replicated.
        + [Replicate()] * num_tolerance_tensors
        for dim in range(ndim - 2)
    ]


# linalg_cross is pointwise on every dim except the cross-product dim (which
# must be size 3).  Shard on any other dim.
@register_single_dim_strategy(
    [aten.linalg_cross.default],
    # The set of shardable dims depends on the `dim` kwarg, so it has to be
    # part of the hashed schema; otherwise a cached result computed for one
    # `dim` could be reused for a call with a different `dim`.
    schema_info=RuntimeSchemaInfo(1, static_kwargkey=["dim"]),
)
def linalg_cross_strategy(
    op: torch._ops.OpOverload,
    args_schema: tuple[Any, ...],
    kwargs_schema: dict[str, Any],
) -> list[list[Placement | _ShardingPlaceholder]]:
    """Single-dim sharding rules for ``linalg_cross``.

    One rule per dim of the first input other than the cross-product dim,
    each laid out as ``[output, self, other]`` with all three sharded on that
    dim. The cross-product dim is read from the ``dim`` kwarg (default -1)
    and wrapped into ``[0, ndim)``.

    Raises:
        AssertionError: if the first arg is not a ``TensorMeta``.
    """
    ndim = _get_ndim(args_schema[0])
    cross_dim = kwargs_schema.get("dim", -1) % ndim
    strategies: list[list[Placement | _ShardingPlaceholder]] = []
    for dim in range(ndim):
        if dim == cross_dim:
            continue
        strategies.append([_ShardingPlaceholder(dim)] * 3)
    return strategies


# ---------------------------------------------------------------------------
# Interpolation / upsample / pooling ops
#
# These ops operate on spatial dims and are safely shardable on batch (dim 0)
# and channel (dim 1). grid_sampler is batch-only because the grid tensor has
# no channel dimension.
# ---------------------------------------------------------------------------


@register_single_dim_strategy(
    [
        aten.upsample_nearest1d.default,
        aten.upsample_nearest2d.default,
        aten.upsample_nearest3d.default,
        aten._upsample_nearest_exact1d.default,
        aten._upsample_nearest_exact2d.default,
        aten._upsample_nearest_exact3d.default,
        aten._upsample_bilinear2d_aa.default,
        aten.upsample_bicubic2d.default,
        aten.upsample_bilinear2d.default,
        aten.upsample_linear1d.default,
        aten.upsample_trilinear3d.default,
    ],
    schema_info=RuntimeSchemaInfo(1),
)
def interp_upsample_1out_1in_strategy(
    op: torch._ops.OpOverload,
    args_schema: tuple[Any, ...],
    kwargs_schema: dict[str, Any],
) -> list[list[Placement | _ShardingPlaceholder]]:
    """Single-dim sharding rules for the one-input, one-output upsample ops.

    Each rule is ``[output, input]``. The same fixed set is returned for
    every op: shard on batch (dim 0), shard on channel (dim 1), and the
    ``Partial("sum")`` / ``Partial("avg")`` pass-throughs.
    """
    shard_rules: list[list[Placement | _ShardingPlaceholder]] = [
        [_ShardingPlaceholder(dim)] * 2 for dim in (0, 1)
    ]
    # Upsampling is a linear transformation, so a pending sum/avg reduction
    # on the input carries over to the output.
    partial_rules: list[list[Placement | _ShardingPlaceholder]] = [
        [Partial(reduce_op), Partial(reduce_op)] for reduce_op in ("sum", "avg")
    ]
    return shard_rules + partial_rules


@register_single_dim_strategy(
    [
        aten.max_unpool2d.default,
        aten.max_unpool3d.default,
        aten._adaptive_avg_pool2d_backward.default,
    ],
    schema_info=RuntimeSchemaInfo(1),
)
def interp_pool_1out_2in_strategy(
    op: torch._ops.OpOverload,
    args_schema: tuple[Any, ...],
    kwargs_schema: dict[str, Any],
) -> list[list[Placement | _ShardingPlaceholder]]:
    """Single-dim sharding rules for pooling-style ops with 1 output and 2 tensor inputs.

    Each rule has 3 placements (output + 2 inputs), all sharded on the same
    dim. Dim 0 is always offered. Dim 1 is offered only for batched inputs
    ``(N, C, *spatial)``; for unbatched inputs ``(C, *spatial)`` dim 1 is a
    spatial dim and must not be sharded.

    Raises:
        AssertionError: never raised here directly; the first arg is assumed
            to be the ``TensorMeta`` of the first tensor input.
    """
    # 3 spatial dims for the 3d unpool op, 2 for the two 2d ops.
    spatial_rank = 3 if op == aten.max_unpool3d.default else 2
    first_input_meta = cast(TensorMeta, args_schema[0])
    is_batched = len(first_input_meta.shape) >= spatial_rank + 2

    strategies: list[list[Placement | _ShardingPlaceholder]] = [
        [_ShardingPlaceholder(0)] * 3,
    ]
    if is_batched:
        # Only with a leading batch dim is dim 1 the channel dim.
        strategies.append([_ShardingPlaceholder(1)] * 3)
    return strategies


@register_single_dim_strategy(
    [aten.max_pool2d_with_indices_backward.default],
    schema_info=RuntimeSchemaInfo(1),
)
def pool_backward_strategy(
    op: torch._ops.OpOverload,
    args_schema: tuple[Any, ...],
    kwargs_schema: dict[str, Any],
) -> list[list[Placement | _ShardingPlaceholder]]:
    """
    Single-dim sharding rules for ``max_pool2d_with_indices_backward``.

    Signature: (grad_output, self, ..., indices) -> grad_input, i.e. one
    output and three tensor inputs, so each rule has 4 placements ordered as
    [output, grad_output, self, indices].

    Rules returned, in order:
      * shard every tensor on dim 0
      * shard every tensor on dim 1, only when the first argument has at
        least 4 dims, i.e. the batched (N, C, H, W) layout
      * Partial("sum") on output and grad_output, the rest replicated
      * Partial("avg") on output and grad_output, the rest replicated
    """
    first_arg_meta = cast(TensorMeta, args_schema[0])
    has_batch_dim = len(first_arg_meta.shape) >= 4  # batched: (N, C, H, W)

    rules: list[list[Placement | _ShardingPlaceholder]] = []
    shardable_dims = (0, 1) if has_batch_dim else (0,)
    for tensor_dim in shardable_dims:
        rules.append([_ShardingPlaceholder(tensor_dim)] * 4)

    # The backward is linear in grad_output, so a pending sum/avg passes
    # through to grad_input. indices hold integer positions and cannot be
    # reduced, and self only contributes its shape, so both stay replicated.
    replicated = Replicate()
    for pending in (Partial("sum"), Partial("avg")):
        rules.append([pending, pending, replicated, replicated])

    return rules


@register_single_dim_strategy(
    [aten.grid_sampler_2d.default, aten.grid_sampler_3d.default],
    schema_info=RuntimeSchemaInfo(1),
)
def grid_sampler_strategy(
    op: torch._ops.OpOverload,
    args_schema: tuple[Any, ...],
    kwargs_schema: dict[str, Any],
) -> list[list[Placement | _ShardingPlaceholder]]:
    """
    Single-dim sharding rules for ``grid_sampler_{2,3}d``.

    Shapes: input[N, C, ...], grid[N, ..., {2,3}] -> output[N, C, ...].
    Each rule has 3 placements: the output followed by input and grid.

    Rules returned, in order:
      * shard all three tensors on batch (dim 0); the grid has no channel
        dim, so channel sharding is not offered
      * Partial("sum") on output and input with the grid replicated
      * Partial("avg") on output and input with the grid replicated

    The Partial rules are valid because the op is linear in ``input``.
    """
    batch_shard = _ShardingPlaceholder(0)
    rules: list[list[Placement | _ShardingPlaceholder]] = [
        [batch_shard, batch_shard, batch_shard],
    ]
    for reduce_op in ("sum", "avg"):
        # Pending reduction flows from input to output; the grid must be whole.
        rules.append([Partial(reduce_op), Partial(reduce_op), Replicate()])
    return rules


@register_single_dim_strategy(
    [aten.grid_sampler_2d_backward.default, aten.grid_sampler_3d_backward.default],
    schema_info=RuntimeSchemaInfo(1),
)
def grid_sampler_backward_strategy(
    op: torch._ops.OpOverload,
    args_schema: tuple[Any, ...],
    kwargs_schema: dict[str, Any],
) -> list[list[Placement | _ShardingPlaceholder]]:
    """
    Single-dim sharding rule for ``grid_sampler_{2,3}d_backward``.

    The op has 2 outputs (grad_input, grad_grid) and 3 tensor inputs, so the
    single rule carries 5 placements. Only batch (dim 0) sharding is offered,
    applied uniformly to every tensor.
    """
    batch_shard = _ShardingPlaceholder(0)
    batch_only_rule: list[Placement | _ShardingPlaceholder] = [batch_shard] * 5
    return [batch_only_rule]
