/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.action.prediction;

import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.ActionListener;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.common.inject.Inject;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
import org.opensearch.ml.model.MLModelCacheHelper;
import org.opensearch.ml.task.MLPredictTaskRunner;
import org.opensearch.ml.task.MLTaskRunner;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

public class TransportPredictionTaskAction
extends HandledTransportAction<ActionRequest, MLTaskResponse> {
    @Generated
    private static final Logger log = LogManager.getLogger(TransportPredictionTaskAction.class);
    private final MLTaskRunner<MLPredictionTaskRequest, MLTaskResponse> mlPredictTaskRunner;
    private final TransportService transportService;
    private final MLModelCacheHelper modelCacheHelper;

    @Inject
    public TransportPredictionTaskAction(TransportService transportService, ActionFilters actionFilters, MLPredictTaskRunner mlPredictTaskRunner, MLModelCacheHelper modelCacheHelper) {
        super("cluster:admin/opensearch/ml/predict", transportService, actionFilters, MLPredictionTaskRequest::new);
        this.mlPredictTaskRunner = mlPredictTaskRunner;
        this.transportService = transportService;
        this.modelCacheHelper = modelCacheHelper;
    }

    protected void doExecute(Task task, ActionRequest request, ActionListener<MLTaskResponse> listener) {
        MLPredictionTaskRequest mlPredictionTaskRequest = MLPredictionTaskRequest.fromActionRequest((ActionRequest)request);
        String modelId = mlPredictionTaskRequest.getModelId();
        String requestId = mlPredictionTaskRequest.getRequestID();
        log.debug("receive predict request " + requestId + " for model " + mlPredictionTaskRequest.getModelId());
        long startTime = System.nanoTime();
        this.mlPredictTaskRunner.run(mlPredictionTaskRequest, this.transportService, (ActionListener<MLTaskResponse>)ActionListener.runAfter(listener, () -> {
            long endTime = System.nanoTime();
            double durationInMs = (double)(endTime - startTime) / 1000000.0;
            this.modelCacheHelper.addPredictRequestDuration(modelId, durationInMs);
            log.debug("completed predict request " + requestId + " for model " + modelId);
        }));
    }
}

