/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.action.mcpserver;

import java.io.IOException;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.OpenSearchException;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionType;
import org.opensearch.action.bulk.BulkItemResponse;
import org.opensearch.action.bulk.BulkRequest;
import org.opensearch.action.bulk.BulkResponse;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.action.support.WriteRequest;
import org.opensearch.action.update.UpdateRequest;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.xcontent.LoggingDeprecationHandler;
import org.opensearch.common.xcontent.json.JsonXContent;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.xcontent.DeprecationHandler;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.core.xcontent.XContentParserUtils;
import org.opensearch.ml.action.mcpserver.McpToolsHelper;
import org.opensearch.ml.cluster.DiscoveryNodeHelper;
import org.opensearch.ml.common.MLIndex;
import org.opensearch.ml.common.settings.MLCommonsSettings;
import org.opensearch.ml.common.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.common.transport.mcpserver.action.MLMcpToolsUpdateOnNodesAction;
import org.opensearch.ml.common.transport.mcpserver.requests.McpToolBaseInput;
import org.opensearch.ml.common.transport.mcpserver.requests.register.McpToolRegisterInput;
import org.opensearch.ml.common.transport.mcpserver.requests.update.MLMcpToolsUpdateNodesRequest;
import org.opensearch.ml.common.transport.mcpserver.requests.update.McpToolUpdateInput;
import org.opensearch.ml.common.transport.mcpserver.responses.update.MLMcpToolsUpdateNodesResponse;
import org.opensearch.tasks.Task;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;
import org.opensearch.transport.client.Client;

/**
 * Transport action that updates MCP tools previously registered in the MCP tools system index.
 *
 * <p>Flow:
 * <ol>
 *   <li>Look up the requested tools together with their {@code seqNo}/{@code primaryTerm}.</li>
 *   <li>Write the changed fields back to the index in one bulk request guarded by those values.</li>
 *   <li>Send the merged tool definitions to the nodes through {@link MLMcpToolsUpdateOnNodesAction}, so the
 *       MCP server memory is updated as well.</li>
 * </ol>
 */
public class TransportMcpToolsUpdateAction extends HandledTransportAction<ActionRequest, MLMcpToolsUpdateNodesResponse> {
    @Generated
    private static final Logger log = LogManager.getLogger(TransportMcpToolsUpdateAction.class);
    TransportService transportService;
    ClusterService clusterService;
    ThreadPool threadPool;
    Client client;
    NamedXContentRegistry xContentRegistry;
    DiscoveryNodeHelper nodeFilter;
    private final MLFeatureEnabledSetting mlFeatureEnabledSetting;
    private final McpToolsHelper mcpToolsHelper;

    /**
     * Creates the action and registers it under {@code cluster:admin/opensearch/ml/mcp_tools/update}.
     */
    @Inject
    public TransportMcpToolsUpdateAction(
        TransportService transportService,
        ActionFilters actionFilters,
        ClusterService clusterService,
        ThreadPool threadPool,
        Client client,
        NamedXContentRegistry xContentRegistry,
        DiscoveryNodeHelper nodeFilter,
        McpToolsHelper mcpToolsHelper,
        MLFeatureEnabledSetting mlFeatureEnabledSetting
    ) {
        super("cluster:admin/opensearch/ml/mcp_tools/update", transportService, actionFilters, MLMcpToolsUpdateNodesRequest::new);
        this.transportService = transportService;
        this.clusterService = clusterService;
        this.threadPool = threadPool;
        this.client = client;
        this.xContentRegistry = xContentRegistry;
        this.nodeFilter = nodeFilter;
        this.mcpToolsHelper = mcpToolsHelper;
        this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
    }

    /**
     * Validates preconditions, then looks up the requested tools in the MCP tools index and continues with the
     * update once the lookup completes.
     *
     * @param task     the task running this action (not used)
     * @param request  expected to be an {@link MLMcpToolsUpdateNodesRequest}
     * @param listener receives the node-level update response, or the first failure encountered
     */
    @Override
    protected void doExecute(Task task, ActionRequest request, ActionListener<MLMcpToolsUpdateNodesResponse> listener) {
        if (!mlFeatureEnabledSetting.isMcpServerEnabled()) {
            listener.onFailure(new OpenSearchException(MLCommonsSettings.ML_COMMONS_MCP_SERVER_DISABLED_MESSAGE));
            return;
        }
        if (!clusterService.state().metadata().hasIndex(MLIndex.MCP_TOOLS.getIndexName())) {
            listener.onFailure(new OpenSearchException("MCP tools index doesn't exist"));
            return;
        }
        MLMcpToolsUpdateNodesRequest updateNodesRequest = (MLMcpToolsUpdateNodesRequest) request;
        List<String> requestedToolNames = updateNodesRequest.getMcpTools().stream().map(McpToolBaseInput::getName).toList();
        // The lookup runs with a stashed thread context. The caller's context is restored before the listener is
        // notified (runBefore) and again when the try block closes.
        try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
            ActionListener<MLMcpToolsUpdateNodesResponse> restoreListener = ActionListener.runBefore(listener, context::restore);
            ActionListener<SearchResponse> searchResultListener = ActionListener.wrap(
                searchResult -> onToolsSearched(updateNodesRequest, searchResult, restoreListener),
                e -> {
                    log.error("Failed to search mcp tools index", e);
                    restoreListener.onFailure(e);
                }
            );
            mcpToolsHelper.searchToolsWithPrimaryTermAndSeqNo(requestedToolNames, searchResultListener);
        } catch (Exception e) {
            log.error("Failed to update mcp tools", e);
            listener.onFailure(e);
        }
    }

    /**
     * Handles the tool lookup result. Parses each hit into a {@link SearchedMcpToolWrapper} and fails if nothing
     * was found, if any hit cannot be parsed, or if any requested tool is missing. Otherwise it proceeds to
     * {@link #updateMcpTools}.
     *
     * @param updateNodesRequest the original update request
     * @param searchResult       result of the lookup by tool name
     * @param listener           listener that already restores the caller's thread context
     */
    private void onToolsSearched(
        MLMcpToolsUpdateNodesRequest updateNodesRequest,
        SearchResponse searchResult,
        ActionListener<MLMcpToolsUpdateNodesResponse> listener
    ) {
        var hits = Objects.requireNonNull(searchResult.getHits().getHits());
        if (hits.length == 0) {
            listener.onFailure(new OpenSearchException("Failed to update tools as none of them is found in index"));
            return;
        }
        Set<String> missingTools = new HashSet<>();
        updateNodesRequest.getMcpTools().forEach(x -> missingTools.add(x.getName()));
        List<SearchedMcpToolWrapper> searchedMcpToolWrappers = new ArrayList<>();
        for (var hit : hits) {
            try (
                XContentParser parser = JsonXContent.jsonXContent
                    .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, hit.getSourceAsString())
            ) {
                XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
                McpToolRegisterInput registerMcpTool = McpToolRegisterInput.parse(parser);
                missingTools.remove(registerMcpTool.getName());
                searchedMcpToolWrappers
                    .add(
                        SearchedMcpToolWrapper
                            .builder()
                            .seqNo(hit.getSeqNo())
                            .primaryTerm(hit.getPrimaryTerm())
                            .mcpTool(registerMcpTool)
                            .build()
                    );
            } catch (IOException e) {
                log.error("Failed to parse mcp tools configuration", e);
                // Stop here. Continuing would notify the listener a second time, either with the "missing tools"
                // failure or through the index update below.
                listener.onFailure(e);
                return;
            }
        }
        if (!missingTools.isEmpty()) {
            String errMsg = String.format(Locale.ROOT, "Failed to find tools: %s in system index", missingTools);
            log.warn(errMsg);
            listener.onFailure(new OpenSearchException(errMsg));
            return;
        }
        updateMcpTools(updateNodesRequest, searchedMcpToolWrappers, listener);
    }

    /**
     * Writes the requested field changes for every tool to the MCP tools index in a single bulk request, then
     * hands the outcome to {@link #onToolsBulkUpdated}.
     *
     * @param updateNodesRequest      the update request; every tool in it must be present in
     *                                {@code searchedMcpToolWrappers}
     * @param searchedMcpToolWrappers the stored tools together with the seqNo/primaryTerm they were read at
     * @param listener                notified with the final outcome
     */
    private void updateMcpTools(
        MLMcpToolsUpdateNodesRequest updateNodesRequest,
        List<SearchedMcpToolWrapper> searchedMcpToolWrappers,
        ActionListener<MLMcpToolsUpdateNodesResponse> listener
    ) {
        try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
            ActionListener<MLMcpToolsUpdateNodesResponse> restoreListener = ActionListener.runBefore(listener, context::restore);
            ActionListener<BulkResponse> updateResultListener = ActionListener.wrap(
                bulkResponse -> onToolsBulkUpdated(updateNodesRequest, searchedMcpToolWrappers, bulkResponse, restoreListener),
                e -> {
                    log.error("Failed to update mcp tools in system index because exception: {}", e.getMessage(), e);
                    restoreListener.onFailure(e);
                }
            );
            client.bulk(buildBulkUpdateRequest(updateNodesRequest, searchedMcpToolWrappers), updateResultListener);
        } catch (Exception e) {
            log.error("Failed to update mcp tools", e);
            listener.onFailure(e);
        }
    }

    /**
     * Builds one partial-document {@link UpdateRequest} per requested tool. Only the description, parameters and
     * attributes that are non-null in the request are written, and {@code last_update_time} is always refreshed.
     * The bulk request uses {@link WriteRequest.RefreshPolicy#IMMEDIATE}.
     */
    private BulkRequest buildBulkUpdateRequest(
        MLMcpToolsUpdateNodesRequest updateNodesRequest,
        List<SearchedMcpToolWrapper> searchedMcpToolWrappers
    ) {
        Map<String, SearchedMcpToolWrapper> searchedByName = searchedMcpToolWrappers
            .stream()
            .collect(Collectors.toMap(x -> x.getMcpTool().getName(), x -> x));
        BulkRequest bulkRequest = new BulkRequest();
        for (McpToolUpdateInput mcpTool : updateNodesRequest.getMcpTools()) {
            SearchedMcpToolWrapper searched = searchedByName.get(mcpTool.getName());
            UpdateRequest updateRequest = new UpdateRequest(MLIndex.MCP_TOOLS.getIndexName(), mcpTool.getName());
            // Optimistic concurrency control: the write is rejected if the document changed since it was read.
            updateRequest.setIfSeqNo(searched.getSeqNo());
            updateRequest.setIfPrimaryTerm(searched.getPrimaryTerm());
            Map<String, Object> source = new HashMap<>();
            if (mcpTool.getDescription() != null) {
                source.put("description", mcpTool.getDescription());
            }
            if (mcpTool.getParameters() != null) {
                source.put("parameters", mcpTool.getParameters());
            }
            if (mcpTool.getAttributes() != null) {
                source.put("attributes", mcpTool.getAttributes());
            }
            source.put("last_update_time", Instant.now().toEpochMilli());
            updateRequest.doc(source);
            bulkRequest.add(updateRequest);
        }
        bulkRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
        return bulkRequest;
    }

    /**
     * Reacts to the bulk index update.
     *
     * <ul>
     *   <li>If every item succeeded, all requested tools are pushed to the nodes.</li>
     *   <li>If some items failed, those tools are removed from both {@code updateNodesRequest} and
     *       {@code searchedMcpToolWrappers}, and their errors are collected. The remaining tools are still pushed
     *       to the nodes, and the collected errors are reported afterwards.</li>
     *   <li>If no item succeeded, the listener fails with the collected errors.</li>
     * </ul>
     */
    private void onToolsBulkUpdated(
        MLMcpToolsUpdateNodesRequest updateNodesRequest,
        List<SearchedMcpToolWrapper> searchedMcpToolWrappers,
        BulkResponse bulkResponse,
        ActionListener<MLMcpToolsUpdateNodesResponse> listener
    ) {
        if (!bulkResponse.hasFailures()) {
            MLMcpToolsUpdateNodesRequest mergedRequest = mergeDocFields(updateNodesRequest, searchedMcpToolWrappers, bulkResponse);
            Set<String> updatedTools = updateNodesRequest
                .getMcpTools()
                .stream()
                .map(McpToolBaseInput::getName)
                .collect(Collectors.toUnmodifiableSet());
            updateMcpToolsOnNodes(new StringBuilder(), mergedRequest, updatedTools, listener);
            return;
        }
        Set<String> updateSucceedTools = new HashSet<>();
        Map<String, String> updateFailedTools = new HashMap<>();
        for (BulkItemResponse item : bulkResponse.getItems()) {
            String toolName = item.getId();
            if (item.isFailed()) {
                updateFailedTools.put(toolName, item.getFailure().getMessage());
                // Failed tools must not be pushed to the nodes.
                updateNodesRequest.getMcpTools().removeIf(x -> x.getName().equals(toolName));
                searchedMcpToolWrappers.removeIf(x -> x.getMcpTool().getName().equals(toolName));
            } else {
                updateSucceedTools.add(toolName);
            }
        }
        StringBuilder errMsgBuilder = new StringBuilder();
        for (Map.Entry<String, String> failedTool : updateFailedTools.entrySet()) {
            errMsgBuilder
                .append(
                    String
                        .format(
                            Locale.ROOT,
                            "Failed to update mcp tool: %s in system index with error: %s",
                            failedTool.getKey(),
                            failedTool.getValue()
                        )
                );
            errMsgBuilder.append("\n");
        }
        log.error(errMsgBuilder.toString());
        if (updateSucceedTools.isEmpty()) {
            // hasFailures() guarantees at least one entry, so there is a trailing newline to drop.
            listener.onFailure(new OpenSearchException(errMsgBuilder.deleteCharAt(errMsgBuilder.length() - 1).toString()));
            return;
        }
        updateMcpToolsOnNodes(
            errMsgBuilder,
            mergeDocFields(updateNodesRequest, searchedMcpToolWrappers, bulkResponse),
            updateSucceedTools,
            listener
        );
    }

    /**
     * Completes each tool in the request with the stored values, so the nodes receive full definitions.
     *
     * <p>The type is always copied from the stored document. Attributes, parameters and description are copied
     * only where the request left them null. The version is set from the successful bulk item with the same id.
     * The request is mutated in place and returned.
     */
    private MLMcpToolsUpdateNodesRequest mergeDocFields(
        MLMcpToolsUpdateNodesRequest updateNodesRequest,
        List<SearchedMcpToolWrapper> updateMcpToolWrappers,
        BulkResponse bulkResponse
    ) {
        Map<String, McpToolRegisterInput> mcpToolsMap = updateMcpToolWrappers
            .stream()
            .collect(Collectors.toMap(x -> x.getMcpTool().getName(), SearchedMcpToolWrapper::getMcpTool));
        Map<String, Long> versions = Arrays
            .stream(bulkResponse.getItems())
            .filter(x -> !x.isFailed())
            .collect(Collectors.toMap(BulkItemResponse::getId, x -> x.getResponse().getVersion()));
        updateNodesRequest.getMcpTools().forEach(x -> {
            McpToolRegisterInput registerMcpTool = mcpToolsMap.get(x.getName());
            x.setType(registerMcpTool.getType());
            if (x.getAttributes() == null) {
                x.setAttributes(registerMcpTool.getAttributes());
            }
            if (x.getParameters() == null) {
                x.setParameters(registerMcpTool.getParameters());
            }
            if (x.getDescription() == null) {
                x.setDescription(registerMcpTool.getDescription());
            }
            x.setVersion(versions.get(x.getName()));
        });
        return updateNodesRequest;
    }

    /**
     * Sends the merged tools to the nodes through {@link MLMcpToolsUpdateOnNodesAction}.
     *
     * <p>The listener receives the node response only if there were no node failures and
     * {@code errMsgBuilder} is still empty. Otherwise the index errors and node errors are combined into a single
     * {@link OpenSearchException}.
     *
     * @param errMsgBuilder           errors already collected from the index update; each line ends with '\n'
     * @param toolsUpdateNodesRequest the merged tools to send to the nodes
     * @param indexSucceedTools       names of the tools written successfully to the index (used in messages)
     * @param listener                notified with the final outcome
     */
    private void updateMcpToolsOnNodes(
        StringBuilder errMsgBuilder,
        MLMcpToolsUpdateNodesRequest toolsUpdateNodesRequest,
        Set<String> indexSucceedTools,
        ActionListener<MLMcpToolsUpdateNodesResponse> listener
    ) {
        try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
            ActionListener<MLMcpToolsUpdateNodesResponse> restoreListener = ActionListener.runBefore(listener, context::restore);
            ActionListener<MLMcpToolsUpdateNodesResponse> addToMemoryResultListener = ActionListener.wrap(r -> {
                if (r.failures() != null && !r.failures().isEmpty()) {
                    r.failures().forEach(x -> {
                        errMsgBuilder
                            .append(
                                String
                                    .format(
                                        Locale.ROOT,
                                        "Tools: %s are updated successfully but failed to update to mcp server memory with error: %s",
                                        indexSucceedTools,
                                        x.getRootCause().getMessage()
                                    )
                            );
                        errMsgBuilder.append("\n");
                    });
                    errMsgBuilder.deleteCharAt(errMsgBuilder.length() - 1);
                    log.error(errMsgBuilder.toString());
                    restoreListener.onFailure(new OpenSearchException(errMsgBuilder.toString()));
                } else if (errMsgBuilder.isEmpty()) {
                    restoreListener.onResponse(r);
                } else {
                    // Every node succeeded, but some index updates had failed earlier: drop the trailing newline
                    // and report those errors.
                    restoreListener.onFailure(new OpenSearchException(errMsgBuilder.deleteCharAt(errMsgBuilder.length() - 1).toString()));
                }
            }, e -> {
                errMsgBuilder
                    .append(
                        String
                            .format(
                                Locale.ROOT,
                                "Tools are updated successfully but failed to update to mcp server memory with error: %s",
                                e.getMessage()
                            )
                    );
                log.error(errMsgBuilder.toString(), e);
                restoreListener.onFailure(new OpenSearchException(errMsgBuilder.toString()));
            });
            client.execute(MLMcpToolsUpdateOnNodesAction.INSTANCE, toolsUpdateNodesRequest, addToMemoryResultListener);
        } catch (Exception e) {
            log.error("Failed to update mcp tools on nodes", e);
            listener.onFailure(e);
        }
    }

    /**
     * A stored MCP tool together with the {@code primaryTerm}/{@code seqNo} it was read at. Those values are used
     * as concurrency guards on the follow-up update.
     */
    private static class SearchedMcpToolWrapper {
        private McpToolRegisterInput mcpTool;
        private Long primaryTerm;
        private Long seqNo;
        private Long version;

        @Generated
        SearchedMcpToolWrapper(McpToolRegisterInput mcpTool, Long primaryTerm, Long seqNo, Long version) {
            this.mcpTool = mcpTool;
            this.primaryTerm = primaryTerm;
            this.seqNo = seqNo;
            this.version = version;
        }

        @Generated
        public static SearchedMcpToolWrapperBuilder builder() {
            return new SearchedMcpToolWrapperBuilder();
        }

        @Generated
        public McpToolRegisterInput getMcpTool() {
            return this.mcpTool;
        }

        @Generated
        public Long getPrimaryTerm() {
            return this.primaryTerm;
        }

        @Generated
        public Long getSeqNo() {
            return this.seqNo;
        }

        @Generated
        public Long getVersion() {
            return this.version;
        }

        @Generated
        public void setMcpTool(McpToolRegisterInput mcpTool) {
            this.mcpTool = mcpTool;
        }

        @Generated
        public void setPrimaryTerm(Long primaryTerm) {
            this.primaryTerm = primaryTerm;
        }

        @Generated
        public void setSeqNo(Long seqNo) {
            this.seqNo = seqNo;
        }

        @Generated
        public void setVersion(Long version) {
            this.version = version;
        }

        @Generated
        @Override
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof SearchedMcpToolWrapper)) {
                return false;
            }
            SearchedMcpToolWrapper other = (SearchedMcpToolWrapper) o;
            if (!other.canEqual(this)) {
                return false;
            }
            return Objects.equals(this.getPrimaryTerm(), other.getPrimaryTerm())
                && Objects.equals(this.getSeqNo(), other.getSeqNo())
                && Objects.equals(this.getVersion(), other.getVersion())
                && Objects.equals(this.getMcpTool(), other.getMcpTool());
        }

        @Generated
        protected boolean canEqual(Object other) {
            return other instanceof SearchedMcpToolWrapper;
        }

        @Generated
        @Override
        public int hashCode() {
            final int PRIME = 59;
            int result = 1;
            Long $primaryTerm = this.getPrimaryTerm();
            result = result * PRIME + ($primaryTerm == null ? 43 : $primaryTerm.hashCode());
            Long $seqNo = this.getSeqNo();
            result = result * PRIME + ($seqNo == null ? 43 : $seqNo.hashCode());
            Long $version = this.getVersion();
            result = result * PRIME + ($version == null ? 43 : $version.hashCode());
            McpToolRegisterInput $mcpTool = this.getMcpTool();
            result = result * PRIME + ($mcpTool == null ? 43 : $mcpTool.hashCode());
            return result;
        }

        @Generated
        @Override
        public String toString() {
            return "TransportMcpToolsUpdateAction.SearchedMcpToolWrapper(mcpTool="
                + String.valueOf(this.getMcpTool())
                + ", primaryTerm="
                + this.getPrimaryTerm()
                + ", seqNo="
                + this.getSeqNo()
                + ", version="
                + this.getVersion()
                + ")";
        }

        @Generated
        public static class SearchedMcpToolWrapperBuilder {
            @Generated
            private McpToolRegisterInput mcpTool;
            @Generated
            private Long primaryTerm;
            @Generated
            private Long seqNo;
            @Generated
            private Long version;

            @Generated
            SearchedMcpToolWrapperBuilder() {}

            @Generated
            public SearchedMcpToolWrapperBuilder mcpTool(McpToolRegisterInput mcpTool) {
                this.mcpTool = mcpTool;
                return this;
            }

            @Generated
            public SearchedMcpToolWrapperBuilder primaryTerm(Long primaryTerm) {
                this.primaryTerm = primaryTerm;
                return this;
            }

            @Generated
            public SearchedMcpToolWrapperBuilder seqNo(Long seqNo) {
                this.seqNo = seqNo;
                return this;
            }

            @Generated
            public SearchedMcpToolWrapperBuilder version(Long version) {
                this.version = version;
                return this;
            }

            @Generated
            public SearchedMcpToolWrapper build() {
                return new SearchedMcpToolWrapper(this.mcpTool, this.primaryTerm, this.seqNo, this.version);
            }

            @Generated
            @Override
            public String toString() {
                return "TransportMcpToolsUpdateAction.SearchedMcpToolWrapper.SearchedMcpToolWrapperBuilder(mcpTool="
                    + String.valueOf(this.mcpTool)
                    + ", primaryTerm="
                    + this.primaryTerm
                    + ", seqNo="
                    + this.seqNo
                    + ", version="
                    + this.version
                    + ")";
            }
        }
    }
}

