/*
 * Decompiled with CFR 0.152.
 */
package fmsim.inference;

import fmsim.inference.SMCScheme;
import fmsim.model.Calculations;
import fmsim.model.InferenceListener;
import fmsim.model.Protocol;
import fmsim.model.ProtocolEvent;
import fmsim.model.RateChangeEvent;
import fmsim.model.Rates;
import fmsim.model.VesicleModel;
import fmsim.model.VesicleModelState;
import fmsim.observations.Observations;
import java.util.ArrayList;
import java.util.List;
import org.apache.commons.math3.util.FastMath;

/**
 * Searches for values of the {@code INFERRED} rates of a {@link Protocol} with a
 * Metropolis-Hastings style random walk.
 *
 * <p>Each iteration replaces exactly one inferred rate value with a proposal from
 * {@link Calculations#getProposalParameter}. The inferred rates are visited cyclically in a fixed
 * order: event rates in iteration order first, then the protocol's default rates. The modified
 * protocol is scored with a sequential Monte Carlo estimate ({@link SMCScheme}) against the
 * observations and then accepted or rejected. The SMC values are treated as log marginal
 * likelihoods: they are summed across rate segments and compared through
 * {@code exp(proposed - current)}.
 *
 * <p>{@link #runOverall} resets and mutates per-run instance fields, so a single instance should
 * not run several searches concurrently.
 */
public class InferenceOverallModel {
    /** Number of INFERRED rate values in the protocol being searched (set by runOverall). */
    private int inferredRateTotal;
    /** Position, in event-then-default order, of the inferred rate the next proposal changes. */
    private int currentInferredRate;
    private final List<InferenceListener> listeners = new ArrayList<InferenceListener>();

    /**
     * Runs the search for a fixed number of iterations.
     *
     * @param originalProtocol protocol to start from. The search works on clones obtained via
     *     {@code clone()}. NOTE(review): whether those clones share {@code Rates} objects with
     *     the original depends on {@code Protocol.clone()}; confirm.
     * @param observations data the SMC likelihood estimate is computed against
     * @param particleCount number of SMC particles used per estimate
     * @param iterations number of proposals to evaluate; a value {@code <= 0} evaluates none
     * @return {@code originalProtocol} itself when it has no INFERRED rates. Otherwise, a clone
     *     whose experiment name has {@code "-inferred"} appended and which holds the most recently
     *     accepted rate values.
     */
    public Protocol runOverall(Protocol originalProtocol, Observations observations, int particleCount, int iterations) {
        if (!this.hasInferredRates(originalProtocol)) {
            return originalProtocol;
        }
        this.inferredRateTotal = this.countInferredRateTotal(originalProtocol);
        this.currentInferredRate = 0;
        Protocol protocol = originalProtocol.clone();
        protocol.setExperimentName(String.valueOf(protocol.getExperimentName()) + "-inferred");
        // -infinity means "nothing scored yet": the first usable proposal is always accepted.
        double marginalY = Double.NEGATIVE_INFINITY;
        for (int i = 1; i <= iterations; i++) {
            Protocol proposalProtocol = protocol.clone();
            this.setProposalParameters(proposalProtocol);
            // Score the proposal itself: the accept/reject decision must depend on the
            // proposed rate values, not on the current state's.
            double newMarginalY = estimateLogMarginal(proposalProtocol, observations, particleCount);
            this.fireCurrentIteration(i, newMarginalY);
            if (acceptProposal(newMarginalY, marginalY)) {
                System.out.println("Accepted: " + newMarginalY + ", previous: " + marginalY);
                marginalY = newMarginalY;
                protocol = proposalProtocol;
                this.fireCandidateFound(i, protocol, marginalY);
                Thread.yield();
            }
        }
        return protocol;
    }

    /**
     * Estimates the log marginal likelihood of {@code observations} under {@code protocol}.
     *
     * <p>One {@link SMCScheme} is run per rate-change event. Each segment runs from the event's
     * time to the next event's time, capped at {@code protocol.getEndTime()}; the last segment
     * ends at that end time. The particle states produced by one segment seed the next. The
     * initial states come from a {@link VesicleModel} configured with the first event's rates.
     * Non-finite segment estimates are printed to stderr and still added to the total.
     */
    private static double estimateLogMarginal(Protocol protocol, Observations observations, int particleCount) {
        List<RateChangeEvent> rateEvents = protocol.getRateChangeEvents();
        VesicleModel model = new VesicleModel(protocol.getConfiguration());
        model.setRates(rateEvents.get(0).rates);
        VesicleModelState[] sampleStates = model.getInitialSampleStates(particleCount);
        double simulationEndTime = protocol.getEndTime();
        double logMarginal = 0.0;
        for (int k = 0; k < rateEvents.size(); k++) {
            RateChangeEvent segment = rateEvents.get(k);
            double startTime = segment.time;
            double endTime = simulationEndTime;
            if (k < rateEvents.size() - 1) {
                endTime = Math.min(endTime, rateEvents.get(k + 1).time);
            }
            model.setRates(segment.rates);
            Observations segmentObservations = observations.getSubset(startTime, endTime);
            SMCScheme smcScheme = new SMCScheme(segmentObservations, particleCount, startTime, endTime, protocol.getConfiguration());
            smcScheme.run(segment.rates, sampleStates);
            double segmentLogMarginal = smcScheme.marginalLikelihood;
            if (Double.isNaN(segmentLogMarginal) || Double.isInfinite(segmentLogMarginal)) {
                System.err.println(segmentLogMarginal);
            }
            logMarginal += segmentLogMarginal;
            sampleStates = smcScheme.sampleStates;
        }
        return logMarginal;
    }

    /**
     * Metropolis acceptance test on log marginal likelihoods.
     *
     * <p>A proposal is accepted outright in two cases: no state has been scored yet
     * ({@code current} is -infinity), or the proposal is at least as likely as the current state.
     * Otherwise it is accepted with probability {@code exp(proposed - current)}. NaN and
     * +infinity estimates are always rejected. Once stored as the current value, they would make
     * every later ratio NaN or zero, so the chain could never move again.
     */
    private static boolean acceptProposal(double proposed, double current) {
        if (Double.isNaN(proposed) || proposed == Double.POSITIVE_INFINITY) {
            return false;
        }
        if (current == Double.NEGATIVE_INFINITY) {
            return true;
        }
        double acceptance = FastMath.exp(proposed - current);
        return acceptance >= 1.0 || FastMath.random() < acceptance;
    }

    /** Counts INFERRED values across all event rates and the protocol's default rates. */
    private int countInferredRateTotal(Protocol protocol) {
        int total = 0;
        for (ProtocolEvent event : protocol.events) {
            total += countInferred(event.getRates().values);
        }
        return total + countInferred(protocol.getDefaultRates().values);
    }

    /** Counts the INFERRED entries of one rate-value array. */
    private static int countInferred(Rates.RateValue[] values) {
        int count = 0;
        for (Rates.RateValue value : values) {
            if (value.type == Rates.RateValue.Type.INFERRED) {
                ++count;
            }
        }
        return count;
    }

    /**
     * Replaces the {@code currentInferredRate}-th INFERRED value of {@code protocol} (0-based)
     * with {@link Calculations#getProposalParameter}, then advances the cursor cyclically. Values
     * are counted over event rates in iteration order first, then over the default rates.
     */
    private void setProposalParameters(Protocol protocol) {
        int remaining = this.currentInferredRate;
        for (ProtocolEvent event : protocol.events) {
            remaining = proposeInArray(event.getRates().values, remaining);
            if (remaining < 0) {
                break;
            }
        }
        if (remaining >= 0) {
            proposeInArray(protocol.getDefaultRates().values, remaining);
        }
        this.currentInferredRate = (this.currentInferredRate + 1) % this.inferredRateTotal;
    }

    /**
     * Skips {@code skip} INFERRED entries of {@code values}, then replaces the next one with a
     * proposal.
     *
     * @return -1 if a value was replaced; otherwise the number of INFERRED entries still to skip
     */
    private static int proposeInArray(Rates.RateValue[] values, int skip) {
        for (int i = 0; i < values.length; i++) {
            if (values[i].type != Rates.RateValue.Type.INFERRED) {
                continue;
            }
            if (skip == 0) {
                values[i] = Calculations.getProposalParameter(values[i]);
                return -1;
            }
            --skip;
        }
        return skip;
    }

    public void addInferenceListener(InferenceListener listener) {
        this.listeners.add(listener);
    }

    public void removeInferenceListener(InferenceListener listener) {
        this.listeners.remove(listener);
    }

    /** Notifies every listener of the iteration number and the proposal's estimate. */
    private void fireCurrentIteration(int iteration, double marginal) {
        for (InferenceListener listener : this.listeners) {
            listener.setCurrentIteration(iteration, marginal);
        }
    }

    /** Notifies every listener that a proposal was accepted. */
    private void fireCandidateFound(int iteration, Protocol protocol, double marginal) {
        for (InferenceListener listener : this.listeners) {
            listener.candidateFound(iteration, protocol, marginal);
        }
    }

    /**
     * Returns whether any event or default rate is INFERRED. This uses the same counting rule as
     * {@link #countInferredRateTotal}, so a {@code true} result guarantees a non-zero total.
     */
    private boolean hasInferredRates(Protocol protocol) {
        return this.countInferredRateTotal(protocol) > 0;
    }
}

