lib/rodauth/features/oauth_jwt.rb in rodauth-oauth-0.0.4 vs lib/rodauth/features/oauth_jwt.rb in rodauth-oauth-0.0.5

- old
+ new

@@ -1,24 +1,19 @@ # frozen-string-literal: true +require "rodauth/oauth/ttl_store" + module Rodauth Feature.define(:oauth_jwt) do depends :oauth - auth_value_method :grant_type_param, "grant_type" - auth_value_method :assertion_param, "assertion" - auth_value_method :oauth_jwt_token_issuer, "Example" auth_value_method :oauth_jwt_key, nil auth_value_method :oauth_jwt_public_key, nil auth_value_method :oauth_jwt_algorithm, "HS256" - auth_value_method :oauth_jwt_jwk_key, nil - auth_value_method :oauth_jwt_jwk_public_key, nil - auth_value_method :oauth_jwt_jwk_algorithm, "RS256" - auth_value_method :oauth_jwt_jwe_key, nil auth_value_method :oauth_jwt_jwe_public_key, nil auth_value_method :oauth_jwt_jwe_algorithm, nil auth_value_method :oauth_jwt_jwe_encryption_method, nil @@ -29,10 +24,12 @@ :jwt_encode, :jwt_decode, :jwks_set ) + JWKS = OAuth::TtlStore.new + def require_oauth_authorization(*scopes) authorization_required unless authorization_token scopes << oauth_application_default_scope if scopes.empty? @@ -44,40 +41,54 @@ private def authorization_token return @authorization_token if defined?(@authorization_token) - @authorization_token = jwt_decode(fetch_access_token) + @authorization_token = begin + bearer_token = fetch_access_token + + return unless bearer_token + + jwt_token = jwt_decode(bearer_token) + + return unless jwt_token + + return if jwt_token["iss"] != oauth_jwt_token_issuer || + jwt_token["aud"] != oauth_jwt_audience || + !jwt_token["sub"] + + jwt_token + end end # /token def before_token # requset authentication optional for assertions - return if param(grant_type_param) == "urn:ietf:params:oauth:grant-type:jwt-bearer" + return if param("grant_type") == "urn:ietf:params:oauth:grant-type:jwt-bearer" super end def validate_oauth_token_params - if param(grant_type_param) == "urn:ietf:params:oauth:grant-type:jwt-bearer" - redirect_response_error("invalid_client") unless param_or_nil(assertion_param) + if param("grant_type") == "urn:ietf:params:oauth:grant-type:jwt-bearer" + redirect_response_error("invalid_client") unless param_or_nil("assertion") else super end end def create_oauth_token - if param(grant_type_param) == "urn:ietf:params:oauth:grant-type:jwt-bearer" + if param("grant_type") == "urn:ietf:params:oauth:grant-type:jwt-bearer" create_oauth_token_from_assertion else super end end def create_oauth_token_from_assertion - claims = jwt_decode(param(assertion_param)) + claims = jwt_decode(param("assertion")) redirect_response_error("invalid_grant") unless claims @oauth_application = db[oauth_applications_table].where(oauth_applications_client_id_column => claims["client_id"]).first @@ -109,11 +120,11 @@ end end oauth_token = _generate_oauth_token(create_params) - issued_at = Time.current.utc.to_i + issued_at = Time.now.utc.to_i payload = { sub: oauth_token[oauth_tokens_account_id_column], iss: oauth_jwt_token_issuer, # issuer iat: issued_at, # issued at @@ -131,11 +142,11 @@ exp: issued_at + oauth_token_expires_in, aud: oauth_jwt_audience, # one of the points of using jwt is avoiding database lookups, so we put here all relevant # token data. - scope: oauth_token[oauth_tokens_scopes_column].gsub(",", " ") + scope: oauth_token[oauth_tokens_scopes_column] } token = jwt_encode(payload) oauth_token[oauth_tokens_token_column] = token @@ -180,27 +191,59 @@ oauth_token["client_id"] == oauth_application[oauth_applications_client_id_column] end def _jwt_key - @_jwt_key ||= oauth_jwt_key || oauth_application[oauth_applications_client_secret_column] + @_jwt_key ||= oauth_jwt_key || (oauth_application[oauth_applications_client_secret_column] if oauth_application) end + # Resource Server only! + # + # returns the jwks set from the authorization server. + def auth_server_jwks_set + metadata = authorization_server_metadata + + return unless metadata && (jwks_uri = metadata[:jwks_uri]) + + jwks_uri = URI(jwks_uri) + + jwks = JWKS[jwks_uri] + + return jwks if jwks + + JWKS.set(jwks_uri) do + http = Net::HTTP.new(jwks_uri.host, jwks_uri.port) + http.use_ssl = jwks_uri.scheme == "https" + + request = Net::HTTP::Get.new(jwks_uri.request_uri) + request["accept"] = json_response_content_type + response = http.request(request) + authorization_required unless response.code.to_i == 200 + + # time-to-live + ttl = if response.key?("cache-control") + cache_control = response["cache_control"] + cache_control[/max-age=(\d+)/, 1] + elsif response.key?("expires") + Time.httpdate(response["expires"]).utc.to_i - Time.now.utc.to_i + end + + [JSON.parse(response.body, symbolize_names: true), ttl] + end + end + if defined?(JSON::JWT) # :nocov: # json-jwt def jwt_encode(payload) jwt = JSON::JWT.new(payload) + jwk = JSON::JWK.new(_jwt_key) - jwt = if oauth_jwt_jwk_key - jwk = JSON::JWK.new(oauth_jwt_jwk_key) - jwt.kid = jwk.thumbprint - jwt.sign(oauth_jwt_jwk_key, oauth_jwt_jwk_algorithm) - else - jwt.sign(_jwt_key, oauth_jwt_algorithm) - end + jwt = jwt.sign(jwk, oauth_jwt_algorithm) + jwt.kid = jwk.thumbprint + if oauth_jwt_jwe_key algorithm = oauth_jwt_jwe_algorithm.to_sym if oauth_jwt_jwe_algorithm jwt = jwt.encrypt(oauth_jwt_jwe_public_key || oauth_jwt_jwe_key, algorithm, oauth_jwt_jwe_encryption_method.to_sym) @@ -211,54 +254,52 @@ def jwt_decode(token) return @jwt_token if defined?(@jwt_token) token = JSON::JWT.decode(token, oauth_jwt_jwe_key).plain_text if oauth_jwt_jwe_key - @jwt_token = if oauth_jwt_jwk_key - jwk = JSON::JWK.new(oauth_jwt_jwk_public_key || oauth_jwt_jwk_key) + jwk = oauth_jwt_public_key || _jwt_key + + @jwt_token = if jwk JSON::JWT.decode(token, jwk) - else - JSON::JWT.decode(token, oauth_jwt_public_key || _jwt_key) + elsif !is_authorization_server? && auth_server_jwks_set + JSON::JWT.decode(token, JSON::JWK::Set.new(auth_server_jwks_set)) end rescue JSON::JWT::Exception nil end def jwks_set [ - (JSON::JWK.new(oauth_jwt_jwk_public_key).merge(use: "sig", alg: oauth_jwt_jwk_algorithm) if oauth_jwt_jwk_public_key), + (JSON::JWK.new(oauth_jwt_public_key).merge(use: "sig", alg: oauth_jwt_algorithm) if oauth_jwt_public_key), (JSON::JWK.new(oauth_jwt_jwe_public_key).merge(use: "enc", alg: oauth_jwt_jwe_algorithm) if oauth_jwt_jwe_public_key) ].compact end + # :nocov: elsif defined?(JWT) # ruby-jwt def jwt_encode(payload) headers = {} - key, algorithm = if oauth_jwt_jwk_key - jwk_key = JWT::JWK.new(oauth_jwt_jwk_key) - # JWK - # Currently only supports RSA public keys. - headers[:kid] = jwk_key.kid + key = _jwt_key - [jwk_key.keypair, oauth_jwt_jwk_algorithm] - else - # JWS + if key.is_a?(OpenSSL::PKey::RSA) + jwk = JWT::JWK.new(_jwt_key) + headers[:kid] = jwk.kid - [_jwt_key, oauth_jwt_algorithm] - end + key = jwk.keypair + end # Use the key and iat to create a unique key per request to prevent replay attacks jti_raw = [key, payload[:iat]].join(":").to_s jti = Digest::SHA256.hexdigest(jti_raw) # @see JWT reserved claims - https://tools.ietf.org/html/draft-jones-json-web-token-07#page-7 payload[:jti] = jti - token = JWT.encode(payload, key, algorithm, headers) + token = JWT.encode(payload, key, oauth_jwt_algorithm, headers) if oauth_jwt_jwe_key params = { zip: "DEF", copyright: oauth_jwt_jwe_copyright @@ -276,39 +317,25 @@ # decrypt jwe token = JWE.decrypt(token, oauth_jwt_jwe_key) if oauth_jwt_jwe_key # decode jwt - headers = { algorithms: [oauth_jwt_algorithm] } + key = oauth_jwt_public_key || _jwt_key - key = if oauth_jwt_jwk_key - jwk_key = JWT::JWK.new(oauth_jwt_jwk_public_key || oauth_jwt_jwk_key) - # JWK - # The jwk loader would fetch the set of JWKs from a trusted source - jwk_loader = lambda do |options| - @cached_keys = nil if options[:invalidate] # need to reload the keys - @cached_keys ||= { keys: [jwk_key.export] } - end - - headers[:algorithms] = [oauth_jwt_jwk_algorithm] - headers[:jwks] = jwk_loader - - nil - else - # JWS - # worst case scenario, the key is the application key - oauth_jwt_public_key || _jwt_key - end - @jwt_token, = JWT.decode(token, key, true, headers) - @jwt_token - rescue JWT::DecodeError + @jwt_token = if key + JWT.decode(token, key, true, algorithms: [oauth_jwt_algorithm]).first + elsif !is_authorization_server? && auth_server_jwks_set + algorithms = auth_server_jwks_set[:keys].select { |k| k[:use] == "sig" }.map { |k| k[:alg] } + JWT.decode(token, nil, true, jwks: auth_server_jwks_set, algorithms: algorithms).first + end + rescue JWT::DecodeError, JWT::JWKError nil end def jwks_set [ - (JWT::JWK.new(oauth_jwt_jwk_public_key).export.merge(use: "sig", alg: oauth_jwt_jwk_algorithm) if oauth_jwt_jwk_public_key), + (JWT::JWK.new(oauth_jwt_public_key).export.merge(use: "sig", alg: oauth_jwt_algorithm) if oauth_jwt_public_key), (JWT::JWK.new(oauth_jwt_jwe_public_key).export.merge(use: "enc", alg: oauth_jwt_jwe_algorithm) if oauth_jwt_jwe_public_key) ].compact end else # :nocov: @@ -326,10 +353,10 @@ # :nocov: end route(:oauth_jwks) do |r| r.get do - json_response_success(jwks_set) + json_response_success({ keys: jwks_set }) end end end end