lib/rodauth/features/oauth_jwt.rb in rodauth-oauth-0.1.0 vs lib/rodauth/features/oauth_jwt.rb in rodauth-oauth-0.2.0
- old
+ new
@@ -20,10 +20,14 @@
auth_value_method :oauth_jwt_jwe_key, nil
auth_value_method :oauth_jwt_jwe_public_key, nil
auth_value_method :oauth_jwt_jwe_algorithm, nil
auth_value_method :oauth_jwt_jwe_encryption_method, nil
+ # values used for rotating keys
+ auth_value_method :oauth_jwt_legacy_public_key, nil
+ auth_value_method :oauth_jwt_legacy_algorithm, nil
+
auth_value_method :oauth_jwt_jwe_copyright, nil
auth_value_method :oauth_jwt_audience, nil
auth_value_method :request_uri_not_supported_message, "request uri is unsupported"
auth_value_method :invalid_request_object_message, "request object is invalid"
@@ -86,13 +90,11 @@
return super unless request_object && oauth_application
jws_jwk = if oauth_application[oauth_application_jws_jwk_column]
jwk = oauth_application[oauth_application_jws_jwk_column]
- if jwk
- jwk = JSON.parse(jwk, symbolize_names: true) if jwk.is_a?(String)
- end
+ jwk = JSON.parse(jwk, symbolize_names: true) if jwk && jwk.is_a?(String)
else
redirect_response_error("invalid_request_object")
end
claims = jwt_decode(request_object, jws_key: jwk_import(jws_jwk), jws_algorithm: jwk[:alg])
@@ -103,12 +105,12 @@
# Object SHOULD contain the Claims "iss" (issuer) and "aud" (audience)
# as members, with their semantics being the same as defined in the JWT
# [RFC7519] specification. The value of "aud" should be the value of
# the Authorization Server (AS) "issuer" as defined in RFC8414
# [RFC8414].
- claims.delete(:iss)
- audience = claims.delete(:aud)
+ claims.delete("iss")
+ audience = claims.delete("aud")
redirect_response_error("invalid_request_object") if audience && audience != authorization_server_url
claims.each do |k, v|
request.params[k.to_s] = v
@@ -117,15 +119,21 @@
super
end
# /token
- def before_token
+ def require_oauth_application
# requset authentication optional for assertions
- return if param("grant_type") == "urn:ietf:params:oauth:grant-type:jwt-bearer"
+ return super unless param("grant_type") == "urn:ietf:params:oauth:grant-type:jwt-bearer"
- super
+ claims = jwt_decode(param("assertion"))
+
+ redirect_response_error("invalid_grant") unless claims
+
+ @oauth_application = db[oauth_applications_table].where(oauth_applications_client_id_column => claims["client_id"]).first
+
+ authorization_required unless @oauth_application
end
def validate_oauth_token_params
if param("grant_type") == "urn:ietf:params:oauth:grant-type:jwt-bearer"
redirect_response_error("invalid_client") unless param_or_nil("assertion")
@@ -143,14 +151,10 @@
end
def create_oauth_token_from_assertion
claims = jwt_decode(param("assertion"))
- redirect_response_error("invalid_grant") unless claims
-
- @oauth_application = db[oauth_applications_table].where(oauth_applications_client_id_column => claims["client_id"]).first
-
account = account_ds(claims["sub"]).first
redirect_response_error("invalid_client") unless oauth_application && account
create_params = {
@@ -165,22 +169,24 @@
def generate_oauth_token(params = {}, should_generate_refresh_token = true)
create_params = {
oauth_grants_expires_in_column => Time.now + oauth_token_expires_in
}.merge(params)
- if should_generate_refresh_token
- refresh_token = oauth_unique_id_generator
+ oauth_token = rescue_from_uniqueness_error do
+ if should_generate_refresh_token
+ refresh_token = oauth_unique_id_generator
- if oauth_tokens_refresh_token_hash_column
- create_params[oauth_tokens_refresh_token_hash_column] = generate_token_hash(refresh_token)
- else
- create_params[oauth_tokens_refresh_token_column] = refresh_token
+ if oauth_tokens_refresh_token_hash_column
+ create_params[oauth_tokens_refresh_token_hash_column] = generate_token_hash(refresh_token)
+ else
+ create_params[oauth_tokens_refresh_token_column] = refresh_token
+ end
end
+
+ _generate_oauth_token(create_params)
end
- oauth_token = _generate_oauth_token(create_params)
-
claims = jwt_claims(oauth_token)
# one of the points of using jwt is avoiding database lookups, so we put here all relevant
# token data.
claims[:scope] = oauth_token[oauth_tokens_scopes_column]
@@ -291,22 +297,21 @@
response = http.request(request)
authorization_required unless response.code.to_i == 200
# time-to-live
ttl = if response.key?("cache-control")
- cache_control = response["cache_control"]
+ cache_control = response["cache-control"]
cache_control[/max-age=(\d+)/, 1]
elsif response.key?("expires")
- Time.httpdate(response["expires"]).utc.to_i - Time.now.utc.to_i
+ DateTime.httpdate(response["expires"]).utc.to_i - Time.now.utc.to_i
end
[JSON.parse(response.body, symbolize_names: true), ttl]
end
end
if defined?(JSON::JWT)
- # :nocov:
def jwk_import(data)
JSON::JWK.new(data)
end
@@ -328,27 +333,31 @@
end
def jwt_decode(token, jws_key: oauth_jwt_public_key || _jwt_key, **)
token = JSON::JWT.decode(token, oauth_jwt_jwe_key).plain_text if oauth_jwt_jwe_key
- @jwt_token = if jws_key
- JSON::JWT.decode(token, jws_key)
- elsif !is_authorization_server? && auth_server_jwks_set
- JSON::JWT.decode(token, JSON::JWK::Set.new(auth_server_jwks_set))
- end
+ if is_authorization_server?
+ if oauth_jwt_legacy_public_key
+ JSON::JWT.decode(token, JSON::JWK::Set.new({ keys: jwks_set }))
+ elsif jws_key
+ JSON::JWT.decode(token, jws_key)
+ end
+ elsif (jwks = auth_server_jwks_set)
+ JSON::JWT.decode(token, JSON::JWK::Set.new(jwks))
+ end
rescue JSON::JWT::Exception
nil
end
def jwks_set
- [
+ @jwks_set ||= [
(JSON::JWK.new(oauth_jwt_public_key).merge(use: "sig", alg: oauth_jwt_algorithm) if oauth_jwt_public_key),
+ (JSON::JWK.new(oauth_jwt_legacy_public_key).merge(use: "sig", alg: oauth_jwt_legacy_algorithm) if oauth_jwt_legacy_public_key),
(JSON::JWK.new(oauth_jwt_jwe_public_key).merge(use: "enc", alg: oauth_jwt_jwe_algorithm) if oauth_jwt_jwe_public_key)
].compact
end
- # :nocov:
elsif defined?(JWT)
# ruby-jwt
def jwk_import(data)
@@ -389,24 +398,33 @@
end
def jwt_decode(token, jws_key: oauth_jwt_public_key || _jwt_key, jws_algorithm: oauth_jwt_algorithm)
# decrypt jwe
token = JWE.decrypt(token, oauth_jwt_jwe_key) if oauth_jwt_jwe_key
-
# decode jwt
- @jwt_token = if jws_key
- JWT.decode(token, jws_key, true, algorithms: [jws_algorithm]).first
- elsif !is_authorization_server? && auth_server_jwks_set
- algorithms = auth_server_jwks_set[:keys].select { |k| k[:use] == "sig" }.map { |k| k[:alg] }
- JWT.decode(token, nil, true, jwks: auth_server_jwks_set, algorithms: algorithms).first
- end
+ if is_authorization_server?
+ if oauth_jwt_legacy_public_key
+ algorithms = jwks_set.select { |k| k[:use] == "sig" }.map { |k| k[:alg] }
+ JWT.decode(token, nil, true, jwks: { keys: jwks_set }, algorithms: algorithms).first
+ elsif jws_key
+ JWT.decode(token, jws_key, true, algorithms: [jws_algorithm]).first
+ end
+ elsif (jwks = auth_server_jwks_set)
+ algorithms = jwks[:keys].select { |k| k[:use] == "sig" }.map { |k| k[:alg] }
+ JWT.decode(token, nil, true, jwks: jwks, algorithms: algorithms).first
+ end
rescue JWT::DecodeError, JWT::JWKError
nil
end
def jwks_set
- [
+ @jwks_set ||= [
(JWT::JWK.new(oauth_jwt_public_key).export.merge(use: "sig", alg: oauth_jwt_algorithm) if oauth_jwt_public_key),
+ (
+ if oauth_jwt_legacy_public_key
+ JWT::JWK.new(oauth_jwt_legacy_public_key).export.merge(use: "sig", alg: oauth_jwt_legacy_algorithm)
+ end
+ ),
(JWT::JWK.new(oauth_jwt_jwe_public_key).export.merge(use: "enc", alg: oauth_jwt_jwe_algorithm) if oauth_jwt_jwe_public_key)
].compact
end
else
# :nocov: