import os
import pickle
import sys

import numpy as np
from rdkit import Chem
from keras.models import model_from_json
from pkg_resources import resource_filename, Requirement
import openbabel


def loadKerasModel(path, name):
    """
    function to load a pre-trained
    keras neural network model

    Reads the architecture from ``<path>/<name>_model.json`` and the
    weights from ``<path>/<name>_model.h5``, then compiles the model
    with mse loss and the sgd optimizer so it is ready for prediction.

    :param: path: path to ANN folder (in package)
    :param: name: string, name of model
    :return: keras model (the neural network object)
    """
    # load model architecture; the context manager closes the file even on error
    with open(os.path.join(path, name + '_model.json'), 'r') as json_file:
        loaded_model_json = json_file.read()
    loaded_model = model_from_json(loaded_model_json)
    # load weights into new model. No file-mode argument here: the second
    # positional parameter of load_weights is a loading option, not a mode.
    loaded_model.load_weights(os.path.join(path, name + '_model.h5'))
    loaded_model.compile(loss="mse", optimizer='sgd',
                         metrics=['mse', 'mae', 'mape'])
    return loaded_model


def _readColumnFile(path_to_file):
    """
    read a one-value-per-line csv file (each value optionally wrapped
    in square brackets) into an (n, 1) float array; blank lines are skipped

    :param: path_to_file: string, path to the csv file
    :return: np.array of shape (n, 1), or shape (0,) when the file has no values
    """
    values = []
    with open(path_to_file, 'r') as f:
        for line in f:
            text = line.strip().strip('[]').strip()
            if text:
                values.append([float(text)])
    return np.array(values)


def loadNormalizationData(path, name):
    """
    function to load the training
    data-scaling numbers needed to keep
    consistent with training

    Files read are ``path + name + '_<stat>.csv'`` for stat in
    mean_x, var_x, mean_y, var_y (``path`` is concatenated directly, so it
    should end with a separator).

    :param: path: path to ANN folder (in package), with trailing separator
    :param: name: string, name of model
    :return: np.arrays for training data, in the order
             mean_x, mean_y, var_x, var_y
    :raises: FileNotFoundError if a var_x, mean_y or var_y file is missing
    """
    prefix = path + name
    mean_x_file = prefix + '_mean_x.csv'
    # NOTE(review): only the x-mean file is treated as optional; a missing
    # file yields an empty array. Confirm this asymmetry is intended.
    if os.path.isfile(mean_x_file):
        train_mean_x = _readColumnFile(mean_x_file)
    else:
        train_mean_x = np.array([])
    train_var_x = _readColumnFile(prefix + '_var_x.csv')
    train_mean_y = _readColumnFile(prefix + '_mean_y.csv')
    train_var_y = _readColumnFile(prefix + '_var_y.csv')

    return train_mean_x, train_mean_y, train_var_x, train_var_y

def loadANNModel():
    """
    function to load the saved
    neural network model and parameters.
    This load can be the slowest
    part of the whole operation, so it
    should only be done once.

    Everything is read from the package data directory
    ``ofrcalc/data/ofrmod`` located via pkg_resources.

    :return: dictionary with keys "model", "feature_names",
             "train_mean_x", "train_mean_y", "train_var_x", "train_var_y"
    """
    # get path info from package manager
    ofrpath = resource_filename(Requirement.parse("ofrcalc"), "ofrcalc/data/ofrmod")
    name = 'single-ANN'

    # load the ANN model and normalization data
    model = loadKerasModel(path=ofrpath, name=name)
    train_mean_x, train_mean_y, train_var_x, train_var_y = loadNormalizationData(path=ofrpath + '/', name=name)
    # feature_names.p ships inside the package, so unpickling it is trusted;
    # never point pickle.load at user-supplied files.
    with open(ofrpath + "/feature_names.p", "rb") as f:
        feature_names = pickle.load(f)
    NNdict = {"model": model, "feature_names": feature_names,
              "train_mean_x": train_mean_x, "train_mean_y": train_mean_y,
              "train_var_x": train_var_x, "train_var_y": train_var_y}
    return NNdict

def loadInputFile(inpath):
    """
    function to read the user
    supplied input file and check
    that it can be understood by the
    code

    Each non-blank line must hold four whitespace-separated fields,
    ``name smiles con1 con2``, where con1/con2 are integer connection-atom
    indices. Blank lines are ignored. Each parsed line is checked with
    inputValidator. On the first invalid line the collected errors are
    printed and the program exits with status 1.

    :param: inpath: path to input file
    :return: list of dictionaries, one per valid line, with keys
             'name', 'smiles' and 'conIdx'
    """
    inputs = []
    with open(inpath, 'r') as f:
        for i, line in enumerate(f):
            ll = line.split()
            if not ll:
                # tolerate blank lines, e.g. an extra empty line at EOF
                continue
            errors = []
            valErrors = []  # defined up front so the error path can always print it
            valid = True
            thisInput = {}
            if len(ll) != 4:
                errors.append('error : line number ' + str(i) + ' is the wrong length')
                valid = False
            else:
                name, smiles, con1, con2 = ll
                try:
                    con1 = int(con1)
                    con2 = int(con2)
                except ValueError:
                    errors.append('error : line number ' + str(i) + ' connection atom was not understood')
                    valid = False
                else:
                    thisInput = {"name": name, 'smiles': smiles, 'conIdx': [con1, con2]}
                    valid, valErrors = inputValidator(thisInput)
            if not valid:
                print(errors)
                print(valErrors)
                # non-zero status so calling scripts can detect the failure
                sys.exit(1)
            inputs.append(thisInput)
            print('valid input detected at line number ' + str(i))
    return inputs

def gradientPicSave(results, picDir):
    """
    function to take results, including gradient info
    from procsmiles.invokeFromSmiles and save the images
    into the indicated directory (created if needed)

    Each image is written as ``<picDir>/<baseName>.png``, pairing
    results['gradPics'][k] with results['baseNames'][k].

    :param: results: dictionary, prepared by procsmiles.invokeFromSmiles;
            must contain 'gradPics' (bytes) and 'baseNames' (str)
    :param: picDir: path to dir, with or without a trailing separator
    """
    if not os.path.isdir(picDir):
        os.makedirs(picDir)
    for pic, baseName in zip(results['gradPics'], results['baseNames']):
        # join (not concatenate) so the file lands inside picDir even
        # when picDir has no trailing separator
        with open(os.path.join(picDir, baseName + '.png'), 'wb') as f:
            f.write(pic)
    
def resultWriter(resultsList, resDir):
    """
    function to take results
    from procsmiles.invokeFromSmiles and save
    COF info into ``<resDir>/results.csv`` (resDir is created if needed)

    The header row is the first result's "conditionNames" followed by
    ``COF``. Each following row is one condition's values plus its
    prediction rounded to 4 decimals.

    :param: resultsList: non-empty list of dictionaries prepared by
            procsmiles.invokeFromSmiles, each with keys "conditionNames",
            "conditions" and "predictions"
    :param: resDir: path to dir (absolute or relative, trailing separator optional)
    """
    if not os.path.isdir(resDir):
        os.makedirs(resDir)
    ## make results table
    # os.path.join keeps absolute paths intact (stripping '/' would drop the root)
    with open(os.path.join(resDir, 'results.csv'), 'w') as wf:
        wf.write(','.join(resultsList[0]["conditionNames"]) + ',COF\n')
        for rn in resultsList:
            for i, cn in enumerate(rn["conditions"]):
                wf.write(','.join([str(j) for j in cn]) + ',' + str(np.round(rn["predictions"][i][0], 4)) + '\n')


def resultRead(path):
    """
    read a results table produced by resultWriter and summarise it.
    Intended for use in testing only.

    The first row is treated as a header and skipped; the COF value is
    taken from the last comma-separated column of every remaining row.

    :param: path: path to results file
    :return: int, float -- number of data rows and their mean COF
    """
    with open(path, 'r') as f:
        body = f.readlines()[1:]  # drop the header row
    cof_values = [float(row.strip().split(",")[-1]) for row in body]
    return (len(cof_values), np.mean(cof_values))
    
    
def inputValidator(ins):
    """
    function to test and validate
    inputs harvested from infiles and provide
    useful errors to the user

    Checks that the keys 'smiles', 'name' and 'conIdx' are present,
    that the SMILES can be parsed by RDKit, and that exactly two
    connection atoms are given.

    :param: ins: dictionary, containing info from one line of input
    :return: bool, list of errors -- the validity status and the error messages
    """
    errors = []

    # test presence of keys:
    for key in ['smiles', 'name', 'conIdx']:
        if key not in ins:
            errors.append('error: key ' + key + ' missing from input')
    if errors:
        return (False, errors)

    valid = True
    # check smiles: RDKit reports an unparsable SMILES by returning None
    # (it only raises for arguments of the wrong type), so check both.
    try:
        mol = Chem.MolFromSmiles(ins['smiles'])
    except Exception:
        mol = None
    if mol is None:
        errors.append('error: SMILEs ' + str(ins['smiles']) + ' not able to be understood by rdkit')
        valid = False
    # check conIdx
    conIdx = ins['conIdx']
    if len(conIdx) != 2:
        errors.append('error: OFRCalc is trained on bidentate binding. '
                      'Instead ' + str(len(conIdx)) + ' atoms were given. Please choose 2 instead.')
        valid = False
    return (valid, errors)


def velocityValidator(args):
    """
    function to test and validate
    velocity inputs provided at command line
    and provide
    useful errors to the user

    Accepted values (as checked below):
      * vlow:   0.25 <= vlow <= 10.0 (m/s)
      * vhigh:  0.5 <= vhigh <= 10.5 (m/s), and vhigh >= vlow
      * vsteps: 1 <= vsteps <= 1000

    :param: args, the argparse object; must provide vlow, vhigh and vsteps
    :return: np.array, velocity range for ANN eval (empty list if invalid)
    :return: bool, is the input valid
    :return: list of errors the status of the input
    """
    vrange = []
    errors = []
    valid = True

    # test presence of keys:
    for arg in ['vlow', 'vhigh', 'vsteps']:
        if not hasattr(args, arg):
            errors.append('error: arg ' + arg + ' missing from input')
            valid = False
    if not valid:
        return (vrange, valid, errors)

    # check the range
    if args.vhigh - args.vlow < 0:
        errors.append('error: vhigh is smaller than vlow. Please correct')
        valid = False
    # check they are within sane limits (messages match the comparisons)
    if args.vhigh > 10.5 or args.vhigh < 0.5:
        errors.append('error: vhigh out of range! must be <= 10.5 m/s and >= 0.5 m/s, training data only extends to 10 m/s.')
        valid = False
    if args.vlow < 0.25 or args.vlow > 10.0:
        errors.append('error: vlow out of range! must be >= 0.25 m/s and <= 10.0 m/s, training data only extends to 0.5 m/s.')
        valid = False
    # check vsteps
    if args.vsteps > 1000 or args.vsteps < 1:
        errors.append('error: vsteps out of range! must be >= 1 and <= 1000.')
        valid = False
    if valid:
        vrange = np.linspace(start=args.vlow, stop=args.vhigh, num=args.vsteps)

    return (vrange, valid, errors)
   
def bondCount(aa):
    """
    count the bonds of an openbabel atom by walking
    openbabel.OBAtomBondIter over it

    :param: aa: openbabel atom
    :return: int, number of bonds yielded by the iterator
    """
    return sum(1 for _ in openbabel.OBAtomBondIter(aa))

def find_pairs(conlist):
    """
    function finds out the possible pairs for
    connection atoms from a list of candidate atom indices

    Rules:
      * empty conlist -> no pairs
      * one entry -> that atom paired with itself (monodentate)
      * otherwise each neighbouring pair (a, b) gives [a, b] when
        |a - b| <= 4 (bidentate), or [a, a] when |a - b| > 4 (monodentate)

    :param: conlist: list of int atom indices
    :return: list of [int, int] pairs
    """
    if not conlist:
        return []
    if len(conlist) < 2:
        return [[conlist[0], conlist[0]]]
    pairlist = []
    for first, second in zip(conlist, conlist[1:]):
        if abs(first - second) <= 4:
            pairlist.append([first, second])
        else:
            pairlist.append([first, first])
    return pairlist

def find_nonCAtoms(obmol):
    """
    function collects candidate connection atoms (every atom that is
    not carbon) from an openbabel molecule

    Oxygen and nitrogen atoms with more than one bond are excluded
    as non-possible cases.

    :param: obmol: openbabel.OBMol
    :return: list of int, zero-based indices of the kept atoms
    """
    nonCindx = []
    for obatom in openbabel.OBMolAtomIter(obmol):
        atomicNum = obatom.GetAtomicNum()
        if atomicNum == 6:
            continue
        # remove non possible cases: O (8) or N (7) bonded more than once.
        # Both elements must be filtered; previously only N was.
        if atomicNum in (7, 8) and bondCount(obatom) > 1:
            continue
        # shift GetIdx() down by one (openbabel atom indices start at 1)
        nonCindx.append(obatom.GetIdx() - 1)
    return nonCindx

def filt_out(mystr):
    """
    drop SMILES punctuation from a string: every '=', '-', '(', ')',
    '[', ']', '1', '2', '.' and '/' character is removed, everything
    else is kept in order

    :param: mystr: string (e.g. a SMILES)
    :return: string with those characters removed
    """
    unwanted = set('=-()[]12./')
    return ''.join(ch for ch in mystr if ch not in unwanted)

def genInputFile(filename):
    """
    function converts a cdxml file into an input file of the
    ``name smiles con1 con2`` form read by loadInputFile

    Steps:
      1. openbabel converts ``filename`` (cdxml) to SMILES, written to
         ``filename + '.smi'``
      2. the first token of the last non-blank line of that file is
         used as the base SMILES
      3. each SMARTS in hgSMARTS is matched against the RDKit molecule
         (first match only); O/N atoms of the match that have exactly
         one bond are collected as connection atoms
      4. one line per pattern with connection atoms is written to
         ``<filename without extension>.txt``

    :param: filename: path to the cdxml file
    :raises: ValueError if no SMILES could be produced or RDKit cannot parse it
    """
    # SMARTS patterns tried in order (presumably head-group motifs)
    hgSMARTS = ['[#6](-[#8])=[#8]',
                '[#6](-[#8])-[#8]',
                '[#6](-[#8])-[#7]',  # alc-amide
                '[*](-[#8])-[#8]',
                '[#6](-[#7])-[#7]',  # amide-amide
                '[#6](-[#7])=[#8]',  # carbonyl-amide
                '[#6](-[#6]-[#8])-[#8]',
                '[#6](-[#6]-[#8])=[#8]',
                '[#6](-[#6]=[#8])=[#8]',
                '[#6](-[#6]-[#7])-[#8]',
                '[#6](-[#6]-[#7])=[#8]',
                '[#6](-[#6]-[#7])-[#7]',
                '[#6](-[#6]-[#8])-[#6]-[#8]',
                '[#6](-[#6]-[#8])-[#6]=[#8]',
                '[#6](-[#6]=[#8])-[#6]=[#8]']

    # convert cdxml -> SMILES with openbabel, written next to the input file
    obconv = openbabel.OBConversion()
    obmol = openbabel.OBMol()
    obconv.SetInAndOutFormats('cdxml', 'smi')
    obconv.ReadFile(obmol, filename)
    smiPath = filename + '.smi'
    obconv.WriteFile(obmol, smiPath)

    # base SMILES = first token of the last non-blank line
    baseSMI = None
    with open(smiPath, 'r') as f:
        for line in f:
            fields = line.split()
            if fields:
                baseSMI = fields[0]
    if baseSMI is None:
        raise ValueError('no SMILES could be generated from ' + str(filename))

    m1 = Chem.MolFromSmiles(baseSMI)
    if m1 is None:
        raise ValueError('SMILES ' + baseSMI + ' generated from ' + str(filename)
                         + ' not able to be understood by rdkit')

    # counter (incremented per matching pattern) -> connection atom indices
    condict = {}
    counter = 0
    for patt in hgSMARTS:
        ipatt = Chem.MolFromSmarts(patt)
        if m1.HasSubstructMatch(ipatt):
            conlist = []
            for i in m1.GetSubstructMatch(ipatt):
                atom = m1.GetAtomWithIdx(i)
                if atom.GetSymbol() in ['O', 'N'] and len(atom.GetBonds()) == 1:
                    conlist.append(i)
            if conlist:
                condict[counter] = conlist
            counter = counter + 1

    # splitext (not split('.')) so paths like './mol.cdxml' keep their stem
    outPath = os.path.splitext(filename)[0] + '.txt'
    with open(outPath, 'w') as wf:
        for i, pair in condict.items():
            if len(pair) == 1:
                molName, con1, con2 = 'mol' + str(i), pair[0], pair[0]
            elif len(pair) == 2:
                molName, con1, con2 = 'mol' + str(i), pair[0], pair[1]
            else:
                # NOTE(review): with more than two candidates, the last two
                # positions of the punctuation-stripped SMILES are used and
                # the name is always 'mol0' -- confirm this is intended.
                smi = filt_out(baseSMI)
                molName, con1, con2 = 'mol0', len(smi) - 2, len(smi) - 1
            wf.write(' '.join([molName, baseSMI, str(con1), str(con2)]) + '\n')

