lib/tasks/switchman.rake in switchman-2.2.3 vs lib/tasks/switchman.rake in switchman-3.0.0

- old
+ new

@@ -1,22 +1,24 @@ +# frozen_string_literal: true + module Switchman module Rake def self.filter_database_servers(&block) chain = filter_database_servers_chain # use a local variable so that the current chain is closed over in the following lambda - @filter_database_servers_chain = lambda { |servers| block.call(servers, chain) } + @filter_database_servers_chain = ->(servers) { block.call(servers, chain) } end def self.scope(base_scope = Shard, - database_server: ENV['DATABASE_SERVER'], - shard: ENV['SHARD']) + database_server: ENV['DATABASE_SERVER'], + shard: ENV['SHARD']) servers = DatabaseServer.all if database_server servers = database_server if servers.first == '-' negative = true - servers = servers[1..-1] + servers = servers[1..] end servers = servers.split(',') open = servers.delete('open') servers = servers.map { |server| DatabaseServer.find(server) }.compact @@ -29,34 +31,31 @@ servers = DatabaseServer.all - servers if negative end servers = filter_database_servers_chain.call(servers) - scope = base_scope.order(::Arel.sql("database_server_id IS NOT NULL, database_server_id, id")) + scope = base_scope.order(::Arel.sql('database_server_id IS NOT NULL, database_server_id, id')) if servers != DatabaseServer.all - conditions = ["database_server_id IN (?)", servers.map(&:id)] - conditions.first << " OR database_server_id IS NULL" if servers.include?(Shard.default.database_server) + conditions = ['database_server_id IN (?)', servers.map(&:id)] + conditions.first << ' OR database_server_id IS NULL' if servers.include?(Shard.default.database_server) scope = scope.where(conditions) end - if shard - scope = shard_scope(scope, shard) - end + scope = shard_scope(scope, shard) if shard scope end def self.options - # we still pass through both of these options for back-compat purposes - { parallel: ENV['PARALLEL']&.to_i, max_procs: ENV['MAX_PARALLEL_PROCS']&.to_i } + { parallel: ENV['PARALLEL'].to_i, max_procs: ENV['MAX_PARALLEL_PROCS'] } end - # categories - an array or proc, to activate as the current shard during the + # classes - an array or proc, to activate as the current shard during the # task. tasks which modify the schema may want to pass all categories in # so that schema updates for non-default tables happen against all shards. # this is handled automatically for the default migration tasks, below. - def self.shardify_task(task_name, categories: [:primary]) + def self.shardify_task(task_name, classes: [::ActiveRecord::Base]) old_task = ::Rake::Task[task_name] old_actions = old_task.actions.dup old_task.actions.clear old_task.enhance do |*task_args| @@ -65,47 +64,34 @@ TestHelper.recreate_persistent_test_shards(dont_create: true) end ::GuardRail.activate(:deploy) do Shard.default.database_server.unguard do - begin - categories = categories.call if categories.respond_to?(:call) - Shard.with_each_shard(scope, categories, options) do - shard = Shard.current - puts "#{shard.id}: #{shard.description}" - ::ActiveRecord::Base.connection_pool.spec.config[:shard_name] = Shard.current.name - if ::Rails.version < '6.0' - ::ActiveRecord::Base.configurations[::Rails.env] = ::ActiveRecord::Base.connection_pool.spec.config.stringify_keys - else - # Adopted from the deprecated code that currently lives in rails proper - remaining_configs = ::ActiveRecord::Base.configurations.configurations.reject { |db_config| db_config.env_name == ::Rails.env } - new_config = ::ActiveRecord::DatabaseConfigurations.new(::Rails.env => - ::ActiveRecord::Base.connection_pool.spec.config.stringify_keys).configurations - new_configs = remaining_configs + new_config - - ::ActiveRecord::Base.configurations = new_configs - end - shard.database_server.unguard do - old_actions.each { |action| action.call(*task_args) } - end - nil + classes = classes.call if classes.respond_to?(:call) + Shard.with_each_shard(scope, classes, **options) do + shard = Shard.current + puts "#{shard.id}: #{shard.description}" + + shard.database_server.unguard do + old_actions.each { |action| action.call(*task_args) } end - rescue => e - puts "Exception from #{e.current_shard.id}: #{e.current_shard.description}" if options[:parallel].to_i != 0 - raise + nil end + rescue => e + puts "Exception from #{e.current_shard.id}: #{e.current_shard.description}" if options[:parallel] != 0 + raise + + #::ActiveRecord::Base.configurations = old_configurations end end end end - %w{db:migrate db:migrate:up db:migrate:down db:rollback}.each do |task_name| - shardify_task(task_name, categories: ->{ Shard.categories }) + %w[db:migrate db:migrate:up db:migrate:down db:rollback].each do |task_name| + shardify_task(task_name, classes: -> { Shard.sharded_models }) end - private - def self.shard_scope(scope, raw_shard_ids) raw_shard_ids = raw_shard_ids.split(',') shard_ids = [] negative_shard_ids = [] @@ -124,10 +110,11 @@ when '-primary' negative_shard_ids.concat(Shard.primary.pluck(:id)) when /^(-?)(\d+)?\.\.(\.)?(\d+)?$/ negative, start, open, finish = $1.present?, $2, $3.present?, $4 raise "Invalid shard id or range: #{id}" unless start || finish + range = [] range << "id>=#{start}" if start range << "id<#{'=' unless open}#{finish}" if finish (negative ? negative_ranges : ranges) << "(#{range.join(' AND ')})" when /^-(\d+)$/ @@ -135,16 +122,16 @@ when /^\d+$/ shard_ids << id.to_i when %r{^(-?\d+)/(\d+)$} numerator = $1.to_i denominator = $2.to_i - if numerator == 0 || numerator.abs > denominator - raise "Invalid fractional chunk: #{id}" - end + raise "Invalid fractional chunk: #{id}" if numerator.zero? || numerator.abs > denominator + # one chunk means everything if denominator == 1 next if numerator == 1 + return scope.none end total_shard_count ||= scope.count per_chunk = (total_shard_count / denominator.to_f).ceil @@ -155,78 +142,75 @@ subscope = Shard.select(:id).order(:id) select = [] if index != 1 subscope = subscope.offset(per_chunk * (index - 1)) - select << "MIN(id) AS min_id" + select << 'MIN(id) AS min_id' end if index != denominator subscope = subscope.limit(per_chunk) - select << "MAX(id) AS max_id" + select << 'MAX(id) AS max_id' end - result = Shard.from(subscope).select(select.join(", ")).to_a.first - if index == 1 - range = "id<=#{result['max_id']}" - elsif index == denominator - range = "id>=#{result['min_id']}" - else - range = "(id>=#{result['min_id']} AND id<=#{result['max_id']})" - end + result = Shard.from(subscope).select(select.join(', ')).to_a.first + range = case index + when 1 + "id<=#{result['max_id']}" + when denominator + "id>=#{result['min_id']}" + else + "(id>=#{result['min_id']} AND id<=#{result['max_id']})" + end - (numerator < 0 ? negative_ranges : ranges) << range - else + (numerator.negative? ? negative_ranges : ranges) << range + else raise "Invalid shard id or range: #{id}" end end shard_ids.uniq! negative_shard_ids.uniq! unless shard_ids.empty? shard_ids -= negative_shard_ids - if shard_ids.empty? && ranges.empty? - return scope.none - end + return scope.none if shard_ids.empty? && ranges.empty? + # we already trimmed them all out; no need to make the server do it as well negative_shard_ids = [] if ranges.empty? end conditions = [] positive_queries = [] - unless ranges.empty? - positive_queries << ranges.join(" OR ") - end + positive_queries << ranges.join(' OR ') unless ranges.empty? unless shard_ids.empty? - positive_queries << "id IN (?)" + positive_queries << 'id IN (?)' conditions << shard_ids end - positive_query = positive_queries.join(" OR ") + positive_query = positive_queries.join(' OR ') scope = scope.where(positive_query, *conditions) unless positive_queries.empty? - scope = scope.where("NOT (#{negative_ranges.join(" OR")})") unless negative_ranges.empty? - scope = scope.where("id NOT IN (?)", negative_shard_ids) unless negative_shard_ids.empty? + scope = scope.where("NOT (#{negative_ranges.join(' OR')})") unless negative_ranges.empty? + scope = scope.where('id NOT IN (?)', negative_shard_ids) unless negative_shard_ids.empty? scope end def self.filter_database_servers_chain @filter_database_servers_chain ||= ->(servers) { servers } end end module ActiveRecord module PostgreSQLDatabaseTasks - def structure_dump(filename, extra_flags=nil) + def structure_dump(filename, extra_flags = nil) set_psql_env args = ['-s', '-x', '-O', '-f', filename] args.concat(Array(extra_flags)) if extra_flags - search_path = configuration['schema_search_path'] shard = Shard.current.name serialized_search_path = shard args << "--schema=#{Shellwords.escape(shard)}" args << configuration['database'] run_cmd('pg_dump', args, 'dumping') - File.open(filename, "a") { |f| f << "SET search_path TO #{serialized_search_path};\n\n" } + File.open(filename, 'a') { |f| f << "SET search_path TO #{serialized_search_path};\n\n" } end end end end