/*
 * Decompiled with CFR 0.152.
 */
package ghidra.feature.vt.api.correlator.program;

import generic.DominantPair;
import generic.lsh.vector.LSHCosineVectorAccum;
import generic.lsh.vector.LSHVector;
import generic.lsh.vector.VectorCompare;
import ghidra.feature.vt.api.main.VTAssociation;
import ghidra.feature.vt.api.main.VTAssociationStatus;
import ghidra.feature.vt.api.main.VTAssociationType;
import ghidra.feature.vt.api.main.VTMatch;
import ghidra.feature.vt.api.main.VTMatchInfo;
import ghidra.feature.vt.api.main.VTMatchSet;
import ghidra.feature.vt.api.main.VTScore;
import ghidra.feature.vt.api.main.VTSession;
import ghidra.feature.vt.api.util.VTAbstractProgramCorrelator;
import ghidra.framework.options.ToolOptions;
import ghidra.framework.plugintool.ServiceProvider;
import ghidra.program.model.address.Address;
import ghidra.program.model.address.AddressSetView;
import ghidra.program.model.listing.CodeUnit;
import ghidra.program.model.listing.CodeUnitIterator;
import ghidra.program.model.listing.Data;
import ghidra.program.model.listing.Function;
import ghidra.program.model.listing.FunctionManager;
import ghidra.program.model.listing.Instruction;
import ghidra.program.model.listing.Listing;
import ghidra.program.model.listing.Program;
import ghidra.program.model.symbol.Reference;
import ghidra.program.model.symbol.ReferenceIterator;
import ghidra.program.model.symbol.ReferenceManager;
import ghidra.util.exception.CancelledException;
import ghidra.util.task.TaskMonitor;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

/**
 * Base class for correlators that pair functions by the already-ACCEPTED associations they
 * reference.
 *
 * <p>For every accepted association whose type passes {@link #isExpectedRefType(VTAssociationType)},
 * the functions referencing its source and destination addresses are collected. Each such
 * association becomes one shared, weighted feature in the {@link LSHCosineVectorAccum} of every
 * referencing function. References counted by {@link #isExpectedRefType(Reference)} that are
 * not explained by a shared feature each add a unique feature. Destination functions are then
 * paired with source functions whose vectors compare above the configured thresholds.
 */
public abstract class VTAbstractReferenceProgramCorrelator
extends VTAbstractProgramCorrelator {
    /** Maximum recursion depth when following thunks and data references. */
    private static final int MAX_DEPTH = 30;
    /** Maximum number of candidate matches kept per destination function when refining. */
    private static final int TOP_N = 5;
    /** When refining, candidates scoring more than this below the best candidate are dropped. */
    private static final double DIFFERENTIAL = 0.2;
    /** Two similarity scores closer than this are treated as tied. */
    private static final double EQUALS_EPSILON = 1.0E-5;
    private final String correlatorName;
    /** Feature vector of each referencing source function, keyed by its entry point. */
    private HashMap<Address, LSHCosineVectorAccum> srcFuncAddresstoVectorMap;
    /** Feature vector of each referencing destination function, keyed by its entry point. */
    private HashMap<Address, LSHCosineVectorAccum> destFuncAddresstoVectorMap;
    private final Program sourceProgram;
    private final Program destinationProgram;
    private final Listing sourceListing;
    private final Listing destinationListing;
    /** Orders match infos by descending similarity score. */
    private static final Comparator<VTMatchInfo> SCORE_COMPARATOR = new Comparator<VTMatchInfo>(){

        @Override
        public int compare(VTMatchInfo o1, VTMatchInfo o2) {
            return o2.getSimilarityScore().compareTo(o1.getSimilarityScore());
        }
    };

    VTAbstractReferenceProgramCorrelator(ServiceProvider serviceProvider, Program sourceProgram, AddressSetView sourceAddressSet, Program destinationProgram, AddressSetView destinationAddressSet, String correlatorName, ToolOptions options) {
        super(serviceProvider, sourceProgram, sourceAddressSet, destinationProgram, destinationAddressSet, options);
        this.correlatorName = correlatorName;
        this.sourceProgram = sourceProgram;
        this.destinationProgram = destinationProgram;
        this.sourceListing = sourceProgram.getListing();
        this.destinationListing = destinationProgram.getListing();
    }

    @Override
    public String getName() {
        return this.correlatorName;
    }

    /**
     * Builds the reference feature vectors, then adds a match for each sufficiently similar
     * source/destination function pair.
     *
     * @param matchSet the match set that receives the new matches
     * @param monitor progress and cancellation monitor
     * @throws CancelledException declared by the superclass contract
     */
    @Override
    protected void doCorrelate(VTMatchSet matchSet, TaskMonitor monitor) throws CancelledException {
        double minbits = this.getOptions().getDouble("Confidence threshold (info content)", 1.0);
        double similarityThreshold = this.getOptions().getDouble("Minimum similarity threshold (score)", 0.5);
        monitor.setMessage("Finding reference features");
        this.extractReferenceFeatures(matchSet, monitor);
        monitor.setMessage("Finding destination functions");
        try {
            this.findDestinations(matchSet, similarityThreshold, minbits, monitor);
        }
        catch (Exception e) {
            throw new RuntimeException("problem finding destination functions", e);
        }
    }

    /**
     * Compares every destination feature vector against every source feature vector.
     * Candidates that pass {@code transform} are added to {@code matchSet}.
     *
     * @param matchSet the match set that receives the new matches
     * @param similarityThreshold minimum similarity score for a candidate to be kept
     * @param minbits minimum confidence (dot product) for a candidate to be kept
     * @param monitor progress and cancellation monitor; cancelling stops early
     */
    protected void findDestinations(VTMatchSet matchSet, double similarityThreshold, double minbits, TaskMonitor monitor) {
        monitor.initialize((long)this.destFuncAddresstoVectorMap.size());
        for (Map.Entry<Address, LSHCosineVectorAccum> destEntry : this.destFuncAddresstoVectorMap.entrySet()) {
            if (monitor.isCancelled()) {
                return;
            }
            monitor.incrementProgress(1L);
            Function destFunc = this.destinationListing.getFunctionAt(destEntry.getKey());
            if (destFunc == null) {
                // the function no longer exists at this entry point; nothing to match
                continue;
            }
            LSHCosineVectorAccum dstVector = destEntry.getValue();
            HashMap<Address, DominantPair<Double, VectorCompare>> srcNeighbors = new HashMap<Address, DominantPair<Double, VectorCompare>>();
            for (Map.Entry<Address, LSHCosineVectorAccum> srcEntry : this.srcFuncAddresstoVectorMap.entrySet()) {
                VectorCompare veccompare = new VectorCompare();
                // compare() also populates veccompare; compute it once and reuse the result
                double similarity = dstVector.compare((LSHVector)srcEntry.getValue(), veccompare);
                if (!(similarity > 0.0)) {
                    // keep only positive similarities (this also rejects NaN)
                    continue;
                }
                srcNeighbors.put(srcEntry.getKey(), new DominantPair<Double, VectorCompare>(similarity, veccompare));
            }
            List<VTMatchInfo> members = this.transform(matchSet, destFunc, dstVector, srcNeighbors, similarityThreshold, minbits, monitor);
            for (VTMatchInfo member : members) {
                if (member == null) continue;
                matchSet.addMatch(member);
            }
        }
    }

    /**
     * Turns the candidate source neighbors of one destination function into match infos.
     * Candidates below either threshold are dropped. The remainder is optionally refined.
     *
     * @return the surviving candidate matches, best first when refinement is enabled
     */
    private List<VTMatchInfo> transform(VTMatchSet matchSet, Function destinationFunction, LSHCosineVectorAccum destinationVector, HashMap<Address, DominantPair<Double, VectorCompare>> neighbors, double similarityThreshold, double minbits, TaskMonitor monitor) {
        boolean refineResult = this.getOptions().getBoolean("Refine Results", true);
        Address destinationAddress = destinationFunction.getEntryPoint();
        int destinationLength = (int)destinationFunction.getBody().getNumAddresses();
        List<VTMatchInfo> result = new ArrayList<VTMatchInfo>();
        for (Map.Entry<Address, DominantPair<Double, VectorCompare>> neighbor : neighbors.entrySet()) {
            if (monitor.isCancelled()) break;
            double similarity = neighbor.getValue().first;
            VectorCompare veccompare = neighbor.getValue().second;
            veccompare.fillOut();
            double confidence = veccompare.dotproduct;
            if (similarity < similarityThreshold || Double.isNaN(similarity) || confidence < minbits) continue;
            Function sourceFunction = this.sourceListing.getFunctionAt(neighbor.getKey());
            if (sourceFunction == null) {
                // the function no longer exists at this entry point; cannot build a match
                continue;
            }
            VTMatchInfo match = new VTMatchInfo(matchSet);
            match.setSimilarityScore(new VTScore(similarity));
            // the reported confidence is the dot product scaled by 10
            match.setConfidenceScore(new VTScore(confidence * 10.0));
            match.setSourceLength((int)sourceFunction.getBody().getNumAddresses());
            match.setDestinationLength(destinationLength);
            match.setSourceAddress(sourceFunction.getEntryPoint());
            match.setDestinationAddress(destinationAddress);
            match.setTag(null);
            match.setAssociationType(VTAssociationType.FUNCTION);
            result.add(match);
        }
        if (refineResult) {
            result = this.refine(result);
        }
        return result;
    }

    /**
     * Prunes the candidate list in several steps:
     * <ol>
     * <li>Sorts it by descending similarity.</li>
     * <li>Keeps at most {@code TOP_N + 1} entries.</li>
     * <li>Cuts the list before the first pair of tied scores. Both tied entries are dropped,
     *     so a tie at the top leaves no candidates.</li>
     * <li>Keeps at most {@code TOP_N} entries.</li>
     * <li>Drops entries scoring more than {@code DIFFERENTIAL} below the best.</li>
     * </ol>
     *
     * @param list candidate matches; it is sorted in place
     * @return a (possibly empty) sub-list view of {@code list}
     */
    private List<VTMatchInfo> refine(List<VTMatchInfo> list) {
        Collections.sort(list, SCORE_COMPARATOR);
        // look at one extra candidate so that a tie at position TOP_N is also detected
        list = list.subList(0, Math.min(TOP_N + 1, list.size()));
        if (list.size() > 1) {
            double previousScore = list.get(0).getSimilarityScore().getScore();
            int cutoffIndex = list.size();
            for (int ii = 1; ii < list.size(); ++ii) {
                double currentScore = list.get(ii).getSimilarityScore().getScore();
                if (currentScore > previousScore - EQUALS_EPSILON) {
                    // entries ii-1 and ii are tied: drop both and everything after them
                    cutoffIndex = ii - 1;
                    break;
                }
                previousScore = currentScore;
            }
            list = list.subList(0, cutoffIndex);
        }
        list = list.subList(0, Math.min(TOP_N, list.size()));
        if (list.size() > 1) {
            double bestScore = list.get(0).getSimilarityScore().getScore();
            int cutoffIndex = list.size();
            for (int ii = 1; ii < list.size(); ++ii) {
                if (list.get(ii).getSimilarityScore().getScore() < bestScore - DIFFERENTIAL) {
                    cutoffIndex = ii;
                    break;
                }
            }
            list = list.subList(0, cutoffIndex);
        }
        return list;
    }

    /**
     * Appends to {@code list} every non-thunk function whose instructions reference
     * {@code address}.
     *
     * <p>The search is recursive. It follows references to this function's thunk addresses,
     * references into thunk functions, and references from data. Recursion stops at
     * {@code MAX_DEPTH}. A function is appended once per referencing instruction, so
     * duplicates are possible.
     */
    private void accumulateFunctionReferences(int depth, List<Function> list, ReferenceManager refManager, FunctionManager funManager, Listing listing, Address address) {
        if (depth >= MAX_DEPTH) {
            return;
        }
        Function addressFunction = funManager.getFunctionAt(address);
        if (addressFunction != null) {
            Address[] thunkAddresses = addressFunction.getFunctionThunkAddresses();
            if (thunkAddresses != null) {
                for (Address thunkAddress : thunkAddresses) {
                    this.accumulateFunctionReferences(depth + 1, list, refManager, funManager, listing, thunkAddress);
                }
            }
        }
        ReferenceIterator refs = refManager.getReferencesTo(address);
        while (refs.hasNext()) {
            Address fromAddress = refs.next().getFromAddress();
            CodeUnit codeUnit = listing.getCodeUnitAt(fromAddress);
            if (codeUnit instanceof Instruction) {
                Function function = funManager.getFunctionContaining(fromAddress);
                if (function == null) {
                    continue;
                }
                if (function.isThunk()) {
                    // look through the thunk to the functions that reference it
                    this.accumulateFunctionReferences(depth + 1, list, refManager, funManager, listing, function.getEntryPoint());
                }
                else {
                    list.add(function);
                }
            }
            else if (codeUnit instanceof Data) {
                // data pointing at the address: find the code that references that data
                this.accumulateFunctionReferences(depth + 1, list, refManager, funManager, listing, fromAddress);
            }
        }
    }

    /**
     * Returns whether accepted associations of the given type are used as reference features.
     *
     * @param associationType the association's type
     */
    protected abstract boolean isExpectedRefType(VTAssociationType associationType);

    /**
     * Returns whether the given reference counts toward a function's total reference count.
     *
     * @param reference a reference made from a code unit of a referencing function
     */
    protected abstract boolean isExpectedRefType(Reference reference);

    /**
     * Builds the source and destination feature vectors.
     *
     * <p>Only the newest match set of each other correlator is used. Within those sets, only
     * ACCEPTED associations of the expected type are considered. Each association with
     * referencing functions on both sides becomes one shared feature. Its weight is
     * {@code sqrt(-ln p)}, where {@code p} is the fraction of all functions, in both programs,
     * that reference it. Remaining expected references of each function are then added as
     * unique features with weight {@code sqrt(-ln 0.5)}.
     *
     * @param matchSet the match set being produced; its session supplies the existing matches
     * @param monitor progress and cancellation monitor; cancelling returns early
     */
    protected void extractReferenceFeatures(VTMatchSet matchSet, TaskMonitor monitor) {
        this.srcFuncAddresstoVectorMap = new HashMap<Address, LSHCosineVectorAccum>();
        this.destFuncAddresstoVectorMap = new HashMap<Address, LSHCosineVectorAccum>();
        FunctionManager srcFuncManager = this.sourceProgram.getFunctionManager();
        FunctionManager destFuncManager = this.destinationProgram.getFunctionManager();
        int srcFunctionCount = srcFuncManager.getFunctionCount();
        int destFunctionCount = destFuncManager.getFunctionCount();
        VTSession session = matchSet.getSession();

        // Keep only the match set with the highest ID from each other correlator.
        HashMap<String, VTMatchSet> dedupedMatchSets = new HashMap<String, VTMatchSet>();
        for (VTMatchSet ms : session.getMatchSets()) {
            String name = ms.getProgramCorrelatorInfo().getName();
            if (name.equals(this.correlatorName)) continue;
            VTMatchSet existing = dedupedMatchSets.get(name);
            if (existing != null && ms.getID() < existing.getID()) continue;
            dedupedMatchSets.put(name, ms);
        }
        Collection<VTMatchSet> matchSets = dedupedMatchSets.values();
        // Count only the match sets actually kept, so progress is not over-estimated.
        int total = 0;
        for (VTMatchSet ms : matchSets) {
            total += ms.getMatchCount();
        }
        monitor.initialize((long)total);

        HashMap<VTMatch, ArrayList<Function>> sourceRefMap = new HashMap<VTMatch, ArrayList<Function>>();
        HashMap<VTMatch, ArrayList<Function>> destinationRefMap = new HashMap<VTMatch, ArrayList<Function>>();
        for (VTMatchSet ms : matchSets) {
            Collection<VTMatch> matches = ms.getMatches();
            for (VTMatch match : matches) {
                if (monitor.isCancelled()) {
                    return;
                }
                monitor.incrementProgress(1L);
                VTAssociation association = match.getAssociation();
                if (!this.isExpectedRefType(association.getType()) || association.getStatus() != VTAssociationStatus.ACCEPTED) continue;
                ArrayList<Function> sourceReferences = new ArrayList<Function>();
                this.accumulateFunctionReferences(0, sourceReferences, this.sourceProgram.getReferenceManager(), srcFuncManager, this.sourceListing, association.getSourceAddress());
                ArrayList<Function> destinationReferences = new ArrayList<Function>();
                this.accumulateFunctionReferences(0, destinationReferences, this.destinationProgram.getReferenceManager(), destFuncManager, this.destinationListing, association.getDestinationAddress());
                if (sourceReferences.isEmpty() || destinationReferences.isEmpty()) continue;
                sourceRefMap.put(match, sourceReferences);
                destinationRefMap.put(match, destinationReferences);
            }
        }

        monitor.setMessage("Adding ACCEPTED matches to feature vectors.");
        monitor.initialize((long)sourceRefMap.size());
        int featureID = 1;
        for (Map.Entry<VTMatch, ArrayList<Function>> entry : sourceRefMap.entrySet()) {
            if (monitor.isCancelled()) {
                return;
            }
            monitor.incrementProgress(1L);
            ArrayList<Function> srcRefs = entry.getValue();
            ArrayList<Function> destRefs = destinationRefMap.get(entry.getKey());
            int distinctRefFuncs = new HashSet<Function>(srcRefs).size() + new HashSet<Function>(destRefs).size();
            double altPraw = (double)distinctRefFuncs / (double)(srcFunctionCount + destFunctionCount);
            double weight = Math.sqrt(-Math.log(altPraw));
            addFeature(this.srcFuncAddresstoVectorMap, srcRefs, featureID, weight);
            addFeature(this.destFuncAddresstoVectorMap, destRefs, featureID, weight);
            ++featureID;
        }

        monitor.setMessage("Adding unmatched references to feature vectors.");
        double pSwitch = 0.5;
        double uniqueWeight = Math.sqrt(-Math.log(pSwitch));
        featureID = this.addUnmatchedReferenceFeatures(this.srcFuncAddresstoVectorMap, srcFuncManager, this.sourceListing, featureID, uniqueWeight);
        this.addUnmatchedReferenceFeatures(this.destFuncAddresstoVectorMap, destFuncManager, this.destinationListing, featureID, uniqueWeight);
    }

    /**
     * Adds {@code featureID} with {@code weight} to the vector of each function in
     * {@code functions}, creating vectors on demand. The vectors are keyed by the functions'
     * entry points.
     */
    private static void addFeature(HashMap<Address, LSHCosineVectorAccum> vectorMap, List<Function> functions, int featureID, double weight) {
        for (Function function : functions) {
            Address entryPoint = function.getEntryPoint();
            LSHCosineVectorAccum vector = vectorMap.get(entryPoint);
            if (vector == null) {
                vector = new LSHCosineVectorAccum();
                vectorMap.put(entryPoint, vector);
            }
            vector.addHash(featureID, weight);
        }
    }

    /**
     * Adds unique features to each vector in {@code vectorMap}.
     *
     * <p>For each function, counts the expected references made from its code units. That
     * count minus the vector's current entry count gives the number of fresh features added,
     * each with {@code weight}.
     *
     * @return the next unused feature ID
     */
    private int addUnmatchedReferenceFeatures(HashMap<Address, LSHCosineVectorAccum> vectorMap, FunctionManager funcManager, Listing listing, int firstFeatureID, double weight) {
        int featureID = firstFeatureID;
        for (Map.Entry<Address, LSHCosineVectorAccum> entry : vectorMap.entrySet()) {
            Function func = funcManager.getFunctionAt(entry.getKey());
            if (func == null) {
                continue;
            }
            int totalRefs = 0;
            CodeUnitIterator iter = listing.getCodeUnits(func.getBody(), true);
            while (iter.hasNext()) {
                for (Reference memRef : iter.next().getReferencesFrom()) {
                    if (this.isExpectedRefType(memRef)) {
                        ++totalRefs;
                    }
                }
            }
            LSHCosineVectorAccum vector = entry.getValue();
            int unmatched = totalRefs - vector.numEntries();
            for (int i = 0; i < unmatched; ++i) {
                vector.addHash(featureID, weight);
                ++featureID;
            }
        }
        return featureID;
    }
}

