# -*- coding: utf-8 -*-
# copyright: sktime developers, BSD-3-Clause License (see LICENSE file)
"""Implements hierarchical reconciliation transformers.

These reconcilers only depend on the structure of the hierarchy.
"""

__author__ = ["ciaran-g", "eenticott-shell", "k1m190r"]

from warnings import warn

import numpy as np
import pandas as pd
from numpy.linalg import inv

from sktime.transformations.base import BaseTransformer
from sktime.transformations.hierarchical.aggregate import _check_index_no_total

# TODO: fix the failing tests, which are currently escaped


class Reconciler(BaseTransformer):
    """Hierarchical reconciliation transformer.

    Hierarchical reconciliation is a transformation which is used to make the
    predictions in a hierarchy of time-series sum together appropriately.

    The methods implemented in this class only require the structure of the
    hierarchy to reconcile the forecasts.

    Please refer to [1]_ for further information.

    Parameters
    ----------
    method : {"bu", "ols", "wls_str"}, default="bu"
        The reconciliation approach applied to the forecasts
            "bu" - bottom-up
            "ols" - ordinary least squares
            "wls_str" - weighted least squares (structural)

    References
    ----------
    .. [1] https://otexts.com/fpp3/hierarchical.html
    """

    _tags = {
        "scitype:transform-input": "Series",
        "scitype:transform-output": "Series",
        "scitype:transform-labels": "None",
        "scitype:instancewise": False,  # is this an instance-wise transform?
        "X_inner_mtype": [
            "pd.Series",
            "pd.DataFrame",
            "pd-multiindex",
            "pd_multiindex_hier",
        ],
        "y_inner_mtype": "None",  # which mtypes do _fit/_predict support for y?
        "capability:inverse_transform": False,
        "skip-inverse-transform": True,  # is inverse-transform skipped when called?
        "univariate-only": True,  # can the transformer handle multivariate X?
        "handles-missing-data": False,  # can estimator handle missing data?
        "X-y-must-have-same-index": False,  # can estimator handle different X/y index?
        "fit_is_empty": False,  # is fit empty and can be skipped? Yes = True
        "transform-returns-same-time-index": True,
    }

    METHOD_LIST = ["bu", "ols", "wls_str"]

    def __init__(self, method="bu"):

        self.method = method

        super(Reconciler, self).__init__()

    def _add_totals(self, X):
        """Add total levels to X, using Aggregator."""
        from sktime.transformations.hierarchical.aggregate import Aggregator

        return Aggregator().fit_transform(X)

    def _fit(self, X, y=None):
        """Fit transformer to X and y.

        private _fit containing the core logic, called from fit

        Stores the summation matrix in ``self.s_matrix`` and the reconciliation
        matrix for ``self.method`` in ``self.g_matrix``. If X has fewer than two
        index levels (no hierarchy besides time) nothing is stored.

        Parameters
        ----------
        X : Panel of mtype pd_multiindex_hier
            Data to fit transform to
        y :  Ignored argument for interface compatibility.

        Returns
        -------
        self: reference to self

        Raises
        ------
        ValueError
            If ``self.method`` is not one of ``METHOD_LIST``.
        """
        self._check_method()

        # a hierarchy needs at least one node level in addition to time
        if X.index.nlevels < 2:
            return self

        # check index for no "__total", if not add totals to X
        if _check_index_no_total(X):
            X = self._add_totals(X)

        if self.method == "bu":
            self.g_matrix = _get_g_matrix_bu(X)
        elif self.method == "ols":
            self.g_matrix = _get_g_matrix_ols(X)
        elif self.method == "wls_str":
            self.g_matrix = _get_g_matrix_wls_str(X)
        else:
            raise RuntimeError("unreachable condition, error in _check_method")

        self.s_matrix = _get_s_matrix(X)

        return self

    def _transform(self, X, y=None):
        """Transform X and return a transformed version.

        private _transform containing core logic, called from transform

        Parameters
        ----------
        X : Panel of mtype pd_multiindex_hier
            Data to be transformed
        y : Ignored argument for interface compatibility.

        Returns
        -------
        recon_preds : multi-indexed pd.DataFrame of Panel mtype pd_multiindex

        Raises
        ------
        ValueError
            If the hierarchy nodes of X (time level removed) do not match the
            nodes seen in fit, or if fit did not compute reconciliation matrices.
        """
        # check the length of index
        if X.index.nlevels < 2:
            warn(
                "Reconciler is intended for use with X.index.nlevels > 1. "
                "Returning X unchanged."
            )
            return X

        # check index for no "__total", if not add totals to X
        if _check_index_no_total(X):
            warn(
                "No elements of the index of X named '__total' found. Adding "
                "aggregate levels using the default Aggregator transformer "
                "before reconciliation."
            )
            X = self._add_totals(X)

        # the node index of X must match the one used in fit. Index.equals is
        # used (rather than element-wise ==) so that indexes of different
        # length produce the message below instead of a pandas comparison error.
        # s_matrix is absent when fit saw data with fewer than two index levels.
        al_inds = X.droplevel(level=-1).index.unique()
        s_matrix = getattr(self, "s_matrix", None)
        if s_matrix is None or not s_matrix.index.equals(al_inds):
            raise ValueError(
                "Check unique indexes of X.droplevel(level=-1) matches "
                "the data used in Reconciler().fit(X)."
            )

        # TODO: include index between matrices here as in df.dot()?
        # reconcile each time point separately: reconciled = S @ G @ base
        X = X.groupby(level=-1)
        recon_preds = X.transform(
            lambda base: np.dot(self.s_matrix, np.dot(self.g_matrix, base))
        )

        return recon_preds

    def _check_method(self):
        """Raise a ValueError if method is not one of METHOD_LIST."""
        if self.method not in self.METHOD_LIST:
            raise ValueError(f"""method must be one of {self.METHOD_LIST}.""")

    @classmethod
    def get_test_params(cls):
        """Return testing parameter settings for the estimator.

        Returns
        -------
        params : dict, default = {}
            Parameters to create testing instances of the class
            Each dict are parameters to construct an "interesting" test instance, i.e.,
            `MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance.
            `create_test_instance` uses the first (or only) dictionary in `params`
        """
        return [{"method": x} for x in cls.METHOD_LIST]


def _get_s_matrix(X):
    """Determine the summation "S" matrix.

    Reconciliation methods require the S matrix, which is defined by the
    structure of the hierarchy only. The S matrix is inferred from the input
    multi-index of the forecasts and is used to sum bottom-level forecasts
    appropriately.

    Please refer to [1]_ for further information.

    Parameters
    ----------
    X :  Panel of mtype pd_multiindex_hier

    Returns
    -------
    s_matrix : pd.DataFrame with rows equal to the number of unique nodes in
        the hierarchy, and columns equal to the number of bottom level nodes only,
        i.e. with no aggregate nodes. The matrix indexes is inherited from the
        input data, with the time level removed.

    References
    ----------
    .. [1] https://otexts.com/fpp3/hierarchical.html
    """
    # get bottom level indexes
    bl_inds = (
        X.loc[~(X.index.get_level_values(level=-2).isin(["__total"]))]
        .index.droplevel(level=-1)
        .unique()
    )
    # get all level indexes
    al_inds = X.droplevel(level=-1).index.unique()

    # set up matrix
    s_matrix = pd.DataFrame(
        [[0.0 for i in range(len(bl_inds))] for i in range(len(al_inds))],
        index=al_inds,
    )
    s_matrix.columns = bl_inds

    # now insert indicator for bottom level
    for i in s_matrix.columns:
        s_matrix.loc[s_matrix.index == i, i] = 1.0

    # now for each unique column add aggregate indicator
    for i in s_matrix.columns:
        if s_matrix.index.nlevels > 1:
            # replace index with totals -> ("nodeA", "__total")
            agg_ind = list(i)[::-1]
            for j in range(len(agg_ind)):
                agg_ind[j] = "__total"
                # insert indicator
                s_matrix.loc[tuple(agg_ind[::-1]), i] = 1.0
        else:
            s_matrix.loc["__total", i] = 1.0

    # drop new levels not present in orginal matrix
    s_matrix.dropna(inplace=True)

    return s_matrix


def _get_g_matrix_bu(X):
    """Determine the reconciliation "G" matrix for the bottom up method.

    Reconciliation methods require the G matrix. The G matrix is used to redefine
    base forecasts for the entire hierarchy to the bottom-level only before
    summation using the S matrix.

    Please refer to [1]_ for further information.

    Parameters
    ----------
    X :  Panel of mtype pd_multiindex_hier

    Returns
    -------
    g_matrix : pd.DataFrame with rows equal to the number of bottom level nodes
        only, i.e. with no aggregate nodes, and columns equal to the number of
        unique nodes in the hierarchy. The matrix indexes is inherited from the
        input data, with the time level removed.

    References
    ----------
    .. [1] https://otexts.com/fpp3/hierarchical.html
    """
    # get bottom level indexes
    bl_inds = (
        X.loc[~(X.index.get_level_values(level=-2).isin(["__total"]))]
        .index.droplevel(level=-1)
        .unique()
    )

    # get all level indexes
    al_inds = X.droplevel(level=-1).index.unique()

    g_matrix = pd.DataFrame(
        [[0.0 for i in range(len(bl_inds))] for i in range(len(al_inds))],
        index=al_inds,
    )
    g_matrix.columns = bl_inds

    # now insert indicator for bottom level
    for i in g_matrix.columns:
        g_matrix.loc[g_matrix.index == i, i] = 1.0

    return g_matrix.transpose()


def _get_g_matrix_ols(X):
    """Determine the reconciliation "G" matrix for the ordinary least squares method.

    Reconciliation methods require the G matrix. The G matrix is used to redefine
    base forecasts for the entire hierarchy to the bottom-level only before
    summation using the S matrix. For OLS, G = (S'S)^{-1} S'.

    Please refer to [1]_ for further information.

    Parameters
    ----------
    X :  Panel of mtype pd_multiindex_hier

    Returns
    -------
    g_ols : pd.DataFrame with rows equal to the number of bottom level nodes
        only, i.e. with no aggregate nodes, and columns equal to the number of
        unique nodes in the hierarchy. The matrix indexes is inherited from the
        summation matrix.

    References
    ----------
    .. [1] https://otexts.com/fpp3/hierarchical.html
    """
    summation = _get_s_matrix(X)
    summation_t = np.transpose(summation)

    # G = (S'S)^{-1} S'
    g_values = np.dot(inv(np.dot(summation_t, summation)), summation_t)

    # rows: bottom level nodes (columns of S); columns: all nodes (rows of S)
    return pd.DataFrame(g_values, index=summation.columns, columns=summation.index)


def _get_g_matrix_wls_str(X):
    """Reconciliation "G" matrix for the weighted least squares (structural) method.

    Reconciliation methods require the G matrix. The G matrix is used to re-define
    base forecasts for the entire hierarchy to the bottom-level only before
    summation using the S matrix.

    With structural scaling the base-forecast error variance of each node is
    taken proportional to the number of bottom level series it aggregates,
    W = diag(S 1), and

        G = (S' W^{-1} S)^{-1} S' W^{-1},

    so aggregate nodes are down-weighted relative to bottom level nodes.

    Please refer to [1]_ for further information.

    Parameters
    ----------
    X :  Panel of mtype pd_multiindex_hier

    Returns
    -------
    g_wls_str : pd.DataFrame with rows equal to the number of bottom level nodes
        only, i.e. with no aggregate nodes, and columns equal to the number of
        unique nodes in the hierarchy. The matrix indexes is inherited from the
        summation matrix.

    References
    ----------
    .. [1] https://otexts.com/fpp3/hierarchical.html
    """
    # this is similar to the ols except we have a weight matrix W
    smat = _get_s_matrix(X)
    s_values = smat.to_numpy(dtype=float)

    # diagonal of W: number of bottom level series aggregated by each node
    n_bottom = s_values.sum(axis=1)
    # diagonal of W^{-1}; a node with no bottom descendants has an all-zero
    # row in S and carries no information, so give it zero weight rather than
    # dividing by zero (which would propagate NaN through 0 * inf)
    precision = np.divide(
        1.0, n_bottom, out=np.zeros_like(n_bottom), where=n_bottom > 0
    )

    # S' W^{-1}: scale column j of S' by the precision of node j
    s_t_w_inv = s_values.T * precision
    # solve (S' W^{-1} S) G = S' W^{-1} instead of forming an explicit inverse;
    # the bottom level identity block in S keeps this system non-singular
    g_values = np.linalg.solve(np.dot(s_t_w_inv, s_values), s_t_w_inv)

    # rows: bottom level nodes (columns of S); columns: all nodes (rows of S)
    return pd.DataFrame(g_values, index=smat.columns, columns=smat.index)
