#!/usr/bin/env python3

import argparse
import time
from os.path import (
    join,
    exists,
    splitext,
    getsize,
)
import shutil

import btllib

from goldpolish_utils import bind_to_parent

# Polling interval used while waiting for a batch's done-pipe to appear.
RUSH_DELAY = 1  # Seconds

# File name, inside each batch directory, that is waited on and then read
# to completion before the batch is processed.
BATCH_DONE_PIPE = "polishing_done"
# File name whose presence in a batch directory means no more batches follow.
POLISHING_OVER_FILE = "gg"
# Placed between the instance prefix and the batch number in directory names.
SEPARATOR = "-"


def get_cli_args():
    """Parse the reaper's command-line arguments from ``sys.argv``.

    Returns:
        argparse.Namespace with the string attributes ``workspace_path``,
        ``prefix``, ``batch_polished_seqs`` and ``output_seqs``.

    Raises:
        SystemExit: raised by argparse on ``-h``/``--help`` or invalid usage.
    """
    # The first positional parameter of ArgumentParser is `prog`, not
    # `description`; pass the text by keyword so the usage line shows the real
    # program name and the help output carries the description.
    parser = argparse.ArgumentParser(description="Reap the finished batches.")
    parser.add_argument(
        "workspace_path", help="The path to where all the polishing is taking place."
    )
    parser.add_argument("prefix", help="Unique prefix of this polishing instance.")
    parser.add_argument(
        "batch_polished_seqs", help="Name of the file with batch polished sequences."
    )
    parser.add_argument("output_seqs", help="Filepath to write polished seqs to.")
    return parser.parse_args()


def write_batch_results(polished_seqs, writer):
    """Copy every record of one batch's polished-sequence file to ``writer``.

    Args:
        polished_seqs: Path of the batch's polished sequences file.
        writer: Open sequence writer; each record is passed to its ``write``
            method as ``(id, comment, seq)``.

    An empty (zero-byte) file is reported through ``btllib.check_error``
    (presumably fatal -- confirm against btllib) before any reading starts.
    """
    file_is_empty = getsize(polished_seqs) <= 0
    empty_msg = f"Polished seqs file is empty: {polished_seqs}"
    btllib.check_error(file_is_empty, empty_msg)

    read_flags = btllib.SeqReaderFlag.LONG_MODE
    with btllib.SeqReader(polished_seqs, read_flags) as seq_source:
        for rec in seq_source:
            # Only id, comment and sequence are forwarded to the output.
            writer.write(rec.id, rec.comment, rec.seq)


def wait_on_batch(batch_done_pipe):
    """Open ``batch_done_pipe`` for reading and consume it until end-of-file.

    The content is discarded; only reaching EOF matters. Presumably the path
    is a FIFO, so the call blocks until the writing side closes it -- confirm
    against the code that creates it.
    """
    with open(batch_done_pipe) as done_signal:
        # Drain everything; the data itself carries no meaning here.
        done_signal.read()


def start_reapin(workspace, prefix, batch_polished_seqs, output_seqs):
    """Collect polished sequences from consecutive batch directories.

    Batch directories are named ``<prefix><SEPARATOR><n>`` under ``workspace``
    with ``n`` starting at 0. For each one, in order, this function:

    1. polls every ``RUSH_DELAY`` seconds until the batch's done-pipe exists,
    2. reads the done-pipe to completion via ``wait_on_batch``,
    3. stops after this batch if the polishing-over marker file is present;
       otherwise appends the batch's polished sequences to the output,
    4. removes the batch directory (removal errors are ignored).

    Args:
        workspace: Directory containing the batch directories.
        prefix: Prefix shared by this instance's batch directory names.
        batch_polished_seqs: File name of the polished sequences inside each
            batch directory.
        output_seqs: Path of the combined output file.
    """
    with btllib.SeqWriter(output_seqs) as writer:
        batch_num = 0
        while True:
            batch_dir = join(workspace, f"{prefix}{SEPARATOR}{batch_num}")
            done_pipe = join(batch_dir, BATCH_DONE_PIPE)

            # The batch directory or its pipe may not have been created yet.
            while not exists(done_pipe):
                time.sleep(RUSH_DELAY)

            wait_on_batch(done_pipe)

            last_batch = exists(join(batch_dir, POLISHING_OVER_FILE))
            if not last_batch:
                write_batch_results(join(batch_dir, batch_polished_seqs), writer)

            # Clean up the directory whether or not this was the final batch.
            shutil.rmtree(batch_dir, ignore_errors=True)

            if last_batch:
                break
            batch_num += 1


if __name__ == "__main__":
    # Script entry point: parse the CLI, then reap batches until told to stop.
    args = get_cli_args()

    # NOTE(review): bind_to_parent is imported from goldpolish_utils;
    # presumably it ties this process's lifetime to its parent -- confirm.
    bind_to_parent()

    start_reapin(
        workspace=args.workspace_path,
        prefix=args.prefix,
        batch_polished_seqs=args.batch_polished_seqs,
        output_seqs=args.output_seqs,
    )
