/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.neuralsearch.processor;

import com.google.common.annotations.VisibleForTesting;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.BiConsumer;
import java.util.function.Supplier;
import lombok.Generated;
import org.apache.commons.lang3.StringUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.action.ActionListener;
import org.opensearch.env.Environment;
import org.opensearch.index.mapper.MapperService;
import org.opensearch.ingest.AbstractProcessor;
import org.opensearch.ingest.IngestDocument;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;

public class TextImageEmbeddingProcessor
extends AbstractProcessor {
    @Generated
    private static final Logger log = LogManager.getLogger(TextImageEmbeddingProcessor.class);
    public static final String TYPE = "text_image_embedding";
    public static final String MODEL_ID_FIELD = "model_id";
    public static final String EMBEDDING_FIELD = "embedding";
    public static final String FIELD_MAP_FIELD = "field_map";
    public static final String TEXT_FIELD_NAME = "text";
    public static final String IMAGE_FIELD_NAME = "image";
    public static final String INPUT_TEXT = "inputText";
    public static final String INPUT_IMAGE = "inputImage";
    private static final Set<String> VALID_FIELD_NAMES = Set.of("text", "image");
    private final String modelId;
    private final String embedding;
    private final Map<String, String> fieldMap;
    private final MLCommonsClientAccessor mlCommonsClientAccessor;
    private final Environment environment;
    private final ClusterService clusterService;

    public TextImageEmbeddingProcessor(String tag, String description, String modelId, String embedding, Map<String, String> fieldMap, MLCommonsClientAccessor clientAccessor, Environment environment, ClusterService clusterService) {
        super(tag, description);
        if (StringUtils.isBlank((CharSequence)modelId)) {
            throw new IllegalArgumentException("model_id is null or empty, can not process it");
        }
        this.validateEmbeddingConfiguration(fieldMap);
        this.modelId = modelId;
        this.embedding = embedding;
        this.fieldMap = fieldMap;
        this.mlCommonsClientAccessor = clientAccessor;
        this.environment = environment;
        this.clusterService = clusterService;
    }

    private void validateEmbeddingConfiguration(Map<String, String> fieldMap) {
        if (fieldMap == null || fieldMap.isEmpty() || fieldMap.entrySet().stream().anyMatch(x -> StringUtils.isBlank((CharSequence)((CharSequence)x.getKey())) || Objects.isNull(x.getValue()))) {
            throw new IllegalArgumentException("Unable to create the TextImageEmbedding processor as field_map has invalid key or value");
        }
        if (fieldMap.entrySet().stream().anyMatch(entry -> !VALID_FIELD_NAMES.contains(entry.getKey()))) {
            throw new IllegalArgumentException(String.format(Locale.ROOT, "Unable to create the TextImageEmbedding processor with provided field name(s). Following names are supported [%s]", String.join((CharSequence)",", VALID_FIELD_NAMES)));
        }
    }

    public IngestDocument execute(IngestDocument ingestDocument) {
        return ingestDocument;
    }

    public void execute(IngestDocument ingestDocument, BiConsumer<IngestDocument, Exception> handler) {
        try {
            this.validateEmbeddingFieldsValue(ingestDocument);
            Map<String, String> knnMap = this.buildMapWithKnnKeyAndOriginalValue(ingestDocument);
            Map<String, String> inferenceMap = this.createInferences(knnMap);
            if (inferenceMap.isEmpty()) {
                handler.accept(ingestDocument, null);
            } else {
                this.mlCommonsClientAccessor.inferenceSentences(this.modelId, inferenceMap, (ActionListener<List<Float>>)ActionListener.wrap(vectors -> {
                    this.setVectorFieldsToDocument(ingestDocument, (List<Float>)vectors);
                    handler.accept(ingestDocument, null);
                }, e -> handler.accept((IngestDocument)null, (Exception)e)));
            }
        }
        catch (Exception e2) {
            handler.accept(null, e2);
        }
    }

    private void setVectorFieldsToDocument(IngestDocument ingestDocument, List<Float> vectors) {
        Objects.requireNonNull(vectors, "embedding failed, inference returns null result!");
        log.debug("Text embedding result fetched, starting build vector output!");
        Map<String, Object> textEmbeddingResult = this.buildTextEmbeddingResult(this.embedding, vectors);
        textEmbeddingResult.forEach((arg_0, arg_1) -> ((IngestDocument)ingestDocument).setFieldValue(arg_0, arg_1));
    }

    private Map<String, String> createInferences(Map<String, String> knnKeyMap) {
        HashMap<String, String> texts = new HashMap<String, String>();
        if (this.fieldMap.containsKey(TEXT_FIELD_NAME) && knnKeyMap.containsKey(this.fieldMap.get(TEXT_FIELD_NAME))) {
            texts.put(INPUT_TEXT, knnKeyMap.get(this.fieldMap.get(TEXT_FIELD_NAME)));
        }
        if (this.fieldMap.containsKey(IMAGE_FIELD_NAME) && knnKeyMap.containsKey(this.fieldMap.get(IMAGE_FIELD_NAME))) {
            texts.put(INPUT_IMAGE, knnKeyMap.get(this.fieldMap.get(IMAGE_FIELD_NAME)));
        }
        return texts;
    }

    @VisibleForTesting
    Map<String, String> buildMapWithKnnKeyAndOriginalValue(IngestDocument ingestDocument) {
        Map sourceAndMetadataMap = ingestDocument.getSourceAndMetadata();
        LinkedHashMap<String, String> mapWithKnnKeys = new LinkedHashMap<String, String>();
        for (Map.Entry<String, String> fieldMapEntry : this.fieldMap.entrySet()) {
            String originalKey = fieldMapEntry.getValue();
            if (!sourceAndMetadataMap.containsKey(originalKey)) continue;
            if (!(sourceAndMetadataMap.get(originalKey) instanceof String)) {
                throw new IllegalArgumentException("Unsupported format of the field in the document, value must be a string");
            }
            mapWithKnnKeys.put(originalKey, (String)sourceAndMetadataMap.get(originalKey));
        }
        return mapWithKnnKeys;
    }

    @VisibleForTesting
    Map<String, Object> buildTextEmbeddingResult(String knnKey, List<Float> modelTensorList) {
        LinkedHashMap<String, Object> result = new LinkedHashMap<String, Object>();
        result.put(knnKey, modelTensorList);
        return result;
    }

    private void validateEmbeddingFieldsValue(IngestDocument ingestDocument) {
        Map sourceAndMetadataMap = ingestDocument.getSourceAndMetadata();
        for (Map.Entry<String, String> embeddingFieldsEntry : this.fieldMap.entrySet()) {
            String mappedSourceKey = embeddingFieldsEntry.getValue();
            Object sourceValue = sourceAndMetadataMap.get(mappedSourceKey);
            if (Objects.isNull(sourceValue)) continue;
            Class<?> sourceValueClass = sourceValue.getClass();
            if (List.class.isAssignableFrom(sourceValueClass) || Map.class.isAssignableFrom(sourceValueClass)) {
                String indexName = sourceAndMetadataMap.get("_index").toString();
                this.validateNestedTypeValue(mappedSourceKey, sourceValue, () -> 1, indexName);
                continue;
            }
            if (!String.class.isAssignableFrom(sourceValueClass)) {
                throw new IllegalArgumentException("field [" + mappedSourceKey + "] is neither string nor nested type, can not process it");
            }
            if (!StringUtils.isBlank((CharSequence)sourceValue.toString())) continue;
            throw new IllegalArgumentException("field [" + mappedSourceKey + "] has empty string value, can not process it");
        }
    }

    private void validateNestedTypeValue(String sourceKey, Object sourceValue, Supplier<Integer> maxDepthSupplier, String indexName) {
        Settings indexSettings;
        int maxDepth = maxDepthSupplier.get();
        if ((long)maxDepth > (Long)MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(indexSettings = this.clusterService.state().metadata().index(indexName).getSettings())) {
            throw new IllegalArgumentException("map type field [" + sourceKey + "] reached max depth limit, can not process it");
        }
        if (List.class.isAssignableFrom(sourceValue.getClass())) {
            TextImageEmbeddingProcessor.validateListTypeValue(sourceKey, (List)sourceValue);
        } else if (Map.class.isAssignableFrom(sourceValue.getClass())) {
            ((Map)sourceValue).values().stream().filter(Objects::nonNull).forEach(x -> this.validateNestedTypeValue(sourceKey, x, () -> maxDepth + 1, indexName));
        } else {
            if (!String.class.isAssignableFrom(sourceValue.getClass())) {
                throw new IllegalArgumentException("map type field [" + sourceKey + "] has non-string type, can not process it");
            }
            if (StringUtils.isBlank((CharSequence)sourceValue.toString())) {
                throw new IllegalArgumentException("map type field [" + sourceKey + "] has empty string, can not process it");
            }
        }
    }

    private static void validateListTypeValue(String sourceKey, List<Object> sourceValue) {
        for (Object value : sourceValue) {
            if (value == null) {
                throw new IllegalArgumentException("list type field [" + sourceKey + "] has null, can not process it");
            }
            if (!(value instanceof String)) {
                throw new IllegalArgumentException("list type field [" + sourceKey + "] has non string value, can not process it");
            }
            if (!StringUtils.isBlank((CharSequence)value.toString())) continue;
            throw new IllegalArgumentException("list type field [" + sourceKey + "] has empty string, can not process it");
        }
    }

    public String getType() {
        return TYPE;
    }
}

