lib/redis/connection/memory.rb in fakeredis-0.5.0 vs lib/redis/connection/memory.rb in fakeredis-0.6.0

- old
+ new

@@ -6,18 +6,20 @@ require "fakeredis/sort_method" require "fakeredis/sorted_set_argument_handler" require "fakeredis/sorted_set_store" require "fakeredis/transaction_commands" require "fakeredis/zset" +require "fakeredis/bitop_command" class Redis module Connection class Memory include Redis::Connection::CommandHelper include FakeRedis include SortMethod include TransactionCommands + include BitopCommand include CommandExecutor attr_accessor :options # Tracks all databases for all instances across the current process. @@ -33,10 +35,18 @@ # Used for resetting everything in specs def self.reset_all_databases @databases = nil end + def self.channels + @channels ||= Hash.new {|h,k| h[k] = [] } + end + + def self.reset_all_channels + @channels = nil + end + def self.connect(options = {}) new(options) end def initialize(options = {}) @@ -84,18 +94,10 @@ def read replies.shift end - # NOT IMPLEMENTED: - # * blpop - # * brpop - # * brpoplpush - # * subscribe - # * psubscribe - # * publish - def flushdb databases.delete_at(database_id) "OK" end @@ -144,10 +146,49 @@ return false if destination.has_key?(key) destination[key] = data.delete(key) true end + def dump(key) + return nil unless exists(key) + + value = data[key] + + Marshal.dump( + value: value, + version: FakeRedis::VERSION, # Redis includes the version, so we might as well + ) + end + + def restore(key, ttl, serialized_value) + raise Redis::CommandError, "ERR Target key name is busy." if exists(key) + + raise Redis::CommandError, "ERR DUMP payload version or checksum are wrong" if serialized_value.nil? + + parsed_value = begin + Marshal.load(serialized_value) + rescue TypeError + raise Redis::CommandError, "ERR DUMP payload version or checksum are wrong" + end + + if parsed_value[:version] != FakeRedis::VERSION + raise Redis::CommandError, "ERR DUMP payload version or checksum are wrong" + end + + # We could figure out what type the key was and set it with the public API here, + # or we could just assign the value. If we presume the serialized_value is only ever + # a return value from `dump` then we've only been given something that was in + # the internal data structure anyway. + data[key] = parsed_value[:value] + + # Set a TTL if one has been passed + ttl = ttl.to_i # Makes nil into 0 + expire(key, ttl / 1000) unless ttl.zero? + + "OK" + end + def get(key) data_type_check(key, String) data[key] end @@ -200,23 +241,72 @@ data_type_check(key, Hash) data[key] && data[key][field.to_s] end def hdel(key, field) - field = field.to_s data_type_check(key, Hash) - deleted = data[key] && data[key].delete(field) + return 0 unless data[key] + + if field.is_a?(Array) + old_keys_count = data[key].size + fields = field.map(&:to_s) + + data[key].delete_if { |k, v| fields.include? k } + deleted = old_keys_count - data[key].size + else + field = field.to_s + deleted = data[key].delete(field) ? 1 : 0 + end + remove_key_for_empty_collection(key) - deleted ? 1 : 0 + deleted end def hkeys(key) data_type_check(key, Hash) return [] if data[key].nil? data[key].keys end + def hscan(key, start_cursor, *args) + data_type_check(key, Hash) + return ["0", []] unless data[key] + + match = "*" + count = 10 + + if args.size.odd? + raise_argument_error('hscan') + end + + if idx = args.index("MATCH") + match = args[idx + 1] + end + + if idx = args.index("COUNT") + count = args[idx + 1] + end + + start_cursor = start_cursor.to_i + + cursor = start_cursor + next_keys = [] + + if start_cursor + count >= data[key].length + next_keys = (data[key].to_a)[start_cursor..-1] + cursor = 0 + else + cursor = start_cursor + count + next_keys = (data[key].to_a)[start_cursor..cursor-1] + end + + filtered_next_keys = next_keys.select{|k,v| File.fnmatch(match, k)} + result = filtered_next_keys.flatten.map(&:to_s) + + return ["#{cursor}", result] + end + def keys(pattern = "*") data.keys.select { |key| File.fnmatch(pattern, key) } end def randomkey @@ -254,11 +344,18 @@ data[key].size end def lrange(key, startidx, endidx) data_type_check(key, Array) - (data[key] && data[key][startidx..endidx]) || [] + if data[key] + # In Ruby when negative start index is out of range Array#slice returns + # nil which is not the case for lrange in Redis. + startidx = 0 if startidx < 0 && startidx.abs > data[key].size + data[key][startidx..endidx] || [] + else + [] + end end def ltrim(key, start, stop) data_type_check(key, Array) return unless data[key] @@ -282,11 +379,15 @@ end def linsert(key, where, pivot, value) data_type_check(key, Array) return unless data[key] - index = data[key].index(pivot) + + value = value.to_s + index = data[key].index(pivot.to_s) + return -1 if index.nil? + case where when :before then data[key].insert(index, value) when :after then data[key].insert(index + 1, value) else raise_syntax_error end @@ -294,16 +395,18 @@ def lset(key, index, value) data_type_check(key, Array) return unless data[key] raise Redis::CommandError, "ERR index out of range" if index >= data[key].size - data[key][index] = value + data[key][index] = value.to_s end def lrem(key, count, value) data_type_check(key, Array) - return unless data[key] + return 0 unless data[key] + + value = value.to_s old_size = data[key].size diff = if count == 0 data[key].delete(value) old_size - data[key].size @@ -316,34 +419,38 @@ remove_key_for_empty_collection(key) diff end def rpush(key, value) + raise_argument_error('rpush') if value.respond_to?(:each) && value.empty? data_type_check(key, Array) data[key] ||= [] [value].flatten.each do |val| data[key].push(val.to_s) end data[key].size end def rpushx(key, value) + raise_argument_error('rpushx') if value.respond_to?(:each) && value.empty? data_type_check(key, Array) return unless data[key] rpush(key, value) end def lpush(key, value) + raise_argument_error('lpush') if value.respond_to?(:each) && value.empty? data_type_check(key, Array) data[key] ||= [] [value].flatten.each do |val| data[key].unshift(val.to_s) end data[key].size end def lpushx(key, value) + raise_argument_error('lpushx') if value.respond_to?(:each) && value.empty? data_type_check(key, Array) return unless data[key] lpush(key, value) end @@ -351,23 +458,54 @@ data_type_check(key, Array) return unless data[key] data[key].pop end + def brpop(keys, timeout=0) + #todo threaded mode + keys = Array(keys) + keys.each do |key| + if data[key] && data[key].size > 0 + return [key, data[key].pop] + end + end + sleep(timeout.to_f) + nil + end + def rpoplpush(key1, key2) data_type_check(key1, Array) rpop(key1).tap do |elem| lpush(key2, elem) unless elem.nil? end end + def brpoplpush(key1, key2, opts={}) + data_type_check(key1, Array) + brpop(key1).tap do |elem| + lpush(key2, elem) unless elem.nil? + end + end + def lpop(key) data_type_check(key, Array) return unless data[key] data[key].shift end + def blpop(keys, timeout=0) + #todo threaded mode + keys = Array(keys) + keys.each do |key| + if data[key] && data[key].size > 0 + return [key, data[key].shift] + end + end + sleep(timeout.to_f) + nil + end + def smembers(key) data_type_check(key, ::Set) return [] unless data[key] data[key].to_a.reverse end @@ -402,11 +540,11 @@ return false unless data[key] if value.is_a?(Array) old_size = data[key].size values = value.map(&:to_s) - values.each { |value| data[key].delete(value) } + values.each { |v| data[key].delete(v) } deleted = old_size - data[key].size else deleted = !!data[key].delete?(value.to_s) end @@ -433,10 +571,11 @@ return 0 unless data[key] data[key].size end def sinter(*keys) + keys = keys[0] if flatten?(keys) raise_argument_error('sinter') if keys.empty? keys.each { |k| data_type_check(k, ::Set) } return ::Set.new if keys.any? { |k| data[k].nil? } keys = keys.map { |k| data[k] || ::Set.new } @@ -450,10 +589,13 @@ result = sinter(*keys) data[destination] = ::Set.new(result) end def sunion(*keys) + keys = keys[0] if flatten?(keys) + raise_argument_error('sunion') if keys.empty? + keys.each { |k| data_type_check(k, ::Set) } keys = keys.map { |k| data[k] || ::Set.new } keys.inject(::Set.new) do |set, key| set | key end.to_a @@ -464,10 +606,11 @@ result = sunion(*keys) data[destination] = ::Set.new(result) end def sdiff(key1, *keys) + keys = keys[0] if flatten?(keys) [key1, *keys].each { |k| data_type_check(k, ::Set) } keys = keys.map { |k| data[k] || ::Set.new } keys.inject(data[key1] || Set.new) do |memo, set| memo - set end.to_a @@ -481,10 +624,48 @@ def srandmember(key, number=nil) number.nil? ? srandmember_single(key) : srandmember_multiple(key, number) end + def sscan(key, start_cursor, *args) + data_type_check(key, ::Set) + return ["0", []] unless data[key] + + match = "*" + count = 10 + + if args.size.odd? + raise_argument_error('sscan') + end + + if idx = args.index("MATCH") + match = args[idx + 1] + end + + if idx = args.index("COUNT") + count = args[idx + 1] + end + + start_cursor = start_cursor.to_i + + cursor = start_cursor + next_keys = [] + + if start_cursor + count >= data[key].length + next_keys = (data[key].to_a)[start_cursor..-1] + cursor = 0 + else + cursor = start_cursor + count + next_keys = (data[key].to_a)[start_cursor..cursor-1] + end + + filtered_next_keys = next_keys.select{ |k,v| File.fnmatch(match, k)} + result = filtered_next_keys.flatten.map(&:to_s) + + return ["#{cursor}", result] + end + def del(*keys) keys = keys.flatten(1) raise_argument_error('del') if keys.empty? old_count = data.keys.size @@ -494,14 +675,14 @@ old_count - data.keys.size end def setnx(key, value) if exists(key) - false + 0 else set(key, value) - true + 1 end end def rename(key, new_key) return unless data[key] @@ -523,18 +704,32 @@ return 0 unless data[key] data.expires[key] = Time.now + ttl 1 end + def pexpire(key, ttl) + return 0 unless data[key] + data.expires[key] = Time.now + (ttl / 1000.0) + 1 + end + def ttl(key) if data.expires.include?(key) && (ttl = data.expires[key].to_i - Time.now.to_i) > 0 ttl else exists(key) ? -1 : -2 end end + def pttl(key) + if data.expires.include?(key) && (ttl = data.expires[key].to_f - Time.now.to_f) > 0 + ttl * 1000 + else + exists(key) ? -1 : -2 + end + end + def expireat(key, timestamp) data.expires[key] = Time.at(timestamp) true end @@ -546,14 +741,14 @@ data_type_check(key, Hash) field = field.to_s if data[key] result = !data[key].include?(field) data[key][field] = value.to_s - result + result ? 1 : 0 else data[key] = { field => value.to_s } - true + 1 end end def hsetnx(key, field, value) data_type_check(key, Hash) @@ -725,10 +920,15 @@ def incrby(key, by) data.merge!({ key => (data[key].to_i + by.to_i).to_s || by }) data[key].to_i end + def incrbyfloat(key, by) + data.merge!({ key => (data[key].to_f + by.to_f).to_s || by }) + data[key] + end + def decr(key) data.merge!({ key => (data[key].to_i - 1).to_s || "-1"}) data[key].to_i end @@ -756,14 +956,10 @@ def scan(start_cursor, *args) match = "*" count = 10 - if args.size.odd? - raise_argument_error('scan') - end - if idx = args.index("MATCH") match = args[idx + 1] end if idx = args.index("COUNT") @@ -772,21 +968,26 @@ start_cursor = start_cursor.to_i data_type_check(start_cursor, Fixnum) cursor = start_cursor - next_keys = [] + returned_keys = [] + final_page = start_cursor + count >= keys(match).length - if start_cursor + count >= data.length - next_keys = keys(match)[start_cursor..-1] + if final_page + previous_keys_been_deleted = (count >= keys(match).length) + start_index = previous_keys_been_deleted ? 0 : cursor + + returned_keys = keys(match)[start_index..-1] cursor = 0 else - cursor = start_cursor + 10 - next_keys = keys(match)[start_cursor..cursor] + end_index = start_cursor + (count - 1) + returned_keys = keys(match)[start_cursor..end_index] + cursor = start_cursor + count end - return "#{cursor}", next_keys + return "#{cursor}", returned_keys end def zadd(key, *args) if !args.first.is_a?(Array) if args.size < 2 @@ -871,23 +1072,48 @@ def zrange(key, start, stop, with_scores = nil) data_type_check(key, ZSet) return [] unless data[key] - # Sort by score, or if scores are equal, key alphanum - results = data[key].sort do |(k1, v1), (k2, v2)| - if v1 == v2 - k1 <=> k2 - else - v1 <=> v2 - end - end + results = sort_keys(data[key]) # Select just the keys unless we want scores results = results.map(&:first) unless with_scores results[start..stop].flatten.map(&:to_s) end + def zrangebylex(key, start, stop, *opts) + data_type_check(key, ZSet) + return [] unless data[key] + zset = data[key] + + sorted = if zset.identical_scores? + zset.keys.sort { |x, y| x.to_s <=> y.to_s } + else + zset.keys + end + + range = get_range start, stop, sorted.first, sorted.last + + filtered = [] + sorted.each do |element| + filtered << element if (range[0][:value]..range[1][:value]).cover?(element) + end + filtered.shift if filtered[0] == range[0][:value] && !range[0][:inclusive] + filtered.pop if filtered.last == range[1][:value] && !range[1][:inclusive] + + limit = get_limit(opts, filtered) + if limit + filtered = filtered[limit[0]..-1].take(limit[1]) + end + + filtered + end + + def zrevrangebylex(key, start, stop, *args) + zrangebylex(key, stop, start, args).reverse + end + def zrevrange(key, start, stop, with_scores = nil) data_type_check(key, ZSet) return [] unless data[key] if with_scores @@ -964,10 +1190,135 @@ args_handler = SortedSetArgumentHandler.new(args) data[out] = SortedSetUnionStore.new(args_handler, data).call data[out].size end + def subscribe(*channels) + raise_argument_error('subscribe') if channels.empty?() + + #Create messages for all data from the channels + channel_replies = channels.map do |channel| + self.class.channels[channel].slice!(0..-1).map!{|v| ["message", channel, v]} + end + channel_replies.flatten!(1) + channel_replies.compact!() + + #Put messages into the replies for the future + channels.each_with_index do |channel,index| + replies << ["subscribe", channel, index+1] + end + replies.push(*channel_replies) + + #Add unsubscribe message to stop blocking (see https://github.com/redis/redis-rb/blob/v3.2.1/lib/redis/subscribe.rb#L38) + replies.push(self.unsubscribe()) + + replies.pop() #Last reply will be pushed back on + end + + def psubscribe(*patterns) + raise_argument_error('psubscribe') if patterns.empty?() + + #Create messages for all data from the channels + channel_replies = self.class.channels.keys.map do |channel| + pattern = patterns.find{|p| File.fnmatch(p, channel) } + unless pattern.nil?() + self.class.channels[channel].slice!(0..-1).map!{|v| ["pmessage", pattern, channel, v]} + end + end + channel_replies.flatten!(1) + channel_replies.compact!() + + #Put messages into the replies for the future + patterns.each_with_index do |pattern,index| + replies << ["psubscribe", pattern, index+1] + end + replies.push(*channel_replies) + + #Add unsubscribe to stop blocking + replies.push(self.punsubscribe()) + + replies.pop() #Last reply will be pushed back on + end + + def publish(channel, message) + self.class.channels[channel] << message + 0 #Just fake number of subscribers + end + + def unsubscribe(*channels) + if channels.empty?() + replies << ["unsubscribe", nil, 0] + else + channels.each do |channel| + replies << ["unsubscribe", channel, 0] + end + end + replies.pop() #Last reply will be pushed back on + end + + def punsubscribe(*patterns) + if patterns.empty?() + replies << ["punsubscribe", nil, 0] + else + patterns.each do |pattern| + replies << ["punsubscribe", pattern, 0] + end + end + replies.pop() #Last reply will be pushed back on + end + + def zscan(key, start_cursor, *args) + data_type_check(key, ZSet) + return [] unless data[key] + + match = "*" + count = 10 + + if args.size.odd? + raise_argument_error('zscan') + end + + if idx = args.index("MATCH") + match = args[idx + 1] + end + + if idx = args.index("COUNT") + count = args[idx + 1] + end + + start_cursor = start_cursor.to_i + data_type_check(start_cursor, Fixnum) + + cursor = start_cursor + next_keys = [] + + sorted_keys = sort_keys(data[key]) + + if start_cursor + count >= sorted_keys.length + next_keys = sorted_keys.to_a.select { |k| File.fnmatch(match, k[0]) } [start_cursor..-1] + cursor = 0 + else + cursor = start_cursor + count + next_keys = sorted_keys.to_a.select { |k| File.fnmatch(match, k[0]) } [start_cursor..cursor-1] + end + return "#{cursor}", next_keys.flatten.map(&:to_s) + end + + # Originally from redis-rb + def zscan_each(key, *args, &block) + data_type_check(key, ZSet) + return [] unless data[key] + + return to_enum(:zscan_each, key, options) unless block_given? + cursor = 0 + loop do + cursor, values = zscan(key, cursor, options) + values.each(&block) + break if cursor == "0" + end + end + private def raise_argument_error(command, match_string=command) error_message = if %w(hmset mset_odd).include?(match_string.downcase) "ERR wrong number of arguments for #{command.upcase}" else @@ -990,10 +1341,31 @@ warn "Operation against a key holding the wrong kind of value: Expected #{klass} at #{key}." raise Redis::CommandError.new("WRONGTYPE Operation against a key holding the wrong kind of value") end end + def get_range(start, stop, min = -Float::INFINITY, max = Float::INFINITY) + range_options = [] + + [start, stop].each do |value| + case value[0] + when "-" + range_options << { value: min, inclusive: true } + when "+" + range_options << { value: max, inclusive: true } + when "[" + range_options << { value: value[1..-1], inclusive: true } + when "(" + range_options << { value: value[1..-1], inclusive: false } + else + raise Redis::CommandError, "ERR min or max not valid string range item" + end + end + + range_options + end + def get_limit(opts, vals) index = opts.index('LIMIT') if index offset = opts[index + 1] @@ -1006,10 +1378,13 @@ end def mapped_param? param param.size == 1 && param[0].is_a?(Array) end + # NOTE : Redis-rb 3.x will flatten *args, so method(["a", "b", "c"]) + # should be handled the same way as method("a", "b", "c") + alias_method :flatten?, :mapped_param? def srandmember_single(key) data_type_check(key, ::Set) return nil unless data[key] data[key].to_a[rand(data[key].size)] @@ -1023,9 +1398,20 @@ available_elements = data[key].to_a - selected selected << available_elements[rand(available_elements.size)] end.compact else (1..-number).map { data[key].to_a[rand(data[key].size)] }.flatten + end + end + + def sort_keys(arr) + # Sort by score, or if scores are equal, key alphanum + sorted_keys = arr.sort do |(k1, v1), (k2, v2)| + if v1 == v2 + k1 <=> k2 + else + v1 <=> v2 + end end end end end end