#!/usr/bin/env python
from __future__ import print_function

import sys
import os
import re
import signal
import shutil
import argparse
from random import randint, uniform
from glob import iglob, glob
from paralleltask import Task

SCRIPT_PATH = os.path.dirname(os.path.realpath(__file__))
sys.path.append(SCRIPT_PATH + '/lib/')
from kit import *
from config_parser import ConfigParser

# Module-level logger handle; main() rebinds it to the plog(...) logger before any step runs.
log = ''

class HelpFormatter(argparse.RawDescriptionHelpFormatter,argparse.ArgumentDefaultsHelpFormatter):
	"""argparse formatter combining the raw-description and show-defaults formatters."""

shell_index = 1
def set_task(cfg, task, args = None):
	"""
	Build the shell script of one pipeline step and write it with write2file.

	The script is named '<NN>.<task>.sh' inside the step's directory, where NN is
	the module-level counter `shell_index`. After naming, the counter is either
	advanced by one or reset to 1, depending on the step.

	cfg:  configuration mapping (the '*dir' entries give the target directories).
	task: step name; an unknown name passes an empty command and path to write2file.
	args: extra input for the steps that take one (seed_cns, cns_align,
	      ctg_align, ctg_cns).
	Returns the script path.
	"""
	global shell_index
	# step -> (cfg key of target dir, command builder,
	#          reset counter before naming, reset counter after naming)
	steps = {
		'db_stat': ('raw_aligndir', db_stat, False, False),
		'db_split': ('raw_aligndir', db_split, False, False),
		'db_local': ('raw_aligndir', db_local, False, False),
		'raw_align': ('raw_aligndir', raw_align, False, False),
		'sort_align': ('raw_aligndir', sort_align, False, True),
		'seed_cns': ('cns_aligndir', seed_cns, False, False),
		'split_seed': ('cns_aligndir', split_seed, True, False),
		'cns_align': ('cns_aligndir', cns_align, False, True),
		'ctg_graph': ('ctg_graphdir', ctg_graph, False, False),
		'ctg_align': ('ctg_graphdir', ctg_align, False, False),
		'ctg_cns': ('ctg_graphdir', ctg_cns, False, True),
	}
	cmd = ''
	path = ''
	if task in steps:
		dir_key, builder, reset_before, reset_after = steps[task]
		if reset_before:
			shell_index = 1
		path = cfg[dir_key] + '/%02d.%s.sh' % (shell_index, task)
		if task == 'db_local':
			# db_local additionally writes a companion clean-up script next to the main one
			cmd = builder(cfg, path[:-3] + '_clean.sh')
		elif task in ('seed_cns', 'cns_align', 'ctg_align', 'ctg_cns'):
			cmd = builder(cfg, args)
		else:
			cmd = builder(cfg)
		shell_index = 1 if reset_after else shell_index + 1
	write2file(cmd, path)
	return path

def seed_cns(cfg, args):
	"""
	Build the nextcorrect.py command lines, one per sorted overlap file in `args`.

	Every line reads the index list '01.seed_cns.input.idxs' from cfg['cns_aligndir']
	and writes 'cns.fasta'. Returns the concatenated command text.
	"""
	idx_list = cfg['cns_aligndir'] + '/' + '/01.seed_cns.input.idxs'
	lines = []
	for ovl in args:
		lines.append(''.join([
			pypath(), ' ', SCRIPT_PATH, '/lib/nextcorrect.py ', ' -f ', idx_list,
			' -i ', ovl, ' ', cfg['correction_options'], ' -o cns.fasta;\n',
		]))
	return ''.join(lines)

def split_seed(cfg):
	"""
	Build the split_cns.py command that cuts the input reads into chunks.

	The chunk count is the smallest multiple i of cfg['seed_cutfiles'] with
	i * (i + 1) >= 2 * cfg['parallel_jobs'].
	"""
	wanted = cfg['parallel_jobs']
	step = int(cfg['seed_cutfiles'])
	total = step
	while total * (total + 1) < 2 * int(wanted):
		total += step
	tail = ' -l %s -c %d' % (cfg['seed_cutoff'], total)
	return pypath() + ' ' + SCRIPT_PATH + '/lib/split_cns.py ' +  ' -f ' + cfg['input_fofn'] + tail

def pre_cns_seqs(cfg, infiles, out_file):
	"""Write each entry of `infiles` on its own line to <cfg['ctg_graphdir']>/<out_file>."""
	list_path = cfg['ctg_graphdir'] + '/' + out_file
	with open(list_path, 'w') as handle:
		for entry in infiles:
			print(entry, file=handle)

def pre_cns_ovl(cfg, infiles, out_file):
	"""Write each entry of `infiles` on its own line to <cfg['ctg_graphdir']>/<out_file>."""
	list_path = cfg['ctg_graphdir'] + '/' + out_file
	with open(list_path, 'w') as handle:
		for entry in infiles:
			print(entry, file=handle)

def pre_raw_idxs(cfg, out_file):
	"""
	List the hidden '.input.part.*.idx' and '.input.seed.*.idx' files found in
	cfg['raw_aligndir'] (parts first, then seeds) in <cfg['cns_aligndir']>/<out_file>.

	When cfg has a 'usetempdir' entry, each listed path is rewritten to point at
	the file of the same name inside that directory.
	"""
	with open (cfg['cns_aligndir'] + '/'  + out_file, 'w') as OUT:
		raw_dir = cfg['raw_aligndir'] + '/'
		for pattern in ('.input.part.*.idx', '.input.seed.*.idx'):
			for idx_path in iglob(raw_dir + pattern):
				if 'usetempdir' in cfg:
					idx_path = cfg['usetempdir'] + '/' + idx_path.split('/')[-1]
				print(idx_path, file=OUT)

def get_sort_align_output(subtasks):
	"""
	Return the single 'input.seed.*.sorted.ovl' file found next to each subtask script.

	Logs an error and exits with status 1 when a subtask directory does not hold
	exactly one such file.
	"""
	pattern = 'input.seed.*.sorted.ovl'
	sorted_ovls = []
	for subtask in subtasks:
		hits = list(iglob(os.path.dirname(subtask.path) + '/' + pattern))
		if len(hits) != 1:
			log.error('Failed to find output file pattern for task: ' + subtask.path)
			sys.exit(1)
		sorted_ovls.append(hits[0])
	return sorted_ovls

def get_cns_ovl_output(subtasks):
	"""Return the 'cns.filt.dovt.ovl' path in each subtask's directory (existence is not checked)."""
	return [os.path.dirname(subtask.path) + '/cns.filt.dovt.ovl' for subtask in subtasks]

def get_ctg_graph_output(subtasks):
	"""Return the 'nd.asm.p.fasta' path located in the directory of the first subtask."""
	first_dir = os.path.dirname(subtasks[0].path)
	return first_dir + '/nd.asm.p.fasta'

def get_cns_output(subtasks):
	"""Return the 'cns.fasta' path in each subtask's directory (existence is not checked)."""
	return [os.path.dirname(subtask.path) + '/cns.fasta' for subtask in subtasks]

def split_cns_output(subtasks):
	"""Return every file matching 'cns*.fasta' in the directory of the first subtask."""
	return glob(os.path.dirname(subtasks[0].path) + '/cns*.fasta')

def get_raw_align_output(subtasks):
	"""
	Group the overlap files produced by the raw_align subtasks by seed index.

	Every file matching 'input.seed.*.ovl' next to a subtask script must be named
	'input.seed.<idx>.2bit.<n>.ovl'; it is filed under the string <idx>.
	A subtask directory without any such file contributes nothing.

	Returns a dict mapping seed index (str) to the list of overlap file paths.
	Logs an error and exits with status 1 on a file name that does not fit.
	"""
	seed_result = {}
	seed_ovl_re = re.compile(r"input\.seed\.(\d+)\.2bit\.\d+\.ovl")
	seed_ovl_iglob = 'input.seed.*.ovl'
	for subtask in subtasks:
		for filename in iglob(os.path.dirname(subtask.path) + '/' + seed_ovl_iglob):
			# match the file name only, so a directory component can never supply the index
			g = seed_ovl_re.search(os.path.basename(filename))
			if not g:
				log.error('Failed to find output file pattern for task: ' + subtask.path)
				sys.exit(1)
			seed_result.setdefault(g.group(1), []).append(filename)
	return seed_result

def get_ctg_align_output(subtasks):
	"""
	Return the single '*.sort.bam' file found next to each subtask script.

	Logs an error and exits with status 1 when a subtask directory does not hold
	exactly one such file.
	"""
	bam_pattern = '*.sort.bam'
	bams = []
	for subtask in subtasks:
		hits = glob(os.path.dirname(subtask.path) + '/' + bam_pattern)
		if len(hits) != 1:
			log.error('Failed to find output file pattern for task: ' + subtask.path)
			sys.exit(1)
		bams.append(hits[0])
	return bams

def clean_ctg_align_output(align_results):
	"""Delete each alignment file (logging the removal) together with its '.bai'/'.csi' index files."""
	for bam in align_results:
		if os.path.exists(bam):
			os.remove(bam)
			log.info('remove temporary result: ' + bam)
		for suffix in ('bai', 'csi'):
			index_file = bam + '.' + suffix
			if os.path.exists(index_file):
				os.remove(index_file)

def get_seed_idx(subtask):
	"""
	Return the seed index (digits, as a string) of the first '.input.seed.<idx>.idx'
	reference found in the shell file `subtask`.

	Logs an error and exits with status 1 when the file holds no such reference.
	"""
	with open(subtask) as IN:
		found = re.search(r"\.input\.seed\.(\d+)\.idx", IN.read())
		if not found:
			log.error('Failed to find shell file pattern for task: ' + subtask)
			sys.exit(1)
		return found.group(1)

def clean_raw_align_output(finished_tasks):
	"""
	For every task whose '<script>.done' marker exists, delete the files listed in
	the 'input.fofn' of the task directory, logging each removal.
	"""
	for task in finished_tasks:
		if not os.path.exists(task.path + '.done'):
			continue
		with open (os.path.dirname(task.path) + '/input.fofn') as IN:
			for entry in IN:
				tmp_file = entry.strip()
				os.remove(tmp_file)
				log.info('remove temporary result: ' + tmp_file)

def pre_sort_align_input(sort_align_input, subtasks):
	"""
	Write the 'input.fofn' of every sort_align subtask.

	The seed index is read from the subtask script via get_seed_idx; the overlap
	files stored under that index in `sort_align_input` are listed one per line.
	"""
	for subtask in subtasks:
		fofn_path = os.path.dirname(subtask.path) + '/input.fofn'
		seed_idx = get_seed_idx(subtask.path)
		with open(fofn_path, 'w') as OUT:
			for ovl_file in sort_align_input[seed_idx]:
				print(ovl_file, file=OUT)

def pre_ctg_cns_input(outfile, ctg_align_output):
	"""Pass the newline-joined entries of `ctg_align_output` to write2file for `outfile`."""
	bam_list = "\n".join(ctg_align_output)
	write2file(bam_list, outfile)

def blc_genome(n, ref):
	"""
	Assign the sequences of `ref` to consecutive blocks and write '<ref>.blc'.

	Only the header lines ('>name k:v:type k:v:length k:v:node') are read; for
	fields 2-4 the text after the last ':' is kept. Sequences are appended to the
	current block until its summed length reaches int(genome_size / n + 1), then
	the next block starts. The '.blc' text holds one '<name>\\t<block>' line per
	sequence and is written with write2file.

	Returns a dict mapping sequence name to [type, length (int), node].
	"""
	data = {}
	genome_size = 0
	with open(ref, 'r') as IN:
		for header in IN:
			if not header.startswith('>'):
				continue
			fields = header.strip().split()
			seq_type = fields[1].split(':')[-1]
			seq_len = int(fields[2].split(':')[-1])
			seq_node = fields[3].split(':')[-1]
			data[fields[0][1:]] = [seq_type, seq_len, seq_node]
			genome_size += seq_len

	blocksize = int(genome_size / float(n) + 1)
	rows = []
	block_id = 0
	filled = 0
	for name in data:
		filled += data[name][1]
		rows.append(name + '\t' + str(block_id) + '\n')
		if filled >= blocksize:
			# current block is full; following sequences go to the next one
			filled = 0
			block_id += 1
	write2file(''.join(rows), ref + '.blc')
	return data

def gather_ctg_cns_output(cfg, subtasks, seq_info):
	"""
	Merge the per-subtask 'nd.asm.f.part*.fasta' files into the final assembly.

	cfg:      configuration mapping; cfg['ctg_graphdir'] receives the assembly.
	subtasks: ctg_cns subtasks; each directory must hold exactly one part file,
	          otherwise an error is logged and the program exits with status 1.
	seq_info: mapping name -> [type, length, node] as returned by blc_genome.

	The assembly is written to 'nd.asm.fasta', or to the first unused
	'nd.asm.v<N>.fasta' when that file already exists. Headers are rewritten to
	'>ctg<id> type:s:... length:i:... node:...'; the id advances by 10 per record.

	Returns (assembly path, result of cal_n50_info for the record lengths).
	"""
	ctg_result = []
	ctg_result_iglob = 'nd.asm.f.part*.fasta'
	for subtask in subtasks:
		filenames = list(iglob(os.path.dirname(subtask.path) + '/' + ctg_result_iglob))
		if len(filenames) == 1:
			ctg_result.append(filenames[0])
		else:
			log.error('Failed to find output file pattern for task: ' + subtask.path)
			sys.exit(1)

	# never overwrite an earlier assembly: fall back to the first free versioned name
	i = 0
	asm = cfg['ctg_graphdir'] + '/nd.asm.fasta'
	while os.path.exists(asm):
		i += 1
		asm = cfg['ctg_graphdir'] + '/nd.asm.v%d.fasta' % (i)
	stat = []
	i = 0
	# 'with' guarantees the assembly file is closed even if a part file is malformed
	with open(asm, 'w') as OUT:
		for ctg_file in ctg_result:
			with open(ctg_file) as IN:
				for line in IN:
					if line.startswith('>'):
						# part-file headers carry exactly two fields: '>name' and the length
						seq_name, seq_len = line.strip().split()
						seq_name = seq_name[1:]
						seq_len = int(seq_len)
						if seq_name in seq_info:
							print(">ctg%06d type:s:%s length:i:%d node:i:%s" % (i, seq_info[seq_name][0], \
								seq_len, seq_info[seq_name][2]), file=OUT)
						else:
							# unknown name: look up the part before the first '_' instead
							seq_name = seq_name.split('_')[0]
							seq_type = 'linear' if seq_info[seq_name][2] != 'unknown' else 'unknown'
							print(">ctg%06d type:s:%s length:i:%d node:s:unknown" % (i, seq_type, \
								seq_len), file=OUT)
						stat.append(int(seq_len))
						i += 10
					else:
						print(line.strip(), file=OUT)
	out = cal_n50_info(stat, asm + '.stat')

	return asm, out

def asm_finished(cfg):
	"""
	Tell whether an existing '01.ctg_graph.sh' in cfg['ctg_graphdir'] was written
	with nextgraph options other than the current cfg['nextgraph_options'].

	Returns False when the script is missing or its options equal the current
	ones, True otherwise (including when no nextgraph call can be parsed).
	"""
	asm_shell = cfg['ctg_graphdir'] + '/01.ctg_graph.sh'
	if not os.path.exists(asm_shell):
		return False
	with open(asm_shell) as IN:
		found = re.search(r'nextgraph (.*) -f', IN.read())
	return not (found and found.group(1) == cfg['nextgraph_options'])

def ctg_graph(cfg):
	"""
	Build the nextgraph command lines for the graph step.

	The first line runs nextgraph with cfg['nextgraph_options'] and writes
	'nd.asm.p.<cfg['_nextgraph_out_format']>'. It is followed by
	cfg['_random_round_with_less_accuracy'] extra '-a 0' runs, exactly one per
	round: rounds 0-8 use fixed options, rounds 9-22 use '-q <round - 8>',
	round 23 uses '-q 100', rounds 24-999 use randomly drawn options, and
	rounds from 1000 on add nothing.

	Returns the concatenated command text.
	"""
	input_seq = cfg['ctg_graphdir'] + '/01.ctg_graph.input.seqs'
	input_ovl = cfg['ctg_graphdir'] + '/01.ctg_graph.input.ovls'
	nextgraph = SCRIPT_PATH + '/bin/nextgraph '
	inputs = ' -f %s %s;\n' % (input_seq, input_ovl)
	cmd = nextgraph + cfg['nextgraph_options'] + ' -f %s %s -o nd.asm.p.%s;\n' \
		% (input_seq, input_ovl, cfg['_nextgraph_out_format'])

	# options of rounds 0..8, indexed by round number
	fixed_opts = ('-A', '-n 45', '-I 0.5', '-q 5', '-N 1', '-u 1', '-k', '-I 0.1', '-G')
	for i in range(int(cfg['_random_round_with_less_accuracy'])):
		# a single if/elif chain: each round contributes at most one command
		# (round 0 used to fall through and also emit a bogus '-q -8' run)
		if i < len(fixed_opts):
			opts = fixed_opts[i]
		elif i < 23:
			opts = '-q %d' % (i - 8)
		elif i == 23:
			opts = '-q %d' % (100)
		elif i < 1000:
			opts = '-n %d -q %d -I %.2f -S %.2f -N %d -r %.2f -m %.2f -C %d -z %d -B %d' % \
				(randint(20, 2000), randint(0, 15), uniform(0.1, 0.7), uniform(0.1, 0.7), randint(1, 2), \
				uniform(0.1, 0.7), uniform(1.5, 10), randint(10, 50), randint(5, 16), randint(35, 700)) + \
				(' -A' if randint(0, 1) else '')
		else:
			continue
		cmd += nextgraph + '-a 0 ' + opts + inputs
	return cmd

def sort_align(cfg):
	"""
	Build one ovl_sort command line per '.input.seed.*.idx' file in cfg['raw_aligndir'].

	Each line sorts the overlaps listed in 'input.fofn' into
	'input.seed.<idx>.sorted.ovl'; '-d <cfg['usetempdir']>' is added when that
	entry is configured. Returns the concatenated command text.
	"""
	raw_dir = cfg['raw_aligndir'] + '/'
	lines = []
	for seed_idx in iglob(raw_dir + '.input.seed.*.idx'):
		# basename '.input.seed.N.idx' -> 'input.seed.N.' (drop leading dot and 'idx')
		stem = os.path.basename(seed_idx)[1:-3]
		parts = [SCRIPT_PATH, '/bin/ovl_sort ', cfg['sort_options'], ' -i ', seed_idx,
			' -o ', stem, 'sorted.ovl ']
		if 'usetempdir' in cfg:
			parts.extend(['-d ', cfg['usetempdir'], ' '])
		parts.append('input.fofn;\n')
		lines.append(''.join(parts))
	return ''.join(lines)

def cns_align(cfg, cns_files):
	"""
	Build one minimap2-nd command line for every unordered pair of corrected-read
	files, including each file against itself.

	Pairs of two different files get '--dual=yes'; every line writes
	'cns.filt.dovt.ovl'. Returns the concatenated command text.
	"""
	count = len(cns_files)
	lines = []
	for i, target in enumerate(cns_files):
		for j in range(i, count):
			dual = '' if i == j else '--dual=yes '
			lines.append(SCRIPT_PATH + '/bin/minimap2-nd -I 6G --step 2 ' + dual + \
				cfg['minimap2_options_cns'] + ' ' + target + ' ' + cns_files[j] + \
				' -o ' + 'cns.filt.dovt.ovl;\n')
	return ''.join(lines)

def ctg_align(cfg, ref):
	"""
	Build the command lines that map reads to the contigs in `ref` and sort the
	alignments ('minimap2-nd --step 3 ... | bam_sort'), one line per read file.

	All seed files are always mapped. For input that is not 'corrected', the
	part files are mapped too (with a '--minlen' cutoff of at least 2000) when the
	summed seed length is below calgs(ref) * 34. File paths are redirected into
	cfg['usetempdir'] when that entry is configured.

	Returns the concatenated command text.
	"""

	def summed_length(idx_files):
		# add up the third column of every non-blank line of the index files
		total = 0
		for idx_file in idx_files:
			with open(idx_file) as IN:
				total += sum(int(line.split()[2]) for line in IN if line != "\n")
		return total

	def seed_inputs(idx = False):
		# seed read files, or their index files when idx is true
		if cfg['input_type'] != 'corrected':
			raw_dir = cfg['raw_aligndir'] + '/'
			return glob(raw_dir + ('.input.seed*idx' if idx else 'input.seed*2bit'))
		with open(cfg['ctg_graphdir'] + '/01.ctg_graph.input.seqs') as IN:
			listed = IN.read().strip()
		if idx:
			return (f + '.idx' for f in listed.split("\n"))
		return listed.split("\n")

	def relocated(path):
		# same file name, but inside the temporary directory when one is configured
		if 'usetempdir' in cfg:
			return cfg['usetempdir'] + '/' + os.path.basename(path)
		return path

	def map_line(reads, extra = ''):
		return SCRIPT_PATH + '/bin/minimap2-nd --step 3 ' + cfg['minimap2_options_map'] + \
			' -a -t %d %s%s %s|' % (cfg['_map_threads'], extra, ref, reads) + \
			SCRIPT_PATH + '/bin/bam_sort -i -@ %d -o %s.sort.bam\n' % \
			(cfg['_map_threads'], os.path.basename(reads))

	cmd = ''
	for seed_file in seed_inputs():
		cmd += map_line(relocated(seed_file))

	if cfg['input_type'] == 'corrected':
		return cmd

	mindepth = 35
	gs = calgs(ref)
	total_seed_len = summed_length(seed_inputs(idx=True))
	if total_seed_len >= gs * (mindepth - 1):
		return cmd

	# seeds alone give too little depth: add part reads above a length cutoff
	raw_dir = cfg['raw_aligndir'] + '/'
	part_idx_files = glob(raw_dir + '.input.part*idx')
	minlen = cal_minlen_from_idx(part_idx_files, len(part_idx_files), gs * mindepth - total_seed_len)
	minlen = max(minlen, 2000)
	for part_file in glob(raw_dir + 'input.part*2bit'):
		cmd += map_line(relocated(part_file), '--minlen %d ' % minlen)
	return cmd

def ctg_cns(cfg, ref):
	"""
	Build cfg['pa_correction'] ctg_cns.py command lines for the contigs in `ref`.

	Line i is given the block file '<ref>.blc', the index i, the BAM list
	'03.ctg_cns.input.bams' from cfg['ctg_graphdir'], and writes
	'nd.asm.f.part<iii>.fasta'. Returns the concatenated command text.
	"""
	blc = ref + '.blc'
	lines = []
	for part in range(int(cfg['pa_correction'])):
		lines.append(''.join([
			pypath(), ' ', SCRIPT_PATH, '/lib/ctg_cns.py ', cfg['ctg_cns_options'],
			' -g ', ref, ' -b ', blc, ' -i ', str(part), ' -r ', cfg['read_type'], ' -l ',
			cfg['ctg_graphdir'], '/03.ctg_cns.input.bams -o nd.asm.f.part%03d.fasta\n' % part,
		]))
	return ''.join(lines)

def raw_align(cfg):
	"""
	Build the minimap2-nd command lines of the raw overlap step.

	Every seed file is aligned against every part file, against itself, and against
	each seed file that follows it in the glob result. Outputs are named
	'<seed basename>.<k>.ovl' with a running counter k; a seed-vs-other-seed line
	also symlinks its output under the other seed's name. File paths are
	redirected into cfg['usetempdir'] when that entry is configured.

	Returns the concatenated command text.
	"""
	raw_dir = cfg['raw_aligndir'] + '/'
	batch_seed_size = 6 if cfg['read_type'] == 'hifi' else 3
	seed_files = glob(raw_dir + 'input.seed*2bit') #TODO: set global var 
	part_files = glob(raw_dir + 'input.part*2bit')

	def relocated(path):
		# same file name, but inside the temporary directory when one is configured
		if 'usetempdir' in cfg:
			return cfg['usetempdir'] + '/' + os.path.basename(path)
		return path

	minimap2 = SCRIPT_PATH + '/bin/minimap2-nd --step 1 '
	lines = []
	k = 0
	for i, seed_file in enumerate(seed_files):
		seed_file = relocated(seed_file)
		seed_name = os.path.basename(seed_file)

		# seed vs. every part file
		for part_file in part_files:
			part_file = relocated(part_file)
			lines.append(minimap2 + '--dual=yes '  + cfg['minimap2_options_raw'] + ' ' + \
				seed_file  + ' ' + part_file + ' -o ' + seed_name + '.' +  str(k) + '.ovl;\n')
			k += 1

		# seed vs. itself and the seeds after it
		for other_seed in seed_files[i:]:
			other_seed = relocated(other_seed)
			out_name = seed_name + '.' +  str(k) + '.ovl'
			if other_seed != seed_file:
				line = minimap2 + '-I ' + str(batch_seed_size) + 'G --dual=yes ' + \
					cfg['minimap2_options_raw'] + ' ' + seed_file  + ' ' + other_seed + ' -o ' + \
					out_name + ';'
				line += 'ln -sf ' + out_name + ' ' + os.path.basename(other_seed) + \
					'.' +  str(k) + '.ovl;'
			else:
				line = minimap2 + '-I ' + str(batch_seed_size) + 'G ' + \
					cfg['minimap2_options_raw'] + ' ' + seed_file  + ' ' + other_seed + ' -o ' + \
					out_name + ';'
			lines.append(line + '\n')
			k += 1
	return ''.join(lines)

def get_availablenodes(cfg):
	"""
	Return the list of distinct compute node names, in first-seen order.

	For job type 'sge' (case-insensitive) the names are parsed from the output of
	'qhost -q'; otherwise the first column of every line of the file
	cfg['nodelist'] is used, skipping lines that start with '#' and empty lines.
	"""
	ava_nodes = []
	if cfg['job_type'].lower() != 'sge':
		with open(cfg['nodelist']) as IN:
			for line in IN:
				if line.startswith('#') or line == '\n':
					continue
				host = line.strip().split()[0]
				if host not in ava_nodes:
					ava_nodes.append(host)
		return ava_nodes

	node = ''
	for line in os.popen("qhost -q").read().split('\n'):
		fields = line.strip().split()
		field_count = len(fields)
		if field_count == 11:
			# 11-column line: remember its first column unless columns 7-9 contain '-'
			node = fields[0] if '-' not in fields[6:9] else ''
		elif node and field_count > 2 and (not cfg['_sge_queue'] or fields[0] in cfg['_sge_queue']) and node not in ava_nodes:
			# following line with more than two columns whose first column passes the queue filter
			ava_nodes.append(node)
	return ava_nodes

def db_local(cfg, clean_path):
	"""
	Build the ssh command lines that copy the seed/part files and their hidden
	index files into cfg['usetempdir'] on every available node.

	On a node where that directory already exists the command only reports it
	and exits with status 1. A companion script that removes the directory on
	every node is written to `clean_path`, and a warning asks the user to run it.

	Returns the copy command text (one line per node).
	"""
	raw_dir = cfg['raw_aligndir'] + '/'
	seed_files = glob(raw_dir + 'input.seed*2bit')
	part_files = glob(raw_dir + 'input.part*2bit')

	ava_nodes = get_availablenodes(cfg)
	log.info('find ' + str(len(ava_nodes)) + ' available nodes')

	copy_lines = []
	clean_lines = []
	for node in ava_nodes:
		tmpdir = cfg['usetempdir']
		ssh_test = 'ssh -o ConnectTimeout=15 ' + node + ' \"if [ -d ' + tmpdir + ' ];'

		clean_lines.append(ssh_test + 'then rm -rf ' + tmpdir  + ' && echo remove ' + tmpdir + \
			' in node ' + node + ';fi;\"\n')

		pieces = [ssh_test, 'then echo ' + tmpdir + ' existed, exit ....; exit 1;',
			'else mkdir -p ' + tmpdir + ';']
		for data_file in seed_files + part_files:
			# index file lives next to the data file: 'x.2bit' -> '.x.idx'
			idx = os.path.dirname(data_file) + '/.' + os.path.basename(data_file)[:-4] + 'idx'
			pieces.append('cp -n ' + data_file + ' ' + tmpdir  + '/;')
			pieces.append('cp ' + idx + ' ' + tmpdir  + '/;')
		pieces.append('fi;\"\n')
		copy_lines.append(''.join(pieces))

	write2file(''.join(clean_lines), clean_path)
	log.warning('please run shell file: [' +  clean_path  + '] to clean ' + cfg['usetempdir'] + ' in each node.')
	return ''.join(copy_lines)

def reset_cfg(cfg):
	"""
	Update `cfg` from the read statistics in <cfg['raw_aligndir']>/input.reads.stat.

	Three numbers are parsed from the file: the second number after 'Clean', the
	value after 'real seed depth:', and a number followed by 'bp'. They are handed
	to ConfigParser.update on a parser wrapping `cfg`. The updated options and the
	'[Read length stat]' part of the file are logged.

	Exits with status 1 when the file cannot be parsed or when cfg['seed_depth']
	is below 10 afterwards.
	"""
	from config_parser import ConfigParser
	stat_file = cfg['raw_aligndir'] + '/input.reads.stat'
	with open(stat_file) as IN:
		stat_info = IN.read()

	found = re.search(r'Clean\s+\d+\s+(\d+).*real seed depth:\s*(\d+\.\d+).*\s+(\d+)\s*bp', stat_info, re.DOTALL)
	if not found or len(found.groups()) != 3:
		log.error('regular match failed for file: %s' % stat_file)
		sys.exit(1)

	updater = ConfigParser()
	updater.cfg = cfg
	updater.update(int(found.group(1)), int(found.group(3)), float(found.group(2)))
	log.info('updated options:\n' + str(cfg))

	summary = stat_info[stat_info.index('[Read length stat]'):].strip('\n')
	log.info('summary of input data:\nfile:\033[35m %s \033[0m\n' % stat_file + summary)

	if float(cfg['seed_depth']) < 10:
		log.error('the input data is insufficient for an assembly.')
		sys.exit(1)

def db_split(cfg):
	"""Build the seq_dump command line (cutoffs, block size, seed file count, output dir, input fofn)."""
	argv = [SCRIPT_PATH + '/bin/seq_dump']
	argv += ['-f', cfg['read_cutoff']]
	argv += ['-s', cfg['seed_cutoff']]
	argv += ['-b', str(cfg['blocksize'])]
	argv += ['-n', cfg['seed_cutfiles']]
	argv += ['-d', cfg['raw_aligndir']]
	argv.append(cfg['input_fofn'])
	return ' '.join(argv)

def db_stat(cfg):
	"""
	Build the seq_stat command line that writes <cfg['raw_aligndir']>/input.reads.stat.

	'-a' is passed for read type 'hifi'; for other read types an empty argument
	takes its place in the space-joined command.
	"""
	argv = [SCRIPT_PATH + '/bin/seq_stat']
	argv += ['-f', cfg['read_cutoff']]
	argv += ['-g', cfg['genome_size']]
	argv += ['-d', cfg['seed_depth']]
	if cfg['read_type'] == 'hifi':
		argv.append('-a')
	else:
		argv.append('')
	argv += ['-o', cfg['raw_aligndir'] + '/input.reads.stat']
	argv.append(cfg['input_fofn'])
	return ' '.join(argv)

def main(args):
	"""
	Run the NextDenovo pipeline described by a config file.

	args: the (namespace, extras) pair returned by parser.parse_known_args();
	      args[0].log is the log file name and args[1][0] the config file path.
	      Without a config file the help of the module-level `parser` is printed
	      and the program exits with status 1.

	Depending on cfg['task'], the correction steps (db_split, db_local, raw_align,
	sort_align, seed_cns) and/or the assembly steps (split_seed, cns_align,
	ctg_graph, ctg_align, ctg_cns) are run in that order. A step whose Task
	reports is_finished() is skipped. A failed step logs the error logs of its
	unfinished jobs and exits with status 1.
	"""
	if not args[1]:
		parser.print_help()
		sys.exit(1)

	global log
	# Replace the literal 'pidXXX.' placeholder prefix by the real pid.
	# (str.strip('pidXXX.') would strip a *set of characters* from both ends and
	# mangle user-supplied names such as 'pipeline.log' or 'run.pid'.)
	log_name = args[0].log
	if log_name.startswith('pidXXX.'):
		log_name = log_name[len('pidXXX.'):]
	log_file = 'pid' + str(os.getpid()) + '.' + log_name
	log = plog(log_file)
	log.info('NextDenovo start...')
	log.info('version:%s logfile:%s' % (getver(SCRIPT_PATH), log_file))
	cfg = ConfigParser(args[1][0]).cfg
	job_prefix = cfg['job_prefix']

	# Create the work directories. Step directories are first moved to a free
	# '.backup.v<N>' name when rewrite is off, and the graph directory also when
	# asm_finished() reports that its script used other nextgraph options.
	for d in (cfg['workdir'], cfg['raw_aligndir'], cfg['cns_aligndir'], cfg['ctg_graphdir']):
		if (d != cfg['workdir'] and not cfg['rewrite']) or (d == cfg['ctg_graphdir'] and asm_finished(cfg)):
			if os.path.exists(d):
				for i in range(100):
					e = d + '.backup.v' + str(i)
					if not os.path.exists(e):
						shutil.move(d, e)
						log.warning('backup ' + d + ' to ' + e)
						break

		if not pmkdir(d):
			log.info('skip mkdir: ' + d)
		else:
			log.info('mkdir: ' + d)

	# Without a positive seed cutoff, collect read statistics first and let
	# reset_cfg() update the configuration from them.
	if int(cfg['seed_cutoff']) <= 0:# or cfg['read_type'] == 'hifi':
		task = Task(set_task(cfg, 'db_stat'), dir_prefix='db_stat', job_prefix=job_prefix)
		if not task.is_finished():
			task.set_run(max_parallel_job=1, job_type=cfg['job_type'], use_drmaa=cfg['use_drmaa'], \
				submit=cfg['submit'], kill=cfg['kill'], check_alive=cfg['check_alive'], job_id_regex=cfg['job_id_regex'])
			task.run.start()
			if task.run.is_finished():
				task.set_task_finished()
				log.info('db_stat done')
			else:
				log.error('db_stat failed: please check the following logs:')
				for subtask in task.run.unfinished_jobs:
					log.error(subtask.err)
				sys.exit(1)
		else:
			log.info('skip step: db_stat')
		reset_cfg(cfg)
	else:
		log.info('options: \n' + str(cfg))

	# ---- read correction ----
	if cfg['task'] in ['all', 'correct']:
		task = Task(set_task(cfg, 'db_split'), dir_prefix='db_split',job_prefix=job_prefix)
		if not task.is_finished():
			task.set_run(max_parallel_job=1, job_type=cfg['job_type'], use_drmaa=cfg['use_drmaa'], \
				submit=cfg['submit'], kill=cfg['kill'], check_alive=cfg['check_alive'], job_id_regex=cfg['job_id_regex'])
			task.run.start()
			if task.run.is_finished():
				task.set_task_finished()
				log.info('db_split done')
			else:
				log.error('db_split failed: please check the following logs:')
				for subtask in task.run.unfinished_jobs:
					log.error(subtask.err)
				sys.exit(1)
		else:
			log.info('skip step: db_split')

		# copy the split database to the nodes; always run as local jobs
		if 'usetempdir' in cfg:
			task = Task(set_task(cfg, 'db_local'), dir_prefix='db_local', job_prefix=job_prefix, convert_path=False)
			if not task.is_finished():
				task.set_run(max_parallel_job=30, job_type='local')
				task.run.start()
				if task.run.is_finished():
					task.set_task_finished()
					log.info('db_local done')
				else:
					log.error('db_local failed: please check the following logs:')
					for subtask in task.run.unfinished_jobs:
						log.error(subtask.err)
					sys.exit(1)
			else:
				log.info('skip step: db_local')

		task = Task(set_task(cfg, 'raw_align'), dir_prefix='raw_align', job_prefix=job_prefix, convert_path=False)
		if not task.is_finished():
			task.set_run(max_parallel_job=cfg['parallel_jobs'], job_type = cfg['job_type'], mem='30G', \
				cpu=cfg['_minimap2_threads'][0], use_drmaa=cfg['use_drmaa'], submit=cfg['submit'], \
				kill=cfg['kill'], check_alive=cfg['check_alive'], job_id_regex=cfg['job_id_regex'])
			total_tasks = len(task.run.unfinished_jobs)
			task.run.start()
			# rerun failed jobs while reruns are left; give up when the number of
			# unfinished jobs still equals the initial job count
			while (not task.run.is_finished()):
				if len(task.run.unfinished_jobs) == total_tasks or not cfg['rerun']:
					log.error('raw_align failed: please check the following logs:')
					for subtask in task.run.unfinished_jobs:
						log.error(subtask.err)
					sys.exit(1)
				else:
					log.info(str(len(task.run.unfinished_jobs)) + ' raw_align jobs failed, and rerun for the '+ str(cfg['rerun']) + ' time')
					task.run.rerun()
					cfg['rerun'] -= 1
			else:
				task.set_task_finished()
				log.info('raw_align done')
		else:
			log.info('skip step: raw_align')

		# the raw_align jobs are needed to locate the overlap files to sort
		subptasks = task.jobs
		task = Task(set_task(cfg, 'sort_align'), dir_prefix='sort_align', job_prefix=job_prefix, convert_path=False)
		if not task.is_finished():
			task.set_run(max_parallel_job=cfg['parallel_jobs'], job_type=cfg['job_type'], mem=cfg['_sort_mem'], \
				cpu=cfg['_sort_threads'], use_drmaa=cfg['use_drmaa'], submit=cfg['submit'], \
				kill=cfg['kill'], check_alive=cfg['check_alive'], job_id_regex=cfg['job_id_regex'])
			pre_sort_align_input(get_raw_align_output(subptasks), task.run.unfinished_jobs)
			task.run.start()
			if task.run.is_finished():
				task.set_task_finished()
				log.info('sort_align done')
				if cfg['deltmp']:
					clean_raw_align_output(task.jobs)
			else:
				log.error('sort_align failed: please check the following logs:')
				for subtask in task.run.unfinished_jobs:
					log.error(subtask.err)
				sys.exit(1)
		else:
			log.info('skip step: sort_align')

		cns_input = get_sort_align_output(task.jobs)
		task = Task(set_task(cfg, 'seed_cns', cns_input), dir_prefix='seed_cns', job_prefix=job_prefix, convert_path=False)
		if not task.is_finished():
			task.set_run(max_parallel_job=cfg['pa_correction'], job_type=cfg['job_type'], cpu=cfg['_cns_threads'],\
				use_drmaa=cfg['use_drmaa'], submit=cfg['submit'], kill=cfg['kill'], check_alive=cfg['check_alive'], \
				job_id_regex=cfg['job_id_regex'])
			pre_raw_idxs(cfg, '01.seed_cns.input.idxs')
			task.run.start()
			if task.run.is_finished():
				task.set_task_finished()
				log.info('seed_cns done')
			else:
				log.error('seed_cns failed: please check the following logs:')
				for subtask in task.run.unfinished_jobs:
					log.error(subtask.err)
				sys.exit(1)
		else:
			log.info('skip step: seed_cns')

		log.info('seed_cns finished, and final corrected reads file:')
		log.info('\033[35m %s/01.seed_cns.sh.work/seed_cns*/cns.fasta \033[0m' % cfg['cns_aligndir'])

	# ---- assembly ----
	if cfg['task'] in ['all', 'assemble']:
		if cfg['input_type'] == 'corrected':
			# already corrected input: only split it into chunks
			task = Task(set_task(cfg, 'split_seed'), dir_prefix='split_seed', job_prefix=job_prefix)
			if not task.is_finished():
				task.set_run(max_parallel_job=1, job_type=cfg['job_type'], use_drmaa=cfg['use_drmaa'], \
					submit=cfg['submit'], kill=cfg['kill'], check_alive=cfg['check_alive'], job_id_regex=cfg['job_id_regex'])
				total_tasks = len(task.run.unfinished_jobs)
				task.run.start()
				while (not task.run.is_finished()):
					if len(task.run.unfinished_jobs) == total_tasks or not cfg['rerun']:
						log.error('split_seed failed: please check the following logs:')
						for subtask in task.run.unfinished_jobs:
							log.error(subtask.err)
						sys.exit(1)
					else:
						log.info(str(len(task.run.unfinished_jobs)) + ' split_seed jobs failed, and rerun for the '+ str(cfg['rerun']) + ' time')
						task.run.rerun()
						cfg['rerun'] -= 1
				else:
					task.set_task_finished()
					log.info('split_seed done')
			else:
				log.info('skip step: split_seed')
			cns_output = split_cns_output(task.jobs)
		else:
			# `task` is whichever Task was created last above (seed_cns when the
			# correction steps ran in this invocation)
			cns_output = get_cns_output(task.jobs)

		task = Task(set_task(cfg, 'cns_align', cns_output), dir_prefix='cns_align', job_prefix=job_prefix, convert_path=False)
		if not task.is_finished():
			task.set_run(max_parallel_job=cfg['parallel_jobs'], job_type=cfg['job_type'], mem='30G', \
				cpu=cfg['_minimap2_threads'][1], use_drmaa=cfg['use_drmaa'], submit=cfg['submit'],\
				kill=cfg['kill'], check_alive=cfg['check_alive'], job_id_regex=cfg['job_id_regex'])
			total_tasks = len(task.run.unfinished_jobs)
			task.run.start()
			while (not task.run.is_finished()):
				if len(task.run.unfinished_jobs) == total_tasks or not cfg['rerun']:
					log.error('cns_align failed: please check the following logs:')
					for subtask in task.run.unfinished_jobs:
						log.error(subtask.err)
					sys.exit(1)
				else:
					log.info(str(len(task.run.unfinished_jobs)) + ' cns_align jobs failed, and rerun for the '+ str(cfg['rerun']) + ' time')
					task.run.rerun()
					cfg['rerun'] -= 1
			else:
				task.set_task_finished()
				log.info('cns_align done')
		else:
			log.info('skip step: cns_align')

		subptasks = task.jobs
		task = Task(set_task(cfg, 'ctg_graph'), dir_prefix='ctg_graph', job_prefix=job_prefix, convert_path=False)
		if not task.is_finished():
			task.set_run(max_parallel_job=cfg['parallel_jobs'], job_type = cfg['job_type'], mem='3G', cpu=1,\
				use_drmaa=cfg['use_drmaa'], submit=cfg['submit'], kill=cfg['kill'], \
				check_alive=cfg['check_alive'], job_id_regex=cfg['job_id_regex'])
			pre_cns_seqs(cfg, cns_output, '01.ctg_graph.input.seqs')
			pre_cns_ovl(cfg, get_cns_ovl_output(subptasks), '01.ctg_graph.input.ovls')
			task.run.start()
			if task.run.is_finished():
				task.set_task_finished()
				log.info('ctg_graph done')
			else:
				log.error('ctg_graph failed: please check the following logs:')
				for subtask in task.run.unfinished_jobs:
					log.error(subtask.err)
				sys.exit(1)
		else:
			log.info('skip step: ctg_graph')

		# the remaining steps are only run for fasta graph output
		if cfg['_nextgraph_out_format'] != "fasta":
			log.info('nextDenovo finished')
			sys.exit(0)

		ctg_graph_output = get_ctg_graph_output(task.jobs)
		task = Task(set_task(cfg, 'ctg_align', ctg_graph_output), dir_prefix='ctg_align', \
			job_prefix=job_prefix, convert_path=False)
		if not task.is_finished():
			task.set_run(max_parallel_job=cfg['pa_correction'], job_type=cfg['job_type'], \
				cpu = cfg['_cns_threads'], use_drmaa=cfg['use_drmaa'], submit=cfg['submit'],\
				kill=cfg['kill'], check_alive=cfg['check_alive'], job_id_regex=cfg['job_id_regex'])
			task.run.start()
			if task.run.is_finished():
				task.set_task_finished()
				log.info('ctg_align done')
			else:
				log.error('ctg_align failed: please check the following logs:')
				for subtask in task.run.unfinished_jobs:
					log.error(subtask.err)
				sys.exit(1)
		else:
			log.info('skip step: ctg_align')

		subptasks = task.jobs
		# writes '<contigs>.blc' (read by the ctg_cns commands) and returns the header info
		seq_info = blc_genome(cfg['pa_correction'], ctg_graph_output)
		task = Task(set_task(cfg, 'ctg_cns', ctg_graph_output), dir_prefix='ctg_cns', \
			job_prefix=job_prefix, convert_path=False)
		if not task.is_finished():
			ctg_align_output = get_ctg_align_output(subptasks)
			task.set_run(max_parallel_job=cfg['pa_correction'], job_type=cfg['job_type'], \
				cpu=cfg['_cns_threads'], use_drmaa=cfg['use_drmaa'], submit=cfg['submit'],\
				kill=cfg['kill'], check_alive=cfg['check_alive'], job_id_regex=cfg['job_id_regex'])
			pre_ctg_cns_input(cfg['ctg_graphdir'] + '/03.ctg_cns.input.bams', ctg_align_output)
			task.run.start()
			if task.run.is_finished():
				task.set_task_finished()
				log.info('ctg_cns done')
				if cfg['deltmp']:
					clean_ctg_align_output(ctg_align_output)
			else:
				log.error('ctg_cns failed: please check the following logs:')
				for subtask in task.run.unfinished_jobs:
					log.error(subtask.err)
				sys.exit(1)
		else:
			log.info('skip step: ctg_cns')

		asm, stat = gather_ctg_cns_output(cfg, task.jobs, seq_info)
		log.info('nextDenovo finished')
		log.info('final assembly file:')
		log.info('\033[35m %s \033[0m' % asm)
		log.info('final stat file:')
		log.info('\033[35m %s.stat \033[0m' % (asm))
		log.info('asm stat:')
		log.info("\n" + stat)

if __name__ == '__main__':
	# Command-line entry point. All pipeline parameters come from the config file
	# given as positional argument; only the log file name, --version and --help
	# are handled here. `parser` stays a module-level name because main() uses it
	# to print the help when no config file is given.
	parser = argparse.ArgumentParser(
		add_help = False,
		formatter_class = HelpFormatter,
		description = '''
nextDenovo:
	Fast and accurate de novo assembler for long reads
examples: 
	%(prog)s run.cfg

For more information about NextDenovo, see https://github.com/Nextomics/NextDenovo
'''
	)
	parser.version = '%(prog)s ' + getver(SCRIPT_PATH)
	parser.add_argument ('-l','--log',metavar = 'FILE',type = str, default = 'pidXXX.log.info',
		help = 'log file')
	parser.add_argument('-v', '--version', action='version')
	parser.add_argument('-h', '--help',  action='help',
		help = 'please use the config file to pass parameters')
	# parse_known_args returns (namespace, list of unrecognised arguments); the
	# config file path is expected as the first unrecognised argument
	args = parser.parse_known_args()
	main(args)
