#!/opt/conda/conda-bld/transabyss_1774380212466/_h_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_pl/bin/python

# written by Ka Ming Nip
# Copyright 2018 Canada's Michael Smith Genome Sciences Centre

import argparse
import multiprocessing
import os
import shutil
import sys
import textwrap
from transabyss import package_info
from transabyss.common_utils import StopWatch
from transabyss.common_utils import check_env
from transabyss.common_utils import log
from transabyss.common_utils import path_action
from transabyss.common_utils import paths_action
from transabyss.common_utils import threshold_action
from transabyss.fasta_utils import abyssmap_merge_fastas
from transabyss.fasta_utils import blat_merge_fastas
from transabyss.fasta_utils import filter_fasta


# Package name and version, as reported by package_info
TRANSABYSS_VERSION = package_info.VERSION
TRANSABYSS_NAME = package_info.NAME

# Package directories, as reported by package_info
PACKAGEDIR = package_info.PACKAGEDIR
BINDIR = package_info.BINDIR
# awk scripts expected in BINDIR; the '_ss' variant is the one selected
# for strand-specific ('--SS') BLAT merging
SKIP_PSL_SELF = os.path.join(BINDIR, 'skip_psl_self.awk')
SKIP_PSL_SELF_SS = os.path.join(BINDIR, 'skip_psl_self_ss.awk')

# External programs and helper scripts passed to check_env() before merging starts
REQUIRED_EXECUTABLES = ['blat', 'abyss-map']
REQUIRED_SCRIPTS = [SKIP_PSL_SELF, SKIP_PSL_SELF_SS]

def default_outfile():
    """Return the default output path.

    The path is 'transabyss-merged.fa' inside the current working
    directory at the time of the call.
    """
    working_dir = os.getcwd()
    merged_name = "transabyss-merged.fa"
    return os.path.join(working_dir, merged_name)
#enddef

def _build_parser():
    """Build the command-line parser for the assembly merging tool."""
    parser = argparse.ArgumentParser(formatter_class=argparse.RawDescriptionHelpFormatter,
        description='Merge Trans-ABySS assemblies.',
        epilog=textwrap.dedent(package_info.SUPPORT_INFO)
    )

    parser.add_argument('--version', action='version', version=TRANSABYSS_VERSION)

    input_group = parser.add_argument_group("Input")
    input_group.add_argument(dest='fastas', metavar='PATH', type=str, nargs='+', help='assembly FASTA file(s)', action=paths_action(check_exist=True))
    input_group.add_argument('--mink', dest='mink', help='smallest k-mer size', metavar='INT', type=int, required=True, action=threshold_action(1, inequality='>='))
    input_group.add_argument('--maxk', dest='maxk', help='largest k-mer size', metavar='INT', type=int, required=True, action=threshold_action(1, inequality='>='))
    input_group.add_argument('--prefixes', dest='prefixes', metavar='STR', type=str, nargs='+', help='prefixes for the contigs from each assembly')
    input_group.add_argument('--SS', dest='stranded', help='assemblies are strand-specific', action='store_true', default=False)

    general_group = parser.add_argument_group("Basic Options")
    general_group.add_argument('--force', dest='force', help='force overwriting', action='store_true', default=False)
    general_group.add_argument('--out', dest='outfile', help='output file [%(default)s]', metavar='PATH', type=str, default=default_outfile(), action=path_action(check_exist=False))
    general_group.add_argument('--threads', dest='threads', help='number of threads [%(default)s]', metavar='INT', type=int, default=1, action=threshold_action(1, inequality='>='))
    general_group.add_argument('--length', dest='length', help='shortest sequence to report [%(default)s]', metavar='INT', type=int, default=0, action=threshold_action(0, inequality='>='))
    general_group.add_argument('--no-cleanup', dest='nocleanup', help='do not remove intermediate files at completion', action='store_true', default=False)

    advanced_group = parser.add_argument_group("Advanced Options")
    advanced_group.add_argument('--abyssmap', dest='abyssmap', help='use abyss-map to merge all FASTA files at once (NOTE: faster than BLAT but less sensitive and more memory intensive)', action='store_true', default=False)
    advanced_group.add_argument('--abyssmap-itr', dest='abyssmap_itr', help='use abyss-map to merge one additional FASTA at a time, utilizing less memory.', action='store_true', default=False)
    advanced_group.add_argument('--indel', dest='indel', help='indel size tolerance [%(default)s]', metavar='INT', type=int, default=1, action=threshold_action(0, inequality='>='))
    advanced_group.add_argument('--pid', dest='p', help='minimum percent sequence identity of redundant sequences [%(default)s]', metavar='FLOAT', type=float, default=0.95, action=threshold_action(0.9, inequality='>='))

    return parser
#enddef

def _resolve_prefixes(prefixes, fastas):
    """Return one contig prefix per input FASTA, exiting on invalid input.

    When no prefixes were supplied, 'A0.', 'A1.', ... are generated.
    Supplied prefixes must match the FASTA files one-to-one and must be
    unique, because they are used as dictionary keys for the merge: a
    repeated prefix would silently drop an input assembly.
    """
    if prefixes is None:
        return ['A' + str(i) + '.' for i in range(len(fastas))]
    #endif

    # Explicit check instead of `assert`, which is stripped under `python -O`
    if len(prefixes) != len(fastas):
        log('ERROR: The number of prefixes (%d) does not match the number of assembly FASTA files (%d).' % (len(prefixes), len(fastas)))
        sys.exit(1)
    #endif

    seen = set()
    duplicates = []
    for prefix in prefixes:
        if prefix in seen and prefix not in duplicates:
            duplicates.append(prefix)
        #endif
        seen.add(prefix)
    #endfor

    if duplicates:
        log('ERROR: Assembly prefixes must be unique; duplicated prefix(es): ' + ', '.join(duplicates))
        sys.exit(1)
    #endif

    return list(prefixes)
#enddef

def _run_merge(args, path_prefix_map, out_path):
    """Merge the assemblies in `path_prefix_map` into `out_path`.

    The method is chosen from the parsed options: abyss-map (all at once),
    abyss-map (iterative), or BLAT by default. '--abyssmap' takes
    precedence when both abyss-map flags are given.
    """
    cleanup = not args.nocleanup

    if args.abyssmap or args.abyssmap_itr:
        abyssmap_merge_fastas(path_prefix_map, out_path, strand_specific=args.stranded, cleanup=cleanup, threads=args.threads, iterative=not args.abyssmap)
        return
    #endif

    # Pick the awk script matching the strandedness; pass None if it is missing
    candidate_awk = SKIP_PSL_SELF_SS if args.stranded else SKIP_PSL_SELF
    skip_psl_self_awk = candidate_awk if os.path.isfile(candidate_awk) else None

    blat_merge_fastas(path_prefix_map, out_path, percent_identity=args.p, strand_specific=args.stranded, cleanup=cleanup, minoverlap=0, threads=args.threads, indel_size_tolerance=args.indel, min_seq_len=args.mink, skip_psl_self_awk=skip_psl_self_awk)
#enddef

def __main__():
    """Command-line entry point: merge multiple Trans-ABySS assemblies.

    Parses the command line, verifies the environment and inputs, merges
    the assemblies into the output file, and optionally removes sequences
    shorter than '--length'. Exits with status 1 on any validation error.
    """
    args = _build_parser().parse_args()

    log(TRANSABYSS_NAME + ' ' + TRANSABYSS_VERSION)
    log('CMD: ' + ' '.join(sys.argv))
    log('=-' * 30)

    # Check environment and required paths
    if not check_env(executables=REQUIRED_EXECUTABLES, scripts=REQUIRED_SCRIPTS):
        log('ERROR: Your environment is not sufficient to run Trans-ABySS. Please check the missing executables, scripts, or directories.')
        sys.exit(1)
    #endif

    if not args.force and os.path.exists(args.outfile):
        log('ERROR: Output file already exists: %s\nPlease either use the \'--force\' option to overwrite the existing file or use the \'--out\' option to specify a different output file path.' % args.outfile)
        sys.exit(1)
    #endif

    # Never use more threads than there are CPUs
    cpu_count = multiprocessing.cpu_count()
    log("# CPU(s) available:\t" + str(cpu_count))
    log("# thread(s) requested:\t" + str(args.threads))
    args.threads = min(cpu_count, args.threads)
    log("# thread(s) to use:\t" + str(args.threads))

    # Start the stop watch
    stopwatch = StopWatch()

    prefixes = _resolve_prefixes(args.prefixes, args.fastas)

    log('Assembly Prefixes:')
    path_prefix_map = {}
    for prefix, path in zip(prefixes, args.fastas):
        if not os.path.exists(path):
            log('ERROR: No such assembly file \'' + path + '\'')
            sys.exit(1)
        #endif
        path_prefix_map[prefix] = path
        log(prefix + '\t' + path)
    #endfor

    if args.length is not None and args.length > 0:
        # Merge into an intermediate file, then length-filter into the final output
        merge_fa = args.outfile + '.prefilter.fa'
        _run_merge(args, path_prefix_map, merge_fa)

        log('Removing sequences shorter than %d bp ...' % args.length)
        filtered_fa = args.outfile + '.tmp.filtered.fa'
        filter_fasta(merge_fa, filtered_fa, min_length=args.length)
        shutil.move(filtered_fa, args.outfile)

        if not args.nocleanup:
            # filtered_fa has normally been moved already; isfile() guards that case
            for tmp_path in (merge_fa, filtered_fa):
                if os.path.isfile(tmp_path):
                    os.remove(tmp_path)
                #endif
            #endfor
        #endif
    else:
        _run_merge(args, path_prefix_map, args.outfile)
    #endif

    log('=-' * 30)
    log('Merged assembly: ' + args.outfile)
    log('Total wallclock run time: %d h %d m %d s' % (stopwatch.stop()))
#enddef

# Run the merge only when this file is executed as a script, not when it is imported
if __name__ == '__main__':
    __main__()
#endif

#EOF
