require 'active_record/connection_adapters/abstract_adapter' module ActiveRecord class Base # Establishes a connection to the database that's used by all Active Record objects def self.postgresql_connection(config) # :nodoc: require_library_or_gem 'postgres' unless self.class.const_defined?(:PGconn) config = config.symbolize_keys host = config[:host] port = config[:port] || 5432 unless host.nil? username = config[:username].to_s password = config[:password].to_s min_messages = config[:min_messages] if config.has_key?(:database) database = config[:database] else raise ArgumentError, "No database specified. Missing argument: database." end pga = ConnectionAdapters::PostgreSQLAdapter.new( PGconn.connect(host, port, "", "", database, username, password), logger, config ) PGconn.translate_results = false if PGconn.respond_to? :translate_results= pga.schema_search_path = config[:schema_search_path] || config[:schema_order] pga end end module ConnectionAdapters # The PostgreSQL adapter works both with the C-based (http://www.postgresql.jp/interfaces/ruby/) and the Ruby-base # (available both as gem and from http://rubyforge.org/frs/?group_id=234&release_id=1145) drivers. # # Options: # # * :host -- Defaults to localhost # * :port -- Defaults to 5432 # * :username -- Defaults to nothing # * :password -- Defaults to nothing # * :database -- The name of the database. No default, must be provided. # * :schema_search_path -- An optional schema search path for the connection given as a string of comma-separated schema names. This is backward-compatible with the :schema_order option. # * :encoding -- An optional client encoding that is using in a SET client_encoding TO call on connection. # * :min_messages -- An optional client min messages that is using in a SET client_min_messages TO call on connection. class PostgreSQLAdapter < AbstractAdapter def adapter_name 'PostgreSQL' end def initialize(connection, logger, config = {}) super(connection, logger) @config = config configure_connection end # Is this connection alive and ready for queries? def active? if @connection.respond_to?(:status) @connection.status == PGconn::CONNECTION_OK else @connection.query 'SELECT 1' true end # postgres-pr raises a NoMethodError when querying if no conn is available rescue PGError, NoMethodError false end # Close then reopen the connection. def reconnect! # TODO: postgres-pr doesn't have PGconn#reset. if @connection.respond_to?(:reset) @connection.reset configure_connection end end def disconnect! # Both postgres and postgres-pr respond to :close @connection.close rescue nil end def native_database_types { :primary_key => "serial primary key", :string => { :name => "character varying", :limit => 255 }, :text => { :name => "text" }, :integer => { :name => "integer" }, :float => { :name => "float" }, :datetime => { :name => "timestamp" }, :timestamp => { :name => "timestamp" }, :time => { :name => "time" }, :date => { :name => "date" }, :binary => { :name => "bytea" }, :boolean => { :name => "boolean" } } end def supports_migrations? true end def table_alias_length 63 end # QUOTING ================================================== def quote(value, column = nil) if value.kind_of?(String) && column && column.type == :binary "'#{escape_bytea(value)}'" else super end end def quote_column_name(name) %("#{name}") end # DATABASE STATEMENTS ====================================== def select_all(sql, name = nil) #:nodoc: select(sql, name) end def select_one(sql, name = nil) #:nodoc: result = select(sql, name) result.first if result end def insert(sql, name = nil, pk = nil, id_value = nil, sequence_name = nil) #:nodoc: execute(sql, name) table = sql.split(" ", 4)[2] id_value || last_insert_id(table, sequence_name || default_sequence_name(table, pk)) end def query(sql, name = nil) #:nodoc: log(sql, name) { @connection.query(sql) } end def execute(sql, name = nil) #:nodoc: log(sql, name) { @connection.exec(sql) } end def update(sql, name = nil) #:nodoc: execute(sql, name).cmdtuples end alias_method :delete, :update #:nodoc: def begin_db_transaction #:nodoc: execute "BEGIN" end def commit_db_transaction #:nodoc: execute "COMMIT" end def rollback_db_transaction #:nodoc: execute "ROLLBACK" end # SCHEMA STATEMENTS ======================================== # Return the list of all tables in the schema search path. def tables(name = nil) #:nodoc: schemas = schema_search_path.split(/,/).map { |p| quote(p) }.join(',') query(<<-SQL, name).map { |row| row[0] } SELECT tablename FROM pg_tables WHERE schemaname IN (#{schemas}) SQL end def indexes(table_name, name = nil) #:nodoc: result = query(<<-SQL, name) SELECT i.relname, d.indisunique, a.attname FROM pg_class t, pg_class i, pg_index d, pg_attribute a WHERE i.relkind = 'i' AND d.indexrelid = i.oid AND d.indisprimary = 'f' AND t.oid = d.indrelid AND t.relname = '#{table_name}' AND a.attrelid = t.oid AND ( d.indkey[0]=a.attnum OR d.indkey[1]=a.attnum OR d.indkey[2]=a.attnum OR d.indkey[3]=a.attnum OR d.indkey[4]=a.attnum OR d.indkey[5]=a.attnum OR d.indkey[6]=a.attnum OR d.indkey[7]=a.attnum OR d.indkey[8]=a.attnum OR d.indkey[9]=a.attnum ) ORDER BY i.relname SQL current_index = nil indexes = [] result.each do |row| if current_index != row[0] indexes << IndexDefinition.new(table_name, row[0], row[1] == "t", []) current_index = row[0] end indexes.last.columns << row[2] end indexes end def columns(table_name, name = nil) #:nodoc: column_definitions(table_name).collect do |name, type, default, notnull| Column.new(name, default_value(default), translate_field_type(type), notnull == "f") end end # Set the schema search path to a string of comma-separated schema names. # Names beginning with $ are quoted (e.g. $user => '$user') # See http://www.postgresql.org/docs/8.0/interactive/ddl-schemas.html def schema_search_path=(schema_csv) #:nodoc: if schema_csv execute "SET search_path TO #{schema_csv}" @schema_search_path = nil end end def schema_search_path #:nodoc: @schema_search_path ||= query('SHOW search_path')[0][0] end def default_sequence_name(table_name, pk = nil) default_pk, default_seq = pk_and_sequence_for(table_name) default_seq || "#{table_name}_#{pk || default_pk || 'id'}_seq" end # Resets sequence to the max value of the table's pk if present. def reset_pk_sequence!(table, pk = nil, sequence = nil) unless pk and sequence default_pk, default_sequence = pk_and_sequence_for(table) pk ||= default_pk sequence ||= default_sequence end if pk if sequence select_value <<-end_sql, 'Reset sequence' SELECT setval('#{sequence}', (SELECT COALESCE(MAX(#{pk})+(SELECT increment_by FROM #{sequence}), (SELECT min_value FROM #{sequence})) FROM #{table}), false) end_sql else @logger.warn "#{table} has primary key #{pk} with no default sequence" if @logger end end end # Find a table's primary key and sequence. def pk_and_sequence_for(table) # First try looking for a sequence with a dependency on the # given table's primary key. result = execute(<<-end_sql, 'PK and serial sequence')[0] SELECT attr.attname, name.nspname, seq.relname FROM pg_class seq, pg_attribute attr, pg_depend dep, pg_namespace name, pg_constraint cons WHERE seq.oid = dep.objid AND seq.relnamespace = name.oid AND seq.relkind = 'S' AND attr.attrelid = dep.refobjid AND attr.attnum = dep.refobjsubid AND attr.attrelid = cons.conrelid AND attr.attnum = cons.conkey[1] AND cons.contype = 'p' AND dep.refobjid = '#{table}'::regclass end_sql if result.nil? or result.empty? # If that fails, try parsing the primary key's default value. # Support the 7.x and 8.0 nextval('foo'::text) as well as # the 8.1+ nextval('foo'::regclass). # TODO: assumes sequence is in same schema as table. result = execute(<<-end_sql, 'PK and custom sequence')[0] SELECT attr.attname, name.nspname, split_part(def.adsrc, '\\\'', 2) FROM pg_class t JOIN pg_namespace name ON (t.relnamespace = name.oid) JOIN pg_attribute attr ON (t.oid = attrelid) JOIN pg_attrdef def ON (adrelid = attrelid AND adnum = attnum) JOIN pg_constraint cons ON (conrelid = adrelid AND adnum = conkey[1]) WHERE t.oid = '#{table}'::regclass AND cons.contype = 'p' AND def.adsrc ~* 'nextval' end_sql end # check for existence of . in sequence name as in public.foo_sequence. if it does not exist, join the current namespace result.last['.'] ? [result.first, result.last] : [result.first, "#{result[1]}.#{result[2]}"] rescue nil end def rename_table(name, new_name) execute "ALTER TABLE #{name} RENAME TO #{new_name}" end def add_column(table_name, column_name, type, options = {}) execute("ALTER TABLE #{table_name} ADD #{column_name} #{type_to_sql(type, options[:limit])}") execute("ALTER TABLE #{table_name} ALTER #{column_name} SET NOT NULL") if options[:null] == false change_column_default(table_name, column_name, options[:default]) unless options[:default].nil? end def change_column(table_name, column_name, type, options = {}) #:nodoc: begin execute "ALTER TABLE #{table_name} ALTER #{column_name} TYPE #{type_to_sql(type, options[:limit])}" rescue ActiveRecord::StatementInvalid # This is PG7, so we use a more arcane way of doing it. begin_db_transaction add_column(table_name, "#{column_name}_ar_tmp", type, options) execute "UPDATE #{table_name} SET #{column_name}_ar_tmp = CAST(#{column_name} AS #{type_to_sql(type, options[:limit])})" remove_column(table_name, column_name) rename_column(table_name, "#{column_name}_ar_tmp", column_name) commit_db_transaction end change_column_default(table_name, column_name, options[:default]) unless options[:default].nil? end def change_column_default(table_name, column_name, default) #:nodoc: execute "ALTER TABLE #{table_name} ALTER COLUMN #{column_name} SET DEFAULT '#{default}'" end def rename_column(table_name, column_name, new_column_name) #:nodoc: execute "ALTER TABLE #{table_name} RENAME COLUMN #{column_name} TO #{new_column_name}" end def remove_index(table_name, options) #:nodoc: execute "DROP INDEX #{index_name(table_name, options)}" end private BYTEA_COLUMN_TYPE_OID = 17 TIMESTAMPOID = 1114 TIMESTAMPTZOID = 1184 def configure_connection if @config[:encoding] execute("SET client_encoding TO '#{@config[:encoding]}'") end if @config[:min_messages] execute("SET client_min_messages TO '#{@config[:min_messages]}'") end end def last_insert_id(table, sequence_name) Integer(select_value("SELECT currval('#{sequence_name}')")) end def select(sql, name = nil) res = execute(sql, name) results = res.result rows = [] if results.length > 0 fields = res.fields results.each do |row| hashed_row = {} row.each_index do |cel_index| column = row[cel_index] case res.type(cel_index) when BYTEA_COLUMN_TYPE_OID column = unescape_bytea(column) when TIMESTAMPTZOID, TIMESTAMPOID column = cast_to_time(column) end hashed_row[fields[cel_index]] = column end rows << hashed_row end end return rows end def escape_bytea(s) if PGconn.respond_to? :escape_bytea self.class.send(:define_method, :escape_bytea) do |s| PGconn.escape_bytea(s) if s end else self.class.send(:define_method, :escape_bytea) do |s| if s result = '' s.each_byte { |c| result << sprintf('\\\\%03o', c) } result end end end escape_bytea(s) end def unescape_bytea(s) if PGconn.respond_to? :unescape_bytea self.class.send(:define_method, :unescape_bytea) do |s| PGconn.unescape_bytea(s) if s end else self.class.send(:define_method, :unescape_bytea) do |s| if s result = '' i, max = 0, s.size while i < max char = s[i] if char == ?\\ if s[i+1] == ?\\ char = ?\\ i += 1 else char = s[i+1..i+3].oct i += 3 end end result << char i += 1 end result end end end unescape_bytea(s) end # Query a table's column names, default values, and types. # # The underlying query is roughly: # SELECT column.name, column.type, default.value # FROM column LEFT JOIN default # ON column.table_id = default.table_id # AND column.num = default.column_num # WHERE column.table_id = get_table_id('table_name') # AND column.num > 0 # AND NOT column.is_dropped # ORDER BY column.num # # If the table name is not prefixed with a schema, the database will # take the first match from the schema search path. # # Query implementation notes: # - format_type includes the column size constraint, e.g. varchar(50) # - ::regclass is a function that gives the id for a table name def column_definitions(table_name) query <<-end_sql SELECT a.attname, format_type(a.atttypid, a.atttypmod), d.adsrc, a.attnotnull FROM pg_attribute a LEFT JOIN pg_attrdef d ON a.attrelid = d.adrelid AND a.attnum = d.adnum WHERE a.attrelid = '#{table_name}'::regclass AND a.attnum > 0 AND NOT a.attisdropped ORDER BY a.attnum end_sql end # Translate PostgreSQL-specific types into simplified SQL types. # These are special cases; standard types are handled by # ConnectionAdapters::Column#simplified_type. def translate_field_type(field_type) # Match the beginning of field_type since it may have a size constraint on the end. case field_type when /^timestamp/i then 'datetime' when /^real|^money/i then 'float' when /^interval/i then 'string' # geometric types (the line type is currently not implemented in postgresql) when /^(?:point|lseg|box|"?path"?|polygon|circle)/i then 'string' when /^bytea/i then 'binary' else field_type # Pass through standard types. end end def default_value(value) # Boolean types return "t" if value =~ /true/i return "f" if value =~ /false/i # Char/String/Bytea type values return $1 if value =~ /^'(.*)'::(bpchar|text|character varying|bytea)$/ # Numeric values return value if value =~ /^-?[0-9]+(\.[0-9]*)?/ # Fixed dates / times return $1 if value =~ /^'(.+)'::(date|timestamp)/ # Anything else is blank, some user type, or some function # and we can't know the value of that, so return nil. return nil end # Only needed for DateTime instances def cast_to_time(value) return value unless value.class == DateTime v = value time_array = [v.year, v.month, v.day, v.hour, v.min, v.sec] Time.send(Base.default_timezone, *time_array) rescue nil end end end end