/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.security.configuration;

import com.google.common.collect.ImmutableList;
import java.io.Serializable;
import java.lang.reflect.Field;
import java.security.AccessController;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Consumer;
import java.util.stream.StreamSupport;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.ConstantScoreQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.util.BytesRef;
import org.opensearch.OpenSearchException;
import org.opensearch.OpenSearchSecurityException;
import org.opensearch.SpecialPermission;
import org.opensearch.action.ActionListener;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.DocWriteRequest;
import org.opensearch.action.RealtimeRequest;
import org.opensearch.action.admin.cluster.shards.ClusterSearchShardsRequest;
import org.opensearch.action.admin.indices.shrink.ResizeRequest;
import org.opensearch.action.bulk.BulkItemRequest;
import org.opensearch.action.bulk.BulkRequest;
import org.opensearch.action.bulk.BulkShardRequest;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.update.UpdateRequest;
import org.opensearch.client.Client;
import org.opensearch.cluster.metadata.IndexNameExpressionResolver;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.Strings;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.xcontent.NamedXContentRegistry;
import org.opensearch.common.xcontent.ToXContent;
import org.opensearch.index.query.ParsedQuery;
import org.opensearch.rest.RestStatus;
import org.opensearch.search.DocValueFormat;
import org.opensearch.search.aggregations.AggregationBuilder;
import org.opensearch.search.aggregations.AggregatorFactories;
import org.opensearch.search.aggregations.BucketOrder;
import org.opensearch.search.aggregations.InternalAggregation;
import org.opensearch.search.aggregations.InternalAggregations;
import org.opensearch.search.aggregations.bucket.MultiBucketsAggregation;
import org.opensearch.search.aggregations.bucket.sampler.DiversifiedAggregationBuilder;
import org.opensearch.search.aggregations.bucket.terms.InternalTerms;
import org.opensearch.search.aggregations.bucket.terms.SignificantTermsAggregationBuilder;
import org.opensearch.search.aggregations.bucket.terms.StringTerms;
import org.opensearch.search.aggregations.bucket.terms.TermsAggregationBuilder;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.query.QuerySearchResult;
import org.opensearch.security.OpenSearchSecurityPlugin;
import org.opensearch.security.configuration.DlsFilterLevelActionHandler;
import org.opensearch.security.configuration.DlsFlsRequestValve;
import org.opensearch.security.configuration.DlsQueryParser;
import org.opensearch.security.resolver.IndexResolverReplacer;
import org.opensearch.security.securityconf.EvaluatedDlsFlsConfig;
import org.opensearch.security.support.Base64Helper;
import org.opensearch.security.support.HeaderHelper;
import org.opensearch.security.support.SecurityUtils;
import org.opensearch.threadpool.ThreadPool;

public class DlsFlsValveImpl
implements DlsFlsRequestValve {
    /** Execution hint set on terms, significant-terms and diversified-sampler aggregation builders when field masking is active. */
    private static final String MAP_EXECUTION_HINT = "map";
    private static final Logger log = LogManager.getLogger(DlsFlsValveImpl.class);
    // nodeClient and clusterService are only handed on to DlsFilterLevelActionHandler.handle() for filter-level DLS.
    private final Client nodeClient;
    private final ClusterService clusterService;
    // Thread context whose headers carry the DLS / FLS / masked-field information read and written by this class.
    private final ThreadContext threadContext;
    // DLS mode resolved once from the node settings via Mode.get().
    private final Mode mode;
    // Parser used both to detect term-lookup queries and to build the Lucene DLS filter.
    private final DlsQueryParser dlsQueryParser;
    private final IndexNameExpressionResolver resolver;

    /**
     * Creates the valve.
     *
     * @param settings              node settings; only {@code plugins.security.dls.mode} is read (through {@link Mode#get})
     * @param nodeClient            client stored for filter-level DLS handling
     * @param clusterService        cluster service stored for filter-level DLS handling
     * @param resolver              index name expression resolver stored for filter-level DLS handling
     * @param namedXContentRegistry registry used to construct the {@link DlsQueryParser}
     * @param threadContext         thread context on which the DLS/FLS headers are read and written
     */
    public DlsFlsValveImpl(Settings settings, Client nodeClient, ClusterService clusterService, IndexNameExpressionResolver resolver, NamedXContentRegistry namedXContentRegistry, ThreadContext threadContext) {
        this.mode = Mode.get(settings);
        this.dlsQueryParser = new DlsQueryParser(namedXContentRegistry);
        this.threadContext = threadContext;
        this.resolver = resolver;
        this.clusterService = clusterService;
        this.nodeClient = nodeClient;
    }

    /**
     * Applies the evaluated DLS / FLS / field-masking configuration to an action request.
     *
     * <p>The method, in order:
     * <ol>
     *   <li>returns {@code true} immediately when there is no configuration, or when the
     *       {@code _opendistro_security_filter_level_dls_done} header is already present;</li>
     *   <li>decides between filter-level and Lucene-level DLS (see {@link #shouldUseFilterLevelDls});</li>
     *   <li>attaches the DLS headers (Lucene-level only) and the FLS / masked-field headers;</li>
     *   <li>returns {@code true} when the configuration filtered by {@code resolved} is empty;</li>
     *   <li>disables realtime on {@link RealtimeRequest}s and adjusts {@link SearchRequest}s;</li>
     *   <li>rejects updates (single, bulk, bulk-shard), resizes, actions containing
     *       {@code plugins/replication}, and - when DLS is present - terms aggregations with
     *       {@code min_doc_count 0} and profiled searches;</li>
     *   <li>delegates to {@link DlsFilterLevelActionHandler#handle} when filter-level DLS applies.</li>
     * </ol>
     *
     * @param action                the action name
     * @param request               the request to inspect and possibly modify
     * @param listener              notified through {@code onFailure} when the request is rejected
     * @param evaluatedDlsFlsConfig the restrictions evaluated for the current user; may be {@code null}
     * @param resolved              the resolved indices of the request
     * @return {@code true} if the request may continue; {@code false} if it was rejected here, or
     *         whatever {@link DlsFilterLevelActionHandler#handle} returns for filter-level DLS
     */
    @Override
    public boolean invoke(String action, ActionRequest request, ActionListener<?> listener, EvaluatedDlsFlsConfig evaluatedDlsFlsConfig, IndexResolverReplacer.Resolved resolved) {
        if (log.isDebugEnabled()) {
            log.debug("DlsFlsValveImpl.invoke()\nrequest: " + request + "\nevaluatedDlsFlsConfig: " + evaluatedDlsFlsConfig + "\nresolved: " + resolved + "\nmode: " + this.mode);
        }
        if (evaluatedDlsFlsConfig == null || evaluatedDlsFlsConfig.isEmpty()) {
            return true;
        }
        String dlsDoneHeader = this.threadContext.getHeader("_opendistro_security_filter_level_dls_done");
        if (dlsDoneHeader != null) {
            if (log.isDebugEnabled()) {
                log.debug("DLS is already done for: " + dlsDoneHeader);
            }
            return true;
        }

        EvaluatedDlsFlsConfig filteredDlsFlsConfig = evaluatedDlsFlsConfig.filter(resolved);
        boolean doFilterLevelDls = this.shouldUseFilterLevelDls(filteredDlsFlsConfig);

        // Headers are always built from the unfiltered configuration.
        if (!doFilterLevelDls) {
            this.setDlsHeaders(evaluatedDlsFlsConfig, request);
        }
        this.setFlsHeaders(evaluatedDlsFlsConfig, request);

        if (filteredDlsFlsConfig.isEmpty()) {
            return true;
        }

        if (request instanceof RealtimeRequest) {
            ((RealtimeRequest) request).realtime(false);
        }
        if (request instanceof SearchRequest) {
            adjustSearchRequest((SearchRequest) request, evaluatedDlsFlsConfig);
        }

        if (request instanceof UpdateRequest) {
            listener.onFailure(new OpenSearchSecurityException("Update is not supported when FLS or DLS or Fieldmasking is activated"));
            return false;
        }
        if (request instanceof BulkRequest) {
            for (DocWriteRequest<?> inner : ((BulkRequest) request).requests()) {
                if (inner instanceof UpdateRequest) {
                    listener.onFailure(new OpenSearchSecurityException("Update is not supported when FLS or DLS or Fieldmasking is activated"));
                    return false;
                }
            }
        }
        if (request instanceof BulkShardRequest) {
            for (BulkItemRequest inner : ((BulkShardRequest) request).items()) {
                if (inner.request() instanceof UpdateRequest) {
                    listener.onFailure(new OpenSearchSecurityException("Update is not supported when FLS or DLS or Fieldmasking is activated"));
                    return false;
                }
            }
        }
        if (request instanceof ResizeRequest) {
            listener.onFailure(new OpenSearchSecurityException("Resize is not supported when FLS or DLS or Fieldmasking is activated"));
            return false;
        }
        if (action.contains("plugins/replication")) {
            listener.onFailure(new OpenSearchSecurityException("Cross Cluster Replication is not supported when FLS or DLS or Fieldmasking is activated", RestStatus.FORBIDDEN));
            return false;
        }

        if (evaluatedDlsFlsConfig.hasDls() && request instanceof SearchRequest) {
            SearchSourceBuilder source = ((SearchRequest) request).source();
            if (source != null) {
                AggregatorFactories.Builder aggregations = source.aggregations();
                if (aggregations != null) {
                    for (AggregationBuilder factory : aggregations.getAggregatorFactories()) {
                        if (factory instanceof TermsAggregationBuilder && ((TermsAggregationBuilder) factory).minDocCount() == 0L) {
                            listener.onFailure(new OpenSearchException("min_doc_count 0 is not supported when DLS is activated"));
                            return false;
                        }
                    }
                }
                if (source.profile()) {
                    listener.onFailure(new OpenSearchSecurityException("Profiling is not supported when DLS is activated"));
                    return false;
                }
            }
        }

        if (doFilterLevelDls && filteredDlsFlsConfig.hasDls()) {
            return DlsFilterLevelActionHandler.handle(action, request, listener, evaluatedDlsFlsConfig, resolved, this.nodeClient, this.clusterService, OpenSearchSecurityPlugin.GuiceHolder.getIndicesService(), this.resolver, this.dlsQueryParser, this.threadContext);
        }
        return true;
    }

    /**
     * Decides whether DLS is enforced at filter level for the current request.
     *
     * <p>{@link Mode#FILTER_LEVEL} and {@link Mode#LUCENE_LEVEL} are fixed answers. In adaptive mode
     * the {@code _opendistro_security_dls_mode} header wins if it already says filter-level;
     * otherwise filter-level is chosen (and recorded in that header) exactly when one of the DLS
     * queries contains a term-lookup query.
     */
    private boolean shouldUseFilterLevelDls(EvaluatedDlsFlsConfig filteredDlsFlsConfig) {
        if (this.mode == Mode.FILTER_LEVEL) {
            return true;
        }
        if (this.mode == Mode.LUCENE_LEVEL) {
            return false;
        }
        if (this.getDlsModeHeader() == Mode.FILTER_LEVEL) {
            log.debug("Doing filter-level DLS due to header");
            return true;
        }
        boolean containsTermLookupQuery = this.dlsQueryParser.containsTermLookupQuery(filteredDlsFlsConfig.getAllQueries());
        if (containsTermLookupQuery) {
            this.setDlsModeHeader(Mode.FILTER_LEVEL);
            log.debug("Doing filter-level DLS because the query contains a TLQ");
        } else {
            log.debug("Doing lucene-level DLS because the query does not contain a TLQ");
        }
        return containsTermLookupQuery;
    }

    /**
     * Forces the {@code map} execution hint on top-level terms-style aggregations when field masking
     * is active, and disables the shard request cache unless the request is restricted by field
     * masking only and all of its top-level aggregations are of type {@code cardinality} or {@code count}.
     *
     * <p>A search request without a source is handled like one without aggregations: the request
     * cache is disabled. (Previously this case threw a {@link NullPointerException} when neither
     * FLS nor DLS was configured.)
     */
    private static void adjustSearchRequest(SearchRequest searchRequest, EvaluatedDlsFlsConfig evaluatedDlsFlsConfig) {
        SearchSourceBuilder source = searchRequest.source();
        AggregatorFactories.Builder aggregations = source == null ? null : source.aggregations();

        if (evaluatedDlsFlsConfig.hasFieldMasking() && aggregations != null) {
            for (AggregationBuilder aggregationBuilder : aggregations.getAggregatorFactories()) {
                if (aggregationBuilder instanceof TermsAggregationBuilder) {
                    ((TermsAggregationBuilder) aggregationBuilder).executionHint(MAP_EXECUTION_HINT);
                } else if (aggregationBuilder instanceof SignificantTermsAggregationBuilder) {
                    ((SignificantTermsAggregationBuilder) aggregationBuilder).executionHint(MAP_EXECUTION_HINT);
                } else if (aggregationBuilder instanceof DiversifiedAggregationBuilder) {
                    ((DiversifiedAggregationBuilder) aggregationBuilder).executionHint(MAP_EXECUTION_HINT);
                }
            }
        }

        if (evaluatedDlsFlsConfig.hasFls() || evaluatedDlsFlsConfig.hasDls() || aggregations == null) {
            searchRequest.requestCache(Boolean.FALSE);
            return;
        }

        // NOTE(review): the "debuglogger" output below is emitted at ERROR level during normal
        // operation; it looks like leftover diagnostics - confirm before changing the level.
        boolean cacheable = true;
        for (AggregationBuilder af : aggregations.getAggregatorFactories()) {
            if (!af.getType().equals("cardinality") && !af.getType().equals("count")) {
                cacheable = false;
                continue;
            }
            LogManager.getLogger("debuglogger").error(Strings.toString(source) + System.lineSeparator() + Strings.toString(af) + System.lineSeparator());
        }
        if (cacheable) {
            LogManager.getLogger("debuglogger").error("Shard requestcache enabled for " + Strings.toString(source));
        } else {
            searchRequest.requestCache(Boolean.FALSE);
        }
    }

    /**
     * Restricts the query of a search context to the documents allowed by the DLS queries found in
     * the {@code _opendistro_security_dls_query} header.
     *
     * <p>Nothing is changed when no header entry matches the shard's index, when the context carries
     * a suggest section, or when the matching entry has no queries. Otherwise the DLS queries are
     * combined with the original query (added as a {@code MUST} clause) and the context is
     * pre-processed again.
     *
     * @param context               the search context to modify
     * @param threadPool            supplies the thread context the header is read from
     * @param namedXContentRegistry not used by this implementation
     * @throws RuntimeException wrapping any exception raised while evaluating the DLS queries
     */
    @Override
    public void handleSearchContext(SearchContext context, ThreadPool threadPool, NamedXContentRegistry namedXContentRegistry) {
        try {
            @SuppressWarnings("unchecked")
            Map<String, Set<String>> dlsQueriesByIndex = (Map<String, Set<String>>) HeaderHelper.deserializeSafeFromHeader(threadPool.getThreadContext(), "_opendistro_security_dls_query");
            String shardIndexName = context.indexShard().indexSettings().getIndex().getName();
            String matchingKey = SecurityUtils.evalMap(dlsQueriesByIndex, shardIndexName);
            if (matchingKey == null) {
                return;
            }
            if (context.suggest() != null) {
                return;
            }
            assert context.parsedQuery() != null;

            Set<String> unparsedDlsQueries = dlsQueriesByIndex.get(matchingKey);
            if (unparsedDlsQueries == null || unparsedDlsQueries.isEmpty()) {
                return;
            }

            BooleanQuery.Builder restricted = this.dlsQueryParser.parse(unparsedDlsQueries, context.getQueryShardContext(), q -> new ConstantScoreQuery(q));
            restricted.add(context.parsedQuery().query(), BooleanClause.Occur.MUST);
            context.parsedQuery(new ParsedQuery(restricted.build()));
            context.preProcess(true);
        } catch (Exception e) {
            throw new RuntimeException("Error evaluating dls for a search query: " + e, e);
        }
    }

    /**
     * Replaces the aggregations of a query result with a copy in which every aggregation has been
     * passed through {@code aggregateBuckets}.
     *
     * @param queryResult the shard-level query result whose aggregations are rewritten
     */
    @Override
    public void onQueryPhase(QuerySearchResult queryResult) {
        // NOTE(review): assumes the result actually carries aggregations - confirm the caller checks this.
        InternalAggregations expanded = (InternalAggregations) queryResult.aggregations().expand();
        assert expanded != null;
        List<InternalAggregation> rewritten = StreamSupport.stream(expanded.spliterator(), false)
            .map(aggregation -> DlsFlsValveImpl.aggregateBuckets((InternalAggregation) aggregation))
            .collect(ImmutableList.toImmutableList());
        queryResult.aggregations(InternalAggregations.from(rewritten));
    }

    /**
     * Merges equal-ranked buckets of a {@link StringTerms} aggregation.
     *
     * @param aggregation any aggregation
     * @return a new {@link StringTerms} with merged buckets when {@code aggregation} is a
     *         {@link StringTerms} holding more than one bucket; otherwise {@code aggregation} itself
     */
    private static InternalAggregation aggregateBuckets(InternalAggregation aggregation) {
        if (!(aggregation instanceof StringTerms)) {
            return aggregation;
        }
        StringTerms terms = (StringTerms) aggregation;
        List<StringTerms.Bucket> currentBuckets = terms.getBuckets();
        if (currentBuckets.size() <= 1) {
            return aggregation;
        }
        Comparator<MultiBucketsAggregation.Bucket> reduceOrder = StringTermsGetter.getReduceOrder(terms).comparator();
        return terms.create(DlsFlsValveImpl.mergeBuckets(currentBuckets, reduceOrder));
    }

    /**
     * Sorts the given buckets with {@code comparator} and merges every run of buckets that compare
     * as equal into a single bucket.
     *
     * <p>The input list is not modified: sorting happens on a stream, so the aggregation that owns
     * {@code buckets} keeps its order and an unmodifiable list is accepted. {@code forEachOrdered}
     * is used because the merger depends on receiving the buckets in sorted order.
     *
     * @param buckets    the buckets to merge; left untouched
     * @param comparator the order used both for sorting and for deciding which buckets are equal
     * @return an immutable list of the merged buckets in sorted order
     */
    private static List<StringTerms.Bucket> mergeBuckets(List<StringTerms.Bucket> buckets, Comparator<MultiBucketsAggregation.Bucket> comparator) {
        if (log.isDebugEnabled()) {
            log.debug("Merging buckets: {}", buckets.stream().map(b -> b.getKeyAsString()).collect(ImmutableList.toImmutableList()));
        }
        BucketMerger merger = new BucketMerger(comparator, buckets.size());
        buckets.stream().sorted(comparator).forEachOrdered(merger);
        List<StringTerms.Bucket> merged = merger.getBuckets();
        if (log.isDebugEnabled()) {
            log.debug("New buckets: {}", merged.stream().map(b -> b.getKeyAsString()).collect(ImmutableList.toImmutableList()));
        }
        return merged;
    }

    /**
     * Publishes the per-index DLS queries under {@code _opendistro_security_dls_query}.
     *
     * <p>Does nothing when there are no DLS queries. For a {@link ClusterSearchShardsRequest} that is
     * a trusted cluster request the queries are added as a response header. Otherwise they are put
     * as a request header, unless that header already exists - in which case it must equal the
     * current queries.
     *
     * @param dlsFls  the evaluated configuration providing the DLS queries
     * @param request the request being processed
     * @throws OpenSearchSecurityException if an existing header does not match the current queries
     */
    private void setDlsHeaders(EvaluatedDlsFlsConfig dlsFls, ActionRequest request) {
        Map<String, Set<String>> dlsQueries = dlsFls.getDlsQueriesByIndex();
        if (dlsQueries.isEmpty()) {
            return;
        }

        boolean trustedShardsRequest = request instanceof ClusterSearchShardsRequest && HeaderHelper.isTrustedClusterRequest(this.threadContext);
        if (trustedShardsRequest) {
            this.threadContext.addResponseHeader("_opendistro_security_dls_query", Base64Helper.serializeObject((Serializable) dlsQueries));
            if (log.isDebugEnabled()) {
                log.debug("added response header for DLS info: {}", dlsQueries);
            }
            return;
        }

        String existingHeader = this.threadContext.getHeader("_opendistro_security_dls_query");
        if (existingHeader == null) {
            this.threadContext.putHeader("_opendistro_security_dls_query", Base64Helper.serializeObject((Serializable) dlsQueries));
            if (log.isDebugEnabled()) {
                log.debug("attach DLS info: {}", dlsQueries);
            }
            return;
        }

        Serializable previouslyAttached = Base64Helper.deserializeObject(existingHeader);
        if (!dlsQueries.equals(previouslyAttached)) {
            throw new OpenSearchSecurityException("_opendistro_security_dls_query does not match (SG 900D)");
        }
    }

    /**
     * Records the DLS mode in the {@code _opendistro_security_dls_mode} header.
     *
     * <p>An existing header is never overwritten; a warning is logged when it differs from
     * {@code mode}.
     *
     * @param mode the mode to record
     */
    private void setDlsModeHeader(Mode mode) {
        String requested = mode.name();
        String current = this.threadContext.getHeader("_opendistro_security_dls_mode");
        if (current == null) {
            this.threadContext.putHeader("_opendistro_security_dls_mode", requested);
            return;
        }
        if (!requested.equals(current)) {
            log.warn("Cannot update DLS mode to " + mode + "; current: " + current);
        }
    }

    /**
     * Reads the DLS mode from the {@code _opendistro_security_dls_mode} header.
     *
     * @return the mode named by the header, or {@code null} when the header is absent
     * @throws IllegalArgumentException if the header value is not the name of a {@link Mode} constant
     */
    private Mode getDlsModeHeader() {
        String headerValue = this.threadContext.getHeader("_opendistro_security_dls_mode");
        return headerValue == null ? null : Mode.valueOf(headerValue);
    }

    /**
     * Publishes the masked-field and FLS field maps under {@code _opendistro_security_masked_fields}
     * and {@code _opendistro_security_fls_fields} respectively.
     *
     * <p>Each non-empty map is handled independently: for a {@link ClusterSearchShardsRequest} that
     * is a trusted cluster request it is added as a response header; otherwise it is put as a
     * request header, unless that header already exists - in which case it must equal the current map.
     *
     * @param dlsFls  the evaluated configuration providing both maps
     * @param request the request being processed
     * @throws OpenSearchSecurityException if an existing header does not match the current map
     */
    private void setFlsHeaders(EvaluatedDlsFlsConfig dlsFls, ActionRequest request) {
        // --- masked fields ---
        Map<String, Set<String>> maskedFieldsMap = dlsFls.getFieldMaskingByIndex();
        if (!maskedFieldsMap.isEmpty()) {
            boolean trustedShardsRequest = request instanceof ClusterSearchShardsRequest && HeaderHelper.isTrustedClusterRequest(this.threadContext);
            if (trustedShardsRequest) {
                this.threadContext.addResponseHeader("_opendistro_security_masked_fields", Base64Helper.serializeObject((Serializable) maskedFieldsMap));
                if (log.isDebugEnabled()) {
                    log.debug("added response header for masked fields info: {}", maskedFieldsMap);
                }
            } else {
                String existingMaskedHeader = this.threadContext.getHeader("_opendistro_security_masked_fields");
                if (existingMaskedHeader == null) {
                    this.threadContext.putHeader("_opendistro_security_masked_fields", Base64Helper.serializeObject((Serializable) maskedFieldsMap));
                    if (log.isDebugEnabled()) {
                        log.debug("attach masked fields info: {}", maskedFieldsMap);
                    }
                } else {
                    if (!maskedFieldsMap.equals(Base64Helper.deserializeObject(existingMaskedHeader))) {
                        throw new OpenSearchSecurityException("_opendistro_security_masked_fields does not match (SG 901D)");
                    }
                    if (log.isDebugEnabled()) {
                        log.debug("_opendistro_security_masked_fields already set");
                    }
                }
            }
        }

        // --- FLS fields ---
        Map<String, Set<String>> flsFields = dlsFls.getFlsByIndex();
        if (!flsFields.isEmpty()) {
            boolean trustedShardsRequest = request instanceof ClusterSearchShardsRequest && HeaderHelper.isTrustedClusterRequest(this.threadContext);
            if (trustedShardsRequest) {
                this.threadContext.addResponseHeader("_opendistro_security_fls_fields", Base64Helper.serializeObject((Serializable) flsFields));
                if (log.isDebugEnabled()) {
                    log.debug("added response header for FLS info: {}", flsFields);
                }
            } else {
                String existingFlsHeader = this.threadContext.getHeader("_opendistro_security_fls_fields");
                if (existingFlsHeader == null) {
                    this.threadContext.putHeader("_opendistro_security_fls_fields", Base64Helper.serializeObject((Serializable) flsFields));
                    if (log.isDebugEnabled()) {
                        log.debug("attach FLS info: {}", flsFields);
                    }
                } else {
                    Serializable previouslyAttached = Base64Helper.deserializeObject(existingFlsHeader);
                    if (!flsFields.equals(previouslyAttached)) {
                        throw new OpenSearchSecurityException("_opendistro_security_fls_fields does not match (SG 901D) " + flsFields + "---" + previouslyAttached);
                    }
                    if (log.isDebugEnabled()) {
                        log.debug("_opendistro_security_fls_fields already set");
                    }
                }
            }
        }
    }

    /** How document-level security is enforced. */
    public static enum Mode {
        ADAPTIVE,
        LUCENE_LEVEL,
        FILTER_LEVEL;

        /**
         * Resolves the mode from the {@code plugins.security.dls.mode} setting.
         *
         * <p>The setting is matched case-insensitively against the constant names; a missing or
         * unrecognised value yields {@link #ADAPTIVE}.
         *
         * @param settings the node settings
         * @return the configured mode, never {@code null}
         */
        static Mode get(Settings settings) {
            String configured = settings.get("plugins.security.dls.mode");
            for (Mode candidate : Mode.values()) {
                // equalsIgnoreCase(null) is false, so an absent setting falls through to the default.
                if (candidate.name().equalsIgnoreCase(configured)) {
                    return candidate;
                }
            }
            return ADAPTIVE;
        }
    }

    private static class StringTermsGetter {
        private static final Field REDUCE_ORDER = StringTermsGetter.getField(InternalTerms.class, "reduceOrder");
        private static final Field TERM_BYTES = StringTermsGetter.getField(StringTerms.Bucket.class, "termBytes");
        private static final Field FORMAT = StringTermsGetter.getField(InternalTerms.Bucket.class, "format");

        private StringTermsGetter() {
        }

        private static <T> Field getFieldPrivileged(Class<T> cls, String name) {
            try {
                Field field = cls.getDeclaredField(name);
                field.setAccessible(true);
                return field;
            }
            catch (NoSuchFieldException | SecurityException e) {
                log.error("Failed to get class {} declared field {}", (Object)cls.getSimpleName(), (Object)name, (Object)e);
                if (e instanceof RuntimeException) {
                    throw (RuntimeException)e;
                }
                throw new RuntimeException(e);
            }
        }

        private static <T> Field getField(Class<T> cls, String name) {
            SpecialPermission.check();
            return AccessController.doPrivileged(() -> StringTermsGetter.getFieldPrivileged(cls, name));
        }

        private static <T, C> T getFieldValue(Field field, C c) {
            try {
                return (T)field.get(c);
            }
            catch (IllegalAccessException | IllegalArgumentException e) {
                log.error("Exception while getting value {} of class {}", (Object)field.getName(), (Object)c.getClass().getSimpleName(), (Object)e);
                if (e instanceof RuntimeException) {
                    throw (RuntimeException)e;
                }
                throw new RuntimeException(e);
            }
        }

        public static BucketOrder getReduceOrder(StringTerms stringTerms) {
            return (BucketOrder)StringTermsGetter.getFieldValue(REDUCE_ORDER, stringTerms);
        }

        public static BytesRef getTerm(StringTerms.Bucket bucket) {
            return (BytesRef)StringTermsGetter.getFieldValue(TERM_BYTES, bucket);
        }

        public static DocValueFormat getDocValueFormat(StringTerms.Bucket bucket) {
            return (DocValueFormat)StringTermsGetter.getFieldValue(FORMAT, bucket);
        }
    }

    /**
     * Stateful consumer that collapses consecutive buckets comparing as equal under the given
     * comparator into one bucket. Buckets must be fed in sorted order; call {@link #getBuckets()}
     * once after the last bucket to obtain the result.
     *
     * <p>A run consisting of a single bucket is passed through unchanged. For a longer run a new
     * bucket is created whose doc count (and, where available, doc count error) is the sum over the
     * run, and whose term, sub-aggregations and format are taken from the last bucket of the run.
     */
    private static class BucketMerger
    implements Consumer<StringTerms.Bucket> {
        private final Comparator<MultiBucketsAggregation.Bucket> comparator;
        /** Last bucket of the run currently being accumulated, or {@code null} if no run is open. */
        private StringTerms.Bucket bucket = null;
        private int mergeCount;
        private long mergedDocCount;
        private long mergedDocCountError;
        private boolean showDocCountError = true;
        private final ImmutableList.Builder<StringTerms.Bucket> builder;

        /**
         * @param comparator decides which buckets belong to the same run; must not be {@code null}
         * @param size       expected number of resulting buckets, used to size the result builder
         */
        BucketMerger(Comparator<MultiBucketsAggregation.Bucket> comparator, int size) {
            this.comparator = Objects.requireNonNull(comparator);
            this.builder = ImmutableList.builderWithExpectedSize(size);
        }

        /** Emits the bucket representing the current run. */
        private void finalizeBucket() {
            if (this.mergeCount == 1) {
                this.builder.add(this.bucket);
            } else {
                this.builder.add(new StringTerms.Bucket(StringTermsGetter.getTerm(this.bucket), this.mergedDocCount, (InternalAggregations) this.bucket.getAggregations(), this.showDocCountError, this.mergedDocCountError, StringTermsGetter.getDocValueFormat(this.bucket)));
            }
        }

        /**
         * Closes the current run if {@code bucket} does not belong to it.
         *
         * @param bucket the next bucket, or {@code null} to close the run unconditionally
         */
        private void merge(StringTerms.Bucket bucket) {
            if (this.bucket == null) {
                return;
            }
            if (bucket == null || this.comparator.compare(this.bucket, bucket) != 0) {
                this.finalizeBucket();
                this.bucket = null;
                this.mergeCount = 0;
                this.mergedDocCount = 0L;
                this.mergedDocCountError = 0L;
                this.showDocCountError = true;
            }
        }

        /** Closes the last open run and returns all merged buckets as an immutable list. */
        public List<StringTerms.Bucket> getBuckets() {
            this.merge(null);
            return this.builder.build();
        }

        @Override
        public void accept(StringTerms.Bucket bucket) {
            // Reject null up front so a bad element cannot close the current run before failing.
            Objects.requireNonNull(bucket);
            this.merge(bucket);
            ++this.mergeCount;
            this.mergedDocCount += bucket.getDocCount();
            if (this.showDocCountError) {
                try {
                    this.mergedDocCountError += bucket.getDocCountError();
                } catch (IllegalStateException ignored) {
                    // The bucket does not expose a doc count error; stop tracking it for this run.
                    this.showDocCountError = false;
                }
            }
            this.bucket = bucket;
        }
    }
}

