import numpy as np
from scipy.spatial.distance import cdist

def DEM_step(locMMin, radiiMM, k=0.5):
    """
    Lightweight DEM with assumption of same size spheres as mechanical regularisation

    Every pair of spheres whose centres are closer than one diameter is pushed
    apart along its branch vector by an equal and opposite displacement.

    Parameters
    ----------
        locMMin : Nx3 2D numpy array of floats
            xyz positions of spheres in mm, with the origin being the middle of the detector.
            This array is not modified.

        radiiMM : 1D numpy array of floats
            Particle radii for projection.
            Only the first value is used; a warning is printed if the radii differ.

        k : float, optional
            Stiffness and timestep wrapped into one
            Default = 0.5

    Returns
    -------
        locMM : Nx3 2D numpy array of floats
            Updated copy of the positions after one relaxation step

    Notes
    -----
        The overlaps are computed once from the input positions, whereas the branch
        vector of each pair is taken from the current (already updated) positions.
        The branch vector is not normalised, so the displacement of a pair is
        -k * overlap * (locMM[i] - locMM[j]).
    """
    # Work on a copy so the caller's array is never modified
    locMM = np.array(locMMin)
    if locMM.dtype.kind in "biu":
        # A float displacement cannot be added in-place to a boolean/integer array
        locMM = locMM.astype(float)

    nParticles = len(locMM)
    if nParticles == 0:
        # Nothing to move, and radiiMM.min() below would fail on empty radii
        return locMM

    radiiMM = np.asarray(radiiMM)
    if radiiMM.min() != radiiMM.max():
        print("DEM.DEM_step(): WARINING I assume all radii are the same, taking first one")

    # Signed gap between sphere surfaces, negative means the spheres overlap
    delta = cdist(locMM, locMM) - 2 * radiiMM[0]  # assuming all radii the same

    # Strict upper triangle: each pair once with i < j, self-distances excluded.
    # nonzero() returns the indices in row-major order, i.e. the same order as a
    # nested "for i ... for j in range(i+1, N)" loop, which matters because the
    # updates below are sequential.
    contactsI, contactsJ = np.nonzero(np.triu(delta < 0, 1))
    for i, j in zip(contactsI, contactsJ):
        branch_vector = locMM[i] - locMM[j]
        # delta is negative here, so i moves away from j and j away from i
        F = -k * delta[i, j] * branch_vector
        locMM[i] += F
        locMM[j] -= F
    return locMM


if __name__ == '__main__':
    import matplotlib.pyplot as plt

    # Demo: relax a random cloud of equal spheres and plot the x-y positions
    # after every step.
    nSpheres = 50
    boxSizeMM = 20
    nSteps = 100

    # A small hand-made alternative to the random cloud:
    # locMM = np.array([[0,0,0],
    #                   [0,1.8,0],
    #                   [0,0,3],
    #                   [3,0,0]])
    locMM = np.random.rand(nSpheres, 3) * boxSizeMM
    print(len(locMM))
    radiiMM = 1. * np.ones(len(locMM))

    # Interactive mode only needs to be switched on once, not on every step
    plt.ion()
    for _ in range(nSteps):
        locMM = DEM_step(locMM, radiiMM)
        plt.plot(locMM[:, 0], locMM[:, 1], '.')
        # Short pause so the figure is redrawn between steps
        plt.pause(0.01)
