
from ofrcalc.routines.autocorrelation import*
import time
from rdkit import Chem


class quickOFR:
    """
    Container that accumulates a descriptor vector (and, optionally, its
    derivatives) for a single OFR molecule.

    Descriptors and their names are appended in lock-step to
    ``self.descriptors`` / ``self.descriptorNames`` by the ``get*`` methods,
    so the order in which those methods are called defines the order of the
    final feature vector.
    """

    # Characters removed by filt_out before pattern matching on a SMILES
    # fragment (bond symbols, brackets, ring-closure digits 1/2, dot, slash).
    _FILTERED_CHARS = frozenset(['=', '-', '(', ')', '[', ']', '1', '2', '.', '/'])

    def __init__(self,name,mol,conIdx):
        """
        constructor for OFR class
        :param: name str name of ofr
        :param: mol  mol3D object for ofr
        :param: conIdx  list of int, connect atom indices
        """
        self.name = name
        self.mol = mol
        self.conIdx = conIdx
        # False until setRdMolFromSmiles is called
        self.smiles = False
        ## descriptors
        self.set_desc = False
        self.descriptors = list()
        self.descriptorNames = list()
        # None until the first derivative block is appended
        self.descriptorDerivatives = None
        self.descriptorDerivativeNames = list()
        # False until setRdMolFromSmiles is called
        self.rdmol = False

    def setRdMolFromSmiles(self, smiles):
        """
        stores the SMILES string and the RDKit mol object built from it
        :param: smiles str SMILES string for OFR
        """
        # keep the raw SMILES, getSidechainCounts parses it directly
        self.smiles = smiles
        # keep the rdmol for substructure queries / drawing later
        self.rdmol = Chem.MolFromSmiles(self.smiles)

    def getDescriptorVector(self,loud=False,name=False):
        """
        computes the molecular AC fingerprint and appends it to the
        descriptor vector
        :param: loud bool print debug info
        :param: name unused, kept so existing callers keep working
        """
        ## full ACs
        resultsDictionary = generate_full_complex_autocorrelations(self.mol,depth=4,loud=loud,oct=False)
        self.appendDescriptors(resultsDictionary['colnames'],resultsDictionary['results'],'f','all')
        ## connection atom ACs
        resultsDictionary = generate_atomonly_autocorrelations(self.mol,self.conIdx,depth=5,loud=loud,oct=False)
        self.appendDescriptors(resultsDictionary['colnames'],resultsDictionary['results'],'con','all')

    @staticmethod
    def _isScalar(item):
        """
        True when item must be treated as one entry rather than a collection.
        Strings are iterable on Python 3, so they are singled out explicitly;
        otherwise a plain feature name would be split into characters.
        """
        return isinstance(item, (str, bytes)) or not hasattr(item, '__iter__')

    @staticmethod
    def _formatName(prefix,rawName,suffix):
        """
        builds the 'prefix.name.suffix' label, with '-' in name turned to '.'
        """
        return ".".join([prefix,str(rawName).replace('-','.'),suffix])

    def appendDescriptors(self,list_of_names,list_of_props,prefix,suffix):
        """
        adds the AC molecular fingerprints to the object
        :param: list_of_names list of names (or lists of names) of features
        :param: list_of_props list of descriptors (or lists of descriptors)
        :param: prefix str prefix for AC names
        :param: suffix str suffix for AC names
        """
        for names in list_of_names:
            if self._isScalar(names):
                self.descriptorNames.append(self._formatName(prefix,names,suffix))
            else:
                self.descriptorNames += [self._formatName(prefix,i,suffix) for i in names]
        for values in list_of_props:
            if self._isScalar(values):
                self.descriptors.append(values)
            else:
                self.descriptors.extend(values)

    def getDescriptorVectorDerivatives(self,loud=False):
        """
        computes molecular AC fingerprint derivatives and appends them to
        the derivative matrix
        :param: loud bool print debug info
        """
        ## full ACs
        resultsDictionary = generate_full_complex_autocorrelation_derivatives(self.mol,depth=4,loud=loud,oct=False)
        self.appendDescriptorDerivatives(resultsDictionary['colnames'],resultsDictionary['results'],'f','all')
        ## connection atom ACs
        resultsDictionary = generate_atomonly_autocorrelation_derivatives(self.mol,self.conIdx,depth=5,loud=loud,oct=False)
        self.appendDescriptorDerivatives(resultsDictionary['colnames'],resultsDictionary['results'],'con','all')

    def appendDescriptorDerivatives(self,mat_of_names,dmat,prefix,suffix):
        """
        adds the AC molecular fingerprint gradients to the object
        :param: mat_of_names list of list of names of features
        :param: dmat np array of descriptors
        :param: prefix str prefix for AC names
        :param: suffix str suffix for AC names
        """
        for names in mat_of_names:
            self.descriptorDerivativeNames.append(
                [self._formatName(prefix,i,suffix) for i in names])
        if self.descriptorDerivatives is None:
            self.descriptorDerivatives = dmat
        else:
            # np.vstack is the supported spelling; np.row_stack is a
            # deprecated alias of it in NumPy 2.0.
            # NOTE(review): np is not imported by name in this file;
            # presumably it arrives through the star import - confirm.
            self.descriptorDerivatives = np.vstack([self.descriptorDerivatives,dmat])

    def getBondCounts(self):
        """
        Counts the matches of the SMARTS patterns 'C=C' and 'C=O' in the
        RDKit mol and appends both counts to the descriptor vector as
        'db_count' and 'co_count'. Requires setRdMolFromSmiles first.
        """
        dbBondProbe = Chem.MolFromSmarts('C=C')
        coBondProbe = Chem.MolFromSmarts('C=O')

        # one match tuple is returned per substructure hit
        dbCount = len(self.rdmol.GetSubstructMatches(dbBondProbe))
        coCount = len(self.rdmol.GetSubstructMatches(coBondProbe))
        # add to vector
        self.descriptors.extend([dbCount,coCount])
        self.descriptorNames.extend(['db_count','co_count'])

    def filt_out(self,mystr):
        """
        Filter out the special characters for recognizing specific
        patterns
        :param: mystr str text to clean (e.g. a SMILES fragment)
        :return: str mystr without any of = - ( ) [ ] 1 2 . /
        """
        return "".join(ch for ch in mystr if ch not in self._FILTERED_CHARS)

    @staticmethod
    def _branchSubstrings(smiles):
        """
        Returns, for every '(' in smiles, the text between it and the next
        ')'. Nested branches are therefore cut at the first ')' that follows,
        and an inner '(' starts a branch of its own. An unmatched '(' yields
        the rest of the string.
        """
        branches = []
        for start, char in enumerate(smiles):
            if char != '(':
                continue
            end = smiles.find(')', start + 1)
            if end == -1:
                end = len(smiles)
            branches.append(smiles[start + 1:end])
        return branches

    def getSidechainCounts(self):
        """
        Counts the side chains in the SMILES string and appends the count to
        the descriptor vector as 'Number.of.tail.side.chains'.

        If the RDKit mol has at least one ring the count is fixed to 1.
        Otherwise each parenthesised branch of the SMILES is one side chain,
        and branches containing the character 'O' are not counted.
        Requires setRdMolFromSmiles first.
        """
        if self.rdmol.GetRingInfo().NumRings() > 0:
            n_chain = 1
        else:
            # Each branch is examined on its own text, so an oxygen in one
            # branch does not disqualify the branches after it.
            n_chain = sum(1 for branch in self._branchSubstrings(self.smiles)
                          if 'O' not in branch)

        self.descriptors.extend([n_chain])
        self.descriptorNames.extend(['Number.of.tail.side.chains'])
