lib/rodauth/features/oauth_base.rb in rodauth-oauth-0.9.2 vs lib/rodauth/features/oauth_base.rb in rodauth-oauth-0.9.3

- old
+ new

@@ -424,36 +424,44 @@ create_params = { oauth_tokens_expires_in_column => Sequel.date_add(Sequel::CURRENT_TIMESTAMP, seconds: oauth_token_expires_in) }.merge(params) rescue_from_uniqueness_error do - token = oauth_unique_id_generator + access_token = _generate_access_token(create_params) + refresh_token = _generate_refresh_token(create_params) if should_generate_refresh_token + oauth_token = _store_oauth_token(create_params) + oauth_token[oauth_tokens_token_column] = access_token + oauth_token[oauth_tokens_refresh_token_column] = refresh_token if refresh_token + oauth_token + end + end - if oauth_tokens_token_hash_column - create_params[oauth_tokens_token_hash_column] = generate_token_hash(token) - else - create_params[oauth_tokens_token_column] = token - end + def _generate_access_token(params = {}) + token = oauth_unique_id_generator - refresh_token = nil - if should_generate_refresh_token - refresh_token = oauth_unique_id_generator + if oauth_tokens_token_hash_column + params[oauth_tokens_token_hash_column] = generate_token_hash(token) + else + params[oauth_tokens_token_column] = token + end - if oauth_tokens_refresh_token_hash_column - create_params[oauth_tokens_refresh_token_hash_column] = generate_token_hash(refresh_token) - else - create_params[oauth_tokens_refresh_token_column] = refresh_token - end - end - oauth_token = _generate_oauth_token(create_params) - oauth_token[oauth_tokens_token_column] = token - oauth_token[oauth_tokens_refresh_token_column] = refresh_token if refresh_token - oauth_token + token + end + + def _generate_refresh_token(params) + token = oauth_unique_id_generator + + if oauth_tokens_refresh_token_hash_column + params[oauth_tokens_refresh_token_hash_column] = generate_token_hash(token) + else + params[oauth_tokens_refresh_token_column] = token end + + token end - def _generate_oauth_token(params = {}) + def _store_oauth_token(params = {}) ds = db[oauth_tokens_table] if __one_oauth_token_per_account token = __insert_or_update_and_return__( @@ -575,46 +583,27 @@ def create_oauth_token_from_token(oauth_token, update_params) redirect_response_error("invalid_grant") unless token_from_application?(oauth_token, oauth_application) rescue_from_uniqueness_error do oauth_tokens_ds = db[oauth_tokens_table] - token = oauth_unique_id_generator + access_token = _generate_access_token(update_params) - if oauth_tokens_token_hash_column - update_params[oauth_tokens_token_hash_column] = generate_token_hash(token) + if oauth_refresh_token_protection_policy == "rotation" + update_params = { + **update_params, + oauth_tokens_oauth_token_id_column => oauth_token[oauth_tokens_id_column], + oauth_tokens_account_id_column => oauth_token[oauth_tokens_account_id_column], + oauth_tokens_scopes_column => oauth_token[oauth_tokens_scopes_column] + } + + refresh_token = _generate_refresh_token(update_params) else - update_params[oauth_tokens_token_column] = token + refresh_token = param("refresh_token") end + oauth_token = __update_and_return__(oauth_tokens_ds, update_params) - oauth_token = if oauth_refresh_token_protection_policy == "rotation" - insert_params = { - **update_params, - oauth_tokens_oauth_token_id_column => oauth_token[oauth_tokens_id_column], - oauth_tokens_scopes_column => oauth_token[oauth_tokens_scopes_column] - } - - refresh_token = oauth_unique_id_generator - - if oauth_tokens_refresh_token_hash_column - insert_params[oauth_tokens_refresh_token_hash_column] = generate_token_hash(refresh_token) - else - insert_params[oauth_tokens_refresh_token_column] = refresh_token - end - - # revoke the refresh token - oauth_tokens_ds.where(oauth_tokens_id_column => oauth_token[oauth_tokens_id_column]) - .update(oauth_tokens_revoked_at_column => Sequel::CURRENT_TIMESTAMP) - - insert_params[oauth_tokens_oauth_token_id_column] = oauth_token[oauth_tokens_id_column] - __insert_and_return__(oauth_tokens_ds, oauth_tokens_id_column, insert_params) - else - # includes none - ds = oauth_tokens_ds.where(oauth_tokens_id_column => oauth_token[oauth_tokens_id_column]) - __update_and_return__(ds, update_params) - end - - oauth_token[oauth_tokens_token_column] = token - oauth_token[oauth_tokens_refresh_token_column] = refresh_token if refresh_token + oauth_token[oauth_tokens_token_column] = access_token + oauth_token[oauth_tokens_refresh_token_column] = refresh_token oauth_token end end def supported_grant_type?(grant_type, expected_grant_type = grant_type)