lib/redis/connection/memory.rb in fakeredis-0.3.2 vs lib/redis/connection/memory.rb in fakeredis-0.3.3

- old
+ new

@@ -1,71 +1,18 @@ require 'set' require 'redis/connection/registry' require 'redis/connection/command_helper' +require "fakeredis/expiring_hash" +require "fakeredis/sorted_set_argument_handler" +require "fakeredis/sorted_set_store" +require "fakeredis/zset" class Redis module Connection class Memory - # Represents a normal hash with some additional expiration information - # associated with each key - class ExpiringHash < Hash - attr_reader :expires - - def initialize(*) - super - @expires = {} - end - - def [](key) - delete(key) if expired?(key) - super - end - - def []=(key, val) - expire(key) - super - end - - def delete(key) - expire(key) - super - end - - def expire(key) - expires.delete(key) - end - - def expired?(key) - expires.include?(key) && expires[key] < Time.now - end - - def key?(key) - delete(key) if expired?(key) - super - end - - def values_at(*keys) - keys.each {|key| delete(key) if expired?(key)} - super - end - - def keys - super.select do |key| - if expired?(key) - delete(key) - false - else - true - end - end - end - end - - class ZSet < Hash - end - include Redis::Connection::CommandHelper + include FakeRedis def initialize @data = ExpiringHash.new @connected = false @replies = [] @@ -91,21 +38,25 @@ def timeout=(usecs) end def write(command) - method = command.shift - reply = send(method, *command) + meffod = command.shift + if respond_to?(meffod) + reply = send(meffod, *command) + else + raise RuntimeError, "ERR unknown command '#{meffod}'" + end if reply == true reply = 1 elsif reply == false reply = 0 end @replies << reply - @buffer << reply if @buffer && method != :multi + @buffer << reply if @buffer && meffod != :multi nil end def read @replies.shift @@ -115,17 +66,18 @@ # * blpop # * brpop # * brpoplpush # * discard # * move + # * sort # * subscribe # * psubscribe # * publish - # * zremrangebyrank - # * zunionstore + def flushdb @data = ExpiringHash.new + "OK" end def flushall flushdb end @@ -158,10 +110,11 @@ def bgsave ; end def bgreriteaof ; end def get(key) + data_type_check(key, String) @data[key] end def getbit(key, offset) return unless @data[key] @@ -173,17 +126,18 @@ @data[key][start..ending] end alias :substr :getrange def getset(key, value) - old_value = @data[key] - @data[key] = value - return old_value + data_type_check(key, String) + @data[key].tap do + set(key, value) + end end def mget(*keys) - raise ArgumentError, "wrong number of arguments for 'mget' command" if keys.empty? + raise RuntimeError, "ERR wrong number of arguments for 'mget' command" if keys.empty? @data.values_at(*keys) end def append(key, value) @data[key] = (@data[key] || "") @@ -252,11 +206,11 @@ @data[key].size end def lrange(key, startidx, endidx) data_type_check(key, Array) - @data[key] && @data[key][startidx..endidx] + @data[key] && @data[key][startidx..endidx] || [] end def ltrim(key, start, stop) data_type_check(key, Array) return unless @data[key] @@ -280,11 +234,11 @@ end def lset(key, index, value) data_type_check(key, Array) return unless @data[key] - raise RuntimeError if index >= @data[key].size + raise RuntimeError, "ERR index out of range" if index >= @data[key].size @data[key][index] = value end def lrem(key, count, value) data_type_check(key, Array) @@ -305,11 +259,11 @@ end def rpush(key, value) data_type_check(key, Array) @data[key] ||= [] - @data[key].push(value) + @data[key].push(value.to_s) @data[key].size end def rpushx(key, value) data_type_check(key, Array) @@ -318,11 +272,11 @@ end def lpush(key, value) data_type_check(key, Array) @data[key] ||= [] - @data[key].unshift(value) + @data[key].unshift(value.to_s) @data[key].size end def lpushx(key, value) data_type_check(key, Array) @@ -336,12 +290,13 @@ @data[key].pop end def rpoplpush(key1, key2) data_type_check(key1, Array) - elem = rpop(key1) - lpush(key2, elem) + rpop(key1).tap do |elem| + lpush(key2, elem) + end end def lpop(key) data_type_check(key, Array) return unless @data[key] @@ -449,11 +404,11 @@ def del(*keys) old_count = @data.keys.size keys.flatten.each do |key| @data.delete(key) end - deleted_count = old_count - @data.keys.size + old_count - @data.keys.size end def setnx(key, value) if exists(key) false @@ -521,22 +476,25 @@ return false if @data[key] && @data[key][field] hset(key, field, value) end def hmset(key, *fields) - raise ArgumentError, "wrong number of arguments for 'hmset' command" if fields.empty? || fields.size.odd? + # mapped_hmset gives us [[:k1, "v1", :k2, "v2"]] for `fields`. Fix that. + fields = fields[0] if fields.size == 1 && fields[0].is_a?(Array) + fields = fields[0] if mapped_param?(fields) + raise RuntimeError, "ERR wrong number of arguments for HMSET" if fields.size > 2 && fields.size.odd? + raise RuntimeError, "ERR wrong number of arguments for 'hmset' command" if fields.empty? || fields.size.odd? data_type_check(key, Hash) @data[key] ||= {} fields.each_slice(2) do |field| @data[key][field[0].to_s] = field[1].to_s end end def hmget(key, *fields) - raise ArgumentError, "wrong number of arguments for 'hmget' command" if fields.empty? + raise RuntimeError, "ERR wrong number of arguments for 'hmget' command" if fields.empty? data_type_check(key, Hash) - values = [] fields.map do |field| field = field.to_s if @data[key] @data[key][field] else @@ -610,54 +568,54 @@ s = @data[key][offset,value.size] @data[key][s] = value end def mset(*pairs) + # Handle pairs for mapped_mset command + pairs = pairs[0] if mapped_param?(pairs) pairs.each_slice(2) do |pair| @data[pair[0].to_s] = pair[1].to_s end "OK" end def msetnx(*pairs) + # Handle pairs for mapped_mset command + pairs = pairs[0] if mapped_param?(pairs) keys = [] pairs.each_with_index{|item, index| keys << item.to_s if index % 2 == 0} - return if keys.any?{|key| @data.key?(key) } + return false if keys.any? {|key| @data.key?(key) } mset(*pairs) true end def sort(key) # TODO: Implement end def incr(key) - @data[key] = (@data[key] || "0") - @data[key] = (@data[key].to_i + 1).to_s + @data.merge!({ key => (@data[key].to_i + 1).to_s || "1"}) @data[key].to_i end def incrby(key, by) - @data[key] = (@data[key] || "0") - @data[key] = (@data[key].to_i + by.to_i).to_s + @data.merge!({ key => (@data[key].to_i + by.to_i).to_s || by }) @data[key].to_i end def decr(key) - @data[key] = (@data[key] || "0") - @data[key] = (@data[key].to_i - 1).to_s + @data.merge!({ key => (@data[key].to_i - 1).to_s || "-1"}) @data[key].to_i end def decrby(key, by) - @data[key] = (@data[key] || "0") - @data[key] = (@data[key].to_i - by.to_i).to_s + @data.merge!({ key => ((@data[key].to_i - by.to_i) || (by.to_i * -1)).to_s }) @data[key].to_i end def type(key) - case value = @data[key] + case @data[key] when nil then "none" when String then "string" when Hash then "hash" when Array then "list" when ::Set then "set" @@ -692,10 +650,11 @@ def zadd(key, score, value) data_type_check(key, ZSet) @data[key] ||= ZSet.new exists = @data[key].key?(value.to_s) + score = "inf" if score == "+inf" @data[key][value.to_s] = score !exists end def zrem(key, value) @@ -711,11 +670,12 @@ @data[key] ? @data[key].size : 0 end def zscore(key, value) data_type_check(key, ZSet) - @data[key] && @data[key][value.to_s].to_s + result = @data[key] && @data[key][value.to_s] + result.to_s if result end def zcount(key, min, max) data_type_check(key, ZSet) return 0 unless @data[key] @@ -724,11 +684,16 @@ def zincrby(key, num, value) data_type_check(key, ZSet) @data[key] ||= ZSet.new @data[key][value.to_s] ||= 0 - @data[key][value.to_s] += num + if %w(+inf -inf).include?(num) + num = "inf" if num == "+inf" + @data[key][value.to_s] = num + elsif ! %w(+inf -inf).include?(@data[key][value.to_s]) + @data[key][value.to_s] += num + end @data[key][value.to_s].to_s end def zrank(key, value) data_type_check(key, ZSet) @@ -742,15 +707,21 @@ def zrange(key, start, stop, with_scores = nil) data_type_check(key, ZSet) return [] unless @data[key] - if with_scores - @data[key].sort_by {|_,v| v } - else - @data[key].keys.sort_by {|k| @data[key][k] } - end[start..stop].flatten.map(&:to_s) + # Sort by score, or if scores are equal, key alphanum + results = @data[key].sort do |(k1, v1), (k2, v2)| + if v1 == v2 + k1 <=> k2 + else + v1 <=> v2 + end + end + # Select just the keys unless we want scores + results = results.map(&:first) unless with_scores + results[start..stop].flatten.map(&:to_s) end def zrevrange(key, start, stop, with_scores = nil) data_type_check(key, ZSet) return [] unless @data[key] @@ -803,35 +774,27 @@ range = zrange_select_by_score(key, min, max) range.each {|k,_| @data[key].delete(k) } range.size end - def zinterstore(out, _, *keys) + def zinterstore(out, *args) data_type_check(out, ZSet) + args_handler = SortedSetArgumentHandler.new(args) + @data[out] = SortedSetIntersectStore.new(args_handler, @data).call + @data[out].size + end - hashes = keys.map do |src| - case @data[src] - when ::Set - Hash[@data[src].zip([0] * @data[src].size)] - when Hash - @data[src] - else - {} - end - end - - @data[out] = ZSet.new - values = hashes.inject([]) {|r, h| r.empty? ? h.keys : r & h.keys } - values.each do |value| - @data[out][value] = hashes.inject(0) {|n, h| n + h[value].to_i } - end - + def zunionstore(out, *args) + data_type_check(out, ZSet) + args_handler = SortedSetArgumentHandler.new(args) + @data[out] = SortedSetUnionStore.new(args_handler, @data).call @data[out].size end def zremrangebyrank(key, start, stop) - sorted_elements = @data[key].sort { |(v_a, r_a), (v_b, r_b)| r_a <=> r_b } + sorted_elements = @data[key].sort { |(_, r_a), (_, r_b)| r_a <=> r_b } + start = sorted_elements.length if start > sorted_elements.length elements_to_delete = sorted_elements[start..stop] elements_to_delete.each { |elem, rank| @data[key].delete(elem) } elements_to_delete.size end @@ -845,11 +808,12 @@ del(key) if @data[key] && @data[key].empty? end def data_type_check(key, klass) if @data[key] && !@data[key].is_a?(klass) - fail "Operation against a key holding the wrong kind of value: Expected #{klass} at #{key}." + warn "Operation against a key holding the wrong kind of value: Expected #{klass} at #{key}." + raise RuntimeError.new("ERR Operation against a key holding the wrong kind of value") end end def get_limit(opts, vals) index = opts.index('LIMIT') @@ -860,9 +824,13 @@ count = opts[index + 2] count = vals.size if count < 0 [offset, count] end + end + + def mapped_param? param + param.size == 1 && param[0].is_a?(Array) end end end end