/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.engine.algorithms.remote;

import java.net.URL;
import java.security.AccessController;
import java.security.PrivilegedExceptionAction;
import java.time.Duration;
import java.util.Locale;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicBoolean;
import lombok.Generated;
import org.apache.commons.text.StringEscapeUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.common.collect.Tuple;
import org.opensearch.common.util.TokenBucket;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.connector.HttpConnector;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.common.httpclient.MLHttpClientFactory;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.model.MLGuard;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.engine.algorithms.remote.AbstractConnectorExecutor;
import org.opensearch.ml.engine.algorithms.remote.ConnectorUtils;
import org.opensearch.ml.engine.algorithms.remote.ExecutionContext;
import org.opensearch.ml.engine.algorithms.remote.MLSdkAsyncHttpResponseHandler;
import org.opensearch.ml.engine.algorithms.remote.streaming.StreamPredictActionListener;
import org.opensearch.ml.engine.algorithms.remote.streaming.StreamingHandler;
import org.opensearch.ml.engine.algorithms.remote.streaming.StreamingHandlerFactory;
import org.opensearch.ml.engine.annotation.ConnectorExecutor;
import org.opensearch.script.ScriptService;
import org.opensearch.transport.StreamTransportService;
import org.opensearch.transport.client.Client;
import software.amazon.awssdk.core.internal.http.async.SimpleHttpContentPublisher;
import software.amazon.awssdk.http.SdkHttpFullRequest;
import software.amazon.awssdk.http.SdkHttpMethod;
import software.amazon.awssdk.http.SdkHttpRequest;
import software.amazon.awssdk.http.async.AsyncExecuteRequest;
import software.amazon.awssdk.http.async.SdkAsyncHttpClient;
import software.amazon.awssdk.http.async.SdkAsyncHttpResponseHandler;
import software.amazon.awssdk.http.async.SdkHttpContentPublisher;

/**
 * Connector executor registered for the {@code "http"} connector protocol.
 *
 * <p>Requests are sent through an asynchronous AWS SDK HTTP client ({@link SdkAsyncHttpClient}). The client
 * is created once, in the constructor, from the connector's client configuration. Non-streaming responses
 * are handed to {@link MLSdkAsyncHttpResponseHandler}. Streaming requests are delegated to a
 * {@link StreamingHandler} selected by the {@code _llm_interface} request parameter.
 */
@ConnectorExecutor(value="http")
public class HttpJsonConnectorExecutor extends AbstractConnectorExecutor {
    @Generated
    private static final Logger log = LogManager.getLogger(HttpJsonConnectorExecutor.class);

    /** Request parameter that names the LLM interface used for streaming. */
    private static final String LLM_INTERFACE_PARAM = "_llm_interface";

    /** The only LLM interface currently accepted for streaming (compared after normalization). */
    private static final String OPENAI_CHAT_COMPLETIONS = "openai/v1/chat/completions";

    private HttpConnector connector;
    private ScriptService scriptService;
    // rateLimiter, userRateLimiterMap and client are only stored/exposed via accessors in this class;
    // presumably they are consumed by AbstractConnectorExecutor — verify there.
    private TokenBucket rateLimiter;
    private Map<String, TokenBucket> userRateLimiterMap;
    private Client client;
    private MLGuard mlGuard;
    // Passed to MLHttpClientFactory.validate when checking endpoint hosts.
    private volatile AtomicBoolean connectorPrivateIpEnabled;
    private SdkAsyncHttpClient httpClient;
    private StreamTransportService streamTransportService;

    /**
     * Creates an executor for the given connector and builds its asynchronous HTTP client.
     *
     * @param connector the connector to execute. It is cast to {@link HttpConnector}, so a connector of any
     *     other type fails here with a {@link ClassCastException}.
     */
    public HttpJsonConnectorExecutor(Connector connector) {
        super.initialize(connector);
        this.connector = (HttpConnector) connector;
        // The client-config timeout values are interpreted as seconds.
        Duration connectionTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getConnectionTimeout().intValue());
        Duration readTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getReadTimeout().intValue());
        int maxConnection = super.getConnectorClientConfig().getMaxConnections();
        this.httpClient = MLHttpClientFactory.getAsyncHttpClient(connectionTimeout, readTimeout, maxConnection);
    }

    @Override
    public Logger getLogger() {
        return log;
    }

    /**
     * Sends the request for {@code action} to the remote endpoint asynchronously.
     *
     * <p>Only the {@code POST} (payload sent as the body) and {@code GET} (no body) HTTP methods are
     * supported; the method name is matched case-insensitively. Before the request is built, the endpoint is
     * checked with {@link MLHttpClientFactory#validate}. The future returned by the HTTP client is not
     * awaited here. The result is delivered to {@code actionListener} by
     * {@link MLSdkAsyncHttpResponseHandler}.
     *
     * <p>Any failure while building or submitting the request is reported to {@code actionListener}.
     * {@link RuntimeException}s are passed through unchanged. Every other throwable is wrapped in an
     * {@link MLException}.
     */
    @Override
    public void invokeRemoteService(String action, MLInput mlInput, Map<String, String> parameters, String payload, ExecutionContext executionContext, ActionListener<Tuple<Integer, ModelTensors>> actionListener) {
        try {
            SdkHttpFullRequest request = this.buildHttpRequest(action, parameters, payload);
            AsyncExecuteRequest executeRequest = AsyncExecuteRequest.builder()
                .request(request)
                .requestContentPublisher(new SimpleHttpContentPublisher(request))
                .responseHandler(new MLSdkAsyncHttpResponseHandler(executionContext, actionListener, parameters, this.connector, this.scriptService, this.mlGuard, action))
                .build();
            // The explicit functional-interface type is required. A bare lambda matches both the
            // doPrivileged(PrivilegedAction) and doPrivileged(PrivilegedExceptionAction) overloads,
            // which makes the call ambiguous.
            AccessController.doPrivileged((PrivilegedExceptionAction<CompletableFuture<Void>>) () -> this.httpClient.execute(executeRequest));
        } catch (RuntimeException e) {
            log.error("Fail to execute http connector", e);
            actionListener.onFailure(e);
        } catch (Throwable e) {
            // Checked exceptions, e.g. from URL parsing or doPrivileged, are wrapped for the listener.
            log.error("Fail to execute http connector", e);
            actionListener.onFailure(new MLException("Fail to execute http connector", e));
        }
    }

    /**
     * Validates the action's endpoint and builds the SDK request for the action's configured HTTP method.
     *
     * @throws IllegalArgumentException if the action's HTTP method is neither {@code POST} nor {@code GET}
     * @throws Exception if the endpoint URL cannot be parsed or fails validation
     */
    private SdkHttpFullRequest buildHttpRequest(String action, Map<String, String> parameters, String payload) throws Exception {
        String httpMethod = this.connector.getActionHttpMethod(action).toUpperCase(Locale.ROOT);
        return switch (httpMethod) {
            case "POST" -> {
                // Parameterized form skips string building when debug logging is disabled.
                log.debug("original payload to remote model: {}", payload);
                this.validateHttpClientParameters(action, parameters);
                yield ConnectorUtils.buildSdkRequest(action, this.connector, parameters, payload, SdkHttpMethod.POST);
            }
            case "GET" -> {
                this.validateHttpClientParameters(action, parameters);
                yield ConnectorUtils.buildSdkRequest(action, this.connector, parameters, null, SdkHttpMethod.GET);
            }
            default -> throw new IllegalArgumentException("unsupported http method");
        };
    }

    /**
     * Starts a streaming prediction through the {@link StreamingHandler} for the requested LLM interface.
     *
     * <p>The interface is read from the {@code _llm_interface} parameter and normalized before validation.
     * Normalization trims the value, lower-cases it and unescapes Java escape sequences. Any failure,
     * including a missing or unsupported interface, is reported to {@code actionListener} wrapped in an
     * {@link MLException}.
     */
    @Override
    public void invokeRemoteServiceStream(String action, MLInput mlInput, Map<String, String> parameters, String payload, ExecutionContext executionContext, StreamPredictActionListener<MLTaskResponse, ?> actionListener) {
        try {
            String llmInterface = resolveLlmInterface(parameters);
            this.validateLLMInterface(llmInterface);
            StreamingHandler handler = StreamingHandlerFactory.createHandler(llmInterface, this.connector, null, super.getConnectorClientConfig());
            handler.startStream(action, parameters, payload, actionListener);
        } catch (Exception e) {
            log.error("Failed to execute streaming", e);
            actionListener.onFailure(new MLException("Fail to execute streaming", e));
        }
    }

    /**
     * Reads and normalizes the {@code _llm_interface} parameter.
     *
     * @return the trimmed, lower-cased value with Java escape sequences unescaped; never {@code null}
     * @throws IllegalArgumentException if the parameter is absent
     */
    private static String resolveLlmInterface(Map<String, String> parameters) {
        String rawInterface = parameters.get(LLM_INTERFACE_PARAM);
        if (rawInterface == null) {
            // Fail with a clear message instead of an NPE from trim().
            throw new IllegalArgumentException("Missing required parameter: " + LLM_INTERFACE_PARAM);
        }
        return StringEscapeUtils.unescapeJava(rawInterface.trim().toLowerCase(Locale.ROOT));
    }

    /**
     * Checks that the action's resolved endpoint passes {@link MLHttpClientFactory#validate}. The check uses
     * the endpoint's protocol, host and port (the port is -1 when the URL has none) together with the
     * private-IP setting.
     *
     * @throws Exception if the endpoint is not a well-formed URL or fails validation
     */
    private void validateHttpClientParameters(String action, Map<String, String> parameters) throws Exception {
        String endpoint = this.connector.getActionEndpoint(action, parameters);
        URL url = new URL(endpoint);
        MLHttpClientFactory.validate(url.getProtocol(), url.getHost(), url.getPort(), this.connectorPrivateIpEnabled);
    }

    /**
     * Rejects any normalized LLM interface other than {@code openai/v1/chat/completions}.
     *
     * @throws IllegalArgumentException if the interface is not supported
     */
    private void validateLLMInterface(String llmInterface) {
        if (!OPENAI_CHAT_COMPLETIONS.equals(llmInterface)) {
            throw new IllegalArgumentException(String.format(Locale.ROOT, "Unsupported llm interface: %s", llmInterface));
        }
    }

    /** @return the HTTP connector this executor was created with */
    @Generated
    public HttpConnector getConnector() {
        return this.connector;
    }

    @Override
    @Generated
    public void setScriptService(ScriptService scriptService) {
        this.scriptService = scriptService;
    }

    @Override
    @Generated
    public ScriptService getScriptService() {
        return this.scriptService;
    }

    @Override
    @Generated
    public void setRateLimiter(TokenBucket rateLimiter) {
        this.rateLimiter = rateLimiter;
    }

    @Override
    @Generated
    public TokenBucket getRateLimiter() {
        return this.rateLimiter;
    }

    @Override
    @Generated
    public void setUserRateLimiterMap(Map<String, TokenBucket> userRateLimiterMap) {
        this.userRateLimiterMap = userRateLimiterMap;
    }

    @Override
    @Generated
    public Map<String, TokenBucket> getUserRateLimiterMap() {
        return this.userRateLimiterMap;
    }

    @Override
    @Generated
    public void setClient(Client client) {
        this.client = client;
    }

    @Override
    @Generated
    public Client getClient() {
        return this.client;
    }

    @Override
    @Generated
    public void setMlGuard(MLGuard mlGuard) {
        this.mlGuard = mlGuard;
    }

    @Override
    @Generated
    public MLGuard getMlGuard() {
        return this.mlGuard;
    }

    @Override
    @Generated
    public void setConnectorPrivateIpEnabled(AtomicBoolean connectorPrivateIpEnabled) {
        this.connectorPrivateIpEnabled = connectorPrivateIpEnabled;
    }

    @Generated
    public void setStreamTransportService(StreamTransportService streamTransportService) {
        this.streamTransportService = streamTransportService;
    }

    @Generated
    public StreamTransportService getStreamTransportService() {
        return this.streamTransportService;
    }
}

