package cc.mallet.grmm.inference.gbp;

import cc.mallet.grmm.types.Factor;
import cc.mallet.grmm.types.FactorGraph;
import cc.mallet.grmm.types.UndirectedGrid;
import cc.mallet.grmm.types.Variable;
import cc.mallet.util.CollectionUtils;
import cc.mallet.util.MalletLogger;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.ListIterator;
import java.util.logging.Logger;

/* loaded from: input_file:WEB-INF/lib/mallet-2.0.7.jar:cc/mallet/grmm/inference/gbp/ClusterVariationalRegionGenerator.class */
/**
 * Builds a region graph using the cluster variation method.
 * <p>
 * A {@link BaseRegionComputer} supplies the top-level (base) regions. The generator then repeatedly
 * intersects the regions of the current level to form the next level of overlap regions. It links
 * every region to each overlap whose variables it contains, and stops once a level yields no
 * overlaps.
 */
public class ClusterVariationalRegionGenerator implements RegionGraphGenerator {
    private static final Logger logger = MalletLogger.getLogger(ClusterVariationalRegionGenerator.class.getName());
    private final BaseRegionComputer regionComputer;

    /** Strategy that chooses the base (largest) regions of the region graph. */
    public interface BaseRegionComputer {
        /**
         * Computes the base regions for a factor graph.
         *
         * @param factorGraph the model whose variables the regions should cover
         * @return a list of {@link Region} objects forming the top level of the region graph
         */
        List computeBaseRegions(FactorGraph factorGraph);
    }

    /**
     * Creates one base region per factor and drops regions whose variables are contained in another
     * region's variables. It then attaches to each surviving region every factor whose variables it
     * covers.
     */
    public static class ByFactorRegionComputer implements BaseRegionComputer {
        @Override
        public List computeBaseRegions(FactorGraph factorGraph) {
            List regions = new ArrayList(factorGraph.factors().size());
            for (Iterator it = factorGraph.factorsIterator(); it.hasNext();) {
                regions.add(new Region((Factor) it.next()));
            }
            removeSubsumedRegions(regions);
            addAllFactors(factorGraph, regions);
            return regions;
        }
    }

    /**
     * Uses every 2x2 square of an {@link UndirectedGrid} as a base region. It then attaches to each
     * square every factor whose variables it covers.
     * <p>
     * The factor graph must be an {@code UndirectedGrid}; otherwise the cast throws
     * {@link ClassCastException}.
     */
    public static class Grid2x2RegionComputer implements BaseRegionComputer {
        @Override
        public List computeBaseRegions(FactorGraph factorGraph) {
            List regions = new ArrayList();
            UndirectedGrid grid = (UndirectedGrid) factorGraph;
            for (int x = 0; x < grid.getWidth() - 1; x++) {
                for (int y = 0; y < grid.getHeight() - 1; y++) {
                    Variable[] square = {
                        grid.get(x, y), grid.get(x, y + 1), grid.get(x + 1, y + 1), grid.get(x + 1, y)
                    };
                    regions.add(new Region(square, new Factor[0]));
                }
            }
            addAllFactors(factorGraph, regions);
            return regions;
        }
    }

    /** Creates a generator whose base regions are the (non-subsumed) factors of the model. */
    public ClusterVariationalRegionGenerator() {
        this(new ByFactorRegionComputer());
    }

    /**
     * Creates a generator with a custom base-region strategy.
     *
     * @param baseRegionComputer supplies the top level of the region graph
     */
    public ClusterVariationalRegionGenerator(BaseRegionComputer baseRegionComputer) {
        this.regionComputer = baseRegionComputer;
    }

    /**
     * Builds the region graph for {@code factorGraph}.
     * <p>
     * Starting from the base regions, each level of overlaps is computed from the previous level.
     * Every region of a level is linked to each overlap of the next level whose variables it
     * contains. Construction stops when a level produces no overlaps. Overlap levels contain no
     * nested regions, so overlap size strictly decreases from level to level and the loop
     * terminates.
     *
     * @param factorGraph the model to build regions for
     * @return the finished region graph, after {@code computeInferenceCaches()} has been called
     */
    @Override
    public RegionGraph constructRegionGraph(FactorGraph factorGraph) {
        RegionGraph regionGraph = new RegionGraph();
        @SuppressWarnings("unchecked") // BaseRegionComputer's contract is a raw List of Region objects
        List<Region> level = regionComputer.computeBaseRegions(factorGraph);
        while (!level.isEmpty()) {
            List<Region> overlaps = computeOverlaps(level);
            addEdgesForOverlaps(regionGraph, level, overlaps);
            level = overlaps;
        }
        regionGraph.computeInferenceCaches();
        logger.info("ClusterVariationalRegionGenerator: Number of regions " + regionGraph.size() + " Number of edges:" + regionGraph.numEdges());
        return regionGraph;
    }

    /**
     * Computes the next level of overlap regions from {@code regions}.
     * <p>
     * Every non-empty intersection of the variable sets of two distinct regions becomes a candidate
     * overlap, holding the factors the two regions share. Only maximal candidates are kept: a
     * candidate is dropped when another candidate's variables contain its own.
     *
     * @param regions the regions of the current level
     * @return the maximal overlaps; possibly empty, never {@code null}
     */
    private List<Region> computeOverlaps(List<Region> regions) {
        Region[] level = regions.toArray(new Region[regions.size()]);
        List<Region> overlaps = new ArrayList<Region>();
        for (int i = 0; i < level.length; i++) {
            // Intersection is symmetric. The reversed pair (j, i) is always rejected by the
            // anySubsumes check below (its intersection is already present or already subsumed),
            // so visiting each unordered pair once yields the same overlaps in the same order.
            for (int j = i + 1; j < level.length; j++) {
                Region first = level[i];
                Region second = level[j];
                if (first == second) {
                    continue; // same object listed twice: no overlap with itself
                }
                Collection sharedVars = CollectionUtils.intersection(first.vars, second.vars);
                if (sharedVars.isEmpty() || anySubsumes(overlaps, sharedVars)) {
                    continue;
                }
                Collection sharedFactors = CollectionUtils.intersection(first.factors, second.factors);
                overlaps.add(new Region(
                        (Variable[]) sharedVars.toArray(new Variable[sharedVars.size()]),
                        (Factor[]) sharedFactors.toArray(new Factor[sharedFactors.size()])));
            }
        }

        // Drop every overlap whose variables are contained in a later overlap. Earlier overlaps
        // never contain later ones (the check above rejects such candidates), so looking ahead
        // is sufficient.
        for (ListIterator<Region> it = overlaps.listIterator(); it.hasNext();) {
            // Advance BEFORE reading nextIndex(), so the sublist excludes the candidate itself.
            // Otherwise every region trivially "subsumes" itself and all overlaps are removed.
            Region candidate = it.next();
            if (anySubsumes(overlaps.subList(it.nextIndex(), overlaps.size()), candidate.vars)) {
                it.remove();
            }
        }
        return overlaps;
    }

    /**
     * Reports whether some region's variables contain all of {@code vars}.
     *
     * @param regions regions to search
     * @param vars variables that must all be present in a single region
     * @return {@code true} if at least one region contains every element of {@code vars}
     */
    private boolean anySubsumes(List<Region> regions, Collection vars) {
        for (Region region : regions) {
            if (region.vars.containsAll(vars)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Adds an edge from each region of one level to every overlap of the next level whose variables
     * it contains.
     *
     * @param regionGraph graph receiving the edges
     * @param parents regions of the current level
     * @param overlaps regions of the next level
     */
    private void addEdgesForOverlaps(RegionGraph regionGraph, List<Region> parents, List<Region> overlaps) {
        for (Region parent : parents) {
            for (Region child : overlaps) {
                if (parent.vars.containsAll(child.vars)) {
                    regionGraph.add(parent, child);
                }
            }
        }
    }

    /**
     * Removes, in place, every region whose variables are contained in another region's variables.
     * <p>
     * Regions are examined in list order, and each is removed as soon as a subsuming region is found
     * among those still in the list. Consequently, when several regions share identical variable
     * sets and no larger region contains them, only the last of them is kept.
     *
     * @param list a mutable list of {@link Region} objects whose list iterator supports {@code remove}
     */
    public static void removeSubsumedRegions(List list) {
        for (ListIterator outer = list.listIterator(); outer.hasNext();) {
            Region region = (Region) outer.next();
            if (isSubsumedByOther(list, region)) {
                outer.remove();
            }
        }
    }

    /** Reports whether a region other than {@code region} in {@code list} contains all of its variables. */
    private static boolean isSubsumedByOther(List list, Region region) {
        for (Iterator it = list.iterator(); it.hasNext();) {
            Region other = (Region) it.next();
            if (other != region
                    && other.vars.size() >= region.vars.size()
                    && other.vars.containsAll(region.vars)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Adds to each region every factor of {@code factorGraph} whose variables all belong to that
     * region.
     *
     * @param factorGraph source of the factors
     * @param list the {@link Region} objects to populate; each region's {@code factors} collection
     *     is modified
     */
    public static void addAllFactors(FactorGraph factorGraph, List list) {
        for (Iterator regions = list.iterator(); regions.hasNext();) {
            Region region = (Region) regions.next();
            for (Iterator factors = factorGraph.factorsIterator(); factors.hasNext();) {
                Factor factor = (Factor) factors.next();
                if (region.vars.containsAll(factor.varSet())) {
                    // NOTE(review): regions built with new Region(Factor) may already hold this
                    // factor; if Region.factors is a List it would be listed twice. Confirm.
                    region.factors.add(factor);
                }
            }
        }
    }
}
