#!/usr/bin/env python3
version = '2.0'

import re, gzip
import argparse, shutil, subprocess
import os, glob
from pathlib import Path
import csv, sys, time
from csv import writer
import subprocess

# New MJA
from collections import defaultdict
import pandas as pd
import plotly.express as px
import plotly.offline as pyo
import plotly.graph_objects as go

from concurrent.futures import ThreadPoolExecutor

def check_executables():
    """Exit with an error if any external tool the pipeline shells out to is missing."""
    required = ["cutadapt", "awk", "gzip", "uniq", "sort", "head", "grep", "rm"]
    missing = [tool for tool in required if shutil.which(tool) is None]
    if missing:
        sys.exit(f"Missing required executables: {', '.join(missing)}")

def get_ShortCut_args(version):
    """Parse and validate ShortCut's command-line arguments.

    Parameters
    ----------
    version : str
        Program version string, shown in the argparse description.

    Returns
    -------
    argparse.Namespace
        Parsed options. Exits with a message if dicermin/dicermax/
        min_read_size are mutually inconsistent.
    """
    ap = argparse.ArgumentParser(description=f"ShortCut v{version}")
    ap.add_argument('--fastq', nargs='+', required=True,
                type=valid_fastq,
                help='One or more FASTQ alignment files')
    ap.add_argument("--min_read_size", type=int,
                help="Minimum length of trimmed reads to retain", default = 12)
    ap.add_argument("--out_dir", help="output directory",default="ShortCut_output")
    ap.add_argument("--threads", type=int, default=1,
                    help="Number of threads to use. Default=1")
    ap.add_argument("--dicermin", type=int, default=21,
                    help="Minimum trimmed read length for true regulatory small RNA. Default=21")
    ap.add_argument("--dicermax", type=int, default=24,
                    # Fixed help text: this is the MAXIMUM length (was a copy-paste of --dicermin's help).
                    help="Maximum trimmed read length for true regulatory small RNA. Default=24")
    group = ap.add_mutually_exclusive_group()
    group.add_argument('--trim_key', type=valid_sequence,
                    help='Trim key to use to infer adapter sequence (default: UCGGACCAGGCUUCAUUCCCC)')
    group.add_argument('--adapter', type=valid_sequence, help='Adapter sequence to use for trimming')

    args = ap.parse_args()

    # Set the trim_key default after parsing: a static argparse default would
    # conflict with the mutually exclusive --adapter option.
    if args.trim_key is None and args.adapter is None:
        args.trim_key = valid_sequence('UCGGACCAGGCUUCAUUCCCC')

    # Validate dicermin, dicermax, and min_read_size
    if args.dicermin > args.dicermax:
        sys.exit("Invalid settings for dicermin and dicermax. The dicermin value must be <= the dicermax value.")
    if args.dicermin < args.min_read_size:
        # Fixed typo in user-facing message: "Invalide" -> "Invalid".
        sys.exit("Invalid settings for dicermin and min_read_size. The dicermin value must be >= the min_read_size value.")

    return args

def valid_fastq(filepath):
    """argparse type checker: the path must be an existing, readable FASTQ file."""
    if not os.path.isfile(filepath):
        raise argparse.ArgumentTypeError(f"File not found: {filepath}")
    fastq_exts = ('.fastq', '.fastq.gz', '.fq', '.fq.gz')
    if not filepath.endswith(fastq_exts):
        raise argparse.ArgumentTypeError(f"Must be FASTQ format: {filepath}")
    if not os.access(filepath, os.R_OK):
        raise argparse.ArgumentTypeError(f"File not readable: {filepath}")
    return filepath

# trim_key and adapter are mutually exclusive, with trim_key being a default of miR166
def valid_sequence(seq):
    """Validate a 15-25 nt DNA/RNA sequence; return it uppercased with U -> T."""
    seq = seq.upper()

    # Enforce the accepted length window.
    if not 15 <= len(seq) <= 25:
        raise argparse.ArgumentTypeError(
            f"Sequence must be 15-25 characters long (got {len(seq)})"
        )

    # Reject anything outside the nucleotide alphabet.
    invalid = set(seq) - set('ATGCU')
    if invalid:
        raise argparse.ArgumentTypeError(
            f"Sequence contains invalid characters: {invalid}. Only A, T, G, C, U allowed"
        )

    # Downstream tools work on the DNA alphabet, so RNA U becomes T.
    return seq.replace('U', 'T')


def print_start_message(args):
    """Print the startup banner and every option that has a value set."""
    print(f"Starting ShortCut version {version} with these options:")
    for name, val in vars(args).items():
        if val is None:
            continue
        # Lists are shown comma-separated; everything else as-is.
        shown = ', '.join(val) if isinstance(val, list) else val
        print(f"  {name}: {shown}")
    print()

def prep_outdir(args):
    """Create the output directory tree and chdir into it.

    Exits if args.out_dir already exists (refuses to overwrite a previous
    run). On success the process CWD is args.out_dir, containing the three
    subdirectories cutadapt output is routed to: trimmed_ok/, short/,
    untrimmed/.
    """
    if os.path.exists(args.out_dir):
        # Fixed message: the real option is '--out_dir' (there is no '-out').
        sys.exit("Output directory '" + args.out_dir + "/' already exists. "
                 "Please assign a new output directory using option '--out_dir'.")
    # Use os.makedirs/os.mkdir instead of shelling out to `mkdir`.
    os.makedirs(args.out_dir)
    os.chdir(args.out_dir)
    print(f'Made output directory {args.out_dir}')
    for sub in ("trimmed_ok", "short", "untrimmed"):
        os.mkdir(sub)

def get_adapter(args, fastqs):
    """Infer the 3' adapter from the first FASTQ when --adapter was not given.

    Scans reads for args.trim_key and takes the most frequent 20 characters
    that follow it as the adapter, storing the result in args.adapter.
    No-op when args.adapter is already set. Exits on failure.
    """
    if args.adapter is not None:
        return

    # Use just one file to learn the adapter.
    abs_fastq = fastqs[0]
    print(f'        Detecting adapter for {abs_fastq} with trim_key {args.trim_key}')
    if not os.path.exists(abs_fastq):
        sys.exit(f"        Error: File {abs_fastq} does not exist. Please check the path.")

    # awk prints the 20 chars following each trim_key hit; the pipe tail
    # picks the single most common candidate.
    if abs_fastq.endswith(".gz"):
        cmd = (
            f"gzip -dc {abs_fastq} | awk -v target='{args.trim_key}' '{{idx = index($0, target); if (idx) print substr($0, idx + length(target),20)}}' | "
            "uniq -c | sort -nr | head -n1 | awk '{print $2}'"
        )
    else:
        cmd = (
            f"awk -v target='{args.trim_key}' '{{idx = index($0, target); if (idx) print substr($0, idx + length(target),20)}}' {abs_fastq} | "
            "uniq -c | sort -nr | head -n1 | awk '{print $2}'"
        )

    try:
        adapter = subprocess.check_output(cmd, shell=True, text=True).strip()
    except subprocess.CalledProcessError as e:
        # Bug fix: the original referenced `adapter.output`, but `adapter` is
        # unbound when check_output raises; report the error's output instead.
        sys.exit(f"        Error while processing {abs_fastq}: {e.output}")

    # A plausible adapter should contain at least one G (sanity check on the
    # inferred sequence; an empty/garbage result fails this).
    if "G" not in adapter:
        print("        No adapter detected. Copying untrimmed file to trimmedLibraries instead.")
        sys.exit(f'        Error while inferring adapter from {abs_fastq} : Invalid adapter found: {adapter}')

    print(f"        Adapter detected: {adapter}")
    args.adapter = adapter

def create_dicerok_figure(full_df, args):
    """Box plot of each library's DicerOK fraction, by Reads and by Sequences.

    DicerOK is the fraction of a library's total that falls within
    [args.dicermin, args.dicermax]. Returns an HTML div (no plotly.js).
    """
    # Per-library totals over all lengths.
    totals = full_df.groupby(['Source'])[['Reads', 'Sequences']].sum().reset_index()

    # Per-library totals restricted to the Dicer size window.
    in_window = full_df['Length'].between(args.dicermin, args.dicermax)
    dicer = full_df[in_window].groupby('Source')[['Reads', 'Sequences']].sum().reset_index()

    # Fraction inside the window, per library (index-aligned division).
    ratios = dicer.set_index('Source') / totals.set_index('Source')

    # Long format: one row per (Source, Method) with its DicerOK value.
    plot_df = ratios.reset_index().melt(id_vars='Source',
                                        value_vars=['Reads', 'Sequences'],
                                        var_name='Method',
                                        value_name='DicerOK')

    fig = px.box(plot_df,
                 color="Method",
                 y="DicerOK",
                 points="all",
                 hover_data=plot_df.columns,
                 range_y=[0, 1],
                 width=600)

    return fig.to_html(include_plotlyjs=False,
                       full_html=False,
                       div_id='dicerok')

def create_trim_stacked_barchart(df):
    """Box plot of total Reads/Sequences per trim Category across libraries.

    Returns an HTML div (plotly.js not included).
    """
    # Collapse the per-length detail: totals per (Category, Source).
    totals = df.groupby(['Category', 'Source'])[['Reads', 'Sequences']].sum().reset_index()

    # One row per (Source, Category, Metric) for plotting.
    long_df = totals.melt(id_vars=['Source', 'Category'],
                          value_vars=['Reads', 'Sequences'],
                          var_name='Metric',
                          value_name='Count')

    fig = px.box(long_df,
                 x="Category",
                 y="Count",
                 color="Metric",
                 points="all",
                 hover_data=long_df.columns,
                 width=900)

    return fig.to_html(include_plotlyjs=False,
                       full_html=False,
                       div_id='trim_fig')


def get_len_figs(df, args):
    """Return concatenated HTML divs: one length-distribution figure per Source.

    Each figure plots Reads or Sequences vs. read Length, one line trace per
    trim Category, with a dropdown toggling between the two metrics and a
    shaded vertical band marking the [dicermin, dicermax] DicerOK range.
    """
    all_html = ''
    # Get all unique Source names into an array to iterate over
    unique_sources = df['Source'].unique()
    
    for unq_source in unique_sources:
        filt_df = df[df['Source'] == unq_source]
        
        # Get unique categories and metrics
        categories = filt_df['Category'].unique()
        metrics = ['Reads', 'Sequences']
        
        # Create figure
        fig = go.Figure()
        
        # Add traces for each metric and category combination.
        # NOTE: metric-major insertion order is load-bearing — the dropdown
        # below toggles a contiguous slice of traces per metric.
        for metric in metrics:
            for category in categories:
                # Filter data for this category
                category_data = filt_df[filt_df['Category'] == category]
                
                fig.add_trace(go.Scatter(
                    x=category_data['Length'],
                    y=category_data[metric],
                    mode='lines+markers',
                    name=category,
                    visible=(metric == 'Reads'),  # Only Reads visible initially
                    legendgroup=category
                ))
        
        # Create dropdown buttons
        buttons = []
        num_categories = len(categories)
        
        for i, metric in enumerate(metrics):
            visible = [False] * len(fig.data)
            showlegend = [False] * len(fig.data)
            
            # Make traces for this metric visible and show their legend.
            # Traces [i*num_categories, (i+1)*num_categories) belong to
            # metric i thanks to the metric-major insertion order above.
            start_idx = i * num_categories
            end_idx = start_idx + num_categories
            visible[start_idx:end_idx] = [True] * num_categories
            showlegend[start_idx:end_idx] = [True] * num_categories
            
            buttons.append(dict(
                label=metric,
                method='update',
                args=[{'visible': visible, 'showlegend': showlegend},
                      {'yaxis.title.text': metric}]
            ))
        
        # Add DicerOK shaded region (half-unit padding so the band covers
        # the integer lengths dicermin..dicermax inclusive)
        fig.add_vrect(
            x0=(args.dicermin - 0.5), 
            x1=(args.dicermax + 0.5), 
            fillcolor="green",
            opacity=0.2, 
            annotation_text="DicerOK", 
            annotation_position="top left"
        )
        
        fig.update_layout(
            updatemenus=[dict(
                buttons=buttons,
                direction='down',
                showactive=True,
            )],
            title=f"{unq_source} : Length distribution by Reads or Sequences",
            xaxis_title='Length',
            yaxis_title='Reads'
        )
        
        # Div only; the report template embeds plotly.js once for all figures.
        html = fig.to_html(include_plotlyjs=False,
                           full_html=False,
                           div_id=f'len_{unq_source}')
        all_html += html

    return all_html

def trim_reads(args,fastqs):
    """Run cutadapt on each FASTQ, then condense its three output FASTAs.

    Each library yields trimmed_ok/t_*.fa, short/s_*.fa and untrimmed/u_*.fa;
    the three are condensed/counted concurrently after trimming.
    """
    known_exts = ('.fq.gz', '.fastq.gz', '.fq', '.fastq')
    for fastq in fastqs:
        _, tail = os.path.split(fastq)
        # Strip a recognized FASTQ extension (first match wins).
        for ext in known_exts:
            if tail.endswith(ext):
                tail = tail[:-len(ext)]
                break
        tail += '.fa'
        tfile = os.path.join('trimmed_ok', f"t_{tail}")
        sfile = os.path.join('short', f"s_{tail}")
        nafile = os.path.join('untrimmed', f"u_{tail}")
        cmd = (
            f"cutadapt -j {args.threads} -m {args.min_read_size}"
            f" -a {args.adapter} --too-short-output {sfile}"
            f" --untrimmed-output {nafile} -o {tfile} --fasta"
            f" {fastq} > /dev/null"
        )
        # cutadapt's progress bar is on stderr and passes through;
        # its report on stdout is discarded via /dev/null.
        print()
        print(f"Trimming {fastq} with cutadapt using adapter {args.adapter}")
        subprocess.run(cmd, shell=True, text=True)
        print("    Condensing and counting the results")

        with ThreadPoolExecutor(max_workers=3) as executor:
            futures = [executor.submit(condense_fasta, f)
                       for f in (tfile, sfile, nafile)]
            for fut in futures:
                fut.result()  # re-raise any exception from a worker

def condense_fasta(input_file):
    """Collapse a FASTA into unique sequences with counts, and tally by length.

    Produces (relative to the CWD / output directory):
      * <name>_Cd.fa     - unique sequences, most abundant first, headers
                           formatted ">{base}_Cd{i}_{count}"
      * <base>_tally.csv - rows of Length,Reads,Sequences,Category,Source
    The input FASTA and the intermediate count file are deleted.

    The basename must start with t_/s_/u_ (cutadapt output class); exits
    otherwise.
    """
    # Count identical sequences with coreutils (fast on large files):
    # one "count sequence" line per unique read, most abundant first.
    cd_file = input_file.replace('.fa', '_Cd.txt')
    cd_cmd = f"grep -v '>' {input_file} | sort | uniq -c | sort -nr > {cd_file}"
    subprocess.run(cd_cmd, shell=True, text=True)
    # Delete directly instead of shelling out to `rm -f`; keep its
    # tolerate-missing behavior.
    try:
        os.remove(input_file)
    except FileNotFoundError:
        pass
    cdfa_file = input_file.replace('.fa', '_Cd.fa')
    seqid = 0
    head, tail = os.path.split(cdfa_file)
    base = tail.replace('_Cd.fa', '')
    read_tally = defaultdict(int)   # length -> total read count
    seq_tally = defaultdict(int)    # length -> number of distinct sequences
    with open(cdfa_file, 'w') as cdfa:
        with open(cd_file) as sorted_file:
            for line in sorted_file:
                seqid += 1
                parts = line.split()

                if len(parts) >= 1:  # skip blank lines (seqid still advances)
                    count = parts[0]
                    sequence = parts[1] if len(parts) >= 2 else ''
                    header = f">{base}_Cd{seqid}_{count}"
                    cdfa.write(f"{header}\n{sequence}\n")
                    read_tally[len(sequence)] += int(count)
                    seq_tally[len(sequence)] += 1
    try:
        os.remove(cd_file)
    except FileNotFoundError:
        pass
    tally_file = base + '_tally.csv'
    # The 2-char prefix encodes the cutadapt output class.
    prefix, rest = base[:2], base[2:]
    categories = {'t_': 't (Trimmed OK)',
                  's_': 's (Short)',
                  'u_': 'u (Untrimmed)'}
    category = categories.get(prefix)
    if category is None:
        sys.exit(f"Error: unable to parse {prefix} during read counting and condensation.")
    with open(tally_file, 'w') as tf:
        for length in sorted(read_tally):
            tf.write(f"{length},{read_tally[length]},{seq_tally[length]},{category},{rest}\n")

def merge_tallys():
    """Merge every per-library *_tally.csv into ShortCut_tallys.csv.

    Writes the header row, appends each tally file (sorted name order, as a
    shell glob would expand), deletes the per-library files, and returns the
    merged table.

    Returns
    -------
    pandas.DataFrame
        Columns: Length, Reads, Sequences, Category, Source.
    """
    tally_file = 'ShortCut_tallys.csv'
    parts = sorted(glob.glob('*_tally.csv'))
    with open(tally_file, 'w') as tf:
        tf.write('Length,Reads,Sequences,Category,Source\n')
        # Pure-Python concatenation instead of `cat *_tally.csv`:
        # no shell, and immune to ARG_MAX limits with many libraries.
        for part in parts:
            with open(part) as pf:
                shutil.copyfileobj(pf, tf)
    for part in parts:
        os.remove(part)
    tallys_df = pd.read_csv(tally_file)
    print()
    print(f"Wrote summary data to {tally_file}")
    return tallys_df

def write_html(args,trim_fig,pdok_fig,len_figs,trim_prop_fig):
    """Assemble the self-contained HTML report (ShortCut_Results.html).

    Embeds plotly.js inline (so the report works offline), the run options,
    a timestamp, and the pre-rendered figure divs passed in by main().
    Writes the file in the current (output) directory.
    """
    # Get the Plotly.js library as a string to embed
    plotlyjs = pyo.get_plotlyjs()
    # Prepare the string from args
    argstring = ''
    for arg, value in vars(args).items():
        if value is not None:
            # Format lists nicely
            if isinstance(value, list):
                argstring += f"<li>{arg}: {', '.join(value)}</li>"
            else:
                argstring += f"<li>{arg}: {value}</li>"

    html_start_string = f"""
<!DOCTYPE html>
<html>
<head>
    <meta charset="utf-8">
    <title>ShortCut Report</title>
    <script type="text/javascript">{plotlyjs}</script>
       <style>
        body {{
            font-family: "Open Sans", verdana, arial, sans-serif;
        }}
    </style>
</head>
<body>
    <div id="settings">
        <h1>ShortCut Report</h1>
        <h2>Options</h2>
        <p>ShortCut version {version} completed at {time.strftime("%Y-%m-%d %H:%M:%S")}</p>
        <ul>{argstring}</ul>
    </div>
    <div><h2>Read Trimming Results</h2>
            <h3>By count</h3>
            {trim_fig}
            <h3>By proportion</h3>
            {trim_prop_fig}
    </div>
    <div><h2>DicerOK Results</h2>
        <p>"DicerOK" is the fraction >= dicermin {args.dicermin} AND <= dicermax {args.dicermax}</p>
        {pdok_fig}
    </div>
    <div><h2>Individual Read Length Distributions</h2>
        {len_figs}
    </div>
</body>
</html>  
"""
    with open('ShortCut_Results.html', 'w', newline='') as out:
        # Bug fix: writelines() on a str iterates it character-by-character;
        # write() emits the whole document in one call.
        out.write(html_start_string)
    print()
    print('Report completed and written to ShortCut_Results.html')

def create_trim_prop_figure(df):
    """Box plot of each library's per-Category proportion of Reads/Sequences.

    Returns an HTML div (plotly.js not included).
    """
    # Totals per (Category, Source), collapsing the length detail.
    totals = df.groupby(['Category', 'Source'])[['Reads', 'Sequences']].sum().reset_index()

    # Normalize each metric within its Source so categories sum to 1 per library.
    props = totals.copy()
    for col in ('Reads', 'Sequences'):
        props[col] = totals.groupby('Source')[col].transform(lambda x: x / x.sum())

    # Long format: one row per (Category, Source, Metric).
    long_df = props.melt(
        id_vars=['Category', 'Source'],
        value_vars=['Reads', 'Sequences'],
        var_name='Metric',
        value_name='Proportion'
    )

    fig = px.box(long_df,
                 x="Category",
                 y="Proportion",
                 color="Metric",
                 points="all",
                 hover_data=long_df.columns,
                 width=900,
                 range_y=[0, 1])

    return fig.to_html(include_plotlyjs=False,
                       full_html=False,
                       div_id='trim_prop_fig')

#_________ run ____________
def main():
    """Top-level pipeline: validate tools/args, trim, tally, and report."""
    check_executables()
    args = get_ShortCut_args(version)
    print_start_message(args)

    # Absolute paths survive the chdir into the output directory.
    fastq_paths = [os.path.abspath(p) for p in args.fastq]
    prep_outdir(args)
    get_adapter(args, fastq_paths)
    trim_reads(args, fastq_paths)
    tally_table = merge_tallys()

    # Each helper returns one or more ready-made HTML divs.
    trim_fig = create_trim_stacked_barchart(tally_table)
    trim_prop_fig = create_trim_prop_figure(tally_table)
    pdok_fig = create_dicerok_figure(tally_table, args)
    len_figs = get_len_figs(tally_table, args)

    # Assemble the final report.
    write_html(args, trim_fig, pdok_fig, len_figs, trim_prop_fig)


# Run the pipeline only when executed as a script (not when imported).
if __name__ == "__main__":
    main()
