/*
 * Decompiled with CFR 0.152.
 */
package org.hipparchus.stat.inference;

import java.util.Map;
import java.util.TreeMap;
import java.util.stream.LongStream;
import org.hipparchus.distribution.continuous.NormalDistribution;
import org.hipparchus.exception.Localizable;
import org.hipparchus.exception.LocalizedCoreFormats;
import org.hipparchus.exception.MathIllegalArgumentException;
import org.hipparchus.exception.MathIllegalStateException;
import org.hipparchus.exception.NullArgumentException;
import org.hipparchus.stat.LocalizedStatFormats;
import org.hipparchus.stat.ranking.NaNStrategy;
import org.hipparchus.stat.ranking.NaturalRanking;
import org.hipparchus.stat.ranking.TiesStrategy;
import org.hipparchus.util.FastMath;
import org.hipparchus.util.Precision;

/**
 * Mann-Whitney U test (Wilcoxon rank-sum test) for two independent samples.
 * <p>
 * The U statistic is derived from the ranks of the pooled sample. P-values are
 * two-sided and are obtained either from the exact null distribution of U
 * (requires tie-free data) or from a normal approximation with continuity
 * correction and tie-corrected variance.
 */
public class MannWhitneyUTest {

    /** Combined sample size at or below which the exact distribution is used by default. */
    private static final int SMALL_SAMPLE_SIZE = 50;

    /** Largest individual sample size accepted by the exact distribution computation. */
    private static final int MAX_EXACT_SAMPLE_SIZE = 100;

    /** Ranking applied to the pooled sample. */
    private final NaturalRanking naturalRanking;

    /** Standard normal distribution used by the large-sample approximation. */
    private final NormalDistribution standardNormal;

    /**
     * Creates a test that ranks with {@link NaNStrategy#FIXED} and
     * {@link TiesStrategy#AVERAGE}.
     */
    public MannWhitneyUTest() {
        this(NaNStrategy.FIXED, TiesStrategy.AVERAGE);
    }

    /**
     * Creates a test with the given ranking strategies.
     *
     * @param nanStrategy strategy for NaN values when ranking the pooled sample
     * @param tiesStrategy strategy for tied values when ranking the pooled sample
     */
    public MannWhitneyUTest(NaNStrategy nanStrategy, TiesStrategy tiesStrategy) {
        this.naturalRanking = new NaturalRanking(nanStrategy, tiesStrategy);
        this.standardNormal = new NormalDistribution(0.0, 1.0);
    }

    /**
     * Computes the Mann-Whitney U statistic {@code min(U1, U2)}, where
     * {@code U1 = R1 - n1 (n1 + 1) / 2}, {@code R1} is the sum of the ranks of
     * {@code x} within the pooled sample, and {@code U2 = n1 n2 - U1}.
     *
     * @param x first sample
     * @param y second sample
     * @return the smaller of U1 and U2
     * @throws NullArgumentException if either sample is {@code null}
     * @throws MathIllegalArgumentException if either sample is empty
     */
    public double mannWhitneyU(double[] x, double[] y) throws MathIllegalArgumentException, NullArgumentException {
        ensureDataConformance(x, y);
        final double[] ranks = naturalRanking.rank(concatenateSamples(x, y));
        // NOTE(review): assumes the ranking returns one rank per pooled element, with the
        // first x.length ranks belonging to x. Verify this holds for NaN strategies that
        // drop elements.
        double sumRankX = 0.0;
        for (int i = 0; i < x.length; ++i) {
            sumRankX += ranks[i];
        }
        final long n1 = x.length;
        // n1 * (n1 + 1) is always even, so the integer division is exact.
        final double u1 = sumRankX - n1 * (n1 + 1) / 2;
        final double u2 = n1 * y.length - u1;
        return FastMath.min(u1, u2);
    }

    /**
     * Concatenates two samples into a new array, x first.
     *
     * @param x first sample
     * @param y second sample
     * @return a new array holding the elements of x followed by those of y
     */
    private double[] concatenateSamples(double[] x, double[] y) {
        final double[] z = new double[x.length + y.length];
        System.arraycopy(x, 0, z, 0, x.length);
        System.arraycopy(y, 0, z, x.length, y.length);
        return z;
    }

    /**
     * Returns the two-sided p-value of the Mann-Whitney U test.
     * <p>
     * The exact distribution is used when the combined sample size is at most
     * {@value #SMALL_SAMPLE_SIZE} and the pooled data contain no ties; otherwise
     * the normal approximation is used.
     *
     * @param x first sample
     * @param y second sample
     * @return two-sided p-value
     * @throws NullArgumentException if either sample is {@code null}
     * @throws MathIllegalArgumentException if either sample is empty
     */
    public double mannWhitneyUTest(double[] x, double[] y) throws MathIllegalArgumentException, NullArgumentException {
        ensureDataConformance(x, y);
        final Map<Double, Integer> ties = tiesMap(x, y);
        final boolean exact = x.length + y.length <= SMALL_SAMPLE_SIZE && ties.isEmpty();
        return pValue(x, y, exact, ties);
    }

    /**
     * Returns the two-sided p-value of the Mann-Whitney U test, using either the
     * exact distribution or the normal approximation.
     *
     * @param x first sample
     * @param y second sample
     * @param exact {@code true} to use the exact distribution of U, {@code false}
     *        for the normal approximation
     * @return two-sided p-value
     * @throws NullArgumentException if either sample is {@code null}
     * @throws MathIllegalArgumentException if either sample is empty, if
     *         {@code exact} is requested and the data contain ties, or if
     *         {@code exact} is requested and a sample is longer than
     *         {@value #MAX_EXACT_SAMPLE_SIZE}
     */
    public double mannWhitneyUTest(double[] x, double[] y, boolean exact) throws MathIllegalArgumentException, NullArgumentException {
        ensureDataConformance(x, y);
        return pValue(x, y, exact, tiesMap(x, y));
    }

    /**
     * Shared p-value computation for both public test methods.
     *
     * @param x first sample (already validated)
     * @param y second sample (already validated)
     * @param exact whether to use the exact distribution
     * @param ties tied values in the pooled sample mapped to their multiplicity
     * @return two-sided p-value
     */
    private double pValue(double[] x, double[] y, boolean exact, Map<Double, Integer> ties) {
        // The statistic is computed before the ties check, as the ranking may itself reject the data.
        final double u = mannWhitneyU(x, y);
        if (exact) {
            if (!ties.isEmpty()) {
                throw new MathIllegalArgumentException(LocalizedStatFormats.TIES_ARE_NOT_ALLOWED);
            }
            return exactP(x.length, y.length, u);
        }
        return approximateP(u, x.length, y.length, varU(x.length, y.length, ties));
    }

    /**
     * Validates that both samples are non-null and non-empty.
     *
     * @param x first sample
     * @param y second sample
     * @throws NullArgumentException if either sample is {@code null}
     * @throws MathIllegalArgumentException if either sample is empty
     */
    private void ensureDataConformance(double[] x, double[] y) throws MathIllegalArgumentException, NullArgumentException {
        if (x == null || y == null) {
            throw new NullArgumentException();
        }
        if (x.length == 0 || y.length == 0) {
            throw new MathIllegalArgumentException(LocalizedCoreFormats.NO_DATA);
        }
    }

    /**
     * Two-sided p-value from the normal approximation to the distribution of U.
     *
     * @param u U statistic, assumed to be {@code min(U1, U2)}
     * @param n1 size of the first sample
     * @param n2 size of the second sample
     * @param varU variance of U (possibly tie-corrected)
     * @return two-sided p-value
     * @throws MathIllegalStateException if the normal CDF evaluation fails
     */
    private double approximateP(double u, int n1, int n2, double varU) throws MathIllegalStateException {
        final double mu = (long) n1 * n2 / 2.0;
        if (Precision.equals(mu, u)) {
            return 1.0;
        }
        // Forcing z <= 0 selects the lower tail. The +0.5 is the continuity correction.
        final double z = -FastMath.abs(u - mu + 0.5) / FastMath.sqrt(varU);
        return 2.0 * standardNormal.cumulativeProbability(z);
    }

    /**
     * Two-sided exact p-value: twice the probability of the tail at or beyond
     * {@code u}, capped at 1.
     *
     * @param n size of the first sample
     * @param m size of the second sample
     * @param u U statistic
     * @return two-sided p-value, or NaN if {@code u} is NaN
     * @throws MathIllegalArgumentException if {@code max(n, m)} exceeds
     *         {@value #MAX_EXACT_SAMPLE_SIZE}
     */
    private double exactP(int n, int m, double u) {
        if (Double.isNaN(u)) {
            // An undefined statistic must not be reported as a significant result (p = 0).
            return Double.NaN;
        }
        // Computed in double so that n * m cannot overflow int.
        final double nm = (double) n * m;
        if (u > nm) {
            return 1.0;
        }
        // U is symmetric about nm / 2. Reflect an upper-tail value onto the lower tail.
        final double crit = u <= nm / 2.0 ? u : nm - u;
        final double[] density = uDistribution(n, m);
        double cum = 0.0;
        for (int ct = 0; ct <= crit; ++ct) {
            cum += density[ct];
        }
        // At u == nm / 2 the two tails overlap and 2 * cum exceeds 1.
        return FastMath.min(1.0, 2.0 * cum);
    }

    /**
     * Computes the exact null distribution of U for samples of sizes n and m.
     * <p>
     * Uses the recurrence obtained by conditioning on which sample holds the
     * largest pooled observation:
     * {@code P[a,b](u) = a/(a+b) * P[a-1,b](u-b) + b/(a+b) * P[a,b-1](u)},
     * with {@code P[a,0] = P[0,b] = delta(u)}. Working with probabilities in
     * {@code double} (rather than counts in {@code long}) avoids overflow for
     * larger samples, and the recurrence only adds non-negative terms.
     *
     * @param n size of the first sample
     * @param m size of the second sample
     * @return array {@code p} of length {@code n * m + 1} with {@code p[u] = P(U = u)}
     * @throws MathIllegalArgumentException if {@code max(n, m)} exceeds
     *         {@value #MAX_EXACT_SAMPLE_SIZE}
     */
    private double[] uDistribution(int n, int m) {
        final int max = FastMath.max(m, n);
        if (max > MAX_EXACT_SAMPLE_SIZE) {
            throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_LARGE,
                                                   max, MAX_EXACT_SAMPLE_SIZE);
        }
        final int min = FastMath.min(m, n);
        // dist[b] holds P[a,b] for the current a, and is sized for the final a = min.
        // Entries beyond a * b stay zero until a grows enough to reach them.
        final double[][] dist = new double[max + 1][];
        for (int b = 0; b <= max; ++b) {
            dist[b] = new double[min * b + 1];
            dist[b][0] = 1.0; // a = 0: U is identically 0
        }
        for (int a = 1; a <= min; ++a) {
            // dist[0] is never updated: with b = 0, U is identically 0 for every a.
            for (int b = 1; b <= max; ++b) {
                final double[] cur = dist[b];      // holds P[a-1,b]; overwritten in place by P[a,b]
                final double[] left = dist[b - 1]; // already updated to P[a,b-1]
                final double pA = (double) a / (a + b);
                final double pB = (double) b / (a + b);
                // Descending u so that cur[u - b] still holds P[a-1,b] when read.
                for (int u = a * b; u >= 0; --u) {
                    final double viaA = u >= b ? cur[u - b] : 0.0;
                    final double viaB = u < left.length ? left[u] : 0.0;
                    cur[u] = pA * viaA + pB * viaB;
                }
            }
        }
        return dist[max];
    }

    /**
     * Variance of U under the null hypothesis, corrected for ties when present.
     *
     * @param n size of the first sample
     * @param m size of the second sample
     * @param tiesMap tied values mapped to their multiplicity (only counts above 1)
     * @return variance of U
     */
    private double varU(int n, int m, Map<Double, Integer> tiesMap) {
        final double nm = (double) n * m;
        final double totalN = (double) n + m;
        if (tiesMap.isEmpty()) {
            return nm * (totalN + 1.0) / 12.0;
        }
        // Accumulate t^3 - t in double: the int product overflowed for tie groups above 1290.
        double tSum = 0.0;
        for (final int t : tiesMap.values()) {
            final double dt = t;
            tSum += dt * dt * dt - dt;
        }
        return nm / 12.0 * (totalN + 1.0 - tSum / (totalN * (totalN - 1.0)));
    }

    /**
     * Counts repeated values in the pooled sample.
     *
     * @param x first sample
     * @param y second sample
     * @return map from each value occurring more than once to its multiplicity
     *         (empty when there are no ties)
     */
    private Map<Double, Integer> tiesMap(double[] x, double[] y) {
        final Map<Double, Integer> counts = new TreeMap<>();
        for (final double v : x) {
            counts.merge(v, 1, Integer::sum);
        }
        for (final double v : y) {
            counts.merge(v, 1, Integer::sum);
        }
        counts.values().removeIf(c -> c == 1);
        return counts;
    }
}

