#!/usr/bin/env python3
import sys
import argparse
import os
import math
import shlex
import subprocess as sp
from os.path import (
    join,
    dirname,
    realpath,
    abspath,
    isfile,
    isdir,
    exists,
    splitext,
    basename,
    getsize,
)
import shutil
import time
from enum import Enum, auto

import btllib

from goldpolish_utils import (
    get_random_name,
    watch_process,
    create_threads_in_use,
    create_simultaneous_batch_processes,
    get_simultaneous_batch_processes,
    update_simultaneous_batch_processes,
)

# Name of the per-batch file that lists the sequence IDs belonging to the batch.
SEQ_IDS_FILENAME = "seq_ids"
# Bloom filter file name; formatted with (batch number, k).
BF_NAME_TEMPLATE = "{}-k{}.bf"
# Suffix of the temporary directory in which the Bloom filters are built.
BFS_DIRNAME = "targeted_bfs"
# File names, inside the Bloom filter directory, of the two pipes used to talk
# to the Bloom filter builder; the builder is treated as ready once both exist.
BATCH_NAME_INPUT_PIPE = "batch_name_input"
BATCH_TARGET_IDS_INPUT_READY_PIPE = "batch_target_ids_input_ready"
# Per-batch pipe names (prefixed with the batch number and SEPARATOR) inside the
# Bloom filter directory; handed to the batch polisher as --bfs-ids-pipe and
# --bfs-ready-pipe respectively.
TARGET_IDS_INPUT_PIPE = "target_ids_input"
BFS_READY_PIPE = "bfs_ready"
# Name of the FIFO created in every batch directory.
BATCH_DONE_PIPE = "polishing_done"
# A new batch is only started while the reported number of simultaneous batch
# processes is below this value.
MAX_SIMULTANEOUS_BATCH_PROCESSES = 200
# --threads value given to each batch polishing process.
BATCH_THREADS = 1
# Seconds to sleep between polls of the simultaneous batch process count.
SPAWN_DELAY = 0.05
# Name of the file written into the final (empty) batch directory.
POLISHING_OVER_FILE = "gg"
# Joins the run prefix and the suffix in temporary directory names, and the
# batch number and pipe name in per-batch pipe names.
SEPARATOR = "-"
# Executable that builds the targeted Bloom filters.
GOLDPOLISH_TARGETED_BFS = "goldpolish-targeted-bfs"
# Makefile used to build indexes and mappings; resolved next to this script's
# real (symlink-resolved) location.
GOLDPOLISH_MAKE = "goldpolish-make"
GOLDPOLISH_MAKE_FULL_PATH = f"{os.path.dirname(os.path.realpath(__file__))}/{GOLDPOLISH_MAKE}"
# Executable started once per batch.
GOLDPOLISH_POLISH_BATCH = "goldpolish-polish-batch"
# Executable started once per run with the workspace, prefix, per-batch
# polished file name and final output path.
GOLDPOLISH_REAPER = "goldpolish-reaper"
# Line written to the batch name pipe to end the Bloom filter builder.
END_SYMBOL = "x"
# Default random-subsampling limits used when -s is left at -1.
NTLINK_SUBSAMPLE_MAX_READS_PER_10KBP = 100
MINIMAP2_SUBSAMPLE_MAX_READS_PER_10KBP = 40
# Executable run instead of the regular pipeline when --target is given.
GOLDPOLISH_TARGET = "goldpolish-target.py"


class MappingTool(Enum):
    """Where the read mappings used for polishing come from."""

    # Mappings are generated by running ntLink.
    NTLINK = 1
    # Mappings are generated by running minimap2.
    MINIMAP2 = 2
    # Pre-generated mappings were supplied on the command line.
    MAPPINGS_PROVIDED = 3


def get_bf_builder_threads_num(total_threads):
    """Return the number of threads given to the Bloom filter builder.

    The builder receives ``total_threads / 1.8`` rounded up, so the result is
    always at least one for a valid (>= 1) thread count.
    """
    assert total_threads >= 1
    builder_share = math.ceil(total_threads / 1.8)
    assert builder_share >= 1
    return builder_share


def get_cli_args():
    """Parse, validate and normalise the command-line arguments.

    Post-processing applied to the parsed namespace:
      * the three positional paths are made absolute;
      * ``k`` falls back to ``[32, 28, 24, 20]`` when no ``-k`` was given;
      * ``threads`` is raised to 2 (with a warning) when a lower value is given.

    Returns:
        The ``argparse.Namespace`` with the adjustments above.

    Exits (via argparse) on invalid input, including non-integer ``-k``
    values and a batch size below 1.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("seqs_to_polish", help="Sequences to polish.")
    parser.add_argument("polishing_seqs", help="Sequences to polish with.")
    parser.add_argument("output_seqs", help="Filename to write polished sequences to.")
    parser.add_argument(
        "-k",
        action="append",
        default=[],
        # Parse as integers so user-supplied values have the same type as the
        # defaults and arbitrary text never reaches the commands built from
        # these values.
        type=int,
        help="k-mer sizes to use for polishing. Example: -k32 -k28 (Default: 32, 28, 24, 20)",
    )
    parser.add_argument(
        "-b",
        "--bsize",
        default=1,
        type=int,
        help="Batch size. A batch is how many polished sequences are processed per Bloom filter. (Default: 1)",
    )
    parser.add_argument(
        "-m",
        "--shared-mem",
        default="/dev/shm",
        help="Shared memory path to do polishing in. (Default: /dev/shm)",
    )
    parser.add_argument(
        "-t",
        "--threads",
        type=int,
        default=48,
        help="How many threads to use. (Default: 48)",
    )
    parser.add_argument("-v", "--verbose", action="store_true")
    parser.add_argument(
        "-x",
        "--mx-max-reads-per-10kbp",
        type=float,
        default=150.0,
        help="When subsampling, increase the common minimizer count threshold for ntLink mappings until there's at most this many reads per 10kbp of polished sequence. (Default: 150)",
    )
    parser.add_argument(
        "-s",
        "--subsample-max-reads-per-10kbp",
        type=float,
        default=-1,
        help=f"Random subsampling of mapped reads. For ntLink mappings, this is done after common minimizer subsampling. For minimap2 mappings, only this subsampling is done. By default, {MINIMAP2_SUBSAMPLE_MAX_READS_PER_10KBP} if using minimap2 mappings and {NTLINK_SUBSAMPLE_MAX_READS_PER_10KBP} if using ntLink mappings.",
    )
    # At most one source of mappings may be selected.
    group = parser.add_mutually_exclusive_group()
    group.add_argument(
        "--ntlink",
        action="store_true",
        help="Run ntLink to generate read mappings (default).",
    )
    group.add_argument(
        "--minimap2",
        action="store_true",
        help="Run minimap2 to generate read mappings.",
    )
    group.add_argument(
        "--mappings",
        default="",
        help="Use provided pre-generated mappings. Accepted formats are PAF, SAM, and *.verbose_mapping.tsv from ntLink.",
    )
    parser.add_argument(
        "--k-ntlink",
        type=int,
        default=88,
        help="k-mer size used for ntLink mappings (if --ntlink specified)",
    )
    parser.add_argument(
        "--w-ntlink",
        type=int,
        default=1000,
        help="Window size used for ntLink mappings (if --ntlink specified)",
    )
    parser.add_argument(
        "--target",
        action="store_true",
        help='Run GoldPolish in targeted mode'
    )
    parser.add_argument(
        "-l",
        "--length",
        type=int,
        default=64,
        help="GoldPolish-Target flank length (if --target specified)"
    )
    # Hidden option; its value is forwarded verbatim to the targeted polisher.
    parser.add_argument(
        "--sensitive",
        default=True,
        required=False,
        help=argparse.SUPPRESS
    )
    # Target coordinates come either from a BED file or from softmasking.
    target_group = parser.add_mutually_exclusive_group()
    target_group.add_argument(
        "--bed",
        help="BED file specifying target coordinates (if --target specified)"
    )
    target_group.add_argument(
        "--softmask",
        action="store_true",
        help="Target coordinates determined from softmasked regions in the input assembly (if --target specified)"
    )
    parser.add_argument(
        "--benchmark",
        action="store_true",
        help=argparse.SUPPRESS
    )
    parser.add_argument(
        '--target_dev',
        action="store_true",
        help=argparse.SUPPRESS
    )

    args = parser.parse_args()

    # Batches are cut with a modulo on the batch size, so anything below 1
    # cannot produce valid batches; reject it up front.
    if args.bsize < 1:
        parser.error("Batch size (-b/--bsize) must be at least 1.")

    args.seqs_to_polish = abspath(args.seqs_to_polish)
    args.polishing_seqs = abspath(args.polishing_seqs)
    args.output_seqs = abspath(args.output_seqs)
    if not args.k:
        args.k = [32, 28, 24, 20]

    if args.threads < 2:
        args.threads = 2
        btllib.log_warning("Threads number is too low. Setting to 2.")

    return args


def autoclean(workspace, prefix):
    """Start ``goldpolish-autoclean`` for this run and register it with ``watch_process``.

    The helper receives the workspace directory and the run prefix as its
    two arguments; this function does not wait for it to finish.
    """
    command = ["goldpolish-autoclean", workspace, prefix]
    watch_process(sp.Popen(command))


def build_indexes_and_mappings(
    polishing_seqs_path,
    seqs_to_polish_path,
    k_values,
    mapping_tool,
    mappings,
    subsample_max_reads_per_10kbp,
    threads,
    verbose,
    k_ntlink,
    w_ntlink,
):
    """Build the sequence indexes and, where needed, the read mappings via make.

    All outputs are created in the current working directory.

    Args:
        polishing_seqs_path: Path of the sequences to polish with.
        seqs_to_polish_path: Path of the sequences to polish.
        k_values: k-mer sizes, passed to make as the space-separated ``K`` variable.
        mapping_tool: ``MappingTool`` member selecting how mappings are obtained.
        mappings: Path of user-provided mappings (only read for
            ``MappingTool.MAPPINGS_PROVIDED``).
        subsample_max_reads_per_10kbp: Subsampling limit; ``-1`` selects the
            default matching the mapping format.
        threads: Thread count passed to make as ``t``.
        verbose: When true, ``time=true`` is passed to make.
        k_ntlink: ntLink k-mer size (only used with ``MappingTool.NTLINK``).
        w_ntlink: ntLink window size (only used with ``MappingTool.NTLINK``).

    Returns:
        Tuple ``(polishing_seqs_index, seqs_to_polish_index, mappings,
        subsample_max_reads_per_10kbp)``; the first three are absolute paths.

    Raises:
        ValueError: If ``mapping_tool`` is not a known ``MappingTool`` member.
        subprocess.CalledProcessError: If make exits with a non-zero status.
    """
    btllib.log_info("Building indexes and mappings...")

    # Indexes and mapping filenames
    polishing_seqs_index = f"{basename(polishing_seqs_path)}.index"
    seqs_to_polish_index = f"{basename(seqs_to_polish_path)}.index"

    if mapping_tool == MappingTool.NTLINK:
        mappings_to_build = (
            f"{basename(seqs_to_polish_path)}.k{k_ntlink}.w{w_ntlink}.z1000.mapping.tsv"
        )
        mappings = mappings_to_build
        if subsample_max_reads_per_10kbp == -1:
            subsample_max_reads_per_10kbp = NTLINK_SUBSAMPLE_MAX_READS_PER_10KBP
    elif mapping_tool == MappingTool.MINIMAP2:
        mappings_to_build = (
            f"{basename(seqs_to_polish_path)}.{basename(polishing_seqs_path)}.paf"
        )
        mappings = mappings_to_build
        if subsample_max_reads_per_10kbp == -1:
            subsample_max_reads_per_10kbp = MINIMAP2_SUBSAMPLE_MAX_READS_PER_10KBP
    elif mapping_tool == MappingTool.MAPPINGS_PROVIDED:
        if mappings.endswith(".verbose_mapping.tsv"):
            # Make the provided file reachable under its base name in the
            # working directory, where make is run.
            if not isfile(basename(mappings)):
                os.symlink(mappings, basename(mappings))
            # The three dot-separated tokens before "verbose_mapping.tsv" are
            # reused as the parameter part of the mapping file to build.
            tokens = mappings.split(".")
            params = ".".join(tokens[(-5):(-2)])
            mappings_to_build = f"{basename(seqs_to_polish_path)}.{params}.mapping.tsv"
            mappings = mappings_to_build
            if subsample_max_reads_per_10kbp == -1:
                subsample_max_reads_per_10kbp = NTLINK_SUBSAMPLE_MAX_READS_PER_10KBP
        else:
            # Other provided formats are used as-is; nothing to build.
            mappings_to_build = ""
            if subsample_max_reads_per_10kbp == -1:
                subsample_max_reads_per_10kbp = MINIMAP2_SUBSAMPLE_MAX_READS_PER_10KBP
    else:
        raise ValueError(f"Unknown mapping tool: {mapping_tool}")

    k_values = " ".join(str(k) for k in k_values)

    # Pass the arguments as a list (no shell) so that paths containing spaces
    # or shell metacharacters reach make unchanged. Goal order is preserved.
    make_command = [
        "make",
        "-Rrf",
        GOLDPOLISH_MAKE_FULL_PATH,
        f"t={threads}",
        f"polishing_seqs={polishing_seqs_path}",
        f"seqs_to_polish={seqs_to_polish_path}",
        f"K={k_values}",
        seqs_to_polish_index,
    ]
    if mappings_to_build:
        make_command.append(mappings_to_build)
    make_command.append(polishing_seqs_index)
    if verbose:
        make_command.append("time=true")
    if mapping_tool == MappingTool.NTLINK:
        make_command.append(f"k_ntLink={k_ntlink}")
        make_command.append(f"w_ntLink={w_ntlink}")

    try:
        p = sp.run(
            make_command,
            text=True,
            capture_output=True,
            check=True,
        )
    except sp.CalledProcessError as e:
        btllib.log_error(f"{e.stderr}")
        raise
    btllib.log_info(p.stdout + p.stderr)

    polishing_seqs_index = abspath(polishing_seqs_index)
    seqs_to_polish_index = abspath(seqs_to_polish_index)
    mappings = abspath(mappings)

    btllib.log_info("Indexes and mappings built.")

    return (
        polishing_seqs_index,
        seqs_to_polish_index,
        mappings,
        subsample_max_reads_per_10kbp,
    )


def run_bf_builder(
    bfs_dir,
    seqs_to_polish,
    seqs_to_polish_index,
    mappings,
    polishing_seqs,
    polishing_seqs_index,
    k_values,
    batch_name_input_pipe,
    batch_target_ids_input_ready_pipe,
    mx_max_reads_per_10kbp,
    subsample_max_reads_per_10kbp,
    threads,
):
    """Start the targeted Bloom filter builder and wait until its pipes exist.

    The builder is launched with ``bfs_dir`` as its working directory and is
    registered with ``watch_process``. The call then blocks, polling every two
    seconds, until both communication pipes are present on disk.

    Returns:
        The ``subprocess.Popen`` handle of the builder.
    """
    command = [
        GOLDPOLISH_TARGETED_BFS,
        seqs_to_polish,
        seqs_to_polish_index,
        mappings,
        polishing_seqs,
        polishing_seqs_index,
        str(mx_max_reads_per_10kbp),
        str(subsample_max_reads_per_10kbp),
        str(threads),
        *(str(k) for k in k_values),
    ]
    process = sp.Popen(command, cwd=bfs_dir)
    watch_process(process)

    # The builder counts as ready once both pipes have appeared.
    required_pipes = (batch_name_input_pipe, batch_target_ids_input_ready_pipe)
    while not all(exists(pipe) for pipe in required_pipes):
        time.sleep(2)
    btllib.log_info(f"{GOLDPOLISH_TARGETED_BFS} is ready!")

    return process


def end_bf_builder(batch_name_input_pipe):
    """Write the END_SYMBOL line to the batch name pipe to end the BF builder."""
    with open(batch_name_input_pipe, "w") as pipe:
        pipe.write(f"{END_SYMBOL}\n")


def get_next_batch_of_contigs(reader, output_filepath, batch_size):
    """Write the next batch of sequences from ``reader`` to ``output_filepath``.

    Records are copied until the record whose number closes a batch
    (``record.num % batch_size == batch_size - 1``) has been written, or
    until the reader has no more records.

    Args:
        reader: Sequence reader whose ``read()`` returns a falsy value once
            the input is exhausted.
        output_filepath: File the batch is written to (created even when the
            batch is empty).
        batch_size: Number of sequences per batch.

    Returns:
        ``(reader_done, seq_ids)``. ``reader_done`` is True only when no
        record could be read, i.e. the returned batch is empty; ``seq_ids``
        lists the IDs written, in order.
    """
    seq_ids = []
    with btllib.SeqWriter(output_filepath) as writer:
        while record := reader.read():
            writer.write(record.id, record.comment, record.seq)
            seq_ids.append(record.id)
            if record.num % batch_size == batch_size - 1:
                break

    # A trailing partial batch (sequence count not divisible by batch_size)
    # still holds sequences that must be polished, so only an empty batch is
    # reported as the end of input; the caller's next call then yields the
    # empty, final batch.
    # NOTE(review): this relies on reader.read() continuing to return a falsy
    # value when called again after exhaustion - confirm for btllib.SeqReader.
    reader_done = not seq_ids
    return reader_done, seq_ids


def make_tmp_dir(workspace, prefix, suffix):
    """Create ``<workspace>/<prefix><SEPARATOR><suffix>`` and return its path.

    Raises whatever ``os.mkdir`` raises, e.g. when the directory already exists.
    """
    dir_name = "{}{}{}".format(prefix, SEPARATOR, suffix)
    path = join(workspace, dir_name)
    os.mkdir(path)
    return path


def polish_batch(
    batch_dir,
    batch_name_input_pipe,
    batch_target_ids_input_ready_pipe,
    target_ids_input_pipe,
    bfs_ready_pipe,
    batch_seqs,
    batch_num,
    seq_ids,
    batch_done_pipe,
    bfs_dir,
    k_values,
    workspace,
    prefix,
    threads,
    max_threads,
    verbose,
):
    """Announce a batch to the Bloom filter builder and start its polisher.

    Steps, in order:
      1. write the batch number to the batch name pipe;
      2. block until the "target ids input ready" pipe can be read to its end;
      3. write the batch's sequence IDs, one per line, into ``batch_dir``;
      4. launch the batch polishing executable in ``batch_dir`` and register
         it with ``watch_process`` (this function does not wait for it).
    """
    # Handshake with the Bloom filter builder.
    with open(batch_name_input_pipe, "w") as name_pipe:
        name_pipe.write(f"{batch_num}\n")
    with open(batch_target_ids_input_ready_pipe) as ready_pipe:
        ready_pipe.read()

    seq_ids_file = join(batch_dir, SEQ_IDS_FILENAME)
    with open(seq_ids_file, "w") as ids_out:
        ids_out.writelines(f"{seq_id}\n" for seq_id in seq_ids)

    # One Bloom filter path per k, named after the batch number and k.
    bf_args = [
        f"-b{join(bfs_dir, BF_NAME_TEMPLATE.format(batch_num, k))}" for k in k_values
    ]
    k_args = [f"-k{k}" for k in k_values]

    command = [
        GOLDPOLISH_POLISH_BATCH,
        batch_seqs,
        bfs_dir,
        workspace,
        prefix,
        str(max_threads),
        *k_args,
        *bf_args,
        "--seq-ids",
        seq_ids_file,
        "--bfs-ids-pipe",
        target_ids_input_pipe,
        "--bfs-ready-pipe",
        bfs_ready_pipe,
        "--batch-done-pipe",
        batch_done_pipe,
        "--threads",
        str(threads),
    ]
    if verbose:
        command.append("--verbose")

    watch_process(sp.Popen(command, cwd=batch_dir))


def start_reaper(workspace, prefix, batch_polished_seqs, output_seqs):
    """Start the reaper executable and register it with ``watch_process``.

    The reaper receives, in order: the workspace, the run prefix, the file
    name of a batch's polished sequences and the final output path. This
    function does not wait for it to finish.
    """
    command = [GOLDPOLISH_REAPER, workspace, prefix, batch_polished_seqs, output_seqs]
    watch_process(sp.Popen(command))


def create_end_batch(end_file):
    """Create (or overwrite) ``end_file`` with the single line ``1899``."""
    with open(end_file, "w") as marker:
        # NOTE(review): the value looks arbitrary - confirm whether any
        # consumer reads the contents or only checks that the file exists.
        marker.write("1899\n")


def confirm_polishing_done(polishing_done_pipepath):
    """Write the character ``1`` (no trailing newline) to the given pipe path."""
    with open(polishing_done_pipepath, "w") as done_pipe:
        done_pipe.write("1")


def polish_seqs(
    seqs_to_polish,
    polishing_seqs,
    output_seqs,
    k_values,
    batch_size,
    workspace,
    mx_max_reads_per_10kbp,
    subsample_max_reads_per_10kbp,
    threads,
    mapping_tool,
    mappings,
    verbose,
    k_ntlink,
    w_ntlink,
):
    """Run the full polishing pipeline on ``seqs_to_polish``.

    Builds indexes and mappings, starts the Bloom filter builder and the
    reaper, then reads the input in batches of ``batch_size`` sequences and
    starts one batch polishing process per batch. When the input is
    exhausted, an end marker batch is written, the Bloom filter builder is
    told to end and waited for, and the Bloom filter directory is removed.

    Args:
        seqs_to_polish: Path of the sequences to polish.
        polishing_seqs: Path of the sequences to polish with.
        output_seqs: Path handed to the reaper as the final output.
        k_values: k-mer sizes used for polishing.
        batch_size: Number of sequences per batch.
        workspace: Directory in which the temporary directories are created.
        mx_max_reads_per_10kbp: Forwarded to the Bloom filter builder.
        subsample_max_reads_per_10kbp: Subsampling limit; may be replaced by
            a default in ``build_indexes_and_mappings``.
        threads: Total thread budget.
        mapping_tool: ``MappingTool`` member selecting the mapping source.
        mappings: Path of user-provided mappings, if any.
        verbose: Forwarded to the index/mapping build and batch polishers.
        k_ntlink: ntLink k-mer size.
        w_ntlink: ntLink window size.
    """
    # Random name identifying this run; used in temporary directory names and
    # passed to the helper processes.
    prefix = get_random_name()

    autoclean(workspace, prefix)

    # Build indexes and mappings
    (
        polishing_seqs_index,
        seqs_to_polish_index,
        mappings,
        subsample_max_reads_per_10kbp,
    ) = build_indexes_and_mappings(
        polishing_seqs,
        seqs_to_polish,
        k_values,
        mapping_tool,
        mappings,
        subsample_max_reads_per_10kbp,
        threads,
        verbose,
        k_ntlink,
        w_ntlink,
    )
    btllib.check_error(
        subsample_max_reads_per_10kbp <= 0, "Subsample max reads per 10kbp is <=0"
    )
    btllib.log_info(f"Subsampling mapped reads to {subsample_max_reads_per_10kbp}")

    # Where to build the Bloom filters
    bfs_dir = make_tmp_dir(workspace, prefix, BFS_DIRNAME)

    # Pipes for communicating with the process building the Bloom filters
    batch_name_input_pipe = join(bfs_dir, BATCH_NAME_INPUT_PIPE)
    batch_target_ids_input_ready_pipe = join(bfs_dir, BATCH_TARGET_IDS_INPUT_READY_PIPE)

    bf_builder_threads = get_bf_builder_threads_num(threads)

    # Blocks until the builder has created both pipes.
    build_targeted_bfs_process = run_bf_builder(
        bfs_dir,
        seqs_to_polish,
        seqs_to_polish_index,
        mappings,
        polishing_seqs,
        polishing_seqs_index,
        k_values,
        batch_name_input_pipe,
        batch_target_ids_input_ready_pipe,
        mx_max_reads_per_10kbp,
        subsample_max_reads_per_10kbp,
        bf_builder_threads,
    )

    # Set up the shared bookkeeping for this run; the thread counter starts
    # from the builder's share.
    create_simultaneous_batch_processes(workspace, prefix)
    create_threads_in_use(workspace, prefix, bf_builder_threads)

    # Temporary name for the seq(s) to polish
    batch_seqs = "batch.fa"
    batch_polished_seqs = "batch.ntedited.prepd.sealer_scaffold.upper.fa"
    batch_num = 0

    start_reaper(workspace, prefix, batch_polished_seqs, output_seqs)

    btllib.log_info("Polishing batches...")
    reader_done = False
    with btllib.SeqReader(seqs_to_polish, btllib.SeqReaderFlag.LONG_MODE) as reader:
        while not reader_done:
            # Throttle: poll until fewer than MAX_SIMULTANEOUS_BATCH_PROCESSES
            # batch processes are reported for this run.
            while True:
                simultaneous_batch_processes = get_simultaneous_batch_processes(
                    workspace, prefix
                )
                if simultaneous_batch_processes <= MAX_SIMULTANEOUS_BATCH_PROCESSES - 1:
                    break
                time.sleep(SPAWN_DELAY)

            # Every batch, including the final end marker batch, gets its own
            # directory named after the batch number.
            batch_dir = make_tmp_dir(workspace, prefix, batch_num)

            reader_done, seq_ids = get_next_batch_of_contigs(
                reader, join(batch_dir, batch_seqs), batch_size
            )

            batch_done_pipe = join(batch_dir, BATCH_DONE_PIPE)
            os.mkfifo(batch_done_pipe)

            if reader_done:
                # End of input: this batch carries no sequences. Write the end
                # marker file and signal on the batch's done pipe.
                # NOTE(review): the assert requires that end of input is only
                # ever reported together with an empty batch - confirm this
                # holds when the number of sequences is not a multiple of
                # batch_size.
                assert len(seq_ids) == 0
                polishing_over_file = join(batch_dir, POLISHING_OVER_FILE)
                create_end_batch(polishing_over_file)
                confirm_polishing_done(batch_done_pipe)
            else:
                # Per-batch pipe paths inside the Bloom filter directory.
                target_ids_input_pipe = join(
                    bfs_dir, f"{batch_num}{SEPARATOR}{TARGET_IDS_INPUT_PIPE}"
                )
                bfs_ready_pipe = join(
                    bfs_dir, f"{batch_num}{SEPARATOR}{BFS_READY_PIPE}"
                )

                # Count the batch before launching its process.
                update_simultaneous_batch_processes(workspace, prefix, 1)
                polish_batch(
                    batch_dir,
                    batch_name_input_pipe,
                    batch_target_ids_input_ready_pipe,
                    target_ids_input_pipe,
                    bfs_ready_pipe,
                    batch_seqs,
                    batch_num,
                    seq_ids,
                    batch_done_pipe,
                    bfs_dir,
                    k_values,
                    workspace,
                    prefix,
                    BATCH_THREADS,
                    threads,
                    verbose,
                )

            batch_num += 1

    btllib.log_info("Done polishing batches, ending BF builder process...")
    end_bf_builder(batch_name_input_pipe)
    build_targeted_bfs_process.wait()
    # Best-effort cleanup of the Bloom filter directory.
    shutil.rmtree(bfs_dir, ignore_errors=True)
    btllib.log_info("Polisher done")

def run_goldpolish_target(
    fasta,
    reads,
    output,
    length,
    k,
    w,
    minimap2,
    sensitive,
    bed,
    benchmark,
    target_dev,
):
    """Run the targeted polishing executable and exit the interpreter.

    This function never returns: it ends the process with exit status 0 when
    the targeted polisher succeeds and with a non-zero status when it fails.

    Args:
        fasta: Path of the sequences to polish (``--fasta``).
        reads: Path of the sequences to polish with (``--reads``).
        output: Output path (``--output``).
        length: Flank length (``--length``).
        k: ntLink k-mer size (``--k_ntlink``).
        w: ntLink window size (``--w_ntlink``).
        minimap2: When true, ``--minimap2`` is passed; otherwise ``--ntLink``.
        sensitive: Value forwarded after ``--sensitive``.
        bed: Optional BED file path (``--bed``); omitted when falsy.
        benchmark: When true, ``--benchmark`` is passed.
        target_dev: When true, ``--dev`` is passed.
    """
    btllib.log_info("Starting targeted polishing")

    mapper = "--minimap2" if minimap2 else "--ntLink"

    # Build the argument vector directly instead of formatting a string and
    # re-splitting it, so paths containing spaces or quotes stay intact.
    command = [
        GOLDPOLISH_TARGET,
        "--fasta",
        str(fasta),
        "--reads",
        str(reads),
        "--output",
        str(output),
        "--length",
        str(length),
        mapper,
        "--k_ntlink",
        str(k),
        "--w_ntlink",
        str(w),
        "--sensitive",
        str(sensitive),
    ]
    if bed:
        command += ["--bed", str(bed)]
    if benchmark:
        command.append("--benchmark")
    if target_dev:
        command.append("--dev")

    returncode = sp.call(command)
    if returncode != 0:
        # Propagate the failure instead of reporting success.
        btllib.log_error(f"Targeted polishing failed with exit code {returncode}")
        # A negative code means the child was killed by a signal; map it to a
        # generic failure status.
        sys.exit(returncode if returncode > 0 else 1)

    btllib.log_info("Targeted polishing finished")
    sys.exit(0)

if __name__ == "__main__":
    args = get_cli_args()

    # Work in shared memory if possible; otherwise fall back to the current
    # working directory and warn.
    shared_mem_available = isdir(args.shared_mem)
    workspace = args.shared_mem if shared_mem_available else os.getcwd()
    if not shared_mem_available:
        btllib.log_warning(
            f"GoldPolish: {args.shared_mem} not present. Polishing might run slower."
        )

    # Targeted mode hands over to the targeted polisher, which exits the
    # interpreter itself; the regular pipeline below is then not reached.
    if args.target:
        run_goldpolish_target(
            fasta=args.seqs_to_polish,
            reads=args.polishing_seqs,
            output=args.output_seqs,
            length=args.length,
            k=args.k_ntlink,
            w=args.w_ntlink,
            minimap2=args.minimap2,
            sensitive=args.sensitive,
            bed=args.bed,
            benchmark=args.benchmark,
            target_dev=args.target_dev,
        )

    # ntLink is the default; --minimap2 takes precedence over --mappings.
    if args.minimap2:
        mapping_tool = MappingTool.MINIMAP2
    elif args.mappings:
        mapping_tool = MappingTool.MAPPINGS_PROVIDED
    else:
        mapping_tool = MappingTool.NTLINK

    polish_seqs(
        seqs_to_polish=args.seqs_to_polish,
        polishing_seqs=args.polishing_seqs,
        output_seqs=args.output_seqs,
        k_values=args.k,
        batch_size=args.bsize,
        workspace=workspace,
        mx_max_reads_per_10kbp=args.mx_max_reads_per_10kbp,
        subsample_max_reads_per_10kbp=args.subsample_max_reads_per_10kbp,
        threads=args.threads,
        mapping_tool=mapping_tool,
        mappings=args.mappings,
        verbose=args.verbose,
        k_ntlink=args.k_ntlink,
        w_ntlink=args.w_ntlink,
    )