#!/usr/bin/env python3
'''
ntEdit: Fast, lightweight, scalable genome sequence polishing and SNV detection & annotation
Written by Johnathan Wong and Lauren Coombe, adapted from https://github.com/bcgsc/ntSynt/blob/main/bin/ntSynt
'''
import argparse
import os
import shlex
import subprocess
import sys
from packaging import version

import snakemake

# Version string printed by the -V/--version option of each sub-command
NTEDIT_VERSION = "ntEdit v2.1.1"

def _build_parser():
    """Build the ntEdit command-line parser with its `polish` and `snv` sub-commands.

    Returns:
        argparse.ArgumentParser: top-level parser; the selected sub-command name is
        stored in the `mode` attribute of the parsed namespace (None if omitted).
    """
    parser = argparse.ArgumentParser(description="ntEdit: Fast, lightweight, scalable genome sequence "
                                                 "polishing and SNV detection & annotation",
                                     formatter_class=argparse.RawTextHelpFormatter)

    subparsers = parser.add_subparsers(help="ntEdit can be run in polishing or SNV modes.",
                                       dest="mode")

    parser_polishing = subparsers.add_parser("polish", help="Run ntEdit polishing")

    parser_snv = subparsers.add_parser("snv", help="Run ntEdit SNV mode")

    # Add the genome argument first so that it appears at the top of the help page
    parser_polishing.add_argument("--draft",
                                  help="Draft genome assembly. Must be specified with exact FILE NAME. "
                                       "Ex: --draft myDraft.fa (FASTA, Multi-FASTA, and/or gzipped compatible), "
                                       "REQUIRED",
                                  required=True)
    parser_snv.add_argument("--reference",
                            help="Reference genome assembly for SNV calling "
                                 "(FASTA, Multi-FASTA, and/or gzipped compatible), REQUIRED")
    # Hidden option: used as the reference when --reference is not given
    parser_snv.add_argument("--draft",
                            help=argparse.SUPPRESS)

    # Arguments for SNV mode only
    parser_snv.add_argument("--reads",
                            help="Prefix of input reads file(s) for variant calling. "
                                 "All files in the working directory with the specified prefix will be used "
                                 "for polishing "
                                 "(fastq, fasta, gz)", type=str)
    parser_snv.add_argument("--genome",
                            help="Genome assembly file(s) for detecting SNV on --reference",
                            nargs="+")

    # Arguments for polishing only
    parser_polishing.add_argument("--reads",
                                  help="Prefix of reads file(s). "
                                       "All files in the working directory with the specified prefix will be "
                                       "used for polishing "
                                       "(fastq, fasta, gz), REQUIRED",
                                  required=True)
    parser_polishing.add_argument("-i",
                                  help="Maximum number of insertion bases to try, range 0-5, [default=5]",
                                  default=5, type=int, choices=range(0, 6))
    parser_polishing.add_argument("-d",
                                  help="Maximum number of deletions bases to try, range 0-10, [default=5]",
                                  default=5, type=int, choices=range(0, 11))
    parser_polishing.add_argument("-x",
                                  help="k/x ratio for the number of k-mers that should be missing, [default=5.000]",
                                  default=5.000, type=float)
    parser_polishing.add_argument("--cap",
                                  help="Cap for the number of base insertions that can be made at one position "
                                       "[default=k*1.5]",
                                  type=float)
    parser_polishing.add_argument("-m",
                                  help="""Mode of editing, range 0-2, [default=0]\n
                                  0: best substitution, or first good indel\n
                                  1: best substitution, or best indel
                                  2: best edit overall (suggestion that you reduce i and d for performance)""",
                                  default=0, type=int, choices=range(0, 3))
    parser_polishing.add_argument("-a",
                                  help="Soft masks missing k-mer positions having no fix (1 = yes, default = 0, no)",
                                  default=0, type=int, choices=range(0, 2))

    # Arguments common to all modes
    for subparser in [parser_polishing, parser_snv]:
        subparser.add_argument("-k",
                               help="k-mer size, REQUIRED",
                               required=True, type=int)
        subparser.add_argument("-l",
                               help="input VCF file with annotated variants (e.g., clinvar.vcf)",
                               type=str)
        subparser.add_argument("--cutoff",
                               help="The minimum coverage of k-mers in output Bloom filter "
                                    "[default=2, ignored if solid=True]",
                               default=2, type=int)
        subparser.add_argument("--solid",
                               help="Output the solid k-mers (non-erroneous k-mers), [default=False]",
                               action="store_true", default=False)
        subparser.add_argument("-t",
                               help="Number of threads [default=4]", default=4, type=int)
        subparser.add_argument("-z",
                               help="Minimum contig length [default=100]", default=100, type=int)
        subparser.add_argument("-y",
                               help="k/y ratio for the number of edited k-mers that should be present, "
                                    "[default=9.000]",
                               default=9.000, type=float)
        subparser.add_argument("-j",
                               help="controls size of k-mer subset. When checking subset of k-mers, "
                                    "check every jth k-mer [default=3]",
                               default=3, type=int)
        # -X and -Y use -1 as an "unset" sentinel; main() fills in 0.5 when only one is given
        subparser.add_argument("-X",
                               help="Ratio of number of k-mers in the k subset that should be missing in order "
                                    "to attempt fix (higher=stringent) "
                                    "[default=0.5, if -Y is specified]", default=-1, type=float)
        subparser.add_argument("-Y",
                               help="Ratio of number of k-mers in the k subset that should "
                                    "be present to accept an edit (higher=stringent) "
                                    "[default=0.5, if -X is specified]",
                               default=-1, type=float)
        subparser.add_argument("-e",
                               help="False positive rate for ntStat Bloom filter", default=0.01, type=float)
        subparser.add_argument("-v",
                               help="Verbose mode, [default=False]", action="store_true", default=False)
        subparser.add_argument("-V", "--version", action="version", version=NTEDIT_VERSION)
        subparser.add_argument("-n", "--dry-run", help="Print out the commands that will be executed",
                               action="store_true")
        subparser.add_argument("-f", "--force", help="Run all ntEdit steps, regardless of existing output files",
                               action="store_true")

    return parser


def main():
    """Parse the command line, assemble the snakemake command for ntEdit and run it.

    Prints the effective parameter settings and the snakemake command before
    launching it. With no sub-command, prints the help text and exits.

    Raises:
        SystemExit: via parser.error on invalid argument combinations (snv mode without
            a reference, or with both/neither of --reads and --genome).
        FileNotFoundError: if the VCF given with -l does not exist.
        subprocess.SubprocessError: if snakemake exits with a non-zero status.
    """
    parser = _build_parser()
    args = parser.parse_args()

    if args.mode is None:
        parser.print_help()
        sys.exit()

    # Validate mode-specific input combinations up front, before anything is built,
    # so the user gets a usage error instead of a traceback further down.
    if args.mode == "snv":
        if bool(args.reads) == bool(args.genome):
            parser.error("Please specify --reads OR --genome")
        if not (args.reference or args.draft):
            parser.error("Please specify --reference")

    base_dir = os.path.dirname(os.path.realpath(__file__))

    intro_string = ["Running ntEdit...",
                    "Parameter settings:"]

    # Determine the snakemake target and the genome that will be processed
    if args.mode == "snv":
        smk_rule = "ntedit_snv_genome" if args.genome else "ntedit_snv_reads"
        # The hidden --draft option acts as a fallback for --reference
        input_genome = args.reference if args.reference else args.draft
        intro_string.append(f"\t--reference {input_genome}")
    else:
        smk_rule = "ntedit"
        input_genome = args.draft
        intro_string.append(f"\t--draft {args.draft}")

    # The command is built as a string (so it can be echoed) and tokenized with shlex.split
    # below; user-supplied values are therefore quoted so whitespace cannot split them.
    # shlex.quote leaves ordinary file names untouched.
    snakefile = shlex.quote(f"{base_dir}/ntedit_run_pipeline.smk")
    command = f"snakemake -s {snakefile} {smk_rule} -p --cores {args.t} " \
              f"--config draft={shlex.quote(input_genome)} kmer={args.k} threads={args.t} " \
              f"z_param={args.z} y_param={args.y} j_param={args.j} e_param={args.e} "

    if args.mode == "snv" and args.genome:
        intro_string.append(f"\t--genome {args.genome}")
        # args.genome is a list (nargs="+"): quote its repr so it stays a single config token
        command += f"genomes={shlex.quote(str(args.genome))} "
    else:
        intro_string.append(f"\t--reads {args.reads}")
        command += f"reads={shlex.quote(args.reads)}"

    intro_string.extend([f"\t-k {args.k}",
                         f"\t-t {args.t}",
                         f"\t-z {args.z}",
                         f"\t-y {args.y}",
                         f"\t-j {args.j}",
                         f"\t-e {args.e}"])

    # -X/-Y are only forwarded when at least one was set; the other then defaults to 0.5
    if args.X != -1 or args.Y != -1:
        if args.X == -1:
            args.X = 0.5
        if args.Y == -1:
            args.Y = 0.5
        intro_string.append(f"\t-X {args.X}")
        intro_string.append(f"\t-Y {args.Y}")
        command += f" X_param={args.X} Y_param={args.Y}"

    # Bloom filter coverage settings are not passed when SNVs are called from genome assemblies
    if not (args.mode == "snv" and args.genome):
        if args.solid:
            intro_string.append("\t--solid")
            command += f" solid={args.solid}"
        else:
            intro_string.append(f"\t--cutoff {args.cutoff}")
            command += f" cutoff={args.cutoff}"

    if args.v:
        intro_string.append("\t-v")
        command += " verbose=1"
    else:
        command += " verbose=0"

    if args.l:
        if not os.path.isfile(args.l):
            raise FileNotFoundError(f"VCF file {args.l} not found")
        intro_string.append(f"\t-l {args.l}")
        command += f" l_vcf={shlex.quote(args.l)}"

    # Polishing-specific parameters
    if args.mode == "polish":
        if args.cap:
            intro_string.append(f"\t--cap {args.cap}")
            command += f" cap={args.cap}"
        intro_string.append(f"\t-i {args.i}")
        intro_string.append(f"\t-d {args.d}")
        intro_string.append(f"\t-x {args.x}")
        intro_string.append(f"\t-m {args.m}")
        intro_string.append(f"\t-a {args.a}")
        command += f" i_param={args.i} d_param={args.d} x_param={args.x} m_param={args.m} a_param={args.a}"

    print("\n".join(intro_string), flush=True)

    if version.parse(snakemake.__version__) >= version.parse("7.8.0"): # Keep behaviour consistent for smk versions
        command += " --rerun-trigger mtime "

    if args.dry_run:
        command += " -n"

    if args.force:
        command += " -F"

    print(f"Running {command}", flush=True)

    command = shlex.split(command)

    ret = subprocess.call(command)
    if ret != 0:
        raise subprocess.SubprocessError("ntEdit failed - check the logs for the error.")

    print("Done ntEdit!")

# Run the pipeline only when this file is executed as a script
if __name__ == "__main__":
    main()
