package com.pb.common.matrix;

import org.apache.log4j.Logger;

/* loaded from: input_file:com/pb/common/matrix/MatrixBalancer.class */
/**
 * Balances a matrix to a set of row and column targets using iterative proportional fitting.
 *
 * <p>Each iteration first scales every row so that its sum equals its row target. It then scales
 * every column so that its sum equals its column target. Finally it measures how far the row sums
 * have drifted from their targets. Iteration stops as soon as any configured closure criterion is
 * satisfied: maximum iterations, maximum relative error, or maximum absolute error. At least one
 * criterion must be set before balancing.
 *
 * <p>Typical use: {@link #setSeed}, {@link #setTargets}, one or more of the {@code setMaximum*}
 * methods, {@link #balance()}, then {@link #getBalancedMatrix()}.
 *
 * <p>Instances hold mutable state and use no synchronization.
 */
public class MatrixBalancer {
    private static final Logger logger = Logger.getLogger(MatrixBalancer.class);

    /** Cloned seed before {@link #balance()}; the balanced result afterwards. */
    private Matrix seed;
    /** Target column sums, looked up by external column number. */
    private RowVector columnTargets;
    /** Target row sums, looked up by external row number. */
    private ColumnVector rowTargets;
    /** Largest relative row-sum error measured in the most recent iteration. */
    private double relativeError;
    /** Largest absolute row-sum error measured in the most recent iteration. */
    private double absoluteError;
    /** 1-based number of the iteration about to run (one past the last completed iteration). */
    private int iteration;

    private double maxRelativeError = 0.0d;
    private boolean maxRelativeErrorSet = false;
    private double maxAbsoluteError = 0.0d;
    private boolean maxAbsoluteErrorSet = false;
    // The default of 20 is only consulted once setMaximumIterations has been called.
    private int maxIterations = 20;
    private boolean maxIterationsSet = false;

    /**
     * Enables the relative-error closure criterion.
     *
     * <p>Balancing stops once the largest relative difference between a row sum and its target
     * is below {@code d}. The same value is also the tolerance used when checking that the row
     * and column target totals agree. While unset, that tolerance is 0.
     *
     * @param d maximum acceptable relative error
     */
    public void setMaximumRelativeError(double d) {
        this.maxRelativeError = d;
        this.maxRelativeErrorSet = true;
    }

    /**
     * Enables the absolute-error closure criterion.
     *
     * <p>Balancing stops once the largest absolute difference between a row sum and its target is
     * below {@code d}.
     *
     * @param d maximum acceptable absolute error
     */
    public void setMaximumAbsoluteError(double d) {
        this.maxAbsoluteError = d;
        this.maxAbsoluteErrorSet = true;
    }

    /**
     * Enables the iteration-count closure criterion: at most {@code i} iterations are performed.
     *
     * @param i maximum number of iterations
     */
    public void setMaximumIterations(int i) {
        this.maxIterations = i;
        this.maxIterationsSet = true;
    }

    /**
     * Stores a clone of {@code matrix} as the seed for {@link #balance()}.
     *
     * <p>Because the seed is cloned, the caller's matrix is not modified.
     *
     * @param matrix seed matrix
     */
    public void setSeed(Matrix matrix) {
        this.seed = (Matrix) matrix.clone();
    }

    /**
     * Stores the targets used by {@link #balance()}. The vectors are kept by reference.
     *
     * @param rowTargets target sum for each row, as a column vector
     * @param columnTargets target sum for each column, as a row vector
     */
    public void setTargets(ColumnVector rowTargets, RowVector columnTargets) {
        this.rowTargets = rowTargets;
        this.columnTargets = columnTargets;
    }

    /**
     * Balances the stored seed to the stored targets.
     *
     * <p>The result is available from {@link #getBalancedMatrix()}.
     *
     * @throws MatrixException under the same conditions as
     *     {@link #balance(Matrix, ColumnVector, RowVector)}
     */
    public void balance() {
        this.seed = balance(this.seed, this.rowTargets, this.columnTargets);
    }

    /**
     * Balances {@code matrix} in place to the given targets using iterative proportional fitting.
     *
     * <p>NOTE(review): if only error criteria are set and they are never met (for example, targets
     * that the seed's zero pattern cannot satisfy), this loop does not terminate. Setting a maximum
     * number of iterations guarantees termination.
     *
     * @param matrix matrix to balance; its cells are modified in place
     * @param rowTargets target sum for each row, looked up by external row number
     * @param columnTargets target sum for each column, looked up by external column number
     * @return {@code matrix}, after balancing
     * @throws MatrixException if any of the following holds:
     *     the row and column target totals differ by more than the maximum relative error;
     *     no closure criterion has been set;
     *     a seed row or column sums to zero while its target is non-zero
     */
    public Matrix balance(Matrix matrix, ColumnVector rowTargets, RowVector columnTargets) {
        checkTargetTotals(rowTargets, columnTargets);
        checkClosureSet();
        this.iteration = 1;
        logger.debug("Beginning matrix balancing using iterative proportional fitting.");
        while (!isClosed()) {
            logger.debug("Iteration " + this.iteration);
            matrix = scaleColumnsToTargets(scaleRowsToTargets(matrix, rowTargets), columnTargets);
            // Columns match exactly after the column pass, so only row drift is measured.
            computeErrors(matrix, rowTargets);
            logger.debug("Balancing relative error: " + this.relativeError + ", absolute error: " + this.absoluteError);
            this.iteration++;
        }
        return matrix;
    }

    /**
     * Verifies that at least one closure criterion is set, and logs which ones are active.
     *
     * @throws MatrixException if no closure criterion has been set
     */
    private void checkClosureSet() {
        if (!this.maxIterationsSet && !this.maxRelativeErrorSet && !this.maxAbsoluteErrorSet) {
            logger.error("No closure criteria set.");
            throw new MatrixException("No closure criteria set.");
        }
        if (this.maxIterationsSet) {
            logger.info("Maximum number of iterations set to " + this.maxIterations);
        } else {
            logger.warn("No maximum number of iterations set.");
        }
        if (this.maxRelativeErrorSet) {
            logger.info("Minimum relative error set to " + this.maxRelativeError);
        } else {
            logger.warn("Minimum relative error not set.");
        }
        if (this.maxAbsoluteErrorSet) {
            logger.info("Minimum absolute error set to " + this.maxAbsoluteError);
        } else {
            logger.warn("Minimum absolute error not set.");
        }
    }

    /**
     * Reports whether any closure criterion is met, logging a summary when it is.
     *
     * @return {@code true} if balancing should stop
     */
    private boolean isClosed() {
        boolean closed = false;
        if (this.maxIterationsSet && this.iteration > this.maxIterations) {
            logger.info("Reached iteration maximum.");
            closed = true;
        }
        // Error values are only meaningful once at least one iteration has computed them.
        if (this.iteration > 1) {
            if (this.maxRelativeErrorSet && this.relativeError < this.maxRelativeError) {
                logger.info("Reached minimum relative error.");
                closed = true;
            }
            if (this.maxAbsoluteErrorSet && this.absoluteError < this.maxAbsoluteError) {
                logger.info("Reached minimum absolute error.");
                closed = true;
            }
        }
        if (closed) {
            // iteration names the iteration that would run next, so one fewer has completed.
            logger.info("Closed in " + (this.iteration - 1) + " iterations.");
            logger.info("Final relative error: " + this.relativeError);
            logger.info("Final absolute error: " + this.absoluteError);
        }
        return closed;
    }

    /**
     * Verifies that the row and column target totals agree within the maximum relative error.
     *
     * <p>The tolerance is 0 when no relative error has been set.
     *
     * @throws MatrixException if the totals disagree by more than the tolerance
     */
    private void checkTargetTotals(ColumnVector rowTargets, RowVector columnTargets) {
        double rowTotal = rowTargets.getSum();
        double columnTotal = columnTargets.getSum();
        if (relativeDifference(rowTotal, columnTotal) > this.maxRelativeError) {
            String message = "Row targets sum (" + rowTotal + ") does not match column target sum (" + columnTotal + ")";
            logger.error(message);
            throw new MatrixException(message);
        }
    }

    /**
     * Returns {@code |a - b|} relative to the smaller of the two values.
     *
     * <p>The result is infinite or NaN when that smaller value is 0.
     */
    private static double relativeDifference(double a, double b) {
        return Math.abs(a - b) / Math.min(a, b);
    }

    /**
     * Scales each row of {@code matrix} in place so that its sum equals its target.
     *
     * @return {@code matrix}
     * @throws MatrixException if a row sums to zero while its target is non-zero
     */
    private Matrix scaleRowsToTargets(Matrix matrix, ColumnVector rowTargets) {
        int[] rowNumbers = matrix.getExternalRowNumbers();
        boolean debug = logger.isDebugEnabled();
        // Position 0 is skipped; presumably the external-number arrays are 1-based. TODO confirm in Matrix.
        for (int i = 1; i < rowNumbers.length; i++) {
            int row = rowNumbers[i];
            float rowSum = matrix.getRowSum(row);
            float target = rowTargets.getValueAt(row);
            float factor;
            if (rowSum != 0.0f) {
                factor = target / rowSum;
            } else if (target == 0.0f) {
                factor = 0.0f;
            } else {
                String message = "Row " + row + ":  Seed row adds to 0 but target is NOT zero.";
                logger.error(message);
                throw new MatrixException(message);
            }
            if (debug) {
                logger.debug("Scaling factor for row " + row + ": " + factor);
            }
            scaleRow(matrix, row, factor);
        }
        return matrix;
    }

    /** Multiplies every cell in external row {@code row} of {@code matrix} by {@code factor}. */
    private void scaleRow(Matrix matrix, int row, float factor) {
        int[] columnNumbers = matrix.getExternalColumnNumbers();
        for (int j = 1; j < columnNumbers.length; j++) {
            int column = columnNumbers[j];
            matrix.setValueAt(row, column, matrix.getValueAt(row, column) * factor);
        }
    }

    /**
     * Scales each column of {@code matrix} in place so that its sum equals its target.
     *
     * @return {@code matrix}
     * @throws MatrixException if a column sums to zero while its target is non-zero
     */
    private Matrix scaleColumnsToTargets(Matrix matrix, RowVector columnTargets) {
        int[] columnNumbers = matrix.getExternalColumnNumbers();
        boolean debug = logger.isDebugEnabled();
        for (int j = 1; j < columnNumbers.length; j++) {
            int column = columnNumbers[j];
            float columnSum = matrix.getColumnSum(column);
            float target = columnTargets.getValueAt(column);
            float factor;
            if (columnSum != 0.0f) {
                factor = target / columnSum;
            } else if (target == 0.0f) {
                factor = 0.0f;
            } else {
                String message = "Column " + column + ":  Seed column adds to 0 but target is NOT zero.";
                logger.error(message);
                throw new MatrixException(message);
            }
            if (debug) {
                logger.debug("Scaling factor for column " + column + ": " + factor);
            }
            scaleColumn(matrix, column, factor);
        }
        return matrix;
    }

    /** Multiplies every cell in external column {@code column} of {@code matrix} by {@code factor}. */
    private void scaleColumn(Matrix matrix, int column, float factor) {
        int[] rowNumbers = matrix.getExternalRowNumbers();
        for (int i = 1; i < rowNumbers.length; i++) {
            int row = rowNumbers[i];
            matrix.setValueAt(row, column, matrix.getValueAt(row, column) * factor);
        }
    }

    /**
     * Recomputes the largest relative and absolute differences between each row sum and its
     * target, storing them in {@link #relativeError} and {@link #absoluteError}.
     */
    private void computeErrors(Matrix matrix, ColumnVector rowTargets) {
        int[] rowNumbers = matrix.getExternalRowNumbers();
        boolean debug = logger.isDebugEnabled();
        this.relativeError = 0.0d;
        this.absoluteError = 0.0d;
        for (int i = 1; i < rowNumbers.length; i++) {
            int row = rowNumbers[i];
            float target = rowTargets.getValueAt(row);
            float rowSum = matrix.getRowSum(row);
            // Two zeros agree exactly; relativeDifference would otherwise give 0/0 = NaN.
            double rowRelative = (target == 0.0f && rowSum == 0.0f) ? 0.0d : relativeDifference(target, rowSum);
            double rowAbsolute = Math.abs(target - rowSum);
            this.relativeError = Math.max(rowRelative, this.relativeError);
            this.absoluteError = Math.max(rowAbsolute, this.absoluteError);
            if (debug) {
                logger.debug("Relative error on row " + row + ": " + rowRelative + ", absolute error: " + rowAbsolute);
            }
        }
    }

    /** Returns the largest relative row-sum error from the most recent iteration. */
    public double getRelativeError() {
        return this.relativeError;
    }

    /** Returns the largest absolute row-sum error from the most recent iteration. */
    public double getAbsoluteError() {
        return this.absoluteError;
    }

    /**
     * Returns the iteration counter.
     *
     * <p>After a balance call returns, this is one greater than the number of completed
     * iterations.
     */
    public int getIteration() {
        return this.iteration;
    }

    /**
     * Returns the stored seed.
     *
     * <p>After {@link #balance()}, this is the balanced matrix. The internal reference is
     * returned, not a copy.
     */
    public Matrix getBalancedMatrix() {
        return this.seed;
    }

    /** Scales the row targets in place so that their total equals the column-target total. */
    public static void scaleRowTargets(ColumnVector rowTargets, RowVector columnTargets) {
        rowTargets.scale((float) (columnTargets.getSum() / rowTargets.getSum()));
    }

    /** Scales the column targets in place so that their total equals the row-target total. */
    public static void scaleColumnTargets(ColumnVector rowTargets, RowVector columnTargets) {
        columnTargets.scale((float) (rowTargets.getSum() / columnTargets.getSum()));
    }

    /** Scales both target vectors in place so that each totals the average of the two original totals. */
    public static void scaleTargetsToAvg(ColumnVector rowTargets, RowVector columnTargets) {
        float average = (float) ((rowTargets.getSum() + columnTargets.getSum()) / 2.0d);
        columnTargets.scale((float) (average / columnTargets.getSum()));
        rowTargets.scale((float) (average / rowTargets.getSum()));
    }
}
