import math
import numpy
import matplotlib.pyplot as plt
import tifffile
from scipy.optimize import curve_fit


def cubicFit( x, l, n, m, c ):
    '''
    Cubic attenuation law: returns l*x**3 + n*x**2 + m*x + c.
    Works element-wise when `x` is a numpy array.
    '''
    return l*x**3 + n*x**2 + m*x + c
def quadraticFit( x, n, m, c ):
    '''
    Quadratic attenuation law: returns n*x**2 + m*x + c.
    Works element-wise when `x` is a numpy array.
    '''
    return n*x**2 + m*x + c
def linearFit( x, m, c ):
    '''
    Linear attenuation law with intercept: returns m*x + c.
    Works element-wise when `x` is a numpy array.
    '''
    return m*x + c
def linearFitNoIntercept( x, m ):
    '''
    Linear attenuation law through the origin: returns m*x.
    Works element-wise when `x` is a numpy array.
    '''
    return m*x

def getPathLengthThroughSphere(p, r, sdd, sod):
    '''
    Get the path length (chord length) of a ray passing through a sphere. A schematic
    diagram is provided in the supplementary methods section of the first radioSphere paper.

    The ray leaves the source and hits the detector at a distance `p` from the centre.
    It makes an angle alpha = atan(p/sdd) with the source-detector axis. Its perpendicular
    distance from the sphere centre is sod*sin(alpha). The chord is then 2*r*cos(beta),
    with sin(beta) = sod*sin(alpha)/r.

    Parameters
    ----------
        p : float
            The distance in mm along the detector panel from the centre.

        r : float
            The particle radius in mm.

        sdd : float
            The source-detector distance in mm.

        sod : float
            The source-object distance in mm.

    Returns
    -------
        L : float
            The path length through the sphere in mm. If the ray passes outside the sphere
            (|sod*sin(alpha)/r| > 1) or the radius is 0, returns 0.

    Raises
    ------
        ValueError / TypeError
            If `p` cannot be converted to a single float.

        ZeroDivisionError
            If `sdd` is 0 (with plain Python floats).
    '''
    p = float(p)

    if r == 0:
        # Degenerate sphere: there is no material for the ray to cross
        return 0.0

    # Angle between the ray and the source-detector axis
    alpha = math.atan( p/sdd )

    # sin(beta) = (perpendicular distance from sphere centre to the ray) / r
    sinBeta = sod*math.sin(alpha) / r
    if abs(sinBeta) > 1.0:
        # The ray misses the sphere; asin would be outside its domain.
        # NaN deliberately falls through so it propagates, as before.
        return 0.0

    beta = math.asin( sinBeta )
    # Chord length 2*r*cos(beta), kept in the original sin(pi/2 - beta) form
    return 2.0*r*math.sin(math.pi/2.0-beta)

def generateFitParameters(calibSphereLog,pixelSizeMM,sourceDetectorDistMM,sourceObjectDistanceMM,radiusMM,centreYXpx,fitFunction=linearFit,outputPath=False,verbose=False):
    '''
    Fit an attenuation law to a logged normalised radiograph of a single particle.

    A profile is read along the image column through the particle centre. It runs from
    row 0 (top edge) down to the centre row. For each row, the signed detector distance
    from the centre, (centreY - row) * pixelSizeMM, is converted to a path length through
    the sphere with `getPathLengthThroughSphere`. Rows whose ray misses the sphere are
    dropped. `fitFunction` is then fitted to predict path length from the logged value.

    Parameters
    ----------
        calibSphereLog : 2D numpy array of floats
            A radiograph of a single particle divided by the background intensity and then logged.

        pixelSizeMM : float
            The size of a single pixel in mm on the detector panel.

        sourceDetectorDistMM: float
            The source-detector distance in mm.

        sourceObjectDistanceMM: float
            The source-object distance in mm.

        radiusMM : float
            The particle radius in mm.

        centreYXpx : 1D numpy array of floats
             The y (row) and x (column) location of the centre of the particle in pixels.

        fitFunction : function handle (optional)
            The fitting function to use. By default fit a linear fitting law. Options are `cubicFit`, `quadraticFit`, `linearFit` and `linearFitNoIntercept`

        outputPath : string (optional)
            A path to save the fit as an `npy` file.

        verbose : bool (optional)
            Show the fit on a graph. Default is False.

    Returns
    -------
        poptN : 1D numpy array of floats
            The optimal parameters of `fitFunction` returned by `scipy.optimize.curve_fit`.
            They map log attenuation to path length through the sphere in mm.

    Raises
    ------
        ValueError
            If no pixel on the profile lies inside the projected sphere.
    '''
    centreY = float(centreYXpx[0])
    centreRow = int(round(centreY))
    centreCol = int(round(float(centreYXpx[1])))

    # Rows from the top edge of the image down to (and including) the centre row
    rows = numpy.arange(centreRow + 1)
    # Signed detector distance (mm) between each row and the particle centre.
    # Row and distance are paired directly, so each sampled pixel gets its own distance.
    distancesMM = (centreY - rows) * pixelSizeMM

    points = []
    for row, pMMdet in zip(rows, distancesMM):
        L = getPathLengthThroughSphere( pMMdet, radiusMM, sourceDetectorDistMM, sourceObjectDistanceMM )
        if L > 0:
            # [path length (mm), logged attenuation, row index]
            points.append( [ L, calibSphereLog[row, centreCol], row ] )
    points = numpy.array( points, dtype=float )

    if len(points) == 0:
        raise ValueError("generateFitParameters(): no pixels on the profile lie inside the projected sphere; check the geometry and centreYXpx")

    poptN, _ = curve_fit(fitFunction, points[:,1], points[:,0])

    if outputPath: numpy.save(outputPath, poptN)

    if verbose:
        D = 150  # half-width in pixels of the zoom window around the particle
        plt.subplot(121)
        plt.imshow(calibSphereLog)
        plt.colorbar()
        # Sampled profile, with its first and last (centre) rows marked
        plt.plot(centreCol*numpy.ones_like(points[:,2]),points[:,2],'w--')
        plt.plot(centreCol,points[0,2],'wx')
        plt.plot(centreCol,points[-1,2],'wx')
        plt.xlim(centreYXpx[1]-D,centreYXpx[1]+D)
        # imshow draws row 0 at the top, so keep the y-axis inverted when zooming
        plt.ylim(centreYXpx[0]+D,centreYXpx[0]-D)
        plt.subplot(122)
        plt.plot( points[:,1], points[:,0], 'k.', label='Measured value' )
        plt.plot( points[:,1], fitFunction( points[:,1], *poptN ), 'k-', alpha=0.5, label='Fit' )
        plt.ylabel("Path length inside\nsphere (mm)")
        plt.xlabel(r"Log Attenuation $\ln(I/I_0)$")
        plt.legend(loc=0)
        plt.show()

    return poptN

if __name__ == '__main__':
    plt.style.use('./tools/radioSphere.mplstyle')

    # ---- Input radiographs and where to store the fitted coefficients ----
    calibSpherePath = './data/2021-02-09-EddyBillesNanoBis/7mm-AVG64.tif'
    backgroundPath = './data/2021-02-09-EddyBillesNanoBis/I0-AVG64.tif'
    outputPath = "./cache/fit-log-linear.npy"

    # Flat-field the sphere radiograph (I / I0), then take the natural log
    sphereImage = tifffile.imread(calibSpherePath).astype(float)
    flatField = tifffile.imread(backgroundPath).astype(float)
    logRadiograph = numpy.log(sphereImage / flatField)

    # ---- Projection geometry ----
    binning = 4
    pixelSizeMM = 0.127*float(binning)      # detector pixel size after binning
    sourceDetectorDistMM = 242.597          # from XAct
    radiusMM = 7/2
    # SOD inferred from the sphere's apparent size on the detector (71 pixels across, per
    # the original author), plus a 0.5 mm wiggle because this parameter is uncertain.
    # NOTE(review): the 71 px are described as spanning the diameter, yet the ratio below
    # uses radiusMM; if 71 px really is the diameter, SOD would be about twice this. Confirm.
    sourceObjectDistanceMM = sourceDetectorDistMM * ( radiusMM/71 / pixelSizeMM) + 0.5

    centreYXpx = numpy.array([ 229, 183 ])

    poptN = generateFitParameters(logRadiograph, pixelSizeMM, sourceDetectorDistMM,
                                  sourceObjectDistanceMM, radiusMM, centreYXpx,
                                  outputPath=outputPath, verbose=True)
