{{def_kernel("A", "B")}}
    M = {{size("A", 0)}}
    N = {{size("B", 1)}}
    K = {{size("A", 1)}}
    if M * N == 0:
        # early exit due to zero-size input(s)
        return
    start_pid = tl.program_id(0)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    k_tiles = tl.cdiv(K, BLOCK_K)
    num_tiles = grid_m * grid_n

    # Note: We require TMA_EXPERIMENTAL_API == False, which
    # we will check before invoking this template.
    stride_am = {{stride("A", 0)}}
    stride_ak = {{stride("A", 1)}}
    stride_bk = {{stride("B", 0)}}
    stride_bn = {{stride("B", 1)}}
    a_desc = triton.language.make_tensor_descriptor(
        base=A,
        shape=[M, K] if A_ROW_MAJOR else [K, M],
        strides=[stride_am, 1] if A_ROW_MAJOR else [stride_ak, 1],
        block_shape=[BLOCK_M, BLOCK_K] if A_ROW_MAJOR else [BLOCK_K, BLOCK_M],
    )
    b_desc = triton.language.make_tensor_descriptor(
        base=B,
        shape=[K, N] if B_ROW_MAJOR else [N, K],
        strides=[stride_bk, 1] if B_ROW_MAJOR else [stride_bn, 1],
        block_shape=[BLOCK_K, BLOCK_N] if B_ROW_MAJOR else [BLOCK_N, BLOCK_K],
    )

    # tile_id_c is used in the epilogue to break the dependency between
    # the prologue and the epilogue
    tile_id_c = start_pid - NUM_SMS
    num_pid_in_group = GROUP_M * grid_n

    for tile_id in tl.range(
        start_pid, num_tiles, NUM_SMS, flatten=FLATTEN, warp_specialize=WARP_SPECIALIZE
    ):
        pid_m, pid_n = _compute_pid(
            tile_id, num_pid_in_group, grid_m, GROUP_M, NUM_SMS
        )
        offs_am = pid_m * BLOCK_M
        offs_bn = pid_n * BLOCK_N
        offs_am_desc = offs_am.to(tl.int32)
        offs_bn_desc = offs_bn.to(tl.int32)

        accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        for ki in range(k_tiles):
            offs_k = ki * BLOCK_K
            offs_k_desc = offs_k.to(tl.int32)
            a = tl.load_tensor_descriptor(
                a_desc,
                [offs_am_desc, offs_k_desc]
                if A_ROW_MAJOR
                else [offs_k_desc, offs_am_desc],
            )
            b = tl.load_tensor_descriptor(
                b_desc,
                [offs_k_desc, offs_bn_desc]
                if B_ROW_MAJOR
                else [offs_bn_desc, offs_k_desc],
            )
            accumulator += tl.dot(
                a if A_ROW_MAJOR else a.T,
                b if B_ROW_MAJOR else b.T,
                allow_tf32=ALLOW_TF32,
            )

        tile_id_c += NUM_SMS
        pid_m, pid_n = _compute_pid(
            tile_id_c, num_pid_in_group, grid_m, GROUP_M, NUM_SMS
        )
        offs_cm = pid_m * BLOCK_M
        offs_cn = pid_n * BLOCK_N
        subtiles = _subtile_accumulator(accumulator, BLOCK_M, BLOCK_N, EPILOGUE_SUBTILE)
        for i in tl.static_range(EPILOGUE_SUBTILE):
            subtile = subtiles[i]
            offs_cn_i = offs_cn + i * (BLOCK_N // EPILOGUE_SUBTILE)
            {{store_output(
                ("offs_cm", "offs_cn_i"),
                "subtile",
                indent_width=12,
                val_shape=("BLOCK_M", "BLOCK_N // EPILOGUE_SUBTILE"),
                block_indexing=True
            )}}

@triton.jit
def _compute_pid(tile_id, num_pid_in_group, grid_m, GROUP_M: tl.constexpr, NUM_SMS: tl.constexpr):
    """
    Map a linear output-tile index to its (pid_m, pid_n) tile coordinates.

    Tiles are enumerated in bands of GROUP_M consecutive tile rows. Inside a
    band the row index cycles fastest and the column index advances once per
    full cycle. The last band may hold fewer than GROUP_M rows; it is narrowed
    to the rows that actually remain.

    Args:
        tile_id: Linear tile index.
        num_pid_in_group: Number of tiles in one full band.
        grid_m: Number of tile rows.
        GROUP_M: Nominal band height in tile rows (constexpr).
        NUM_SMS: Accepted for call-site compatibility; not read here.

    Returns:
        (pid_m, pid_n): the tile row index and tile column index.
    """
    band = tile_id // num_pid_in_group
    band_first_row = band * GROUP_M
    # Height of this band, clamped for the final partial band. Held in its own
    # local instead of rebinding the constexpr GROUP_M parameter.
    band_rows = min(grid_m - band_first_row, GROUP_M)
    pid_n = (tile_id % num_pid_in_group) // band_rows
    pid_m = band_first_row + (tile_id % band_rows)
    return pid_m, pid_n


@triton.jit
def _subtile_accumulator(
    acc,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    SUBTILE_FACTOR: tl.constexpr,
):
    """
    Cut a (BLOCK_M, BLOCK_N) tile column-wise into SUBTILE_FACTOR equal pieces.

    Every recursion level halves the N dimension. The tile is viewed as
    (BLOCK_M, 2, BLOCK_N // 2), the size-2 axis is moved last, and ``tl.split``
    separates it into the leading and trailing column halves. Each half is
    split again until SUBTILE_FACTOR pieces remain.

    Args:
        acc: Tile of shape (BLOCK_M, BLOCK_N).
        BLOCK_M: Row count of ``acc`` (constexpr).
        BLOCK_N: Column count of ``acc`` (constexpr).
        SUBTILE_FACTOR: Number of pieces; must be a positive power of 2 (constexpr).

    Returns:
        A tuple of SUBTILE_FACTOR tensors of shape
        (BLOCK_M, BLOCK_N // SUBTILE_FACTOR), ordered left to right. Piece i holds
        columns [i * w, (i + 1) * w) of ``acc``, where w = BLOCK_N // SUBTILE_FACTOR.
    """
    # Compile-time validation of the split factor.
    tl.static_assert(SUBTILE_FACTOR > 0, "SUBTILE_FACTOR must be positive")
    tl.static_assert((SUBTILE_FACTOR & (SUBTILE_FACTOR - 1)) == 0, "SUBTILE_FACTOR must be a power of 2")

    if SUBTILE_FACTOR != 1:
        # (M, N) -> (M, 2, N/2) -> (M, N/2, 2); splitting the trailing axis
        # yields columns [0, N/2) and [N/2, N) as two (M, N/2) tiles.
        tl.static_assert(BLOCK_N % 2 == 0)
        halves = tl.permute(tl.reshape(acc, (BLOCK_M, 2, BLOCK_N // 2)), (0, 2, 1))
        first_half, second_half = tl.split(halves)
        head = _subtile_accumulator(first_half, BLOCK_M, BLOCK_N // 2, SUBTILE_FACTOR // 2)
        tail = _subtile_accumulator(second_half, BLOCK_M, BLOCK_N // 2, SUBTILE_FACTOR // 2)
        # Leading-half pieces first, so the result stays in column order.
        return head + tail
    else:
        # Nothing left to split: wrap the tile in a single-element tuple.
        return (acc,)
