/*
 * Decompiled with CFR 0.152.
 */
package com.floragunn.searchguard.transport;

import com.floragunn.searchguard.GuiceDependencies;
import com.floragunn.searchguard.auditlog.AuditLog;
import com.floragunn.searchguard.configuration.AdminDNs;
import com.floragunn.searchguard.configuration.ClusterInfoHolder;
import com.floragunn.searchguard.ssl.SslExceptionHandler;
import com.floragunn.searchguard.ssl.transport.PrincipalExtractor;
import com.floragunn.searchguard.support.Base64Helper;
import com.floragunn.searchguard.transport.InterClusterRequestEvaluator;
import com.floragunn.searchguard.transport.SearchGuardRequestHandler;
import com.floragunn.searchguard.user.User;
import com.floragunn.searchsupport.diag.DiagnosticContext;
import com.google.common.collect.Maps;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.regex.PatternSyntaxException;
import java.util.stream.Collectors;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.admin.cluster.shards.ClusterSearchShardsResponse;
import org.elasticsearch.action.get.GetRequest;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.transport.TransportAddress;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.RemoteClusterService;
import org.elasticsearch.transport.Transport;
import org.elasticsearch.transport.TransportException;
import org.elasticsearch.transport.TransportInterceptor;
import org.elasticsearch.transport.TransportRequest;
import org.elasticsearch.transport.TransportRequestHandler;
import org.elasticsearch.transport.TransportRequestOptions;
import org.elasticsearch.transport.TransportResponse;
import org.elasticsearch.transport.TransportResponseHandler;

/**
 * Search Guard hook into the Elasticsearch transport layer.
 *
 * <p>Incoming requests are wrapped in a {@link SearchGuardRequestHandler} via {@link #getHandler}.
 * Outgoing requests go through {@link #sendRequestDecorate}, which does the following:
 * <ul>
 *   <li>stashes the caller's thread context;</li>
 *   <li>copies a whitelisted subset of the caller's headers into the fresh context;</li>
 *   <li>adjusts the DLS/FLS/masked-field headers for cross-cluster search;</li>
 *   <li>restores the caller's context before the response or failure is passed to the original
 *       response handler.</li>
 * </ul>
 */
public class SearchGuardInterceptor {
    protected final Logger actionTrace = LogManager.getLogger("sg_action_trace");
    protected static final Logger log = LogManager.getLogger(SearchGuardInterceptor.class);
    private final AuditLog auditLog;
    private final ThreadPool threadPool;
    private final PrincipalExtractor principalExtractor;
    private final InterClusterRequestEvaluator requestEvalProvider;
    private final ClusterService cs;
    private final SslExceptionHandler sslExceptionHandler;
    private final ClusterInfoHolder clusterInfoHolder;
    /** Compiled patterns from {@code searchguard.allow_custom_headers}; never null, possibly empty. */
    private final List<Pattern> customAllowedHeaderPatterns;
    private final DiagnosticContext diagnosticContext;
    private final GuiceDependencies guiceDependencies;
    private final AdminDNs adminDns;

    public SearchGuardInterceptor(Settings settings, ThreadPool threadPool, AuditLog auditLog, PrincipalExtractor principalExtractor, InterClusterRequestEvaluator requestEvalProvider, ClusterService cs, SslExceptionHandler sslExceptionHandler, ClusterInfoHolder clusterInfoHolder, GuiceDependencies guiceDependencies, DiagnosticContext diagnosticContext, AdminDNs adminDns) {
        this.auditLog = auditLog;
        this.threadPool = threadPool;
        this.principalExtractor = principalExtractor;
        this.requestEvalProvider = requestEvalProvider;
        this.cs = cs;
        this.sslExceptionHandler = sslExceptionHandler;
        this.clusterInfoHolder = clusterInfoHolder;
        this.customAllowedHeaderPatterns = getCustomAllowedHeaderPatterns(settings);
        this.diagnosticContext = diagnosticContext;
        this.guiceDependencies = guiceDependencies;
        this.adminDns = adminDns;
    }

    /**
     * Wraps a transport request handler so that incoming requests for {@code action} pass through
     * Search Guard's checks before reaching {@code actualHandler}.
     *
     * @param action the transport action name
     * @param actualHandler the handler that ultimately processes the request
     * @return a {@link SearchGuardRequestHandler} delegating to {@code actualHandler}
     */
    public <T extends TransportRequest> SearchGuardRequestHandler<T> getHandler(String action, TransportRequestHandler<T> actualHandler) {
        return new SearchGuardRequestHandler<>(action, actualHandler, this.threadPool, this.auditLog, this.principalExtractor, this.requestEvalProvider, this.cs, this.sslExceptionHandler, this.adminDns);
    }

    /**
     * Sends {@code request} through {@code sender} with a freshly stashed thread context that
     * carries only the headers Search Guard wants to propagate.
     *
     * <p>Steps, in order:
     * <ol>
     *   <li>Capture the caller's headers, the SG user, origin and remote-address transients, and the
     *       {@code _sg_*_ccs} transients. The response handler below sets those transients when a
     *       {@link ClusterSearchShardsResponse} arrives.</li>
     *   <li>Stash the context and put {@code _sg_remotecn} (the local cluster name).</li>
     *   <li>Copy the whitelisted headers. If CCS is enabled and the target node is unknown to
     *       {@link ClusterInfoHolder}:
     *       <ul>
     *         <li>for search/search_shards, strip the DLS/FLS/masked-field headers;</li>
     *         <li>for other non-{@code internal:} actions, re-add DLS/FLS/masked-field values from the
     *             {@code _sg_*_ccs} transients.</li>
     *       </ul></li>
     *   <li>Add the action stack, the whitelisted headers, and the origin/remote-address/user
     *       headers.</li>
     *   <li>Send with a handler that restores the caller's context on response or failure.</li>
     * </ol>
     *
     * @param sender the downstream sender
     * @param connection connection to the target node
     * @param action the transport action name
     * @param request the outgoing request
     * @param options transport options, passed through unchanged
     * @param handler the caller's response handler
     */
    public <T extends TransportResponse> void sendRequestDecorate(TransportInterceptor.AsyncSender sender, Transport.Connection connection, String action, TransportRequest request, TransportRequestOptions options, TransportResponseHandler<T> handler) {
        ThreadContext threadContext = this.getThreadContext();

        // Everything needed from the caller's context must be read before it is stashed.
        Map<String, String> origHeaders = threadContext.getHeaders();
        User origUser = (User) threadContext.getTransient("_sg_user");
        String origOrigin = (String) threadContext.getTransient("_sg_origin");
        Object origRemoteAddress = threadContext.getTransient("_sg_remote_address");
        String origCCSTransientDls = (String) threadContext.getTransient("_sg_dls_query_ccs");
        String origCCSTransientFls = (String) threadContext.getTransient("_sg_fls_fields_ccs");
        String origCCSTransientMf = (String) threadContext.getTransient("_sg_masked_fields_ccs");
        String actionStack = this.diagnosticContext.getActionStack();

        try (ThreadContext.StoredContext stashedContext = threadContext.stashContext()) {
            RestoringTransportResponseHandler<T> restoringHandler = new RestoringTransportResponseHandler<>(handler, stashedContext);
            threadContext.putHeader("_sg_remotecn", this.cs.getClusterName().value());

            Map<String, String> headerMap = new HashMap<>(Maps.filterKeys(origHeaders, k -> this.isForwardedHeader(k, request)));

            RemoteClusterService remoteClusterService = this.guiceDependencies.getTransportService().getRemoteClusterService();

            // The action checks stay ahead of isRemoteNode() so that hasNode() is only consulted
            // for the actions that need it.
            if (remoteClusterService.isCrossClusterSearchEnabled()
                    && this.clusterInfoHolder.isInitialized()
                    && (action.equals("indices:admin/shards/search_shards") || action.equals("indices:data/read/search"))
                    && this.isRemoteNode(connection)) {
                if (log.isDebugEnabled()) {
                    log.debug("remove dls/fls/mf because we sent a ccs request to a remote cluster");
                }
                headerMap.remove("_sg_dls_query");
                headerMap.remove("_sg_dls_mode");
                headerMap.remove("_sg_masked_fields");
                headerMap.remove("_sg_fls_fields");
                headerMap.remove("_sg_filter_level_dls_done");
                headerMap.remove("_sg_dls_filter_level_query");
                headerMap.remove("_sg_doc_whitelist");
            }

            if (remoteClusterService.isCrossClusterSearchEnabled()
                    && this.clusterInfoHolder.isInitialized()
                    && !action.startsWith("internal:")
                    && !action.equals("indices:admin/shards/search_shards")
                    && this.isRemoteNode(connection)) {
                if (log.isDebugEnabled()) {
                    log.debug("add dls/fls/mf from transient");
                }
                if (origCCSTransientDls != null && !origCCSTransientDls.isEmpty()) {
                    headerMap.put("_sg_dls_query", origCCSTransientDls);
                }
                if (origCCSTransientMf != null && !origCCSTransientMf.isEmpty()) {
                    headerMap.put("_sg_masked_fields", origCCSTransientMf);
                }
                if (origCCSTransientFls != null && !origCCSTransientFls.isEmpty()) {
                    headerMap.put("_sg_fls_fields", origCCSTransientFls);
                }
            }

            if (actionStack != null) {
                threadContext.putHeader("x_action_stack", actionStack);
            }
            threadContext.putHeader(headerMap);
            this.ensureCorrectHeaders(origRemoteAddress, origUser, origOrigin);

            if (this.actionTrace.isTraceEnabled()) {
                Map<String, String> nonTraceHeaders = threadContext.getHeaders().entrySet().stream()
                        .filter(e -> !e.getKey().startsWith("_sg_trace"))
                        .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
                threadContext.putHeader("_sg_trace" + System.currentTimeMillis() + "#" + UUID.randomUUID().toString(), Thread.currentThread().getName() + " IC -> " + action + " " + nonTraceHeaders);
            }

            sender.sendRequest(connection, action, request, options, restoringHandler);
        }
    }

    /**
     * Decides whether a header from the caller's context is copied into the context used for
     * an outgoing transport request.
     *
     * @param key header name; {@code null} is never forwarded
     * @param request the outgoing request; {@code _sg_source_field_context} is withheld from
     *     {@link SearchRequest} and {@link GetRequest}
     * @return {@code true} if the header should be forwarded
     */
    private boolean isForwardedHeader(String key, TransportRequest request) {
        if (key == null) {
            return false;
        }
        switch (key) {
            case "_sg_conf_request":
            case "_sg_origin_header":
            case "_sg_remote_address_header":
            case "_sg_user_header":
            case "_sg_dls_query":
            case "_sg_fls_fields":
            case "_sg_masked_fields":
            case "_sg_doc_whitelist":
            case "_sg_filter_level_dls_done":
            case "_sg_dls_mode":
            case "_sg_dls_filter_level_query":
                return true;
            case "_sg_source_field_context":
                return !(request instanceof SearchRequest) && !(request instanceof GetRequest);
            default:
                return key.startsWith("_sg_trace")
                        || key.startsWith("_sg_initial_action_class_header")
                        || this.checkCustomAllowedHeader(key);
        }
    }

    /**
     * Returns {@code true} if {@link ClusterInfoHolder} does not know the connection's target node.
     * The debug messages treat such a node as belonging to a remote cluster.
     *
     * <p>NOTE(review): {@code hasNode} returns a boxed {@code Boolean}. A {@code null} result would
     * throw {@link NullPointerException} here, exactly as before. Confirm it cannot be null once
     * {@code isInitialized()} is true.
     */
    private boolean isRemoteNode(Transport.Connection connection) {
        return !this.clusterInfoHolder.hasNode(connection.getNode()).booleanValue();
    }

    /**
     * Makes sure the stashed context carries origin, remote-address and user headers. Any header
     * already present (e.g. forwarded from the caller) is left untouched.
     *
     * @param remoteAdr the caller's {@code _sg_remote_address} transient; only a
     *     {@link TransportAddress} is serialized into a header
     * @param origUser the caller's {@code _sg_user} transient, may be {@code null}
     * @param origin the caller's {@code _sg_origin} transient. {@code null} means LOCAL, and an
     *     empty string leaves the origin header unset.
     */
    private void ensureCorrectHeaders(Object remoteAdr, User origUser, String origin) {
        ThreadContext threadContext = this.getThreadContext();
        if (threadContext.getHeader("_sg_origin_header") == null) {
            if (origin == null) {
                threadContext.putHeader("_sg_origin_header", AuditLog.Origin.LOCAL.toString());
            } else if (!origin.isEmpty()) {
                threadContext.putHeader("_sg_origin_header", origin);
            }
        }
        if (remoteAdr instanceof TransportAddress && threadContext.getHeader("_sg_remote_address_header") == null) {
            threadContext.putHeader("_sg_remote_address_header", Base64Helper.serializeObject(((TransportAddress) remoteAdr).address()));
        }
        if (origUser != null && threadContext.getHeader("_sg_user_header") == null) {
            threadContext.putHeader("_sg_user_header", Base64Helper.serializeObject(origUser));
        }
    }

    private ThreadContext getThreadContext() {
        return this.threadPool.getThreadContext();
    }

    /**
     * Checks a non-whitelisted header against the configured custom-header patterns.
     *
     * <p>{@code _sg_*} headers are never allowed through this path; only the fixed whitelist may
     * forward them. {@code X-Opaque-Id} is excluded as well. NOTE(review): this is presumably
     * because {@code stashContext()} already carries it over, and a second {@code putHeader} would
     * fail. Confirm against the Elasticsearch version in use.
     *
     * @param headerKey non-null header name
     * @return {@code true} if some pattern matches the whole key
     */
    private boolean checkCustomAllowedHeader(String headerKey) {
        if (headerKey.startsWith("_sg_") || headerKey.equals("X-Opaque-Id")) {
            return false;
        }
        for (Pattern pattern : this.customAllowedHeaderPatterns) {
            if (pattern.matcher(headerKey).matches()) {
                return true;
            }
        }
        return false;
    }

    /**
     * Compiles the regular expressions listed in {@code searchguard.allow_custom_headers}.
     * Invalid patterns are logged and skipped rather than failing startup.
     *
     * @param settings node settings
     * @return an unmodifiable list of the valid patterns (empty if none are configured)
     */
    private static List<Pattern> getCustomAllowedHeaderPatterns(Settings settings) {
        List<String> patternStrings = settings.getAsList("searchguard.allow_custom_headers", Collections.emptyList());
        List<Pattern> result = new ArrayList<>(patternStrings.size());
        for (String patternString : patternStrings) {
            try {
                result.add(Pattern.compile(patternString));
            } catch (PatternSyntaxException e) {
                log.error("Invalid pattern configured in searchguard.allow_custom_headers: {}", patternString, e);
            }
        }
        return Collections.unmodifiableList(result);
    }

    /**
     * Response handler that restores the caller's stashed thread context before delegating.
     * For {@link ClusterSearchShardsResponse}, it also lifts the DLS/FLS/masked-field response
     * headers into {@code _sg_*_ccs} transients. {@link SearchGuardInterceptor#sendRequestDecorate}
     * reads those transients for later requests.
     */
    private class RestoringTransportResponseHandler<T extends TransportResponse>
    implements TransportResponseHandler<T> {
        private final ThreadContext.StoredContext contextToRestore;
        private final TransportResponseHandler<T> innerHandler;

        private RestoringTransportResponseHandler(TransportResponseHandler<T> innerHandler, ThreadContext.StoredContext contextToRestore) {
            this.contextToRestore = contextToRestore;
            this.innerHandler = innerHandler;
        }

        @Override
        public T read(StreamInput in) throws IOException {
            return this.innerHandler.read(in);
        }

        @Override
        public void handleResponse(T response) {
            ThreadContext threadContext = SearchGuardInterceptor.this.getThreadContext();
            // Read the response headers first, because restore() replaces the current context.
            Map<String, List<String>> responseHeaders = threadContext.getResponseHeaders();
            List<String> flsResponseHeader = responseHeaders.get("_sg_fls_fields");
            List<String> dlsResponseHeader = responseHeaders.get("_sg_dls_query");
            List<String> maskedFieldsResponseHeader = responseHeaders.get("_sg_masked_fields");
            this.contextToRestore.restore();
            if (response instanceof ClusterSearchShardsResponse) {
                putFirstAsTransient(threadContext, "_sg_fls_fields_ccs", flsResponseHeader);
                putFirstAsTransient(threadContext, "_sg_dls_query_ccs", dlsResponseHeader);
                putFirstAsTransient(threadContext, "_sg_masked_fields_ccs", maskedFieldsResponseHeader);
            }
            this.innerHandler.handleResponse(response);
        }

        /** Stores the first value of {@code values} as transient {@code key}; no-op for null/empty. */
        private void putFirstAsTransient(ThreadContext threadContext, String key, List<String> values) {
            if (values != null && !values.isEmpty()) {
                threadContext.putTransient(key, values.get(0));
            }
        }

        @Override
        public void handleException(TransportException e) {
            this.contextToRestore.restore();
            this.innerHandler.handleException(e);
        }

        @Override
        public String executor() {
            return this.innerHandler.executor();
        }
    }
}

