/*
 * Decompiled with CFR 0.152.
 */
package org.hipparchus.linear;

import org.hipparchus.complex.Complex;
import org.hipparchus.exception.LocalizedCoreFormats;
import org.hipparchus.exception.MathIllegalArgumentException;
import org.hipparchus.exception.MathRuntimeException;
import org.hipparchus.linear.Array2DRowRealMatrix;
import org.hipparchus.linear.DecompositionSolver;
import org.hipparchus.linear.FieldDecompositionSolver;
import org.hipparchus.linear.FieldLUDecomposition;
import org.hipparchus.linear.FieldMatrix;
import org.hipparchus.linear.LUDecomposition;
import org.hipparchus.linear.MatrixUtils;
import org.hipparchus.linear.OrderedComplexEigenDecomposition;
import org.hipparchus.linear.RealMatrix;
import org.hipparchus.linear.RiccatiEquationSolver;
import org.hipparchus.linear.SingularValueDecomposition;
import org.hipparchus.util.FastMath;

/**
 * Solver for the continuous-time algebraic Riccati equation
 * {@code A^T P + P A - P B R^-1 B^T P + Q = 0}.
 * <p>
 * An initial solution is obtained from the ordered complex eigen decomposition of the
 * Hamiltonian matrix {@code [[A, -B R^-1 B^T], [-Q, -A^T]]}. It is then refined
 * iteratively until two successive iterates differ (in SVD norm) by at most
 * {@link #EPSILON}, or {@link #MAX_ITERATIONS} refinement iterations have been spent.
 * </p>
 * <p>
 * Both the solution {@code P} and the gain {@code K = R^-1 B^T P} are computed once,
 * in the constructor.
 * </p>
 */
public class RiccatiEquationSolverImpl
implements RiccatiEquationSolver {
    /** Maximum number of refinement iterations. */
    private static final int MAX_ITERATIONS = 100;
    /** Convergence threshold on the norm of the difference between successive iterates. */
    private static final double EPSILON = 1.0E-8;
    /** Solution of the Riccati equation. */
    private final RealMatrix P;
    /** Gain matrix {@code R^-1 B^T P}. */
    private final RealMatrix K;

    /**
     * Solves the Riccati equation for the given system matrices.
     *
     * @param A state matrix (must be square)
     * @param B control matrix (same number of rows as A)
     * @param Q state weighting matrix
     * @param R control weighting matrix (must be non-singular)
     * @throws MathIllegalArgumentException if A is not square, if the matrix dimensions
     *         are inconsistent, or if R is singular
     * @throws MathRuntimeException if an intermediate system is singular, if the initial
     *         solution cannot be formed, or if the refinement does not converge
     */
    public RiccatiEquationSolverImpl(RealMatrix A, RealMatrix B, RealMatrix Q, RealMatrix R) {
        if (!A.isSquare()) {
            throw new MathIllegalArgumentException(LocalizedCoreFormats.NON_SQUARE_MATRIX, A.getRowDimension(), A.getColumnDimension());
        }
        if (A.getColumnDimension() != B.getRowDimension()) {
            throw new MathIllegalArgumentException(LocalizedCoreFormats.DIMENSIONS_MISMATCH, A.getRowDimension(), B.getRowDimension());
        }
        MatrixUtils.checkMultiplicationCompatible(B, R);
        MatrixUtils.checkMultiplicationCompatible(A, Q);

        DecompositionSolver rSolver = new SingularValueDecomposition(R).getSolver();
        if (!rSolver.isNonSingular()) {
            throw new MathIllegalArgumentException(LocalizedCoreFormats.SINGULAR_MATRIX);
        }
        RealMatrix R_inv = rSolver.getInverse();

        this.P = this.computeP(A, B, Q, R, R_inv, MAX_ITERATIONS, EPSILON);
        // K = R^-1 B^T P
        this.K = R_inv.multiplyTransposed(B).multiply(this.P);
    }

    /**
     * Computes the Riccati solution: an eigenvector-based initial guess followed by
     * iterative refinement.
     *
     * @param A state matrix
     * @param B control matrix
     * @param Q state weighting matrix
     * @param R control weighting matrix
     * @param R_inv inverse of R
     * @param maxIterations maximum number of refinement iterations
     * @param epsilon convergence threshold
     * @return refined solution P
     */
    private RealMatrix computeP(RealMatrix A, RealMatrix B, RealMatrix Q, RealMatrix R, RealMatrix R_inv, int maxIterations, double epsilon) {
        RealMatrix P_ = this.computeInitialP(A, B, Q, R_inv);
        return this.approximateP(A, B, Q, R, R_inv, P_, maxIterations, epsilon);
    }

    /**
     * Assembles the 2n x 2n Hamiltonian matrix {@code [[A, -B R^-1 B^T], [-Q, -A^T]]},
     * validating the shapes of its four blocks first.
     *
     * @param A state matrix (n x n)
     * @param B control matrix
     * @param Q state weighting matrix
     * @param R_inv inverse of the control weighting matrix
     * @return the assembled Hamiltonian matrix
     * @throws MathIllegalArgumentException if the blocks do not fit together
     */
    private static RealMatrix buildHamiltonian(RealMatrix A, RealMatrix B, RealMatrix Q, RealMatrix R_inv) {
        RealMatrix m11 = A;
        // scalarAdd(0.0) turns any -0.0 entries produced by the negation into +0.0
        // (IEEE 754: -0.0 + 0.0 == +0.0); kept as-is from the original implementation.
        RealMatrix m12 = B.multiply(R_inv).multiply(B.transpose()).scalarMultiply(-1.0).scalarAdd(0.0);
        RealMatrix m21 = Q.scalarMultiply(-1.0).scalarAdd(0.0);
        RealMatrix m22 = A.transpose().scalarMultiply(-1.0).scalarAdd(0.0);
        if (m11.getRowDimension() != m12.getRowDimension()) {
            throw new MathIllegalArgumentException(LocalizedCoreFormats.DIMENSIONS_MISMATCH, m11.getRowDimension(), m12.getRowDimension());
        }
        if (m21.getRowDimension() != m22.getRowDimension()) {
            throw new MathIllegalArgumentException(LocalizedCoreFormats.DIMENSIONS_MISMATCH, m21.getRowDimension(), m22.getRowDimension());
        }
        if (m11.getColumnDimension() != m21.getColumnDimension()) {
            throw new MathIllegalArgumentException(LocalizedCoreFormats.DIMENSIONS_MISMATCH, m11.getColumnDimension(), m21.getColumnDimension());
        }
        if (m21.getColumnDimension() != m22.getColumnDimension()) {
            throw new MathIllegalArgumentException(LocalizedCoreFormats.DIMENSIONS_MISMATCH, m21.getColumnDimension(), m22.getColumnDimension());
        }
        RealMatrix m = MatrixUtils.createRealMatrix(m11.getRowDimension() + m21.getRowDimension(), m11.getColumnDimension() + m12.getColumnDimension());
        m.setSubMatrix(m11.getData(), 0, 0);
        m.setSubMatrix(m12.getData(), 0, m11.getColumnDimension());
        m.setSubMatrix(m21.getData(), m11.getRowDimension(), 0);
        m.setSubMatrix(m22.getData(), m11.getRowDimension(), m11.getColumnDimension());
        return m;
    }

    /**
     * Computes the initial guess {@code P0 = U21 * U11^-1}. Here {@code U11} and
     * {@code U21} are the top and bottom n x n blocks of the first n columns of the
     * eigenvector matrix returned by {@link OrderedComplexEigenDecomposition} for the
     * Hamiltonian matrix.
     *
     * @param A state matrix
     * @param B control matrix
     * @param Q state weighting matrix
     * @param R_inv inverse of the control weighting matrix
     * @return real part of the initial guess
     * @throws MathRuntimeException if {@code U11} is singular
     */
    private RealMatrix computeInitialP(RealMatrix A, RealMatrix B, RealMatrix Q, RealMatrix R_inv) {
        RealMatrix hamiltonian = buildHamiltonian(A, B, Q, R_inv);
        int n = A.getRowDimension();

        OrderedComplexEigenDecomposition eigenDecomposition = new OrderedComplexEigenDecomposition(hamiltonian);
        FieldMatrix<Complex> u = eigenDecomposition.getV();
        FieldMatrix<Complex> u11 = u.getSubMatrix(0, n - 1, 0, n - 1);
        FieldMatrix<Complex> u21 = u.getSubMatrix(n, 2 * n - 1, 0, n - 1);

        FieldDecompositionSolver<Complex> solver = new FieldLUDecomposition<Complex>(u11).getSolver();
        if (!solver.isNonSingular()) {
            throw new MathRuntimeException(LocalizedCoreFormats.SINGULAR_MATRIX);
        }
        FieldMatrix<Complex> p = u21.multiply(solver.getInverse());
        // Double.MAX_VALUE disables the imaginary-part check: imaginary parts are dropped.
        return this.convertToRealMatrix(p, Double.MAX_VALUE);
    }

    /**
     * Iteratively refines the solution. Each iteration forms
     * {@code Kg = -P B R^-1}, {@code X = A + B Kg^T} and {@code Y = -Kg R Kg^T - Q}.
     * It then solves the Kronecker-vectorized linear matrix equation
     * {@code (kron(X^T, I) + kron(I, X^T)) vec(Pnext) = vec(Y)} for the next iterate.
     *
     * @param A state matrix
     * @param B control matrix
     * @param Q state weighting matrix
     * @param R control weighting matrix
     * @param R_inv inverse of R
     * @param initialP starting iterate
     * @param maxIterations maximum number of refinement iterations allowed
     * @param epsilon convergence threshold on the SVD norm of {@code Pnext - P}
     * @return converged solution
     * @throws MathRuntimeException if the linear system is singular, or if convergence
     *         is not reached within {@code maxIterations} iterations
     */
    private RealMatrix approximateP(RealMatrix A, RealMatrix B, RealMatrix Q, RealMatrix R, RealMatrix R_inv, RealMatrix initialP, int maxIterations, double epsilon) {
        RealMatrix currentP = initialP;
        double error = 1.0;
        int iterations = 0;
        while (error > epsilon) {
            // Check the budget before iterating, so that an iteration which converges
            // is never reported as a convergence failure.
            if (iterations >= maxIterations) {
                throw new MathRuntimeException(LocalizedCoreFormats.CONVERGENCE_FAILED);
            }
            ++iterations;

            // Kg = -P B R^-1
            RealMatrix gain = currentP.multiply(B).multiply(R_inv).scalarMultiply(-1.0);
            // X = A + B Kg^T
            RealMatrix X = A.add(B.multiplyTransposed(gain));
            // Y = -Kg R Kg^T - Q
            RealMatrix Y = gain.multiply(R).multiplyTransposed(gain).scalarMultiply(-1.0).subtract(Q);

            // kroneckerProduct/stack/unstackSquare live on Array2DRowRealMatrix; convert
            // instead of casting, because MatrixUtils factories may return other types
            // (e.g. BlockRealMatrix for large sizes).
            Array2DRowRealMatrix xT = asArray2DRowRealMatrix(X.transpose());
            Array2DRowRealMatrix eye = asArray2DRowRealMatrix(MatrixUtils.createRealIdentityMatrix(xT.getRowDimension()));
            RealMatrix lhs = xT.kroneckerProduct(eye).add(eye.kroneckerProduct(xT));
            RealMatrix rhs = asArray2DRowRealMatrix(Y).stack();

            // LU solve instead of an explicit inverse, for numerical robustness.
            DecompositionSolver solver = new LUDecomposition(lhs).getSolver();
            if (!solver.isNonSingular()) {
                throw new MathRuntimeException(LocalizedCoreFormats.SINGULAR_MATRIX);
            }
            RealMatrix nextP = asArray2DRowRealMatrix(solver.solve(rhs)).unstackSquare();

            error = new SingularValueDecomposition(nextP.subtract(currentP)).getNorm();
            currentP = nextP;
        }
        return currentP;
    }

    /**
     * Returns {@code m} itself if it already is an {@link Array2DRowRealMatrix}, or a
     * new {@link Array2DRowRealMatrix} holding a copy of its entries otherwise.
     *
     * @param m matrix to convert
     * @return an {@link Array2DRowRealMatrix} with the same entries as {@code m}
     */
    private static Array2DRowRealMatrix asArray2DRowRealMatrix(RealMatrix m) {
        if (m instanceof Array2DRowRealMatrix) {
            return (Array2DRowRealMatrix) m;
        }
        // getData() already returns a fresh copy, so no second copy is needed.
        return new Array2DRowRealMatrix(m.getData(), false);
    }

    @Override
    public RealMatrix getP() {
        return this.P;
    }

    @Override
    public RealMatrix getK() {
        return this.K;
    }

    /**
     * Converts a complex matrix to a real matrix by keeping the real parts.
     *
     * @param matrix complex matrix to convert
     * @param tolerance largest allowed absolute imaginary part
     * @return real matrix with the same dimensions as {@code matrix}
     * @throws MathRuntimeException if an entry's imaginary part exceeds {@code tolerance}
     *         in absolute value
     */
    private RealMatrix convertToRealMatrix(FieldMatrix<Complex> matrix, double tolerance) {
        RealMatrix toRet = MatrixUtils.createRealMatrix(matrix.getRowDimension(), matrix.getColumnDimension());
        for (int i = 0; i < toRet.getRowDimension(); ++i) {
            for (int j = 0; j < toRet.getColumnDimension(); ++j) {
                Complex c = matrix.getEntry(i, j);
                if (FastMath.abs(c.getImaginary()) > tolerance) {
                    throw new MathRuntimeException(LocalizedCoreFormats.COMPLEX_CANNOT_BE_CONSIDERED_A_REAL_NUMBER, c.getReal(), c.getImaginary());
                }
                toRet.setEntry(i, j, c.getReal());
            }
        }
        return toRet;
    }
}

