################################
# Imports
################################
import os
import shutil

from seqnado import Assay
from seqnado.config import SeqnadoConfig
from seqnado.inputs import (
    SampleGroupings,
    SampleGroups,
    get_sample_collection,
)
from seqnado.outputs import SeqnadoOutputFactory, SeqnadoOutputFiles
from seqnado.utils import remove_unwanted_run_files


################################
# Hardcoded Config
################################
# Run every rule inside the published seqnado container image.
container: "oras://ghcr.io/alsmith151/seqnado_pipeline:latest"
# Multiplier applied to per-rule resource requests; defaults to 1x and can be
# overridden via the SCALE_RESOURCES environment variable.
SCALE_RESOURCES = float(os.environ.get("SCALE_RESOURCES", "1"))


################################
# Load Configuration
################################
# `config` is injected by Snakemake (--configfile / --config); validate it
# into a typed SeqnadoConfig object.
CONFIG = SeqnadoConfig(**config)  # pyright: ignore[reportUndefinedVariable]
ASSAY = CONFIG.assay
# Resolve the metadata sheet into a typed collection of input sample files.
INPUT_FILES = get_sample_collection(assay=ASSAY, path=CONFIG.metadata, config=CONFIG)
# Allow an explicit output_dir override; otherwise derive one from the assay name.
OUTPUT_DIR = config.get("output_dir", f"seqnado_output/{ASSAY.clean_name}")

################################
# Define Sample Groupings
################################
# Container holding every named grouping used downstream.
SAMPLE_GROUPINGS = SampleGroupings()

# Defines groupings used to normalise between bigwigs.
# NOTE(review): this reads the same subset_column ("scaling_group") as
# scaling_groups below — confirm normalisation is not meant to use a
# dedicated column (e.g. "norm_group"); looks like a possible copy-paste.
normalisation_groups = SampleGroups.from_dataframe(
    INPUT_FILES.to_dataframe(),
    subset_column="scaling_group",
)

# Defines groupings used to scale bigwigs e.g. antibody
scaling_groups = SampleGroups.from_dataframe(
    INPUT_FILES.to_dataframe(),
    subset_column="scaling_group",
)

# Defines groupings used to call consensus peaks and aggregate samples into bigwigs.
# For FastqCollectionForIP, to_dataframe() returns only IP rows (index = uid = sample_id + ip),
# so consensus groups are guaranteed to contain IP samples only (no input/control samples).
consensus_groups = SampleGroups.from_dataframe(
    INPUT_FILES.to_dataframe(),
    subset_column="consensus_group",
)

# Defines condition-based groupings for comparison bigwigs and eventual differential analysis
condition_groups = SampleGroups.from_dataframe(
    INPUT_FILES.to_dataframe(),
    subset_column="condition",
)

# Register all groupings under the names downstream rules look up.
SAMPLE_GROUPINGS.add_grouping("normalisation", normalisation_groups)
SAMPLE_GROUPINGS.add_grouping("scaling", scaling_groups)
SAMPLE_GROUPINGS.add_grouping("consensus", consensus_groups)
SAMPLE_GROUPINGS.add_grouping("condition", condition_groups)


################################
# Define Outputs
################################
# Build the complete set of expected output files for this assay / config /
# sample combination via the factory -> builder chain.
OUTPUT: SeqnadoOutputFiles = (
    SeqnadoOutputFactory(
        assay=ASSAY,
        samples=INPUT_FILES,
        config=CONFIG,
        sample_groupings=SAMPLE_GROUPINGS,
        output_dir=OUTPUT_DIR,
    )
    .create_output_builder()
    .build()
)
SAMPLE_NAMES = OUTPUT.sample_names

# For IP-based assays, create a separate list for IP-only samples (used for peak calling)
# Local import: only needed for this isinstance check.
from seqnado.inputs import FastqCollectionForIP
if isinstance(INPUT_FILES, FastqCollectionForIP):
    IP_SAMPLE_NAMES = INPUT_FILES.ip_sample_names
else:
    # Non-IP assays: every sample is eligible, so reuse the full list.
    IP_SAMPLE_NAMES = SAMPLE_NAMES

################################
# Set-Up Workflow
################################
# Symlink raw FASTQs into the output tree so all rules read from one
# canonical location.
fastq_dir = f"{OUTPUT_DIR}/fastqs/"
INPUT_FILES.symlink_fastq_files(output_dir=fastq_dir)


################################
# Include Rules
################################
# Rules shared by every assay type.
include: "rules/common/utilities.smk"
include: "rules/fastq/screen.smk"
include: "rules/protocol/protocol.smk"
include: "rules/qc/qc.smk"

# Optional dataset-export rules.
if CONFIG.assay_config.create_dataset:
    include: "rules/dataset/dataset.smk"

# Assay-specific rule sets.
match ASSAY:
    # DNA-based IP/accessibility assays: align, peak-call, pile up, visualise.
    case Assay.ATAC | Assay.CAT | Assay.CHIP:
        include: "rules/alignment/dna.smk"
        include: "rules/bam/all.smk"
        include: "rules/fastq/trim.smk"
        include: "rules/motifs/motif.smk"
        include: "rules/normalization/scaling.smk"
        include: "rules/peaks/default.smk"
        include: "rules/peaks/merge.smk"
        include: "rules/pileup/common.smk"
        include: "rules/pileup/dna.smk"
        include: "rules/quant/merged.smk"
        include: "rules/visualise/browser.smk"
        include: "rules/visualise/hub.smk"
    case Assay.RNA:
        include: "rules/alignment/rna.smk"
        include: "rules/bam/all.smk"
        include: "rules/fastq/trim.smk"
        include: "rules/normalization/scaling.smk"
        include: "rules/pileup/common.smk"
        include: "rules/pileup/rna.smk"
        include: "rules/quant/featurecounts.smk"
        include: "rules/quant/salmon.smk"
        include: "rules/visualise/browser.smk"
        include: "rules/visualise/hub.smk"
    case Assay.MCC:
        include: "rules/bam/all.smk"
        include: "rules/fastq/manipulate.smk"
        include: "rules/fastq/trim.smk"
        include: "rules/mcc/all.smk"
        include: "rules/visualise/hub.smk"
    case Assay.CRISPR:
        include: "rules/crispr/all.smk"
    case Assay.METH:
        include: "rules/alignment/dna.smk"
        include: "rules/bam/all.smk"
        include: "rules/fastq/trim.smk"
        include: "rules/meth/calling.smk"
        include: "rules/pileup/common.smk"
        include: "rules/pileup/dna.smk"
        include: "rules/visualise/browser.smk"
    case Assay.SNP:
        include: "rules/alignment/dna.smk"
        include: "rules/bam/all.smk"
        include: "rules/fastq/trim.smk"
        include: "rules/variant/call.smk"
        include: "rules/variant/annotate.smk"
    # Fail fast on an assay with no rule set rather than silently building
    # an incomplete DAG.
    case _:
        raise ValueError(f"Assay {ASSAY} not supported yet.")


################################
# Conditional Includes
################################
# Optional extras toggled by the assay configuration.
if CONFIG.assay_config.create_heatmaps:
    include: "rules/visualise/heatmap.smk"

if CONFIG.assay_config.create_geo_submission_files:
    include: "rules/geo/submission.smk"


################################
# Rule Ordering
################################
# Several rules can produce the same output file; ruleorder tells Snakemake
# which one wins when the DAG is ambiguous.
if CONFIG.assay_config.has_spikein:
    # Spike-in experiments need extra alignment/normalisation rules, and the
    # spike-in-aware variants must take precedence over the plain ones.
    include: "rules/spikein/all.smk"
    include: "rules/normalization/spikein.smk"
    match ASSAY:
        case Assay.RNA:
            ruleorder: rename_aligned_spikein > align_paired_spikein_rna > align_single_spikein_rna
            ruleorder: bam_move_ref > rename_aligned > align_paired_spikein_rna > align_paired > align_single_spikein_rna > align_single
        case _:
            ruleorder: align_paired_spikein > align_single_spikein
            ruleorder: bam_move_ref > align_paired > align_single
            ruleorder: make_bigwigs_deeptools_spikein > make_bigwigs_deeptools
else:
    match ASSAY:
        case Assay.RNA:
            ruleorder: rename_aligned > align_paired > align_single
        # Fixed: value patterns must use the Assay enum class, not the ASSAY
        # member (ASSAY.ATAC is attribute access on an enum *member*, which is
        # unreliable across Python versions and inconsistent with the assay
        # match block above).
        case Assay.ATAC | Assay.CAT | Assay.CHIP:
            ruleorder: align_paired > align_single
            ruleorder: call_peaks_lanceotron_no_input_consensus > call_peaks_lanceotron_no_input
            ruleorder: call_peaks_macs2_no_input_consensus > call_peaks_macs2_no_input


################################
# Final Aggregation Rule
################################
# Default target: requesting every expected output file drives the whole DAG.
rule all:
    input:
        OUTPUT.all_files
    message: "All workflow steps completed successfully."


################################
# Workflow Hooks
################################
# On success: clean up temporary/unwanted run artefacts.
onsuccess:
    remove_unwanted_run_files()


# On failure: preserve the Snakemake log under a predictable name, tell the
# user where to look, then clean up as on success.
onerror:
    log_out = "seqnado_error.log"
    # log is a list in Snakemake, get the first element
    log_file = log[0] if isinstance(log, list) else log  # pyright: ignore[reportUndefinedVariable]
    shutil.copyfile(log_file, log_out)
    print(
        f"An error occurred. Please check the log file {log_out} for more information."
    )
    remove_unwanted_run_files()
