# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
from collections.abc import Callable, Sequence, Sized
from typing import cast

import torch
from torch._ops import OpOverload
from torch._prims_common import IntLike
from torch.distributed.tensor._dtensor_spec import DTensorSpec
from torch.distributed.tensor._op_schema import (
    ArgsType,
    KwargsType,
    OpSchema,
    OpSpec,
    OpStrategy,
    PlacementList,
    RuntimeSchemaInfo,
    StrategyType,
    TensorMeta,
    TupleStrategy,
)
from torch.distributed.tensor._ops.single_dim_strategy import (
    _ShardingPlaceholder,
    register_single_dim_strategy,
)
from torch.distributed.tensor._ops.utils import (
    expand_to_full_mesh_op_strategy,
    generate_redistribute_costs,
    is_tensor_dim_sharded,
    is_tensor_partial,
    normalize_dim,
    register_op_strategy,
    shift_shard_dims_after_insert,
    shift_shard_dims_after_remove,
)
from torch.distributed.tensor.placement_types import (
    _is_shard_like,
    _MaskPartial,
    Partial,
    Placement,
    Replicate,
    Shard,
)
from torch.fx.experimental.symbolic_shapes import statically_known_true


# Short aliases for the operator namespaces used by the registrations in this file.
aten = torch.ops.aten
prims = torch.ops.prims


def propagate_single_input_strategy(op_schema: OpSchema) -> StrategyType:
    """Mirror the lone tensor input's candidate placements onto the output.

    Each candidate placement of the single ``OpStrategy`` argument becomes one
    ``OpSpec``. Its input and output both carry that placement and tensor meta
    on the input's mesh. It is costed by redistributing the input to it.

    Note: this may be wasted work. It should be equivalent to returning the
    input strategy itself (unless producing fresh spec objects matters).
    """
    tensor_arg_count = sum(
        1 for arg in op_schema.args_schema if isinstance(arg, OpStrategy)
    )
    if tensor_arg_count != 1:
        raise AssertionError(
            "propagate_single_input_strategy only works for single-tensor-input ops"
        )
    source = op_schema.args_schema[0]
    if not isinstance(source, OpStrategy):
        raise AssertionError(f"Expected OpStrategy, got {type(source)}")

    def _respec(candidate: DTensorSpec) -> DTensorSpec:
        # fresh spec on the source mesh with the candidate's layout and meta
        return DTensorSpec(
            mesh=source.mesh,
            placements=candidate.placements,
            tensor_meta=candidate.tensor_meta,
        )

    mirrored: list[OpSpec] = []
    for candidate_strategy in source.strategies:
        candidate = candidate_strategy.output_spec
        mirrored.append(
            OpSpec(
                output_specs=_respec(candidate),
                input_specs=[_respec(candidate)],
                redistribute_cost=[generate_redistribute_costs(source, candidate)],
            )
        )
    return OpStrategy(mirrored)


# These ops (copies, aliases/views, and the in-place fill_/zero_) all use the
# 1:1 mapping from propagate_single_input_strategy: each input placement
# becomes a candidate where the output takes the same placement.
register_op_strategy(
    [
        aten.clone.default,
        aten.contiguous.default,
        aten.detach.default,
        aten.alias.default,
        aten.fill_.Scalar,
        aten.view.dtype,
        aten.zero_.default,
        prims.view_of.default,
    ]
)(propagate_single_input_strategy)


def _partial_needs_reduce_for_dtype_cast(
    reduce_op: str,
    src_dtype: torch.dtype,
    target_dtype: torch.dtype | None,
) -> bool:
    """Return True when reduce_op does not commute with the dtype cast."""
    if target_dtype is None or src_dtype == target_dtype:
        return False
    if target_dtype == torch.bool:
        return True
    if reduce_op in ("max", "min"):
        return False
    return src_dtype.is_floating_point and not target_dtype.is_floating_point


@register_single_dim_strategy(
    aten._to_copy.default,
    schema_info=RuntimeSchemaInfo(static_kwargkey=["dtype"]),
    allow_unbacked_sharding=True,
    allow_uneven_sharding=True,
)
def _to_copy_single_dim_strategy(
    op: OpOverload, args_schema: ArgsType, kwargs_schema: KwargsType
) -> list[list[Placement | _ShardingPlaceholder]]:
    """Single-mesh-dim placements for ``_to_copy``, as ``[output, input]`` pairs.

    Sharding on any tensor dim passes straight through. A ``Partial`` input
    stays ``Partial`` only for the reductions that commute with the requested
    dtype cast.
    """
    meta = cast(TensorMeta, args_schema[0])
    requested_dtype = cast(torch.dtype | None, kwargs_schema.get("dtype", None))

    shard_rules: list[list[Placement | _ShardingPlaceholder]] = [
        [_ShardingPlaceholder(d), _ShardingPlaceholder(d)]
        for d in range(len(meta.shape))
    ]
    partial_rules: list[list[Placement | _ShardingPlaceholder]] = [
        [Partial(op_name), Partial(op_name)]
        for op_name in Partial.ALL_REDUCE_OPS
        if not _partial_needs_reduce_for_dtype_cast(
            op_name, meta.dtype, requested_dtype
        )
    ]
    return shard_rules + partial_rules


@register_op_strategy(
    [
        aten.equal.default,
        aten.is_same_size.default,
    ]
)
def equal_strategy(op_schema: OpSchema) -> StrategyType:
    """Strategy for ops that compare two tensors (``equal``, ``is_same_size``).

    Both operands must share a layout. We follow whichever operand has more
    shards (ties go to ``self``). ``is_same_size`` is grouped here because the
    two ops share the same strategy in theory.
    """
    mesh = op_schema.get_mesh_from_args()
    lhs, rhs = op_schema.args_schema
    if not isinstance(lhs, OpStrategy):
        raise AssertionError(f"Expected OpStrategy, got {type(lhs)}")
    if not isinstance(rhs, OpStrategy):
        raise AssertionError(f"Expected OpStrategy, got {type(rhs)}")

    # A 0-dim (scalar) operand forces both sides to be fully replicated.
    if lhs.ndim == 0 or rhs.ndim == 0:
        fully_replicated = DTensorSpec(
            mesh=mesh,
            placements=tuple(Replicate() for _ in range(mesh.ndim)),
        )
        return OpStrategy([OpSpec(output_specs=fully_replicated)])

    if lhs.max_num_shards() >= rhs.max_num_shards():
        followed = lhs
    else:
        followed = rhs

    candidates: list[OpSpec] = []
    for candidate_strategy in followed.strategies:
        spec = candidate_strategy.output_spec
        if not is_tensor_partial(spec):
            candidates.append(OpSpec(spec))
            continue
        # Comparing local shards with pending reductions would be invalid,
        # so every Partial is resolved to Replicate first.
        resolved = DTensorSpec(
            mesh=mesh,
            placements=tuple(
                Replicate() if isinstance(p, Partial) else p
                for p in spec.placements
            ),
        )
        candidates.append(OpSpec(output_specs=resolved))
    return OpStrategy(candidates)


# empty_like reuses the 1:1 propagate_single_input_strategy. The
# RuntimeSchemaInfo marks args from index 1 and the ``dtype`` kwarg as static
# schema info.
register_op_strategy(
    aten.empty_like.default, schema_info=RuntimeSchemaInfo(1, ["dtype"])
)(propagate_single_input_strategy)


@register_op_strategy(
    [
        aten.ones_like.default,
        aten.rand_like.default,
        aten.randn_like.default,
        aten.zeros_like.default,
    ],
    schema_info=RuntimeSchemaInfo(1, ["dtype"]),
)
@register_op_strategy(
    [aten.full_like.default],
    schema_info=RuntimeSchemaInfo(2, ["dtype"]),
)
@register_op_strategy(
    [
        aten.randint_like.default,
        aten.randint_like.low_dtype,
        aten.randint_like.low_dtype_out,
    ],
    schema_info=RuntimeSchemaInfo(3, ["dtype"]),
)
def create_like_strategy(op_schema: OpSchema) -> StrategyType:
    """Strategy for ``*_like`` factories (same shape as input, fresh contents).

    The output keeps each input candidate's sharding and tensor meta. The new
    contents do not depend on the input's values, so any ``Partial`` becomes
    ``Replicate`` on the output. The input is consumed in its current placement.
    """
    source = op_schema.args_schema[0]
    if not isinstance(source, OpStrategy):
        raise AssertionError(f"Expected OpStrategy, got {type(source)}")

    candidates: list[OpSpec] = []
    for candidate_strategy in source.strategies:
        in_spec = candidate_strategy.output_spec
        out_placements = tuple(
            Replicate() if isinstance(p, Partial) else p
            for p in in_spec.placements
        )
        out_spec = DTensorSpec(
            mesh=source.mesh,
            placements=out_placements,
            tensor_meta=in_spec.tensor_meta,
        )
        candidates.append(
            OpSpec(
                output_specs=out_spec,
                input_specs=(in_spec,),
                redistribute_cost=[generate_redistribute_costs(source, in_spec)],
            )
        )
    return OpStrategy(candidates)


@register_op_strategy(
    [
        aten.new_empty.default,
        aten.new_full.default,
        aten.new_ones.default,
        aten.new_zeros.default,
        aten.new_empty_strided.default,
    ],
    schema_info=RuntimeSchemaInfo(1, ["dtype"]),
)
def new_factory_strategy(op_schema: OpSchema) -> StrategyType:
    """Strategy for ``new_*`` factory ops.

    For every candidate placement of ``self`` we offer:
    1. a fully replicated output (zero redistribute cost), and
    2. when the requested size equals ``self``'s shape and ``self`` is sharded,
       an output with ``self``'s placement. Its slightly negative cost nudges
       the new tensor toward the input's layout.
    """
    self_strategy = op_schema.args_schema[0]
    if not isinstance(self_strategy, OpStrategy):
        raise AssertionError(f"Expected OpStrategy, got {type(self_strategy)}")

    mesh = self_strategy.mesh
    self_shape = self_strategy.shape
    requested_size = op_schema.args_schema[1]
    if not isinstance(requested_size, list):
        raise AssertionError(f"Expected list, got {type(requested_size)}")

    n_candidates = len(self_strategy.strategies)
    candidates: list[OpSpec] = []
    for candidate_strategy in self_strategy.strategies:
        in_spec = candidate_strategy.output_spec
        replicated_out = DTensorSpec(mesh, tuple([Replicate()] * mesh.ndim))
        candidates.append(
            OpSpec(
                output_specs=replicated_out,
                input_specs=(in_spec,),
                redistribute_cost=[[0.0] * n_candidates],
            )
        )
        if tuple(self_shape) == tuple(requested_size) and in_spec.is_sharded():
            candidates.append(
                OpSpec(
                    output_specs=in_spec,
                    input_specs=(in_spec,),
                    redistribute_cost=[[-0.1] * n_candidates],
                )
            )
    return OpStrategy(candidates)


@register_single_dim_strategy(aten.bucketize.Tensor)
def bucketize_single_dim_strategy(
    op: OpOverload, args_schema: ArgsType, kwargs_schema: KwargsType
) -> list[list[Placement | _ShardingPlaceholder]]:
    """Single-mesh-dim placements for bucketize, as ``[output, input, boundaries]``.

    * Input and output sharded on the same tensor dim, boundaries replicated.
    * Boundaries sharded on dim 0, input replicated, output ``Partial("sum")``.
      Each rank counts how many of its local boundaries each element passes,
      and summing those counts across ranks gives the global bucket index.
    * ``Partial("max")`` / ``Partial("min")`` input with replicated boundaries.
      Bucketize is monotonically non-decreasing in its input, so max/min of
      local indices equals the index of the max/min of the reduced input.
    """
    input_meta, _boundaries_meta = args_schema
    if not isinstance(input_meta, TensorMeta):
        raise AssertionError(f"Expected TensorMeta, got {type(input_meta)}")

    rules: list[list[Placement | _ShardingPlaceholder]] = [
        [_ShardingPlaceholder(d), _ShardingPlaceholder(d), Replicate()]
        for d in range(len(input_meta.shape))
    ]
    rules.append([Partial("sum"), Replicate(), _ShardingPlaceholder(0)])
    rules.extend(
        [Partial(op_name), Partial(op_name), Replicate()]
        for op_name in ("max", "min")
    )
    return rules


@register_op_strategy(aten.select.int, schema_info=RuntimeSchemaInfo(1))
def select_int_strategy(op_schema: OpSchema) -> StrategyType:
    """Strategy for ``aten.select.int(self, dim, index)``.

    Input: if a candidate is sharded on the selected dim, that sharding is
    replaced by Replicate; otherwise the candidate is used unchanged.
    Output: the selected dim disappears, so shard dims are renumbered with
    ``shift_shard_dims_after_remove``. No shard can sit on the selected dim
    after the unsharding step.
    """
    source = op_schema.args_schema[0]
    if not isinstance(source, OpStrategy):
        raise AssertionError(f"Expected OpStrategy, got {type(source)}")
    if len(op_schema.args_schema) != 3:
        raise AssertionError(f"Expected 3 args, got {len(op_schema.args_schema)}")
    raw_dim = cast(int, op_schema.args_schema[1])
    raw_index = cast(int, op_schema.args_schema[2])
    shape, rank = source.shape, source.ndim
    sel_dim = normalize_dim(raw_dim, rank)
    # the normalized index is not used below; the call is kept as-is
    normalize_dim(raw_index, shape[sel_dim])

    candidates: list[OpSpec] = []
    for candidate_strategy in source.strategies:
        spec = candidate_strategy.output_spec
        if is_tensor_dim_sharded(spec, dim=sel_dim):
            in_spec = DTensorSpec(
                spec.mesh, unshard_tensor_dim(spec.placements, dim=sel_dim)
            )
        else:
            in_spec = spec

        if in_spec.is_sharded():
            out_spec = DTensorSpec(
                spec.mesh,
                placements=tuple(
                    shift_shard_dims_after_remove(in_spec.placements, sel_dim)
                ),
            )
        else:
            out_spec = in_spec

        candidates.append(OpSpec(output_specs=out_spec, input_specs=(in_spec,)))
    return OpStrategy(candidates)


@register_op_strategy(
    aten.select_backward.default,
    schema_info=RuntimeSchemaInfo(1),
)
def select_backward_strategy(op_schema: OpSchema) -> OpStrategy:
    """Propagate grad_output placements to grad_input for ``select_backward``.

    grad_input has one more dim than grad_output: the selected ``dim`` is
    re-inserted. Each grad_output placement is renumbered with
    ``shift_shard_dims_after_insert`` at that (normalized) dim.

    Raises:
        AssertionError: if grad_output is not an ``OpStrategy`` or ``dim`` is
            not an int.
    """
    # func: select_backward(Tensor grad_output, SymInt[] input_sizes, int dim, SymInt index) -> Tensor
    args_schema = op_schema.args_schema
    input_strategy, dim = args_schema[0], args_schema[2]
    if not isinstance(input_strategy, OpStrategy):
        raise AssertionError(f"Expected OpStrategy, got {input_strategy}")
    if not isinstance(dim, int):
        raise AssertionError(f"Expected int, got {type(dim)}")
    # ``dim`` indexes grad_input (rank = grad_output rank + 1) and may be
    # negative as supplied by the caller. Normalize it so the shard-dim shift
    # compares against a non-negative insertion point.
    insert_dim = normalize_dim(dim, input_strategy.ndim + 1)
    output_strategies: list[OpSpec] = []
    for placement_strategy in input_strategy.strategies:
        input_spec = placement_strategy.output_spec
        # NOTE: shard_dim is guaranteed to exist because
        # grad_input has one more dim than grad_output
        output_placements = shift_shard_dims_after_insert(
            input_spec.placements, insert_dim
        )
        output_specs = DTensorSpec(input_spec.mesh, tuple(output_placements))
        output_strategies.append(
            OpSpec(output_specs=output_specs, input_specs=(input_spec,))
        )
    return OpStrategy(output_strategies)


@register_op_strategy(aten.slice.Tensor, schema_info=RuntimeSchemaInfo(1))
def gen_slice_strategy(op_schema: OpSchema) -> StrategyType:
    """Forward all shardings except the slice dimension.

    Positional args are ``(self, dim, start, end, step)``. Missing trailing args
    are filled from ``defaults``: ``dim=0``, ``start=None``, ``end=None``,
    ``step=1``.

    * An input candidate that is not sharded on the slice dim is forwarded
      unchanged, with zero redistribute cost.
    * A candidate sharded on the slice dim is also forwarded when the slice is
      statically known to be a no-op (``start == 0``, ``end == size``,
      ``step == 1``).
    * If no candidate survives, each candidate is unsharded on the slice dim
      and used as the output spec, costed against redistributing the input.

    Raises:
        AssertionError: if ``self`` is not an ``OpStrategy``, ``dim`` is not an
            int, or ``start``/``end``/``step`` are not ``IntLike``.
    """
    # pad the missing trailing positional args with the schema defaults
    defaults = (None, 0, None, None, 1)
    input_strategy, dim, start, end, step = (
        op_schema.args_schema + defaults[len(op_schema.args_schema) :]
    )
    if not isinstance(input_strategy, OpStrategy):
        raise AssertionError(f"Expected OpStrategy, got {type(input_strategy)}")

    mesh = input_strategy.mesh
    input_shape = input_strategy.shape
    input_ndim = input_strategy.ndim
    if not isinstance(dim, int):
        raise AssertionError(f"Expected int, got {type(dim)}")
    if start is None:
        start = 0
    # Clamp an end that is absent or statically past the dim size to the size.
    # NOTE: input_shape[dim] indexes with the un-normalized dim, so a negative
    # dim relies on sequence negative indexing.
    if end is None or statically_known_true(end > input_shape[dim]):
        end = input_shape[dim]
    if not isinstance(start, IntLike):
        raise AssertionError(f"Expected IntLike, got {type(start)}")
    if not isinstance(end, IntLike):
        raise AssertionError(f"Expected IntLike, got {type(end)}")
    if not isinstance(step, IntLike):
        raise AssertionError(f"Expected IntLike, got {type(step)}")

    # normalize args
    slice_dim = normalize_dim(dim, input_ndim)  # type: ignore[arg-type]
    start = normalize_dim(start, input_shape[dim])  # type: ignore[arg-type]
    end = normalize_dim(end, input_shape[dim])  # type: ignore[arg-type]

    # True only when the slice provably keeps the whole dim. statically_known_true
    # returns False for anything it cannot prove (e.g. unresolved symbolic sizes).
    statically_redundant_slice = (
        statically_known_true(start == 0)
        and statically_known_true(end == input_shape[dim])
        and statically_known_true(step == 1)
    )

    slice_strategy = OpStrategy([])

    for arg_strategy in input_strategy.strategies:
        arg_spec = arg_strategy.output_spec
        if (
            not is_tensor_dim_sharded(arg_spec, dim=slice_dim)
            or statically_redundant_slice
        ):
            # only add the strategy if the slice dim is not sharded, or if the
            # slice is a provable no-op (then any sharding can pass through)
            out_spec = DTensorSpec(mesh, arg_spec.placements)
            slice_strategy.strategies.append(
                OpSpec(
                    output_specs=out_spec,
                    input_specs=(arg_spec,),
                    redistribute_cost=[[0.0] * len(input_strategy.strategies)],
                )
            )
    if not slice_strategy.strategies:
        # if all strategies are filtered out, unsharding all specs on slice dim
        # of the input strategy, and use that as the op strategy
        for arg_strategy in input_strategy.strategies:
            arg_spec = arg_strategy.output_spec
            unshard_spec = DTensorSpec(
                mesh, unshard_tensor_dim(arg_spec.placements, dim=slice_dim)
            )
            # NOTE(review): input_specs is not set here; presumably the input is
            # expected to match output_spec (hence the cost below) — confirm.
            slice_strategy.strategies.append(
                OpSpec(
                    output_specs=unshard_spec,
                    redistribute_cost=[
                        generate_redistribute_costs(input_strategy, unshard_spec)
                    ],
                )
            )
    return slice_strategy


@register_op_strategy(
    aten.slice_backward.default,
    schema_info=RuntimeSchemaInfo(1),
)
def slice_backward_rules(op_schema: OpSchema) -> OpStrategy:
    """Strategy for ``slice_backward``.

    grad_output and grad_input have the same rank. For each grad_output
    candidate, any sharding on the sliced ``dim`` is replaced by Replicate and
    every other placement is kept. The candidate is costed against
    redistributing grad_output to that spec.

    Raises:
        AssertionError: if grad_output is not an ``OpStrategy`` or ``dim`` is
            not an int.
    """
    # func: slice_backward(Tensor grad_output, SymInt[] input_sizes, int dim, SymInt start, SymInt end, SymInt step) -> Tensor
    args_schema = op_schema.args_schema
    input_strategy, dim = args_schema[0], args_schema[2]
    if not isinstance(input_strategy, OpStrategy):
        raise AssertionError(f"Expected OpStrategy, got {input_strategy}")
    if not isinstance(dim, int):
        raise AssertionError(f"Expected int, got {type(dim)}")
    # ``dim`` may arrive negative (e.g. from ``x[..., a:b]``). Shard dims are
    # non-negative, so normalize before comparing or a sharded slice dim would
    # never be detected and would wrongly be kept.
    slice_dim = normalize_dim(dim, input_strategy.ndim)
    output_strategies: list[OpSpec] = []
    for placement_strategy in input_strategy.strategies:
        output_spec = placement_strategy.output_spec
        # Redistribute to replicate only if the dim is sharded and matches the slice dim
        new_spec = DTensorSpec(
            output_spec.mesh,
            unshard_tensor_dim(output_spec.placements, dim=slice_dim),
        )
        redistribute_cost = [generate_redistribute_costs(input_strategy, new_spec)]
        output_strategies.append(
            OpSpec(output_specs=new_spec, redistribute_cost=redistribute_cost)
        )
    return OpStrategy(output_strategies)


def unshard_tensor_dim(
    placements: Sequence[Placement], dim: int
) -> tuple[Placement, ...]:
    """Disallow the given tensor dimension to be sharded.

    Every shard-like placement on tensor dim ``dim`` becomes ``Replicate()``.
    All other placements (including ``Partial``) are returned unchanged.
    """
    result: list[Placement] = []
    for placement in placements:
        shards_dim = _is_shard_like(placement) and placement.dim == dim
        result.append(Replicate() if shards_dim else placement)
    return tuple(result)


def replicate_tensor_dim(
    placements: Sequence[Placement], dim: int
) -> tuple[Placement, ...]:
    """Force the given tensor dimension to be replicated.

    Unlike ``unshard_tensor_dim``, every ``Partial`` placement is also turned
    into ``Replicate()``, alongside any shard-like placement on ``dim``.
    """

    def _resolve(placement: Placement) -> Placement:
        if placement.is_partial():
            return Replicate()
        if _is_shard_like(placement) and placement.dim == dim:
            return Replicate()
        return placement

    return tuple(map(_resolve, placements))


@register_op_strategy(aten.slice_scatter.default, schema_info=RuntimeSchemaInfo(2))
def gen_slice_scatter_strategy(op_schema: OpSchema) -> StrategyType:
    """Strategy for ``slice_scatter(self, src, dim=0, ...)``.

    Constraints: self and src have the same rank, the same size on every
    non-scatter dim, and src's size on ``dim`` equals the slice length.

    src is asked to follow self's sharding. Self candidates that are sharded
    on the scatter dim, or partial, are skipped. If that leaves nothing, each
    candidate is forced to Replicate on the scatter dim (and on any partial
    mesh dim) for self, src and output.
    TODO: ideally the output would be resharded afterwards to keep self's
    sharding.
    """
    mesh = op_schema.get_mesh_from_args()
    self_strategy = op_schema.args_schema[0]
    src_strategy = op_schema.args_schema[1]
    if not isinstance(self_strategy, OpStrategy):
        raise AssertionError(f"Expected OpStrategy, got {type(self_strategy)}")
    if not isinstance(src_strategy, OpStrategy):
        raise AssertionError(f"Expected OpStrategy, got {type(src_strategy)}")
    rank = self_strategy.ndim
    if len(op_schema.args_schema) > 2:
        raw_dim = cast(int, op_schema.args_schema[2])
    else:
        raw_dim = 0
    scatter_dim = normalize_dim(raw_dim, rank)

    def _candidate(
        out_spec: DTensorSpec, self_spec: DTensorSpec, src_spec: DTensorSpec
    ) -> OpSpec:
        return OpSpec(
            output_specs=out_spec,
            input_specs=(self_spec, src_spec),
            redistribute_cost=[
                generate_redistribute_costs(self_strategy, self_spec),
                generate_redistribute_costs(src_strategy, src_spec),
            ],
        )

    candidates: list[OpSpec] = []
    for candidate_strategy in self_strategy.strategies:
        spec = candidate_strategy.output_spec
        if is_tensor_dim_sharded(spec, dim=scatter_dim) or is_tensor_partial(spec):
            continue
        # TODO: need to relax the constraint to src
        candidates.append(
            _candidate(
                spec,
                DTensorSpec(mesh, spec.placements, spec.tensor_meta),
                DTensorSpec(mesh, spec.placements),
            )
        )

    if not candidates:
        for candidate_strategy in self_strategy.strategies:
            spec = candidate_strategy.output_spec
            forced = replicate_tensor_dim(spec.placements, dim=scatter_dim)
            self_spec = DTensorSpec(mesh, forced)
            candidates.append(_candidate(self_spec, self_spec, DTensorSpec(mesh, forced)))

    return OpStrategy(candidates)


@register_single_dim_strategy(
    [aten.select_scatter.default],
    schema_info=RuntimeSchemaInfo(1),
)
def select_scatter_single_dim_strategy(
    op: OpOverload,
    args_schema: ArgsType,
    kwargs_schema: KwargsType,
) -> list[list[Placement | _ShardingPlaceholder]]:
    """Single-mesh-dim placements for ``select_scatter``, as ``[output, self, src]``.

    Any tensor dim other than the selected one may be sharded on output and
    self. src lacks the selected dim, so src's matching dim index drops by one
    for dims after it.
    """
    self_meta = args_schema[0]
    if not isinstance(self_meta, TensorMeta):
        raise AssertionError(f"Expected TensorMeta, got {type(self_meta)}")
    rank = len(self_meta.shape)
    removed = normalize_dim(cast(int, args_schema[2]), rank)
    rules: list[list[Placement | _ShardingPlaceholder]] = [
        [
            _ShardingPlaceholder(d),
            _ShardingPlaceholder(d),
            _ShardingPlaceholder(d - 1 if d > removed else d),
        ]
        for d in range(rank)
        if d != removed
    ]
    return rules


@register_single_dim_strategy(
    [aten.diagonal_scatter.default],
    schema_info=RuntimeSchemaInfo(1),
)
def diagonal_scatter_single_dim_strategy(
    op: OpOverload,
    args_schema: ArgsType,
    kwargs_schema: KwargsType,
) -> list[list[Placement | _ShardingPlaceholder]]:
    """Single-mesh-dim placements for ``diagonal_scatter``, as ``[output, self, src]``.

    Schema: ``(self, src, offset=0, dim1=0, dim2=1)``. Any dim other than
    dim1/dim2 may be sharded on output and self. src has dim1 and dim2 removed
    (and the diagonal appended), so its matching dim index is reduced by the
    number of removed dims that come before it.
    """
    self_meta = args_schema[0]
    if not isinstance(self_meta, TensorMeta):
        raise AssertionError(f"Expected TensorMeta, got {type(self_meta)}")
    rank = len(self_meta.shape)
    first = normalize_dim(cast(int, args_schema[3]) if len(args_schema) > 3 else 0, rank)
    second = normalize_dim(cast(int, args_schema[4]) if len(args_schema) > 4 else 1, rank)
    diag_dims = (first, second)

    rules: list[list[Placement | _ShardingPlaceholder]] = []
    for d in range(rank):
        if d in diag_dims:
            continue
        shift = sum(1 for removed in diag_dims if d > removed)
        rules.append(
            [
                _ShardingPlaceholder(d),
                _ShardingPlaceholder(d),
                _ShardingPlaceholder(d - shift),
            ]
        )
    return rules


@register_op_strategy(aten._local_scalar_dense.default)
def replica_only_strategy(op_schema: OpSchema) -> StrategyType:
    """Only allow replication on the input/output.

    Returns a single candidate whose spec is ``Replicate`` on every mesh dim.
    """
    source = op_schema.args_schema[0]
    if not isinstance(source, OpStrategy):
        raise AssertionError(f"Expected OpStrategy, got {type(source)}")
    mesh = source.mesh
    fully_replicated = DTensorSpec(mesh, (Replicate(),) * mesh.ndim)
    return OpStrategy([OpSpec(fully_replicated)])


@register_op_strategy(
    [
        aten.scatter_.value,
        aten.scatter.value,
        aten.scatter_.src,
        aten.scatter.src,
    ],
    schema_info=RuntimeSchemaInfo(1),
)
def scatter_strategy(op_schema: OpSchema) -> StrategyType:
    """Replicate-only strategy for ``scatter``/``scatter_`` (``.value`` and ``.src``).

    Placement lists are ``[output, input, index]`` plus ``src`` when the
    scattered values are a tensor. Fewer than three tensor arguments means the
    value is a Python scalar (the ``.value`` overloads), so only three
    placements are listed.
    """
    mesh = op_schema.get_mesh_from_args()
    n_placements = 3 if len(op_schema.args_strategy) < 3 else 4
    all_replicate: PlacementList = [Replicate()] * n_placements

    # TODO: see if we can support input sharding pattern
    return expand_to_full_mesh_op_strategy(
        mesh,
        op_schema,
        [all_replicate],
        inplace_op=op_schema.is_inplace_op(),
    )


@register_op_strategy(aten.scatter_add.default, schema_info=RuntimeSchemaInfo(1))
def scatter_add_strategy(op_schema: OpSchema) -> StrategyType:
    """Strategy for ``scatter_add(self, dim, index, src)``.

    Placement lists are ``[output, input, index, src]``:

    * all replicated;
    * all four sharded on a dim ``d != dim`` on which input, index and src
      have the same size. scatter_add only requires ``index.size(d)`` to be
      at most the input/src sizes. When the sizes differ, equal-sized shards
      of the operands cover different global ranges on ``d`` and the local
      computation would pair the wrong elements.

    Raises:
        AssertionError: on unexpected argument types.
    """
    input_strategy = op_schema.args_schema[0]
    dim = op_schema.args_schema[1]
    index_strategy = op_schema.args_schema[2]
    src_strategy = op_schema.args_schema[3]

    if not isinstance(input_strategy, OpStrategy):
        raise AssertionError(f"Expected OpStrategy, got {type(input_strategy)}")
    if not isinstance(index_strategy, OpStrategy):
        raise AssertionError(f"Expected OpStrategy, got {type(index_strategy)}")
    if not isinstance(src_strategy, OpStrategy):
        raise AssertionError(f"Expected OpStrategy, got {type(src_strategy)}")
    if not isinstance(dim, int):
        raise AssertionError(f"Expected int, got {type(dim)}")
    dim = normalize_dim(dim, input_strategy.ndim)
    mesh = input_strategy.mesh
    input_shape = input_strategy.shape
    index_shape = index_strategy.shape
    src_shape = src_strategy.shape

    # placement list stores placements of [output, input, index, src]
    # first we always have replicate all for inputs and output
    all_replicate: PlacementList = [Replicate()] * 4
    single_mesh_dim_strategies: list[PlacementList] = [all_replicate]

    if len(input_shape) == len(index_shape) == len(src_shape):
        for d in range(len(input_shape)):
            if d != dim and input_shape[d] == index_shape[d] == src_shape[d]:
                sharding: PlacementList = [Shard(d), Shard(d), Shard(d), Shard(d)]
                single_mesh_dim_strategies.append(sharding)

    return expand_to_full_mesh_op_strategy(
        mesh, op_schema, single_mesh_dim_strategies, input_index=1
    )


@register_op_strategy(aten.gather.default, schema_info=RuntimeSchemaInfo(1))
def gather_strategy(op_schema: OpSchema) -> StrategyType:
    """Strategy for ``gather(input, dim, index)``.

    Placement lists are ``[output, input, index]``:

    * all replicated;
    * input sharded on ``dim`` while index has size 1 there: index and output
      use ``_MaskPartial`` over the input's offsets;
    * input replicated, index sharded on ``dim``: output follows index;
    * all three sharded on a dim ``d != dim`` on which input and index have
      the same size. gather only requires ``index.size(d) <= input.size(d)``.
      When the sizes differ, equal-sized shards of input and index cover
      different global ranges on ``d``, so this strategy would be wrong.
    """
    mesh = op_schema.get_mesh_from_args()
    input_strategy = cast(OpStrategy, op_schema.args_schema[0])
    dim = cast(int, op_schema.args_schema[1])
    dim = normalize_dim(dim, input_strategy.ndim)
    index_strategy = cast(OpStrategy, op_schema.args_schema[2])

    input_shape = input_strategy.shape
    index_shape = index_strategy.shape

    # placement list stores placements of [output, input, index]
    # first we always have replicate all for inputs and output
    all_replicate: PlacementList = [Replicate()] * 3
    single_mesh_dim_strategies: list[PlacementList] = [all_replicate]

    # input sharding, input sharded, index accepts mask partial, output follows index
    # this only works when the input is sharded on the gather dimension, and
    # index has size 1 on the gather dimension
    if dim < len(index_shape) and index_shape[dim] == 1:
        index_partial_placement = _MaskPartial(offset_shape=input_shape, offset_dim=dim)
        input_sharding: PlacementList = [
            index_partial_placement,
            Shard(dim),
            index_partial_placement,
        ]
        single_mesh_dim_strategies.append(input_sharding)

    # index sharding, input replicated, index sharded, output follows index
    # this only works when the sharding dimension is the gather dimension
    index_sharding: PlacementList = [Shard(dim), Replicate(), Shard(dim)]
    single_mesh_dim_strategies.append(index_sharding)

    # co-shard all operands on a non-gather dim, only when sizes match
    # (same rule as scatter_add_strategy)
    if len(input_shape) == len(index_shape):
        for d in range(len(input_shape)):
            if d != dim and input_shape[d] == index_shape[d]:
                sharding: PlacementList = [Shard(d), Shard(d), Shard(d)]
                single_mesh_dim_strategies.append(sharding)

    return expand_to_full_mesh_op_strategy(
        mesh, op_schema, single_mesh_dim_strategies, input_index=1
    )


def _derive_follow_placements_from_tuple_strategy(
    op: torch._ops.OpOverload,
    tuple_strategy: TupleStrategy,
) -> Sequence[Placement]:
    """
    Merge the output placements of every operand strategy in ``tuple_strategy``
    into one placement per mesh dim that all operands should follow.

    Used by ops such as aten.stack / aten.cat whose operands share a shape and
    are therefore expected to share a sharding. Per mesh dim the merge prefers
    Partial, then Shard, then Replicate; conflicting shards (or conflicting
    partial types) collapse to Replicate.
    """

    def merge_placement(
        cur_placement: Placement, new_placement: Placement
    ) -> Placement:
        # _StridedShard.__eq__ compares both dim and split_factor, so two
        # _StridedShard with different split_factor are not considered equal.
        if cur_placement == new_placement:
            return cur_placement
        if cur_placement.is_partial():
            # a shard wins over a partial; two different partials cannot be
            # merged; a replicate yields to the partial
            if _is_shard_like(new_placement):
                return new_placement
            return Replicate() if new_placement.is_partial() else cur_placement
        if _is_shard_like(cur_placement):
            # different shard dims fall back to replicate; otherwise keep shard
            return Replicate() if _is_shard_like(new_placement) else cur_placement
        # current is replicate: adopt whatever the new operand has
        return new_placement

    expected_mesh = tuple_strategy.child_mesh(0)
    merged: list[Placement] | None = None
    for child in tuple_strategy.children:
        if not isinstance(child, OpStrategy):
            raise AssertionError(f"Expected OpStrategy, got {type(child)}")
        if child.mesh != expected_mesh:
            raise ValueError(
                f"All operands in {op} must have the same mesh, "
                f"but got {child.mesh} and {expected_mesh}."
            )

        for candidate in child.strategies:
            candidate_placements = candidate.output_spec.placements
            if merged is None:
                # the very first placement seen seeds the merge
                merged = list(candidate_placements)
                continue
            for mesh_dim in range(expected_mesh.ndim):
                merged[mesh_dim] = merge_placement(
                    merged[mesh_dim], candidate_placements[mesh_dim]
                )

    if merged is None:
        raise AssertionError("follow placements should not be None!")
    return merged


@register_op_strategy(aten.stack.default, RuntimeSchemaInfo(1, needs_pytree=True))
def stack_strategy(op_schema: OpSchema) -> StrategyType:
    """Single sharding strategy for ``aten.stack(tensors, dim=0)``.

    All operands are asked to adopt one merged placement (see
    ``_derive_follow_placements_from_tuple_strategy``); the output uses the same
    placement with shard dims at or after the inserted ``dim`` shifted by one.
    """
    args_schema = op_schema.args_schema
    tuple_strategy = args_schema[0]
    if not isinstance(tuple_strategy, TupleStrategy):
        raise AssertionError(f"Expected TupleStrategy, got {tuple_strategy}")

    operand_strategies: list[OpStrategy] = []
    for operand in tuple_strategy.children:
        if not isinstance(operand, OpStrategy):
            raise AssertionError(f"Expected OpStrategy, got {operand}")
        operand_strategies.append(operand)

    leading = operand_strategies[0]
    stack_dim = cast(int, args_schema[1]) if len(args_schema) > 1 else 0
    # stack inserts a new dim, so the valid range is that of (input ndim + 1)
    stack_dim = normalize_dim(stack_dim, leading.ndim + 1)
    mesh = leading.mesh

    placements_to_follow = _derive_follow_placements_from_tuple_strategy(
        op_schema.op, tuple_strategy
    )
    input_specs = tuple(
        DTensorSpec(mesh, tuple(placements_to_follow))
        for _ in tuple_strategy.children
    )

    # shard dims >= the inserted dim move one position to the right
    shifted = shift_shard_dims_after_insert(placements_to_follow, stack_dim)
    output_spec = DTensorSpec(mesh, tuple(shifted))
    redistribute_cost = [
        generate_redistribute_costs(operand, wanted)
        for operand, wanted in zip(operand_strategies, input_specs)
    ]

    result = OpStrategy([])
    result.strategies.append(
        OpSpec(
            output_specs=output_spec,
            input_specs=input_specs,
            redistribute_cost=redistribute_cost,
        )
    )
    return result


# TODO enable in a separate PR along with more extensive validation.
# currently just used in test_single_dim_strategy.py to help validate the single-dim expansion infra
# @register_single_dim_strategy(aten.cat.default, RuntimeSchemaInfo(1, needs_pytree=True))
def cat_single_dim_strategy(
    op: OpOverload, args_schema: ArgsType, kwargs_schema: KwargsType
) -> list[list[Placement | _ShardingPlaceholder]]:
    """Single-dim strategies for ``aten.cat(tensors, dim=0)``.

    Each rule is ``[output, *inputs]``:
      - shard every operand and the output on any dim other than the cat dim;
      - Partial("sum") passthrough on all operands and the output.

    The TensorList input arrives as a tuple/list of ``TensorMeta``. Operands
    must share one ndim, except that a 0-dim entry is tolerated next to it.
    """
    operands = args_schema[0]
    if not isinstance(operands, (tuple, list)):
        raise AssertionError(type(operands))
    if not all(isinstance(meta, TensorMeta) for meta in operands):
        raise AssertionError

    operands = tuple(operands)
    row_width = 1 + len(operands)

    ranks = {len(meta.shape) for meta in operands}
    if len(ranks) not in (1, 2):
        raise AssertionError(
            "Expected all cat inputs to be the same ndim, except empty tensors"
        )
    if len(ranks) == 2 and 0 not in ranks:
        raise AssertionError

    ndim = max(ranks)
    cat_dim = cast(int, args_schema[1]) if len(args_schema) > 1 else 0
    cat_dim = normalize_dim(cat_dim, ndim)

    rules: list[list[Placement | _ShardingPlaceholder]] = [
        [_ShardingPlaceholder(d)] * row_width for d in range(ndim) if d != cat_dim
    ]
    # concatenating partial sums yields a partial sum of the concatenation
    partial_row: list[Placement | _ShardingPlaceholder] = [
        Partial("sum")
    ] * row_width
    rules.append(partial_row)
    return rules


@register_op_strategy(aten.cat.default, RuntimeSchemaInfo(1, needs_pytree=True))
def cat_strategy(op_schema: OpSchema) -> StrategyType:
    """Sharding strategies for ``aten.cat(tensors, dim=0)``.

    Every output placement of every operand is used as an exemplar that all
    operands and the output should follow. A sharding on the cat dim is removed
    first (via ``unshard_tensor_dim``). Exemplars that end up identical are
    emitted only once.
    """
    args_schema = op_schema.args_schema
    tuple_strategy = args_schema[0]
    if not isinstance(tuple_strategy, TupleStrategy):
        raise AssertionError(f"Expected TupleStrategy, got {tuple_strategy}")
    operands = tuple_strategy.children
    head = operands[0]
    if not isinstance(head, OpStrategy):
        raise AssertionError(f"Expected OpStrategy, got {head}")

    cat_dim = cast(int, args_schema[1]) if len(args_schema) > 1 else 0
    cat_dim = normalize_dim(cat_dim, head.ndim)
    mesh = head.mesh

    result = OpStrategy([])
    emitted_placements = set()
    for operand in operands:
        if not isinstance(operand, OpStrategy):
            raise AssertionError(f"Expected OpStrategy, got {type(operand)}")
        if operand.mesh != mesh:
            raise AssertionError("cat op doesn't support cross mesh concatenation")

        for candidate in operand.strategies:
            if not isinstance(candidate, OpSpec):
                raise AssertionError(f"Expected OpSpec, got {type(candidate)}")
            candidate_spec = candidate.output_spec
            if is_tensor_dim_sharded(candidate_spec, cat_dim):
                # cannot concatenate along a sharded dim: unshard it first
                target = unshard_tensor_dim(candidate_spec.placements, cat_dim)
            else:
                target = candidate_spec.placements

            if target in emitted_placements:
                continue
            emitted_placements.add(target)

            wanted_specs = []
            costs = []
            for other in operands:
                if not isinstance(other, OpStrategy):
                    raise AssertionError(
                        f"Expected OpStrategy, got {type(other)}"
                    )
                wanted = DTensorSpec(
                    mesh,
                    target,
                    tensor_meta=other.strategies[0].output_spec.tensor_meta,
                )
                wanted_specs.append(wanted)
                costs.append(generate_redistribute_costs(other, wanted))

            result.strategies.append(
                OpSpec(
                    output_specs=DTensorSpec(mesh, target),
                    input_specs=tuple(wanted_specs),
                    redistribute_cost=costs,
                )
            )
    return result


@register_single_dim_strategy(
    aten.index_select.default, schema_info=RuntimeSchemaInfo(1)
)
def index_select_single_dim_strategy(
    op: OpOverload, args_schema: ArgsType, kwargs_schema: KwargsType
) -> list[list[Placement | _ShardingPlaceholder]]:
    """Single-dim strategies for ``index_select(self, dim, index)``.

    Each rule is ``[output, self, index]``.
    """
    values_meta, dim, index_meta = args_schema
    if not isinstance(values_meta, TensorMeta):
        raise AssertionError(f"Expected TensorMeta, got {type(values_meta)}")
    if not isinstance(dim, int):
        raise AssertionError(f"Expected int, got {type(dim)}")
    values_ndim = len(values_meta.shape)
    dim = normalize_dim(dim, values_ndim)

    strategies: list[list[Placement | _ShardingPlaceholder]] = []

    # Shard values on any non-indexed dim (output has same ndim)
    for d in range(values_ndim):
        if d == dim:
            continue
        strategies.append(
            [_ShardingPlaceholder(d), _ShardingPlaceholder(d), Replicate()]
        )

    # Shard index -> output sharded on the indexed dim. index_select accepts a
    # 0-d index (and a 0-d self); a 0-d tensor has no dim to shard, so only emit
    # this rule when both the output and the index have at least one dim.
    index_ndim = (
        len(index_meta.shape) if isinstance(index_meta, TensorMeta) else None
    )
    if values_ndim > 0 and index_ndim != 0:
        strategies.append(
            [_ShardingPlaceholder(dim), Replicate(), _ShardingPlaceholder(0)]
        )

    # Partial passthrough from values: selection commutes with any elementwise
    # reduction across ranks.
    for reduce_op in Partial.ALL_REDUCE_OPS:
        strategies.append([Partial(reduce_op), Partial(reduce_op), Replicate()])

    return strategies


@register_single_dim_strategy(
    aten.index.Tensor, schema_info=RuntimeSchemaInfo(needs_pytree=True)
)
def index_single_dim_strategy(
    op: OpOverload, args_schema: ArgsType, kwargs_schema: KwargsType
) -> list[list[Placement | _ShardingPlaceholder]]:
    """Single-dim strategies for ``aten.index.Tensor(self, indices)``.

    Each rule is ``[output, self, *non-None index tensors]``. ``indices`` may
    contain ``None`` entries (non-indexed dims); dims beyond ``len(indices)``
    are also non-indexed.

    Output layout: when the indexed dims are consecutive, the broadcast index
    dims replace them in place; otherwise they are moved to the front.
    """
    values_meta, multi_indices_meta = args_schema
    if not isinstance(values_meta, TensorMeta):
        raise AssertionError(f"Expected TensorMeta, got {type(values_meta)}")
    if not isinstance(multi_indices_meta, (list, tuple)):
        raise AssertionError(f"Expected list or tuple, got {type(multi_indices_meta)}")

    indexed_dims = [i for i, idx in enumerate(multi_indices_meta) if idx is not None]
    # build the membership set once instead of once per candidate dim
    indexed_dim_set = set(indexed_dims)
    non_indexed_dims = [
        i for i in range(len(values_meta.shape)) if i not in indexed_dim_set
    ]

    index_metas = [idx for idx in multi_indices_meta if idx is not None]
    if not all(isinstance(m, TensorMeta) for m in index_metas):
        raise AssertionError("Expected all index metas to be TensorMeta")
    # default=0 keeps an all-None indices list from raising ValueError
    broadcast_ndim = max((len(m.shape) for m in index_metas), default=0)
    num_indices = len(indexed_dims)

    # Determine where index output dims are inserted in the result
    all_consecutive = all(b - a == 1 for a, b in zip(indexed_dims, indexed_dims[1:]))
    insert_dim = indexed_dims[0] if indexed_dims and all_consecutive else 0

    def values_dim_to_output_dim(d: int) -> int:
        # dims before the insertion point keep their position; later dims are
        # shifted by the broadcast block minus the indexed dims preceding them
        if d < insert_dim:
            return d
        return d + broadcast_ndim - sum(1 for idx_dim in indexed_dims if d > idx_dim)

    strategies: list[list[Placement | _ShardingPlaceholder]] = []

    # Shard values on a non-indexed dim, all indices replicated
    for d in non_indexed_dims:
        strategies.append(
            [
                _ShardingPlaceholder(values_dim_to_output_dim(d)),
                _ShardingPlaceholder(d),
                *([Replicate()] * num_indices),
            ]
        )

    # Shard indices on the same broadcast dim. Each index tensor may have a
    # different ndim, so broadcast dim -> tensor dim is mapped via left-padding.
    # Tensors with size 1 on that dim are replicated (broadcast semantics).
    for bd in range(broadcast_ndim):
        per_tensor: list[tuple[int, int]] = []  # (tensor_dim, size)
        for m in index_metas:
            offset = broadcast_ndim - len(m.shape)
            if bd < offset:
                per_tensor.append((-1, 1))  # implicit broadcast
            else:
                td = bd - offset
                per_tensor.append((td, m.shape[td]))
        if all(s == 1 for _, s in per_tensor):
            continue  # all broadcast-only, skip
        index_rule: list[Placement | _ShardingPlaceholder] = [
            _ShardingPlaceholder(bd + insert_dim),
            Replicate(),
        ]
        for td, s in per_tensor:
            index_rule.append(_ShardingPlaceholder(td) if s > 1 else Replicate())
        strategies.append(index_rule)

    # Partial passthrough from values
    for reduce_op in Partial.LINEAR_REDUCE_OPS:
        strategies.append(
            [
                Partial(reduce_op),
                Partial(reduce_op),
                *([Replicate()] * num_indices),
            ]
        )

    return strategies


@register_single_dim_strategy(
    [aten.index_put.default, aten.index_put_.default, aten._index_put_impl_.default],
    schema_info=RuntimeSchemaInfo(needs_pytree=True),
)
def index_put_single_dim_strategy(
    op: OpOverload, args: ArgsType, kwargs: KwargsType
) -> list[list[Placement | _ShardingPlaceholder]]:
    """Single-dim sharding strategy for index_put(self, indices, values).

    Strategy format: [output, input, *indices, value]
    (only non-None index tensors get a slot; the infra drops None entries).

    Semantics recap:
      - ``indices`` is a tuple of index tensors and Nones. A tensor at position
        i indexes self on dim i; a None (or any trailing dim beyond
        ``len(indices)``) selects the whole dim, like ``:``.
      - The non-None index tensors broadcast together; each coordinate of the
        broadcast shape selects an element, or a slice if non-indexed dims exist.
      - ``values`` broadcasts to the indexing result shape, which is
        ``(*self.shape[:k], *broadcast_shape, *self.shape[k+n_indexed:])`` when
        the indexed dims are consecutive starting at k, and
        ``(*broadcast_shape, *non_indexed_dim_sizes)`` otherwise.

    Rules (possibly conservative and incomplete):
      - Index tensors are always Replicate.
      - self is never sharded on an indexed dim (local != global position).
      - self and values may be sharded together on a non-indexed dim; where
        values is broadcast on that dim (missing or size 1) it stays Replicate.
      - A full Partial rule is allowed on self, output and values.
    """
    self_meta = cast(TensorMeta, args[0])
    indices_meta = cast(tuple[TensorMeta | None, ...], args[1])
    values_meta = cast(TensorMeta, args[2])

    index_positions = {pos for pos, idx in enumerate(indices_meta) if idx is not None}
    free_dims = [d for d in range(len(self_meta.shape)) if d not in index_positions]
    n_indexed = len(index_positions)
    values_shape = values_meta.shape
    values_ndim = len(values_shape)

    present_index_shapes = [idx.shape for idx in indices_meta if idx is not None]
    broadcast_ndim = (
        len(torch.broadcast_shapes(*present_index_shapes))
        if present_index_shapes
        else 0
    )

    ordered_positions = sorted(index_positions)
    is_contiguous = len(ordered_positions) <= 1 or (
        ordered_positions[-1] - ordered_positions[0] + 1 == len(ordered_positions)
    )
    # the broadcast block stays in place only for a non-empty contiguous run
    block_in_place = bool(is_contiguous and ordered_positions)

    # values is right-aligned against the result shape, so it may be missing
    # some leading dims
    result_ndim = broadcast_ndim + len(free_dims)
    missing_leading = result_ndim - values_ndim

    rules: list[list[Placement | _ShardingPlaceholder]] = []
    for position, self_dim in enumerate(free_dims):
        if not block_in_place:
            result_dim = broadcast_ndim + position
        elif self_dim < ordered_positions[0]:
            result_dim = self_dim
        else:
            result_dim = self_dim - n_indexed + broadcast_ndim

        values_dim = result_dim - missing_leading
        values_placement: Placement | _ShardingPlaceholder
        if values_dim < 0 or values_shape[values_dim] == 1:
            values_placement = Replicate()
        else:
            values_placement = _ShardingPlaceholder(values_dim)

        rules.append(
            [
                _ShardingPlaceholder(self_dim),
                _ShardingPlaceholder(self_dim),
                *([Replicate()] * n_indexed),
                values_placement,
            ]
        )

    rules.append(
        [
            Partial(),
            Partial(),
            *([Replicate()] * n_indexed),
            Partial(),
        ]
    )
    return rules


def _index_dim_strategy(
    args_schema: ArgsType,
    shard_row: Callable[[int], list[Placement | _ShardingPlaceholder]],
    partial_rules: list[list[Placement | _ShardingPlaceholder]] | None = None,
) -> list[list[Placement | _ShardingPlaceholder]]:
    """Build strategies for index ops that may shard any dim except ``dim``.

    ``args_schema[0]`` is the self TensorMeta and ``args_schema[1]`` the
    indexed dim.

    Args:
        shard_row: maps a shardable dim to the strategy row sharding on it.
        partial_rules: optional extra Partial passthrough rows appended last.
    """
    rank = len(cast(TensorMeta, args_schema[0]).shape)
    indexed_dim = normalize_dim(cast(int, args_schema[1]), rank)
    rows = [shard_row(d) for d in range(rank) if d != indexed_dim]
    if partial_rules:
        rows.extend(partial_rules)
    return rows


@register_single_dim_strategy(
    [aten.index_fill.int_Scalar, aten.index_fill_.int_Scalar],
    schema_info=RuntimeSchemaInfo(1),
)
def index_fill_scalar_single_dim_strategy(
    op: OpOverload, args_schema: ArgsType, kwargs_schema: KwargsType
) -> list[list[Placement | _ShardingPlaceholder]]:
    """Single-dim strategies for ``index_fill(self, dim, index, value: Scalar)``.

    Rule layout is ``[result, self, index]``. The scalar ``value`` is not a
    tensor and therefore has no placement slot; the third entry is the index.

    Partial rules: each rank fills its partial ``self`` with the same scalar v,
    then the partials are reduced. Only idempotent reductions preserve v:
    avg(v,...,v) = max(v,...,v) = min(v,...,v) = v, while
    sum(v,...,v) = n*v and product(v,...,v) = v^n are wrong.
    """
    return _index_dim_strategy(
        args_schema,
        lambda d: [
            _ShardingPlaceholder(d),  # result
            _ShardingPlaceholder(d),  # self
            Replicate(),  # index (positions along `dim`, needed on every rank)
        ],
        [
            [Partial(reduce_op), Partial(reduce_op), Replicate()]
            for reduce_op in ("avg", "max", "min")
        ],
    )


@register_single_dim_strategy(
    [aten.index_fill.int_Tensor, aten.index_fill_.int_Tensor],
    schema_info=RuntimeSchemaInfo(1),
)
def index_fill_tensor_single_dim_strategy(
    op: OpOverload, args_schema: ArgsType, kwargs_schema: KwargsType
) -> list[list[Placement | _ShardingPlaceholder]]:
    """Single-dim strategies for ``index_fill(self, dim, index, value: Tensor)``.

    Rule layout is ``[result, self, index, value]`` where ``value`` is a 0-d
    tensor.

    Partial rules: each rank fills its partial ``self_i`` with its partial
    ``value_i``. Filled positions reduce to reduce(value_0, ..., value_{n-1}),
    which is the global value for any reduce op, and untouched positions reduce
    to the global self, because filling replaces rather than mixes. Hence every
    op in ``Partial.ALL_REDUCE_OPS`` is valid.
    """

    def shard_on(d: int) -> list[Placement | _ShardingPlaceholder]:
        return [
            _ShardingPlaceholder(d),  # result
            _ShardingPlaceholder(d),  # self
            Replicate(),  # index
            Replicate(),  # value
        ]

    partial_rows: list[list[Placement | _ShardingPlaceholder]] = [
        [Partial(reduce_op), Partial(reduce_op), Replicate(), Partial(reduce_op)]
        for reduce_op in Partial.ALL_REDUCE_OPS
    ]
    return _index_dim_strategy(args_schema, shard_on, partial_rows)


@register_single_dim_strategy(
    [aten.index_reduce.default, aten.index_reduce_.default],
    schema_info=RuntimeSchemaInfo(1),
)
def index_reduce_single_dim_strategy(
    op: OpOverload, args_schema: ArgsType, kwargs_schema: KwargsType
) -> list[list[Placement | _ShardingPlaceholder]]:
    """Single-dim strategies for ``index_reduce(self, dim, index, source, reduce)``.

    Rule layout is ``[result, self, index, source]``; self, source and result
    may be co-sharded on any dim other than ``dim``.

    No Partial rules are emitted: the reduce modes ("mean"/"amax"/"amin"/"prod")
    do not correspond to Partial reduce op names ("avg"/"max"/"min"/"product"/"sum").
    """

    def shard_on(d: int) -> list[Placement | _ShardingPlaceholder]:
        return [
            _ShardingPlaceholder(d),  # result
            _ShardingPlaceholder(d),  # self
            Replicate(),  # index
            _ShardingPlaceholder(d),  # source
        ]

    return _index_dim_strategy(args_schema, shard_on)


@register_op_strategy(
    [
        aten.split.Tensor,
        aten.split_with_sizes.default,
        aten.split_with_sizes_copy.default,
    ],
    RuntimeSchemaInfo(1),
)
def split_strategy(op_schema: OpSchema) -> OpStrategy:
    """Sharding strategy for split / split_with_sizes / split_with_sizes_copy.

    Every input strategy is kept, except that a sharding on the split dim is
    removed (via ``unshard_tensor_dim``). All outputs share the resulting
    placements, and one output spec is produced per chunk.
    """
    input_strategy = op_schema.args_schema[0]
    split_size_or_sections = op_schema.args_schema[1]
    if not isinstance(input_strategy, OpStrategy):
        raise AssertionError(f"Expected OpStrategy, got {type(input_strategy)}")
    input_ndim = input_strategy.ndim
    split_dim = (
        cast(int, op_schema.args_schema[2]) if len(op_schema.args_schema) > 2 else 0
    )
    dim = normalize_dim(split_dim, input_ndim)

    def size_split(N, i) -> list:
        """Chunk sizes from splitting a dim of size N into chunks of size i.

        The last chunk is smaller if N is not divisible by i. Like torch.split,
        at least one chunk is always produced: a zero-size dim yields a single
        empty chunk, and a split size of 0 is only accepted when N == 0.
        """
        if N == 0 and i >= 0:
            return [0]
        if not i > 0:
            raise AssertionError(f"Split size must be positive, got {i}")
        return [i] * (N // i) + ([N % i] if N % i != 0 else [])

    output_size_list = (
        size_split(input_strategy.shape[dim], split_size_or_sections)
        if isinstance(split_size_or_sections, IntLike)
        else split_size_or_sections
    )
    if not isinstance(output_size_list, Sized):
        raise AssertionError(f"Expected Sized, got {type(output_size_list)}")
    num_outputs = len(output_size_list)

    all_strategies = []
    for strategy in input_strategy.strategies:
        spec = strategy.output_spec
        placements = spec.placements
        if is_tensor_dim_sharded(spec, dim=dim):
            # if the input is sharded on the split dim, we need to unshard it
            placements = unshard_tensor_dim(spec.placements, dim=dim)

        input_spec = DTensorSpec(spec.device_mesh, placements, spec.tensor_meta)
        output_specs = tuple(
            DTensorSpec(spec.device_mesh, placements) for _ in range(num_outputs)
        )
        all_strategies.append(
            OpSpec(
                output_specs=output_specs,
                input_specs=(input_spec,),
                redistribute_cost=[
                    generate_redistribute_costs(input_strategy, input_spec)
                ],
            )
        )

    return OpStrategy(all_strategies)


# TODO: fix remaining failures in xfail("unbind") in test_dtensor_ops.py
#       and remove this xfail item
@register_op_strategy(aten.unbind.int, schema_info=RuntimeSchemaInfo(1))
def gen_unbind_strategy(op_schema: OpSchema) -> StrategyType:
    """Forward all shardings except the unbind dimension.

    For each input strategy, emit one output spec per slice along the unbind
    dim, with shard dims after it shifted down by one (the dim is removed).

    Raises:
        RuntimeError: if any input strategy is sharded on the unbind dim, since
            this strategy does not request a redistribution to unshard it.
    """
    input_strategy = op_schema.args_schema[0]
    if not isinstance(input_strategy, OpStrategy):
        raise AssertionError(f"Expected OpStrategy, got {type(input_strategy)}")
    input_ndim = input_strategy.ndim
    input_shape = input_strategy.shape
    unbind_dim = (
        cast(int, op_schema.args_schema[1]) if len(op_schema.args_schema) > 1 else 0
    )
    unbind_dim = normalize_dim(unbind_dim, input_ndim)

    mesh = input_strategy.mesh
    unbind_strategy = OpStrategy([])
    for arg_strategy in input_strategy.strategies:
        arg_spec = arg_strategy.output_spec
        if is_tensor_dim_sharded(arg_spec, dim=unbind_dim):
            # a single message string (not a tuple of fragments)
            raise RuntimeError(
                f"Attempted to unbind along the sharded dimension {unbind_dim}. "
                "It cannot be performed without redistribution, which is disallowed "
                "by the current operator."
            )
        # only add the strategy if the unbind dim is not sharded
        output_placements = shift_shard_dims_after_remove(
            arg_spec.placements, unbind_dim
        )
        output_specs = tuple(
            DTensorSpec(mesh, tuple(output_placements))
            for _ in range(input_shape[unbind_dim])
        )
        unbind_strategy.strategies.append(
            OpSpec(
                output_specs=output_specs,
                input_specs=(arg_spec,),
                # NOTE(review): reports zero cost from every input strategy;
                # confirm this is intended rather than a per-source cost vector.
                redistribute_cost=[[0.0] * len(input_strategy.strategies)],
            )
        )
    return unbind_strategy


@register_op_strategy(aten.eye.m_out)
def eye_out_strategy(op_schema: OpSchema) -> OpStrategy:
    """
    Strategy for ``torch.eye(n, m, *, out=...)``.

    ``eye.m_out`` has signature ``eye(int n, int m, *, Tensor(a!) out) -> Tensor(a!)``.
    Its only tensor is the ``out`` kwarg, so each placement the ``out`` tensor
    supports is reused as both the input and the output placement, at zero
    redistribution cost.
    """
    out_strategy = op_schema.kwargs_schema["out"]
    if not isinstance(out_strategy, OpStrategy):
        raise AssertionError(f"Expected OpStrategy for out, got {type(out_strategy)}")

    specs = []
    for candidate in out_strategy.strategies:
        out_placement = candidate.output_spec
        specs.append(
            OpSpec(
                output_specs=out_placement,
                input_specs=[out_placement],  # out is both input and output
                redistribute_cost=[[0.0]],
            )
        )
    return OpStrategy(specs)


def _pass_through_partials(
    num_inputs: int = 1,
) -> list[list[Placement | _ShardingPlaceholder]]:
    """Rows keeping a Partial(sum/avg/max/min) placement on output and all inputs."""
    row_width = num_inputs + 1
    rows: list[list[Placement | _ShardingPlaceholder]] = []
    for reduce_op in ("sum", "avg", "max", "min"):
        rows.append([Partial(reduce_op)] * row_width)
    return rows


def _shard_inactive_dims(
    ndim: int, active_dims: set[int], num_inputs: int = 1
) -> list[list[Placement | _ShardingPlaceholder]]:
    """Rows sharding output and every input on each dim not in ``active_dims``."""
    row_width = num_inputs + 1
    return [
        [_ShardingPlaceholder(dim)] * row_width
        for dim in range(ndim)
        if dim not in active_dims
    ]


@register_single_dim_strategy(aten.roll.default, schema_info=RuntimeSchemaInfo(1))
def roll_single_dim_strategy(
    op: OpOverload, args_schema: ArgsType, kwargs_schema: KwargsType
) -> list[list[Placement | _ShardingPlaceholder]]:
    """Single-dim strategies for ``roll(self, shifts, dims=[])``.

    Dims that are rolled cannot be sharded; every other dim can. Partial
    placements pass through unchanged.
    """
    meta = args_schema[0]
    if not isinstance(meta, TensorMeta):
        raise AssertionError(f"Expected TensorMeta, got {type(meta)}")
    rank = len(meta.shape)
    dims_arg = cast(list[int], args_schema[2]) if len(args_schema) > 2 else []
    # An empty dims list rolls the flattened tensor, so every dim is touched.
    rolled = {normalize_dim(d, rank) for d in (dims_arg or range(rank))}
    return _shard_inactive_dims(rank, rolled) + _pass_through_partials()


@register_single_dim_strategy(aten.flip.default, schema_info=RuntimeSchemaInfo(1))
def flip_single_dim_strategy(
    op: OpOverload, args_schema: ArgsType, kwargs_schema: KwargsType
) -> list[list[Placement | _ShardingPlaceholder]]:
    """Single-dim strategies for ``flip(self, dims)``.

    Flipped dims cannot be sharded; every other dim can. Partial placements
    pass through unchanged.
    """
    meta = args_schema[0]
    if not isinstance(meta, TensorMeta):
        raise AssertionError(f"Expected TensorMeta, got {type(meta)}")
    rank = len(meta.shape)
    flipped = {normalize_dim(d, rank) for d in cast(list[int], args_schema[1])}
    return _shard_inactive_dims(rank, flipped) + _pass_through_partials()


@register_single_dim_strategy(
    [aten._fft_c2c.default, aten._fft_r2c.default, aten._fft_c2r.default],
    schema_info=RuntimeSchemaInfo(1),
)
def fft_single_dim_strategy(
    op: OpOverload, args_schema: ArgsType, kwargs_schema: KwargsType
) -> list[list[Placement | _ShardingPlaceholder]]:
    """Single-dim strategies for the internal FFT ops (c2c / r2c / c2r).

    Transformed dims (``args_schema[1]``) cannot be sharded; every other dim
    can. The FFT is linear, so only linear Partial reductions pass through:
    fft(sum(x_i)) == sum(fft(x_i)) (and likewise for avg), but
    fft(max(a, b)) != max(fft(a), fft(b)), so max/min partials are not valid.
    """
    input_meta = args_schema[0]
    if not isinstance(input_meta, TensorMeta):
        raise AssertionError(f"Expected TensorMeta, got {type(input_meta)}")
    ndim = len(input_meta.shape)
    raw_dims = cast(list[int], args_schema[1])
    active_dims = {normalize_dim(d, ndim) for d in raw_dims}
    linear_partials: list[list[Placement | _ShardingPlaceholder]] = [
        [Partial(reduce_op)] * 2 for reduce_op in Partial.LINEAR_REDUCE_OPS
    ]
    return _shard_inactive_dims(ndim, active_dims) + linear_partials
