package com.hbaspecto.pecas.sd.estimation;

import com.hbaspecto.pecas.FormatLogger;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import no.uib.cipr.matrix.BandMatrix;
import no.uib.cipr.matrix.DenseCholesky;
import no.uib.cipr.matrix.DenseMatrix;
import no.uib.cipr.matrix.DenseVector;
import no.uib.cipr.matrix.Matrices;
import no.uib.cipr.matrix.Matrix;
import no.uib.cipr.matrix.Vector;
import org.apache.log4j.Logger;

/* loaded from: input_file:com/hbaspecto/pecas/sd/estimation/MarquardtMinimizer.class */
/**
 * Minimizes an {@link ObjectiveFunction} with a Marquardt-damped Newton iteration, optionally
 * subject to {@link Constraint}s that contribute penalty terms to the objective, its gradient and
 * its Hessian.
 *
 * <p>Each iteration solves {@code (H + lambda * D) step = -g}, where {@code H} and {@code g} are the
 * Hessian and gradient of objective plus penalty, and {@code D} is the diagonal of {@code H} (zero
 * entries replaced by 1). A step is accepted when it lowers the objective and the Hessian at the
 * new point is finite; lambda then shrinks by {@code lambdaDecrement}. Otherwise lambda grows by
 * {@code lambdaIncrement} and the step is recomputed.
 */
public class MarquardtMinimizer {
    private static final Logger logger = Logger.getLogger(MarquardtMinimizer.class);
    private static final FormatLogger loggerf = new FormatLogger(logger);

    // Values of the termination field: how the last iterateToConvergence run ended.
    private static final int CONVERGED = 0;
    private static final int MAX_ITERATIONS = 1;
    private static final int INVALID_STEP = 2;
    private static final int NOT_RUN = -1;

    /** Floor applied to lambda when it shrinks after an accepted step. */
    private static final double MINIMUM_LAMBDA = 1.0E-7d;
    /** When a rejected step pushes lambda above this value, the iteration gives up. */
    private static final double MAXIMUM_LAMBDA = 1000000.0d;

    private final ObjectiveFunction obj;
    /** Per-parameter cap on the magnitude of a single step; +Infinity means no cap. */
    private double[] maxParamChange;
    /** Current parameter values; replaced by a new vector when a step is accepted. */
    private Vector parameters;
    /** Objective plus penalty evaluated at the current parameters. */
    private double currentObjective;
    /** Number of completed iterations in the current run. */
    private int iteration;
    /** One of CONVERGED, MAX_ITERATIONS, INVALID_STEP or NOT_RUN. */
    private int termination;
    /** Incremented at the end of each iterateToConvergence run; reset by resetPenalty. */
    private int penaltyIteration;
    /** Penalty parameter passed to every constraint; divided by 10 after each run. */
    private double alpha;
    private boolean penaltyConverged;
    private final Collection<Constraint> constraints = new ArrayList<>();
    private double lambda = 600.0d;
    private double initialLambda = this.lambda;
    private double lambdaIncrement = 2.0d;
    private double lambdaDecrement = 0.9d;

    /** Receives notifications after iterations of {@link MarquardtMinimizer}. */
    public interface AfterIterationCallback {
        /**
         * Called after a step has been accepted.
         *
         * @param completedIterations number of iterations completed so far in this run
         */
        void finishedIteration(int completedIterations);

        /**
         * Called after a trial step has been rejected.
         *
         * @param iteration index of the iteration that is still in progress
         * @param failedAttempts number of trial steps rejected so far in this iteration
         */
        void finishedFailedIteration(int iteration, int failedAttempts);
    }

    /** Receives notifications before iterations of {@link MarquardtMinimizer}. */
    public interface BeforeIterationCallback {
        /**
         * Called before an iteration starts.
         *
         * @param iteration index of the iteration about to run
         */
        void startIteration(int iteration);

        /**
         * Called before a trial step is retried with a larger lambda.
         *
         * @param iteration index of the iteration that is still in progress
         * @param failedAttempts number of trial steps rejected so far in this iteration
         */
        void startFailedIteration(int iteration, int failedAttempts);
    }

    /**
     * Creates a minimizer that starts from the given parameter values.
     *
     * @param objectiveFunction the function to minimize
     * @param initialGuess starting parameter values; copied, so later changes by the caller
     *     have no effect
     * @throws OptimizationException if the objective cannot be evaluated at the initial guess
     */
    public MarquardtMinimizer(ObjectiveFunction objectiveFunction, Vector initialGuess)
            throws OptimizationException {
        this.obj = objectiveFunction;
        this.parameters = initialGuess.copy();
        initMaxParamChange();
        this.obj.setParameterValues(this.parameters);
        try {
            // alpha is still 0.0 here; it is only initialised from this objective value below.
            this.currentObjective = this.obj.getValue() + getPenaltyFunction(this.parameters);
            this.iteration = 0;
            this.penaltyIteration = 0;
            this.alpha = 0.001d * this.currentObjective;
            this.termination = NOT_RUN;
        } catch (OptimizationException e) {
            throw new OptimizationException("Objective function could not be evaluated at guess", e);
        }
    }

    /**
     * Restarts the inner iteration from new parameter values.
     *
     * <p>Resets lambda to its initial value and the iteration count to zero. Keeps the penalty
     * state (alpha and the penalty-iteration count) unchanged.
     *
     * @param newParameters the new starting point; copied
     * @throws OptimizationException if the objective cannot be evaluated at the new point
     */
    public void reset(Vector newParameters) throws OptimizationException {
        this.parameters = newParameters.copy();
        initMaxParamChange();
        this.obj.setParameterValues(this.parameters);
        try {
            this.currentObjective = this.obj.getValue() + getPenaltyFunction(this.parameters);
            this.lambda = this.initialLambda;
            this.iteration = 0;
        } catch (OptimizationException e) {
            throw new OptimizationException("Objective function could not be evaluated at guess", e);
        }
    }

    /** Fills maxParamChange with +Infinity unless it already matches the parameter count. */
    private void initMaxParamChange() {
        int size = this.parameters.size();
        if (this.maxParamChange == null || this.maxParamChange.length != size) {
            this.maxParamChange = new double[size];
            Arrays.fill(this.maxParamChange, Double.POSITIVE_INFINITY);
        }
    }

    /**
     * Like {@link #reset(Vector)}, but also restarts the penalty sequence.
     *
     * <p>Zeroes the penalty-iteration count and recomputes alpha from the new objective value.
     *
     * @param newParameters the new starting point; copied
     * @throws OptimizationException if the objective cannot be evaluated at the new point
     */
    public void resetPenalty(Vector newParameters) throws OptimizationException {
        reset(newParameters);
        this.penaltyIteration = 0;
        this.alpha = 0.001d * this.currentObjective;
    }

    /** Sets lambda, and the value {@link #reset(Vector)} restores it to. */
    public void setInitialMarquardtFactor(double initialFactor) {
        this.lambda = initialFactor;
        this.initialLambda = initialFactor;
    }

    /**
     * Sets how lambda changes between trial steps.
     *
     * @param increment multiplier applied to lambda after a rejected step
     * @param decrement multiplier applied to lambda after an accepted step
     */
    public void setMarquardtFactorAdjustments(double increment, double decrement) {
        this.lambdaIncrement = increment;
        this.lambdaDecrement = decrement;
    }

    /** Adds a constraint whose penalty is included in subsequent evaluations. */
    public void addConstraint(Constraint constraint) {
        this.constraints.add(constraint);
    }

    /** Returns a copy of the registered constraints. */
    public Collection<Constraint> getConstraints() {
        return new ArrayList<>(this.constraints);
    }

    /** Removes all registered constraints. */
    public void clearConstraints() {
        this.constraints.clear();
    }

    /**
     * Caps the size of a single step for each parameter.
     *
     * @param maxChange per-parameter limits; their absolute values are used
     */
    public void setMaxParameterChange(Vector maxChange) {
        this.maxParamChange = Matrices.getArray(maxChange);
        for (int k = 0; k < this.maxParamChange.length; k++) {
            this.maxParamChange[k] = Math.abs(this.maxParamChange[k]);
        }
    }

    /**
     * Performs a single Marquardt iteration without callbacks.
     *
     * @return a copy of the parameter values after the accepted step
     * @throws OptimizationException if no acceptable step is found before lambda exceeds its maximum
     */
    public Vector doOneIteration() throws OptimizationException {
        return doOneIteration(null, null);
    }

    /**
     * Performs one iteration, retrying with growing lambda until a step is accepted.
     *
     * @param beforeCallback optional listener notified before the iteration and each retry
     * @param afterCallback optional listener notified after each rejected step and on success
     * @return a copy of the parameter values after the accepted step
     * @throws OptimizationException if lambda exceeds MAXIMUM_LAMBDA before a step is accepted
     */
    private Vector doOneIteration(BeforeIterationCallback beforeCallback,
            AfterIterationCallback afterCallback) throws OptimizationException {
        Vector gradient = this.obj.getGradient(this.parameters)
                .add(getPenaltyFunctionGradient(this.parameters));
        Matrix hessian = this.obj.getHessian(this.parameters)
                .add(getPenaltyFunctionHessian(this.parameters));
        if (beforeCallback != null) {
            beforeCallback.startIteration(this.iteration);
        }
        Matrix correction = getMarquardtCorrection(hessian);
        boolean stepAccepted = false;
        int failedAttempts = 0;
        while (!stepAccepted) {
            Matrix damped = hessian.copy();
            damped.add(this.lambda, correction);
            logger.info("Solving for the new parameter values...");
            Vector step = convertColumnMatrixToVector(
                    DenseCholesky.factorize(damped).solve(new DenseMatrix(gradient))).scale(-1.0d);
            clampStep(step);
            Vector trial = new DenseVector(this.parameters);
            trial.add(step);
            logger.info("Found new parameter values");
            this.obj.setParameterValues(trial);
            try {
                double trialObjective = this.obj.getValue() + getPenaltyFunction(trial);
                this.obj.logObjective(logger);
                if (trialObjective < this.currentObjective) {
                    logger.info("Potentially good step, checking...");
                    Vector previous = this.parameters;
                    this.parameters = trial;
                    Matrix trialHessian = this.obj.getHessian(this.parameters)
                            .add(getPenaltyFunctionHessian(this.parameters));
                    if (isHessianOkay(trialHessian)) {
                        this.currentObjective = trialObjective;
                        this.lambda = Math.max(this.lambdaDecrement * this.lambda, MINIMUM_LAMBDA);
                        logger.info("Found valid step, setting lambda to " + this.lambda);
                        stepAccepted = true;
                    } else {
                        this.parameters = previous;
                        this.lambda = this.lambdaIncrement * this.lambda;
                        logger.info("***Step looked good but Hessian had infinity or NaN in it so we have to back up, lambda set to " + this.lambda);
                    }
                } else {
                    logger.info("Step leads to worse goodness of fit, backing up");
                    this.lambda = this.lambdaIncrement * this.lambda;
                    logger.info("Interpolated step, now setting lambda to " + this.lambda);
                }
            } catch (OptimizationException e) {
                logger.info("Overflow error, backing up");
                this.lambda = this.lambdaIncrement * this.lambda;
                logger.info("Setting lambda to " + this.lambda);
            }
            if (!stepAccepted) {
                failedAttempts++;
                if (afterCallback != null) {
                    afterCallback.finishedFailedIteration(this.iteration, failedAttempts);
                }
                if (beforeCallback != null) {
                    beforeCallback.startFailedIteration(this.iteration, failedAttempts);
                }
                // Only give up on a rejected step; an accepted step must never be reported as a
                // failure, even if lambda started above the maximum.
                if (this.lambda > MAXIMUM_LAMBDA) {
                    throw new OptimizationException("Cannot find a valid step");
                }
            }
        }
        logger.info("Finished iteration " + this.iteration + ".");
        this.iteration++;
        if (afterCallback != null) {
            afterCallback.finishedIteration(this.iteration);
        }
        return this.parameters.copy();
    }

    /** Clips each step component to [-maxParamChange[k], maxParamChange[k]] in place. */
    private void clampStep(Vector step) {
        for (int k = 0; k < step.size(); k++) {
            double limit = this.maxParamChange[k];
            double value = step.get(k);
            if (value > limit) {
                step.set(k, limit);
            } else if (value < -limit) {
                step.set(k, -limit);
            }
        }
    }

    /** Sums the penalty of every constraint at the given point using the current alpha. */
    private double getPenaltyFunction(Vector point) {
        double total = 0.0d;
        for (Constraint constraint : this.constraints) {
            total += constraint.getPenaltyFunction(point, this.alpha);
        }
        return total;
    }

    /** Sums the penalty gradients of every constraint at the given point. */
    private Vector getPenaltyFunctionGradient(Vector point) {
        DenseVector total = new DenseVector(point.size());
        for (Constraint constraint : this.constraints) {
            total.add(constraint.getPenaltyFunctionGradient(point, this.alpha));
        }
        return total;
    }

    /** Sums the penalty Hessians of every constraint at the given point. */
    private Matrix getPenaltyFunctionHessian(Vector point) {
        int n = point.size();
        DenseMatrix total = new DenseMatrix(n, n);
        for (Constraint constraint : this.constraints) {
            total.add(constraint.getPenaltyFunctionHessian(point, this.alpha));
        }
        return total;
    }

    /** Copies the first column of a matrix into a dense vector. */
    private Vector convertColumnMatrixToVector(Matrix column) {
        int rows = column.numRows();
        DenseVector result = new DenseVector(rows);
        for (int r = 0; r < rows; r++) {
            result.set(r, column.get(r, 0));
        }
        return result;
    }

    /**
     * Builds the diagonal damping matrix from the Hessian's diagonal.
     *
     * <p>Zero diagonal entries are replaced by 1 so every parameter receives some damping.
     */
    private Matrix getMarquardtCorrection(Matrix hessian) {
        int n = hessian.numRows();
        BandMatrix correction = new BandMatrix(n, 0, 0);
        for (int k = 0; k < n; k++) {
            double diagonal = hessian.get(k, k);
            correction.set(k, k, diagonal == 0.0d ? 1.0d : diagonal);
        }
        return correction;
    }

    /** Returns true when the Hessian's 1-norm is finite (no infinity or NaN entries). */
    private boolean isHessianOkay(Matrix hessian) {
        double norm = hessian.norm(Matrix.Norm.One);
        return !Double.isInfinite(norm) && !Double.isNaN(norm);
    }

    /** Logs the index and value of every infinite or NaN gradient element. */
    private void logInvalidGradient(Vector gradient) {
        loggerf.info("The following gradient elements are invalid:", new Object[0]);
        for (int k = 0; k < gradient.size(); k++) {
            double value = gradient.get(k);
            if (Double.isInfinite(value) || Double.isNaN(value)) {
                loggerf.info("Element %d: %f", Integer.valueOf(k), Double.valueOf(value));
            }
        }
    }

    /**
     * Runs {@link #iterateToConvergence(Vector, int, BeforeIterationCallback, AfterIterationCallback)}
     * without callbacks.
     */
    public Vector iterateToConvergence(Vector tolerance, int maxIterations) {
        return iterateToConvergence(tolerance, maxIterations, null, null);
    }

    /**
     * Iterates until the parameters stop changing or the iteration limit is reached.
     *
     * <p>Convergence requires two consecutive iterations in which every parameter moved by less
     * than its tolerance. After the run, the penalty-iteration count is incremented and alpha is
     * divided by 10. This does not happen if the Hessian at the starting point already contains
     * infinity or NaN; in that case the method returns immediately with an invalid-step status.
     *
     * <p>Query {@link #lastRunConverged()}, {@link #lastRunMaxIterations()} and
     * {@link #lastRunInvalidStep()} to learn how the run ended. An {@link OptimizationException}
     * during an iteration ends the run with the invalid-step status.
     *
     * @param tolerance per-parameter convergence tolerance
     * @param maxIterations maximum number of iterations in this run
     * @param beforeCallback optional listener called before iterations; may be null
     * @param afterCallback optional listener called after iterations; may be null
     * @return a copy of the final parameter values
     */
    public Vector iterateToConvergence(Vector tolerance, int maxIterations,
            BeforeIterationCallback beforeCallback, AfterIterationCallback afterCallback) {
        this.iteration = 0;
        Vector previous = this.parameters;
        boolean withinToleranceLastTime = false;
        boolean converged = false;
        logger.info("Starting optimization - max iterations = " + maxIterations);
        this.obj.logObjective(logger);
        this.obj.storePreviousValues();
        try {
            Matrix initialHessian = this.obj.getHessian(this.parameters)
                    .add(getPenaltyFunctionHessian(this.parameters));
            if (!isHessianOkay(initialHessian)) {
                logger.info("Initial guess led to infinity or NaN in Hessian");
                logInvalidGradient(this.obj.getGradient(this.parameters));
                if (beforeCallback != null) {
                    beforeCallback.startIteration(this.iteration);
                }
                this.termination = INVALID_STEP;
                return this.parameters.copy();
            }
            while (!converged && this.iteration < maxIterations) {
                doOneIteration(beforeCallback, afterCallback);
                boolean withinTolerance = checkConvergence(this.parameters, previous, tolerance);
                this.obj.storePreviousValues();
                previous = this.parameters;
                converged = withinTolerance && withinToleranceLastTime;
                withinToleranceLastTime = withinTolerance;
            }
            // Use the flag, not the counter: convergence on the final allowed iteration still counts.
            this.termination = converged ? CONVERGED : MAX_ITERATIONS;
        } catch (OptimizationException e) {
            logger.warn("Optimization stopped: no valid step could be found", e);
            this.termination = INVALID_STEP;
        }
        this.penaltyIteration++;
        this.alpha /= 10.0d;
        return this.parameters.copy();
    }

    /** Returns true when every |current - previous| is strictly below its tolerance. */
    private boolean checkConvergence(Vector current, Vector previous, Vector tolerance) {
        for (int k = 0; k < current.size(); k++) {
            if (Math.abs(current.get(k) - previous.get(k)) >= tolerance.get(k)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Minimizes the objective, using a sequence of penalty rounds when constraints are present.
     *
     * <p>With no constraints this is a single
     * {@link #iterateToConvergence(Vector, int)} run. Otherwise each round restarts from the
     * previous round's result and iterates to convergence; the sequence stops when two
     * consecutive rounds move every parameter by less than {@code penaltyTolerance}, or when
     * {@code maxPenaltyIterations} rounds have been counted since the last
     * {@link #resetPenalty(Vector)}.
     *
     * @param tolerance per-parameter tolerance for each inner run
     * @param maxIterations maximum iterations per inner run
     * @param penaltyTolerance per-parameter tolerance between penalty rounds
     * @param maxPenaltyIterations maximum number of penalty rounds
     * @return a copy of the final parameter values
     * @throws OptimizationException if the objective cannot be evaluated when restarting a round
     */
    public Vector minimize(Vector tolerance, int maxIterations, Vector penaltyTolerance,
            int maxPenaltyIterations) throws OptimizationException {
        if (this.constraints.isEmpty()) {
            return iterateToConvergence(tolerance, maxIterations);
        }
        Vector previous = this.parameters;
        boolean withinToleranceLastTime = false;
        boolean converged = false;
        while (!converged && this.penaltyIteration < maxPenaltyIterations) {
            reset(previous);
            iterateToConvergence(tolerance, maxIterations);
            boolean withinTolerance = checkConvergence(this.parameters, previous, penaltyTolerance);
            previous = this.parameters;
            converged = withinTolerance && withinToleranceLastTime;
            withinToleranceLastTime = withinTolerance;
        }
        this.penaltyConverged = converged;
        return this.parameters.copy();
    }

    /** Returns objective plus penalty at the current parameters. */
    public double getCurrentObjectiveValue() {
        return this.currentObjective;
    }

    /** Returns the number of iterations completed in the current run. */
    public int getNumberOfIterations() {
        return this.iteration;
    }

    /** Returns the number of penalty rounds counted since the last penalty reset. */
    public int getNumberOfPenaltyIterations() {
        return this.penaltyIteration;
    }

    /** Returns true if the last run ended because the parameters converged. */
    public boolean lastRunConverged() {
        return this.termination == CONVERGED;
    }

    /** Returns true if the last run stopped at the iteration limit. */
    public boolean lastRunMaxIterations() {
        return this.termination == MAX_ITERATIONS;
    }

    /** Returns true if the last {@link #minimize} penalty sequence converged. */
    public boolean lastRunPenaltyConverged() {
        return this.penaltyConverged;
    }

    /** Returns true if the last {@link #minimize} penalty sequence did not converge. */
    public boolean lastRunPenaltyMaxIterations() {
        return !this.penaltyConverged;
    }

    /** Returns true if the last run stopped because no valid step could be found. */
    public boolean lastRunInvalidStep() {
        return this.termination == INVALID_STEP;
    }
}
