from __future__ import division

import logging

from numpy import abs as npabs
from numpy import max as npmax
from numpy import asarray, concatenate

from ..exception import OptimixError


def minimize(function, verbose=True, factr=1e5, pgtol=1e-7):
    r"""Minimize a function using L-BFGS-B.

    The optimum is written back into the free (non-fixed) variables of
    ``function``; nothing is returned. If the largest absolute gradient
    component at the starting point is already ``<= pgtol``, the optimizer
    is not run at all.

    Parameters
    ----------
    function : object
        Objective function. It has to implement the
        :class:`optimix.function.Function` interface.
    verbose : bool, optional
        ``True`` for verbose output; ``False`` otherwise. Defaults to
        ``True``.
    factr : float, optional
        The iteration stops when
        ``(f^k - f^{k+1})/max{|f^k|,|f^{k+1}|,1} <= factr * eps``,
        where ``eps`` is the machine precision, which is automatically
        generated by the code. Typical values for `factr` are: 1e12 for
        low accuracy; 1e7 for moderate accuracy; 10.0 for extremely
        high accuracy. Defaults to ``1e5``. See Notes for the relationship
        to `ftol`, which is exposed (instead of `factr`) by the
        `scipy.optimize.minimize` interface to L-BFGS-B.
    pgtol : float, optional
        The iteration will stop when
        ``max{|proj g_i | i = 1, ..., n} <= pgtol``
        where ``proj g_i`` is the i-th component of the projected gradient.
        Defaults to ``1e-7``.

    Raises
    ------
    OptimixError
        If L-BFGS-B terminates abnormally and cannot be successfully
        restarted.

    Notes
    -----
    ``factr`` and the ``ftol`` option of :func:`scipy.optimize.minimize`
    (method ``'L-BFGS-B'``) are related by
    ``ftol = factr * numpy.finfo(float).eps``.
    """
    _minimize(
        ProxyFunction(function, verbose, negative=False),
        factr=factr,
        pgtol=pgtol)


def maximize(function, verbose=True, factr=1e5, pgtol=1e-7):
    r"""Maximize a function using L-BFGS-B.

    The objective (and its gradient) is negated and then minimized. The
    optimum is written back into the free (non-fixed) variables of
    ``function``; nothing is returned.

    Parameters
    ----------
    function : object
        Objective function. It has to implement the
        :class:`optimix.function.Function` interface.
    verbose : bool, optional
        ``True`` for verbose output; ``False`` otherwise. Defaults to
        ``True``.
    factr : float, optional
        The iteration stops when
        ``(f^k - f^{k+1})/max{|f^k|,|f^{k+1}|,1} <= factr * eps``,
        where ``eps`` is the machine precision, which is automatically
        generated by the code. Typical values for `factr` are: 1e12 for
        low accuracy; 1e7 for moderate accuracy; 10.0 for extremely
        high accuracy. Defaults to ``1e5``. See Notes for the relationship
        to `ftol`, which is exposed (instead of `factr`) by the
        `scipy.optimize.minimize` interface to L-BFGS-B.
    pgtol : float, optional
        The iteration will stop when
        ``max{|proj g_i | i = 1, ..., n} <= pgtol``
        where ``proj g_i`` is the i-th component of the projected gradient.
        Defaults to ``1e-7``.

    Raises
    ------
    OptimixError
        If L-BFGS-B terminates abnormally and cannot be successfully
        restarted.

    Notes
    -----
    ``factr`` and the ``ftol`` option of :func:`scipy.optimize.minimize`
    (method ``'L-BFGS-B'``) are related by
    ``ftol = factr * numpy.finfo(float).eps``.
    """
    _minimize(
        ProxyFunction(function, verbose, negative=True),
        factr=factr,
        pgtol=pgtol)


def _do_flatten(x):
    """Concatenate array-like pieces into one flat 1-D array.

    Parameters
    ----------
    x : list, tuple or array_like
        If a list or tuple, each element (scalars allowed) is raveled and
        the pieces are concatenated in order. Any other input is passed to
        :func:`numpy.concatenate` unchanged.

    Returns
    -------
    numpy.ndarray
        The concatenation. An empty list/tuple yields an empty float array
        instead of raising, as :func:`numpy.concatenate` would.
    """
    if isinstance(x, (list, tuple)):
        if len(x) == 0:
            # numpy.concatenate refuses an empty sequence; having no pieces
            # (e.g. no free variables) simply means an empty vector.
            return asarray([], dtype=float)
        return concatenate([asarray(xi).ravel() for xi in x])
    return concatenate(x)


class ProxyFunction(object):
    """Adapt an optimix function to the calling convention of L-BFGS-B.

    The free (non-fixed) variables of ``function`` are laid out in one flat
    vector, ordered by sorted variable name. Calling the proxy with such a
    vector writes it back into the variables and returns the objective value
    together with its flattened gradient. When ``negative`` is true both are
    negated, so that minimizing the proxy maximizes ``function``. Every
    vector the proxy is called with is recorded in :attr:`solutions`.

    Parameters
    ----------
    function : object
        Objective implementing the :class:`optimix.function.Function`
        interface.
    verbose : bool
        Stored as :attr:`verbose`.
    negative : bool
        Whether to negate the objective value and gradient.
    """

    def __init__(self, function, verbose, negative):
        self._function = function
        # Sign applied to value and gradient; -1 turns maximization into
        # minimization.
        self._signal = +1
        if negative:
            self._signal = -1
        self.verbose = verbose
        self._solutions = []
        self._logger = logging.getLogger(__name__)

    @property
    def solutions(self):
        """Copies of every point passed to :meth:`__call__`, oldest first."""
        return self._solutions

    def names(self):
        """Return the sorted names of the free (non-fixed) variables."""
        free_vars = self._function.variables().select(fixed=False)
        return sorted(free_vars.names())

    def value(self):
        """Return the objective value multiplied by the sign factor."""
        objective = self._function.value()
        return self._signal * objective

    def gradient(self):
        """Return ``{name: sign * derivative}`` for every free variable."""
        full = self._function.gradient()
        sign = self._signal
        grad = {}
        for key in self.names():
            grad[key] = sign * full[key]

        if self._logger.getEffectiveLevel() <= logging.DEBUG:
            self._logger.debug("Gradient: %s", str(grad))

        return grad

    def unflatten(self, x):
        """Split flat vector ``x`` into per-variable slices keyed by name."""
        free_vars = self._function.variables().select(fixed=False)
        pieces = dict()
        start = 0
        for key in self.names():
            stop = start + free_vars.get(key).size
            pieces[key] = x[start:stop]
            start = stop
        return pieces

    def flatten(self, d):
        """Concatenate the entries of ``d`` into one vector, in name order."""
        ordered = [d[key] for key in self.names()]
        return _do_flatten(ordered)

    def __call__(self, x):
        """Evaluate at ``x`` and return ``(value, flat_gradient)``.

        ``x`` is recorded in :attr:`solutions` and written into the
        function's free variables before evaluation.
        """
        point = asarray(x).ravel()
        self._solutions.append(point.copy())
        self._function.variables().set(self.unflatten(point))

        log = self._logger
        if log.getEffectiveLevel() <= logging.DEBUG:
            free_vars = self._function.variables().select(fixed=False)
            for key in self.names():
                log.debug("Setting %s to %s", key, free_vars[key])

        fval = self.value()
        fgrad = self.flatten(self.gradient())

        if log.getEffectiveLevel() <= logging.DEBUG:
            log.debug("Function evaluation is %.10f", fval)

        return fval, fgrad

    def set_solution(self, x):
        """Write flat vector ``x`` into the free variables."""
        target = self._function.variables()
        target.set(self.unflatten(x))

    def get_solution(self):
        """Return the current free-variable values as one flat vector."""
        free_vars = self._function.variables().select(fixed=False)
        pieces = [
            free_vars.get(key).asarray().ravel() for key in self.names()
        ]
        return concatenate(pieces)


def _try_minimize(proxy_function, n, factr, pgtol):
    """Run L-BFGS-B, restarting from a perturbed point after a failure.

    Parameters
    ----------
    proxy_function : ProxyFunction
        Objective adapter; called by ``fmin_l_bfgs_b`` with a flat vector.
    n : int
        Remaining number of attempts, including this one.
    factr, pgtol : float
        Stopping tolerances forwarded to
        :func:`scipy.optimize.fmin_l_bfgs_b` (also on every restart).

    Returns
    -------
    tuple
        The ``(x, f, info)`` result of ``fmin_l_bfgs_b`` for the first
        attempt that finished with ``info['warnflag'] == 0``.

    Raises
    ------
    OptimixError
        If ``n`` reaches zero, or an attempt fails before at least two
        points have been evaluated.
    """
    from scipy.optimize import fmin_l_bfgs_b

    disp = 1 if proxy_function.verbose else 0
    logger = logging.getLogger(__name__)

    if n == 0:
        raise OptimixError("Too many bad solutions")

    try:
        x0 = proxy_function.get_solution()

        bounds = []

        var = proxy_function._function.variables().select(fixed=False)
        for name in proxy_function.names():
            if len(var[name].shape) == 0:
                # Scalar variable: a single (lower, upper) pair.
                bounds.append(var[name].bounds)
            else:
                # Array variable: extend with its per-element bounds.
                bounds += var[name].bounds

        res = fmin_l_bfgs_b(
            proxy_function,
            x0,
            bounds=bounds,
            factr=factr,
            pgtol=pgtol,
            disp=disp)

    except OptimixError:
        # Raised while preparing the start point or evaluating the objective.
        warn = True
    else:
        warn = res[2]['warnflag'] > 0

    if warn:
        xs = proxy_function.solutions
        if len(xs) < 2:
            raise OptimixError("Bad solution at the first iteration.")

        # Restart from a point derived from the last two evaluated points.
        # NOTE(review): this equals 0.2 * (xs[-2] + xs[-1]); the weights sum
        # to 0.4, not 1, so it is not an interpolation between the two
        # points. Confirm that this shrinkage is intended.
        proxy_function.set_solution(xs[-2] / 5 + xs[-1] / 5)

        logger.info("Optimix: Restarting L-BFGS-B due to bad solution.")
        # Forward the tolerances: omitting them made every restart fail with
        # a TypeError.
        res = _try_minimize(proxy_function, n - 1, factr=factr, pgtol=pgtol)

    return res


def _minimize(proxy_function, factr, pgtol):
    """Run L-BFGS-B on ``proxy_function`` and store the optimum in it.

    Parameters
    ----------
    proxy_function : ProxyFunction
        Objective adapter (already sign-adjusted when maximizing).
    factr, pgtol : float
        Stopping tolerances forwarded to L-BFGS-B.

    Raises
    ------
    OptimixError
        If L-BFGS-B reports too many function evaluations/iterations or
        another abnormal termination, or if restarts are exhausted.
    """
    if not proxy_function.names():
        # Every variable is fixed: there is nothing to optimize, and
        # flattening/reducing an empty gradient would raise.
        return

    g = proxy_function.flatten(proxy_function.gradient())
    # Already stationary within pgtol (bounds are not considered here):
    # skip the optimizer and leave the variables untouched.
    if npmax(npabs(g)) <= pgtol:
        return

    r = _try_minimize(proxy_function, 5, factr=factr, pgtol=pgtol)

    # r is the (x, f, info) triple produced by fmin_l_bfgs_b.
    warnflag = r[2]['warnflag']
    if warnflag == 1:
        raise OptimixError("L-BFGS-B: too many function evaluations" +
                           " or too many iterations")
    elif warnflag == 2:
        raise OptimixError("L-BFGS-B: %s" % r[2]['task'])

    proxy_function.set_solution(r[0])
