/*
 * Decompiled with CFR 0.152.
 */
package com.wcohen.ss;

import com.wcohen.ss.AbstractStatisticalTokenDistance;
import com.wcohen.ss.BagOfTokens;
import com.wcohen.ss.PrintfFormat;
import com.wcohen.ss.api.StringWrapper;
import com.wcohen.ss.api.Token;
import com.wcohen.ss.api.Tokenizer;
import java.util.Iterator;

/**
 * TF-IDF weighted cosine similarity between token bags.
 *
 * <p>Each string is tokenized and converted into a {@link UnitVector}. With training statistics
 * available ({@code collectionSize > 0}) a token's weight is {@code log(w + 1) * log(N / df)}:
 * <ul>
 *   <li>{@code w} is the weight the {@link BagOfTokens} initially holds for the token (presumably its
 *       term frequency — see BagOfTokens).</li>
 *   <li>{@code N} is {@code collectionSize}.</li>
 *   <li>{@code df} is the token's recorded document frequency. A missing entry counts as 1.</li>
 * </ul>
 * Without statistics every token gets weight 1. The weights are then scaled to unit Euclidean
 * length, and {@link #score} is the dot product of two such vectors.
 *
 * <p>NOTE(review): {@code collectionSize}, {@code documentFrequency}, {@code tokenizer},
 * {@code checkTrainingHasHappened} and {@code doMain} are inherited from
 * {@link AbstractStatisticalTokenDistance}; their exact contracts are defined there.
 *
 * <p>{@link #prepare} overwrites an unsynchronized field that {@link #getTokens} and
 * {@link #getWeight} read. Those two accessors are therefore only meaningful when a single thread
 * uses the instance.
 */
public class TFIDF
extends AbstractStatisticalTokenDistance {
    /** The vector produced by the most recent call to {@link #prepare}, or null if none yet. */
    private UnitVector lastVector = null;

    /**
     * Creates a TF-IDF distance that tokenizes strings with the given tokenizer.
     *
     * @param tokenizer tokenizer handed to the superclass
     */
    public TFIDF(Tokenizer tokenizer) {
        super(tokenizer);
    }

    /** Creates a TF-IDF distance using the superclass's default configuration. */
    public TFIDF() {
    }

    /**
     * Returns the TF-IDF weighted similarity of two strings.
     *
     * <p>Both inputs are converted with {@link #asUnitVector}. The result is the dot product of the
     * two vectors over their shared tokens. That equals their cosine similarity when both vectors
     * are non-zero, and it is 0 when no token is shared.
     *
     * @param s first string (a prepared {@link UnitVector}, a {@link BagOfTokens}, or any wrapper)
     * @param t second string, same accepted forms as {@code s}
     * @return the similarity score
     */
    @Override
    public double score(StringWrapper s, StringWrapper t) {
        this.checkTrainingHasHappened(s, t);
        UnitVector sBag = this.asUnitVector(s);
        UnitVector tBag = this.asUnitVector(t);
        double sim = 0.0;
        Iterator<Token> i = sBag.tokenIterator();
        while (i.hasNext()) {
            Token tok = i.next();
            if (!tBag.contains(tok)) continue;
            sim += sBag.getWeight(tok) * tBag.getWeight(tok);
        }
        return sim;
    }

    /**
     * Converts a wrapped string into a TF-IDF unit vector.
     *
     * <p>An existing {@link UnitVector} is returned as-is. Its weights reflect the statistics in
     * effect when it was built. A {@link BagOfTokens} is re-weighted from its tokens. Any other
     * wrapper is tokenized with {@code tokenizer}.
     *
     * @param w the string to convert
     * @return a unit vector for {@code w}
     */
    protected UnitVector asUnitVector(StringWrapper w) {
        if (w instanceof UnitVector) {
            return (UnitVector)w;
        }
        if (w instanceof BagOfTokens) {
            return new UnitVector((BagOfTokens)w);
        }
        return new UnitVector(w.unwrap(), this.tokenizer.tokenize(w.unwrap()));
    }

    /**
     * Tokenizes {@code s}, converts it to a TF-IDF unit vector, and remembers that vector for
     * {@link #getTokens} and {@link #getWeight}.
     *
     * @param s the raw string
     * @return the prepared {@link UnitVector}
     */
    @Override
    public StringWrapper prepare(String s) {
        this.lastVector = new UnitVector(s, this.tokenizer.tokenize(s));
        return this.lastVector;
    }

    /**
     * Returns the tokens of the most recently prepared string.
     *
     * @return the tokens of the vector produced by the last {@link #prepare} call
     * @throws IllegalStateException if {@link #prepare} has not been called yet
     */
    public Token[] getTokens() {
        return this.requirePrepared().getTokens();
    }

    /**
     * Returns the normalized TF-IDF weight of {@code token} in the most recently prepared string.
     *
     * @param token the token to look up
     * @return its weight in the vector produced by the last {@link #prepare} call
     * @throws IllegalStateException if {@link #prepare} has not been called yet
     */
    public double getWeight(Token token) {
        return this.requirePrepared().getWeight(token);
    }

    /** Returns {@link #lastVector}, failing with a descriptive error instead of an NPE when unset. */
    private UnitVector requirePrepared() {
        if (this.lastVector == null) {
            throw new IllegalStateException("prepare(String) must be called before inspecting the last prepared vector");
        }
        return this.lastVector;
    }

    /**
     * Returns the recorded document frequency of {@code token}.
     *
     * @param token the token to look up
     * @return its document frequency, or 0 if none has been recorded
     */
    @Override
    public int getDocumentFrequency(Token token) {
        Integer df = (Integer)this.documentFrequency.get(token);
        // A token without an entry has no recorded frequency; auto-unboxing null would throw an NPE.
        return df == null ? 0 : df.intValue();
    }

    /**
     * Records the document frequency of {@code token}.
     *
     * @param token the token
     * @param df number of documents containing the token
     */
    public void setDocumentFrequency(Token token, int df) {
        this.documentFrequency.put(token, Integer.valueOf(df));
    }

    /** @return the number of documents in the training collection ({@code N} in the IDF term) */
    public int getCollectionSize() {
        return this.collectionSize;
    }

    /** @param n the number of documents in the training collection ({@code N} in the IDF term) */
    public void setCollectionSize(int n) {
        this.collectionSize = n;
    }

    /**
     * Describes how {@link #score} arrives at its value for {@code s} and {@code t}.
     *
     * <p>The output lists each shared token as {@code value: sWeight*tWeight} (weights formatted to
     * three decimals), followed by the score itself. The inputs go through the same
     * {@link #asUnitVector} conversion as {@link #score}, so the listed weights are the ones that
     * produce the score.
     *
     * @param s first string
     * @param t second string
     * @return a human-readable explanation
     */
    @Override
    public String explainScore(StringWrapper s, StringWrapper t) {
        // Score first: score() runs the training check before it builds any vector, so the
        // vectors built below see the same statistics that score() used.
        double score = this.score(s, t);
        UnitVector sBag = this.asUnitVector(s);
        UnitVector tBag = this.asUnitVector(t);
        StringBuilder buf = new StringBuilder();
        PrintfFormat fmt = new PrintfFormat("%.3f");
        buf.append("Common tokens: ");
        Iterator<Token> i = sBag.tokenIterator();
        while (i.hasNext()) {
            Token tok = i.next();
            if (!tBag.contains(tok)) continue;
            buf.append(" " + tok.getValue() + ": ");
            buf.append(fmt.sprintf(sBag.getWeight(tok)));
            buf.append("*");
            buf.append(fmt.sprintf(tBag.getWeight(tok)));
        }
        buf.append("\nscore = " + score);
        return buf.toString();
    }

    @Override
    public String toString() {
        return "[TFIDF]";
    }

    /**
     * Command-line entry point. Passes a fresh {@code TFIDF} and the arguments to the inherited
     * {@code doMain} helper.
     */
    public static void main(String[] argv) {
        TFIDF.doMain(new TFIDF(), argv);
    }

    /**
     * A bag of tokens whose weights have been converted to TF-IDF values and scaled to unit
     * Euclidean length. It uses the enclosing {@link TFIDF}'s statistics at construction time.
     */
    protected class UnitVector
    extends BagOfTokens {
        /**
         * Builds a vector for {@code s} from its tokens, then applies TF-IDF weighting.
         *
         * @param s the original string
         * @param tokens the tokens of {@code s}
         */
        public UnitVector(String s, Token[] tokens) {
            super(s, tokens);
            this.termFreq2TFIDF();
        }

        /**
         * Builds a vector from an existing bag's string and tokens.
         *
         * <p>The delegated constructor applies TF-IDF weighting once. Applying it again here would
         * re-weight values that are already normalized.
         *
         * <p>NOTE(review): the bag's original term frequencies may not carry over. That happens if
         * {@link BagOfTokens#getTokens} returns each distinct token only once. Confirm against
         * BagOfTokens.
         *
         * @param bag the bag to convert
         */
        public UnitVector(BagOfTokens bag) {
            this(bag.unwrap(), bag.getTokens());
        }

        /**
         * Replaces each token's weight with its TF-IDF weight and rescales all weights to unit
         * Euclidean length.
         *
         * <p>With a positive {@code collectionSize} the weight is {@code log(w + 1) * log(N / df)}.
         * Here {@code w} is the token's current weight and a missing document frequency counts as
         * 1. Otherwise every token gets weight 1.
         *
         * <p>If every resulting weight is zero (e.g. each token occurs in all {@code N} documents),
         * the weights are left at zero. Dividing by the zero norm would produce {@code NaN}.
         */
        private void termFreq2TFIDF() {
            boolean haveStatistics = TFIDF.this.collectionSize > 0;
            double sumOfSquares = 0.0;
            Iterator<Token> i = this.tokenIterator();
            while (i.hasNext()) {
                Token tok = i.next();
                double w;
                if (haveStatistics) {
                    Integer dfInteger = (Integer)TFIDF.this.documentFrequency.get(tok);
                    // Tokens with no recorded document frequency are treated as appearing in one document.
                    double df = dfInteger == null ? 1.0 : (double)dfInteger.intValue();
                    w = Math.log(this.getWeight(tok) + 1.0) * Math.log((double)TFIDF.this.collectionSize / df);
                } else {
                    w = 1.0;
                }
                this.setWeight(tok, w);
                sumOfSquares += w * w;
            }
            double norm = Math.sqrt(sumOfSquares);
            if (norm == 0.0) {
                // All weights are already zero; avoid 0/0 = NaN.
                return;
            }
            i = this.tokenIterator();
            while (i.hasNext()) {
                Token tok = i.next();
                this.setWeight(tok, this.getWeight(tok) / norm);
            }
        }
    }
}

