lib/rodauth/features/oauth.rb in rodauth-oauth-0.7.1 vs lib/rodauth/features/oauth.rb in rodauth-oauth-0.7.2
- old
+ new
@@ -64,10 +64,11 @@
error_flash "Please authorize to continue", "require_authorization"
error_flash "There was an error registering your oauth application", "create_oauth_application"
notice_flash "Your oauth application has been registered", "create_oauth_application"
notice_flash "The oauth token has been revoked", "revoke_oauth_token"
+ error_flash "You are not authorized to revoke this token", "revoke_unauthorized_account"
view "authorize", "Authorize", "authorize"
view "oauth_applications", "Oauth Applications", "oauth_applications"
view "oauth_application", "Oauth Application", "oauth_application"
view "new_oauth_application", "New Oauth Application", "new_oauth_application"
@@ -277,12 +278,18 @@
# /revoke
route(:revoke) do |r|
next unless is_authorization_server?
before_revoke_route
- require_oauth_application
+ if logged_in?
+ require_account
+ require_oauth_application_from_account
+ else
+ require_oauth_application
+ end
+
r.post do
catch_error do
validate_oauth_revoke_params
oauth_token = nil
@@ -384,11 +391,14 @@
request.get "new" do
new_oauth_application_view
end
request.on(oauth_applications_id_pattern) do |id|
- oauth_application = db[oauth_applications_table].where(oauth_applications_id_column => id).first
+ oauth_application = db[oauth_applications_table]
+ .where(oauth_applications_id_column => id)
+ .where(oauth_applications_account_id_column => account_id)
+ .first
next unless oauth_application
scope.instance_variable_set(:@oauth_application, oauth_application)
request.is do
@@ -405,11 +415,12 @@
end
end
end
request.get do
- scope.instance_variable_set(:@oauth_applications, db[oauth_applications_table])
+ scope.instance_variable_set(:@oauth_applications, db[oauth_applications_table]
+ .where(oauth_applications_account_id_column => account_id))
oauth_applications_view
end
request.post do
catch_error do
@@ -472,11 +483,11 @@
when Array
scope
when String
scope.split(" ")
when nil
- [oauth_application_default_scope]
+ Array(oauth_application_default_scope)
end
end
def redirect_uri
param_or_nil("redirect_uri") || begin
@@ -682,10 +693,24 @@
return if @oauth_application && use_oauth_pkce? && param_or_nil("code_verifier")
authorization_required unless @oauth_application && secret_matches?(@oauth_application, client_secret)
end
+ def require_oauth_application_from_account
+ ds = db[oauth_applications_table]
+ .join(oauth_tokens_table, Sequel[oauth_tokens_table][oauth_tokens_oauth_application_id_column] =>
+ Sequel[oauth_applications_table][oauth_applications_id_column])
+ .where(oauth_token_by_token_ds(param("token")).opts.fetch(:where, true))
+ .where(Sequel[oauth_applications_table][oauth_applications_account_id_column] => account_id)
+
+ @oauth_application = ds.qualify.first
+ return if @oauth_application
+
+ set_redirect_error_flash revoke_unauthorized_account_error_flash
+ redirect request.referer || "/"
+ end
+
def secret_matches?(oauth_application, secret)
BCrypt::Password.new(oauth_application[oauth_applications_client_secret_column]) == secret
end
def secret_hash(secret)
@@ -772,20 +797,24 @@
end
__insert_and_return__(ds, oauth_tokens_id_column, params)
end
end
- def oauth_token_by_token(token)
+ def oauth_token_by_token_ds(token)
ds = db[oauth_tokens_table]
ds = if oauth_tokens_token_hash_column
- ds.where(oauth_tokens_token_hash_column => generate_token_hash(token))
+ ds.where(Sequel[oauth_tokens_table][oauth_tokens_token_hash_column] => generate_token_hash(token))
else
- ds.where(oauth_tokens_token_column => token)
+ ds.where(Sequel[oauth_tokens_table][oauth_tokens_token_column] => token)
end
- ds.where(Sequel[oauth_tokens_expires_in_column] >= Sequel::CURRENT_TIMESTAMP)
- .where(oauth_tokens_revoked_at_column => nil).first
+ ds.where(Sequel[oauth_tokens_table][oauth_tokens_expires_in_column] >= Sequel::CURRENT_TIMESTAMP)
+ .where(Sequel[oauth_tokens_table][oauth_tokens_revoked_at_column] => nil)
+ end
+
+ def oauth_token_by_token(token)
+ oauth_token_by_token_ds(token).first
end
def oauth_token_by_refresh_token(token, revoked: false)
ds = db[oauth_tokens_table]
#