package org.jruby.puma;

import org.jruby.Ruby;
import org.jruby.RubyArray;
import org.jruby.RubyClass;
import org.jruby.RubyModule;
import org.jruby.RubyObject;
import org.jruby.RubyString;
import org.jruby.anno.JRubyMethod;
import org.jruby.exceptions.RaiseException;
import org.jruby.javasupport.JavaEmbedUtils;
import org.jruby.runtime.Block;
import org.jruby.runtime.ObjectAllocator;
import org.jruby.runtime.ThreadContext;
import org.jruby.runtime.builtin.IRubyObject;
import org.jruby.util.ByteList;

import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.TrustManager;
import javax.net.ssl.TrustManagerFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLPeerUnverifiedException;
import javax.net.ssl.SSLSession;
import javax.net.ssl.X509TrustManager;
import java.io.FileInputStream;
import java.io.InputStream;
import java.io.IOException;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.security.KeyManagementException;
import java.security.KeyStore;
import java.security.KeyStoreException;
import java.security.NoSuchAlgorithmException;
import java.security.UnrecoverableKeyException;
import java.security.cert.Certificate;
import java.security.cert.CertificateEncodingException;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import java.util.concurrent.ConcurrentHashMap;
import java.util.Map;
import java.util.function.Supplier;

import static javax.net.ssl.SSLEngineResult.Status;
import static javax.net.ssl.SSLEngineResult.HandshakeStatus;

public class MiniSSL extends RubyObject { // MiniSSL::Engine
  /** Allocator registered for the Engine class; every allocation yields a fresh {@link MiniSSL}. */
  private static final ObjectAllocator ALLOCATOR = new ObjectAllocator() {
    @Override
    public IRubyObject allocate(Ruby runtime, RubyClass klass) {
      return new MiniSSL(runtime, klass);
    }
  };

  /**
   * Defines the Puma::MiniSSL module together with its SSLError and Engine classes in the given
   * runtime, binding the annotated methods of this class to Engine.
   *
   * @param runtime the JRuby runtime to define the constants in
   */
  public static void createMiniSSL(Ruby runtime) {
    final RubyModule pumaModule = runtime.defineModule("Puma");
    final RubyModule miniSSLModule = pumaModule.defineModuleUnder("MiniSSL");

    // Puma::MiniSSL::SSLError
    miniSSLModule.defineClassUnder("SSLError", runtime.getStandardError(), runtime.getStandardError().getAllocator());

    // Puma::MiniSSL::Engine
    final RubyClass engineClass = miniSSLModule.defineClassUnder("Engine", runtime.getObject(), ALLOCATOR);
    engineClass.defineAnnotatedMethods(MiniSSL.class);
  }

  /**
   * Thin wrapper around a {@link java.nio.ByteBuffer} that adds on-demand growth and the ability
   * to drain the written contents into a ByteList.
   */
  private static class MiniSSLBuffer {
    ByteBuffer buffer;

    private MiniSSLBuffer(int capacity) {
      this.buffer = ByteBuffer.allocate(capacity);
    }

    private MiniSSLBuffer(byte[] initialContents) {
      this.buffer = ByteBuffer.wrap(initialContents);
    }

    public void clear() {
      buffer.clear();
    }

    public void compact() {
      buffer.compact();
    }

    public void flip() {
      ((Buffer) buffer).flip();
    }

    public boolean hasRemaining() {
      return buffer.hasRemaining();
    }

    public int position() {
      return buffer.position();
    }

    public ByteBuffer getRawBuffer() {
      return buffer;
    }

    /**
     * Appends {@code length} bytes of {@code bytes}, starting at {@code offset}, making room first
     * when the remaining space is too small.
     */
    private void put(byte[] bytes, final int offset, final int length) {
      final boolean fits = buffer.remaining() >= length;
      if (!fits) {
        resize(buffer.limit() + length);
      }
      buffer.put(bytes, offset, length);
    }

    /**
     * Makes the buffer usable up to {@code newCapacity}: when that exceeds the current capacity a
     * larger buffer is allocated and the bytes written so far are copied over, otherwise only the
     * limit is moved to {@code newCapacity}.
     */
    public void resize(int newCapacity) {
      if (newCapacity <= buffer.capacity()) {
        // enough backing storage already; just move the limit
        buffer.limit(newCapacity);
        return;
      }

      final ByteBuffer grown = ByteBuffer.allocate(newCapacity);
      flip(); // expose what has been written so far for copying
      grown.put(buffer);
      buffer = grown;
    }

    /**
     * Drains everything written so far into a ByteList and resets the buffer; yields null when
     * nothing has been written.
     */
    public ByteList asByteList() {
      flip();

      final int available = buffer.remaining();
      if (available == 0) {
        buffer.clear();
        return null;
      }

      final byte[] drained = new byte[available];
      buffer.get(drained);
      buffer.clear();
      return new ByteList(drained, false);
    }

    @Override
    public String toString() {
      return buffer.toString();
    }
  }

  // JSSE engine performing wrap/unwrap; created in initialize()
  private SSLEngine engine;
  // set once an engine operation reports Status.CLOSED (see doOp)
  private boolean closed;
  // set once an engine operation reports HandshakeStatus.FINISHED (see doOp); init? returns its negation
  private boolean handshake;
  // bytes handed in via inject(), consumed by read()
  private MiniSSLBuffer inboundNetData;
  // bytes handed in via write(), consumed by extract()
  private MiniSSLBuffer outboundAppData;
  // output of wrap operations, drained by extract()
  private MiniSSLBuffer outboundNetData;

  /**
   * Creates the bare Ruby object; the SSL engine and buffers are assigned later by
   * {@code initialize}.
   *
   * @param runtime the owning JRuby runtime
   * @param klass the Ruby class of the new object
   */
  public MiniSSL(final Ruby runtime, final RubyClass klass) {
    super(runtime, klass);
  }

  // Factories created by Engine.server and looked up again in initialize, keyed by store file path.
  private static final Map<String, KeyManagerFactory> keyManagerFactoryMap = new ConcurrentHashMap<>();
  private static final Map<String, TrustManagerFactory> trustManagerFactoryMap = new ConcurrentHashMap<>();

  /**
   * Engine.server: loads the key store (and, unless the truststore is the :default symbol, a trust
   * store) described by the given MiniSSL context, registers the resulting factories under their
   * store file paths, and then instantiates the receiver class with that context.
   *
   * When the context's truststore is nil the key store file, password and type are reused for the
   * trust store. When it is :default no TrustManagerFactory is registered.
   *
   * @param context the current thread context
   * @param recv the class this meta method was invoked on
   * @param miniSSLContext object answering keystore, keystore_pass, keystore_type, truststore,
   *                       truststore_pass and truststore_type
   * @return the instance produced by {@code recv.newInstance(context, miniSSLContext, ...)}
   * @throws IOException if a store file cannot be opened or read
   * @throws KeyStoreException if a store of the requested type cannot be created or used
   * @throws CertificateException if certificates in a store cannot be loaded
   * @throws NoSuchAlgorithmException if a required algorithm is unavailable
   * @throws UnrecoverableKeyException if the key cannot be recovered with the key store password
   */
  @JRubyMethod(meta = true) // Engine.server
  public static synchronized IRubyObject server(ThreadContext context, IRubyObject recv, IRubyObject miniSSLContext)
      throws KeyStoreException, IOException, CertificateException, NoSuchAlgorithmException, UnrecoverableKeyException {
    // Create the KeyManagerFactory and TrustManagerFactory for this server
    String keystoreFile = asStringValue(miniSSLContext.callMethod(context, "keystore"), null);
    char[] keystorePass = asStringValue(miniSSLContext.callMethod(context, "keystore_pass"), null).toCharArray();
    String keystoreType = asStringValue(miniSSLContext.callMethod(context, "keystore_type"), KeyStore::getDefaultType);

    String truststoreFile;
    char[] truststorePass;
    String truststoreType;
    IRubyObject truststore = miniSSLContext.callMethod(context, "truststore");
    if (truststore.isNil()) {
      // no explicit truststore: trust material comes from the key store
      truststoreFile = keystoreFile;
      truststorePass = keystorePass;
      truststoreType = keystoreType;
    } else if (!isDefaultSymbol(context, truststore)) {
      truststoreFile = truststore.convertToString().asJavaString();
      IRubyObject pass = miniSSLContext.callMethod(context, "truststore_pass");
      if (pass.isNil()) {
        truststorePass = null;
      } else {
        truststorePass = asStringValue(pass, null).toCharArray();
      }
      truststoreType = asStringValue(miniSSLContext.callMethod(context, "truststore_type"), KeyStore::getDefaultType);
    } else { // self.truststore = :default
      truststoreFile = null;
      truststorePass = null;
      truststoreType = null;
    }

    KeyStore ks = loadKeyStore(keystoreType, keystoreFile, keystorePass);
    KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509");
    kmf.init(ks, keystorePass);
    keyManagerFactoryMap.put(keystoreFile, kmf);

    if (truststoreFile != null) {
      KeyStore ts = loadKeyStore(truststoreType, truststoreFile, truststorePass);
      TrustManagerFactory tmf = TrustManagerFactory.getInstance("SunX509");
      tmf.init(ts);
      trustManagerFactoryMap.put(truststoreFile, tmf);
    }

    RubyClass klass = (RubyClass) recv;
    return klass.newInstance(context, miniSSLContext, Block.NULL_BLOCK);
  }

  /**
   * Creates a KeyStore of the given type and loads it from the given file. The file stream is
   * closed via try-with-resources, so a failure while closing cannot hide a failure from load.
   */
  private static KeyStore loadKeyStore(String type, String storeFile, char[] password)
      throws KeyStoreException, IOException, CertificateException, NoSuchAlgorithmException {
    KeyStore store = KeyStore.getInstance(type);
    try (InputStream is = new FileInputStream(storeFile)) {
      store.load(is, password);
    }
    return store;
  }

  /**
   * Converts a Ruby value to a Java String. When a default supplier is given and the value is nil,
   * the supplier's result is returned instead; without a supplier a nil value goes through
   * {@code convertToString} like any other value.
   */
  private static String asStringValue(IRubyObject value, Supplier<String> defaultValue) {
    final boolean useDefault = defaultValue != null && value.isNil();
    return useDefault ? defaultValue.get() : value.convertToString().asJavaString();
  }

  /** Tells whether the given truststore setting equals the Ruby symbol {@code :default}. */
  private static boolean isDefaultSymbol(ThreadContext context, IRubyObject truststore) {
    final IRubyObject defaultSymbol = context.runtime.newSymbol("default");
    return defaultSymbol.equals(truststore);
  }

  /**
   * Configures this engine from the given MiniSSL context: looks up the factories registered by
   * Engine.server, builds a server-mode SSLEngine, applies protocol, client-auth and cipher suite
   * settings, and allocates the network/application buffers.
   *
   * @param context the current thread context
   * @param miniSSLContext object answering keystore, truststore, protocols, no_tlsv1, no_tlsv1_1,
   *                       verify_mode and cipher_suites
   * @return this engine
   * @throws KeyStoreException if no KeyManagerFactory is registered for the context's keystore
   * @throws NoSuchAlgorithmException if the "TLS" SSLContext is unavailable
   * @throws KeyManagementException if the SSLContext cannot be initialized
   */
  @JRubyMethod
  public IRubyObject initialize(ThreadContext context, IRubyObject miniSSLContext)
      throws KeyStoreException, NoSuchAlgorithmException, KeyManagementException {

    final String keystorePath = miniSSLContext.callMethod(context, "keystore").convertToString().asJavaString();
    final KeyManagerFactory keyFactory = keyManagerFactoryMap.get(keystorePath);

    final IRubyObject truststoreSetting = miniSSLContext.callMethod(context, "truststore");
    final String truststorePath;
    if (isDefaultSymbol(context, truststoreSetting)) {
      truststorePath = ""; // no factory is registered under this key, so the lookup below yields null
    } else {
      truststorePath = asStringValue(truststoreSetting, () -> keystorePath);
    }
    final TrustManagerFactory trustFactory = trustManagerFactoryMap.get(truststorePath); // null if self.truststore = :default

    if (keyFactory == null) {
      throw new KeyStoreException("Could not find KeyManagerFactory for keystore: " + keystorePath + " truststore: " + truststorePath);
    }

    final SSLContext sslContext = SSLContext.getInstance("TLS");
    sslContext.init(keyFactory.getKeyManagers(), getTrustManagers(trustFactory), null);

    closed = false;
    handshake = false;
    engine = sslContext.createSSLEngine();

    final IRubyObject protocols = miniSSLContext.callMethod(context, "protocols");
    final String[] enabledProtocols;
    if (protocols.isNil()) {
      final boolean withoutTls10 = miniSSLContext.callMethod(context, "no_tlsv1").isTrue();
      final boolean withoutTls11 = miniSSLContext.callMethod(context, "no_tlsv1_1").isTrue();
      if (withoutTls11) {
        // no_tlsv1_1 takes precedence and also drops TLSv1
        enabledProtocols = new String[] { "TLSv1.2", "TLSv1.3" };
      } else if (withoutTls10) {
        enabledProtocols = new String[] { "TLSv1.1", "TLSv1.2", "TLSv1.3" };
      } else {
        enabledProtocols = new String[] { "TLSv1", "TLSv1.1", "TLSv1.2", "TLSv1.3" };
      }
    } else if (protocols instanceof RubyArray) {
      enabledProtocols = (String[]) ((RubyArray) protocols).toArray(new String[0]);
    } else {
      throw context.runtime.newTypeError(protocols, context.runtime.getArray());
    }
    engine.setEnabledProtocols(enabledProtocols);

    engine.setUseClientMode(false);

    final long verifyMode = miniSSLContext.callMethod(context, "verify_mode").convertToInteger("to_i").getLongValue();
    final boolean peer = (verifyMode & 0x1) != 0;      // 'peer'
    final boolean forcePeer = (verifyMode & 0x2) != 0; // 'force_peer'
    if (peer) {
      engine.setWantClientAuth(true);
    }
    if (forcePeer) {
      engine.setNeedClientAuth(true);
    }

    final IRubyObject cipherSuites = miniSSLContext.callMethod(context, "cipher_suites");
    final boolean suitesGiven = cipherSuites instanceof RubyArray;
    if (!suitesGiven && !cipherSuites.isNil()) {
      throw context.runtime.newTypeError(cipherSuites, context.runtime.getArray());
    }
    if (suitesGiven) {
      engine.setEnabledCipherSuites((String[]) ((RubyArray) cipherSuites).toArray(new String[0]));
    }

    final SSLSession session = engine.getSession();
    inboundNetData = new MiniSSLBuffer(session.getPacketBufferSize());
    outboundAppData = new MiniSSLBuffer(session.getApplicationBufferSize());
    outboundAppData.flip(); // nothing to send yet: flipping the empty buffer leaves no remaining bytes
    outboundNetData = new MiniSSLBuffer(session.getPacketBufferSize());

    return this;
  }

  /**
   * Obtains the factory's trust managers and replaces every X509TrustManager among them with a
   * {@link TrustManagerWrapper} around it.
   *
   * @param factory the trust manager factory, or null
   * @return the (possibly wrapped) trust managers; null when factory is null or supplies none
   */
  private TrustManager[] getTrustManagers(TrustManagerFactory factory) {
    if (factory == null) {
      return null; // use JDK trust defaults
    }

    final TrustManager[] managers = factory.getTrustManagers();
    if (managers == null) {
      return null;
    }

    for (int idx = 0; idx < managers.length; idx++) {
      if (managers[idx] instanceof X509TrustManager) {
        managers[idx] = new TrustManagerWrapper((X509TrustManager) managers[idx]);
      }
    }
    return managers;
  }

  // First certificate of the chain most recently passed to TrustManagerWrapper.checkClientTrusted
  // (null for an empty chain); peercert falls back to it when the session reports no verified peer.
  private volatile transient X509Certificate lastCheckedCert0;

  /**
   * X509TrustManager decorator that records the first certificate of each client chain it is asked
   * to check (in the enclosing engine's lastCheckedCert0) and otherwise defers every decision to
   * the wrapped manager.
   */
  private class TrustManagerWrapper implements X509TrustManager {

    private final X509TrustManager delegate;

    TrustManagerWrapper(X509TrustManager delegate) {
      this.delegate = delegate;
    }

    /**
     * Remembers the leading certificate of the chain, then lets the delegate decide.
     *
     * @throws CertificateException if the delegate rejects the chain
     */
    @Override
    public void checkClientTrusted(X509Certificate[] chain, String authType) throws CertificateException {
      // tolerate a null chain here so that the delegate reports the bad argument instead of an NPE
      lastCheckedCert0 = (chain != null && chain.length > 0) ? chain[0] : null;
      delegate.checkClientTrusted(chain, authType);
    }

    @Override
    public void checkServerTrusted(X509Certificate[] chain, String authType) throws CertificateException {
      delegate.checkServerTrusted(chain, authType);
    }

    @Override
    public X509Certificate[] getAcceptedIssuers() {
      return delegate.getAcceptedIssuers();
    }

  }

  /**
   * Appends the bytes of the given Ruby string to the inbound network buffer.
   *
   * @param arg value converted to a Ruby String whose bytes are appended
   * @return this engine
   */
  @JRubyMethod
  public IRubyObject inject(IRubyObject arg) {
    final ByteList incoming = arg.convertToString().getByteList();
    final byte[] backing = incoming.unsafeBytes();
    inboundNetData.put(backing, incoming.getBegin(), incoming.getRealSize());
    return this;
  }

  /** Selects which SSLEngine call {@code doOp} performs: {@code wrap} or {@code unwrap}. */
  private enum SSLOperation { WRAP, UNWRAP }

  /**
   * Runs a single wrap or unwrap against the engine, growing the destination buffer and retrying
   * for as long as the engine reports BUFFER_OVERFLOW. Records a CLOSED status in {@code closed}
   * and a FINISHED handshake status in {@code handshake}.
   *
   * @param sslOp which engine operation to perform
   * @param src buffer the engine reads from
   * @param dst buffer the engine writes to
   * @return the result of the last engine call
   * @throws SSLException if the engine call fails
   */
  private SSLEngineResult doOp(SSLOperation sslOp, MiniSSLBuffer src, MiniSSLBuffer dst) throws SSLException {
    SSLEngineResult result = null;
    boolean again;
    do {
      // raw buffers are fetched on every pass because resize() may replace dst's backing buffer
      switch (sslOp) {
        case WRAP:
          result = engine.wrap(src.getRawBuffer(), dst.getRawBuffer());
          break;
        case UNWRAP:
          result = engine.unwrap(src.getRawBuffer(), dst.getRawBuffer());
          break;
        default:
          throw new AssertionError("Unknown SSLOperation: " + sslOp);
      }

      final Status status = result.getStatus();
      if (status == Status.BUFFER_OVERFLOW) {
        // make room for the larger of a network packet or an application record, then try again
        final int extra = Math.max(engine.getSession().getPacketBufferSize(), engine.getSession().getApplicationBufferSize());
        dst.resize(extra + dst.position());
        again = true;
      } else {
        // BUFFER_UNDERFLOW (more input is needed first), CLOSED and OK all end the loop
        if (status == Status.CLOSED) {
          closed = true;
        }
        again = false;
      }

      if (result.getHandshakeStatus() == HandshakeStatus.FINISHED) {
        handshake = true;
      }
    } while (again);

    return result;
  }

  /**
   * Unwraps the network data previously handed in via inject, performs whatever handshake steps
   * the engine asks for, and returns the resulting application data.
   *
   * @return a Ruby String with the unwrapped application bytes, or nil when there was no inbound
   *         data or the unwrap produced no application bytes
   * @throws RaiseException Puma::MiniSSL::SSLError (built by newSSLError) when the engine raises an SSLException
   */
  @JRubyMethod
  public IRubyObject read() {
    try {
      // switch the inbound buffer from being written to (inject) to being read from
      inboundNetData.flip();

      if(!inboundNetData.hasRemaining()) {
        return getRuntime().getNil();
      }

      MiniSSLBuffer inboundAppData = new MiniSSLBuffer(engine.getSession().getApplicationBufferSize());
      doOp(SSLOperation.UNWRAP, inboundNetData, inboundAppData);

      // keep going until the engine no longer asks for a wrap/unwrap/task, or input runs out
      HandshakeStatus handshakeStatus = engine.getHandshakeStatus();
      boolean done = false;
      while (!done) {
        SSLEngineResult res;
        switch (handshakeStatus) {
          case NEED_WRAP:
            // output lands in outboundNetData, which extract() drains
            res = doOp(SSLOperation.WRAP, inboundAppData, outboundNetData);
            handshakeStatus = res.getHandshakeStatus();
            break;
          case NEED_UNWRAP:
            res = doOp(SSLOperation.UNWRAP, inboundNetData, inboundAppData);
            if (res.getStatus() == Status.BUFFER_UNDERFLOW) {
              // need more data before we can shake more hands
              done = true;
            }
            handshakeStatus = res.getHandshakeStatus();
            break;
          case NEED_TASK:
            // run all delegated tasks on the calling thread
            Runnable runnable;
            while ((runnable = engine.getDelegatedTask()) != null) {
              runnable.run();
            }
            handshakeStatus = engine.getHandshakeStatus();
            break;
          default:
            done = true;
        }
      }

      // return the buffer to write mode, preserving bytes the engine has not consumed yet
      if (inboundNetData.hasRemaining()) {
        inboundNetData.compact();
      } else {
        inboundNetData.clear();
      }

      ByteList appDataByteList = inboundAppData.asByteList();
      if (appDataByteList == null) {
        return getRuntime().getNil();
      }

      return RubyString.newString(getRuntime(), appDataByteList);
    } catch (SSLException e) {
      throw newSSLError(getRuntime(), e);
    }
  }

  /**
   * Replaces the outbound application buffer with the bytes of the given Ruby string; they are
   * wrapped later by extract.
   *
   * @param arg value converted to a Ruby String whose bytes are to be sent
   * @return a Fixnum holding the number of bytes accepted
   */
  @JRubyMethod
  public IRubyObject write(IRubyObject arg) {
    final byte[] payload = arg.convertToString().getBytes();
    outboundAppData = new MiniSSLBuffer(payload);

    return getRuntime().newFixnum(payload.length);
  }

  /**
   * Returns the next chunk of network data to send. Data already sitting in the outbound network
   * buffer is returned first; otherwise pending application data (if any) is wrapped and the
   * result returned.
   *
   * @param context the current thread context
   * @return a Ruby String with the network bytes, or nil when there is nothing to send
   * @throws RaiseException Puma::MiniSSL::SSLError (built by newSSLError) when the engine raises an SSLException
   */
  @JRubyMethod
  public IRubyObject extract(ThreadContext context) {
    try {
      ByteList pending = outboundNetData.asByteList();

      if (pending == null) {
        // nothing buffered yet: wrap outstanding application data, if there is any
        if (!outboundAppData.hasRemaining()) {
          return context.nil;
        }

        outboundNetData.clear();
        doOp(SSLOperation.WRAP, outboundAppData, outboundNetData);
        pending = outboundNetData.asByteList();
      }

      return pending == null ? context.nil : RubyString.newString(context.runtime, pending);
    } catch (SSLException e) {
      throw newSSLError(getRuntime(), e);
    }
  }

  /**
   * Returns the encoded form of the peer's certificate: the first certificate of the session's
   * peer chain, or, when the session has no verified peer, the certificate last seen by the
   * client trust check.
   *
   * @param context the current thread context
   * @return the encoded certificate bytes converted to a Ruby object, or nil when none is known
   * @throws CertificateEncodingException if the certificate cannot be encoded
   */
  @JRubyMethod
  public IRubyObject peercert(ThreadContext context) throws CertificateEncodingException {
    Certificate leaf;
    try {
      final Certificate[] peerChain = engine.getSession().getPeerCertificates();
      leaf = peerChain[0];
    } catch (SSLPeerUnverifiedException e) {
      leaf = lastCheckedCert0; // null if trust check did not happen
    }

    if (leaf == null) {
      return context.nil;
    }
    return JavaEmbedUtils.javaToRuby(context.runtime, leaf.getEncoded());
  }

  /**
   * init? — Ruby true until an engine operation has reported a finished handshake, Ruby false
   * afterwards.
   */
  @JRubyMethod(name = "init?")
  public IRubyObject isInit(ThreadContext context) {
    if (handshake) {
      return getRuntime().getFalse();
    }
    return getRuntime().getTrue();
  }

  /**
   * Reports whether the connection is shut down: Ruby true when the engine was seen CLOSED or
   * both its inbound and outbound sides are done, Ruby false otherwise.
   */
  @JRubyMethod
  public IRubyObject shutdown() {
    final boolean finished = closed || (engine.isInboundDone() && engine.isOutboundDone());
    if (!finished) {
      return getRuntime().getFalse();
    }

    // NOTE(review): closeOutbound() is only called when the outbound side already reports done;
    // the condition may have been meant to be negated — confirm before changing.
    if (engine.isOutboundDone()) {
      engine.closeOutbound();
    }
    return getRuntime().getTrue();
  }

  /** Looks up the Puma::MiniSSL::SSLError class in the given runtime. */
  private static RubyClass getSSLError(Ruby runtime) {
    final RubyModule puma = runtime.getModule("Puma");
    final RubyModule miniSSL = (RubyModule) puma.getConstantAt("MiniSSL");
    return (RubyClass) miniSSL.getConstantAt("SSLError");
  }

  /**
   * Builds a Puma::MiniSSL::SSLError whose message is the cause's {@code toString()} and whose
   * Java cause is the given SSLException.
   */
  private static RaiseException newSSLError(Ruby runtime, SSLException cause) {
    final RubyClass errorClass = getSSLError(runtime);
    final String message = cause.toString();
    return newError(runtime, errorClass, message, cause);
  }

  /**
   * Creates a RaiseException of the given Ruby error class with the given message and attaches
   * the Java cause to it.
   */
  private static RaiseException newError(Ruby runtime, RubyClass errorClass, String message, Throwable cause) {
    final RaiseException raised = new RaiseException(runtime, errorClass, message, true);
    raised.initCause(cause);
    return raised;
  }

}