lib/lhm/migrator.rb in lhm-2.1.0 vs lib/lhm/migrator.rb in lhm-2.2.0
- old
+ new
@@ -11,17 +11,18 @@
# `run` returns a Migration which can be used for the remaining process.
class Migrator
include Command
include SqlHelper
- attr_reader :name, :statements, :connection, :conditions
+ attr_reader :name, :statements, :connection, :conditions, :renames
def initialize(table, connection = nil)
@connection = connection
@origin = table
@name = table.destination_name
@statements = []
+ @renames = {}
end
# Alter a table with a custom statement
#
# @example
@@ -50,11 +51,11 @@
# end
#
# @param [String] name Name of the column to add
# @param [String] definition Valid SQL column definition
def add_column(name, definition)
- ddl("alter table `%s` add column `%s` %s" % [@name, name, definition])
+ ddl('alter table `%s` add column `%s` %s' % [@name, name, definition])
end
# Change an existing column to a new definition
#
# @example
@@ -64,24 +65,45 @@
# end
#
# @param [String] name Name of the column to change
# @param [String] definition Valid SQL column definition
def change_column(name, definition)
- ddl("alter table `%s` modify column `%s` %s" % [@name, name, definition])
+ ddl('alter table `%s` modify column `%s` %s' % [@name, name, definition])
end
+ # Rename an existing column.
+ #
+ # @example
+ #
+ # Lhm.change_table(:users) do |m|
+ # m.rename_column(:login, :username)
+ # end
+ #
+ # @param [String] old Name of the column to change
+ # @param [String] nu New name to use for the column
+ def rename_column(old, nu)
+ col = @origin.columns[old.to_s]
+
+ definition = col[:type]
+ definition += ' NOT NULL' unless col[:is_nullable]
+ definition += " DEFAULT #{@connection.quote_value(col[:column_default])}" if col[:column_default]
+
+ ddl('alter table `%s` change column `%s` `%s` %s' % [@name, old, nu, definition])
+ @renames[old.to_s] = nu.to_s
+ end
+
# Remove a column from a table
#
# @example
#
# Lhm.change_table(:users) do |m|
# m.remove_column(:comment)
# end
#
# @param [String] name Name of the column to delete
def remove_column(name)
- ddl("alter table `%s` drop `%s`" % [@name, name])
+ ddl('alter table `%s` drop `%s`' % [@name, name])
end
# Add an index to a table
#
# @example
@@ -134,14 +156,14 @@
# for compound indexes.
# @param [String, Symbol] index_name
# Optional name of the index to be removed
def remove_index(columns, index_name = nil)
columns = [columns].flatten.map(&:to_sym)
- from_origin = @origin.indices.find {|name, cols| cols.map(&:to_sym) == columns}
+ from_origin = @origin.indices.find { |name, cols| cols.map(&:to_sym) == columns }
index_name ||= from_origin[0] unless from_origin.nil?
index_name ||= idx_name(@origin.name, columns)
- ddl("drop index `%s` on `%s`" % [index_name, @name])
+ ddl('drop index `%s` on `%s`' % [index_name, @name])
end
# Filter the data that is copied into the new table by the provided SQL.
# This SQL will be inserted into the copy directly after the "from"
# statement - so be sure to use inner/outer join syntax and not cross joins.
@@ -164,11 +186,11 @@
unless @connection.table_exists?(@origin.name)
error("could not find origin table #{ @origin.name }")
end
unless @origin.satisfies_primary_key?
- error("origin does not satisfy primary key requirements")
+ error('origin does not satisfy primary key requirements')
end
dest = @origin.destination_name
if @connection.table_exists?(dest)
@@ -177,11 +199,11 @@
end
def execute
destination_create
@connection.sql(@statements)
- Migration.new(@origin, destination_read, conditions)
+ Migration.new(@origin, destination_read, conditions, renames)
end
def destination_create
@connection.destination_create(@origin)
end
@@ -189,12 +211,19 @@
def destination_read
Table.parse(@origin.destination_name, connection)
end
def index_ddl(cols, unique = nil, index_name = nil)
- type = unique ? "unique index" : "index"
+ assert_valid_idx_name(index_name)
+ type = unique ? 'unique index' : 'index'
index_name ||= idx_name(@origin.name, cols)
parts = [type, index_name, @name, idx_spec(cols)]
- "create %s `%s` on `%s` (%s)" % parts
+ 'create %s `%s` on `%s` (%s)' % parts
+ end
+
+ def assert_valid_idx_name(index_name)
+ if index_name && !(index_name.is_a?(String) || index_name.is_a?(Symbol))
+ raise ArgumentError, 'index_name must be a string or symbol'
+ end
end
end
end