/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.common.tree;

import com.oracle.labs.mlrg.olcut.util.Pair;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.PriorityQueue;
import java.util.Set;
import org.tribuo.Example;
import org.tribuo.Excuse;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Output;
import org.tribuo.Prediction;
import org.tribuo.SparseModel;
import org.tribuo.common.tree.LeafNode;
import org.tribuo.common.tree.Node;
import org.tribuo.common.tree.SplitNode;
import org.tribuo.math.la.SparseVector;
import org.tribuo.provenance.ModelProvenance;

/**
 * A {@link SparseModel} backed by a binary decision tree of {@link Node}s.
 * <p>
 * Internal nodes are {@link SplitNode}s which route an example to either their
 * "greater than" or "less than or equal" child; traversal stops at a
 * {@link LeafNode}, which produces the prediction.
 * <p>
 * All per-feature information exposed by this class is stored under the single
 * key {@code "ALL_OUTPUTS"}, as the tree is shared by every output dimension.
 *
 * @param <T> The output type.
 */
public class TreeModel<T extends Output<T>> extends SparseModel<T> {
    private static final long serialVersionUID = 3L;

    /** The root of the tree. It is {@code null} when the protected constructor is used. */
    private final Node<T> root;

    /**
     * Builds a tree model around the supplied root node. The active feature list
     * passed to the superclass is gathered from the split nodes of the tree.
     *
     * @param name The model name.
     * @param description The model provenance.
     * @param featureIDMap The feature domain.
     * @param outputIDInfo The output domain.
     * @param generatesProbabilities Does this model emit probabilities?
     * @param root The root node of the tree.
     */
    TreeModel(String name, ModelProvenance description, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, boolean generatesProbabilities, Node<T> root) {
        super(name, description, featureIDMap, outputIDInfo, generatesProbabilities, TreeModel.gatherActiveFeatures(featureIDMap, root));
        this.root = root;
    }

    /**
     * Builds a tree model without a root node, using a caller supplied active feature map.
     * <p>
     * NOTE(review): {@code root} is left {@code null} here, so {@link #predict},
     * {@link #getExcuse}, {@link #copy} and {@link #toString} as implemented in this
     * class would fail with a {@link NullPointerException}; presumably subclasses using
     * this constructor override them - confirm against the subclasses.
     *
     * @param name The model name.
     * @param description The model provenance.
     * @param featureIDMap The feature domain.
     * @param outputIDInfo The output domain.
     * @param generatesProbabilities Does this model emit probabilities?
     * @param activeFeatures The active features, passed straight to the superclass.
     */
    protected TreeModel(String name, ModelProvenance description, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, boolean generatesProbabilities, Map<String, List<String>> activeFeatures) {
        super(name, description, featureIDMap, outputIDInfo, generatesProbabilities, activeFeatures);
        this.root = null;
    }

    /**
     * Walks the tree breadth first and returns the name of the feature used by every
     * split node, in visit order, one entry per split node (so names may repeat).
     * <p>
     * {@code null} nodes and leaves are skipped. For each split node the "greater than"
     * child is queued before the "less than or equal" child.
     *
     * @param fMap The feature map used to turn feature ids into names.
     * @param root The root of the tree, may be {@code null}.
     * @param <T> The output type.
     * @return The split feature names in breadth first order.
     */
    private static <T extends Output<T>> List<String> collectSplitFeatureNames(ImmutableFeatureMap fMap, Node<T> root) {
        List<String> names = new ArrayList<>();
        // LinkedList rather than ArrayDeque because null children may be queued.
        LinkedList<Node<T>> pending = new LinkedList<>();
        pending.offer(root);
        while (!pending.isEmpty()) {
            Node<T> node = pending.poll();
            if (node == null || node.isLeaf()) {
                continue;
            }
            SplitNode<T> split = (SplitNode<T>) node;
            names.add(fMap.get(split.getFeatureID()).getName());
            pending.offer(split.getGreaterThan());
            pending.offer(split.getLessThanOrEqual());
        }
        return names;
    }

    /**
     * Collects the distinct features used by the split nodes of the tree.
     *
     * @param fMap The feature map used to turn feature ids into names.
     * @param root The root of the tree.
     * @param <T> The output type.
     * @return An immutable single entry map from {@code "ALL_OUTPUTS"} to the distinct
     *         feature names, ordered by first appearance in a breadth first traversal.
     */
    private static <T extends Output<T>> Map<String, List<String>> gatherActiveFeatures(ImmutableFeatureMap fMap, Node<T> root) {
        Set<String> activeFeatures = new LinkedHashSet<>();
        activeFeatures.addAll(collectSplitFeatureNames(fMap, root));
        return Collections.singletonMap("ALL_OUTPUTS", new ArrayList<>(activeFeatures));
    }

    /**
     * Computes the depth of this tree.
     *
     * @return The depth, as computed by {@link #computeDepth(int, Node)} starting from zero.
     */
    public int getDepth() {
        return TreeModel.computeDepth(0, this.root);
    }

    /**
     * Computes the depth of the deepest {@link LeafNode} reachable from {@code root}.
     * <p>
     * The depth of a leaf is {@code initialDepth} plus the number of split nodes on the
     * path to it. If {@code root} is {@code null} or is itself a leaf then
     * {@code initialDepth} is returned.
     *
     * @param initialDepth The depth assigned to {@code root}.
     * @param root The root of the (sub)tree, may be {@code null}.
     * @param <T> The output type.
     * @return The maximum leaf depth.
     */
    protected static <T extends Output<T>> int computeDepth(int initialDepth, Node<T> root) {
        int maxDepth = initialDepth;
        LinkedList<Pair<Integer, Node<T>>> pending = new LinkedList<>();
        pending.offer(new Pair<>(initialDepth, root));
        while (!pending.isEmpty()) {
            Pair<Integer, Node<T>> entry = pending.poll();
            Node<T> node = entry.getB();
            if (node == null || node.isLeaf()) {
                continue;
            }
            // Both children sit one level below the node just dequeued.
            int childDepth = entry.getA() + 1;
            SplitNode<T> split = (SplitNode<T>) node;

            Node<T> greaterThan = split.getGreaterThan();
            if (greaterThan instanceof LeafNode) {
                maxDepth = Math.max(maxDepth, childDepth);
            } else {
                pending.offer(new Pair<>(childDepth, greaterThan));
            }

            Node<T> lessThan = split.getLessThanOrEqual();
            if (lessThan instanceof LeafNode) {
                maxDepth = Math.max(maxDepth, childDepth);
            } else {
                pending.offer(new Pair<>(childDepth, lessThan));
            }
        }
        return maxDepth;
    }

    /**
     * Routes the example down the tree and returns the prediction of the leaf it lands in.
     *
     * @param example The example to predict.
     * @return The prediction produced by the reached {@link LeafNode}.
     * @throws IllegalArgumentException If the example has no features after conversion
     *         to a {@link SparseVector} over this model's feature map.
     */
    public Prediction<T> predict(Example<T> example) {
        SparseVector vec = SparseVector.createSparseVector(example, this.featureIDMap, false);
        if (vec.numActiveElements() == 0) {
            throw new IllegalArgumentException("No features found in Example " + example.toString());
        }
        Node<T> lastNode = this.root;
        Node<T> curNode = this.root;
        while (curNode != null) {
            lastNode = curNode;
            curNode = lastNode.getNextNode(vec);
        }
        // The traversal only stops when getNextNode returns null, which is expected at a leaf.
        return ((LeafNode<T>) lastNode).getPrediction(vec.numActiveElements(), example);
    }

    /**
     * Ranks features by the number of split nodes which use them.
     *
     * @param n The maximum number of features to return. A negative value means
     *          "as many as there are features in the feature map"; zero yields an empty list.
     * @return A mutable single entry map from {@code "ALL_OUTPUTS"} to the features and
     *         their split counts (as doubles), sorted by descending count.
     */
    public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) {
        int maxFeatures = n < 0 ? this.featureIDMap.size() : n;

        Map<String, Integer> featureCounts = new HashMap<>();
        for (String featureName : collectSplitFeatureNames(this.featureIDMap, this.root)) {
            featureCounts.merge(featureName, 1, Integer::sum);
        }

        List<Pair<String, Double>> ranked = new ArrayList<>();
        // Never keep more entries than exist; this also keeps the heap capacity valid (>= 1)
        // and bounded when n is zero or very large.
        int limit = Math.min(maxFeatures, featureCounts.size());
        if (limit > 0) {
            Comparator<Pair<String, Double>> comparator = Comparator.comparingDouble(p -> Math.abs(p.getB()));
            // Min-heap holding the best 'limit' candidates seen so far.
            PriorityQueue<Pair<String, Double>> best = new PriorityQueue<>(limit, comparator);
            for (Map.Entry<String, Integer> e : featureCounts.entrySet()) {
                // The count must be widened to a double, the comparator and callers expect a Double.
                Pair<String, Double> cur = new Pair<>(e.getKey(), (double) e.getValue());
                if (best.size() < limit) {
                    best.offer(cur);
                } else if (comparator.compare(cur, best.peek()) > 0) {
                    best.poll();
                    best.offer(cur);
                }
            }
            while (!best.isEmpty()) {
                ranked.add(best.poll());
            }
            // The heap drains smallest first, callers get largest first.
            Collections.reverse(ranked);
        }

        Map<String, List<Pair<String, Double>>> map = new HashMap<>();
        map.put("ALL_OUTPUTS", ranked);
        return map;
    }

    /**
     * Explains a prediction by listing the features tested on the path from the root to
     * the leaf the example reaches.
     * <p>
     * With {@code k} split nodes on the path, the feature tested first receives weight
     * {@code k + 1} and each subsequent one receives one less.
     *
     * @param example The example to explain.
     * @return An {@link Optional} which always contains the excuse.
     */
    public Optional<Excuse<T>> getExcuse(Example<T> example) {
        List<String> pathFeatures = new ArrayList<>();
        SparseVector vec = SparseVector.createSparseVector(example, this.featureIDMap, false);
        Node<T> lastNode = this.root;
        Node<T> curNode = this.root;
        while (curNode != null) {
            lastNode = curNode;
            if (curNode instanceof SplitNode) {
                SplitNode<T> split = (SplitNode<T>) curNode;
                pathFeatures.add(this.featureIDMap.get(split.getFeatureID()).getName());
            }
            curNode = lastNode.getNextNode(vec);
        }
        Prediction<T> pred = ((LeafNode<T>) lastNode).getPrediction(vec.numActiveElements(), example);

        List<Pair<String, Double>> weights = new ArrayList<>();
        int weight = pathFeatures.size() + 1;
        for (String featureName : pathFeatures) {
            weights.add(new Pair<>(featureName, weight + 0.0));
            weight--;
        }

        Map<String, List<Pair<String, Double>>> map = new HashMap<>();
        map.put("ALL_OUTPUTS", weights);
        return Optional.of(new Excuse<>(example, pred, map));
    }

    /**
     * Copies this model, replacing its name and provenance.
     *
     * @param newName The name of the copy.
     * @param newProvenance The provenance of the copy.
     * @return A new model built around {@code root.copy()}.
     */
    protected TreeModel<T> copy(String newName, ModelProvenance newProvenance) {
        return new TreeModel<T>(newName, newProvenance, this.featureIDMap, this.outputIDInfo, this.generatesProbabilities, this.root.copy());
    }

    /**
     * Probes the tree for the features used by its split nodes.
     *
     * @return A fresh, mutable set of the feature names used in at least one split.
     */
    public Set<String> getFeatures() {
        Set<String> features = new HashSet<>();
        features.addAll(collectSplitFeatureNames(this.featureIDMap, this.root));
        return features;
    }

    /**
     * Returns a description made of the provenance followed by the string form of the tree.
     *
     * @return The string representation of this model.
     */
    @Override
    public String toString() {
        return "TreeModel(description=" + this.provenance.toString() + ",\n\t\ttree=" + this.root.toString() + ")";
    }

    /**
     * Returns the root node of the tree.
     *
     * @return The root node, or {@code null} if this model was built without one.
     */
    public Node<T> getRoot() {
        return this.root;
    }
}

