lib/redis/connection/memory.rb in fakeredis-0.3.2 vs lib/redis/connection/memory.rb in fakeredis-0.3.3
- old
+ new
@@ -1,71 +1,18 @@
require 'set'
require 'redis/connection/registry'
require 'redis/connection/command_helper'
+require "fakeredis/expiring_hash"
+require "fakeredis/sorted_set_argument_handler"
+require "fakeredis/sorted_set_store"
+require "fakeredis/zset"
class Redis
module Connection
class Memory
- # Represents a normal hash with some additional expiration information
- # associated with each key
- class ExpiringHash < Hash
- attr_reader :expires
-
- def initialize(*)
- super
- @expires = {}
- end
-
- def [](key)
- delete(key) if expired?(key)
- super
- end
-
- def []=(key, val)
- expire(key)
- super
- end
-
- def delete(key)
- expire(key)
- super
- end
-
- def expire(key)
- expires.delete(key)
- end
-
- def expired?(key)
- expires.include?(key) && expires[key] < Time.now
- end
-
- def key?(key)
- delete(key) if expired?(key)
- super
- end
-
- def values_at(*keys)
- keys.each {|key| delete(key) if expired?(key)}
- super
- end
-
- def keys
- super.select do |key|
- if expired?(key)
- delete(key)
- false
- else
- true
- end
- end
- end
- end
-
- class ZSet < Hash
- end
-
include Redis::Connection::CommandHelper
+ include FakeRedis
def initialize
@data = ExpiringHash.new
@connected = false
@replies = []
@@ -91,21 +38,25 @@
def timeout=(usecs)
end
def write(command)
- method = command.shift
- reply = send(method, *command)
+ meffod = command.shift
+ if respond_to?(meffod)
+ reply = send(meffod, *command)
+ else
+ raise RuntimeError, "ERR unknown command '#{meffod}'"
+ end
if reply == true
reply = 1
elsif reply == false
reply = 0
end
@replies << reply
- @buffer << reply if @buffer && method != :multi
+ @buffer << reply if @buffer && meffod != :multi
nil
end
def read
@replies.shift
@@ -115,17 +66,18 @@
# * blpop
# * brpop
# * brpoplpush
# * discard
# * move
+ # * sort
# * subscribe
# * psubscribe
# * publish
- # * zremrangebyrank
- # * zunionstore
+
def flushdb
@data = ExpiringHash.new
+ "OK"
end
def flushall
flushdb
end
@@ -158,10 +110,11 @@
def bgsave ; end
def bgreriteaof ; end
def get(key)
+ data_type_check(key, String)
@data[key]
end
def getbit(key, offset)
return unless @data[key]
@@ -173,17 +126,18 @@
@data[key][start..ending]
end
alias :substr :getrange
def getset(key, value)
- old_value = @data[key]
- @data[key] = value
- return old_value
+ data_type_check(key, String)
+ @data[key].tap do
+ set(key, value)
+ end
end
def mget(*keys)
- raise ArgumentError, "wrong number of arguments for 'mget' command" if keys.empty?
+ raise RuntimeError, "ERR wrong number of arguments for 'mget' command" if keys.empty?
@data.values_at(*keys)
end
def append(key, value)
@data[key] = (@data[key] || "")
@@ -252,11 +206,11 @@
@data[key].size
end
def lrange(key, startidx, endidx)
data_type_check(key, Array)
- @data[key] && @data[key][startidx..endidx]
+ @data[key] && @data[key][startidx..endidx] || []
end
def ltrim(key, start, stop)
data_type_check(key, Array)
return unless @data[key]
@@ -280,11 +234,11 @@
end
def lset(key, index, value)
data_type_check(key, Array)
return unless @data[key]
- raise RuntimeError if index >= @data[key].size
+ raise RuntimeError, "ERR index out of range" if index >= @data[key].size
@data[key][index] = value
end
def lrem(key, count, value)
data_type_check(key, Array)
@@ -305,11 +259,11 @@
end
def rpush(key, value)
data_type_check(key, Array)
@data[key] ||= []
- @data[key].push(value)
+ @data[key].push(value.to_s)
@data[key].size
end
def rpushx(key, value)
data_type_check(key, Array)
@@ -318,11 +272,11 @@
end
def lpush(key, value)
data_type_check(key, Array)
@data[key] ||= []
- @data[key].unshift(value)
+ @data[key].unshift(value.to_s)
@data[key].size
end
def lpushx(key, value)
data_type_check(key, Array)
@@ -336,12 +290,13 @@
@data[key].pop
end
def rpoplpush(key1, key2)
data_type_check(key1, Array)
- elem = rpop(key1)
- lpush(key2, elem)
+ rpop(key1).tap do |elem|
+ lpush(key2, elem)
+ end
end
def lpop(key)
data_type_check(key, Array)
return unless @data[key]
@@ -449,11 +404,11 @@
def del(*keys)
old_count = @data.keys.size
keys.flatten.each do |key|
@data.delete(key)
end
- deleted_count = old_count - @data.keys.size
+ old_count - @data.keys.size
end
def setnx(key, value)
if exists(key)
false
@@ -521,22 +476,25 @@
return false if @data[key] && @data[key][field]
hset(key, field, value)
end
def hmset(key, *fields)
- raise ArgumentError, "wrong number of arguments for 'hmset' command" if fields.empty? || fields.size.odd?
+ # mapped_hmset gives us [[:k1, "v1", :k2, "v2"]] for `fields`. Fix that.
+ fields = fields[0] if fields.size == 1 && fields[0].is_a?(Array)
+ fields = fields[0] if mapped_param?(fields)
+ raise RuntimeError, "ERR wrong number of arguments for HMSET" if fields.size > 2 && fields.size.odd?
+ raise RuntimeError, "ERR wrong number of arguments for 'hmset' command" if fields.empty? || fields.size.odd?
data_type_check(key, Hash)
@data[key] ||= {}
fields.each_slice(2) do |field|
@data[key][field[0].to_s] = field[1].to_s
end
end
def hmget(key, *fields)
- raise ArgumentError, "wrong number of arguments for 'hmget' command" if fields.empty?
+ raise RuntimeError, "ERR wrong number of arguments for 'hmget' command" if fields.empty?
data_type_check(key, Hash)
- values = []
fields.map do |field|
field = field.to_s
if @data[key]
@data[key][field]
else
@@ -610,54 +568,54 @@
s = @data[key][offset,value.size]
@data[key][s] = value
end
def mset(*pairs)
+ # Handle pairs for mapped_mset command
+ pairs = pairs[0] if mapped_param?(pairs)
pairs.each_slice(2) do |pair|
@data[pair[0].to_s] = pair[1].to_s
end
"OK"
end
def msetnx(*pairs)
+ # Handle pairs for mapped_mset command
+ pairs = pairs[0] if mapped_param?(pairs)
keys = []
pairs.each_with_index{|item, index| keys << item.to_s if index % 2 == 0}
- return if keys.any?{|key| @data.key?(key) }
+ return false if keys.any? {|key| @data.key?(key) }
mset(*pairs)
true
end
def sort(key)
# TODO: Implement
end
def incr(key)
- @data[key] = (@data[key] || "0")
- @data[key] = (@data[key].to_i + 1).to_s
+ @data.merge!({ key => (@data[key].to_i + 1).to_s || "1"})
@data[key].to_i
end
def incrby(key, by)
- @data[key] = (@data[key] || "0")
- @data[key] = (@data[key].to_i + by.to_i).to_s
+ @data.merge!({ key => (@data[key].to_i + by.to_i).to_s || by })
@data[key].to_i
end
def decr(key)
- @data[key] = (@data[key] || "0")
- @data[key] = (@data[key].to_i - 1).to_s
+ @data.merge!({ key => (@data[key].to_i - 1).to_s || "-1"})
@data[key].to_i
end
def decrby(key, by)
- @data[key] = (@data[key] || "0")
- @data[key] = (@data[key].to_i - by.to_i).to_s
+ @data.merge!({ key => ((@data[key].to_i - by.to_i) || (by.to_i * -1)).to_s })
@data[key].to_i
end
def type(key)
- case value = @data[key]
+ case @data[key]
when nil then "none"
when String then "string"
when Hash then "hash"
when Array then "list"
when ::Set then "set"
@@ -692,10 +650,11 @@
def zadd(key, score, value)
data_type_check(key, ZSet)
@data[key] ||= ZSet.new
exists = @data[key].key?(value.to_s)
+ score = "inf" if score == "+inf"
@data[key][value.to_s] = score
!exists
end
def zrem(key, value)
@@ -711,11 +670,12 @@
@data[key] ? @data[key].size : 0
end
def zscore(key, value)
data_type_check(key, ZSet)
- @data[key] && @data[key][value.to_s].to_s
+ result = @data[key] && @data[key][value.to_s]
+ result.to_s if result
end
def zcount(key, min, max)
data_type_check(key, ZSet)
return 0 unless @data[key]
@@ -724,11 +684,16 @@
def zincrby(key, num, value)
data_type_check(key, ZSet)
@data[key] ||= ZSet.new
@data[key][value.to_s] ||= 0
- @data[key][value.to_s] += num
+ if %w(+inf -inf).include?(num)
+ num = "inf" if num == "+inf"
+ @data[key][value.to_s] = num
+ elsif ! %w(+inf -inf).include?(@data[key][value.to_s])
+ @data[key][value.to_s] += num
+ end
@data[key][value.to_s].to_s
end
def zrank(key, value)
data_type_check(key, ZSet)
@@ -742,15 +707,21 @@
def zrange(key, start, stop, with_scores = nil)
data_type_check(key, ZSet)
return [] unless @data[key]
- if with_scores
- @data[key].sort_by {|_,v| v }
- else
- @data[key].keys.sort_by {|k| @data[key][k] }
- end[start..stop].flatten.map(&:to_s)
+ # Sort by score, or if scores are equal, key alphanum
+ results = @data[key].sort do |(k1, v1), (k2, v2)|
+ if v1 == v2
+ k1 <=> k2
+ else
+ v1 <=> v2
+ end
+ end
+ # Select just the keys unless we want scores
+ results = results.map(&:first) unless with_scores
+ results[start..stop].flatten.map(&:to_s)
end
def zrevrange(key, start, stop, with_scores = nil)
data_type_check(key, ZSet)
return [] unless @data[key]
@@ -803,35 +774,27 @@
range = zrange_select_by_score(key, min, max)
range.each {|k,_| @data[key].delete(k) }
range.size
end
- def zinterstore(out, _, *keys)
+ def zinterstore(out, *args)
data_type_check(out, ZSet)
+ args_handler = SortedSetArgumentHandler.new(args)
+ @data[out] = SortedSetIntersectStore.new(args_handler, @data).call
+ @data[out].size
+ end
- hashes = keys.map do |src|
- case @data[src]
- when ::Set
- Hash[@data[src].zip([0] * @data[src].size)]
- when Hash
- @data[src]
- else
- {}
- end
- end
-
- @data[out] = ZSet.new
- values = hashes.inject([]) {|r, h| r.empty? ? h.keys : r & h.keys }
- values.each do |value|
- @data[out][value] = hashes.inject(0) {|n, h| n + h[value].to_i }
- end
-
+ def zunionstore(out, *args)
+ data_type_check(out, ZSet)
+ args_handler = SortedSetArgumentHandler.new(args)
+ @data[out] = SortedSetUnionStore.new(args_handler, @data).call
@data[out].size
end
def zremrangebyrank(key, start, stop)
- sorted_elements = @data[key].sort { |(v_a, r_a), (v_b, r_b)| r_a <=> r_b }
+ sorted_elements = @data[key].sort { |(_, r_a), (_, r_b)| r_a <=> r_b }
+ start = sorted_elements.length if start > sorted_elements.length
elements_to_delete = sorted_elements[start..stop]
elements_to_delete.each { |elem, rank| @data[key].delete(elem) }
elements_to_delete.size
end
@@ -845,11 +808,12 @@
del(key) if @data[key] && @data[key].empty?
end
def data_type_check(key, klass)
if @data[key] && !@data[key].is_a?(klass)
- fail "Operation against a key holding the wrong kind of value: Expected #{klass} at #{key}."
+ warn "Operation against a key holding the wrong kind of value: Expected #{klass} at #{key}."
+ raise RuntimeError.new("ERR Operation against a key holding the wrong kind of value")
end
end
def get_limit(opts, vals)
index = opts.index('LIMIT')
@@ -860,9 +824,13 @@
count = opts[index + 2]
count = vals.size if count < 0
[offset, count]
end
+ end
+
+ def mapped_param? param
+ param.size == 1 && param[0].is_a?(Array)
end
end
end
end