# Working directory of the aberrant-expression module, as returned by
# cfg.AE.renameLocalDir().
# NOTE(review): presumably this also renames/moves the local module
# directory -- confirm against the config implementation.
AE_WORKDIR = cfg.AE.renameLocalDir()

# Index input, rule-graph file and index output of the module, in the order
# returned by getModuleIndexFiles().
AE_index_input, AE_graph_file, AE_index_output = cfg.AE.getModuleIndexFiles(
    "aberrant-expression-pipeline",
    AE_WORKDIR,
)
rule aberrantExpression:
    # Entry-point rule of the aberrant-expression module: needs the module's
    # index input and rule-graph file, and declares the index output.
    input:
        AE_index_input,
        AE_graph_file
    output:
        AE_index_output
    run:
        # ci() is only invoked when the module is enabled in the config;
        # otherwise the run block does nothing.
        if cfg.AE.run:
            ci(str(AE_WORKDIR), 'aberrant-expression-pipeline')

# Render the rule graph leading to the module's index output as an SVG file.
rule aberrantExpression_dependency:
    output: AE_graph_file
    # Pipeline of the shell command below:
    #   1. `snakemake --rulegraph` prints the rule graph for AE_index_output
    #      (`--nolock` avoids taking the working-directory lock).
    #   2. `sed` keeps only the lines from `digraph snakemake_dag` up to the
    #      closing brace; the doubled brace is the escaped form of a literal
    #      `}` in Snakemake's shell formatting.
    #   3. `dot` converts the graph to SVG, laid out top-to-bottom.
    shell:
        """
        snakemake --nolock --rulegraph {AE_index_output} | \
            sed -ne '/digraph snakemake_dag/,/}}/p' | \
            dot -Tsvg -Grankdir=TB > {output}
        """

def get_bam_only_ids(wildcards):
    """Return the sorted RNA sample IDs of a dataset that lack external counts.

    The result is every ID that ``sa.getIDsByGroup`` reports for
    ``wildcards.dataset`` under assay ``"RNA"``, minus every ID it reports
    under assay ``"GENE_COUNT"``. Duplicates are removed.
    """
    group = wildcards.dataset
    with_bam = sa.getIDsByGroup(group, assay="RNA")
    with_external_counts = frozenset(sa.getIDsByGroup(group, assay="GENE_COUNT"))
    # Keep only the samples that have no externally provided counts.
    bam_only = {sample for sample in with_bam if sample not in with_external_counts}
    return sorted(bam_only)

# Count the reads of one sample's BAM file that map to the chromosomes listed
# in the UCSC/NCBI name table and write "<sampleID><TAB><count>" to the output.
#
# Notes on the shell block:
#   * The command string is a regular (non-raw) Python string, so "\t" reaches
#     bash/grep as a literal tab character.
#   * idxstats is run once and reused. Piping it directly into `grep -q` can
#     return a non-zero status under bash strict mode (pipefail) when grep
#     exits before samtools has finished writing, which would pick the wrong
#     name column.
#   * Chromosome names are joined with `paste`, which adds no trailing
#     separator. A trailing "|" would create an empty regex alternative that
#     matches every idxstats line and silently disables the filter.
#   * The pattern is anchored on the tab following the name, so a listed name
#     only matches the contig with exactly that name, not every contig that
#     merely starts with it.
rule aberrantExpression_bamStats:
    input:
        bam = lambda wildcards: sa.getFilePath(wildcards.sampleID, "RNA_BAM_FILE"),
        ucsc2ncbi = cfg.workDir  / "Scripts/Pipeline/chr_UCSC_NCBI.txt"
    output:
        cfg.processedDataDir / "aberrant_expression" / "bam_stats" / "{sampleID}.txt"
    params:
        samtools = config["tools"]["samtoolsCmd"]
    shell:
        """
        # read the per-contig statistics once and reuse them below
        idxstats=$({params.samtools} idxstats {input.bam})

        # identify chromosome format: column 1 holds the chr-prefixed names,
        # column 2 the names without prefix
        if grep -q "^chr" <<< "$idxstats";
        then
            chrCol=1
        else
            chrCol=2
        fi

        # build the alternation of chromosome names, skipping empty lines
        chrNames=$(cut -f"$chrCol" {input.ucsc2ncbi} | sed '/^$/d' | paste -sd'|' -)

        # sum the mapped-read column of the selected chromosomes
        count=$(grep -E "^($chrNames)\t" <<< "$idxstats" | \
                cut -f3 | paste -sd+ - | bc)

        echo -e "{wildcards.sampleID}\t${{count}}" > {output}
        """

# Merge the per-sample read counts of a dataset into one TSV table.
#
# Input:  one bam_stats file per sample returned by get_bam_only_ids().
# Output: a table with header "RNA_ID<TAB>record_count"; the per-sample files
#         are appended verbatim, followed by one "<id><TAB>NA" row for every
#         ID that sa.getIDsByGroup() reports under assay "GENE_COUNT".
rule aberrantExpression_mergeBamStats:
    input:
        lambda w: expand(
            rules.aberrantExpression_bamStats.output[0],
            sampleID=get_bam_only_ids(w)
        )
    output:
        cfg.processedDataDir / "aberrant_expression" / "bam_stats" / "{dataset}.tsv"
    params:
        exIDs = lambda w: sa.getIDsByGroup(w.dataset, assay="GENE_COUNT")
    run:
        with open(output[0], "w") as bam_stats:
            bam_stats.write("RNA_ID\trecord_count\n")
            for f in input:
                # Close each per-sample file right after copying it instead
                # of leaving the handle to the garbage collector.
                with open(f, "r") as sample_stats:
                    bam_stats.write(sample_stats.read())
            # These IDs have no per-sample stats file in the input, so their
            # count is written as NA.
            for eid in params.exIDs:
                bam_stats.write(f"{eid}\tNA\n")
