import numpy as np
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import Draw
from rdkit.Chem.Draw import DrawingOptions
from rdkit.Chem.Draw import rdMolDraw2D
from ofrcalc.classes.mol3D import *

def normalizeGrad(grad):
    """Map signed gradient values to RGB highlight colours.

    For each value g the colour is a blend of white and a tint:

    * ``s = (|g| - min|g|) / (max(g) - min(g))`` sets how far the colour
      moves away from white (0 -> white).
    * ``t = (g - min(g)) / (max(g) - min(g))`` places g on [0, 1] and picks
      the tint ``[0, 0, 1] + t * [1, 0, 1]``. The low end is blue and the
      high end tends towards magenta.

    Components are clipped to [0, 1], because the tint's blue channel
    (``1 + s*t``) would otherwise exceed the valid RGB range.

    :param: grad: sequence or numpy array of per-atom gradient values.
        An (n, 1) column, such as the output of projectGradient, is
        flattened.
    :return: list of (r, g, b) float tuples, one per value. All entries are
        white when every value is identical, and the list is empty for
        empty input.
    """
    values = np.asarray(grad, dtype=float).ravel()
    if values.size == 0:
        return []
    lo = values.min()
    span = values.max() - lo
    if span == 0:
        # No contrast to show. Avoid the 0/0 the scaling below would produce.
        return [(1.0, 1.0, 1.0)] * values.size

    # Position of each value within [min, max], mapped to [0, 1].
    t = (values - lo) / span
    # Distance from white. NOTE(review): the denominator is the signed
    # range, not the magnitude range, so s never exceeds 1. Kept as-is to
    # preserve the existing colour scale.
    absValues = np.abs(values)
    s = (absValues - absValues.min()) / span

    base1 = np.array([0.0, 0.0, 1.0])
    base2 = np.array([1.0, 0.0, 1.0])
    tint = base1 + t[:, None] * base2
    colors = (1.0 - s)[:, None] * np.ones(3) + s[:, None] * tint
    # Keep every component a valid RGB value.
    colors = np.clip(colors, 0.0, 1.0)
    return [tuple(float(c) for c in rgb) for rgb in colors]

def normalizeAbsGrad(grad):
    """Map gradient magnitudes to white-to-red RGB highlight colours.

    Each magnitude |g| is rescaled to ``s = (|g| - min|g|) / (max|g| - min|g|)``.
    The colour is ``(1 - s) * white + s * (1, -1, -1)``, clipped to [0, 1].
    That gives white at s = 0, reaching pure red at s >= 0.5.

    :param: grad: sequence or numpy array of per-atom gradient values.
        An (n, 1) column, such as the output of projectGradient, is
        flattened.
    :return: list of (r, g, b) float tuples, one per value. All entries are
        white when every magnitude is identical, and the list is empty for
        empty input.
    """
    magnitudes = np.abs(np.asarray(grad, dtype=float)).ravel()
    if magnitudes.size == 0:
        return []
    lo = magnitudes.min()
    span = magnitudes.max() - lo
    if span == 0:
        # No contrast to show. Avoid the 0/0 the scaling below would produce.
        return [(1.0, 1.0, 1.0)] * magnitudes.size

    scaled = (magnitudes - lo) / span
    target = np.array([1.0, -1.0, -1.0])
    colors = (1.0 - scaled)[:, None] * np.ones(3) + scaled[:, None] * target
    # The (1, -1, -1) target drives green/blue negative for s > 0.5. Clip so
    # every component is a valid RGB value.
    colors = np.clip(colors, 0.0, 1.0)
    return [tuple(float(c) for c in rgb) for rgb in colors]



def gradientPrune(OFR, ignoreNames):
    """
    Drop the rows of an OFR descriptor-derivative matrix that belong to
    features the model does not use.

    Each name in ``ignoreNames`` is matched against the first element of the
    corresponding entry of ``OFR.descriptorDerivativeNames``. The row at
    that position is then deleted from ``OFR.descriptorDerivatives`` (axis 0).
    Presumably the names list is aligned row-for-row with the derivative
    matrix; confirm this against the OFR class.

    :param: OFR: OFR class instance providing ``descriptorDerivativeNames``
        and ``descriptorDerivatives``
    :param: ignoreNames: list of str, row names to discard. A name that is
        not present raises ValueError (from ``list.index``).
    :return: truncated gradient (np array); the input is not modified
    """
    leadingNames = [entry[0] for entry in OFR.descriptorDerivativeNames]
    rowsToDrop = [leadingNames.index(name) for name in ignoreNames]
    return np.delete(OFR.descriptorDerivatives, rowsToDrop, 0)
    
def projectGradient(targetGrad, OFR, train_var_x, featureSlice=slice(5, -2),
                    ignoreNames=None):
    """Project a model gradient onto the atom-based descriptors of an OFR.

    The model gradient is restricted to the RAC features selected by
    ``featureSlice`` and divided element-wise by the square root of
    ``train_var_x`` over the same features. It is then contracted with the
    OFR descriptor-derivative matrix, after the rows named in
    ``ignoreNames`` have been removed by gradientPrune.

    :param: targetGrad: np.array, model gradient with respect to its input
        features (for one call)
    :param: OFR: OFR class instance to project onto
    :param: train_var_x: numpy array aligned with targetGrad. Its square
        root rescales the gradient (presumably the per-feature training-set
        variance; confirm against the caller).
    :param: featureSlice: slice selecting the RAC features from both
        targetGrad and train_var_x. The default slice(5, -2) reproduces the
        previously hard-coded window.
    :param: ignoreNames: list of str, OFR derivative rows not used by the
        model. Defaults to ['f.dI.0/dI0.all'].
    :return: numpy array of shape (columns of the OFR derivative matrix, 1),
        i.e. the projected gradient (presumably one row per atom)
    """
    if ignoreNames is None:
        # derivatives we don't use
        ignoreNames = ['f.dI.0/dI0.all']
    # truncate the model gradient to the RAC features, then rescale it
    racTarget = targetGrad[featureSlice]
    racTarget = racTarget / np.sqrt(train_var_x[featureSlice]).squeeze()
    # truncate the OFR gradient to the rows the model uses
    racGrad = gradientPrune(OFR, ignoreNames)
    # inner product of the pruned derivatives with the gradient column
    targetColumn = np.reshape(racTarget, (racTarget.shape[0], 1))
    return np.dot(np.transpose(racGrad), targetColumn)
    
    
    
def makeSVGFromRdMol(rdmol, grad, makeAbs=True):
    """Render an RDKit molecule as SVG with every atom coloured by a gradient.

    :param: rdmol: RDKit molecule. 2D coordinates are computed on it in
        place via ``Compute2DCoords``.
    :param: grad: per-atom gradient values (e.g. the output of
        projectGradient). Presumably ordered by atom index; confirm against
        the caller.
    :param: makeAbs: if True, colour by gradient magnitude
        (normalizeAbsGrad); otherwise use the signed colour map
        (normalizeGrad)
    :return: str, SVG text with the ``svg:`` namespace prefixes removed
    """
    highlight = list(range(rdmol.GetNumAtoms()))
    # Both branches previously called normalizeGrad, so makeAbs had no effect.
    # Pick the normaliser the same way makeCairoFromRdMol does.
    normalizer = normalizeAbsGrad if makeAbs else normalizeGrad
    colors = dict(zip(highlight, normalizer(grad)))

    rdmol.Compute2DCoords()
    drawer = rdMolDraw2D.MolDraw2DSVG(800, 400)
    drawer.drawOptions().useBWAtomPalette()
    drawer.DrawMolecule(rdmol, highlightAtoms=highlight,
                        highlightAtomColors=colors, highlightBonds=[])
    drawer.FinishDrawing()
    return drawer.GetDrawingText().replace('svg:', '')
    
def makeCairoFromRdMol(rdmol, grad, makeAbs=True):
    """Render an RDKit molecule with the Cairo backend, atoms coloured by a gradient.

    :param: rdmol: RDKit molecule. 2D coordinates are computed on it in
        place via ``Compute2DCoords``.
    :param: grad: per-atom gradient values. Presumably ordered by atom
        index; confirm against the caller.
    :param: makeAbs: if True, colour by gradient magnitude
        (normalizeAbsGrad); otherwise use the signed colour map
        (normalizeGrad)
    :return: whatever the Cairo drawer's ``GetDrawingText`` yields
        (presumably PNG image data; confirm against the RDKit version in use)
    """
    atomCount = rdmol.GetNumAtoms()
    atomIndices = range(atomCount)
    normalizer = normalizeAbsGrad if makeAbs else normalizeGrad
    atomColours = {idx: rgb for idx, rgb in zip(atomIndices, normalizer(grad))}

    rdmol.Compute2DCoords()
    canvas = rdMolDraw2D.MolDraw2DCairo(800, 400)
    canvas.drawOptions().useBWAtomPalette()
    canvas.DrawMolecule(
        rdmol,
        highlightAtoms=atomIndices,
        highlightAtomColors=atomColours,
        highlightBonds=[],
    )
    canvas.FinishDrawing()
    return canvas.GetDrawingText()
