lib/lhm/migrator.rb in lhm-2.1.0 vs lib/lhm/migrator.rb in lhm-2.2.0

- old
+ new

@@ -11,17 +11,18 @@ # `run` returns a Migration which can be used for the remaining process. class Migrator include Command include SqlHelper - attr_reader :name, :statements, :connection, :conditions + attr_reader :name, :statements, :connection, :conditions, :renames def initialize(table, connection = nil) @connection = connection @origin = table @name = table.destination_name @statements = [] + @renames = {} end # Alter a table with a custom statement # # @example @@ -50,11 +51,11 @@ # end # # @param [String] name Name of the column to add # @param [String] definition Valid SQL column definition def add_column(name, definition) - ddl("alter table `%s` add column `%s` %s" % [@name, name, definition]) + ddl('alter table `%s` add column `%s` %s' % [@name, name, definition]) end # Change an existing column to a new definition # # @example @@ -64,24 +65,45 @@ # end # # @param [String] name Name of the column to change # @param [String] definition Valid SQL column definition def change_column(name, definition) - ddl("alter table `%s` modify column `%s` %s" % [@name, name, definition]) + ddl('alter table `%s` modify column `%s` %s' % [@name, name, definition]) end + # Rename an existing column. + # + # @example + # + # Lhm.change_table(:users) do |m| + # m.rename_column(:login, :username) + # end + # + # @param [String] old Name of the column to change + # @param [String] nu New name to use for the column + def rename_column(old, nu) + col = @origin.columns[old.to_s] + + definition = col[:type] + definition += ' NOT NULL' unless col[:is_nullable] + definition += " DEFAULT #{@connection.quote_value(col[:column_default])}" if col[:column_default] + + ddl('alter table `%s` change column `%s` `%s` %s' % [@name, old, nu, definition]) + @renames[old.to_s] = nu.to_s + end + # Remove a column from a table # # @example # # Lhm.change_table(:users) do |m| # m.remove_column(:comment) # end # # @param [String] name Name of the column to delete def remove_column(name) - ddl("alter table `%s` drop `%s`" % [@name, name]) + ddl('alter table `%s` drop `%s`' % [@name, name]) end # Add an index to a table # # @example @@ -134,14 +156,14 @@ # for compound indexes. # @param [String, Symbol] index_name # Optional name of the index to be removed def remove_index(columns, index_name = nil) columns = [columns].flatten.map(&:to_sym) - from_origin = @origin.indices.find {|name, cols| cols.map(&:to_sym) == columns} + from_origin = @origin.indices.find { |name, cols| cols.map(&:to_sym) == columns } index_name ||= from_origin[0] unless from_origin.nil? index_name ||= idx_name(@origin.name, columns) - ddl("drop index `%s` on `%s`" % [index_name, @name]) + ddl('drop index `%s` on `%s`' % [index_name, @name]) end # Filter the data that is copied into the new table by the provided SQL. # This SQL will be inserted into the copy directly after the "from" # statement - so be sure to use inner/outer join syntax and not cross joins. @@ -164,11 +186,11 @@ unless @connection.table_exists?(@origin.name) error("could not find origin table #{ @origin.name }") end unless @origin.satisfies_primary_key? - error("origin does not satisfy primary key requirements") + error('origin does not satisfy primary key requirements') end dest = @origin.destination_name if @connection.table_exists?(dest) @@ -177,11 +199,11 @@ end def execute destination_create @connection.sql(@statements) - Migration.new(@origin, destination_read, conditions) + Migration.new(@origin, destination_read, conditions, renames) end def destination_create @connection.destination_create(@origin) end @@ -189,12 +211,19 @@ def destination_read Table.parse(@origin.destination_name, connection) end def index_ddl(cols, unique = nil, index_name = nil) - type = unique ? "unique index" : "index" + assert_valid_idx_name(index_name) + type = unique ? 'unique index' : 'index' index_name ||= idx_name(@origin.name, cols) parts = [type, index_name, @name, idx_spec(cols)] - "create %s `%s` on `%s` (%s)" % parts + 'create %s `%s` on `%s` (%s)' % parts + end + + def assert_valid_idx_name(index_name) + if index_name && !(index_name.is_a?(String) || index_name.is_a?(Symbol)) + raise ArgumentError, 'index_name must be a string or symbol' + end end end end