/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.network.crypto;

import com.google.crypto.tink.subtle.AesGcmHkdfStreaming;
import com.google.crypto.tink.subtle.StreamSegmentDecrypter;
import com.google.crypto.tink.subtle.StreamSegmentEncrypter;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelOutboundHandlerAdapter;
import io.netty.channel.ChannelPromise;
import io.netty.channel.FileRegion;
import io.netty.util.ReferenceCounted;
import java.io.IOException;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.channels.WritableByteChannel;
import java.security.GeneralSecurityException;
import java.security.InvalidAlgorithmParameterException;
import javax.crypto.spec.SecretKeySpec;
import org.apache.spark.network.crypto.TransportCipher;
import org.apache.spark.network.crypto.TransportCipherUtil;
import org.apache.spark.network.util.AbstractFileRegion;
import org.apache.spark.network.util.ByteBufferWriteableChannel;
import org.sparkproject.guava.annotations.VisibleForTesting;
import org.sparkproject.guava.base.Preconditions;
import org.sparkproject.guava.primitives.Longs;

public class GcmTransportCipher
implements TransportCipher {
    /** Name of the MAC algorithm handed to {@link AesGcmHkdfStreaming} for its HKDF step. */
    private static final String HKDF_ALG = "HmacSha256";
    /**
     * Size in bytes of a message's length header. The handlers in this file read and write that
     * header as a single {@code long}.
     */
    private static final int LENGTH_HEADER_BYTES = 8;
    /** Ciphertext segment size, in bytes, requested from {@link AesGcmHkdfStreaming} (32 KiB). */
    @VisibleForTesting
    static final int CIPHERTEXT_BUFFER_SIZE = 32 * 1024;
    /** AES key whose encoded bytes are the input key material of every streaming primitive. */
    private final SecretKeySpec aesKey;

    /**
     * Creates a transport cipher backed by the given AES key.
     *
     * @param aesKey the AES key; must not be {@code null}
     * @throws NullPointerException if {@code aesKey} is {@code null}
     */
    public GcmTransportCipher(SecretKeySpec aesKey) {
        // Fail at construction time instead of at the first handler creation.
        if (aesKey == null) {
            throw new NullPointerException("aesKey");
        }
        this.aesKey = aesKey;
    }

    /**
     * Builds a new streaming AEAD primitive from this cipher's key.
     *
     * <p>The derived key size equals the length of the encoded AES key, the ciphertext segment
     * size is {@link #CIPHERTEXT_BUFFER_SIZE}, and the first-segment offset is 0.
     *
     * @return a freshly constructed {@link AesGcmHkdfStreaming}
     * @throws InvalidAlgorithmParameterException if the primitive rejects the parameters
     */
    AesGcmHkdfStreaming getAesGcmHkdfStreaming() throws InvalidAlgorithmParameterException {
        // getEncoded() is called once so the key bytes and their length come from the same array.
        byte[] keyBytes = this.aesKey.getEncoded();
        return new AesGcmHkdfStreaming(keyBytes, HKDF_ALG, keyBytes.length, CIPHERTEXT_BUFFER_SIZE, 0);
    }

    /**
     * Returns the identifier that {@link TransportCipherUtil#getKeyId} computes for this cipher's
     * AES key.
     *
     * @return the key id string
     * @throws GeneralSecurityException if the key id cannot be computed
     */
    @Override
    @VisibleForTesting
    public String getKeyId() throws GeneralSecurityException {
        return TransportCipherUtil.getKeyId(this.aesKey);
    }

    /**
     * Creates a new outbound handler that encrypts messages with this cipher's key.
     *
     * @return a new {@link EncryptionHandler} bound to this cipher instance
     * @throws GeneralSecurityException if the handler's streaming primitive cannot be created
     */
    @VisibleForTesting
    EncryptionHandler getEncryptionHandler() throws GeneralSecurityException {
        return new EncryptionHandler();
    }

    /**
     * Creates a new inbound handler that decrypts messages with this cipher's key.
     *
     * @return a new {@link DecryptionHandler} bound to this cipher instance
     * @throws GeneralSecurityException if the handler's streaming primitive or decrypter cannot
     *     be created
     */
    @VisibleForTesting
    DecryptionHandler getDecryptionHandler() throws GeneralSecurityException {
        return new DecryptionHandler();
    }

    /**
     * Installs this cipher's handlers at the head of the channel's pipeline.
     *
     * <p>The encryption handler is added first and the decryption handler second, both with
     * {@code addFirst}, so the decryption handler ends up ahead of the encryption handler.
     *
     * @param ch the channel whose pipeline receives the handlers
     * @throws GeneralSecurityException if either handler cannot be created
     */
    @Override
    public void addToChannel(Channel ch) throws GeneralSecurityException {
        ch.pipeline().addFirst("GcmTransportEncryption", this.getEncryptionHandler());
        ch.pipeline().addFirst("GcmTransportDecryption", this.getDecryptionHandler());
    }

    /**
     * Inbound handler that decrypts length-prefixed streaming ciphertext and forwards the
     * plaintext downstream.
     *
     * <p>Every message is expected as
     * {@code [8-byte total length][streaming header][ciphertext segments]}, where the total length
     * also counts the length prefix and the header. Input may arrive split across any number of
     * {@link ByteBuf}s, and one {@link ByteBuf} may carry the end of one message and the start of
     * the next. Once a message has been fully decrypted, the handler resets itself, with a fresh
     * decrypter, for the next message on the same channel.
     *
     * <p>The {@code (Buffer)} casts bind calls to {@code java.nio.Buffer} rather than to
     * {@code ByteBuffer}'s covariant overrides (presumably kept for Java 8 runtime compatibility).
     */
    @VisibleForTesting
    class DecryptionHandler extends ChannelInboundHandlerAdapter {
        /** Collects the 8-byte length prefix, which may arrive in pieces. */
        private final ByteBuffer expectedLengthBuffer;
        /** Collects the streaming header, which may arrive in pieces. */
        private final ByteBuffer headerBuffer;
        /** Collects one ciphertext segment before it is decrypted. */
        private final ByteBuffer ciphertextBuffer;
        private final AesGcmHkdfStreaming aesGcmHkdfStreaming;
        private final int plaintextSegmentSize;
        /** Decrypter for the message currently being read; replaced after each message. */
        private StreamSegmentDecrypter decrypter;
        private boolean decrypterInit = false;
        private boolean completed = false;
        private int segmentNumber = 0;
        /** Total length declared by the current message, or -1 while the prefix is incomplete. */
        private long expectedLength = -1L;
        /** Bytes of the current message consumed so far, including prefix and header. */
        private long ciphertextRead = 0L;

        /**
         * Allocates the reassembly buffers and the first decrypter.
         *
         * @throws GeneralSecurityException if the streaming primitive or decrypter cannot be created
         */
        DecryptionHandler() throws GeneralSecurityException {
            this.aesGcmHkdfStreaming = GcmTransportCipher.this.getAesGcmHkdfStreaming();
            this.expectedLengthBuffer = ByteBuffer.allocate(Long.BYTES);
            this.headerBuffer = ByteBuffer.allocate(this.aesGcmHkdfStreaming.getHeaderLength());
            this.ciphertextBuffer = ByteBuffer.allocate(this.aesGcmHkdfStreaming.getCiphertextSegmentSize());
            this.decrypter = this.aesGcmHkdfStreaming.newStreamSegmentDecrypter();
            this.plaintextSegmentSize = this.aesGcmHkdfStreaming.getPlaintextSegmentSize();
        }

        /**
         * Copies as many bytes as are available, and fit, from {@code source} into
         * {@code destination}.
         *
         * <p>{@code ByteBuf.readBytes(ByteBuffer)} throws when the destination has more room than
         * the source has readable bytes, so the destination's limit is narrowed for the copy and
         * restored afterwards.
         */
        private void fill(ByteBuf source, ByteBuffer destination) {
            int count = Integer.min(source.readableBytes(), destination.remaining());
            int fullLimit = destination.limit();
            ((Buffer) destination).limit(destination.position() + count);
            source.readBytes(destination);
            ((Buffer) destination).limit(fullLimit);
        }

        /**
         * Reads the length prefix of the current message if it has not been read yet.
         *
         * @return {@code true} once the expected length is known, {@code false} if more input is
         *     needed
         * @throws IllegalStateException if the declared length is smaller than the length prefix
         *     plus the streaming header, which no message could ever satisfy
         */
        private boolean initializeExpectedLength(ByteBuf ciphertextNettyBuf) {
            if (this.expectedLength < 0L) {
                this.fill(ciphertextNettyBuf, this.expectedLengthBuffer);
                if (this.expectedLengthBuffer.hasRemaining()) {
                    return false;
                }
                ((Buffer) this.expectedLengthBuffer).flip();
                long declaredLength = this.expectedLengthBuffer.getLong();
                long minimumLength = (long) Long.BYTES + this.aesGcmHkdfStreaming.getHeaderLength();
                if (declaredLength < minimumLength) {
                    throw new IllegalStateException("Invalid expected ciphertext length.");
                }
                this.expectedLength = declaredLength;
                this.ciphertextRead += Long.BYTES;
            }
            return true;
        }

        /**
         * Reads the streaming header of the current message and initializes the decrypter with it,
         * using the declared length as associated data.
         *
         * @return {@code true} once the decrypter is initialized, {@code false} if more input is
         *     needed
         * @throws GeneralSecurityException if the decrypter rejects the header
         */
        private boolean initializeDecrypter(ByteBuf ciphertextNettyBuf) throws GeneralSecurityException {
            if (!this.decrypterInit) {
                this.fill(ciphertextNettyBuf, this.headerBuffer);
                if (this.headerBuffer.hasRemaining()) {
                    return false;
                }
                ((Buffer) this.headerBuffer).flip();
                byte[] lengthAad = Longs.toByteArray(this.expectedLength);
                this.decrypter.init(this.headerBuffer, lengthAad);
                this.decrypterInit = true;
                this.ciphertextRead += (long) this.aesGcmHkdfStreaming.getHeaderLength();
                if (this.expectedLength == this.ciphertextRead) {
                    // The message consists of prefix and header only; there is nothing to decrypt.
                    this.completed = true;
                }
            }
            return true;
        }

        /**
         * Moves ciphertext of the current message into the segment buffer and, whenever a segment
         * is full or the message ends, decrypts it and fires the plaintext downstream.
         */
        private void decryptAvailableSegments(ChannelHandlerContext ctx, ByteBuf ciphertextNettyBuf)
                throws GeneralSecurityException {
            while (ciphertextNettyBuf.readableBytes() > 0 && !this.completed) {
                int readableBytes = Integer.min(ciphertextNettyBuf.readableBytes(), this.ciphertextBuffer.remaining());
                // Compare in long arithmetic: the remaining length can exceed Integer.MAX_VALUE,
                // and narrowing it first would yield a negative or zero read size.
                long expectedRemaining = this.expectedLength - this.ciphertextRead;
                int bytesToRead = (int) Math.min((long) readableBytes, expectedRemaining);
                ((Buffer) this.ciphertextBuffer).limit(this.ciphertextBuffer.position() + bytesToRead);
                ciphertextNettyBuf.readBytes(this.ciphertextBuffer);
                this.ciphertextRead += (long) bytesToRead;
                if (this.ciphertextRead == this.expectedLength) {
                    this.completed = true;
                } else if (this.ciphertextRead > this.expectedLength) {
                    throw new IllegalStateException("Read more ciphertext than expected.");
                }
                if (this.ciphertextBuffer.limit() == this.ciphertextBuffer.capacity() || this.completed) {
                    // A full segment, or the final (possibly short) one, is ready.
                    ByteBuffer plaintextBuffer = ByteBuffer.allocate(this.plaintextSegmentSize);
                    ((Buffer) this.ciphertextBuffer).flip();
                    this.decrypter.decryptSegment(this.ciphertextBuffer, this.segmentNumber, this.completed, plaintextBuffer);
                    ++this.segmentNumber;
                    ((Buffer) this.ciphertextBuffer).clear();
                    ((Buffer) plaintextBuffer).flip();
                    ctx.fireChannelRead(Unpooled.wrappedBuffer(plaintextBuffer));
                } else {
                    // Segment still incomplete: reopen the buffer for the next chunk.
                    ((Buffer) this.ciphertextBuffer).limit(this.ciphertextBuffer.capacity());
                }
            }
        }

        /** Returns the handler to its initial state, with a new decrypter, for the next message. */
        private void resetForNextMessage() throws GeneralSecurityException {
            ((Buffer) this.expectedLengthBuffer).clear();
            ((Buffer) this.headerBuffer).clear();
            ((Buffer) this.ciphertextBuffer).clear();
            this.decrypter = this.aesGcmHkdfStreaming.newStreamSegmentDecrypter();
            this.decrypterInit = false;
            this.completed = false;
            this.segmentNumber = 0;
            this.expectedLength = -1L;
            this.ciphertextRead = 0L;
        }

        /**
         * Consumes a chunk of ciphertext, forwarding every plaintext segment that becomes
         * available. The incoming buffer is always released before this method returns.
         *
         * @param ctx the handler context used to fire decrypted buffers downstream
         * @param ciphertextMessage the inbound message; must be a {@link ByteBuf}
         * @throws GeneralSecurityException if header initialization or segment decryption fails
         * @throws IllegalArgumentException if {@code ciphertextMessage} is not a {@link ByteBuf}
         * @throws IllegalStateException if a declared message length is invalid
         */
        public void channelRead(ChannelHandlerContext ctx, Object ciphertextMessage) throws GeneralSecurityException {
            Preconditions.checkArgument(ciphertextMessage instanceof ByteBuf, "Unrecognized message type: %s", ciphertextMessage.getClass().getName());
            ByteBuf ciphertextNettyBuf = (ByteBuf) ciphertextMessage;
            try {
                // One buffer may hold several messages, so keep going until it is drained or the
                // current message needs more input.
                while (ciphertextNettyBuf.readableBytes() > 0) {
                    if (!this.initializeExpectedLength(ciphertextNettyBuf)) {
                        return;
                    }
                    if (!this.initializeDecrypter(ciphertextNettyBuf)) {
                        return;
                    }
                    this.decryptAvailableSegments(ctx, ciphertextNettyBuf);
                    if (!this.completed) {
                        return;
                    }
                    this.resetForNextMessage();
                }
            }
            finally {
                ciphertextNettyBuf.release();
            }
        }
    }

    /**
     * File region that encrypts a plaintext {@link ByteBuf} or {@link FileRegion} segment by
     * segment while it is being transferred.
     *
     * <p>The bytes written to the target are
     * {@code [8-byte total length][encrypter header][ciphertext segments]}; the total length is
     * also bound into the encryption as associated data.
     *
     * <p>The plaintext and ciphertext scratch buffers are supplied by the caller and are mutated
     * by the constructor (ciphertext limit set to 0), by {@link #transferTo} and by
     * {@link #deallocate} (both cleared). They must therefore not be shared with another message
     * that is still in flight.
     *
     * <p>The {@code (Buffer)} casts bind calls to {@code java.nio.Buffer} rather than to
     * {@code ByteBuffer}'s covariant overrides (presumably kept for Java 8 runtime compatibility).
     */
    static class GcmEncryptedMessage extends AbstractFileRegion {
        private final Object plaintextMessage;
        /** Scratch space holding the plaintext of the segment being assembled. */
        private final ByteBuffer plaintextBuffer;
        /** Holds the most recently encrypted segment until the target has accepted all of it. */
        private final ByteBuffer ciphertextBuffer;
        /** Length prefix followed by the encrypter's header; written before any segment. */
        private final ByteBuffer headerByteBuffer;
        /** Plaintext bytes that were readable when the message was created. */
        private final long bytesToRead;
        /** Plaintext bytes encrypted so far. */
        private long bytesRead = 0L;
        private final StreamSegmentEncrypter encrypter;
        /** Bytes written to targets so far, across all {@link #transferTo} calls. */
        private long transferred = 0L;
        /** Total bytes this region is expected to write, including the length prefix. */
        private final long encryptedCount;

        /**
         * Wraps a plaintext message for encrypted transfer.
         *
         * @param aesGcmHkdfStreaming primitive used to size the output and create the encrypter
         * @param plaintextMessage the payload; must be a {@link ByteBuf} or a {@link FileRegion}
         * @param plaintextBuffer scratch buffer for one plaintext segment
         * @param ciphertextBuffer scratch buffer for one ciphertext segment
         * @throws GeneralSecurityException if the encrypter cannot be created
         * @throws IllegalArgumentException if {@code plaintextMessage} has an unsupported type
         */
        GcmEncryptedMessage(AesGcmHkdfStreaming aesGcmHkdfStreaming, Object plaintextMessage, ByteBuffer plaintextBuffer, ByteBuffer ciphertextBuffer) throws GeneralSecurityException {
            Preconditions.checkArgument(plaintextMessage instanceof ByteBuf || plaintextMessage instanceof FileRegion, "Unrecognized message type: %s", plaintextMessage.getClass().getName());
            this.plaintextMessage = plaintextMessage;
            this.plaintextBuffer = plaintextBuffer;
            this.ciphertextBuffer = ciphertextBuffer;
            // Mark the ciphertext buffer as empty so the first transferTo does not write stale data.
            ((Buffer) this.ciphertextBuffer).limit(0);
            this.bytesToRead = this.getReadableBytes();
            // NOTE(review): this assumes expectedCiphertextSize(n) equals the header plus the
            // segments transferTo actually emits (every segment, including the first, is filled
            // up to plaintextBuffer's capacity, and none is emitted when n is 0) — confirm
            // against Tink's first-segment accounting.
            this.encryptedCount = Long.BYTES + aesGcmHkdfStreaming.expectedCiphertextSize(this.bytesToRead);
            byte[] lengthAad = Longs.toByteArray(this.encryptedCount);
            this.encrypter = aesGcmHkdfStreaming.newStreamSegmentEncrypter(lengthAad);
            this.headerByteBuffer = this.createHeaderByteBuffer();
        }

        /** Builds the readable buffer holding the total length followed by the encrypter header. */
        private ByteBuffer createHeaderByteBuffer() {
            ByteBuffer encrypterHeader = this.encrypter.getHeader();
            ByteBuffer output = ByteBuffer.allocate(encrypterHeader.remaining() + Long.BYTES);
            output.putLong(this.encryptedCount);
            output.put(encrypterHeader);
            ((Buffer) output).flip();
            return output;
        }

        /** Always 0: this region has no meaningful starting offset. */
        public long position() {
            return 0L;
        }

        /** Returns the number of bytes written to targets so far. */
        public long transferred() {
            return this.transferred;
        }

        /** Returns the total number of bytes this region writes, including the length prefix. */
        public long count() {
            return this.encryptedCount;
        }

        /** Records the hint on this region and on the wrapped plaintext message. */
        @Override
        public GcmEncryptedMessage touch(Object o) {
            super.touch(o);
            if (this.plaintextMessage instanceof ByteBuf) {
                ((ByteBuf) this.plaintextMessage).touch(o);
            } else if (this.plaintextMessage instanceof FileRegion) {
                ((FileRegion) this.plaintextMessage).touch(o);
            }
            return this;
        }

        /** Increases the reference count of this region and of the wrapped plaintext message. */
        @Override
        public GcmEncryptedMessage retain(int increment) {
            super.retain(increment);
            if (this.plaintextMessage instanceof ByteBuf) {
                ((ByteBuf) this.plaintextMessage).retain(increment);
            } else if (this.plaintextMessage instanceof FileRegion) {
                ((FileRegion) this.plaintextMessage).retain(increment);
            }
            return this;
        }

        /**
         * Decreases the reference count of the wrapped plaintext message and then of this region.
         *
         * <p>NOTE(review): {@link #deallocate} releases the plaintext message once more, so this
         * overload looks safe only when it undoes an earlier {@link #retain(int)} and the final
         * release happens elsewhere — confirm against the callers.
         */
        public boolean release(int decrement) {
            if (this.plaintextMessage instanceof ByteBuf) {
                ((ByteBuf) this.plaintextMessage).release(decrement);
            } else if (this.plaintextMessage instanceof FileRegion) {
                ((FileRegion) this.plaintextMessage).release(decrement);
            }
            return super.release(decrement);
        }

        /** Writes {@code source} to {@code target}, records the progress and returns the count. */
        private int drain(ByteBuffer source, WritableByteChannel target) throws IOException {
            int written = target.write(source);
            this.transferred += (long) written;
            return written;
        }

        /**
         * Writes as much of the encrypted message as the target accepts.
         *
         * <p>Pending header bytes go first, then any ciphertext left over from an earlier call,
         * then newly encrypted segments. The method returns as soon as the target stops accepting
         * data or the plaintext source cannot supply a whole segment.
         *
         * @param target the channel receiving the encrypted bytes
         * @param position ignored
         * @return the number of bytes written to {@code target} during this call
         * @throws IOException if writing to the target or reading the plaintext fails
         * @throws IllegalStateException if the encrypter fails
         */
        public long transferTo(WritableByteChannel target, long position) throws IOException {
            // Accumulate in a long: one call may move more than Integer.MAX_VALUE bytes.
            long transferredThisCall = 0L;
            if (this.headerByteBuffer.hasRemaining()) {
                transferredThisCall += this.drain(this.headerByteBuffer, target);
                if (this.headerByteBuffer.hasRemaining()) {
                    return transferredThisCall;
                }
            }
            if (this.ciphertextBuffer.hasRemaining()) {
                transferredThisCall += this.drain(this.ciphertextBuffer, target);
                if (this.ciphertextBuffer.hasRemaining()) {
                    return transferredThisCall;
                }
            }
            while (this.bytesRead < this.bytesToRead) {
                long readableBytes = this.getReadableBytes();
                int readLimit = (int) Math.min(readableBytes, (long) this.plaintextBuffer.remaining());
                if (this.plaintextMessage instanceof ByteBuf) {
                    ByteBuf byteBuf = (ByteBuf) this.plaintextMessage;
                    Preconditions.checkState(0 == this.plaintextBuffer.position());
                    ((Buffer) this.plaintextBuffer).limit(readLimit);
                    byteBuf.readBytes(this.plaintextBuffer);
                    Preconditions.checkState(readLimit == this.plaintextBuffer.position());
                } else if (this.plaintextMessage instanceof FileRegion) {
                    FileRegion fileRegion = (FileRegion) this.plaintextMessage;
                    ByteBufferWriteableChannel plaintextChannel = new ByteBufferWriteableChannel(this.plaintextBuffer);
                    long plaintextRead = fileRegion.transferTo(plaintextChannel, fileRegion.transferred());
                    if (plaintextRead < (long) readLimit) {
                        // The region produced less than a full segment; keep what was buffered
                        // and continue on the next call.
                        return transferredThisCall;
                    }
                }
                boolean lastSegment = this.getReadableBytes() == 0L;
                ((Buffer) this.plaintextBuffer).flip();
                this.bytesRead += (long) this.plaintextBuffer.remaining();
                ((Buffer) this.ciphertextBuffer).clear();
                try {
                    this.encrypter.encryptSegment(this.plaintextBuffer, lastSegment, this.ciphertextBuffer);
                }
                catch (GeneralSecurityException e) {
                    throw new IllegalStateException("GeneralSecurityException from encrypter", e);
                }
                ((Buffer) this.plaintextBuffer).clear();
                ((Buffer) this.ciphertextBuffer).flip();
                transferredThisCall += this.drain(this.ciphertextBuffer, target);
                if (this.ciphertextBuffer.hasRemaining()) {
                    // The target is full; the rest of this segment goes out on the next call.
                    return transferredThisCall;
                }
            }
            return transferredThisCall;
        }

        /** Returns how many plaintext bytes the wrapped message can still supply. */
        private long getReadableBytes() {
            if (this.plaintextMessage instanceof ByteBuf) {
                return ((ByteBuf) this.plaintextMessage).readableBytes();
            }
            if (this.plaintextMessage instanceof FileRegion) {
                FileRegion fileRegion = (FileRegion) this.plaintextMessage;
                return fileRegion.count() - fileRegion.transferred();
            }
            throw new IllegalArgumentException("Unsupported message type: " + this.plaintextMessage.getClass().getName());
        }

        /** Releases the wrapped plaintext message once and clears both scratch buffers. */
        protected void deallocate() {
            if (this.plaintextMessage instanceof ReferenceCounted) {
                ((ReferenceCounted) this.plaintextMessage).release();
            }
            this.plaintextBuffer.clear();
            this.ciphertextBuffer.clear();
        }
    }

    /**
     * Outbound handler that wraps every written message in a {@link GcmEncryptedMessage}, so the
     * payload is encrypted while it is transferred to the channel.
     */
    @VisibleForTesting
    class EncryptionHandler extends ChannelOutboundHandlerAdapter {
        private final AesGcmHkdfStreaming aesGcmHkdfStreaming;
        private final int plaintextSegmentSize;
        private final int ciphertextSegmentSize;

        /**
         * Creates the streaming primitive and records its segment sizes.
         *
         * @throws InvalidAlgorithmParameterException if the streaming primitive cannot be created
         */
        EncryptionHandler() throws InvalidAlgorithmParameterException {
            this.aesGcmHkdfStreaming = GcmTransportCipher.this.getAesGcmHkdfStreaming();
            this.plaintextSegmentSize = this.aesGcmHkdfStreaming.getPlaintextSegmentSize();
            this.ciphertextSegmentSize = this.aesGcmHkdfStreaming.getCiphertextSegmentSize();
        }

        /**
         * Replaces {@code msg} with an encrypting wrapper and passes it down the pipeline.
         *
         * <p>Each message gets its own scratch buffers. A {@link GcmEncryptedMessage} resets its
         * ciphertext buffer when it is constructed and clears both buffers when it is deallocated,
         * so a buffer pair shared between messages would be corrupted whenever one message is
         * created or freed while another is still queued or only partially transferred.
         *
         * @param ctx the handler context
         * @param msg the plaintext message; must be a type {@link GcmEncryptedMessage} accepts
         * @param promise the promise completed when the write finishes
         * @throws Exception if the encrypted message cannot be created
         */
        public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
            ByteBuffer plaintextBuffer = ByteBuffer.allocate(this.plaintextSegmentSize);
            ByteBuffer ciphertextBuffer = ByteBuffer.allocate(this.ciphertextSegmentSize);
            GcmEncryptedMessage encryptedMessage = new GcmEncryptedMessage(this.aesGcmHkdfStreaming, msg, plaintextBuffer, ciphertextBuffer);
            ctx.write(encryptedMessage, promise);
        }
    }
}

