/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.rest;

import com.google.common.collect.ImmutableList;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.ActionListener;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionType;
import org.opensearch.client.node.NodeClient;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.xcontent.ToXContent;
import org.opensearch.common.xcontent.XContentBuilder;
import org.opensearch.common.xcontent.XContentParser;
import org.opensearch.common.xcontent.XContentParserUtils;
import org.opensearch.ml.action.stats.MLStatsNodesAction;
import org.opensearch.ml.action.stats.MLStatsNodesRequest;
import org.opensearch.ml.stats.MLClusterLevelStat;
import org.opensearch.ml.stats.MLNodeLevelStat;
import org.opensearch.ml.stats.MLStat;
import org.opensearch.ml.stats.MLStatLevel;
import org.opensearch.ml.stats.MLStats;
import org.opensearch.ml.stats.MLStatsInput;
import org.opensearch.ml.utils.IndexUtils;
import org.opensearch.ml.utils.RestActionUtils;
import org.opensearch.rest.BaseRestHandler;
import org.opensearch.rest.BytesRestResponse;
import org.opensearch.rest.RestChannel;
import org.opensearch.rest.RestHandler;
import org.opensearch.rest.RestRequest;
import org.opensearch.rest.RestResponse;
import org.opensearch.rest.RestStatus;

public class RestMLStatsAction
extends BaseRestHandler {
    @Generated
    private static final Logger log = LogManager.getLogger(RestMLStatsAction.class);
    private static final String STATS_ML_ACTION = "stats_ml";
    private MLStats mlStats;
    private ClusterService clusterService;
    private IndexUtils indexUtils;

    public RestMLStatsAction(MLStats mlStats, ClusterService clusterService, IndexUtils indexUtils) {
        this.mlStats = mlStats;
        this.clusterService = clusterService;
        this.indexUtils = indexUtils;
    }

    public String getName() {
        return STATS_ML_ACTION;
    }

    public List<RestHandler.Route> routes() {
        return ImmutableList.of((Object)new RestHandler.Route(RestRequest.Method.GET, "/_plugins/_ml/{nodeId}/stats/"), (Object)new RestHandler.Route(RestRequest.Method.GET, "/_plugins/_ml/{nodeId}/stats/{stat}"), (Object)new RestHandler.Route(RestRequest.Method.GET, "/_plugins/_ml/stats/"), (Object)new RestHandler.Route(RestRequest.Method.GET, "/_plugins/_ml/stats/{stat}"));
    }

    protected BaseRestHandler.RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException {
        MLStatsInput mlStatsInput;
        boolean hasContent = request.hasContent();
        if (hasContent) {
            XContentParser parser = request.contentParser();
            XContentParserUtils.ensureExpectedToken((XContentParser.Token)XContentParser.Token.START_OBJECT, (XContentParser.Token)parser.nextToken(), (XContentParser)parser);
            mlStatsInput = MLStatsInput.parse(parser);
        } else {
            mlStatsInput = this.createMlStatsInputFromRequestParams(request);
        }
        String[] nodeIds = mlStatsInput.retrieveStatsOnAllNodes() ? this.getAllNodes() : mlStatsInput.getNodeIds().toArray(new String[0]);
        MLStatsNodesRequest mlStatsNodesRequest = new MLStatsNodesRequest(nodeIds, mlStatsInput);
        HashMap<MLClusterLevelStat, Object> clusterStatsMap = new HashMap<MLClusterLevelStat, Object>();
        if (mlStatsInput.getTargetStatLevels().contains((Object)MLStatLevel.CLUSTER)) {
            clusterStatsMap.putAll(this.getClusterStatsMap(mlStatsInput));
        }
        MLStatsInput finalMlStatsInput = mlStatsInput;
        return channel -> {
            if (finalMlStatsInput.getTargetStatLevels().contains((Object)MLStatLevel.CLUSTER) && (finalMlStatsInput.retrieveAllClusterLevelStats() || finalMlStatsInput.getClusterLevelStats().contains((Object)MLClusterLevelStat.ML_MODEL_COUNT))) {
                this.indexUtils.getNumberOfDocumentsInIndex(".plugins-ml-model", (ActionListener<Long>)ActionListener.wrap(count -> {
                    clusterStatsMap.put(MLClusterLevelStat.ML_MODEL_COUNT, count);
                    this.getNodeStats(finalMlStatsInput, (Map<MLClusterLevelStat, Object>)clusterStatsMap, client, mlStatsNodesRequest, (RestChannel)channel);
                }, e -> {
                    String errorMessage = "Failed to get ML model count";
                    log.error(errorMessage, (Throwable)e);
                    this.onFailure((RestChannel)channel, RestStatus.INTERNAL_SERVER_ERROR, errorMessage, (Exception)e);
                }));
            } else {
                this.getNodeStats(finalMlStatsInput, (Map<MLClusterLevelStat, Object>)clusterStatsMap, client, mlStatsNodesRequest, (RestChannel)channel);
            }
        };
    }

    MLStatsInput createMlStatsInputFromRequestParams(RestRequest request) {
        Optional<String[]> stats;
        MLStatsInput mlStatsInput = new MLStatsInput();
        Optional<String[]> nodeIds = RestActionUtils.splitCommaSeparatedParam(request, "nodeId");
        if (nodeIds.isPresent()) {
            mlStatsInput.getNodeIds().addAll(Arrays.asList(nodeIds.get()));
        }
        if ((stats = RestActionUtils.splitCommaSeparatedParam(request, "stat")).isPresent()) {
            for (String state : stats.get()) {
                if ((state = state.toUpperCase(Locale.ROOT)).startsWith("ML_NODE")) {
                    mlStatsInput.getNodeLevelStats().add(MLNodeLevelStat.from(state));
                    continue;
                }
                mlStatsInput.getClusterLevelStats().add(MLClusterLevelStat.from(state));
            }
            if (mlStatsInput.getClusterLevelStats().size() > 0) {
                mlStatsInput.getTargetStatLevels().add(MLStatLevel.CLUSTER);
            }
            if (mlStatsInput.getNodeLevelStats().size() > 0) {
                mlStatsInput.getTargetStatLevels().add(MLStatLevel.NODE);
            }
        } else {
            mlStatsInput.getTargetStatLevels().addAll(EnumSet.allOf(MLStatLevel.class));
        }
        return mlStatsInput;
    }

    void getNodeStats(MLStatsInput mlStatsInput, Map<MLClusterLevelStat, Object> clusterStatsMap, NodeClient client, MLStatsNodesRequest mlStatsNodesRequest, RestChannel channel) throws IOException {
        XContentBuilder builder = channel.newBuilder();
        if (mlStatsInput.onlyRetrieveClusterLevelStats()) {
            builder.startObject();
            if (clusterStatsMap != null && clusterStatsMap.size() > 0) {
                for (Map.Entry<MLClusterLevelStat, Object> entry : clusterStatsMap.entrySet()) {
                    builder.field(entry.getKey().name().toLowerCase(Locale.ROOT), entry.getValue());
                }
            }
            builder.endObject();
            channel.sendResponse((RestResponse)new BytesRestResponse(RestStatus.OK, builder));
        } else {
            client.execute((ActionType)MLStatsNodesAction.INSTANCE, (ActionRequest)mlStatsNodesRequest, ActionListener.wrap(r -> {
                List nodeStats;
                builder.startObject();
                if (clusterStatsMap != null && clusterStatsMap.size() > 0) {
                    for (Map.Entry entry : clusterStatsMap.entrySet()) {
                        builder.field(((MLClusterLevelStat)((Object)((Object)entry.getKey()))).name().toLowerCase(Locale.ROOT), entry.getValue());
                    }
                }
                if ((nodeStats = r.getNodes().stream().filter(s -> !s.isEmpty()).collect(Collectors.toList())) != null && nodeStats.size() > 0) {
                    r.toXContent(builder, ToXContent.EMPTY_PARAMS);
                }
                builder.endObject();
                channel.sendResponse((RestResponse)new BytesRestResponse(RestStatus.OK, builder));
            }, e -> {
                String errorMessage = "Failed to get ML node level stats";
                log.error(errorMessage, (Throwable)e);
                this.onFailure(channel, RestStatus.INTERNAL_SERVER_ERROR, errorMessage, (Exception)e);
            }));
        }
    }

    private String[] getAllNodes() {
        Iterator iterator = this.clusterService.state().nodes().iterator();
        ArrayList<String> nodeIds = new ArrayList<String>();
        while (iterator.hasNext()) {
            nodeIds.add(((DiscoveryNode)iterator.next()).getId());
        }
        return nodeIds.toArray(new String[0]);
    }

    private void onFailure(RestChannel channel, RestStatus status, String errorMessage, Exception exception) {
        BytesRestResponse bytesRestResponse;
        try {
            bytesRestResponse = new BytesRestResponse(channel, exception);
        }
        catch (Exception e) {
            bytesRestResponse = new BytesRestResponse(status, errorMessage);
        }
        channel.sendResponse((RestResponse)bytesRestResponse);
    }

    private Map<MLClusterLevelStat, Object> getClusterStatsMap(MLStatsInput mlStatsInput) {
        HashMap<MLClusterLevelStat, Object> clusterStats = new HashMap<MLClusterLevelStat, Object>();
        this.mlStats.getClusterStats().entrySet().stream().filter(s -> mlStatsInput.retrieveStat((Enum)s.getKey())).forEach(s -> clusterStats.put((MLClusterLevelStat)((Object)((Object)s.getKey())), ((MLStat)s.getValue()).getValue()));
        return clusterStats;
    }
}

