# -*- coding: utf-8 -*-
"""
YAKE
----
"""
from __future__ import absolute_import, division, print_function, unicode_literals

import collections
import math
import operator

from cytoolz import itertoolz

from . import utils as ke_utils
from .. import compat, utils


def yake(
    doc,
    normalize="lemma",
    ngrams=(1, 2, 3),
    include_pos=("NOUN", "PROPN", "ADJ"),
    window_size=2,
    topn=10,
):
    """
    Extract key terms from a document using the YAKE algorithm.

    Per-word scores are computed first, from statistics gathered over every
    occurrence of each word; candidate terms (unigrams and higher-n n-grams)
    are then scored from the scores of their constituent words. Lower scores
    rank higher: terms are sorted in *ascending* order of score.

    Args:
        doc (:class:`spacy.tokens.Doc`): spaCy ``Doc`` from which to extract keyterms.
            Must be sentence-segmented; optionally POS-tagged.
        normalize (str): If "lemma", lemmatize terms; if "lower", lowercase terms;
            if None, use the form of terms as they appeared in ``doc``.

            .. note:: Unlike the other keyterm extraction functions, this one
               doesn't accept a callable for ``normalize``.

        ngrams (int or Set[int]): n of which n-grams to consider as keyterm candidates.
            For example, ``(1, 2, 3)`` includes all unigrams, bigrams, and trigrams,
            while ``2`` includes bigrams only.
        include_pos (str or Set[str]): One or more POS tags with which to filter
            for good candidate keyterms. If None, include tokens of all POS tags
            (which also allows keyterm extraction from docs without POS-tagging.)
        window_size (int): Number of words to the right and left of a given word
            to use as context when computing the "relatedness to context"
            component of its score. Note that the resulting sliding window's
            full width is ``1 + (2 * window_size)``.
        topn (int or float): Number of top-ranked terms to return as key terms.
            If an integer, represents the absolute number; if a float, value
            must be in the interval (0.0, 1.0], which is converted to an int by
            ``int(round(len(candidates) * topn))``

    Returns:
        List[Tuple[str, float]]: Sorted list of top ``topn`` key terms and
        their corresponding scores.

    Raises:
        ValueError: If ``topn`` is a float outside of (0.0, 1.0].

    References:
        Campos, Mangaravite, Pasquali, Jorge, Nunes, and Jatowt. (2018).
        A Text Feature Based Automatic Keyword Extraction Method for Single Documents.
        Advances in Information Retrieval. ECIR 2018.
        Lecture Notes in Computer Science, vol 10772, pp. 684-691.
    """
    # coerce the collection-like args (presumably scalars get wrapped; see
    # ``utils.to_collection``) and reject out-of-range fractional ``topn``
    ngrams = utils.to_collection(ngrams, int, tuple)
    include_pos = utils.to_collection(include_pos, compat.unicode_, set)
    topn_is_fraction = isinstance(topn, float)
    if topn_is_fraction and not 0.0 < topn <= 1.0:
        raise ValueError(
            "topn={} is invalid; "
            "must be an int, or a float between 0.0 and 1.0".format(topn)
        )

    # nothing to extract from an empty doc
    if not doc:
        return []

    # both sets are filled in-place by the helpers called below
    stop_words = set()
    seen_candidates = set()

    # gather raw per-occurrence values for every word in the doc
    occ_vals_by_word = _get_per_word_occurrence_values(
        doc, normalize, stop_words, window_size)
    # ... which may turn out to contain no words at all
    if not occ_vals_by_word:
        return []

    # one "is_uc" entry is recorded per occurrence, so its length is the frequency
    word_freqs = {}
    for w_id, vals in occ_vals_by_word.items():
        word_freqs[w_id] = len(vals["is_uc"])
    word_scores = _compute_word_scores(doc, occ_vals_by_word, word_freqs, stop_words)

    # candidate term scores are derived from the scores of their words
    term_scores = {}

    # single-word candidates get their own, simpler and faster, code path
    if 1 in ngrams:
        unigrams = _get_unigram_candidates(doc, include_pos)
        _score_unigram_candidates(
            unigrams,
            word_freqs, word_scores, term_scores,
            stop_words, seen_candidates,
            normalize,
        )

    # multi-word candidates: combine the scores of the constituent words
    higher_ns = [n for n in ngrams if n > 1]
    multiword_candidates = list(
        ke_utils.get_ngram_candidates(doc, higher_ns, include_pos=include_pos)
    )
    text_attr = _get_attr_name(normalize, True)
    ngram_freqs = itertoolz.frequencies(
        " ".join(getattr(tok, text_attr) for tok in ngram)
        for ngram in multiword_candidates
    )
    _score_ngram_candidates(
        multiword_candidates,
        ngram_freqs, word_scores, term_scores,
        seen_candidates,
        normalize,
    )

    # a fractional ``topn`` is relative to the number of distinct candidates seen
    if topn_is_fraction:
        topn = int(round(len(seen_candidates) * topn))
    # lowest score first
    ranked_terms = sorted(term_scores.items(), key=operator.itemgetter(1))
    return ke_utils.get_filtered_topn_terms(ranked_terms, topn, match_threshold=0.8)


def _get_attr_name(normalize, as_strings):
    """
    Map a ``normalize`` option onto the name of the token attribute to read.

    Args:
        normalize (str): One of "lemma", "lower", or None. None maps onto
            the "norm" attribute.
        as_strings (bool): If True (exactly), a trailing underscore is
            appended to the attribute name, e.g. "lemma" => "lemma_".

    Returns:
        str: Name of the attribute, e.g. "lemma", "lower_", or "norm".

    Raises:
        ValueError: If ``normalize`` is neither None nor one of the
            accepted strings (callables are not accepted).
    """
    valid_values = ("lemma", "lower")
    if normalize is None:
        attr_name = "norm"
    elif normalize in valid_values:
        attr_name = normalize
    else:
        # both placeholders need an argument: the offending value, and the
        # accepted values (a tuple, so the message text is deterministic)
        raise ValueError(
            "normalize='{}' is invalid; "
            "must be None or one of {}".format(normalize, valid_values)
        )
    if as_strings is True:
        attr_name = attr_name + "_"
    return attr_name


def _get_per_word_occurrence_values(doc, normalize, stop_words, window_size):
    """
    Collect raw values for every single occurrence of every word in ``doc``,
    grouped by word; these are later aggregated into per-word scores.

    Each sentence is scanned with a sliding window centered on one token at a
    time. For the centered token, four lists are appended to / extended:

    - "is_uc": whether this occurrence is upper-cased, or title-cased
      somewhere other than the start of a sentence
    - "sent_idx": index of the sentence containing this occurrence
    - "l_context": ids of the (non-punctuation, non-space) tokens found
      within ``window_size`` positions to the left
    - "r_context": the same, to the right

    Args:
        doc (:class:`spacy.tokens.Doc`)
        normalize (str)
        stop_words (Set[str]): Modified in-place: the id of every token
            flagged as a stop word is added to it.
        window_size (int)

    Returns:
        Dict[int, Dict[str, list]]: Keyed by word id, i.e. the value of the
        token attribute selected via ``normalize``.
    """
    id_attr = _get_attr_name(normalize, False)
    occ_vals = collections.defaultdict(lambda: collections.defaultdict(list))
    # ``None`` placeholders let the first and last tokens of a sentence sit at
    # the center of a full-width window
    pad = [None] * window_size
    window_width = 1 + (2 * window_size)

    def _context_ids(neighbors):
        # skip padding, punctuation, and whitespace tokens
        return [
            getattr(tok, id_attr)
            for tok in neighbors
            if tok is not None and not tok.is_punct and not tok.is_space
        ]

    for sent_idx, sent in enumerate(doc.sents):
        padded_sent = itertoolz.concatv(pad, sent, pad)
        for window in itertoolz.sliding_window(window_width, padded_sent):
            center = window[window_size]
            w_id = getattr(center, id_attr)
            if center.is_stop:
                stop_words.add(w_id)
            vals = occ_vals[w_id]
            vals["is_uc"].append(
                center.is_upper or (center.is_title and not center.is_sent_start)
            )
            vals["sent_idx"].append(sent_idx)
            vals["l_context"].extend(_context_ids(window[:window_size]))
            vals["r_context"].extend(_context_ids(window[window_size + 1:]))
    return occ_vals


def _compute_word_scores(doc, word_occ_vals, word_freqs, stop_words):
    """
    Aggregate values from per-word occurrence values, compute per-word weights
    of several components, then combine components into per-word scores.

    The components computed for each word are:

    - "case": number of upper-cased occurrences, over log2(1 + frequency)
    - "pos": log2(log2(3 + median index of the sentences the word occurs in))
    - "freq": frequency, over a baseline of the mean plus standard deviation
      of the frequencies of all non-stop words
    - "disp": fraction of the doc's sentences in which the word occurs
    - "rel": relatedness to context, growing with the variety of distinct
      words seen in the left and right context windows

    and they're combined as ``(rel * pos) / (case + freq / rel + disp / rel)``.

    Args:
        doc (:class:`spacy.tokens.Doc`)
        word_occ_vals (Dict[int, Dict[str, list]])
        word_freqs (Dict[int, int]): Must not be empty.
        stop_words (Set[str])

    Returns:
        Dict[int, float]
    """
    # summary stats for word frequencies
    freqs_nsw = [freq for w_id, freq in word_freqs.items() if w_id not in stop_words]
    freq_max = max(word_freqs.values())
    # mean + stdev is only well-defined given two or more values: a (sample)
    # standard deviation of a single value, like the mean of none at all,
    # either raises or yields NaN, depending on the implementation behind
    # ``compat``; so, short docs get an explicit fallback baseline instead
    n_freqs_nsw = len(freqs_nsw)
    if n_freqs_nsw > 1:
        freq_baseline = compat.mean_(freqs_nsw) + compat.stdev_(freqs_nsw)
    elif n_freqs_nsw == 1:
        # no spread to speak of: the baseline is just the lone frequency
        freq_baseline = freqs_nsw[0]
    else:
        # every word is a stop word; leave frequencies un-normalized
        freq_baseline = 1.0
    n_sents = itertoolz.count(doc.sents)

    word_scores = {}
    for w_id, vals in word_occ_vals.items():
        freq = word_freqs[w_id]
        case = sum(vals["is_uc"]) / compat.log2_(1 + freq)
        pos = compat.log2_(compat.log2_(3 + compat.median_(vals["sent_idx"])))
        rel_freq = freq / freq_baseline
        disp = len(set(vals["sent_idx"])) / n_sents

        l_context = vals["l_context"]
        r_context = vals["r_context"]
        n_unique_lc = len(set(l_context))
        n_unique_rc = len(set(r_context))
        # ratio of distinct to total context words; 0 if there's no context
        wl = n_unique_lc / len(l_context) if l_context else 0.0
        wr = n_unique_rc / len(r_context) if r_context else 0.0
        pl = n_unique_lc / freq_max
        pr = n_unique_rc / freq_max
        rel = 1.0 + (wl + wr) * (freq / freq_max) + pl + pr

        # combine individual weights into the word's score
        word_scores[w_id] = (rel * pos) / (case + (rel_freq / rel) + (disp / rel))
    return word_scores


def _get_unigram_candidates(doc, include_pos):
    """
    Lazily select the tokens in ``doc`` that may serve as single-word
    keyterm candidates: those that are not stop words, punctuation, or
    whitespace and, if ``include_pos`` is given, whose POS tag is in it.

    Args:
        doc (:class:`spacy.tokens.Doc`)
        include_pos (Set[str]): If empty or None, no filtering by POS tag
            is done (and tokens' POS tags are never read).

    Returns:
        Iterable[:class:`spacy.tokens.Token`]: A generator, in doc order.
    """
    if include_pos:
        return (
            tok for tok in doc
            if not tok.is_stop and not tok.is_punct and not tok.is_space
            and tok.pos_ in include_pos
        )
    return (
        tok for tok in doc
        if not tok.is_stop and not tok.is_punct and not tok.is_space
    )


def _score_unigram_candidates(
    candidates,
    word_freqs,
    word_scores,
    term_scores,
    stop_words,
    seen_candidates,
    normalize,
):
    """
    Score each distinct single-word candidate and record the result in
    ``term_scores``, keyed by the word's string form. Words whose id is a
    known stop word, or was already seen, are skipped.

    Args:
        candidates (List[:class:`spacy.tokens.Token`])
        word_freqs (Dict[int, float])
        word_scores (Dict[int, float])
        term_scores (Dict[str, float]): Modified in-place.
        stop_words (Set[str])
        seen_candidates (Set[str]): Modified in-place: the id of every
            newly-scored word is added to it.
        normalize (str)
    """
    id_attr = _get_attr_name(normalize, False)
    text_attr = _get_attr_name(normalize, True)
    for tok in candidates:
        w_id = getattr(tok, id_attr)
        if w_id in stop_words or w_id in seen_candidates:
            continue
        seen_candidates.add(w_id)
        score = word_scores[w_id]
        # NOTE: this departs from the YAKE algorithm, which divides by the raw
        # word frequency; log2(1 + freq) puts less emphasis on term freq
        damped_freq = compat.log2_(1 + word_freqs[w_id])
        term_scores[getattr(tok, text_attr)] = score / (damped_freq * (1 + score))


def _score_ngram_candidates(
    candidates,
    ngram_freqs, word_scores, term_scores,
    seen_candidates,
    normalize,
):
    """
    Score each distinct multi-word candidate from the scores of its words,
    and record the result in ``term_scores``, keyed by the space-joined
    string forms of the words. Candidates whose text was already seen
    are skipped.

    Args:
        candidates (List[Tuple[:class:`spacy.tokens.Token`]])
        ngram_freqs (Dict[str, int])
        word_scores (Dict[int, float])
        term_scores (Dict[str, float]): Modified in-place.
        seen_candidates (Set[str]): Modified in-place: the text of every
            newly-scored candidate is added to it.
        normalize (str)
    """
    id_attr = _get_attr_name(normalize, False)
    text_attr = _get_attr_name(normalize, True)
    for ngram in candidates:
        ngtxt = " ".join(getattr(tok, text_attr) for tok in ngram)
        if ngtxt in seen_candidates:
            continue
        seen_candidates.add(ngtxt)
        scores = [word_scores[getattr(tok, id_attr)] for tok in ngram]
        # numerator: the product of the individual word scores
        score_product = compat.reduce_(operator.mul, scores, 1.0)
        # NOTE: this departs from the YAKE algorithm, which multiplies by the
        # raw n-gram frequency; log2(1 + freq) puts less emphasis on term freq
        damped_freq = compat.log2_(1 + ngram_freqs[ngtxt])
        term_scores[ngtxt] = score_product / (damped_freq * (1.0 + sum(scores)))
