package fmsim.inference;

import fmsim.model.Calculations;
import fmsim.model.InferenceListener;
import fmsim.model.Protocol;
import fmsim.model.ProtocolEvent;
import fmsim.model.RateChangeEvent;
import fmsim.model.Rates;
import fmsim.model.VesicleModel;
import fmsim.model.VesicleModelState;
import fmsim.observations.Observations;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.apache.commons.math3.util.FastMath;

/* loaded from: input_file:fmsim/inference/InferenceOverallModel.class */
/**
 * Searches for values of the {@code INFERRED} rates of a {@link Protocol} with a Metropolis-style
 * accept/reject loop.
 *
 * <p>Each iteration does three things:
 * <ul>
 *   <li>clones the currently accepted protocol;</li>
 *   <li>replaces exactly one inferred rate with {@link Calculations#getProposalParameter}, cycling
 *       through the inferred rates in a fixed order;</li>
 *   <li>scores the clone with a sequential Monte Carlo likelihood estimate and accepts it according
 *       to the likelihood ratio against the currently accepted protocol.</li>
 * </ul>
 *
 * <p>NOTE(review): a plain Metropolis ratio is only correct if {@code getProposalParameter} is a
 * symmetric proposal. That cannot be confirmed from this file.
 *
 * <p>Listeners are notified after every iteration and whenever a proposal is accepted. The
 * listener list is a plain {@link ArrayList}, and this class performs no synchronization of its
 * own.
 */
public class InferenceOverallModel {
    /** Number of INFERRED rate slots across all protocol events plus the default rates. */
    private int inferredRateTotal;

    /**
     * Position of the inferred slot that the next proposal perturbs. Slots are counted 0-based,
     * event rates first (in {@code protocol.events} order), then the default rates.
     */
    private int currentInferredRate;

    private final List<InferenceListener> listeners = new ArrayList<>();

    /**
     * Runs the inference loop.
     *
     * @param protocol the starting protocol. It is cloned before any proposal is applied (assumes
     *     {@code m12clone()} copies the rate arrays; TODO confirm).
     * @param observations data that {@link SMCScheme} scores each candidate against
     * @param sampleCount number of samples passed to {@code getInitialSampleStates} and to every
     *     {@link SMCScheme}
     * @param iterations number of proposals to evaluate
     * @return {@code protocol} itself if it has no INFERRED rates. Otherwise, a clone whose
     *     experiment name carries an {@code "-inferred"} suffix, holding the last accepted rates.
     *     This is the unmodified starting clone if nothing was accepted.
     */
    public Protocol runOverall(
            Protocol protocol, Observations observations, int sampleCount, int iterations) {
        if (!hasInferredRates(protocol)) {
            return protocol;
        }
        this.inferredRateTotal = countInferredRateTotal(protocol);
        this.currentInferredRate = 0;

        Protocol current = protocol.m12clone();
        current.setExperimentName(String.valueOf(current.getExperimentName()) + "-inferred");
        // Nothing has been scored yet, so the first proposal with a non-NaN score is accepted.
        double currentLogLikelihood = Double.NEGATIVE_INFINITY;

        for (int iteration = 1; iteration <= iterations; iteration++) {
            Protocol proposal = current.m12clone();
            setProposalParameters(proposal);
            // Score the proposal, not the current protocol. The accept/reject decision must
            // reflect the perturbed rates that are adopted on acceptance.
            double proposalLogLikelihood =
                    evaluateLogLikelihood(proposal, observations, sampleCount);
            fireCurrentIteration(iteration, proposalLogLikelihood);
            if (accept(proposalLogLikelihood, currentLogLikelihood)) {
                System.out.println(
                        "Accepted: " + proposalLogLikelihood + ", previous: " + currentLogLikelihood);
                currentLogLikelihood = proposalLogLikelihood;
                current = proposal;
                fireCandidateFound(iteration, current, currentLogLikelihood);
                Thread.yield();
            }
        }
        return current;
    }

    /**
     * Estimates the likelihood of {@code candidate} over the whole experiment.
     *
     * <p>The time axis is split at every rate-change event. Each segment runs from its event's
     * time to the next event's time, capped at the protocol end time; the last segment ends at the
     * end time. One {@link SMCScheme} runs per segment, seeded with the sample states left by the
     * previous segment. The per-segment {@code marginalLikelihood} values are summed, so they are
     * treated as log-likelihoods.
     *
     * <p>Assumes {@code getRateChangeEvents()} is non-empty (TODO confirm). Its first element is
     * read unconditionally.
     */
    private static double evaluateLogLikelihood(
            Protocol candidate, Observations observations, int sampleCount) {
        List<RateChangeEvent> rateChangeEvents = candidate.getRateChangeEvents();
        VesicleModel vesicleModel = new VesicleModel(candidate.getConfiguration());
        // The first event's rates are used to draw the initial sample states.
        vesicleModel.setRates(rateChangeEvents.get(0).rates);
        VesicleModelState[] states = vesicleModel.getInitialSampleStates(sampleCount);
        double endTime = candidate.getEndTime();

        double logLikelihood = 0.0d;
        int eventCount = rateChangeEvents.size();
        for (int k = 0; k < eventCount; k++) {
            RateChangeEvent event = rateChangeEvents.get(k);
            double segmentStart = event.time;
            double segmentEnd = endTime;
            if (k < eventCount - 1) {
                segmentEnd = Math.min(segmentEnd, rateChangeEvents.get(k + 1).time);
            }
            vesicleModel.setRates(event.rates);
            SMCScheme scheme = new SMCScheme(
                    observations.getSubset(segmentStart, segmentEnd),
                    sampleCount,
                    segmentStart,
                    segmentEnd,
                    candidate.getConfiguration());
            scheme.run(event.rates, states);
            if (Double.isNaN(scheme.marginalLikelihood)
                    || Double.isInfinite(scheme.marginalLikelihood)) {
                System.err.println(scheme.marginalLikelihood);
            }
            logLikelihood += scheme.marginalLikelihood;
            // Carry the particle states forward into the next segment.
            states = scheme.sampleStates;
        }
        return logLikelihood;
    }

    /**
     * Applies the Metropolis acceptance rule to two log-likelihoods.
     *
     * <p>A NaN proposal is always rejected. Accepting it would set the current score to NaN.
     * Every later ratio would then be NaN, so no further proposal could ever be accepted.
     *
     * @return {@code true} to accept the proposal
     */
    private static boolean accept(double proposedLogLikelihood, double currentLogLikelihood) {
        if (Double.isNaN(proposedLogLikelihood)) {
            return false;
        }
        if (currentLogLikelihood == Double.NEGATIVE_INFINITY) {
            // Nothing scored yet (or the current state is impossible): take any valid proposal.
            return true;
        }
        double ratio = FastMath.exp(proposedLogLikelihood - currentLogLikelihood);
        return ratio >= 1.0d || FastMath.random() < ratio;
    }

    /** Counts INFERRED slots over every event's rates plus the protocol's default rates. */
    private int countInferredRateTotal(Protocol protocol) {
        int total = 0;
        for (ProtocolEvent protocolEvent : protocol.events) {
            total += countInferred(protocolEvent.getRates().values);
        }
        return total + countInferred(protocol.getDefaultRates().values);
    }

    /**
     * Counts INFERRED entries in {@code values}. Every entry is scanned, so this count agrees
     * with {@link #hasInferredRates} and with the slot indexing in
     * {@link #setProposalParameters}.
     */
    private static int countInferred(Rates.RateValue[] values) {
        int count = 0;
        for (Rates.RateValue value : values) {
            if (value.type == Rates.RateValue.Type.INFERRED) {
                count++;
            }
        }
        return count;
    }

    /**
     * Replaces the single inferred slot at position {@link #currentInferredRate} with a proposal
     * drawn by {@link Calculations#getProposalParameter}.
     *
     * <p>Slots are counted over event rates first, then the default rates. Afterwards the method
     * advances the cursor cyclically. Requires {@code inferredRateTotal > 0}, which
     * {@link #runOverall} guarantees.
     */
    private void setProposalParameters(Protocol protocol) {
        int remaining = this.currentInferredRate;
        for (ProtocolEvent protocolEvent : protocol.events) {
            remaining = perturbNthInferred(protocolEvent.getRates().values, remaining);
            if (remaining < 0) {
                break;
            }
        }
        if (remaining >= 0) {
            perturbNthInferred(protocol.getDefaultRates().values, remaining);
        }
        this.currentInferredRate = (this.currentInferredRate + 1) % this.inferredRateTotal;
    }

    /**
     * Replaces the {@code n}-th (0-based) INFERRED entry of {@code values} in place with a
     * proposal.
     *
     * @return -1 if an entry was replaced. Otherwise {@code n} minus the number of INFERRED
     *     entries in {@code values}, i.e. the index still to find in later arrays.
     */
    private static int perturbNthInferred(Rates.RateValue[] values, int n) {
        for (int k = 0; k < values.length; k++) {
            if (values[k].type == Rates.RateValue.Type.INFERRED) {
                if (n == 0) {
                    values[k] = Calculations.getProposalParameter(values[k]);
                    return -1;
                }
                n--;
            }
        }
        return n;
    }

    public void addInferenceListener(InferenceListener inferenceListener) {
        this.listeners.add(inferenceListener);
    }

    public void removeInferenceListener(InferenceListener inferenceListener) {
        this.listeners.remove(inferenceListener);
    }

    /** Reports the score of the proposal evaluated in {@code iteration} to every listener. */
    private void fireCurrentIteration(int iteration, double logLikelihood) {
        for (InferenceListener listener : this.listeners) {
            listener.setCurrentIteration(iteration, logLikelihood);
        }
    }

    /** Reports a newly accepted protocol and its score to every listener. */
    private void fireCandidateFound(int iteration, Protocol protocol, double logLikelihood) {
        for (InferenceListener listener : this.listeners) {
            listener.candidateFound(iteration, protocol, logLikelihood);
        }
    }

    /** Returns whether any event rate or default rate is marked INFERRED. */
    private boolean hasInferredRates(Protocol protocol) {
        // Derived from the same count used for the modulo in setProposalParameters, so a
        // positive answer always implies a non-zero inferredRateTotal.
        return countInferredRateTotal(protocol) > 0;
    }
}
