'''
Implementation of tomopack by Stéphane Roux
'''

import os
import numpy
import radioSphere.projectSphere
import scipy.ndimage
from scipy.spatial import distance

import matplotlib.pyplot as plt
import tifffile

# In-plane 3x3 binomial ("Gaussian") smoothing kernel, normalised to sum to 1.
# The leading singleton axis makes a convolution with an (N, rows, cols) stack act
# on each slice independently.  A flat alternative would be numpy.ones((1,3,3))/9.
_kernel = numpy.outer([1, 2, 1], [1, 2, 1])[numpy.newaxis, :, :] / 16.

# Projector used by the non-naive (iterative) approach: snaps estimates to non-negative integers
def _PI_projector(f):
    '''
    Project an indicator-function estimate onto non-negative integers.

    Entries of ``f`` that compare below 0.5 are set to zero. For complex input this
    relies on numpy's ordering of complex values. The modulus of every entry is then
    rounded to the nearest integer.

    Parameters
    ----------
        f : numpy array (real or complex)

    Returns
    -------
        numpy array with the same shape as ``f``, holding rounded moduli
    '''
    belowHalf = f < 0.5
    thresholded = numpy.where(belowHalf, numpy.zeros_like(f), f)
    return numpy.round(numpy.abs(thresholded))

# Bottom of page 8 in TomoPack.pdf: take the peaks in a 3x3 pixel area and weight each one by all of the mass in that area.
# Here the maxima are searched for directly in the local masses (the smoothed |f|) rather than in f itself.
def _filter_maxima(f, debug=False, removeNegatives=False):
    '''
    Keep only the local maxima of the smoothed modulus of ``f``, weighted by their local mass.

    ``|f|`` is convolved with ``_kernel`` to obtain local "masses". A pixel is kept
    when its mass equals the largest mass in its 3x3 in-plane neighbourhood. Its
    output value is then that mass; every other pixel becomes zero.

    Parameters
    ----------
        f : 2D or 3D numpy array (real or complex)
            A 2D image is handled as a single-slice stack.

        debug : bool, optional
            Show the masses of the first slice, and their thresholding at 0.5, with matplotlib.
            Default = False

        removeNegatives : bool, optional
            Zero the modulus wherever ``f < 0`` before smoothing.
            Keeping negatives can help sharpen the result of the convolution.
            Default = False

    Returns
    -------
        numpy array with the same shape as ``f``
    '''
    inputWas2D = (len(f.shape) == 2)
    stack = f[numpy.newaxis, ...] if inputWas2D else f

    # Collapse (possibly complex) values to a real magnitude
    magnitude = numpy.abs(stack)
    if removeNegatives:
        magnitude[stack < 0] = 0

    masses = scipy.ndimage.convolve(magnitude, _kernel)
    isPeak = (masses == scipy.ndimage.maximum_filter(masses, size=(1, 3, 3)))

    if debug:
        debugImages = ((masses[0], 'Local masses (first slice if 3D)'),
                       (masses[0] > 0.5, 'Local masses > 0.5 (first slice if 3D)'))
        for image, title in debugImages:
            plt.imshow(image)
            plt.colorbar()
            plt.title(title)
            plt.show()

    weighted = isPeak * masses
    return weighted[0] if inputWas2D else weighted


def _tomopack_trust_mask(psiMM_FFT, epsilon, kTrustRatio):
    '''
    Select the wavenumbers at which dividing by ``psiMM_FFT`` is trusted.

    Parameters
    ----------
        psiMM_FFT : 2D complex numpy array

        epsilon : float or 'iterative'
            A float is used directly as the cutoff on ``|psiMM_FFT|``.
            'iterative' starts the cutoff at 1e-5 and multiplies it by 1.5 until at most
            ``kTrustRatio`` of the wavenumbers remain trusted.

        kTrustRatio : float
            Only used when ``epsilon == 'iterative'``.

    Returns
    -------
        k_trust : 2D boolean numpy array, True where ``|psiMM_FFT| > epsilon``

    Raises
    ------
        ValueError if ``epsilon`` is a string other than 'iterative'
    '''
    modulus = numpy.abs(psiMM_FFT)
    if isinstance(epsilon, str):
        if epsilon != 'iterative':
            raise ValueError(f"detectSpheres.tomopack(): epsilon must be a float or 'iterative', got {epsilon!r}")
        epsilon = 1e-5
        k_trust = modulus > epsilon
        while k_trust.sum()/k_trust.shape[0]/k_trust.shape[1] > kTrustRatio:
            epsilon *= 1.5
            k_trust = modulus > epsilon
        return k_trust
    return modulus > epsilon


def _tomopack_summary_plot(radioMM, psiMM, radioMM_FFT, psiMM_FFT, k_trust, f_k, f_x):
    '''Show a 3x3 grid of diagnostic images for one tomopack() run (used when GRAPH is set).'''
    plt.figure(figsize=[8,6])

    plt.subplot(331)
    plt.title(r'$p(x)$')
    plt.imshow(radioMM)
    plt.colorbar()

    plt.subplot(332)
    plt.title(r'$\psi(x)$')
    plt.imshow(psiMM, vmin=radioMM.min(), vmax=radioMM.max())
    plt.colorbar()

    plt.subplot(333)
    plt.title(r'True $f(x)$')
    # HACK: no ground-truth indicator function is available in here, so show a flat image
    plt.imshow(numpy.zeros_like(radioMM))

    plt.subplot(334)
    plt.title(r'$|\tilde{p}(k)|$')
    plt.imshow(numpy.abs(radioMM_FFT))
    plt.colorbar()

    plt.subplot(335)
    plt.title(r'$|\tilde{\psi}(k)|$')
    plt.imshow(numpy.abs(psiMM_FFT), vmin=0, vmax=10)
    plt.colorbar()

    # Only the trusted wavenumbers are shown; untrusted ones are masked out with NaN
    with numpy.errstate(divide='ignore'):
        one_on_psi_plot = 1./numpy.abs(psiMM_FFT)
    one_on_psi_plot[~k_trust] = numpy.nan
    plt.subplot(336)
    plt.title(r'$1/|\psi(k)|$')
    plt.imshow(one_on_psi_plot)
    plt.colorbar()

    plt.subplot(337)
    plt.title(r'Estimated $|f(k)|$')
    plt.imshow(numpy.abs(f_k))
    plt.colorbar()

    plt.subplot(338)
    plt.title(r'Estimated $f(x)$')
    plt.imshow(numpy.real(f_x))
    plt.colorbar()

    plt.subplots_adjust(hspace=0.5)
    plt.show()


def tomopack(radioMM, psiMM, maxIterations=50, l=0.1, epsilon='iterative', kTrustRatio=0.75, GRAPH=False):
    '''
    'tomopack' FFT-based algorithm for sphere identification by Stéphane Roux.
    Uses the FFT of psiMM as a structuring element to pick up projected spheres in radioMM,
    under a parallel-projection hypothesis.

    Parameters
    ----------
        radioMM : 2D numpy array of floats
            radiography containing "distance" projections in mm
            This is typically the result of mu*log(I/I_0)

        psiMM : 2D numpy array of floats
            A (synthetic) projection of a single sphere in approximately the same settings
            as radioMM -- this is the structuring element used for FFT recognition.
            Must be the same size as radioMM

        maxIterations : int (optional)
            Iteration cap. The loop stops as soon as two successive estimates differ by
            at most 1e-8 (L2 norm), or after maxIterations+1 updates, whichever is first.
            Default = 50

        l : float (optional)
            "lambda" parameter controlling the relaxation towards the integer
            projection in each iteration.
            Default = 0.1

        epsilon : float (optional), or 'iterative'
            Trust cutoff on |FFT(psi)|: only wavenumbers above it take the value of the
            naive deconvolution. If the radiograph is in mm so is this parameter.
            'iterative' grows the cutoff from 1e-5 until at most kTrustRatio of the
            wavenumbers are trusted.
            Default = 'iterative'

        kTrustRatio : float (optional)
            Target fraction of trusted wavenumbers when epsilon='iterative'.
            Default = 0.75

        GRAPH : bool (optional)
            VERY noisily make graphs?

    Returns
    -------
        f_x : 2D numpy array
            Real part of the estimated indicator function, same size as radioMM,
            high values where we think particles are.
            Consider using indicatorFunctionToDetectorPositions()
    '''
    assert(len(psiMM.shape) == 2), "detectSpheres.tomopack(): psiMM should be a 2D array"
    assert(psiMM.shape == radioMM.shape), "detectSpheres.tomopack(): psi and radioMM should be same size"

    radioMM_FFT = numpy.fft.fft2(radioMM)

    # FFT of psi, the structuring element. psi is phase-shifted (fftshift) so that zero
    # phase sits at the centre of the detector. f_x then comes out directly in detector
    # coordinates, with no need to shift it afterwards.
    psiMM_FFT = numpy.fft.fft2(numpy.fft.fftshift(psiMM))

    # Naive deconvolution p(k)/psi(k). It is inf/nan where psi(k) == 0; with epsilon >= 0
    # those wavenumbers are excluded from k_trust and never used.
    with numpy.errstate(divide='ignore', invalid='ignore'):
        f_k_naiive = radioMM_FFT/psiMM_FFT

    k_trust = _tomopack_trust_mask(psiMM_FFT, epsilon, kTrustRatio)
    if GRAPH: print("INFO: k_trust maintains {}% of wavenumubers".format(100*k_trust.sum()/k_trust.shape[0]/k_trust.shape[1]))
    if GRAPH: plt.ion()

    # Iterations as per Algorithm 1 in TomoPack -- objective: a good indicator function f_x.
    # The body always runs at least once. A pre-checked loop would skip every update
    # whenever the starting guess is already within 1e-8 of zero (e.g. epsilon=0), and
    # would then return zeros.
    f_x = numpy.zeros_like(radioMM)
    count = 0
    while True:
        if GRAPH: plt.clf()

        f_x_old = f_x.copy()
        f_k = numpy.fft.fft2(f_x)
        # Impose the data where 1/psi(k) is trusted; keep the current estimate elsewhere
        f_k[k_trust] = f_k_naiive[k_trust]
        f_x = numpy.fft.ifft2(f_k)
        # Relax towards the non-negative integer projection of the estimate
        f_x = f_x + l*(_PI_projector(f_x) - f_x)

        if GRAPH:
            plt.imshow(numpy.abs(f_x), vmin=0)
            plt.colorbar()
            plt.pause(0.00001)
        count += 1

        if count > maxIterations:
            print('\nKILLED LOOP')
            break
        # Written as "not >" so that a NaN norm also stops the loop
        if not (numpy.linalg.norm(f_x - f_x_old) > 1e-8):
            break

    if GRAPH:
        _tomopack_summary_plot(radioMM, psiMM, radioMM_FFT, psiMM_FFT, k_trust, f_k, f_x)

    return numpy.real(f_x)


def indicatorFunctionToDetectorPositions(f_x, debug=False, particlesPresenceThreshold=0.9):
    '''
    Takes an approximation of the indicator function f_x from tomopack and returns positions on the detector

    Parameters
    ----------
        f_x : 2D numpy array of floats
            The value of the indicator function for each pixel on the detector

        debug : bool, optional
            Show debug graphs (especially in the maximum filtering)
            Default = False

        particlesPresenceThreshold : float, optional
            A pixel is accepted when its rounded particle count is above this threshold
            Default = 0.9

    Returns
    -------
        pPredictedIJ : 2D numpy array of ints, shape (nParticles, 2)
            One row per detected particle, holding (column, row) pixel indices.
            A pixel whose rounded count is n contributes n identical rows.
            When nothing is detected the shape is (0, 2).
    '''
    assert(len(f_x.shape) == 2), "detectSpheres.indicatorFunctionToDetectorPositions(): need 2D array"

    f_x = _filter_maxima(f_x, debug=debug)

    particlesPresence = numpy.round(numpy.real(_PI_projector(f_x)))
    rows, cols = numpy.nonzero(particlesPresence > particlesPresenceThreshold)
    # Number of particles estimated at each accepted pixel (truncated like int())
    counts = particlesPresence[rows, cols].astype(int)
    # Repeat each (column, row) pair once per particle, keeping row-major pixel order
    pPredictedIJ = numpy.repeat(numpy.column_stack((cols, rows)), counts, axis=0)

    return pPredictedIJ


def psiSeriesScanTo3DPositions(radio,      # not (necessarily) MM
                               psiXseries, # Obviously in the same units as radio above, please
                               radiusMM,   # sphere radius, sets the maximum-filter size
                               CORxPositions = None,
                               massThreshold=0.12,
                               scanPersistenceThresholdRadii=None,
                               scanFixedNumber=None,
                               scanPersistenceThreshold=7, maxIterations=50,
                               sourceDetectorDistMM=100, pixelSizeMM=0.1, l=0.2, kTrustRatio=0.7, useCache=True,
                               numCores=1,
                               blur=0.0,
                               cacheFile='fXseries.tif',
                               verbose=False):
    '''
    Run tomopack on ``radio`` against each structuring element in ``psiXseries``, and
    convert the peaks of the resulting stack of indicator functions into 3D positions.

    Parameters
    ----------
        radio : 2D numpy array
            The radiograph to analyse.

        psiXseries : 3D numpy array
            One structuring element per scan position; each slice has the same shape as ``radio``.

        radiusMM : float
            Sphere radius, used to size the maximum filter.

        CORxPositions : 1D array-like, optional
            X position (mm) for each slice of ``psiXseries``. At least two are needed.
            Default: 0, 1, 2, ...

        massThreshold : float, optional
            Minimum peak value accepted when ``scanFixedNumber`` is not given.
            Default = 0.12

        scanFixedNumber : int, optional
            If given, return exactly this many highest peaks instead of thresholding.

        maxIterations, l, kTrustRatio : optional
            Passed through to tomopack().

        sourceDetectorDistMM, pixelSizeMM : float, optional
            Geometry used for the zoom level and for the detector-to-space conversion.

        cacheFile : str, optional
            The stack of indicator functions is always written here (tif). With ``verbose``,
            intermediate stacks are also written next to it.

        verbose : bool, optional

        scanPersistenceThresholdRadii, scanPersistenceThreshold, useCache, numCores, blur :
            Currently unused; kept so the signature matches tomopackDivergentScanTo3DPositions().

    Returns
    -------
        positionsXYZmm : 2D numpy array, shape (nParticles, 3)

    Raises
    ------
        ValueError if the number of CORxPositions does not match psiXseries, or is below two
    '''
    psiXseries = numpy.asarray(psiXseries)

    if CORxPositions is None:
        print("xPositions is not passed, just putting 1, 2, 3...")
        CORxPositions = numpy.arange(psiXseries.shape[0])
    CORxPositions = numpy.asarray(CORxPositions)
    if len(CORxPositions) != psiXseries.shape[0]:
        raise ValueError("detectSpheres.psiSeriesScanTo3DPositions(): need one CORxPosition per slice of psiXseries")
    if len(CORxPositions) < 2:
        raise ValueError("detectSpheres.psiSeriesScanTo3DPositions(): need at least two scan positions")

    cacheRoot = os.path.splitext(cacheFile)[0]

    # One indicator function per scan position. Force a float dtype so that an integer
    # psiXseries cannot truncate the results.
    fXseries = numpy.zeros_like(psiXseries, dtype=float)
    for posN, CORxPos in enumerate(CORxPositions):
        ### "Structuring Element"
        print("\t{}/{} CORxPos = {:0.2f} mm".format(posN+1, len(CORxPositions), CORxPos), end='\r')
        fXseries[posN] = tomopack(radio,
                                  psiXseries[posN],
                                  GRAPH=0,
                                  maxIterations=maxIterations,
                                  l=l,
                                  kTrustRatio=kTrustRatio)
    tifffile.imwrite(cacheFile, fXseries.astype('float'))
    print(f"saved {cacheFile}")

    zoomLevel = sourceDetectorDistMM/((CORxPositions[0] + CORxPositions[-1])/2)
    CORxDelta = numpy.abs(CORxPositions[0]-CORxPositions[1])
    # Look in a neighbourhood sized from the particle radius for the highest value.
    # Each size is at least 1 so the filter stays valid for small radii.
    radiusSlices = int(numpy.floor(radiusMM/CORxDelta))
    radiusPx = max(1, int(numpy.floor(radiusMM/pixelSizeMM*zoomLevel)))
    filterSize = (max(1, 3*radiusSlices), radiusPx, radiusPx)
    fXseriesMaximumFiltered = scipy.ndimage.maximum_filter(fXseries, size=filterSize)
    allPeaks = fXseries == fXseriesMaximumFiltered
    masses = allPeaks*fXseries

    if verbose:
        tifffile.imwrite(cacheRoot + '_masses.tif', masses.astype('<f4'))
        tifffile.imwrite(cacheRoot + '_peaks.tif', allPeaks.astype('<f4'))
        tifffile.imwrite(cacheRoot + '_fXseries.tif', fXseries.astype('<f4'))
        tifffile.imwrite(cacheRoot + '_fXseriesMaximumFiltered.tif', fXseriesMaximumFiltered.astype('<f4'))

    if scanFixedNumber:
        # indices of all values from highest to lowest; keep the first scanFixedNumber as (n, 3)
        sortedPeakIndices = numpy.argsort(masses, axis=None)[::-1]
        peaksCORxPOSnJI = numpy.vstack(numpy.unravel_index(sortedPeakIndices[:scanFixedNumber], masses.shape)).T
    else:
        filteredPeaks = masses > massThreshold
        peaksCORxPOSnJI = numpy.argwhere(filteredPeaks)
        if verbose: tifffile.imwrite(cacheRoot + '_filteredPeaks.tif', filteredPeaks.astype('<f4'))

    ###############################################################
    ### Now we have guesses for all particle according to detector
    ###   (IJ) and position along the X-scanning direction
    ### We're going to convert that to spatial XYZ
    ###############################################################
    print("\nConverting tomopack x-scan to 3D positions\n")
    positionsXYZmm = numpy.zeros([peaksCORxPOSnJI.shape[0], 3])
    # X: the CORx of the slice holding the maximum (rounded; could be interpolated instead)
    positionsXYZmm[:, 0] = CORxPositions[numpy.round(peaksCORxPOSnJI[:, 0]).astype(int)]
    # Detector column gives Y and detector row gives Z, in mm on the detector
    yPosDetMM = -1*(peaksCORxPOSnJI[:, 2] - radio.shape[1]/2.0)*pixelSizeMM
    zPosDetMM = -1*(peaksCORxPOSnJI[:, 1] - radio.shape[0]/2.0)*pixelSizeMM
    # Scale down by the zoom factor at each particle's X position
    positionsXYZmm[:, 1] = yPosDetMM * ( positionsXYZmm[:, 0] / sourceDetectorDistMM )
    positionsXYZmm[:, 2] = zPosDetMM * ( positionsXYZmm[:, 0] / sourceDetectorDistMM )

    print(f"\npsiSeriesScanTo3DPositions(): I'm returning {positionsXYZmm.shape[0]} 3D positions.\n")
    return positionsXYZmm


def tomopackDivergentScanTo3DPositions(radioMM, radiusMM,
                                       CORxMin=None, CORxMax=None, CORxNumber=100,
                                       psiXseries = None,
                                       massThreshold=0.12,
                                       scanPersistenceThresholdRadii=None,
                                       scanFixedNumber=None,
                                       scanPersistenceThreshold=7, maxIterations=50,
                                       sourceDetectorDistMM=100, pixelSizeMM=0.1, l=0.2, kTrustRatio=0.7, useCache=True,
                                       numCores=1,
                                       blur=0.0,
                                       cacheFile='fXseries.tif',
                                       verbose=False,
                                       CORx=None):
    """
    This function takes in a single divergent projection, and will run tomopack
    generating different psis by varying their position as Centre Of Rotation (COR) in the x-direction,
    from CORxMin to CORxMax in CORxNumber steps.

    The resulting series of indicator functions is analysed and a 3D position guess for
    all identified spheres is returned.

    Parameters
    ----------
        radioMM : 2D numpy array
            Divergent radiograph in mm.

        radiusMM : float
            Sphere radius in mm.

        CORxMin, CORxMax : float (optional)
            Scan range along x in mm. Each missing bound defaults to CORx -/+ 3 radii.

        CORxNumber : int (optional)
            Number of scan positions. Default = 100

        psiXseries : ignored
            Always regenerated, or loaded from the cache.

        massThreshold : float (optional)
            Threshold for accepting the result of the convolution of the indicator function.
            Default = 0.12

        scanFixedNumber : int (optional)
            If given, return exactly this many highest peaks instead of thresholding.

        maxIterations, l, kTrustRatio : optional
            Passed through to tomopack().

        sourceDetectorDistMM, pixelSizeMM : float (optional)
            Projection geometry.

        useCache : bool (optional)
            Load the indicator functions from ``cacheFile`` (and its ``_psi`` companion)
            if they exist with the right dimensions; otherwise compute them and save them there.

        blur : float, optional
            sigma of blur to pass to scipy.ndimage.gaussian_filter to
            blur the radiograph at the end of everything

        cacheFile : str (optional)
            Cache file name (tif).

        verbose : bool (optional)
            Also write intermediate stacks next to ``cacheFile``.

        CORx : float (optional)
            Centre of the default scan range; only needed when CORxMin or CORxMax is omitted.

        scanPersistenceThresholdRadii, scanPersistenceThreshold, numCores :
            Currently unused; kept for signature compatibility.

    Returns
    -------
        positionsXYZmm : 2D numpy array, shape (nParticles, 3)

    Raises
    ------
        ValueError if a scan bound is missing and CORx is not given
    """
    assert(len(radioMM.shape) == 2), "detectSpheres.tomopackDivergentScanTo3DPositions(): need 2D array"

    if CORxMin is None or CORxMax is None:
        if CORx is None:
            raise ValueError("detectSpheres.tomopackDivergentScanTo3DPositions(): pass CORxMin and CORxMax, or CORx to centre the default scan range on")
        if CORxMin is None: CORxMin = CORx-3*radiusMM
        if CORxMax is None: CORxMax = CORx+3*radiusMM

    CORxPositions = numpy.linspace(CORxMin, CORxMax, CORxNumber)
    CORxDelta = CORxPositions[1] - CORxPositions[0]

    cacheRoot = os.path.splitext(cacheFile)[0]
    cachePsiFile = cacheRoot + '_psi.tif'
    expectedShape = (CORxNumber, radioMM.shape[0], radioMM.shape[1])

    loadedCache = False
    if useCache:
        if os.path.isfile(cacheFile) and os.path.isfile(cachePsiFile):
            print("Loading previous indicator functions... ", end="")
            fXseries = tifffile.imread(cacheFile)
            psiXseries = tifffile.imread(cachePsiFile)
            if fXseries.shape == expectedShape and psiXseries.shape == expectedShape:
                print("done.")
                loadedCache = True
            else:
                print("cached file had wrong dimensions. Generating new cache file.")
        else:
            print('No cached indicator functions found. Generating them now to cache.')

    if not loadedCache:
        fXseries = numpy.zeros(expectedShape)
        psiXseries = numpy.zeros_like(fXseries)
        psiMMseries = numpy.zeros_like(fXseries)

        # Reference sphere at the middle of the scan range
        psiRefMM = radioSphere.projectSphere.projectSphereMM(numpy.array([[(CORxMax+CORxMin)/2., 0., 0.]]),
                                                             numpy.array([radiusMM]),
                                                             detectorResolution=radioMM.shape,
                                                             pixelSizeMM=pixelSizeMM,
                                                             sourceDetectorDistMM=sourceDetectorDistMM,
                                                             blur=blur)

        for posN, CORxPos in enumerate(CORxPositions):
            ### "Structuring Element"
            print("\t{}/{} CORxPos = {:0.2f}mm".format(posN+1, len(CORxPositions), CORxPos), end='\r')
            psiMM = radioSphere.projectSphere.projectSphereMM(numpy.array([[CORxPos, 0., 0.]]),
                                                              numpy.array([radiusMM]),
                                                              detectorResolution=radioMM.shape,
                                                              pixelSizeMM=pixelSizeMM,
                                                              sourceDetectorDistMM=sourceDetectorDistMM,
                                                              blur=blur)

            fXseries[posN] = tomopack(radioMM, psiMM, GRAPH=0, maxIterations=maxIterations, l=l, kTrustRatio=kTrustRatio)
            # Response of the reference sphere to each psi; its centre is used below as a structuring element
            psiXseries[posN] = tomopack(psiRefMM, psiMM, GRAPH=0, maxIterations=maxIterations, l=l, kTrustRatio=kTrustRatio)
            psiMMseries[posN] = psiMM

        if useCache:
            print("Saving indicator functions for next time... ", end="")
            tifffile.imwrite(cacheFile, fXseries.astype('<f4'))
            tifffile.imwrite(cachePsiFile, psiXseries.astype('<f4'))
            tifffile.imwrite(cacheRoot + '_psi_MM.tif', psiMMseries.astype('<f4'))
            print("done.")

    # Convolve the indicator functions with the centre of the reference response.
    # Slice starts are clamped at 0 so that a short scan (fewer than 2*L_x+1 positions)
    # cannot produce a negative index, which would wrap around to the end of the array.
    L_x  = 20 # TODO: SCALING IN X DIRECTION SHOULD BE A FUNCTION OF THE CONE ANGLE
    L_yz =  2 # TODO: THIS SHOULD BE A FUNCTION OF THE PIXELS PER RADIUS
    cX, cI, cJ = (n//2 for n in psiXseries.shape)
    struct = psiXseries[max(0, cX -  L_x):cX + L_x  + 1,
                        max(0, cI - L_yz):cI + L_yz + 1,
                        max(0, cJ - L_yz):cJ + L_yz + 1]

    fXconvolvedSeries = scipy.ndimage.convolve(fXseries, struct/struct.sum())
    if useCache and not loadedCache:
        tifffile.imwrite(f'{cacheRoot}_struct.tif', struct.astype('<f4'))
        tifffile.imwrite(f'{cacheRoot}_fXconvolvedSeries.tif', fXconvolvedSeries.astype('<f4'))

    zoomLevel = sourceDetectorDistMM/((CORxMin + CORxMax)/2)
    # Look in a volume of +/- half a radius in all directions for the highest value
    # (+/- 1 radius keeps overlapping and causing issues; half a radius doesn't overlap
    # particles but still contains one clean peak). Each size is at least 1.
    radiusPx = max(1, int(numpy.floor(radiusMM/pixelSizeMM*zoomLevel)))
    filterSize = (max(1, int(numpy.floor(radiusMM/CORxDelta))), radiusPx, radiusPx)
    fXconvolvedMaximumFiltered = scipy.ndimage.maximum_filter(fXconvolvedSeries, size=filterSize)
    allPeaks = fXconvolvedSeries == fXconvolvedMaximumFiltered
    masses = allPeaks*fXconvolvedSeries
    if verbose:
        tifffile.imwrite(cacheRoot + '_masses.tif', masses.astype('<f4'))
        tifffile.imwrite(cacheRoot + '_peaks.tif', allPeaks.astype('<f4'))
        tifffile.imwrite(cacheRoot + '_fXconvolvedSeries.tif', fXconvolvedSeries.astype('<f4'))
        tifffile.imwrite(cacheRoot + '_fXconvolvedMaximumFiltered.tif', fXconvolvedMaximumFiltered.astype('<f4'))

    if scanFixedNumber:
        # indices of all values from highest to lowest; keep the first scanFixedNumber as (n, 3)
        sortedPeakIndices = numpy.argsort(masses, axis=None)[::-1]
        peaksCORxPOSnJI = numpy.vstack(numpy.unravel_index(sortedPeakIndices[:scanFixedNumber], masses.shape)).T
    else:
        filteredPeaks = masses > massThreshold
        peaksCORxPOSnJI = numpy.argwhere(filteredPeaks)
        if verbose: tifffile.imwrite(cacheRoot + '_filteredPeaks.tif', filteredPeaks.astype('<f4'))

    ###############################################################
    ### Now we have guesses for all particle according to detector
    ###   (IJ) and position along the X-scanning direction
    ### We're going to convert that to spatial XYZ
    ###############################################################
    print("\nConverting tomopack x-scan to 3D positions\n")
    positionsXYZmm = numpy.zeros([peaksCORxPOSnJI.shape[0], 3])
    # X: the CORx of the slice holding the maximum (rounded; could be interpolated instead)
    positionsXYZmm[:, 0] = CORxPositions[numpy.round(peaksCORxPOSnJI[:, 0]).astype(int)]
    # Detector column gives Y and detector row gives Z, in mm on the detector
    yPosDetMM = -1*(peaksCORxPOSnJI[:, 2] - radioMM.shape[1]/2.0)*pixelSizeMM
    zPosDetMM = -1*(peaksCORxPOSnJI[:, 1] - radioMM.shape[0]/2.0)*pixelSizeMM
    # Scale down by the zoom factor at each particle's X position
    positionsXYZmm[:, 1] = yPosDetMM * ( positionsXYZmm[:, 0] / sourceDetectorDistMM )
    positionsXYZmm[:, 2] = zPosDetMM * ( positionsXYZmm[:, 0] / sourceDetectorDistMM )

    print(f"\ntomopackDivergentScanTo3DPositions(): I'm returning {positionsXYZmm.shape[0]} 3D positions.\n")
    return positionsXYZmm

def removeParticle(positionsXYZmm,residual,radioMM,radiiMM,pixelSizeDetectorMM,zoomLevel,sourceObjectDistMM,verbose,GRAPH):
    '''
    Delete the particle lying closest to the ray through the largest positive residual.

    The largest value of ``residual`` (projection minus ``radioMM``) marks the
    detector pixel where the current spheres project too much material. A unit
    vector is built from the origin through that pixel. The code measures
    point-to-line distance with a cross product, so it assumes the source sits
    at the origin and the detector lies at ``zoomLevel*sourceObjectDistMM``
    along X. The particle whose centre is nearest that ray is removed, and the
    projection is recomputed.

    Parameters
    ----------
    positionsXYZmm : array, one row (X, Y, Z) per particle
    residual : 2D array, current projection minus ``radioMM`` (only its argmax is used)
    radioMM : 2D array, measured radiograph (its shape is the detector resolution)
    radiiMM : array, only ``radiiMM[0]`` is used for every particle
    pixelSizeDetectorMM, zoomLevel, sourceObjectDistMM : float, acquisition geometry
    verbose : bool, print diagnostics
    GRAPH : bool, display the updated residual

    Returns
    -------
    positionsXYZmm : array with one row fewer
    residual : 2D array, updated projection minus ``radioMM``
    '''
    if verbose: print('Removing a particle')
    peakFlat = numpy.argmax(residual, axis=None)
    residualMMPeakIndices = numpy.unravel_index(peakFlat, residual.shape)
    if verbose: print(f'residualMMPeakIndices: {residualMMPeakIndices}')

    nRows, nCols = radioMM.shape[0], radioMM.shape[1]
    sourceDetectorDistMM = zoomLevel*sourceObjectDistMM
    # detector column -> Y, detector row -> Z, both relative to the panel centre, in mm
    yPosDetMM = -1*(residualMMPeakIndices[1] - nCols/2.0)*pixelSizeDetectorMM
    zPosDetMM = -1*(residualMMPeakIndices[0] - nRows/2.0)*pixelSizeDetectorMM
    magnitude = numpy.sqrt(zoomLevel**2*sourceObjectDistMM**2 + yPosDetMM**2 + zPosDetMM**2)
    s = numpy.array([component/magnitude for component in (sourceDetectorDistMM, yPosDetMM, zPosDetMM)])
    if verbose: print(f's: {s}')

    # perpendicular distance of every centre from the ray: |p x s| with s a unit vector
    distances = numpy.linalg.norm(numpy.cross(positionsXYZmm,s),axis=1)
    if verbose: print(f'distances: {distances}')

    positionsXYZmm = numpy.delete(positionsXYZmm, numpy.argmin(distances), axis=0)

    nParticles = len(positionsXYZmm)
    projection = radioSphere.projectSphere.projectSphereMM(positionsXYZmm,
                                                           radiiMM[0]*numpy.ones(nParticles),
                                                           sourceDetectorDistMM=sourceDetectorDistMM,
                                                           pixelSizeMM=pixelSizeDetectorMM,
                                                           detectorResolution=radioMM.shape)
    residual = projection - radioMM

    if GRAPH:
        plt.imshow(residual)
        plt.show()

    return positionsXYZmm, residual

def addParticle(*args, **kwargs):
    '''
    Add one particle where the residual is most negative.

    Thin dispatcher: every argument is forwarded unchanged to the active
    strategy, currently ``addParticleSensitivity``. ``addParticleRaster``
    accepts the same arguments and can be swapped in here.
    '''
    strategy = addParticleSensitivity  # alternative: addParticleRaster
    return strategy(*args, **kwargs)


def addParticleRaster(positionsXYZmm,residual,radioMM,radiiMM,pixelSizeDetectorMM,zoomLevel,sourceObjectDistMM,CORxMin, CORxMax,CORxNumber,verbose,GRAPH,optimise=True):
    '''
    Add one particle along the ray of the most negative residual, located by a raster scan.

    The most negative pixel of ``residual`` (projection minus ``radioMM``) is
    taken as where material is missing. A unit vector ``s`` is built from the
    origin (the source, as the rest of this module assumes) through that
    detector pixel. Trial spheres are placed at ``CORxNumber`` distances between
    ``CORxMin`` and ``CORxMax`` along ``s``. The trial giving the smallest sum of
    squared residuals inside a detector crop is kept. If ``optimise`` is set, it
    is refined with ``radioSphere.optimisePositions.optimiseSensitivityFields``.
    The result is then appended to the particle list.

    Parameters
    ----------
    positionsXYZmm : array, one row (X, Y, Z) per existing particle
    residual : 2D array, current projection minus ``radioMM`` (only its argmin is used)
    radioMM : 2D array, measured radiograph (its shape is the detector resolution)
    radiiMM : array, only ``radiiMM[0]`` is used for every particle
    pixelSizeDetectorMM, zoomLevel, sourceObjectDistMM : float
        Geometry; the source-detector distance is ``zoomLevel*sourceObjectDistMM``.
    CORxMin, CORxMax, CORxNumber :
        Range and number of samples of the raster scan along the ray.
    verbose : bool, print diagnostics
    GRAPH : bool, display the updated residual
    optimise : bool, optional
        Refine the raster result with the sensitivity-field optimiser (default True).

    Returns
    -------
    positionsXYZmm : array with the new particle appended
    residual : 2D array, updated projection minus ``radioMM``
    '''
    if verbose: print('Adding a particle')
    # The most negative residual is where the current projection lacks material
    residualMMPeakIndices = numpy.unravel_index(numpy.argmin(residual, axis=None), residual.shape)
    if verbose: print(f'residualMMPeakIndices: {residualMMPeakIndices}')

    # unit vector from the source (origin) through that detector pixel
    yPosDetMM = -1*(residualMMPeakIndices[1] - radioMM.shape[1]/2.0)*pixelSizeDetectorMM
    zPosDetMM = -1*(residualMMPeakIndices[0] - radioMM.shape[0]/2.0)*pixelSizeDetectorMM
    magnitude = numpy.sqrt(zoomLevel**2*sourceObjectDistMM**2 + yPosDetMM**2 + zPosDetMM**2)
    s = numpy.array([
        zoomLevel*sourceObjectDistMM/magnitude,
        yPosDetMM/magnitude,
        zPosDetMM/magnitude])
    if verbose: print(f's: {s}')

    # Raster scan along the ray, keeping the trial position with minimal squared residual
    x_test = numpy.linspace(CORxMin,CORxMax,CORxNumber)
    ref_projection = radioSphere.projectSphere.projectSphereMM(positionsXYZmm,
                                                      radiiMM[0]*numpy.ones(len(positionsXYZmm)),
                                                      sourceDetectorDistMM=zoomLevel*sourceObjectDistMM,
                                                      pixelSizeMM=pixelSizeDetectorMM,
                                                      detectorResolution=radioMM.shape)

    # Crop the detector around the footprint of the sphere at x_test[0]. Presumably
    # this is the largest footprint when CORxMin is nearest the source — TODO confirm.
    limits = radioSphere.projectSphere.singleSphereToDetectorPixelRange(s*x_test[0],
                                                                        radiiMM[0],
                                                                        radiusMargin=0.1,
                                                                        sourceDetectorDistMM=zoomLevel*sourceObjectDistMM,
                                                                        pixelSizeMM=pixelSizeDetectorMM,
                                                                        detectorResolution=radioMM.shape)

    ref_projection_crop = ref_projection[limits[0,0]:limits[0,1], limits[1,0]:limits[1,1]]
    radioMM_crop = radioMM[limits[0,0]:limits[0,1], limits[1,0]:limits[1,1]]

    best_index = 0
    best_residual = numpy.inf
    for i,x in enumerate(x_test):
        single_particle_projection = radioSphere.projectSphere.projectSphereMM(numpy.expand_dims(s*x,axis=0),
                                                          numpy.expand_dims(radiiMM[0],axis=0),
                                                          sourceDetectorDistMM=zoomLevel*sourceObjectDistMM,
                                                          pixelSizeMM=pixelSizeDetectorMM,
                                                          detectorResolution=radioMM.shape,
                                                          ROIcentreMM=s*x_test[0],
                                                          ROIradiusMM=radiiMM[0])

        # separate name so the `residual` argument is not shadowed inside the scan
        trial_residual = ref_projection_crop + single_particle_projection - radioMM_crop
        trial_cost = (trial_residual**2).sum()
        if trial_cost < best_residual:
            best_index = i
            best_residual = trial_cost

    bestPositionXYZmm = s*x_test[best_index]

    print(f'Best location at {best_index}-th x value: {x_test[best_index]}')

    if optimise:
        # Bind the submodule by name: a plain `import radioSphere.optimisePositions` here
        # would turn `radioSphere` into a function local and break the earlier
        # radioSphere.projectSphere calls with UnboundLocalError.
        from radioSphere import optimisePositions
        bestPositionXYZmm = optimisePositions.optimiseSensitivityFields(radioMM,
                                                                        numpy.expand_dims(bestPositionXYZmm,axis=0), # start from the raster-scan optimum
                                                                        numpy.expand_dims(radiiMM[0],axis=0),
                                                                        perturbationMM=(0.01, 0.01, 0.01),
                                                                        minDeltaMM=0.0001,
                                                                        iterationsMax=500,
                                                                        sourceDetectorDistMM=zoomLevel*sourceObjectDistMM,
                                                                        pixelSizeMM=pixelSizeDetectorMM,
                                                                        detectorResolution=radioMM.shape,
                                                                        verbose=False,
                                                                        )

    positionsXYZmm = numpy.vstack([positionsXYZmm, bestPositionXYZmm])
    p_f_x = radioSphere.projectSphere.projectSphereMM(positionsXYZmm,
                                                      radiiMM[0]*numpy.ones(len(positionsXYZmm)),
                                                      sourceDetectorDistMM=zoomLevel*sourceObjectDistMM,
                                                      pixelSizeMM=pixelSizeDetectorMM,
                                                      detectorResolution=radioMM.shape)
    residual = p_f_x - radioMM

    if GRAPH:
        plt.imshow(residual)
        plt.show()

    return positionsXYZmm, residual


def addParticleSensitivity(positionsXYZmm,residual,radioMM,radiiMM,pixelSizeDetectorMM,zoomLevel,sourceObjectDistMM,CORxMin, CORxMax,CORxNumber,verbose,GRAPH):
    '''
    Add one particle along the ray of the most negative residual, placed by optimisation.

    The most negative pixel of ``residual`` (projection minus ``radioMM``) is
    taken as where material is missing. A unit vector ``s`` is built from the
    origin (the source) through that detector pixel. A sphere is started at
    ``sourceObjectDistMM`` along ``s`` and refined with
    ``radioSphere.optimisePositions.optimiseSensitivityFields``. The result is
    then appended to the particle list.

    Parameters
    ----------
    positionsXYZmm : array, one row (X, Y, Z) per existing particle
    residual : 2D array, current projection minus ``radioMM`` (only its argmin is used)
    radioMM : 2D array, measured radiograph (its shape is the detector resolution)
    radiiMM : array, only ``radiiMM[0]`` is used for every particle
    pixelSizeDetectorMM, zoomLevel, sourceObjectDistMM : float
        Geometry; the source-detector distance is ``zoomLevel*sourceObjectDistMM``.
    CORxMin, CORxMax, CORxNumber :
        Not used here; accepted so the signature matches ``addParticleRaster``.
    verbose : bool, print diagnostics
    GRAPH : bool, display the updated residual

    Returns
    -------
    positionsXYZmm : array with the new particle appended
    residual : 2D array, updated projection minus ``radioMM``
    '''
    # The module-level imports only load radioSphere.projectSphere, so the optimiser
    # submodule must be imported explicitly. It is bound by name so that `radioSphere`
    # does not become a function local.
    from radioSphere import optimisePositions

    if verbose: print('Adding a particle')
    # The most negative residual is where the current projection lacks material
    residualMMPeakIndices = numpy.unravel_index(numpy.argmin(residual, axis=None), residual.shape)
    if verbose: print(f'residualMMPeakIndices: {residualMMPeakIndices}')

    # unit vector from the source (origin) through that detector pixel
    yPosDetMM = -1*(residualMMPeakIndices[1] - radioMM.shape[1]/2.0)*pixelSizeDetectorMM
    zPosDetMM = -1*(residualMMPeakIndices[0] - radioMM.shape[0]/2.0)*pixelSizeDetectorMM
    magnitude = numpy.sqrt(zoomLevel**2*sourceObjectDistMM**2 + yPosDetMM**2 + zPosDetMM**2)
    s = numpy.array([
        zoomLevel*sourceObjectDistMM/magnitude,
        yPosDetMM/magnitude,
        zPosDetMM/magnitude])
    if verbose: print(f's: {s}')

    # Sensitivity-field optimisation of a single sphere starting on the ray
    positionXYZmmOpt = optimisePositions.optimiseSensitivityFields(   radioMM,
                            numpy.expand_dims(s*sourceObjectDistMM,axis=0), # start sourceObjectDistMM along the ray (presumably the sample centre)
                            numpy.expand_dims(radiiMM[0],axis=0),
                            perturbationMM=(0.001, 0.001, 0.001),
                            minDeltaMM=0.0001,
                            iterationsMax=500,
                            sourceDetectorDistMM=zoomLevel*sourceObjectDistMM,
                            pixelSizeMM=pixelSizeDetectorMM,
                            detectorResolution=radioMM.shape,
                            verbose=False,
                            NDDEM_output=True
                            )

    positionsXYZmm = numpy.vstack([positionsXYZmm, positionXYZmmOpt])
    p_f_x = radioSphere.projectSphere.projectSphereMM(positionsXYZmm,
                                                      radiiMM[0]*numpy.ones(len(positionsXYZmm)),
                                                      sourceDetectorDistMM=zoomLevel*sourceObjectDistMM,
                                                      pixelSizeMM=pixelSizeDetectorMM,
                                                      detectorResolution=radioMM.shape)
    residual = p_f_x - radioMM

    if GRAPH:
        plt.imshow(residual)
        plt.show()

    return positionsXYZmm, residual


def cleanDivergentScan(positionsXYZmm,radioMM,radiiMM,zoomLevel,sourceObjectDistMM,pixelSizeDetectorMM,CORxMin,CORxMax, CORxNumber,verbose=False,GRAPH=False):
    '''
    Iteratively remove and add particles until the projection matches the radiograph.

    The residual is the projection of ``positionsXYZmm`` minus ``radioMM``.
    Particles are removed while any residual pixel exceeds 0.9 (too much
    material projected). Particles are then added while any pixel is below
    -0.9 (material missing). Finally, removals or additions are repeated until
    the two counts are equal, so the particle count is unchanged overall.

    Returns the cleaned array of positions.
    '''
    sourceDetectorDistMM = zoomLevel*sourceObjectDistMM
    projection = radioSphere.projectSphere.projectSphereMM(positionsXYZmm,
                                                           radiiMM[0]*numpy.ones(len(positionsXYZmm)),
                                                           sourceDetectorDistMM=sourceDetectorDistMM,
                                                           pixelSizeMM=pixelSizeDetectorMM,
                                                           detectorResolution=radioMM.shape)
    residual = projection - radioMM
    print(f'Maximum |residual| {numpy.abs(residual).max()}')

    # Both helpers return (positions, residual) with one particle removed / added.
    def _drop(positions, res):
        return removeParticle(positions, res, radioMM, radiiMM, pixelSizeDetectorMM,
                              zoomLevel, sourceObjectDistMM, verbose, GRAPH)

    def _insert(positions, res):
        return addParticle(positions, res, radioMM, radiiMM, pixelSizeDetectorMM,
                           zoomLevel, sourceObjectDistMM, CORxMin, CORxMax, CORxNumber,
                           verbose, GRAPH)

    # NOTE(review): none of these loops has an iteration cap; if adding/removing
    # cannot bring the residual back inside +/-0.9 they will not terminate.
    nRemoved = nAdded = 0
    # positive residual: something projected where the radiograph has no material
    while residual.max() > 0.9:
        positionsXYZmm, residual = _drop(positionsXYZmm, residual)
        nRemoved += 1
    # negative residual: material in the radiograph that no particle explains
    while residual.min() < -0.9:
        positionsXYZmm, residual = _insert(positionsXYZmm, residual)
        nAdded += 1

    # rebalance so the net number of particles is unchanged
    while nAdded > nRemoved:
        positionsXYZmm, residual = _drop(positionsXYZmm, residual)
        nRemoved += 1
    while nRemoved > nAdded:
        positionsXYZmm, residual = _insert(positionsXYZmm, residual)
        nAdded += 1

    print(f'Removed {nRemoved} particles and added {nAdded}.')
    return positionsXYZmm


def calculateErrors(posXYZa,posXYZb,radiiMM,verbose=True):
    '''
    Match two point sets and report the nearest-neighbour error.

    Each point of ``posXYZb`` (one column of the distance matrix) is matched to
    its nearest point in ``posXYZa``. A ``posXYZb`` point with no ``posXYZa``
    point within ``radiiMM[0]`` (inclusive) counts as lost and is excluded from
    the error statistics.

    Parameters
    ----------
    posXYZa : array-like, shape (Na, D)
    posXYZb : array-like, shape (Nb, D)
    radiiMM : sequence, only ``radiiMM[0]`` is used as the matching threshold
    verbose : bool, print a message when points are lost (default True)

    Returns
    -------
    err_mean, err_std : float
        Mean and (population) standard deviation of the matched distances;
        NaN when nothing matched.
    number_lost : int
        Number of ``posXYZb`` points with no match.
    '''
    distance_threshold = radiiMM[0]

    if len(posXYZa) == 0:
        # Nothing to match against: every point of posXYZb is lost, and the
        # min-reduction below would fail on an empty axis.
        number_lost = len(posXYZb)
        min_valid_errors = numpy.empty(0)
    else:
        distances = distance.cdist(posXYZa,posXYZb)
        # best match for each column, i.e. for each point of posXYZb
        min_errors = numpy.min(distances, axis=0)
        # Use the same inclusive test for "found" and for "counted in the error", so
        # a point exactly at the threshold is not dropped from both.
        found = min_errors <= distance_threshold
        number_lost = int(numpy.count_nonzero(~found))
        min_valid_errors = min_errors[found]

    if verbose and number_lost > 0: print(f"detectSpheres.calculateErrors(): number_lost = {number_lost}")

    if min_valid_errors.size == 0:
        # numpy would return NaN here too, but with "Mean of empty slice" warnings
        return numpy.nan, numpy.nan, number_lost

    err_mean    = numpy.mean(min_valid_errors)
    err_std     = numpy.std( min_valid_errors)

    return err_mean, err_std, number_lost
