#!/opt/conda/conda-bld/transabyss_1774380212466/_h_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_pl/bin/python

# written by Ka Ming Nip
# Copyright 2018 Canada's Michael Smith Genome Sciences Centre

import argparse
import glob
import math
import multiprocessing
import os
import shutil
import string
import sys
import textwrap
import time
from transabyss import package_info
from transabyss import psl_cid_extractor
from transabyss.adj_utils import has_edges
from transabyss.adj_utils import remove_redundant_paths
from transabyss.adj_utils import unbraid
from transabyss.adj_utils import walk
from transabyss.common_utils import StopWatch
from transabyss.common_utils import check_env
from transabyss.common_utils import is_empty_txt
from transabyss.common_utils import log
from transabyss.common_utils import path_action
from transabyss.common_utils import paths_action
from transabyss.common_utils import run_shell_cmd
from transabyss.common_utils import threshold_action
from transabyss.common_utils import touch
from transabyss.fasta_utils import abyssmap_merge_fastas
from transabyss.fasta_utils import blat_merge_fastas
from transabyss.fasta_utils import blat_self_align
from transabyss.fasta_utils import filter_fasta


# Package identity; used for `--version`, the log header and the default
# output directory name.
TRANSABYSS_VERSION = package_info.VERSION
TRANSABYSS_NAME = package_info.NAME

# Suffixes of the stamp files that mark a completed stage. A stamp file is
# created as '<prefix>.<suffix>' by make_stamp().
STAGE_DBG_STAMP = "DBG.COMPLETE"
STAGE_UNITIGS_STAMP = "UNITIGS.COMPLETE"
STAGE_CONTIGS_STAMP = "CONTIGS.COMPLETE"
STAGE_REFERENCES_STAMP = "REFERENCES.COMPLETE"
STAGE_JUNCTIONS_STAMP = "JUNCTIONS.COMPLETE"
STAGE_FINAL_STAMP = "FINAL.COMPLETE"
# Stamps ordered from the last stage to the first one. check_stamp() walks
# this list so that the stamp of any later stage also counts as "done".
STAGE_DEPENDENCY = [STAGE_FINAL_STAMP, STAGE_JUNCTIONS_STAMP, STAGE_REFERENCES_STAMP, STAGE_CONTIGS_STAMP, STAGE_UNITIGS_STAMP, STAGE_DBG_STAMP]

# Stage names accepted by the `--stage` command line option.
STAGE_DBG = 'dbg'
STAGE_UNITIGS = 'unitigs'
STAGE_CONTIGS = 'contigs'
STAGE_REFERENCES = 'references'
STAGE_JUNCTIONS = 'junctions'
STAGE_FINAL = 'final'
# Stage names in execution order (first to last); compare_stages() relies on
# the position of a stage in this list.
STAGES = [STAGE_DBG, STAGE_UNITIGS, STAGE_CONTIGS, STAGE_REFERENCES, STAGE_JUNCTIONS, STAGE_FINAL]

# Locations provided by the package, and the awk scripts looked up in BINDIR.
PACKAGEDIR = package_info.PACKAGEDIR
BINDIR = package_info.BINDIR
SKIP_PSL_SELF = os.path.join(BINDIR, 'skip_psl_self.awk')
SKIP_PSL_SELF_SS = os.path.join(BINDIR, 'skip_psl_self_ss.awk')

# External programs and scripts handed to check_env() before the pipeline starts.
REQUIRED_EXECUTABLES = ['abyss-pe', 'MergeContigs', 'abyss-filtergraph', 'abyss-junction', 'blat', 'abyss-map']
REQUIRED_SCRIPTS = [SKIP_PSL_SELF, SKIP_PSL_SELF_SS]

def defaultwd():
    """Build the path of the default output directory.

    The directory is named 'transabyss_<version>_assembly' and is placed in
    the current working directory.
    """

    cwd = os.getcwd()
    dirname = "transabyss_" + TRANSABYSS_VERSION + "_assembly"
    return os.path.join(cwd, dirname)
#enddef

def compare_stages(stage1, stage2):
    """Return how many stages `stage2` lies after `stage1` in STAGES.

    The result is zero when both names are the same stage and negative when
    `stage2` comes before `stage1`. A name that is not in STAGES raises
    ValueError.
    """

    target_rank = STAGES.index(stage2)
    start_rank = STAGES.index(stage1)
    return target_rank - start_rank
#enddef

def check_stamp(prefix, stamp):
    """Tell whether the stage identified by `stamp` can be skipped.

    STAGE_DEPENDENCY is scanned from the last stage towards the first one.
    The answer is True as soon as the stamp file of a later stage is found,
    otherwise it is whether the stamp file of `stamp` itself exists. False is
    returned when `stamp` is not listed in STAGE_DEPENDENCY and no stamp file
    exists at all.
    """

    for candidate in STAGE_DEPENDENCY:
        found = os.path.exists(prefix + '.' + candidate)
        if candidate == stamp or found:
            # Either we reached the requested stage (report its own state),
            # or a later stage is already done (found is True).
            return found
        #endif
    #endfor

    return False
#enddef

def make_stamp(prefix, stamp):
    """Touch the stamp file '<prefix>.<stamp>' to record a completed stage.
    """

    stamp_path = prefix + '.' + stamp
    touch(stamp_path)
#enddef

def required_files_exist(files):
    """Make sure every path in `files` exists.

    When at least one path is missing, the missing paths are logged and the
    program exits with status 1. Nothing is returned otherwise.
    """

    absent = [path for path in files if not os.path.exists(path)]

    if len(absent) == 0:
        return
    #endif

    log('ERROR: Cannot find:\n' + '\n'.join(absent))
    sys.exit(1)
#enddef

def dbg_assembly(reads, outdir, q=3, Q=None, k=32, name='transabyss', cov=2, threads=1, mpi_np=0, strand_specific=False, E=0, e=2):
    """Generate the initial De Bruijn graph assembly with ABySS.

    Builds an `abyss-pe` command line that requests the targets
    '<name>-1.fa' and '<name>-1.adj' and runs it through run_shell_cmd().

    Parameters:
        reads: list of read file paths; passed to abyss-pe as `se="..."`.
        outdir: directory passed to abyss-pe as `--directory`.
        q: passed as `q=`; omitted only when None.
        Q: passed as `Q=`; omitted only when None.
        k: passed as `k=`.
        name: assembly name; passed as `name=` and used for the targets.
        cov: passed as `c=`.
        threads: passed as `j=`.
        mpi_np: passed as `np=` when greater than zero.
        strand_specific: when true, `SS=--SS` is added.
        E: passed as `E=`.
        e: passed as `e=`.
    """

    # Generate *-1.fa, *-1.adj, *-bubbles.fa, coverage.hist
    cmd_params = ['abyss-pe', 'graph=adj', '--directory=%s' % outdir, 'k=%d' % k, 'name=%s' % name, 'E=%d' % E, 'e=%d' % e, 'c=%d' % cov, 'j=%d' % threads, '%s-1.fa' % name, '%s-1.adj' % name]

    # An explicit threshold of 0 is a legitimate request and must be forwarded.
    # A plain truthiness test would silently drop it.
    # NOTE(review): abyss-pe is understood to apply its own default for `q`
    # when the variable is not given -- confirm against the abyss-pe docs.
    if q is not None:
        cmd_params.append('q=%d' % q)
    #endif

    if Q is not None:
        cmd_params.append('Q=%d' % Q)
    #endif

    if strand_specific:
        cmd_params.append('SS=--SS')
    #endif

    # Specify the number of MPI processes
    if mpi_np > 0:
        cmd_params.append('np=%d' % mpi_np)
    #endif

    # Specify the input reads for sequence content of the assembly
    cmd_params.append('se="%s"' % ' '.join(reads))

    run_shell_cmd(' '.join(cmd_params))
#enddef

def abyss_merge_contigs(in_fasta, in_adj, in_path, out_fasta, k=32, merged_only=False):
    """Run MergeContigs on the given FASTA, ADJ and path files.

    The output FASTA is passed as `--out`; `--merged` is added when
    `merged_only` is true.
    """

    options = ['--kmer=%d' % k, '--out=%s' % out_fasta]
    merged_flag = ['--merged'] if merged_only else []
    positional = [in_fasta, in_adj, in_path]

    command = ' '.join(['MergeContigs'] + options + merged_flag + positional)
    run_shell_cmd(command)
#enddef

def abyss_merge_contigs_adj(in_fasta, in_adj, in_path, out_fasta, out_adj, k=32, merged_only=False):
    """Run MergeContigs on the given paths and also request an ADJ graph.

    Same as abyss_merge_contigs(), with `--adj --graph=<out_adj>` added to
    the command line.
    """

    options = ['--kmer=%d' % k, '--out=%s' % out_fasta, '--adj', '--graph=%s' % out_adj]
    merged_flag = ['--merged'] if merged_only else []
    positional = [in_fasta, in_adj, in_path]

    command = ' '.join(['MergeContigs'] + options + merged_flag + positional)
    run_shell_cmd(command)
#enddef

def get_skip_psl_self_awk_script_path(strand_specific=False):
    """Return the path of the awk script matching the strandedness.

    SKIP_PSL_SELF_SS is selected for strand-specific data and SKIP_PSL_SELF
    otherwise. The selected script is asserted to be an existing file.
    """

    script_path = SKIP_PSL_SELF_SS if strand_specific else SKIP_PSL_SELF
    assert os.path.isfile(script_path)
    return script_path
#enddef

def unitig_assembly(in_fasta, in_adj, out_fasta, out_adj, tmp_file_prefix, island_size, max_iteration=2, k=32, strand_specific=False, min_percent_identity=0.95, indel_size_tolerance=1, seed_cov_grad=0.05, threads=1, useblat=False):
    """Generate the unitig assembly with multiple iterations of adjacency graph simplification.

    Each iteration reads the FASTA/ADJ pair produced by the previous one
    (initially `in_fasta`/`in_adj`) and writes
    '<tmp_file_prefix>r<i>.filtered.fa' and '<tmp_file_prefix>r<i>.filtered.adj'.
    A '<tmp_file_prefix>r<i>.COMPLETE' stamp is touched after each iteration
    so that a later call skips iterations that were already completed.

    The loop ends after `max_iteration` iterations, or earlier when the
    current ADJ file has no edges, or (with `useblat`) when neither braids
    nor removable paths were found.

    On return, `out_fasta` and `out_adj` hold the result: the last iteration
    output is moved there, or the inputs are copied when no iteration
    produced output or when `max_iteration` is 0.

    Parameters:
        in_fasta, in_adj: input unitig FASTA and adjacency graph.
        out_fasta, out_adj: destination of the simplified assembly.
        tmp_file_prefix: path prefix of all per-iteration files.
        island_size: passed to abyss-filtergraph as `--island`.
        max_iteration: maximum number of iterations; 0 turns simplification off.
        k: k-mer size passed to the helpers and external tools.
        strand_specific: forwarded to the helpers; adds `--SS` to abyss-filtergraph.
        min_percent_identity: forwarded to the BLAT self-alignment and PSL extraction.
        indel_size_tolerance: forwarded to unbraid() and to the BLAT-based steps.
        seed_cov_grad: forwarded to unbraid() as `cov_gradient`.
        threads: forwarded to blat_self_align().
        useblat: when true, redundant paths are also detected with BLAT.
    """
    
    i = 1
    if max_iteration > 0:
        # Track the most recent usable FASTA/ADJ pair across iterations.
        last_good_adj = in_adj
        last_good_fasta = in_fasta
    
        while i <= max_iteration:
            iteration_prefix = tmp_file_prefix + 'r' + str(i)
            iteration_stamp = iteration_prefix + '.COMPLETE'
            fasta2 = iteration_prefix + '.filtered.fa'
            adj2 = iteration_prefix + '.filtered.adj'
            
            if os.path.exists(iteration_stamp):
                log('CHECKPOINT: Iteration %d of graph simplification was done previously. Will not re-run ...' % i)
            else:
                log('Iteration %d of graph simplification ...' % i)
                
                ref_path1 = iteration_prefix + '.ref.path'
                braid_cids1 = iteration_prefix + '.braid.cids'
                ref_fasta1 = iteration_prefix + '.ref.fa'
                ref_fasta1_selfalign_psl = ref_fasta1 + '.selfalign.psl'
                
                if not has_edges(last_good_adj):
                    # ADJ file has no edges. No more simplification.
                    # Breaking here keeps last_good_* on the previous iteration's output.
                    log('WARNING: ADJ (%s) has no edges; no graph simplification can be done.' % os.path.basename(last_good_adj))
                    break
                #endif
                
                # Identify braids and generate reference paths        
                walked, marked = unbraid(last_good_adj, k, ref_path1, braid_cids1, strand_specific=strand_specific, cov_gradient=seed_cov_grad, length_diff_tolerance=indel_size_tolerance)
                # The reported count is doubled for non-strand-specific data
                # (presumably one vertex per orientation -- TODO confirm).
                log('Walked %d paths and marked %d vertices for removal.' % (walked, (marked if strand_specific else 2*marked)))
                                        
                # By default only the braid vertices are removed; the BLAT
                # branch below may replace this list with a larger one.
                remove_cids1 = braid_cids1
                have_braids = marked > 0
                have_rpaths = False
                
                if useblat:
                    log('Redundancy removal with BLAT ...')
                    
                    # Generate fasta for reference paths and islands
                    abyss_merge_contigs(last_good_fasta, last_good_adj, ref_path1, ref_fasta1, k=k)
                
                    # Self-align reference fasta with BLAT                    
                    blat_self_align(ref_fasta1, ref_fasta1_selfalign_psl, percent_id=min_percent_identity, max_consecutive_edits=indel_size_tolerance, min_seq_len=k, threads=threads, skip_psl_self_awk=get_skip_psl_self_awk_script_path(strand_specific=strand_specific))
                    
                    # Identify redundant references
                    rrefs = psl_cid_extractor.extract_cids(psl=ref_fasta1_selfalign_psl, samestrand=strand_specific, min_percent_identity=min_percent_identity, max_consecutive_edits=indel_size_tolerance, report_redundant=True)
                    
                    log('%d potentially removable paths ...' % len(rrefs))
                    
                    have_rpaths = len(rrefs) > 0

                    if have_rpaths:
                        # Evaluate potentially removable paths, generate a list of ids of removable sequences.
                        remove_cids1 = iteration_prefix + '.rm.cids'
                        marked = remove_redundant_paths(rrefs, last_good_adj, k, braid_cids1, ref_path1, remove_cids1, strand_specific=strand_specific)
                        log('Marked %d more vertices for removal.' % (marked if strand_specific else 2*marked))
                        # Drop the reference to the extracted ids; they are not used again.
                        rrefs = None
                    elif not have_braids:
                        # No potentially removable paths and no braids, this is the last iteration!
                        log('No removable paths ...')
                        break
                    else:
                        # No potentially removable paths, but we have braids. So, continue!
                        log('No removable paths ...')
                    #endif
                #endif
                                
                # Generate *.filtered.adj1, *.path
                # The standard output of abyss-filtergraph is redirected to the
                # *.path file; the intermediate graph is requested with --graph.
                path1 = iteration_prefix + '.path'
                adj2_tmp1 = adj2 + '1'
                remove_rref_cmd_params1 = ['abyss-filtergraph --shim --assemble', '--kmer=%d' % k, '--island=%d' % island_size, '--remove=%s' % remove_cids1, '--graph=%s' % adj2_tmp1]
                
                if strand_specific:
                    remove_rref_cmd_params1.append('--SS')
                #endif
                
                remove_rref_cmd_params1.append(last_good_adj)
                remove_rref_cmd_params1.append(last_good_fasta)
                
                remove_rref_cmd_params1.append('> %s' % path1)
                                                
                run_shell_cmd(' '.join(remove_rref_cmd_params1))                                                
                                                
                # Generate *.filtered.fa, *.filtered.adj
                abyss_merge_contigs_adj(last_good_fasta, adj2_tmp1, path1, fasta2, adj2, k=k)
                                                
                # Record the checkpoint only after both outputs were generated.
                touch(iteration_stamp)
                log('Completed iteration %d of graph simplification.' % i)
            #endif
            
            # Reached for both freshly run and previously completed iterations.
            last_good_adj = adj2
            last_good_fasta = fasta2
            i += 1
        #endwhile    
        # Note: after a loop that ran to completion, i equals max_iteration + 1.
        log('Graph simplification stopped at iteration %d' % i)
        
        # Copy the inputs when no iteration produced output (they must be
        # preserved); otherwise move the last iteration's output into place.
        if last_good_adj == in_adj:
            shutil.copy(in_adj, out_adj)            
        else:
            shutil.move(last_good_adj, out_adj)
        #endif
        
        if last_good_fasta == in_fasta:
            shutil.copy(in_fasta, out_fasta)
        else:
            shutil.move(last_good_fasta, out_fasta)
        #endif
                
    else:
        log('Graph simplification has been turned off ...')
    
        shutil.copy(in_adj, out_adj)
        shutil.copy(in_fasta, out_fasta)
    #endif
    
#enddef

def reference_assembly(in_fasta, in_adj, ref_path, ref_fasta, k=32, strand_specific=False, seed_cov_grad=0.05, in_dist=None):
    """Assemble the reference paths of an adjacency graph.

    When `in_adj` has edges, the graph is walked to write `ref_path` and the
    walked paths are merged into `ref_fasta` (merged sequences only). When it
    has no edges, a warning is logged and both output files are touched so
    that they exist.
    """

    if not has_edges(in_adj):
        log('WARNING: ADJ (%s) has no edges; no reference paths to assemble.' % os.path.basename(in_adj))
        # Nothing to walk: leave empty placeholder outputs behind.
        touch(ref_fasta)
        touch(ref_path)
        return
    #endif

    num_walked = walk(in_adj, k, ref_path, strand_specific=strand_specific, cov_gradient=seed_cov_grad, dist_file=in_dist)
    log('Walked %d paths.' % num_walked)
    abyss_merge_contigs(in_fasta, in_adj, ref_path, ref_fasta, k=k, merged_only=True)
#enddef

def contig_assembly(pe_reads, outdir, k=32, n=2, name='transabyss', threads=1, strand_specific=False, s=32):
    """Run abyss-pe to build the paired-end contig assembly.

    The command requests the target '<name>-6.fa' in `outdir`, using the
    paired-end read files in `pe_reads` as `in="..."`.
    """

    command = [
        'abyss-pe',
        'graph=adj',
        '--directory=%s' % outdir,
        'k=%d' % k,
        'name=%s' % name,
        'j=%d' % threads,
        'in="%s"' % ' '.join(pe_reads),
    ]

    if strand_specific:
        command.append('SS=--SS')
    #endif

    # abyss-pe parameters for paired-end transcriptome assembly
    command += [
        'l=%d' % k,
        's=%d' % s,
        'n=%d' % n,
        'SIMPLEGRAPH_OPTIONS="--no-scaffold"',
        'OVERLAP_OPTIONS="--no-scaffold"',
        'MERGEPATH_OPTIONS="--greedy"',
    ]

    # The make target to build
    command.append('%s-6.fa' % name)

    run_shell_cmd(' '.join(command))
#enddef

def junction_extension(in_adj, in_fasta, out_path, out_fasta, k=32, dist=None):
    """Extend the junctions of an adjacency graph.

    When `in_adj` has edges, abyss-junction is run (with `dist` as an extra
    argument when given) with its standard output redirected to `out_path`,
    and the resulting paths are merged into `out_fasta` (merged sequences
    only). When it has no edges, a warning is logged and both output files
    are touched so that they exist.
    """

    if not has_edges(in_adj):
        log('WARNING: ADJ (%s) has no edges; no junction paths to assemble.' % os.path.basename(in_adj))
        # Nothing to extend: leave empty placeholder outputs behind.
        touch(out_fasta)
        touch(out_path)
        return
    #endif

    dist_args = [dist] if dist else []
    redirect = ' >' + out_path
    command = ['abyss-junction', in_adj] + dist_args + [redirect]
    run_shell_cmd(' '.join(command))

    abyss_merge_contigs(in_fasta, in_adj, out_path, out_fasta, k=k, merged_only=True)
#enddef

def clean_up(stage_files_dict, level=0):
    """Delete the intermediate files registered at or below a clean-up level.

    `stage_files_dict` maps a clean-up level to a list of paths. Levels are
    visited in ascending order; for each level not greater than `level`,
    every listed path that is a regular file is removed. Anything else
    (missing paths, directories) is left alone.
    """

    selected_levels = [lvl for lvl in sorted(stage_files_dict) if lvl <= level]

    for lvl in selected_levels:
        for path in stage_files_dict[lvl]:
            if not os.path.isfile(path):
                continue
            #endif
            os.remove(path)
        #endfor
    #endfor
#enddef

def __main__():
    """Parse the command line and run the assembly pipeline.

    The stages listed in STAGES are run in order, up to the stage selected
    with `--stage`. A stage is skipped when its stamp file, or the stamp
    file of a later stage, already exists in the output directory. After the
    stages were processed, intermediate files are removed according to
    `--cleanup`.

    The program exits with status 1 when the environment check fails, when
    no (existing) read files were given, or when a file required by a stage
    is missing.
    """

    parser = argparse.ArgumentParser(formatter_class=argparse.RawDescriptionHelpFormatter,
        description='Assemble RNAseq with Trans-ABySS.',
        epilog=textwrap.dedent(package_info.SUPPORT_INFO)
    )

    parser.add_argument('--version', action='version', version=TRANSABYSS_VERSION)

    input_group = parser.add_argument_group("Input")
    input_group.add_argument('--se', dest='se_reads', metavar='PATH', type=str, nargs='+', help='single-end read files', action=paths_action(check_exist=True))
    input_group.add_argument('--pe', dest='pe_reads', metavar='PATH', type=str, nargs='+', help='paired-end read files', action=paths_action(check_exist=True))
    input_group.add_argument('--SS', dest='stranded', help='input reads are strand-specific', action='store_true', default=False)

    general_group = parser.add_argument_group("Basic Options")
    general_group.add_argument('--outdir', dest='outdir', help='output directory [%(default)s]', metavar='PATH', type=str, default=defaultwd(), action=path_action(check_exist=False))
    general_group.add_argument('--name', dest='name', help='assembly name [%(default)s] (ie. output assembly: \'%(default)s-final.fa\')', metavar='STR', type=str, default="transabyss")
    general_group.add_argument('--stage', dest='stage', choices=STAGES, help='run up to the specified stage [%(default)s]', type=str, default=STAGE_FINAL)
    general_group.add_argument('--length', dest='length', help='minimum output sequence length [%(default)s]', metavar='INT', type=int, default=100, action=threshold_action(0, inequality='>='))
    general_group.add_argument('--cleanup', dest='cleanup', choices=[0, 1, 2, 3], help='level of clean-up of intermediate files [%(default)s]', type=int, default=1)
    #parser.add_argument('--verbose', dest='verbose', choices=[0, 1, 2], help='verbosity level 0,1,2 [%(default)s]', type=int, default=0)

    abyss_group = parser.add_argument_group("ABySS Parameters")
    abyss_group.add_argument('--threads', dest='threads', help='number of threads (\'j\' in abyss-pe) [%(default)s]', metavar='INT', type=int, default=1, action=threshold_action(1, inequality='>='))
    abyss_group.add_argument('--mpi', dest='mpi', help='number of MPI processes (\'np\' in abyss-pe) [%(default)s]', metavar='INT', type=int, default=0, action=threshold_action(0, inequality='>='))
    abyss_group.add_argument('-k', '--kmer', dest='k', help='k-mer size [%(default)s]', metavar='INT', type=int, default=32, action=threshold_action(1, inequality='>='))
    abyss_group.add_argument('-c', '--cov', dest='c', help='minimum mean k-mer coverage of a unitig [%(default)s]', metavar='INT', type=int, default=2, action=threshold_action(0, inequality='>='))
    abyss_group.add_argument('-e', '--eros', dest='e', help='minimum erosion k-mer coverage [c]', metavar='INT', type=int, action=threshold_action(0, inequality='>='))
    abyss_group.add_argument('-E', '--seros', dest='E', help='minimum erosion k-mer coverage per strand [%(default)s]', metavar='INT', type=int, default=0, action=threshold_action(0, inequality='>='))
    abyss_group.add_argument('-q', '--qends', dest='q', help='minimum base quality on 5\' and 3\' ends of a read [%(default)s]', metavar='INT', type=int, default=3, action=threshold_action(0, inequality='>='))
    abyss_group.add_argument('-Q', '--qall', dest='Q', help='minimum base quality throughout a read', metavar='INT', type=int, action=threshold_action(0, inequality='>='))
    abyss_group.add_argument('-n', '--pairs', dest='n', help='minimum number of pairs for building contigs [%(default)s]', metavar='INT', type=int, default=2, action=threshold_action(1, inequality='>='))
    abyss_group.add_argument('-s', '--seed', dest='s', help='minimum unitig size for building contigs [k]', metavar='INT', type=int, action=threshold_action(0, inequality='>='))

    graph_group = parser.add_argument_group("Advanced Options")
    graph_group.add_argument('--gsim', dest='iterations', help='maximum iterations of graph simplification [%(default)s]', metavar='INT', type=int, default=2, action=threshold_action(0, inequality='>='))
    graph_group.add_argument('--indel', dest='indel', help='indel size tolerance [%(default)s]', metavar='INT', type=int, default=1, action=threshold_action(0, inequality='>='))
    graph_group.add_argument('--island', dest='island', help='minimum length of island unitigs [%(default)s]', metavar='INT', type=int, default=0, action=threshold_action(0, inequality='>='))
    graph_group.add_argument('--useblat', dest='useblat', help='use BLAT alignments to remove redundant sequences.', action='store_true', default=False)
    #graph_group.add_argument('--overlap', dest='overlap', help='minimum overlap to merge contigs [2*k]', metavar='INT', type=int, action=threshold_action(0, inequality='>='))
    graph_group.add_argument('--pid', dest='p', help='minimum percent sequence identity of redundant sequences [%(default)s]', metavar='FLOAT', type=float, default=0.95, action=threshold_action(0.9, inequality='>='))
    graph_group.add_argument('--walk', dest='walk', help='percentage of mean k-mer coverage of seed for path-walking [%(default)s]', metavar='FLOAT', type=float, default=0.05, action=threshold_action(0.0, inequality='>='))
    graph_group.add_argument('--noref', dest='noref', help='do not include reference paths in final assembly', action='store_true', default=False)

    args = parser.parse_args()

    log(TRANSABYSS_NAME + ' ' + TRANSABYSS_VERSION)
    log('CMD: ' + ' '.join(sys.argv))
    log('=-' * 30)

    # Check environment and required paths
    if not check_env(executables=REQUIRED_EXECUTABLES, scripts=REQUIRED_SCRIPTS):
        log('ERROR: Your environment is not sufficient to run Trans-ABySS. Please check the missing executables, scripts, or directories.')
        sys.exit(1)
    #endif

    # Never use more threads than there are CPUs
    cpu_count = multiprocessing.cpu_count()
    log("# CPU(s) available:\t" + str(cpu_count))
    log("# thread(s) requested:\t" + str(args.threads))
    args.threads = min(cpu_count, args.threads)
    log("# thread(s) to use:\t" + str(args.threads))

    # Set default erosion threshold
    if args.e is None:
        args.e = args.c
    #endif

    # Set default seed size
    if args.s is None:
        args.s = args.k
    #endif

    # Extract the absolute path of each input read file
    se_reads = []
    pe_reads = []
    all_reads = []

    if args.se_reads is not None:
        for path in args.se_reads:
            if os.path.exists(path):
                se_reads.append(os.path.abspath(path))
            else:
                log('ERROR: No such single-end reads file \'' + path + '\'')
                sys.exit(1)
            #endif
        #endfor
    #endif

    if args.pe_reads is not None:
        for path in args.pe_reads:
            if os.path.exists(path):
                pe_reads.append(os.path.abspath(path))
            else:
                log('ERROR: No such paired-end reads file \'' + path + '\'')
                sys.exit(1)
            #endif
        #endfor
    #endif

    all_reads.extend(pe_reads)
    all_reads.extend(se_reads)

    if len(se_reads) == 0 and len(pe_reads) == 0:
        log("ERROR: No input reads specified! Use option '--pe' to specify paired-end reads and/or option '--se' to specify single-end reads.")
        sys.exit(1)
    #endif

    # Create the output directory if it does not already exist
    if not os.path.isdir(args.outdir):
        log("Creating output directory: %s" % args.outdir)
        os.makedirs(args.outdir)
    #endif

    # The path prefix of all output files
    prefix = os.path.join(args.outdir, args.name)

    # Intermediate files, grouped by the clean-up level that removes them
    tmpfiles1 = []
    tmpfiles2 = []
    tmpfiles3 = []

    # Start the stop watch
    stopwatch = StopWatch()

    # For every stage below: remaining_stages is negative when the stage lies
    # beyond the requested one (skip it) and zero when it is the last stage
    # to run (report the run time).
    remaining_stages = compare_stages(STAGE_DBG, args.stage)
    if remaining_stages >= 0:
        # CHECKPOINT DBG
        if check_stamp(prefix, STAGE_DBG_STAMP):
            log('CHECKPOINT: De Bruijn graph assembly was done previously. Will not re-run ...')
        else:
            # Generate *-1.fa, *-1.adj, *-bubbles.fa, coverage.hist
            dbg_assembly(all_reads, args.outdir, q=args.q, Q=args.Q, k=args.k, name=args.name, cov=args.c, threads=args.threads, mpi_np=args.mpi, strand_specific=args.stranded, E=args.E, e=args.e)
            make_stamp(prefix, STAGE_DBG_STAMP)
            log('CHECKPOINT: De Bruijn graph assembly completed.')
        #endif

        if remaining_stages == 0:
            log('=-' * 30)
            log('Assembly stopped at stage \'%s\'' % STAGE_DBG)
            log('Total wallclock run time: %d h %d m %d s' % (stopwatch.stop()))
        #endif
    #endif

    adj1 = prefix + '-1.adj'
    fasta1 = prefix + '-1.fa'
    bubbles1 = prefix + '-bubbles.fa'
    coverage_hist = os.path.join(args.outdir, 'coverage.hist')

    tmpfiles3.extend([adj1, fasta1, bubbles1, coverage_hist])

    last_good_adj = adj1
    last_good_fasta = fasta1
    adj3 = prefix + '-3.adj'
    fasta3 = prefix + '-3.fa'
    tmp_file_prefix = prefix + '-unitigs.'

    remaining_stages = compare_stages(STAGE_UNITIGS, args.stage)
    if remaining_stages >= 0:
        # CHECKPOINT: UNITIGS
        if check_stamp(prefix, STAGE_UNITIGS_STAMP):
            log('CHECKPOINT: Unitig assembly was done previously. Will not re-run ...')
        else:
            required_files_exist([fasta1, adj1])
            unitig_assembly(fasta1, adj1, fasta3, adj3, tmp_file_prefix, island_size=args.island, max_iteration=args.iterations, k=args.k, strand_specific=args.stranded, min_percent_identity=args.p, seed_cov_grad=args.walk, indel_size_tolerance=args.indel, threads=args.threads, useblat=args.useblat)
            make_stamp(prefix, STAGE_UNITIGS_STAMP)
            log('CHECKPOINT: Unitig assembly completed.')
        #endif

        if remaining_stages == 0:
            log('=-' * 30)
            log('Assembly stopped at stage \'%s\'' % STAGE_UNITIGS)
            log('Total wallclock run time: %d h %d m %d s' % (stopwatch.stop()))
        #endif
    #endif

    # Collected after the unitig stage so that the per-iteration files exist
    tmpfiles1.extend(glob.glob(tmp_file_prefix + '*'))

    ref_fasta = prefix + '-ref.fa'
    ref_path = prefix + '-ref.path'
    pathj = prefix + '-jn.path'
    fastaj = prefix + '-jn.fa'
    std_assembly = None

    perform_pe_assembly = len(pe_reads) > 0

    # Only inspect the unitig graph when it exists: it is absent when the run
    # stops at the 'dbg' stage. has_edges() is presumed to read the file and
    # would fail on a missing path. When the file is missing although a later
    # stage needs it, required_files_exist() below reports the problem.
    if perform_pe_assembly and os.path.isfile(adj3) and not has_edges(adj3):
        log('WARNING: ADJ (%s) has no edges; will not perform paired-end assembly.' % os.path.basename(adj3))
        perform_pe_assembly = False
    #endif

    if perform_pe_assembly:
        # have paired-end reads

        std_assembly = prefix + '-6.fa'
        adj5 = prefix + '-5.adj'
        dist3 = prefix + '-3.dist'

        remaining_stages = compare_stages(STAGE_CONTIGS, args.stage)
        if remaining_stages >= 0:
            # CHECKPOINT: CONTIGS
            if check_stamp(prefix, STAGE_CONTIGS_STAMP):
                log('CHECKPOINT: Contig assembly was done previously. Will not re-run ...')
            else:
                required_files_exist([fasta3, adj3])
                contig_assembly(pe_reads, args.outdir, k=args.k, n=args.n, name=args.name, threads=args.threads, strand_specific=args.stranded, s=args.s)
                make_stamp(prefix, STAGE_CONTIGS_STAMP)
                log('CHECKPOINT: Contig assembly completed.')
            #endif

            if remaining_stages == 0:
                log('=-' * 30)
                log('Assembly stopped at stage \'%s\'' % STAGE_CONTIGS)
                log('Total wallclock run time: %d h %d m %d s' % (stopwatch.stop()))
            #endif
        #endif

        remaining_stages = compare_stages(STAGE_REFERENCES, args.stage)
        if remaining_stages >= 0:
            # CHECKPOINT: REFERENCES
            if check_stamp(prefix, STAGE_REFERENCES_STAMP):
                log('CHECKPOINT: Reference path assembly was done previously. Will not re-run ...')
            else:
                # reference assembly
                required_files_exist([fasta3, adj5])
                reference_assembly(fasta3, adj5, ref_path, ref_fasta, k=args.k, strand_specific=args.stranded, seed_cov_grad=args.walk, in_dist=dist3)
                make_stamp(prefix, STAGE_REFERENCES_STAMP)
                log('CHECKPOINT: Reference path assembly completed.')
            #endif

            if remaining_stages == 0:
                log('=-' * 30)
                log('Assembly stopped at stage \'%s\'' % STAGE_REFERENCES)
                log('Total wallclock run time: %d h %d m %d s' % (stopwatch.stop()))
            #endif
        #endif

        remaining_stages = compare_stages(STAGE_JUNCTIONS, args.stage)
        if remaining_stages >= 0:
            # CHECKPOINT: JUNCTIONS
            if check_stamp(prefix, STAGE_JUNCTIONS_STAMP):
                log('CHECKPOINT: Junction extension was done previously. Will not re-run ...')
            else:
                # Extend PE junctions
                required_files_exist([fasta3, adj5])
                junction_extension(adj5, fasta3, pathj, fastaj, k=args.k, dist=dist3)
                make_stamp(prefix, STAGE_JUNCTIONS_STAMP)
                log('CHECKPOINT: Junction extension completed.')
            #endif

            if remaining_stages == 0:
                log('=-' * 30)
                log('Assembly stopped at stage \'%s\'' % STAGE_JUNCTIONS)
                log('Total wallclock run time: %d h %d m %d s' % (stopwatch.stop()))
            #endif
        #endif

        tmpfiles2.extend([adj5, dist3])
        for suffix in ['-3.hist', '-4.fa', '-4.adj', '-4.path1', '-4.path2', '-4.path3', '-5.fa', '-5.path']:
            tmpfiles2.append(prefix + suffix)
        #endfor
    else:
        # no paired-end reads

        std_assembly = fasta3

        remaining_stages = compare_stages(STAGE_REFERENCES, args.stage)
        if remaining_stages >= 0:
            # CHECKPOINT: REFERENCES
            if check_stamp(prefix, STAGE_REFERENCES_STAMP):
                log('CHECKPOINT: Reference path assembly was done previously. Will not re-run ...')
            else:
                # reference assembly
                required_files_exist([fasta3, adj3])
                reference_assembly(fasta3, adj3, ref_path, ref_fasta, k=args.k, strand_specific=args.stranded, seed_cov_grad=args.walk)
                make_stamp(prefix, STAGE_REFERENCES_STAMP)
                log('CHECKPOINT: Reference path assembly completed.')
            #endif

            if remaining_stages == 0:
                log('=-' * 30)
                log('Assembly stopped at stage \'%s\'' % STAGE_REFERENCES)
                log('Total wallclock run time: %d h %d m %d s' % (stopwatch.stop()))
            #endif
        #endif

        remaining_stages = compare_stages(STAGE_JUNCTIONS, args.stage)
        if remaining_stages >= 0:
            # CHECKPOINT: JUNCTIONS
            if check_stamp(prefix, STAGE_JUNCTIONS_STAMP):
                log('CHECKPOINT: Junction extension was done previously. Will not re-run ...')
            else:
                # Extend SE junctions
                required_files_exist([fasta3, adj3])
                junction_extension(adj3, fasta3, pathj, fastaj, k=args.k)
                make_stamp(prefix, STAGE_JUNCTIONS_STAMP)
                log('CHECKPOINT: Junction extension completed.')
            #endif

            if remaining_stages == 0:
                log('=-' * 30)
                log('Assembly stopped at stage \'%s\'' % STAGE_JUNCTIONS)
                log('Total wallclock run time: %d h %d m %d s' % (stopwatch.stop()))
            #endif
        #endif
    #endif

    tmpfiles2.extend([pathj, fastaj, std_assembly])

    fastac = prefix + '-concat.fa'
    fastaf = prefix + '-final.fa'
    fastac_selfalign_psl = fastac + '.selfalign.psl'

    remaining_stages = compare_stages(STAGE_FINAL, args.stage)
    if remaining_stages >= 0:
        # CHECKPOINT: FINAL
        if check_stamp(prefix, STAGE_FINAL_STAMP):
            log('CHECKPOINT: Final assembly was done previously. Will not re-run ...')
        else:
            # Concatenate -3.fa/-6.fa, -jn.fa (and -ref.fa) to concat.fa
            required_files = [std_assembly, fastaj]
            path_prefix_map = {'S':std_assembly, 'J':fastaj}

            if not args.noref:
                # include sequences from -ref.fa in the final assembly
                required_files.append(ref_fasta)
                path_prefix_map['R'] = ref_fasta
            #endif

            required_files_exist(required_files)

            if args.length is not None and args.length > args.k:
                # Since the shortest assembled sequence *always* has a length >= k, we only filter if threshold > k.
                fastaf_all = prefix + '-final.all.fa'
                tmpfiles1.append(fastaf_all)
                if args.useblat:
                    log('Using BLAT to remove redundancy ...')

                    blat_merge_fastas(path_prefix_map, fastaf_all, concat_fa=fastac, concat_fa_selfalign_psl=fastac_selfalign_psl, percent_identity=args.p, strand_specific=args.stranded, indel_size_tolerance=args.indel, min_seq_len=args.k, minoverlap=0, threads=args.threads, cleanup=args.cleanup>0, skip_psl_self_awk=get_skip_psl_self_awk_script_path(strand_specific=args.stranded))
                else:
                    log('Using abyss-map to remove redundancy ...')
                    abyssmap_merge_fastas(path_prefix_map, fastaf_all, concat_fa=fastac, strand_specific=args.stranded, cleanup=args.cleanup>0, threads=args.threads, iterative=False)
                #endif

                log('Removing sequences shorter than %d ...' % args.length)
                filter_fasta(fastaf_all, fastaf, min_length=args.length)
            else:
                # Keep all sequences
                if args.useblat:
                    log('Using BLAT to remove redundancy ...')

                    blat_merge_fastas(path_prefix_map, fastaf, concat_fa=fastac, concat_fa_selfalign_psl=fastac_selfalign_psl, percent_identity=args.p, strand_specific=args.stranded, indel_size_tolerance=args.indel, min_seq_len=args.k, minoverlap=0, threads=args.threads, cleanup=args.cleanup>0, skip_psl_self_awk=get_skip_psl_self_awk_script_path(strand_specific=args.stranded))
                else:
                    log('Using abyss-map to remove redundancy ...')
                    abyssmap_merge_fastas(path_prefix_map, fastaf, concat_fa=fastac, strand_specific=args.stranded, cleanup=args.cleanup>0, threads=args.threads, iterative=False)
                #endif
            #endif

            make_stamp(prefix, STAGE_FINAL_STAMP)
            log('CHECKPOINT: Final assembly completed.')
        #endif

        if remaining_stages == 0:
            log('=-' * 30)
            log('Assembly generated with %s %s :)' % (TRANSABYSS_NAME, TRANSABYSS_VERSION))
            log('Final assembly: ' + fastaf)
            log('Total wallclock run time: %d h %d m %d s' % (stopwatch.stop()))
        #endif
    #endif

    tmpfiles1.append(fastac)
    tmpfiles1.append(fastac_selfalign_psl)

    # Clean up intermediate files
    if args.cleanup > 0:
        clean_up({1:tmpfiles1, 2:tmpfiles2, 3:tmpfiles3}, level=args.cleanup)
    #endif

#enddef

# Run the pipeline only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    __main__()
#endif

#EOF
