#!/usr/bin/env python
version = '0.3.5'

import RNA, re, pysam, gzip
from pysam import FastaFile
import argparse, shutil, subprocess
from subprocess import Popen, PIPE, STDOUT
from Bio import SeqIO
import os, tqdm, glob
from pathlib import Path
import csv, sys, time
import pandas as pd
import numpy as np
from csv import writer
startTime = time.time()


# Verify that every external command-line tool the pipeline shells out to is on PATH.
print("Checking required modules...")
req_packs = ["bowtie", "samtools", "RNAfold", "cutadapt"]
for pack in req_packs:
    path = shutil.which(pack)
    if path is None:
        # Stop at the first missing tool; sys.exit() reports the message and aborts.
        sys.exit(f"Requires package {pack} : Not found!")
    print(f"Requires package {pack} : {path}")


# Command-line interface.
# Switch-like options (-mm, -nostrucvis, -autotrim, -rescue) are declared with
# nargs='?' and const='': they are None when absent and '' when given without a
# value, which is why the rest of the script tests them with `is not None`.
ap = argparse.ArgumentParser()

ap.add_argument(
    '-fastq',
    nargs='*',
    required=True,
    help='One or more fastq alignment files',
)
ap.add_argument(
    "-mature",
    required=True,
    help="fasta file of mature miRNA sequence",
)
ap.add_argument(
    "-hairpin",
    required=True,
    help="fasta file of hairpin precursor sequence",
)
ap.add_argument(
    "-mm",
    nargs='?',
    const='',
    help="Allow up to 1 mismatch in miRNA reads",
)
ap.add_argument(
    "-n",
    help="Results file name",
)
ap.add_argument(
    "-nostrucvis",
    nargs='?',
    const='',
    help="Do not include StrucVis output",
)
ap.add_argument(
    "-threads",
    default=1,
    help="Specify number of threads for samtools",
)
ap.add_argument(
    "-kingdom",
    required=True,
    choices=["plant", "animal"],
    help="Specify animal or plant",
)
ap.add_argument(
    "-out",
    default="miRScore_output",
    help="output directory",
)
ap.add_argument(
    "-autotrim",
    nargs='?',
    const='',
    help="Trim untrimmed fastq files",
)
ap.add_argument(
    "-trimkey",
    help="Abundant miRNA used to find adapters for trimming with option -autotrim",
)
ap.add_argument(
    "-rescue",
    nargs='?',
    const='',
    help="Reevaluate failed miRNAs and reannotate loci with alternative miRNA duplex that meets all criteria",
)
args = ap.parse_args()


# The mature and hairpin FASTA files must both exist; abort on the first one that is missing.
for reference_path in (args.mature, args.hairpin):
    if not os.path.isfile(reference_path):
        sys.exit(reference_path + " does not exist")
    # A blank line is printed for every file that was found.
    print('')

# Echo the effective options to stdout, framed by two separator lines.
# Optional settings are listed only when they were supplied on the command line.
summary = [
    "--------",
    f"miRScore version {version}",
    "Options:",
    f"     'Mature file' {args.mature}",
    f"     'Hairpin file' {args.hairpin}",
]
if args.n is not None:
    summary.append(f"     'Filename' {args.n}")
if args.mm is not None:
    summary.append("     'Mismatches' Yes")
if args.nostrucvis is not None:
    summary.append("     'No strucvis' Yes")
summary.append(f"     'Threads' {args.threads}")
summary.append(f"     'Kingdom' {args.kingdom}")
if args.fastq is not None:
    summary.append(f"     'Fastqs' {args.fastq}")
summary.append(f"     'Output directory' {args.out}/")
if args.rescue is not None:
    summary.append("     'Rescue mode' Yes")
summary.append("--------")

for summary_line in summary:
    print(summary_line)

# Sort the supplied libraries into FASTA and FASTQ inputs by file extension.
if args.fastq is not None:
    # main() changes into the output directory (os.chdir(args.out)), so relative
    # library paths are rewritten with a '../' prefix. Absolute paths are valid
    # from any working directory and must be kept as given: prefixing them with
    # '../' would produce a path that does not exist.
    libraries = [
        item if os.path.isabs(item) else '../' + item
        for item in args.fastq
    ]
    fastas = [lib for lib in libraries if lib.endswith(('.fa', '.fasta', '.fa.gz', '.fasta.gz'))]
    fastqs = [lib for lib in libraries if lib.endswith(('.fastq', '.fq', '.fastq.gz', '.fq.gz'))]

    total_files = len(fastas) + len(fastqs)
    print(f"Number of FASTA files detected: {len(fastas)}")
    print(f"Number of FASTQ files detected: {len(fastqs)}")
    print(f"Total input libraries: {total_files}")

    # Inputs with any other extension are ignored; with none left there is nothing to map.
    if total_files == 0:
        sys.exit("No valid FASTQ or FASTA files found!")
else:
    msg = "Small-RNA libraries required. Please provide libraries."
    sys.exit(msg)

# Refuse to reuse an existing output directory so earlier results are never overwritten.
path = args.out
isExist = os.path.exists(path)
if isExist:
    sys.exit(f"Output directory '{args.out}/' already exists. Please assign a new output directory using option '-out'.")

# Writable handle on the null device, kept open for the lifetime of the script.
DEVNULL = open(os.devnull, 'w')

#______________ FUNCTIONS ________________#
def find_pos(pairs, start):
    """
    Return the first position after 'start' whose entry in 'pairs' is not None.

    Parameters:
        pairs (dict): Lookup keyed by consecutive integer positions, as built by get_pairs.
        start (int): Position to search after; 'start' itself is not examined.

    Returns:
        tuple: (position, pairs[position]) for the first hit, or (None, None)
        when every later position maps to None.
    """
    last_position = len(pairs)
    position = start + 1
    while position <= last_position:
        partner = pairs[position]
        if partner is not None:
            return position, partner
        position += 1
    return None, None

def run(cmd):
    """
    Run a command line through the shell with its stdout and stderr discarded.

    Parameters:
        cmd (str): Shell command line to execute.

    Returns:
        int: Exit status of the command (0 means success). Earlier versions
        dropped this value; returning it lets callers detect failures, and
        callers that ignore the result are unaffected.
    """
    return subprocess.call(cmd, shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.STDOUT)

def add_values_in_dict(dict, key, list_of_values):
    """
    Append values to the list stored under 'key', creating the list when the key is new.

    Parameters:
        dict (dict): The dictionary to update in place.
        key: The key under which the values are stored.
        list_of_values (iterable): Values appended, in order, to the key's list.

    Returns:
        dict: The same dictionary object that was passed in.
    """
    dict.setdefault(key, []).extend(list_of_values)
    return dict

def process_mirnas(input_file, output_file):
    """
    Copy a FASTA file, converting every uppercase 'U' to 'T' on sequence lines.

    Lines that contain '>' (FASTA headers) are copied unchanged. The copy is
    done in Python rather than through a shell pipeline, so paths containing
    spaces or shell metacharacters are handled correctly.

    Parameters:
        input_file (str): Path to the input file. A relative path is resolved
            against the parent of the current working directory (the script
            runs from inside its output directory); an absolute path is used
            as given.
        output_file (str): Path of the file to write.

    Raises:
        OSError: If the input cannot be read or the output cannot be written.
    """
    if os.path.isabs(input_file):
        source_path = input_file
    else:
        source_path = os.path.join("..", input_file)

    # Binary mode keeps line endings and any non-ASCII header bytes untouched.
    with open(source_path, "rb") as source, open(output_file, "wb") as target:
        for line in source:
            target.write(line if b">" in line else line.replace(b"U", b"T"))

def run_rnaplot(x, seq, mstart, mstop, mirstart, mirstop, status):
    """
    Fold a hairpin with RNAfold and draw it with RNAplot, marking both duplex strands.

    The sequence is written to '<x>.fa' inside the 'failed' directory when
    'status' contains "Fail" and inside 'passed' otherwise, and the
    RNAfold | RNAplot pipeline is run with that directory as its working
    directory. The call waits for the pipeline to finish, so anything it
    writes is complete when this function returns, and the caller's working
    directory is never changed.

    Parameters:
        x (str): MIRNA name, used as the FASTA header and file name.
        seq (str): hairpin sequence
        mstart (int): start position of mature miRNA on hairpin
        mstop (int): stop position of mature miRNA on hairpin
        mirstart (int): start position of star miRNA on hairpin
        mirstop (int): stop position of star miRNA on hairpin
        status (str): pass or fail result from miRScore scoring step

    Returns:
        None. Output files are produced by RNAfold/RNAplot in the working directory.
    """
    # Choose the output directory from the scoring result and make sure it exists.
    work_dir = "failed" if "Fail" in status else "passed"
    os.makedirs(work_dir, exist_ok=True)

    # Write the sequence to a fasta file inside that directory.
    fasta_file = f"{x}.fa"
    with open(os.path.join(work_dir, fasta_file), "w") as f:
        f.write(f"> {x}\n{seq}")

    # The --pre string passes two 'omark' annotations to RNAplot, one per strand.
    command = (
        f"RNAfold {fasta_file} | RNAplot --pre '"
        f"{mstart} {mstop} 8 1 0.5 0 omark "
        f"{mirstart} {mirstop} 8 0 0.5 1 omark'"
    )

    # cwd= scopes the directory change to the child process; run() blocks until it exits.
    subprocess.run(command, shell=True, cwd=work_dir)


def DNAcheck(dna):
    """
    Report whether a sequence consists only of the nucleotides A, T, G, C and U.

    The comparison is case-insensitive because the sequence is upper-cased first.

    Parameters:
        dna (str): DNA/RNA sequence to validate.

    Returns:
        int: 2 if the sequence is valid, 1 otherwise (including the empty string).
    """
    only_nucleotides = re.match("^[ATGCU]+$", dna.upper()) is not None
    return 2 if only_nucleotides else 1

def index_of(asym, in_list):
    """
    Locate 'asym' inside 'in_list' without raising when it is absent.

    Parameters:
        asym: Item (or substring) to look for, e.g. a run of dots in a dot-bracket string.
        in_list: Any object with an index() method, such as a str or list.

    Returns:
        int: Index of the first occurrence, or -1 when index() raises ValueError.
    """
    try:
        position = in_list.index(asym)
    except ValueError:
        position = -1
    return position

def score_animal(mature,star,hp,ss,maturepos,starpos):
    """
    Score a candidate animal miRNA duplex (criteria based on Axtell and Meyers et al. 2018).

    Five tests are applied in a fixed order. Each test appends one score
    (0, 10 or 20) to the score list and one entry to the criteria list; a
    reason is appended only when a test deducts points:
        1. precursor length: more than 200 nt scores 10, otherwise 20
        2. mature length: 20-26 nt scores 20, otherwise 0
        3. star length: 20-26 nt scores 20, otherwise 0
        4. unpaired bases: the larger count of '.' in the mature or star slice
           of 'ss'; more than 7 scores 0, more than 5 scores 10, otherwise 20.
           The count itself is stored in the criteria list.
        5. asymmetric bulge: examined only when exactly one strand contains a
           run of four dots ("....").

    Parameters:
        mature (str): Mature miRNA sequence.
        star (str): miRNA* (star) sequence.
        hp (str): Hairpin precursor sequence.
        ss (str): Secondary structure in dot-bracket notation.
        maturepos (list): Start and end of the mature miRNA, used as slice bounds into 'ss'.
            Element 0 is incremented in place when it is 0 and only the mature
            strand contains a "...." run.
        starpos (list): Start and end of the star miRNA, used as slice bounds into 'ss'.

    Returns:
        tuple: A list of scores, reasons for deductions, and evaluation criteria.
    """
    mscore, reason, criteria = [], [], []

    def record(points, flag, why=None):
        # One test outcome: its score, its criteria entry and, for deductions, the reason.
        mscore.append(points)
        if why is not None:
            reason.append(why)
        criteria.append(flag)

    # Test 1: precursor length.
    if len(hp) <= 200:
        record(20, "Y")
    else:
        record(10, "N", "Precursor > 200 nt")

    # Test 2: mature length.
    if 20 <= len(mature) <= 26:
        record(20, "Y")
    else:
        record(0, "N", "Mature miRNA length not met")

    # Test 3: star length.
    if 20 <= len(star) <= 26:
        record(20, "Y")
    else:
        record(0, "N", "Star length not met")

    # Dot-bracket slices covering each strand, plus the pairing lookup for the whole hairpin.
    star_db = ss[starpos[0]:starpos[1]]
    mature_db = ss[maturepos[0]:maturepos[1]]
    pairs = get_pairs(ss)

    # Test 4: unpaired positions, judged on whichever strand has more of them.
    unpaired = max(star_db.count("."), mature_db.count("."))
    if unpaired > 7:
        record(0, unpaired, "More than 7 mismatches in duplex")
    elif unpaired > 5:
        record(10, unpaired, "More than 5 mismatches in duplex")
    else:
        record(20, unpaired)

    # Test 5: asymmetric bulge. Only a "...." run present on exactly one strand is examined.
    star_bulge = index_of("....", star_db)
    mature_bulge = index_of("....", mature_db)
    if star_bulge != -1 and mature_bulge == -1:
        start = star_bulge + starpos[0]
    elif star_bulge == -1 and mature_bulge != -1:
        # A start of 0 is shifted to 1 because the pairing lookup has no key 0.
        if maturepos[0] == 0:
            maturepos[0] = maturepos[0] + 1
        start = mature_bulge + maturepos[0]
    else:
        record(20, "0")
        return mscore, reason, criteria

    # Next paired position after 'start' and its partner, then the stretch of
    # structure between the two partner positions.
    stop, stop2 = find_pos(pairs, start)
    start2 = pairs[start]
    dbs = ss[stop2:start2]

    if start is None or stop is None:
        record(0, "1")
    elif "." in dbs and (stop - (start + 1)) - dbs.count(".") <= 3:
        # Unpaired bases on the opposite side offset the run; a difference of at most 3 passes.
        record(20, "0")
    else:
        record(0, "1", "asymmetric bulge greater than 3")

    return mscore, reason, criteria

def score_plant(mature,star,hp,ss,maturepos,starpos):
    """
    Score a candidate plant miRNA duplex (criteria based on Axtell and Meyers et al. 2018).

    Five tests are applied in a fixed order. Each test appends one score
    (0, 10 or 20) to the score list and one entry to the criteria list; a
    reason is appended when a test deducts points or raises a flag:
        1. precursor length: more than 300 nt scores 10, otherwise 20
        2. mature length: 20-22 nt scores 20; 23-24 nt scores 20 but is flagged
           "23/24 nt miRNA"; anything else scores 0
        3. star length: 19-22 nt scores 20; 23-24 nt scores 20 but is flagged
           "23/24 nt miRNA star"; anything else scores 0
        4. unpaired bases: the larger count of '.' in the mature or star slice
           of 'ss'; more than 5 scores 0, otherwise 20. The count itself is
           stored in the criteria list.
        5. asymmetric bulge: examined only when exactly one strand contains a
           run of four dots ("....").

    Parameters:
        mature (str): Mature miRNA sequence.
        star (str): miRNA* (star) sequence.
        hp (str): Hairpin precursor sequence.
        ss (str): Secondary structure in dot-bracket notation.
        maturepos (list): Start and end of the mature miRNA, used as slice bounds into 'ss'.
        starpos (list): Start and end of the star miRNA, used as slice bounds into 'ss'.

    Returns:
        tuple: A list of scores, reasons for deductions, and evaluation criteria.
    """
    mscore, reason, criteria = [], [], []

    def record(points, flag, why=None):
        # One test outcome: its score, its criteria entry and, when given, the reason.
        mscore.append(points)
        if why is not None:
            reason.append(why)
        criteria.append(flag)

    # Test 1: precursor length.
    if len(hp) <= 300:
        record(20, "Y")
    else:
        record(10, "N", "Precursor > 300 nt")

    # Test 2: mature length.
    mature_len = len(mature)
    if mature_len < 20 or mature_len > 24:
        record(0, "N", "Mature miRNA length not met")
    elif mature_len < 23:
        record(20, "Y")
    else:
        record(20, "23/24", "23/24 nt miRNA")

    # Test 3: star length (one nucleotide shorter than the mature is tolerated).
    star_len = len(star)
    if star_len < 19 or star_len > 24:
        record(0, "N", "Star length not met")
    elif star_len < 23:
        record(20, "Y")
    else:
        record(20, "23/24", "23/24 nt miRNA star")

    # Dot-bracket slices covering each strand, plus the pairing lookup for the whole hairpin.
    star_db = ss[starpos[0]:starpos[1]]
    mature_db = ss[maturepos[0]:maturepos[1]]
    pairs = get_pairs(ss)

    # Test 4: unpaired positions, judged on whichever strand has more of them.
    unpaired = max(star_db.count("."), mature_db.count("."))
    if unpaired > 5:
        record(0, unpaired, "More than 5 mismatches in duplex")
    else:
        record(20, unpaired)

    # Test 5: asymmetric bulge. Only a "...." run present on exactly one strand is examined.
    star_bulge = index_of("....", star_db)
    mature_bulge = index_of("....", mature_db)
    if star_bulge != -1 and mature_bulge == -1:
        start = star_bulge + starpos[0]
        # Only the star-strand case reports a reason when no paired position follows.
        invalid_reason = "Hairpin structure invalid"
    elif star_bulge == -1 and mature_bulge != -1:
        start = mature_bulge + maturepos[0]
        invalid_reason = None
    else:
        record(20, "0")
        return mscore, reason, criteria

    # Next paired position after 'start' and its partner, then the stretch of
    # structure between the two partner positions.
    stop, stop2 = find_pos(pairs, start)
    start2 = pairs[start]
    dbs = ss[stop2:start2]

    if start is None or stop is None:
        record(0, "1", invalid_reason)
    elif "." in dbs and (stop - (start + 1)) - dbs.count(".") <= 3:
        # Unpaired bases on the opposite side offset the run; a difference of at most 3 passes.
        record(20, "0")
    else:
        record(0, "1", "asymmetric bulge greater than 3")

    return mscore, reason, criteria

def get_s_rels(q_rel_start, q_rel_stop, dotbracket):
    """
    Determine the start and stop of miR* for a given miRNA and hairpin precursor

    Parameters:
        q_rel_start (int): Start of the miRNA, as a key of the get_pairs lookup.
        q_rel_stop (int): Stop of the miRNA, as a key of the get_pairs lookup.
        dotbracket (str): Secondary structure of the hairpin in dot-bracket notation.

    Returns:
        tuple: (s_rel_start, s_rel_stop). Either value is None when no paired
        position is found in the range scanned for it.
    """
    pairs = get_pairs(dotbracket)

    # 3' end of miR*: walk forward from the miRNA start to the first paired base,
    # then extend its partner by the bases skipped plus 2.
    s_rel_stop = next(
        (
            pairs[qpos] + skipped + 2
            for skipped, qpos in enumerate(range(q_rel_start, q_rel_stop))
            if pairs[qpos] is not None
        ),
        None,
    )

    # 5' end of miR*: walk backward from two bases before the miRNA stop to the
    # first paired base, then pull its partner back by the bases skipped.
    s_rel_start = next(
        (
            pairs[qpos] - skipped
            for skipped, qpos in enumerate(range(q_rel_stop - 2, q_rel_start, -1))
            if pairs[qpos] is not None
        ),
        None,
    )
    return s_rel_start, s_rel_stop
   
def get_pairs(dotbracket):
    """
    Compute positional lookups for a dotbracket structure

    Parameters:
        dotbracket (str): Secondary structure made of '(', ')' and other characters.

    Returns:
        dict: Maps every one-based position to the one-based position of its
        pairing partner, or to None when the position is not part of a closed pair.

    Raises:
        IndexError: If a ')' has no matching '(' before it.
    """
    # Positions are one-based throughout; every position starts out unpaired.
    pairs = {position: None for position, _ in enumerate(dotbracket, start=1)}

    unmatched_open = []
    for position, symbol in enumerate(dotbracket, start=1):
        if symbol == '(':
            unmatched_open.append(position)
        elif symbol == ')':
            # The most recent unmatched '(' closes with this ')'.
            partner = unmatched_open.pop()
            pairs[partner] = position
            pairs[position] = partner
    return pairs

def align_fastqs(args, files):
    """
    Align FASTQ or FASTA file(s) to hairpin sequences and output alignments in BAM format.

    A bowtie index named 'hairpin' is built from 'tmp.hairpin.fa', then every
    input file is mapped with bowtie and converted to BAM with samtools.

    Parameters:
        args: argparse arguments with .threads and .mm (mismatch) options.
            One mismatch is allowed whenever .mm is not None, so a bare '-mm'
            (which argparse stores as '') enables it.
        files: list of input FASTQ or FASTA files, optionally gzip-compressed.

    Returns:
        None (writes BAM files to alignments/)
    """
    import shlex  # local import: shlex is not imported at file level

    print("\nBuilding index...")
    subprocess.call(['bowtie-build', 'tmp.hairpin.fa', 'hairpin'],
                    stdout=subprocess.DEVNULL, stderr=subprocess.STDOUT)

    # '-mm' is declared with nargs='?' and const='', so testing its truthiness
    # would treat a bare '-mm' as "off"; test for None like the option summary does.
    mismatch_option = "-v1" if args.mm is not None else "-v0"

    print("Mapping each file to hairpin...")
    for file in files:
        ext = Path(file).suffix.lower()
        is_gz = file.endswith('.gz')
        # For 'name.fastq.gz' both suffixes are stripped to obtain 'name'.
        file_core = Path(file).with_suffix('').stem if is_gz else Path(file).stem
        print(file_core)
        print(file)

        # The input format is passed explicitly so gzipped FASTA is not parsed as FASTQ.
        is_fasta = ext in (".fa", ".fasta") or file.endswith((".fa.gz", ".fasta.gz"))
        format_flag = "-f" if is_fasta else "-q"

        print(f"Mapping {file} with mismatch option {mismatch_option} and format {'FASTA' if is_fasta else 'FASTQ'}...")

        if is_gz:
            # Decompress on the fly and let bowtie read from stdin ('-').
            # 'gzip -dc' is the same decompressor process_fastqs uses.
            reader = f"gzip -dc {shlex.quote(file)} | "
            reads = "-"
        else:
            reader = ""
            reads = shlex.quote(file)

        bam_path = shlex.quote(f"alignments/{file_core}.bam")
        cmd = (
            f"{reader}bowtie -p {args.threads} -a {mismatch_option} {format_flag} "
            f"--no-unal --norc -S -x hairpin {reads} | "
            f"samtools view -Sb - > {bam_path}"
        )
        run(cmd)




def process_fastqs(fastqs, keydna, args):
    """
    Detect the 3' adapter of each library and trim it with cutadapt.

    For every file, the 20 characters that follow each occurrence of 'keydna'
    are collected and the most frequent one is used as the adapter. The trimmed
    library is written to 'trimmedLibraries/t_<original file name>'.

    Parameters:
        fastqs (list): Paths of the libraries to trim; names ending in '.gz'
            are decompressed on the fly.
        keydna (str): Sequence searched for in the reads to locate the adapter.
        args: argparse arguments; only .threads is used (cutadapt -j).

    Returns:
        None. The script exits with a message when a file is missing, no
        adapter is detected, or a command fails.
    """
    import shlex  # local import: shlex is not imported at file level

    # Print whatever follows the key sequence on each line that contains it.
    extract = (
        f"awk -v target={shlex.quote(str(keydna))} "
        "'{idx = index($0, target); if (idx) print substr($0, idx + length(target),20)}'"
    )
    # 'uniq -c' only merges adjacent duplicates, so the candidates must be
    # sorted before they are counted; the most frequent one is then selected.
    most_common = "sort | uniq -c | sort -nr | head -n1 | awk '{print $2}'"

    for fastq in tqdm.tqdm(fastqs, desc="Library trimming", disable=None):
        # Check if the file exists
        if not os.path.exists(fastq):
            sys.exit(f"Error: File {fastq} does not exist. Please check the path.")

        is_gz = fastq.endswith(".gz")
        quoted_fastq = shlex.quote(fastq)

        if is_gz:
            cmd = f"gzip -dc {quoted_fastq} | {extract} | {most_common}"
        else:
            cmd = f"{extract} {quoted_fastq} | {most_common}"

        # Run command to find adapter
        try:
            output = subprocess.check_output(cmd, shell=True, text=True).strip()
        except subprocess.CalledProcessError as e:
            sys.exit(f"Error while processing {fastq}: {e.output}")

        print(output)

        # Exit if no adapter detected (a candidate without any 'G' is treated as no adapter).
        if "G" not in output:
            sys.exit("No adapter detected. Please check input libraries are not already trimmed.")

        print(f"Trimming {fastq}...")
        print(f"Adapter detected: {output}")

        # Prepare trimmed output filename
        tfile = os.path.join('trimmedLibraries', f"t_{os.path.basename(fastq)}")
        trim = (
            f"cutadapt -j {args.threads} -a {shlex.quote(output)} "
            f"-o {shlex.quote(tfile)} -m 12 --report minimal"
        )
        if is_gz:
            cmd = f"gzip -dc {quoted_fastq} | {trim} -"
        else:
            cmd = f"{trim} {quoted_fastq}"

        # A failed command does not raise by itself, so the exit status of the
        # pipeline (that of cutadapt, its last stage) is checked explicitly.
        status = subprocess.call(cmd, shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.STDOUT)
        if status != 0:
            sys.exit(f"Error trimming {fastq}: cutadapt exited with status {status}")



def get_counts(bam, x, start_indx=None, stop_indx=None):
    """
    Count alignments to reference 'x' in alignments/merged.bam for one library.

    Reads are selected by read group (the file name of 'bam' without its final
    extension) and alignments with SAM flag 16 are excluded. When both
    'start_indx' and 'stop_indx' are given, only alignments lying entirely
    within that interval are counted.

    Parameters:
        bam (str): Path or name of the library's BAM file; only its stem is used.
        x (str): Reference (hairpin) name passed to samtools as the region.
        start_indx (int, optional): Lowest accepted alignment start ('pos').
        stop_indx (int, optional): Highest accepted alignment end ('endpos').

    Returns:
        int: Number of lines printed by 'samtools view' for the selection.
    """
    read_group = Path(bam).stem
    if start_indx is not None and stop_indx is not None:
        command = f"samtools view -r {read_group} -F 16 -e 'pos>={start_indx} && endpos<={stop_indx}' alignments/merged.bam {x} | wc -l"
    else:
        command = f"samtools view -r {read_group} -F 16 alignments/merged.bam {x} | wc -l"

    return int(subprocess.check_output(command, shell=True, text=True).strip())

def predict_alt_mirnas(mirna_dict):
    """
    Propose an alternative mature sequence for every failed miRNA locus.

    For each entry whose result (index 10) is "Fail" and whose flag (index 11)
    is not "Hairpin structure invalid", the most abundant read sequence aligned
    to that hairpin in alignments/merged.bam (flag 16 excluded) within the
    kingdom's length window is taken: 20-24 nt for plants, 19-25 nt for animals.

    Parameters:
        mirna_dict (dict): Hairpin name -> list of result fields.

    Returns:
        dict: Hairpin name -> alternative sequence, for loci where a read was
        found. None when args.kingdom is neither "plant" nor "animal".
    """
    alt_mirnas = {}

    # Read-length window, expressed as the awk condition used in the pipeline below.
    if args.kingdom == "plant":
        length_filter = "length($2) > 19 && length($2) < 25"
    elif args.kingdom == 'animal':
        length_filter = "length($2) >= 19 && length($2) <= 25"
    else:
        return None

    for x in mirna_dict:
        entry = mirna_dict[x]
        if entry[10] != "Fail" or entry[11] == "Hairpin structure invalid":
            continue
        # Tally the read sequences (SAM column 10) and keep the most frequent one in the window.
        cmd = f"samtools view -@ {args.threads} -F 16 alignments/merged.bam {x} | awk -F '\t' '{{print $10}}' | sort | uniq -c | sort -nr | awk '{length_filter} {{print}}' | head -n1 | awk '{{print $2}}'"
        output = subprocess.check_output(cmd, shell=True, text=True).strip()
        if output:
            alt_mirnas[x] = output
    return alt_mirnas

def score_alternative_mirnas(alt_mrnas, hp_dict):
    """
    Score alternative miRNAs by evaluating their position, secondary structure, and other criteria.

    Each alternative sequence is located on its hairpin, the hairpin is folded
    with RNA.fold, the star strand is derived with get_s_rels, and the duplex
    is scored with score_plant or score_animal depending on args.kingdom.
    A candidate is skipped, and never reuses values from a previous candidate,
    when its sequence is not found on the hairpin or no acceptable star strand
    can be derived.

    Parameters:
        alt_mrnas (dict): Hairpin name -> alternative mature sequence.
        hp_dict (dict): Hairpin name -> record whose .seq holds the hairpin sequence.

    Returns:
        dict: Hairpin name -> list of result fields, only for candidates whose
        total score is greater than 81.
    """
    pred_dict = {}
    sep = ";"
    to_rna = str.maketrans("tT", "uU")
    scorer = score_plant if args.kingdom == "plant" else score_animal

    for x in tqdm.tqdm(alt_mrnas, desc="Scoring alternative miRNAs", disable=None):
        hp = str(hp_dict[x].seq)
        mature = str(alt_mrnas[x])

        # Locate the read on the hairpin. A read that aligned with a mismatch
        # (or differs in letter case) has no exact copy; comparing upper-cased
        # copies handles case, and an absent read is skipped instead of raising.
        mstart0 = hp.upper().find(mature.upper())  # 0-based
        if mstart0 == -1:
            continue
        mstart = mstart0 + 1                       # 1-based
        mstop = mstart0 + len(mature)              # 1-based stop

        (ss, mfe) = RNA.fold(hp)

        # get_s_rels works on the 1-based positions used by the pairing lookup.
        star_start, star_stop = get_s_rels(mstart, mstop, ss)
        if star_start is None or star_stop is None:
            continue
        # The star must run forward and span fewer than 25 positions.
        if not (star_start < star_stop and (star_stop - star_start) < 25):
            continue

        # 1-based coordinates -> python slice
        mirstar = hp[(star_start - 1):star_stop]
        mirsspos = [star_start, star_stop]
        mirpos = [mstart, mstop]
        result = scorer(mature, mirstar, hp, ss, mirpos, mirsspos)

        flag = list(result[1]) if result[1] else ["NA"]

        # Only candidates scoring above 81 are reported, always as "Pass".
        if sum(result[0]) > 81:
            add_values_in_dict(pred_dict, x, [mature.translate(to_rna),
                                              len(mature), mstart, mstop,
                                              str(mirstar.translate(to_rna)),
                                              len(mirstar), (mirsspos[0] + 1), mirsspos[1],
                                              hp.translate(to_rna), len(hp),
                                              "Pass", sep.join(flag), result[2][3], result[2][4]])
    return pred_dict

def swap_mirnas(file):
    """
    Swap the mature and star columns on rows where the star has more reads.

    The CSV is loaded, 'mReads' and 'msReads' are coerced to numbers, and on
    every row where both are present and mReads < msReads the mature and star
    values (reads, sequence, length, start, stop) are exchanged. Missing values
    are written as "NA".

    Parameters:
        file (str): Path of the results CSV to read.

    Returns:
        None. The table is written to '<args.n>_miRScore_results.csv', or to
        'miRScore_results.csv' when args.n is not set.
    """
    df = pd.read_csv(file)

    # Non-numeric read counts become NaN so the comparison below is well defined.
    for count_col in ('mReads', 'msReads'):
        df[count_col] = pd.to_numeric(df[count_col], errors='coerce')

    # Rows to swap: both counts present and the star more abundant than the mature.
    both_present = df['mReads'].notna() & df['msReads'].notna()
    swap_mask = both_present & (df['mReads'] < df['msReads'])

    # Mature column / star column pairs exchanged on the selected rows.
    for m_col, ms_col in (
        ('mReads', 'msReads'),
        ('mSeq', 'msSeq'),
        ('mLen', 'msLen'),
        ('mStart', 'msStart'),
        ('mStop', 'msStop'),
    ):
        mature_values = df.loc[swap_mask, m_col].copy()
        df.loc[swap_mask, m_col] = df.loc[swap_mask, ms_col]
        df.loc[swap_mask, ms_col] = mature_values

    df = df.fillna("NA")

    output_file = args.n + "_miRScore_results.csv" if args.n else "miRScore_results.csv"
    df.to_csv(output_file, index=False)


def generate_options(hp):
    """
    Generate naming options for miRNA and miRNA* based on the input hairpin name (hp).

    Candidate names are formed by appending a fixed set of suffixes to a base
    name. When the hairpin name has enough hyphen-separated parts, the last
    part is split off and also tried as a suffix of its own. Animal hairpins
    ending in "-as" keep the full name as the base.

    Parameters:
        hp (str): Hairpin name.

    Returns:
        tuple: Two lists containing options for the miRNA and miRNA* (star) respectively.
        None when args.kingdom is neither "plant" nor "animal".
    """
    if args.kingdom == "plant":
        if re.match(r'^\w+-\w+-\w+', hp):
            base_name = "-".join(hp.split("-")[:-1])  # Everything before the last hyphen
            suffix_name = hp.split("-")[-1]  # The part after the last hyphen
            suffixes = ["","-5p","-3p","-" + suffix_name,"-" + suffix_name + "-5p","-" + suffix_name + "-3p",".mature",".1",".2",".3",".4"]
            star_suffixes = ["-3p", "-" + suffix_name + "-3p", ".star", "*"]
        else:
            base_name = hp
            suffixes = ["","-5p","-3p",".mature",".1",".2",".3",".4"]
            star_suffixes = ["-3p", ".star", "*"]
        options = [base_name + suffix for suffix in suffixes]
        star_options = [base_name + suffix for suffix in star_suffixes]
        return options, star_options

    elif args.kingdom == "animal":
        if hp.endswith("-as"):
            # Keep -as intact as part of the base name. The following tests are
            # chained with elif so this choice is not overwritten by them.
            base_name = hp
            suffixes = ["", "-5p", "-3p", ".mature"]
            star_suffixes = ["-3p", ".star", "*"]
        elif re.match(r'^\w+-mir\w+-\w+-', hp) or re.match(r'^\w+-\w+-\w+-', hp):
            base_name = "-".join(hp.split("-")[:-1])  # Everything before the last hyphen
            suffix_name = hp.split("-")[-1]  # The part after the last hyphen
            suffixes = ["", "-5p", "-" + suffix_name, "-" + suffix_name + "-5p", ".mature",'-3p','-' + suffix_name + '-3p']
            star_suffixes = ["-3p","-" + suffix_name + '-3p',".star","*"]
        else:
            base_name = hp
            suffixes = ["", "-5p",".mature",'-3p']
            star_suffixes = ["-3p",".star","*"]
        options = [base_name + suffix for suffix in suffixes]
        star_options = [base_name + suffix for suffix in star_suffixes]
        return options, star_options


def cleanup_directories():
    """
    Removes temporary and intermediate files from various directories.

    Every entry below is a glob pattern relative to the current working
    directory. Matching files are deleted and matching directories are removed
    together with their contents. Patterns without a match and entries that
    cannot be removed are skipped silently, so the cleanup never aborts the run.
    """
    paths_to_remove = [
        "Results.csv",
        "RNAplots/passed/*ss.eps*",
        "RNAplots/failed/*ss.eps*",
        "RNAplots/passed/*.fa",
        "RNAplots/failed/*.fa",
        "*.ebwt",
        "*.fai",
        "tmp*",
        "alignments/*bai",
        "ms_fastqs/",
        "RNAplots/passed/*_plot.ps",
        "RNAplots/failed/*_plot.ps",
        "RNAplots/passed/*ss.ps*",
        "RNAplots/failed/*ss.ps*"
    ]

    for pattern in paths_to_remove:
        # Patterns are expanded in-process instead of spawning one shell per pattern.
        for target in glob.glob(pattern):
            if os.path.isdir(target) and not os.path.islink(target):
                shutil.rmtree(target, ignore_errors=True)
            else:
                try:
                    os.remove(target)
                except OSError:
                    # Best effort: leave behind whatever cannot be removed.
                    pass
### _______________________ Code begins __________________________ ###

def main():

    #Create output directory
    subprocess.call(["mkdir",args.out])
    os.chdir(args.out)

    #Write log file for output
    text_file = open("log.txt", "w")

    text_file.write("miRScore version " + version + " \n")

    text_file.write("Options: \n")
    text_file.write("     'Mature file' " + args.mature + "\n")
    text_file.write("     'Hairpin file' " + args.hairpin +"\n")
    if args.n is not None:
        text_file.write("     'Filename' "+ args.n + "\n")
    if args.mm is not None:
        text_file.write("     'Mismatches' Yes \n")
    if args.nostrucvis is not None:
        text_file.write("     'No strucvis' Yes \n")
    text_file.write("     'Threads' " + str(args.threads)+ "\n")
    text_file.write("     'Kingdom' "+ args.kingdom + "\n")
    if args.fastq is not None:
        text_file.write("     'Fastqs' " +str(args.fastq) + "\n")
    text_file.write("     'Output directory' "+ args.out+"/ \n")
    text_file.close()

    #Create temporary files converting any U's to T's
    #hairpins
    process_mirnas(args.hairpin, "tmp.hairpin.fa")
    #miRNA
    process_mirnas(args.mature, "tmp.mature.fa")

    #Fetch miRNA precursor sequence matching mature candidate miRNA sequence
    mir_dict =  SeqIO.index("tmp.mature.fa", "fasta")
    hp_dict = SeqIO.index("tmp.hairpin.fa", "fasta")


    print("miRNAs submitted: " + str(len(hp_dict.values())))
          

    #Working list of miRNAs that failed and the first criteria in which they failed.
    name_dict={}
    y=[]
    mirna_dict={}
    failed={}
    mir_names={}
    sep=';'
    #Checks length of hairpin and if the miRNA multimaps to it. 
    #If multimapping is detect, miRNA is removed and added to failed dictionary
    #Check that each miRNA contains only A, T, G, C. If not, add to failed or quit.
    print('')
    print("Checking hairpin and miRNA sequences...")
    print(' ')
    print(' ')

#Create dictionaries with all lowercase names to compare names between files
    normalized_hpdict = {key.lower(): key for key in hp_dict.keys()}
    normalized_mirdict = {key.lower(): key for key in mir_dict.keys()}
    #determining start positions are at start of hairpin. Z is for mature, s for star.
    mirpos_error=[]
    starpos_error=[]
    character_error=[]
    noHP_error=[]
    noHPstar_error=[]

    for hp in normalized_hpdict:
        #Generate normalized name options for mature and star
        options = generate_options(hp)
        mature_option = next((option for option in options[0] if option in normalized_mirdict), None)
        if mature_option == None:
            noHP_error.append(normalized_hpdict[hp])
        if mature_option is not None:
            if '3p' not in mature_option:
                star_option = next((option for option in options[1] if option in normalized_mirdict), None)
            else:
                star_option = None
        else:
            star_option = None
        mir_names[hp] = mature_option,star_option  

        #Determine original, non-normalized name
        og_hp=normalized_hpdict[hp]
        try:
            og_mir=normalized_mirdict[mature_option]
        except:
            if og_hp not in noHP_error:
                noHP_error.append(og_hp)

        if star_option is not None:
            try:
                og_star=normalized_mirdict[star_option]
            except:
                if og_hp not in noHP_error:
                    noHPstar_error.append(og_hp)
        else:
            og_star=None

        if len(noHPstar_error) >1:
            sys.exit("miRNA star sequence not found for the following hairpin: " + str(noHPstar_error))
        
        if mature_option is not None:
            #Create a dictionary of hairpin names and corresponding mir/mir*
            name_dict[og_hp] = og_mir,og_star
            
            #Hairpin and mature sequences based on og name
            hairpin=str(hp_dict[og_hp].seq).upper()
            mature=str(mir_dict[og_mir].seq).upper()
            if len(hairpin)<len(mature):
                msg="Hairpin sequence is shorter than the mature sequence. Please check you used the correct files."
                sys.exit(msg)
            #Index mature sequence
            if mature in hairpin:
                maturestarti=hairpin.index(mature)
                maturestop=(maturestarti+len(mature))
                maturestart=maturestarti+1
                if maturestarti == 0 and og_star==None:
                    mirpos_error.append(og_hp)
                    maturestarti=maturestarti+5
                    maturestop=maturestop+5
                #if maturestop==len(hairpin):
                    #mirpos_error.append(og_hp)
                (ss,mfe)=RNA.fold(hairpin)
                mir_ss=ss[maturestarti:maturestop]

                #Index star sequence if exists, otherwise find star sequence.
                if og_star is not None:
                    star=str(mir_dict[og_star].seq).upper()
                    if star in hairpin:
                        starstarti=hairpin.index(star)
                        starstop=(starstarti+len(star))
                        #Normal indexing of start position
                        starstart=starstarti+1
                        starpos=[starstart,starstop]
                        starpos_check=get_s_rels(maturestart,maturestop,ss)
                    else:
                        og_star= None
                        starpos=get_s_rels(maturestart,maturestop,ss)
                        if starpos is None or starpos[0] is None or starpos[1] is None:
                            failed[og_hp]="Hairpin structure invalid"
                        else:
                            star=hairpin[(starpos[0]-1):starpos[1]].upper()

                else:
                    starpos=get_s_rels(maturestart,maturestop,ss)
                    if starpos[0]== None:
                        failed[og_hp]="Hairpin structure invalid"
                    else:
                        starpos=[(starpos[0]),(starpos[1])]
                        if starpos[0] > starpos[1] or starpos[0]<0:
                            failed[og_hp]="Hairpin structure invalid"
                        else:
                            star=hairpin[(starpos[0]-1):starpos[1]]

                if starpos[0]==0:
                    starpos_error.append(og_hp)
                    
                #Check mature, star, and hp sequences for ATGCU
                hp_check=DNAcheck(hairpin)
                star_check=DNAcheck(star)
                mature_check=DNAcheck(mature)
                if hp_check == 1 or star_check ==1 or mature_check == 1:
                    character_error.append(og_hp)
            
            if og_hp not in failed:
                    #Check hairpin length and if mature or star multimaps to hairpin
                    if mature not in hairpin:
                        failed[og_hp] = "Mature not found in hairpin sequence"
                    elif star not in hairpin:
                        failed[og_hp]='Star not found in hairpin sequence'
                    elif len(hairpin) < 50:
                        failed[og_hp] = "Hairpin is less than 50 basepairs" 
                    elif hairpin.count(mature) > 1:
                        failed[og_hp] = "miRNA multimaps to hairpin" 
                    elif hairpin.count(star)>1:
                        failed[og_hp] = "miRNA multimaps to hairpin" 
                    elif starpos[0]== None or starpos[1]>len(ss):
                        failed[og_hp]="Hairpin structure invalid"
                    elif starpos[0] > starpos[1] or starpos[0]<0:
                        failed[og_hp]="Hairpin structure invalid"
                    elif starpos[0]>maturestart and starpos[0]<=maturestop:
                        failed[og_hp]="Hairpin structure invalid"
                    elif starpos[0]<maturestart and starpos[1]>maturestart:
                        failed[og_hp]="Hairpin structure invalid"
                    else:
                        #Score miRNAs
                        maturepos=[maturestart,maturestop]
                        if args.kingdom == "animal":
                            result=score_animal(mature,star,hairpin,ss,maturepos,starpos)
                        elif args.kingdom == "plant": 
                            result=score_plant(mature,star,hairpin,ss,maturepos,starpos)
                            #Add result from score function to mirna dict
                        sep = ";"
                        if len(result[1]) == 0:
                            flag = ["NA"]
                        else:
                            flag=list(result[1])
                            #If mirna contains both ) and (, it fails for pairing to itself. Else, if the result of scoring is greater than 80, add mirna to dictionary as a pass, and if not fail it. 
            
                        if starpos[0] > starpos[1]:
                            flag=["Hairpin structure invalid"]
                            add_values_in_dict(mirna_dict,og_hp,[mature.translate(str.maketrans("tT", "uU")),len(mature),maturestart,maturestop,"NA","NA",(starpos[0]),starpos[1],hairpin.translate(str.maketrans("tT", "uU")),len(hairpin),"Fail",sep.join(flag),"NA","NA","NA","NA","NA","NA","NA"])
                        elif mir_ss.count("(") != 0 and mir_ss.count(")") !=0:
                            flag=["Hairpin structure invalid"]
                            y.append(og_hp)
                            add_values_in_dict(mirna_dict,og_hp,[mature.translate(str.maketrans("tT", "uU")),len(mature),maturestart,maturestop,str(star.translate(str.maketrans("tT", "uU"))),len(star),(starpos[0]),starpos[1],hairpin.translate(str.maketrans("tT", "uU")),len(hairpin),"Fail",sep.join(flag),result[2][3],result[2][4],"NA","NA","NA","NA","NA"])
                        elif og_star is not None:
                            if starpos[0] != starpos_check[0] or starpos[1] != starpos_check[1]:
                                flag=["No 2nt 3' overhang"]
                                add_values_in_dict(mirna_dict,og_hp,[mature.translate(str.maketrans("tT", "uU")),len(mature),maturestart,maturestop,str(star.translate(str.maketrans("tT", "uU"))),len(star),(starpos[0]),starpos[1],hairpin.translate(str.maketrans("tT", "uU")),len(hairpin),"Fail",sep.join(flag),result[2][3],result[2][4]])
                            elif sum(result[0]) < 81:
                                y.append(og_hp)
                                flag=list(result[1])
                                add_values_in_dict(mirna_dict,og_hp,[mature.translate(str.maketrans("tT", "uU")),len(mature),maturestart,maturestop,str(star.translate(str.maketrans("tT", "uU"))),len(star),(starpos[0]),starpos[1],hairpin.translate(str.maketrans("tT", "uU")),len(hairpin),"Fail",sep.join(flag),result[2][3],result[2][4]])
                            else:  
                                add_values_in_dict(mirna_dict,og_hp,[mature.translate(str.maketrans("tT", "uU")),len(mature),maturestart,maturestop,str(star.translate(str.maketrans("tT", "uU"))),len(star),(starpos[0]),starpos[1],hairpin.translate(str.maketrans("tT", "uU")),len(hairpin),"Pass",sep.join(flag),result[2][3],result[2][4]])
                        elif sum(result[0]) < 81:
                            y.append(og_hp)
                            flag=list(result[1])
                            add_values_in_dict(mirna_dict,og_hp,[mature.translate(str.maketrans("tT", "uU")),len(mature),maturestart,maturestop,str(star.translate(str.maketrans("tT", "uU"))),len(star),(starpos[0]),starpos[1],hairpin.translate(str.maketrans("tT", "uU")),len(hairpin),"Fail",sep.join(flag),result[2][3],result[2][4]])
                        else:
                            add_values_in_dict(mirna_dict,og_hp,[mature.translate(str.maketrans("tT", "uU")),len(mature),maturestart,maturestop,str(star.translate(str.maketrans("tT", "uU"))),len(star),(starpos[0]),starpos[1],hairpin.translate(str.maketrans("tT", "uU")),len(hairpin),"Pass",sep.join(flag),result[2][3],result[2][4]])
            
        #mirna_dict is now complete.
    #Stop program if all miRNAs failed previous check
    if len(mirna_dict) < 1:
        # Write failed miRNAs to CSV before exiting
        if failed:  # Check if there are any failed miRNAs
            header = ["name", "flag"]
            output_file = "failed.csv"
            with open(output_file, mode="w", newline='') as csv_out:
                csv_writer = writer(csv_out)
                csv_writer.writerow(header)
                for k, v in failed.items():
                    csv_writer.writerow([k, v] if isinstance(v, str) else [k, *v])
        sys.exit("Error! No candidate miRNAs left to score. Check 'failed.csv' for details.")
            
    if len(noHP_error)>0:
        print(f"Error! The following entries in the MIRNA hairpin file '" + args.hairpin + "' have no mature sequences that match their identifiers in file '" + args.mature + "'. "+ str(noHP_error))
        print("        Please check all hairpins in the hairpin FASTA file have miRNAs in the mature FASTA file with the same name.") 
        print("        If you have multiple miRNAs assigned to a single locus (i.e. osa-miR159a.1/osa-miR159a.2 to osa-MIR159a)")
        print("        please run 'hairpinHelper' then rerun miRScore with the'miRScore_adjusted_hairpins.fa' file as the hairpin FASTA input.")
        print("        See GitHub README for more details.")
        sys.exit()

    print("miRNAs that failed scoring step: " + str(len(y)))
    print(' ')
    #miRNAs that start at position 1
    #Quit program if any exist
    if len(mirpos_error)>0 and len(starpos_error) and len(character_error)>0:
        print('Error! There are multiple errors in the dataset.')
        print('Please address the following issues: ')
        print("")
        print('Mature miRNAs sequences that index to the first or last position of the hairpin: ')
        print(mirpos_error)
        print("")
        print("Star miRNAs sequences that index to the first or last position of the hairpin: ")
        print(starpos_error)
        print("")
        print( "The following miRNA or hairpin sequences have characters besides A, T, G, C, or U in the sequence. Please ammend or remove the following:")
        print(character_error)
        print(" ")
        sys.exit()
    elif len(mirpos_error)>0:
        print('Error! These miRNAs start at position 1 of the hairpin: ')
        print(mirpos_error)
        print(" ")
        print("The miRNA duplex structure requires a 3p overhang which cannot be determined if the mature/star starts or ends the hairpin. Please extend the hairpin sequences or remove them.")
        print(" ")
        sys.exit()
    elif len(starpos_error)>0:
        print('Error! These miRNA star sequences start at position 1 of the hairpin: ')
        print(starpos_error)
        print(" ")
        print("The miRNA duplex structure requires a 3p overhang which cannot be determined if the mature/star starts or ends the hairpin. Please extend the hairpin sequences or remove them.")
        print(" ")
        sys.exit()
    elif len(character_error)>0:
        print( "Error! The following have characters besides A, T, G, C, or U in the sequence. Please fix or remove the following sequences:")
        print(character_error)
        print(" ")
        sys.exit()
## ______________________BEGIN READ CHECK _____________________##

    DEVNULL = open(os.devnull, 'w')

    #If argument -fastq used, map FASTQ to hairpins directly
    if args.fastq is not None:
        #Create a list of all fastq files in provided directory.
        fastas = ['../' + item for item in args.fastq if item.endswith(('.fa', '.fasta', '.fa.gz', '.fasta.gz'))]
        fastqs = ['../' + item for item in args.fastq if item.endswith(('.fastq', '.fq', '.fastq.gz', '.fq.gz'))]

        total_files = len(fastas) + len(fastqs)
        print(f"Number of FASTA files detected: {len(fastas)}")
        print(f"Number of FASTQ files detected: {len(fastqs)}")
        print(f"Total input files: {total_files}")

        if total_files == 0:
            sys.exit("No valid FASTQ or FASTA files found!")

        subprocess.call(["mkdir","alignments"])
        #Trim fastq files
        if fastqs:
            if args.autotrim is not None:
                print("\nOption autotrim enabled. Trimming fastq files...")
                if args.trimkey:
                    print("Adapter trimming initiated with key:", args.trimkey)
                else:
                    print("No key provided for trimming.")
                    args.trimkey = "UCGGACCAGGCUUCAUUCCCC" if args.kingdom == 'plant' else "UGAGGUAGUAGGUUGUAUAGUU"
                    print("-trimkey used for discovering adapter sequence:", args.trimkey)

                key = args.trimkey.upper()
                keydna = key.translate(str.maketrans("uU", "tT"))

                # Sanity check
                if DNAcheck(keydna) == 1:
                    sys.exit("Error! Key for trimming adapters must only contain A, T, G, C, or U.")
                if not (20 <= len(keydna) <= 30):
                    sys.exit("Trim key must be between 20 and 30 nucleotides")

                subprocess.call(["mkdir", "-p", "trimmedLibraries"])
                process_fastqs(fastqs, keydna, args)

                print("Adapter trimming complete.")
                modified_fastqs = ["trimmedLibraries/t_" + os.path.basename(f) for f in fastqs]
                align_fastqs(args, modified_fastqs)
            else:
                align_fastqs(args, fastqs)

        # If there are any FASTA files, align them directly (no trimming needed)
        if fastas:
            for fasta in fastas:
                print(f"Processing untrimmed FASTA file: {fasta}")
            align_fastqs(args, fastas)

    bamfiles=glob.glob('alignments/*.bam')
    #Sort bam files
    print("Sorting bamfiles...")

    #Index and sort each bam file
    for bam in bamfiles:
        pysam.sort("-o", "alignments/" + Path(bam).stem + "_s.bam", "-@", str(args.threads), bam)
        pysam.index("alignments/" + Path(bam).stem + "_s.bam", "-@", str(args.threads))
        cmd=("rm -rf "+ bam) 
        run(cmd)        

    #Read counting
    mirna_counts={}
    # Merge library files, filter by minimum read length, index merged BAM file, and remove temporary file
    run("samtools merge -@ " + str(args.threads) + " -r alignments/merge1.bam alignments/*s.bam")
    run("samtools view -@ "+ str(args.threads) + " -e 'qlen>15' -bh alignments/merge1.bam > alignments/merged.bam")
    print("Indexing merged BAM file...")
    run("samtools index -@ " + str(args.threads) + " alignments/merged.bam")
    run("rm -rf alignments/merge1.bam")

    #Create a list of read groups in the merged bam file.
    rgs=[]
    bam_input = pysam.AlignmentFile("alignments/merged.bam", 'rb')
    rg=bam_input.header["RG"]
    for group in range(0,len(rg)):
        for values in rg[group].values():
            rgs.append(values)

     #Count number of reads for each miRNA loci
    print("Counting reads for each miRNA locus...")
    for bam in rgs:
        print("Counting reads in " + bam + "...")
        for x in tqdm.tqdm(mirna_dict,disable=None):
            #Set window for miRNA with variance
            if "Hairpin structure invalid" not in mirna_dict[x][11]:
                if mirna_dict[x][2] !=0 and mirna_dict[x][6]!=0:
                    mstart_indx=mirna_dict[x][2]-1
                    mstop_indx=mirna_dict[x][3]+1
                    mirstart_indx=mirna_dict[x][6]-1
                    mirstop_indx=mirna_dict[x][7]+1
                else:
                    print("Error! " + x +" mature or star sequence starts at first position of hairpin precursor. Please extend the precursor and rerun miRScore.")
                    sys.exit()

                add_values_in_dict(mirna_counts,x,[str(Path(bam).stem)])

                mir_count = get_counts(bam, x, mstart_indx, mstop_indx)
                mirs_count = get_counts(bam, x, mirstart_indx, mirstop_indx)
                tot_count = get_counts(bam, x)

                add_values_in_dict(mirna_counts, x, [mir_count, mirs_count,tot_count])
                
                if int(tot_count) != 0:
                    precision="{:.2f}".format((100*((int(mir_count)+int(mirs_count))/int(tot_count))))
                    add_values_in_dict(mirna_counts,x,[precision])
                else:
                    precision= 0.00
                    add_values_in_dict(mirna_counts,x,[precision])

    #Create pandas dataframe for writing reads file
    df = pd.DataFrame.from_dict(mirna_counts, orient='index')
    index = [x for x in mirna_counts for _ in bamfiles]

    # Reshape DataFrame and set column names
    df2 = pd.DataFrame(df.values.reshape(-1, 5), index=index,
                    columns=['library', 'mReads', 'msReads', 'allReads', 'precision'])

    #Add pass/fail column
    df2['precision'] = df2['precision'].astype('float')
    df2["mReads"] = df2["mReads"].astype(int)
    df2["msReads"] = df2["msReads"].astype(int)
    df2['result'] = df2.apply(lambda row: 'Pass' if row['mReads'] > 0 and row['precision'] >= 75 and row['msReads'] > 0 and (row['mReads'] + row['msReads'] >= 10) else 'Fail', axis=1)

    # Add flags column based on the reads and precision conditions
    def generate_flags(row):
        """Return the read-support flag for one row of the per-library reads table.

        `row` must provide 'mReads', 'msReads' and 'precision' by key.
        The checks run in priority order and only the first one that
        applies is reported; "NA" is returned when none applies.
        """
        # A missing mature or star strand outranks every other flag.
        if row['mReads'] <= 0 or row['msReads'] <= 0:
            return "No mature or star reads detected"

        # Read floor: mature + star reads together must reach 10 in this row.
        duplex_reads = row['mReads'] + row['msReads']
        if duplex_reads < 10:
            return "Less than 10 reads in a single library"

        # Precision is only judged once both read checks have passed.
        if row['precision'] < 75:
            return "Precision less than 75%"

        return "NA"
        
    #Add flags to reads.csv
    df2['flags'] = df2.apply(generate_flags, axis=1)

    # Save DataFrame to CSV
    filename = args.n + '_reads.csv' if args.n is not None else 'reads.csv'
    df2.to_csv(filename, sep=',', encoding='utf-8', index_label='miRNA')

    #Determine if miRNA and miRStar are found in multiple libraries. Precision is calculated individually, library by library.
    for x in mirna_counts:
        mirstar_reads = list(map(int, mirna_counts[x][2::5]))
        mir_reads = list(map(int, mirna_counts[x][1::5]))
        all_reads = list(map(int, mirna_counts[x][3::5]))
        precCount = list(map(float, mirna_counts[x][4::5]))

        #Changed from 'and' to 'or' on account that there should be reads detected for both
        if sum(mir_reads)==0 or sum(mirstar_reads)==0:
            mirna_dict[x][10]="Fail"
            if mirna_dict[x][11] =="NA":
                mirna_dict[x][11]="No mature or star reads detected"
                if sum(all_reads) != 0:
                    precision="{:.2f}".format((100*(sum(mir_reads)+sum(mirstar_reads))/sum(all_reads)))
                    add_values_in_dict(mirna_dict,x,[sum(mir_reads),sum(mirstar_reads),sum(all_reads),len(bamfiles),precision])
                else:
                    precision= 0.00
                    add_values_in_dict(mirna_dict,x,[sum(mir_reads),sum(mirstar_reads),sum(all_reads),len(bamfiles),precision])
            else:
                    mirna_dict[x][11]+=";No mature or star reads detected"       
                    if sum(all_reads) != 0:
                        precision="{:.2f}".format((100*(sum(mir_reads)+sum(mirstar_reads))/sum(all_reads)))
                        add_values_in_dict(mirna_dict,x,[sum(mir_reads),sum(mirstar_reads),sum(all_reads),len(bamfiles),precision])
                    else:
                        precision= 0.00
                        add_values_in_dict(mirna_dict,x,[sum(mir_reads),sum(mirstar_reads),sum(all_reads),len(bamfiles),precision]) 
        else:
            #combined miR/miR* reads
            totreads=[]
            #Number of libraries with mir/mir* reads and precision greater than 75
            count=0
            #mirReads
            mr=0
            #mirstarReads
            ms=0
            #Total reads
            all=0
            #Total precision across libraries counted in results
            precSum=0
            #Count libraries that have reads present
            for i in range(0,len(mir_reads)):
                #implement read floor for each library
                tot_c=int(mir_reads[i])+int(mirstar_reads[i])
                totreads.append(tot_c)
                if int(tot_c) >= 10 and int(mir_reads[i]) > 0 and int(mirstar_reads[i]) > 0 and precCount[i]>=75:
                    count=count+1
                    precSum=precSum+precCount[i]
                    mr=mr+int(mir_reads[i])
                    ms=ms+int(mirstar_reads[i])
                    all=all+int(all_reads[i])

            #calculate precision of libraries with reads present
            if count >= 1:
                prec=str(round(precSum/count,2))
                add_values_in_dict(mirna_dict,x,[mr,ms,all,count,prec])
            else:
                mirna_dict[x][10]="Fail"
                if sum(all_reads) != 0:
                    precision="{:.2f}".format((100*(sum(mir_reads)+sum(mirstar_reads))/sum(all_reads)))
                else:
                    precision= 0.00
                #If total count of miR and miR* is less than 10, fail the locus
                #implement read floor for each library
                floor=10
                #If any miRNAs have a library that meets the read floor, check precision and report it failed for precision.
                if any(num >= floor  for num in totreads):
                    if sum(mir_reads)>0 and sum(mirstar_reads)>0:
                        if len(bamfiles)>1:
                            if any(df2.loc[x]["flags"] == 'Precision less than 75%'):
                                if mirna_dict[x][11]=="NA":
                                    mirna_dict[x][11]="Precision less than 75%"
                                    add_values_in_dict(mirna_dict,x,[sum(mir_reads),sum(mirstar_reads),sum(all_reads),len(bamfiles),precision])
                                else:
                                    mirna_dict[x][11]+=";Precision less than 75%"      
                                    add_values_in_dict(mirna_dict,x,[sum(mir_reads),sum(mirstar_reads),sum(all_reads),len(bamfiles),precision]) 
                            else:
                                if mirna_dict[x][11]=="NA":
                                    mirna_dict[x][11]="No mature or star reads detected"
                                    add_values_in_dict(mirna_dict,x,[sum(mir_reads),sum(mirstar_reads),sum(all_reads),len(bamfiles),precision])
                                else:
                                    mirna_dict[x][11]+=";No mature or star reads detected"      
                                    add_values_in_dict(mirna_dict,x,[sum(mir_reads),sum(mirstar_reads),sum(all_reads),len(bamfiles),precision])    
                        else:
                            if df2.loc[x]["flags"] == 'Precision less than 75%':
                                if mirna_dict[x][11]=="NA":
                                    mirna_dict[x][11]="Precision less than 75%"
                                    add_values_in_dict(mirna_dict,x,[sum(mir_reads),sum(mirstar_reads),sum(all_reads),len(bamfiles),precision])
                                else:
                                    mirna_dict[x][11]+=";Precision less than 75%"      
                                    add_values_in_dict(mirna_dict,x,[sum(mir_reads),sum(mirstar_reads),sum(all_reads),len(bamfiles),precision]) 
                            else:
                                if mirna_dict[x][11]=="NA":
                                    mirna_dict[x][11]="No mature or star reads detected"
                                    add_values_in_dict(mirna_dict,x,[sum(mir_reads),sum(mirstar_reads),sum(all_reads),len(bamfiles),precision])
                                else:
                                    mirna_dict[x][11]+=";No mature or star reads detected"      
                                    add_values_in_dict(mirna_dict,x,[sum(mir_reads),sum(mirstar_reads),sum(all_reads),len(bamfiles),precision])  
                    #Otherwise, if no libraries had at least 10 reads, report it failed for reads.
                else:                    
                    if mirna_dict[x][11] =="NA":
                        mirna_dict[x][11]="Less than 10 reads in a single library"
                        add_values_in_dict(mirna_dict,x,[sum(mir_reads),sum(mirstar_reads),sum(all_reads),len(bamfiles),precision])
                    else:
                        mirna_dict[x][11]+=";Less than 10 reads in a single library"       
                        add_values_in_dict(mirna_dict,x,[sum(mir_reads),sum(mirstar_reads),sum(all_reads),len(bamfiles),precision])

    fails=['Hairpin is less than 50 basepairs','Mature not found in hairpin sequence','miRNA multimaps to hairpin','Sequence contained characters besides U,A, T, G, or C',"Hairpin structure invalid"]
    
    #Check failed miRNAs and add to results
    for mirfail in failed:
        if failed[mirfail] not in fails:
            hairpin = str(hp_dict[mirfail].seq).upper()
                #mature sequence
            og_mir=str(name_dict[mirfail][0])
            mature = str(mir_dict[og_mir].seq).upper()
            maturestart=hairpin.index(mature)
            maturestop=(hairpin.index(mature)+len(mature))

            (ss, mfe)=RNA.fold(hairpin)
            starpos=get_s_rels(maturestart,maturestop,ss)
            if starpos is not None:
                star=hairpin[(starpos[0]-1):starpos[1]].upper()
            else:
                star = "NA"
            if starpos is not None:
                add_values_in_dict(mirna_dict,mirfail,[mature.translate(str.maketrans("tT", "uU")),len(mature),maturestart,maturestop,star,len(star),"NA","NA",hairpin,len(hairpin),"Fail",str(failed[mirfail]),"NA","NA","NA","NA","NA","NA","NA"])
            else:
                add_values_in_dict(mirna_dict,mirfail,[mature.translate(str.maketrans("tT", "uU")),len(mature),maturestart,maturestop,"NA","NA","NA","NA",hairpin,len(hairpin),"Fail",str(failed[mirfail]),"NA","NA","NA","NA","NA","NA","NA"])
        else:
            seq= str(hp_dict[mirfail].seq).upper()
            og_mir=str(name_dict[mirfail][0])
            mat = str(mir_dict[og_mir].seq).upper()
            add_values_in_dict(mirna_dict,mirfail,[mat.translate(str.maketrans("tT", "uU")),len(mat),"NA","NA","NA","NA","NA","NA",seq.translate(str.maketrans("tT", "uU")),len(seq),"Fail",str(failed[mirfail]),"NA","NA","NA","NA","NA","NA","NA"])
    DEVNULL = open(os.devnull, 'w')

    #Write final output file
    header = ["name", "mSeq", "mLen", "mStart", "mStop", "msSeq", "msLen", "msStart", "msStop", "precSeq", "precLen", "result", "flags", "mismatches", "bulges", "mReads", "msReads", "totReads", "numLibraries", "precision"]
    output_file = "Results.csv"

    with open(output_file, mode="w", newline='') as csv_out:
        csv_writer = writer(csv_out)
        csv_writer.writerow(header)
        for k, v in mirna_dict.items():
            csv_writer.writerow([k, *v])
    # swap miRNA and miRNA* position if miRNA is more abundant
    swap_mirnas("Results.csv")
    ##############################################
    ##Beginning of mirna alternative analyzer
    if args.rescue is not None:
        print("*************")
        print("Reevaluting failed miRNAs for potential alternatives...")
        print("*************")

        alt_mirnas=predict_alt_mirnas(mirna_dict)
        pred_dict=score_alternative_mirnas(alt_mirnas,hp_dict)

        print("miRNAs predicted from failed miRNA loci:")
        print(len(alt_mirnas))

        pred_counts = {}
        for bam in rgs:
            print("Counting reads in " + bam + "...")
            for x in tqdm.tqdm(pred_dict, disable=None):
                if "Hairpin structure invalid" in pred_dict[x][11]:
                    continue

                # choose bounds
                if pred_dict[x][2] != 0 and pred_dict[x][6] != 0:
                    mstart_indx  = pred_dict[x][2] - 1
                    mstop_indx   = pred_dict[x][3] + 1
                    mirstart_indx = pred_dict[x][6] - 1
                    mirstop_indx  = pred_dict[x][7] + 1
                else:
                    mstart_indx  = pred_dict[x][2]
                    mstop_indx   = pred_dict[x][3] + 1
                    mirstart_indx = pred_dict[x][6]
                    mirstop_indx  = pred_dict[x][7] + 1

                # ALWAYS count + store (this is the critical part)
                add_values_in_dict(pred_counts, x, [str(Path(bam).stem)])

                mir_count  = get_counts(bam, x, mstart_indx,  mstop_indx)
                mirs_count = get_counts(bam, x, mirstart_indx, mirstop_indx)
                tot_count  = get_counts(bam, x)
                tot_locus  = int(mir_count) + int(mirs_count)

                add_values_in_dict(pred_counts, x, [mir_count, mirs_count, tot_locus, tot_count])

                if int(tot_count) != 0:
                    precision = "{:.2f}".format(100 * (tot_locus / int(tot_count)))
                else:
                    precision = 0.00
                add_values_in_dict(pred_counts, x, [precision])

        # Create DataFrame from dictionary
        df = pd.DataFrame.from_dict(pred_counts, orient='index')

        # Repeat the index for each BAM file
        index = [x for x in pred_counts for _ in bamfiles]

        # Reshape DataFrame and set column names
        df2 = pd.DataFrame(df.values.reshape(-1, 6), index=index,
                    columns=['library', 'mReads', 'msReads','locusReads', 'hairpinReads', 'precision'])

        # Convert 'precision' to float and add 'result' column
        df2['precision'] = df2['precision'].astype(float)
        df2['result'] = df2.apply(lambda row: 'Pass' if row['mReads'] > 0 and row['precision'] >= 75 and row['msReads'] > 0 and (row['mReads'] + row['msReads'] >= 10) else 'Fail', axis=1)
        df2['flags'] = df2.apply(generate_flags, axis=1)

        # Save DataFrame to CSV
        filename = args.n + '_alt_reads.csv' if args.n is not None else 'alt_reads.csv'
        df2.to_csv(filename, sep=',', encoding='utf-8', index_label='miRNA')

        for x in tqdm.tqdm(list(pred_counts),desc="Counting miRNAs",disable=None):
    #Count mir and mirstar reads for each locus
            mircounts=[]
            mirscounts=[]
            mirstar_reads=(pred_counts[x][2::6])
            mir_reads=(pred_counts[x][1::6])
            all_reads=(pred_counts[x][4::6])
            for c in mirstar_reads:
                mirscounts.append(int(c))
            for c in mir_reads:
                mircounts.append(int(c))
            precCount=[]
            p=(pred_counts[x][5::6])
            for c in p:
                precCount.append(float(c)) 

            if sum(mircounts)==0 or sum(mirscounts)==0:
                del(pred_dict[x])
            else:
                count=0
                mr=0
                ms=0
                all=0
                precSum=0
            #Count libraries that have reads present
                for i in range(0,len(mir_reads)):
                    sumreads=int(mir_reads[i])+int(mirstar_reads[i])
                    if int(mir_reads[i]) > 0 and int(mirstar_reads[i]) > 0 and precCount[i]>=75 and sumreads >9:
                        count=count+1
                        precSum=precSum+precCount[i]
                        mr=mr+int(mir_reads[i])
                        ms=ms+int(mirstar_reads[i])
                        all=all+int(all_reads[i])

                #calculate precision of libraries with reads present
                if count >= 1:
                    prec=str(round(precSum/count,2))
                    add_values_in_dict(pred_dict,x,[mr,ms,all,count,prec])
                else:
                    del(pred_dict[x])



        output_file = args.n + "_alt_mirna_results.csv" if args.n else "alt_mirna_results.csv"

        header = ["name", "mSeq", "mLen", "mStart", "mStop", "msSeq", "msLen", "msStart", 'msStop',
            "precSeq", "precLen", "result", "flags", "mismatches", "bulges", 'mReads', 'msReads',
            "totReads", "nlibraries", "precision"]

        with open(output_file, mode="w", newline='') as csv_out:
            csv_writer = writer(csv_out)
            csv_writer.writerow(header)
            for k, v in pred_dict.items():
                csv_writer.writerow([k, *v])
                
    #Make directories for PostScript images
    subprocess.call(["mkdir","RNAplots"])
    subprocess.call(["mkdir","RNAplots/failed"])
    subprocess.call(["mkdir","RNAplots/passed"])

    if args.nostrucvis == None:
        #Index hairpins for StrucVis and make directories
        subprocess.call(["mkdir","strucVis"])
        subprocess.call(["mkdir","strucVis/failed"])
        subprocess.call(["mkdir","strucVis/passed"])
        cmd='samtools faidx tmp.hairpin.fa'
        run(cmd)

        for x in tqdm.tqdm(mirna_dict,desc="strucVis",disable=None):
            if mirna_dict[x][11]!="Hairpin structure invalid":
                #Create strucVis plots for each miRNA
                if "Fail" in mirna_dict[x][10]:
                    hp_len=str(mirna_dict[x][9])
                    cmd="strucVis -b alignments/merged.bam -g tmp.hairpin.fa -c "+ x + ":1-" + hp_len + " -s plus -p strucVis/failed/" +x+ ".ps -n " + x
                    run(cmd)
                else:
                    hp_len=str(mirna_dict[x][9])
                    cmd="strucVis -b alignments/merged.bam -g tmp.hairpin.fa -c "+ x + ":1-" + hp_len + " -s plus -p strucVis/passed/" +x+ ".ps -n " + x
                    run(cmd)


    #Create RNAplots depicting MIRNA loci
    os.chdir('RNAplots')
    # Main loop to process the miRNA dictionary
    for x in tqdm.tqdm(mirna_dict, desc="RNAplots", disable=None):
        if mirna_dict[x][11] not in fails:
            seq = str(mirna_dict[x][8])  # Full sequence
            mstart = int(mirna_dict[x][2])  # Mature start position
            mstop = int(mirna_dict[x][3])   # Mature stop position
            mirstart = int(mirna_dict[x][6])  # Mature* start position
            mirstop = int(mirna_dict[x][7])   # Mature* stop position
            status = mirna_dict[x][10]  # Pass or Fail status
            
            # Process the miRNA
            run_rnaplot(x, seq, mstart, mstop, mirstart, mirstop, status)               
    os.chdir('..')


    #Add legend to failed RNAplots
    files=glob.glob("RNAplots/failed/*.fa")
    for file in files:
        x=Path(file).stem
        mstart=int(mirna_dict[x][2])
        mstop=int(mirna_dict[x][3])
        mirstart=int(mirna_dict[x][6])
        mirstop=int(mirna_dict[x][7])
        cmd="awk -v n=14 -v s='0 0 0 setrgbcolor/Helvetica findfont\\n9 scalefont setfont \\n72 114 moveto \\n(miRNA: "+ x+" ) show' 'NR == n {print s} {print}' RNAplots/failed/"+x+"_ss.eps | awk -v n=14 -v s='/Helvetica findfont\\n9 scalefont setfont \\n72 104 moveto \\n(miRNA locus: "+ str(mstart)+ "-" + str(mstop)+ ") show \\n/Helvetica findfont\\n9 scalefont setfont \\n72 94 moveto \\n(miRNAStar locus: "+ str(mirstart)+ "-" + str(mirstop)+ ") show \\n0 0.5 1 setrgbcolor \\n65 96 4 0 360 arc closepath fill stroke \\n 1 0.5 0 setrgbcolor \\n65 106 4 0 360 arc closepath fill stroke' 'NR==n {print s} {print}'  > RNAplots/failed/"+x+"_plot.ps"
        run(cmd)

    #Add legend to passed RNAplots
    files=glob.glob("RNAplots/passed/*.fa")
    for file in files:
        x=Path(file).stem
        mstart=int(mirna_dict[x][2])
        mstop=int(mirna_dict[x][3])
        mirstart=int(mirna_dict[x][6])
        mirstop=int(mirna_dict[x][7])
        cmd="awk -v n=14 -v s='0 0 0 setrgbcolor/Helvetica findfont\\n9 scalefont setfont \\n72 114 moveto \\n(miRNA: "+ x+" ) show' 'NR == n {print s} {print}' RNAplots/passed/"+x+"_ss.eps | awk -v n=14 -v s='/Helvetica findfont\\n9 scalefont setfont \\n72 104 moveto \\n(miRNA locus: "+ str(mstart)+ "-" + str(mstop)+ ") show \\n/Helvetica findfont\\n9 scalefont setfont \\n72 94 moveto \\n(miRNAStar locus: "+ str(mirstart)+ "-" + str(mirstop)+ ") show \\n0 0.5 1 setrgbcolor \\n65 96 4 0 360 arc closepath fill stroke \\n 1 0.5 0 setrgbcolor \\n65 106 4 0 360 arc closepath fill stroke' 'NR==n {print s} {print}'  > RNAplots/passed/"+x+"_plot.ps"
        run(cmd)

    ps_files = glob.glob(os.path.join('RNAplots/passed/', '*_plot.ps'))
    # Loop through each file and convert it to PDF
    for ps_file in ps_files:
        # Construct the PDF file name
        pdf_file = ps_file.replace('.ps', '.pdf')     
        # Ghostscript command to convert PS to PDF
        cmd=("ps2pdf "+ ps_file + " " + pdf_file)
        run(cmd)

    ps_files = glob.glob(os.path.join('RNAplots/failed/', '*_plot.ps'))
    # Loop through each file and convert it to PDF
    for ps_file in ps_files:
        # Construct the PDF file name
        pdf_file = ps_file.replace('.ps', '.pdf')
        cmd=("ps2pdf "+ ps_file + " " + pdf_file)
        run(cmd)


    #Remove unnecessary files and move to output directory
    cleanup_directories()

    print("Summary")
    print("____________________________________")
    if args.n != None:
        print("Results file:" + args.n + "_miRScore_results")
    print("Number of submitted candidate MIRNA loci: " + str(len(hp_dict)))
    res_pass = 0
    for key in mirna_dict:
        if mirna_dict[key][10] == "Pass":
            res_pass = res_pass + 1
    print("Total number of MIRNA loci validated to meet all criteria: " + str(res_pass) )
    res_fail = 0
    for key in mirna_dict:
        if mirna_dict[key][10] == "Fail":
            res_fail = res_fail + 1
    print("Total number of Failed MIRNA loci: " + str(res_fail))
    if args.rescue is not None:
        print("     Failed miRNAs that could pass with an alternative miRNA: " + str(len(pred_dict)))
    print('')
    executionTime = (time.time() - startTime)
    print('Time to run: ' + str(int(executionTime)) + " seconds")
    print('Run Completed!')

# Script entry point: call main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()

