package com.hbaspecto.pecas.sd.estimation;

import com.hbaspecto.discreteChoiceModelling.Coefficient;
import java.io.BufferedWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.stream.Collectors;
import no.uib.cipr.matrix.DenseCholesky;
import no.uib.cipr.matrix.DenseMatrix;
import no.uib.cipr.matrix.DenseVector;
import no.uib.cipr.matrix.Matrices;
import no.uib.cipr.matrix.Matrix;
import no.uib.cipr.matrix.Vector;
import org.apache.log4j.Logger;

/* loaded from: input_file:com/hbaspecto/pecas/sd/estimation/GaussBayesianObjective.class */
/**
 * Objective function for Bayesian estimation with Gaussian errors on the targets and a Gaussian
 * prior on the parameters.
 *
 * <p>Writing {@code p} for the parameter vector, {@code f(p)} for the modelled target values,
 * {@code t} for the observed target values, {@code m} for the prior means, {@code V} for the target
 * variance matrix and {@code U} for the prior variance matrix, the objective is
 *
 * <pre>
 *   (f(p) - t)' V^-1 (f(p) - t)  +  (p - m)' U^-1 (p - m)
 * </pre>
 *
 * <p>The gradient is {@code 2 (J' V^-1 (f(p) - t) + U^-1 (p - m))}, where {@code J} is the model
 * Jacobian (exact for symmetric variance matrices). The Hessian is the Gauss-Newton approximation
 * {@code 2 (J' V^-1 J + U^-1)}, i.e. second derivatives of the model are ignored.
 *
 * <p>The most recently computed objective, gradient and Hessian are kept for the {@code log*} and
 * {@code print*} methods. One earlier iterate can be snapshotted with
 * {@link #storePreviousValues()}. Instances are mutable and unsynchronized.
 */
public class GaussBayesianObjective implements ObjectiveFunction {
    /** Model that produces target values and their Jacobian for a parameter vector. */
    private final DifferentiableModel myModel;
    /** Estimation targets, in the row order of the target variance matrix. */
    private final List<EstimationTarget> myTargets;
    /** Coefficients being estimated; coefficient i corresponds to parameter i. */
    private final List<Coefficient> myCoeffs;
    /** Inverse of the target variance matrix (V^-1). */
    private final Matrix myInverseTargetVariance;
    /** Prior means of the parameters (m). */
    private final Vector myMean;
    /** Inverse of the prior variance matrix (U^-1). */
    private final Matrix myInversePriorVariance;
    /** Number of parameters, taken from the size of the prior mean vector. */
    private final int numParams;

    // State from the most recent getValue / getGradient / getHessian calls.
    private double currentObjectiveValue;
    private double currentParameterError;
    private double currentTargetError;
    private Vector currentParameterValues;
    private double[] currentModelledValues;
    private Vector currentGradient;
    private Matrix currentHessian;

    // Snapshot taken by storePreviousValues(); only meaningful when previousValuesExist is true.
    private boolean previousValuesExist = false;
    private double previousObjectiveValue;
    private Vector previousParameterValues;
    private double[] previousModelledValues;

    /**
     * Creates the objective and inverts both variance matrices once up front.
     *
     * @param differentiableModel model producing target values and the Jacobian
     * @param list coefficients being estimated, in parameter-vector order
     * @param list2 estimation targets, in the row order of {@code matrix}
     * @param matrix target variance matrix; must be symmetric positive definite (it is
     *     Cholesky-factorized)
     * @param vector prior means; its size fixes the number of parameters
     * @param matrix2 prior variance matrix; must be symmetric positive definite (it is
     *     Cholesky-factorized)
     */
    public GaussBayesianObjective(DifferentiableModel differentiableModel, List<Coefficient> list, List<EstimationTarget> list2, Matrix matrix, Vector vector, Matrix matrix2) {
        this.numParams = vector.size();
        this.myModel = differentiableModel;
        this.myTargets = new ArrayList<>(list2);
        this.myCoeffs = new ArrayList<>(list);
        this.myMean = vector.copy();
        // The variances are only ever used inverted, so invert them once via Cholesky solves
        // against the identity.
        this.myInverseTargetVariance = DenseCholesky.factorize(matrix).solve(Matrices.identity(list2.size()));
        this.myInversePriorVariance = DenseCholesky.factorize(matrix2).solve(Matrices.identity(this.numParams));
    }

    /**
     * Sets the parameter vector used by {@link #getValue()} and by the logging/printing methods.
     * The vector is kept by reference, not copied.
     */
    @Override
    public void setParameterValues(Vector vector) {
        this.currentParameterValues = vector;
    }

    /**
     * Evaluates the objective at the values last passed to {@link #setParameterValues(Vector)}.
     * Records the modelled target values and the target/parameter contributions for later
     * logging.
     *
     * @return the objective value
     * @throws IllegalStateException if no parameter values have been set
     * @throws OptimizationException if the model fails to produce target values
     */
    @Override
    public double getValue() throws OptimizationException {
        if (this.currentParameterValues == null) {
            throw new IllegalStateException("setParameterValues must be called before getValue");
        }
        Vector modelled = this.myModel.getTargetValues(this.myTargets, this.currentParameterValues);
        this.currentModelledValues = Matrices.getArray(modelled);

        // Residuals are formed on copies so the model's vector and the caller's parameter vector
        // are not modified.
        Vector targetResidual = modelled.copy();
        targetResidual.add(-1.0d, getMyTargetValues());
        Vector parameterResidual = this.currentParameterValues.copy();
        parameterResidual.add(-1.0d, this.myMean);

        this.currentTargetError = quadraticForm(targetResidual, this.myInverseTargetVariance);
        this.currentParameterError = quadraticForm(parameterResidual, this.myInversePriorVariance);
        this.currentObjectiveValue = this.currentTargetError + this.currentParameterError;
        return this.currentObjectiveValue;
    }

    /**
     * Computes the gradient {@code 2 (J' V^-1 (f(p) - t) + U^-1 (p - m))} at {@code vector} and
     * remembers it for {@link #printGradient}. Does not change the current parameter or modelled
     * values.
     *
     * @param vector parameter values at which to evaluate the gradient
     * @return the gradient, of length equal to the number of parameters
     * @throws OptimizationException if the model fails to produce values or the Jacobian
     */
    @Override
    public Vector getGradient(Vector vector) throws OptimizationException {
        Vector targetResidual = this.myModel.getTargetValues(this.myTargets, vector).copy();
        targetResidual.add(-1.0d, getMyTargetValues());
        Vector weightedResidual = this.myInverseTargetVariance.mult(targetResidual, new DenseVector(targetResidual.size()));

        Vector parameterResidual = vector.copy();
        parameterResidual.add(-1.0d, this.myMean);

        Matrix jacobian = this.myModel.getJacobian(this.myTargets, vector);
        // transMult overwrites its output: gradient = J' * V^-1 * r
        Vector gradient = jacobian.transMult(weightedResidual, new DenseVector(this.numParams));
        // multAdd accumulates into its output: gradient += U^-1 * (p - m)
        this.myInversePriorVariance.multAdd(parameterResidual, gradient);
        gradient.scale(2.0d);

        this.currentGradient = gradient;
        return gradient;
    }

    /**
     * Computes the Gauss-Newton Hessian {@code 2 (J' V^-1 J + U^-1)} at {@code vector} and
     * remembers it for {@link #printHessian} and {@link #printStdError}.
     *
     * @param vector parameter values at which to evaluate the Jacobian
     * @return a square matrix of size equal to the number of parameters
     * @throws OptimizationException if the model fails to produce the Jacobian
     */
    @Override
    public Matrix getHessian(Vector vector) throws OptimizationException {
        Matrix jacobian = this.myModel.getJacobian(this.myTargets, vector);
        // weightedJacobian = V^-1 * J (same shape as J)
        Matrix weightedJacobian = this.myInverseTargetVariance.mult(jacobian, new DenseMatrix(jacobian.numRows(), jacobian.numColumns()));
        // hessian = J' * V^-1 * J, then add U^-1 and scale by 2
        Matrix hessian = jacobian.transAmult(weightedJacobian, new DenseMatrix(this.numParams, this.numParams));
        hessian.add(this.myInversePriorVariance);
        hessian.scale(2.0d);

        this.currentHessian = hessian;
        return hessian;
    }

    /** Returns the quadratic form {@code x' A x}. */
    private static double quadraticForm(Vector x, Matrix a) {
        return x.dot(a.mult(x, new DenseVector(x.size())));
    }

    /** Collects the observed value of each target into a vector, in target order. */
    private Vector getMyTargetValues() {
        DenseVector values = new DenseVector(this.myTargets.size());
        for (int i = 0; i < this.myTargets.size(); i++) {
            values.set(i, this.myTargets.get(i).getTargetValue());
        }
        return values;
    }

    /**
     * Logs each coefficient's prior mean and current value, plus the stored previous value when
     * {@link #storePreviousValues()} has been called.
     */
    public void logParameters(Logger logger) {
        logger.info("Parameter values:");
        for (int i = 0; i < this.myCoeffs.size(); i++) {
            StringBuilder line = new StringBuilder()
                    .append("Parameter ").append(this.myCoeffs.get(i).getName())
                    .append(": prior mean = ").append(this.myMean.get(i))
                    .append(", current value = ").append(this.currentParameterValues.get(i));
            if (this.previousValuesExist) {
                line.append(", previous value = ").append(this.previousParameterValues.get(i));
            }
            logger.info(line.toString());
        }
    }

    /** Logs the parameters, then the objective and per-target values. */
    public void logCurrentValues(Logger logger) {
        logParameters(logger);
        logTargetAndObjective(logger);
    }

    /**
     * Logs the objective, then each target's observed value and the modelled value from the last
     * {@link #getValue()} (plus the stored previous modelled value, if any).
     */
    public void logTargetAndObjective(Logger logger) {
        logObjective(logger);
        logger.info("Target values:");
        for (int i = 0; i < this.myTargets.size(); i++) {
            EstimationTarget target = this.myTargets.get(i);
            StringBuilder line = new StringBuilder()
                    .append("Target ").append(target.getName())
                    .append(": target value = ").append(target.getTargetValue())
                    .append(", modelled value = ").append(this.currentModelledValues[i]);
            if (this.previousValuesExist) {
                line.append(", previous value = ").append(this.previousModelledValues[i]);
            }
            logger.info(line.toString());
        }
    }

    /**
     * Logs the current objective value (and the stored previous one, if any) together with its
     * split into parameter and target contributions.
     */
    @Override
    public void logObjective(Logger logger) {
        logger.info("Current value of objective function = " + this.currentObjectiveValue);
        if (this.previousValuesExist) {
            logger.info("Previous value of objective function = " + this.previousObjectiveValue);
        }
        logger.info("Contribution from parameters = " + this.currentParameterError);
        logger.info("Contribution from targets = " + this.currentTargetError);
    }

    /**
     * Snapshots the current parameter values, objective value and modelled values as the
     * "previous" values reported by the log/print methods.
     */
    @Override
    public void storePreviousValues() {
        // Copy rather than alias: the parameter vector is held by reference (see
        // setParameterValues), so an in-place update by the caller would otherwise change the
        // "previous" values as well.
        this.previousParameterValues = this.currentParameterValues == null ? null : this.currentParameterValues.copy();
        this.previousObjectiveValue = this.currentObjectiveValue;
        this.previousModelledValues = this.currentModelledValues == null ? null : this.currentModelledValues.clone();
        this.previousValuesExist = true;
    }

    /**
     * Writes a CSV table of coefficients with columns for the printer's common fields,
     * {@code PriorMean}, {@code CurValue} and, when previous values are stored, {@code PrevValue}.
     */
    public void printParameters(BufferedWriter bufferedWriter, ParameterPrinter parameterPrinter) throws IOException {
        List<Field> commonFields = parameterPrinter.getCommonFields(this.myCoeffs);
        writeFieldNames(bufferedWriter, commonFields);
        bufferedWriter.write(",PriorMean,CurValue");
        if (this.previousValuesExist) {
            bufferedWriter.write(",PrevValue");
        }
        bufferedWriter.newLine();
        for (int i = 0; i < this.myCoeffs.size(); i++) {
            bufferedWriter.write(String.join(",", parameterPrinter.adaptToFields(this.myCoeffs.get(i), commonFields)));
            bufferedWriter.write("," + this.myMean.get(i));
            bufferedWriter.write("," + this.currentParameterValues.get(i));
            if (this.previousValuesExist) {
                bufferedWriter.write("," + this.previousParameterValues.get(i));
            }
            bufferedWriter.newLine();
        }
    }

    /**
     * Writes the objective summary rows ({@code CurObj}, optional {@code PrevObj},
     * {@code ParamError}, {@code TargetError}), followed by a CSV table of targets with their
     * observed, current and (optionally) previous modelled values.
     */
    public void printTargetAndObjective(BufferedWriter bufferedWriter, TargetPrinter targetPrinter) throws IOException {
        bufferedWriter.write("CurObj," + this.currentObjectiveValue);
        bufferedWriter.newLine();
        if (this.previousValuesExist) {
            bufferedWriter.write("PrevObj," + this.previousObjectiveValue);
            bufferedWriter.newLine();
        }
        bufferedWriter.write("ParamError," + this.currentParameterError);
        bufferedWriter.newLine();
        bufferedWriter.write("TargetError," + this.currentTargetError);
        bufferedWriter.newLine();

        List<Field> commonFields = targetPrinter.getCommonFields(this.myTargets);
        writeFieldNames(bufferedWriter, commonFields);
        bufferedWriter.write(",TargetValue,CurValue");
        if (this.previousValuesExist) {
            bufferedWriter.write(",PrevValue");
        }
        bufferedWriter.newLine();
        for (int i = 0; i < this.myTargets.size(); i++) {
            EstimationTarget target = this.myTargets.get(i);
            bufferedWriter.write(String.join(",", targetPrinter.adaptToFields(target, commonFields)));
            bufferedWriter.write("," + target.getTargetValue());
            bufferedWriter.write("," + this.currentModelledValues[i]);
            if (this.previousValuesExist) {
                bufferedWriter.write("," + this.previousModelledValues[i]);
            }
            bufferedWriter.newLine();
        }
    }

    /**
     * Writes a CSV table of the gradient from the last {@link #getGradient(Vector)} call, one row
     * per coefficient.
     *
     * @throws IllegalStateException if no gradient has been computed yet
     */
    public void printGradient(BufferedWriter bufferedWriter, ParameterPrinter parameterPrinter) throws IOException {
        requireComputed(this.currentGradient, "getGradient", "printGradient");
        List<Field> commonFields = parameterPrinter.getCommonFields(this.myCoeffs);
        writeFieldNames(bufferedWriter, commonFields);
        bufferedWriter.write(",Derivative");
        bufferedWriter.newLine();
        for (int i = 0; i < this.myCoeffs.size(); i++) {
            bufferedWriter.write(String.join(",", parameterPrinter.adaptToFields(this.myCoeffs.get(i), commonFields)));
            bufferedWriter.write("," + this.currentGradient.get(i));
            bufferedWriter.newLine();
        }
    }

    /**
     * Writes the Hessian from the last {@link #getHessian(Vector)} call as a labelled square CSV
     * matrix.
     *
     * @throws IllegalStateException if no Hessian has been computed yet
     */
    public void printHessian(BufferedWriter bufferedWriter, ParameterPrinter parameterPrinter) throws IOException {
        requireComputed(this.currentHessian, "getHessian", "printHessian");
        int size = this.myCoeffs.size();
        writeCoefficientColumnHeader(bufferedWriter, parameterPrinter);
        for (int row = 0; row < size; row++) {
            bufferedWriter.write(parameterPrinter.asString(this.myCoeffs.get(row)));
            for (int col = 0; col < size; col++) {
                bufferedWriter.write("," + this.currentHessian.get(row, col));
            }
            bufferedWriter.newLine();
        }
    }

    /**
     * Writes a labelled square CSV matrix derived from the last Hessian. The diagonal holds
     * standard errors and the off-diagonal entries hold correlations.
     *
     * <p>The objective has no 1/2 factor, so {@code (H / 2)^-1} is treated as the approximate
     * parameter covariance {@code C}. The diagonal is {@code sqrt(C_ii)} and each off-diagonal
     * entry is {@code C_ij / sqrt(C_ii C_jj)}.
     *
     * @throws IllegalStateException if no Hessian has been computed yet
     */
    public void printStdError(BufferedWriter bufferedWriter, ParameterPrinter parameterPrinter) throws IOException {
        requireComputed(this.currentHessian, "getHessian", "printStdError");
        int size = this.myCoeffs.size();
        // Invert before writing anything so a singular Hessian does not leave a partial table.
        Matrix covariance = this.currentHessian.copy().scale(0.5d).solve(Matrices.identity(size), new DenseMatrix(size, size));
        writeCoefficientColumnHeader(bufferedWriter, parameterPrinter);
        for (int row = 0; row < size; row++) {
            bufferedWriter.write(parameterPrinter.asString(this.myCoeffs.get(row)));
            for (int col = 0; col < size; col++) {
                double entry = row == col
                        ? Math.sqrt(covariance.get(row, row))
                        : covariance.get(row, col) / Math.sqrt(covariance.get(row, row) * covariance.get(col, col));
                bufferedWriter.write("," + entry);
            }
            bufferedWriter.newLine();
        }
    }

    /** Writes the comma-joined names of {@code fields}, without a trailing newline. */
    private static void writeFieldNames(BufferedWriter out, List<Field> fields) throws IOException {
        out.write(fields.stream().map(Field::getName).collect(Collectors.joining(",")));
    }

    /** Writes the header row of a coefficient-by-coefficient matrix: a leading comma per label. */
    private void writeCoefficientColumnHeader(BufferedWriter out, ParameterPrinter parameterPrinter) throws IOException {
        for (Coefficient coeff : this.myCoeffs) {
            out.write("," + parameterPrinter.asString(coeff));
        }
        out.newLine();
    }

    /** Fails fast, with a descriptive message, when a print method runs before its computation. */
    private static void requireComputed(Object value, String producer, String consumer) {
        if (value == null) {
            throw new IllegalStateException(producer + " must be called before " + consumer);
        }
    }
}
