module JSON
  # A JSON Web Key stored as an indifferent-access hash of its members
  # (+kty+, +kid+, +n+, +e+, +crv+, +x+, +y+, +k+, ...).
  #
  # It can be built from an OpenSSL RSA/EC key, from a String (treated as the
  # +k+ member of an +oct+ key) or from any hash-like object of JWK members.
  # When no +kid+ is supplied, the thumbprint is used as the default +kid+ on a
  # best-effort basis.
  class JWK < ActiveSupport::HashWithIndifferentAccess
    # Raised for key types, EC curves or digest algorithms this class does not
    # handle.
    class UnknownAlgorithm < JWT::Exception; end

    # @param params [OpenSSL::PKey::RSA, OpenSSL::PKey::EC, String, Hash]
    #   the key material or the JWK members
    # @param ex_params [Hash] extra members; handed to +to_jwk+ for OpenSSL
    #   keys, merged into the hash otherwise
    # @raise [UnknownAlgorithm] for any other kind of OpenSSL::PKey::PKey
    def initialize(params = {}, ex_params = {})
      case params
      when OpenSSL::PKey::RSA, OpenSSL::PKey::EC
        # NOTE(review): to_jwk is not defined in this file; it is presumably an
        # extension on the OpenSSL key classes -- confirm.
        super(params.to_jwk(ex_params))
      when OpenSSL::PKey::PKey
        raise UnknownAlgorithm, 'Unknown Key Type'
      when String
        super(k: params, kty: :oct)
        merge!(ex_params)
      else
        super(params)
        merge!(ex_params)
      end
      calculate_default_kid if self[:kid].blank?
    end

    # @return [String] the media type used for a serialized JWK
    def content_type
      'application/jwk+json'
    end

    # Base64url-encoded digest of the JSON form of the members picked by
    # +normalize+ (the layout looks like an RFC 7638 JWK thumbprint).
    #
    # @param digest [OpenSSL::Digest, String, Symbol] digest instance or
    #   algorithm name; SHA-256 by default
    # @return [String]
    # @raise [UnknownAlgorithm] if +digest+ is of another type, or the key type
    #   is not RSA, EC or oct
    def thumbprint(digest = OpenSSL::Digest::SHA256.new)
      # The digest is resolved before the key is normalized, so an unusable
      # digest argument is reported first.
      hasher =
        case digest
        when OpenSSL::Digest then digest
        when String, Symbol then OpenSSL::Digest.new(digest.to_s)
        else raise UnknownAlgorithm, 'Unknown Digest Algorithm'
        end
      UrlSafeBase64.encode64(hasher.digest(normalize.to_json))
    end

    # Converts this JWK into usable key material.
    #
    # @return [OpenSSL::PKey::RSA, OpenSSL::PKey::EC, Object] an OpenSSL key for
    #   RSA/EC keys, or the stored +k+ member as-is for oct keys
    # @raise [UnknownAlgorithm] if +kty+ is not RSA, EC or oct
    def to_key
      return to_rsa_key if rsa?
      return to_ec_key if ec?
      return self[:k] if oct?

      raise UnknownAlgorithm, 'Unknown Key Type'
    end

    private

    # The +kty+ member as a Symbol, or nil when it is absent or does not respond
    # to +to_sym+.
    def key_type
      self[:kty].try(:to_sym)
    end

    def rsa?
      key_type == :RSA
    end

    def ec?
      key_type == :EC
    end

    def oct?
      key_type == :oct
    end

    # Uses the thumbprint as +kid+. Any StandardError raised while computing it
    # is swallowed on purpose, leaving +kid+ unset.
    def calculate_default_kid
      self[:kid] = thumbprint
    rescue StandardError
      # ignore
    end

    # Hash holding only the members that feed the thumbprint, inserted in the
    # order listed here; an absent member is kept with a nil value.
    #
    # @raise [UnknownAlgorithm] if +kty+ is not RSA, EC or oct
    def normalize
      members =
        if rsa?
          %i[e kty n]
        elsif ec?
          %i[crv kty x y]
        elsif oct?
          %i[k kty]
        else
          raise UnknownAlgorithm, 'Unknown Key Type'
        end
      members.each_with_object({}) do |member, canonical|
        canonical[member] = self[member]
      end
    end

    # Base64url-decoded value of +member+, or nil when the member is not set.
    def decoded_member(member)
      UrlSafeBase64.decode64(self[member]) if self[member]
    end

    # Builds an OpenSSL RSA key from the base64url-encoded big-endian integers
    # stored in this JWK. Only members that are present are applied.
    def to_rsa_key
      e, n, d, prime_p, prime_q, dp, dq, qi = %i[e n d p q dp dq qi].map do |member|
        octets = decoded_member(member)
        OpenSSL::BN.new(octets, 2) if octets
      end

      rsa = OpenSSL::PKey::RSA.new
      if rsa.respond_to?(:set_key)
        # Bindings that provide the grouped setters.
        rsa.set_key(n, e, d)
        rsa.set_factors(prime_p, prime_q) if prime_p && prime_q
        rsa.set_crt_params(dp, dq, qi) if dp && dq && qi
      else
        # Bindings without set_key: assign each component through its writer.
        rsa.e = e
        rsa.n = n
        optional = { d: d, p: prime_p, q: prime_q, dmp1: dp, dmq1: dq, iqmp: qi }
        optional.each do |attribute, value|
          rsa.public_send("#{attribute}=", value) if value
        end
      end
      rsa
    end

    # Builds an OpenSSL EC key from +crv+, +x+, +y+ and, when present, +d+.
    #
    # @raise [UnknownAlgorithm] if +crv+ is not P-256, P-384 or P-521
    def to_ec_key
      curve_name = {
        :'P-256' => 'prime256v1',
        :'P-384' => 'secp384r1',
        :'P-521' => 'secp521r1'
      }.fetch(self[:crv].try(:to_sym)) { raise UnknownAlgorithm, 'Unknown EC Curve' }

      x, y, d = %i[x y d].map { |member| decoded_member(member) }

      ec = OpenSSL::PKey::EC.new(curve_name)
      ec.private_key = OpenSSL::BN.new(d, 2) if d

      group = OpenSSL::PKey::EC::Group.new(curve_name)
      # Public point as octets: a leading 04 followed by x and then y
      # (the uncompressed point form).
      point_hex = "04#{x.unpack('H*').first}#{y.unpack('H*').first}"
      ec.public_key = OpenSSL::PKey::EC::Point.new(
        group,
        OpenSSL::BN.new([point_hex].pack('H*'), 2)
      )
      ec
    end
  end
end