
import sys
import os
import numpy as np
import math
import random 
from ofrcalc.classes.atom3D import *
from ofrcalc.classes.mol3D import *
from ofrcalc.classes.globalvars import globalvars



########### UNIT CONVERSION
# Energy conversion factor to kcal/mol. The value 627.503 matches
# 1 Hartree in kcal/mol (presumably what "HF" in the name denotes).
HF_to_Kcal_mol = 627.503


def autocorrelation(mol, prop_vec, orig, d, oct=True, catoms=None):
    """Return the autocorrelation of ``prop_vec`` rooted at a single atom.

    Element ``k`` of the result is ``prop_vec[orig] * prop_vec[j]`` summed
    over every atom ``j`` first reached ``k`` bond hops away from ``orig``
    (a breadth-first search over ``mol.getBondedAtomsSmart``).  Element 0
    is therefore ``prop_vec[orig] ** 2``.

    Inputs:
        mol - mol3D class; only ``getBondedAtomsSmart`` is used here
        prop_vec - vector, property of atoms in mol in order of index
        orig - int, zero-indexed starting atom
        d - int, number of hops to travel
        oct - bool, if complex is octahedral, will use better bond checks
        catoms - unused; accepted so existing callers keep working
    Returns:
        numpy array of length d + 1; shells beyond the reachable graph are 0
    """
    result_vector = np.zeros(d + 1)
    result_vector[0] = prop_vec[orig] * prop_vec[orig]
    active_set = {orig}        # atoms in the current shell
    historical_set = set()     # atoms in all shells closer than active_set
    for hopped in range(1, d + 1):
        new_active_set = set()
        for this_atom in active_set:
            for bound_atom in mol.getBondedAtomsSmart(this_atom, oct=oct):
                if bound_atom not in historical_set and bound_atom not in active_set:
                    new_active_set.add(bound_atom)
        for inds in new_active_set:
            result_vector[hopped] += prop_vec[orig] * prop_vec[inds]
        # The current shell becomes "visited" exactly once per hop.
        historical_set.update(active_set)
        if not new_active_set:
            # Nothing further is reachable; remaining entries stay 0.
            break
        active_set = new_active_set
    return result_vector

def autocorrelation_derivative(mol, prop_vec, orig, d, oct=True, catoms=None):
    """Return the gradient of the single-atom autocorrelation w.r.t. ``prop_vec``.

    Row ``k`` of the result is the derivative of element ``k`` of
    ``autocorrelation(mol, prop_vec, orig, d)`` with respect to every entry
    of ``prop_vec``:
      * row 0: d(p_orig^2) -> ``2 * p_orig`` in column ``orig``, 0 elsewhere;
      * row k > 0: the element is ``sum_j p_orig * p_j`` over the atoms ``j``
        first reached k hops from ``orig``.  Column ``orig`` gets
        ``sum_j p_j`` and each column ``j`` gets ``p_orig``.

    Inputs:
        mol - mol3D class; only ``getBondedAtomsSmart`` is used here
        prop_vec - vector, property of atoms in mol in order of index
        orig - int, zero-indexed starting atom
        d - int, number of hops to travel
        oct - bool, if complex is octahedral, will use better bond checks
        catoms - unused; accepted so existing callers keep working
    Returns:
        numpy array of shape (d + 1, len(prop_vec))
    """
    derivative_mat = np.zeros((d + 1, len(prop_vec)))
    derivative_mat[0, orig] = 2 * prop_vec[orig]
    active_set = {orig}        # atoms in the current shell
    historical_set = set()     # atoms in all shells closer than active_set
    for hopped in range(1, d + 1):
        new_active_set = set()
        for this_atom in active_set:
            for bound_atom in mol.getBondedAtomsSmart(this_atom, oct=oct):
                if bound_atom not in historical_set and bound_atom not in active_set:
                    new_active_set.add(bound_atom)
        # orig is never in a shell with hopped >= 1 (it is excluded via
        # active_set at hop 1 and historical_set afterwards), so these two
        # updates never hit the same column.
        for inds in new_active_set:
            derivative_mat[hopped, orig] += prop_vec[inds]
            derivative_mat[hopped, inds] += prop_vec[orig]
        historical_set.update(active_set)
        if not new_active_set:
            # Nothing further is reachable; remaining rows stay 0.
            break
        active_set = new_active_set
    return derivative_mat
	





def full_autocorrelation(mol, prop, d, oct=True, modifier=False):
    """Sum the single-atom autocorrelations of ``prop`` over every atom in ``mol``.

    Inputs:
        mol - mol3D class
        prop - str, property name accepted by construct_property_vector
        d - int, number of hops to travel
        oct - bool, if complex is octahedral, will use better bond checks
              (defaults to an explicit True; ``oct=oct`` would have bound the
              builtin ``oct`` function as the default)
        modifier - dict or False, forwarded to construct_property_vector
    Returns:
        numpy array of length d + 1
    """
    w = construct_property_vector(mol, prop, oct=oct, modifier=modifier)
    autocorrelation_vector = np.zeros(d + 1)
    for center in range(mol.natoms):
        autocorrelation_vector += autocorrelation(mol, w, center, d, oct=oct)
    return autocorrelation_vector

def full_autocorrelation_derivative(mol, prop, d, oct=True, modifier=False):
    """Sum the single-atom autocorrelation gradients of ``prop`` over all atoms.

    Inputs:
        mol - mol3D class
        prop - str, property name accepted by construct_property_vector
        d - int, number of hops to travel
        oct - bool, if complex is octahedral, will use better bond checks
              (defaults to an explicit True; ``oct=oct`` would have bound the
              builtin ``oct`` function as the default)
        modifier - dict or False, forwarded to construct_property_vector
    Returns:
        numpy array of shape (d + 1, mol.natoms); row k is the gradient of
        the k-th full autocorrelation element w.r.t. each atom's property
    """
    w = construct_property_vector(mol, prop, oct=oct, modifier=modifier)
    autocorrelation_derivative_mat = np.zeros((d + 1, mol.natoms))
    for center in range(mol.natoms):
        autocorrelation_derivative_mat += autocorrelation_derivative(mol, w, center, d, oct=oct)
    return autocorrelation_derivative_mat

def atom_only_autocorrelation(mol, prop, d, atomIdx, oct=True, modifier=False):
    """Autocorrelation of ``prop`` rooted at one atom, or averaged over several.

    Inputs:
        mol - mol3D class
        prop - str, property name accepted by construct_property_vector
        d - int, number of hops to travel
        atomIdx - int, or a sequence of ints; for a sequence the per-atom
                  vectors are averaged (an empty sequence yields NaNs, 0/0)
        oct - bool, if complex is octahedral, will use better bond checks
        modifier - dict or False, forwarded to construct_property_vector
                   (required there for 'ox_nuclear_charge'); default: none
    Returns:
        numpy array of length d + 1
    """
    w = construct_property_vector(mol, prop, oct=oct, modifier=modifier)
    autocorrelation_vector = np.zeros(d + 1)
    if hasattr(atomIdx, "__len__"):
        for elements in atomIdx:
            autocorrelation_vector += autocorrelation(mol, w, elements, d, oct=oct)
        autocorrelation_vector = np.divide(autocorrelation_vector, len(atomIdx))
    else:
        autocorrelation_vector += autocorrelation(mol, w, atomIdx, d, oct=oct)
    return autocorrelation_vector

def atom_only_autocorrelation_derivative(mol, prop, d, atomIdx, oct=True, modifier=False):
    """Gradient of the atom-rooted autocorrelation w.r.t. every atom's property.

    Inputs:
        mol - mol3D class
        prop - str, property name accepted by construct_property_vector
        d - int, number of hops to travel
        atomIdx - int, or a sequence of ints; for a sequence the per-atom
                  matrices are averaged (an empty sequence yields NaNs, 0/0)
        oct - bool, if complex is octahedral, will use better bond checks
        modifier - dict or False, forwarded to construct_property_vector
                   (required there for 'ox_nuclear_charge'); default: none
    Returns:
        numpy array of shape (d + 1, mol.natoms)
    """
    w = construct_property_vector(mol, prop, oct=oct, modifier=modifier)
    autocorrelation_derivative_mat = np.zeros((d + 1, mol.natoms))
    if hasattr(atomIdx, "__len__"):
        for elements in atomIdx:
            autocorrelation_derivative_mat += autocorrelation_derivative(mol, w, elements, d, oct=oct)
        autocorrelation_derivative_mat = np.divide(autocorrelation_derivative_mat, len(atomIdx))
    else:
        autocorrelation_derivative_mat += autocorrelation_derivative(mol, w, atomIdx, d, oct=oct)
    return autocorrelation_derivative_mat
    

def atom_only_deltametric(mol, prop, d, atomIdx, oct=True,modifier=False):
    """Deltametric of ``prop`` rooted at one atom, or averaged over several.

    Inputs:
        mol - mol3D class
        prop - str, property name accepted by construct_property_vector
        d - int, number of hops to travel
        atomIdx - int, or a sequence of ints whose per-atom vectors are averaged
        oct - bool, if complex is octahedral, will use better bond checks
        modifier - dict or False, forwarded to construct_property_vector
    Returns:
        numpy array of length d + 1 accumulating the output of
        ``deltametric`` (defined outside this view)
    """
    weights = construct_property_vector(mol, prop, oct=oct, modifier=modifier)
    accumulated = np.zeros(d + 1)
    is_sequence = hasattr(atomIdx, "__len__")
    starts = atomIdx if is_sequence else [atomIdx]
    for start in starts:
        accumulated += deltametric(mol, weights, start, d, oct=oct)
    if is_sequence:
        # Average over all requested starting atoms.
        accumulated = np.divide(accumulated, len(atomIdx))
    return accumulated

def atom_only_deltametric_derivative(mol, prop, d, atomIdx, oct=True,modifier=False):
    """Derivative matrix of the atom-rooted deltametric for ``prop``.

    Inputs:
        mol - mol3D class
        prop - str, property name accepted by construct_property_vector
        d - int, number of hops to travel
        atomIdx - int, or a sequence of ints whose per-atom matrices are averaged
        oct - bool, if complex is octahedral, will use better bond checks
        modifier - dict or False, forwarded to construct_property_vector
    Returns:
        numpy array of shape (d + 1, mol.natoms) accumulating the output of
        ``deltametric_derivative`` (defined outside this view)
    """
    weights = construct_property_vector(mol, prop, oct=oct, modifier=modifier)
    accumulated = np.zeros((d + 1, mol.natoms))
    is_sequence = hasattr(atomIdx, "__len__")
    starts = atomIdx if is_sequence else [atomIdx]
    for start in starts:
        accumulated += deltametric_derivative(mol, weights, start, d, oct=oct)
    if is_sequence:
        # Average over all requested starting atoms.
        accumulated = np.divide(accumulated, len(atomIdx))
    return accumulated



def construct_property_vector(mol, prop, oct=True, modifier=False):
    """Build a per-atom property vector for ``mol``.

    Position ``i`` of the returned vector holds the value of ``prop`` for atom
    ``i`` (zero-indexed) of ``mol``; it can be used to create weighted graph
    representations (e.g. as input to the autocorrelation functions).

    Inputs:
        mol - mol3D class (uses ``natoms``, ``getAtoms()``,
              ``getBondedAtomsSmart()`` and each atom's ``symbol()``)
        prop - str, one of 'electronegativity', 'nuclear_charge', 'ident',
               'topology', 'ox_nuclear_charge', 'size', 'vdwrad',
               'effective_nuclear_charge'.
               'ident' codes every atom as 1 (a purely topological index);
               'topology' is the number of atoms bonded to atom i (similar
               to a Randic index).
        oct - bool, if complex is octahedral, will use better bond checks
        modifier - dict or False, element symbol -> value added to the
                   tabulated property, e.g. {"Fe": 2, "Co": 3}. Required for
                   'ox_nuclear_charge', optional for 'effective_nuclear_charge',
                   ignored for every other property.
    Returns:
        numpy array of length mol.natoms, or False (after printing an error)
        when ``prop`` is not recognised or 'ox_nuclear_charge' is requested
        without a modifier.
    """
    allowed_strings = ['electronegativity', 'nuclear_charge', 'ident', 'topology',
                       'ox_nuclear_charge', 'size', 'vdwrad', 'effective_nuclear_charge']
    # Validate before building globalvars or touching mol, so a bad request
    # fails fast and does not need a usable molecule.
    if prop not in allowed_strings:
        print('error, property  ' + str(prop) + ' is not a vaild choice')
        print(' options are  ' + str(allowed_strings))
        return False
    if prop == 'ox_nuclear_charge' and not modifier:
        print('Error, must give modifier with ox_nuclear_charge')
        return False

    globs = globalvars()
    w = np.zeros(mol.natoms)

    if prop == 'topology':
        # Number of bonded neighbours of each atom.
        for i, _ in enumerate(mol.getAtoms()):
            w[i] = len(mol.getBondedAtomsSmart(i, oct=oct))
        return w

    if prop == 'vdwrad':
        vdw_radii = globs.vdwrad()
        metals = globs.metalslist()
        amass = globs.amass()
        for i, atom in enumerate(mol.getAtoms()):
            symbol = atom.symbol()
            # Metals take amass column 2 (the value 'size' uses) instead of
            # an entry from the vdW-radius table.
            w[i] = amass[symbol][2] if symbol in metals else vdw_radii[symbol]
        return w

    if prop == 'electronegativity':
        prop_dict = globs.endict()
    elif prop == 'ident':
        prop_dict = {symbol: 1 for symbol in globs.amass().keys()}
    else:
        # The remaining properties each read one column of the amass table.
        column = {'nuclear_charge': 1, 'ox_nuclear_charge': 1,
                  'size': 2, 'effective_nuclear_charge': 3}[prop]
        # Only the two charge-based properties honour ``modifier``.
        if modifier and prop in ('ox_nuclear_charge', 'effective_nuclear_charge'):
            offsets = modifier
        else:
            offsets = {}
        amass = globs.amass()  # fetched once rather than once per element
        prop_dict = {}
        for symbol in amass.keys():
            value = amass[symbol][column]
            if symbol in offsets:
                value += float(offsets[symbol])
            prop_dict[symbol] = value

    for i, atom in enumerate(mol.getAtoms()):
        w[i] = prop_dict[atom.symbol()]
    return w





def generate_full_complex_autocorrelations(mol, loud, depth=4, oct=True, flag_name=False, modifier=False):
    """Whole-complex autocorrelations for the six standard atomic properties.

    Inputs:
        mol - mol3D class
        loud - bool, print output (not used in this function)
        depth - int, number of hops to travel
        oct - bool, if complex is octahedral, will use better bond checks
        flag_name - bool, store results under 'results_f_all' instead of 'results'
        modifier - dict or False, forwarded to full_autocorrelation
    Returns:
        dict with 'colnames' (one list of labels such as 'chi-0' per property)
        and the results key (one numpy array of length depth + 1 per property)
    """
    descriptor_specs = (
        ('electronegativity', 'chi'),
        ('nuclear_charge', 'Z'),
        ('ident', 'I'),
        ('topology', 'T'),
        ('size', 'S'),
        ('effective_nuclear_charge', 'Zeff'),
    )
    colnames = []
    result = []
    for prop_name, label in descriptor_specs:
        result.append(full_autocorrelation(mol, prop_name, depth, oct=oct, modifier=modifier))
        colnames.append([label + '-' + str(hop) for hop in range(depth + 1)])
    results_key = 'results_f_all' if flag_name else 'results'
    return {'colnames': colnames, results_key: result}

def generate_full_complex_autocorrelation_derivatives(mol, loud, depth=4, oct=True, flag_name=False, modifier=False):
    """Gradients of the whole-complex autocorrelations for six properties.

    Inputs:
        mol - mol3D class
        loud - bool, print output (not used in this function)
        depth - int, number of hops to travel
        oct - bool, if complex is octahedral, will use better bond checks
        flag_name - bool, store results under 'results_f_all' instead of 'results'
        modifier - dict or False, forwarded to full_autocorrelation_derivative
    Returns:
        dict with
          'colnames' - one list per (property, depth) row, each holding
                       mol.natoms labels like 'dchi-0/dchi3';
          results key - numpy array of shape (6 * (depth + 1), mol.natoms):
                        the per-property (depth + 1, natoms) matrices stacked
                        in the order of ``allowed_strings``.
    """
    allowed_strings = ['electronegativity', 'nuclear_charge', 'ident', 'topology', 'size', 'effective_nuclear_charge']
    labels_strings = ['chi', 'Z', 'I', 'T', 'S', 'Zeff']
    colnames = []
    blocks = []
    for prop, label in zip(allowed_strings, labels_strings):
        blocks.append(full_autocorrelation_derivative(mol, prop, depth, oct=oct, modifier=modifier))
        for i in range(depth + 1):
            colnames.append(['d' + label + '-' + str(i) + '/d' + label + str(j)
                             for j in range(mol.natoms)])
    # One np.vstack call replaces the deprecated np.row_stack alias and
    # avoids re-copying the growing result on every property.
    result = np.vstack(blocks)
    if flag_name:
        results_dictionary = {'colnames': colnames, 'results_f_all': result}
    else:
        results_dictionary = {'colnames': colnames, 'results': result}
    return results_dictionary


def generate_atomonly_autocorrelations(mol, atomIdx, loud, depth=4, oct=True):
    """Atom-rooted autocorrelations for the six standard atomic properties.

    Inputs:
        mol - mol3D class
        atomIdx - int index of the starting atom, or a sequence of indices
                  (atom_only_autocorrelation averages over them)
        loud - bool, print output (not used in this function)
        depth - int, number of hops to travel
        oct - bool, if complex is octahedral, will use better bond checks
    Returns:
        dict with 'colnames' (one list of labels such as 'chi-0' per property)
        and 'results' (one numpy array of length depth + 1 per property)
    """
    descriptor_specs = (
        ('electronegativity', 'chi'),
        ('nuclear_charge', 'Z'),
        ('ident', 'I'),
        ('topology', 'T'),
        ('size', 'S'),
        ('effective_nuclear_charge', 'Zeff'),
    )
    colnames = []
    result = []
    for prop_name, label in descriptor_specs:
        result.append(atom_only_autocorrelation(mol, prop_name, depth, atomIdx, oct=oct))
        colnames.append([label + '-' + str(hop) for hop in range(depth + 1)])
    return {'colnames': colnames, 'results': result}
    
def generate_atomonly_autocorrelation_derivatives(mol, atomIdx, loud, depth=4, oct=True):
    """Gradients of the atom-rooted autocorrelations for six properties.

    Inputs:
        mol - mol3D class
        atomIdx - int index of the starting atom, or a sequence of indices
                  (atom_only_autocorrelation_derivative averages over them)
        loud - bool, print output (not used in this function)
        depth - int, number of hops to travel
        oct - bool, if complex is octahedral, will use better bond checks
    Returns:
        dict with
          'colnames' - one list per (property, depth) row, each holding
                       mol.natoms labels like 'dchi-0/dchi3';
          'results'  - numpy array with the per-property (depth + 1, natoms)
                       matrices stacked in the order of ``allowed_strings``.
    """
    allowed_strings = ['electronegativity', 'nuclear_charge', 'ident', 'topology', 'size', 'effective_nuclear_charge']
    labels_strings = ['chi', 'Z', 'I', 'T', 'S', 'Zeff']
    colnames = []
    blocks = []
    for prop, label in zip(allowed_strings, labels_strings):
        blocks.append(atom_only_autocorrelation_derivative(mol, prop, depth, atomIdx, oct=oct))
        for i in range(depth + 1):
            colnames.append(['d' + label + '-' + str(i) + '/d' + label + str(j)
                             for j in range(mol.natoms)])
    # One np.vstack call replaces the deprecated np.row_stack alias and
    # avoids re-copying the growing result on every property.
    result = np.vstack(blocks)
    results_dictionary = {'colnames': colnames, 'results': result}
    return results_dictionary

def generate_atomonly_deltametrics(mol, atomIdx, loud, depth=4, oct=True):
    """Atom-rooted deltametrics for the six standard atomic properties.

    Inputs:
        mol - mol3D class
        atomIdx - int index of the starting atom, or a sequence of indices
                  (atom_only_deltametric averages over them)
        loud - bool, print output (not used in this function)
        depth - int, number of hops to travel
        oct - bool, if complex is octahedral, will use better bond checks
    Returns:
        dict with 'colnames' (one list of labels such as 'chi-0' per property)
        and 'results' (one length depth + 1 array per property)
    """
    descriptor_specs = (
        ('electronegativity', 'chi'),
        ('nuclear_charge', 'Z'),
        ('ident', 'I'),
        ('topology', 'T'),
        ('size', 'S'),
        ('effective_nuclear_charge', 'Zeff'),
    )
    colnames = []
    result = []
    for prop_name, label in descriptor_specs:
        result.append(atom_only_deltametric(mol, prop_name, depth, atomIdx, oct=oct))
        colnames.append([label + '-' + str(hop) for hop in range(depth + 1)])
    return {'colnames': colnames, 'results': result}
    
def generate_atomonly_deltametric_derivatives(mol, atomIdx, loud, depth=4, oct=True):
    """Derivatives of the atom-rooted deltametrics for six properties.

    Inputs:
        mol - mol3D class
        atomIdx - int index of the starting atom, or a sequence of indices
                  (atom_only_deltametric_derivative averages over them)
        loud - bool, print output (not used in this function)
        depth - int, number of hops to travel
        oct - bool, if complex is octahedral, will use better bond checks
    Returns:
        dict with
          'colnames' - one list per (property, depth) row, each holding
                       mol.natoms labels like 'dchi-0/dchi3';
          'results'  - numpy array with the per-property (depth + 1, natoms)
                       matrices stacked in the order of ``allowed_strings``.
    """
    allowed_strings = ['electronegativity', 'nuclear_charge', 'ident', 'topology', 'size', 'effective_nuclear_charge']
    labels_strings = ['chi', 'Z', 'I', 'T', 'S', 'Zeff']
    colnames = []
    blocks = []
    for prop, label in zip(allowed_strings, labels_strings):
        blocks.append(atom_only_deltametric_derivative(mol, prop, depth, atomIdx, oct=oct))
        for i in range(depth + 1):
            colnames.append(['d' + label + '-' + str(i) + '/d' + label + str(j)
                             for j in range(mol.natoms)])
    # One np.vstack call replaces the deprecated np.row_stack alias and
    # avoids re-copying the growing result on every property.
    result = np.vstack(blocks)
    results_dictionary = {'colnames': colnames, 'results': result}
    return results_dictionary
