lib/redis/connection/memory.rb in fakeredis-0.5.0 vs lib/redis/connection/memory.rb in fakeredis-0.6.0
- old
+ new
@@ -6,18 +6,20 @@
require "fakeredis/sort_method"
require "fakeredis/sorted_set_argument_handler"
require "fakeredis/sorted_set_store"
require "fakeredis/transaction_commands"
require "fakeredis/zset"
+require "fakeredis/bitop_command"
class Redis
module Connection
class Memory
include Redis::Connection::CommandHelper
include FakeRedis
include SortMethod
include TransactionCommands
+ include BitopCommand
include CommandExecutor
attr_accessor :options
# Tracks all databases for all instances across the current process.
@@ -33,10 +35,18 @@
# Used for resetting everything in specs
def self.reset_all_databases
@databases = nil
end
+ def self.channels
+ @channels ||= Hash.new {|h,k| h[k] = [] }
+ end
+
+ def self.reset_all_channels
+ @channels = nil
+ end
+
def self.connect(options = {})
new(options)
end
def initialize(options = {})
@@ -84,18 +94,10 @@
def read
replies.shift
end
- # NOT IMPLEMENTED:
- # * blpop
- # * brpop
- # * brpoplpush
- # * subscribe
- # * psubscribe
- # * publish
-
def flushdb
databases.delete_at(database_id)
"OK"
end
@@ -144,10 +146,49 @@
return false if destination.has_key?(key)
destination[key] = data.delete(key)
true
end
+ def dump(key)
+ return nil unless exists(key)
+
+ value = data[key]
+
+ Marshal.dump(
+ value: value,
+ version: FakeRedis::VERSION, # Redis includes the version, so we might as well
+ )
+ end
+
+ def restore(key, ttl, serialized_value)
+ raise Redis::CommandError, "ERR Target key name is busy." if exists(key)
+
+ raise Redis::CommandError, "ERR DUMP payload version or checksum are wrong" if serialized_value.nil?
+
+ parsed_value = begin
+ Marshal.load(serialized_value)
+ rescue TypeError
+ raise Redis::CommandError, "ERR DUMP payload version or checksum are wrong"
+ end
+
+ if parsed_value[:version] != FakeRedis::VERSION
+ raise Redis::CommandError, "ERR DUMP payload version or checksum are wrong"
+ end
+
+ # We could figure out what type the key was and set it with the public API here,
+ # or we could just assign the value. If we presume the serialized_value is only ever
+ # a return value from `dump` then we've only been given something that was in
+ # the internal data structure anyway.
+ data[key] = parsed_value[:value]
+
+ # Set a TTL if one has been passed
+ ttl = ttl.to_i # Makes nil into 0
+ expire(key, ttl / 1000) unless ttl.zero?
+
+ "OK"
+ end
+
def get(key)
data_type_check(key, String)
data[key]
end
@@ -200,23 +241,72 @@
data_type_check(key, Hash)
data[key] && data[key][field.to_s]
end
def hdel(key, field)
- field = field.to_s
data_type_check(key, Hash)
- deleted = data[key] && data[key].delete(field)
+ return 0 unless data[key]
+
+ if field.is_a?(Array)
+ old_keys_count = data[key].size
+ fields = field.map(&:to_s)
+
+ data[key].delete_if { |k, v| fields.include? k }
+ deleted = old_keys_count - data[key].size
+ else
+ field = field.to_s
+ deleted = data[key].delete(field) ? 1 : 0
+ end
+
remove_key_for_empty_collection(key)
- deleted ? 1 : 0
+ deleted
end
def hkeys(key)
data_type_check(key, Hash)
return [] if data[key].nil?
data[key].keys
end
+ def hscan(key, start_cursor, *args)
+ data_type_check(key, Hash)
+ return ["0", []] unless data[key]
+
+ match = "*"
+ count = 10
+
+ if args.size.odd?
+ raise_argument_error('hscan')
+ end
+
+ if idx = args.index("MATCH")
+ match = args[idx + 1]
+ end
+
+ if idx = args.index("COUNT")
+ count = args[idx + 1]
+ end
+
+ start_cursor = start_cursor.to_i
+
+ cursor = start_cursor
+ next_keys = []
+
+ if start_cursor + count >= data[key].length
+ next_keys = (data[key].to_a)[start_cursor..-1]
+ cursor = 0
+ else
+ cursor = start_cursor + count
+ next_keys = (data[key].to_a)[start_cursor..cursor-1]
+ end
+
+ filtered_next_keys = next_keys.select{|k,v| File.fnmatch(match, k)}
+ result = filtered_next_keys.flatten.map(&:to_s)
+
+ return ["#{cursor}", result]
+ end
+
def keys(pattern = "*")
data.keys.select { |key| File.fnmatch(pattern, key) }
end
def randomkey
@@ -254,11 +344,18 @@
data[key].size
end
def lrange(key, startidx, endidx)
data_type_check(key, Array)
- (data[key] && data[key][startidx..endidx]) || []
+ if data[key]
+ # In Ruby when negative start index is out of range Array#slice returns
+ # nil which is not the case for lrange in Redis.
+ startidx = 0 if startidx < 0 && startidx.abs > data[key].size
+ data[key][startidx..endidx] || []
+ else
+ []
+ end
end
def ltrim(key, start, stop)
data_type_check(key, Array)
return unless data[key]
@@ -282,11 +379,15 @@
end
def linsert(key, where, pivot, value)
data_type_check(key, Array)
return unless data[key]
- index = data[key].index(pivot)
+
+ value = value.to_s
+ index = data[key].index(pivot.to_s)
+ return -1 if index.nil?
+
case where
when :before then data[key].insert(index, value)
when :after then data[key].insert(index + 1, value)
else raise_syntax_error
end
@@ -294,16 +395,18 @@
def lset(key, index, value)
data_type_check(key, Array)
return unless data[key]
raise Redis::CommandError, "ERR index out of range" if index >= data[key].size
- data[key][index] = value
+ data[key][index] = value.to_s
end
def lrem(key, count, value)
data_type_check(key, Array)
- return unless data[key]
+ return 0 unless data[key]
+
+ value = value.to_s
old_size = data[key].size
diff =
if count == 0
data[key].delete(value)
old_size - data[key].size
@@ -316,34 +419,38 @@
remove_key_for_empty_collection(key)
diff
end
def rpush(key, value)
+ raise_argument_error('rpush') if value.respond_to?(:each) && value.empty?
data_type_check(key, Array)
data[key] ||= []
[value].flatten.each do |val|
data[key].push(val.to_s)
end
data[key].size
end
def rpushx(key, value)
+ raise_argument_error('rpushx') if value.respond_to?(:each) && value.empty?
data_type_check(key, Array)
return unless data[key]
rpush(key, value)
end
def lpush(key, value)
+ raise_argument_error('lpush') if value.respond_to?(:each) && value.empty?
data_type_check(key, Array)
data[key] ||= []
[value].flatten.each do |val|
data[key].unshift(val.to_s)
end
data[key].size
end
def lpushx(key, value)
+ raise_argument_error('lpushx') if value.respond_to?(:each) && value.empty?
data_type_check(key, Array)
return unless data[key]
lpush(key, value)
end
@@ -351,23 +458,54 @@
data_type_check(key, Array)
return unless data[key]
data[key].pop
end
+ def brpop(keys, timeout=0)
+ #todo threaded mode
+ keys = Array(keys)
+ keys.each do |key|
+ if data[key] && data[key].size > 0
+ return [key, data[key].pop]
+ end
+ end
+ sleep(timeout.to_f)
+ nil
+ end
+
def rpoplpush(key1, key2)
data_type_check(key1, Array)
rpop(key1).tap do |elem|
lpush(key2, elem) unless elem.nil?
end
end
+ def brpoplpush(key1, key2, opts={})
+ data_type_check(key1, Array)
+ brpop(key1).tap do |elem|
+ lpush(key2, elem) unless elem.nil?
+ end
+ end
+
def lpop(key)
data_type_check(key, Array)
return unless data[key]
data[key].shift
end
+ def blpop(keys, timeout=0)
+ #todo threaded mode
+ keys = Array(keys)
+ keys.each do |key|
+ if data[key] && data[key].size > 0
+ return [key, data[key].shift]
+ end
+ end
+ sleep(timeout.to_f)
+ nil
+ end
+
def smembers(key)
data_type_check(key, ::Set)
return [] unless data[key]
data[key].to_a.reverse
end
@@ -402,11 +540,11 @@
return false unless data[key]
if value.is_a?(Array)
old_size = data[key].size
values = value.map(&:to_s)
- values.each { |value| data[key].delete(value) }
+ values.each { |v| data[key].delete(v) }
deleted = old_size - data[key].size
else
deleted = !!data[key].delete?(value.to_s)
end
@@ -433,10 +571,11 @@
return 0 unless data[key]
data[key].size
end
def sinter(*keys)
+ keys = keys[0] if flatten?(keys)
raise_argument_error('sinter') if keys.empty?
keys.each { |k| data_type_check(k, ::Set) }
return ::Set.new if keys.any? { |k| data[k].nil? }
keys = keys.map { |k| data[k] || ::Set.new }
@@ -450,10 +589,13 @@
result = sinter(*keys)
data[destination] = ::Set.new(result)
end
def sunion(*keys)
+ keys = keys[0] if flatten?(keys)
+ raise_argument_error('sunion') if keys.empty?
+
keys.each { |k| data_type_check(k, ::Set) }
keys = keys.map { |k| data[k] || ::Set.new }
keys.inject(::Set.new) do |set, key|
set | key
end.to_a
@@ -464,10 +606,11 @@
result = sunion(*keys)
data[destination] = ::Set.new(result)
end
def sdiff(key1, *keys)
+ keys = keys[0] if flatten?(keys)
[key1, *keys].each { |k| data_type_check(k, ::Set) }
keys = keys.map { |k| data[k] || ::Set.new }
keys.inject(data[key1] || Set.new) do |memo, set|
memo - set
end.to_a
@@ -481,10 +624,48 @@
def srandmember(key, number=nil)
number.nil? ? srandmember_single(key) : srandmember_multiple(key, number)
end
+ def sscan(key, start_cursor, *args)
+ data_type_check(key, ::Set)
+ return ["0", []] unless data[key]
+
+ match = "*"
+ count = 10
+
+ if args.size.odd?
+ raise_argument_error('sscan')
+ end
+
+ if idx = args.index("MATCH")
+ match = args[idx + 1]
+ end
+
+ if idx = args.index("COUNT")
+ count = args[idx + 1]
+ end
+
+ start_cursor = start_cursor.to_i
+
+ cursor = start_cursor
+ next_keys = []
+
+ if start_cursor + count >= data[key].length
+ next_keys = (data[key].to_a)[start_cursor..-1]
+ cursor = 0
+ else
+ cursor = start_cursor + count
+ next_keys = (data[key].to_a)[start_cursor..cursor-1]
+ end
+
+ filtered_next_keys = next_keys.select{ |k,v| File.fnmatch(match, k)}
+ result = filtered_next_keys.flatten.map(&:to_s)
+
+ return ["#{cursor}", result]
+ end
+
def del(*keys)
keys = keys.flatten(1)
raise_argument_error('del') if keys.empty?
old_count = data.keys.size
@@ -494,14 +675,14 @@
old_count - data.keys.size
end
def setnx(key, value)
if exists(key)
- false
+ 0
else
set(key, value)
- true
+ 1
end
end
def rename(key, new_key)
return unless data[key]
@@ -523,18 +704,32 @@
return 0 unless data[key]
data.expires[key] = Time.now + ttl
1
end
+ def pexpire(key, ttl)
+ return 0 unless data[key]
+ data.expires[key] = Time.now + (ttl / 1000.0)
+ 1
+ end
+
def ttl(key)
if data.expires.include?(key) && (ttl = data.expires[key].to_i - Time.now.to_i) > 0
ttl
else
exists(key) ? -1 : -2
end
end
+ def pttl(key)
+ if data.expires.include?(key) && (ttl = data.expires[key].to_f - Time.now.to_f) > 0
+ ttl * 1000
+ else
+ exists(key) ? -1 : -2
+ end
+ end
+
def expireat(key, timestamp)
data.expires[key] = Time.at(timestamp)
true
end
@@ -546,14 +741,14 @@
data_type_check(key, Hash)
field = field.to_s
if data[key]
result = !data[key].include?(field)
data[key][field] = value.to_s
- result
+ result ? 1 : 0
else
data[key] = { field => value.to_s }
- true
+ 1
end
end
def hsetnx(key, field, value)
data_type_check(key, Hash)
@@ -725,10 +920,15 @@
def incrby(key, by)
data.merge!({ key => (data[key].to_i + by.to_i).to_s || by })
data[key].to_i
end
+ def incrbyfloat(key, by)
+ data.merge!({ key => (data[key].to_f + by.to_f).to_s || by })
+ data[key]
+ end
+
def decr(key)
data.merge!({ key => (data[key].to_i - 1).to_s || "-1"})
data[key].to_i
end
@@ -756,14 +956,10 @@
def scan(start_cursor, *args)
match = "*"
count = 10
- if args.size.odd?
- raise_argument_error('scan')
- end
-
if idx = args.index("MATCH")
match = args[idx + 1]
end
if idx = args.index("COUNT")
@@ -772,21 +968,26 @@
start_cursor = start_cursor.to_i
data_type_check(start_cursor, Fixnum)
cursor = start_cursor
- next_keys = []
+ returned_keys = []
+ final_page = start_cursor + count >= keys(match).length
- if start_cursor + count >= data.length
- next_keys = keys(match)[start_cursor..-1]
+ if final_page
+ previous_keys_been_deleted = (count >= keys(match).length)
+ start_index = previous_keys_been_deleted ? 0 : cursor
+
+ returned_keys = keys(match)[start_index..-1]
cursor = 0
else
- cursor = start_cursor + 10
- next_keys = keys(match)[start_cursor..cursor]
+ end_index = start_cursor + (count - 1)
+ returned_keys = keys(match)[start_cursor..end_index]
+ cursor = start_cursor + count
end
- return "#{cursor}", next_keys
+ return "#{cursor}", returned_keys
end
def zadd(key, *args)
if !args.first.is_a?(Array)
if args.size < 2
@@ -871,23 +1072,48 @@
def zrange(key, start, stop, with_scores = nil)
data_type_check(key, ZSet)
return [] unless data[key]
- # Sort by score, or if scores are equal, key alphanum
- results = data[key].sort do |(k1, v1), (k2, v2)|
- if v1 == v2
- k1 <=> k2
- else
- v1 <=> v2
- end
- end
+ results = sort_keys(data[key])
# Select just the keys unless we want scores
results = results.map(&:first) unless with_scores
results[start..stop].flatten.map(&:to_s)
end
+ def zrangebylex(key, start, stop, *opts)
+ data_type_check(key, ZSet)
+ return [] unless data[key]
+ zset = data[key]
+
+ sorted = if zset.identical_scores?
+ zset.keys.sort { |x, y| x.to_s <=> y.to_s }
+ else
+ zset.keys
+ end
+
+ range = get_range start, stop, sorted.first, sorted.last
+
+ filtered = []
+ sorted.each do |element|
+ filtered << element if (range[0][:value]..range[1][:value]).cover?(element)
+ end
+ filtered.shift if filtered[0] == range[0][:value] && !range[0][:inclusive]
+ filtered.pop if filtered.last == range[1][:value] && !range[1][:inclusive]
+
+ limit = get_limit(opts, filtered)
+ if limit
+ filtered = filtered[limit[0]..-1].take(limit[1])
+ end
+
+ filtered
+ end
+
+ def zrevrangebylex(key, start, stop, *args)
+ zrangebylex(key, stop, start, args).reverse
+ end
+
def zrevrange(key, start, stop, with_scores = nil)
data_type_check(key, ZSet)
return [] unless data[key]
if with_scores
@@ -964,10 +1190,135 @@
args_handler = SortedSetArgumentHandler.new(args)
data[out] = SortedSetUnionStore.new(args_handler, data).call
data[out].size
end
+ def subscribe(*channels)
+ raise_argument_error('subscribe') if channels.empty?()
+
+ #Create messages for all data from the channels
+ channel_replies = channels.map do |channel|
+ self.class.channels[channel].slice!(0..-1).map!{|v| ["message", channel, v]}
+ end
+ channel_replies.flatten!(1)
+ channel_replies.compact!()
+
+ #Put messages into the replies for the future
+ channels.each_with_index do |channel,index|
+ replies << ["subscribe", channel, index+1]
+ end
+ replies.push(*channel_replies)
+
+ #Add unsubscribe message to stop blocking (see https://github.com/redis/redis-rb/blob/v3.2.1/lib/redis/subscribe.rb#L38)
+ replies.push(self.unsubscribe())
+
+ replies.pop() #Last reply will be pushed back on
+ end
+
+ def psubscribe(*patterns)
+ raise_argument_error('psubscribe') if patterns.empty?()
+
+ #Create messages for all data from the channels
+ channel_replies = self.class.channels.keys.map do |channel|
+ pattern = patterns.find{|p| File.fnmatch(p, channel) }
+ unless pattern.nil?()
+ self.class.channels[channel].slice!(0..-1).map!{|v| ["pmessage", pattern, channel, v]}
+ end
+ end
+ channel_replies.flatten!(1)
+ channel_replies.compact!()
+
+ #Put messages into the replies for the future
+ patterns.each_with_index do |pattern,index|
+ replies << ["psubscribe", pattern, index+1]
+ end
+ replies.push(*channel_replies)
+
+ #Add unsubscribe to stop blocking
+ replies.push(self.punsubscribe())
+
+ replies.pop() #Last reply will be pushed back on
+ end
+
+ def publish(channel, message)
+ self.class.channels[channel] << message
+ 0 #Just fake number of subscribers
+ end
+
+ def unsubscribe(*channels)
+ if channels.empty?()
+ replies << ["unsubscribe", nil, 0]
+ else
+ channels.each do |channel|
+ replies << ["unsubscribe", channel, 0]
+ end
+ end
+ replies.pop() #Last reply will be pushed back on
+ end
+
+ def punsubscribe(*patterns)
+ if patterns.empty?()
+ replies << ["punsubscribe", nil, 0]
+ else
+ patterns.each do |pattern|
+ replies << ["punsubscribe", pattern, 0]
+ end
+ end
+ replies.pop() #Last reply will be pushed back on
+ end
+
+ def zscan(key, start_cursor, *args)
+ data_type_check(key, ZSet)
+ return [] unless data[key]
+
+ match = "*"
+ count = 10
+
+ if args.size.odd?
+ raise_argument_error('zscan')
+ end
+
+ if idx = args.index("MATCH")
+ match = args[idx + 1]
+ end
+
+ if idx = args.index("COUNT")
+ count = args[idx + 1]
+ end
+
+ start_cursor = start_cursor.to_i
+ data_type_check(start_cursor, Fixnum)
+
+ cursor = start_cursor
+ next_keys = []
+
+ sorted_keys = sort_keys(data[key])
+
+ if start_cursor + count >= sorted_keys.length
+ next_keys = sorted_keys.to_a.select { |k| File.fnmatch(match, k[0]) } [start_cursor..-1]
+ cursor = 0
+ else
+ cursor = start_cursor + count
+ next_keys = sorted_keys.to_a.select { |k| File.fnmatch(match, k[0]) } [start_cursor..cursor-1]
+ end
+ return "#{cursor}", next_keys.flatten.map(&:to_s)
+ end
+
+ # Originally from redis-rb
+ def zscan_each(key, *args, &block)
+ data_type_check(key, ZSet)
+ return [] unless data[key]
+
+ return to_enum(:zscan_each, key, options) unless block_given?
+ cursor = 0
+ loop do
+ cursor, values = zscan(key, cursor, options)
+ values.each(&block)
+ break if cursor == "0"
+ end
+ end
+
private
def raise_argument_error(command, match_string=command)
error_message = if %w(hmset mset_odd).include?(match_string.downcase)
"ERR wrong number of arguments for #{command.upcase}"
else
@@ -990,10 +1341,31 @@
warn "Operation against a key holding the wrong kind of value: Expected #{klass} at #{key}."
raise Redis::CommandError.new("WRONGTYPE Operation against a key holding the wrong kind of value")
end
end
+ def get_range(start, stop, min = -Float::INFINITY, max = Float::INFINITY)
+ range_options = []
+
+ [start, stop].each do |value|
+ case value[0]
+ when "-"
+ range_options << { value: min, inclusive: true }
+ when "+"
+ range_options << { value: max, inclusive: true }
+ when "["
+ range_options << { value: value[1..-1], inclusive: true }
+ when "("
+ range_options << { value: value[1..-1], inclusive: false }
+ else
+ raise Redis::CommandError, "ERR min or max not valid string range item"
+ end
+ end
+
+ range_options
+ end
+
def get_limit(opts, vals)
index = opts.index('LIMIT')
if index
offset = opts[index + 1]
@@ -1006,10 +1378,13 @@
end
def mapped_param? param
param.size == 1 && param[0].is_a?(Array)
end
+ # NOTE : Redis-rb 3.x will flatten *args, so method(["a", "b", "c"])
+ # should be handled the same way as method("a", "b", "c")
+ alias_method :flatten?, :mapped_param?
def srandmember_single(key)
data_type_check(key, ::Set)
return nil unless data[key]
data[key].to_a[rand(data[key].size)]
@@ -1023,9 +1398,20 @@
available_elements = data[key].to_a - selected
selected << available_elements[rand(available_elements.size)]
end.compact
else
(1..-number).map { data[key].to_a[rand(data[key].size)] }.flatten
+ end
+ end
+
+ def sort_keys(arr)
+ # Sort by score, or if scores are equal, key alphanum
+ sorted_keys = arr.sort do |(k1, v1), (k2, v2)|
+ if v1 == v2
+ k1 <=> k2
+ else
+ v1 <=> v2
+ end
end
end
end
end
end