#!/usr/bin/env python3
"""
TranslatorY - pYthon successor to TranslatorX.
Multiple alignment of nucleotide sequences guided by amino acid translations.
"""

import sys
import os
import subprocess
import argparse
import warnings

# --- Suppress the specific free-threaded GIL warning from Biopython C-extensions ---
warnings.filterwarnings("ignore", message=".*global interpreter lock.*", category=RuntimeWarning)

from Bio import SeqIO
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
from Bio.Data import CodonTable


def setup_custom_codon_tables():
    """
    Register the custom genetic codes 100 and 101 historically offered by TranslatorX.

    * 100 - Ancestral Arthropod Mitochondrial: NCBI table 5 with AGG read as K.
    * 101 - Hemichordate Mitochondrial: an alias of NCBI table 33.

    ``Seq.translate(table=<int>)`` resolves integer ids through
    ``CodonTable.ambiguous_generic_by_id``, so a custom code must be present in
    every per-id registry, not only in ``unambiguous_dna_by_id``.
    ``CodonTable.register_ncbi_table`` builds and stores all variants (DNA,
    RNA, generic and their ambiguous counterparts) in one call.

    Any failure is reported as a warning and otherwise ignored, so the standard
    NCBI codes remain usable.
    """
    try:
        arthropod_table = CodonTable.unambiguous_dna_by_id[5]
        custom_forward_table = dict(arthropod_table.forward_table)
        custom_forward_table["AGG"] = "K"

        # NOTE(review): register_ncbi_table is the helper Biopython itself uses to
        # build its NCBI tables; its docstring marks it private, so a future
        # Biopython could change it (the except clause below then reports it).
        CodonTable.register_ncbi_table(
            name="Ancestral Arthropod Mitochondrial",
            alt_name=None,
            id=100,
            table=custom_forward_table,
            start_codons=list(arthropod_table.start_codons),
            stop_codons=list(arthropod_table.stop_codons),
        )

        # Alias table 33 as 101 in every id-keyed registry, including the one
        # Seq.translate actually consults for integer ids.
        for registry in (
            CodonTable.unambiguous_dna_by_id,
            CodonTable.unambiguous_rna_by_id,
            CodonTable.generic_by_id,
            CodonTable.ambiguous_dna_by_id,
            CodonTable.ambiguous_rna_by_id,
            CodonTable.ambiguous_generic_by_id,
        ):
            registry[101] = registry[33]
    except Exception as e:
        # Best effort: missing custom tables must not prevent using the standard ones.
        print(f"Warning: Could not register custom tables 100 and 101: {e}")

def parse_arguments():
    """
    Parse and validate the command line and return the resulting namespace.

    Besides argparse's own checks, the following are rejected via ``parser.error``
    (usage line plus exit status 2) instead of failing later with a traceback:

    * an input path that is not an existing file;
    * a genetic code number Biopython cannot translate with.

    An output prefix ending in a path separator is treated as a directory, and
    the default file prefix is appended to it.

    Returns:
        argparse.Namespace with ``input``, ``output``, ``program``, ``code``,
        ``guess`` and ``gblocks``.
    """

    # Text to display AFTER the options (epilog)
    codes_table = """
Available genetic codes (-c):
  1   Standard
  2   Vertebrate Mitochondrial
  3   Yeast Mitochondrial
  4   Mold, Protozoan, Coelenterate Mito; Mycoplasma; Spiroplasma
  5   Invertebrate Mitochondrial
  6   Ciliate, Dasycladacean, Hexamita Nuclear
  9   Echinoderm, Flatworm Mitochondrial
  10  Euplotid Nuclear
  11  Bacterial and Plant Plastid
  12  Alternative Yeast Nuclear
  13  Ascidian Mitochondrial
  14  Alternative Flatworm Mitochondrial
  15  Blepharisma Macronuclear
  16  Chlorophycean Mitochondrial
  21  Trematode Mitochondrial
  22  Scenedesmus obliquus Mitochondrial
  23  Thraustochytrium Mitochondrial
  100 Ancestral Arthropod Mitochondrial Code (AGG=K)
  101 Hemichordate Mitochondrial
"""

    parser = argparse.ArgumentParser(
        description=(
            "TranslatorY - pYthon successor to TranslatorX.\n"
            "Multiple alignment of nucleotide sequences guided by amino acid translations."
        ),
        epilog=codes_table,
        formatter_class=argparse.RawTextHelpFormatter,
        usage="\n  %(prog)s [-h] -i INPUT [-o OUTPUT] [-p {muscle,mafft,clustalw,t_coffee,prank}] [-c CODE] [-g] [-b]"
    )

    parser.add_argument('-i', '--input', required=True, help="Input nucleotide FASTA file")
    parser.add_argument('-o', '--output', default="translatorY_res", help="Output files prefix")
    parser.add_argument('-p', '--program', default="mafft", choices=['muscle', 'mafft', 'clustalw', 't_coffee', 'prank'], help="Alignment program")
    parser.add_argument('-c', '--code', type=int, default=1, help="Genetic code table number (default: 1)")
    parser.add_argument('-g', '--guess', action='store_true', help="Guess best reading frame to minimize stop codons")
    parser.add_argument('-b', '--gblocks', action='store_true', help="Apply Gblocks to clean the alignment")

    args = parser.parse_args()

    if not os.path.isfile(args.input):
        parser.error(f"input file not found: {args.input}")

    # Seq.translate(table=<int>) looks integer codes up in this registry, so a
    # code missing from it would otherwise only fail mid-way through translation.
    if args.code not in CodonTable.ambiguous_generic_by_id:
        parser.error(f"unsupported genetic code {args.code} (see the available codes with -h)")

    # If the user provides a directory path ending with a slash, append the
    # default prefix (the same one used when -o is omitted) so no hidden files
    # such as ".aaseqs.fasta" are created.
    if args.output.endswith(("/", "\\")):
        args.output += parser.get_default("output")
    return args

# Register the custom genetic codes 100 and 101 once, at import time, so they
# are in place before main() parses the arguments and translates any sequence.
setup_custom_codon_tables()

def clean_sequence(seq_str):
    """
    Remove alignment gaps ('-'), dots ('.') and every kind of whitespace.

    Whitespace covers spaces, tabs, newlines and carriage returns (anything
    ``str.split()`` treats as a separator). Letter case is preserved.

    Args:
        seq_str: Raw nucleotide sequence text.

    Returns:
        The sequence with only its residue characters left.
    """
    # str.split() with no argument splits on any whitespace run, so joining the
    # pieces drops tabs and '\r' too, not just ' ' and '\n'.
    return "".join(seq_str.split()).replace("-", "").replace(".", "")

def find_best_frame(seq_str, genetic_code):
    """
    Choose the reading frame (offset 0, 1 or 2) with the fewest stop codons.

    For each offset the sequence is trimmed to whole codons and translated with
    ``genetic_code``. When several frames tie, the smallest offset wins.

    Returns:
        (frame, protein, stop_count) for the winning frame.
    """
    candidates = []
    for offset in (0, 1, 2):
        # Skip `offset` bases, then drop the trailing partial codon, if any.
        shifted = seq_str[offset:]
        coding = shifted[: len(shifted) - len(shifted) % 3]
        protein = str(Seq(coding).translate(table=genetic_code, to_stop=False))
        candidates.append((protein.count('*'), offset, protein))

    # min() returns the first minimal entry, i.e. the lowest offset among ties.
    stops, frame, protein = min(candidates, key=lambda cand: cand[0])
    return frame, protein, stops

def run_aligner(program, input_fasta, output_fasta):
    """
    Run the selected multiple sequence alignment tool via subprocess.

    Aligns the protein FASTA ``input_fasta`` and writes the alignment, in FASTA
    format, to ``output_fasta``. Every program offered by ``-p/--program`` is
    handled:

    * mafft;
    * muscle (v5 ``-align``/``-output`` syntax);
    * clustalw;
    * t_coffee;
    * prank.

    The aligner's stdout is discarded, except for mafft, whose stdout is the
    alignment itself. Its stderr is captured and printed only when it fails.

    Exits the process with status 1 in any of these cases:

    * the executable is missing;
    * it exits non-zero;
    * no non-empty output file results;
    * ``program`` is not supported.
    """
    print(f"Running {program}...")

    def _execute(cmd, stdout=subprocess.DEVNULL):
        # Capture stderr (instead of discarding it) so a failure can be diagnosed.
        subprocess.run(cmd, stdout=stdout, stderr=subprocess.PIPE, errors="replace", check=True)

    # Drop a leftover alignment from an earlier run so the final existence check
    # cannot mistake it for this run's result.
    if os.path.exists(output_fasta):
        os.remove(output_fasta)

    try:
        if program == "mafft":
            # MAFFT writes the alignment to stdout.
            with open(output_fasta, "w") as out_f:
                _execute(["mafft", "--auto", input_fasta], stdout=out_f)

        elif program == "muscle":
            _execute(["muscle", "-align", input_fasta, "-output", output_fasta])

        elif program == "clustalw":
            _execute(["clustalw", f"-INFILE={input_fasta}", f"-OUTFILE={output_fasta}", "-OUTPUT=FASTA"])

        elif program == "t_coffee":
            _execute(["t_coffee", input_fasta, "-output=fasta_aln", f"-outfile={output_fasta}"])

        elif program == "prank":
            # PRANK takes an output *prefix* and names its final alignment
            # "<prefix>.best.fas"; move it to the path the caller expects.
            prank_prefix = f"{output_fasta}.prank"
            _execute(["prank", f"-d={input_fasta}", f"-o={prank_prefix}"])
            prank_result = f"{prank_prefix}.best.fas"
            if os.path.exists(prank_result):
                os.replace(prank_result, output_fasta)

        else:
            print(f"Error: Unsupported alignment program '{program}'.")
            sys.exit(1)

    except FileNotFoundError:
        print(f"Error: {program} executable not found. Please ensure it is installed and in your PATH.")
        sys.exit(1)

    except subprocess.CalledProcessError as e:
        print(f"Error during alignment: {e}")
        if e.stderr:
            print(e.stderr.strip())
        sys.exit(1)

    if not os.path.isfile(output_fasta) or os.path.getsize(output_fasta) == 0:
        print(f"Error: {program} did not produce an alignment in {output_fasta}.")
        sys.exit(1)

def generate_html_report(output_prefix, aa_ali_file, nt_ali_file, wrap_len=60):
    """
    Generate a static, zero-dependency HTML report replicating TranslatorX's
    interleaved and colored codon format.

    The report is written, UTF-8 encoded, to ``<output_prefix>.html``. The
    alignment is shown in blocks of ``wrap_len`` columns. Each cell pairs an
    amino acid with the three nucleotides at the same position of the
    nucleotide alignment.

    Args:
        output_prefix: Prefix of the report file; also shown in the page title.
        aa_ali_file: Aligned protein FASTA.
        nt_ali_file: Aligned nucleotide FASTA. Its records are paired
            positionally with those of ``aa_ali_file``.
        wrap_len: Number of amino-acid columns per block (default 60).

    Raises:
        ValueError: If ``wrap_len`` is smaller than 1.
    """
    # Local import keeps this block self-contained.
    from html import escape

    if wrap_len < 1:
        raise ValueError(f"wrap_len must be >= 1, got {wrap_len}")

    # Read aligned records
    aa_records = list(SeqIO.parse(aa_ali_file, "fasta"))
    nt_records = list(SeqIO.parse(nt_ali_file, "fasta"))

    # Original TranslatorX color scheme for amino acids
    colors = {
        'A': '#C8C8C8', 'V': '#C8C8C8', 'L': '#C8C8C8', 'I': '#C8C8C8', 'P': '#C8C8C8',
        'W': '#C8C8C8', 'F': '#C8C8C8', 'M': '#C8C8C8', 'G': '#C8C8C8',
        'S': '#15C015', 'T': '#15C015', 'C': '#15C015', 'Y': '#15C015', 'N': '#15C015', 'Q': '#15C015',
        'D': '#C048C0', 'E': '#C048C0',
        'K': '#F01505', 'R': '#F01505', 'H': '#F01505',
        '-': '#FFFFFF', '*': '#FFFFFF', 'X': '#FFFFFF',
    }

    # The prefix and the sequence ids come from user input (CLI / FASTA headers),
    # so they are escaped before being embedded in markup or attributes.
    parts = [
        "<!DOCTYPE html>\n<html lang='en'>\n<head>\n<meta charset='UTF-8'>\n"
        f"<title>TranslatorY Results - {escape(output_prefix)}</title>",
        # Modern CSS to make the alignment beautiful and readable
        "<style>",
        "body { font-family: 'Courier New', monospace; background-color: #f8f9fa; margin: 20px; }",
        "h1 { color: #333; font-family: Arial, sans-serif; }",
        ".block { background-color: #fff; padding: 20px; margin-bottom: 25px; border-radius: 8px; box-shadow: 0 2px 5px rgba(0,0,0,0.1); overflow-x: auto; white-space: nowrap; }",
        ".row { display: flex; align-items: center; margin-bottom: 5px; }",
        ".name { width: 200px; font-weight: bold; font-size: 13px; flex-shrink: 0; overflow: hidden; text-overflow: ellipsis; padding-right: 15px; color: #0056b3; }",
        ".codon { display: inline-flex; flex-direction: column; align-items: center; width: 34px; border-radius: 4px; margin-right: 2px; padding: 3px 0; border: 1px solid rgba(0,0,0,0.05); }",
        ".aa { font-weight: bold; font-size: 15px; color: #000; }",
        ".nt { font-size: 11px; color: #444; letter-spacing: 0.5px; margin-top: 2px; }",
        "</style>\n</head>\n<body>",
        "<h1>TranslatorY Alignment Results</h1>",
    ]

    if not aa_records:
        parts.append("<p>Error: No sequences found in alignment.</p>")
    else:
        # Convert each record once instead of once per block.
        rows = [
            (escape(aa_rec.id), str(aa_rec.seq), str(nt_rec.seq))
            for aa_rec, nt_rec in zip(aa_records, nt_records)
        ]
        seq_len = len(aa_records[0].seq)

        # Generate interwoven blocks of alignments (wrap_len amino acids wide)
        for start in range(0, seq_len, wrap_len):
            end = min(start + wrap_len, seq_len)
            parts.append("<div class='block'>")
            parts.append(f"<h3 style='font-family: Arial, sans-serif; margin-top: 0; color: #666;'>Positions {start+1} - {end}</h3>")

            for name, aa_seq, nt_seq in rows:
                aa_sub = aa_seq[start:end]
                nt_sub = nt_seq[start * 3 : end * 3]

                parts.append("<div class='row'>")
                parts.append(f"<div class='name' title='{name}'>{name}</div>")

                # Each amino acid is shown above the 3 nucleotides that encode it.
                for i, aa in enumerate(aa_sub):
                    codon = nt_sub[i * 3 : i * 3 + 3]
                    bg_color = colors.get(aa.upper(), '#FFFFFF')
                    parts.append(
                        f"<div class='codon' style='background-color: {bg_color};'>"
                        f"<span class='aa'>{escape(aa)}</span><span class='nt'>{escape(codon)}</span></div>"
                    )

                parts.append("</div>")  # End row

            parts.append("</div>")  # End block

    parts.append("</body>\n</html>")

    # Encode explicitly so the bytes match the declared <meta charset='UTF-8'>.
    with open(f"{output_prefix}.html", "w", encoding="utf-8") as f:
        f.write("\n".join(parts))

def apply_gblocks_mask(aa_ali_file, aa_gb_file, nt_ali_file, nt_gb_file):
    """
    Map the protein columns kept by Gblocks back onto the nucleotide alignment.

    Only the trimmed protein alignment is read from Gblocks. The kept columns
    are recovered by walking the original protein alignment left to right and
    greedily matching its columns against the next Gblocks column. Each kept
    protein column ``c`` selects nucleotide columns ``3c`` to ``3c + 2``.

    On success, ``nt_gb_file`` receives the trimmed nucleotide alignment.
    ``aa_gb_file`` is rewritten with Gblocks' spacing removed and the original
    record ids restored.

    Returns:
        True on success, and False in any of these cases (no file is written then):

        * Gblocks kept no column;
        * its record count differs from the original alignment's;
        * its columns cannot all be located, in order, in the original alignment.
    """
    orig_aa = list(SeqIO.parse(aa_ali_file, "fasta"))
    gb_aa = list(SeqIO.parse(aa_gb_file, "fasta"))
    orig_nt = list(SeqIO.parse(nt_ali_file, "fasta"))

    # Gblocks injects spaces into its FASTA sequence lines; drop all whitespace.
    gb_seqs = ["".join(str(r.seq).split()) for r in gb_aa]
    orig_seqs = [str(r.seq) for r in orig_aa]

    if not gb_seqs:
        return False
    if len(gb_seqs) != len(orig_seqs):
        print("Warning: Gblocks output has a different number of sequences than the alignment.")
        return False

    # Column-major views: one tuple of residues per alignment column.
    orig_cols = list(zip(*orig_seqs))
    gb_cols = list(zip(*gb_seqs))

    # Records present but every column removed: nothing to keep.
    if not gb_cols:
        return False

    # Identify which columns were kept by Gblocks.
    # NOTE: identical columns are indistinguishable here. The earliest match is
    # taken, so codons may come from an adjacent column with the same amino acids.
    keep_cols = []
    next_gb = 0
    for orig_idx, column in enumerate(orig_cols):
        if next_gb == len(gb_cols):
            break
        if column == gb_cols[next_gb]:
            keep_cols.append(orig_idx)
            next_gb += 1

    # A partial match means the mapping is unreliable; do not write a wrong NT mask.
    if next_gb != len(gb_cols):
        print("Warning: could not locate every Gblocks column in the original alignment.")
        return False

    # Apply the mask to the nucleotide alignment (3 nt per kept amino-acid column).
    gb_nt_records = []
    for nt_rec in orig_nt:
        nt_seq = str(nt_rec.seq)
        trimmed = "".join(nt_seq[c * 3 : c * 3 + 3] for c in keep_cols)
        gb_nt_records.append(SeqRecord(Seq(trimmed), id=nt_rec.id, description=""))
    SeqIO.write(gb_nt_records, nt_gb_file, "fasta")

    # Rewrite the AA file with the spaces removed and the original names restored.
    gb_aa_clean_records = [
        SeqRecord(Seq(gb_seq), id=orig_rec.id, description="")
        for orig_rec, gb_seq in zip(orig_aa, gb_seqs)
    ]
    SeqIO.write(gb_aa_clean_records, aa_gb_file, "fasta")

    return True

def _translate_input(input_path, genetic_code, guess_frame):
    """
    Read, clean and translate every record of the input nucleotide FASTA.

    Returns:
        (trimmed_nt_dict, aa_records):

        * trimmed_nt_dict maps each record id to the frame-corrected nucleotide
          string whose codons were translated;
        * aa_records holds the unaligned protein SeqRecords, in input order.

    Exits with status 1 when two records share an id (back-translation is keyed
    by id) or when the input contains no record.
    """
    trimmed_nt_dict = {}
    aa_records = []

    for record in SeqIO.parse(input_path, "fasta"):
        if record.id in trimmed_nt_dict:
            print(f"Error: duplicate sequence id '{record.id}' in {input_path}. Sequence ids must be unique.")
            sys.exit(1)

        clean_nt = clean_sequence(str(record.seq))

        if guess_frame:
            best_frame, best_aa, stops = find_best_frame(clean_nt, genetic_code)
            if stops > 0:
                print(f"Warning: {record.id} has {stops} stop codons in its best reading frame.")
        else:
            truncated_len = (len(clean_nt) // 3) * 3
            best_frame = 0
            best_aa = str(Seq(clean_nt[:truncated_len]).translate(table=genetic_code, to_stop=False))
            stops = best_aa.count('*')
            if stops > 0:
                print(f"Warning: {record.id} has {stops} stop codons. Consider using -g | --guess to guess frame.")

        # Keep exactly the translated codons so back-translation lines up 1:1.
        trimmed_nt_dict[record.id] = clean_nt[best_frame : best_frame + len(best_aa) * 3]
        aa_records.append(SeqRecord(Seq(best_aa), id=record.id, description=""))

    if not aa_records:
        print(f"Error: no sequences found in {input_path}.")
        sys.exit(1)

    return trimmed_nt_dict, aa_records


def _back_translate(aa_aligned_file, trimmed_nt_dict):
    """
    Build the codon alignment from the protein alignment.

    Each '-' in an aligned protein becomes '---'. Every other character consumes
    the next codon of that record's unaligned nucleotide sequence.

    Returns:
        List of aligned nucleotide SeqRecords, in the aligner's output order.

    Exits with status 1 if the aligner emitted an id that is not in the input.
    """
    aligned_nt_records = []

    for aa_rec in SeqIO.parse(aa_aligned_file, "fasta"):
        seq_id = aa_rec.id
        unaligned_nt = trimmed_nt_dict.get(seq_id)
        if unaligned_nt is None:
            print(f"Error: aligned sequence '{seq_id}' does not match any input id (the aligner may have renamed it).")
            sys.exit(1)

        pieces = []
        nt_idx = 0
        for amino_acid in str(aa_rec.seq):
            if amino_acid == "-":
                pieces.append("---")
            else:
                pieces.append(unaligned_nt[nt_idx : nt_idx + 3])
                nt_idx += 3

        # A mismatch means the aligner changed the residue count; codons would shift.
        if nt_idx != len(unaligned_nt):
            print(f"Warning: {seq_id}: aligned protein has {nt_idx // 3} residues but "
                  f"{len(unaligned_nt) // 3} codons were translated; its nucleotide alignment may be shifted.")

        aligned_nt_records.append(SeqRecord(Seq("".join(pieces)), id=seq_id, description=""))

    return aligned_nt_records


def _run_gblocks(output_prefix, aa_aligned_file, nt_aligned_file):
    """
    Clean the protein alignment with Gblocks and apply the same mask to the codons.

    Best effort: on any failure a message is printed and the uncleaned files are
    returned.

    Returns:
        (aa_file, nt_file) pair to feed into the HTML report.
    """
    fallback = (aa_aligned_file, nt_aligned_file)

    # Gblocks' hardcoded output filenames
    gb_raw_fasta = aa_aligned_file + "-gb"
    gb_raw_htm = aa_aligned_file + "-gb.htm"

    # Our clean, modern filenames
    aa_gb_file = f"{output_prefix}.aa_ali_gblocks.fasta"
    htm_gb_file = f"{output_prefix}.aa_ali_gblocks.html"
    nt_gb_file = f"{output_prefix}.nt_ali_gblocks.fasta"

    try:
        # Remove leftovers of an earlier run so they cannot pass for this run's output.
        for stale in (gb_raw_fasta, gb_raw_htm):
            if os.path.exists(stale):
                os.remove(stale)

        try:
            # No check=True: Gblocks often exits non-zero even on success, so
            # success is judged by the presence of its output file instead.
            subprocess.run(["Gblocks", aa_aligned_file, "-t=p"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
        except FileNotFoundError:
            print("Error: 'Gblocks' executable not found in PATH. Skipping cleaning.")
            return fallback

        if not os.path.exists(gb_raw_fasta):
            print("Error: Gblocks failed to produce the expected output file. Skipping cleaning.")
            return fallback

        # os.replace (unlike os.rename) also overwrites existing targets on Windows,
        # so re-running with the same prefix works.
        os.replace(gb_raw_fasta, aa_gb_file)
        if os.path.exists(gb_raw_htm):
            os.replace(gb_raw_htm, htm_gb_file)

        if apply_gblocks_mask(aa_aligned_file, aa_gb_file, nt_aligned_file, nt_gb_file):
            print("Gblocks cleaning successful! Cleaned alignments saved.")
            return aa_gb_file, nt_gb_file

        print("Warning: Gblocks cleaning could not be applied! Falling back to uncleaned alignment.")
    except Exception as e:
        # Deliberate best-effort boundary: cleaning is optional.
        print(f"Error during Gblocks execution: {e}. Skipping cleaning.")

    return fallback


def main():
    """
    Run the TranslatorY pipeline end to end.

    1. Read, clean and translate the input nucleotide FASTA.
    2. Write the unaligned proteins to ``<prefix>.aaseqs.fasta``.
    3. Align them with the selected program into ``<prefix>.aa_ali.fasta``.
    4. Back-translate into the codon alignment ``<prefix>.nt_ali.fasta``.
    5. Optionally clean both alignments with Gblocks (``-b``).
    6. Write the HTML report ``<prefix>.html``.
    """
    args = parse_arguments()

    # Create the output directory when the prefix points into one that is missing.
    out_dir = os.path.dirname(args.output)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    print("Reading and translating sequences...")

    # Step 1 & 2: Read, Clean, and Translate
    trimmed_nt_dict, aa_records = _translate_input(args.input, args.code, args.guess)

    # Save unaligned AA for the aligner
    aa_unaligned_file = f"{args.output}.aaseqs.fasta"
    SeqIO.write(aa_records, aa_unaligned_file, "fasta")

    # Step 3: Run Aligner
    aa_aligned_file = f"{args.output}.aa_ali.fasta"
    run_aligner(args.program, aa_unaligned_file, aa_aligned_file)

    # Step 4: Back-translation
    print("Mapping nucleotide sequences back to the amino acid alignment...")
    aligned_nt_records = _back_translate(aa_aligned_file, trimmed_nt_dict)
    nt_aligned_file = f"{args.output}.nt_ali.fasta"
    SeqIO.write(aligned_nt_records, nt_aligned_file, "fasta")

    # Default to the standard alignment files if Gblocks is not triggered
    aa_final_file, nt_final_file = aa_aligned_file, nt_aligned_file
    if args.gblocks:
        print("Running Gblocks to clean the alignment...")
        aa_final_file, nt_final_file = _run_gblocks(args.output, aa_aligned_file, nt_aligned_file)

    # Step 5: HTML Report (fed with the cleaned files when Gblocks succeeded)
    print("Generating modern HTML report...")
    generate_html_report(args.output, aa_final_file, nt_final_file)

    print(f"Done! Results saved with prefix '{args.output}'.")
    print(f"Open {args.output}.html in your web browser to view the alignments.")

# Run the pipeline only when executed as a script; importing the module merely
# defines the functions (and registers the custom codon tables above).
if __name__ == "__main__":
    main()
