lib/mini_sql/connection.rb in mini_sql-0.1.2 vs lib/mini_sql/connection.rb in mini_sql-0.1.3
- old
+ new
@@ -1,7 +1,10 @@
+# frozen_string_literal: true
+
module MiniSql
class Connection
+ attr_reader :raw_connection
def self.default_deserializer_cache
@deserializer_cache ||= DeserializerCache.new
end
@@ -12,79 +15,81 @@
map.add_coder(MiniSql::Coders::NumericCoder.new(name: "numeric", oid: 1700, format: 0))
map.add_coder(MiniSql::Coders::IPAddrCoder.new(name: "inet", oid: 869, format: 0))
end
end
- def initialize(conn, deserializer_cache = nil, type_map = nil)
+ # Initialize a new MiniSql::Connection object
+ #
+ # @param raw_connection [PG::Connection] an active connection to PG
+ # @param deserializer_cache [MiniSql::DeserializerCache] a cache of field names to deserializer, can be nil
+ # @param type_map [PG::TypeMap] a type mapper for all results returned, can be nil
+ def initialize(raw_connection, deserializer_cache: nil, type_map: nil, param_encoder: nil)
# TODO adapter to support other databases
- @conn = conn
+ @raw_connection = raw_connection
@deserializer_cache = deserializer_cache || Connection.default_deserializer_cache
- @type_map = type_map || Connection.type_map(conn)
+ @type_map = type_map || Connection.type_map(raw_connection)
+ @param_encoder = param_encoder || InlineParamEncoder.new(self)
end
- def query_single(sql, params=nil)
+ # Returns a flat array containing all results.
+ # Note, if selecting multiple columns array will be flattened
+ #
+ # @param sql [String] the query to run
+ # @param params [Array or Hash], params to apply to query
+ # @return [Object] a flat array containing all results
+ def query_single(sql, *params)
result = run(sql, params)
result.type_map = @type_map
- result.column_values(0)
+ if result.nfields == 1
+ result.column_values(0)
+ else
+ array = []
+ f = 0
+ row = 0
+ while row < result.ntuples
+ while f < result.nfields
+ array << result.getvalue(row, f)
+ f += 1
+ end
+ f = 0
+ row += 1
+ end
+ array
+ end
ensure
result.clear if result
end
- def query(sql, params=nil)
+ def query(sql, *params)
result = run(sql, params)
result.type_map = @type_map
@deserializer_cache.materialize(result)
ensure
result.clear if result
end
- def exec(sql, params=nil)
+ def exec(sql, *params)
result = run(sql, params)
result.cmd_tuples
ensure
result.clear if result
end
def build(sql)
Builder.new(self, sql)
end
+ def escape_string(str)
+ raw_connection.escape_string(str)
+ end
+
private
def run(sql, params)
- if params
- @conn.async_exec(*process_params(sql, params))
- else
- @conn.async_exec(sql)
+ if params && params.length > 0
+ sql = @param_encoder.encode(sql, *params)
end
- end
-
- def process_params(sql, params)
- sql = sql.dup
- param_array = nil
-
- if Hash === params
- param_array = []
- params.each do |k, v|
- sql.gsub!(":#{k.to_s}", "$#{param_array.length + 1}")
- param_array << v
- end
- elsif Array === params
- i = 0
- sql.gsub!("?") do
- i += 1
- case params[i-1]
- when Integer then "$#{i}::bigint"
- when Float then "$#{i}::float8"
- when String then "$#{i}::text"
- else "$#{i}::unknown"
- end
- end
- param_array = params
- end
-
- [sql, param_array]
-
+ raw_connection.async_exec(sql)
end
end
end