/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.parser.metrics;

import edu.stanford.nlp.ling.Label;
import edu.stanford.nlp.parser.metrics.AbstractEval;
import edu.stanford.nlp.parser.metrics.Evalb;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.trees.Constituent;
import edu.stanford.nlp.trees.Tree;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.TreeMap;
import java.util.regex.Pattern;

/**
 * Labeled-bracket (evalb-style) evaluation broken down by constituent category.
 *
 * <p>Each call to {@link #evaluate(Tree, Tree, PrintWriter)} extracts constituents from the guess
 * and gold trees (via a wrapped {@link Evalb}), groups them by their {@link Label}, and accumulates
 * two kinds of per-category statistics:
 * <ul>
 *   <li>per-tree precision/recall/F1 sums, reported as "sent ave" after dividing by the number of
 *       evaluated trees; and</li>
 *   <li>size-weighted sums together with guessed/gold constituent totals, reported as "evalb"
 *       (corpus-level) figures.</li>
 * </ul>
 * {@link #display(boolean, PrintWriter)} prints the corpus-level figures for each category in
 * ascending order of F1, followed by a micro-averaged total.
 */
public class EvalbByCat extends AbstractEval {
    /** Used only to turn a tree into its set of labeled constituents. */
    private final Evalb evalb;
    /** When non-null, display() reports only categories whose label value fully matches this. */
    private Pattern pLabelFilter = null;
    /** Per category: sum over trees of that tree's precision. */
    private final Counter<Label> precisions;
    /** Per category: sum over trees of that tree's recall. */
    private final Counter<Label> recalls;
    /** Per category: sum over trees of that tree's F1. */
    private final Counter<Label> f1s;
    /** Per category: sum over trees of (guessed count * precision). */
    private final Counter<Label> precisions2;
    /** Per category: sum over trees of (gold count * recall). */
    private final Counter<Label> recalls2;
    /** Per category: total number of guessed constituents. */
    private final Counter<Label> pnums2;
    /** Per category: total number of gold constituents. */
    private final Counter<Label> rnums2;

    /**
     * Creates an evaluator that reports every category.
     *
     * @param str name forwarded to the superclass and to the wrapped Evalb; also prefixes the
     *     per-tree F1 lines printed by {@link #evaluate}
     * @param runningAverages if true (and a writer is supplied), {@link #evaluate} prints per-tree
     *     and running figures; also forwarded to the superclass
     */
    public EvalbByCat(String str, boolean runningAverages) {
        super(str, runningAverages);
        this.evalb = new Evalb(str, false);
        this.precisions = new ClassicCounter<Label>();
        this.recalls = new ClassicCounter<Label>();
        this.f1s = new ClassicCounter<Label>();
        this.precisions2 = new ClassicCounter<Label>();
        this.recalls2 = new ClassicCounter<Label>();
        this.pnums2 = new ClassicCounter<Label>();
        this.rnums2 = new ClassicCounter<Label>();
    }

    /**
     * Creates an evaluator whose final report is restricted to matching categories.
     *
     * @param str see {@link #EvalbByCat(String, boolean)}
     * @param runningAverages see {@link #EvalbByCat(String, boolean)}
     * @param labelRegex regex (trimmed before compiling) that a category's label value must fully
     *     match to appear in {@link #display}; {@code null} disables filtering
     */
    public EvalbByCat(String str, boolean runningAverages, String labelRegex) {
        this(str, runningAverages);
        if (labelRegex != null) {
            this.pLabelFilter = Pattern.compile(labelRegex.trim());
        }
    }

    /** Returns the labeled constituents of {@code tree} as computed by the wrapped Evalb. */
    protected Set<Constituent> makeObjects(Tree tree) {
        return this.evalb.makeObjects(tree);
    }

    /** Groups the constituents of {@code t} by their label. */
    private Map<Label, Set<Constituent>> makeObjectsByCat(Tree t) {
        Map<Label, Set<Constituent>> objMap = new HashMap<Label, Set<Constituent>>();
        for (Constituent lc : this.makeObjects(t)) {
            Label l = lc.label();
            Set<Constituent> bucket = objMap.get(l);
            if (bucket == null) {
                bucket = new HashSet<Constituent>();
                objMap.put(l, bucket);
            }
            bucket.add(lc);
        }
        return objMap;
    }

    /**
     * Truncates (not rounds) a value that has already been multiplied by 10000 to a percentage
     * with two decimals. The caller keeps its own arithmetic order, so the printed values match
     * the original inline expressions exactly.
     */
    private static double truncPercent(double timesTenThousand) {
        return (double) ((int) timesTenThousand) / 100.0;
    }

    /**
     * Scores one guess/gold tree pair per category and adds the results to the running totals.
     *
     * <p>If either tree is null, prints a message to stderr and returns without counting the pair.
     * If {@code pw} is non-null and running averages are enabled, prints per-category per-tree,
     * sentence-average and evalb figures.
     */
    @Override
    public void evaluate(Tree guess, Tree gold, PrintWriter pw) {
        if (gold == null || guess == null) {
            System.err.printf("%s: Cannot compare against a null gold or guess tree!%n", this.getClass().getName());
            return;
        }
        Map<Label, Set<Constituent>> guessDeps = this.makeObjectsByCat(guess);
        Map<Label, Set<Constituent>> goldDeps = this.makeObjectsByCat(gold);
        Set<Label> cats = new HashSet<Label>(guessDeps.keySet());
        cats.addAll(goldDeps.keySet());
        boolean printRunning = pw != null && this.runningAverages;
        if (printRunning) {
            pw.println("========================================");
            pw.println("Labeled Bracketed Evaluation by Category");
            pw.println("========================================");
        }
        // num counts every evaluated tree, but a category's per-tree scores are only added for
        // trees in which it occurs; "sent ave" therefore treats absent trees as contributing 0.
        this.num += 1.0;
        for (Label cat : cats) {
            Set<Constituent> thisGuessDeps = guessDeps.get(cat);
            if (thisGuessDeps == null) {
                thisGuessDeps = Collections.emptySet();
            }
            Set<Constituent> thisGoldDeps = goldDeps.get(cat);
            if (thisGoldDeps == null) {
                thisGoldDeps = Collections.emptySet();
            }
            double currentPrecision = EvalbByCat.precision(thisGuessDeps, thisGoldDeps);
            // Recall is precision with the roles of the two sets swapped.
            double currentRecall = EvalbByCat.precision(thisGoldDeps, thisGuessDeps);
            double currentF1 = currentPrecision > 0.0 && currentRecall > 0.0 ? 2.0 / (1.0 / currentPrecision + 1.0 / currentRecall) : 0.0;
            this.precisions.incrementCount(cat, currentPrecision);
            this.recalls.incrementCount(cat, currentRecall);
            this.f1s.incrementCount(cat, currentF1);
            // Size-weighted sums; divided by the size totals they give corpus-level (evalb) P/R.
            this.precisions2.incrementCount(cat, (double) thisGuessDeps.size() * currentPrecision);
            this.pnums2.incrementCount(cat, thisGuessDeps.size());
            this.recalls2.incrementCount(cat, (double) thisGoldDeps.size() * currentRecall);
            this.rnums2.incrementCount(cat, thisGoldDeps.size());
            if (printRunning) {
                pw.println(cat + "\tP: " + truncPercent(currentPrecision * 10000.0)
                        + " (sent ave " + truncPercent(this.precisions.getCount(cat) * 10000.0 / this.num)
                        + ") (evalb " + truncPercent(this.precisions2.getCount(cat) * 10000.0 / this.pnums2.getCount(cat)) + ")");
                pw.println("\tR: " + truncPercent(currentRecall * 10000.0)
                        + " (sent ave " + truncPercent(this.recalls.getCount(cat) * 10000.0 / this.num)
                        + ") (evalb " + truncPercent(this.recalls2.getCount(cat) * 10000.0 / this.rnums2.getCount(cat)) + ")");
                double cF1 = 2.0 / (this.rnums2.getCount(cat) / this.recalls2.getCount(cat) + this.pnums2.getCount(cat) / this.precisions2.getCount(cat));
                String emit = this.str + " F1: " + truncPercent(currentF1 * 10000.0)
                        + " (sent ave " + truncPercent(10000.0 * this.f1s.getCount(cat) / this.num)
                        + ", evalb " + truncPercent(10000.0 * cF1) + ")";
                pw.println(emit);
            }
        }
        if (printRunning) {
            pw.println("========================================");
        }
    }

    /**
     * Returns the subset of {@code labelSet} to report: all of it when no label filter is set,
     * otherwise only labels whose value fully matches the filter.
     */
    private Set<Label> getEvalLabelSet(Set<Label> labelSet) {
        Set<Label> evalSet = new HashSet<Label>();
        for (Label label : labelSet) {
            if (this.pLabelFilter == null || this.pLabelFilter.matcher(label.value()).matches()) {
                evalSet.add(label);
            }
        }
        return evalSet;
    }

    /**
     * Returns {@code cats} sorted by ascending corpus-level F1. Categories whose F1 is undefined
     * (NaN, e.g. no guessed or no gold constituents) sort first, as -1. Ties are broken by label
     * value, so the order is deterministic and every category is kept.
     */
    private List<Label> sortByCorpusF1(Set<Label> cats) {
        final Map<Label, Double> f1ByCat = new HashMap<Label, Double>();
        for (Label cat : cats) {
            double prec = this.precisions2.getCount(cat) / this.pnums2.getCount(cat);
            double rec = this.recalls2.getCount(cat) / this.rnums2.getCount(cat);
            double f1 = 2.0 / (1.0 / prec + 1.0 / rec);
            f1ByCat.put(cat, Double.isNaN(f1) ? -1.0 : f1);
        }
        List<Label> sorted = new ArrayList<Label>(cats);
        Collections.sort(sorted, new Comparator<Label>() {
            @Override
            public int compare(Label a, Label b) {
                int byF1 = Double.compare(f1ByCat.get(a), f1ByCat.get(b));
                if (byF1 != 0) {
                    return byF1;
                }
                return String.valueOf(a.value()).compareTo(String.valueOf(b.value()));
            }
        });
        return sorted;
    }

    /**
     * Prints corpus-level precision, recall and F1 for each reported category (ascending F1),
     * then a micro-averaged total over those categories. {@code verbose} is not used here.
     *
     * <p>Prints an error to stderr and returns without output if the precision and recall
     * counters hold different numbers of categories.
     */
    @Override
    public void display(boolean verbose, PrintWriter pw) {
        if (this.precisions.keySet().size() != this.recalls.keySet().size()) {
            System.err.println("ERROR: Different counts for precisions and recalls!");
            return;
        }
        List<Label> cats = this.sortByCorpusF1(this.getEvalLabelSet(this.precisions.keySet()));
        pw.println("============================================================");
        pw.println("Labeled Bracketed Evaluation by Category -- final statistics");
        pw.println("============================================================");
        double catPrecisions = 0.0;
        double catPrecisionNums = 0.0;
        double catRecalls = 0.0;
        double catRecallNums = 0.0;
        for (Label cat : cats) {
            double pnum2 = this.pnums2.getCount(cat);
            double rnum2 = this.rnums2.getCount(cat);
            double prec = this.precisions2.getCount(cat) / pnum2 * 100.0;
            double rec = this.recalls2.getCount(cat) / rnum2 * 100.0;
            double f1 = 2.0 / (1.0 / prec + 1.0 / rec);
            catPrecisions += this.precisions2.getCount(cat);
            catPrecisionNums += pnum2;
            catRecalls += this.recalls2.getCount(cat);
            catRecallNums += rnum2;
            // A zero denominator makes the ratio undefined; print N/A instead of NaN.
            String LP = pnum2 == 0.0 ? "N/A" : String.format("%.2f", prec);
            String LR = rnum2 == 0.0 ? "N/A" : String.format("%.2f", rec);
            String F1 = pnum2 == 0.0 || rnum2 == 0.0 ? "N/A" : String.format("%.2f", f1);
            pw.printf("%s\tLP: %s\tguessed: %d\tLR: %s\tgold: %d\t F1: %s%n", cat.value(), LP, (int) pnum2, LR, (int) rnum2, F1);
        }
        pw.println("============================================================");
        double prec = catPrecisions / catPrecisionNums;
        double rec = catRecalls / catRecallNums;
        double f1 = 2.0 * prec * rec / (prec + rec);
        pw.printf("Total\tLP: %.2f\tguessed: %d\tLR: %.2f\tgold: %d\t F1: %.2f%n", prec * 100.0, (int) catPrecisionNums, rec * 100.0, (int) catRecallNums, f1 * 100.0);
        pw.println("============================================================");
    }
}

