import os
import pandas as pd
import sys
from pling.utils import get_pling_root_dir, get_fasta_file_info, get_number_of_batches

configfile: "../config.yaml"

# FASTAEXT maps genome name -> fasta file extension; FASTAPATH is the
# directory holding the input fastas (see pling.utils.get_fasta_file_info).
FASTAFILES, FASTAEXT, FASTAPATH = get_fasta_file_info(config["genomes_list"])
GENOMES = list(FASTAEXT.keys())
OUTPUTPATH = config["output_dir"] #output directory will contain subdirectory with unimogs from integerisation pipeline, as well as placing all new output from current pipeline in subdirectories within it
PREFIX = config["prefix"]
INTEGERISATION = config["integerisation"]
CONTAINMENT_DISTANCE = config["seq_containment_distance"]
COMMUNITIES = config["communities"]
batch_size = config["batch_size"]
# DCJ distance cutoff; also embedded in the output graph directory name.
dcj_threshold = config["dcj_dist_threshold"]

def get_timelimit(timelimit):
    """Return the ``--timelimit`` CLI fragment for dcj.py, or "" when disabled.

    The config stores the literal string "None" to mean "no time limit".

    Fix: the original compared the global ``config["timelimit"]`` instead of
    the ``timelimit`` argument. The sole call site passes
    ``config["timelimit"]``, so checking the parameter is behaviorally
    equivalent and removes the hidden global dependency.
    """
    if timelimit == "None":
        return ""
    return f"--timelimit {timelimit}"

def get_unimog():
    """Build the optional ``--unimog`` flag for dcj.py; "" when unset."""
    unimog_path = config.get("unimog")
    if not unimog_path:
        return ""
    return f"--unimog {unimog_path}"

def get_prev_pling():
    """Build the ``--prev_typing``/``--reclustering_method`` flags for plasnet.

    ``config["previous_pling"]`` may list several previous pling output
    directories separated by commas; each contributes its typing.tsv.
    Returns "" when no previous run is configured.
    """
    if not config.get("previous_pling", False):
        return ""
    prev_dirs = config["previous_pling"].strip().split(",")
    typing_paths = " ".join(
        f"{d}/dcj_thresh_{dcj_threshold}_graph/objects/typing.tsv" for d in prev_dirs
    )
    method = config["reclustering_method"]
    return f"--prev_typing {typing_paths} --reclustering_method {method}"

def get_prev_dcj():
    """Path to a previous pling run's distance tsv, or None when unset.

    NOTE(review): unlike get_prev_pling, this treats config["previous_pling"]
    as a single directory (no strip/split on commas) — if several previous
    runs are listed there, the resulting path is bogus. Confirm intent.
    """
    prev = config.get("previous_pling", False)
    if not prev:
        return None
    return f"{prev}/all_plasmids_distances.tsv"

def get_vis():
    """Return ``--no-vis`` when visualisation is disabled, else ""."""
    return "--no-vis" if config["visualisation"] == "none" else ""

def get_containmentpath(batch):
    """Containment tsv that feeds a given DCJ batch.

    With ``sourmash_only`` every batch reads the single all-pairs table;
    otherwise each batch has its own pre-split containment file.
    """
    if config.get("sourmash_only", False):
        return f"{OUTPUTPATH}/containment/all_pairs_containment_distance.tsv"
    return f"{OUTPUTPATH}/tmp_files/containment_batchwise/batch_{batch}_containment.tsv"

# Terminal target: the typed DCJ graph directory produced by build_DCJ_graph.
rule all:
    input:
        dcj_graph_outdir = f"{OUTPUTPATH}/dcj_thresh_{dcj_threshold}_graph"

# Compute DCJ distances for one batch of plasmid pairs by invoking dcj.py
# (which solves the underlying ILP with the configured solver).
rule ding:
    input:
        containment_tsv = lambda wildcards: get_containmentpath(wildcards.batch),
        batch_list = lambda wildcards: f"{OUTPUTPATH}/batches/batch_{wildcards.batch}.txt"
    output:
        f"{OUTPUTPATH}/tmp_files/dists_batchwise/batch_{{batch}}_dcj.tsv"
    params:
        containment_distance=CONTAINMENT_DISTANCE,
        integerisation=INTEGERISATION,
        outputpath=OUTPUTPATH,
        communitypath=f"{COMMUNITIES}/objects/communities.txt",
        batch=lambda wildcards: wildcards.batch,
        timelimit=get_timelimit(config["timelimit"]),
        # Recover this Snakefile's directory from the CLI; assumes snakemake is
        # always invoked with an explicit --snakefile flag — TODO confirm.
        snakefile_dir=os.path.dirname(sys.argv[sys.argv.index("--snakefile")+1]),
        pling_root_dir = get_pling_root_dir(),
        unimog = get_unimog(),
        ilp_solver = config["ilp_solver"]
    threads: config["ilp_threads"]
    resources:
        # Retry with proportionally more memory on each attempt.
        mem_mb = lambda wildcards, attempt: attempt * config["ilp_mem"]
    shadow: "shallow"
    shell:
            """
            PYTHONPATH={params.pling_root_dir} python {params.pling_root_dir}/pling/dcj_snakemake/dcj.py \
                    --batch_file {input.batch_list} \
                    --batch {params.batch} \
                    --containment_tsv {input.containment_tsv} \
                    --containment_distance {params.containment_distance} \
                    --outputpath {params.outputpath} \
                    --distpath {output} \
                    --communitypath {params.communitypath} \
                    --integerisation {params.integerisation} \
                    {params.timelimit} \
                    --threads {threads} \
                    --snakefile_dir {params.snakefile_dir} \
                    {params.unimog} \
                    --ilp_solver {params.ilp_solver}
            """

# Concatenate the per-batch DCJ distance tables into a single tsv, prepending
# one header row and optionally appending the distances of a previous pling
# run (whose own header line is skipped).
rule dcj_tsv:
    input:
        dists = expand(f"{OUTPUTPATH}/tmp_files/dists_batchwise/batch_{{batch}}_dcj.tsv", batch=[str(i) for i in range(get_number_of_batches(OUTPUTPATH))])
    output:
        tsv = f"{OUTPUTPATH}/{PREFIX}_distances.tsv"
    threads: 1
    resources:
        mem_mb=lambda wildcards, attempt: 4000*attempt
    params:
        prev_dcj = get_prev_dcj()
    run:
        with open(output.tsv, "w") as dcj_out:
            dcj_out.write("plasmid_1\tplasmid_2\tdistance\n")
            # Batch files are copied verbatim — assumed headerless (only the
            # previous-run table below has its header skipped) — TODO confirm.
            for file in input.dists:
                with open(file, "r") as f:
                    to_cat = f.read()
                dcj_out.write(to_cat)
            # A previous run's table starts with its own header line; drop it.
            if params.prev_dcj:
                with open(params.prev_dcj) as f:
                    next(f)
                    to_cat = f.read()
                dcj_out.write(to_cat)

# Run `plasnet type` on the combined DCJ distances and the containment-stage
# communities to produce the typed subcommunity graph directory.
rule build_DCJ_graph:
    input:
        distances_tsv = rules.dcj_tsv.output.tsv,
        communities=COMMUNITIES+"/objects/communities.pkl"
    output:
        dcj_graph_outdir = directory(f"{OUTPUTPATH}/dcj_thresh_{dcj_threshold}_graph")
    threads: 1
    resources:
        mem_mb=lambda wildcards, attempt: config["build_DCJ_graph_mem"]*attempt
    params:
        dcj_dist_threshold=config["dcj_dist_threshold"],
        small_subcommunity_size_threshold = config["small_subcommunity_size_threshold"], #Communities with size up to this parameter will be joined to neighbouring larger subcommunities
        output_type = config["output_type"],
        # Optional extra flags: "" when not configured (see get_prev_pling/get_vis).
        prev_pling = get_prev_pling(),
        vis = get_vis()
    shell: """
            plasnet type \
                --distance-threshold {params.dcj_dist_threshold} \
                --small-subcommunity-size-threshold {params.small_subcommunity_size_threshold} \
                {input.communities} \
                {input.distances_tsv} \
                {output.dcj_graph_outdir} \
                --output-type {params.output_type} \
                {params.prev_pling} \
                {params.vis}
        """
