"""
This module gathers tree-based methods: a base decision tree and the
randomized Mondrian tree regressor and classifier built on top of it.
Only single-output problems are handled.
"""

# Authors: Gilles Louppe <g.louppe@gmail.com>
#          Peter Prettenhofer <peter.prettenhofer@gmail.com>
#          Brian Holt <bdholt1@gmail.com>
#          Noel Dawe <noel@dawe.me>
#          Satrajit Gosh <satrajit.ghosh@gmail.com>
#          Joly Arnaud <arnaud.v.joly@gmail.com>
#          Fares Hedayati <fares.hedayati@gmail.com>
#          Nelson Liu <nelson@nelsonliu.me>
#
# License: BSD 3 clause

from __future__ import division


import numbers
from abc import ABCMeta
from abc import abstractmethod
from math import ceil

import numpy as np
from scipy.sparse import issparse

from sklearn.base import BaseEstimator
from sklearn.base import ClassifierMixin
from sklearn.base import RegressorMixin
from sklearn.externals import six
from sklearn.preprocessing import LabelEncoder
from sklearn.utils import check_random_state
from sklearn.utils import compute_sample_weight
from sklearn.utils.multiclass import check_classification_targets
from sklearn.utils.validation import check_array
from sklearn.utils.validation import check_is_fitted
from sklearn.utils.validation import check_X_y
from sklearn.exceptions import NotFittedError

from ._criterion import Criterion
from ._splitter import Splitter
from ._tree import DepthFirstTreeBuilder
from ._tree import PartialFitTreeBuilder
from ._tree import Tree
from . import _tree, _splitter, _criterion

# Public estimators of this module. ``from ... import *`` resolves every name
# listed here, so each entry must be a class actually defined in this file.
__all__ = ["MondrianTreeRegressor",
           "MondrianTreeClassifier"]


# =============================================================================
# Types and constants
# =============================================================================

# Floating point dtypes shared with the compiled ``_tree`` module: DTYPE is
# used for the input features, DOUBLE for targets and sample weights.
DTYPE, DOUBLE = _tree.DTYPE, _tree.DOUBLE

# Lookup tables from the string names accepted by the ``criterion`` and
# ``splitter`` estimator parameters to their implementation classes.
CRITERIA_CLF = dict(classification=_criterion.ClassificationCriterion)
CRITERIA_REG = dict(mse=_criterion.MSE)

SPLITTERS = dict(mondrian=_splitter.MondrianSplitter)

# =============================================================================
# Base decision tree
# =============================================================================


class BaseDecisionTree(six.with_metaclass(ABCMeta, BaseEstimator)):
    """Base class for decision trees.

    Warning: This class should not be used directly.
    Use derived classes instead.

    Subclasses pass ``criterion`` and ``splitter`` either as names (looked
    up in ``CRITERIA_CLF`` / ``CRITERIA_REG`` and ``SPLITTERS``) or as
    ready-made ``Criterion`` / ``Splitter`` instances.
    """

    @abstractmethod
    def __init__(self,
                 criterion,
                 splitter,
                 max_depth,
                 min_samples_split,
                 random_state,
                 class_weight=None):
        self.criterion = criterion
        self.splitter = splitter
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.random_state = random_state
        self.class_weight = class_weight

    def fit(self, X, y, sample_weight=None, check_input=True,
            X_idx_sorted=None):
        """Build a tree from the training set (X, y).

        Parameters
        ----------
        X : array-like, shape = [n_samples, n_features]
            The training input samples. When ``check_input`` is True it is
            validated (single-output) and converted to ``DTYPE``.

        y : array-like, shape = [n_samples]
            The target values. For classifiers the labels are encoded as
            integers via ``np.unique``; the stored targets are cast to
            ``DOUBLE``.

        sample_weight : array-like, shape = [n_samples] or None
            Per-sample weights. When ``class_weight`` is set, they are
            multiplied by the weights derived from it.

        check_input : boolean, (default=True)
            Allow to bypass several input checking.
            Don't use this parameter unless you know what you do.

        X_idx_sorted : array-like or None, (default=None)
            Passed through unchanged to the tree builder.

        Returns
        -------
        self : object

        Raises
        ------
        ValueError
            If ``min_samples_split`` or ``max_depth`` is out of range, or if
            the lengths of ``y`` / ``sample_weight`` do not match ``X``.
        """
        random_state = check_random_state(self.random_state)
        if check_input:
            X, y = check_X_y(X, y, dtype=DTYPE, multi_output=False)

        # Determine output settings
        n_samples, self.n_features_ = X.shape
        is_classification = isinstance(self, ClassifierMixin)

        y = np.atleast_1d(y)
        expanded_class_weight = None

        if y.ndim == 1:
            # np.reshape preserves data contiguity, whereas indexing with
            # [:, np.newaxis] does not.
            y = np.reshape(y, (-1, 1))

        self.n_outputs_ = y.shape[1]

        if is_classification:
            check_classification_targets(y)
            y = np.copy(y)

            self.classes_ = []
            self.n_classes_ = []

            if self.class_weight is not None:
                y_original = np.copy(y)

            # ``np.int`` was only an alias of the builtin ``int`` and has
            # been removed from recent NumPy releases.
            y_encoded = np.zeros(y.shape, dtype=int)
            for k in range(self.n_outputs_):
                classes_k, y_encoded[:, k] = np.unique(y[:, k],
                                                       return_inverse=True)
                self.classes_.append(classes_k)
                self.n_classes_.append(classes_k.shape[0])
            y = y_encoded

            if self.class_weight is not None:
                expanded_class_weight = compute_sample_weight(
                    self.class_weight, y_original)

        else:
            self.classes_ = [None] * self.n_outputs_
            self.n_classes_ = [1] * self.n_outputs_

        self.n_classes_ = np.array(self.n_classes_, dtype=np.intp)

        if getattr(y, "dtype", None) != DOUBLE or not y.flags.contiguous:
            y = np.ascontiguousarray(y, dtype=DOUBLE)

        # Check parameters
        # ``None`` means "unbounded": use the largest signed 32-bit int.
        max_depth = ((2 ** 31) - 1 if self.max_depth is None
                     else self.max_depth)

        if isinstance(self.min_samples_split, (numbers.Integral, np.integer)):
            if not 2 <= self.min_samples_split:
                raise ValueError("min_samples_split must be an integer "
                                 "greater than 1 or a float in (0.0, 1.0]; "
                                 "got the integer %s"
                                 % self.min_samples_split)
            min_samples_split = self.min_samples_split
        else:  # float: fraction of the training set size
            if not 0. < self.min_samples_split <= 1.:
                raise ValueError("min_samples_split must be an integer "
                                 "greater than 1 or a float in (0.0, 1.0]; "
                                 "got the float %s"
                                 % self.min_samples_split)
            min_samples_split = int(ceil(self.min_samples_split * n_samples))
            min_samples_split = max(2, min_samples_split)

        if len(y) != n_samples:
            raise ValueError("Number of labels=%d does not match "
                             "number of samples=%d" % (len(y), n_samples))
        if max_depth <= 0:
            raise ValueError("max_depth must be greater than zero. ")

        if sample_weight is not None:
            if (getattr(sample_weight, "dtype", None) != DOUBLE or
                    not sample_weight.flags.contiguous):
                sample_weight = np.ascontiguousarray(
                    sample_weight, dtype=DOUBLE)
            if len(sample_weight.shape) > 1:
                raise ValueError("Sample weights array has more "
                                 "than one dimension: %d" %
                                 len(sample_weight.shape))
            if len(sample_weight) != n_samples:
                raise ValueError("Number of weights=%d does not match "
                                 "number of samples=%d" %
                                 (len(sample_weight), n_samples))

        if expanded_class_weight is not None:
            if sample_weight is not None:
                sample_weight = sample_weight * expanded_class_weight
            else:
                sample_weight = expanded_class_weight

        # Build tree
        criterion = self.criterion
        if not isinstance(criterion, Criterion):
            if is_classification:
                criterion = CRITERIA_CLF[self.criterion](self.n_outputs_,
                                                         self.n_classes_)
            else:
                criterion = CRITERIA_REG[self.criterion](self.n_outputs_,
                                                         n_samples)
        splitter = self.splitter
        if not isinstance(self.splitter, Splitter):
            splitter = SPLITTERS[self.splitter](criterion,
                                                random_state)
        self.tree_ = Tree(self.n_features_, self.n_classes_, self.n_outputs_)
        builder = DepthFirstTreeBuilder(splitter, min_samples_split,
                                        max_depth)
        builder.build(self.tree_, X, y, sample_weight, X_idx_sorted)

        # Single-output models expose scalars / flat arrays rather than
        # per-output lists.
        if self.n_outputs_ == 1:
            self.n_classes_ = self.n_classes_[0]
            self.classes_ = self.classes_[0]

        return self

    def _validate_X_predict(self, X, check_input):
        """Validate X whenever one tries to predict, apply, predict_proba.

        Raises ``NotFittedError`` if the estimator has no ``tree_`` yet and
        ``ValueError`` if the number of features differs from training.
        """
        # getattr so that an unfitted estimator raises NotFittedError
        # instead of AttributeError.
        if getattr(self, "tree_", None) is None:
            raise NotFittedError("Estimator not fitted, "
                                 "call `fit` before exploiting the model.")

        if check_input:
            X = check_array(X, dtype=DTYPE, accept_sparse="csr")
            if issparse(X) and (X.indices.dtype != np.intc or
                                X.indptr.dtype != np.intc):
                raise ValueError("No support for np.int64 index based "
                                 "sparse matrices")

        n_features = X.shape[1]
        if self.n_features_ != n_features:
            raise ValueError("Number of features of the model must "
                             "match the input. Model n_features is %s and "
                             "input n_features is %s "
                             % (self.n_features_, n_features))

        return X

    def predict(self, X, check_input=True, return_std=False):
        """Predict class or regression value for X.

        For a classification model, the predicted class for each sample in X is
        returned. For a regression model, the predicted value based on X is
        returned.

        Parameters
        ----------
        X : array-like or sparse matrix of shape = [n_samples, n_features]
            The input samples. Internally, it will be converted to
            ``DTYPE`` and if a sparse matrix is provided
            to a sparse ``csr_matrix``.

        check_input : boolean, (default=True)
            Allow to bypass several input checking.
            Don't use this parameter unless you know what you do.

        return_std : boolean, (default=False)
            Regression only: whether to also return the standard deviation.
            When True, the full result of ``tree_.predict`` (mean and
            standard deviation) is returned. Ignored for classifiers.

        Returns
        -------
        y : array of shape = [n_samples] or [n_samples, n_outputs]
            The predicted classes, or the predicted values.
        """
        check_is_fitted(self, 'tree_')
        X = self._validate_X_predict(X, check_input)

        # Classification
        if isinstance(self, ClassifierMixin):
            return self.classes_[self.predict_proba(X).argmax(axis=1)]

        # Regression
        else:
            mean_and_std = self.tree_.predict(
                X, return_std=return_std, is_regression=True)
            if return_std:
                return mean_and_std
            return mean_and_std[0]

    def apply(self, X, check_input=True):
        """
        Returns the index of the leaf that each sample is predicted as.

        .. versionadded:: 0.17

        Parameters
        ----------
        X : array_like or sparse matrix, shape = [n_samples, n_features]
            The input samples. Internally, it will be converted to
            ``DTYPE`` and if a sparse matrix is provided
            to a sparse ``csr_matrix``.

        check_input : boolean, (default=True)
            Allow to bypass several input checking.
            Don't use this parameter unless you know what you do.

        Returns
        -------
        X_leaves : array_like, shape = [n_samples,]
            For each datapoint x in X, return the index of the leaf x
            ends up in. Leaves are numbered within
            ``[0; self.tree_.node_count)``, possibly with gaps in the
            numbering.
        """
        check_is_fitted(self, 'tree_')
        X = self._validate_X_predict(X, check_input)
        return self.tree_.apply(X)

    def decision_path(self, X, check_input=True):
        """Return the decision path in the tree

        .. versionadded:: 0.18

        Parameters
        ----------
        X : array_like or sparse matrix, shape = [n_samples, n_features]
            The input samples. Internally, it will be converted to
            ``DTYPE`` and if a sparse matrix is provided
            to a sparse ``csr_matrix``.

        check_input : boolean, (default=True)
            Allow to bypass several input checking.
            Don't use this parameter unless you know what you do.

        Returns
        -------
        indicator : sparse csr array, shape = [n_samples, n_nodes]
            Return a node indicator matrix where non zero elements
            indicates that the samples goes through the nodes.
        """
        # Same fitted-check as predict/apply so unfitted use fails cleanly.
        check_is_fitted(self, 'tree_')
        X = self._validate_X_predict(X, check_input)
        return self.tree_.decision_path(X)


class BaseMondrianTree(BaseDecisionTree):
    """A Mondrian tree.

    The splits in a Mondrian tree differ from the standard regression
    tree in the following ways.

    At fit time:
        - Splits are done independently of the labels.
        - The candidate feature is drawn with a probability proportional to the
          feature range.
        - The candidate threshold is drawn from a uniform distribution
          with the bounds equal to the bounds of the candidate feature.
        - The time of split is also stored which is proportional to the
          inverse of the size of the bounding-box.

    At prediction time:
        - Every node in the path from the root to the leaf is given a weight
          while making predictions.
        - At each node, the probability of an unseen sample splitting from that
          node is calculated. The farther the sample is away from the bounding
          box, the more probable that it will split away.
        - For every node, the probability that an unseen sample has not split
          before reaching that node and the probability that it will split away
          at that particular node are multiplied to give a weight.

    The tree can also be grown incrementally with ``partial_fit``.

    Parameters
    ----------
    max_depth : int or None, optional (default=None)
        The maximum depth of the tree. If None, then nodes are expanded until
        all leaves are pure or until all leaves contain less than
        min_samples_split samples.

    min_samples_split : int, float, optional (default=2)
        The minimum number of samples required to split an internal node:

        - If int, then consider `min_samples_split` as the minimum number.
        - If float, then `min_samples_split` is a fraction and
          `ceil(min_samples_split * n_samples)` are the minimum
          number of samples for each split.

    random_state : int, RandomState instance or None, optional (default=None)
        If int, random_state is the seed used by the random number generator;
        If RandomState instance, random_state is the random number generator;
        If None, the random number generator is the RandomState instance used
        by `np.random`.
    """
    def partial_fit(self, X, y, classes=None):
        """
        Incremental building of Mondrian Trees.

        Parameters
        ----------
        X : array_like, shape = [n_samples, n_features]
            The input samples. Internally, it will be converted to
            ``DTYPE``.

        y: array_like, shape = [n_samples]
            Input targets.

        classes: array_like, shape = [n_classes]
            Ignored for a regression problem. For a classification
            problem, if not provided this is inferred from y.
            This is taken into account for only the first call to
            partial_fit and ignored for subsequent calls.

        Returns
        -------
        self: instance of MondrianTree

        Raises
        ------
        ValueError
            If classes cannot be inferred on the first call, or if X has a
            different number of features than the existing tree.
        """
        random_state = check_random_state(self.random_state)
        X, y = check_X_y(X, y, dtype=DTYPE, multi_output=False, order="C")
        is_classifier = isinstance(self, ClassifierMixin)
        # ``None`` means "unbounded": use the largest signed 32-bit int.
        max_depth = ((2 ** 31) - 1 if self.max_depth is None
                     else self.max_depth)

        # ``first_`` is only set once a partial_fit call has completed, so a
        # model produced by ``fit`` (or by a failed first call) is treated as
        # a first call and its tree is rebuilt from scratch.
        first_call = not hasattr(self, "first_")

        if not first_call and X.shape[1] != self.n_features_:
            raise ValueError("X has %d features, but the tree was built "
                             "with %d features."
                             % (X.shape[1], self.n_features_))

        if is_classifier:
            check_classification_targets(y)

            # First call to partial_fit
            if first_call:
                if len(y) == 1 and classes is None:
                    raise ValueError("Unable to infer classes. Should be "
                                     "provided at the first call to partial_fit.")
                self.le_ = LabelEncoder()
                if classes is not None:
                    self.le_.fit(classes)
                else:
                    self.le_.fit(y)
                self.classes_ = self.le_.classes_
            y = self.le_.transform(y)
            n_classes = [len(self.le_.classes_)]
        else:
            n_classes = [1]

        # To be consistent with sklearn's tree architecture, we reshape.
        y = np.array(y, dtype=np.float64)
        y = np.reshape(y, (-1, 1))

        # First call to partial_fit, initialize tree
        if first_call:
            self.n_features_ = X.shape[1]
            self.n_classes_ = np.array(n_classes, dtype=np.intp)
            self.n_outputs_ = 1
            self.tree_ = Tree(self.n_features_, self.n_classes_, self.n_outputs_)

        builder = PartialFitTreeBuilder(
            self.min_samples_split, max_depth, random_state)
        builder.build(self.tree_, X, y)

        # Mark initialization as complete only after a successful build.
        self.first_ = True
        return self

    def weighted_decision_path(self, X, check_input=True):
        """
        Returns the weighted decision path in the tree.

        Each non-zero value in the decision path determines the weight
        of that particular node in making predictions.

        Parameters
        ----------
        X : array_like, shape = [n_samples, n_features]
            The input samples. Internally, it will be converted to
            ``DTYPE`` and if a sparse matrix is provided
            to a sparse ``csr_matrix``.

        check_input : boolean, (default=True)
            Allow to bypass several input checking.
            Don't use this parameter unless you know what you do.

        Returns
        -------
        indicator : sparse csr array, shape = [n_samples, n_nodes]
            Return a node indicator matrix where non zero elements
            indicate the weight of that particular node in making predictions.
        """
        # Fail with NotFittedError (not AttributeError) before fit/partial_fit.
        check_is_fitted(self, 'tree_')
        X = self._validate_X_predict(X, check_input)
        return self.tree_.weighted_decision_path(X)


class MondrianTreeRegressor(BaseMondrianTree, RegressorMixin):
    """Mondrian tree for regression.

    Always uses the ``"mse"`` criterion and the ``"mondrian"`` splitter.

    Parameters
    ----------
    max_depth : int or None, optional (default=None)
        The maximum depth of the tree; None means no explicit limit.

    min_samples_split : int, float, optional (default=2)
        The minimum number of samples required to split an internal node
        (an int count, or a float fraction of the number of samples).

    random_state : int, RandomState instance or None, optional (default=None)
        Seed or random number generator used to draw the splits.
    """
    def __init__(self,
                 max_depth=None,
                 min_samples_split=2,
                 random_state=None):
        # The criterion and splitter are fixed for this estimator; only the
        # tree-shape and randomness parameters are user-configurable.
        parent = super(MondrianTreeRegressor, self)
        parent.__init__(
            random_state=random_state,
            min_samples_split=min_samples_split,
            max_depth=max_depth,
            splitter="mondrian",
            criterion="mse")

    def partial_fit(self, X, y):
        """
        Grow the regression tree incrementally with a new batch of data.

        Parameters
        ----------
        X : array_like, shape = [n_samples, n_features]
            The input samples. Internally, it will be converted to
            ``DTYPE``.

        y: array_like, shape = [n_samples]
            Input targets.

        Returns
        -------
        self: instance of MondrianTree
        """
        # Narrows the parent signature: no ``classes`` for regression.
        parent = super(MondrianTreeRegressor, self)
        return parent.partial_fit(X, y)

class MondrianTreeClassifier(BaseMondrianTree, ClassifierMixin):
    """Mondrian tree for classification.

    Always uses the ``"classification"`` criterion and the ``"mondrian"``
    splitter.

    Parameters
    ----------
    max_depth : int or None, optional (default=None)
        The maximum depth of the tree; None means no explicit limit.

    min_samples_split : int, float, optional (default=2)
        The minimum number of samples required to split an internal node
        (an int count, or a float fraction of the number of samples).

    random_state : int, RandomState instance or None, optional (default=None)
        Seed or random number generator used to draw the splits.
    """
    def __init__(self,
                 max_depth=None,
                 min_samples_split=2,
                 random_state=None):
        # The criterion and splitter are fixed for this estimator; only the
        # tree-shape and randomness parameters are user-configurable.
        parent = super(MondrianTreeClassifier, self)
        parent.__init__(
            random_state=random_state,
            min_samples_split=min_samples_split,
            max_depth=max_depth,
            splitter="mondrian",
            criterion="classification")

    def predict_proba(self, X, check_input=True):
        """
        Return the predicted probability of every class label for X.

        Parameters
        ----------
        X : array-like, shape = [n_samples, n_features]
            The input samples. Internally, it will be converted to
            ``DTYPE``.

        check_input : boolean, (default=True)
            Allow to bypass several input checking.
            Don't use this parameter unless you know what you do.

        Returns
        -------
        y_prob : array of shape = [n_samples, n_classes]
            Predicted probabilities for each class.
        """
        check_is_fitted(self, 'tree_')
        X = self._validate_X_predict(X, check_input)

        prediction = self.tree_.predict(
            X, is_regression=False, return_std=False)
        # Only the first element (the class probabilities) is returned.
        return prediction[0]

    def partial_fit(self, X, y, classes=None):
        """
        Grow the classification tree incrementally with a new batch of data.

        Parameters
        ----------
        X : array_like, shape = [n_samples, n_features]
            The input samples. Internally, it will be converted to
            ``DTYPE``.

        y: array_like, shape = [n_samples]
            Input targets.

        classes: array_like, shape = [n_classes]
            All class labels. If not provided on the first call, they are
            inferred from y. Only the first call to partial_fit takes
            this into account; it is ignored for subsequent calls.

        Returns
        -------
        self: instance of MondrianTree
        """
        parent = super(MondrianTreeClassifier, self)
        return parent.partial_fit(X, y, classes=classes)
