require "pg"
require 'ostruct'

require "pg_conn/version"
require "pg_conn/role_methods"
require "pg_conn/schema_methods"
require "pg_conn/rdbms_methods"

module PgConn
  class Error < StandardError; end

  # Can be raised in #transaction blocks to rollback changes
  class Rollback < Error; end

  # Return a PgConn::Connection object. All arguments (and the optional
  # block) are forwarded to PgConn::Connection.new
  def self.new(*args, &block)
    Connection.new(*args, &block)
  end

  # Make the PgConn module pretend it has PgConn instances
  def self.===(element)
    element.is_a?(PgConn::Connection) or super
  end

  # Returns a PgConn::Connection object (aka. a PgConn object). Its arguments
  # can be an existing connection that will just be returned or a set of
  # PgConn::Connection#initialize arguments that will be used to create a new
  # PgConn::Connection object
  def self.ensure(*args)
    if args.size == 1 && args.first.is_a?(PgConn::Connection)
      args.first
    else
      PgConn::Connection.new(*args)
    end
  end

  # All results from the database are converted into native Ruby types
  class Connection
    # Make PgConn::Connection pretend to be an instance of the PgConn module
    def is_a?(klass)
      klass == PgConn or super
    end

    # The PG::Connection object. Nil if the connection was initialized with
    # an array argument
    attr_reader :pg_connection

    # The class of column names (Symbol or String). Default is Symbol
    attr_reader :field_name_class

    # Name of user
    def user() @pg_connection.user end
    alias_method :username, :user # Obsolete

    # Name of database
    def name() @pg_connection.db end
    alias_method :database, :name # Obsolete

    # Database manipulation methods: #exist?, #create, #drop, #list
    attr_reader :rdbms

    # Role manipulation methods: #exist?, #create, #drop, #list
    attr_reader :role

    # Schema manipulation methods: #exist?, #create, #drop, #list, and
    # #exist?/#list for relations/tables/views/columns
    attr_reader :schema

    # The transaction timestamp of the most recent SQL statement executed by
    # #exec or #transaction block
    attr_reader :timestamp

    # PG::Error object if the last statement failed; otherwise nil
    attr_reader :err

    # Last error message. The error message is the first line of the PG error
    # message that may contain additional info. It doesn't contain a
    # terminating newline. Nil if there is no error or the message doesn't
    # match the expected format
    def errmsg = err&.message =~ /^ERROR:\s*(.*?)\n/m && $1.capitalize

    # The one-based line number of the last PG::Error or nil if absent in the
    # Postgres error message
    def errline = err&.message =~ /\n\s*LINE\s+(\d+):/m && $1.to_i

    # The one-based character number of the error in the last PG::Error or nil
    # if absent in the Postgres error message. It is computed from the
    # position of the '^' marker relative to the 'LINE n: ' prefix
    def errchar = err&.message =~ /\n(\s*LINE\s+\d+: ).*?\n(\s+)\^\n/m && ($2.size - $1.size + 1)

    # :call-seq:
    #   initialize(dbname = nil, user = nil, field_name_class: Symbol)
    #   initialize(connection_hash, field_name_class: Symbol)
    #   initialize(connection_string, field_name_class: Symbol)
    #   initialize(host, port, dbname, user, password, field_name_class: Symbol)
    #   initialize(array, field_name_class: Symbol)
    #   initialize(pg_connection_object)
    #
    # Initialize a connection object and connect to the database
    #
    # The possible keys of the connection hash are :host, :port, :dbname, :user,
    # and :password. The connection string can either be a space-separated list
    # of key=value pairs with the same keys as the hash, or a URI with the
    # format 'postgres[ql]://[user[:password]@][host][:port][/name]
    #
    # If given an array argument, PgConn will not connect to the database and
    # instead write its commands to the array. In this case, methods extracting
    # values from the database (eg. #value) will return nil or raise an
    # exception
    #
    # The last variant is used to establish a PgConn from an existing
    # connection. It doesn't change the connection settings and is not
    # recommended except in cases where you want to piggyback on an existing
    # connection (eg. a Rails connection)
    #
    # The :field_name_class option controls the Ruby type of column names. It can be
    # Symbol (the default) or String. The :timestamp option is used
    # internally to set the timestamp for transactions
    #
    # Note that the connection hash and the connection string may support more
    # parameters than documented here. Consult
    # https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING
    # for the full list
    #
    # Raises PgConn::Error on an illegal argument type or argument count
    #
    # TODO: Change to 'initialize(*args, **opts)'
    def initialize(*args)
      if args.last.is_a?(Hash)
        # Work on a copy so the caller's hash is left untouched
        opts = args.pop.dup
        @field_name_class = opts.delete(:field_name_class) || Symbol
        @timestamp = opts.delete(:timestamp)
        # What remains (if anything) is a connection hash
        args.push(opts) unless opts.empty?
      else
        @field_name_class = Symbol
        @timestamp = nil
      end

      # else # We assume that the current user is a postgres superuser
      #   @db = PgConn.new("template0")

      using_existing_connection = false
      @pg_connection =
          case args.size
            when 0
              make_connection
            when 1
              case arg = args.first
                when PG::Connection
                  using_existing_connection = true
                  arg
                when String
                  # Connection strings ('key=value ...') and URIs are passed
                  # on verbatim, everything else is taken as a database name
                  if arg =~ /=/ || arg =~ /\//
                    make_connection arg
                  else
                    make_connection dbname: arg
                  end
                when Hash
                  make_connection(**arg)
                when Array
                  # Don't connect, record commands in the array instead
                  @pg_commands = arg
                  nil
              else
                raise Error, "Illegal argument type: #{arg.class}"
              end
            when 2
              make_connection dbname: args.first, user: args.last
            when 5
              # PG::Connection.new(host, port, options, tty, dbname, user, password)
              make_connection args[0], args[1], nil, nil, args[2], args[3], args[4]
          else
            raise Error, "Illegal number of parameters: #{args.size}"
          end

      if @pg_connection && !using_existing_connection
        # Auto-convert to ruby types
        type_map = PG::BasicTypeMapForResults.new(@pg_connection)

        # Use String as default type. Kills 'Warning: no type cast defined for
        # type "uuid" with oid 2950..' warnings
        type_map.default_type_map = PG::TypeMapAllStrings.new

        # Timestamp decoder
        type_map.add_coder PG::TextDecoder::Timestamp.new( # Timestamp without time zone
            oid: 1114,
            flags: PG::Coder::TIMESTAMP_DB_UTC | PG::Coder::TIMESTAMP_APP_UTC)

        # Decode anonymous records but note that this is only useful to convert the
        # outermost structure into an array, the elements are not decoded and are
        # returned as strings. It is best to avoid anonymous records if possible
        type_map.add_coder PG::TextDecoder::Record.new(
            oid: 2249
        )

        @pg_connection.type_map_for_results = type_map
        @pg_connection.field_name_type = @field_name_class.to_s.downcase.to_sym # Use symbol field names
        @pg_connection.exec "set client_min_messages to warning;" # Silence warnings
      end

      @schema = SchemaMethods.new(self)
      @role = RoleMethods.new(self)
      @rdbms = RdbmsMethods.new(self)
      @savepoints = nil # Stack of savepoint names. Nil if no transaction in progress
    end

    # Close the database connection. It is a NOP if there is no connection or
    # it has already been closed
    def terminate()
      @pg_connection.close if @pg_connection && !@pg_connection.finished?
    end

    # Create a new connection. If given a block, the connection is yielded to
    # the block and terminated afterwards (also if the block raises) and the
    # value of the block is returned
    def self.new(*args, **opts, &block)
      if block_given?
        begin
          object = allocate
          object.send(:initialize, *args, **opts)
          yield(object)
        ensure
          object.terminate if object&.pg_connection
        end
      else
        super(*args, **opts)
      end
    end

    # :call-seq:
    #   exist?(query)
    #   exist?(table, id)
    #   exist?(table, where_clause)
    #
    # Return true iff the query returns exactly one record. Raises
    # ArgumentError if the second argument is neither an Integer, a String,
    # nor nil
    def exist?(*args)
      arg1, arg2 = *args
      query =
          case arg2
            when Integer; "select from #{arg1} where id = #{arg2}"
            when String; "select from #{arg1} where #{arg2}"
            when NilClass; arg1
          else
            raise ArgumentError, "Illegal argument: #{arg2.inspect}"
          end
      count(query) == 1
    end

    # :call-seq:
    #   empty?(query)
    #   empty?(table, where_clause = nil)
    #
    # Return true if the table or the result of the query is empty. The
    # argument is considered a query if it contains whitespace
    def empty?(arg, where_clause = nil)
      inner_query =
          if arg =~ /\s/
            "#{arg} limit 1"
          elsif where_clause
            "select 1 from #{arg} where #{where_clause} limit 1"
          else
            "select 1 from #{arg} limit 1"
          end
      value("select count(*) from (#{inner_query}) as inner_query") == 0
    end

    # :call-seq:
    #   count(query)
    #   count(table_name, where_clause = nil)
    #
    # The number of records in the table or in the query. The argument is
    # considered a query if it contains whitespace
    def count(arg, where_clause = nil)
      if arg =~ /\s/
        value("select count(*) from (#{arg}) as inner_query")
      else
        value("select count(*) from #{arg}" + (where_clause ? " where #{where_clause}" : ""))
      end
    end

    # Return a single value. It is an error if the query doesn't return a
    # single record with a single column. This can be used in 'insert ...
    # returning ...' statements
    def value(query)
      r = pg_exec(query)
      check_1c(r)
      check_1r(r)
      r.values[0][0]
    end

    # Like #value but returns nil if no record was found. It is still an error
    # if the query returns more than one column
    def value?(query)
      r = pg_exec(query)
      check_1c(r)
      return nil if r.ntuples == 0
      check_1r(r)
      r.values[0][0]
    end

    # Return an array of values. It is an error if the query returns records
    # with more than one column. This can be used in 'insert ... returning
    # ...' statements
    def values(query)
      r = pg_exec(query)
      check_1c(r)
      r.column_values(0)
    end

    # Return an array of column values. It is an error if the query doesn't
    # return exactly one record. This can be used in 'insert ... returning
    # ...' statements
    def tuple(query)
      r = pg_exec(query)
      check_1r(r)
      r.values[0]
    end

    # Like #tuple but returns nil if no record was found
    def tuple?(query)
      r = pg_exec(query)
      return nil if r.ntuples == 0
      check_1r(r)
      r.values[0]
    end

    # Return an array of tuples. This can be used in 'insert ... returning
    # ...' statements
    def tuples(query)
      pg_exec(query).values
    end

    # Return a single-element hash from column name to value. It is an error
    # if the query returns more than one record or more than one column. Note
    # that you will probably prefer to use #value instead when you expect only
    # a single field
    def field(query)
      r = pg_exec(query)
      check_1c(r)
      check_1r(r)
      r.tuple(0).to_h
    end

    # Like #field but returns nil if no record was found
    def field?(query)
      r = pg_exec(query)
      check_1c(r)
      return nil if r.ntuples == 0
      check_1r(r)
      r.tuple(0).to_h
    end

    # Return an array of single-element hashes from column name to value. It
    # is an error if the query returns records with more than one column. Note
    # that you will probably prefer to use #values instead when you expect only
    # single-column records
    def fields(query)
      r = pg_exec(query)
      check_1c(r)
      r.each.to_a.map(&:to_h)
    end

    # Return a hash from column name (a Symbol) to field value. It is an error if
    # the query returns more than one record. It blows up if a column name is
    # not a valid ruby symbol
    def record(query)
      r = pg_exec(query)
      check_1r(r)
      r.tuple(0).to_h
    end

    # Like #record but returns nil if no record was found
    def record?(query)
      r = pg_exec(query)
      return nil if r.ntuples == 0
      check_1r(r)
      r.tuple(0).to_h
    end

    # Return an array of hashes from column name to field value
    def records(query)
      r = pg_exec(query)
      r.each.to_a.map(&:to_h)
    end

    # Return a record as a OpenStruct object. It is an error if the query
    # returns more than one record. It blows up if a column name is not a valid
    # ruby symbol
    def struct(query)
      OpenStruct.new(**record(query))
    end

    # Like #struct but returns nil if no record was found
    def struct?(query)
      args = record?(query)
      return nil if args.nil?
      OpenStruct.new(**args)
    end

    # Return an array of OpenStruct objects
    def structs(query)
      records(query).map { |record| OpenStruct.new(**record) }
    end

    # Return a hash from the record id column to record (hash from column name
    # to field value). If the :key_column option is defined it will be used
    # instead of id as the key. It is an error if the id field value is not
    # unique
    def table(query, key_column: :id)
      [String, Symbol].include?(key_column.class) or raise "Illegal key_column"
      # Records are keyed by the configured field name class
      key_column = (field_name_class == Symbol ? key_column.to_sym : key_column.to_s)
      r = pg_exec(query)
      begin
        r.fnumber(key_column.to_s) # Check that key column exists
      rescue ArgumentError
        raise Error, "Can't find column #{key_column}"
      end
      h = {}
      r.each { |record|
        key = record[key_column]
        !h.key?(key) or raise Error, "Duplicate key: #{key.inspect}"
        h[key] = record.to_h
      }
      h
    end

    # Return a hash from the record id column to an OpenStruct representation
    # of the record. If the :key_column option is defined it will be used
    # instead of id as the key. It is an error if the id field value is not
    # unique
    def set(query, key_column: :id)
      key_column = key_column.to_sym
      r = pg_exec(query)
      begin
        r.fnumber(key_column.to_s) # Check that key column exists
      rescue ArgumentError
        raise Error, "Can't find column #{key_column}"
      end
      h = {}
      r.each { |record|
        struct = OpenStruct.new(**record)
        key = struct.send(key_column)
        !h.key?(key) or raise Error, "Duplicate key: #{key.inspect}"
        h[key] = struct
      }
      h
    end

    # Returns a hash from the first field to a tuple of the remaining fields.
    # If there is only one remaining field then that value is used instead of a
    # tuple of that value. The optional +key+ argument sets the mapping field
    def map(query, key = nil)
      r = pg_exec(query)
      begin
        key = (key || r.fname(0)).to_s
        key_index = r.fnumber(key)
        one = (r.nfields == 2)
      rescue ArgumentError
        raise Error, "Can't find column #{key}"
      end
      h = {}
      r.each_row { |row|
        key_value = row.delete_at(key_index)
        !h.key?(key_value) or raise Error, "Duplicate key: #{key_value}"
        h[key_value] = (one ? row.first : row)
      }
      h
    end

    # Return the value of calling the given function (which can be a String or
    # a Symbol and can contain the schema of the function). It dynamically
    # detects the structure of the result and return a value or an array of
    # values if the result contained only one column (like #value or #values),
    # a tuple if the record has multiple columns (like #tuple), and an array of
    # of tuples if the result contained more than one record with multiple
    # columns (like #tuples)
    #
    # If :proc is true, +name+ is called as a procedure and nil is returned
    #
    # String arguments (and array elements) are emitted as SQL string literals
    # with embedded single quotes doubled. Raises ArgumentError on
    # unsupported argument types and PgConn::Error if a function returns no
    # records
    #
    def call(name, *args, proc: false) # :proc may interfere with hashes
      args_sql = args.map { |arg| # TODO: Use pg's encoder
        case arg
          when NilClass; "null"
          when String; "'#{arg.gsub("'", "''")}'"
          when Integer; arg
          when TrueClass, FalseClass; arg
          when Array; "Array[#{PgConn.sql_values(arg)}]" # Quick and dirty # FIXME
          when Hash; raise NotImplementedError
        else
          raise ArgumentError, "Unrecognized value: #{arg.inspect}"
        end
      }.join(", ")
      if proc
        pg_exec "call #{name}(#{args_sql})"
        return nil
      else
        r = pg_exec "select * from #{name}(#{args_sql})"
        if r.ntuples == 0
          raise Error, "No records returned"
        elsif r.ntuples == 1
          if r.nfields == 1
            r.values[0][0]
          else
            r.values[0]
          end
        elsif r.nfields == 1
          r.column_values(0)
        else
          r.values
        end
      end
    end

    # Execute SQL statement(s) in a transaction and return the number of
    # affected records (if any). Also sets #timestamp unless a transaction is
    # already in progress. The +sql+ argument can be a command (String) or an
    # arbitrarily nested array of commands. Note that you can't have commands
    # that span multiple lines. The empty array is a NOP but the empty string
    # is not.
    #
    # #exec pass Postgres exceptions to the caller unless :fail is false. If
    # fail is false #exec instead return nil but note that postgres doesn't
    # ignore it so that if you're inside a transaction, the transaction will be
    # in an error state and if you're also using subtransactions the whole
    # transaction stack has collapsed
    #
    # TODO: Make sure the transaction stack is emptied on postgres errors
    def exec(sql, commit: true, fail: true, silent: false)
      transaction(commit: commit) { execute(sql, fail: fail, silent: silent) }
    end

    # Execute SQL statement(s) without a transaction block and return the
    # number of affected records (if any). This used to call procedures that
    # may manipulate transactions. The +sql+ argument can be a SQL command or
    # an arbitrarily nested array of commands. The empty array is a NOP but the
    # empty string is not. #execute pass Postgres exceptions to the caller
    # unless :fail is false in which case it returns nil
    #
    # TODO: Handle postgres exceptions wrt transaction state and stack
    def execute(sql, fail: true, silent: false)
      # pg_exec returns nil when there is no connection or nothing to execute
      pg_exec(sql, fail: fail, silent: silent)&.cmd_tuples
    end

    # Switch user to the given user and execute the block before switching
    # back to the original user. The original user is restored even if the
    # block raises. Returns the value of the block
    #
    # FIXME: Wrapping this in a transaction block makes postspec fail for some reason
    def su(username, &block)
      raise Error, "Missing block in call to PgConn::Connection#su" if !block_given?
      realuser = self.value "select current_user"
      execute "set session authorization #{username}"
      begin
        yield
      ensure
        execute "set session authorization #{realuser}"
      end
    end

    # Commit the innermost transaction or savepoint. Sends a plain 'commit' if
    # no transaction was started through this object
    #
    # TODO: Move to TransactionMethods
    def commit()
      if transaction?
        pop_transaction
      else
        pg_exec("commit")
      end
    end

    # Roll back the enclosing #transaction block by raising PgConn::Rollback
    def rollback() raise Rollback end

    # True if a transaction is in progress
    def transaction?() !@savepoints.nil? end

    # Returns number of transaction or savepoint levels
    def transactions() @savepoints ? 1 + @savepoints.size : 0 end

    # Start a transaction, or create a savepoint if a transaction is already
    # in progress. #timestamp is set when the outermost transaction starts
    def push_transaction
      if transaction?
        savepoint = "savepoint_#{@savepoints.size + 1}"
        @savepoints.push savepoint
        pg_exec("savepoint #{savepoint}")
      else
        @savepoints = []
        pg_exec("begin")
        @timestamp = pg_exec("select current_timestamp").values[0][0] if @pg_connection
      end
    end

    # End the innermost savepoint or, if there are no savepoints, the
    # transaction itself. Changes are rolled back if :commit is false. Raises
    # PgConn::Error if no transaction is in progress
    def pop_transaction(commit: true)
      transaction? or raise Error, "No transaction in progress"
      if savepoint = @savepoints.pop
        pg_exec("rollback to savepoint #{savepoint}") if !commit
        pg_exec("release savepoint #{savepoint}")
      else
        @savepoints = nil
        pg_exec(commit ? "commit" : "rollback")
      end
    end

    # Does a rollback and empties the stack. This should be called in response
    # to PG::Error exceptions because then the whole transaction stack is
    # invalid
    def cancel_transaction
      pg_exec("rollback")
      @savepoints = nil
    end

    # Execute block within a transaction and return the result of the block.
    # The transaction can be rolled back by raising a PgConn::Rollback
    # exception in which case #transaction returns nil. Note that the
    # transaction timestamp is set to the start of the first transaction even
    # if transactions are nested
    def transaction(commit: true, &block)
      result = nil
      begin
        push_transaction
        result = yield
      rescue PgConn::Rollback
        pop_transaction(commit: false)
        return nil
        # FIXME: Rescue other postgres errors and wipe-out stack
      end
      pop_transaction(commit: commit)
      result
    end

  private
    # Wrapper around PG::Connection.new that switches to the postgres user
    # before connecting if the current user is the root user. Raises
    # PgConn::Error if the 'postgres' user or group doesn't exist
    #
    def make_connection(*args, **opts)
      if Process.euid == 0
        begin
          postgres_uid = Process::UID.from_name "postgres"
        rescue ArgumentError
          raise Error, "Can't find 'postgres' user"
        end
        begin
          postgres_gid = Process::GID.from_name "postgres"
        rescue ArgumentError
          raise Error, "Can't find 'postgres' group"
        end

        begin
          # The group has to be changed first because we're not allowed to
          # change it after root privileges have been dropped
          Process::Sys.setegid postgres_gid
          Process::Sys.seteuid postgres_uid
          PG::Connection.new(*args, **opts)
        ensure
          # Regain root before restoring the group
          Process::Sys.seteuid 0
          Process::Sys.setegid 0
        end
      else
        PG::Connection.new(*args, **opts)
      end
    end

    # :call-seq:
    #   pg_exec(string)
    #   pg_exec(array)
    #
    # Execute statement(s) on the server. If the argument is an array of
    # commands, the commands are concatenated with ';' before being sent to the
    # server. #pg_exec returns a PG::Result object or nil if +arg+ was empty.
    # #pg_exec pass Postgres exceptions to the caller unless :fail is false.
    # The exception is also recorded in #err
    #
    # If there is no connection, the statements are appended to the command
    # array given to #initialize and nil is returned
    #
    # FIXME: Error message prints the statements but not which one of them
    # that failed
    #
    # TODO: Consider executing statements one-by-one so we're able to
    # pin-point Postgres errors without a line number. This may be expensive,
    # though
    #
    # TODO: Fix silent by not handling exceptions
    def pg_exec(arg, fail: true, silent: false)
      if @pg_connection
        @err = nil
        begin
          if arg.is_a?(String)
            return nil if arg == ""
            @pg_connection.exec(arg)
          else
            stmts = arg.flatten.compact
            return nil if stmts.empty?
            @pg_connection.exec(stmts.join(";\n"))
          end
        rescue PG::Error => ex
          @err = ex
          return nil if !fail
          if !silent # FIXME Why do we handle this?
            $stderr.puts arg
            $stderr.puts
            $stderr.puts ex.message
            $stderr.flush
          end
          raise
        end
      else # @pg_commands is defined
        if arg.is_a?(String)
          @pg_commands << arg if arg != ""
        else
          # Same normalization as when the statements are sent to the server
          @pg_commands.concat(arg.flatten.compact)
        end
        nil
      end
    end

    # Check that the result has exactly one column. Raises PgConn::Error
    # otherwise
    def check_1c(r)
      case r.nfields
        when 0; raise Error, "No columns returned"
        when 1;
      else
        raise Error, "More than one column returned"
      end
    end

    # Check that the result has exactly one record. Raises PgConn::Error
    # otherwise
    def check_1r(r)
      if r.ntuples == 0
        raise Error, "No records returned"
      elsif r.ntuples > 1
        raise Error, "More than one record returned"
      end
    end
  end

  # Return a comma-separated list of SQL string literals for the given
  # values. Embedded single quotes are doubled. Note that an empty list
  # yields a single empty string literal
  def self.sql_values(values)
    "'" + values.to_a.flatten.map { |value| value.to_s.gsub("'", "''") }.join("', '") + "'"
  end

  # Return a comma-separated list of quoted SQL identifiers for the given
  # values. Embedded double quotes are doubled
  def self.sql_idents(values)
    '"' + values.to_a.flatten.map { |value| value.to_s.gsub('"', '""') }.join('", "') + '"'
  end
end