from pathlib import Path
import os
import pandas as pd
import subprocess
import glob
from pling.utils import get_pling_root_dir

# Workflow configuration is loaded from config.yaml into the global `config`.
configfile: "config.yaml"

# NOTE(review): INTEGERISATION is not referenced anywhere else in the visible
# part of this file; the `ding` rule passes the literal "align" on its command
# line instead. Confirm whether this constant is still needed.
INTEGERISATION="align"
# Root directory under which every rule in this file writes its outputs.
OUTPUTPATH = config["submatrices_dir"]
# Threshold value that is embedded in the name of the graph directory below.
dcj_threshold = config["dcj_dist_threshold"]
# Output directory of an earlier run whose results are read as inputs here.
old_dir = config["output_dir"]
# Tab-separated table read by get_subcomms() (its "type" column) and passed
# as an input to the incomplete_submatrices checkpoint.
SUBCOMMUNITIESPATH = f"{old_dir}/dcj_thresh_{dcj_threshold}_graph/objects/typing.tsv"

def get_subcomms(wildcards):
    """Return the names of the subcommunities to build outputs for.

    When ``config["ignore_containment"]`` is falsy, the
    ``incomplete_submatrices`` checkpoint is requested first and the names are
    derived from the files found in ``{OUTPUTPATH}/dists`` (basename with
    ``.dist`` removed).

    Otherwise the ``type`` column of ``SUBCOMMUNITIESPATH`` is read and every
    type that occurs more than twice is returned exactly once, in order of
    first appearance.

    Args:
        wildcards: Snakemake wildcards of the requesting rule; forwarded to
            the checkpoint lookup.

    Returns:
        list: subcommunity names.
    """
    if not config["ignore_containment"]:
        # Accessing the checkpoint output makes Snakemake re-evaluate this
        # function once the checkpoint has run; the value itself is unused.
        checkpoints.incomplete_submatrices.get(**wildcards).output[0]
        subcomms = [os.path.basename(file).replace(".dist","") for file in glob.glob(f"{OUTPUTPATH}/dists/*")]
    else:
        from collections import Counter

        subcommunities_list = pd.read_csv(SUBCOMMUNITIESPATH, sep='\t')["type"].to_list()
        # One pass to count members instead of list.count() per unique type.
        count = Counter(subcommunities_list)
        # dict.fromkeys() drops repeats while keeping first-seen order, so a
        # subcommunity is requested once rather than once per member.
        subcomms = [subcomm for subcomm in dict.fromkeys(subcommunities_list) if count[subcomm] > 2]
    return subcomms

def get_regions():
    """Return the ``--regions`` command-line flag if ``config["regions"]`` is
    truthy, otherwise an empty string (also when the key is absent)."""
    return "--regions" if config.get("regions", False) else ""

def get_topology(topology):
    """Build the ``--topology`` command-line fragment.

    Note that the decision is made on ``config["topology"]`` (compared with
    the string ``"None"``), not on the ``topology`` argument; the argument is
    only used to fill in the option value.

    Args:
        topology: value placed after ``--topology``.

    Returns:
        str: ``"--topology <topology>"`` or an empty string.
    """
    if config["topology"] != "None":
        return f"--topology {topology}"
    return ""

def read_in_dists(filepath):
    """Read a three-column, tab-separated distance file into a dictionary.

    Each non-blank line must have the form ``plasmid1<TAB>plasmid2<TAB>dist``
    with an integer ``dist``. The distance is stored under both key orders so
    callers can look a pair up either way round. Blank lines (for example a
    trailing empty line) are skipped.

    Args:
        filepath: path of the file to read.

    Returns:
        dict: maps ``(plasmid1, plasmid2)`` and ``(plasmid2, plasmid1)`` to
        the integer distance.

    Raises:
        ValueError: if a non-blank line does not have exactly three
            tab-separated fields or the distance is not an integer.
    """
    dists = {}
    with open(filepath, "r") as f:
        for line in f:
            record = line.rstrip("\n")
            if not record.strip():
                # Skip empty lines rather than failing on the unpack below.
                continue
            plasmid1, plasmid2, dist = record.split("\t")
            value = int(dist)
            dists[(plasmid1, plasmid2)] = value
            dists[(plasmid2, plasmid1)] = value
    return dists

def get_final_outputs(wildcards):
    """List the final target files, one per subcommunity.

    Targets are the PDFs under ``{OUTPUTPATH}/vis`` when
    ``config["vis_trees"]`` is truthy, otherwise the tree files under
    ``{OUTPUTPATH}/trees``.

    Args:
        wildcards: Snakemake wildcards, forwarded to ``get_subcomms``.

    Returns:
        list: expanded target paths.
    """
    if config["vis_trees"]:
        pattern = f"{OUTPUTPATH}/vis/{{subcommunity}}.pdf"
    else:
        pattern = f"{OUTPUTPATH}/trees/{{subcommunity}}.tree"
    return expand(pattern, subcommunity=get_subcomms(wildcards))

def get_input(wildcards):
    """Return the distance-matrix path for ``wildcards.subcommunity``.

    The ``incomplete_submatrices`` checkpoint is requested first so that the
    path is only resolved after the checkpoint has run; the checkpoint output
    value itself is not used.
    """
    subcommunity = wildcards.subcommunity
    checkpoints.incomplete_submatrices.get(subcommunity=subcommunity).output[0]
    return f"{OUTPUTPATH}/dists/{subcommunity}.dist"

# Default target rule: requests whatever get_final_outputs() returns, i.e. one
# tree file per subcommunity, or one PDF per subcommunity when
# config["vis_trees"] is set.
rule all:
    input:
        get_final_outputs

# Checkpoint that runs get_submatrices.py on the subcommunity table and the
# all-vs-all distance table, and declares two output directories.
# get_subcomms() and get_input() wait on this checkpoint before resolving
# files under {OUTPUTPATH}/dists.
# NOTE(review): the script is not visible here; presumably it writes the
# `<subcommunity>_incomplete.dist` files that missing_entries and
# dcj_submatrix read from {OUTPUTPATH}/incomplete — confirm in the script.
checkpoint incomplete_submatrices:
    input:
        subcom = SUBCOMMUNITIESPATH,
        tsv = f"{old_dir}/all_plasmids_distances.tsv"
    output:
        outdir_1 = directory(f"{OUTPUTPATH}/incomplete"),
        outdir_2 = directory(f"{OUTPUTPATH}/dists")
    resources:
        # Requested memory grows linearly with the attempt number.
        mem_mb = lambda wildcards, attempt: attempt * 1000
    script:
        "get_submatrices.py"

# List every plasmid pair whose distance is missing (NaN) in the lower
# triangle of a subcommunity's incomplete distance matrix.
# Each output line is the Python list repr of one pair, e.g. ['p1', 'p2'].
rule missing_entries:
    input:
        incomplete_submatrix = lambda wildcards: f"{OUTPUTPATH}/incomplete/{wildcards.subcommunity}_incomplete.dist"
    output:
        missing_entries = f"{OUTPUTPATH}/missing/{{subcommunity}}_missing.txt"
    resources:
        # Requested memory grows linearly with the attempt number.
        mem_mb = lambda wildcards, attempt: attempt * 1000
    run:
        submatrix = pd.read_csv(input.incomplete_submatrix, sep="\t", index_col=0)
        # Label the columns with the row labels so both axes use plasmid names.
        submatrix.columns = submatrix.index
        # Build the label list and the NaN mask once instead of per iteration.
        plasmids = list(submatrix.index)
        is_missing = submatrix.isna().to_numpy()
        with open(output.missing_entries, "w") as f:
            for i, plasmid1 in enumerate(plasmids):
                # Only j < i: the strict lower triangle, so each pair appears once.
                for j in range(i):
                    if is_missing[i, j]:
                        f.write(str([plasmid1, plasmids[j]]) + "\n")

# Run pling's align_snakemake/unimog.py on the list of missing pairs of one
# subcommunity, producing a containment table, a unimog file and a map file.
rule make_unimogs:
    input:
        # Pair list written by the missing_entries rule.
        batch_list=lambda wildcards: f"{OUTPUTPATH}/missing/{wildcards.subcommunity}_missing.txt"
    output:
        containment=f"{OUTPUTPATH}/containment_missing/{{subcommunity}}_containment.tsv",
        unimog = f"{OUTPUTPATH}/unimogs_missing/{{subcommunity}}_align.unimog",
        map = f"{OUTPUTPATH}/unimogs_missing/{{subcommunity}}_map.txt"
    threads: 1
    resources:
        # Requested memory grows linearly with the attempt number.
        mem_mb=lambda wildcards, attempt: 10000*attempt
    params:
        genomes_list = config["genomes_list"],
        outputpath = OUTPUTPATH,
        # Fixed values passed straight through to unimog.py.
        identity_threshold = 80,
        containment_distance = 1,
        pling_root_dir = get_pling_root_dir(),
        # Optional flags; each is an empty string when not configured.
        regions = get_regions(),
        topology = get_topology(config["topology"])
    shadow: "shallow"
    shell:
        """
        PYTHONPATH={params.pling_root_dir} python {params.pling_root_dir}/pling/align_snakemake/unimog.py \
            --genomes_list {params.genomes_list} \
            --batch {input.batch_list} \
            --identity_threshold {params.identity_threshold} \
            --containment_distance {params.containment_distance} \
            --outputpath {params.outputpath} \
            --containment_output {output.containment} \
            --unimog_output {output.unimog} \
            --map_output {output.map} \
            {params.regions} \
            {params.topology}
        """

# Run pling's dcj_snakemake/dcj.py for the missing pairs of one subcommunity,
# using the containment table and unimog file from make_unimogs, and write the
# resulting distances to {OUTPUTPATH}/dists_missing/<subcommunity>_dcj.tsv.
rule ding:
    input:
        containment_tsv=f"{OUTPUTPATH}/containment_missing/{{subcommunity}}_containment.tsv",
        unimog = f"{OUTPUTPATH}/unimogs_missing/{{subcommunity}}_align.unimog",
        # Pair list written by the missing_entries rule.
        batch_list = lambda wildcards: f"{OUTPUTPATH}/missing/{wildcards.subcommunity}_missing.txt"
    output:
        f"{OUTPUTPATH}/dists_missing/{{subcommunity}}_dcj.tsv"
    params:
        containment_distance=1,
        outputpath=OUTPUTPATH,
        pling_root_dir = get_pling_root_dir(),
        ilp_solver = config["ilp_solver"]
    threads: 1
    resources:
        # Requested memory grows linearly with the attempt number.
        mem_mb = lambda wildcards, attempt: attempt * 16000
    shadow: "shallow"
    # The command hard-codes `--batch 0`, `--communitypath NA` and
    # `--integerisation align`.
    shell:
            """
            PYTHONPATH={params.pling_root_dir} python {params.pling_root_dir}/pling/dcj_snakemake/dcj.py \
                    --batch_file {input.batch_list} \
                    --batch 0 \
                    --containment_tsv {input.containment_tsv} \
                    --containment_distance {params.containment_distance} \
                    --outputpath {params.outputpath} \
                    --distpath {output} \
                    --communitypath NA \
                    --integerisation align \
                    --threads {threads} \
                    --snakefile_dir {params.pling_root_dir}/pling/dcj_snakemake \
                    --unimog {input.unimog} \
                    --ilp_solver {params.ilp_solver}
            """

# Fill the missing (NaN) entries of a subcommunity's incomplete distance matrix
# with the distances computed by the `ding` rule and write the completed
# matrix to {OUTPUTPATH}/dists/<subcommunity>.dist.
rule dcj_submatrix:
    input:
        missing_dists = lambda wildcards: f"{OUTPUTPATH}/dists_missing/{wildcards.subcommunity}_dcj.tsv",
        incomplete_submatrix = lambda wildcards: f"{OUTPUTPATH}/incomplete/{wildcards.subcommunity}_incomplete.dist"
    output:
        submatrix_tsv = f"{OUTPUTPATH}/dists/{{subcommunity}}.dist"
    run:
        submatrix = pd.read_csv(input.incomplete_submatrix, sep="\t", index_col=0)
        # Label the columns with the row labels so both axes use plasmid names.
        submatrix.columns = submatrix.index
        missing_dists = read_in_dists(input.missing_dists)
        distances = submatrix.copy()
        # Build the label list and the NaN mask once instead of per iteration.
        # The mask comes from the untouched input, so edits to `distances`
        # below do not affect which cells are visited.
        plasmids = list(submatrix.index)
        is_missing = submatrix.isna().to_numpy()
        for i, plasmid1 in enumerate(plasmids):
            # Only j < i: walk the strict lower triangle and mirror each value.
            for j in range(i):
                if is_missing[i, j]:
                    plasmid2 = plasmids[j]
                    dist = missing_dists[(plasmid1, plasmid2)]
                    distances.loc[plasmid1, plasmid2] = dist
                    distances.loc[plasmid2, plasmid1] = dist
        # Prepend a row that carries only the index name as its label; with
        # header=False below it becomes the first line of the output file.
        # NOTE(review): presumably the index name is the matrix-size header
        # expected by quicktree's matrix input — confirm against the input file.
        new_row = pd.DataFrame({el:pd.NA for el in submatrix.columns}, index=[submatrix.index.name])
        distances = pd.concat([new_row,distances.loc[:]])
        distances.to_csv(output.submatrix_tsv, sep="\t", index=True, header=False)

# Build a tree for one subcommunity by running quicktree on its distance
# matrix (`-in m` selects distance-matrix input) and redirecting stdout to the
# .tree file.
rule trees:
    input:
        # Resolved after the incomplete_submatrices checkpoint; see get_input().
        get_input
    output:
        f"{OUTPUTPATH}/trees/{{subcommunity}}.tree"
    resources:
        # Requested memory grows linearly with the attempt number.
        mem_mb = lambda wildcards, attempt: attempt * 1000
    shell:
        "quicktree -in m {input} > {output}"


# Render one subcommunity's tree file to a PDF by calling pling's
# submatrix_snakemake/draw_trees.py with the tree path and the PDF path.
rule vis:
    input:
        f"{OUTPUTPATH}/trees/{{subcommunity}}.tree"
    output:
        f"{OUTPUTPATH}/vis/{{subcommunity}}.pdf"
    params:
        # Used both as PYTHONPATH and to locate the script.
        pling_root_dir = get_pling_root_dir()
    shell:
        "PYTHONPATH={params.pling_root_dir} python {params.pling_root_dir}/pling/submatrix_snakemake/draw_trees.py {input} {output}"