lib/rodauth/features/oauth_jwt.rb in rodauth-oauth-0.1.0 vs lib/rodauth/features/oauth_jwt.rb in rodauth-oauth-0.2.0

- old
+ new

@@ -20,10 +20,14 @@ auth_value_method :oauth_jwt_jwe_key, nil auth_value_method :oauth_jwt_jwe_public_key, nil auth_value_method :oauth_jwt_jwe_algorithm, nil auth_value_method :oauth_jwt_jwe_encryption_method, nil + # values used for rotating keys + auth_value_method :oauth_jwt_legacy_public_key, nil + auth_value_method :oauth_jwt_legacy_algorithm, nil + auth_value_method :oauth_jwt_jwe_copyright, nil auth_value_method :oauth_jwt_audience, nil auth_value_method :request_uri_not_supported_message, "request uri is unsupported" auth_value_method :invalid_request_object_message, "request object is invalid" @@ -86,13 +90,11 @@ return super unless request_object && oauth_application jws_jwk = if oauth_application[oauth_application_jws_jwk_column] jwk = oauth_application[oauth_application_jws_jwk_column] - if jwk - jwk = JSON.parse(jwk, symbolize_names: true) if jwk.is_a?(String) - end + jwk = JSON.parse(jwk, symbolize_names: true) if jwk && jwk.is_a?(String) else redirect_response_error("invalid_request_object") end claims = jwt_decode(request_object, jws_key: jwk_import(jws_jwk), jws_algorithm: jwk[:alg]) @@ -103,12 +105,12 @@ # Object SHOULD contain the Claims "iss" (issuer) and "aud" (audience) # as members, with their semantics being the same as defined in the JWT # [RFC7519] specification. The value of "aud" should be the value of # the Authorization Server (AS) "issuer" as defined in RFC8414 # [RFC8414]. - claims.delete(:iss) - audience = claims.delete(:aud) + claims.delete("iss") + audience = claims.delete("aud") redirect_response_error("invalid_request_object") if audience && audience != authorization_server_url claims.each do |k, v| request.params[k.to_s] = v @@ -117,15 +119,21 @@ super end # /token - def before_token + def require_oauth_application # requset authentication optional for assertions - return if param("grant_type") == "urn:ietf:params:oauth:grant-type:jwt-bearer" + return super unless param("grant_type") == "urn:ietf:params:oauth:grant-type:jwt-bearer" - super + claims = jwt_decode(param("assertion")) + + redirect_response_error("invalid_grant") unless claims + + @oauth_application = db[oauth_applications_table].where(oauth_applications_client_id_column => claims["client_id"]).first + + authorization_required unless @oauth_application end def validate_oauth_token_params if param("grant_type") == "urn:ietf:params:oauth:grant-type:jwt-bearer" redirect_response_error("invalid_client") unless param_or_nil("assertion") @@ -143,14 +151,10 @@ end def create_oauth_token_from_assertion claims = jwt_decode(param("assertion")) - redirect_response_error("invalid_grant") unless claims - - @oauth_application = db[oauth_applications_table].where(oauth_applications_client_id_column => claims["client_id"]).first - account = account_ds(claims["sub"]).first redirect_response_error("invalid_client") unless oauth_application && account create_params = { @@ -165,22 +169,24 @@ def generate_oauth_token(params = {}, should_generate_refresh_token = true) create_params = { oauth_grants_expires_in_column => Time.now + oauth_token_expires_in }.merge(params) - if should_generate_refresh_token - refresh_token = oauth_unique_id_generator + oauth_token = rescue_from_uniqueness_error do + if should_generate_refresh_token + refresh_token = oauth_unique_id_generator - if oauth_tokens_refresh_token_hash_column - create_params[oauth_tokens_refresh_token_hash_column] = generate_token_hash(refresh_token) - else - create_params[oauth_tokens_refresh_token_column] = refresh_token + if oauth_tokens_refresh_token_hash_column + create_params[oauth_tokens_refresh_token_hash_column] = generate_token_hash(refresh_token) + else + create_params[oauth_tokens_refresh_token_column] = refresh_token + end end + + _generate_oauth_token(create_params) end - oauth_token = _generate_oauth_token(create_params) - claims = jwt_claims(oauth_token) # one of the points of using jwt is avoiding database lookups, so we put here all relevant # token data. claims[:scope] = oauth_token[oauth_tokens_scopes_column] @@ -291,22 +297,21 @@ response = http.request(request) authorization_required unless response.code.to_i == 200 # time-to-live ttl = if response.key?("cache-control") - cache_control = response["cache_control"] + cache_control = response["cache-control"] cache_control[/max-age=(\d+)/, 1] elsif response.key?("expires") - Time.httpdate(response["expires"]).utc.to_i - Time.now.utc.to_i + DateTime.httpdate(response["expires"]).utc.to_i - Time.now.utc.to_i end [JSON.parse(response.body, symbolize_names: true), ttl] end end if defined?(JSON::JWT) - # :nocov: def jwk_import(data) JSON::JWK.new(data) end @@ -328,27 +333,31 @@ end def jwt_decode(token, jws_key: oauth_jwt_public_key || _jwt_key, **) token = JSON::JWT.decode(token, oauth_jwt_jwe_key).plain_text if oauth_jwt_jwe_key - @jwt_token = if jws_key - JSON::JWT.decode(token, jws_key) - elsif !is_authorization_server? && auth_server_jwks_set - JSON::JWT.decode(token, JSON::JWK::Set.new(auth_server_jwks_set)) - end + if is_authorization_server? + if oauth_jwt_legacy_public_key + JSON::JWT.decode(token, JSON::JWK::Set.new({ keys: jwks_set })) + elsif jws_key + JSON::JWT.decode(token, jws_key) + end + elsif (jwks = auth_server_jwks_set) + JSON::JWT.decode(token, JSON::JWK::Set.new(jwks)) + end rescue JSON::JWT::Exception nil end def jwks_set - [ + @jwks_set ||= [ (JSON::JWK.new(oauth_jwt_public_key).merge(use: "sig", alg: oauth_jwt_algorithm) if oauth_jwt_public_key), + (JSON::JWK.new(oauth_jwt_legacy_public_key).merge(use: "sig", alg: oauth_jwt_legacy_algorithm) if oauth_jwt_legacy_public_key), (JSON::JWK.new(oauth_jwt_jwe_public_key).merge(use: "enc", alg: oauth_jwt_jwe_algorithm) if oauth_jwt_jwe_public_key) ].compact end - # :nocov: elsif defined?(JWT) # ruby-jwt def jwk_import(data) @@ -389,24 +398,33 @@ end def jwt_decode(token, jws_key: oauth_jwt_public_key || _jwt_key, jws_algorithm: oauth_jwt_algorithm) # decrypt jwe token = JWE.decrypt(token, oauth_jwt_jwe_key) if oauth_jwt_jwe_key - # decode jwt - @jwt_token = if jws_key - JWT.decode(token, jws_key, true, algorithms: [jws_algorithm]).first - elsif !is_authorization_server? && auth_server_jwks_set - algorithms = auth_server_jwks_set[:keys].select { |k| k[:use] == "sig" }.map { |k| k[:alg] } - JWT.decode(token, nil, true, jwks: auth_server_jwks_set, algorithms: algorithms).first - end + if is_authorization_server? + if oauth_jwt_legacy_public_key + algorithms = jwks_set.select { |k| k[:use] == "sig" }.map { |k| k[:alg] } + JWT.decode(token, nil, true, jwks: { keys: jwks_set }, algorithms: algorithms).first + elsif jws_key + JWT.decode(token, jws_key, true, algorithms: [jws_algorithm]).first + end + elsif (jwks = auth_server_jwks_set) + algorithms = jwks[:keys].select { |k| k[:use] == "sig" }.map { |k| k[:alg] } + JWT.decode(token, nil, true, jwks: jwks, algorithms: algorithms).first + end rescue JWT::DecodeError, JWT::JWKError nil end def jwks_set - [ + @jwks_set ||= [ (JWT::JWK.new(oauth_jwt_public_key).export.merge(use: "sig", alg: oauth_jwt_algorithm) if oauth_jwt_public_key), + ( + if oauth_jwt_legacy_public_key + JWT::JWK.new(oauth_jwt_legacy_public_key).export.merge(use: "sig", alg: oauth_jwt_legacy_algorithm) + end + ), (JWT::JWK.new(oauth_jwt_jwe_public_key).export.merge(use: "enc", alg: oauth_jwt_jwe_algorithm) if oauth_jwt_jwe_public_key) ].compact end else # :nocov: