lib/rodauth/features/oauth.rb in rodauth-oauth-0.7.1 vs lib/rodauth/features/oauth.rb in rodauth-oauth-0.7.2

- old
+ new

@@ -64,10 +64,11 @@ error_flash "Please authorize to continue", "require_authorization" error_flash "There was an error registering your oauth application", "create_oauth_application" notice_flash "Your oauth application has been registered", "create_oauth_application" notice_flash "The oauth token has been revoked", "revoke_oauth_token" + error_flash "You are not authorized to revoke this token", "revoke_unauthorized_account" view "authorize", "Authorize", "authorize" view "oauth_applications", "Oauth Applications", "oauth_applications" view "oauth_application", "Oauth Application", "oauth_application" view "new_oauth_application", "New Oauth Application", "new_oauth_application" @@ -277,12 +278,18 @@ # /revoke route(:revoke) do |r| next unless is_authorization_server? before_revoke_route - require_oauth_application + if logged_in? + require_account + require_oauth_application_from_account + else + require_oauth_application + end + r.post do catch_error do validate_oauth_revoke_params oauth_token = nil @@ -384,11 +391,14 @@ request.get "new" do new_oauth_application_view end request.on(oauth_applications_id_pattern) do |id| - oauth_application = db[oauth_applications_table].where(oauth_applications_id_column => id).first + oauth_application = db[oauth_applications_table] + .where(oauth_applications_id_column => id) + .where(oauth_applications_account_id_column => account_id) + .first next unless oauth_application scope.instance_variable_set(:@oauth_application, oauth_application) request.is do @@ -405,11 +415,12 @@ end end end request.get do - scope.instance_variable_set(:@oauth_applications, db[oauth_applications_table]) + scope.instance_variable_set(:@oauth_applications, db[oauth_applications_table] + .where(oauth_applications_account_id_column => account_id)) oauth_applications_view end request.post do catch_error do @@ -472,11 +483,11 @@ when Array scope when String scope.split(" ") when nil - [oauth_application_default_scope] + Array(oauth_application_default_scope) end end def redirect_uri param_or_nil("redirect_uri") || begin @@ -682,10 +693,24 @@ return if @oauth_application && use_oauth_pkce? && param_or_nil("code_verifier") authorization_required unless @oauth_application && secret_matches?(@oauth_application, client_secret) end + def require_oauth_application_from_account + ds = db[oauth_applications_table] + .join(oauth_tokens_table, Sequel[oauth_tokens_table][oauth_tokens_oauth_application_id_column] => + Sequel[oauth_applications_table][oauth_applications_id_column]) + .where(oauth_token_by_token_ds(param("token")).opts.fetch(:where, true)) + .where(Sequel[oauth_applications_table][oauth_applications_account_id_column] => account_id) + + @oauth_application = ds.qualify.first + return if @oauth_application + + set_redirect_error_flash revoke_unauthorized_account_error_flash + redirect request.referer || "/" + end + def secret_matches?(oauth_application, secret) BCrypt::Password.new(oauth_application[oauth_applications_client_secret_column]) == secret end def secret_hash(secret) @@ -772,20 +797,24 @@ end __insert_and_return__(ds, oauth_tokens_id_column, params) end end - def oauth_token_by_token(token) + def oauth_token_by_token_ds(token) ds = db[oauth_tokens_table] ds = if oauth_tokens_token_hash_column - ds.where(oauth_tokens_token_hash_column => generate_token_hash(token)) + ds.where(Sequel[oauth_tokens_table][oauth_tokens_token_hash_column] => generate_token_hash(token)) else - ds.where(oauth_tokens_token_column => token) + ds.where(Sequel[oauth_tokens_table][oauth_tokens_token_column] => token) end - ds.where(Sequel[oauth_tokens_expires_in_column] >= Sequel::CURRENT_TIMESTAMP) - .where(oauth_tokens_revoked_at_column => nil).first + ds.where(Sequel[oauth_tokens_table][oauth_tokens_expires_in_column] >= Sequel::CURRENT_TIMESTAMP) + .where(Sequel[oauth_tokens_table][oauth_tokens_revoked_at_column] => nil) + end + + def oauth_token_by_token(token) + oauth_token_by_token_ds(token).first end def oauth_token_by_refresh_token(token, revoked: false) ds = db[oauth_tokens_table] #