lib/rodauth/features/oauth_base.rb in rodauth-oauth-0.9.2 vs lib/rodauth/features/oauth_base.rb in rodauth-oauth-0.9.3
- old
+ new
@@ -424,36 +424,44 @@
create_params = {
oauth_tokens_expires_in_column => Sequel.date_add(Sequel::CURRENT_TIMESTAMP, seconds: oauth_token_expires_in)
}.merge(params)
rescue_from_uniqueness_error do
- token = oauth_unique_id_generator
+ access_token = _generate_access_token(create_params)
+ refresh_token = _generate_refresh_token(create_params) if should_generate_refresh_token
+ oauth_token = _store_oauth_token(create_params)
+ oauth_token[oauth_tokens_token_column] = access_token
+ oauth_token[oauth_tokens_refresh_token_column] = refresh_token if refresh_token
+ oauth_token
+ end
+ end
- if oauth_tokens_token_hash_column
- create_params[oauth_tokens_token_hash_column] = generate_token_hash(token)
- else
- create_params[oauth_tokens_token_column] = token
- end
+ def _generate_access_token(params = {})
+ token = oauth_unique_id_generator
- refresh_token = nil
- if should_generate_refresh_token
- refresh_token = oauth_unique_id_generator
+ if oauth_tokens_token_hash_column
+ params[oauth_tokens_token_hash_column] = generate_token_hash(token)
+ else
+ params[oauth_tokens_token_column] = token
+ end
- if oauth_tokens_refresh_token_hash_column
- create_params[oauth_tokens_refresh_token_hash_column] = generate_token_hash(refresh_token)
- else
- create_params[oauth_tokens_refresh_token_column] = refresh_token
- end
- end
- oauth_token = _generate_oauth_token(create_params)
- oauth_token[oauth_tokens_token_column] = token
- oauth_token[oauth_tokens_refresh_token_column] = refresh_token if refresh_token
- oauth_token
+ token
+ end
+
+ def _generate_refresh_token(params)
+ token = oauth_unique_id_generator
+
+ if oauth_tokens_refresh_token_hash_column
+ params[oauth_tokens_refresh_token_hash_column] = generate_token_hash(token)
+ else
+ params[oauth_tokens_refresh_token_column] = token
end
+
+ token
end
- def _generate_oauth_token(params = {})
+ def _store_oauth_token(params = {})
ds = db[oauth_tokens_table]
if __one_oauth_token_per_account
token = __insert_or_update_and_return__(
@@ -575,46 +583,27 @@
def create_oauth_token_from_token(oauth_token, update_params)
redirect_response_error("invalid_grant") unless token_from_application?(oauth_token, oauth_application)
rescue_from_uniqueness_error do
oauth_tokens_ds = db[oauth_tokens_table]
- token = oauth_unique_id_generator
+ access_token = _generate_access_token(update_params)
- if oauth_tokens_token_hash_column
- update_params[oauth_tokens_token_hash_column] = generate_token_hash(token)
+ if oauth_refresh_token_protection_policy == "rotation"
+ update_params = {
+ **update_params,
+ oauth_tokens_oauth_token_id_column => oauth_token[oauth_tokens_id_column],
+ oauth_tokens_account_id_column => oauth_token[oauth_tokens_account_id_column],
+ oauth_tokens_scopes_column => oauth_token[oauth_tokens_scopes_column]
+ }
+
+ refresh_token = _generate_refresh_token(update_params)
else
- update_params[oauth_tokens_token_column] = token
+ refresh_token = param("refresh_token")
end
+ oauth_token = __update_and_return__(oauth_tokens_ds, update_params)
- oauth_token = if oauth_refresh_token_protection_policy == "rotation"
- insert_params = {
- **update_params,
- oauth_tokens_oauth_token_id_column => oauth_token[oauth_tokens_id_column],
- oauth_tokens_scopes_column => oauth_token[oauth_tokens_scopes_column]
- }
-
- refresh_token = oauth_unique_id_generator
-
- if oauth_tokens_refresh_token_hash_column
- insert_params[oauth_tokens_refresh_token_hash_column] = generate_token_hash(refresh_token)
- else
- insert_params[oauth_tokens_refresh_token_column] = refresh_token
- end
-
- # revoke the refresh token
- oauth_tokens_ds.where(oauth_tokens_id_column => oauth_token[oauth_tokens_id_column])
- .update(oauth_tokens_revoked_at_column => Sequel::CURRENT_TIMESTAMP)
-
- insert_params[oauth_tokens_oauth_token_id_column] = oauth_token[oauth_tokens_id_column]
- __insert_and_return__(oauth_tokens_ds, oauth_tokens_id_column, insert_params)
- else
- # includes none
- ds = oauth_tokens_ds.where(oauth_tokens_id_column => oauth_token[oauth_tokens_id_column])
- __update_and_return__(ds, update_params)
- end
-
- oauth_token[oauth_tokens_token_column] = token
- oauth_token[oauth_tokens_refresh_token_column] = refresh_token if refresh_token
+ oauth_token[oauth_tokens_token_column] = access_token
+ oauth_token[oauth_tokens_refresh_token_column] = refresh_token
oauth_token
end
end
def supported_grant_type?(grant_type, expected_grant_type = grant_type)