lib/mini_sql/connection.rb in mini_sql-0.1.2 vs lib/mini_sql/connection.rb in mini_sql-0.1.3

- old
+ new

@@ -1,7 +1,10 @@ +# frozen_string_literal: true + module MiniSql class Connection + attr_reader :raw_connection def self.default_deserializer_cache @deserializer_cache ||= DeserializerCache.new end @@ -12,79 +15,81 @@ map.add_coder(MiniSql::Coders::NumericCoder.new(name: "numeric", oid: 1700, format: 0)) map.add_coder(MiniSql::Coders::IPAddrCoder.new(name: "inet", oid: 869, format: 0)) end end - def initialize(conn, deserializer_cache = nil, type_map = nil) + # Initialize a new MiniSql::Connection object + # + # @param raw_connection [PG::Connection] an active connection to PG + # @param deserializer_cache [MiniSql::DeserializerCache] a cache of field names to deserializer, can be nil + # @param type_map [PG::TypeMap] a type mapper for all results returned, can be nil + def initialize(raw_connection, deserializer_cache: nil, type_map: nil, param_encoder: nil) # TODO adapter to support other databases - @conn = conn + @raw_connection = raw_connection @deserializer_cache = deserializer_cache || Connection.default_deserializer_cache - @type_map = type_map || Connection.type_map(conn) + @type_map = type_map || Connection.type_map(raw_connection) + @param_encoder = param_encoder || InlineParamEncoder.new(self) end - def query_single(sql, params=nil) + # Returns a flat array containing all results. + # Note, if selecting multiple columns array will be flattened + # + # @param sql [String] the query to run + # @param params [Array or Hash], params to apply to query + # @return [Object] a flat array containing all results + def query_single(sql, *params) result = run(sql, params) result.type_map = @type_map - result.column_values(0) + if result.nfields == 1 + result.column_values(0) + else + array = [] + f = 0 + row = 0 + while row < result.ntuples + while f < result.nfields + array << result.getvalue(row, f) + f += 1 + end + f = 0 + row += 1 + end + array + end ensure result.clear if result end - def query(sql, params=nil) + def query(sql, *params) result = run(sql, params) result.type_map = @type_map @deserializer_cache.materialize(result) ensure result.clear if result end - def exec(sql, params=nil) + def exec(sql, *params) result = run(sql, params) result.cmd_tuples ensure result.clear if result end def build(sql) Builder.new(self, sql) end + def escape_string(str) + raw_connection.escape_string(str) + end + private def run(sql, params) - if params - @conn.async_exec(*process_params(sql, params)) - else - @conn.async_exec(sql) + if params && params.length > 0 + sql = @param_encoder.encode(sql, *params) end - end - - def process_params(sql, params) - sql = sql.dup - param_array = nil - - if Hash === params - param_array = [] - params.each do |k, v| - sql.gsub!(":#{k.to_s}", "$#{param_array.length + 1}") - param_array << v - end - elsif Array === params - i = 0 - sql.gsub!("?") do - i += 1 - case params[i-1] - when Integer then "$#{i}::bigint" - when Float then "$#{i}::float8" - when String then "$#{i}::text" - else "$#{i}::unknown" - end - end - param_array = params - end - - [sql, param_array] - + raw_connection.async_exec(sql) end end end