/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ad.transport;

import java.time.Instant;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.Message;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.opensearch.action.ActionListener;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.action.support.master.AcknowledgedResponse;
import org.opensearch.ad.NodeStateManager;
import org.opensearch.ad.breaker.ADCircuitBreakerService;
import org.opensearch.ad.caching.CacheProvider;
import org.opensearch.ad.common.exception.EndRunException;
import org.opensearch.ad.common.exception.LimitExceededException;
import org.opensearch.ad.indices.ADIndex;
import org.opensearch.ad.indices.AnomalyDetectionIndices;
import org.opensearch.ad.ml.EntityModel;
import org.opensearch.ad.ml.ModelManager;
import org.opensearch.ad.ml.ModelState;
import org.opensearch.ad.ml.ThresholdingResult;
import org.opensearch.ad.model.AnomalyDetector;
import org.opensearch.ad.model.AnomalyResult;
import org.opensearch.ad.model.Entity;
import org.opensearch.ad.ratelimit.CheckpointReadWorker;
import org.opensearch.ad.ratelimit.ColdEntityWorker;
import org.opensearch.ad.ratelimit.EntityColdStartWorker;
import org.opensearch.ad.ratelimit.EntityFeatureRequest;
import org.opensearch.ad.ratelimit.RequestPriority;
import org.opensearch.ad.ratelimit.ResultWriteRequest;
import org.opensearch.ad.ratelimit.ResultWriteWorker;
import org.opensearch.ad.stats.ADStats;
import org.opensearch.ad.stats.StatNames;
import org.opensearch.ad.transport.EntityResultAction;
import org.opensearch.ad.transport.EntityResultRequest;
import org.opensearch.ad.util.ExceptionUtil;
import org.opensearch.ad.util.ParseUtils;
import org.opensearch.common.inject.Inject;
import org.opensearch.tasks.Task;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;

/**
 * Transport handler registered under {@link EntityResultAction#NAME}.
 *
 * <p>For every entity carried by an {@link EntityResultRequest} the handler looks up the
 * entity's model in the cache. Entities whose model is cached are scored right away and any
 * result with a positive RCF score is handed to the result write queue. Entities whose model is
 * not cached are split by the cache into two groups, which are forwarded to the checkpoint read
 * queue and the cold entity queue respectively.
 *
 * <p>All collaborators are injected once through the constructor and never reassigned.
 */
public class EntityResultTransportAction extends HandledTransportAction<EntityResultRequest, AcknowledgedResponse> {
    private static final Logger LOG = LogManager.getLogger(EntityResultTransportAction.class);

    /** Name of the executor used to release cache memory when the circuit breaker is open. */
    private static final String AD_THREAD_POOL_NAME = "ad-threadpool";

    /**
     * Attribute name checked by {@link #isEntityFromOldNodeMsg(Entity)}.
     * NOTE(review): presumably older nodes serialize a single-category entity under an empty
     * attribute name - confirm against the Entity wire format.
     */
    private static final String LEGACY_ATTRIBUTE_NAME = "";

    private final ModelManager modelManager;
    private final ADCircuitBreakerService adCircuitBreakerService;
    private final CacheProvider cache;
    private final NodeStateManager stateManager;
    private final AnomalyDetectionIndices indexUtil;
    private final ResultWriteWorker resultWriteQueue;
    private final CheckpointReadWorker checkpointReadQueue;
    private final ColdEntityWorker coldEntityQueue;
    private final ThreadPool threadPool;
    private final EntityColdStartWorker entityColdStartWorker;
    private final ADStats adStats;

    /**
     * Creates the action and registers it with the transport service under
     * {@link EntityResultAction#NAME}.
     *
     * @param actionFilters action filters passed to the parent handler
     * @param transportService transport service passed to the parent handler
     * @param manager scores a data point against an entity model
     * @param adCircuitBreakerService consulted before any work is done
     * @param entityCache provider of the entity model cache
     * @param stateManager source of the detector definition and of the previously recorded exception
     * @param indexUtil source of the result index schema version
     * @param resultWriteQueue receives anomaly results to persist
     * @param checkpointReadQueue receives the first group of cache-miss entities
     * @param coldEntityQueue receives the second group of cache-miss entities
     * @param threadPool provides the executor used to release cache memory
     * @param entityColdStartWorker receives entities whose cached model failed during scoring
     * @param adStats statistics registry (model corruption counter)
     */
    @Inject
    public EntityResultTransportAction(ActionFilters actionFilters, TransportService transportService, ModelManager manager, ADCircuitBreakerService adCircuitBreakerService, CacheProvider entityCache, NodeStateManager stateManager, AnomalyDetectionIndices indexUtil, ResultWriteWorker resultWriteQueue, CheckpointReadWorker checkpointReadQueue, ColdEntityWorker coldEntityQueue, ThreadPool threadPool, EntityColdStartWorker entityColdStartWorker, ADStats adStats) {
        super(EntityResultAction.NAME, transportService, actionFilters, EntityResultRequest::new);
        this.modelManager = manager;
        this.adCircuitBreakerService = adCircuitBreakerService;
        this.cache = entityCache;
        this.stateManager = stateManager;
        this.indexUtil = indexUtil;
        this.resultWriteQueue = resultWriteQueue;
        this.checkpointReadQueue = checkpointReadQueue;
        this.coldEntityQueue = coldEntityQueue;
        this.threadPool = threadPool;
        this.entityColdStartWorker = entityColdStartWorker;
        this.adStats = adStats;
    }

    /**
     * Entry point of the action.
     *
     * <p>Fails fast with a {@link LimitExceededException} when the circuit breaker is open (and
     * schedules a cache memory release). Otherwise it fetches and clears the exception previously
     * recorded for the detector: an {@link EndRunException} flagged "end now" fails the request
     * immediately, any other exception is attached to the listener through
     * {@link ExceptionUtil#wrapListener} before the detector is loaded and the entities are
     * processed.
     *
     * @param task the task this request runs under (unused here)
     * @param request detector id, time range and per-entity feature values
     * @param listener notified once the entities have been dispatched, or on failure
     */
    @Override
    protected void doExecute(Task task, EntityResultRequest request, ActionListener<AcknowledgedResponse> listener) {
        if (adCircuitBreakerService.isOpen()) {
            threadPool.executor(AD_THREAD_POOL_NAME).execute(() -> cache.get().releaseMemoryForOpenCircuitBreaker());
            listener.onFailure(new LimitExceededException(request.getDetectorId(), "AD memory circuit is broken.", false));
            return;
        }

        // The listener that is ultimately notified; swapped for a wrapped one when a previous
        // exception has to be propagated. Declared outside the try so the catch block reports
        // through the same (possibly wrapped) listener.
        ActionListener<AcknowledgedResponse> responder = listener;
        try {
            String detectorId = request.getDetectorId();
            Optional<Exception> previousException = stateManager.fetchExceptionAndClear(detectorId);
            if (previousException.isPresent()) {
                Exception exception = previousException.get();
                LOG.error("Previous exception of {}: {}", detectorId, exception);
                if (exception instanceof EndRunException && ((EndRunException) exception).isEndNow()) {
                    listener.onFailure(exception);
                    return;
                }
                responder = ExceptionUtil.wrapListener(listener, exception, detectorId);
            }
            stateManager.getAnomalyDetector(detectorId, onGetDetector(responder, detectorId, request, previousException));
        } catch (Exception exception) {
            LOG.error("fail to get entity's anomaly grade", exception);
            responder.onFailure(exception);
        }
    }

    /**
     * Builds the callback that runs once the detector definition has been looked up.
     *
     * @param listener listener to notify when processing is finished
     * @param detectorId id of the detector the request belongs to
     * @param request the entity result request being served
     * @param prevException exception recorded by an earlier run, if any; when present the listener
     *        is failed with it after the entities have been dispatched
     * @return a listener that scores cached entities and enqueues the cache misses
     */
    private ActionListener<Optional<AnomalyDetector>> onGetDetector(ActionListener<AcknowledgedResponse> listener, String detectorId, EntityResultRequest request, Optional<Exception> prevException) {
        return ActionListener.wrap(detectorOptional -> {
            if (!detectorOptional.isPresent()) {
                listener.onFailure(new EndRunException(detectorId, "AnomalyDetector is not available.", false));
                return;
            }
            AnomalyDetector detector = detectorOptional.get();
            if (request.getEntities() == null) {
                listener.onFailure(new EndRunException(detectorId, "Fail to get any entities from request.", false));
                return;
            }

            Instant executionStartTime = Instant.now();
            Map<Entity, double[]> cacheMissEntities = new HashMap<>();
            for (Map.Entry<Entity, double[]> entityEntry : request.getEntities().entrySet()) {
                Entity entity = normalizeEntity(entityEntry.getKey(), detector);
                Optional<String> modelIdOptional = entity.getModelId(detectorId);
                if (!modelIdOptional.isPresent()) {
                    // Without a model id the entity can neither be scored nor queued.
                    continue;
                }
                String modelId = modelIdOptional.get();
                double[] datapoint = entityEntry.getValue();
                ModelState<EntityModel> entityModel = cache.get().get(modelId, detector);
                if (entityModel == null) {
                    cacheMissEntities.put(entity, datapoint);
                    continue;
                }
                try {
                    ThresholdingResult result = modelManager.getAnomalyResultForEntity(datapoint, entityModel, modelId, entity, detector.getShingleSize());
                    // Only results with a positive RCF score are persisted.
                    if (result.getRcfScore() > 0.0) {
                        AnomalyResult resultToSave = result.toAnomalyResult(detector, Instant.ofEpochMilli(request.getStart()), Instant.ofEpochMilli(request.getEnd()), executionStartTime, Instant.now(), ParseUtils.getFeatureData(datapoint, detector), entity, indexUtil.getSchemaVersion(ADIndex.RESULT), modelId, null, null);
                        // A positive grade is written with HIGH priority, everything else with MEDIUM.
                        RequestPriority priority = result.getGrade() > 0.0 ? RequestPriority.HIGH : RequestPriority.MEDIUM;
                        resultWriteQueue.put(new ResultWriteRequest(nowPlusDetectorInterval(detector), detectorId, priority, resultToSave, detector.getResultIndex()));
                    }
                } catch (IllegalArgumentException e) {
                    // Scoring rejected the cached model: count it, evict the model and send the
                    // entity to the cold start worker.
                    LOG.error(new ParameterizedMessage("Likely model corruption for [{}]", modelId), e);
                    adStats.getStat(StatNames.MODEL_CORRUTPION_COUNT.getName()).increment();
                    cache.get().removeEntityModel(detectorId, modelId);
                    entityColdStartWorker.put(new EntityFeatureRequest(nowPlusDetectorInterval(detector), detectorId, RequestPriority.MEDIUM, entity, datapoint, request.getStart()));
                }
            }

            // Left: entities forwarded to the checkpoint read queue; right: entities forwarded to
            // the cold entity queue.
            Pair<List<Entity>, List<Entity>> hotColdEntities = cache.get().selectUpdateCandidate(cacheMissEntities.keySet(), detectorId, detector);
            List<EntityFeatureRequest> hotEntityRequests = buildFeatureRequests(hotColdEntities.getLeft(), cacheMissEntities, detector, detectorId, RequestPriority.MEDIUM, request);
            List<EntityFeatureRequest> coldEntityRequests = buildFeatureRequests(hotColdEntities.getRight(), cacheMissEntities, detector, detectorId, RequestPriority.LOW, request);
            checkpointReadQueue.putAll(hotEntityRequests);
            coldEntityQueue.putAll(coldEntityRequests);

            if (prevException.isPresent()) {
                listener.onFailure(prevException.get());
            } else {
                listener.onResponse(new AcknowledgedResponse(true));
            }
        }, exception -> {
            LOG.error(new ParameterizedMessage("fail to get entity's anomaly grade for detector [{}]: start: [{}], end: [{}]", detectorId, request.getStart(), request.getEnd()), exception);
            listener.onFailure(exception);
        });
    }

    /**
     * Rewrites an entity received in the legacy form (see {@link #isEntityFromOldNodeMsg(Entity)})
     * into a single-attribute entity keyed by the detector's only category field. Entities in any
     * other form, or detectors without exactly one category field, are returned unchanged.
     */
    private Entity normalizeEntity(Entity entity, AnomalyDetector detector) {
        if (isEntityFromOldNodeMsg(entity) && detector.getCategoryField() != null && detector.getCategoryField().size() == 1) {
            Map<String, String> attrValues = entity.getAttributes();
            return Entity.createSingleAttributeEntity(detector.getCategoryField().get(0), attrValues.get(LEGACY_ATTRIBUTE_NAME));
        }
        return entity;
    }

    /**
     * Creates one {@link EntityFeatureRequest} per entity, using the feature values recorded for
     * it in {@code featuresByEntity}. Entities without recorded features are logged and skipped.
     */
    private static List<EntityFeatureRequest> buildFeatureRequests(List<Entity> entities, Map<Entity, double[]> featuresByEntity, AnomalyDetector detector, String detectorId, RequestPriority priority, EntityResultRequest request) {
        List<EntityFeatureRequest> requests = new ArrayList<>();
        for (Entity entity : entities) {
            double[] features = featuresByEntity.get(entity);
            if (features == null) {
                LOG.error(new ParameterizedMessage("feature value should not be null: [{}]", entity));
                continue;
            }
            requests.add(new EntityFeatureRequest(nowPlusDetectorInterval(detector), detectorId, priority, entity, features, request.getStart()));
        }
        return requests;
    }

    /**
     * Returns the current wall-clock time plus one detector interval, in epoch milliseconds.
     * NOTE(review): passed as the first constructor argument of every queued request -
     * presumably the request's expiry; confirm against the request classes.
     */
    private static long nowPlusDetectorInterval(AnomalyDetector detector) {
        return System.currentTimeMillis() + detector.getDetectorIntervalInMilliseconds();
    }

    /**
     * Tells whether the entity carries an attribute with an empty name, which this class treats
     * as the marker of a message produced by an older node.
     */
    private boolean isEntityFromOldNodeMsg(Entity categoricalValues) {
        Map<String, String> attrValues = categoricalValues.getAttributes();
        return attrValues != null && attrValues.containsKey(LEGACY_ATTRIBUTE_NAME);
    }
}

