lib/rodauth/features/oauth_jwt.rb in rodauth-oauth-0.9.3 vs lib/rodauth/features/oauth_jwt.rb in rodauth-oauth-0.10.0

- old
+ new

@@ -42,14 +42,17 @@ auth_value_method :oauth_application_jwt_public_key_param, "jwt_public_key" auth_value_method :oauth_application_jwks_param, "jwks" auth_value_method :oauth_jwt_keys, {} auth_value_method :oauth_jwt_key, nil + auth_value_method :oauth_jwt_public_keys, {} auth_value_method :oauth_jwt_public_key, nil auth_value_method :oauth_jwt_algorithm, "RS256" + auth_value_method :oauth_jwt_jwe_keys, {} auth_value_method :oauth_jwt_jwe_key, nil + auth_value_method :oauth_jwt_jwe_public_keys, {} auth_value_method :oauth_jwt_jwe_public_key, nil auth_value_method :oauth_jwt_jwe_algorithm, nil auth_value_method :oauth_jwt_jwe_encryption_method, nil # values used for rotating keys @@ -405,14 +408,15 @@ JSON::JWK.new(data) end def jwt_encode(payload, jwks: nil, - jwe_key: oauth_jwt_jwe_public_key || oauth_jwt_jwe_key, - signing_algorithm: oauth_jwt_algorithm, encryption_algorithm: oauth_jwt_jwe_algorithm, - encryption_method: oauth_jwt_jwe_encryption_method) + encryption_method: oauth_jwt_jwe_encryption_method, + jwe_key: oauth_jwt_jwe_keys[[encryption_algorithm, + encryption_method]] || oauth_jwt_jwe_public_key || oauth_jwt_jwe_key, + signing_algorithm: oauth_jwt_algorithm || oauth_jwt_keys.keys.first) payload[:jti] = generate_jti(payload) jwt = JSON::JWT.new(payload) key = oauth_jwt_keys[signing_algorithm] || _jwt_key key = key.first if key.is_a?(Array) @@ -425,10 +429,11 @@ if jwks && (jwk = jwks.find { |k| k[:use] == "enc" && k[:alg] == encryption_algorithm && k[:enc] == encryption_method }) jwk = JSON::JWK.new(jwk) jwe = jwt.encrypt(jwk, encryption_algorithm.to_sym, encryption_method.to_sym) jwe.to_s elsif jwe_key + jwe_key = jwe_key.first if jwe_key.is_a?(Array) algorithm = encryption_algorithm.to_sym if encryption_algorithm meth = encryption_method.to_sym if encryption_method jwt.encrypt(jwe_key, algorithm, meth) else jwt.to_s @@ -436,23 +441,28 @@ end def jwt_decode( token, jwks: nil, - jws_key: oauth_jwt_public_key || _jwt_key, - jws_algorithm: oauth_jwt_algorithm, - jwe_key: oauth_jwt_jwe_key, + jws_algorithm: oauth_jwt_algorithm || oauth_jwt_public_key.keys.first || oauth_jwt_keys.keys.first, + jws_key: oauth_jwt_public_key || oauth_jwt_keys[jws_algorithm] || _jwt_key, jws_encryption_algorithm: oauth_jwt_jwe_algorithm, jws_encryption_method: oauth_jwt_jwe_encryption_method, + jwe_key: oauth_jwt_jwe_keys[[jws_encryption_algorithm, jws_encryption_method]] || oauth_jwt_jwe_key, verify_claims: true, verify_jti: true, verify_iss: true, verify_aud: false, ** ) - token = JSON::JWT.decode(token, oauth_jwt_jwe_key).plain_text if jwe_key + jws_key = jws_key.first if jws_key.is_a?(Array) + if jwe_key + jwe_key = jwe_key.first if jwe_key.is_a?(Array) + token = JSON::JWT.decode(token, jwe_key).plain_text + end + claims = if is_authorization_server? if oauth_jwt_legacy_public_key JSON::JWT.decode(token, JSON::JWK::Set.new({ keys: jwks_set })) elsif jwks enc_algs = [jws_encryption_algorithm].compact @@ -485,10 +495,25 @@ nil end def jwks_set @jwks_set ||= [ + *( + unless oauth_jwt_public_keys.empty? + oauth_jwt_public_keys.flat_map { |algo, pkeys| pkeys.map { |pkey| JSON::JWK.new(pkey).merge(use: "sig", alg: algo) } } + end + ), + *( + unless oauth_jwt_jwe_public_keys.empty? + oauth_jwt_jwe_public_keys.flat_map do |(algo, _enc), pkeys| + pkeys.map do |pkey| + JSON::JWK.new(pkey).merge(use: "enc", alg: algo) + end + end + end + ), + # legacy (JSON::JWK.new(oauth_jwt_public_key).merge(use: "sig", alg: oauth_jwt_algorithm) if oauth_jwt_public_key), (JSON::JWK.new(oauth_jwt_legacy_public_key).merge(use: "sig", alg: oauth_jwt_legacy_algorithm) if oauth_jwt_legacy_public_key), (JSON::JWK.new(oauth_jwt_jwe_public_key).merge(use: "enc", alg: oauth_jwt_jwe_algorithm) if oauth_jwt_jwe_public_key) ].compact end @@ -520,11 +545,12 @@ def jwk_import(data) JWT::JWK.import(data).keypair end - def jwt_encode(payload, signing_algorithm: oauth_jwt_algorithm) + def jwt_encode(payload, + signing_algorithm: oauth_jwt_algorithm || oauth_jwt_keys.keys.first) headers = {} key = oauth_jwt_keys[signing_algorithm] || _jwt_key key = key.first if key.is_a?(Array) @@ -543,22 +569,23 @@ if defined?(JWE) def jwt_encode_with_jwe( payload, jwks: nil, - jwe_key: oauth_jwt_jwe_public_key || oauth_jwt_jwe_key, encryption_algorithm: oauth_jwt_jwe_algorithm, - encryption_method: oauth_jwt_jwe_encryption_method, **args + encryption_method: oauth_jwt_jwe_encryption_method, + jwe_key: oauth_jwt_jwe_public_key || oauth_jwt_jwe_keys[[encryption_algorithm, encryption_method]] || oauth_jwt_jwe_key, + **args ) - token = jwt_encode_without_jwe(payload, **args) return token unless encryption_algorithm && encryption_method if jwks && jwks.any? { |k| k[:use] == "enc" } JWE.__rodauth_oauth_encrypt_from_jwks(token, jwks, alg: encryption_algorithm, enc: encryption_method) elsif jwe_key + jwe_key = jwe_key.first if jwe_key.is_a?(Array) params = { zip: "DEF", copyright: oauth_jwt_jwe_copyright } params[:enc] = encryption_method if encryption_method @@ -574,17 +601,19 @@ end def jwt_decode( token, jwks: nil, - jws_key: oauth_jwt_public_key || _jwt_key, - jws_algorithm: oauth_jwt_algorithm, + jws_algorithm: oauth_jwt_algorithm || oauth_jwt_public_key.keys.first || oauth_jwt_keys.keys.first, + jws_key: oauth_jwt_public_key || oauth_jwt_keys[jws_algorithm] || _jwt_key, verify_claims: true, verify_jti: true, verify_iss: true, verify_aud: false ) + jws_key = jws_key.first if jws_key.is_a?(Array) + # verifying the JWT implies verifying: # # issuer: check that server generated the token # aud: check the audience field (client is who he says he is) # iat: check that the token didn't expire @@ -629,19 +658,20 @@ if defined?(JWE) def jwt_decode_with_jwe( token, jwks: nil, - jwe_key: oauth_jwt_jwe_key, jws_encryption_algorithm: oauth_jwt_jwe_algorithm, jws_encryption_method: oauth_jwt_jwe_encryption_method, + jwe_key: oauth_jwt_jwe_keys[[jws_encryption_algorithm, jws_encryption_method]] || oauth_jwt_jwe_key, **args ) token = if jwks && jwks.any? { |k| k[:use] == "enc" } JWE.__rodauth_oauth_decrypt_from_jwks(token, jwks, alg: jws_encryption_algorithm, enc: jws_encryption_method) elsif jwe_key + jwe_key = jwe_key.first if jwe_key.is_a?(Array) JWE.decrypt(token, jwe_key) else token end @@ -654,9 +684,24 @@ alias_method :jwt_decode, :jwt_decode_with_jwe end def jwks_set @jwks_set ||= [ + *( + unless oauth_jwt_public_keys.empty? + oauth_jwt_public_keys.flat_map { |algo, pkeys| pkeys.map { |pkey| JWT::JWK.new(pkey).export.merge(use: "sig", alg: algo) } } + end + ), + *( + unless oauth_jwt_jwe_public_keys.empty? + oauth_jwt_jwe_public_keys.flat_map do |(algo, _enc), pkeys| + pkeys.map do |pkey| + JWT::JWK.new(pkey).export.merge(use: "enc", alg: algo) + end + end + end + ), + # legacy (JWT::JWK.new(oauth_jwt_public_key).export.merge(use: "sig", alg: oauth_jwt_algorithm) if oauth_jwt_public_key), ( if oauth_jwt_legacy_public_key JWT::JWK.new(oauth_jwt_legacy_public_key).export.merge(use: "sig", alg: oauth_jwt_legacy_algorithm) end