#!/usr/bin/env python3

import argparse
import os
import subprocess
import sys
import shutil
import urllib.request
import re
from datetime import datetime
from typing import List

def log(msg: str, verbose: bool = True) -> None:
    """Write a timestamped message to standard error.

    Args:
        msg: Text to emit after the timestamp.
        verbose: When false, the call is a no-op.
    """
    if not verbose:
        return
    stamp = datetime.now().strftime("[%Y-%m-%d %H:%M:%S]")
    line = f"{stamp} {msg}\n"
    sys.stderr.write(line)

def run_help(binary_name: str) -> None:
    """Invoke ``binary_name -h`` and let its output go to the terminal.

    Raises:
        subprocess.CalledProcessError: If the program exits non-zero.
    """
    help_cmd = [binary_name, "-h"]
    subprocess.check_call(help_cmd)

def list_available_lineages(base_url: str) -> None:
    """Print the lineage datasets linked from the index page at ``base_url``.

    The page is fetched, every ``href`` of the form
    ``<lineage>_odb<N>...tar.gz`` is collected, and the unique
    (lineage, odb version) pairs are printed in sorted order.
    """
    log(f"Fetching lineage list from {base_url}...")
    with urllib.request.urlopen(base_url) as response:
        page = response.read().decode("utf-8")

    href_re = re.compile(r'href="(.*?)_(odb\d+).*?\.tar\.gz"')
    # Both groups always participate in a match, so groups() yields the same
    # (lineage, odb) tuples that findall() would.
    found = {m.groups() for m in href_re.finditer(page)}

    if found:
        print("Available lineages:")
        for lineage, odb in sorted(found):
            print(f"  {lineage} (version: {odb})")
        print("\nTo use one of these, pass --lineage <name> and optionally --odb-version <version number, e.g., 12>")
        return

    print("No lineages found.")

def download_and_prepare_lineage(lineage_name: str, data_dir: str, odb_version: str, busco_website: str, dry_run: bool, verbose: bool = False) -> str:
    """Download a lineage tarball and turn its HMMs into one FASTA file.

    The index page at ``busco_website`` is searched for a tarball named
    ``<lineage_name>_<odb_version>.<YYYY-MM-DD>.tar.gz``. The tarball is
    fetched with ``wget``, only its ``hmms`` sub-directory is extracted, and
    ``hmmemit -c`` is run on every file in it. The combined output is
    written to ``<data_dir>/<lineage_name>_<odb_version>.faa``. The tarball
    and the extracted HMMs are removed afterwards.

    Args:
        lineage_name: Lineage name as it appears in the tarball file name.
        data_dir: Directory that receives the download and the FASTA file.
        odb_version: Version tag including its prefix, e.g. ``"odb12"``.
        busco_website: URL of the index page that links to the tarballs.
        dry_run: When true, nothing is fetched or written; only the path
            that would be produced is returned.
        verbose: Enables the progress messages for extraction and cleanup.

    Returns:
        Path of the FASTA file (which does not exist after a dry run).

    Raises:
        ValueError: If the index page links no matching tarball.
        FileNotFoundError: If the tarball did not contain an ``hmms`` folder.
        subprocess.CalledProcessError: If ``wget`` or ``tar`` fails.
    """
    lineage_dir_name = f"{lineage_name}_{odb_version}"
    output_fasta = os.path.join(data_dir, f"{lineage_dir_name}.faa")

    if dry_run:
        # Nothing is produced in a dry run, so do not claim the FASTA is ready.
        log(f"Dry run: would download lineage {lineage_name} version {odb_version} to {data_dir}")
        return output_fasta

    with urllib.request.urlopen(busco_website) as response:
        html = response.read().decode("utf-8")

    # Escape the user-supplied parts so regex metacharacters in a lineage
    # name are matched literally.
    pattern = re.compile(
        rf'href="({re.escape(lineage_dir_name)}\.\d{{4}}-\d{{2}}-\d{{2}}\.tar\.gz)"'
    )
    matches = pattern.findall(html)
    if not matches:
        raise ValueError(f"No dataset found for {lineage_name}_{odb_version}")

    # All candidates share the same prefix and carry an ISO date, so the
    # lexicographic maximum is the most recent tarball.
    tarball_name = max(matches)
    # Tolerate a base URL given with or without a trailing slash.
    tarball_url = busco_website.rstrip("/") + "/" + tarball_name
    tarball_path = os.path.join(data_dir, tarball_name)

    # The target directory must exist before wget writes into it.
    os.makedirs(data_dir, exist_ok=True)

    log(f"Downloading with wget: {tarball_url}")
    subprocess.run(["wget", "-O", tarball_path, tarball_url], check=True)

    hmm_path = f"{lineage_dir_name}/hmms"
    if shutil.which("pigz"):
        log(f"Using pigz to extract {hmm_path}...", verbose)
        tar_cmd = ["tar", "--use-compress-program=pigz", "-xf", tarball_name, hmm_path]
    else:
        log(f"Using gzip to extract {hmm_path}...", verbose)
        tar_cmd = ["tar", "-xzf", tarball_name, hmm_path]
    # Run tar inside data_dir via cwd= so this process's own working
    # directory is never changed.
    subprocess.run(tar_cmd, check=True, cwd=data_dir)

    hmms_dir = os.path.join(data_dir, hmm_path)
    if not os.path.isdir(hmms_dir):
        raise FileNotFoundError(f"hmms directory not found: {hmms_dir}")

    log(f"Running hmmemit -c on HMMs in {hmms_dir}...", verbose)

    with open(output_fasta, "w") as out_f:
        # Sorted so the order of records in the FASTA is reproducible.
        for hmm_file in sorted(os.listdir(hmms_dir)):
            hmm_path_full = os.path.join(hmms_dir, hmm_file)
            try:
                result = subprocess.run(["hmmemit", "-c", hmm_path_full], capture_output=True, check=True, text=True)
            except subprocess.CalledProcessError:
                # Best effort: one bad HMM must not abort the whole lineage.
                log(f"Warning: hmmemit failed for {hmm_file}")
            else:
                out_f.write(result.stdout)

    log("Cleaning up .tar.gz and extracted hmms folder...", verbose)
    os.remove(tarball_path)
    shutil.rmtree(hmms_dir)
    # Drop the now-empty <lineage>_<odb> folder that held the hmms directory.
    parent = os.path.dirname(hmms_dir)
    if os.path.isdir(parent) and not os.listdir(parent):
        shutil.rmtree(parent)

    log(f"Lineage FASTA ready: {output_fasta}")
    return output_fasta

def is_database_built(mibf_path: str, verbose: bool = False) -> bool:
    """Report whether a regular file exists at ``mibf_path``.

    The outcome is always logged; ``verbose`` is accepted but not consulted.
    """
    found = os.path.isfile(mibf_path)
    outcome = "Located" if found else "Could not locate"
    log(f"{outcome} miBF at: {mibf_path}")
    return found

def wrap_with_time(cmd: List[str], label: str, output_prefix: str, track_time: bool) -> List[str]:
    """Optionally prefix ``cmd`` with a ``/usr/bin/time`` invocation.

    When ``track_time`` is set and ``/usr/bin/time`` can be found, a new list
    is returned that runs ``cmd`` under it and writes the report to
    ``<label>_<output_prefix>.time``. Otherwise ``cmd`` is returned as is.
    """
    if not track_time or not shutil.which("/usr/bin/time"):
        return cmd
    report_file = f"{label}_{output_prefix}.time"
    timing_prefix = ["/usr/bin/time", "-pv", "-o", report_file]
    return timing_prefix + cmd

def run_make_mibf(args: argparse.Namespace) -> None:
    """Build the miBF for ``args.reference`` inside its database directory.

    The database directory is ``<args.db_dir>/<reference stem>``. The
    reference is placed there (moved when it came from ``--lineage``,
    symlinked otherwise) and ``make_mibf`` is run with that directory as its
    working directory, so ``-r`` and ``-o`` are given as bare names.

    Side effects: on the lineage path ``args.reference`` is updated to the
    moved file's new location. With ``args.dry_run`` the command is only
    logged and nothing is created.

    Raises:
        subprocess.CalledProcessError: If ``make_mibf`` exits non-zero.
    """
    log("Starting make_mibf...")
    ref_name = os.path.basename(args.reference)
    base_name = os.path.splitext(ref_name)[0]
    full_db_path = os.path.join(args.db_dir, base_name)
    # NOTE(review): "-h" carries the hash count here, while run_help() uses
    # "-h" to request help from the same binary — confirm against make_mibf.
    cmd = [
        "make_mibf", "-r", ref_name, "-o", base_name,
        "-k", str(args.kmer), "-h", str(args.hash), "-t", str(args.threads)
    ]
    if args.verbose:
        cmd.append("-v")
    if args.debug:
        cmd.append("--debug")

    cmd = wrap_with_time(cmd, "make_mibf", args.output, args.track_time)

    log(f"Running command in {full_db_path}: {' '.join(cmd)}")
    if not args.dry_run:
        os.makedirs(full_db_path, exist_ok=True)
        ref_path = os.path.join(full_db_path, ref_name)
        source_path = os.path.abspath(args.reference)

        # A dangling symlink left by an earlier run makes os.path.exists()
        # return False, after which os.symlink()/shutil.move() would fail on
        # the existing entry. Remove it first, unless it is the reference
        # the caller passed in.
        if (
            os.path.islink(ref_path)
            and not os.path.exists(ref_path)
            and os.path.abspath(ref_path) != source_path
        ):
            os.remove(ref_path)
            log(f"Removed dangling symlink: {ref_path}", args.verbose)

        if not os.path.exists(ref_path):
            if args.lineage:
                shutil.move(source_path, ref_path)
                log(f"Moved lineage-derived reference to: {ref_path}", args.verbose)
                args.reference = ref_path
            else:
                os.symlink(source_path, ref_path)
                log(f"Created symlink for reference file: {ref_path}", args.verbose)
        else:
            log(f"Reference file already exists at: {ref_path}", args.verbose)

        subprocess.check_call(cmd, cwd=full_db_path)
    else:
        log("Dry run: skipping execution")

    log("Finished building miBF.")

def run_aakomp(args: argparse.Namespace) -> None:
    """Assemble the ``aakomp`` command line from ``args`` and execute it.

    With ``args.dry_run`` the command is logged but not run.

    Raises:
        subprocess.CalledProcessError: If ``aakomp`` exits non-zero.
    """
    log("Starting aaKomp...")

    command = ["aakomp"]
    command += ["-i", args.input, "-o", args.output, "-r", args.reference]
    command += ["-t", str(args.threads), "-k", str(args.kmer), "-h", str(args.hash)]
    command += ["-l", str(args.lower_bound), "--mibf_path", args.mibf_path]
    command += ["--rescue_kmer", str(args.rescue_kmer), "--max_offset", str(args.max_offset)]
    # Optional switches, appended in the same order as the flags are listed.
    optional_flags = (("-v", args.verbose), ("--debug", args.debug))
    command += [flag for flag, enabled in optional_flags if enabled]

    command = wrap_with_time(command, "aakomp", args.output, args.track_time)

    log(f"Running command: {' '.join(command)}")
    if args.dry_run:
        log("Dry run: skipping execution")
    else:
        subprocess.check_call(command)
    log("Finished running aaKomp.")

def run_aakomp_analysis(args: argparse.Namespace) -> None:
    """Run ``aakomp_plot.R`` on the GFF produced for ``args.output``.

    The script receives ``<output>.gff``, the lower bound, and
    ``<output>_score.txt``. With ``args.dry_run`` it is only logged.

    Raises:
        subprocess.CalledProcessError: If the script exits non-zero.
    """
    log("Starting aaKomp analysis...")
    gff_path = f"{args.output}.gff"
    score_path = f"{args.output}_score.txt"
    plot_cmd = wrap_with_time(
        ["aakomp_plot.R", gff_path, str(args.lower_bound), score_path],
        "aakomp_plot.R",
        args.output,
        args.track_time,
    )

    log(f"Running command: {' '.join(plot_cmd)}")
    if args.dry_run:
        log("Dry run: skipping execution")
    else:
        subprocess.check_call(plot_cmd)

    log("Finished post-analysis.")

def main():
    """Parse the command line and drive the lineage / miBF / aaKomp steps.

    Flow: handle the informational flags (list lineages, tool help), validate
    the inputs, resolve the reference (reusing or downloading a lineage FASTA
    when ``--lineage`` is given), build the miBF if its file is missing, run
    aaKomp, and optionally run the plotting step.
    """
    parser = argparse.ArgumentParser(description="aaKomp driver script: download lineage, build miBf, run aaKomp")
    add = parser.add_argument

    add("--help-aakomp", action="store_true")
    add("--help-mibf", action="store_true")
    add("-i", "--input", help="Input file name")
    add("-o", "--output", default="_")
    add("-r", "--reference")
    add("-t", "--threads", type=int, default=48)
    add("-v", "--verbose", action="store_true")
    add("--debug", action="store_true")
    add("-H", "--hash", type=int, default=9)
    add("-k", "--kmer", type=int, default=9)
    add("-l", "--lower-bound", type=float, default=0.7)
    add("--rescue-kmer", type=int, default=4)
    add("--max-offset", type=int, default=2)
    add("--lineage")
    add("--db-dir", default="./")
    add("--dry-run", action="store_true")
    add("--track-time", action="store_true")
    add("--odb-version", default="12", help="BUSCO odb version to use (default: 12)")
    add("--list-lineages", action="store_true", help="List available BUSCO lineages")
    add("--visualise", action="store_true", help="Visualise the cumulative distribution function")
    add("--version", action="version", version="v1.0.0")

    args = parser.parse_args()

    busco_website = "https://busco-data.ezlab.org/v5/data/lineages/"

    # Informational modes: do the one thing asked for, then exit.
    if args.list_lineages:
        list_available_lineages(busco_website)
        sys.exit(0)

    for requested, binary in ((args.help_aakomp, "aakomp"), (args.help_mibf, "make_mibf")):
        if requested:
            run_help(binary)
            sys.exit(0)

    # Exactly one of --reference / --lineage is required, plus a real input.
    if not (args.reference or args.lineage):
        parser.error("--reference or --lineage must be provided.")
    if args.reference and args.lineage:
        parser.error("Cannot provide both --reference and --lineage.")
    if not args.input:
        parser.error("Input file is required.")
    elif not os.path.isfile(args.input):
        parser.error(f"Input file '{args.input}' does not exist.")

    log("Arguments:")
    for name, value in vars(args).items():
        log(f"  {name}: {value}")

    if args.lineage:
        if not args.dry_run:
            os.makedirs(args.db_dir, exist_ok=True)
        odb_tag = f"odb{args.odb_version}"
        lineage_id = f"{args.lineage}_{odb_tag}"
        # Location the FASTA ends up in once run_make_mibf has moved it.
        cached_faa = os.path.join(args.db_dir, lineage_id, f"{lineage_id}.faa")
        if os.path.isfile(cached_faa):
            log(f"Found existing lineage FASTA: {cached_faa}")
            args.reference = cached_faa
        else:
            args.reference = download_and_prepare_lineage(
                args.lineage, args.db_dir, odb_tag, busco_website, args.dry_run, args.verbose
            )

    reference_file = os.path.basename(args.reference)
    base_name = os.path.splitext(reference_file)[0]
    args.mibf_path = os.path.join(args.db_dir, f"{base_name}/", f"{base_name}.mibf")

    if not is_database_built(args.mibf_path):
        run_make_mibf(args)

    run_aakomp(args)

    if args.visualise:
        run_aakomp_analysis(args)

# Run the driver only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()

