#!/usr/bin/env python3

import os
import argparse
from os.path import join, dirname, realpath, abspath, isfile, exists, splitext, basename
import subprocess as sp
import time

import btllib

from goldpolish_utils import (
    bind_to_parent,
    update_threads_in_use,
    compare_update_threads_in_use,
    update_simultaneous_batch_processes,
)

GOLDPOLISH_MAKE = "goldpolish-make"
GOLDPOLISH_MAKE_FULL_PATH = f"{os.path.dirname(os.path.realpath(__file__))}/{GOLDPOLISH_MAKE}"
GOLDPOLISH_HOLD = "goldpolish-hold"
RETRY_DELAY = 0.2


def get_cli_args():
    """Parse and validate the command-line arguments of this script.

    Returns:
        argparse.Namespace: the parsed arguments, with ``seqs_to_polish``
        converted to an absolute path. ``k`` and ``b`` are lists of strings
        collected from repeated ``-k`` / ``-b`` options.

    Exits:
        Through ``parser.error`` (usage message on stderr, exit status 2)
        when no ``-k`` value was supplied.
    """
    parser = argparse.ArgumentParser(
        description="Request the Bloom filters and run the polishing protocol on the given sequences."
    )
    parser.add_argument("seqs_to_polish")
    parser.add_argument("bfs_dir")
    parser.add_argument("workspace")
    parser.add_argument("prefix")
    parser.add_argument("max_threads", type=int)
    parser.add_argument(
        "-k", action="append", default=[], help="List of k values to use for polishing."
    )
    parser.add_argument(
        "-b",
        action="append",
        default=[],
        help="Bloom filter paths, corresponding to provided k values.",
    )
    parser.add_argument("--seq-ids", default="seq_ids")
    parser.add_argument("--bfs-ids-pipe", default="targeted_input")
    parser.add_argument("--bfs-ready-pipe", default="targeted_ready")
    parser.add_argument("--batch-done-pipe", default="batch_done")
    parser.add_argument("-t", "--threads", type=int, default=2)
    parser.add_argument("-v", "--verbose", action="store_true")

    args = parser.parse_args()

    # Report the problem through argparse rather than `assert`: assertions are
    # removed when Python runs with -O, which would silently skip this check.
    if not args.k:
        parser.error("At least one k value must be specified.")

    args.seqs_to_polish = abspath(args.seqs_to_polish)

    return args


def get_seq_ids(record_ids_filepath):
    """Return the lines of the given file with surrounding whitespace removed.

    Every line of the file produces one list entry, so a blank line yields
    an empty string.
    """
    with open(record_ids_filepath) as handle:
        # Consume the iterator while the file is still open.
        return list(map(str.strip, handle))


def get_bfs_ready(bfs_dirpath, id_pipename, seq_ids, ready_pipename):
    """Send the sequence IDs, then read the ready file until end of input.

    Both names are resolved inside ``bfs_dirpath``. The IDs are written one
    per line to ``id_pipename``; afterwards ``ready_pipename`` is read in
    full and its content discarded. The parameter names suggest both are
    named pipes, in which case the read presumably blocks until the other
    side closes its end -- confirm against the producer.
    """
    ids_path = join(bfs_dirpath, id_pipename)
    with open(ids_path, "w") as ids_out:
        ids_out.writelines(seq_id + "\n" for seq_id in seq_ids)

    ready_path = join(bfs_dirpath, ready_pipename)
    with open(ready_path) as ready_in:
        ready_in.read()


def run_polishing(seqs_to_polish, bfs, k_values, threads):
    """Run the polishing makefile for one file of sequences.

    Invokes ``make -Rrf GOLDPOLISH_MAKE_FULL_PATH`` and asks it to build
    ``<seqs_to_polish without extension>.ntedited.prepd.sealer_scaffold.upper.fa``.

    Args:
        seqs_to_polish: Path of the sequences file; passed as the make
            variable ``seqs_to_polish``.
        bfs: Iterable of Bloom filter path strings; space-joined into the
            make variable ``bfs``.
        k_values: Iterable of k value strings; space-joined into the make
            variable ``K``.
        threads: Passed as the make variable ``t``.

    Returns:
        tuple[str, str]: the captured stdout and stderr of ``make``.

    Raises:
        subprocess.CalledProcessError: if ``make`` exits with a non-zero
            status. Its stderr is logged with ``btllib.log_error`` first.
    """
    polished_seqs = (
        f"{splitext(seqs_to_polish)[0]}.ntedited.prepd.sealer_scaffold.upper.fa"
    )

    joined_bfs = " ".join(bfs)
    joined_k_values = " ".join(k_values)

    # Pass an argument vector instead of a shell string: each element reaches
    # make verbatim, so paths containing spaces, quotes or shell
    # metacharacters need no quoting and cannot be interpreted by a shell.
    command = [
        "make",
        "-Rrf",
        GOLDPOLISH_MAKE_FULL_PATH,
        f"seqs_to_polish={seqs_to_polish}",
        f"bfs={joined_bfs}",
        f"K={joined_k_values}",
        f"t={threads}",
        polished_seqs,
    ]

    try:
        sealer_protocol_process = sp.run(
            command,
            text=True,
            capture_output=True,
            check=True,
        )
    except sp.CalledProcessError as e:
        btllib.log_error(f"{e.stderr}")
        # Bare raise keeps the original traceback intact.
        raise

    return sealer_protocol_process.stdout, sealer_protocol_process.stderr


def cleanup_bfs(bfs_dirpath, bfs):
    """Delete every file named in ``bfs``, each resolved against ``bfs_dirpath``.

    Files are removed in the order given; an error from ``os.remove``
    propagates and leaves the remaining files untouched.
    """
    bf_paths = (join(bfs_dirpath, bf_name) for bf_name in bfs)
    for bf_path in bf_paths:
        os.remove(bf_path)


def confirm_batch_done(batch_done_pipepath):
    """Start the ``GOLDPOLISH_HOLD`` program on the batch-done pipe path.

    The child process is launched with ``Popen`` and is not waited on;
    nothing is returned.
    """
    hold_command = [GOLDPOLISH_HOLD, batch_done_pipepath]
    sp.Popen(hold_command)

def print_info(stdout, stderr):
    """Log a condensed summary extracted from the polishing stderr.

    Two kinds of marker lines are recognised in ``stderr``:

    * a line containing ``"No start/goal kmer"`` opens a parenthesised,
      comma-separated group made of that line and the 11 lines after it;
    * a line containing ``"Gaps closed = "`` is emitted followed by the next
      line in parentheses.

    Lines consumed as part of a group are not checked for markers. The
    assembled text is sent to ``btllib.log_info`` in a single call.
    ``stdout`` is accepted but not used.
    """
    pieces = []
    remaining = 0  # how many following lines still belong to the open group
    for raw_line in stderr.split("\n"):
        text = raw_line.strip()
        if remaining:
            remaining -= 1
            # The last line of a group closes the parenthesis.
            pieces.append(text + (", " if remaining else ")\n"))
            continue
        if "No start/goal kmer" in text:
            remaining = 11
            pieces.append(f"({text}, ")
        elif "Gaps closed = " in text:
            remaining = 1
            pieces.append(f"{text} (")
    btllib.log_info("".join(pieces))


def main():
    """Polish one batch of sequences from start to finish.

    Parses the command line, requests the Bloom filters for the batch's
    sequence IDs, waits for a thread allocation, runs the polishing
    makefile, removes the Bloom filter files and finally signals that the
    batch is done. The steps run strictly in the order written below.
    """
    cli = get_cli_args()

    bind_to_parent()

    # Hand this batch's sequence IDs over and wait on the ready file.
    batch_ids = get_seq_ids(cli.seq_ids)
    get_bfs_ready(cli.bfs_dir, cli.bfs_ids_pipe, batch_ids, cli.bfs_ready_pipe)

    # Poll until compare_update_threads_in_use returns a truthy value --
    # presumably meaning the requested threads fit under max_threads and were
    # reserved; confirm in goldpolish_utils.
    while True:
        if compare_update_threads_in_use(
            cli.workspace, cli.prefix, cli.max_threads, cli.threads
        ):
            break
        time.sleep(RETRY_DELAY)

    make_stdout, make_stderr = run_polishing(
        cli.seqs_to_polish, cli.b, cli.k, cli.threads
    )

    # NOTE(review): this call and everything after it is skipped when
    # run_polishing raises -- confirm that is acceptable for the shared
    # thread counter.
    update_threads_in_use(cli.workspace, cli.prefix, -cli.threads)

    cleanup_bfs(cli.bfs_dir, cli.b)

    if cli.verbose:
        print_info(make_stdout, make_stderr)

    update_simultaneous_batch_processes(cli.workspace, cli.prefix, -1)

    confirm_batch_done(cli.batch_done_pipe)


if __name__ == "__main__":
    main()
