# frozen_string_literal: true
require 'openssl'
begin
  require 'rbnacl'
rescue LoadError
end

# JWT::Signature module
module JWT
  # Signature logic for JWT
  module Signature
    extend self

    HMAC_ALGORITHMS = %w(HS256 HS512256 HS384 HS512).freeze
    RSA_ALGORITHMS = %w(RS256 RS384 RS512).freeze
    ECDSA_ALGORITHMS = %w(ES256 ES384 ES512).freeze

    NAMED_CURVES = {
      'prime256v1' => 'ES256',
      'secp384r1' => 'ES384',
      'secp521r1' => 'ES512'
    }.freeze

    def sign(algorithm, msg, key)
      if HMAC_ALGORITHMS.include?(algorithm)
        sign_hmac(algorithm, msg, key)
      elsif RSA_ALGORITHMS.include?(algorithm)
        sign_rsa(algorithm, msg, key)
      elsif ECDSA_ALGORITHMS.include?(algorithm)
        sign_ecdsa(algorithm, msg, key)
      else
        raise NotImplementedError, 'Unsupported signing method'
      end
    end

    def verify(algo, key, signing_input, signature)
      verified = if HMAC_ALGORITHMS.include?(algo)
                   verify_hmac(algo, key, signing_input, signature)
                 elsif RSA_ALGORITHMS.include?(algo)
                   verify_rsa(algo, key, signing_input, signature)
                 elsif ECDSA_ALGORITHMS.include?(algo)
                   verify_ecdsa(algo, key, signing_input, signature)
                 else
                   raise JWT::VerificationError, 'Algorithm not supported'
                 end

      raise(JWT::VerificationError, 'Signature verification raised') unless verified
    rescue OpenSSL::PKey::PKeyError
      raise JWT::VerificationError, 'Signature verification raised'
    ensure
      OpenSSL.errors.clear
    end

    private

    def sign_rsa(algorithm, msg, private_key)
      raise EncodeError, "The given key is a #{private_key.class}. It has to be an OpenSSL::PKey::RSA instance." if private_key.class == String
      private_key.sign(OpenSSL::Digest.new(algorithm.sub('RS', 'sha')), msg)
    end

    def sign_ecdsa(algorithm, msg, private_key)
      key_algorithm = NAMED_CURVES[private_key.group.curve_name]
      if algorithm != key_algorithm
        raise IncorrectAlgorithm, "payload algorithm is #{algorithm} but #{key_algorithm} signing key was provided"
      end

      digest = OpenSSL::Digest.new(algorithm.sub('ES', 'sha'))
      asn1_to_raw(private_key.dsa_sign_asn1(digest.digest(msg)), private_key)
    end

    def sign_hmac(algorithm, msg, key)
      authenticator, padded_key = rbnacl_fixup(algorithm, key)
      if authenticator && padded_key
        authenticator.auth(padded_key, msg.encode('binary'))
      else
        OpenSSL::HMAC.digest(OpenSSL::Digest.new(algorithm.sub('HS', 'sha')), key, msg)
      end
    end

    def verify_rsa(algorithm, public_key, signing_input, signature)
      public_key.verify(OpenSSL::Digest.new(algorithm.sub('RS', 'sha')), signature, signing_input)
    end

    def verify_ecdsa(algorithm, public_key, signing_input, signature)
      key_algorithm = NAMED_CURVES[public_key.group.curve_name]
      if algorithm != key_algorithm
        raise IncorrectAlgorithm, "payload algorithm is #{algorithm} but #{key_algorithm} verification key was provided"
      end

      digest = OpenSSL::Digest.new(algorithm.sub('ES', 'sha'))
      public_key.dsa_verify_asn1(digest.digest(signing_input), raw_to_asn1(signature, public_key))
    end

    def verify_hmac(algorithm, public_key, signing_input, signature)
      authenticator, padded_key = rbnacl_fixup(algorithm, public_key)
      if authenticator && padded_key
        begin
          authenticator.verify(padded_key, signature.encode('binary'), signing_input.encode('binary'))
        rescue RbNaCl::BadAuthenticatorError
          false
        end
      else
        secure_compare(signature, sign_hmac(algorithm, signing_input, public_key))
      end
    end

    def asn1_to_raw(signature, public_key)
      byte_size = (public_key.group.degree + 7) / 8
      OpenSSL::ASN1.decode(signature).value.map { |value| value.value.to_s(2).rjust(byte_size, "\x00") }.join
    end

    def raw_to_asn1(signature, private_key)
      byte_size = (private_key.group.degree + 7) / 8
      r = signature[0..(byte_size - 1)]
      s = signature[byte_size..-1] || ''
      OpenSSL::ASN1::Sequence.new([r, s].map { |int| OpenSSL::ASN1::Integer.new(OpenSSL::BN.new(int, 2)) }).to_der
    end

    def rbnacl_fixup(algorithm, key)
      algorithm = algorithm.sub('HS', 'SHA').to_sym

      return [] unless defined?(RbNaCl) && RbNaCl::HMAC.constants(false).include?(algorithm)

      authenticator = RbNaCl::HMAC.const_get(algorithm)

      # Fall back to OpenSSL for keys larger than 32 bytes.
      return [] if key.bytesize > authenticator.key_bytes

      [
        authenticator,
        key.bytes.fill(0, key.bytesize...authenticator.key_bytes).pack('C*')
      ]
    end

    # From devise
    # constant-time comparison algorithm to prevent timing attacks
    def secure_compare(a, b)
      return false if a.nil? || b.nil? || a.empty? || b.empty? || a.bytesize != b.bytesize
      l = a.unpack "C#{a.bytesize}"
      res = 0
      b.each_byte { |byte| res |= byte ^ l.shift }
      res.zero?
    end
  end
end