/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.model;

import java.time.Duration;
import java.time.Instant;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
import lombok.Generated;
import org.apache.commons.lang3.BooleanUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.TokenBucket;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.exception.MLLimitExceededException;
import org.opensearch.ml.common.model.MLGuard;
import org.opensearch.ml.common.model.MLModelFormat;
import org.opensearch.ml.common.model.MLModelState;
import org.opensearch.ml.engine.MLExecutable;
import org.opensearch.ml.engine.Predictable;
import org.opensearch.ml.model.MLModelCache;
import org.opensearch.ml.profile.MLModelProfile;
import org.opensearch.ml.settings.MLCommonsSettings;

/**
 * Node-local registry of per-model runtime state, keyed by model id.
 *
 * <p>Each entry is an {@link MLModelCache} holding, among other things, the model's deploy state,
 * function name, predictor/executor, worker and target worker nodes, rate limiters, guard,
 * interface, memory size estimates, inference statistics and cached {@link MLModel} metadata. A
 * separate map tracks models registered for auto deployment.
 *
 * <p>Thread-safety: both maps are {@link ConcurrentHashMap}s, so individual map reads and writes
 * are safe from any thread. Many of the mutators are additionally {@code synchronized} on this
 * helper; plain getters are not.
 */
public class MLModelCacheHelper {
    @Generated
    private static final Logger log = LogManager.getLogger(MLModelCacheHelper.class);

    /** Runtime state of every model known to this node, keyed by model id. */
    private final Map<String, MLModelCache> modelCaches = new ConcurrentHashMap<>();

    /** Models registered through {@link #addModelToAutoDeployCache}, keyed by model id. */
    private final Map<String, MLModel> autoDeployModels = new ConcurrentHashMap<>();

    /**
     * Current value of {@code ML_COMMONS_MONITORING_REQUEST_COUNT}, forwarded to each cache when
     * recording request durations. Kept current by the settings consumer registered in the
     * constructor.
     */
    private volatile Long maxRequestCount;

    /**
     * Creates the helper and subscribes to dynamic updates of
     * {@code ML_COMMONS_MONITORING_REQUEST_COUNT}.
     *
     * @param clusterService used to register the settings update consumer
     * @param settings node settings providing the initial monitoring request count
     */
    public MLModelCacheHelper(ClusterService clusterService, Settings settings) {
        this.maxRequestCount = MLCommonsSettings.ML_COMMONS_MONITORING_REQUEST_COUNT.get(settings);
        clusterService
            .getClusterSettings()
            .addSettingsUpdateConsumer(MLCommonsSettings.ML_COMMONS_MONITORING_REQUEST_COUNT, it -> this.maxRequestCount = it);
    }

    /**
     * Registers a fresh cache entry for a model that is being deployed, replacing any existing entry.
     *
     * @throws MLLimitExceededException if the model already has state on this node and is not
     *     currently being auto deployed
     */
    public synchronized void initModelState(
        String modelId,
        MLModelState state,
        FunctionName functionName,
        List<String> targetWorkerNodes,
        boolean deployToAllNodes
    ) {
        if (this.isModelRunningOnNode(modelId) && !this.isAutoDeploying(modelId)) {
            throw new MLLimitExceededException("Duplicate deploy model task");
        }
        log.debug("init model state for model {}, state: {}", modelId, state);
        this.modelCaches.put(modelId, newModelCache(state, functionName, targetWorkerNodes, deployToAllNodes));
    }

    /**
     * Registers a cache entry for a model being auto deployed on this node. Does nothing if the model
     * already has state here. The new entry is marked as auto deploying and not deploy-to-all-nodes.
     */
    public synchronized void initModelStateAutoDeploy(
        String modelId,
        MLModelState state,
        FunctionName functionName,
        List<String> targetWorkerNodes
    ) {
        log.debug("init local model deployment state for model {}, state: {}", modelId, state);
        if (this.isModelRunningOnNode(modelId)) {
            return;
        }
        MLModelCache modelCache = newModelCache(state, functionName, targetWorkerNodes, false);
        // Set the flag before publishing so unsynchronized readers never see the entry without it.
        modelCache.setIsAutoDeploying(true);
        this.modelCaches.put(modelId, modelCache);
    }

    /** Builds a new cache entry with the given deploy metadata and the current time as last access. */
    private static MLModelCache newModelCache(
        MLModelState state,
        FunctionName functionName,
        List<String> targetWorkerNodes,
        boolean deployToAllNodes
    ) {
        MLModelCache modelCache = new MLModelCache();
        modelCache.setModelState(state);
        modelCache.setFunctionName(functionName);
        modelCache.setTargetWorkerNodes(targetWorkerNodes);
        modelCache.setDeployToAllNodes(deployToAllNodes);
        modelCache.setLastAccessTime(Instant.now());
        return modelCache;
    }

    /**
     * Updates the deploy state of a cached model.
     *
     * @throws IllegalArgumentException if the model is not cached
     */
    public synchronized void setModelState(String modelId, MLModelState state) {
        log.debug("Updating State of Model {}  to state {}", modelId, state);
        this.getExistingModelCache(modelId).setModelState(state);
    }

    /**
     * Sets the model-level rate limiter.
     *
     * @throws IllegalArgumentException if the model is not cached
     */
    public synchronized void setRateLimiter(String modelId, TokenBucket rateLimiter) {
        log.debug("Setting the rate limiter for Model {}", modelId);
        this.getExistingModelCache(modelId).setRateLimiter(rateLimiter);
    }

    /** @return the model-level rate limiter, or {@code null} if the model is not cached or has none */
    public TokenBucket getRateLimiter(String modelId) {
        MLModelCache modelCache = this.modelCaches.get(modelId);
        return modelCache == null ? null : modelCache.getRateLimiter();
    }

    /**
     * Clears the model-level rate limiter.
     *
     * @throws IllegalArgumentException if the model is not cached
     */
    public synchronized void removeRateLimiter(String modelId) {
        log.debug("Removing the rate limiter for Model {}", modelId);
        this.getExistingModelCache(modelId).setRateLimiter(null);
    }

    /**
     * Sets the per-user rate limiters.
     *
     * @throws IllegalArgumentException if the model is not cached
     */
    public synchronized void setUserRateLimiterMap(String modelId, Map<String, TokenBucket> userRateLimiterMap) {
        log.debug("Setting the user level rate limiter for Model {}", modelId);
        this.getExistingModelCache(modelId).setUserRateLimiterMap(userRateLimiterMap);
    }

    /**
     * Clears the per-user rate limiters.
     *
     * @throws IllegalArgumentException if the model is not cached
     */
    public synchronized void removeUserRateLimiterMap(String modelId) {
        log.debug("Removing the user level rate limiter for Model {}", modelId);
        this.getExistingModelCache(modelId).setUserRateLimiterMap(null);
    }

    /** @return the per-user rate limiter map, or {@code null} if the model is not cached or has none */
    public Map<String, TokenBucket> getUserRateLimiterMap(String modelId) {
        MLModelCache modelCache = this.modelCaches.get(modelId);
        return modelCache == null ? null : modelCache.getUserRateLimiterMap();
    }

    /** @return the rate limiter for {@code user}, or {@code null} if none is registered */
    public TokenBucket getUserRateLimiter(String modelId, String user) {
        Map<String, TokenBucket> userRateLimiterMap = this.getUserRateLimiterMap(modelId);
        return userRateLimiterMap == null ? null : userRateLimiterMap.get(user);
    }

    /**
     * Sets the model interface.
     *
     * @throws IllegalArgumentException if the model is not cached
     */
    public synchronized void setModelInterface(String modelId, Map<String, String> modelInterface) {
        log.debug("Setting ML Interface {} for Model {}", modelInterface, modelId);
        this.getExistingModelCache(modelId).setModelInterface(modelInterface);
    }

    /** @return the model interface, or {@code null} if the model is not cached or has none */
    public Map<String, String> getModelInterface(String modelId) {
        MLModelCache modelCache = this.modelCaches.get(modelId);
        return modelCache == null ? null : modelCache.getModelInterface();
    }

    /**
     * Clears the model interface.
     *
     * @throws IllegalArgumentException if the model is not cached
     */
    public synchronized void removeModelInterface(String modelId) {
        log.debug("Removing the ML Interface from Model {}", modelId);
        this.getExistingModelCache(modelId).setModelInterface(null);
    }

    /**
     * Sets the model guard.
     *
     * @throws IllegalArgumentException if the model is not cached
     */
    public synchronized void setMLGuard(String modelId, MLGuard mlGuard) {
        log.debug("Setting ML guard {} for Model {}", mlGuard, modelId);
        this.getExistingModelCache(modelId).setMlGuard(mlGuard);
    }

    /** @return the model guard, or {@code null} if the model is not cached or has none */
    public MLGuard getMLGuard(String modelId) {
        MLModelCache modelCache = this.modelCaches.get(modelId);
        return modelCache == null ? null : modelCache.getMlGuard();
    }

    /**
     * Clears the model guard.
     *
     * @throws IllegalArgumentException if the model is not cached
     */
    public synchronized void removeMLGuard(String modelId) {
        log.debug("Removing the ML guard from Model {}", modelId);
        this.getExistingModelCache(modelId).setMlGuard(null);
    }

    /**
     * Sets the model's enabled flag.
     *
     * @throws IllegalArgumentException if the model is not cached
     */
    public synchronized void setIsModelEnabled(String modelId, Boolean isModelEnabled) {
        log.debug("Setting the quota flag for Model {}", modelId);
        this.getExistingModelCache(modelId).setIsModelEnabled(isModelEnabled);
    }

    /** @return the model's enabled flag, or {@code null} if the model is not cached or unset */
    public Boolean getIsModelEnabled(String modelId) {
        MLModelCache modelCache = this.modelCaches.get(modelId);
        return modelCache == null ? null : modelCache.getIsModelEnabled();
    }

    /**
     * Sets the auto-deploying flag.
     *
     * @throws IllegalArgumentException if the model is not cached
     */
    public synchronized void setIsAutoDeploying(String modelId, Boolean isModelAutoDeploying) {
        log.debug("Setting the auto deploying flag for Model {}", modelId);
        this.getExistingModelCache(modelId).setIsAutoDeploying(isModelAutoDeploying);
    }

    /** @return {@code true} only if the model is cached and its auto-deploying flag is {@code TRUE} */
    public boolean isAutoDeploying(String modelId) {
        MLModelCache modelCache = this.modelCaches.get(modelId);
        return modelCache != null && BooleanUtils.isTrue(modelCache.getIsAutoDeploying());
    }

    /**
     * Stores the same scaled memory estimate as both the CPU and GPU estimate of a model.
     *
     * @param format model format, selects the scaling factor
     * @param size raw model size
     * @throws IllegalArgumentException if the model is not cached
     */
    public synchronized void setMemSizeEstimation(String modelId, MLModelFormat format, Long size) {
        Long memSize = this.getMemSizeEstimation(format, size);
        log.debug("Updating memSizeEstimation of Model {}  to {}", modelId, memSize);
        MLModelCache modelCache = this.getExistingModelCache(modelId);
        modelCache.setMemSizeEstimationCPU(memSize);
        modelCache.setMemSizeEstimationGPU(memSize);
    }

    /**
     * Scales a raw model size by a format-specific factor and truncates toward zero. The factor is
     * 1.5 for ONNX, 1.2 for TorchScript and 1.0 for any other format.
     */
    private Long getMemSizeEstimation(MLModelFormat format, Long size) {
        double scale;
        switch (format) {
            case ONNX:
                scale = 1.5;
                break;
            case TORCH_SCRIPT:
                scale = 1.2;
                break;
            default:
                scale = 1.0;
                break;
        }
        return (long) (scale * size);
    }

    /** @return the CPU memory estimate, or {@code null} if the model is not cached or unset */
    public Long getMemEstCPU(String modelId) {
        MLModelCache modelCache = this.modelCaches.get(modelId);
        return modelCache == null ? null : modelCache.getMemSizeEstimationCPU();
    }

    /** @return the GPU memory estimate, or {@code null} if the model is not cached or unset */
    public Long getMemEstGPU(String modelId) {
        MLModelCache modelCache = this.modelCaches.get(modelId);
        return modelCache == null ? null : modelCache.getMemSizeEstimationGPU();
    }

    /** @return {@code true} if the model is cached with state {@code DEPLOYED} */
    public synchronized boolean isModelDeployed(String modelId) {
        MLModelCache modelCache = this.modelCaches.get(modelId);
        return modelCache != null && modelCache.getModelState() == MLModelState.DEPLOYED;
    }

    /** @return ids of all cached models whose state is {@code DEPLOYED} */
    public String[] getDeployedModels() {
        return this.modelCaches
            .entrySet()
            .stream()
            .filter(entry -> entry.getValue().getModelState() == MLModelState.DEPLOYED)
            .map(Map.Entry::getKey)
            .toArray(String[]::new);
    }

    /** @return ids of cached models whose state is {@code DEPLOYED} and whose function is not {@code REMOTE} */
    public String[] getLocalDeployedModels() {
        return this.modelCaches
            .entrySet()
            .stream()
            .filter(entry -> {
                MLModelCache modelCache = entry.getValue();
                return modelCache.getModelState() == MLModelState.DEPLOYED && modelCache.getFunctionName() != FunctionName.REMOTE;
            })
            .map(Map.Entry::getKey)
            .toArray(String[]::new);
    }

    /**
     * Returns the ids of models whose TTL has elapsed since their last access.
     *
     * <p>A model is reported only if all of the following hold:
     * <ul>
     *   <li>its state is {@code DEPLOYED} or {@code PARTIALLY_DEPLOYED};</li>
     *   <li>its cached model info has a deploy setting with a non-negative TTL;</li>
     *   <li>it has a recorded last access time.</li>
     * </ul>
     * Entries lacking any of these are skipped rather than failing the whole scan.
     */
    public String[] getExpiredModels() {
        // One timestamp for the whole scan so every entry is judged against the same "now".
        Instant now = Instant.now();
        return this.modelCaches
            .entrySet()
            .stream()
            .filter(entry -> isExpired(entry.getValue(), now))
            .map(Map.Entry::getKey)
            .toArray(String[]::new);
    }

    /** Decides whether a single cache entry has outlived its deploy-setting TTL as of {@code now}. */
    private static boolean isExpired(MLModelCache modelCache, Instant now) {
        MLModelState modelState = modelCache.getModelState();
        if (modelState != MLModelState.DEPLOYED && modelState != MLModelState.PARTIALLY_DEPLOYED) {
            return false;
        }
        MLModel mlModel = modelCache.getCachedModelInfo();
        if (mlModel == null || mlModel.getDeploySetting() == null) {
            return false;
        }
        Long ttlInMinutes = mlModel.getDeploySetting().getModelTTLInMinutes();
        // A negative TTL disables expiry; a missing TTL is treated the same way instead of NPE-ing.
        if (ttlInMinutes == null || ttlInMinutes < 0L) {
            return false;
        }
        // Entries created via getOrCreateModelCache never get a last access time.
        Instant lastAccessTime = modelCache.getLastAccessTime();
        if (lastAccessTime == null) {
            return false;
        }
        Duration liveDuration = Duration.between(lastAccessTime, now);
        return liveDuration.getSeconds() >= Duration.ofMinutes(ttlInMinutes).getSeconds();
    }

    /** @return {@code true} if the model is cached with a non-null state */
    public boolean isModelRunningOnNode(String modelId) {
        MLModelCache modelCache = this.modelCaches.get(modelId);
        return modelCache != null && modelCache.getModelState() != null;
    }

    /**
     * Sets the predictor of a cached model.
     *
     * @throws IllegalArgumentException if the model is not cached
     */
    public synchronized void setPredictor(String modelId, Predictable predictor) {
        this.getExistingModelCache(modelId).setPredictor(predictor);
    }

    /**
     * Sets the executor of a cached model.
     *
     * @throws IllegalArgumentException if the model is not cached
     */
    public synchronized void setMLExecutor(String modelId, MLExecutable mlExecutor) {
        this.getExistingModelCache(modelId).setExecutor(mlExecutor);
    }

    /** @return the executor, or {@code null} if the model is not cached or has none */
    public MLExecutable getMLExecutor(String modelId) {
        MLModelCache modelCache = this.modelCaches.get(modelId);
        return modelCache == null ? null : modelCache.getExecutor();
    }

    /** @return the predictor, or {@code null} if the model is not cached or has none */
    public Predictable getPredictor(String modelId) {
        MLModelCache modelCache = this.modelCaches.get(modelId);
        return modelCache == null ? null : modelCache.getPredictor();
    }

    /** Sets the target worker nodes of a cached model; a no-op if the model is not cached. */
    public void setTargetWorkerNodes(String modelId, List<String> targetWorkerNodes) {
        MLModelCache modelCache = this.modelCaches.get(modelId);
        if (modelCache != null) {
            modelCache.setTargetWorkerNodes(targetWorkerNodes);
        }
    }

    /**
     * Records the current time as the model's last access. A no-op if the model is not cached, for
     * example because it was undeployed concurrently.
     */
    public void refreshLastAccessTime(String modelId) {
        MLModelCache modelCache = this.modelCaches.get(modelId);
        if (modelCache != null) {
            modelCache.setLastAccessTime(Instant.now());
        }
    }

    /**
     * Removes the model's cache entry, clears it, and drops the model from the auto deploy cache.
     */
    public void removeModel(String modelId) {
        // Remove first (atomically) so no other thread can obtain an entry that is being cleared.
        MLModelCache modelCache = this.modelCaches.remove(modelId);
        if (modelCache != null) {
            log.debug("removing model {} from cache", modelId);
            modelCache.clear();
        }
        this.autoDeployModels.remove(modelId);
    }

    /** @return ids of every model with a cache entry on this node */
    public String[] getAllModels() {
        return this.modelCaches.keySet().toArray(new String[0]);
    }

    /** @return the model's worker nodes, or {@code null} if the model is not cached */
    public String[] getWorkerNodes(String modelId) {
        MLModelCache modelCache = this.modelCaches.get(modelId);
        return modelCache == null ? null : modelCache.getWorkerNodes();
    }

    /** Adds a worker node to the model's routing table, creating the cache entry if needed. */
    public synchronized void addWorkerNode(String modelId, String nodeId) {
        log.debug("add node {} to model routing table for model: {}", nodeId, modelId);
        this.getOrCreateModelCache(modelId).addWorkerNode(nodeId);
    }

    /**
     * Removes the given nodes from every model's worker nodes. Entries that are no longer valid
     * afterwards are evicted.
     */
    public void removeWorkerNodes(Set<String> removedNodes, boolean isFromUndeploy) {
        // Iterate entries so the value is read together with its key; re-reading via get() after
        // a concurrent removal would yield null.
        for (Map.Entry<String, MLModelCache> entry : this.modelCaches.entrySet()) {
            String modelId = entry.getKey();
            MLModelCache modelCache = entry.getValue();
            log.debug("remove worker nodes of model {} : {}", modelId, removedNodes);
            modelCache.removeWorkerNodes(removedNodes, isFromUndeploy);
            if (!modelCache.isValidCache()) {
                log.debug("remove model cache {}", modelId);
                this.modelCaches.remove(modelId);
            }
        }
    }

    /**
     * Removes one worker node from a model. The entry is evicted if it is no longer valid
     * afterwards. A no-op if the model is not cached.
     */
    public void removeWorkerNode(String modelId, String nodeId, boolean isFromUndeploy) {
        MLModelCache modelCache = this.modelCaches.get(modelId);
        if (modelCache == null) {
            return;
        }
        log.debug("remove worker node {} of model {} from cache", nodeId, modelId);
        modelCache.removeWorkerNode(nodeId, isFromUndeploy);
        if (!modelCache.isValidCache()) {
            log.debug("remove model {} from cache as no node running it", modelId);
            this.modelCaches.remove(modelId);
        }
    }

    /**
     * Synchronizes worker nodes with the given snapshot:
     * <ul>
     *   <li>cached models absent from the snapshot have their worker nodes cleared;</li>
     *   <li>every model in the snapshot gets a cache entry, created if needed, synced to the listed nodes.</li>
     * </ul>
     */
    public void syncWorkerNodes(Map<String, Set<String>> modelWorkerNodes) {
        log.debug("sync model worker nodes");
        Set<String> staleModels = new HashSet<>(this.modelCaches.keySet());
        staleModels.removeAll(modelWorkerNodes.keySet());
        staleModels.forEach(this::clearWorkerNodes);
        modelWorkerNodes.forEach((modelId, workerNodes) -> this.getOrCreateModelCache(modelId).syncWorkerNode(workerNodes));
    }

    /** Clears the worker nodes of every cached model, evicting entries that become invalid. */
    public void clearWorkerNodes() {
        log.debug("clear all model worker nodes");
        this.modelCaches.keySet().forEach(this::clearWorkerNodes);
    }

    /**
     * Clears a model's worker nodes and evicts the entry if it is no longer valid. A no-op if the
     * model is not cached.
     */
    public void clearWorkerNodes(String modelId) {
        MLModelCache modelCache = this.modelCaches.get(modelId);
        if (modelCache == null) {
            return;
        }
        log.debug("clear worker nodes of model {}", modelId);
        modelCache.clearWorkerNodes();
        if (!modelCache.isValidCache()) {
            this.modelCaches.remove(modelId);
        }
    }

    /**
     * Builds a profile snapshot of a cached model.
     *
     * @return the profile, or {@code null} if the model is not cached
     */
    public MLModelProfile getModelProfile(String modelId) {
        MLModelCache modelCache = this.modelCaches.get(modelId);
        if (modelCache == null) {
            return null;
        }
        MLModelProfile.MLModelProfileBuilder builder = MLModelProfile.builder();
        builder.modelState(modelCache.getModelState());
        Predictable predictor = modelCache.getPredictor();
        if (predictor != null) {
            builder.predictor(predictor.toString());
        }
        String[] targetWorkerNodes = modelCache.getTargetWorkerNodes();
        if (targetWorkerNodes.length > 0) {
            builder.targetWorkerNodes(targetWorkerNodes);
        }
        String[] workerNodes = modelCache.getWorkerNodes();
        if (workerNodes.length > 0) {
            builder.workerNodes(workerNodes);
        }
        builder.modelInferenceStats(modelCache.getInferenceStats(true));
        builder.predictRequestStats(modelCache.getInferenceStats(false));
        builder.memSizeEstimationCPU(modelCache.getMemSizeEstimationCPU());
        builder.memSizeEstimationGPU(modelCache.getMemSizeEstimationGPU());
        return builder.build();
    }

    /** Records a model inference duration, creating the cache entry if needed. */
    public void addModelInferenceDuration(String modelId, double duration) {
        this.getOrCreateModelCache(modelId).addModelInferenceDuration(duration, this.maxRequestCount);
    }

    /** Records a predict request duration, creating the cache entry if needed. */
    public void addPredictRequestDuration(String modelId, double duration) {
        this.getOrCreateModelCache(modelId).addPredictRequestDuration(duration, this.maxRequestCount);
    }

    /** Resizes the monitoring queue of every cached model. */
    public void resizeMonitoringQueue(long monitoringReqCount) {
        this.modelCaches.values().forEach(modelCache -> modelCache.resizeMonitoringQueue(monitoringReqCount));
    }

    /**
     * Returns the function name of a cached model.
     *
     * @throws IllegalArgumentException if the model is not cached
     */
    public FunctionName getFunctionName(String modelId) {
        return this.getExistingModelCache(modelId).getFunctionName();
    }

    /** @return the function name, or empty if the model is not cached or has none */
    public Optional<FunctionName> getOptionalFunctionName(String modelId) {
        return Optional.ofNullable(this.modelCaches.get(modelId)).map(MLModelCache::getFunctionName);
    }

    /** Sets the deploy-to-all-nodes flag of a cached model; a no-op if the model is not cached. */
    public void setDeployToAllNodes(String modelId, Boolean deployToAllNodes) {
        MLModelCache modelCache = this.modelCaches.get(modelId);
        if (modelCache != null) {
            log.info("Starting to set deployToAllNodes flag to modelId: {}, value to: {}", modelId, deployToAllNodes);
            modelCache.setDeployToAllNodes(deployToAllNodes);
        }
    }

    /**
     * Returns the deploy-to-all-nodes flag of a cached model.
     *
     * @throws IllegalArgumentException if the model is not cached
     */
    public boolean getDeployToAllNodes(String modelId) {
        return this.getExistingModelCache(modelId).isDeployToAllNodes();
    }

    /** Stores model metadata on a cached model; a no-op if the model is not cached. */
    public void setModelInfo(String modelId, MLModel mlModel) {
        MLModelCache modelCache = this.modelCaches.get(modelId);
        if (modelCache != null) {
            modelCache.setModelInfo(mlModel);
        }
    }

    /** @return the cached model metadata, or {@code null} if the model is not cached or has none */
    public MLModel getModelInfo(String modelId) {
        MLModelCache modelCache = this.modelCaches.get(modelId);
        return modelCache == null ? null : modelCache.getCachedModelInfo();
    }

    /** Looks up a cache entry that callers require to exist. */
    private MLModelCache getExistingModelCache(String modelId) {
        MLModelCache modelCache = this.modelCaches.get(modelId);
        if (modelCache == null) {
            throw new IllegalArgumentException("Model not found in cache");
        }
        return modelCache;
    }

    /**
     * Returns the cache entry for a model, atomically creating an empty one if absent. The new entry
     * has no state or last access time set.
     */
    private MLModelCache getOrCreateModelCache(String modelId) {
        return this.modelCaches.computeIfAbsent(modelId, it -> new MLModelCache());
    }

    /**
     * Registers a model for auto deployment unless one is already registered under the same id.
     *
     * @return the registered model: {@code model} if it was added, otherwise the existing one
     */
    public MLModel addModelToAutoDeployCache(String modelId, MLModel model) {
        MLModel addedModel = this.autoDeployModels.computeIfAbsent(modelId, key -> model);
        if (addedModel == model) {
            log.info("Add model {} to auto deploy cache", modelId);
        }
        return addedModel;
    }

    /** Removes a model from the auto deploy cache, if present. */
    public void removeAutoDeployModel(String modelId) {
        MLModel removedModel = this.autoDeployModels.remove(modelId);
        if (removedModel != null) {
            log.info("Remove model {} from auto deploy cache", modelId);
        }
    }
}

