import pandas as pd
from pathlib import Path

# Load the sample sheet, keyed by sample name. ``sample_name`` is forced to
# ``str`` so purely numeric names (e.g. "001") keep their exact spelling:
# Snakemake wildcards are always strings, so an int-typed key would make
# ``SAMPLES[wildcards.sample]`` raise KeyError (and "001" would collapse to 1).
samples_df = pd.read_csv(config["samples_csv"], dtype={"sample_name": str})
SAMPLES = samples_df.set_index("sample_name", drop=False).to_dict(orient="index")

# Reference genome: indexed by bwa (reads mode) and passed to lofreq via -f.
REF = config["reference"]
# Root directory for all per-sample outputs and the final spreadsheet.
OUTDIR = config.get("outdir", "results")
# Optional primer BED; when unset, ampliconclip just copies the aligned BAM.
PRIMER_BED = config.get("primer_bed", None)
# "reads": trim + align from FASTQ. Any other value: start from the BAM paths
# given in the sample sheet's "bam" column.
MODE = config.get("mode", "reads")


def _value_or_blank(sample, key, allow_blank=False):
    value = SAMPLES[sample].get(key)
    if value is None or (isinstance(value, float) and pd.isna(value)):
        return "" if allow_blank else None
    text = str(value).strip()
    if not text:
        return "" if allow_blank else None
    return text


def _require_value(sample, key):
    """Return the non-blank text of column ``key`` for ``sample``.

    Raises:
        ValueError: if the cell is absent, NaN, or whitespace-only.
    """
    found = _value_or_blank(sample, key)
    if found:
        return found
    raise ValueError(f"Missing value for column '{key}' in sample '{sample}'")


def _initial_bam(wildcards):
    if MODE == "reads":
        return f"{OUTDIR}/{wildcards.sample}/aligned.clipped.bam"
    bam_path = _value_or_blank(wildcards.sample, "bam")
    if not bam_path:
        raise ValueError(
            f"Sample '{wildcards.sample}' is missing a BAM path in the input CSV"
        )
    return bam_path

# Default target: every sample's bgzipped VCF plus its CSI index and its depth
# table, followed by the spreadsheet that points vartracker at those files.
rule all:
    input:
        expand(
            [
                f"{OUTDIR}/{{sample}}/{{sample}}_variants.vcf.gz",
                f"{OUTDIR}/{{sample}}/{{sample}}_variants.vcf.gz.csi",
                f"{OUTDIR}/{{sample}}/{{sample}}_depth.txt",
            ],
            sample=list(SAMPLES),
        ),
        f"{OUTDIR}/vartracker_execution_spreadsheet.csv",

# Read-processing rules (trim -> align -> primer-clip). They exist only in
# "reads" mode; in any other mode lofreq_indelqual starts from the BAMs
# listed in the sample sheet instead (see _initial_bam).
if MODE == "reads":
    # Every file written by `bwa index`. bwa_mem depends on all of them so
    # that a partially built index triggers a rebuild.
    BWA_INDEX_FILES = [REF + ext for ext in (".amb", ".ann", ".bwt", ".pac", ".sa")]

    rule bwa_index:
        input:
            ref = REF
        output:
            amb = REF + ".amb",
            ann = REF + ".ann",
            bwt = REF + ".bwt",
            pac = REF + ".pac",
            sa = REF + ".sa"
        shell:
            """
            bwa index {input.ref} 2> /dev/null
            """

    rule fastp:
        # Quality/length trimming with fastp. Paired-end when the sheet has a
        # `reads2` value; otherwise single-end, plus an empty trimmed_R2
        # placeholder so both cases produce the same outputs.
        input:
            r1 = lambda w: _require_value(w.sample, "reads1"),
            # An input (not a param) so Snakemake verifies the file exists and
            # reruns when it changes; resolves to [] for single-end samples,
            # which renders as an empty string in the shell command.
            r2 = lambda w: _value_or_blank(w.sample, "reads2") or []
        output:
            r1 = temp(f"{OUTDIR}/{{sample}}/trimmed_R1.fastq.gz"),
            r2 = temp(f"{OUTDIR}/{{sample}}/trimmed_R2.fastq.gz"),
            html = temp(f"{OUTDIR}/{{sample}}/fastp.html"),
            json = temp(f"{OUTDIR}/{{sample}}/fastp.json")
        threads: max(1, int(workflow.cores * 0.5))
        shell:
            """
            if [ -n "{input.r2}" ]; then
                fastp -i {input.r1} -I {input.r2} \
                    -o {output.r1} -O {output.r2} \
                    -w {threads} \
                    --detect_adapter_for_pe \
                    --cut_front \
                    --cut_tail \
                    --cut_mean_quality 20 \
                    --correction \
                    --length_required 50 \
                    -h {output.html} \
                    -j {output.json} 2> /dev/null
            else
                fastp -i {input.r1} \
                    -o {output.r1} \
                    -w {threads} \
                    --cut_front \
                    --cut_tail \
                    --cut_mean_quality 20 \
                    --length_required 50 \
                    -h {output.html} \
                    -j {output.json} 2> /dev/null
                touch {output.r2}
            fi
            """

    rule bwa_mem:
        # Align trimmed reads, then coordinate-sort and index the result.
        # A non-empty trimmed_R2 (`-s`) selects paired-end alignment; the
        # empty single-end placeholder from fastp selects single-end.
        input:
            r1 = f"{OUTDIR}/{{sample}}/trimmed_R1.fastq.gz",
            r2 = f"{OUTDIR}/{{sample}}/trimmed_R2.fastq.gz",
            ref = REF,
            idx = BWA_INDEX_FILES
        output:
            bam = temp(f"{OUTDIR}/{{sample}}/aligned.raw.bam"),
            bai = temp(f"{OUTDIR}/{{sample}}/aligned.raw.bam.bai")
        params:
            # `\\t` becomes a literal backslash-t, which bwa expands to a tab.
            rg = lambda w: f"@RG\\tID:{w.sample}\\tSM:{w.sample}\\tPL:ILLUMINA"
        threads: max(1, workflow.cores)
        shell:
            """
            if [ -s {input.r2} ]; then
                bwa mem -t {threads} -R '{params.rg}' {input.ref} {input.r1} {input.r2} 2> /dev/null | \
                    samtools view -b - | \
                    samtools sort -@ {threads} -o {output.bam} 2> /dev/null
            else
                bwa mem -t {threads} -R '{params.rg}' {input.ref} {input.r1} 2> /dev/null | \
                    samtools view -b - | \
                    samtools sort -@ {threads} -o {output.bam} 2> /dev/null
            fi
            samtools index {output.bam}
            """

    rule ampliconclip:
        # With a primer BED, clip primer regions using samtools ampliconclip,
        # then re-sort and index. Without one, pass the BAM through unchanged.
        input:
            bam = f"{OUTDIR}/{{sample}}/aligned.raw.bam",
            # Tracked so a missing BED fails early and edits to it re-run
            # clipping; empty when no primer BED is configured.
            bed = [PRIMER_BED] if PRIMER_BED else []
        output:
            bam = temp(f"{OUTDIR}/{{sample}}/aligned.clipped.bam"),
            bai = temp(f"{OUTDIR}/{{sample}}/aligned.clipped.bam.bai"),
            log = temp(f"{OUTDIR}/{{sample}}/ampliconclip.log")
        threads: max(1, int(workflow.cores * 0.5))
        run:
            if PRIMER_BED:
                shell("""
                    samtools ampliconclip -b {input.bed} -@ {threads} \
                    --strand --both-ends -o - {input.bam} 2> {output.log} \
                    | samtools sort -@ {threads} -o {output.bam} - 2> /dev/null
                    samtools index {output.bam}
                """)
            else:
                shell("""
                    cp {input.bam} {output.bam}
                    samtools index {output.bam}
                    touch {output.log}
                """)

rule lofreq_indelqual:
    # Add Dindel indel qualities to the starting BAM and index the result.
    # The output feeds samtools_depth and `lofreq call-parallel --call-indels`
    # in lofreq_call. The index is presumably required by call-parallel —
    # confirm against the lofreq docs.
    input:
        bam = _initial_bam,
        ref = REF
    output:
        bam = f"{OUTDIR}/{{sample}}/{{sample}}_aligned.indelqual.bam",
        # Declared so Snakemake tracks the index written by `samtools index`
        # below (and removes it along with the BAM if the job fails).
        bai = f"{OUTDIR}/{{sample}}/{{sample}}_aligned.indelqual.bam.bai"
    shell:
        """
        lofreq indelqual --dindel -f {input.ref} -o {output.bam} {input.bam} 2> /dev/null
        samtools index {output.bam}
        """

rule samtools_depth:
    # Per-position read depth of the indel-qualified BAM. `-aa` asks samtools
    # to report every reference position, including zero-depth ones (see the
    # `samtools depth` manual).
    input:
        bam = f"{OUTDIR}/{{sample}}/{{sample}}_aligned.indelqual.bam"
    output:
        depth = f"{OUTDIR}/{{sample}}/{{sample}}_depth.txt"
    shell:
        "samtools depth -aa {input.bam} 2> /dev/null > {output.depth}"

rule lofreq_call:
    # Call variants (including indels) with lofreq's parallel wrapper, then
    # bgzip the VCF and build its CSI index (the .csi target in rule all).
    input:
        bam = f"{OUTDIR}/{{sample}}/{{sample}}_aligned.indelqual.bam",
        ref = REF
    output:
        # Uncompressed lofreq output; temp() lets Snakemake delete it once the
        # job no longer needs it.
        vcf_raw = temp(f"{OUTDIR}/{{sample}}/{{sample}}_variants.raw.vcf"),
        vcf = f"{OUTDIR}/{{sample}}/{{sample}}_variants.vcf.gz",
        csi = f"{OUTDIR}/{{sample}}/{{sample}}_variants.vcf.gz.csi"
    threads: max(1, workflow.cores)
    shell:
        """
        lofreq call-parallel \
            --no-baq \
            --call-indels \
            --pp-threads {threads} \
            -f {input.ref} \
            -o {output.vcf_raw} \
            {input.bam} 2> /dev/null
        bgzip -c {output.vcf_raw} > {output.vcf}
        bcftools index {output.vcf}
        """

rule update_csv:
    # Copy the input sample sheet and set its bam / vcf / coverage columns to
    # the absolute paths of this workflow's per-sample outputs (existing
    # columns of those names are overwritten). The inputs only make the rule
    # wait until every sample is finished.
    input:
        vcfs = expand(f"{OUTDIR}/{{sample}}/{{sample}}_variants.vcf.gz", sample=SAMPLES.keys()),
        depths = expand(f"{OUTDIR}/{{sample}}/{{sample}}_depth.txt", sample=SAMPLES.keys()),
        bams = expand(f"{OUTDIR}/{{sample}}/{{sample}}_aligned.indelqual.bam", sample=SAMPLES.keys())
    output:
        csv = f"{OUTDIR}/vartracker_execution_spreadsheet.csv"
    params:
        original_csv = config["samples_csv"],
        outdir = OUTDIR
    run:
        import os

        # Read every cell as verbatim text so the sheet round-trips unchanged.
        # With type inference, sample "001" would become 1, which points the
        # paths below at the wrong directory. Integer columns with blanks
        # would also be rewritten as floats, and literal "NA" cells would be
        # blanked.
        df = pd.read_csv(params.original_csv, dtype=str, keep_default_na=False)

        def _abs_output(sample, suffix):
            """Absolute path of ``<outdir>/<sample>/<sample><suffix>``."""
            return os.path.abspath(f"{params.outdir}/{sample}/{sample}{suffix}")

        for column, suffix in (
            ("bam", "_aligned.indelqual.bam"),
            ("vcf", "_variants.vcf.gz"),
            ("coverage", "_depth.txt"),
        ):
            df[column] = [_abs_output(name, suffix) for name in df["sample_name"]]

        df.to_csv(output.csv, index=False)
