# This file is part of the pyMOR project (http://www.pymor.org).
# Copyright 2013-2019 pyMOR developers and contributors. All rights reserved.
# License: BSD 2-Clause License (http://opensource.org/licenses/BSD-2-Clause)

import numpy as np
import scipy.linalg as spla

from pymor.algorithms.gram_schmidt import gram_schmidt, gram_schmidt_biorth
from pymor.algorithms.riccati import solve_ricc_lrcf, solve_pos_ricc_lrcf
from pymor.core.interfaces import BasicInterface
from pymor.models.iosys import LTIModel
from pymor.operators.constructions import IdentityOperator
from pymor.reductors.basic import LTIPGReductor


class GenericBTReductor(BasicInterface):
    """Generic Balanced Truncation reductor.

    Subclasses implement :meth:`_gramians` (low-rank Cholesky factors of the
    two Gramians) and :meth:`error_bounds`. This class computes the singular
    value decomposition of the matrix `E.apply2(of, cf)`, chooses the reduced
    order and builds the Petrov-Galerkin projection.

    Parameters
    ----------
    fom
        The full-order |LTIModel| to reduce.
    mu
        |Parameter|.
    """
    def __init__(self, fom, mu=None):
        assert isinstance(fom, LTIModel)
        self.fom = fom
        self.mu = fom.parse_parameter(mu)
        self.V = None
        self.W = None
        self._pg_reductor = None
        self._gramians_cache = None
        self._sv_U_V_cache = None

    def _gramians(self):
        """Return low-rank Cholesky factors of Gramians."""
        raise NotImplementedError

    def _cached_gramians(self):
        """Return the factors from :meth:`_gramians`, computing them at most once.

        Both :meth:`_sv_U_V` and :meth:`reduce` need the factors. For subclasses
        whose :meth:`_gramians` solves Riccati equations, recomputing them would
        double the dominant cost of the reduction.
        """
        if self._gramians_cache is None:
            self._gramians_cache = self._gramians()
        return self._gramians_cache

    def _sv_U_V(self):
        """Return singular values and vectors.

        Returns
        -------
        sv
            Singular values of `E.apply2(of, cf)` (as returned by
            :func:`scipy.linalg.svd`, i.e. in descending order).
        sU
            Left singular vectors, one per row (`U.T`).
        sV
            Right singular vectors, one per row (`Vh`).
        """
        if self._sv_U_V_cache is None:
            cf, of = self._cached_gramians()
            U, sv, Vh = spla.svd(self.fom.E.apply2(of, cf, mu=self.mu), lapack_driver='gesvd')
            self._sv_U_V_cache = (sv, U.T, Vh)
        return self._sv_U_V_cache

    def error_bounds(self):
        """Returns error bounds for all possible reduced orders."""
        raise NotImplementedError

    def reduce(self, r=None, tol=None, projection='bfsr'):
        """Generic Balanced Truncation.

        Parameters
        ----------
        r
            Order of the reduced model if `tol` is `None`, maximum order if `tol` is specified.
        tol
            Tolerance for the error bound. The smallest order whose error bound
            is at most `tol` is used, capped by `r` if `r` is also given. If no
            order meets the tolerance, all computed singular values are kept.
        projection
            Projection method used:

            - `'sr'`: square root method
            - `'bfsr'`: balancing-free square root method (default, since it avoids scaling by
              singular values and orthogonalizes the projection matrices, which might make it more
              accurate than the square root method)
            - `'biorth'`: like the balancing-free square root method, except it biorthogonalizes the
              projection matrices (using :func:`~pymor.algorithms.gram_schmidt.gram_schmidt_biorth`)

        Returns
        -------
        rom
            Reduced-order model.

        Raises
        ------
        ValueError
            If the selected order exceeds the number of columns of the Gramian factors.
        """
        assert r is not None or tol is not None
        assert r is None or 0 < r < self.fom.order
        assert projection in ('sr', 'bfsr', 'biorth')

        cf, of = self._cached_gramians()
        sv, sU, sV = self._sv_U_V()

        # find reduced order if tol is specified
        if tol is not None:
            error_bounds = self.error_bounds()
            # error_bounds[i] bounds the error of the order-(i + 1) model and is
            # assumed non-increasing, so the first admissible index gives the
            # smallest order meeting `tol`.
            admissible = np.flatnonzero(error_bounds <= tol)
            if len(admissible) > 0:
                r_tol = int(admissible[0]) + 1
            else:
                # np.argmax on an all-False mask would yield order 1 (the worst
                # approximation); keep all computed singular values instead.
                r_tol = len(sv)
            r = r_tol if r is None else min(r, r_tol)
        if r > min(len(cf), len(of)):
            raise ValueError('r needs to be smaller than the sizes of Gramian factors.')

        # compute projection matrices
        self.V = cf.lincomb(sV[:r])
        self.W = of.lincomb(sU[:r])
        if projection == 'sr':
            # scaling by sv^(-1/2) makes W^T E V the identity
            alpha = 1 / np.sqrt(sv[:r])
            self.V.scal(alpha)
            self.W.scal(alpha)
        elif projection == 'bfsr':
            gram_schmidt(self.V, atol=0, rtol=0, copy=False)
            gram_schmidt(self.W, atol=0, rtol=0, copy=False)
        elif projection == 'biorth':
            gram_schmidt_biorth(self.V, self.W, product=self.fom.E, copy=False)

        # find reduced-order model
        if self.fom.parametric:
            # project the model with all operators assembled at self.mu
            fom_mu = self.fom.with_(**{op: getattr(self.fom, op).assemble(mu=self.mu)
                                       for op in ['A', 'B', 'C', 'D', 'E']},
                                    parameter_space=None)
        else:
            fom_mu = self.fom
        # The last argument is True exactly for the methods ('sr', 'biorth') that
        # make W and V E-biorthonormal; presumably LTIPGReductor uses it to skip
        # projecting E.
        self._pg_reductor = LTIPGReductor(fom_mu, self.W, self.V, projection in ('sr', 'biorth'))
        rom = self._pg_reductor.reduce()
        return rom

    def reconstruct(self, u):
        """Reconstruct high-dimensional vector from reduced vector `u`.

        :meth:`reduce` must have been called first; before that the underlying
        Petrov-Galerkin reductor is `None`.
        """
        return self._pg_reductor.reconstruct(u)


class BTReductor(GenericBTReductor):
    """Standard (Lyapunov) Balanced Truncation reductor.

    Uses the low-rank Cholesky factors of the controllability and
    observability Gramians of the full-order model. See Section 7.3 in [A05]_.

    Parameters
    ----------
    fom
        The full-order |LTIModel| to reduce.
    mu
        |Parameter|.
    """
    def _gramians(self):
        controllability_factor = self.fom.gramian('c_lrcf', mu=self.mu)
        observability_factor = self.fom.gramian('o_lrcf', mu=self.mu)
        return controllability_factor, observability_factor

    def error_bounds(self):
        # Entry r - 1 bounds the order-r error by twice the sum of the
        # discarded singular values sv[r:].
        singular_values = self._sv_U_V()[0]
        tail_sums = np.cumsum(singular_values[::-1])[::-1]
        return 2 * tail_sums[1:]


class LQGBTReductor(GenericBTReductor):
    r"""Linear Quadratic Gaussian (LQG) Balanced Truncation reductor.

    The two Gramian factors come from the pair of Riccati equations solved by
    :func:`~pymor.algorithms.riccati.solve_ricc_lrcf` (non-transposed and
    transposed). See Section 3 in [MG91]_.

    Parameters
    ----------
    fom
        The full-order |LTIModel| to reduce.
    mu
        |Parameter|.
    solver_options
        The solver options to use to solve the Riccati equations.
    """
    def __init__(self, fom, mu=None, solver_options=None):
        super().__init__(fom, mu=mu)
        self.solver_options = solver_options

    def _gramians(self):
        A, B, C, E = [getattr(self.fom, name).assemble(mu=self.mu) for name in 'ABCE']
        # an identity E is passed to the solver as None
        E = None if isinstance(E, IdentityOperator) else E
        opts = self.solver_options

        def riccati_factor(transposed):
            return solve_ricc_lrcf(A, E, B.as_range_array(), C.as_source_array(),
                                   trans=transposed, options=opts)

        return riccati_factor(False), riccati_factor(True)

    def error_bounds(self):
        # discarded singular values sv[1:], smallest first
        trailing = self._sv_U_V()[0][:0:-1]
        damped = trailing / np.sqrt(1 + trailing**2)
        return 2 * np.cumsum(damped)[::-1]


class BRBTReductor(GenericBTReductor):
    r"""Bounded Real (BR) Balanced Truncation reductor.

    The two Gramian factors come from the pair of positive Riccati equations
    solved by :func:`~pymor.algorithms.riccati.solve_pos_ricc_lrcf`, with
    weight :math:`R = \gamma^2 I`. See [A05]_ (Section 7.5.3) and [OJ88]_.

    Parameters
    ----------
    fom
        The full-order |LTIModel| to reduce.
    gamma
        Upper bound for the :math:`\mathcal{H}_\infty`-norm.
    mu
        |Parameter|.
    solver_options
        The solver options to use to solve the positive Riccati equations.
    """
    def __init__(self, fom, gamma=1, mu=None, solver_options=None):
        super().__init__(fom, mu=mu)
        self.gamma = gamma
        self.solver_options = solver_options

    def _gramians(self):
        A, B, C, E = [getattr(self.fom, name).assemble(mu=self.mu) for name in 'ABCE']
        # an identity E is passed to the solver as None
        E = None if isinstance(E, IdentityOperator) else E
        opts = self.solver_options

        def weight(dim):
            # gamma**2 * I of size `dim`; None is passed when gamma == 1
            # (presumably the solver's identity default)
            if self.gamma != 1:
                return self.gamma**2 * np.eye(dim)
            return None

        def pos_riccati_factor(dim, transposed):
            return solve_pos_ricc_lrcf(A, E, B.as_range_array(), C.as_source_array(),
                                       R=weight(dim), trans=transposed, options=opts)

        controllability_factor = pos_riccati_factor(self.fom.output_dim, False)
        observability_factor = pos_riccati_factor(self.fom.input_dim, True)
        return controllability_factor, observability_factor

    def error_bounds(self):
        # twice the sum of the discarded singular values sv[r:] for each order r
        trailing = self._sv_U_V()[0][:0:-1]
        return 2 * np.cumsum(trailing)[::-1]
