import numpy
import scipy.ndimage
import sys, os


def pointToDetectorPixelRange(posMM, sourceDetectorDistMM=100, pixelSizeMM=0.1, detectorResolution=(512, 512)):
    """
    This function gives the detector pixel that a single point will affect.
    The main idea is that it will be used with the ROI option of projectSphereMM() (via singleSphereToDetectorPixelRange())
    in order to only project the needed pixels.

    Parameters
    ----------
        posMM : 3-component array-like of floats (e.g., a 1x3 2D numpy array)
            xyz position of the point in mm, with the origin being the middle of the detector

        sourceDetectorDistMM : float, optional
            Distance between x-ray source and middle of detector
            Set as numpy.inf for parallel projection
            Default = 100

        pixelSizeMM : float, optional
            Pixel size on detector in mm
            Default = 0.1

        detectorResolution : 2-component sequence of ints, optional
            Number of pixels rows, columns of detector
            Default = (512, 512)

    Returns
    -------
        detectorPixel : 1D numpy array of 2 ints
            row, column (j,i) coordinate on detector as per figures/projectedCoords_v2.pdf.
            The result is not clipped, so it may lie outside the detector.
    """
    posMM = numpy.asarray(posMM, dtype=float).ravel()
    assert(len(posMM) == 3)

    if sourceDetectorDistMM == numpy.inf:
        # Parallel beam: no magnification, so one detector pixel covers pixelSizeMM in object space
        projectedPixelSizeMM = pixelSizeMM
    else:
        # Cone beam magnification; assumes posMM[0] is the source-to-point distance (TODO confirm vs projectedCoords_v2.pdf)
        zoomLevel = sourceDetectorDistMM/posMM[0]
        projectedPixelSizeMM = pixelSizeMM/zoomLevel

    # Pixel offsets with respect to the middle of the detector
    offsetYPX = posMM[1]/projectedPixelSizeMM
    offsetZPX = posMM[2]/projectedPixelSizeMM

    # Detector is rows, columns, so Z, Y; increasing z/y moves towards row/column 0
    detectorPX = numpy.array(detectorResolution)//2 - [offsetZPX, offsetYPX]

    return numpy.round(detectorPX).astype(int)


def singleSphereToDetectorPixelRange(spherePositionMM, radiusMM, radiusMargin=0.1, sourceDetectorDistMM=100, pixelSizeMM=0.1, detectorResolution=(512, 512), transformationCentreMM=None, transformationMatrix=None):
    """
    This function gives the detector pixel range that a single sphere will affect.
    The main idea is that it will be used with the ROI option of projectSphereMM()
    in order to only project the needed pixels.

    Parameters
    ----------
        spherePositionMM : 3-component array-like of floats (e.g., a 1x3 2D numpy array)
            xyz position of sphere in mm, with the origin being the middle of the detector

        radiusMM : float
            Particle radius for projection

        radiusMargin : float, optional
            Multiplicative margin on radius (the box uses radiusMM*(1+radiusMargin))
            Default = 0.1

        sourceDetectorDistMM : float, optional
            Distance between x-ray source and middle of detector
            Set as numpy.inf for parallel projection
            Default = 100

        pixelSizeMM : float, optional
            Pixel size on detector in mm
            Default = 0.1

        detectorResolution : 2-component sequence of ints, optional
            Number of pixels rows, columns of detector
            Default = (512, 512)

        transformationCentreMM : 3-component vector, optional
            XYZ centre for a transformation
            Default = None

        transformationMatrix : 3x3 numpy array, optional
            XYZ transformation matrix to apply to coordinates
            Default = None

    Returns
    -------
        JIrange : 2x2 numpy array of ints
            [[rowMin, rowMax], [colMin, colMax]] of the detector concerned by this sphere.
            All bounds are clipped to [0, detectorResolution] so they can be used directly as
            slice bounds. The ranges are empty (min == max) if the sphere misses the detector.
    """
    assert((transformationCentreMM is None) == (transformationMatrix is None)), "projectSphere.singleSphereToDetectorPixelRange(): transformationCentreMM and transformationMatrix must both be set or unset"

    spherePositionMM = numpy.asarray(spherePositionMM, dtype=float).ravel()

    # Transform coordinates if so asked: x' = M.(x - c) + c
    if transformationCentreMM is not None:
        centreMM = numpy.asarray(transformationCentreMM, dtype=float).ravel()
        spherePositionMM = numpy.dot(transformationMatrix, spherePositionMM - centreMM) + centreMM

    x, y, z = spherePositionMM[:3]
    extentMM = radiusMM + radiusMM*radiusMargin

    # Project all 8 corners of the (margin-enlarged) bounding cube for safety;
    # the extreme pixels among them bound the sphere's footprint on the detector
    corners = numpy.array([pointToDetectorPixelRange(numpy.array([x + dx*extentMM,
                                                                  y + dy*extentMM,
                                                                  z + dz*extentMM]),
                                                     sourceDetectorDistMM=sourceDetectorDistMM,
                                                     pixelSizeMM=pixelSizeMM,
                                                     detectorResolution=detectorResolution)
                           for dx in (-1, 1)
                           for dy in (-1, 1)
                           for dz in (-1, 1)], dtype=int)

    # Clip BOTH bounds to the detector: clipping is monotone so min <= max is preserved,
    # and the resulting slices never run outside the detector
    resolution = numpy.asarray(detectorResolution)
    minRC = numpy.clip(corners.min(axis=0), 0, resolution)
    maxRC = numpy.clip(corners.max(axis=0), 0, resolution)
    return numpy.array([[minRC[0], maxRC[0]], [minRC[1], maxRC[1]]])


def _transformPositionsMM(positionsMM, transformationCentreMM, transformationMatrix):
    # Apply x' = M.(x - c) + c to every row of an Nx3 array, in float64 so integer positions are not truncated
    centreMM = numpy.asarray(transformationCentreMM, dtype=float)
    matrix = numpy.asarray(transformationMatrix, dtype=float)
    return numpy.dot(numpy.asarray(positionsMM, dtype=float) - centreMM, matrix.T) + centreMM


def _detectorAxisMM(nPixels, pixelSizeMM, descending=False):
    # Evenly spaced coordinates in mm along one detector axis spanning +/- half the detector size, as float32
    halfMM = pixelSizeMM*nPixels/2.
    if descending:
        return numpy.linspace(halfMM, -halfMM, nPixels).astype('<f4')
    return numpy.linspace(-halfMM, halfMM, nPixels).astype('<f4')


def _parallelProjectionMM(spheresPositionMM, radiiMM, pixelSizeMM, detectorResolution):
    # Crossed distance in mm through the spheres for a parallel beam along X (x-positions are ignored)
    # Refer to projectedCoords_v2.pdf: -z in space goes with j (rows) and -y with i (columns),
    # so both detector axes run from +half to -half
    rowsMM = _detectorAxisMM(detectorResolution[0], pixelSizeMM, descending=True)
    colsMM = _detectorAxisMM(detectorResolution[1], pixelSizeMM, descending=True)
    zDetector2D, yDetector2D = numpy.meshgrid(rowsMM, colsMM, indexing='ij')

    projectionXmm = numpy.zeros(tuple(detectorResolution), dtype='<f4')
    for positionMM, radiusMM in zip(spheresPositionMM, radiiMM):
        # Squared half-chord through this sphere; positive only where the beam actually hits it
        halfChordSq = radiusMM**2 - (positionMM[1] - yDetector2D)**2 - (positionMM[2] - zDetector2D)**2
        hit = halfChordSq > 0
        projectionXmm[hit] += 2*numpy.sqrt(halfChordSq[hit])
    return projectionXmm


def projectSphereMM(spheresPositionMM, radiiMM, ROIcentreMM=None, ROIradiusMM=None, sourceDetectorDistMM=100, pixelSizeMM=0.1, detectorResolution=(512, 512), projector='C', transformationCentreMM=None, transformationMatrix=None, blur=None):
    """
    This is the python wrapping function for the C++ projector. It takes the projection geometry and
    a list of particle positions and radii, and projects them in the X direction.
    The output is the crossed distance through the spheres in mm.

    Please refer to the figures/projectedCoords_v2 for geometry

    In order to allow projections from different angles, an XYZ centre and a transformation matrix can be provided,
    which will be applied to the particle positions.

    Parameters
    ----------
        spheresPositionMM : Nx3 2D numpy array of floats
            xyz positions of spheres in mm, with the origin being the middle of the detector

        radiiMM : 1D numpy array of floats
            Particle radii for projection

        ROIcentreMM : 3-component vector of floats, optional
            Particle position for ROI
            Default = Deactivated (None)

        ROIradiusMM : float, optional
            Particle radius for ROI. ROI mode is only used if both ROIcentreMM and ROIradiusMM are set;
            otherwise the full detector is projected.
            Default = Deactivated (None)

        sourceDetectorDistMM : float, optional
            Distance between x-ray source and middle of detector.
            Set as numpy.inf for parallel projection (computed in python, no C++ projector needed)
            Default = 100

        pixelSizeMM : float, optional
            Pixel size on detector in mm
            Default = 0.1

        detectorResolution : 2-component sequence of ints, optional
            Number of pixels rows, columns of detector
            Default = (512, 512)

        projector : string, optional
            Algorithm for the projector (leave this alone for now)
            Default = 'C'

        transformationCentreMM : 3-component vector, optional
            XYZ centre for a transformation
            Default = None

        transformationMatrix : 3x3 numpy array, optional
            XYZ transformation matrix to apply to coordinates
            Default = None

        blur : float, optional
            sigma of blur to pass to scipy.ndimage.gaussian_filter to
            blur the radiograph at the end of everything (in every projection mode)

    Returns
    -------
        projectionMM : 2D numpy array of float32
            Radiography containing the total crossed distance through the spheres in mm for each beam path
            (full detector, or only the ROI pixel range in ROI mode).
            To turn this into a radiography, the distances should be put into a calibrated Beer-Lambert law.
            None is returned (after printing a message) for modes that are not implemented.
    """

    assert(len(spheresPositionMM.shape) == 2),                                  "projectSphere.projectSphereMM(): spheresPositionMM is not 2D array"
    assert(len(radiiMM.shape) == 1),                                            "projectSphere.projectSphereMM(): radiiMM is not 1D array"
    assert(spheresPositionMM.shape[0] == radiiMM.shape[0]),                     "projectSphere.projectSphereMM(): number of radii and number of sphere positions not the same"
    assert((transformationCentreMM is None) == (transformationMatrix is None)), "projectSphere.projectSphereMM(): transformationCentreMM and transformationMatrix must both be set or unset"

    # Transform coordinates if so asked.
    # ROIcentreMM is deliberately NOT transformed here: singleSphereToDetectorPixelRange() receives
    # the same transformation options and applies them itself.
    if transformationCentreMM is not None:
        spheresPositionMM = _transformPositionsMM(spheresPositionMM, transformationCentreMM, transformationMatrix)

    # If only one of the two ROI parameters is set, the full detector is projected
    roiRequested = ROIcentreMM is not None and ROIradiusMM is not None

    if sourceDetectorDistMM == numpy.inf:
        # Special case: parallel projection in the X-direction, so x-positions are ignored.
        if roiRequested:
            print("projectSphere.projectSphereMM(): ROI mode in parallel not yet implemented (but shoudn't be to hard)")
            return
        projectionXmm = _parallelProjectionMM(spheresPositionMM, radiiMM, pixelSizeMM, detectorResolution)

    elif projector == 'C':
        import projectSphereC3

        # Flip Y and Z for the C++ projector after applying transformation
        flippedPositionsMM = spheresPositionMM * numpy.array([1, -1, -1])

        rowsMM = _detectorAxisMM(detectorResolution[0], pixelSizeMM)
        colsMM = _detectorAxisMM(detectorResolution[1], pixelSizeMM)

        if roiRequested:
            # Make sure there's only one sphere:
            assert(len(flippedPositionsMM.ravel())==3), "projectSphere.projectSphereMM(): in ROI mode I want only one sphere"

            limits = singleSphereToDetectorPixelRange(numpy.asarray(ROIcentreMM).ravel(),
                                                      ROIradiusMM,
                                                      radiusMargin=0.1,
                                                      sourceDetectorDistMM=sourceDetectorDistMM,
                                                      pixelSizeMM=pixelSizeMM,
                                                      detectorResolution=detectorResolution,
                                                      transformationCentreMM=transformationCentreMM,
                                                      transformationMatrix=transformationMatrix)
            # Restrict detector axes to the ROI; limits are applied after the full-detector linspace
            rowsMM = rowsMM[limits[0, 0]:limits[0, 1]]
            colsMM = colsMM[limits[1, 0]:limits[1, 1]]

        # Output shape derived from the (possibly sliced) axes so they can never disagree
        projectionXmm = numpy.zeros((len(rowsMM), len(colsMM)), dtype='<f4')

        # This C++ function fills in the passed projectionXmm array
        projectSphereC3.project_func(numpy.array([sourceDetectorDistMM], dtype='<f4'),
                                     radiiMM.astype('<f4'),
                                     rowsMM,
                                     colsMM,
                                     flippedPositionsMM.astype('<f4'),
                                     projectionXmm)

    else:
        print("projectSphere.projectSphereMM(): This projection mode not implemented")
        return

    if blur is not None:
        projectionXmm = scipy.ndimage.gaussian_filter(projectionXmm, sigma=blur)

    return projectionXmm


def computeLinearBackground(radioMM, mask=None, radiiMM=None):
    """
    This function computes a plane-fit for background greylevels to help with the correction of the background

    Parameters
    ----------
        radioMM : 2D numpy array of floats
            The image

        mask : 2D numpy array of bools, optional
            Pixels to use for the fit (True = used).
            If None, it is built from radiiMM (see below), or all pixels are used if radiiMM is also None.

        radiiMM : 1D numpy array of floats, optional
            Particle radii; only used when mask is None, in which case the fit uses
            the pixels where abs(radioMM) < 0.05*radiiMM[0] (i.e., reasonably valued background)

    Returns
    -------
        background : 2D numpy array of floats
            Same size as radioMM, equal to -a*x - b*y + c for the fitted plane parameters (a, b, c)
    """
    # scipy.optimize is not imported at file level (only scipy.ndimage is)
    import scipy.optimize

    x, y = numpy.meshgrid(numpy.arange(radioMM.shape[0]), numpy.arange(radioMM.shape[1]), indexing='ij')

    if mask is None:
        if radiiMM is not None:
            # just where it is reasonably valued
            mask = numpy.abs(radioMM) < 0.05*numpy.asarray(radiiMM).ravel()[0]
        else:
            mask = numpy.ones(radioMM.shape, dtype=bool)

    # Only the selected pixels enter the fit; extract them once
    xValid = x[mask]
    yValid = y[mask]
    zValid = radioMM[mask]

    def planeResidual(params):
        # Residual of z against the plane -a*x - b*y + c
        return numpy.ravel(params[0]*xValid + params[1]*yValid + zValid - params[2])

    fitted = scipy.optimize.least_squares(planeResidual, [0, 0, 0]).x

    return -fitted[0]*x - fitted[1]*y + fitted[2]


def gl2mm(radio, calib=None):
    """
    This function takes a greylevel radiograph (I/I0) and returns
    a radiograph in mm (L), representing the path length encountered.

    Parameters
    ----------
        radio : a 2D numpy array of floats
            Normalised greylevels I/I0; values <= 0 give inf/nan

        calib : dictionary (optional)
            This contains a calibration of greylevels to mm.
            Not used yet: if passed, a message is printed and the attenuation
            coefficient mu is still assumed to be 1.

    Returns
    -------
        radioMM : numpy array of floats, same shape as radio
            -log(radio), i.e. the Beer-Lambert path length for mu = 1
    """
    if calib is not None:
        print("projectSphere.gl2mm() I need to be implemented")
    return -numpy.log(radio)


def mm2gl(radioMM, calib=None):
    """
    This function takes a radiograph in mm (L) and returns
    a radiograph in greylevels (I/I0)

    Parameters
    ----------
        radioMM : a 2D numpy array of floats
            Crossed path length in mm

        calib : dictionary (optional)
            This contains a calibration of greylevels to mm.
            Not used yet: if passed, a message is printed and the attenuation
            coefficient mu is still assumed to be 1.

    Returns
    -------
        radio : numpy array of floats, same shape as radioMM
            exp(-radioMM), i.e. the Beer-Lambert transmission I/I0 for mu = 1
    """
    if calib is not None:
        print("projectSphere.mm2gl() I need to be implemented")
    return numpy.exp(-radioMM)
