# -*- coding: utf-8 -*-
# Copyright (c) 2015, PyRETIS Development Team.
# Distributed under the LGPLv2.1+ License. See LICENSE for more info.
"""Classes and functions for paths.

The classes and functions defined in this module are useful for
representing paths.


Important classes defined here
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

PathBase (:py:class:`.PathBase`)
    A base class for paths.

Path (:py:class:`.Path`)
    Class for a generic path that stores all possible information.

PathExt (:py:class:`.PathExt`)
    Class for external paths. In external paths, the trajectories
    are stored in external files and the object will only contain the
    file names so that the external snapshots can be accessed.

Important methods defined here
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

paste_paths
    Function for joining two paths, where one is in the backward time
    direction and the other is in the forward time direction.
"""
from abc import abstractmethod
import logging
import numpy as np
logger = logging.getLogger(__name__)  # pylint: disable=C0103
logger.addHandler(logging.NullHandler())


__all__ = ['PathBase', 'Path', 'PathExt', 'paste_paths']

# Human-readable descriptions of the possible path status codes
# (looked up by `PathBase.__str__` via `self.status`):
_STATUS = {
    'ACC': 'The path has been accepted',
    'MCR': 'Momenta change rejection',
    'BWI': 'Backward trajectory end at wrong interface',
    'BTL': 'Backward trajectory too long (detailed balance condition)',
    'BTX': 'Backward trajectory too long (max-path exceeded)',
    'BTS': 'Backward trajectory too short',
    'KOB': 'Kicked outside of boundaries',
    'FTL': 'Forward trajectory too long (detailed balance condition)',
    'FTX': 'Forward trajectory too long (max-path exceeded)',
    'FTS': 'Forward trajectory too short',
    'NCR': 'No crossing with middle interface',
}

# Human-readable descriptions of the move that generated a path
# (looked up by `PathBase.__str__` via `self.generated[0]`):
_GENERATED = {
    'sh': 'Path was generated with a shooting move',
    'tr': 'Path was generated with a time-reversal move',
    'ki': 'Path was generated by integration after kicking',
    're': 'Path was loaded from an external file',
    's+': 'Path was generated by a swapping move from +',
    's-': 'Path was generated by a Swapping move from -',
    '00': 'Path was generated by a null move',
}


def paste_paths(path_back, path_forw, overlap=True, maxlen=None):
    """Merge a backward with a forward path into a new path.

    The resulting path is equal to the two paths stacked, in correct
    time. Note that the ordering is important here so that:
    ``paste_paths(path1, path2) != paste_paths(path2, path1)``.

    There are two things we need to take care of here:

    - `path_back` must be iterated in reverse (it is assumed to be a
      backward trajectory).
    - we may have to remove one point in `path_forw` (if the paths
      overlap).

    Parameters
    ----------
    path_back : object like :py:class:`.PathBase`
        This is the backward trajectory.
    path_forw : object like :py:class:`.PathBase`
        This is the forward trajectory.
    overlap : boolean, default is True
        If true, `path_back` and `path_forw` have a common
        starting-point, that is, the first point in `path_forw` is
        identical to the first point in `path_back`. In time-space this
        means that the *first* point in `path_forw` is identical to the
        *last* point in `path_back` (the backward and forward path
        started at the same location in space). The shared point is
        then only included once in the new path.
    maxlen : int, optional
        This is the maximum length for the new path. If it's not given,
        it is set to the common `maxlen` of the two given paths. If
        their `maxlen` values differ, the largest one is used, where a
        value of None is ignored in favour of the other one.

    Returns
    -------
    new_path : object like :py:class:`.PathBase`
        The pasted path, created with ``path_back.empty_path``. If
        `maxlen` is reached while pasting, a warning is logged and the
        truncated path is returned.

    Note
    ----
    Some information about the path will not be set here. This must be
    set elsewhere. This includes how the path was generated
    (`path.generated`) and the status of the path (`path.status`).
    """
    if maxlen is None:
        if path_back.maxlen == path_forw.maxlen:
            maxlen = path_back.maxlen
        else:
            # The lengths differ, so at most one of them is None. Skip a
            # None value so that the other one is used (in Python 3,
            # comparing None with a number raises a TypeError).
            # Note that now there is a chance of truncating the path while
            # pasting!
            maxlen = max(length for length in (path_back.maxlen,
                                               path_forw.maxlen)
                         if length is not None)
            msg = 'Unequal length: Using {} for the new path!'.format(maxlen)
            logger.warning(msg)
    # The backward path is pasted in reverse, so the new path starts
    # (path_back.length - 1) steps before path_back.time_origin.
    time_origin = path_back.time_origin - path_back.length + 1
    new_path = path_back.empty_path(maxlen=maxlen, time_origin=time_origin)
    for phasepoint in path_back.trajectory(reverse=True):
        if not new_path.append(phasepoint):
            msg = 'Truncated while pasting backwards at: {}'
            msg = msg.format(new_path.length)
            logger.warning(msg)
            return new_path
    forward = path_forw.trajectory()
    if overlap:
        # The first forward point duplicates the last point taken from
        # the backward path, so skip it.
        next(forward, None)
    for phasepoint in forward:
        if not new_path.append(phasepoint):
            msg = 'Truncated path at: {}'.format(new_path.length)
            logger.warning(msg)
            return new_path
    return new_path


def check_crossing(cycle, orderp, interfaces, leftside_prev):
    """Check if we have crossed an interface during the last step.

    This function is useful for checking if an interface was crossed
    from the previous step till the current one. This is for instance
    used in the MD simulations for the initial flux.
    It will use a variable to store the previous positions with respect
    to the interfaces and check if interfaces were crossed here.

    An order parameter exactly equal to an interface does not count as
    a crossing of that interface.

    Parameters
    ----------
    cycle : int
        This is the current simulation cycle number.
    orderp : float
        The current order parameter.
    interfaces : list of floats
        These are the interfaces to check.
    leftside_prev : list of booleans or None
        These are used to store the previous positions with respect
        to the interfaces (True means "left of the interface"). If
        None, the positions are initialized from `orderp` and no
        crossings are reported. The given list is not modified.

    Returns
    -------
    leftside_curr : list of booleans
        These are the updated positions with respect to the interfaces.
    cross : list of tuples
        If a certain interface is crossed, a tuple will be added to this
        list. The tuple is of form
        (cycle number, interface number, direction)
        where direction is '-' for a crossing in the negative direction
        and '+' for a crossing in the positive direction.
    """
    if leftside_prev is None:
        # Nothing to compare with yet: just classify the current point.
        return [orderp < interf for interf in interfaces], []
    cross = []
    leftside_curr = list(leftside_prev)  # copy, the input is not mutated
    for i, (left, interf) in enumerate(zip(leftside_prev, interfaces)):
        if left and orderp > interf:
            leftside_curr[i] = False
            cross.append((cycle, i, '+'))
        elif not left and orderp < interf:
            leftside_curr[i] = True
            cross.append((cycle, i, '-'))
    return leftside_curr, cross


class PathBase:
    """Base class for representation of paths.

    This class represents a path. A path consists of a series of
    consecutive snapshots (the trajectory) with the corresponding order
    parameter. We are going to assume that we always store the order
    parameter as a function of the time. For the other properties, the
    different sub-classes might not store all the information.

    Sub-classes are expected to implement :py:meth:`.phasepoint`,
    :py:meth:`._append_posvel`, :py:meth:`.get_shooting_point`,
    :py:meth:`.reverse_velocities`, :py:meth:`.restart_info` and
    :py:meth:`.empty_path`.

    NOTE(review): the class does not derive from ``abc.ABC``, so the
    ``@abstractmethod`` markers below are documentation only and are
    not enforced at instantiation time.

    Attributes
    ----------
    generated : tuple
        This contains information on how the path was generated.
        `generated[0]` : string, as defined in the variable `_GENERATED`
        `generated[1:]` : additional information:
        For ``generated[0] == 'sh'`` the additional information is the
        index of the shooting point on the old path, the new path and
        the corresponding order parameter.
    length : int
        The number of phase points currently stored in the path.
    maxlen : int
        This is the maximum path length. Some algorithms requires this
        to be set. Others don't, which is indicated by setting `maxlen`
        equal to None.
    order : list of floats
        The order parameters as function of time.
    ordermin : tuple
        This is the (current) minimum order parameter for the path.
        `ordermin[0]` is the value, `ordermin[1]` is the index of the
        corresponding phase point in the path.
    ordermax : tuple
        This is the (current) maximum order parameter for the path.
        `ordermax[0]` is the value, `ordermax[1]` is the index of the
        corresponding phase point in the path.
    rgen : object like :py:class:`.RandomGenerator`
        This is the random generator that will be used for the
        paths that required random numbers.
    time_origin : int
        This is the location of the phase point `path[0]` relative to
        its parent. This might be useful for plotting.
    status : str or None
        The status of the path. The possibilities are defined
        in the variable `_STATUS`
    vpot : list of floats
        The potential energy as function of time.
    ekin : list of floats
        The kinetic energy as function of time.
    """

    def __init__(self, rgen, maxlen=None, time_origin=0):
        """Initialize the Path object.

        Parameters
        ----------
        rgen : object like :py:class:`.RandomGenerator`
            This is the random generator that will be used.
        maxlen : int, optional
            This is the max-length of the path. The default value,
            None, is just a path of arbitrary length.
        time_origin : int, optional
            This can be used to store the shooting point of a parent
            trajectory.
        """
        self.order = []
        self.vpot = []
        self.ekin = []
        self.maxlen = maxlen
        self.length = 0
        self.ordermin = None
        self.ordermax = None
        self.time_origin = time_origin
        self.status = None
        self.generated = None
        self.rgen = rgen

    def _update_orderp(self, orderp, idx):
        """Update current min/max order parameter.

        Update the min/max order parameter given a new order parameter.
        It will just check if the given order parameter is larger or
        smaller than the current ones. On ties, the earlier index is
        kept.

        Parameters
        ----------
        orderp : float
            This is the new order parameter.
        idx : int
            This is the index of the new order parameter in the path.
        """
        if self.ordermax is None or orderp > self.ordermax[0]:
            self.ordermax = (orderp, idx)
        if self.ordermin is None or orderp < self.ordermin[0]:
            self.ordermin = (orderp, idx)

    def get_min_max_orderp(self):
        """Get the minimum and maximum order parameter on the path.

        Update the minimum and maximum order parameter on the path and
        return them. This function will explicitly loop over the path,
        check all phase-space points and find the minimum and maximum
        order parameter. This is useful if the path was read directly
        without calling `append`.

        Returns
        -------
        out[0] : tuple or None
            This is the minimum order parameter, tuple with
            (value, index). None if the path is empty.
        out[1] : tuple or None
            This is the maximum order parameter, tuple with
            (value, index). None if the path is empty.
        """
        ordermin = None
        ordermax = None
        for i in range(self.length):
            orderp = self.order[i][0]
            if ordermin is None or ordermax is None:
                ordermin = (orderp, i)
                ordermax = (orderp, i)
            else:
                if orderp > ordermax[0]:
                    ordermax = (orderp, i)
                if orderp < ordermin[0]:
                    ordermin = (orderp, i)
        self.ordermin = ordermin
        self.ordermax = ordermax
        return ordermin, ordermax

    def check_interfaces(self, interfaces):
        """Check current status of the path.

        Get the current status of the path with respect to the
        `interfaces`. This is intended to determine if we have crossed
        certain interfaces or not.

        Parameters
        ----------
        interfaces : list of floats
            This list is assumed to contain the three interface values
            left, middle and right.

        Returns
        -------
        out[0] : str, 'L' or 'R' or None
            Start condition: did the trajectory start at the left ('L')
            or right ('R') interface.
        out[1] : str, 'L' or 'R' or None
            Ending condition: did the trajectory end at the left ('L')
            or right ('R') interface or None of them.
        out[2] : str, 'M' or '*'
            'M' if middle interface is crossed, '*' otherwise.
        out[3] : list of boolean
            out[3][i] = True if ordermin < interfaces[i] <= ordermax

        All four items are None (and a warning is logged) if the path
        is empty.
        """
        if self.length < 1:
            logger.warning('Path is empty!')
            return None, None, None, None
        ordermax, ordermin = self.ordermax[0], self.ordermin[0]
        cross = [ordermin < interpos <= ordermax for interpos in interfaces]
        left, right = min(interfaces), max(interfaces)
        # check end & start:
        end = self.get_end_point(left, right)
        start = self.get_start_point(left, right)
        middle = 'M' if cross[1] else '*'
        return start, end, middle, cross

    def get_end_point(self, left, right):
        """Return the end point of the path as a string.

        The end point is either to the left of the `left` interface or
        to the right of the `right` interface, or somewhere in between.

        Parameters
        ----------
        left : float
            The left interface
        right : float
            The right interface

        Returns
        -------
        out : string
            String representing where the end point is ('L' - left,
            'R' - right or None).
        """
        if self.order[-1][0] <= left:
            end = 'L'
        elif self.order[-1][0] >= right:
            end = 'R'
        else:
            end = None
            logger.debug('Undefined end point.')
        return end

    def get_start_point(self, left, right):
        """Return the start point of the path as a string.

        The start point is either to the left of the `left` interface or
        to the right of the `right` interface, or somewhere in between.

        Parameters
        ----------
        left : float
            The left interface
        right : float
            The right interface

        Returns
        -------
        out : string
            String representing where the start point is ('L' - left,
            'R' - right or None).
        """
        if self.order[0][0] <= left:
            start = 'L'
        elif self.order[0][0] >= right:
            start = 'R'
        else:
            start = None
            logger.debug('Undefined starting point.')
        return start

    @abstractmethod
    def get_shooting_point(self):
        """Return a shooting point from the path.

        Implementations draw the random choice with `self.rgen`.

        Returns
        -------
        phasepoint : dict
            The selected phase point, as returned by
            :py:meth:`.phasepoint`.
        idx : int
            The shooting point index.
        """
        pass

    def trajectory(self, reverse=False):
        """Iterate over the phase-space points in the path.

        Parameters
        ----------
        reverse : boolean
            If this is True, we iterate in the reverse direction.

        Yields
        ------
        out : dict
            The phase-space points in the path, as returned by
            :py:meth:`.phasepoint`.
        """
        if reverse:
            for i in range(self.length - 1, -1, -1):
                yield self.phasepoint(i)
        else:
            for i in range(self.length):
                yield self.phasepoint(i)

    @abstractmethod
    def phasepoint(self, idx):
        """Return a specific phase point.

        Parameters
        ----------
        idx : int
            Index for phase-space point to return.

        Returns
        -------
        out : dict
            A phase-space point in the path, with the keys accepted by
            :py:meth:`.append` ('order', 'pos', 'vel', 'vpot', 'ekin').
        """
        pass

    @abstractmethod
    def _append_posvel(self, pos, vel):
        """Method to append positions and velocities."""
        pass

    def append(self, phasepoint):
        """Append a new phase point to the path.

        We will here append a new phase-space point to the path.
        The phase point is assumed to be given by positions and
        velocities with a corresponding order parameter and energy.

        Parameters
        ----------
        phasepoint : dict
            A dictionary with the things to add to the path.
            We assume that it contains the following keys:

            * 'order': list of floats representing the order parameter(s).

            * 'pos': the representation of the positions.

            * 'vel': the representation of velocities.

            * 'vpot': the potential energy.

            * 'ekin': the kinetic energy.

        Returns
        -------
        out : boolean
            True if the point was appended, False if the path already
            has `maxlen` points.
        """
        if self.maxlen is None or self.length < self.maxlen:
            orderp = phasepoint['order']
            self.order.append(orderp)
            self._update_orderp(orderp[0], self.length)
            self._append_posvel(phasepoint['pos'], phasepoint['vel'])
            self.vpot.append(phasepoint['vpot'])
            self.ekin.append(phasepoint['ekin'])
            self.length += 1
            return True
        msg = 'Max length exceeded! Could not append to path!'
        logger.debug(msg)
        return False

    def get_path_data(self, status, interfaces):
        """Return information about the Path.

        This information can (and is typically) stored in a
        `PathEnsemble`.

        Parameters
        ----------
        status : string
            This represents the current status of the path.
        interfaces : list
            These are just the interfaces we are currently considering.

        Returns
        -------
        path_info : dict
            With the keys 'generated', 'status', 'length', 'ordermax',
            'ordermin' and 'interface' (a tuple of start, middle, end
            as computed by :py:meth:`.check_interfaces`).
        """
        path_info = {'generated': self.generated,
                     'status': status,
                     'length': self.length}

        path_info['ordermax'] = tuple(self.ordermax)
        path_info['ordermin'] = tuple(self.ordermin)

        start, end, middle, _ = self.check_interfaces(interfaces)
        path_info['interface'] = (start, middle, end)
        return path_info

    def set_move(self, move):
        """Update the path move.

        The path move is a short string that represent how the path
        was generated. It should preferably match one of the moves
        defined in `_GENERATED`. Any additional information already
        stored in `generated` is kept.

        Parameters
        ----------
        move : string
            A short description of the move
        """
        if self.generated is None:
            self.generated = (move, 0, 0, 0)
        else:
            self.generated = (move, self.generated[1], self.generated[2],
                              self.generated[3])

    def success(self, detect):
        """Check if the path is successful.

        The check is based on the maximum order parameter and the value
        of `detect`. It is successful if the maximum order parameter is
        greater than `detect`.

        Parameters
        ----------
        detect : float
            The value for which the path is successful, i.e. the
            "detect" interface.

        Returns
        -------
        out : boolean
            True if the path is successful.
        """
        return self.ordermax[0] > detect

    def __iadd__(self, other):
        """Add path data to a path from another path, i.e. ``self += other``.

        This will simply append the phase points from `other`. If
        `maxlen` is reached, the remaining points are dropped and a
        warning is logged.

        Parameters
        ----------
        other : object of type `Path`
            The object to add path data from.

        Returns
        -------
        self : object of type `Path`
            The updated path object.
        """
        for phasepoint in other.trajectory():
            app = self.append(phasepoint)
            if not app:
                msg = 'Truncated path while +=: {}'.format(self.length)
                logger.warning(msg)
                return self
        return self

    def copy_path(self):
        """Return a copy of the path.

        The copy keeps the phase points, `status`, `time_origin`,
        `generated` and `maxlen` of this path.
        """
        new_path = self.empty_path()
        for phasepoint in self.trajectory():
            new_path.append(phasepoint)
        new_path.status = self.status
        new_path.time_origin = self.time_origin
        new_path.generated = self.generated
        # empty_path() does not carry over maxlen, set it explicitly (after
        # appending, so that the copy is never truncated):
        new_path.maxlen = self.maxlen
        return new_path

    @staticmethod
    @abstractmethod
    def reverse_velocities(vel):
        """Method that handles reversing of velocities."""
        pass

    def reverse(self):
        """Helper method for reversing the path, intended to be extended."""
        return self.reverse_trajectory()

    def reverse_trajectory(self):
        """Reverse a path and return the reverse path as a new path.

        This will simply reverse a path and return the reversed path as
        a new `Path` object with the same `maxlen`. Velocities are
        reversed with :py:meth:`.reverse_velocities`. Note that
        currently, recalculating order parameters have not been
        implemented! Typically, reversing will not change the order
        parameter, but it might change the velocity for the order
        parameter and so on.

        Returns
        -------
        new_path : object like :py:class:`.PathBase`
            This is basically a copy of `self`, just reversed.
        """
        new_path = self.empty_path()
        for phasepoint in self.trajectory(reverse=True):
            new_point = {key: val for key, val in phasepoint.items()}
            new_point['vel'] = self.reverse_velocities(new_point['vel'])
            app = new_path.append(new_point)
            if not app:  # pragma: no cover
                msg = 'Could not reverse path'
                logger.error(msg)
                return None
        # empty_path() does not carry over maxlen, set it explicitly:
        new_path.maxlen = self.maxlen
        return new_path

    def __str__(self):
        """Return a simple string representation of the Path."""
        msg = ['Path with length {} (max: {})'.format(self.length,
                                                      self.maxlen)]
        msg += ['Order parameter max: {}'.format(self.ordermax)]
        msg += ['Order parameter min: {}'.format(self.ordermin)]
        if self.length > 0:
            msg += ['Start {}'.format(self.order[0][0])]
            msg += ['End {}'.format(self.order[-1][0])]
        if self.status:
            # Use .get (as for the move below) so that an unknown status
            # code does not make str() raise a KeyError.
            txtstatus = _STATUS.get(self.status, 'unknown status')
            msg += ['Status: {}'.format(txtstatus)]
        if self.generated:
            move = self.generated[0]
            txtmove = _GENERATED.get(move, 'unknown move')
            msg += ['Generated: {}'.format(txtmove)]
        return '\n'.join(msg)

    @abstractmethod
    def restart_info(self):
        """Return a dictionary with restart information."""
        return

    @abstractmethod
    def empty_path(self, **kwargs):
        """Return an empty path of same class as the current one.

        This function is intended to spawn child paths that share some
        properties and also some characteristics of the current path.
        The idea here is that a path of a certain class should only be
        able to create paths of the same class.

        Returns
        -------
        out : object like :py:class:`.PathBase`
            A new empty path.
        """
        return


class Path(PathBase):
    """A path that keeps its entire trajectory in memory.

    A path is a series of consecutive snapshots (the trajectory)
    together with the matching order parameters. This implementation
    stores every piece of information for every phase point on the
    path.

    Attributes
    ----------
    pos : list of numpy.arrays
        Positions as a function of time (copies of the appended data).
    vel : list of numpy.arrays
        Velocities as a function of time (copies of the appended data).
    """

    def __init__(self, rgen, maxlen=None, time_origin=0):
        """Set up an empty in-memory path.

        Parameters
        ----------
        rgen : object like :py:class:`.RandomGenerator`
            Random generator used when random numbers are needed.
        maxlen : int, optional
            Maximum number of phase points. The default, None, means
            the path may grow without limit.
        time_origin : int, optional
            May be used to record the shooting point of a parent
            trajectory.
        """
        super(Path, self).__init__(rgen, maxlen=maxlen,
                                   time_origin=time_origin)
        self.pos, self.vel = [], []

    def phasepoint(self, idx):
        """Look up a single phase point.

        Parameters
        ----------
        idx : int
            Index of the phase-space point to return.

        Returns
        -------
        out : dict
            The phase point, with the keys 'order', 'pos', 'vel',
            'vpot' and 'ekin'.
        """
        keys = ('order', 'pos', 'vel', 'vpot', 'ekin')
        return {key: getattr(self, key)[idx] for key in keys}

    def _append_posvel(self, pos, vel):
        """Store copies of the given positions and velocities."""
        for storage, value in ((self.pos, pos), (self.vel, vel)):
            storage.append(np.copy(value))

    def get_shooting_point(self):
        """Pick a random shooting point on the path.

        Every interior point is equally likely; the two end points are
        never selected. The index is drawn with
        ``self.rgen.random_integers(1, self.length - 2)``.

        Returns
        -------
        out[0] : dict
            The selected phase point; its 'order' item holds the
            order parameter(s).
        out[1] : int
            Index of the selected shooting point.
        """
        first, last = 1, self.length - 2  # the end points are excluded
        idx = self.rgen.random_integers(first, last)
        return self.phasepoint(idx), idx

    def empty_path(self, **kwargs):
        """Create a new, empty path of the same class as this one.

        Only the keyword arguments `maxlen` (default None) and
        `time_origin` (default 0) are used; `rgen` is shared.

        Returns
        -------
        out : object like :py:class:`.PathBase`
            A new empty path.
        """
        return type(self)(self.rgen,
                          maxlen=kwargs.get('maxlen'),
                          time_origin=kwargs.get('time_origin', 0))

    @staticmethod
    def reverse_velocities(vel):
        """Flip the sign of the velocities.

        Parameters
        ----------
        vel : np.array or None
            The velocities to reverse.

        Returns
        -------
        out : np.array or None
            The negated velocities, or None if `vel` is None.
        """
        return None if vel is None else vel * -1

    def restart_info(self):
        """Collect the path state needed for a restart into a dict.

        The lists are returned as-is (not copied).
        """
        fields = ('generated', 'maxlen', 'order', 'ordermin', 'ordermax',
                  'time_origin', 'status', 'vpot', 'ekin', 'pos', 'vel',
                  'length')
        return {field: getattr(self, field) for field in fields}

    def load_restart_info(self, info):
        """Restore the path state from restart information.

        Only keys that match an existing attribute of the path are
        applied; all other keys are silently ignored. `info` is
        presumably the dict produced by :py:meth:`.restart_info`.
        """
        settable = (key for key in info if hasattr(self, key))
        for key in settable:
            setattr(self, key, info[key])


class PathExt(Path):
    """A path whose snapshots live on disk rather than in memory.

    Intended for external engines: the trajectory is kept in files and
    this object only holds references to them, so that entire
    trajectories do not have to be read into memory.

    Attributes
    ----------
    pos : list of strings
        Positions as a function of time. Each item actually refers to
        a file holding both the positions AND the velocities.
    vel : list of booleans
        When an item is True, the velocities in the matching snapshot
        file in ``pos`` should be reversed.
    """

    def __init__(self, rgen, maxlen=None, time_origin=0):
        """Set up an empty external path.

        Parameters
        ----------
        rgen : object like :py:class:`.RandomGenerator`
            Random generator used when random numbers are needed.
        maxlen : int, optional
            Maximum number of phase points. The default, None, means
            the path may grow without limit.
        time_origin : int, optional
            May be used to record the shooting point of a parent
            trajectory.
        """
        super(PathExt, self).__init__(rgen, maxlen=maxlen,
                                      time_origin=time_origin)

    def _append_posvel(self, pos, vel):
        """Store the snapshot reference and velocity flag as given.

        Unlike the in-memory path, nothing is copied here.
        """
        for storage, value in zip((self.pos, self.vel), (pos, vel)):
            storage.append(value)

    @staticmethod
    def reverse_velocities(vel):
        """Toggle the "reverse velocities" flag.

        Parameters
        ----------
        vel : boolean
            The current flag.

        Returns
        -------
        out : boolean
            The logical negation of `vel`.
        """
        return False if vel else True
