#!/usr/bin/env python3

'''
Driver script for ntRoot
'''

import argparse
import os
import shlex
import subprocess
from packaging import version
import snakemake

# Version string reported by the -V/--version command-line option.
NTROOT_VERSION = "v1.2.0"

def set_up_parser():
    """Set up the ntRoot argparse parser.

    Only ``-l`` is enforced by argparse itself (``required=True``); no other
    constraint between options is declared here, so callers must validate
    option combinations after parsing.

    Returns:
        argparse.ArgumentParser: parser exposing every ntRoot option.
    """
    # RawTextHelpFormatter keeps the explicit newline in the epilog intact.
    parser = argparse.ArgumentParser(description="ntRoot: Ancestry inference from genomic data",
                                     formatter_class=argparse.RawTextHelpFormatter,
                                     epilog="Note: please specify --reads OR --genome (not both)\n"
                                            "If you have any questions about ntRoot, please open an "
                                            "issue at https://github.com/bcgsc/ntRoot")

    # --- Inputs ---------------------------------------------------------
    # --draft is accepted but hidden from the help output.
    parser.add_argument("--draft",
                        help=argparse.SUPPRESS)
    parser.add_argument("-r", "--reference",
                        help="Reference genome (FASTA, Multi-FASTA, and/or gzipped compatible)")

    parser.add_argument("--reads",
                        help="Prefix of input reads file(s) for detecting SNVs. "
                             "All files in the working directory with the specified prefix will be used. "
                             "(fastq, fasta, gz, bz, zip)", type=str)
    # nargs="+" means the parsed value is a list of one or more paths.
    parser.add_argument("--genome",
                        help="Genome assembly file(s) for detecting SNVs compared to --reference", nargs="+")
    parser.add_argument("-l",
                        help="input IVC VCF file with annotated variants (e.g., 1000GP_integrated_snv_v2a_27022019.GRCh38.phased_gt1.vcf.gz, clinvar.vcf, etc.)",
                        type=str, required=True)
    parser.add_argument("--exome", help="Input reads for detecting SNVs are from whole exome sequencing. "
                        "If provided, must also specify either --exome_bed or --masked. --cutoff 2 is implied unless otherwise specified.",
                        action="store_true")

    # --- Tuning parameters ----------------------------------------------
    # -k has no default, so it parses to None when omitted.
    parser.add_argument("-k",
                        help="k-mer size",
                        required=False, type=int)
    parser.add_argument("--tile", help="Tile size for ancestry fraction inference (bp) [default=5000000]",
                        default=5000000, type=int)
    # Help text typo fixed here ("predictons" -> "predictions").
    parser.add_argument("--lai", help="Output ancestry predictions per tile in a separate output file",
                        action="store_true")
    parser.add_argument("-t",
                        help="Number of threads [default=4]", default=4, type=int)
    parser.add_argument("-z",
                        help="Minimum contig length [default=100]", default=100, type=int)
    parser.add_argument("-j",
                        help="controls size of k-mer subset. When checking subset of k-mers, check every jth k-mer "
                             "[default=3]",
                        default=3, type=int)
    parser.add_argument("--cutoff", help="Minimum coverage of k-mers in ntEdit Bloom filter. "
                        "Solid k-mers are used if set to 0 [0]", default=0, type=int)
    parser.add_argument("-Y",
                        help="Ratio of number of k-mers in the k subset that should be present to accept "
                             "an edit (higher=stringent) "
                             "[default=0.55]", default=0.55, type=float)

    # --- Alternative modes ----------------------------------------------
    parser.add_argument("--custom_vcf", help="Input VCF for computing ancestry. "
                                             "When specified, ntRoot will skip the ntEdit step, and "
                                             "predict ancestry from the provided VCF.",
                        type=str)
    parser.add_argument("--masked", help="Exome Mode (--exome) only: "
                        "Indicates that the reference genome provided with --reference "
                        "has all NON-targeted exonic regions hard-masked. ",
                        action="store_true")
    parser.add_argument("--exome_bed", help="Exome Mode (--exome) only: BED file of exome targeted regions. ",
                        type=str)
    parser.add_argument("--strip_info", help="When using --custom_vcf, "
                        "strip the existing INFO field from the input VCF.",
                        action="store_true")

    # --- Execution control ----------------------------------------------
    parser.add_argument("-v", "--verbose",
                        help="Verbose mode [default=False]", action="store_true", default=False)
    parser.add_argument("-V", "--version", action="version", version=NTROOT_VERSION)

    parser.add_argument("-n", "--dry-run", help="Print out the commands that will be executed", action="store_true")

    parser.add_argument("-f", "--force", help="Run all ntRoot steps, regardless of existing output files",
                        action="store_true")
    return parser


def _validate_args(args):
    """Validate option combinations and apply implied defaults in place.

    Side effects on ``args``: ``cutoff`` is raised to 2 in exome mode when it
    is lower, and ``reference`` is overwritten by ``draft`` when the hidden
    ``--draft`` option was supplied.

    Raises:
        argparse.ArgumentTypeError: for an invalid or incomplete combination
            of options.
        FileNotFoundError: if the file given with ``-l`` does not exist.
    """
    # Without --custom_vcf, exactly one of --reads / --genome must be given.
    if not args.custom_vcf and bool(args.reads) == bool(args.genome):
        raise argparse.ArgumentTypeError("Please specify --reads OR --genome")

    if args.exome and not (args.masked or args.exome_bed):
        raise argparse.ArgumentTypeError("In exome mode, please specify either --masked or --exome_bed")

    if args.exome and (args.masked and args.exome_bed):
        raise argparse.ArgumentTypeError("In exome mode, please specify either --masked OR --exome_bed")

    if not args.exome and (args.masked or args.exome_bed):
        raise argparse.ArgumentTypeError("Parameters --masked and --exome_bed are only used in exome mode")

    if args.exome and args.masked:
        print("--exome and --masked specified. Assuming reference genome provided with --reference has all"
              " NON-targeted exonic regions hard-masked.", flush=True)

    if args.exome and args.cutoff < 2:
        args.cutoff = 2
        print("In exome mode, minimum k-mer coverage cutoff is set to 2 by default "
              "unless explicitly set higher. Setting cutoff to 2.",
              flush=True)

    if not args.draft and not args.reference:
        raise argparse.ArgumentTypeError("Please specify the reference genome with --reference")
    if args.draft:
        # --draft, when given, takes the place of --reference.
        args.reference = args.draft

    # The k-mer size is only needed when variants are detected from reads or genomes.
    if (args.reads or args.genome) and not args.k:
        raise argparse.ArgumentTypeError("Please specify the k-mer size with -k")

    if not os.path.isfile(args.l):
        raise FileNotFoundError(f"VCF file {args.l} not found")


def _select_rule(args):
    """Return the name of the Snakemake target rule for the requested mode.

    Precedence is --custom_vcf, then --genome, then --reads; the ``_exome``
    variant exists only for reads mode, and ``_lai`` is appended for --lai.
    """
    if args.custom_vcf:
        rule = "ntroot_input_vcf"
    elif args.genome:
        rule = "ntroot_genome"
    elif args.exome:
        rule = "ntroot_reads_exome"
    else:
        rule = "ntroot_reads"
    return f"{rule}_lai" if args.lai else rule


def _build_snakemake_argv(args, base_dir, smk_rule):
    """Assemble the snakemake argument vector and the parameter summary.

    The command is built as a list rather than a string that is later split,
    so values containing whitespace or quote characters reach snakemake
    unchanged and a multi-file --genome list stays in a single argument.

    Args:
        args: parsed (and validated) command-line namespace.
        base_dir: directory containing ``ntroot_run_pipeline.smk``.
        smk_rule: name of the Snakemake target rule to run.

    Returns:
        tuple: ``(argv, intro_lines)`` - the list of command tokens and the
        list of human-readable lines describing the parameter settings.
    """
    intro_lines = ["Running ntRoot...",
                   "Parameter settings:"]
    argv = ["snakemake", "-s", f"{base_dir}/ntroot_run_pipeline.smk", smk_rule,
            "--nolock", "-p", "--cores", str(args.t),
            "--config", f"reference={args.reference}", f"threads={args.t}",
            f"tile_size={args.tile}"]

    # Same precedence as _select_rule, so the config always matches the rule.
    if args.custom_vcf:
        intro_lines.append(f"\t--custom_vcf {args.custom_vcf}")
        argv.append(f"input_vcf={args.custom_vcf}")
        if args.strip_info:
            intro_lines.append("\t--strip_info")
            argv.append("strip_info=True")
    elif args.genome:
        intro_lines.append(f"\t--genome {args.genome}")
        # One argv element holding a bracketed, comma-separated list; for a
        # single genome this is the same token the old string-splitting produced.
        # NOTE(review): presumably Snakemake parses this value as a list -
        # confirm against the Snakefile's use of config["genomes"].
        argv.append("genomes=[" + ", ".join(args.genome) + "]")
    else:
        intro_lines.append(f"\t--reads {args.reads}")
        argv.append(f"reads={args.reads}")
        intro_lines.append(f"\t--cutoff {args.cutoff}")
        argv.append(f"cutoff={args.cutoff}")

    if args.exome:
        intro_lines.append("\t--exome")
        argv.append("exome=True")
        if args.masked:
            intro_lines.append("\t--masked")
            argv.append("masked=True")
        elif args.exome_bed:
            intro_lines.append(f"\t--exome_bed {args.exome_bed}")
            argv.append(f"exome_bed={args.exome_bed}")

    intro_lines.extend([f"\t--reference {args.reference}",
                        f"\t-t {args.t}"])

    if args.genome or args.reads:
        intro_lines.extend([f"\t-k {args.k}",
                            f"\t-z {args.z}",
                            f"\t-j {args.j}",
                            f"\t-Y {args.Y}"])
        argv.extend([f"kmer={args.k}", f"z_param={args.z}",
                     f"j_param={args.j}", f"Y_param={args.Y}"])

    if args.lai:
        intro_lines.append("\t--lai")

    if args.verbose:
        intro_lines.append("\t-v")
        argv.append("verbose=1")
    else:
        argv.append("verbose=0")

    intro_lines.append(f"\t-l {args.l}")
    argv.append(f"l_vcf={args.l}")

    return argv, intro_lines


def main():
    """Run ntRoot: parse the command line and launch the Snakemake pipeline.

    Raises:
        argparse.ArgumentTypeError: for an invalid combination of options.
        FileNotFoundError: if the file given with ``-l`` does not exist.
        subprocess.SubprocessError: if snakemake exits with a non-zero status.
    """
    parser = set_up_parser()
    args = parser.parse_args()
    _validate_args(args)

    # The Snakefile is looked up next to this script (symlinks resolved).
    base_dir = os.path.dirname(os.path.realpath(__file__))
    smk_rule = _select_rule(args)
    argv, intro_lines = _build_snakemake_argv(args, base_dir, smk_rule)

    print("\n".join(intro_lines), flush=True)

    # Keep behaviour consistent for smk versions
    if version.parse(snakemake.__version__) >= version.parse("7.8.0"):
        argv.extend(["--rerun-trigger", "mtime"])

    if args.dry_run:
        argv.append("-n")

    if args.force:
        argv.append("-F")

    # Quote each token so the logged command can be pasted into a shell.
    command = " ".join(shlex.quote(token) for token in argv)
    print(f"Running {command}", flush=True)

    ret = subprocess.call(argv)
    if ret != 0:
        raise subprocess.SubprocessError("ntRoot failed - check the logs for the error.")

    print("Done ntRoot!")

# Run the driver only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
