import os
import re
import shutil
import json
import sys
import pandas as pd
import warnings

from glob import glob
from tqdm import tqdm

warnings.filterwarnings("ignore")

# Read-direction token: "R1"/"R2" (any case) that is immediately followed by a
# separator ('.', '_' or '-') or by the end of the file stem.
READ_TOKEN_RE = re.compile(r"R([12])(?=$|[._-])", re.IGNORECASE)


def _config_flag(value):
    """Coerce a config value to a flag.

    Strings are truthy only when they read "1", "true", "yes" or "y"
    (case-insensitive); every other type is passed through unchanged.
    """
    if isinstance(value, str):
        return value.lower() in ("1", "true", "yes", "y")
    return value


# Load configuration variables
raw_data_path = config.get("raw_data", "raw_data")
sample_manifest_path = config.get("sample_manifest", None)
list_of_samples = config.get("samples", None)
# Fallback when no install_path is configured: the directory of the main
# Snakefile. `workflow.basedir` is used instead of `__file__` because a
# Snakefile is executed inside Snakemake's own module namespace, where
# `__file__` is understood to point at Snakemake's workflow.py rather than at
# this file (see the Snakemake docs on workflow.basedir).
install_path = config.get("install_path", None) or os.path.realpath(workflow.basedir)
shovill_cpus = config.get("shovill_cpu_cores", 1)
shovill_ram = config.get("shovill_ram", None)
# Interpolated verbatim into the shovill shell command; empty when no RAM limit is set.
shovill_ram_flag = f"--ram {shovill_ram}" if shovill_ram else ""
prokka_cpus = config.get("prokka_cpu_cores", 1)
roary_cpus = config.get("roary_cpu_cores", 1)
work_directory = config.get("output_dir", None)
# Each report switch falls back to a second config key ("r_report", "html_report").
simple_report = _config_flag(config.get("simple_report", config.get("r_report", True)))
interactive_report = _config_flag(config.get("interactive_report", config.get("html_report", False)))

html_path = str(install_path)
rmd_template_path = os.path.join(install_path, "report_template.Rmd")
r_render_script_path = os.path.join(install_path, "render_report.R")
# Report targets requested by `rule all`; an empty list disables the report.
r_report_outputs = ["report/report_r.html"] if simple_report else []
html_report_outputs = ["report/report.html"] if interactive_report else []

# Best effort: the version is only passed on to the R report, so any failure
# to determine it degrades to "unknown" instead of aborting the workflow.
try:
    import importlib.metadata as metadata
    pegas_version = metadata.version("pegas")
except Exception:
    pegas_version = "unknown"

# Make the directory that contains install_path importable so that the
# `pegas` package imports below resolve.
package_root = os.path.dirname(install_path)
if package_root not in sys.path:
    sys.path.insert(0, package_root)

from pegas.build_dataframe import build_dataframe
from pegas.build_report import build_report
# ================= Utility Functions =================

def list_fastq_files(path):
    """Return the sorted absolute paths of all ``*.fastq.gz`` files directly inside *path*.

    ``~`` is expanded and the path is made absolute. The directory part is
    escaped before globbing so that directories whose names contain glob
    metacharacters (``[``, ``*``, ``?``) are searched literally instead of
    being interpreted as a pattern. A missing directory yields an empty list.
    """
    # The file imports only the `glob` function, so fetch `escape` locally.
    from glob import escape

    full_path = os.path.abspath(os.path.expanduser(path))
    # Sorted so that the resulting file order is stable between runs.
    return sorted(glob(os.path.join(escape(full_path), "*.fastq.gz")))

def parse_fastq_read(filename):
    """Split a FASTQ file name into ``(sample, read)``.

    Only names ending in ``.fastq.gz`` are considered. The last read token
    found by ``READ_TOKEN_RE`` in the stem decides the read ("R1" or "R2",
    always reported in upper case); everything before it, minus trailing
    ``.``, ``_`` and ``-`` characters, is the sample name.

    Returns ``None`` when the extension does not match, when no read token is
    present, or when nothing is left for the sample name.
    """
    suffix = ".fastq.gz"
    name = os.path.basename(filename)
    if not name.endswith(suffix):
        return None
    stem = name[:-len(suffix)]

    # Walk all tokens and keep the final one; earlier hits belong to the
    # sample name.
    last_token = None
    for last_token in READ_TOKEN_RE.finditer(stem):
        pass
    if last_token is None:
        return None

    sample = stem[:last_token.start()].rstrip("._-")
    if sample:
        return sample, f"R{last_token.group(1)}"
    return None

def get_core_sample_name(filename):
    """Return the sample name encoded in a FASTQ file name.

    When ``parse_fastq_read`` recognises the name, its sample part is
    returned. Otherwise the basename is returned with ``.fastq.gz`` removed,
    or, for any other extension, with its last extension removed.
    """
    parsed = parse_fastq_read(filename)
    if parsed:
        sample, _read = parsed
        return sample

    suffix = ".fastq.gz"
    name = os.path.basename(filename)
    if not name.endswith(suffix):
        return os.path.splitext(name)[0]
    return name[:-len(suffix)]

def _preview_items(items, limit=5):
    """Return ``(preview, more)``: the first *limit* items joined by ", " and a "(+N more)" suffix."""
    preview = ", ".join(items[:limit])
    more = f" (+{len(items) - limit} more)" if len(items) > limit else ""
    return preview, more


def build_fastq_pairs(fastq_files):
    """Pair R1 and R2 files by sample name.

    Each path is parsed with ``parse_fastq_read``. Returns a dict mapping
    sample name to ``{"R1": path, "R2": path}`` containing only samples for
    which both reads were found.

    Warnings are written through ``tqdm.write`` for:
      * files whose name does not match the R1/R2 pattern,
      * files that resolve to a sample/read already seen (the later file
        replaces the earlier one, as before, but no longer silently),
      * samples that lack one of the two reads.
    """
    pairs = {}
    unmatched = []
    duplicates = []
    for path in fastq_files:
        parsed = parse_fastq_read(path)
        if not parsed:
            unmatched.append(path)
            continue
        sample, read = parsed
        reads = pairs.setdefault(sample, {})
        if read in reads:
            duplicates.append(f"{sample}/{read}")
        reads[read] = path

    orphan_samples = sorted(
        sample for sample, reads in pairs.items()
        if "R1" not in reads or "R2" not in reads
    )

    if unmatched:
        preview, more = _preview_items([os.path.basename(f) for f in unmatched])
        tqdm.write(f"[pegas] Warning: {len(unmatched)} FASTQ files did not match the R1/R2 pattern: {preview}{more}")
    if duplicates:
        preview, more = _preview_items(duplicates)
        tqdm.write(f"[pegas] Warning: {len(duplicates)} FASTQ files map to an already seen sample/read and replace the earlier file: {preview}{more}")
    if orphan_samples:
        preview, more = _preview_items(orphan_samples)
        tqdm.write(f"[pegas] Warning: {len(orphan_samples)} samples missing R1 or R2: {preview}{more}")

    # Keep complete pairs only.
    return {s: p for s, p in pairs.items() if "R1" in p and "R2" in p}


# ================= Data Preparation =================

def _fastq_stem(path):
    """Return the basename of *path* with a trailing ".fastq.gz" removed.

    Only the suffix is stripped; a ".fastq.gz" occurring elsewhere in the
    name is kept. Names with another extension are returned unchanged.
    """
    suffix = ".fastq.gz"
    name = os.path.basename(path)
    return name[:-len(suffix)] if name.endswith(suffix) else name


# Prefer a pre-computed manifest; otherwise (no manifest configured, or the
# configured file does not exist) scan the raw data directory.
if sample_manifest_path and os.path.exists(sample_manifest_path):
    # JSON text is UTF-8; do not depend on the locale's default encoding.
    with open(sample_manifest_path, "r", encoding="utf-8") as f:
        sample_manifest = json.load(f)
    fastq_files = sample_manifest.get("fastq_files", [])
    sample_pairs = sample_manifest.get("sample_pairs", {})
else:
    fastq_files = list_fastq_files(raw_data_path)
    sample_pairs = build_fastq_pairs(fastq_files)

# Per-file names drive the FastQC targets; per-sample names drive everything else.
file_names = [_fastq_stem(f) for f in fastq_files]
sample_names = list(sample_pairs.keys())
# Maps the {file} wildcard of the fastqc rule back to the FASTQ path.
fastq_lookup = {_fastq_stem(f): f for f in fastq_files}

# ================= Snakemake Workflow =================

# Default target: FastQC reports for every FASTQ file, the per-sample assembly,
# ABRicate, MLST and Prokka results, the combined dataframe, and whichever
# reports are enabled (the report lists are empty when a report is disabled).
rule all:
    input:
        html=expand("fastqc/{file}_fastqc.html", file=file_names),
        html_zip=expand("fastqc/{file}_fastqc.zip", file=file_names),
        shovill=expand("results/{sample}/shovill/contigs.fa", sample=sample_names),
        abricate_ncbi=expand("results/{sample}/abricate_ncbi.tsv", sample=sample_names),
        abricate_plasmidfinder=expand("results/{sample}/abricate_plasmidfinder.tsv", sample=sample_names),
        abricate_vfdb=expand("results/{sample}/abricate_vfdb.tsv", sample=sample_names),
        mlst=expand("results/{sample}/mlst.tsv", sample=sample_names),
        dataframe="dataframe/results.csv",
        prokka=expand("results/{sample}/prokka/{sample}.gff", sample=sample_names),
        report=html_report_outputs,
        r_report=r_report_outputs

# Run FastQC on one FASTQ file.
# The {file} wildcard is the FASTQ basename without ".fastq.gz"; the path of
# the file itself is looked up in fastq_lookup. Both the HTML report and the
# zip archive are written to the fastqc/ directory.
rule fastqc:
    input:
        fastq_files=lambda wildcards: fastq_lookup[wildcards.file]
    output:
        html="fastqc/{file}_fastqc.html",
        html_zip="fastqc/{file}_fastqc.zip",
    conda:
        "envs/fastqc_env.yml"
    shell:
        """
        mkdir -p fastqc
        fastqc {input.fastq_files} -o fastqc
        """

# Assemble the R1/R2 pair of one sample with Shovill (with --trim).
# Threads come from shovill_cpus; the mem_gb resource is shovill_ram, or 0 when
# no RAM limit is configured (in which case shovill_ram_flag is empty as well).
# A failing Shovill run does not fail the rule: the shell block then creates
# empty versions of all declared outputs (keeping an existing shovill.log) so
# that the remaining samples can still be processed. Downstream rules therefore
# may receive an empty contigs.fa.
rule shovill:
    input:
        R1=lambda wildcards: sample_pairs[wildcards.sample]["R1"],
        R2=lambda wildcards: sample_pairs[wildcards.sample]["R2"]
    output:
        assembly="results/{sample}/shovill/contigs.fa",
        gfa="results/{sample}/shovill/contigs.gfa",
        corrections="results/{sample}/shovill/shovill.corrections",
        log="results/{sample}/shovill/shovill.log",
        spades="results/{sample}/shovill/spades.fasta"
    threads:
        shovill_cpus
    resources:
        mem_gb=shovill_ram if shovill_ram else 0
    conda:
        "envs/shovill_env.yml"
    shell:
        r"""
        set -euo pipefail
        if ! shovill --trim \
                --outdir results/{wildcards.sample}/shovill \
                --R1 {input.R1} --R2 {input.R2} \
                --force --cpus {threads} {shovill_ram_flag} ; then
            echo -e "\n[ERROR] Shovill failed for sample {wildcards.sample}. Creating empty output files. Check {output.log} for details." >&2
            mkdir -p results/{wildcards.sample}/shovill
            # Create empty output files to prevent pipeline failure
            :> {output.assembly}
            :> {output.gfa}
            :> {output.corrections}
            :> {output.spades}
            # Ensure log file exists; if Shovill didn't create it, make an empty one
            if [[ ! -f {output.log} ]]; then
                :> {output.log}
            fi
        fi
        """

# Screen the sample's assembly with ABRicate against the "ncbi" database and
# store the table that ABRicate prints to stdout.
rule abricate_ncbi:
    input:
        assembly="results/{sample}/shovill/contigs.fa"
    output:
        abricate="results/{sample}/abricate_ncbi.tsv"
    conda:
        "envs/abricate_env.yml"
    shell:
        "abricate --db ncbi {input.assembly} > {output.abricate}"

# Screen the sample's assembly with ABRicate against the "plasmidfinder"
# database and store the table that ABRicate prints to stdout.
rule abricate_plasmidfinder:
    input:
        assembly="results/{sample}/shovill/contigs.fa"
    output:
        abricate="results/{sample}/abricate_plasmidfinder.tsv"
    conda:
        "envs/abricate_env.yml"
    shell:
        "abricate --db plasmidfinder {input.assembly} > {output.abricate}"

# Screen the sample's assembly with ABRicate against the "vfdb" database and
# store the table that ABRicate prints to stdout.
rule abricate_vfdb:
    input:
        assembly="results/{sample}/shovill/contigs.fa"
    output:
        abricate="results/{sample}/abricate_vfdb.tsv"
    conda:
        "envs/abricate_env.yml"
    shell:
        "abricate --db vfdb {input.assembly} > {output.abricate}"

# Run the `mlst` tool on the sample's assembly and store what it prints to
# stdout as results/{sample}/mlst.tsv.
rule mlst:
    input:
        assembly="results/{sample}/shovill/contigs.fa"
    output:
        mlst="results/{sample}/mlst.tsv"
    conda:
        "envs/mlst_env.yml"
    shell:
        "mlst {input.assembly} > {output.mlst}"

# Annotate the sample's assembly with Prokka, using the sample name as the
# output prefix inside results/{sample}/prokka/. Threads come from prokka_cpus.
# A failing Prokka run does not fail the rule: the shell block then creates
# empty versions of all declared outputs (keeping an existing log file), so
# consumers of the .gff file may receive an empty file.
rule prokka:
    input:
        assembly="results/{sample}/shovill/contigs.fa"
    output:
        prokka="results/{sample}/prokka/{sample}.gff",
        err="results/{sample}/prokka/{sample}.err",
        faa="results/{sample}/prokka/{sample}.faa",
        ffn="results/{sample}/prokka/{sample}.ffn",
        fna="results/{sample}/prokka/{sample}.fna",
        fsa="results/{sample}/prokka/{sample}.fsa",
        gbk="results/{sample}/prokka/{sample}.gbk",
        log="results/{sample}/prokka/{sample}.log",
        sqn="results/{sample}/prokka/{sample}.sqn",
        tbl="results/{sample}/prokka/{sample}.tbl",
        tsv="results/{sample}/prokka/{sample}.tsv",
        txt="results/{sample}/prokka/{sample}.txt"
    threads:
        prokka_cpus
    conda:
        "envs/prokka_env.yml"
    shell:
        r"""
        set -euo pipefail
        if ! prokka --centre X --compliant {input.assembly} \
            --outdir results/{wildcards.sample}/prokka/ \
            --force --cpus {threads} \
            --prefix {wildcards.sample}; then
            echo -e "\n[ERROR] Prokka failed for sample {wildcards.sample}. Creating empty outputs. Check {output.log} for details." >&2
            mkdir -p results/{wildcards.sample}/prokka
            # Create empty versions of all outputs
            :> {output.prokka}
            :> {output.err}
            :> {output.faa}
            :> {output.ffn}
            :> {output.fna}
            :> {output.fsa}
            :> {output.gbk}
            :> {output.sqn}
            :> {output.tbl}
            :> {output.tsv}
            :> {output.txt}
            # Preserve existing log or create empty
            if [[ ! -f {output.log} ]]; then
                :> {output.log}
            fi
        fi
        """

# Build dataframe/results.csv once all FastQC, ABRicate, MLST and Shovill
# results exist for every file/sample.
# Declared as a checkpoint because the input functions
# get_pangenome_prokka_outputs and get_pangenome_outputs read the CSV through
# checkpoints.build_dataframe.get() to decide which files they need.
# build_dataframe() is the function imported from pegas.build_dataframe; it is
# called without arguments, so it presumably locates its inputs and the output
# path on its own -- TODO confirm against that module.
checkpoint build_dataframe:
    input:
        html=expand("fastqc/{file}_fastqc.html", file=file_names),
        html_zip=expand("fastqc/{file}_fastqc.zip", file=file_names),
        abricate_ncbi=expand("results/{sample}/abricate_ncbi.tsv", sample=sample_names),
        abricate_plasmidfinder=expand("results/{sample}/abricate_plasmidfinder.tsv", sample=sample_names),
        abricate_vfdb=expand("results/{sample}/abricate_vfdb.tsv", sample=sample_names),
        mlst=expand("results/{sample}/mlst.tsv", sample=sample_names),
        shovill=expand("results/{sample}/shovill/contigs.fa", sample=sample_names)
    output:
        dataframe="dataframe/results.csv"
    run:
        build_dataframe()

# Helper: species labels of the (species, sample) pairs that qualify for Roary
def get_species_list_for_roary():
    """Return one SPECIES entry per sample for every species with more than one sample.

    Reads ``dataframe/results.csv`` (relative to the current directory),
    reduces it to distinct (SPECIES, SAMPLE) pairs and keeps the pairs whose
    species occurs with at least two different samples. The returned list
    holds the SPECIES value of each kept pair, so a species with three
    samples appears three times.
    """
    df = pd.read_csv("dataframe/results.csv", dtype={'SAMPLE': str})

    # Distinct (species, sample) pairs; groupby also drops rows lacking either key.
    pairs = df.groupby(["SPECIES", "SAMPLE"]).size().reset_index()[["SPECIES", "SAMPLE"]]

    # Number of samples per species; only species with more than one qualify.
    samples_per_species = pairs.groupby("SPECIES")["SAMPLE"].count()
    eligible_species = samples_per_species[samples_per_species > 1].index

    return pairs.loc[pairs["SPECIES"].isin(eligible_species), "SPECIES"].tolist()

def get_species_list_for_roary_unique():
    """Return each species that has more than one distinct sample, once, in sorted order.

    Reads ``dataframe/results.csv`` (relative to the current directory).
    The result is taken from the group index, which pandas sorts, so the
    order is stable between runs.
    """
    df = pd.read_csv("dataframe/results.csv", dtype={'SAMPLE': str})

    # Distinct samples per species (rows without species or sample are ignored).
    samples_per_species = df.groupby("SPECIES")["SAMPLE"].nunique()

    return samples_per_species[samples_per_species > 1].index.tolist()

def get_pangenome_prokka_outputs(wildcards):
    """Input function: Prokka GFF files needed for the pangenome step.

    Reads the CSV produced by the ``build_dataframe`` checkpoint and returns
    ``results/<sample>/prokka/<sample>.gff`` for every sample of every
    species that has more than one distinct sample. *wildcards* is unused.
    """
    results_csv = checkpoints.build_dataframe.get().output[0]
    table = pd.read_csv(results_csv, dtype={'SAMPLE': str})

    # Species represented by at least two different samples.
    sample_counts = table.groupby('SPECIES')['SAMPLE'].nunique().reset_index()
    multi_sample_species = sample_counts.loc[sample_counts['SAMPLE'] > 1, 'SPECIES'].tolist()

    return [
        f"results/{sample}/prokka/{sample}.gff"
        for species in multi_sample_species
        for sample in table.loc[table['SPECIES'] == species, 'SAMPLE'].unique()
    ]

# Run the pangenome step through the run_roary.py script in the Roary conda
# environment, producing the "pangenome" directory. Threads come from roary_cpus.
# Inputs are the results dataframe plus the Prokka GFF files selected by
# get_pangenome_prokka_outputs (samples of species with more than one sample).
# Declared as a checkpoint so that get_pangenome_outputs can wait for it via
# checkpoints.pangenome.get().
checkpoint pangenome:
    input:
        dataframe="dataframe/results.csv",
        prokka_outputs=get_pangenome_prokka_outputs
    output:
        pangenome_dir=directory("pangenome")
    threads:
        roary_cpus
    conda:
        "envs/roary_env.yml"
    script:
        "run_roary.py"

def get_pangenome_outputs(wildcards):
    """Input function: per-species pangenome result files.

    For every species with more than one distinct sample in the
    ``build_dataframe`` CSV, returns the ``gene_presence_absence.csv`` and
    ``summary_statistics.txt`` paths below ``pangenome/<species>/output/``.
    *wildcards* is unused.
    """
    results_csv = checkpoints.build_dataframe.get().output[0]
    table = pd.read_csv(results_csv, dtype={'SAMPLE': str})
    sample_counts = table.groupby('SPECIES')['SAMPLE'].nunique().reset_index()
    multi_sample_species = sample_counts.loc[sample_counts['SAMPLE'] > 1, 'SPECIES'].tolist()

    # Evaluated only for its effect: asking for the checkpoint makes this
    # input function depend on the pangenome checkpoint having run.
    checkpoints.pangenome.get().output[0]

    collected = []
    for species in multi_sample_species:
        collected.extend((
            f"pangenome/{species}/output/gene_presence_absence.csv",
            f"pangenome/{species}/output/summary_statistics.txt",
        ))
    return collected

# Build the interactive HTML report (report/report.html).
# The only declared input is the list from get_pangenome_outputs, so this rule
# runs after both the build_dataframe and pangenome checkpoints.
# build_report() is the function imported from pegas.build_report; it receives
# the install path (html_path), the raw data path and the configured output
# directory, which is None when "output_dir" is not set in the config.
rule build_report:
    input:
        pangenome_outputs=get_pangenome_outputs,
    output:
        report="report/report.html"
    run:
        build_report(html_path, raw_data_path, work_directory)

# Render the R Markdown report (report/report_r.html) by calling the
# render_report.R script with the Rmd template and the results dataframe.
# report_html is "report/report.html" when the interactive report is enabled
# and an empty string otherwise. It is passed as a parameter only, not as an
# input, so this rule does not wait for that file to be created.
# NOTE(review): confirm the R script does not require the file to exist.
rule build_r_report:
    input:
        dataframe="dataframe/results.csv",
        rmd_template=rmd_template_path,
        render_script=r_render_script_path
    output:
        html="report/report_r.html"
    params:
        data_dir=raw_data_path,
        output_dir=work_directory,
        report_html="report/report.html" if interactive_report else "",
        pegas_version=pegas_version,
        pegas_install_dir=install_path
    conda:
        "envs/report_r_env.yml"
    shell:
        r"""
        Rscript "{input.render_script}" \
            --rmd "{input.rmd_template}" \
            --output "{output.html}" \
            --dataframe_csv "{input.dataframe}" \
            --report_html "{params.report_html}" \
            --data_dir "{params.data_dir}" \
            --output_dir "{params.output_dir}" \
            --pegas_version "{params.pegas_version}" \
            --pegas_install_dir "{params.pegas_install_dir}"
        """
