/*
 * Decompiled with CFR 0.152.
 */
package jitk.spline;

import jitk.spline.ThinPlateR2LogRSplineKernelTransform;
import org.apache.log4j.LogManager;
import org.apache.log4j.Logger;
import org.ejml.data.DenseMatrix64F;
import org.ejml.ops.CommonOps;
import org.ejml.ops.NormOps;

public class TransformInverseGradientDescent {
    int ndims;
    ThinPlateR2LogRSplineKernelTransform xfm;
    DenseMatrix64F jacobian;
    DenseMatrix64F directionalDeriv;
    DenseMatrix64F descentDirectionMag;
    DenseMatrix64F dir;
    DenseMatrix64F errorV;
    DenseMatrix64F estimate;
    DenseMatrix64F estimateXfm;
    DenseMatrix64F target;
    double error = 9999.0;
    double stepSz = 1.0;
    int maxIters = 20;
    double eps = 1.0E-6;
    double beta = 0.7;
    protected static Logger logger = LogManager.getLogger(TransformInverseGradientDescent.class.getName());

    public TransformInverseGradientDescent(int ndims, ThinPlateR2LogRSplineKernelTransform xfm) {
        this.ndims = ndims;
        this.xfm = xfm;
        this.dir = new DenseMatrix64F(ndims, 1);
        this.errorV = new DenseMatrix64F(ndims, 1);
        this.directionalDeriv = new DenseMatrix64F(ndims, 1);
        this.descentDirectionMag = new DenseMatrix64F(1, 1);
    }

    public void setEps(double eps) {
        this.eps = eps;
    }

    public void setStepSize(double stepSize) {
        this.stepSz = stepSize;
    }

    public void setJacobian(double[][] mtx) {
        this.jacobian = new DenseMatrix64F(mtx);
        logger.trace("setJacobian:\n" + this.jacobian);
    }

    public void setTarget(double[] tgt) {
        this.target = new DenseMatrix64F(this.ndims, 1);
        this.target.setData(tgt);
    }

    public DenseMatrix64F getErrorVector() {
        return this.errorV;
    }

    public DenseMatrix64F getDirection() {
        return this.dir;
    }

    public DenseMatrix64F getJacobian() {
        return this.jacobian;
    }

    public void setEstimate(double[] est) {
        this.estimate = new DenseMatrix64F(this.ndims, 1);
        this.estimate.setData(est);
    }

    public void setEstimateXfm(double[] est) {
        this.estimateXfm = new DenseMatrix64F(this.ndims, 1);
        this.estimateXfm.setData(est);
        this.updateError();
    }

    public DenseMatrix64F getEstimate() {
        return this.estimate;
    }

    public double getError() {
        return this.error;
    }

    public void oneIteration() {
        this.oneIteration(true);
    }

    public void oneIteration(boolean updateError) {
        this.computeDirection();
        this.updateEstimate(this.stepSz);
        if (updateError) {
            this.updateError();
        }
    }

    public void computeDirectionSteepest() {
        DenseMatrix64F tmp = new DenseMatrix64F(this.ndims, 1);
        logger.trace("\nerrorV:\n" + this.errorV);
        CommonOps.mult(this.jacobian, this.estimate, tmp);
        CommonOps.subEquals(tmp, this.errorV);
        CommonOps.multTransA(2.0, this.jacobian, tmp, this.dir);
        double norm = NormOps.normP2(this.dir);
        CommonOps.divide(norm, this.dir);
        CommonOps.mult(this.jacobian, this.dir, this.directionalDeriv);
        CommonOps.scale(-1.0, this.dir);
    }

    public void computeDirection() {
        CommonOps.solve(this.jacobian, this.errorV, this.dir);
        double norm = NormOps.normP2(this.dir);
        CommonOps.divide(norm, this.dir);
        CommonOps.mult(this.jacobian, this.dir, this.directionalDeriv);
        CommonOps.multTransA(this.dir, this.directionalDeriv, this.descentDirectionMag);
        logger.debug("descentDirectionMag: " + this.descentDirectionMag.get(0));
    }

    public double backtrackingLineSearch(double c, double beta, int maxtries, double t0) {
        int k;
        double t = t0;
        for (k = 0; k < maxtries && !this.armijoCondition(c, t); ++k) {
            t *= beta;
        }
        logger.trace("selected step size after " + k + " tries");
        return t;
    }

    public boolean armijoCondition(double c, double t) {
        double[] d = this.dir.data;
        double[] x = this.estimate.data;
        double[] x_ap = new double[this.ndims];
        for (int i = 0; i < this.ndims; ++i) {
            x_ap[i] = x[i] + t * d[i];
        }
        double[] phix = this.estimateXfm.data;
        double[] phix_ap = this.xfm.apply(x_ap);
        double fx = this.squaredError(phix);
        double fx_ap = this.squaredError(phix_ap);
        double m = this.sumSquaredErrorsDeriv(this.target.data, phix) * this.descentDirectionMag.get(0);
        logger.trace("   f( x )     : " + fx);
        logger.trace("   f( x + ap ): " + fx_ap);
        logger.trace("   f( x ) + c * m * t: " + (fx + c * t * m));
        return fx_ap < fx + c * t * m;
    }

    public double squaredError(double[] x) {
        double error = 0.0;
        for (int i = 0; i < this.ndims; ++i) {
            error += (x[i] - this.target.get(i)) * (x[i] - this.target.get(i));
        }
        return error;
    }

    public void updateEstimate(double stepSize) {
        logger.trace("step size: " + stepSize);
        logger.trace("estimate:\n" + this.estimate);
        CommonOps.addEquals(this.estimate, stepSize, this.dir);
        logger.trace("new estimate:\n" + this.estimate);
    }

    public void updateEstimateNormBased(double stepSize) {
        logger.debug("step size: " + stepSize);
        logger.trace("estimate:\n" + this.estimate);
        double norm = NormOps.normP2(this.dir);
        logger.debug("norm: " + norm);
        if (norm > stepSize) {
            CommonOps.scale(-stepSize / norm, this.dir);
        }
        CommonOps.addEquals(this.estimate, this.dir);
        logger.trace("new estimate:\n" + this.estimate);
    }

    public void updateError() {
        if (this.estimate == null || this.target == null) {
            System.err.println("WARNING: Call to updateError with null target or estimate");
            return;
        }
        CommonOps.sub(this.target, this.estimateXfm, this.errorV);
        logger.trace("#########################");
        logger.trace("updateError, estimate   :\n" + this.estimate);
        logger.trace("updateError, estimateXfm:\n" + this.estimateXfm);
        logger.trace("updateError, target     :\n" + this.target);
        logger.trace("updateError, error      :\n" + this.errorV);
        logger.trace("#########################");
        this.error = Math.abs(this.errorV.get(0));
        for (int i = 1; i < this.ndims; ++i) {
            if (!(Math.abs(this.errorV.get(i)) > this.error)) continue;
            this.error = Math.abs(this.errorV.get(i));
        }
    }

    private double sumSquaredErrorsDeriv(double[] y, double[] x) {
        double errDeriv = 0.0;
        for (int i = 0; i < this.ndims; ++i) {
            errDeriv += (y[i] - x[i]) * (y[i] - x[i]);
        }
        return 2.0 * errDeriv;
    }

    public static double sumSquaredErrors(double[] y, double[] x) {
        int ndims = y.length;
        double err = 0.0;
        for (int i = 0; i < ndims; ++i) {
            err += (y[i] - x[i]) * (y[i] - x[i]);
        }
        return err;
    }

    public static void copyVectorIntoArray(DenseMatrix64F vec, double[] array) {
        System.arraycopy(vec.data, 0, array, 0, vec.getNumElements());
    }
}

