lib/redis/connection/memory.rb in fakeredis-0.3.3 vs lib/redis/connection/memory.rb in fakeredis-0.4.0

- old
+ new

@@ -1,32 +1,85 @@ require 'set' require 'redis/connection/registry' require 'redis/connection/command_helper' -require "fakeredis/expiring_hash" -require "fakeredis/sorted_set_argument_handler" -require "fakeredis/sorted_set_store" -require "fakeredis/zset" class Redis module Connection class Memory + # Represents a normal hash with some additional expiration information + # associated with each key + class ExpiringHash < Hash + attr_reader :expires + + def initialize(*) + super + @expires = {} + end + + def [](key) + delete(key) if expired?(key) + super + end + + def []=(key, val) + expire(key) + super + end + + def delete(key) + expire(key) + super + end + + def expire(key) + expires.delete(key) + end + + def expired?(key) + expires.include?(key) && expires[key] < Time.now + end + + def key?(key) + delete(key) if expired?(key) + super + end + + def values_at(*keys) + keys.each {|key| delete(key) if expired?(key)} + super + end + + def keys + super.select do |key| + if expired?(key) + delete(key) + false + else + true + end + end + end + end + + class ZSet < Hash + end + include Redis::Connection::CommandHelper - include FakeRedis - def initialize + def initialize(connected = false) @data = ExpiringHash.new - @connected = false + @connected = connected @replies = [] @buffer = nil end def connected? @connected end - def connect(host, port, timeout) - @connected = true + def self.connect(options = {}) + self.new(true) end def connect_unix(path, timeout) @connected = true end @@ -38,25 +91,21 @@ def timeout=(usecs) end def write(command) - meffod = command.shift - if respond_to?(meffod) - reply = send(meffod, *command) - else - raise RuntimeError, "ERR unknown command '#{meffod}'" - end + method = command.shift + reply = send(method, *command) if reply == true reply = 1 elsif reply == false reply = 0 end @replies << reply - @buffer << reply if @buffer && meffod != :multi + @buffer << reply if @buffer && method != :multi nil end def read @replies.shift @@ -66,18 +115,17 @@ # * blpop # * brpop # * brpoplpush # * discard # * move - # * sort # * subscribe # * psubscribe # * publish - + # * zremrangebyrank + # * zunionstore def flushdb @data = ExpiringHash.new - "OK" end def flushall flushdb end @@ -110,11 +158,10 @@ def bgsave ; end def bgreriteaof ; end def get(key) - data_type_check(key, String) @data[key] end def getbit(key, offset) return unless @data[key] @@ -126,18 +173,17 @@ @data[key][start..ending] end alias :substr :getrange def getset(key, value) - data_type_check(key, String) - @data[key].tap do - set(key, value) - end + old_value = @data[key] + @data[key] = value + return old_value end def mget(*keys) - raise RuntimeError, "ERR wrong number of arguments for 'mget' command" if keys.empty? + raise ArgumentError, "wrong number of arguments for 'mget' command" if keys.empty? @data.values_at(*keys) end def append(key, value) @data[key] = (@data[key] || "") @@ -149,11 +195,11 @@ @data[key].size end def hgetall(key) data_type_check(key, Hash) - @data[key] || {} + @data[key].to_a.flatten || {} end def hget(key, field) data_type_check(key, Hash) @data[key] && @data[key][field.to_s] @@ -206,11 +252,11 @@ @data[key].size end def lrange(key, startidx, endidx) data_type_check(key, Array) - @data[key] && @data[key][startidx..endidx] || [] + @data[key] && @data[key][startidx..endidx] end def ltrim(key, start, stop) data_type_check(key, Array) return unless @data[key] @@ -234,11 +280,11 @@ end def lset(key, index, value) data_type_check(key, Array) return unless @data[key] - raise RuntimeError, "ERR index out of range" if index >= @data[key].size + raise RuntimeError if index >= @data[key].size @data[key][index] = value end def lrem(key, count, value) data_type_check(key, Array) @@ -259,11 +305,11 @@ end def rpush(key, value) data_type_check(key, Array) @data[key] ||= [] - @data[key].push(value.to_s) + @data[key].push(value) @data[key].size end def rpushx(key, value) data_type_check(key, Array) @@ -272,11 +318,11 @@ end def lpush(key, value) data_type_check(key, Array) @data[key] ||= [] - @data[key].unshift(value.to_s) + @data[key].unshift(value) @data[key].size end def lpushx(key, value) data_type_check(key, Array) @@ -290,13 +336,12 @@ @data[key].pop end def rpoplpush(key1, key2) data_type_check(key1, Array) - rpop(key1).tap do |elem| - lpush(key2, elem) - end + elem = rpop(key1) + lpush(key2, elem) end def lpop(key) data_type_check(key, Array) return unless @data[key] @@ -404,11 +449,11 @@ def del(*keys) old_count = @data.keys.size keys.flatten.each do |key| @data.delete(key) end - old_count - @data.keys.size + deleted_count = old_count - @data.keys.size end def setnx(key, value) if exists(key) false @@ -476,25 +521,22 @@ return false if @data[key] && @data[key][field] hset(key, field, value) end def hmset(key, *fields) - # mapped_hmset gives us [[:k1, "v1", :k2, "v2"]] for `fields`. Fix that. - fields = fields[0] if fields.size == 1 && fields[0].is_a?(Array) - fields = fields[0] if mapped_param?(fields) - raise RuntimeError, "ERR wrong number of arguments for HMSET" if fields.size > 2 && fields.size.odd? - raise RuntimeError, "ERR wrong number of arguments for 'hmset' command" if fields.empty? || fields.size.odd? + raise ArgumentError, "wrong number of arguments for 'hmset' command" if fields.empty? || fields.size.odd? data_type_check(key, Hash) @data[key] ||= {} fields.each_slice(2) do |field| @data[key][field[0].to_s] = field[1].to_s end end def hmget(key, *fields) - raise RuntimeError, "ERR wrong number of arguments for 'hmget' command" if fields.empty? + raise ArgumentError, "wrong number of arguments for 'hmget' command" if fields.empty? data_type_check(key, Hash) + values = [] fields.map do |field| field = field.to_s if @data[key] @data[key][field] else @@ -568,54 +610,54 @@ s = @data[key][offset,value.size] @data[key][s] = value end def mset(*pairs) - # Handle pairs for mapped_mset command - pairs = pairs[0] if mapped_param?(pairs) pairs.each_slice(2) do |pair| @data[pair[0].to_s] = pair[1].to_s end "OK" end def msetnx(*pairs) - # Handle pairs for mapped_mset command - pairs = pairs[0] if mapped_param?(pairs) keys = [] pairs.each_with_index{|item, index| keys << item.to_s if index % 2 == 0} - return false if keys.any? {|key| @data.key?(key) } + return if keys.any?{|key| @data.key?(key) } mset(*pairs) true end def sort(key) # TODO: Implement end def incr(key) - @data.merge!({ key => (@data[key].to_i + 1).to_s || "1"}) + @data[key] = (@data[key] || "0") + @data[key] = (@data[key].to_i + 1).to_s @data[key].to_i end def incrby(key, by) - @data.merge!({ key => (@data[key].to_i + by.to_i).to_s || by }) + @data[key] = (@data[key] || "0") + @data[key] = (@data[key].to_i + by.to_i).to_s @data[key].to_i end def decr(key) - @data.merge!({ key => (@data[key].to_i - 1).to_s || "-1"}) + @data[key] = (@data[key] || "0") + @data[key] = (@data[key].to_i - 1).to_s @data[key].to_i end def decrby(key, by) - @data.merge!({ key => ((@data[key].to_i - by.to_i) || (by.to_i * -1)).to_s }) + @data[key] = (@data[key] || "0") + @data[key] = (@data[key].to_i - by.to_i).to_s @data[key].to_i end def type(key) - case @data[key] + case value = @data[key] when nil then "none" when String then "string" when Hash then "hash" when Array then "list" when ::Set then "set" @@ -650,11 +692,10 @@ def zadd(key, score, value) data_type_check(key, ZSet) @data[key] ||= ZSet.new exists = @data[key].key?(value.to_s) - score = "inf" if score == "+inf" @data[key][value.to_s] = score !exists end def zrem(key, value) @@ -670,12 +711,11 @@ @data[key] ? @data[key].size : 0 end def zscore(key, value) data_type_check(key, ZSet) - result = @data[key] && @data[key][value.to_s] - result.to_s if result + @data[key] && @data[key][value.to_s].to_s end def zcount(key, min, max) data_type_check(key, ZSet) return 0 unless @data[key] @@ -684,16 +724,11 @@ def zincrby(key, num, value) data_type_check(key, ZSet) @data[key] ||= ZSet.new @data[key][value.to_s] ||= 0 - if %w(+inf -inf).include?(num) - num = "inf" if num == "+inf" - @data[key][value.to_s] = num - elsif ! %w(+inf -inf).include?(@data[key][value.to_s]) - @data[key][value.to_s] += num - end + @data[key][value.to_s] += num @data[key][value.to_s].to_s end def zrank(key, value) data_type_check(key, ZSet) @@ -707,21 +742,15 @@ def zrange(key, start, stop, with_scores = nil) data_type_check(key, ZSet) return [] unless @data[key] - # Sort by score, or if scores are equal, key alphanum - results = @data[key].sort do |(k1, v1), (k2, v2)| - if v1 == v2 - k1 <=> k2 - else - v1 <=> v2 - end - end - # Select just the keys unless we want scores - results = results.map(&:first) unless with_scores - results[start..stop].flatten.map(&:to_s) + if with_scores + @data[key].sort_by {|_,v| v } + else + @data[key].keys.sort_by {|k| @data[key][k] } + end[start..stop].flatten.map(&:to_s) end def zrevrange(key, start, stop, with_scores = nil) data_type_check(key, ZSet) return [] unless @data[key] @@ -774,27 +803,35 @@ range = zrange_select_by_score(key, min, max) range.each {|k,_| @data[key].delete(k) } range.size end - def zinterstore(out, *args) + def zinterstore(out, _, *keys) data_type_check(out, ZSet) - args_handler = SortedSetArgumentHandler.new(args) - @data[out] = SortedSetIntersectStore.new(args_handler, @data).call - @data[out].size - end - def zunionstore(out, *args) - data_type_check(out, ZSet) - args_handler = SortedSetArgumentHandler.new(args) - @data[out] = SortedSetUnionStore.new(args_handler, @data).call + hashes = keys.map do |src| + case @data[src] + when ::Set + Hash[@data[src].zip([0] * @data[src].size)] + when Hash + @data[src] + else + {} + end + end + + @data[out] = ZSet.new + values = hashes.inject([]) {|r, h| r.empty? ? h.keys : r & h.keys } + values.each do |value| + @data[out][value] = hashes.inject(0) {|n, h| n + h[value].to_i } + end + @data[out].size end def zremrangebyrank(key, start, stop) - sorted_elements = @data[key].sort { |(_, r_a), (_, r_b)| r_a <=> r_b } - start = sorted_elements.length if start > sorted_elements.length + sorted_elements = @data[key].sort { |(v_a, r_a), (v_b, r_b)| r_a <=> r_b } elements_to_delete = sorted_elements[start..stop] elements_to_delete.each { |elem, rank| @data[key].delete(elem) } elements_to_delete.size end @@ -808,12 +845,11 @@ del(key) if @data[key] && @data[key].empty? end def data_type_check(key, klass) if @data[key] && !@data[key].is_a?(klass) - warn "Operation against a key holding the wrong kind of value: Expected #{klass} at #{key}." - raise RuntimeError.new("ERR Operation against a key holding the wrong kind of value") + fail "Operation against a key holding the wrong kind of value: Expected #{klass} at #{key}." end end def get_limit(opts, vals) index = opts.index('LIMIT') @@ -824,13 +860,9 @@ count = opts[index + 2] count = vals.size if count < 0 [offset, count] end - end - - def mapped_param? param - param.size == 1 && param[0].is_a?(Array) end end end end