package com.floragunn.searchguard.enterprise.dlsfls.legacy;

import com.floragunn.searchguard.test.TestData;
import com.floragunn.searchguard.test.TestSgConfig;
import com.floragunn.searchguard.test.helper.cluster.JavaSecurityTestSetup;
import com.floragunn.searchguard.test.helper.cluster.LocalCluster;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.regex.Pattern;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.bouncycastle.crypto.digests.Blake2bDigest;
import org.bouncycastle.util.encoders.Hex;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.ClassRule;
import org.junit.Test;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.client.Client;
import org.opensearch.client.RequestOptions;
import org.opensearch.client.RestHighLevelClient;
import org.opensearch.common.Strings;
import org.opensearch.common.bytes.BytesReference;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.xcontent.ToXContent;
import org.opensearch.common.xcontent.XContentBuilder;
import org.opensearch.common.xcontent.json.JsonXContent;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.search.aggregations.AggregationBuilders;
import org.opensearch.search.aggregations.BucketOrder;
import org.opensearch.search.aggregations.bucket.terms.ParsedStringTerms;
import org.opensearch.search.aggregations.bucket.terms.Terms;
import org.opensearch.search.builder.SearchSourceBuilder;

/* loaded from: input_file:com/floragunn/searchguard/enterprise/dlsfls/legacy/FieldMaskingAggregationTest.class */
public class FieldMaskingAggregationTest {
    private static final int DOC_COUNT = 1000;
    private static final TestSgConfig.User MASKED_TEST_USER = new TestSgConfig.User("masked_test").roles(new TestSgConfig.Role[]{new TestSgConfig.Role("mask").indexPermissions(new String[]{"*"}).maskedFields(new String[]{"*ip::/[0-9]{1,3}$/::XXX", "source_loc"}).on(new String[]{"ip"}).clusterPermissions(new String[]{"*"})});
    private static final TestSgConfig.User UNMASKED_TEST_USER = new TestSgConfig.User("unmasked_test").roles(new TestSgConfig.Role[]{new TestSgConfig.Role("allaccess").indexPermissions(new String[]{"*"}).on(new String[]{"ip"}).clusterPermissions(new String[]{"*"})});
    private static final byte[] salt = "e1ukloTsQlOgPquJ".getBytes(StandardCharsets.UTF_8);

    @ClassRule
    public static JavaSecurityTestSetup javaSecurity = new JavaSecurityTestSetup();

    @ClassRule
    public static LocalCluster cluster = new LocalCluster.Builder().sslEnabled().enterpriseModulesEnabled().users(new TestSgConfig.User[]{MASKED_TEST_USER, UNMASKED_TEST_USER}).resources("dlsfls_legacy").build();
    private static ReferenceAggregationTable referenceAggregationTable = new ReferenceAggregationTable().maskingFunction("source_loc", "hash", Masks::blake2bHash).maskingFunction("source_ip", "masked", str -> {
        return Masks.regexReplace(str, Pattern.compile("[0-9]{1,3}$"), "XXX");
    });
    private static final Logger log = LogManager.getLogger(FieldMaskingAggregationTest.class);

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:com/floragunn/searchguard/enterprise/dlsfls/legacy/FieldMaskingAggregationTest$Masks.class */
    public static class Masks {
        Masks() {
        }

        static String blake2bHash(String str) {
            return new String(blake2bHash(str.getBytes()));
        }

        static byte[] blake2bHash(byte[] bArr) {
            Blake2bDigest blake2bDigest = new Blake2bDigest((byte[]) null, 32, (byte[]) null, FieldMaskingAggregationTest.salt);
            blake2bDigest.update(bArr, 0, bArr.length);
            byte[] bArr2 = new byte[blake2bDigest.getDigestSize()];
            blake2bDigest.doFinal(bArr2, 0);
            return Hex.encode(bArr2);
        }

        /* JADX INFO: Access modifiers changed from: package-private */
        public static String regexReplace(String str, Pattern pattern, String str2) {
            return pattern.matcher(str).replaceAll(str2);
        }
    }

    /* loaded from: input_file:com/floragunn/searchguard/enterprise/dlsfls/legacy/FieldMaskingAggregationTest$ReferenceAggregationTable.class */
    static class ReferenceAggregationTable {
        private Map<String, Map<String, Integer>> aggregatedAttributeCounts = new HashMap();
        private Map<String, Map<String, Function<String, String>>> maskingFunctions = new HashMap();

        ReferenceAggregationTable() {
        }

        ReferenceAggregationTable maskingFunction(String str, String str2, Function<String, String> function) {
            this.maskingFunctions.computeIfAbsent(str, str3 -> {
                return new HashMap();
            }).put(str2, function);
            return this;
        }

        void add(Collection<Map<String, ?>> collection) {
            Iterator<Map<String, ?>> it = collection.iterator();
            while (it.hasNext()) {
                add(it.next());
            }
        }

        void add(Map<String, ?> map) {
            for (Map.Entry<String, ?> entry : map.entrySet()) {
                this.aggregatedAttributeCounts.computeIfAbsent(entry.getKey(), str -> {
                    return new HashMap();
                }).compute(String.valueOf(entry.getValue()), (str2, num) -> {
                    return Integer.valueOf(num == null ? 1 : num.intValue() + 1);
                });
                addMaskedValues(entry.getKey(), String.valueOf(entry.getValue()));
            }
        }

        int getCount(String str, String str2) {
            Map<String, Integer> map = this.aggregatedAttributeCounts.get(str);
            if (map == null) {
                throw new IllegalArgumentException("Unknown attribute " + str + "; available: " + this.aggregatedAttributeCounts.keySet());
            }
            Integer num = map.get(str2);
            if (num != null) {
                return num.intValue();
            }
            return 0;
        }

        private void addMaskedValues(String str, String str2) {
            Map<String, Function<String, String>> map = this.maskingFunctions.get(str);
            if (map != null) {
                for (Map.Entry<String, Function<String, String>> entry : map.entrySet()) {
                    this.aggregatedAttributeCounts.computeIfAbsent(str + ":" + entry.getKey(), str3 -> {
                        return new HashMap();
                    }).compute(entry.getValue().apply(str2), (str4, num) -> {
                        return Integer.valueOf(num == null ? 1 : num.intValue() + 1);
                    });
                }
            }
        }
    }

    @BeforeClass
    public static void setupTestData() {
        Client internalNodeClient = cluster.getInternalNodeClient();
        try {
            TestData testData = TestData.documentCount(DOC_COUNT).get();
            testData.createIndex(internalNodeClient, "ip", Settings.builder().put("index.number_of_shards", 5).build());
            referenceAggregationTable.add(testData.getRetainedDocuments().values());
            if (internalNodeClient != null) {
                internalNodeClient.close();
            }
        } catch (Throwable th) {
            if (internalNodeClient != null) {
                try {
                    internalNodeClient.close();
                } catch (Throwable th2) {
                    th.addSuppressed(th2);
                }
            }
            throw th;
        }
    }

    @Test
    public void testPartiallyMaskedField() throws Exception {
        RestHighLevelClient restHighLevelClient = cluster.getRestHighLevelClient(MASKED_TEST_USER);
        try {
            SearchResponse search = restHighLevelClient.search(new SearchRequest(new String[]{"ip"}).source(new SearchSourceBuilder().query(QueryBuilders.matchAllQuery()).size(10).aggregation(AggregationBuilders.terms("source_ip_terms").field("source_ip.keyword").size(100).shardSize(DOC_COUNT))), RequestOptions.DEFAULT);
            log.info(Strings.toString(search, true, true));
            ParsedStringTerms parsedStringTerms = (ParsedStringTerms) search.getAggregations().asList().get(0);
            for (int i = 0; i < parsedStringTerms.getBuckets().size(); i++) {
                Terms.Bucket bucket = (Terms.Bucket) parsedStringTerms.getBuckets().get(i);
                Assert.assertEquals("Bucket " + i + ":\n" + toxToString(bucket), referenceAggregationTable.getCount("source_ip:masked", bucket.getKeyAsString()), bucket.getDocCount());
            }
            if (restHighLevelClient != null) {
                restHighLevelClient.close();
            }
        } catch (Throwable th) {
            if (restHighLevelClient != null) {
                try {
                    restHighLevelClient.close();
                } catch (Throwable th2) {
                    th.addSuppressed(th2);
                }
            }
            throw th;
        }
    }

    @Test
    public void testHashMaskedField() throws Exception {
        SearchRequest source = new SearchRequest(new String[]{"ip"}).source(new SearchSourceBuilder().query(QueryBuilders.matchAllQuery()).size(10).aggregation(AggregationBuilders.terms("source_loc_terms").field("source_loc.keyword").size(DOC_COUNT)));
        RestHighLevelClient restHighLevelClient = cluster.getRestHighLevelClient(UNMASKED_TEST_USER);
        try {
            SearchResponse search = restHighLevelClient.search(source, RequestOptions.DEFAULT);
            log.info(Strings.toString(search, true, true));
            if (restHighLevelClient != null) {
                restHighLevelClient.close();
            }
            restHighLevelClient = cluster.getRestHighLevelClient(MASKED_TEST_USER);
            try {
                SearchResponse search2 = restHighLevelClient.search(source, RequestOptions.DEFAULT);
                log.info(Strings.toString(search2, true, true));
                if (restHighLevelClient != null) {
                    restHighLevelClient.close();
                }
                compareHashedBuckets(search2, search);
            } finally {
            }
        } finally {
        }
    }

    @Test
    public void testHashMaskedFieldWithShardSizeParam() throws Exception {
        SearchRequest source = new SearchRequest(new String[]{"ip"}).source(new SearchSourceBuilder().query(QueryBuilders.matchAllQuery()).size(10).aggregation(AggregationBuilders.terms("source_loc_terms").field("source_loc.keyword").size(100).shardSize(DOC_COUNT)));
        RestHighLevelClient restHighLevelClient = cluster.getRestHighLevelClient(UNMASKED_TEST_USER);
        try {
            SearchResponse search = restHighLevelClient.search(source, RequestOptions.DEFAULT);
            log.info(Strings.toString(search, true, true));
            if (restHighLevelClient != null) {
                restHighLevelClient.close();
            }
            restHighLevelClient = cluster.getRestHighLevelClient(MASKED_TEST_USER);
            try {
                SearchResponse search2 = restHighLevelClient.search(source, RequestOptions.DEFAULT);
                log.info(Strings.toString(search2, true, true));
                if (restHighLevelClient != null) {
                    restHighLevelClient.close();
                }
                compareHashedBuckets(search2, search);
            } finally {
            }
        } finally {
        }
    }

    @Test
    public void testHashMaskedFieldOrderedByKey() throws Exception {
        SearchRequest source = new SearchRequest(new String[]{"ip"}).source(new SearchSourceBuilder().query(QueryBuilders.matchAllQuery()).size(10).aggregation(AggregationBuilders.terms("source_loc_terms").field("source_loc.keyword").order(BucketOrder.key(true)).size(100).shardSize(DOC_COUNT).showTermDocCountError(true)));
        RestHighLevelClient restHighLevelClient = cluster.getRestHighLevelClient(UNMASKED_TEST_USER);
        try {
            log.info(Strings.toString(restHighLevelClient.search(source, RequestOptions.DEFAULT), true, true));
            if (restHighLevelClient != null) {
                restHighLevelClient.close();
            }
            restHighLevelClient = cluster.getRestHighLevelClient(MASKED_TEST_USER);
            try {
                SearchResponse search = restHighLevelClient.search(source, RequestOptions.DEFAULT);
                log.info(Strings.toString(search, true, true));
                if (restHighLevelClient != null) {
                    restHighLevelClient.close();
                }
                ParsedStringTerms parsedStringTerms = (ParsedStringTerms) search.getAggregations().asList().get(0);
                for (int i = 0; i < parsedStringTerms.getBuckets().size(); i++) {
                    Terms.Bucket bucket = (Terms.Bucket) parsedStringTerms.getBuckets().get(i);
                    Assert.assertEquals("Bucket " + i + ":\n" + toxToString(bucket), referenceAggregationTable.getCount("source_loc:hash", bucket.getKeyAsString()), bucket.getDocCount());
                }
            } finally {
            }
        } finally {
        }
    }

    private void compareHashedBuckets(SearchResponse searchResponse, SearchResponse searchResponse2) {
        ParsedStringTerms parsedStringTerms = (ParsedStringTerms) searchResponse.getAggregations().asList().get(0);
        ParsedStringTerms parsedStringTerms2 = (ParsedStringTerms) searchResponse2.getAggregations().asList().get(0);
        Assert.assertEquals(parsedStringTerms2.getBuckets().size(), parsedStringTerms.getBuckets().size());
        HashSet hashSet = new HashSet();
        HashSet hashSet2 = new HashSet();
        Terms.Bucket bucket = null;
        int i = 0;
        for (int i2 = 0; i2 < parsedStringTerms2.getBuckets().size(); i2++) {
            Terms.Bucket bucket2 = (Terms.Bucket) parsedStringTerms2.getBuckets().get(i2);
            Terms.Bucket bucket3 = (Terms.Bucket) parsedStringTerms.getBuckets().get(i2);
            if (bucket != null && bucket.getDocCount() != bucket2.getDocCount()) {
                Assert.assertEquals("Buckets at " + i + " to " + (i2 - 1) + ":\n" + toxToString(bucket2) + "\n" + toxToString(bucket3), hashSet, hashSet2);
                hashSet.clear();
                hashSet2.clear();
                i = 1;
            }
            Assert.assertEquals("Bucket " + i2 + ":\n" + toxToString(bucket2) + "\n" + toxToString(bucket3), bucket2.getDocCount(), bucket2.getDocCount());
            hashSet.add(Masks.blake2bHash(bucket2.getKeyAsString()));
            hashSet2.add(bucket3.getKeyAsString());
            bucket = bucket2;
        }
    }

    private static String toxToString(ToXContent toXContent) {
        try {
            XContentBuilder humanReadable = JsonXContent.contentBuilder().prettyPrint().humanReadable(true);
            toXContent.toXContent(humanReadable, ToXContent.EMPTY_PARAMS);
            return BytesReference.bytes(humanReadable).utf8ToString();
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }
}
