import os
import numpy as np
import tensorflow as tf
from keras import backend as k

from rdkit import Chem

from ofrcalc.routines.visupfuns import *
from ofrcalc.classes.ofrclass import quickOFR
from ofrcalc.classes.mol3D import mol3D


def data_rescale(scaled_dat, train_mean, train_var):
    """
    Undo a mean/variance standardisation.

    Computes ``scaled_dat * sqrt(train_var) + transpose(train_mean)``
    element-wise with NumPy broadcasting.

    Parameters
    ----------
    scaled_dat : array_like
        standardised values
    train_mean : array_like or scalar
        mean that was subtracted; it is transposed before being added
        back, so a column vector broadcasts across the columns of
        ``scaled_dat``
    train_var : array_like or scalar
        variance whose square root was used as the divisor

    Returns
    -------
    the values on the original scale
    """
    # np.transpose matches ``.T`` for arrays but also accepts 0-d arrays
    # and plain Python numbers, which the previous shape probe rejected.
    dat = np.multiply(scaled_dat, np.sqrt(train_var)) + np.transpose(train_mean)
    return dat
def data_normalize(data, train_mean, train_var):
    """
    Standardise data with a given mean and variance.

    Computes ``(data - transpose(train_mean)) / sqrt(train_var)``
    element-wise with NumPy broadcasting.

    Parameters
    ----------
    data : array_like
        values on the original scale
    train_mean : array_like or scalar
        mean to subtract; it is transposed first, so a column vector
        broadcasts across the columns of ``data``
    train_var : array_like or scalar
        variance; its square root is the divisor

    Returns
    -------
    the standardised values
    """
    # np.transpose matches ``.T`` for arrays but also accepts 0-d arrays
    # and plain Python numbers, which the previous shape probe rejected.
    centered = data - np.transpose(train_mean)
    scaled_dat = np.divide(centered, np.sqrt(train_var))
    return scaled_dat


def gradientEvaluationDraw(model, inputs, OFR, train_var_x, absGradients):
    """
    Evaluate the gradient of the model output with respect to its input
    for every row of ``inputs`` and draw one picture per row.

    Parameters
    ----------
    model :
        Keras model; its ``input`` and ``output`` tensors are used
    inputs :
        values fed to ``model.input`` (one gradient is produced per row)
    OFR :
        object passed to ``projectGradient``; its ``rdmol`` attribute is
        passed to ``makeCairoFromRdMol``
    train_var_x :
        forwarded unchanged to ``projectGradient``
    absGradients :
        forwarded unchanged to ``makeCairoFromRdMol``

    Returns
    -------
    list with one ``makeCairoFromRdMol`` result per input row
    """
    # symbolic gradient of the output w.r.t. the input
    gradients = k.gradients(model.output, model.input)

    # Evaluate in the session Keras itself manages. Variable values are
    # stored per session, so opening a new session and running the global
    # initializer would compute the gradients on freshly initialised
    # weights instead of the weights the model predicts with. The Keras
    # session is shared, hence it is not closed here.
    sess = k.get_session()
    evaluatedGradients = sess.run(gradients, feed_dict={model.input: inputs})[0]

    pics = []
    for egrad in evaluatedGradients:
        thisGrad = projectGradient(egrad, OFR, train_var_x)
        pics.append(makeCairoFromRdMol(OFR.rdmol, thisGrad, absGradients))

    return pics

def invokeFromSmiles(name, smiles, conIdx, vrange, NNdict,
                     getGradients=False, absGradients=True):
    """
    Main function to create an OFR class object from SMILES and
    connection atoms, run model evaluation and return the results.

    The model is evaluated on a full grid of conditions: every speed in
    ``vrange`` x temperatures 45/120/150 x low/medium coverage.

    Parameters
    ----------
    name : str
        label used in printed messages and in the generated row names
    smiles : str
        SMILES string handed to ``mol3D.read_smiles`` and
        ``quickOFR.setRdMolFromSmiles``
    conIdx : sequence of int
        connection atom indices handed to ``quickOFR``
    vrange : sequence of numbers
        speed values of the condition grid
    NNdict : dict
        must provide the keys 'model', 'feature_names', 'train_mean_x',
        'train_var_x', 'train_mean_y' and 'train_var_y'
    getGradients : bool
        if True, gradient pictures are produced with
        ``gradientEvaluationDraw``
    absGradients : bool
        forwarded to ``gradientEvaluationDraw``

    Returns
    -------
    dict with the keys 'conditions', 'predictions', 'OFR', 'gradPics'
    (False when gradients were not requested), 'conditionNames' and
    'baseNames'

    Raises
    ------
    SystemExit
        if a connection atom is neither O nor N
    """
    # imported locally: this module has no top-level ``import sys``
    import sys

    # load the OFR
    thisMol = mol3D()
    thisMol.read_smiles(smiles, ff='uff', steps=200)
    thisOFR = quickOFR(name, thisMol, conIdx)
    thisOFR.setRdMolFromSmiles(smiles)

    # check connection set right
    conAtoms = [thisOFR.mol.getAtom(i).symbol() for i in thisOFR.conIdx]
    conAtomLabel = "/".join(conAtoms)
    print(name + " connection atom symbols: " + conAtomLabel)

    # check correct species are given:
    if any(ca not in ('O', 'N') for ca in conAtoms):
        msg = " Error for " + name + ": connection atom type must be O or N, found " + conAtomLabel
        msg += ', aborting'
        print(msg)
        sys.exit()

    # get descriptors
    thisOFR.getDescriptorVector()
    thisOFR.getDescriptorVectorDerivatives()
    thisOFR.getBondCounts()
    thisOFR.getSidechainCounts()

    # column order of the matrix assembled below
    colNames = ['speed', 'temp', 'fact_data.low.coverage',
                'fact_data.medium.coverage'] + thisOFR.descriptorNames

    # full factorial design: speed x temperature x low-coverage flag
    baseExpDesign = [vrange, [45, 120, 150], [0, 1]]
    baseRows = np.array(np.meshgrid(*baseExpDesign)).T.reshape(-1, 3)
    # medium-coverage flag is the complement of the low-coverage flag
    baseRows = np.column_stack([baseRows, 1 - baseRows[:, -1]])

    # append the (constant) descriptor vector to every condition row
    descriptorBlock = np.tile(np.array(thisOFR.descriptors), (baseRows.shape[0], 1))
    inputMat = np.column_stack([baseRows, descriptorBlock])

    # perform column swap to get the ordering the model expects
    correctInds = [colNames.index(fn) for fn in NNdict['feature_names']]
    inputMat = inputMat[:, correctInds]

    # normalize data
    scaled_input = data_normalize(inputMat,
                                  NNdict['train_mean_x'].squeeze(),
                                  NNdict['train_var_x'].squeeze())

    # call the model and bring the output back to the original scale
    predictions = data_rescale(NNdict['model'].predict(scaled_input),
                               train_mean=NNdict['train_mean_y'],
                               train_var=NNdict['train_var_y'])

    # clamp negative values to zero; the mask form also works when a
    # prediction row holds more than one value
    negative = predictions < 0.0
    if np.any(negative):
        predictions[negative] = 0.00
        msg = 'Warning: prediction results are out of bounds (<0.0) '
        msg += 'for candidate ' + name + '. Setting to 0. Be cautious, '
        msg += 'results may be unreliable'
        print(msg)

    # make row labels, one per condition row
    conditionNames = ['name', 'smiles', 'conIdx', 'velocity[m/s]', 'temperature[C]', 'coverage']
    conIdxLabel = " ".join([str(i) for i in conIdx])
    baseNames = []
    conditions = []
    for r in baseRows:
        vel = str(r[0])
        T = str(int(r[1]))
        # column 2 is the low-coverage flag
        cov = 'low' if r[2] == 1 else 'medium'

        conditions.append([name, smiles, conIdxLabel, vel, T, cov])
        baseNames.append("_".join([name, 'speed', vel, 'temp', T, 'cov', cov]))

    # get gradients if requested
    if getGradients:
        gradPics = gradientEvaluationDraw(model=NNdict['model'],
                                          inputs=scaled_input,
                                          OFR=thisOFR,
                                          train_var_x=NNdict['train_var_x'],
                                          absGradients=absGradients)
    else:
        gradPics = False

    msg = name + ': average COF across all ' + str(len(predictions)) + ' cases is ' + str(round(np.mean(predictions), 3))
    print(msg)

    # construct results object
    results = {"conditions": conditions, "predictions": predictions,
               "OFR": thisOFR, "gradPics": gradPics,
               "conditionNames": conditionNames, "baseNames": baseNames}

    return results
    
    
