/*
 * Decompiled with CFR 0.152.
 */
package com.floragunn.encryption.at.rest.index;

import com.floragunn.encryption.at.rest.key_management.AESKey;
import com.floragunn.encryption.at.rest.key_management.WrappedAESKey;
import com.floragunn.encryption.at.rest.key_management.WrappedAESKeyContainer;
import com.floragunn.encryption.at.rest.plugin.KeyStore;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.LinkOption;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
import java.time.Duration;
import java.util.Base64;
import java.util.Map;
import java.util.Objects;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.index.IndexService;
import org.elasticsearch.index.engine.Engine;
import org.elasticsearch.index.mapper.ParsedDocument;
import org.elasticsearch.index.mapper.SourceToParse;
import org.elasticsearch.index.shard.IndexingOperationListener;
import org.elasticsearch.index.shard.ShardId;
import org.elasticsearch.index.shard.ShardPath;
import org.elasticsearch.indices.IndicesService;
import org.elasticsearch.xcontent.XContent;
import org.elasticsearch.xcontent.XContentType;

/**
 * {@link IndexingOperationListener} that keeps the translog copy of a document's source encrypted.
 * <p>
 * For operations whose origin is {@code PRIMARY}, {@code REPLICA} or {@code PEER_RECOVERY} the document source is
 * replaced by the JSON wrapper {@code {"_encrypted_tl_content":"<base64>"}}, which holds the SIV-encrypted original
 * source (the document id is passed as associated data). For {@code LOCAL_TRANSLOG_RECOVERY} and
 * {@code LOCAL_RESET} the wrapper is decrypted again and the Lucene documents are re-parsed from the plaintext.
 * When {@code onlyDecrypt} is set, no new encryption happens; decryption still does.
 * <p>
 * Every shard has its own key, stored wrapped by the cluster KEK in the file {@value #FILENAME} inside the shard's
 * data path. The container used for encryption is obtained from that key via {@code newRandomWrapped()} and is
 * cached per shard for 60 minutes after creation, after which a new one is created.
 */
public class CryptoTranslogIndexingOperationListener
implements IndexingOperationListener {
    private static final Logger logger = LogManager.getLogger(CryptoTranslogIndexingOperationListener.class);
    /** File name, relative to the shard data path, holding the per-shard key wrapped with the cluster KEK. */
    private static final String FILENAME = "_encrypted_translog_key";
    /** Name of the single field of the JSON wrapper that replaces an encrypted source. */
    private static final String ENCRYPTED_FIELD = "_encrypted_tl_content";
    /** Leading bytes of every encrypted wrapper; used to detect an attempt to encrypt twice. */
    private static final byte[] ENCRYPTED_PREFIX = ("{\"" + ENCRYPTED_FIELD + "\":").getBytes(StandardCharsets.UTF_8);

    /** Per-shard keys unwrapped with the cluster KEK; evicted 60 minutes after the last access. */
    private final Cache<ShardId, AESKey> perShardKekWrappedKeyCache = CacheBuilder.newBuilder()
            .maximumSize(1000L)
            .expireAfterAccess(Duration.ofMinutes(60L))
            .removalListener(e -> logger.trace("Invalidate perShardKekWrappedKeyCache entry because of {}", e.getCause()))
            .build();
    /** Per-shard encryption containers; expire 60 minutes after creation, which rotates them. */
    private final Cache<ShardId, WrappedAESKeyContainer> wrappedKeyCache = CacheBuilder.newBuilder()
            .maximumSize(1000L)
            .expireAfterWrite(Duration.ofMinutes(60L))
            .removalListener(e -> logger.trace("Invalidate wrappedKeyCache entry because of {}", e.getCause()))
            .build();
    private final KeyStore keyStore;
    private final IndicesService indicesService;
    private final boolean onlyDecrypt;

    /**
     * @param keyStore       source of the cluster KEK; must not be null
     * @param indicesService used to resolve the shard's {@link IndexService}; must not be null
     * @param onlyDecrypt    if {@code true}, sources are never encrypted, only decrypted during local recovery
     */
    public CryptoTranslogIndexingOperationListener(KeyStore keyStore, IndicesService indicesService, boolean onlyDecrypt) {
        this.keyStore = Objects.requireNonNull(keyStore, "keyStore must not be null");
        this.indicesService = Objects.requireNonNull(indicesService, "indicesService must not be null");
        this.onlyDecrypt = onlyDecrypt;
    }

    /**
     * Encrypts or decrypts the source of an index operation depending on its origin.
     * <p>
     * On any failure the operation's Lucene documents and its source are cleared before the failure is rethrown,
     * presumably so that a document whose source could not be encrypted is never indexed in plaintext.
     *
     * @return the (possibly modified) operation, mirroring the interface's default implementation
     * @throws RuntimeException wrapping any failure while handling the operation
     */
    @Override
    public Engine.Index preIndex(ShardId shardId, Engine.Index _operation) {
        try {
            this.preIndex0(shardId, _operation);
        } catch (Exception e) {
            logger.error("preIndex Error {}", e, e);
            // Fail closed: drop everything the operation would otherwise index or write to the translog.
            _operation.parsedDoc().docs().clear();
            _operation.parsedDoc().setSource(null, null);
            throw new RuntimeException(e);
        }
        return _operation;
    }

    /** Validates the operation, resolves the shard key and dispatches to encryption or decryption by origin. */
    private void preIndex0(ShardId shardId, Engine.Index _operation) throws Exception {
        if (_operation.operationType() != Engine.Operation.TYPE.INDEX) {
            return;
        }
        if (_operation.id() == null || _operation.id().isEmpty()) {
            logger.error("No id set for in index {}", shardId.getIndexName());
            throw new RuntimeException("no id set");
        }
        IndexService indexService = this.indicesService.indexService(shardId.getIndex());
        if (indexService == null) {
            throw new RuntimeException("indexService must not be null");
        }
        if (this.keyStore.getClusterKeK() == null) {
            throw new IllegalStateException("Cluster key must not be null here");
        }
        ShardPath shardPath = indexService.getShard(shardId.id()).shardPath();
        Path storedFile = shardPath.getDataPath().resolve(FILENAME);
        // Guava loads a missing entry at most once per key concurrently, so only one thread creates the key file.
        AESKey key = this.perShardKekWrappedKeyCache.get(shardId, () -> this.loadOrCreateShardKey(storedFile));
        if (logger.isTraceEnabled()) {
            traceOperation(shardId, _operation);
        }
        switch (_operation.origin()) {
            case PEER_RECOVERY:
            case PRIMARY:
            case REPLICA:
                this.encryptSourceFieldForTranslogAndStoredField(_operation, key, shardId);
                return;
            case LOCAL_TRANSLOG_RECOVERY:
            case LOCAL_RESET: {
                BytesReference decryptedSource = this.decryptSourceFieldForTranslog(_operation, key, shardId);
                if (decryptedSource != null) {
                    // Re-parse so the Lucene documents are built from the plaintext, not from the wrapper.
                    SourceToParse toParse = new SourceToParse(_operation.id(), decryptedSource,
                            detectXContentType(decryptedSource), _operation.routing());
                    ParsedDocument decryptedParsedDocument = indexService.mapperService().documentMapper().parse(toParse);
                    _operation.parsedDoc().docs().clear();
                    _operation.parsedDoc().docs().addAll(decryptedParsedDocument.docs());
                }
                return;
            }
            default:
                throw new IllegalStateException("unreachable code");
        }
    }

    /**
     * Returns the shard's key. If {@code storedFile} exists, the key is read from it and unwrapped with the cluster
     * KEK. Otherwise a new random key is generated, its wrapped form is synchronously written to
     * {@code storedFile}, and the key is returned.
     */
    private AESKey loadOrCreateShardKey(Path storedFile) throws Exception {
        if (Files.exists(storedFile)) {
            if (logger.isTraceEnabled()) {
                logger.trace("Read perShardKekWrappedKeyCache from {}", storedFile.toAbsolutePath().toString());
            }
            byte[] encryptedKey = Files.readAllBytes(storedFile);
            return this.keyStore.getClusterKeK().unwrapAESKey(new WrappedAESKey(encryptedKey));
        }
        WrappedAESKeyContainer wkey = this.keyStore.getClusterKeK().newRandomWrapped();
        // CREATE_NEW makes the write fail rather than silently replace a key file that appeared in the meantime.
        Files.write(storedFile, wkey.wrappedAESKey().bytes(), StandardOpenOption.WRITE, StandardOpenOption.SYNC, StandardOpenOption.DSYNC, StandardOpenOption.CREATE_NEW);
        if (logger.isTraceEnabled()) {
            logger.trace("Created new per-shard key and stored it in {}", storedFile.toAbsolutePath().toString());
        }
        return wkey.aesKey();
    }

    /** Logs the operation's metadata and source at TRACE level; callers must check {@code isTraceEnabled()}. */
    private static void traceOperation(ShardId shardId, Engine.Index operation) {
        logger.trace("preIndex on {}/{} Id: {} of type {} from {} with routing {} isRetry {}, seqno: {} , primary term: {}, version {}, version type: {}",
                shardId.getIndexName(), shardId.id(), operation.id(), operation.operationType(), operation.origin(),
                operation.routing(), operation.isRetry(), operation.seqNo(), operation.primaryTerm(),
                operation.version(), operation.versionType());
        try {
            // The source is passed as a parameter, not as the pattern, so "{}" inside a document is not expanded.
            // NOTE(review): for PRIMARY/REPLICA operations this logs the plaintext document to the log files.
            logger.trace("{}", operation.source().utf8ToString());
        } catch (Exception e) {
            logger.trace("Unable to print source content because of {}", e.toString());
        }
    }

    /**
     * Replaces the operation's source with the encrypted JSON wrapper, unless {@code onlyDecrypt} is set or the
     * source already is such a wrapper (the latter is logged as an error and trips an assertion).
     */
    private void encryptSourceFieldForTranslogAndStoredField(Engine.Index operation, AESKey key, ShardId shardId) throws Exception {
        if (this.onlyDecrypt) {
            return;
        }
        if (logger.isTraceEnabled()) {
            logger.trace("Encrypt id {} on shard {}", operation.id(), shardId.id());
        }
        BytesReference source = operation.parsedDoc().source();
        if (isAlreadyEncrypted(source)) {
            logger.error("source for id {} is already encrypted", operation.id());
            assert false : "source for id " + operation.id() + " is already encrypted";
            return;
        }
        WrappedAESKeyContainer wrappedKey = this.wrappedKeyCache.get(shardId, () -> {
            if (logger.isTraceEnabled()) {
                logger.trace("Rotate key for shard {}", shardId.id());
            }
            return key.newRandomWrapped();
        });
        byte[] encryptedSource = wrappedKey.encryptWithSiv(source, operation.id().getBytes(StandardCharsets.UTF_8));
        String wrapper = "{\"" + ENCRYPTED_FIELD + "\":\"" + Base64.getEncoder().encodeToString(encryptedSource) + "\"}";
        // The wrapper is JSON text whatever the original content type was, so it must be labelled as JSON.
        operation.parsedDoc().setSource(new BytesArray(wrapper), XContentType.JSON);
    }

    /**
     * Returns whether {@code source} starts with the encrypted-wrapper prefix. Uses {@link BytesReference#get(int)}
     * so that it works for every implementation, not only array-backed ones.
     */
    private static boolean isAlreadyEncrypted(BytesReference source) {
        if (source.length() < ENCRYPTED_PREFIX.length) {
            return false;
        }
        for (int i = 0; i < ENCRYPTED_PREFIX.length; i++) {
            if (source.get(i) != ENCRYPTED_PREFIX[i]) {
                return false;
            }
        }
        return true;
    }

    /**
     * Decrypts the operation's source if it is an encrypted wrapper and stores the plaintext back on the operation.
     *
     * @return the decrypted source, or {@code null} if the source has no {@value #ENCRYPTED_FIELD} field
     * @throws RuntimeException ("decryption failed") wrapping any parsing, decoding or decryption failure
     */
    private BytesReference decryptSourceFieldForTranslog(Engine.Index operation, AESKey key, ShardId shardId) throws Exception {
        if (logger.isTraceEnabled()) {
            logger.trace("Decrypt id {} on shard {}", operation.id(), shardId.id());
        }
        try {
            Map<String, Object> jsonMap;
            try (InputStream in = operation.parsedDoc().source().streamInput()) {
                jsonMap = XContentHelper.convertToMap(operation.parsedDoc().getXContentType().xContent(), in, false);
            }
            Object encryptedTlField = jsonMap.get(ENCRYPTED_FIELD);
            if (encryptedTlField == null) {
                logger.trace("skip decryption of Id: {}", operation.id());
                return null;
            }
            byte[] decodedTlField = Base64.getDecoder().decode(encryptedTlField.toString());
            logger.trace("decrypt Id: {}", operation.id());
            BytesReference decryptedSource = AESKey.decryptWithSiv(decodedTlField, key, operation.id().getBytes(StandardCharsets.UTF_8));
            operation.parsedDoc().setSource(decryptedSource, detectXContentType(decryptedSource));
            return decryptedSource;
        } catch (Exception e) {
            throw new RuntimeException("decryption failed", e);
        }
    }

    /**
     * Guesses the content type of a decrypted source from its leading bytes, falling back to JSON when it cannot be
     * determined. The original source is not necessarily JSON (e.g. SMILE or CBOR), so it must not be hard-coded.
     */
    private static XContentType detectXContentType(BytesReference bytes) {
        XContentType detected = XContentHelper.xContentType(bytes);
        return detected != null ? detected : XContentType.JSON;
    }

    /** Drops all cached per-shard keys and encryption containers; they are reloaded/recreated on next use. */
    void invalidateCache() {
        this.perShardKekWrappedKeyCache.invalidateAll();
        this.wrappedKeyCache.invalidateAll();
    }
}

