From 04b19266cf92712c820c1ee3a79c580220fdfa52 Mon Sep 17 00:00:00 2001 From: Bruce Merry Date: Mon, 15 Nov 2021 10:20:05 +0200 Subject: [PATCH 01/20] Prepare a 1.7.0 release --- README.rst | 9 ++++++++- fakeredis/__init__.py | 2 +- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/README.rst b/README.rst index 9932745..a95bbb0 100644 --- a/README.rst +++ b/README.rst @@ -95,7 +95,7 @@ your error handling. Simply set the connected attribute of the server to Fakeredis implements the same interface as `redis-py`_, the popular redis client for python, and models the responses -of redis 6.0 (although most new feature in 6.0 are not supported). +of redis 6.2 (although most new features are not supported). Support for aioredis ==================== @@ -452,6 +452,13 @@ they have all been tagged as 'slow' so you can skip them by running:: Revision history ================ +1.7.0 +----- +- `#310 `_ Fix DeprecationWarning for sampling from a set +- `#315 `_ Improved support for constructor arguments +- `#316 `_ Support redis-py 4, + and change some corner-case behaviours to match Redis 6.2.6. + 1.6.1 ----- - `#305 `_ Some packaging modernisation diff --git a/fakeredis/__init__.py b/fakeredis/__init__.py index 180d732..cff7eef 100644 --- a/fakeredis/__init__.py +++ b/fakeredis/__init__.py @@ -1,4 +1,4 @@ from ._server import FakeServer, FakeRedis, FakeStrictRedis, FakeConnection # noqa: F401 -__version__ = '1.6.1' +__version__ = '1.7.0' From da9a71d64f02d950630b7a74c6cb23b8df5e2625 Mon Sep 17 00:00:00 2001 From: Bruce Merry Date: Mon, 15 Nov 2021 14:48:31 +0200 Subject: [PATCH 02/20] Fix from_url for redis-py 4.0.0 Between 4.0.0rc2 and the final release, something changed that caused one of the from_url tests to attempt to use the connection pool before it was fully created. This works around that by avoiding calling into Redis.from_url. --- fakeredis/_server.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/fakeredis/_server.py b/fakeredis/_server.py index 73a6120..4cfb729 100644 --- a/fakeredis/_server.py +++ b/fakeredis/_server.py @@ -2812,15 +2812,14 @@ def from_url(https://github.com/cls,%20*args,%20**kwargs): server = kwargs.pop('server', None) if server is None: server = FakeServer() - self = super().from_url(https://github.com/*args,%20**kwargs) + pool = redis.ConnectionPool.from_url(https://github.com/*args,%20**kwargs) # Now override how it creates connections - pool = self.connection_pool pool.connection_class = FakeConnection pool.connection_kwargs['server'] = server # FakeConnection cannot handle the path kwarg (present when from_url # is called with a unix socket) pool.connection_kwargs.pop('path', None) - return self + return cls(connection_pool=pool) class FakeStrictRedis(FakeRedisMixin, redis.StrictRedis): From fcfb5e882daace231de875beee1b97ec54d43e69 Mon Sep 17 00:00:00 2001 From: Bruce Merry Date: Mon, 15 Nov 2021 14:50:06 +0200 Subject: [PATCH 03/20] Run tests against redis-py 4.0.0 --- .github/workflows/test.yml | 4 ++-- requirements.in | 2 +- requirements.txt | 6 +++++- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index d3c650b..3531db2 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -10,7 +10,7 @@ jobs: fail-fast: false matrix: python-version: ["3.6", "3.7", "3.8", "3.9", "pypy-3.7"] - redis-py: ["3.5.3"] + redis-py: ["4.0.0"] aioredis: ["2.0.0"] include: - python-version: "3.9" @@ -35,7 +35,7 @@ jobs: redis-py: "3.5.3" aioredis: "1.3.1" - python-version: "3.9" - redis-py: "3.5.*" + redis-py: "4.0.*" aioredis: "2.0.0" coverage: yes services: diff --git a/requirements.in b/requirements.in index adc64f1..ae95967 100644 --- a/requirements.in +++ b/requirements.in @@ -7,7 +7,7 @@ pytest pytest-asyncio pytest-cov pytest-mock -redis==3.5.3 # Latest at time of writing +redis==4.0.0 # Latest at time of writing six sortedcontainers diff --git a/requirements.txt b/requirements.txt index ac96be3..cfb96f6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -16,6 +16,8 @@ coverage==5.3 # via # -r requirements.in # pytest-cov +deprecated==1.2.13 + # via redis flake8==3.8.4 # via -r requirements.in hiredis==1.1.0 @@ -52,7 +54,7 @@ pytest-cov==2.10.1 # via -r requirements.in pytest-mock==3.3.1 # via -r requirements.in -redis==3.5.3 +redis==4.0.0 # via -r requirements.in six==1.15.0 # via -r requirements.in @@ -62,5 +64,7 @@ sortedcontainers==2.3.0 # hypothesis toml==0.10.2 # via pytest +wrapt==1.13.3 + # via deprecated zipp==1.2.0 # via -r requirements.in From 173ee2513c7faec1088c6b6909aafe5787168182 Mon Sep 17 00:00:00 2001 From: Bruce Merry Date: Mon, 15 Nov 2021 19:43:33 +0200 Subject: [PATCH 04/20] Add #319 to README --- README.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/README.rst b/README.rst index a95bbb0..2451b03 100644 --- a/README.rst +++ b/README.rst @@ -458,6 +458,7 @@ Revision history - `#315 `_ Improved support for constructor arguments - `#316 `_ Support redis-py 4, and change some corner-case behaviours to match Redis 6.2.6. +- `#319 `_ Add support for GET option to SET 1.6.1 ----- From 964acc7f9053678a480e69c7bf9da3990690ee3d Mon Sep 17 00:00:00 2001 From: Bruce Merry Date: Tue, 16 Nov 2021 14:46:35 +0200 Subject: [PATCH 05/20] Make zrange work with multiple WITHSCORES I don't know why this only just started showing up as a hypothesis failure - it seems like it was already the behaviour in 6.0.10 to allow multiple WITHSCORES modifiers. --- fakeredis/_server.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/fakeredis/_server.py b/fakeredis/_server.py index 3728da0..fdeb3af 100644 --- a/fakeredis/_server.py +++ b/fakeredis/_server.py @@ -2164,14 +2164,17 @@ def zlexcount(self, key, min, max): def _zrange(self, key, start, stop, reverse, *args): zset = key.value - # TODO: does redis allow multiple WITHSCORES? - if len(args) > 1 or (args and not casematch(args[0], b'withscores')): - raise SimpleError(SYNTAX_ERROR_MSG) + withscores = False + for arg in args: + if casematch(arg, b'withscores'): + withscores = True + else: + raise SimpleError(SYNTAX_ERROR_MSG) start, stop = self._fix_range(start, stop, len(zset)) if reverse: start, stop = len(zset) - stop, len(zset) - start items = zset.islice_score(start, stop, reverse) - items = self._apply_withscores(items, bool(args)) + items = self._apply_withscores(items, withscores) return items @command((Key(ZSet), Int, Int), (bytes,)) From 681f5abd326ac9facdbac9f23123c110d2e83628 Mon Sep 17 00:00:00 2001 From: Bruce Merry Date: Mon, 22 Nov 2021 13:33:31 +0200 Subject: [PATCH 06/20] Check for overflow in SET expiry times redis 6.2 introduced checks for expiry times overflowing once converted to UNIX epoch times in milliseconds. Implement the same to allow hypothesis testing to pass. I haven't implemented it for EXPIRE/PEXPIRE because in Redis 6.2.6 they have inconsistent behaviour - see https://github.com/redis/redis/issues/9825. --- fakeredis/_server.py | 8 ++++---- test/test_fakeredis.py | 16 ++++++++++++++++ 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/fakeredis/_server.py b/fakeredis/_server.py index fdeb3af..2839d8d 100644 --- a/fakeredis/_server.py +++ b/fakeredis/_server.py @@ -1507,12 +1507,12 @@ def set_(self, key, value, *args): i += 1 elif casematch(args[i], b'ex') and i + 1 < len(args): ex = Int.decode(args[i + 1]) - if ex <= 0: + if ex <= 0 or (self._db.time + ex) * 1000 >= 2**63: raise SimpleError(INVALID_EXPIRE_MSG.format('set')) i += 2 elif casematch(args[i], b'px') and i + 1 < len(args): px = Int.decode(args[i + 1]) - if px <= 0: + if px <= 0 or self._db.time * 1000 + px >= 2**63: raise SimpleError(INVALID_EXPIRE_MSG.format('set')) i += 2 elif casematch(args[i], b'keepttl'): @@ -1551,7 +1551,7 @@ def set_(self, key, value, *args): @command((Key(), Int, bytes)) def setex(self, key, seconds, value): - if seconds <= 0: + if seconds <= 0 or (self._db.time + seconds) * 1000 >= 2**63: raise SimpleError(INVALID_EXPIRE_MSG.format('setex')) key.value = value key.expireat = self._db.time + seconds @@ -1559,7 +1559,7 @@ def setex(self, key, seconds, value): @command((Key(), Int, bytes)) def psetex(self, key, ms, value): - if ms <= 0: + if ms <= 0 or self._db.time * 1000 + ms >= 2**63: raise SimpleError(INVALID_EXPIRE_MSG.format('psetex')) key.value = value key.expireat = self._db.time + ms / 1000.0 diff --git a/test/test_fakeredis.py b/test/test_fakeredis.py index 240552a..72cd15b 100644 --- a/test/test_fakeredis.py +++ b/test/test_fakeredis.py @@ -709,6 +709,12 @@ def test_setex_using_float(r): r.setex('foo', 1.2, 'bar') +@pytest.mark.min_server('6.2') +def test_setex_overflow(r): + with pytest.raises(ResponseError): + r.setex('foo', 18446744073709561, 'bar') # Overflows long long in ms + + def test_set_ex(r): assert r.set('foo', 'bar', ex=100) is True assert r.get('foo') == b'bar' @@ -719,6 +725,16 @@ def test_set_ex_using_timedelta(r): assert r.get('foo') == b'bar' +def test_set_ex_overflow(r): + with pytest.raises(ResponseError): + r.set('foo', 'bar', ex=18446744073709561) + + +def test_set_px_overflow(r): + with pytest.raises(ResponseError): + r.set('foo', 'bar', px=2**63 - 2) + + def test_set_px(r): assert r.set('foo', 'bar', px=100) is True assert r.get('foo') == b'bar' From 0d73c3b29a0bd0f535bd024da07fb7bd3dbc1ae4 Mon Sep 17 00:00:00 2001 From: Bruce Merry Date: Mon, 22 Nov 2021 13:49:42 +0200 Subject: [PATCH 07/20] Update README for 1.7.0 Indicate that handling corner cases is not limited to a single PR. --- README.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.rst b/README.rst index 2451b03..f91de03 100644 --- a/README.rst +++ b/README.rst @@ -454,10 +454,10 @@ Revision history 1.7.0 ----- +- Change a number of corner-case behaviours to match Redis 6.2.6. - `#310 `_ Fix DeprecationWarning for sampling from a set - `#315 `_ Improved support for constructor arguments -- `#316 `_ Support redis-py 4, - and change some corner-case behaviours to match Redis 6.2.6. +- `#316 `_ Support redis-py 4 - `#319 `_ Add support for GET option to SET 1.6.1 From 2a51f0f15a7d734b840ee784e43b0fe5c50b0232 Mon Sep 17 00:00:00 2001 From: Bruce Merry Date: Mon, 22 Nov 2021 17:05:41 +0200 Subject: [PATCH 08/20] Add some comments to overflow tests --- test/test_fakeredis.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_fakeredis.py b/test/test_fakeredis.py index 72cd15b..bd354f0 100644 --- a/test/test_fakeredis.py +++ b/test/test_fakeredis.py @@ -727,12 +727,12 @@ def test_set_ex_using_timedelta(r): def test_set_ex_overflow(r): with pytest.raises(ResponseError): - r.set('foo', 'bar', ex=18446744073709561) + r.set('foo', 'bar', ex=18446744073709561) # Overflows long long in ms def test_set_px_overflow(r): with pytest.raises(ResponseError): - r.set('foo', 'bar', px=2**63 - 2) + r.set('foo', 'bar', px=2**63 - 2) # Overflows after adding current time def test_set_px(r): From 5b0fb1787d86492384b7d725e50db8c06c8e5fdc Mon Sep 17 00:00:00 2001 From: Bruce Merry Date: Tue, 23 Nov 2021 15:12:46 +0200 Subject: [PATCH 09/20] PERSIST and EXPIRE should invalidate watches Since Redis 6.0.7, PERSIST has invalidated watches, and EXPIRE already did so. Have all changes to the expiry time mark a key as changed. --- fakeredis/_server.py | 1 + test/test_fakeredis.py | 24 ++++++++++++++++++++++++ 2 files changed, 25 insertions(+) diff --git a/fakeredis/_server.py b/fakeredis/_server.py index 2839d8d..873043a 100644 --- a/fakeredis/_server.py +++ b/fakeredis/_server.py @@ -243,6 +243,7 @@ def expireat(self): def expireat(self, value): self._expireat = value self._expireat_modified = True + self._modified = True # Since redis 6.0.7 def get(self, default): return self._value if self else default diff --git a/test/test_fakeredis.py b/test/test_fakeredis.py index bd354f0..81dc123 100644 --- a/test/test_fakeredis.py +++ b/test/test_fakeredis.py @@ -4136,6 +4136,18 @@ def test_expire_should_expire_immediately_with_millisecond_timedelta(r): assert r.expire('bar', 1) is False +def test_watch_expire(r): + """EXPIRE should mark a key as changed for WATCH.""" + r.set('foo', 'bar') + with r.pipeline() as p: + p.watch('foo') + r.expire('foo', 10000) + p.multi() + p.get('foo') + with pytest.raises(redis.exceptions.WatchError): + p.execute() + + @pytest.mark.slow def test_pexpire_should_expire_key(r): r.set('foo', 'bar') @@ -4263,6 +4275,18 @@ def test_persist(r): assert r.persist('foo') == 0 +def test_watch_persist(r): + """PERSIST should mark a variable as changed.""" + r.set('foo', 'bar', ex=10000) + with r.pipeline() as p: + p.watch('foo') + r.persist('foo') + p.multi() + p.get('foo') + with pytest.raises(redis.exceptions.WatchError): + p.execute() + + def test_set_existing_key_persists(r): r.set('foo', 'bar', ex=20) r.set('foo', 'foo') From 69d275d7b94d1fa6e9ff125dbef5f4bcde7cf91f Mon Sep 17 00:00:00 2001 From: Bruce Merry Date: Wed, 24 Nov 2021 07:54:25 +0200 Subject: [PATCH 10/20] Add #323 to changelog --- README.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/README.rst b/README.rst index f91de03..7246661 100644 --- a/README.rst +++ b/README.rst @@ -459,6 +459,7 @@ Revision history - `#315 `_ Improved support for constructor arguments - `#316 `_ Support redis-py 4 - `#319 `_ Add support for GET option to SET +- `#323 `_ PERSIST and EXPIRE should invalidate watches 1.6.1 ----- From 8ef8dc6dacc9baf571d66a25ffbf0fadd7c70f78 Mon Sep 17 00:00:00 2001 From: rotten Date: Tue, 4 Jan 2022 15:51:33 -0500 Subject: [PATCH 11/20] add simple support for Redis 4.1.0 --- build/lib/fakeredis/__init__.py | 4 + build/lib/fakeredis/_aioredis1.py | 181 ++ build/lib/fakeredis/_aioredis2.py | 170 ++ build/lib/fakeredis/_async.py | 51 + build/lib/fakeredis/_server.py | 2850 +++++++++++++++++++++++++++++ build/lib/fakeredis/_zset.py | 87 + build/lib/fakeredis/aioredis.py | 10 + fakeredis/_server.py | 7 +- requirements.in | 2 +- requirements.txt | 4 +- setup.cfg | 2 +- 11 files changed, 3362 insertions(+), 6 deletions(-) create mode 100644 build/lib/fakeredis/__init__.py create mode 100644 build/lib/fakeredis/_aioredis1.py create mode 100644 build/lib/fakeredis/_aioredis2.py create mode 100644 build/lib/fakeredis/_async.py create mode 100644 build/lib/fakeredis/_server.py create mode 100644 build/lib/fakeredis/_zset.py create mode 100644 build/lib/fakeredis/aioredis.py diff --git a/build/lib/fakeredis/__init__.py b/build/lib/fakeredis/__init__.py new file mode 100644 index 0000000..cff7eef --- /dev/null +++ b/build/lib/fakeredis/__init__.py @@ -0,0 +1,4 @@ +from ._server import FakeServer, FakeRedis, FakeStrictRedis, FakeConnection # noqa: F401 + + +__version__ = '1.7.0' diff --git a/build/lib/fakeredis/_aioredis1.py b/build/lib/fakeredis/_aioredis1.py new file mode 100644 index 0000000..7679f2e --- /dev/null +++ b/build/lib/fakeredis/_aioredis1.py @@ -0,0 +1,181 @@ +import asyncio +import sys +import warnings + +import aioredis + +from . import _async, _server + + +class FakeSocket(_async.AsyncFakeSocket): + def _decode_error(self, error): + return aioredis.ReplyError(error.value) + + +class FakeReader: + """Re-implementation of aioredis.stream.StreamReader. + + It does not use a socket, but instead provides a queue that feeds + `readobj`. + """ + + def __init__(self, socket): + self._socket = socket + + def set_parser(self, parser): + pass # No parser needed, we get already-parsed data + + async def readobj(self): + if self._socket.responses is None: + raise asyncio.CancelledError + result = await self._socket.responses.get() + return result + + def at_eof(self): + return self._socket.responses is None + + def feed_obj(self, obj): + self._queue.put_nowait(obj) + + +class FakeWriter: + """Replaces a StreamWriter for an aioredis connection.""" + + def __init__(self, socket): + self.transport = socket # So that aioredis can call writer.transport.close() + + def write(self, data): + self.transport.sendall(data) + + +class FakeConnectionsPool(aioredis.ConnectionsPool): + def __init__(self, server=None, db=None, password=None, encoding=None, + *, minsize, maxsize, ssl=None, parser=None, + create_connection_timeout=None, + connection_cls=None, + loop=None): + super().__init__('fakeredis', + db=db, + password=password, + encoding=encoding, + minsize=minsize, + maxsize=maxsize, + ssl=ssl, + parser=parser, + create_connection_timeout=create_connection_timeout, + connection_cls=connection_cls, + loop=loop) + if server is None: + server = _server.FakeServer() + self._server = server + + def _create_new_connection(self, address): + # TODO: what does address do here? Might just be for sentinel? + return create_connection(self._server, + db=self._db, + password=self._password, + ssl=self._ssl, + encoding=self._encoding, + parser=self._parser_class, + timeout=self._create_connection_timeout, + connection_cls=self._connection_cls, + ) + + +async def create_connection(server=None, *, db=None, password=None, ssl=None, + encoding=None, parser=None, loop=None, + timeout=None, connection_cls=None): + # This is mostly copied from aioredis.connection.create_connection + if timeout is not None and timeout <= 0: + raise ValueError("Timeout has to be None or a number greater than 0") + + if connection_cls: + assert issubclass(connection_cls, aioredis.abc.AbcConnection),\ + "connection_class does not meet the AbcConnection contract" + cls = connection_cls + else: + cls = aioredis.connection.RedisConnection + + if loop is not None and sys.version_info >= (3, 8, 0): + warnings.warn("The loop argument is deprecated", + DeprecationWarning) + + if server is None: + server = _server.FakeServer() + socket = FakeSocket(server) + reader = FakeReader(socket) + writer = FakeWriter(socket) + conn = cls(reader, writer, encoding=encoding, + address='fakeredis', parser=parser) + + try: + if password is not None: + await conn.auth(password) + if db is not None: + await conn.select(db) + except Exception: + conn.close() + await conn.wait_closed() + raise + return conn + + +async def create_redis(server=None, *, db=None, password=None, ssl=None, + encoding=None, commands_factory=aioredis.Redis, + parser=None, timeout=None, + connection_cls=None, loop=None): + conn = await create_connection(server, db=db, + password=password, + ssl=ssl, + encoding=encoding, + parser=parser, + timeout=timeout, + connection_cls=connection_cls, + loop=loop) + return commands_factory(conn) + + +async def create_pool(server=None, *, db=None, password=None, ssl=None, + encoding=None, minsize=1, maxsize=10, + parser=None, loop=None, create_connection_timeout=None, + pool_cls=None, connection_cls=None): + # Mostly copied from aioredis.pool.create_pool. + if pool_cls: + assert issubclass(pool_cls, aioredis.AbcPool),\ + "pool_class does not meet the AbcPool contract" + cls = pool_cls + else: + cls = FakeConnectionsPool + + pool = cls(server, db, password, encoding, + minsize=minsize, maxsize=maxsize, + ssl=ssl, parser=parser, + create_connection_timeout=create_connection_timeout, + connection_cls=connection_cls, + loop=loop) + try: + await pool._fill_free(override_min=False) + except Exception: + pool.close() + await pool.wait_closed() + raise + return pool + + +async def create_redis_pool(server=None, *, db=None, password=None, ssl=None, + encoding=None, commands_factory=aioredis.Redis, + minsize=1, maxsize=10, parser=None, + timeout=None, pool_cls=None, + connection_cls=None, loop=None): + pool = await create_pool(server, db=db, + password=password, + ssl=ssl, + encoding=encoding, + minsize=minsize, + maxsize=maxsize, + parser=parser, + create_connection_timeout=timeout, + pool_cls=pool_cls, + connection_cls=connection_cls, + loop=loop) + return commands_factory(pool) diff --git a/build/lib/fakeredis/_aioredis2.py b/build/lib/fakeredis/_aioredis2.py new file mode 100644 index 0000000..d07d197 --- /dev/null +++ b/build/lib/fakeredis/_aioredis2.py @@ -0,0 +1,170 @@ +import asyncio +from typing import Union + +import aioredis + +from . import _async, _server + + +class FakeSocket(_async.AsyncFakeSocket): + _connection_error_class = aioredis.ConnectionError + + def _decode_error(self, error): + return aioredis.connection.BaseParser(1).parse_error(error.value) + + +class FakeReader: + pass + + +class FakeWriter: + def __init__(self, socket: FakeSocket) -> None: + self._socket = socket + + def close(self): + self._socket = None + + async def wait_closed(self): + pass + + async def drain(self): + pass + + def writelines(self, data): + for chunk in data: + self._socket.sendall(chunk) + + +class FakeConnection(aioredis.Connection): + def __init__(self, *args, **kwargs): + self._server = kwargs.pop('server') + self._sock = None + super().__init__(*args, **kwargs) + + async def _connect(self): + if not self._server.connected: + raise aioredis.ConnectionError(_server.CONNECTION_ERROR_MSG) + self._sock = FakeSocket(self._server) + self._reader = FakeReader() + self._writer = FakeWriter(self._sock) + + async def disconnect(self): + await super().disconnect() + self._sock = None + + async def can_read(self, timeout: float = 0): + if not self.is_connected: + await self.connect() + if timeout == 0: + return not self._sock.responses.empty() + # asyncio.Queue doesn't have a way to wait for the queue to be + # non-empty without consuming an item, so kludge it with a sleep/poll + # loop. + loop = asyncio.get_event_loop() + start = loop.time() + while True: + if not self._sock.responses.empty(): + return True + await asyncio.sleep(0.01) + now = loop.time() + if timeout is not None and now > start + timeout: + return False + + def _decode(self, response): + if isinstance(response, list): + return [self._decode(item) for item in response] + elif isinstance(response, bytes): + return self.encoder.decode(response) + else: + return response + + async def read_response(self): + if not self._server.connected: + try: + response = self._sock.responses.get_nowait() + except asyncio.QueueEmpty: + raise aioredis.ConnectionError(_server.CONNECTION_ERROR_MSG) + else: + response = await self._sock.responses.get() + if isinstance(response, aioredis.ResponseError): + raise response + return self._decode(response) + + def repr_pieces(self): + pieces = [ + ('server', self._server), + ('db', self.db) + ] + if self.client_name: + pieces.append(('client_name', self.client_name)) + return pieces + + +class FakeRedis(aioredis.Redis): + def __init__( + self, + *, + db: Union[str, int] = 0, + password: str = None, + socket_timeout: float = None, + connection_pool: aioredis.ConnectionPool = None, + encoding: str = "utf-8", + encoding_errors: str = "strict", + decode_responses: bool = False, + retry_on_timeout: bool = False, + max_connections: int = None, + health_check_interval: int = 0, + client_name: str = None, + username: str = None, + server: _server.FakeServer = None, + connected: bool = True, + **kwargs + ): + if not connection_pool: + # Adapted from aioredis + if server is None: + server = _server.FakeServer() + server.connected = connected + connection_kwargs = { + "db": db, + "username": username, + "password": password, + "socket_timeout": socket_timeout, + "encoding": encoding, + "encoding_errors": encoding_errors, + "decode_responses": decode_responses, + "retry_on_timeout": retry_on_timeout, + "max_connections": max_connections, + "health_check_interval": health_check_interval, + "client_name": client_name, + "server": server, + "connection_class": FakeConnection + } + connection_pool = aioredis.ConnectionPool(**connection_kwargs) + super().__init__( + db=db, + password=password, + socket_timeout=socket_timeout, + connection_pool=connection_pool, + encoding=encoding, + encoding_errors=encoding_errors, + decode_responses=decode_responses, + retry_on_timeout=retry_on_timeout, + max_connections=max_connections, + health_check_interval=health_check_interval, + client_name=client_name, + username=username, + **kwargs + ) + + @classmethod + def from_url(https://github.com/cls,%20url:%20str,%20**kwargs): + server = kwargs.pop('server', None) + if server is None: + server = _server.FakeServer() + self = super().from_url(https://github.com/url,%20**kwargs) + # Now override how it creates connections + pool = self.connection_pool + pool.connection_class = FakeConnection + pool.connection_kwargs['server'] = server + return self diff --git a/build/lib/fakeredis/_async.py b/build/lib/fakeredis/_async.py new file mode 100644 index 0000000..ec51d1e --- /dev/null +++ b/build/lib/fakeredis/_async.py @@ -0,0 +1,51 @@ +import asyncio + +import async_timeout + +from . import _server + + +class AsyncFakeSocket(_server.FakeSocket): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.responses = asyncio.Queue() + + def put_response(self, msg): + self.responses.put_nowait(msg) + + async def _async_blocking(self, timeout, func, event, callback): + try: + result = None + with async_timeout.timeout(timeout if timeout else None): + while True: + await event.wait() + event.clear() + # This is a coroutine outside the normal control flow that + # locks the server, so we have to take our own lock. + with self._server.lock: + ret = func(False) + if ret is not None: + result = self._decode_result(ret) + self.put_response(result) + break + except asyncio.TimeoutError: + result = None + finally: + with self._server.lock: + self._db.remove_change_callback(callback) + self.put_response(result) + self.resume() + + def _blocking(self, timeout, func): + loop = asyncio.get_event_loop() + ret = func(True) + if ret is not None or self._in_transaction: + return ret + event = asyncio.Event() + + def callback(): + loop.call_soon_threadsafe(event.set) + self._db.add_change_callback(callback) + self.pause() + loop.create_task(self._async_blocking(timeout, func, event, callback)) + return _server.NoResponse() diff --git a/build/lib/fakeredis/_server.py b/build/lib/fakeredis/_server.py new file mode 100644 index 0000000..dbeb438 --- /dev/null +++ b/build/lib/fakeredis/_server.py @@ -0,0 +1,2850 @@ +import functools +import hashlib +import inspect +import itertools +import logging +import math +import pickle +import queue +import random +import re +import threading +import time +import warnings +import weakref +from collections import defaultdict +from collections.abc import MutableMapping + +import redis +import six + +from ._zset import ZSet + +LOGGER = logging.getLogger('fakeredis') +REDIS_LOG_LEVELS = { + b'LOG_DEBUG': 0, + b'LOG_VERBOSE': 1, + b'LOG_NOTICE': 2, + b'LOG_WARNING': 3 +} +REDIS_LOG_LEVELS_TO_LOGGING = { + 0: logging.DEBUG, + 1: logging.INFO, + 2: logging.INFO, + 3: logging.WARNING +} + +MAX_STRING_SIZE = 512 * 1024 * 1024 + +INVALID_EXPIRE_MSG = "ERR invalid expire time in {}" +WRONGTYPE_MSG = \ + "WRONGTYPE Operation against a key holding the wrong kind of value" +SYNTAX_ERROR_MSG = "ERR syntax error" +INVALID_INT_MSG = "ERR value is not an integer or out of range" +INVALID_FLOAT_MSG = "ERR value is not a valid float" +INVALID_OFFSET_MSG = "ERR offset is out of range" +INVALID_BIT_OFFSET_MSG = "ERR bit offset is not an integer or out of range" +INVALID_BIT_VALUE_MSG = "ERR bit is not an integer or out of range" +INVALID_DB_MSG = "ERR DB index is out of range" +INVALID_MIN_MAX_FLOAT_MSG = "ERR min or max is not a float" +INVALID_MIN_MAX_STR_MSG = "ERR min or max not a valid string range item" +STRING_OVERFLOW_MSG = "ERR string exceeds maximum allowed size (512MB)" +OVERFLOW_MSG = "ERR increment or decrement would overflow" +NONFINITE_MSG = "ERR increment would produce NaN or Infinity" +SCORE_NAN_MSG = "ERR resulting score is not a number (NaN)" +INVALID_SORT_FLOAT_MSG = "ERR One or more scores can't be converted into double" +SRC_DST_SAME_MSG = "ERR source and destination objects are the same" +NO_KEY_MSG = "ERR no such key" +INDEX_ERROR_MSG = "ERR index out of range" +ZADD_NX_XX_ERROR_MSG = "ERR ZADD allows either 'nx' or 'xx', not both" +ZADD_INCR_LEN_ERROR_MSG = "ERR INCR option supports a single increment-element pair" +ZUNIONSTORE_KEYS_MSG = "ERR at least 1 input key is needed for ZUNIONSTORE/ZINTERSTORE" +WRONG_ARGS_MSG = "ERR wrong number of arguments for '{}' command" +UNKNOWN_COMMAND_MSG = "ERR unknown command '{}'" +EXECABORT_MSG = "EXECABORT Transaction discarded because of previous errors." +MULTI_NESTED_MSG = "ERR MULTI calls can not be nested" +WITHOUT_MULTI_MSG = "ERR {0} without MULTI" +WATCH_INSIDE_MULTI_MSG = "ERR WATCH inside MULTI is not allowed" +NEGATIVE_KEYS_MSG = "ERR Number of keys can't be negative" +TOO_MANY_KEYS_MSG = "ERR Number of keys can't be greater than number of args" +TIMEOUT_NEGATIVE_MSG = "ERR timeout is negative" +NO_MATCHING_SCRIPT_MSG = "NOSCRIPT No matching script. Please use EVAL." +GLOBAL_VARIABLE_MSG = "ERR Script attempted to set global variables: {}" +COMMAND_IN_SCRIPT_MSG = "ERR This Redis command is not allowed from scripts" +BAD_SUBCOMMAND_MSG = "ERR Unknown {} subcommand or wrong # of args." +BAD_COMMAND_IN_PUBSUB_MSG = \ + "ERR only (P)SUBSCRIBE / (P)UNSUBSCRIBE / PING / QUIT allowed in this context" +CONNECTION_ERROR_MSG = "FakeRedis is emulating a connection error." +REQUIRES_MORE_ARGS_MSG = "ERR {} requires {} arguments or more." +LOG_INVALID_DEBUG_LEVEL_MSG = "ERR Invalid debug level." +LUA_COMMAND_ARG_MSG = "ERR Lua redis() command arguments must be strings or integers" +LUA_WRONG_NUMBER_ARGS_MSG = "ERR wrong number or type of arguments" +SCRIPT_ERROR_MSG = "ERR Error running script (call to f_{}): @user_script:?: {}" +RESTORE_KEY_EXISTS = "BUSYKEY Target key name already exists." +RESTORE_INVALID_CHECKSUM_MSG = "ERR DUMP payload version or checksum are wrong" +RESTORE_INVALID_TTL_MSG = "ERR Invalid TTL value, must be >= 0" + +FLAG_NO_SCRIPT = 's' # Command not allowed in scripts + +# This needs to be grabbed early to avoid breaking tests that mock redis.Redis. +_ORIG_SIG = inspect.signature(redis.Redis) + + +class SimpleString: + def __init__(self, value): + assert isinstance(value, bytes) + self.value = value + + +class SimpleError(Exception): + """Exception that will be turned into a frontend-specific exception.""" + + def __init__(self, value): + assert isinstance(value, str) + self.value = value + + +class NoResponse: + """Returned by pub/sub commands to indicate that no response should be returned""" + pass + + +OK = SimpleString(b'OK') +QUEUED = SimpleString(b'QUEUED') +PONG = SimpleString(b'PONG') +BGSAVE_STARTED = SimpleString(b'Background saving started') + + +def null_terminate(s): + # Redis uses C functions on some strings, which means they stop at the + # first NULL. + if b'\0' in s: + return s[:s.find(b'\0')] + return s + + +def casenorm(s): + return null_terminate(s).lower() + + +def casematch(a, b): + return casenorm(a) == casenorm(b) + + +def compile_pattern(pattern): + """Compile a glob pattern (e.g. for keys) to a bytes regex. + + fnmatch.fnmatchcase doesn't work for this, because it uses different + escaping rules to redis, uses ! instead of ^ to negate a character set, + and handles invalid cases (such as a [ without a ]) differently. This + implementation was written by studying the redis implementation. + """ + # It's easier to work with text than bytes, because indexing bytes + # doesn't behave the same in Python 3. Latin-1 will round-trip safely. + pattern = pattern.decode('latin-1') + parts = ['^'] + i = 0 + L = len(pattern) + while i < L: + c = pattern[i] + i += 1 + if c == '?': + parts.append('.') + elif c == '*': + parts.append('.*') + elif c == '\\': + if i == L: + i -= 1 + parts.append(re.escape(pattern[i])) + i += 1 + elif c == '[': + parts.append('[') + if i < L and pattern[i] == '^': + i += 1 + parts.append('^') + parts_len = len(parts) # To detect if anything was added + while i < L: + if pattern[i] == '\\' and i + 1 < L: + i += 1 + parts.append(re.escape(pattern[i])) + elif pattern[i] == ']': + i += 1 + break + elif i + 2 < L and pattern[i + 1] == '-': + start = pattern[i] + end = pattern[i + 2] + if start > end: + start, end = end, start + parts.append(re.escape(start) + '-' + re.escape(end)) + i += 2 + else: + parts.append(re.escape(pattern[i])) + i += 1 + if len(parts) == parts_len: + if parts[-1] == '[': + # Empty group - will never match + parts[-1] = '(?:$.)' + else: + # Negated empty group - matches any character + assert parts[-1] == '^' + parts.pop() + parts[-1] = '.' + else: + parts.append(']') + else: + parts.append(re.escape(c)) + parts.append('\\Z') + regex = ''.join(parts).encode('latin-1') + return re.compile(regex, re.S) + + +class Item: + """An item stored in the database""" + + __slots__ = ['value', 'expireat'] + + def __init__(self, value): + self.value = value + self.expireat = None + + +class CommandItem: + """An item referenced by a command. + + It wraps an Item but has extra fields to manage updates and notifications. + """ + def __init__(self, key, db, item=None, default=None): + if item is None: + self._value = default + self._expireat = None + else: + self._value = item.value + self._expireat = item.expireat + self.key = key + self.db = db + self._modified = False + self._expireat_modified = False + + @property + def value(self): + return self._value + + @value.setter + def value(self, new_value): + self._value = new_value + self._modified = True + self.expireat = None + + @property + def expireat(self): + return self._expireat + + @expireat.setter + def expireat(self, value): + self._expireat = value + self._expireat_modified = True + self._modified = True # Since redis 6.0.7 + + def get(self, default): + return self._value if self else default + + def update(self, new_value): + self._value = new_value + self._modified = True + + def updated(self): + self._modified = True + + def writeback(self): + if self._modified: + self.db.notify_watch(self.key) + if not isinstance(self.value, bytes) and not self.value: + self.db.pop(self.key, None) + return + else: + item = self.db.setdefault(self.key, Item(None)) + item.value = self.value + item.expireat = self.expireat + elif self._expireat_modified and self.key in self.db: + self.db[self.key].expireat = self.expireat + + def __bool__(self): + return bool(self._value) or isinstance(self._value, bytes) + + __nonzero__ = __bool__ # For Python 2 + + +class Database(MutableMapping): + def __init__(self, lock, *args, **kwargs): + self._dict = dict(*args, **kwargs) + self.time = 0.0 + self._watches = defaultdict(weakref.WeakSet) # key to set of connections + self.condition = threading.Condition(lock) + self._change_callbacks = set() + + def swap(self, other): + self._dict, other._dict = other._dict, self._dict + self.time, other.time = other.time, self.time + + def notify_watch(self, key): + for sock in self._watches.get(key, set()): + sock.notify_watch() + self.condition.notify_all() + for callback in self._change_callbacks: + callback() + + def add_watch(self, key, sock): + self._watches[key].add(sock) + + def remove_watch(self, key, sock): + watches = self._watches[key] + watches.discard(sock) + if not watches: + del self._watches[key] + + def add_change_callback(self, callback): + self._change_callbacks.add(callback) + + def remove_change_callback(self, callback): + self._change_callbacks.remove(callback) + + def clear(self): + for key in self: + self.notify_watch(key) + self._dict.clear() + + def expired(self, item): + return item.expireat is not None and item.expireat < self.time + + def _remove_expired(self): + for key in list(self._dict): + item = self._dict[key] + if self.expired(item): + del self._dict[key] + + def __getitem__(self, key): + item = self._dict[key] + if self.expired(item): + del self._dict[key] + raise KeyError(key) + return item + + def __setitem__(self, key, value): + self._dict[key] = value + + def __delitem__(self, key): + del self._dict[key] + + def __iter__(self): + self._remove_expired() + return iter(self._dict) + + def __len__(self): + self._remove_expired() + return len(self._dict) + + def __hash__(self): + return hash(super(object, self)) + + def __eq__(self, other): + return super(object, self) == other + + +class Hash(dict): + redis_type = b'hash' + + +class Int: + """Argument converter for 64-bit signed integers""" + + DECODE_ERROR = INVALID_INT_MSG + ENCODE_ERROR = OVERFLOW_MSG + MIN_VALUE = -2**63 + MAX_VALUE = 2**63 - 1 + + @classmethod + def valid(cls, value): + return cls.MIN_VALUE <= value <= cls.MAX_VALUE + + @classmethod + def decode(cls, value): + try: + out = int(value) + if not cls.valid(out) or str(out).encode() != value: + raise ValueError + except ValueError: + raise SimpleError(cls.DECODE_ERROR) + return out + + @classmethod + def encode(cls, value): + if cls.valid(value): + return str(value).encode() + else: + raise SimpleError(cls.ENCODE_ERROR) + + +class BitOffset(Int): + """Argument converter for unsigned bit positions""" + + DECODE_ERROR = INVALID_BIT_OFFSET_MSG + MIN_VALUE = 0 + MAX_VALUE = 8 * MAX_STRING_SIZE - 1 # Redis imposes 512MB limit on keys + + +class BitValue(Int): + DECODE_ERROR = INVALID_BIT_VALUE_MSG + MIN_VALUE = 0 + MAX_VALUE = 1 + + +class DbIndex(Int): + """Argument converter for database indices""" + + DECODE_ERROR = INVALID_DB_MSG + MIN_VALUE = 0 + MAX_VALUE = 15 + + +class Timeout(Int): + """Argument converter for timeouts""" + + DECODE_ERROR = TIMEOUT_NEGATIVE_MSG + MIN_VALUE = 0 + + +class Float: + """Argument converter for floating-point values. + + Redis uses long double for some cases (INCRBYFLOAT, HINCRBYFLOAT) + and double for others (zset scores), but Python doesn't support + long double. + """ + + DECODE_ERROR = INVALID_FLOAT_MSG + + @classmethod + def decode(cls, value, + allow_leading_whitespace=False, + allow_erange=False, + allow_empty=False, + crop_null=False): + # redis has some quirks in float parsing, with several variants. + # See https://github.com/antirez/redis/issues/5706 + try: + if crop_null: + value = null_terminate(value) + if allow_empty and value == b'': + value = b'0.0' + if not allow_leading_whitespace and value[:1].isspace(): + raise ValueError + if value[-1:].isspace(): + raise ValueError + out = float(value) + if math.isnan(out): + raise ValueError + if not allow_erange: + # Values that over- or underflow- are explicitly rejected by + # redis. This is a crude hack to determine whether the input + # may have been such a value. + if out in (math.inf, -math.inf, 0.0) and re.match(b'^[^a-zA-Z]*[1-9]', value): + raise ValueError + return out + except ValueError: + raise SimpleError(cls.DECODE_ERROR) + + @classmethod + def encode(cls, value, humanfriendly): + if math.isinf(value): + return str(value).encode() + elif humanfriendly: + # Algorithm from ld2string in redis + out = '{:.17f}'.format(value) + out = re.sub(r'(?:\.)?0+$', '', out) + return out.encode() + else: + return '{:.17g}'.format(value).encode() + + +class SortFloat(Float): + DECODE_ERROR = INVALID_SORT_FLOAT_MSG + + @classmethod + def decode(cls, value): + return super().decode( + value, allow_leading_whitespace=True, allow_empty=True, crop_null=True) + + +class ScoreTest: + """Argument converter for sorted set score endpoints.""" + def __init__(self, value, exclusive=False): + self.value = value + self.exclusive = exclusive + + @classmethod + def decode(cls, value): + try: + exclusive = False + if value[:1] == b'(': + exclusive = True + value = value[1:] + value = Float.decode( + value, allow_leading_whitespace=True, allow_erange=True, + allow_empty=True, crop_null=True) + return cls(value, exclusive) + except SimpleError: + raise SimpleError(INVALID_MIN_MAX_FLOAT_MSG) + + def __str__(self): + if self.exclusive: + return '({!r}'.format(self.value) + else: + return repr(self.value) + + @property + def lower_bound(self): + return (self.value, AfterAny() if self.exclusive else BeforeAny()) + + @property + def upper_bound(self): + return (self.value, BeforeAny() if self.exclusive else AfterAny()) + + +class StringTest: + """Argument converter for sorted set LEX endpoints.""" + def __init__(self, value, exclusive): + self.value = value + self.exclusive = exclusive + + @classmethod + def decode(cls, value): + if value == b'-': + return cls(BeforeAny(), True) + elif value == b'+': + return cls(AfterAny(), True) + elif value[:1] == b'(': + return cls(value[1:], True) + elif value[:1] == b'[': + return cls(value[1:], False) + else: + raise SimpleError(INVALID_MIN_MAX_STR_MSG) + + +@functools.total_ordering +class BeforeAny: + def __gt__(self, other): + return False + + def __eq__(self, other): + return isinstance(other, BeforeAny) + + +@functools.total_ordering +class AfterAny: + def __lt__(self, other): + return False + + def __eq__(self, other): + return isinstance(other, AfterAny) + + +class Key: + """Marker to indicate that argument in signature is a key""" + UNSPECIFIED = object() + + def __init__(self, type_=None, missing_return=UNSPECIFIED): + self.type_ = type_ + self.missing_return = missing_return + + +class Signature: + def __init__(self, name, fixed, repeat=(), flags=""): + self.name = name + self.fixed = fixed + self.repeat = repeat + self.flags = flags + + def check_arity(self, args): + if len(args) != len(self.fixed): + delta = len(args) - len(self.fixed) + if delta < 0 or not self.repeat: + raise SimpleError(WRONG_ARGS_MSG.format(self.name)) + + def apply(self, args, db): + """Returns a tuple, which is either: + - transformed args and a dict of CommandItems; or + - a single containing a short-circuit return value + """ + self.check_arity(args) + if self.repeat: + delta = len(args) - len(self.fixed) + if delta % len(self.repeat) != 0: + raise SimpleError(WRONG_ARGS_MSG.format(self.name)) + + types = list(self.fixed) + for i in range(len(args) - len(types)): + types.append(self.repeat[i % len(self.repeat)]) + + args = list(args) + # First pass: convert/validate non-keys, and short-circuit on missing keys + for i, (arg, type_) in enumerate(zip(args, types)): + if isinstance(type_, Key): + if type_.missing_return is not Key.UNSPECIFIED and arg not in db: + return (type_.missing_return,) + elif type_ != bytes: + args[i] = type_.decode(args[i]) + + # Second pass: read keys and check their types + command_items = [] + for i, (arg, type_) in enumerate(zip(args, types)): + if isinstance(type_, Key): + item = db.get(arg) + default = None + if type_.type_ is not None: + if item is not None and type(item.value) != type_.type_: + raise SimpleError(WRONGTYPE_MSG) + if item is None: + if type_.type_ is not bytes: + default = type_.type_() + args[i] = CommandItem(arg, db, item, default=default) + command_items.append(args[i]) + + return args, command_items + + +def valid_response_type(value, nested=False): + if isinstance(value, NoResponse) and not nested: + return True + if value is not None and not isinstance(value, (bytes, SimpleString, SimpleError, + int, list)): + return False + if isinstance(value, list): + if any(not valid_response_type(item, True) for item in value): + return False + return True + + +def command(*args, **kwargs): + def decorator(func): + name = kwargs.pop('name', func.__name__) + func._fakeredis_sig = Signature(name, *args, **kwargs) + return func + + return decorator + + +class FakeServer: + def __init__(self): + self.lock = threading.Lock() + self.dbs = defaultdict(lambda: Database(self.lock)) + # Maps SHA1 to script source + self.script_cache = {} + # Maps channel/pattern to weak set of sockets + self.subscribers = defaultdict(weakref.WeakSet) + self.psubscribers = defaultdict(weakref.WeakSet) + self.lastsave = int(time.time()) + self.connected = True + # List of weakrefs to sockets that are being closed lazily + self.closed_sockets = [] + + +class FakeSocket: + _connection_error_class = redis.ConnectionError + + def __init__(self, server): + self._server = server + self._db = server.dbs[0] + self._db_num = 0 + # When in a MULTI, set to a list of function calls + self._transaction = None + self._transaction_failed = False + # Set when executing the commands from EXEC + self._in_transaction = False + self._watch_notified = False + self._watches = set() + self._pubsub = 0 # Count of subscriptions + self.responses = queue.Queue() + # Prevents parser from processing commands. Not used in this module, + # but set by aioredis module to prevent new commands being processed + # while handling a blocking command. + self._paused = False + self._parser = self._parse_commands() + self._parser.send(None) + + def put_response(self, msg): + # redis.Connection.__del__ might call self.close at any time, which + # will set self.responses to None. We assume this will happen + # atomically, and the code below then protects us against this. + responses = self.responses + if responses: + responses.put(msg) + + def pause(self): + self._paused = True + + def resume(self): + self._paused = False + self._parser.send(b'') + + def shutdown(self, flags): + self._parser.close() + + def fileno(self): + # Our fake socket must return an integer from `FakeSocket.fileno()` since a real selector + # will be created. The value does not matter since we replace the selector with our own + # `FakeSelector` before it is ever used. + return 0 + + def _cleanup(self, server): + """Remove all the references to `self` from `server`. + + This is called with the server lock held, but it may be some time after + self.close. + """ + for subs in server.subscribers.values(): + subs.discard(self) + for subs in server.psubscribers.values(): + subs.discard(self) + self._clear_watches() + + def close(self): + # Mark ourselves for cleanup. This might be called from + # redis.Connection.__del__, which the garbage collection could call + # at any time, and hence we can't safely take the server lock. + # We rely on list.append being atomic. + self._server.closed_sockets.append(weakref.ref(self)) + self._server = None + self._db = None + self.responses = None + + @staticmethod + def _extract_line(buf): + pos = buf.find(b'\n') + 1 + assert pos > 0 + line = buf[:pos] + buf = buf[pos:] + assert line.endswith(b'\r\n') + return line, buf + + def _parse_commands(self): + """Generator that parses commands. + + It is fed pieces of redis protocol data (via `send`) and calls + `_process_command` whenever it has a complete one. + """ + buf = b'' + while True: + while self._paused or b'\n' not in buf: + buf += yield + line, buf = self._extract_line(buf) + assert line[:1] == b'*' # array + n_fields = int(line[1:-2]) + fields = [] + for i in range(n_fields): + while b'\n' not in buf: + buf += yield + line, buf = self._extract_line(buf) + assert line[:1] == b'$' # string + length = int(line[1:-2]) + while len(buf) < length + 2: + buf += yield + fields.append(buf[:length]) + buf = buf[length+2:] # +2 to skip the CRLF + self._process_command(fields) + + def _run_command(self, func, sig, args, from_script): + command_items = {} + try: + ret = sig.apply(args, self._db) + if len(ret) == 1: + result = ret[0] + else: + args, command_items = ret + if from_script and FLAG_NO_SCRIPT in sig.flags: + raise SimpleError(COMMAND_IN_SCRIPT_MSG) + if self._pubsub and sig.name not in [ + 'ping', 'subscribe', 'unsubscribe', + 'psubscribe', 'punsubscribe', 'quit']: + raise SimpleError(BAD_COMMAND_IN_PUBSUB_MSG) + result = func(*args) + assert valid_response_type(result) + except SimpleError as exc: + result = exc + for command_item in command_items: + command_item.writeback() + return result + + def _decode_error(self, error): + return redis.connection.BaseParser().parse_error(error.value) + + def _decode_result(self, result): + """Convert SimpleString and SimpleError, recursively""" + if isinstance(result, list): + return [self._decode_result(r) for r in result] + elif isinstance(result, SimpleString): + return result.value + elif isinstance(result, SimpleError): + return self._decode_error(result) + else: + return result + + def _blocking(self, timeout, func): + """Run a function until it succeeds or timeout is reached. + + The timeout must be an integer, and 0 means infinite. The function + is called with a boolean to indicate whether this is the first call. + If it returns None it is considered to have "failed" and is retried + each time the condition variable is notified, until the timeout is + reached. + + Returns the function return value, or None if the timeout was reached. + """ + ret = func(True) + if ret is not None or self._in_transaction: + return ret + if timeout: + deadline = time.time() + timeout + else: + deadline = None + while True: + timeout = deadline - time.time() if deadline is not None else None + if timeout is not None and timeout <= 0: + return None + # Python <3.2 doesn't return a status from wait. On Python 3.2+ + # we bail out early on False. + if self._db.condition.wait(timeout=timeout) is False: + return None # Timeout expired + ret = func(False) + if ret is not None: + return ret + + def _name_to_func(self, name): + name = six.ensure_str(name, encoding='utf-8', errors='replace') + func_name = name.lower() + func = getattr(self, func_name, None) + if name.startswith('_') or not func or not hasattr(func, '_fakeredis_sig'): + # redis remaps \r or \n in an error to ' ' to make it legal protocol + clean_name = name.replace('\r', ' ').replace('\n', ' ') + raise SimpleError(UNKNOWN_COMMAND_MSG.format(clean_name)) + return func, func_name + + def sendall(self, data): + if not self._server.connected: + raise self._connection_error_class(CONNECTION_ERROR_MSG) + if isinstance(data, str): + data = data.encode('ascii') + self._parser.send(data) + + def _process_command(self, fields): + if not fields: + return + func_name = None + try: + func, func_name = self._name_to_func(fields[0]) + sig = func._fakeredis_sig + with self._server.lock: + # Clean out old connections + while True: + try: + weak_sock = self._server.closed_sockets.pop() + except IndexError: + break + else: + sock = weak_sock() + if sock: + sock._cleanup(self._server) + now = time.time() + for db in self._server.dbs.values(): + db.time = now + sig.check_arity(fields[1:]) + # TODO: make a signature attribute for transactions + if self._transaction is not None \ + and func_name not in ('exec', 'discard', 'multi', 'watch'): + self._transaction.append((func, sig, fields[1:])) + result = QUEUED + else: + result = self._run_command(func, sig, fields[1:], False) + except SimpleError as exc: + if self._transaction is not None: + # TODO: should not apply if the exception is from _run_command + # e.g. watch inside multi + self._transaction_failed = True + if func_name == 'exec' and exc.value.startswith('ERR '): + exc.value = 'EXECABORT Transaction discarded because of: ' + exc.value[4:] + self._transaction = None + self._transaction_failed = False + self._clear_watches() + result = exc + result = self._decode_result(result) + if not isinstance(result, NoResponse): + self.put_response(result) + + def notify_watch(self): + self._watch_notified = True + + # redis has inconsistent handling of negative indices, hence two versions + # of this code. + + @staticmethod + def _fix_range_string(start, end, length): + # Negative number handling is based on the redis source code + if start < 0 and end < 0 and start > end: + return -1, -1 + if start < 0: + start = max(0, start + length) + if end < 0: + end = max(0, end + length) + end = min(end, length - 1) + return start, end + 1 + + @staticmethod + def _fix_range(start, end, length): + # Redis handles negative slightly differently for zrange + if start < 0: + start = max(0, start + length) + if end < 0: + end += length + if start > end or start >= length: + return -1, -1 + end = min(end, length - 1) + return start, end + 1 + + def _scan(self, keys, cursor, *args): + """ + This is the basis of most of the ``scan`` methods. + + This implementation is KNOWN to be un-performant, as it requires + grabbing the full set of keys over which we are investigating subsets. + + It also doesn't adhere to the guarantee that every key will be iterated + at least once even if the database is modified during the scan. + However, provided the database is not modified, every key will be + returned exactly once. + """ + pattern = None + type = None + count = 10 + if len(args) % 2 != 0: + raise SimpleError(SYNTAX_ERROR_MSG) + for i in range(0, len(args), 2): + if casematch(args[i], b'match'): + pattern = args[i + 1] + elif casematch(args[i], b'count'): + count = Int.decode(args[i + 1]) + if count <= 0: + raise SimpleError(SYNTAX_ERROR_MSG) + elif casematch(args[i], b'type'): + type = args[i + 1] + else: + raise SimpleError(SYNTAX_ERROR_MSG) + + if cursor >= len(keys): + return [0, []] + data = sorted(keys) + result_cursor = cursor + count + result_data = [] + + regex = compile_pattern(pattern) if pattern is not None else None + + def match_key(key): + return regex.match(key) if pattern is not None else True + + def match_type(key): + if type is not None: + return casematch(self.type(self._db[key]).value, type) + return True + + if pattern is not None or type is not None: + for val in itertools.islice(data, cursor, result_cursor): + compare_val = val[0] if isinstance(val, tuple) else val + if match_key(compare_val) and match_type(compare_val): + result_data.append(val) + else: + result_data = data[cursor:result_cursor] + + if result_cursor >= len(data): + result_cursor = 0 + return [result_cursor, result_data] + + # Connection commands + # TODO: auth, quit + + @command((bytes,)) + def echo(self, message): + return message + + @command((), (bytes,)) + def ping(self, *args): + if len(args) > 1: + raise SimpleError(WRONG_ARGS_MSG.format('ping')) + if self._pubsub: + return [b'pong', args[0] if args else b''] + else: + return args[0] if args else PONG + + @command((DbIndex,)) + def select(self, index): + self._db = self._server.dbs[index] + self._db_num = index + return OK + + @command((DbIndex, DbIndex)) + def swapdb(self, index1, index2): + if index1 != index2: + db1 = self._server.dbs[index1] + db2 = self._server.dbs[index2] + db1.swap(db2) + return OK + + # Key commands + # TODO: lots + + def _delete(self, *keys): + ans = 0 + done = set() + for key in keys: + if key and key.key not in done: + key.value = None + done.add(key.key) + ans += 1 + return ans + + @command((Key(),), (Key(),), name='del') + def del_(self, *keys): + return self._delete(*keys) + + @command((Key(),), (Key(),), name='unlink') + def unlink(self, *keys): + return self._delete(*keys) + + @command((Key(),), (Key(),)) + def exists(self, *keys): + ret = 0 + for key in keys: + if key: + ret += 1 + return ret + + def _expireat(self, key, timestamp): + if not key: + return 0 + else: + key.expireat = timestamp + return 1 + + def _ttl(self, key, scale): + if not key: + return -2 + elif key.expireat is None: + return -1 + else: + return int(round((key.expireat - self._db.time) * scale)) + + @command((Key(), Int)) + def expire(self, key, seconds): + return self._expireat(key, self._db.time + seconds) + + @command((Key(), Int)) + def expireat(self, key, timestamp): + return self._expireat(key, float(timestamp)) + + @command((Key(), Int)) + def pexpire(self, key, ms): + return self._expireat(key, self._db.time + ms / 1000.0) + + @command((Key(), Int)) + def pexpireat(self, key, ms_timestamp): + return self._expireat(key, ms_timestamp / 1000.0) + + @command((Key(),)) + def ttl(self, key): + return self._ttl(key, 1.0) + + @command((Key(),)) + def pttl(self, key): + return self._ttl(key, 1000.0) + + @command((Key(),)) + def type(self, key): + if key.value is None: + return SimpleString(b'none') + elif isinstance(key.value, bytes): + return SimpleString(b'string') + elif isinstance(key.value, list): + return SimpleString(b'list') + elif isinstance(key.value, set): + return SimpleString(b'set') + elif isinstance(key.value, ZSet): + return SimpleString(b'zset') + elif isinstance(key.value, dict): + return SimpleString(b'hash') + else: + assert False # pragma: nocover + + @command((Key(),)) + def persist(self, key): + if key.expireat is None: + return 0 + key.expireat = None + return 1 + + @command((bytes,)) + def keys(self, pattern): + if pattern == b'*': + return list(self._db) + else: + regex = compile_pattern(pattern) + return [key for key in self._db if regex.match(key)] + + @command((Key(), DbIndex)) + def move(self, key, db): + if db == self._db_num: + raise SimpleError(SRC_DST_SAME_MSG) + if not key or key.key in self._server.dbs[db]: + return 0 + # TODO: what is the interaction with expiry? + self._server.dbs[db][key.key] = self._server.dbs[self._db_num][key.key] + key.value = None # Causes deletion + return 1 + + @command(()) + def randomkey(self): + keys = list(self._db.keys()) + if not keys: + return None + return random.choice(keys) + + @command((Key(), Key())) + def rename(self, key, newkey): + if not key: + raise SimpleError(NO_KEY_MSG) + # TODO: check interaction with WATCH + if newkey.key != key.key: + newkey.value = key.value + newkey.expireat = key.expireat + key.value = None + return OK + + @command((Key(), Key())) + def renamenx(self, key, newkey): + if not key: + raise SimpleError(NO_KEY_MSG) + if newkey: + return 0 + self.rename(key, newkey) + return 1 + + @command((Int,), (bytes, bytes)) + def scan(self, cursor, *args): + return self._scan(list(self._db), cursor, *args) + + def _lookup_key(self, key, pattern): + """Python implementation of lookupKeyByPattern from redis""" + if pattern == b'#': + return key + p = pattern.find(b'*') + if p == -1: + return None + prefix = pattern[:p] + suffix = pattern[p+1:] + arrow = suffix.find(b'->', 0, -1) + if arrow != -1: + field = suffix[arrow+2:] + suffix = suffix[:arrow] + else: + field = None + new_key = prefix + key + suffix + item = CommandItem(new_key, self._db, item=self._db.get(new_key)) + if item.value is None: + return None + if field is not None: + if not isinstance(item.value, dict): + return None + return item.value.get(field) + else: + if not isinstance(item.value, bytes): + return None + return item.value + + @command((Key(),), (bytes,)) + def sort(self, key, *args): + i = 0 + desc = False + alpha = False + limit_start = 0 + limit_count = -1 + store = None + sortby = None + dontsort = False + get = [] + if key.value is not None: + if not isinstance(key.value, (set, list, ZSet)): + raise SimpleError(WRONGTYPE_MSG) + + while i < len(args): + arg = args[i] + if casematch(arg, b'asc'): + desc = False + elif casematch(arg, b'desc'): + desc = True + elif casematch(arg, b'alpha'): + alpha = True + elif casematch(arg, b'limit') and i + 2 < len(args): + try: + limit_start = Int.decode(args[i + 1]) + limit_count = Int.decode(args[i + 2]) + except SimpleError: + raise SimpleError(SYNTAX_ERROR_MSG) + else: + i += 2 + elif casematch(arg, b'store') and i + 1 < len(args): + store = args[i + 1] + i += 1 + elif casematch(arg, b'by') and i + 1 < len(args): + sortby = args[i + 1] + if b'*' not in sortby: + dontsort = True + i += 1 + elif casematch(arg, b'get') and i + 1 < len(args): + get.append(args[i + 1]) + i += 1 + else: + raise SimpleError(SYNTAX_ERROR_MSG) + i += 1 + + # TODO: force sorting if the object is a set and either in Lua or + # storing to a key, to match redis behaviour. + items = list(key.value) if key.value is not None else [] + + # These transformations are based on the redis implementation, but + # changed to produce a half-open range. + start = max(limit_start, 0) + end = len(items) if limit_count < 0 else start + limit_count + if start >= len(items): + start = end = len(items) - 1 + end = min(end, len(items)) + + if not get: + get.append(b'#') + if sortby is None: + sortby = b'#' + + if not dontsort: + if alpha: + def sort_key(v): + byval = self._lookup_key(v, sortby) + # TODO: use locale.strxfrm when not storing? But then need + # to decode too. + if byval is None: + byval = BeforeAny() + return byval + + else: + def sort_key(v): + byval = self._lookup_key(v, sortby) + score = SortFloat.decode(byval) if byval is not None else 0.0 + return (score, v) + + items.sort(key=sort_key, reverse=desc) + elif isinstance(key.value, (list, ZSet)): + items.reverse() + + out = [] + for row in items[start:end]: + for g in get: + v = self._lookup_key(row, g) + if store is not None and v is None: + v = b'' + out.append(v) + if store is not None: + item = CommandItem(store, self._db, item=self._db.get(store)) + item.value = out + item.writeback() + return len(out) + else: + return out + + @command((Key(missing_return=None),)) + def dump(self, key): + value = pickle.dumps(key.value) + checksum = hashlib.sha1(value).digest() + return checksum + value + + @command((Key(), Int, bytes), (bytes,)) + def restore(self, key, ttl, value, *args): + replace = False + i = 0 + while i < len(args): + if casematch(args[i], b'replace'): + replace = True + i += 1 + else: + raise SimpleError(SYNTAX_ERROR_MSG) + if key and not replace: + raise SimpleError(RESTORE_KEY_EXISTS) + checksum, value = value[:20], value[20:] + if hashlib.sha1(value).digest() != checksum: + raise SimpleError(RESTORE_INVALID_CHECKSUM_MSG) + if ttl < 0: + raise SimpleError(RESTORE_INVALID_TTL_MSG) + if ttl == 0: + expireat = None + else: + expireat = self._db.time + ttl / 1000.0 + key.value = pickle.loads(value) + key.expireat = expireat + return OK + + # Transaction commands + + def _clear_watches(self): + self._watch_notified = False + while self._watches: + (key, db) = self._watches.pop() + db.remove_watch(key, self) + + @command((), flags='s') + def multi(self): + if self._transaction is not None: + raise SimpleError(MULTI_NESTED_MSG) + self._transaction = [] + self._transaction_failed = False + return OK + + @command((), flags='s') + def discard(self): + if self._transaction is None: + raise SimpleError(WITHOUT_MULTI_MSG.format('DISCARD')) + self._transaction = None + self._transaction_failed = False + self._clear_watches() + return OK + + @command((), name='exec', flags='s') + def exec_(self): + if self._transaction is None: + raise SimpleError(WITHOUT_MULTI_MSG.format('EXEC')) + if self._transaction_failed: + self._transaction = None + self._clear_watches() + raise SimpleError(EXECABORT_MSG) + transaction = self._transaction + self._transaction = None + self._transaction_failed = False + watch_notified = self._watch_notified + self._clear_watches() + if watch_notified: + return None + result = [] + for func, sig, args in transaction: + try: + self._in_transaction = True + ans = self._run_command(func, sig, args, False) + except SimpleError as exc: + ans = exc + finally: + self._in_transaction = False + result.append(ans) + return result + + @command((Key(),), (Key(),), flags='s') + def watch(self, *keys): + if self._transaction is not None: + raise SimpleError(WATCH_INSIDE_MULTI_MSG) + for key in keys: + if key not in self._watches: + self._watches.add((key.key, self._db)) + self._db.add_watch(key.key, self) + return OK + + @command((), flags='s') + def unwatch(self): + self._clear_watches() + return OK + + # String commands + # TODO: bitfield, bitop, bitpos + + @command((Key(bytes), bytes)) + def append(self, key, value): + old = key.get(b'') + if len(old) + len(value) > MAX_STRING_SIZE: + raise SimpleError(STRING_OVERFLOW_MSG) + key.update(key.get(b'') + value) + return len(key.value) + + @command((Key(bytes, 0),), (bytes,)) + def bitcount(self, key, *args): + # Redis checks the argument count before decoding integers. That's why + # we can't declare them as Int. + if args: + if len(args) != 2: + raise SimpleError(SYNTAX_ERROR_MSG) + start = Int.decode(args[0]) + end = Int.decode(args[1]) + start, end = self._fix_range_string(start, end, len(key.value)) + value = key.value[start:end] + else: + value = key.value + return bin(int.from_bytes(value, 'little')).count('1') + + @command((Key(bytes), Int)) + def decrby(self, key, amount): + return self.incrby(key, -amount) + + @command((Key(bytes),)) + def decr(self, key): + return self.incrby(key, -1) + + @command((Key(bytes), Int)) + def incrby(self, key, amount): + c = Int.decode(key.get(b'0')) + amount + key.update(Int.encode(c)) + return c + + @command((Key(bytes),)) + def incr(self, key): + return self.incrby(key, 1) + + @command((Key(bytes), bytes)) + def incrbyfloat(self, key, amount): + # TODO: introduce convert_order so that we can specify amount is Float + c = Float.decode(key.get(b'0')) + Float.decode(amount) + if not math.isfinite(c): + raise SimpleError(NONFINITE_MSG) + encoded = Float.encode(c, True) + key.update(encoded) + return encoded + + @command((Key(bytes),)) + def get(self, key): + return key.get(None) + + @command((Key(bytes), BitOffset)) + def getbit(self, key, offset): + value = key.get(b'') + byte = offset // 8 + remaining = offset % 8 + actual_bitoffset = 7 - remaining + try: + actual_val = value[byte] + except IndexError: + return 0 + return 1 if (1 << actual_bitoffset) & actual_val else 0 + + @command((Key(bytes), BitOffset, BitValue)) + def setbit(self, key, offset, value): + val = key.get(b'\x00') + byte = offset // 8 + remaining = offset % 8 + actual_bitoffset = 7 - remaining + if len(val) - 1 < byte: + # We need to expand val so that we can set the appropriate + # bit. + needed = byte - (len(val) - 1) + val += b'\x00' * needed + old_byte = val[byte] + if value == 1: + new_byte = old_byte | (1 << actual_bitoffset) + else: + new_byte = old_byte & ~(1 << actual_bitoffset) + old_value = value if old_byte == new_byte else 1 - value + reconstructed = bytearray(val) + reconstructed[byte] = new_byte + key.update(bytes(reconstructed)) + return old_value + + @command((Key(bytes), Int, Int)) + def getrange(self, key, start, end): + value = key.get(b'') + start, end = self._fix_range_string(start, end, len(value)) + return value[start:end] + + # substr is a deprecated alias for getrange + @command((Key(bytes), Int, Int)) + def substr(self, key, start, end): + return self.getrange(key, start, end) + + @command((Key(bytes), bytes)) + def getset(self, key, value): + old = key.value + key.value = value + return old + + @command((Key(),), (Key(),)) + def mget(self, *keys): + return [key.value if isinstance(key.value, bytes) else None for key in keys] + + @command((Key(), bytes), (Key(), bytes)) + def mset(self, *args): + for i in range(0, len(args), 2): + args[i].value = args[i + 1] + return OK + + @command((Key(), bytes), (Key(), bytes)) + def msetnx(self, *args): + for i in range(0, len(args), 2): + if args[i]: + return 0 + for i in range(0, len(args), 2): + args[i].value = args[i + 1] + return 1 + + @command((Key(), bytes), (bytes,), name='set') + def set_(self, key, value, *args): + i = 0 + ex = None + px = None + xx = False + nx = False + keepttl = False + get = False + while i < len(args): + if casematch(args[i], b'nx'): + nx = True + i += 1 + elif casematch(args[i], b'xx'): + xx = True + i += 1 + elif casematch(args[i], b'ex') and i + 1 < len(args): + ex = Int.decode(args[i + 1]) + if ex <= 0 or (self._db.time + ex) * 1000 >= 2**63: + raise SimpleError(INVALID_EXPIRE_MSG.format('set')) + i += 2 + elif casematch(args[i], b'px') and i + 1 < len(args): + px = Int.decode(args[i + 1]) + if px <= 0 or self._db.time * 1000 + px >= 2**63: + raise SimpleError(INVALID_EXPIRE_MSG.format('set')) + i += 2 + elif casematch(args[i], b'keepttl'): + keepttl = True + i += 1 + elif casematch(args[i], b'get'): + get = True + i += 1 + else: + raise SimpleError(SYNTAX_ERROR_MSG) + if (xx and nx) or ((px is not None) + (ex is not None) + keepttl > 1): + raise SimpleError(SYNTAX_ERROR_MSG) + if nx and get: + # The command docs say this is allowed from Redis 7.0. + raise SimpleError(SYNTAX_ERROR_MSG) + + old_value = None + if get: + if key.value is not None and type(key.value) is not bytes: + raise SimpleError(WRONGTYPE_MSG) + old_value = key.value + + if nx and key: + return old_value + if xx and not key: + return old_value + if not keepttl: + key.value = value + else: + key.update(value) + if ex is not None: + key.expireat = self._db.time + ex + if px is not None: + key.expireat = self._db.time + px / 1000.0 + return OK if not get else old_value + + @command((Key(), Int, bytes)) + def setex(self, key, seconds, value): + if seconds <= 0 or (self._db.time + seconds) * 1000 >= 2**63: + raise SimpleError(INVALID_EXPIRE_MSG.format('setex')) + key.value = value + key.expireat = self._db.time + seconds + return OK + + @command((Key(), Int, bytes)) + def psetex(self, key, ms, value): + if ms <= 0 or self._db.time * 1000 + ms >= 2**63: + raise SimpleError(INVALID_EXPIRE_MSG.format('psetex')) + key.value = value + key.expireat = self._db.time + ms / 1000.0 + return OK + + @command((Key(), bytes)) + def setnx(self, key, value): + if key: + return 0 + key.value = value + return 1 + + @command((Key(bytes), Int, bytes)) + def setrange(self, key, offset, value): + if offset < 0: + raise SimpleError(INVALID_OFFSET_MSG) + elif not value: + return len(key.get(b'')) + elif offset + len(value) > MAX_STRING_SIZE: + raise SimpleError(STRING_OVERFLOW_MSG) + else: + out = key.get(b'') + if len(out) < offset: + out += b'\x00' * (offset - len(out)) + out = out[0:offset] + value + out[offset+len(value):] + key.update(out) + return len(out) + + @command((Key(bytes),)) + def strlen(self, key): + return len(key.get(b'')) + + # Hash commands + + @command((Key(Hash), bytes), (bytes,)) + def hdel(self, key, *fields): + h = key.value + rem = 0 + for field in fields: + if field in h: + del h[field] + key.updated() + rem += 1 + return rem + + @command((Key(Hash), bytes)) + def hexists(self, key, field): + return int(field in key.value) + + @command((Key(Hash), bytes)) + def hget(self, key, field): + return key.value.get(field) + + @command((Key(Hash),)) + def hgetall(self, key): + return list(itertools.chain(*key.value.items())) + + @command((Key(Hash), bytes, Int)) + def hincrby(self, key, field, amount): + c = Int.decode(key.value.get(field, b'0')) + amount + key.value[field] = Int.encode(c) + key.updated() + return c + + @command((Key(Hash), bytes, bytes)) + def hincrbyfloat(self, key, field, amount): + c = Float.decode(key.value.get(field, b'0')) + Float.decode(amount) + if not math.isfinite(c): + raise SimpleError(NONFINITE_MSG) + encoded = Float.encode(c, True) + key.value[field] = encoded + key.updated() + return encoded + + @command((Key(Hash),)) + def hkeys(self, key): + return list(key.value.keys()) + + @command((Key(Hash),)) + def hlen(self, key): + return len(key.value) + + @command((Key(Hash), bytes), (bytes,)) + def hmget(self, key, *fields): + return [key.value.get(field) for field in fields] + + @command((Key(Hash), bytes, bytes), (bytes, bytes)) + def hmset(self, key, *args): + self.hset(key, *args) + return OK + + @command((Key(Hash), Int,), (bytes, bytes)) + def hscan(self, key, cursor, *args): + cursor, keys = self._scan(key.value, cursor, *args) + items = [] + for k in keys: + items.append(k) + items.append(key.value[k]) + return [cursor, items] + + @command((Key(Hash), bytes, bytes), (bytes, bytes)) + def hset(self, key, *args): + h = key.value + created = 0 + for i in range(0, len(args), 2): + if args[i] not in h: + created += 1 + h[args[i]] = args[i + 1] + key.updated() + return created + + @command((Key(Hash), bytes, bytes)) + def hsetnx(self, key, field, value): + if field in key.value: + return 0 + return self.hset(key, field, value) + + @command((Key(Hash), bytes)) + def hstrlen(self, key, field): + return len(key.value.get(field, b'')) + + @command((Key(Hash),)) + def hvals(self, key): + return list(key.value.values()) + + # List commands + + def _bpop_pass(self, keys, op, first_pass): + for key in keys: + item = CommandItem(key, self._db, item=self._db.get(key), default=[]) + if not isinstance(item.value, list): + if first_pass: + raise SimpleError(WRONGTYPE_MSG) + else: + continue + if item.value: + ret = op(item.value) + item.updated() + item.writeback() + return [key, ret] + return None + + def _bpop(self, args, op): + keys = args[:-1] + timeout = Timeout.decode(args[-1]) + return self._blocking(timeout, functools.partial(self._bpop_pass, keys, op)) + + @command((bytes, bytes), (bytes,), flags='s') + def blpop(self, *args): + return self._bpop(args, lambda lst: lst.pop(0)) + + @command((bytes, bytes), (bytes,), flags='s') + def brpop(self, *args): + return self._bpop(args, lambda lst: lst.pop()) + + def _brpoplpush_pass(self, source, destination, first_pass): + src = CommandItem(source, self._db, item=self._db.get(source), default=[]) + if not isinstance(src.value, list): + if first_pass: + raise SimpleError(WRONGTYPE_MSG) + else: + return None + if not src.value: + return None # Empty list + dst = CommandItem(destination, self._db, item=self._db.get(destination), default=[]) + if not isinstance(dst.value, list): + raise SimpleError(WRONGTYPE_MSG) + el = src.value.pop() + dst.value.insert(0, el) + src.updated() + src.writeback() + if destination != source: + # Ensure writeback only happens once + dst.updated() + dst.writeback() + return el + + @command((bytes, bytes, Timeout), flags='s') + def brpoplpush(self, source, destination, timeout): + return self._blocking(timeout, + functools.partial(self._brpoplpush_pass, source, destination)) + + @command((Key(list, None), Int)) + def lindex(self, key, index): + try: + return key.value[index] + except IndexError: + return None + + @command((Key(list), bytes, bytes, bytes)) + def linsert(self, key, where, pivot, value): + if not casematch(where, b'before') and not casematch(where, b'after'): + raise SimpleError(SYNTAX_ERROR_MSG) + if not key: + return 0 + else: + try: + index = key.value.index(pivot) + except ValueError: + return -1 + if casematch(where, b'after'): + index += 1 + key.value.insert(index, value) + key.updated() + return len(key.value) + + @command((Key(list),)) + def llen(self, key): + return len(key.value) + + def _list_pop(self, get_slice, key, *args): + """Implements lpop and rpop. + + `get_slice` must take a count and return a slice expression for the + range to pop. + """ + # This implementation is somewhat contorted to match the odd + # behaviours described in https://github.com/redis/redis/issues/9680. + count = 1 + if len(args) > 1: + raise SimpleError(SYNTAX_ERROR_MSG) + elif len(args) == 1: + count = args[0] + if count < 0: + raise SimpleError(INDEX_ERROR_MSG) + elif count == 0: + return None + if not key: + return None + elif type(key.value) != list: + raise SimpleError(WRONGTYPE_MSG) + slc = get_slice(count) + ret = key.value[slc] + del key.value[slc] + key.updated() + if not args: + ret = ret[0] + return ret + + @command((Key(),), (Int(),)) + def lpop(self, key, *args): + return self._list_pop(lambda count: slice(None, count), key, *args) + + @command((Key(list), bytes), (bytes,)) + def lpush(self, key, *values): + for value in values: + key.value.insert(0, value) + key.updated() + return len(key.value) + + @command((Key(list), bytes), (bytes,)) + def lpushx(self, key, *values): + if not key: + return 0 + return self.lpush(key, *values) + + @command((Key(list), Int, Int)) + def lrange(self, key, start, stop): + start, stop = self._fix_range(start, stop, len(key.value)) + return key.value[start:stop] + + @command((Key(list), Int, bytes)) + def lrem(self, key, count, value): + a_list = key.value + found = [] + for i, el in enumerate(a_list): + if el == value: + found.append(i) + if count > 0: + indices_to_remove = found[:count] + elif count < 0: + indices_to_remove = found[count:] + else: + indices_to_remove = found + # Iterating in reverse order to ensure the indices + # remain valid during deletion. + for index in reversed(indices_to_remove): + del a_list[index] + if indices_to_remove: + key.updated() + return len(indices_to_remove) + + @command((Key(list), Int, bytes)) + def lset(self, key, index, value): + if not key: + raise SimpleError(NO_KEY_MSG) + try: + key.value[index] = value + key.updated() + except IndexError: + raise SimpleError(INDEX_ERROR_MSG) + return OK + + @command((Key(list), Int, Int)) + def ltrim(self, key, start, stop): + if key: + if stop == -1: + stop = None + else: + stop += 1 + new_value = key.value[start:stop] + # TODO: check if this should actually be conditional + if len(new_value) != len(key.value): + key.update(new_value) + return OK + + @command((Key(),), (Int(),)) + def rpop(self, key, *args): + return self._list_pop(lambda count: slice(None, -count - 1, -1), key, *args) + + @command((Key(list, None), Key(list))) + def rpoplpush(self, src, dst): + el = self.rpop(src) + self.lpush(dst, el) + return el + + @command((Key(list), bytes), (bytes,)) + def rpush(self, key, *values): + for value in values: + key.value.append(value) + key.updated() + return len(key.value) + + @command((Key(list), bytes), (bytes,)) + def rpushx(self, key, *values): + if not key: + return 0 + return self.rpush(key, *values) + + # Set commands + + @command((Key(set), bytes), (bytes,)) + def sadd(self, key, *members): + old_size = len(key.value) + key.value.update(members) + key.updated() + return len(key.value) - old_size + + @command((Key(set),)) + def scard(self, key): + return len(key.value) + + def _calc_setop(self, op, stop_if_missing, key, *keys): + if stop_if_missing and not key.value: + return set() + ans = key.value.copy() + for other in keys: + value = other.value if other.value is not None else set() + if not isinstance(value, set): + raise SimpleError(WRONGTYPE_MSG) + if stop_if_missing and not value: + return set() + ans = op(ans, value) + return ans + + def _setop(self, op, stop_if_missing, dst, key, *keys): + """Apply one of SINTER[STORE], SUNION[STORE], SDIFF[STORE]. + + If `stop_if_missing`, the output will be made an empty set as soon as + an empty input set is encountered (use for SINTER[STORE]). May assume + that `key` is a set (or empty), but `keys` could be anything. + """ + ans = self._calc_setop(op, stop_if_missing, key, *keys) + if dst is None: + return list(ans) + else: + dst.value = ans + return len(dst.value) + + @command((Key(set),), (Key(set),)) + def sdiff(self, *keys): + return self._setop(lambda a, b: a - b, False, None, *keys) + + @command((Key(), Key(set)), (Key(set),)) + def sdiffstore(self, dst, *keys): + return self._setop(lambda a, b: a - b, False, dst, *keys) + + @command((Key(set),), (Key(set),)) + def sinter(self, *keys): + return self._setop(lambda a, b: a & b, True, None, *keys) + + @command((Key(), Key(set)), (Key(set),)) + def sinterstore(self, dst, *keys): + return self._setop(lambda a, b: a & b, True, dst, *keys) + + @command((Key(set), bytes)) + def sismember(self, key, member): + return int(member in key.value) + + @command((Key(set),)) + def smembers(self, key): + return list(key.value) + + @command((Key(set, 0), Key(set), bytes)) + def smove(self, src, dst, member): + try: + src.value.remove(member) + src.updated() + except KeyError: + return 0 + else: + dst.value.add(member) + dst.updated() # TODO: is it updated if member was already present? + return 1 + + @command((Key(set),), (Int,)) + def spop(self, key, count=None): + if count is None: + if not key.value: + return None + item = random.sample(list(key.value), 1)[0] + key.value.remove(item) + key.updated() + return item + else: + if count < 0: + raise SimpleError(INDEX_ERROR_MSG) + items = self.srandmember(key, count) + for item in items: + key.value.remove(item) + key.updated() # Inside the loop because redis special-cases count=0 + return items + + @command((Key(set),), (Int,)) + def srandmember(self, key, count=None): + if count is None: + if not key.value: + return None + else: + return random.sample(list(key.value), 1)[0] + elif count >= 0: + count = min(count, len(key.value)) + return random.sample(list(key.value), count) + else: + items = list(key.value) + return [random.choice(items) for _ in range(-count)] + + @command((Key(set), bytes), (bytes,)) + def srem(self, key, *members): + old_size = len(key.value) + for member in members: + key.value.discard(member) + deleted = old_size - len(key.value) + if deleted: + key.updated() + return deleted + + @command((Key(set), Int), (bytes, bytes)) + def sscan(self, key, cursor, *args): + return self._scan(key.value, cursor, *args) + + @command((Key(set),), (Key(set),)) + def sunion(self, *keys): + return self._setop(lambda a, b: a | b, False, None, *keys) + + @command((Key(), Key(set)), (Key(set),)) + def sunionstore(self, dst, *keys): + return self._setop(lambda a, b: a | b, False, dst, *keys) + + # Hyperloglog commands + # These are not quite the same as the real redis ones, which are + # approximate and store the results in a string. Instead, it is implemented + # on top of sets. + + @command((Key(set),), (bytes,)) + def pfadd(self, key, *elements): + result = self.sadd(key, *elements) + # Per the documentation: + # - 1 if at least 1 HyperLogLog internal register was altered. 0 otherwise. + return 1 if result > 0 else 0 + + @command((Key(set),), (Key(set),)) + def pfcount(self, *keys): + """ + Return the approximated cardinality of + the set observed by the HyperLogLog at key(s). + """ + return len(self.sunion(*keys)) + + @command((Key(set), Key(set)), (Key(set),)) + def pfmerge(self, dest, *sources): + "Merge N different HyperLogLogs into a single one." + self.sunionstore(dest, *sources) + return OK + + # Sorted set commands + # TODO: [b]zpopmin/zpopmax, + + @staticmethod + def _limit_items(items, offset, count): + out = [] + for item in items: + if offset: # Note: not offset > 0, in order to match redis + offset -= 1 + continue + if count == 0: + break + count -= 1 + out.append(item) + return out + + @staticmethod + def _apply_withscores(items, withscores): + if withscores: + out = [] + for item in items: + out.append(item[1]) + out.append(Float.encode(item[0], False)) + else: + out = [item[1] for item in items] + return out + + @command((Key(ZSet), bytes, bytes), (bytes,)) + def zadd(self, key, *args): + zset = key.value + + i = 0 + ch = False + nx = False + xx = False + incr = False + while i < len(args): + if casematch(args[i], b'ch'): + ch = True + i += 1 + elif casematch(args[i], b'nx'): + nx = True + i += 1 + elif casematch(args[i], b'xx'): + xx = True + i += 1 + elif casematch(args[i], b'incr'): + incr = True + i += 1 + else: + # First argument not matching flags indicates the start of + # score pairs. + break + + if nx and xx: + raise SimpleError(ZADD_NX_XX_ERROR_MSG) + + elements = args[i:] + if not elements or len(elements) % 2 != 0: + raise SimpleError(SYNTAX_ERROR_MSG) + if incr and len(elements) != 2: + raise SimpleError(ZADD_INCR_LEN_ERROR_MSG) + # Parse all scores first, before updating + items = [ + (Float.decode(elements[j]), elements[j + 1]) + for j in range(0, len(elements), 2) + ] + old_len = len(zset) + changed_items = 0 + + if incr: + item_score, item_name = items[0] + if (nx and item_name in zset) or (xx and item_name not in zset): + return None + return self.zincrby(key, item_score, item_name) + + for item_score, item_name in items: + if ( + (not nx or item_name not in zset) + and (not xx or item_name in zset) + ): + if zset.add(item_name, item_score): + changed_items += 1 + + if changed_items: + key.updated() + + if ch: + return changed_items + return len(zset) - old_len + + @command((Key(ZSet),)) + def zcard(self, key): + return len(key.value) + + @command((Key(ZSet), ScoreTest, ScoreTest)) + def zcount(self, key, min, max): + return key.value.zcount(min.lower_bound, max.upper_bound) + + @command((Key(ZSet), Float, bytes)) + def zincrby(self, key, increment, member): + # Can't just default the old score to 0.0, because in IEEE754, adding + # 0.0 to something isn't a nop (e.g. 0.0 + -0.0 == 0.0). + try: + score = key.value.get(member, None) + increment + except TypeError: + score = increment + if math.isnan(score): + raise SimpleError(SCORE_NAN_MSG) + key.value[member] = score + key.updated() + return Float.encode(score, False) + + @command((Key(ZSet), StringTest, StringTest)) + def zlexcount(self, key, min, max): + return key.value.zlexcount(min.value, min.exclusive, max.value, max.exclusive) + + def _zrange(self, key, start, stop, reverse, *args): + zset = key.value + withscores = False + for arg in args: + if casematch(arg, b'withscores'): + withscores = True + else: + raise SimpleError(SYNTAX_ERROR_MSG) + start, stop = self._fix_range(start, stop, len(zset)) + if reverse: + start, stop = len(zset) - stop, len(zset) - start + items = zset.islice_score(start, stop, reverse) + items = self._apply_withscores(items, withscores) + return items + + @command((Key(ZSet), Int, Int), (bytes,)) + def zrange(self, key, start, stop, *args): + return self._zrange(key, start, stop, False, *args) + + @command((Key(ZSet), Int, Int), (bytes,)) + def zrevrange(self, key, start, stop, *args): + return self._zrange(key, start, stop, True, *args) + + def _zrangebylex(self, key, min, max, reverse, *args): + if args: + if len(args) != 3 or not casematch(args[0], b'limit'): + raise SimpleError(SYNTAX_ERROR_MSG) + offset = Int.decode(args[1]) + count = Int.decode(args[2]) + else: + offset = 0 + count = -1 + zset = key.value + items = zset.irange_lex(min.value, max.value, + inclusive=(not min.exclusive, not max.exclusive), + reverse=reverse) + items = self._limit_items(items, offset, count) + return items + + @command((Key(ZSet), StringTest, StringTest), (bytes,)) + def zrangebylex(self, key, min, max, *args): + return self._zrangebylex(key, min, max, False, *args) + + @command((Key(ZSet), StringTest, StringTest), (bytes,)) + def zrevrangebylex(self, key, max, min, *args): + return self._zrangebylex(key, min, max, True, *args) + + def _zrangebyscore(self, key, min, max, reverse, *args): + withscores = False + offset = 0 + count = -1 + i = 0 + while i < len(args): + if casematch(args[i], b'withscores'): + withscores = True + i += 1 + elif casematch(args[i], b'limit') and i + 2 < len(args): + offset = Int.decode(args[i + 1]) + count = Int.decode(args[i + 2]) + i += 3 + else: + raise SimpleError(SYNTAX_ERROR_MSG) + zset = key.value + items = list(zset.irange_score(min.lower_bound, max.upper_bound, reverse=reverse)) + items = self._limit_items(items, offset, count) + items = self._apply_withscores(items, withscores) + return items + + @command((Key(ZSet), ScoreTest, ScoreTest), (bytes,)) + def zrangebyscore(self, key, min, max, *args): + return self._zrangebyscore(key, min, max, False, *args) + + @command((Key(ZSet), ScoreTest, ScoreTest), (bytes,)) + def zrevrangebyscore(self, key, max, min, *args): + return self._zrangebyscore(key, min, max, True, *args) + + @command((Key(ZSet), bytes)) + def zrank(self, key, member): + try: + return key.value.rank(member) + except KeyError: + return None + + @command((Key(ZSet), bytes)) + def zrevrank(self, key, member): + try: + return len(key.value) - 1 - key.value.rank(member) + except KeyError: + return None + + @command((Key(ZSet), bytes), (bytes,)) + def zrem(self, key, *members): + old_size = len(key.value) + for member in members: + key.value.discard(member) + deleted = old_size - len(key.value) + if deleted: + key.updated() + return deleted + + @command((Key(ZSet), StringTest, StringTest)) + def zremrangebylex(self, key, min, max): + items = key.value.irange_lex(min.value, max.value, + inclusive=(not min.exclusive, not max.exclusive)) + return self.zrem(key, *items) + + @command((Key(ZSet), ScoreTest, ScoreTest)) + def zremrangebyscore(self, key, min, max): + items = key.value.irange_score(min.lower_bound, max.upper_bound) + return self.zrem(key, *[item[1] for item in items]) + + @command((Key(ZSet), Int, Int)) + def zremrangebyrank(self, key, start, stop): + zset = key.value + start, stop = self._fix_range(start, stop, len(zset)) + items = zset.islice_score(start, stop) + return self.zrem(key, *[item[1] for item in items]) + + @command((Key(ZSet), Int), (bytes, bytes)) + def zscan(self, key, cursor, *args): + new_cursor, ans = self._scan(key.value.items(), cursor, *args) + flat = [] + for (key, score) in ans: + flat.append(key) + flat.append(Float.encode(score, False)) + return [new_cursor, flat] + + @command((Key(ZSet), bytes)) + def zscore(self, key, member): + try: + return Float.encode(key.value[member], False) + except KeyError: + return None + + @staticmethod + def _get_zset(value): + if isinstance(value, set): + zset = ZSet() + for item in value: + zset[item] = 1.0 + return zset + elif isinstance(value, ZSet): + return value + else: + raise SimpleError(WRONGTYPE_MSG) + + def _zunioninter(self, func, dest, numkeys, *args): + if numkeys < 1: + raise SimpleError(ZUNIONSTORE_KEYS_MSG) + if numkeys > len(args): + raise SimpleError(SYNTAX_ERROR_MSG) + aggregate = b'sum' + sets = [] + for i in range(numkeys): + item = CommandItem(args[i], self._db, item=self._db.get(args[i]), default=ZSet()) + sets.append(self._get_zset(item.value)) + weights = [1.0] * numkeys + + i = numkeys + while i < len(args): + arg = args[i] + if casematch(arg, b'weights') and i + numkeys < len(args): + weights = [Float.decode(x) for x in args[i + 1:i + numkeys + 1]] + i += numkeys + 1 + elif casematch(arg, b'aggregate') and i + 1 < len(args): + aggregate = casenorm(args[i + 1]) + if aggregate not in (b'sum', b'min', b'max'): + raise SimpleError(SYNTAX_ERROR_MSG) + i += 2 + else: + raise SimpleError(SYNTAX_ERROR_MSG) + + out_members = set(sets[0]) + for s in sets[1:]: + if func == 'ZUNIONSTORE': + out_members |= set(s) + else: + out_members.intersection_update(s) + + # We first build a regular dict and turn it into a ZSet. The + # reason is subtle: a ZSet won't update a score from -0 to +0 + # (or vice versa) through assignment, but a regular dict will. + out = {} + # The sort affects the order of floating-point operations. + # Note that redis uses qsort(1), which has no stability guarantees, + # so we can't be sure to match it in all cases. + for s, w in sorted(zip(sets, weights), key=lambda x: len(x[0])): + for member, score in s.items(): + score *= w + # Redis only does this step for ZUNIONSTORE. See + # https://github.com/antirez/redis/issues/3954. + if func == 'ZUNIONSTORE' and math.isnan(score): + score = 0.0 + if member not in out_members: + continue + if member in out: + old = out[member] + if aggregate == b'sum': + score += old + if math.isnan(score): + score = 0.0 + elif aggregate == b'max': + score = max(old, score) + elif aggregate == b'min': + score = min(old, score) + else: + assert False # pragma: nocover + if math.isnan(score): + score = 0.0 + out[member] = score + + out_zset = ZSet() + for member, score in out.items(): + out_zset[member] = score + + dest.value = out_zset + return len(out_zset) + + @command((Key(), Int, bytes), (bytes,)) + def zunionstore(self, dest, numkeys, *args): + return self._zunioninter('ZUNIONSTORE', dest, numkeys, *args) + + @command((Key(), Int, bytes), (bytes,)) + def zinterstore(self, dest, numkeys, *args): + return self._zunioninter('ZINTERSTORE', dest, numkeys, *args) + + # Server commands + # TODO: lots + + @command((), (bytes,), flags='s') + def bgsave(self, *args): + if len(args) > 1 or (len(args) == 1 and not casematch(args[0], b'schedule')): + raise SimpleError(SYNTAX_ERROR_MSG) + self._server.lastsave = int(time.time()) + return BGSAVE_STARTED + + @command(()) + def dbsize(self): + return len(self._db) + + @command((), (bytes,)) + def flushdb(self, *args): + if args: + if len(args) != 1 or not casematch(args[0], b'async'): + raise SimpleError(SYNTAX_ERROR_MSG) + self._db.clear() + return OK + + @command((), (bytes,)) + def flushall(self, *args): + if args: + if len(args) != 1 or not casematch(args[0], b'async'): + raise SimpleError(SYNTAX_ERROR_MSG) + for db in self._server.dbs.values(): + db.clear() + # TODO: clear watches and/or pubsub as well? + return OK + + @command(()) + def lastsave(self): + return self._server.lastsave + + @command((), flags='s') + def save(self): + self._server.lastsave = int(time.time()) + return OK + + @command(()) + def time(self): + now_us = round(time.time() * 1000000) + now_s = now_us // 1000000 + now_us %= 1000000 + return [str(now_s).encode(), str(now_us).encode()] + + # Script commands + # script debug and script kill will probably not be supported + + def _convert_redis_arg(self, lua_runtime, value): + # Type checks are exact to avoid issues like bool being a subclass of int. + if type(value) is bytes: + return value + elif type(value) in {int, float}: + return '{:.17g}'.format(value).encode() + else: + # TODO: add the context + raise SimpleError(LUA_COMMAND_ARG_MSG) + + def _convert_redis_result(self, lua_runtime, result): + if isinstance(result, (bytes, int)): + return result + elif isinstance(result, SimpleString): + return lua_runtime.table_from({b"ok": result.value}) + elif result is None: + return False + elif isinstance(result, list): + converted = [ + self._convert_redis_result(lua_runtime, item) + for item in result + ] + return lua_runtime.table_from(converted) + elif isinstance(result, SimpleError): + raise result + else: + raise RuntimeError("Unexpected return type from redis: {}".format(type(result))) + + def _convert_lua_result(self, result, nested=True): + from lupa import lua_type + if lua_type(result) == 'table': + for key in (b'ok', b'err'): + if key in result: + msg = self._convert_lua_result(result[key]) + if not isinstance(msg, bytes): + raise SimpleError(LUA_WRONG_NUMBER_ARGS_MSG) + if key == b'ok': + return SimpleString(msg) + elif nested: + return SimpleError(msg.decode('utf-8', 'replace')) + else: + raise SimpleError(msg.decode('utf-8', 'replace')) + # Convert Lua tables into lists, starting from index 1, mimicking the behavior of StrictRedis. + result_list = [] + for index in itertools.count(1): + if index not in result: + break + item = result[index] + result_list.append(self._convert_lua_result(item)) + return result_list + elif isinstance(result, str): + return result.encode() + elif isinstance(result, float): + return int(result) + elif isinstance(result, bool): + return 1 if result else None + return result + + def _check_for_lua_globals(self, lua_runtime, expected_globals): + actual_globals = set(lua_runtime.globals().keys()) + if actual_globals != expected_globals: + unexpected = [six.ensure_str(var, 'utf-8', 'replace') + for var in actual_globals - expected_globals] + raise SimpleError(GLOBAL_VARIABLE_MSG.format(", ".join(unexpected))) + + def _lua_redis_call(self, lua_runtime, expected_globals, op, *args): + # Check if we've set any global variables before making any change. + self._check_for_lua_globals(lua_runtime, expected_globals) + func, func_name = self._name_to_func(op) + args = [self._convert_redis_arg(lua_runtime, arg) for arg in args] + result = self._run_command(func, func._fakeredis_sig, args, True) + return self._convert_redis_result(lua_runtime, result) + + def _lua_redis_pcall(self, lua_runtime, expected_globals, op, *args): + try: + return self._lua_redis_call(lua_runtime, expected_globals, op, *args) + except Exception as ex: + return lua_runtime.table_from({b"err": str(ex)}) + + def _lua_redis_log(self, lua_runtime, expected_globals, lvl, *args): + self._check_for_lua_globals(lua_runtime, expected_globals) + if len(args) < 1: + raise SimpleError(REQUIRES_MORE_ARGS_MSG.format("redis.log()", "two")) + if lvl not in REDIS_LOG_LEVELS.values(): + raise SimpleError(LOG_INVALID_DEBUG_LEVEL_MSG) + msg = ' '.join([x.decode('utf-8') + if isinstance(x, bytes) else str(x) + for x in args if not isinstance(x, bool)]) + LOGGER.log(REDIS_LOG_LEVELS_TO_LOGGING[lvl], msg) + + @command((bytes, Int), (bytes,), flags='s') + def eval(self, script, numkeys, *keys_and_args): + from lupa import LuaError, LuaRuntime, as_attrgetter + + if numkeys > len(keys_and_args): + raise SimpleError(TOO_MANY_KEYS_MSG) + if numkeys < 0: + raise SimpleError(NEGATIVE_KEYS_MSG) + sha1 = hashlib.sha1(script).hexdigest().encode() + self._server.script_cache[sha1] = script + lua_runtime = LuaRuntime(encoding=None, unpack_returned_tuples=True) + + set_globals = lua_runtime.eval( + """ + function(keys, argv, redis_call, redis_pcall, redis_log, redis_log_levels) + redis = {} + redis.call = redis_call + redis.pcall = redis_pcall + redis.log = redis_log + for level, pylevel in python.iterex(redis_log_levels.items()) do + redis[level] = pylevel + end + redis.error_reply = function(msg) return {err=msg} end + redis.status_reply = function(msg) return {ok=msg} end + KEYS = keys + ARGV = argv + end + """ + ) + expected_globals = set() + set_globals( + lua_runtime.table_from(keys_and_args[:numkeys]), + lua_runtime.table_from(keys_and_args[numkeys:]), + functools.partial(self._lua_redis_call, lua_runtime, expected_globals), + functools.partial(self._lua_redis_pcall, lua_runtime, expected_globals), + functools.partial(self._lua_redis_log, lua_runtime, expected_globals), + as_attrgetter(REDIS_LOG_LEVELS) + ) + expected_globals.update(lua_runtime.globals().keys()) + + try: + result = lua_runtime.execute(script) + except (LuaError, SimpleError) as ex: + raise SimpleError(SCRIPT_ERROR_MSG.format(sha1.decode(), ex)) + + self._check_for_lua_globals(lua_runtime, expected_globals) + + return self._convert_lua_result(result, nested=False) + + @command((bytes, Int), (bytes,), flags='s') + def evalsha(self, sha1, numkeys, *keys_and_args): + try: + script = self._server.script_cache[sha1] + except KeyError: + raise SimpleError(NO_MATCHING_SCRIPT_MSG) + return self.eval(script, numkeys, *keys_and_args) + + @command((bytes,), (bytes,), flags='s') + def script(self, subcmd, *args): + if casematch(subcmd, b'load'): + if len(args) != 1: + raise SimpleError(BAD_SUBCOMMAND_MSG.format('SCRIPT')) + script = args[0] + sha1 = hashlib.sha1(script).hexdigest().encode() + self._server.script_cache[sha1] = script + return sha1 + elif casematch(subcmd, b'exists'): + return [int(sha1 in self._server.script_cache) for sha1 in args] + elif casematch(subcmd, b'flush'): + if len(args) > 1 or (len(args) == 1 and casenorm(args[0]) not in {b'sync', b'async'}): + raise SimpleError(BAD_SUBCOMMAND_MSG.format('SCRIPT')) + self._server.script_cache = {} + return OK + else: + raise SimpleError(BAD_SUBCOMMAND_MSG.format('SCRIPT')) + + # Pubsub commands + # TODO: pubsub command + + def _subscribe(self, channels, subscribers, mtype): + for channel in channels: + subs = subscribers[channel] + if self not in subs: + subs.add(self) + self._pubsub += 1 + msg = [mtype, channel, self._pubsub] + self.put_response(msg) + return NoResponse() + + def _unsubscribe(self, channels, subscribers, mtype): + if not channels: + channels = [] + for (channel, subs) in subscribers.items(): + if self in subs: + channels.append(channel) + for channel in channels: + subs = subscribers.get(channel, set()) + if self in subs: + subs.remove(self) + if not subs: + del subscribers[channel] + self._pubsub -= 1 + msg = [mtype, channel, self._pubsub] + self.put_response(msg) + return NoResponse() + + @command((bytes,), (bytes,), flags='s') + def psubscribe(self, *patterns): + return self._subscribe(patterns, self._server.psubscribers, b'psubscribe') + + @command((bytes,), (bytes,), flags='s') + def subscribe(self, *channels): + return self._subscribe(channels, self._server.subscribers, b'subscribe') + + @command((), (bytes,), flags='s') + def punsubscribe(self, *patterns): + return self._unsubscribe(patterns, self._server.psubscribers, b'punsubscribe') + + @command((), (bytes,), flags='s') + def unsubscribe(self, *channels): + return self._unsubscribe(channels, self._server.subscribers, b'unsubscribe') + + @command((bytes, bytes)) + def publish(self, channel, message): + receivers = 0 + msg = [b'message', channel, message] + subs = self._server.subscribers.get(channel, set()) + for sock in subs: + sock.put_response(msg) + receivers += 1 + for (pattern, socks) in self._server.psubscribers.items(): + regex = compile_pattern(pattern) + if regex.match(channel): + msg = [b'pmessage', pattern, channel, message] + for sock in socks: + sock.put_response(msg) + receivers += 1 + return receivers + + +setattr(FakeSocket, 'del', FakeSocket.del_) +delattr(FakeSocket, 'del_') +setattr(FakeSocket, 'set', FakeSocket.set_) +delattr(FakeSocket, 'set_') +setattr(FakeSocket, 'exec', FakeSocket.exec_) +delattr(FakeSocket, 'exec_') + + +class _DummyParser: + def __init__(self, socket_read_size): + self.socket_read_size = socket_read_size + + def on_disconnect(self): + pass + + def on_connect(self, connection): + pass + + +# Redis <3.2 will not have a selector +try: + from redis.selector import BaseSelector +except ImportError: + class BaseSelector: + def __init__(self, sock): + self.sock = sock + + +class FakeSelector(BaseSelector): + def check_can_read(self, timeout): + if self.sock.responses.qsize(): + return True + if timeout is not None and timeout <= 0: + return False + + # A sleep/poll loop is easier to mock out than messing with condition + # variables. + start = time.time() + while True: + if self.sock.responses.qsize(): + return True + time.sleep(0.01) + now = time.time() + if timeout is not None and now > start + timeout: + return False + + def check_is_ready_for_command(self, timeout): + return True + + +class FakeConnection(redis.Connection): + description_format = "FakeConnection" + + def __init__(self, *args, **kwargs): + self._server = kwargs.pop('server') + super().__init__(*args, **kwargs) + + def connect(self): + super().connect() + # The selector is set in redis.Connection.connect() after _connect() is called + self._selector = FakeSelector(self._sock) + + def _connect(self): + if not self._server.connected: + raise redis.ConnectionError(CONNECTION_ERROR_MSG) + return FakeSocket(self._server) + + def can_read(self, timeout=0): + if not self._server.connected: + return True + if not self._sock: + self.connect() + # We use check_can_read rather than can_read, because on redis-py<3.2, + # FakeSelector inherits from a stub BaseSelector which doesn't + # implement can_read. Normally can_read provides retries on EINTR, + # but that's not necessary for the implementation of + # FakeSelector.check_can_read. + return self._selector.check_can_read(timeout) + + def _decode(self, response): + if isinstance(response, list): + return [self._decode(item) for item in response] + elif isinstance(response, bytes): + return self.encoder.decode(response) + else: + return response + + def read_response(self, disable_decoding=False): + if not self._server.connected: + try: + response = self._sock.responses.get_nowait() + except queue.Empty: + raise redis.ConnectionError(CONNECTION_ERROR_MSG) + else: + response = self._sock.responses.get() + if isinstance(response, redis.ResponseError): + raise response + if disable_decoding: + return response + else: + return self._decode(response) + + def repr_pieces(self): + pieces = [ + ('server', self._server), + ('db', self.db) + ] + if self.client_name: + pieces.append(('client_name', self.client_name)) + return pieces + + +class FakeRedisMixin: + def __init__(self, *args, server=None, connected=True, **kwargs): + # Interpret the positional and keyword arguments according to the + # version of redis in use. + bound = _ORIG_SIG.bind(*args, **kwargs) + bound.apply_defaults() + if not bound.arguments['connection_pool']: + charset = bound.arguments['charset'] + errors = bound.arguments['errors'] + # Adapted from redis-py + if charset is not None: + warnings.warn(DeprecationWarning( + '"charset" is deprecated. Use "encoding" instead')) + bound.arguments['encoding'] = charset + if errors is not None: + warnings.warn(DeprecationWarning( + '"errors" is deprecated. Use "encoding_errors" instead')) + bound.arguments['encoding_errors'] = errors + + if server is None: + server = FakeServer() + server.connected = connected + kwargs = { + 'connection_class': FakeConnection, + 'server': server + } + conn_pool_args = [ + 'db', + 'username', + 'password', + 'socket_timeout', + 'encoding', + 'encoding_errors', + 'decode_responses', + 'retry_on_timeout', + 'max_connections', + 'health_check_interval', + 'client_name' + ] + for arg in conn_pool_args: + if arg in bound.arguments: + kwargs[arg] = bound.arguments[arg] + bound.arguments['connection_pool'] = redis.connection.ConnectionPool(**kwargs) + super().__init__(*bound.args, **bound.kwargs) + + @classmethod + def from_url(https://github.com/cls,%20*args,%20**kwargs): + server = kwargs.pop('server', None) + if server is None: + server = FakeServer() + pool = redis.ConnectionPool.from_url(https://github.com/*args,%20**kwargs) + # Now override how it creates connections + pool.connection_class = FakeConnection + pool.connection_kwargs['server'] = server + # FakeConnection cannot handle the path kwarg (present when from_url + # is called with a unix socket) + pool.connection_kwargs.pop('path', None) + return cls(connection_pool=pool) + + +class FakeStrictRedis(FakeRedisMixin, redis.StrictRedis): + pass + + +class FakeRedis(FakeRedisMixin, redis.Redis): + pass diff --git a/build/lib/fakeredis/_zset.py b/build/lib/fakeredis/_zset.py new file mode 100644 index 0000000..47d1169 --- /dev/null +++ b/build/lib/fakeredis/_zset.py @@ -0,0 +1,87 @@ +import sortedcontainers + + +class ZSet: + def __init__(self): + self._bylex = {} # Maps value to score + self._byscore = sortedcontainers.SortedList() + + def __contains__(self, value): + return value in self._bylex + + def add(self, value, score): + """Update the item and return whether it modified the zset""" + old_score = self._bylex.get(value, None) + if old_score is not None: + if score == old_score: + return False + self._byscore.remove((old_score, value)) + self._bylex[value] = score + self._byscore.add((score, value)) + return True + + def __setitem__(self, value, score): + self.add(value, score) + + def __getitem__(self, key): + return self._bylex[key] + + def get(self, key, default=None): + return self._bylex.get(key, default) + + def __len__(self): + return len(self._bylex) + + def __iter__(self): + def gen(): + for score, value in self._byscore: + yield value + + return gen() + + def discard(self, key): + try: + score = self._bylex.pop(key) + except KeyError: + return + else: + self._byscore.remove((score, key)) + + def zcount(self, min_, max_): + pos1 = self._byscore.bisect_left(min_) + pos2 = self._byscore.bisect_left(max_) + return max(0, pos2 - pos1) + + def zlexcount(self, min_value, min_exclusive, max_value, max_exclusive): + if not self._byscore: + return 0 + score = self._byscore[0][0] + if min_exclusive: + pos1 = self._byscore.bisect_right((score, min_value)) + else: + pos1 = self._byscore.bisect_left((score, min_value)) + if max_exclusive: + pos2 = self._byscore.bisect_left((score, max_value)) + else: + pos2 = self._byscore.bisect_right((score, max_value)) + return max(0, pos2 - pos1) + + def islice_score(self, start, stop, reverse=False): + return self._byscore.islice(start, stop, reverse) + + def irange_lex(self, start, stop, inclusive=(True, True), reverse=False): + if not self._byscore: + return iter([]) + score = self._byscore[0][0] + it = self._byscore.irange((score, start), (score, stop), + inclusive=inclusive, reverse=reverse) + return (item[1] for item in it) + + def irange_score(self, start, stop, reverse=False): + return self._byscore.irange(start, stop, reverse=reverse) + + def rank(self, member): + return self._byscore.index((self._bylex[member], member)) + + def items(self): + return self._bylex.items() diff --git a/build/lib/fakeredis/aioredis.py b/build/lib/fakeredis/aioredis.py new file mode 100644 index 0000000..7d5ba08 --- /dev/null +++ b/build/lib/fakeredis/aioredis.py @@ -0,0 +1,10 @@ +import aioredis +import packaging.version + + +if packaging.version.Version(aioredis.__version__) >= packaging.version.Version('2.0.0a1'): + from ._aioredis2 import FakeConnection, FakeRedis # noqa: F401 +else: + from ._aioredis1 import ( # noqa: F401 + FakeConnectionsPool, create_connection, create_redis, create_pool, create_redis_pool + ) diff --git a/fakeredis/_server.py b/fakeredis/_server.py index 873043a..dbeb438 100644 --- a/fakeredis/_server.py +++ b/fakeredis/_server.py @@ -2757,7 +2757,7 @@ def _decode(self, response): else: return response - def read_response(self): + def read_response(self, disable_decoding=False): if not self._server.connected: try: response = self._sock.responses.get_nowait() @@ -2767,7 +2767,10 @@ def read_response(self): response = self._sock.responses.get() if isinstance(response, redis.ResponseError): raise response - return self._decode(response) + if disable_decoding: + return response + else: + return self._decode(response) def repr_pieces(self): pieces = [ diff --git a/requirements.in b/requirements.in index ae95967..2e14ac2 100644 --- a/requirements.in +++ b/requirements.in @@ -7,7 +7,7 @@ pytest pytest-asyncio pytest-cov pytest-mock -redis==4.0.0 # Latest at time of writing +redis==4.1.0 # Latest at time of writing six sortedcontainers diff --git a/requirements.txt b/requirements.txt index cfb96f6..c9328fb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -30,7 +30,7 @@ lupa==1.10 # via -r requirements.in mccabe==0.6.1 # via flake8 -packaging==20.7 +packaging==21.3 # via pytest pluggy==0.13.1 # via pytest @@ -54,7 +54,7 @@ pytest-cov==2.10.1 # via -r requirements.in pytest-mock==3.3.1 # via -r requirements.in -redis==4.0.0 +redis==4.1.0 # via -r requirements.in six==1.15.0 # via -r requirements.in diff --git a/setup.cfg b/setup.cfg index d6680c6..0ded32c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -26,7 +26,7 @@ install_requires = # Minor version updates to redis tend to break fakeredis. If you # need to use fakeredis with a newer redis, please submit a PR that # relaxes this restriction and adds it to the Github Actions tests. - redis<4.1.0 + redis<=4.1.0 six>=1.12 sortedcontainers python_requires = >=3.5 From ccaf36201bb895d4fe06b82805d721512a7c18fc Mon Sep 17 00:00:00 2001 From: rotten Date: Tue, 4 Jan 2022 17:12:13 -0500 Subject: [PATCH 12/20] remove build artifacts --- build/lib/fakeredis/__init__.py | 4 - build/lib/fakeredis/_aioredis1.py | 181 -- build/lib/fakeredis/_aioredis2.py | 170 -- build/lib/fakeredis/_async.py | 51 - build/lib/fakeredis/_server.py | 2850 ----------------------------- build/lib/fakeredis/_zset.py | 87 - build/lib/fakeredis/aioredis.py | 10 - 7 files changed, 3353 deletions(-) delete mode 100644 build/lib/fakeredis/__init__.py delete mode 100644 build/lib/fakeredis/_aioredis1.py delete mode 100644 build/lib/fakeredis/_aioredis2.py delete mode 100644 build/lib/fakeredis/_async.py delete mode 100644 build/lib/fakeredis/_server.py delete mode 100644 build/lib/fakeredis/_zset.py delete mode 100644 build/lib/fakeredis/aioredis.py diff --git a/build/lib/fakeredis/__init__.py b/build/lib/fakeredis/__init__.py deleted file mode 100644 index cff7eef..0000000 --- a/build/lib/fakeredis/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from ._server import FakeServer, FakeRedis, FakeStrictRedis, FakeConnection # noqa: F401 - - -__version__ = '1.7.0' diff --git a/build/lib/fakeredis/_aioredis1.py b/build/lib/fakeredis/_aioredis1.py deleted file mode 100644 index 7679f2e..0000000 --- a/build/lib/fakeredis/_aioredis1.py +++ /dev/null @@ -1,181 +0,0 @@ -import asyncio -import sys -import warnings - -import aioredis - -from . import _async, _server - - -class FakeSocket(_async.AsyncFakeSocket): - def _decode_error(self, error): - return aioredis.ReplyError(error.value) - - -class FakeReader: - """Re-implementation of aioredis.stream.StreamReader. - - It does not use a socket, but instead provides a queue that feeds - `readobj`. - """ - - def __init__(self, socket): - self._socket = socket - - def set_parser(self, parser): - pass # No parser needed, we get already-parsed data - - async def readobj(self): - if self._socket.responses is None: - raise asyncio.CancelledError - result = await self._socket.responses.get() - return result - - def at_eof(self): - return self._socket.responses is None - - def feed_obj(self, obj): - self._queue.put_nowait(obj) - - -class FakeWriter: - """Replaces a StreamWriter for an aioredis connection.""" - - def __init__(self, socket): - self.transport = socket # So that aioredis can call writer.transport.close() - - def write(self, data): - self.transport.sendall(data) - - -class FakeConnectionsPool(aioredis.ConnectionsPool): - def __init__(self, server=None, db=None, password=None, encoding=None, - *, minsize, maxsize, ssl=None, parser=None, - create_connection_timeout=None, - connection_cls=None, - loop=None): - super().__init__('fakeredis', - db=db, - password=password, - encoding=encoding, - minsize=minsize, - maxsize=maxsize, - ssl=ssl, - parser=parser, - create_connection_timeout=create_connection_timeout, - connection_cls=connection_cls, - loop=loop) - if server is None: - server = _server.FakeServer() - self._server = server - - def _create_new_connection(self, address): - # TODO: what does address do here? Might just be for sentinel? - return create_connection(self._server, - db=self._db, - password=self._password, - ssl=self._ssl, - encoding=self._encoding, - parser=self._parser_class, - timeout=self._create_connection_timeout, - connection_cls=self._connection_cls, - ) - - -async def create_connection(server=None, *, db=None, password=None, ssl=None, - encoding=None, parser=None, loop=None, - timeout=None, connection_cls=None): - # This is mostly copied from aioredis.connection.create_connection - if timeout is not None and timeout <= 0: - raise ValueError("Timeout has to be None or a number greater than 0") - - if connection_cls: - assert issubclass(connection_cls, aioredis.abc.AbcConnection),\ - "connection_class does not meet the AbcConnection contract" - cls = connection_cls - else: - cls = aioredis.connection.RedisConnection - - if loop is not None and sys.version_info >= (3, 8, 0): - warnings.warn("The loop argument is deprecated", - DeprecationWarning) - - if server is None: - server = _server.FakeServer() - socket = FakeSocket(server) - reader = FakeReader(socket) - writer = FakeWriter(socket) - conn = cls(reader, writer, encoding=encoding, - address='fakeredis', parser=parser) - - try: - if password is not None: - await conn.auth(password) - if db is not None: - await conn.select(db) - except Exception: - conn.close() - await conn.wait_closed() - raise - return conn - - -async def create_redis(server=None, *, db=None, password=None, ssl=None, - encoding=None, commands_factory=aioredis.Redis, - parser=None, timeout=None, - connection_cls=None, loop=None): - conn = await create_connection(server, db=db, - password=password, - ssl=ssl, - encoding=encoding, - parser=parser, - timeout=timeout, - connection_cls=connection_cls, - loop=loop) - return commands_factory(conn) - - -async def create_pool(server=None, *, db=None, password=None, ssl=None, - encoding=None, minsize=1, maxsize=10, - parser=None, loop=None, create_connection_timeout=None, - pool_cls=None, connection_cls=None): - # Mostly copied from aioredis.pool.create_pool. - if pool_cls: - assert issubclass(pool_cls, aioredis.AbcPool),\ - "pool_class does not meet the AbcPool contract" - cls = pool_cls - else: - cls = FakeConnectionsPool - - pool = cls(server, db, password, encoding, - minsize=minsize, maxsize=maxsize, - ssl=ssl, parser=parser, - create_connection_timeout=create_connection_timeout, - connection_cls=connection_cls, - loop=loop) - try: - await pool._fill_free(override_min=False) - except Exception: - pool.close() - await pool.wait_closed() - raise - return pool - - -async def create_redis_pool(server=None, *, db=None, password=None, ssl=None, - encoding=None, commands_factory=aioredis.Redis, - minsize=1, maxsize=10, parser=None, - timeout=None, pool_cls=None, - connection_cls=None, loop=None): - pool = await create_pool(server, db=db, - password=password, - ssl=ssl, - encoding=encoding, - minsize=minsize, - maxsize=maxsize, - parser=parser, - create_connection_timeout=timeout, - pool_cls=pool_cls, - connection_cls=connection_cls, - loop=loop) - return commands_factory(pool) diff --git a/build/lib/fakeredis/_aioredis2.py b/build/lib/fakeredis/_aioredis2.py deleted file mode 100644 index d07d197..0000000 --- a/build/lib/fakeredis/_aioredis2.py +++ /dev/null @@ -1,170 +0,0 @@ -import asyncio -from typing import Union - -import aioredis - -from . import _async, _server - - -class FakeSocket(_async.AsyncFakeSocket): - _connection_error_class = aioredis.ConnectionError - - def _decode_error(self, error): - return aioredis.connection.BaseParser(1).parse_error(error.value) - - -class FakeReader: - pass - - -class FakeWriter: - def __init__(self, socket: FakeSocket) -> None: - self._socket = socket - - def close(self): - self._socket = None - - async def wait_closed(self): - pass - - async def drain(self): - pass - - def writelines(self, data): - for chunk in data: - self._socket.sendall(chunk) - - -class FakeConnection(aioredis.Connection): - def __init__(self, *args, **kwargs): - self._server = kwargs.pop('server') - self._sock = None - super().__init__(*args, **kwargs) - - async def _connect(self): - if not self._server.connected: - raise aioredis.ConnectionError(_server.CONNECTION_ERROR_MSG) - self._sock = FakeSocket(self._server) - self._reader = FakeReader() - self._writer = FakeWriter(self._sock) - - async def disconnect(self): - await super().disconnect() - self._sock = None - - async def can_read(self, timeout: float = 0): - if not self.is_connected: - await self.connect() - if timeout == 0: - return not self._sock.responses.empty() - # asyncio.Queue doesn't have a way to wait for the queue to be - # non-empty without consuming an item, so kludge it with a sleep/poll - # loop. - loop = asyncio.get_event_loop() - start = loop.time() - while True: - if not self._sock.responses.empty(): - return True - await asyncio.sleep(0.01) - now = loop.time() - if timeout is not None and now > start + timeout: - return False - - def _decode(self, response): - if isinstance(response, list): - return [self._decode(item) for item in response] - elif isinstance(response, bytes): - return self.encoder.decode(response) - else: - return response - - async def read_response(self): - if not self._server.connected: - try: - response = self._sock.responses.get_nowait() - except asyncio.QueueEmpty: - raise aioredis.ConnectionError(_server.CONNECTION_ERROR_MSG) - else: - response = await self._sock.responses.get() - if isinstance(response, aioredis.ResponseError): - raise response - return self._decode(response) - - def repr_pieces(self): - pieces = [ - ('server', self._server), - ('db', self.db) - ] - if self.client_name: - pieces.append(('client_name', self.client_name)) - return pieces - - -class FakeRedis(aioredis.Redis): - def __init__( - self, - *, - db: Union[str, int] = 0, - password: str = None, - socket_timeout: float = None, - connection_pool: aioredis.ConnectionPool = None, - encoding: str = "utf-8", - encoding_errors: str = "strict", - decode_responses: bool = False, - retry_on_timeout: bool = False, - max_connections: int = None, - health_check_interval: int = 0, - client_name: str = None, - username: str = None, - server: _server.FakeServer = None, - connected: bool = True, - **kwargs - ): - if not connection_pool: - # Adapted from aioredis - if server is None: - server = _server.FakeServer() - server.connected = connected - connection_kwargs = { - "db": db, - "username": username, - "password": password, - "socket_timeout": socket_timeout, - "encoding": encoding, - "encoding_errors": encoding_errors, - "decode_responses": decode_responses, - "retry_on_timeout": retry_on_timeout, - "max_connections": max_connections, - "health_check_interval": health_check_interval, - "client_name": client_name, - "server": server, - "connection_class": FakeConnection - } - connection_pool = aioredis.ConnectionPool(**connection_kwargs) - super().__init__( - db=db, - password=password, - socket_timeout=socket_timeout, - connection_pool=connection_pool, - encoding=encoding, - encoding_errors=encoding_errors, - decode_responses=decode_responses, - retry_on_timeout=retry_on_timeout, - max_connections=max_connections, - health_check_interval=health_check_interval, - client_name=client_name, - username=username, - **kwargs - ) - - @classmethod - def from_url(https://github.com/cls,%20url:%20str,%20**kwargs): - server = kwargs.pop('server', None) - if server is None: - server = _server.FakeServer() - self = super().from_url(https://github.com/url,%20**kwargs) - # Now override how it creates connections - pool = self.connection_pool - pool.connection_class = FakeConnection - pool.connection_kwargs['server'] = server - return self diff --git a/build/lib/fakeredis/_async.py b/build/lib/fakeredis/_async.py deleted file mode 100644 index ec51d1e..0000000 --- a/build/lib/fakeredis/_async.py +++ /dev/null @@ -1,51 +0,0 @@ -import asyncio - -import async_timeout - -from . import _server - - -class AsyncFakeSocket(_server.FakeSocket): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.responses = asyncio.Queue() - - def put_response(self, msg): - self.responses.put_nowait(msg) - - async def _async_blocking(self, timeout, func, event, callback): - try: - result = None - with async_timeout.timeout(timeout if timeout else None): - while True: - await event.wait() - event.clear() - # This is a coroutine outside the normal control flow that - # locks the server, so we have to take our own lock. - with self._server.lock: - ret = func(False) - if ret is not None: - result = self._decode_result(ret) - self.put_response(result) - break - except asyncio.TimeoutError: - result = None - finally: - with self._server.lock: - self._db.remove_change_callback(callback) - self.put_response(result) - self.resume() - - def _blocking(self, timeout, func): - loop = asyncio.get_event_loop() - ret = func(True) - if ret is not None or self._in_transaction: - return ret - event = asyncio.Event() - - def callback(): - loop.call_soon_threadsafe(event.set) - self._db.add_change_callback(callback) - self.pause() - loop.create_task(self._async_blocking(timeout, func, event, callback)) - return _server.NoResponse() diff --git a/build/lib/fakeredis/_server.py b/build/lib/fakeredis/_server.py deleted file mode 100644 index dbeb438..0000000 --- a/build/lib/fakeredis/_server.py +++ /dev/null @@ -1,2850 +0,0 @@ -import functools -import hashlib -import inspect -import itertools -import logging -import math -import pickle -import queue -import random -import re -import threading -import time -import warnings -import weakref -from collections import defaultdict -from collections.abc import MutableMapping - -import redis -import six - -from ._zset import ZSet - -LOGGER = logging.getLogger('fakeredis') -REDIS_LOG_LEVELS = { - b'LOG_DEBUG': 0, - b'LOG_VERBOSE': 1, - b'LOG_NOTICE': 2, - b'LOG_WARNING': 3 -} -REDIS_LOG_LEVELS_TO_LOGGING = { - 0: logging.DEBUG, - 1: logging.INFO, - 2: logging.INFO, - 3: logging.WARNING -} - -MAX_STRING_SIZE = 512 * 1024 * 1024 - -INVALID_EXPIRE_MSG = "ERR invalid expire time in {}" -WRONGTYPE_MSG = \ - "WRONGTYPE Operation against a key holding the wrong kind of value" -SYNTAX_ERROR_MSG = "ERR syntax error" -INVALID_INT_MSG = "ERR value is not an integer or out of range" -INVALID_FLOAT_MSG = "ERR value is not a valid float" -INVALID_OFFSET_MSG = "ERR offset is out of range" -INVALID_BIT_OFFSET_MSG = "ERR bit offset is not an integer or out of range" -INVALID_BIT_VALUE_MSG = "ERR bit is not an integer or out of range" -INVALID_DB_MSG = "ERR DB index is out of range" -INVALID_MIN_MAX_FLOAT_MSG = "ERR min or max is not a float" -INVALID_MIN_MAX_STR_MSG = "ERR min or max not a valid string range item" -STRING_OVERFLOW_MSG = "ERR string exceeds maximum allowed size (512MB)" -OVERFLOW_MSG = "ERR increment or decrement would overflow" -NONFINITE_MSG = "ERR increment would produce NaN or Infinity" -SCORE_NAN_MSG = "ERR resulting score is not a number (NaN)" -INVALID_SORT_FLOAT_MSG = "ERR One or more scores can't be converted into double" -SRC_DST_SAME_MSG = "ERR source and destination objects are the same" -NO_KEY_MSG = "ERR no such key" -INDEX_ERROR_MSG = "ERR index out of range" -ZADD_NX_XX_ERROR_MSG = "ERR ZADD allows either 'nx' or 'xx', not both" -ZADD_INCR_LEN_ERROR_MSG = "ERR INCR option supports a single increment-element pair" -ZUNIONSTORE_KEYS_MSG = "ERR at least 1 input key is needed for ZUNIONSTORE/ZINTERSTORE" -WRONG_ARGS_MSG = "ERR wrong number of arguments for '{}' command" -UNKNOWN_COMMAND_MSG = "ERR unknown command '{}'" -EXECABORT_MSG = "EXECABORT Transaction discarded because of previous errors." -MULTI_NESTED_MSG = "ERR MULTI calls can not be nested" -WITHOUT_MULTI_MSG = "ERR {0} without MULTI" -WATCH_INSIDE_MULTI_MSG = "ERR WATCH inside MULTI is not allowed" -NEGATIVE_KEYS_MSG = "ERR Number of keys can't be negative" -TOO_MANY_KEYS_MSG = "ERR Number of keys can't be greater than number of args" -TIMEOUT_NEGATIVE_MSG = "ERR timeout is negative" -NO_MATCHING_SCRIPT_MSG = "NOSCRIPT No matching script. Please use EVAL." -GLOBAL_VARIABLE_MSG = "ERR Script attempted to set global variables: {}" -COMMAND_IN_SCRIPT_MSG = "ERR This Redis command is not allowed from scripts" -BAD_SUBCOMMAND_MSG = "ERR Unknown {} subcommand or wrong # of args." -BAD_COMMAND_IN_PUBSUB_MSG = \ - "ERR only (P)SUBSCRIBE / (P)UNSUBSCRIBE / PING / QUIT allowed in this context" -CONNECTION_ERROR_MSG = "FakeRedis is emulating a connection error." -REQUIRES_MORE_ARGS_MSG = "ERR {} requires {} arguments or more." -LOG_INVALID_DEBUG_LEVEL_MSG = "ERR Invalid debug level." -LUA_COMMAND_ARG_MSG = "ERR Lua redis() command arguments must be strings or integers" -LUA_WRONG_NUMBER_ARGS_MSG = "ERR wrong number or type of arguments" -SCRIPT_ERROR_MSG = "ERR Error running script (call to f_{}): @user_script:?: {}" -RESTORE_KEY_EXISTS = "BUSYKEY Target key name already exists." -RESTORE_INVALID_CHECKSUM_MSG = "ERR DUMP payload version or checksum are wrong" -RESTORE_INVALID_TTL_MSG = "ERR Invalid TTL value, must be >= 0" - -FLAG_NO_SCRIPT = 's' # Command not allowed in scripts - -# This needs to be grabbed early to avoid breaking tests that mock redis.Redis. -_ORIG_SIG = inspect.signature(redis.Redis) - - -class SimpleString: - def __init__(self, value): - assert isinstance(value, bytes) - self.value = value - - -class SimpleError(Exception): - """Exception that will be turned into a frontend-specific exception.""" - - def __init__(self, value): - assert isinstance(value, str) - self.value = value - - -class NoResponse: - """Returned by pub/sub commands to indicate that no response should be returned""" - pass - - -OK = SimpleString(b'OK') -QUEUED = SimpleString(b'QUEUED') -PONG = SimpleString(b'PONG') -BGSAVE_STARTED = SimpleString(b'Background saving started') - - -def null_terminate(s): - # Redis uses C functions on some strings, which means they stop at the - # first NULL. - if b'\0' in s: - return s[:s.find(b'\0')] - return s - - -def casenorm(s): - return null_terminate(s).lower() - - -def casematch(a, b): - return casenorm(a) == casenorm(b) - - -def compile_pattern(pattern): - """Compile a glob pattern (e.g. for keys) to a bytes regex. - - fnmatch.fnmatchcase doesn't work for this, because it uses different - escaping rules to redis, uses ! instead of ^ to negate a character set, - and handles invalid cases (such as a [ without a ]) differently. This - implementation was written by studying the redis implementation. - """ - # It's easier to work with text than bytes, because indexing bytes - # doesn't behave the same in Python 3. Latin-1 will round-trip safely. - pattern = pattern.decode('latin-1') - parts = ['^'] - i = 0 - L = len(pattern) - while i < L: - c = pattern[i] - i += 1 - if c == '?': - parts.append('.') - elif c == '*': - parts.append('.*') - elif c == '\\': - if i == L: - i -= 1 - parts.append(re.escape(pattern[i])) - i += 1 - elif c == '[': - parts.append('[') - if i < L and pattern[i] == '^': - i += 1 - parts.append('^') - parts_len = len(parts) # To detect if anything was added - while i < L: - if pattern[i] == '\\' and i + 1 < L: - i += 1 - parts.append(re.escape(pattern[i])) - elif pattern[i] == ']': - i += 1 - break - elif i + 2 < L and pattern[i + 1] == '-': - start = pattern[i] - end = pattern[i + 2] - if start > end: - start, end = end, start - parts.append(re.escape(start) + '-' + re.escape(end)) - i += 2 - else: - parts.append(re.escape(pattern[i])) - i += 1 - if len(parts) == parts_len: - if parts[-1] == '[': - # Empty group - will never match - parts[-1] = '(?:$.)' - else: - # Negated empty group - matches any character - assert parts[-1] == '^' - parts.pop() - parts[-1] = '.' - else: - parts.append(']') - else: - parts.append(re.escape(c)) - parts.append('\\Z') - regex = ''.join(parts).encode('latin-1') - return re.compile(regex, re.S) - - -class Item: - """An item stored in the database""" - - __slots__ = ['value', 'expireat'] - - def __init__(self, value): - self.value = value - self.expireat = None - - -class CommandItem: - """An item referenced by a command. - - It wraps an Item but has extra fields to manage updates and notifications. - """ - def __init__(self, key, db, item=None, default=None): - if item is None: - self._value = default - self._expireat = None - else: - self._value = item.value - self._expireat = item.expireat - self.key = key - self.db = db - self._modified = False - self._expireat_modified = False - - @property - def value(self): - return self._value - - @value.setter - def value(self, new_value): - self._value = new_value - self._modified = True - self.expireat = None - - @property - def expireat(self): - return self._expireat - - @expireat.setter - def expireat(self, value): - self._expireat = value - self._expireat_modified = True - self._modified = True # Since redis 6.0.7 - - def get(self, default): - return self._value if self else default - - def update(self, new_value): - self._value = new_value - self._modified = True - - def updated(self): - self._modified = True - - def writeback(self): - if self._modified: - self.db.notify_watch(self.key) - if not isinstance(self.value, bytes) and not self.value: - self.db.pop(self.key, None) - return - else: - item = self.db.setdefault(self.key, Item(None)) - item.value = self.value - item.expireat = self.expireat - elif self._expireat_modified and self.key in self.db: - self.db[self.key].expireat = self.expireat - - def __bool__(self): - return bool(self._value) or isinstance(self._value, bytes) - - __nonzero__ = __bool__ # For Python 2 - - -class Database(MutableMapping): - def __init__(self, lock, *args, **kwargs): - self._dict = dict(*args, **kwargs) - self.time = 0.0 - self._watches = defaultdict(weakref.WeakSet) # key to set of connections - self.condition = threading.Condition(lock) - self._change_callbacks = set() - - def swap(self, other): - self._dict, other._dict = other._dict, self._dict - self.time, other.time = other.time, self.time - - def notify_watch(self, key): - for sock in self._watches.get(key, set()): - sock.notify_watch() - self.condition.notify_all() - for callback in self._change_callbacks: - callback() - - def add_watch(self, key, sock): - self._watches[key].add(sock) - - def remove_watch(self, key, sock): - watches = self._watches[key] - watches.discard(sock) - if not watches: - del self._watches[key] - - def add_change_callback(self, callback): - self._change_callbacks.add(callback) - - def remove_change_callback(self, callback): - self._change_callbacks.remove(callback) - - def clear(self): - for key in self: - self.notify_watch(key) - self._dict.clear() - - def expired(self, item): - return item.expireat is not None and item.expireat < self.time - - def _remove_expired(self): - for key in list(self._dict): - item = self._dict[key] - if self.expired(item): - del self._dict[key] - - def __getitem__(self, key): - item = self._dict[key] - if self.expired(item): - del self._dict[key] - raise KeyError(key) - return item - - def __setitem__(self, key, value): - self._dict[key] = value - - def __delitem__(self, key): - del self._dict[key] - - def __iter__(self): - self._remove_expired() - return iter(self._dict) - - def __len__(self): - self._remove_expired() - return len(self._dict) - - def __hash__(self): - return hash(super(object, self)) - - def __eq__(self, other): - return super(object, self) == other - - -class Hash(dict): - redis_type = b'hash' - - -class Int: - """Argument converter for 64-bit signed integers""" - - DECODE_ERROR = INVALID_INT_MSG - ENCODE_ERROR = OVERFLOW_MSG - MIN_VALUE = -2**63 - MAX_VALUE = 2**63 - 1 - - @classmethod - def valid(cls, value): - return cls.MIN_VALUE <= value <= cls.MAX_VALUE - - @classmethod - def decode(cls, value): - try: - out = int(value) - if not cls.valid(out) or str(out).encode() != value: - raise ValueError - except ValueError: - raise SimpleError(cls.DECODE_ERROR) - return out - - @classmethod - def encode(cls, value): - if cls.valid(value): - return str(value).encode() - else: - raise SimpleError(cls.ENCODE_ERROR) - - -class BitOffset(Int): - """Argument converter for unsigned bit positions""" - - DECODE_ERROR = INVALID_BIT_OFFSET_MSG - MIN_VALUE = 0 - MAX_VALUE = 8 * MAX_STRING_SIZE - 1 # Redis imposes 512MB limit on keys - - -class BitValue(Int): - DECODE_ERROR = INVALID_BIT_VALUE_MSG - MIN_VALUE = 0 - MAX_VALUE = 1 - - -class DbIndex(Int): - """Argument converter for database indices""" - - DECODE_ERROR = INVALID_DB_MSG - MIN_VALUE = 0 - MAX_VALUE = 15 - - -class Timeout(Int): - """Argument converter for timeouts""" - - DECODE_ERROR = TIMEOUT_NEGATIVE_MSG - MIN_VALUE = 0 - - -class Float: - """Argument converter for floating-point values. - - Redis uses long double for some cases (INCRBYFLOAT, HINCRBYFLOAT) - and double for others (zset scores), but Python doesn't support - long double. - """ - - DECODE_ERROR = INVALID_FLOAT_MSG - - @classmethod - def decode(cls, value, - allow_leading_whitespace=False, - allow_erange=False, - allow_empty=False, - crop_null=False): - # redis has some quirks in float parsing, with several variants. - # See https://github.com/antirez/redis/issues/5706 - try: - if crop_null: - value = null_terminate(value) - if allow_empty and value == b'': - value = b'0.0' - if not allow_leading_whitespace and value[:1].isspace(): - raise ValueError - if value[-1:].isspace(): - raise ValueError - out = float(value) - if math.isnan(out): - raise ValueError - if not allow_erange: - # Values that over- or underflow- are explicitly rejected by - # redis. This is a crude hack to determine whether the input - # may have been such a value. - if out in (math.inf, -math.inf, 0.0) and re.match(b'^[^a-zA-Z]*[1-9]', value): - raise ValueError - return out - except ValueError: - raise SimpleError(cls.DECODE_ERROR) - - @classmethod - def encode(cls, value, humanfriendly): - if math.isinf(value): - return str(value).encode() - elif humanfriendly: - # Algorithm from ld2string in redis - out = '{:.17f}'.format(value) - out = re.sub(r'(?:\.)?0+$', '', out) - return out.encode() - else: - return '{:.17g}'.format(value).encode() - - -class SortFloat(Float): - DECODE_ERROR = INVALID_SORT_FLOAT_MSG - - @classmethod - def decode(cls, value): - return super().decode( - value, allow_leading_whitespace=True, allow_empty=True, crop_null=True) - - -class ScoreTest: - """Argument converter for sorted set score endpoints.""" - def __init__(self, value, exclusive=False): - self.value = value - self.exclusive = exclusive - - @classmethod - def decode(cls, value): - try: - exclusive = False - if value[:1] == b'(': - exclusive = True - value = value[1:] - value = Float.decode( - value, allow_leading_whitespace=True, allow_erange=True, - allow_empty=True, crop_null=True) - return cls(value, exclusive) - except SimpleError: - raise SimpleError(INVALID_MIN_MAX_FLOAT_MSG) - - def __str__(self): - if self.exclusive: - return '({!r}'.format(self.value) - else: - return repr(self.value) - - @property - def lower_bound(self): - return (self.value, AfterAny() if self.exclusive else BeforeAny()) - - @property - def upper_bound(self): - return (self.value, BeforeAny() if self.exclusive else AfterAny()) - - -class StringTest: - """Argument converter for sorted set LEX endpoints.""" - def __init__(self, value, exclusive): - self.value = value - self.exclusive = exclusive - - @classmethod - def decode(cls, value): - if value == b'-': - return cls(BeforeAny(), True) - elif value == b'+': - return cls(AfterAny(), True) - elif value[:1] == b'(': - return cls(value[1:], True) - elif value[:1] == b'[': - return cls(value[1:], False) - else: - raise SimpleError(INVALID_MIN_MAX_STR_MSG) - - -@functools.total_ordering -class BeforeAny: - def __gt__(self, other): - return False - - def __eq__(self, other): - return isinstance(other, BeforeAny) - - -@functools.total_ordering -class AfterAny: - def __lt__(self, other): - return False - - def __eq__(self, other): - return isinstance(other, AfterAny) - - -class Key: - """Marker to indicate that argument in signature is a key""" - UNSPECIFIED = object() - - def __init__(self, type_=None, missing_return=UNSPECIFIED): - self.type_ = type_ - self.missing_return = missing_return - - -class Signature: - def __init__(self, name, fixed, repeat=(), flags=""): - self.name = name - self.fixed = fixed - self.repeat = repeat - self.flags = flags - - def check_arity(self, args): - if len(args) != len(self.fixed): - delta = len(args) - len(self.fixed) - if delta < 0 or not self.repeat: - raise SimpleError(WRONG_ARGS_MSG.format(self.name)) - - def apply(self, args, db): - """Returns a tuple, which is either: - - transformed args and a dict of CommandItems; or - - a single containing a short-circuit return value - """ - self.check_arity(args) - if self.repeat: - delta = len(args) - len(self.fixed) - if delta % len(self.repeat) != 0: - raise SimpleError(WRONG_ARGS_MSG.format(self.name)) - - types = list(self.fixed) - for i in range(len(args) - len(types)): - types.append(self.repeat[i % len(self.repeat)]) - - args = list(args) - # First pass: convert/validate non-keys, and short-circuit on missing keys - for i, (arg, type_) in enumerate(zip(args, types)): - if isinstance(type_, Key): - if type_.missing_return is not Key.UNSPECIFIED and arg not in db: - return (type_.missing_return,) - elif type_ != bytes: - args[i] = type_.decode(args[i]) - - # Second pass: read keys and check their types - command_items = [] - for i, (arg, type_) in enumerate(zip(args, types)): - if isinstance(type_, Key): - item = db.get(arg) - default = None - if type_.type_ is not None: - if item is not None and type(item.value) != type_.type_: - raise SimpleError(WRONGTYPE_MSG) - if item is None: - if type_.type_ is not bytes: - default = type_.type_() - args[i] = CommandItem(arg, db, item, default=default) - command_items.append(args[i]) - - return args, command_items - - -def valid_response_type(value, nested=False): - if isinstance(value, NoResponse) and not nested: - return True - if value is not None and not isinstance(value, (bytes, SimpleString, SimpleError, - int, list)): - return False - if isinstance(value, list): - if any(not valid_response_type(item, True) for item in value): - return False - return True - - -def command(*args, **kwargs): - def decorator(func): - name = kwargs.pop('name', func.__name__) - func._fakeredis_sig = Signature(name, *args, **kwargs) - return func - - return decorator - - -class FakeServer: - def __init__(self): - self.lock = threading.Lock() - self.dbs = defaultdict(lambda: Database(self.lock)) - # Maps SHA1 to script source - self.script_cache = {} - # Maps channel/pattern to weak set of sockets - self.subscribers = defaultdict(weakref.WeakSet) - self.psubscribers = defaultdict(weakref.WeakSet) - self.lastsave = int(time.time()) - self.connected = True - # List of weakrefs to sockets that are being closed lazily - self.closed_sockets = [] - - -class FakeSocket: - _connection_error_class = redis.ConnectionError - - def __init__(self, server): - self._server = server - self._db = server.dbs[0] - self._db_num = 0 - # When in a MULTI, set to a list of function calls - self._transaction = None - self._transaction_failed = False - # Set when executing the commands from EXEC - self._in_transaction = False - self._watch_notified = False - self._watches = set() - self._pubsub = 0 # Count of subscriptions - self.responses = queue.Queue() - # Prevents parser from processing commands. Not used in this module, - # but set by aioredis module to prevent new commands being processed - # while handling a blocking command. - self._paused = False - self._parser = self._parse_commands() - self._parser.send(None) - - def put_response(self, msg): - # redis.Connection.__del__ might call self.close at any time, which - # will set self.responses to None. We assume this will happen - # atomically, and the code below then protects us against this. - responses = self.responses - if responses: - responses.put(msg) - - def pause(self): - self._paused = True - - def resume(self): - self._paused = False - self._parser.send(b'') - - def shutdown(self, flags): - self._parser.close() - - def fileno(self): - # Our fake socket must return an integer from `FakeSocket.fileno()` since a real selector - # will be created. The value does not matter since we replace the selector with our own - # `FakeSelector` before it is ever used. - return 0 - - def _cleanup(self, server): - """Remove all the references to `self` from `server`. - - This is called with the server lock held, but it may be some time after - self.close. - """ - for subs in server.subscribers.values(): - subs.discard(self) - for subs in server.psubscribers.values(): - subs.discard(self) - self._clear_watches() - - def close(self): - # Mark ourselves for cleanup. This might be called from - # redis.Connection.__del__, which the garbage collection could call - # at any time, and hence we can't safely take the server lock. - # We rely on list.append being atomic. - self._server.closed_sockets.append(weakref.ref(self)) - self._server = None - self._db = None - self.responses = None - - @staticmethod - def _extract_line(buf): - pos = buf.find(b'\n') + 1 - assert pos > 0 - line = buf[:pos] - buf = buf[pos:] - assert line.endswith(b'\r\n') - return line, buf - - def _parse_commands(self): - """Generator that parses commands. - - It is fed pieces of redis protocol data (via `send`) and calls - `_process_command` whenever it has a complete one. - """ - buf = b'' - while True: - while self._paused or b'\n' not in buf: - buf += yield - line, buf = self._extract_line(buf) - assert line[:1] == b'*' # array - n_fields = int(line[1:-2]) - fields = [] - for i in range(n_fields): - while b'\n' not in buf: - buf += yield - line, buf = self._extract_line(buf) - assert line[:1] == b'$' # string - length = int(line[1:-2]) - while len(buf) < length + 2: - buf += yield - fields.append(buf[:length]) - buf = buf[length+2:] # +2 to skip the CRLF - self._process_command(fields) - - def _run_command(self, func, sig, args, from_script): - command_items = {} - try: - ret = sig.apply(args, self._db) - if len(ret) == 1: - result = ret[0] - else: - args, command_items = ret - if from_script and FLAG_NO_SCRIPT in sig.flags: - raise SimpleError(COMMAND_IN_SCRIPT_MSG) - if self._pubsub and sig.name not in [ - 'ping', 'subscribe', 'unsubscribe', - 'psubscribe', 'punsubscribe', 'quit']: - raise SimpleError(BAD_COMMAND_IN_PUBSUB_MSG) - result = func(*args) - assert valid_response_type(result) - except SimpleError as exc: - result = exc - for command_item in command_items: - command_item.writeback() - return result - - def _decode_error(self, error): - return redis.connection.BaseParser().parse_error(error.value) - - def _decode_result(self, result): - """Convert SimpleString and SimpleError, recursively""" - if isinstance(result, list): - return [self._decode_result(r) for r in result] - elif isinstance(result, SimpleString): - return result.value - elif isinstance(result, SimpleError): - return self._decode_error(result) - else: - return result - - def _blocking(self, timeout, func): - """Run a function until it succeeds or timeout is reached. - - The timeout must be an integer, and 0 means infinite. The function - is called with a boolean to indicate whether this is the first call. - If it returns None it is considered to have "failed" and is retried - each time the condition variable is notified, until the timeout is - reached. - - Returns the function return value, or None if the timeout was reached. - """ - ret = func(True) - if ret is not None or self._in_transaction: - return ret - if timeout: - deadline = time.time() + timeout - else: - deadline = None - while True: - timeout = deadline - time.time() if deadline is not None else None - if timeout is not None and timeout <= 0: - return None - # Python <3.2 doesn't return a status from wait. On Python 3.2+ - # we bail out early on False. - if self._db.condition.wait(timeout=timeout) is False: - return None # Timeout expired - ret = func(False) - if ret is not None: - return ret - - def _name_to_func(self, name): - name = six.ensure_str(name, encoding='utf-8', errors='replace') - func_name = name.lower() - func = getattr(self, func_name, None) - if name.startswith('_') or not func or not hasattr(func, '_fakeredis_sig'): - # redis remaps \r or \n in an error to ' ' to make it legal protocol - clean_name = name.replace('\r', ' ').replace('\n', ' ') - raise SimpleError(UNKNOWN_COMMAND_MSG.format(clean_name)) - return func, func_name - - def sendall(self, data): - if not self._server.connected: - raise self._connection_error_class(CONNECTION_ERROR_MSG) - if isinstance(data, str): - data = data.encode('ascii') - self._parser.send(data) - - def _process_command(self, fields): - if not fields: - return - func_name = None - try: - func, func_name = self._name_to_func(fields[0]) - sig = func._fakeredis_sig - with self._server.lock: - # Clean out old connections - while True: - try: - weak_sock = self._server.closed_sockets.pop() - except IndexError: - break - else: - sock = weak_sock() - if sock: - sock._cleanup(self._server) - now = time.time() - for db in self._server.dbs.values(): - db.time = now - sig.check_arity(fields[1:]) - # TODO: make a signature attribute for transactions - if self._transaction is not None \ - and func_name not in ('exec', 'discard', 'multi', 'watch'): - self._transaction.append((func, sig, fields[1:])) - result = QUEUED - else: - result = self._run_command(func, sig, fields[1:], False) - except SimpleError as exc: - if self._transaction is not None: - # TODO: should not apply if the exception is from _run_command - # e.g. watch inside multi - self._transaction_failed = True - if func_name == 'exec' and exc.value.startswith('ERR '): - exc.value = 'EXECABORT Transaction discarded because of: ' + exc.value[4:] - self._transaction = None - self._transaction_failed = False - self._clear_watches() - result = exc - result = self._decode_result(result) - if not isinstance(result, NoResponse): - self.put_response(result) - - def notify_watch(self): - self._watch_notified = True - - # redis has inconsistent handling of negative indices, hence two versions - # of this code. - - @staticmethod - def _fix_range_string(start, end, length): - # Negative number handling is based on the redis source code - if start < 0 and end < 0 and start > end: - return -1, -1 - if start < 0: - start = max(0, start + length) - if end < 0: - end = max(0, end + length) - end = min(end, length - 1) - return start, end + 1 - - @staticmethod - def _fix_range(start, end, length): - # Redis handles negative slightly differently for zrange - if start < 0: - start = max(0, start + length) - if end < 0: - end += length - if start > end or start >= length: - return -1, -1 - end = min(end, length - 1) - return start, end + 1 - - def _scan(self, keys, cursor, *args): - """ - This is the basis of most of the ``scan`` methods. - - This implementation is KNOWN to be un-performant, as it requires - grabbing the full set of keys over which we are investigating subsets. - - It also doesn't adhere to the guarantee that every key will be iterated - at least once even if the database is modified during the scan. - However, provided the database is not modified, every key will be - returned exactly once. - """ - pattern = None - type = None - count = 10 - if len(args) % 2 != 0: - raise SimpleError(SYNTAX_ERROR_MSG) - for i in range(0, len(args), 2): - if casematch(args[i], b'match'): - pattern = args[i + 1] - elif casematch(args[i], b'count'): - count = Int.decode(args[i + 1]) - if count <= 0: - raise SimpleError(SYNTAX_ERROR_MSG) - elif casematch(args[i], b'type'): - type = args[i + 1] - else: - raise SimpleError(SYNTAX_ERROR_MSG) - - if cursor >= len(keys): - return [0, []] - data = sorted(keys) - result_cursor = cursor + count - result_data = [] - - regex = compile_pattern(pattern) if pattern is not None else None - - def match_key(key): - return regex.match(key) if pattern is not None else True - - def match_type(key): - if type is not None: - return casematch(self.type(self._db[key]).value, type) - return True - - if pattern is not None or type is not None: - for val in itertools.islice(data, cursor, result_cursor): - compare_val = val[0] if isinstance(val, tuple) else val - if match_key(compare_val) and match_type(compare_val): - result_data.append(val) - else: - result_data = data[cursor:result_cursor] - - if result_cursor >= len(data): - result_cursor = 0 - return [result_cursor, result_data] - - # Connection commands - # TODO: auth, quit - - @command((bytes,)) - def echo(self, message): - return message - - @command((), (bytes,)) - def ping(self, *args): - if len(args) > 1: - raise SimpleError(WRONG_ARGS_MSG.format('ping')) - if self._pubsub: - return [b'pong', args[0] if args else b''] - else: - return args[0] if args else PONG - - @command((DbIndex,)) - def select(self, index): - self._db = self._server.dbs[index] - self._db_num = index - return OK - - @command((DbIndex, DbIndex)) - def swapdb(self, index1, index2): - if index1 != index2: - db1 = self._server.dbs[index1] - db2 = self._server.dbs[index2] - db1.swap(db2) - return OK - - # Key commands - # TODO: lots - - def _delete(self, *keys): - ans = 0 - done = set() - for key in keys: - if key and key.key not in done: - key.value = None - done.add(key.key) - ans += 1 - return ans - - @command((Key(),), (Key(),), name='del') - def del_(self, *keys): - return self._delete(*keys) - - @command((Key(),), (Key(),), name='unlink') - def unlink(self, *keys): - return self._delete(*keys) - - @command((Key(),), (Key(),)) - def exists(self, *keys): - ret = 0 - for key in keys: - if key: - ret += 1 - return ret - - def _expireat(self, key, timestamp): - if not key: - return 0 - else: - key.expireat = timestamp - return 1 - - def _ttl(self, key, scale): - if not key: - return -2 - elif key.expireat is None: - return -1 - else: - return int(round((key.expireat - self._db.time) * scale)) - - @command((Key(), Int)) - def expire(self, key, seconds): - return self._expireat(key, self._db.time + seconds) - - @command((Key(), Int)) - def expireat(self, key, timestamp): - return self._expireat(key, float(timestamp)) - - @command((Key(), Int)) - def pexpire(self, key, ms): - return self._expireat(key, self._db.time + ms / 1000.0) - - @command((Key(), Int)) - def pexpireat(self, key, ms_timestamp): - return self._expireat(key, ms_timestamp / 1000.0) - - @command((Key(),)) - def ttl(self, key): - return self._ttl(key, 1.0) - - @command((Key(),)) - def pttl(self, key): - return self._ttl(key, 1000.0) - - @command((Key(),)) - def type(self, key): - if key.value is None: - return SimpleString(b'none') - elif isinstance(key.value, bytes): - return SimpleString(b'string') - elif isinstance(key.value, list): - return SimpleString(b'list') - elif isinstance(key.value, set): - return SimpleString(b'set') - elif isinstance(key.value, ZSet): - return SimpleString(b'zset') - elif isinstance(key.value, dict): - return SimpleString(b'hash') - else: - assert False # pragma: nocover - - @command((Key(),)) - def persist(self, key): - if key.expireat is None: - return 0 - key.expireat = None - return 1 - - @command((bytes,)) - def keys(self, pattern): - if pattern == b'*': - return list(self._db) - else: - regex = compile_pattern(pattern) - return [key for key in self._db if regex.match(key)] - - @command((Key(), DbIndex)) - def move(self, key, db): - if db == self._db_num: - raise SimpleError(SRC_DST_SAME_MSG) - if not key or key.key in self._server.dbs[db]: - return 0 - # TODO: what is the interaction with expiry? - self._server.dbs[db][key.key] = self._server.dbs[self._db_num][key.key] - key.value = None # Causes deletion - return 1 - - @command(()) - def randomkey(self): - keys = list(self._db.keys()) - if not keys: - return None - return random.choice(keys) - - @command((Key(), Key())) - def rename(self, key, newkey): - if not key: - raise SimpleError(NO_KEY_MSG) - # TODO: check interaction with WATCH - if newkey.key != key.key: - newkey.value = key.value - newkey.expireat = key.expireat - key.value = None - return OK - - @command((Key(), Key())) - def renamenx(self, key, newkey): - if not key: - raise SimpleError(NO_KEY_MSG) - if newkey: - return 0 - self.rename(key, newkey) - return 1 - - @command((Int,), (bytes, bytes)) - def scan(self, cursor, *args): - return self._scan(list(self._db), cursor, *args) - - def _lookup_key(self, key, pattern): - """Python implementation of lookupKeyByPattern from redis""" - if pattern == b'#': - return key - p = pattern.find(b'*') - if p == -1: - return None - prefix = pattern[:p] - suffix = pattern[p+1:] - arrow = suffix.find(b'->', 0, -1) - if arrow != -1: - field = suffix[arrow+2:] - suffix = suffix[:arrow] - else: - field = None - new_key = prefix + key + suffix - item = CommandItem(new_key, self._db, item=self._db.get(new_key)) - if item.value is None: - return None - if field is not None: - if not isinstance(item.value, dict): - return None - return item.value.get(field) - else: - if not isinstance(item.value, bytes): - return None - return item.value - - @command((Key(),), (bytes,)) - def sort(self, key, *args): - i = 0 - desc = False - alpha = False - limit_start = 0 - limit_count = -1 - store = None - sortby = None - dontsort = False - get = [] - if key.value is not None: - if not isinstance(key.value, (set, list, ZSet)): - raise SimpleError(WRONGTYPE_MSG) - - while i < len(args): - arg = args[i] - if casematch(arg, b'asc'): - desc = False - elif casematch(arg, b'desc'): - desc = True - elif casematch(arg, b'alpha'): - alpha = True - elif casematch(arg, b'limit') and i + 2 < len(args): - try: - limit_start = Int.decode(args[i + 1]) - limit_count = Int.decode(args[i + 2]) - except SimpleError: - raise SimpleError(SYNTAX_ERROR_MSG) - else: - i += 2 - elif casematch(arg, b'store') and i + 1 < len(args): - store = args[i + 1] - i += 1 - elif casematch(arg, b'by') and i + 1 < len(args): - sortby = args[i + 1] - if b'*' not in sortby: - dontsort = True - i += 1 - elif casematch(arg, b'get') and i + 1 < len(args): - get.append(args[i + 1]) - i += 1 - else: - raise SimpleError(SYNTAX_ERROR_MSG) - i += 1 - - # TODO: force sorting if the object is a set and either in Lua or - # storing to a key, to match redis behaviour. - items = list(key.value) if key.value is not None else [] - - # These transformations are based on the redis implementation, but - # changed to produce a half-open range. - start = max(limit_start, 0) - end = len(items) if limit_count < 0 else start + limit_count - if start >= len(items): - start = end = len(items) - 1 - end = min(end, len(items)) - - if not get: - get.append(b'#') - if sortby is None: - sortby = b'#' - - if not dontsort: - if alpha: - def sort_key(v): - byval = self._lookup_key(v, sortby) - # TODO: use locale.strxfrm when not storing? But then need - # to decode too. - if byval is None: - byval = BeforeAny() - return byval - - else: - def sort_key(v): - byval = self._lookup_key(v, sortby) - score = SortFloat.decode(byval) if byval is not None else 0.0 - return (score, v) - - items.sort(key=sort_key, reverse=desc) - elif isinstance(key.value, (list, ZSet)): - items.reverse() - - out = [] - for row in items[start:end]: - for g in get: - v = self._lookup_key(row, g) - if store is not None and v is None: - v = b'' - out.append(v) - if store is not None: - item = CommandItem(store, self._db, item=self._db.get(store)) - item.value = out - item.writeback() - return len(out) - else: - return out - - @command((Key(missing_return=None),)) - def dump(self, key): - value = pickle.dumps(key.value) - checksum = hashlib.sha1(value).digest() - return checksum + value - - @command((Key(), Int, bytes), (bytes,)) - def restore(self, key, ttl, value, *args): - replace = False - i = 0 - while i < len(args): - if casematch(args[i], b'replace'): - replace = True - i += 1 - else: - raise SimpleError(SYNTAX_ERROR_MSG) - if key and not replace: - raise SimpleError(RESTORE_KEY_EXISTS) - checksum, value = value[:20], value[20:] - if hashlib.sha1(value).digest() != checksum: - raise SimpleError(RESTORE_INVALID_CHECKSUM_MSG) - if ttl < 0: - raise SimpleError(RESTORE_INVALID_TTL_MSG) - if ttl == 0: - expireat = None - else: - expireat = self._db.time + ttl / 1000.0 - key.value = pickle.loads(value) - key.expireat = expireat - return OK - - # Transaction commands - - def _clear_watches(self): - self._watch_notified = False - while self._watches: - (key, db) = self._watches.pop() - db.remove_watch(key, self) - - @command((), flags='s') - def multi(self): - if self._transaction is not None: - raise SimpleError(MULTI_NESTED_MSG) - self._transaction = [] - self._transaction_failed = False - return OK - - @command((), flags='s') - def discard(self): - if self._transaction is None: - raise SimpleError(WITHOUT_MULTI_MSG.format('DISCARD')) - self._transaction = None - self._transaction_failed = False - self._clear_watches() - return OK - - @command((), name='exec', flags='s') - def exec_(self): - if self._transaction is None: - raise SimpleError(WITHOUT_MULTI_MSG.format('EXEC')) - if self._transaction_failed: - self._transaction = None - self._clear_watches() - raise SimpleError(EXECABORT_MSG) - transaction = self._transaction - self._transaction = None - self._transaction_failed = False - watch_notified = self._watch_notified - self._clear_watches() - if watch_notified: - return None - result = [] - for func, sig, args in transaction: - try: - self._in_transaction = True - ans = self._run_command(func, sig, args, False) - except SimpleError as exc: - ans = exc - finally: - self._in_transaction = False - result.append(ans) - return result - - @command((Key(),), (Key(),), flags='s') - def watch(self, *keys): - if self._transaction is not None: - raise SimpleError(WATCH_INSIDE_MULTI_MSG) - for key in keys: - if key not in self._watches: - self._watches.add((key.key, self._db)) - self._db.add_watch(key.key, self) - return OK - - @command((), flags='s') - def unwatch(self): - self._clear_watches() - return OK - - # String commands - # TODO: bitfield, bitop, bitpos - - @command((Key(bytes), bytes)) - def append(self, key, value): - old = key.get(b'') - if len(old) + len(value) > MAX_STRING_SIZE: - raise SimpleError(STRING_OVERFLOW_MSG) - key.update(key.get(b'') + value) - return len(key.value) - - @command((Key(bytes, 0),), (bytes,)) - def bitcount(self, key, *args): - # Redis checks the argument count before decoding integers. That's why - # we can't declare them as Int. - if args: - if len(args) != 2: - raise SimpleError(SYNTAX_ERROR_MSG) - start = Int.decode(args[0]) - end = Int.decode(args[1]) - start, end = self._fix_range_string(start, end, len(key.value)) - value = key.value[start:end] - else: - value = key.value - return bin(int.from_bytes(value, 'little')).count('1') - - @command((Key(bytes), Int)) - def decrby(self, key, amount): - return self.incrby(key, -amount) - - @command((Key(bytes),)) - def decr(self, key): - return self.incrby(key, -1) - - @command((Key(bytes), Int)) - def incrby(self, key, amount): - c = Int.decode(key.get(b'0')) + amount - key.update(Int.encode(c)) - return c - - @command((Key(bytes),)) - def incr(self, key): - return self.incrby(key, 1) - - @command((Key(bytes), bytes)) - def incrbyfloat(self, key, amount): - # TODO: introduce convert_order so that we can specify amount is Float - c = Float.decode(key.get(b'0')) + Float.decode(amount) - if not math.isfinite(c): - raise SimpleError(NONFINITE_MSG) - encoded = Float.encode(c, True) - key.update(encoded) - return encoded - - @command((Key(bytes),)) - def get(self, key): - return key.get(None) - - @command((Key(bytes), BitOffset)) - def getbit(self, key, offset): - value = key.get(b'') - byte = offset // 8 - remaining = offset % 8 - actual_bitoffset = 7 - remaining - try: - actual_val = value[byte] - except IndexError: - return 0 - return 1 if (1 << actual_bitoffset) & actual_val else 0 - - @command((Key(bytes), BitOffset, BitValue)) - def setbit(self, key, offset, value): - val = key.get(b'\x00') - byte = offset // 8 - remaining = offset % 8 - actual_bitoffset = 7 - remaining - if len(val) - 1 < byte: - # We need to expand val so that we can set the appropriate - # bit. - needed = byte - (len(val) - 1) - val += b'\x00' * needed - old_byte = val[byte] - if value == 1: - new_byte = old_byte | (1 << actual_bitoffset) - else: - new_byte = old_byte & ~(1 << actual_bitoffset) - old_value = value if old_byte == new_byte else 1 - value - reconstructed = bytearray(val) - reconstructed[byte] = new_byte - key.update(bytes(reconstructed)) - return old_value - - @command((Key(bytes), Int, Int)) - def getrange(self, key, start, end): - value = key.get(b'') - start, end = self._fix_range_string(start, end, len(value)) - return value[start:end] - - # substr is a deprecated alias for getrange - @command((Key(bytes), Int, Int)) - def substr(self, key, start, end): - return self.getrange(key, start, end) - - @command((Key(bytes), bytes)) - def getset(self, key, value): - old = key.value - key.value = value - return old - - @command((Key(),), (Key(),)) - def mget(self, *keys): - return [key.value if isinstance(key.value, bytes) else None for key in keys] - - @command((Key(), bytes), (Key(), bytes)) - def mset(self, *args): - for i in range(0, len(args), 2): - args[i].value = args[i + 1] - return OK - - @command((Key(), bytes), (Key(), bytes)) - def msetnx(self, *args): - for i in range(0, len(args), 2): - if args[i]: - return 0 - for i in range(0, len(args), 2): - args[i].value = args[i + 1] - return 1 - - @command((Key(), bytes), (bytes,), name='set') - def set_(self, key, value, *args): - i = 0 - ex = None - px = None - xx = False - nx = False - keepttl = False - get = False - while i < len(args): - if casematch(args[i], b'nx'): - nx = True - i += 1 - elif casematch(args[i], b'xx'): - xx = True - i += 1 - elif casematch(args[i], b'ex') and i + 1 < len(args): - ex = Int.decode(args[i + 1]) - if ex <= 0 or (self._db.time + ex) * 1000 >= 2**63: - raise SimpleError(INVALID_EXPIRE_MSG.format('set')) - i += 2 - elif casematch(args[i], b'px') and i + 1 < len(args): - px = Int.decode(args[i + 1]) - if px <= 0 or self._db.time * 1000 + px >= 2**63: - raise SimpleError(INVALID_EXPIRE_MSG.format('set')) - i += 2 - elif casematch(args[i], b'keepttl'): - keepttl = True - i += 1 - elif casematch(args[i], b'get'): - get = True - i += 1 - else: - raise SimpleError(SYNTAX_ERROR_MSG) - if (xx and nx) or ((px is not None) + (ex is not None) + keepttl > 1): - raise SimpleError(SYNTAX_ERROR_MSG) - if nx and get: - # The command docs say this is allowed from Redis 7.0. - raise SimpleError(SYNTAX_ERROR_MSG) - - old_value = None - if get: - if key.value is not None and type(key.value) is not bytes: - raise SimpleError(WRONGTYPE_MSG) - old_value = key.value - - if nx and key: - return old_value - if xx and not key: - return old_value - if not keepttl: - key.value = value - else: - key.update(value) - if ex is not None: - key.expireat = self._db.time + ex - if px is not None: - key.expireat = self._db.time + px / 1000.0 - return OK if not get else old_value - - @command((Key(), Int, bytes)) - def setex(self, key, seconds, value): - if seconds <= 0 or (self._db.time + seconds) * 1000 >= 2**63: - raise SimpleError(INVALID_EXPIRE_MSG.format('setex')) - key.value = value - key.expireat = self._db.time + seconds - return OK - - @command((Key(), Int, bytes)) - def psetex(self, key, ms, value): - if ms <= 0 or self._db.time * 1000 + ms >= 2**63: - raise SimpleError(INVALID_EXPIRE_MSG.format('psetex')) - key.value = value - key.expireat = self._db.time + ms / 1000.0 - return OK - - @command((Key(), bytes)) - def setnx(self, key, value): - if key: - return 0 - key.value = value - return 1 - - @command((Key(bytes), Int, bytes)) - def setrange(self, key, offset, value): - if offset < 0: - raise SimpleError(INVALID_OFFSET_MSG) - elif not value: - return len(key.get(b'')) - elif offset + len(value) > MAX_STRING_SIZE: - raise SimpleError(STRING_OVERFLOW_MSG) - else: - out = key.get(b'') - if len(out) < offset: - out += b'\x00' * (offset - len(out)) - out = out[0:offset] + value + out[offset+len(value):] - key.update(out) - return len(out) - - @command((Key(bytes),)) - def strlen(self, key): - return len(key.get(b'')) - - # Hash commands - - @command((Key(Hash), bytes), (bytes,)) - def hdel(self, key, *fields): - h = key.value - rem = 0 - for field in fields: - if field in h: - del h[field] - key.updated() - rem += 1 - return rem - - @command((Key(Hash), bytes)) - def hexists(self, key, field): - return int(field in key.value) - - @command((Key(Hash), bytes)) - def hget(self, key, field): - return key.value.get(field) - - @command((Key(Hash),)) - def hgetall(self, key): - return list(itertools.chain(*key.value.items())) - - @command((Key(Hash), bytes, Int)) - def hincrby(self, key, field, amount): - c = Int.decode(key.value.get(field, b'0')) + amount - key.value[field] = Int.encode(c) - key.updated() - return c - - @command((Key(Hash), bytes, bytes)) - def hincrbyfloat(self, key, field, amount): - c = Float.decode(key.value.get(field, b'0')) + Float.decode(amount) - if not math.isfinite(c): - raise SimpleError(NONFINITE_MSG) - encoded = Float.encode(c, True) - key.value[field] = encoded - key.updated() - return encoded - - @command((Key(Hash),)) - def hkeys(self, key): - return list(key.value.keys()) - - @command((Key(Hash),)) - def hlen(self, key): - return len(key.value) - - @command((Key(Hash), bytes), (bytes,)) - def hmget(self, key, *fields): - return [key.value.get(field) for field in fields] - - @command((Key(Hash), bytes, bytes), (bytes, bytes)) - def hmset(self, key, *args): - self.hset(key, *args) - return OK - - @command((Key(Hash), Int,), (bytes, bytes)) - def hscan(self, key, cursor, *args): - cursor, keys = self._scan(key.value, cursor, *args) - items = [] - for k in keys: - items.append(k) - items.append(key.value[k]) - return [cursor, items] - - @command((Key(Hash), bytes, bytes), (bytes, bytes)) - def hset(self, key, *args): - h = key.value - created = 0 - for i in range(0, len(args), 2): - if args[i] not in h: - created += 1 - h[args[i]] = args[i + 1] - key.updated() - return created - - @command((Key(Hash), bytes, bytes)) - def hsetnx(self, key, field, value): - if field in key.value: - return 0 - return self.hset(key, field, value) - - @command((Key(Hash), bytes)) - def hstrlen(self, key, field): - return len(key.value.get(field, b'')) - - @command((Key(Hash),)) - def hvals(self, key): - return list(key.value.values()) - - # List commands - - def _bpop_pass(self, keys, op, first_pass): - for key in keys: - item = CommandItem(key, self._db, item=self._db.get(key), default=[]) - if not isinstance(item.value, list): - if first_pass: - raise SimpleError(WRONGTYPE_MSG) - else: - continue - if item.value: - ret = op(item.value) - item.updated() - item.writeback() - return [key, ret] - return None - - def _bpop(self, args, op): - keys = args[:-1] - timeout = Timeout.decode(args[-1]) - return self._blocking(timeout, functools.partial(self._bpop_pass, keys, op)) - - @command((bytes, bytes), (bytes,), flags='s') - def blpop(self, *args): - return self._bpop(args, lambda lst: lst.pop(0)) - - @command((bytes, bytes), (bytes,), flags='s') - def brpop(self, *args): - return self._bpop(args, lambda lst: lst.pop()) - - def _brpoplpush_pass(self, source, destination, first_pass): - src = CommandItem(source, self._db, item=self._db.get(source), default=[]) - if not isinstance(src.value, list): - if first_pass: - raise SimpleError(WRONGTYPE_MSG) - else: - return None - if not src.value: - return None # Empty list - dst = CommandItem(destination, self._db, item=self._db.get(destination), default=[]) - if not isinstance(dst.value, list): - raise SimpleError(WRONGTYPE_MSG) - el = src.value.pop() - dst.value.insert(0, el) - src.updated() - src.writeback() - if destination != source: - # Ensure writeback only happens once - dst.updated() - dst.writeback() - return el - - @command((bytes, bytes, Timeout), flags='s') - def brpoplpush(self, source, destination, timeout): - return self._blocking(timeout, - functools.partial(self._brpoplpush_pass, source, destination)) - - @command((Key(list, None), Int)) - def lindex(self, key, index): - try: - return key.value[index] - except IndexError: - return None - - @command((Key(list), bytes, bytes, bytes)) - def linsert(self, key, where, pivot, value): - if not casematch(where, b'before') and not casematch(where, b'after'): - raise SimpleError(SYNTAX_ERROR_MSG) - if not key: - return 0 - else: - try: - index = key.value.index(pivot) - except ValueError: - return -1 - if casematch(where, b'after'): - index += 1 - key.value.insert(index, value) - key.updated() - return len(key.value) - - @command((Key(list),)) - def llen(self, key): - return len(key.value) - - def _list_pop(self, get_slice, key, *args): - """Implements lpop and rpop. - - `get_slice` must take a count and return a slice expression for the - range to pop. - """ - # This implementation is somewhat contorted to match the odd - # behaviours described in https://github.com/redis/redis/issues/9680. - count = 1 - if len(args) > 1: - raise SimpleError(SYNTAX_ERROR_MSG) - elif len(args) == 1: - count = args[0] - if count < 0: - raise SimpleError(INDEX_ERROR_MSG) - elif count == 0: - return None - if not key: - return None - elif type(key.value) != list: - raise SimpleError(WRONGTYPE_MSG) - slc = get_slice(count) - ret = key.value[slc] - del key.value[slc] - key.updated() - if not args: - ret = ret[0] - return ret - - @command((Key(),), (Int(),)) - def lpop(self, key, *args): - return self._list_pop(lambda count: slice(None, count), key, *args) - - @command((Key(list), bytes), (bytes,)) - def lpush(self, key, *values): - for value in values: - key.value.insert(0, value) - key.updated() - return len(key.value) - - @command((Key(list), bytes), (bytes,)) - def lpushx(self, key, *values): - if not key: - return 0 - return self.lpush(key, *values) - - @command((Key(list), Int, Int)) - def lrange(self, key, start, stop): - start, stop = self._fix_range(start, stop, len(key.value)) - return key.value[start:stop] - - @command((Key(list), Int, bytes)) - def lrem(self, key, count, value): - a_list = key.value - found = [] - for i, el in enumerate(a_list): - if el == value: - found.append(i) - if count > 0: - indices_to_remove = found[:count] - elif count < 0: - indices_to_remove = found[count:] - else: - indices_to_remove = found - # Iterating in reverse order to ensure the indices - # remain valid during deletion. - for index in reversed(indices_to_remove): - del a_list[index] - if indices_to_remove: - key.updated() - return len(indices_to_remove) - - @command((Key(list), Int, bytes)) - def lset(self, key, index, value): - if not key: - raise SimpleError(NO_KEY_MSG) - try: - key.value[index] = value - key.updated() - except IndexError: - raise SimpleError(INDEX_ERROR_MSG) - return OK - - @command((Key(list), Int, Int)) - def ltrim(self, key, start, stop): - if key: - if stop == -1: - stop = None - else: - stop += 1 - new_value = key.value[start:stop] - # TODO: check if this should actually be conditional - if len(new_value) != len(key.value): - key.update(new_value) - return OK - - @command((Key(),), (Int(),)) - def rpop(self, key, *args): - return self._list_pop(lambda count: slice(None, -count - 1, -1), key, *args) - - @command((Key(list, None), Key(list))) - def rpoplpush(self, src, dst): - el = self.rpop(src) - self.lpush(dst, el) - return el - - @command((Key(list), bytes), (bytes,)) - def rpush(self, key, *values): - for value in values: - key.value.append(value) - key.updated() - return len(key.value) - - @command((Key(list), bytes), (bytes,)) - def rpushx(self, key, *values): - if not key: - return 0 - return self.rpush(key, *values) - - # Set commands - - @command((Key(set), bytes), (bytes,)) - def sadd(self, key, *members): - old_size = len(key.value) - key.value.update(members) - key.updated() - return len(key.value) - old_size - - @command((Key(set),)) - def scard(self, key): - return len(key.value) - - def _calc_setop(self, op, stop_if_missing, key, *keys): - if stop_if_missing and not key.value: - return set() - ans = key.value.copy() - for other in keys: - value = other.value if other.value is not None else set() - if not isinstance(value, set): - raise SimpleError(WRONGTYPE_MSG) - if stop_if_missing and not value: - return set() - ans = op(ans, value) - return ans - - def _setop(self, op, stop_if_missing, dst, key, *keys): - """Apply one of SINTER[STORE], SUNION[STORE], SDIFF[STORE]. - - If `stop_if_missing`, the output will be made an empty set as soon as - an empty input set is encountered (use for SINTER[STORE]). May assume - that `key` is a set (or empty), but `keys` could be anything. - """ - ans = self._calc_setop(op, stop_if_missing, key, *keys) - if dst is None: - return list(ans) - else: - dst.value = ans - return len(dst.value) - - @command((Key(set),), (Key(set),)) - def sdiff(self, *keys): - return self._setop(lambda a, b: a - b, False, None, *keys) - - @command((Key(), Key(set)), (Key(set),)) - def sdiffstore(self, dst, *keys): - return self._setop(lambda a, b: a - b, False, dst, *keys) - - @command((Key(set),), (Key(set),)) - def sinter(self, *keys): - return self._setop(lambda a, b: a & b, True, None, *keys) - - @command((Key(), Key(set)), (Key(set),)) - def sinterstore(self, dst, *keys): - return self._setop(lambda a, b: a & b, True, dst, *keys) - - @command((Key(set), bytes)) - def sismember(self, key, member): - return int(member in key.value) - - @command((Key(set),)) - def smembers(self, key): - return list(key.value) - - @command((Key(set, 0), Key(set), bytes)) - def smove(self, src, dst, member): - try: - src.value.remove(member) - src.updated() - except KeyError: - return 0 - else: - dst.value.add(member) - dst.updated() # TODO: is it updated if member was already present? - return 1 - - @command((Key(set),), (Int,)) - def spop(self, key, count=None): - if count is None: - if not key.value: - return None - item = random.sample(list(key.value), 1)[0] - key.value.remove(item) - key.updated() - return item - else: - if count < 0: - raise SimpleError(INDEX_ERROR_MSG) - items = self.srandmember(key, count) - for item in items: - key.value.remove(item) - key.updated() # Inside the loop because redis special-cases count=0 - return items - - @command((Key(set),), (Int,)) - def srandmember(self, key, count=None): - if count is None: - if not key.value: - return None - else: - return random.sample(list(key.value), 1)[0] - elif count >= 0: - count = min(count, len(key.value)) - return random.sample(list(key.value), count) - else: - items = list(key.value) - return [random.choice(items) for _ in range(-count)] - - @command((Key(set), bytes), (bytes,)) - def srem(self, key, *members): - old_size = len(key.value) - for member in members: - key.value.discard(member) - deleted = old_size - len(key.value) - if deleted: - key.updated() - return deleted - - @command((Key(set), Int), (bytes, bytes)) - def sscan(self, key, cursor, *args): - return self._scan(key.value, cursor, *args) - - @command((Key(set),), (Key(set),)) - def sunion(self, *keys): - return self._setop(lambda a, b: a | b, False, None, *keys) - - @command((Key(), Key(set)), (Key(set),)) - def sunionstore(self, dst, *keys): - return self._setop(lambda a, b: a | b, False, dst, *keys) - - # Hyperloglog commands - # These are not quite the same as the real redis ones, which are - # approximate and store the results in a string. Instead, it is implemented - # on top of sets. - - @command((Key(set),), (bytes,)) - def pfadd(self, key, *elements): - result = self.sadd(key, *elements) - # Per the documentation: - # - 1 if at least 1 HyperLogLog internal register was altered. 0 otherwise. - return 1 if result > 0 else 0 - - @command((Key(set),), (Key(set),)) - def pfcount(self, *keys): - """ - Return the approximated cardinality of - the set observed by the HyperLogLog at key(s). - """ - return len(self.sunion(*keys)) - - @command((Key(set), Key(set)), (Key(set),)) - def pfmerge(self, dest, *sources): - "Merge N different HyperLogLogs into a single one." - self.sunionstore(dest, *sources) - return OK - - # Sorted set commands - # TODO: [b]zpopmin/zpopmax, - - @staticmethod - def _limit_items(items, offset, count): - out = [] - for item in items: - if offset: # Note: not offset > 0, in order to match redis - offset -= 1 - continue - if count == 0: - break - count -= 1 - out.append(item) - return out - - @staticmethod - def _apply_withscores(items, withscores): - if withscores: - out = [] - for item in items: - out.append(item[1]) - out.append(Float.encode(item[0], False)) - else: - out = [item[1] for item in items] - return out - - @command((Key(ZSet), bytes, bytes), (bytes,)) - def zadd(self, key, *args): - zset = key.value - - i = 0 - ch = False - nx = False - xx = False - incr = False - while i < len(args): - if casematch(args[i], b'ch'): - ch = True - i += 1 - elif casematch(args[i], b'nx'): - nx = True - i += 1 - elif casematch(args[i], b'xx'): - xx = True - i += 1 - elif casematch(args[i], b'incr'): - incr = True - i += 1 - else: - # First argument not matching flags indicates the start of - # score pairs. - break - - if nx and xx: - raise SimpleError(ZADD_NX_XX_ERROR_MSG) - - elements = args[i:] - if not elements or len(elements) % 2 != 0: - raise SimpleError(SYNTAX_ERROR_MSG) - if incr and len(elements) != 2: - raise SimpleError(ZADD_INCR_LEN_ERROR_MSG) - # Parse all scores first, before updating - items = [ - (Float.decode(elements[j]), elements[j + 1]) - for j in range(0, len(elements), 2) - ] - old_len = len(zset) - changed_items = 0 - - if incr: - item_score, item_name = items[0] - if (nx and item_name in zset) or (xx and item_name not in zset): - return None - return self.zincrby(key, item_score, item_name) - - for item_score, item_name in items: - if ( - (not nx or item_name not in zset) - and (not xx or item_name in zset) - ): - if zset.add(item_name, item_score): - changed_items += 1 - - if changed_items: - key.updated() - - if ch: - return changed_items - return len(zset) - old_len - - @command((Key(ZSet),)) - def zcard(self, key): - return len(key.value) - - @command((Key(ZSet), ScoreTest, ScoreTest)) - def zcount(self, key, min, max): - return key.value.zcount(min.lower_bound, max.upper_bound) - - @command((Key(ZSet), Float, bytes)) - def zincrby(self, key, increment, member): - # Can't just default the old score to 0.0, because in IEEE754, adding - # 0.0 to something isn't a nop (e.g. 0.0 + -0.0 == 0.0). - try: - score = key.value.get(member, None) + increment - except TypeError: - score = increment - if math.isnan(score): - raise SimpleError(SCORE_NAN_MSG) - key.value[member] = score - key.updated() - return Float.encode(score, False) - - @command((Key(ZSet), StringTest, StringTest)) - def zlexcount(self, key, min, max): - return key.value.zlexcount(min.value, min.exclusive, max.value, max.exclusive) - - def _zrange(self, key, start, stop, reverse, *args): - zset = key.value - withscores = False - for arg in args: - if casematch(arg, b'withscores'): - withscores = True - else: - raise SimpleError(SYNTAX_ERROR_MSG) - start, stop = self._fix_range(start, stop, len(zset)) - if reverse: - start, stop = len(zset) - stop, len(zset) - start - items = zset.islice_score(start, stop, reverse) - items = self._apply_withscores(items, withscores) - return items - - @command((Key(ZSet), Int, Int), (bytes,)) - def zrange(self, key, start, stop, *args): - return self._zrange(key, start, stop, False, *args) - - @command((Key(ZSet), Int, Int), (bytes,)) - def zrevrange(self, key, start, stop, *args): - return self._zrange(key, start, stop, True, *args) - - def _zrangebylex(self, key, min, max, reverse, *args): - if args: - if len(args) != 3 or not casematch(args[0], b'limit'): - raise SimpleError(SYNTAX_ERROR_MSG) - offset = Int.decode(args[1]) - count = Int.decode(args[2]) - else: - offset = 0 - count = -1 - zset = key.value - items = zset.irange_lex(min.value, max.value, - inclusive=(not min.exclusive, not max.exclusive), - reverse=reverse) - items = self._limit_items(items, offset, count) - return items - - @command((Key(ZSet), StringTest, StringTest), (bytes,)) - def zrangebylex(self, key, min, max, *args): - return self._zrangebylex(key, min, max, False, *args) - - @command((Key(ZSet), StringTest, StringTest), (bytes,)) - def zrevrangebylex(self, key, max, min, *args): - return self._zrangebylex(key, min, max, True, *args) - - def _zrangebyscore(self, key, min, max, reverse, *args): - withscores = False - offset = 0 - count = -1 - i = 0 - while i < len(args): - if casematch(args[i], b'withscores'): - withscores = True - i += 1 - elif casematch(args[i], b'limit') and i + 2 < len(args): - offset = Int.decode(args[i + 1]) - count = Int.decode(args[i + 2]) - i += 3 - else: - raise SimpleError(SYNTAX_ERROR_MSG) - zset = key.value - items = list(zset.irange_score(min.lower_bound, max.upper_bound, reverse=reverse)) - items = self._limit_items(items, offset, count) - items = self._apply_withscores(items, withscores) - return items - - @command((Key(ZSet), ScoreTest, ScoreTest), (bytes,)) - def zrangebyscore(self, key, min, max, *args): - return self._zrangebyscore(key, min, max, False, *args) - - @command((Key(ZSet), ScoreTest, ScoreTest), (bytes,)) - def zrevrangebyscore(self, key, max, min, *args): - return self._zrangebyscore(key, min, max, True, *args) - - @command((Key(ZSet), bytes)) - def zrank(self, key, member): - try: - return key.value.rank(member) - except KeyError: - return None - - @command((Key(ZSet), bytes)) - def zrevrank(self, key, member): - try: - return len(key.value) - 1 - key.value.rank(member) - except KeyError: - return None - - @command((Key(ZSet), bytes), (bytes,)) - def zrem(self, key, *members): - old_size = len(key.value) - for member in members: - key.value.discard(member) - deleted = old_size - len(key.value) - if deleted: - key.updated() - return deleted - - @command((Key(ZSet), StringTest, StringTest)) - def zremrangebylex(self, key, min, max): - items = key.value.irange_lex(min.value, max.value, - inclusive=(not min.exclusive, not max.exclusive)) - return self.zrem(key, *items) - - @command((Key(ZSet), ScoreTest, ScoreTest)) - def zremrangebyscore(self, key, min, max): - items = key.value.irange_score(min.lower_bound, max.upper_bound) - return self.zrem(key, *[item[1] for item in items]) - - @command((Key(ZSet), Int, Int)) - def zremrangebyrank(self, key, start, stop): - zset = key.value - start, stop = self._fix_range(start, stop, len(zset)) - items = zset.islice_score(start, stop) - return self.zrem(key, *[item[1] for item in items]) - - @command((Key(ZSet), Int), (bytes, bytes)) - def zscan(self, key, cursor, *args): - new_cursor, ans = self._scan(key.value.items(), cursor, *args) - flat = [] - for (key, score) in ans: - flat.append(key) - flat.append(Float.encode(score, False)) - return [new_cursor, flat] - - @command((Key(ZSet), bytes)) - def zscore(self, key, member): - try: - return Float.encode(key.value[member], False) - except KeyError: - return None - - @staticmethod - def _get_zset(value): - if isinstance(value, set): - zset = ZSet() - for item in value: - zset[item] = 1.0 - return zset - elif isinstance(value, ZSet): - return value - else: - raise SimpleError(WRONGTYPE_MSG) - - def _zunioninter(self, func, dest, numkeys, *args): - if numkeys < 1: - raise SimpleError(ZUNIONSTORE_KEYS_MSG) - if numkeys > len(args): - raise SimpleError(SYNTAX_ERROR_MSG) - aggregate = b'sum' - sets = [] - for i in range(numkeys): - item = CommandItem(args[i], self._db, item=self._db.get(args[i]), default=ZSet()) - sets.append(self._get_zset(item.value)) - weights = [1.0] * numkeys - - i = numkeys - while i < len(args): - arg = args[i] - if casematch(arg, b'weights') and i + numkeys < len(args): - weights = [Float.decode(x) for x in args[i + 1:i + numkeys + 1]] - i += numkeys + 1 - elif casematch(arg, b'aggregate') and i + 1 < len(args): - aggregate = casenorm(args[i + 1]) - if aggregate not in (b'sum', b'min', b'max'): - raise SimpleError(SYNTAX_ERROR_MSG) - i += 2 - else: - raise SimpleError(SYNTAX_ERROR_MSG) - - out_members = set(sets[0]) - for s in sets[1:]: - if func == 'ZUNIONSTORE': - out_members |= set(s) - else: - out_members.intersection_update(s) - - # We first build a regular dict and turn it into a ZSet. The - # reason is subtle: a ZSet won't update a score from -0 to +0 - # (or vice versa) through assignment, but a regular dict will. - out = {} - # The sort affects the order of floating-point operations. - # Note that redis uses qsort(1), which has no stability guarantees, - # so we can't be sure to match it in all cases. - for s, w in sorted(zip(sets, weights), key=lambda x: len(x[0])): - for member, score in s.items(): - score *= w - # Redis only does this step for ZUNIONSTORE. See - # https://github.com/antirez/redis/issues/3954. - if func == 'ZUNIONSTORE' and math.isnan(score): - score = 0.0 - if member not in out_members: - continue - if member in out: - old = out[member] - if aggregate == b'sum': - score += old - if math.isnan(score): - score = 0.0 - elif aggregate == b'max': - score = max(old, score) - elif aggregate == b'min': - score = min(old, score) - else: - assert False # pragma: nocover - if math.isnan(score): - score = 0.0 - out[member] = score - - out_zset = ZSet() - for member, score in out.items(): - out_zset[member] = score - - dest.value = out_zset - return len(out_zset) - - @command((Key(), Int, bytes), (bytes,)) - def zunionstore(self, dest, numkeys, *args): - return self._zunioninter('ZUNIONSTORE', dest, numkeys, *args) - - @command((Key(), Int, bytes), (bytes,)) - def zinterstore(self, dest, numkeys, *args): - return self._zunioninter('ZINTERSTORE', dest, numkeys, *args) - - # Server commands - # TODO: lots - - @command((), (bytes,), flags='s') - def bgsave(self, *args): - if len(args) > 1 or (len(args) == 1 and not casematch(args[0], b'schedule')): - raise SimpleError(SYNTAX_ERROR_MSG) - self._server.lastsave = int(time.time()) - return BGSAVE_STARTED - - @command(()) - def dbsize(self): - return len(self._db) - - @command((), (bytes,)) - def flushdb(self, *args): - if args: - if len(args) != 1 or not casematch(args[0], b'async'): - raise SimpleError(SYNTAX_ERROR_MSG) - self._db.clear() - return OK - - @command((), (bytes,)) - def flushall(self, *args): - if args: - if len(args) != 1 or not casematch(args[0], b'async'): - raise SimpleError(SYNTAX_ERROR_MSG) - for db in self._server.dbs.values(): - db.clear() - # TODO: clear watches and/or pubsub as well? - return OK - - @command(()) - def lastsave(self): - return self._server.lastsave - - @command((), flags='s') - def save(self): - self._server.lastsave = int(time.time()) - return OK - - @command(()) - def time(self): - now_us = round(time.time() * 1000000) - now_s = now_us // 1000000 - now_us %= 1000000 - return [str(now_s).encode(), str(now_us).encode()] - - # Script commands - # script debug and script kill will probably not be supported - - def _convert_redis_arg(self, lua_runtime, value): - # Type checks are exact to avoid issues like bool being a subclass of int. - if type(value) is bytes: - return value - elif type(value) in {int, float}: - return '{:.17g}'.format(value).encode() - else: - # TODO: add the context - raise SimpleError(LUA_COMMAND_ARG_MSG) - - def _convert_redis_result(self, lua_runtime, result): - if isinstance(result, (bytes, int)): - return result - elif isinstance(result, SimpleString): - return lua_runtime.table_from({b"ok": result.value}) - elif result is None: - return False - elif isinstance(result, list): - converted = [ - self._convert_redis_result(lua_runtime, item) - for item in result - ] - return lua_runtime.table_from(converted) - elif isinstance(result, SimpleError): - raise result - else: - raise RuntimeError("Unexpected return type from redis: {}".format(type(result))) - - def _convert_lua_result(self, result, nested=True): - from lupa import lua_type - if lua_type(result) == 'table': - for key in (b'ok', b'err'): - if key in result: - msg = self._convert_lua_result(result[key]) - if not isinstance(msg, bytes): - raise SimpleError(LUA_WRONG_NUMBER_ARGS_MSG) - if key == b'ok': - return SimpleString(msg) - elif nested: - return SimpleError(msg.decode('utf-8', 'replace')) - else: - raise SimpleError(msg.decode('utf-8', 'replace')) - # Convert Lua tables into lists, starting from index 1, mimicking the behavior of StrictRedis. - result_list = [] - for index in itertools.count(1): - if index not in result: - break - item = result[index] - result_list.append(self._convert_lua_result(item)) - return result_list - elif isinstance(result, str): - return result.encode() - elif isinstance(result, float): - return int(result) - elif isinstance(result, bool): - return 1 if result else None - return result - - def _check_for_lua_globals(self, lua_runtime, expected_globals): - actual_globals = set(lua_runtime.globals().keys()) - if actual_globals != expected_globals: - unexpected = [six.ensure_str(var, 'utf-8', 'replace') - for var in actual_globals - expected_globals] - raise SimpleError(GLOBAL_VARIABLE_MSG.format(", ".join(unexpected))) - - def _lua_redis_call(self, lua_runtime, expected_globals, op, *args): - # Check if we've set any global variables before making any change. - self._check_for_lua_globals(lua_runtime, expected_globals) - func, func_name = self._name_to_func(op) - args = [self._convert_redis_arg(lua_runtime, arg) for arg in args] - result = self._run_command(func, func._fakeredis_sig, args, True) - return self._convert_redis_result(lua_runtime, result) - - def _lua_redis_pcall(self, lua_runtime, expected_globals, op, *args): - try: - return self._lua_redis_call(lua_runtime, expected_globals, op, *args) - except Exception as ex: - return lua_runtime.table_from({b"err": str(ex)}) - - def _lua_redis_log(self, lua_runtime, expected_globals, lvl, *args): - self._check_for_lua_globals(lua_runtime, expected_globals) - if len(args) < 1: - raise SimpleError(REQUIRES_MORE_ARGS_MSG.format("redis.log()", "two")) - if lvl not in REDIS_LOG_LEVELS.values(): - raise SimpleError(LOG_INVALID_DEBUG_LEVEL_MSG) - msg = ' '.join([x.decode('utf-8') - if isinstance(x, bytes) else str(x) - for x in args if not isinstance(x, bool)]) - LOGGER.log(REDIS_LOG_LEVELS_TO_LOGGING[lvl], msg) - - @command((bytes, Int), (bytes,), flags='s') - def eval(self, script, numkeys, *keys_and_args): - from lupa import LuaError, LuaRuntime, as_attrgetter - - if numkeys > len(keys_and_args): - raise SimpleError(TOO_MANY_KEYS_MSG) - if numkeys < 0: - raise SimpleError(NEGATIVE_KEYS_MSG) - sha1 = hashlib.sha1(script).hexdigest().encode() - self._server.script_cache[sha1] = script - lua_runtime = LuaRuntime(encoding=None, unpack_returned_tuples=True) - - set_globals = lua_runtime.eval( - """ - function(keys, argv, redis_call, redis_pcall, redis_log, redis_log_levels) - redis = {} - redis.call = redis_call - redis.pcall = redis_pcall - redis.log = redis_log - for level, pylevel in python.iterex(redis_log_levels.items()) do - redis[level] = pylevel - end - redis.error_reply = function(msg) return {err=msg} end - redis.status_reply = function(msg) return {ok=msg} end - KEYS = keys - ARGV = argv - end - """ - ) - expected_globals = set() - set_globals( - lua_runtime.table_from(keys_and_args[:numkeys]), - lua_runtime.table_from(keys_and_args[numkeys:]), - functools.partial(self._lua_redis_call, lua_runtime, expected_globals), - functools.partial(self._lua_redis_pcall, lua_runtime, expected_globals), - functools.partial(self._lua_redis_log, lua_runtime, expected_globals), - as_attrgetter(REDIS_LOG_LEVELS) - ) - expected_globals.update(lua_runtime.globals().keys()) - - try: - result = lua_runtime.execute(script) - except (LuaError, SimpleError) as ex: - raise SimpleError(SCRIPT_ERROR_MSG.format(sha1.decode(), ex)) - - self._check_for_lua_globals(lua_runtime, expected_globals) - - return self._convert_lua_result(result, nested=False) - - @command((bytes, Int), (bytes,), flags='s') - def evalsha(self, sha1, numkeys, *keys_and_args): - try: - script = self._server.script_cache[sha1] - except KeyError: - raise SimpleError(NO_MATCHING_SCRIPT_MSG) - return self.eval(script, numkeys, *keys_and_args) - - @command((bytes,), (bytes,), flags='s') - def script(self, subcmd, *args): - if casematch(subcmd, b'load'): - if len(args) != 1: - raise SimpleError(BAD_SUBCOMMAND_MSG.format('SCRIPT')) - script = args[0] - sha1 = hashlib.sha1(script).hexdigest().encode() - self._server.script_cache[sha1] = script - return sha1 - elif casematch(subcmd, b'exists'): - return [int(sha1 in self._server.script_cache) for sha1 in args] - elif casematch(subcmd, b'flush'): - if len(args) > 1 or (len(args) == 1 and casenorm(args[0]) not in {b'sync', b'async'}): - raise SimpleError(BAD_SUBCOMMAND_MSG.format('SCRIPT')) - self._server.script_cache = {} - return OK - else: - raise SimpleError(BAD_SUBCOMMAND_MSG.format('SCRIPT')) - - # Pubsub commands - # TODO: pubsub command - - def _subscribe(self, channels, subscribers, mtype): - for channel in channels: - subs = subscribers[channel] - if self not in subs: - subs.add(self) - self._pubsub += 1 - msg = [mtype, channel, self._pubsub] - self.put_response(msg) - return NoResponse() - - def _unsubscribe(self, channels, subscribers, mtype): - if not channels: - channels = [] - for (channel, subs) in subscribers.items(): - if self in subs: - channels.append(channel) - for channel in channels: - subs = subscribers.get(channel, set()) - if self in subs: - subs.remove(self) - if not subs: - del subscribers[channel] - self._pubsub -= 1 - msg = [mtype, channel, self._pubsub] - self.put_response(msg) - return NoResponse() - - @command((bytes,), (bytes,), flags='s') - def psubscribe(self, *patterns): - return self._subscribe(patterns, self._server.psubscribers, b'psubscribe') - - @command((bytes,), (bytes,), flags='s') - def subscribe(self, *channels): - return self._subscribe(channels, self._server.subscribers, b'subscribe') - - @command((), (bytes,), flags='s') - def punsubscribe(self, *patterns): - return self._unsubscribe(patterns, self._server.psubscribers, b'punsubscribe') - - @command((), (bytes,), flags='s') - def unsubscribe(self, *channels): - return self._unsubscribe(channels, self._server.subscribers, b'unsubscribe') - - @command((bytes, bytes)) - def publish(self, channel, message): - receivers = 0 - msg = [b'message', channel, message] - subs = self._server.subscribers.get(channel, set()) - for sock in subs: - sock.put_response(msg) - receivers += 1 - for (pattern, socks) in self._server.psubscribers.items(): - regex = compile_pattern(pattern) - if regex.match(channel): - msg = [b'pmessage', pattern, channel, message] - for sock in socks: - sock.put_response(msg) - receivers += 1 - return receivers - - -setattr(FakeSocket, 'del', FakeSocket.del_) -delattr(FakeSocket, 'del_') -setattr(FakeSocket, 'set', FakeSocket.set_) -delattr(FakeSocket, 'set_') -setattr(FakeSocket, 'exec', FakeSocket.exec_) -delattr(FakeSocket, 'exec_') - - -class _DummyParser: - def __init__(self, socket_read_size): - self.socket_read_size = socket_read_size - - def on_disconnect(self): - pass - - def on_connect(self, connection): - pass - - -# Redis <3.2 will not have a selector -try: - from redis.selector import BaseSelector -except ImportError: - class BaseSelector: - def __init__(self, sock): - self.sock = sock - - -class FakeSelector(BaseSelector): - def check_can_read(self, timeout): - if self.sock.responses.qsize(): - return True - if timeout is not None and timeout <= 0: - return False - - # A sleep/poll loop is easier to mock out than messing with condition - # variables. - start = time.time() - while True: - if self.sock.responses.qsize(): - return True - time.sleep(0.01) - now = time.time() - if timeout is not None and now > start + timeout: - return False - - def check_is_ready_for_command(self, timeout): - return True - - -class FakeConnection(redis.Connection): - description_format = "FakeConnection" - - def __init__(self, *args, **kwargs): - self._server = kwargs.pop('server') - super().__init__(*args, **kwargs) - - def connect(self): - super().connect() - # The selector is set in redis.Connection.connect() after _connect() is called - self._selector = FakeSelector(self._sock) - - def _connect(self): - if not self._server.connected: - raise redis.ConnectionError(CONNECTION_ERROR_MSG) - return FakeSocket(self._server) - - def can_read(self, timeout=0): - if not self._server.connected: - return True - if not self._sock: - self.connect() - # We use check_can_read rather than can_read, because on redis-py<3.2, - # FakeSelector inherits from a stub BaseSelector which doesn't - # implement can_read. Normally can_read provides retries on EINTR, - # but that's not necessary for the implementation of - # FakeSelector.check_can_read. - return self._selector.check_can_read(timeout) - - def _decode(self, response): - if isinstance(response, list): - return [self._decode(item) for item in response] - elif isinstance(response, bytes): - return self.encoder.decode(response) - else: - return response - - def read_response(self, disable_decoding=False): - if not self._server.connected: - try: - response = self._sock.responses.get_nowait() - except queue.Empty: - raise redis.ConnectionError(CONNECTION_ERROR_MSG) - else: - response = self._sock.responses.get() - if isinstance(response, redis.ResponseError): - raise response - if disable_decoding: - return response - else: - return self._decode(response) - - def repr_pieces(self): - pieces = [ - ('server', self._server), - ('db', self.db) - ] - if self.client_name: - pieces.append(('client_name', self.client_name)) - return pieces - - -class FakeRedisMixin: - def __init__(self, *args, server=None, connected=True, **kwargs): - # Interpret the positional and keyword arguments according to the - # version of redis in use. - bound = _ORIG_SIG.bind(*args, **kwargs) - bound.apply_defaults() - if not bound.arguments['connection_pool']: - charset = bound.arguments['charset'] - errors = bound.arguments['errors'] - # Adapted from redis-py - if charset is not None: - warnings.warn(DeprecationWarning( - '"charset" is deprecated. Use "encoding" instead')) - bound.arguments['encoding'] = charset - if errors is not None: - warnings.warn(DeprecationWarning( - '"errors" is deprecated. Use "encoding_errors" instead')) - bound.arguments['encoding_errors'] = errors - - if server is None: - server = FakeServer() - server.connected = connected - kwargs = { - 'connection_class': FakeConnection, - 'server': server - } - conn_pool_args = [ - 'db', - 'username', - 'password', - 'socket_timeout', - 'encoding', - 'encoding_errors', - 'decode_responses', - 'retry_on_timeout', - 'max_connections', - 'health_check_interval', - 'client_name' - ] - for arg in conn_pool_args: - if arg in bound.arguments: - kwargs[arg] = bound.arguments[arg] - bound.arguments['connection_pool'] = redis.connection.ConnectionPool(**kwargs) - super().__init__(*bound.args, **bound.kwargs) - - @classmethod - def from_url(https://github.com/cls,%20*args,%20**kwargs): - server = kwargs.pop('server', None) - if server is None: - server = FakeServer() - pool = redis.ConnectionPool.from_url(https://github.com/*args,%20**kwargs) - # Now override how it creates connections - pool.connection_class = FakeConnection - pool.connection_kwargs['server'] = server - # FakeConnection cannot handle the path kwarg (present when from_url - # is called with a unix socket) - pool.connection_kwargs.pop('path', None) - return cls(connection_pool=pool) - - -class FakeStrictRedis(FakeRedisMixin, redis.StrictRedis): - pass - - -class FakeRedis(FakeRedisMixin, redis.Redis): - pass diff --git a/build/lib/fakeredis/_zset.py b/build/lib/fakeredis/_zset.py deleted file mode 100644 index 47d1169..0000000 --- a/build/lib/fakeredis/_zset.py +++ /dev/null @@ -1,87 +0,0 @@ -import sortedcontainers - - -class ZSet: - def __init__(self): - self._bylex = {} # Maps value to score - self._byscore = sortedcontainers.SortedList() - - def __contains__(self, value): - return value in self._bylex - - def add(self, value, score): - """Update the item and return whether it modified the zset""" - old_score = self._bylex.get(value, None) - if old_score is not None: - if score == old_score: - return False - self._byscore.remove((old_score, value)) - self._bylex[value] = score - self._byscore.add((score, value)) - return True - - def __setitem__(self, value, score): - self.add(value, score) - - def __getitem__(self, key): - return self._bylex[key] - - def get(self, key, default=None): - return self._bylex.get(key, default) - - def __len__(self): - return len(self._bylex) - - def __iter__(self): - def gen(): - for score, value in self._byscore: - yield value - - return gen() - - def discard(self, key): - try: - score = self._bylex.pop(key) - except KeyError: - return - else: - self._byscore.remove((score, key)) - - def zcount(self, min_, max_): - pos1 = self._byscore.bisect_left(min_) - pos2 = self._byscore.bisect_left(max_) - return max(0, pos2 - pos1) - - def zlexcount(self, min_value, min_exclusive, max_value, max_exclusive): - if not self._byscore: - return 0 - score = self._byscore[0][0] - if min_exclusive: - pos1 = self._byscore.bisect_right((score, min_value)) - else: - pos1 = self._byscore.bisect_left((score, min_value)) - if max_exclusive: - pos2 = self._byscore.bisect_left((score, max_value)) - else: - pos2 = self._byscore.bisect_right((score, max_value)) - return max(0, pos2 - pos1) - - def islice_score(self, start, stop, reverse=False): - return self._byscore.islice(start, stop, reverse) - - def irange_lex(self, start, stop, inclusive=(True, True), reverse=False): - if not self._byscore: - return iter([]) - score = self._byscore[0][0] - it = self._byscore.irange((score, start), (score, stop), - inclusive=inclusive, reverse=reverse) - return (item[1] for item in it) - - def irange_score(self, start, stop, reverse=False): - return self._byscore.irange(start, stop, reverse=reverse) - - def rank(self, member): - return self._byscore.index((self._bylex[member], member)) - - def items(self): - return self._bylex.items() diff --git a/build/lib/fakeredis/aioredis.py b/build/lib/fakeredis/aioredis.py deleted file mode 100644 index 7d5ba08..0000000 --- a/build/lib/fakeredis/aioredis.py +++ /dev/null @@ -1,10 +0,0 @@ -import aioredis -import packaging.version - - -if packaging.version.Version(aioredis.__version__) >= packaging.version.Version('2.0.0a1'): - from ._aioredis2 import FakeConnection, FakeRedis # noqa: F401 -else: - from ._aioredis1 import ( # noqa: F401 - FakeConnectionsPool, create_connection, create_redis, create_pool, create_redis_pool - ) From 00778489aafa930bb6952288a0f352708ff51358 Mon Sep 17 00:00:00 2001 From: rotten Date: Tue, 4 Jan 2022 17:12:38 -0500 Subject: [PATCH 13/20] add gitignore for /build --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index b408ae9..db7d5b6 100644 --- a/.gitignore +++ b/.gitignore @@ -10,3 +10,4 @@ extras/* cover/ venv/ dist/ +build/* From b25a6ea9a5d086f5ea05800389ebce927563ba0d Mon Sep 17 00:00:00 2001 From: Bruce Merry Date: Mon, 14 Feb 2022 10:41:21 +0200 Subject: [PATCH 14/20] Support redis<4.2.0 Rather than <=4.1.0, which prevented 4.1.1 etc from being tested. --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index 0ded32c..a7b6b48 100644 --- a/setup.cfg +++ b/setup.cfg @@ -26,7 +26,7 @@ install_requires = # Minor version updates to redis tend to break fakeredis. If you # need to use fakeredis with a newer redis, please submit a PR that # relaxes this restriction and adds it to the Github Actions tests. - redis<=4.1.0 + redis<4.2.0 six>=1.12 sortedcontainers python_requires = >=3.5 From d220c4be91635cb72183f12dd13316d0af897e4d Mon Sep 17 00:00:00 2001 From: Bruce Merry Date: Mon, 14 Feb 2022 10:43:44 +0200 Subject: [PATCH 15/20] Bump testing to test with latest redis-py --- .github/workflows/test.yml | 7 +++++-- requirements.in | 2 +- requirements.txt | 6 ++++-- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 3531db2..d2aef35 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -10,7 +10,7 @@ jobs: fail-fast: false matrix: python-version: ["3.6", "3.7", "3.8", "3.9", "pypy-3.7"] - redis-py: ["4.0.0"] + redis-py: ["4.1.3"] aioredis: ["2.0.0"] include: - python-version: "3.9" @@ -35,7 +35,10 @@ jobs: redis-py: "3.5.3" aioredis: "1.3.1" - python-version: "3.9" - redis-py: "4.0.*" + redis-py: "4.0.1" + aioredis: "1.3.1" + - python-version: "3.9" + redis-py: "4.1.*" aioredis: "2.0.0" coverage: yes services: diff --git a/requirements.in b/requirements.in index 2e14ac2..a5d59d7 100644 --- a/requirements.in +++ b/requirements.in @@ -7,7 +7,7 @@ pytest pytest-asyncio pytest-cov pytest-mock -redis==4.1.0 # Latest at time of writing +redis==4.1.3 # Latest at time of writing six sortedcontainers diff --git a/requirements.txt b/requirements.txt index c9328fb..a695ef0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -31,7 +31,9 @@ lupa==1.10 mccabe==0.6.1 # via flake8 packaging==21.3 - # via pytest + # via + # pytest + # redis pluggy==0.13.1 # via pytest py==1.10.0 @@ -54,7 +56,7 @@ pytest-cov==2.10.1 # via -r requirements.in pytest-mock==3.3.1 # via -r requirements.in -redis==4.1.0 +redis==4.1.3 # via -r requirements.in six==1.15.0 # via -r requirements.in From ff50c53590ac31b092176521595511e96d65857c Mon Sep 17 00:00:00 2001 From: Bruce Merry Date: Mon, 14 Feb 2022 10:59:26 +0200 Subject: [PATCH 16/20] Prepare 1.7.1 release --- README.rst | 4 ++++ fakeredis/__init__.py | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/README.rst b/README.rst index 7246661..11cf9a7 100644 --- a/README.rst +++ b/README.rst @@ -452,6 +452,10 @@ they have all been tagged as 'slow' so you can skip them by running:: Revision history ================ +1.7.1 +----- +- `#324 `_ Support redis-py 4.1.x + 1.7.0 ----- - Change a number of corner-case behaviours to match Redis 6.2.6. diff --git a/fakeredis/__init__.py b/fakeredis/__init__.py index cff7eef..7967393 100644 --- a/fakeredis/__init__.py +++ b/fakeredis/__init__.py @@ -1,4 +1,4 @@ from ._server import FakeServer, FakeRedis, FakeStrictRedis, FakeConnection # noqa: F401 -__version__ = '1.7.0' +__version__ = '1.7.1' From 798d534bd4d6d396ba5082184e478f160f791940 Mon Sep 17 00:00:00 2001 From: Bruce Merry Date: Tue, 22 Mar 2022 13:49:10 +0200 Subject: [PATCH 17/20] Note that fakeredis is now unmaintained. --- README.rst | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/README.rst b/README.rst index 11cf9a7..c42795b 100644 --- a/README.rst +++ b/README.rst @@ -1,3 +1,14 @@ +**Maintainer needed** +===================== +At present fakeredis is unmaintained. I (**@bmerry**) have changed roles in my +job and am no longer using redis or fakeredis, and nobody from my old team is +available to take over. Thus, further development will need community +involvement. + +I don't have admin rights on PyPI so I'm not able to transfer control; thus, +you'll need to either create a fork or persuade **@jamesls** to make you the +new maintainer. + fakeredis: A fake version of a redis-py ======================================= From 0303b2ddfa4d6cd5098a0a29b23d300296307c59 Mon Sep 17 00:00:00 2001 From: Bruce Merry Date: Tue, 22 Mar 2022 13:53:11 +0200 Subject: [PATCH 18/20] Update README to indicate that forks may be linked to --- README.rst | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/README.rst b/README.rst index c42795b..1c41fbb 100644 --- a/README.rst +++ b/README.rst @@ -9,6 +9,10 @@ I don't have admin rights on PyPI so I'm not able to transfer control; thus, you'll need to either create a fork or persuade **@jamesls** to make you the new maintainer. +If you do publish a fork, feel free to send me an email (see `maintainer_email` +in setup.cfg) to let me know and I'll link to it from here. + + fakeredis: A fake version of a redis-py ======================================= From 0a0bcdbdcfc9c08b01d16604ae6bc7dc763c6a83 Mon Sep 17 00:00:00 2001 From: Bruce Merry Date: Mon, 9 May 2022 08:55:40 +0200 Subject: [PATCH 19/20] Replace with placeholder README It points at the new repository. --- .github/workflows/test.yml | 76 - .gitignore | 13 - CONTRIBUTING.rst | 17 - COPYING | 51 - README.rst | 798 +---- fakeredis/__init__.py | 4 - fakeredis/_aioredis1.py | 181 -- fakeredis/_aioredis2.py | 170 -- fakeredis/_async.py | 51 - fakeredis/_server.py | 2850 ------------------ fakeredis/_zset.py | 87 - fakeredis/aioredis.py | 10 - pyproject.toml | 2 - requirements-dev.txt | 4 - requirements.in | 15 - requirements.txt | 72 - scripts/supported | 65 - setup.cfg | 52 - setup.py | 4 - test/conftest.py | 24 - test/test_aioredis1.py | 158 - test/test_aioredis2.py | 252 -- test/test_fakeredis.py | 5592 ------------------------------------ test/test_hypothesis.py | 620 ---- tox.ini | 11 - 25 files changed, 4 insertions(+), 11175 deletions(-) delete mode 100644 .github/workflows/test.yml delete mode 100644 .gitignore delete mode 100644 CONTRIBUTING.rst delete mode 100644 COPYING delete mode 100644 fakeredis/__init__.py delete mode 100644 fakeredis/_aioredis1.py delete mode 100644 fakeredis/_aioredis2.py delete mode 100644 fakeredis/_async.py delete mode 100644 fakeredis/_server.py delete mode 100644 fakeredis/_zset.py delete mode 100644 fakeredis/aioredis.py delete mode 100644 pyproject.toml delete mode 100644 requirements-dev.txt delete mode 100644 requirements.in delete mode 100644 requirements.txt delete mode 100755 scripts/supported delete mode 100644 setup.cfg delete mode 100644 setup.py delete mode 100644 test/conftest.py delete mode 100644 test/test_aioredis1.py delete mode 100644 test/test_aioredis2.py delete mode 100644 test/test_fakeredis.py delete mode 100644 test/test_hypothesis.py delete mode 100644 tox.ini diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml deleted file mode 100644 index d2aef35..0000000 --- a/.github/workflows/test.yml +++ /dev/null @@ -1,76 +0,0 @@ -name: Unit tests -on: [push, pull_request] -concurrency: - group: test-${{ github.ref }} - cancel-in-progress: true -jobs: - test: - runs-on: ubuntu-20.04 - strategy: - fail-fast: false - matrix: - python-version: ["3.6", "3.7", "3.8", "3.9", "pypy-3.7"] - redis-py: ["4.1.3"] - aioredis: ["2.0.0"] - include: - - python-version: "3.9" - redis-py: "2.10.6" - aioredis: "1.3.1" - - python-version: "3.9" - redis-py: "3.0.1" - aioredis: "1.3.1" - - python-version: "3.9" - redis-py: "3.1.0" - aioredis: "1.3.1" - - python-version: "3.9" - redis-py: "3.2.1" - aioredis: "1.3.1" - - python-version: "3.9" - redis-py: "3.3.11" - aioredis: "1.3.1" - - python-version: "3.9" - redis-py: "3.4.1" - aioredis: "1.3.1" - - python-version: "3.9" - redis-py: "3.5.3" - aioredis: "1.3.1" - - python-version: "3.9" - redis-py: "4.0.1" - aioredis: "1.3.1" - - python-version: "3.9" - redis-py: "4.1.*" - aioredis: "2.0.0" - coverage: yes - services: - redis: - image: redis:6.2.6 - ports: - - 6379:6379 - steps: - - uses: actions/checkout@v2 - - uses: actions/cache@v2 - with: - path: ~/.cache/pip - key: ${{ runner.os }}-pip-cache-v2-${{ matrix.python-version }}-${{ hashFiles('requirements.txt') }} - restore-keys: | - ${{ runner.os }}-pip-cache-v2-${{ matrix.python-version }}- - ${{ runner.os }}-pip-cache-v2- - - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - - run: pip install -U pip setuptools wheel - - run: pip install -r requirements.txt - - run: pip install redis==${{ matrix.redis-py }} aioredis==${{ matrix.aioredis }} coveralls - - run: pip install -e . - if: ${{ matrix.coverage == 'yes' }} - - run: pip install . - if: ${{ matrix.coverage != 'yes' }} - - run: flake8 - - run: pytest -v --cov=fakeredis --cov-branch - if: ${{ matrix.coverage == 'yes' }} - - run: pytest -v - if: ${{ matrix.coverage != 'yes' }} - - run: coveralls --service=github - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - if: ${{ matrix.coverage == 'yes' }} diff --git a/.gitignore b/.gitignore deleted file mode 100644 index db7d5b6..0000000 --- a/.gitignore +++ /dev/null @@ -1,13 +0,0 @@ -.commands.json -fakeredis.egg-info -dump.rdb -extras/* -.tox -*.pyc -.idea -.hypothesis -.coverage -cover/ -venv/ -dist/ -build/* diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst deleted file mode 100644 index 0cec852..0000000 --- a/CONTRIBUTING.rst +++ /dev/null @@ -1,17 +0,0 @@ -============ -Contributing -============ - -Contributions are welcome. To ensure that your contributions are accepted -please follow these guidelines. - -* Follow pep8 (Github Actions will fail builds that don't pass flake8). -* If you are adding docstrings, follow pep257. -* If you are adding new functionality or fixing a bug, please add tests. -* If you are making a large change, consider filing an issue on github - first to see if there are any objections to the proposed changes. - -In general, new features or bug fixes **will not be merged unless they -have tests.** This is not only to ensure the correctness of -the code, but to also encourage others to experiment without wondering -whether or not they are breaking existing code. diff --git a/COPYING b/COPYING deleted file mode 100644 index f225b41..0000000 --- a/COPYING +++ /dev/null @@ -1,51 +0,0 @@ -Copyright (c) 2011 James Saryerwinnie, 2017-2018 Bruce Merry -All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions -are met: -1. Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. -2. Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. -3. The name of the author may not be used to endorse or promote products - derived from this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR -IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES -OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. -IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, -INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT -NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF -THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - - -This software contains portions of code from redis-py, which is distributed -under the following license: - -Copyright (c) 2012 Andy McCurdy - - Permission is hereby granted, free of charge, to any person - obtaining a copy of this software and associated documentation - files (the "Software"), to deal in the Software without - restriction, including without limitation the rights to use, - copy, modify, merge, publish, distribute, sublicense, and/or sell - copies of the Software, and to permit persons to whom the - Software is furnished to do so, subject to the following - conditions: - - The above copyright notice and this permission notice shall be - included in all copies or substantial portions of the Software. - - THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, - EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES - OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND - NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT - HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, - WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING - FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR - OTHER DEALINGS IN THE SOFTWARE. diff --git a/README.rst b/README.rst index 1c41fbb..edad847 100644 --- a/README.rst +++ b/README.rst @@ -1,794 +1,4 @@ -**Maintainer needed** -===================== -At present fakeredis is unmaintained. I (**@bmerry**) have changed roles in my -job and am no longer using redis or fakeredis, and nobody from my old team is -available to take over. Thus, further development will need community -involvement. - -I don't have admin rights on PyPI so I'm not able to transfer control; thus, -you'll need to either create a fork or persuade **@jamesls** to make you the -new maintainer. - -If you do publish a fork, feel free to send me an email (see `maintainer_email` -in setup.cfg) to let me know and I'll link to it from here. - - -fakeredis: A fake version of a redis-py -======================================= - -.. image:: https://github.com/jamesls/fakeredis/actions/workflows/test.yml/badge.svg - :target: https://github.com/jamesls/fakeredis/actions/workflows/test.yml - -.. image:: https://coveralls.io/repos/jamesls/fakeredis/badge.svg?branch=master - :target: https://coveralls.io/r/jamesls/fakeredis - - -fakeredis is a pure-Python implementation of the redis-py python client -that simulates talking to a redis server. This was created for a single -purpose: **to write unittests**. Setting up redis is not hard, but -many times you want to write unittests that do not talk to an external server -(such as redis). This module now allows tests to simply use this -module as a reasonable substitute for redis. - -Although fakeredis is pure Python, you will need lupa_ if you want to run Lua -scripts (this includes features like ``redis.lock.Lock``, which are implemented -in Lua). If you install fakeredis with ``pip install fakeredis[lua]`` it will -be automatically installed. - -.. _lupa: https://pypi.org/project/lupa/ - -Alternatives -============ - -Consider using redislite_ instead of fakeredis. It runs a real redis server and -connects to it over a UNIX domain socket, so it will behave just like a real -server. Another alternative is birdisle_, which runs the redis code as a Python -extension (no separate process), but which is currently unmaintained. - -.. _birdisle: https://birdisle.readthedocs.io/en/latest/ -.. _redislite: https://redislite.readthedocs.io/en/latest/ - - -How to Use -========== - -The intent is for fakeredis to act as though you're talking to a real -redis server. It does this by storing state internally. -For example: - -.. code-block:: python - - >>> import fakeredis - >>> r = fakeredis.FakeStrictRedis() - >>> r.set('foo', 'bar') - True - >>> r.get('foo') - 'bar' - >>> r.lpush('bar', 1) - 1 - >>> r.lpush('bar', 2) - 2 - >>> r.lrange('bar', 0, -1) - [2, 1] - -The state is stored in an instance of `FakeServer`. If one is not provided at -construction, a new instance is automatically created for you, but you can -explicitly create one to share state: - -.. code-block:: python - - >>> import fakeredis - >>> server = fakeredis.FakeServer() - >>> r1 = fakeredis.FakeStrictRedis(server=server) - >>> r1.set('foo', 'bar') - True - >>> r2 = fakeredis.FakeStrictRedis(server=server) - >>> r2.get('foo') - 'bar' - >>> r2.set('bar', 'baz') - True - >>> r1.get('bar') - 'baz' - >>> r2.get('bar') - 'baz' - -It is also possible to mock connection errors so you can effectively test -your error handling. Simply set the connected attribute of the server to -`False` after initialization. - -.. code-block:: python - - >>> import fakeredis - >>> server = fakeredis.FakeServer() - >>> server.connected = False - >>> r = fakeredis.FakeStrictRedis(server=server) - >>> r.set('foo', 'bar') - ConnectionError: FakeRedis is emulating a connection error. - >>> server.connected = True - >>> r.set('foo', 'bar') - True - -Fakeredis implements the same interface as `redis-py`_, the -popular redis client for python, and models the responses -of redis 6.2 (although most new features are not supported). - -Support for aioredis -==================== - -You can also use fakeredis to mock out aioredis_. This is a much newer -addition to fakeredis (added in 1.4.0) with less testing, so your mileage may -vary. Both version 1 and version 2 (which have very different APIs) are -supported. The API provided by fakeredis depends on the version of aioredis that is -installed. - -.. _aioredis: https://aioredis.readthedocs.io/ - -aioredis 1.x ------------- - -Example: - -.. code-block:: python - - >>> import fakeredis.aioredis - >>> r = await fakeredis.aioredis.create_redis_pool() - >>> await r.set('foo', 'bar') - True - >>> await r.get('foo') - b'bar' - -You can pass a `FakeServer` as the first argument to `create_redis` or -`create_redis_pool` to share state (you can even share state with a -`fakeredis.FakeRedis`). It should even be safe to do this state sharing between -threads (as long as each connection/pool is only used in one thread). - -It is highly recommended that you only use the aioredis support with -Python 3.5.3 or higher. Earlier versions will not work correctly with -non-default event loops. - -aioredis 2.x ------------- - -Example: - -.. code-block:: python - - >>> import fakeredis.aioredis - >>> r = fakeredis.aioredis.FakeRedis() - >>> await r.set('foo', 'bar') - True - >>> await r.get('foo') - b'bar' - -The support is essentially the same as for redis-py e.g., you can pass a -`server` keyword argument to the `FakeRedis` constructor. - -Porting to fakeredis 1.0 -======================== - -Version 1.0 is an almost total rewrite, intended to support redis-py 3.x and -improve the Lua scripting emulation. It has a few backwards incompatibilities -that may require changes to your code: - -1. By default, each FakeRedis or FakeStrictRedis instance contains its own - state. This is equivalent to the `singleton=False` option to previous - versions of fakeredis. This change was made to improve isolation between - tests. If you need to share state between instances, create a FakeServer, - as described above. - -2. FakeRedis is now a subclass of Redis, and similarly - FakeStrictRedis is a subclass of StrictRedis. Code that uses `isinstance` - may behave differently. - -3. The `connected` attribute is now a property of `FakeServer`, rather than - `FakeRedis` or `FakeStrictRedis`. You can still pass the property to the - constructor of the latter (provided no server is provided). - - -Unimplemented Commands -====================== - -All of the redis commands are implemented in fakeredis with -these exceptions: - - -server ------- - - * acl load - * acl save - * acl list - * acl users - * acl getuser - * acl setuser - * acl deluser - * acl cat - * acl genpass - * acl whoami - * acl log - * acl help - * bgrewriteaof - * command - * command count - * command getkeys - * command info - * config get - * config rewrite - * config set - * config resetstat - * debug object - * debug segfault - * info - * lolwut - * memory doctor - * memory help - * memory malloc-stats - * memory purge - * memory stats - * memory usage - * module list - * module load - * module unload - * monitor - * role - * shutdown - * slaveof - * replicaof - * slowlog - * sync - * psync - * latency doctor - * latency graph - * latency history - * latency latest - * latency reset - * latency help - - -connection ----------- - - * auth - * client caching - * client id - * client kill - * client list - * client getname - * client getredir - * client pause - * client reply - * client setname - * client tracking - * client unblock - * hello - * quit - - -string ------- - - * bitfield - * bitop - * bitpos - * stralgo - - -sorted_set ----------- - - * bzpopmin - * bzpopmax - * zpopmax - * zpopmin - - -cluster -------- - - * cluster addslots - * cluster bumpepoch - * cluster count-failure-reports - * cluster countkeysinslot - * cluster delslots - * cluster failover - * cluster flushslots - * cluster forget - * cluster getkeysinslot - * cluster info - * cluster keyslot - * cluster meet - * cluster myid - * cluster nodes - * cluster replicate - * cluster reset - * cluster saveconfig - * cluster set-config-epoch - * cluster setslot - * cluster slaves - * cluster replicas - * cluster slots - * readonly - * readwrite - - -generic -------- - - * migrate - * object - * touch - * wait - - -geo ---- - - * geoadd - * geohash - * geopos - * geodist - * georadius - * georadiusbymember - - -list ----- - - * lpos - - -pubsub ------- - - * pubsub - - -scripting ---------- - - * script debug - * script kill - - -stream ------- - - * xinfo - * xadd - * xtrim - * xdel - * xrange - * xrevrange - * xlen - * xread - * xgroup - * xreadgroup - * xack - * xclaim - * xpending - - -Other limitations -================= - -Apart from unimplemented commands, there are a number of cases where fakeredis -won't give identical results to real redis. The following are differences that -are unlikely to ever be fixed; there are also differences that are fixable -(such as commands that do not support all features) which should be filed as -bugs in Github. - -1. Hyperloglogs are implemented using sets underneath. This means that the - `type` command will return the wrong answer, you can't use `get` to retrieve - the encoded value, and counts will be slightly different (they will in fact be - exact). - -2. When a command has multiple error conditions, such as operating on a key of - the wrong type and an integer argument is not well-formed, the choice of - error to return may not match redis. - -3. The `incrbyfloat` and `hincrbyfloat` commands in redis use the C `long - double` type, which typically has more precision than Python's `float` - type. - -4. Redis makes guarantees about the order in which clients blocked on blocking - commands are woken up. Fakeredis does not honour these guarantees. - -5. Where redis contains bugs, fakeredis generally does not try to provide exact - bug-compatibility. It's not practical for fakeredis to try to match the set - of bugs in your specific version of redis. - -6. There are a number of cases where the behaviour of redis is undefined, such - as the order of elements returned by set and hash commands. Fakeredis will - generally not produce the same results, and in Python versions before 3.6 - may produce different results each time the process is re-run. - -7. SCAN/ZSCAN/HSCAN/SSCAN will not necessarily iterate all items if items are - deleted or renamed during iteration. They also won't necessarily iterate in - the same chunk sizes or the same order as redis. - -8. DUMP/RESTORE will not return or expect data in the RDB format. Instead the - `pickle` module is used to mimic an opaque and non-standard format. - **WARNING**: Do not use RESTORE with untrusted data, as a malicious pickle - can execute arbitrary code. - -Contributing -============ - -Contributions are welcome. Please see the `contributing guide`_ for -more details. The maintainer generally has very little time to work on -fakeredis, so the best way to get a bug fixed is to contribute a pull -request. - -If you'd like to help out, you can start with any of the issues -labeled with `HelpWanted`_. - - -Running the Tests -================= - -To ensure parity with the real redis, there are a set of integration tests -that mirror the unittests. For every unittest that is written, the same -test is run against a real redis instance using a real redis-py client -instance. In order to run these tests you must have a redis server running -on localhost, port 6379 (the default settings). **WARNING**: the tests will -completely wipe your database! - - -First install the requirements file:: - - pip install -r requirements.txt - -To run all the tests:: - - pytest - -If you only want to run tests against fake redis, without a real redis:: - - pytest -m fake - -Because this module is attempting to provide the same interface as `redis-py`_, -the python bindings to redis, a reasonable way to test this to to take each -unittest and run it against a real redis server. fakeredis and the real redis -server should give the same result. To run tests against a real redis instance -instead:: - - pytest -m real - -If redis is not running and you try to run tests against a real redis server, -these tests will have a result of 's' for skipped. - -There are some tests that test redis blocking operations that are somewhat -slow. If you want to skip these tests during day to day development, -they have all been tagged as 'slow' so you can skip them by running:: - - pytest -m "not slow" - - -Revision history -================ - -1.7.1 ------ -- `#324 `_ Support redis-py 4.1.x - -1.7.0 ------ -- Change a number of corner-case behaviours to match Redis 6.2.6. -- `#310 `_ Fix DeprecationWarning for sampling from a set -- `#315 `_ Improved support for constructor arguments -- `#316 `_ Support redis-py 4 -- `#319 `_ Add support for GET option to SET -- `#323 `_ PERSIST and EXPIRE should invalidate watches - -1.6.1 ------ -- `#305 `_ Some packaging modernisation -- `#306 `_ Fix FakeRedisMixin.from_url for unix sockets -- `#308 `_ Remove use of async_generator from tests - -1.6.0 ------ -- `#304 `_ Support aioredis 2 -- `#302 `_ Switch CI from Travis CI to Github Actions - -1.5.2 ------ -- Depend on `aioredis<2` (aioredis 2.x is a backwards-incompatible rewrite). - -1.5.1 ------ -- `#298 `_ Fix a deadlock caused - by garbage collection - -1.5.0 ------ -- Fix clearing of watches when a transaction is aborted. -- Support Python 3.9 and drop support for Python 3.5. -- Update handling of EXEC failures to match redis 6.0.6+. -- `#293 `_ Align - `FakeConnection` constructor signature to base class -- Skip hypothesis tests on 32-bit Redis servers. - -1.4.5 ------ -- `#285 `_ Add support for DUMP - and RESTORE commands -- `#286 `_ Add support for TYPE - option to SCAN command - -1.4.4 ------ -- `#281 `_ Add support for - SCRIPT EXISTS and SCRIPT FLUSH subcommands -- `#280 `_ Fix documentation - about singleton argument - -1.4.3 ------ -- `#277 `_ Implement SET with KEEPTTL -- `#278 `_ Handle indefinite - timeout for PUBSUB commands - -1.4.2 ------ -- `#269 `_ Prevent passing - booleans from Lua to redis -- `#254 `_ Implement TIME command -- `#232 `_ Implement ZADD with INCR -- Rework of unit tests to use more pytest idioms - -1.4.1 ------ -- `#268 `_ Support redis-py 3.5 - (no code changes, just setup.py) - -1.4.0 ------ -- Add support for aioredis. -- Fix interaction of no-op SREM with WATCH. - -1.3.1 ------ -- Make errors from Lua behave more like real redis - -1.3.0 ------ -- `#266 `_ Implement redis.log in Lua - -1.2.1 ------ -- `#262 `_ Cannot repr redis object without host attribute -- Fix a bug in the hypothesis test framework that occasionally caused a failure - -1.2.0 ------ -- Drop support for Python 2.7. -- Test with Python 3.8 and Pypy3. -- Refactor Hypothesis-based tests to support the latest version of Hypothesis. -- Fix a number of bugs in the Hypothesis tests that were causing spurious test - failures or hangs. -- Fix some obscure corner cases - - - If a WATCHed key is MOVEd, don't invalidate the transaction. - - Some cases of passing a key of the wrong type to SINTER/SINTERSTORE were - not reporting a WRONGTYPE error. - - ZUNIONSTORE/ZINTERSTORE could generate different scores from real redis - in corner cases (mostly involving infinities). - -- Speed up the implementation of BINCOUNT. - -1.1.1 ------ -- Support redis-py 3.4. - -1.1.0 ------ -- `#257 `_ Add other inputs for redis connection - -1.0.5 ------ -- `#247 `_ Support NX/XX/CH flags in ZADD command -- `#250 `_ Implement UNLINK command -- `#252 `_ Fix implementation of ZSCAN - -1.0.4 ------ -- `#240 `_ `#242 `_ Support for ``redis==3.3`` - -1.0.3 ------ -- `#235 `_ Support for ``redis==3.2`` - -1.0.2 ------ -- `#235 `_ Depend on ``redis<3.2`` - -1.0.1 ------ -- Fix crash when a connection closes without unsubscribing and there is a subsequent PUBLISH - -1.0 ---- - -Version 1.0 is a major rewrite. It works at the redis protocol level, rather -than at the redis-py level. This allows for many improvements and bug fixes. - -- `#225 `_ Support redis-py 3.0 -- `#65 `_ Support `execute_command` method -- `#206 `_ Drop Python 2.6 support -- `#141 `_ Support strings in integer arguments -- `#218 `_ Watches checks commands rather than final value -- `#220 `_ Better support for calling into redis from Lua -- `#158 `_ Better timestamp handling -- Support for `register_script` function. -- Fixes for race conditions caused by keys expiring mid-command -- Disallow certain commands in scripts -- Fix handling of blocking commands inside transactions -- Fix handling of PING inside pubsub connections - -It also has new unit tests based on hypothesis_, which has identified many -corner cases that are now handled correctly. - -.. _hypothesis: https://hypothesis.readthedocs.io/en/latest/ - -1.0rc1 ------- -Compared to 1.0b1: - -- `#231 `_ Fix setup.py, fakeredis is directory/package now -- Fix some corner case handling of +0 vs -0 -- Fix pubsub `get_message` with a timeout -- Disallow certain commands in scripts -- Fix handling of blocking commands inside transactions -- Fix handling of PING inside pubsub connections -- Make hypothesis tests skip if redis is not running -- Minor optimisations to zset - -1.0b1 ------ -Version 1.0 is a major rewrite. It works at the redis protocol level, rather -than at the redis-py level. This allows for many improvements and bug fixes. - -- `#225 `_ Support redis-py 3.0 -- `#65 `_ Support `execute_command` method -- `#206 `_ Drop Python 2.6 support -- `#141 `_ Support strings in integer arguments -- `#218 `_ Watches checks commands rather than final value -- `#220 `_ Better support for calling into redis from Lua -- `#158 `_ Better timestamp handling -- Support for `register_script` function. -- Fixes for race conditions caused by keys expiring mid-command - -It also has new unit tests based on hypothesis_, which has identified many -corner cases that are now handled correctly. - -.. _hypothesis: https://hypothesis.readthedocs.io/en/latest/ - -0.16.0 ------- -- `#224 `_ Add __delitem__ -- Restrict to redis<3 - -0.15.0 ------- -- `#219 `_ Add SAVE, BGSAVE and LASTSAVE commands -- `#222 `_ Fix deprecation warnings in Python 3.7 - -0.14.0 ------- -This release greatly improves support for threads: the bulk of commands are now -thread-safe, ``lock`` has been rewritten to more closely match redis-py, and -pubsub now supports ``run_in_thread``: - -- `#213 `_ pipeline.watch runs transaction even if no commands are queued -- `#214 `_ Added pubsub.run_in_thread as it is implemented in redis-py -- `#215 `_ Keep pace with redis-py for zrevrange method -- `#216 `_ Update behavior of lock to behave closer to redis lock - -0.13.1 ------- -- `#208 `_ eval's KEYS and ARGV are now lua tables -- `#209 `_ Redis operation that returns dict now converted to Lua table when called inside eval operation -- `#212 `_ Optimize ``_scan()`` - -0.13.0.1 --------- -- Fix a typo in the Trove classifiers - -0.13.0 ------- -- `#202 `_ Function smembers returns deepcopy -- `#205 `_ Implemented hstrlen -- `#207 `_ Test on Python 3.7 - -0.12.0 ------- -- `#197 `_ Mock connection error -- `#195 `_ Align bool/len behaviour of pipeline -- `#199 `_ future.types.newbytes does not encode correctly - -0.11.0 ------- -- `#194 `_ Support ``score_cast_func`` in zset functions -- `#192 `_ Make ``__getitem__`` raise a KeyError for missing keys - -0.10.3 ------- -This is a minor bug-fix release. - -- `#189 `_ Add 'System' to the list of libc equivalents - -0.10.2 ------- -This is a bug-fix release. - -- `#181 `_ Upgrade twine & other packaging dependencies -- `#106 `_ randomkey method is not implemented, but is not in the list of unimplemented commands -- `#170 `_ Prefer readthedocs.io instead of readthedocs.org for doc links -- `#180 `_ zadd with no member-score pairs should fail -- `#145 `_ expire / _expire: accept 'long' also as time -- `#182 `_ Pattern matching does not match redis behaviour -- `#135 `_ Scan includes expired keys -- `#185 `_ flushall() doesn't clean everything -- `#186 `_ Fix psubscribe with handlers -- Run CI on PyPy -- Fix coverage measurement - -0.10.1 ------- -This release merges the fakenewsredis_ fork back into fakeredis. The version -number is chosen to be larger than any fakenewsredis release, so version -numbers between the forks are comparable. All the features listed under -fakenewsredis version numbers below are thus included in fakeredis for the -first time in this release. - -Additionally, the following was added: -- `#169 `_ Fix set-bit - -fakenewsredis 0.10.0 --------------------- -- `#14 `_ Add option to create an instance with non-shared data -- `#13 `_ Improve emulation of redis -> Lua returns -- `#12 `_ Update tox.ini: py35/py36 and extras for eval tests -- `#11 `_ Fix typo in private method name - -fakenewsredis 0.9.5 -------------------- -This release makes a start on supporting Lua scripting: -- `#9 `_ Add support for StrictRedis.eval for Lua scripts - -fakenewsredis 0.9.4 -------------------- -This is a minor bugfix and optimization release: -- `#5 `_ Update to match redis-py 2.10.6 -- `#7 `_ Set with invalid expiry time should not set key -- Avoid storing useless expiry times in hashes and sorted sets -- Improve the performance of bulk zadd - -fakenewsredis 0.9.3 -------------------- -This is a minor bugfix release: -- `#6 `_ Fix iteration over pubsub list -- `#3 `_ Preserve expiry time when mutating keys -- Fixes to typos and broken links in documentation - -fakenewsredis 0.9.2 -------------------- -This is the first release of fakenewsredis, based on fakeredis 0.9.0, with the following features and fixes: - -- fakeredis `#78 `_ Behaviour of transaction() does not match redis-py -- fakeredis `#79 `_ Implement redis-py's .lock() -- fakeredis `#90 `_ HINCRBYFLOAT changes hash value type to float -- fakeredis `#101 `_ Should raise an error when attempting to get a key holding a list) -- fakeredis `#146 `_ Pubsub messages and channel names are forced to be ASCII strings on Python 2 -- fakeredis `#163 `_ getset does not to_bytes the value -- fakeredis `#165 `_ linsert implementation is incomplete -- fakeredis `#128 `_ Remove `_ex_keys` mapping -- fakeredis `#139 `_ Fixed all flake8 errors and added flake8 to Travis CI -- fakeredis `#166 `_ Add type checking -- fakeredis `#168 `_ Use repr to encode floats in to_bytes - -.. _fakenewsredis: https://github.com/ska-sa/fakenewsredis -.. _redis-py: http://redis-py.readthedocs.io/ -.. _contributing guide: https://github.com/jamesls/fakeredis/blob/master/CONTRIBUTING.rst -.. _HelpWanted: https://github.com/jamesls/fakeredis/issues?q=is%3Aissue+is%3Aopen+label%3AHelpWanted +**fakeredis has moved** +======================= +Fakeredis has a new maintainer, and can now be found at +https://github.com/dsoftwareinc/fakeredis. diff --git a/fakeredis/__init__.py b/fakeredis/__init__.py deleted file mode 100644 index 7967393..0000000 --- a/fakeredis/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from ._server import FakeServer, FakeRedis, FakeStrictRedis, FakeConnection # noqa: F401 - - -__version__ = '1.7.1' diff --git a/fakeredis/_aioredis1.py b/fakeredis/_aioredis1.py deleted file mode 100644 index 7679f2e..0000000 --- a/fakeredis/_aioredis1.py +++ /dev/null @@ -1,181 +0,0 @@ -import asyncio -import sys -import warnings - -import aioredis - -from . import _async, _server - - -class FakeSocket(_async.AsyncFakeSocket): - def _decode_error(self, error): - return aioredis.ReplyError(error.value) - - -class FakeReader: - """Re-implementation of aioredis.stream.StreamReader. - - It does not use a socket, but instead provides a queue that feeds - `readobj`. - """ - - def __init__(self, socket): - self._socket = socket - - def set_parser(self, parser): - pass # No parser needed, we get already-parsed data - - async def readobj(self): - if self._socket.responses is None: - raise asyncio.CancelledError - result = await self._socket.responses.get() - return result - - def at_eof(self): - return self._socket.responses is None - - def feed_obj(self, obj): - self._queue.put_nowait(obj) - - -class FakeWriter: - """Replaces a StreamWriter for an aioredis connection.""" - - def __init__(self, socket): - self.transport = socket # So that aioredis can call writer.transport.close() - - def write(self, data): - self.transport.sendall(data) - - -class FakeConnectionsPool(aioredis.ConnectionsPool): - def __init__(self, server=None, db=None, password=None, encoding=None, - *, minsize, maxsize, ssl=None, parser=None, - create_connection_timeout=None, - connection_cls=None, - loop=None): - super().__init__('fakeredis', - db=db, - password=password, - encoding=encoding, - minsize=minsize, - maxsize=maxsize, - ssl=ssl, - parser=parser, - create_connection_timeout=create_connection_timeout, - connection_cls=connection_cls, - loop=loop) - if server is None: - server = _server.FakeServer() - self._server = server - - def _create_new_connection(self, address): - # TODO: what does address do here? Might just be for sentinel? - return create_connection(self._server, - db=self._db, - password=self._password, - ssl=self._ssl, - encoding=self._encoding, - parser=self._parser_class, - timeout=self._create_connection_timeout, - connection_cls=self._connection_cls, - ) - - -async def create_connection(server=None, *, db=None, password=None, ssl=None, - encoding=None, parser=None, loop=None, - timeout=None, connection_cls=None): - # This is mostly copied from aioredis.connection.create_connection - if timeout is not None and timeout <= 0: - raise ValueError("Timeout has to be None or a number greater than 0") - - if connection_cls: - assert issubclass(connection_cls, aioredis.abc.AbcConnection),\ - "connection_class does not meet the AbcConnection contract" - cls = connection_cls - else: - cls = aioredis.connection.RedisConnection - - if loop is not None and sys.version_info >= (3, 8, 0): - warnings.warn("The loop argument is deprecated", - DeprecationWarning) - - if server is None: - server = _server.FakeServer() - socket = FakeSocket(server) - reader = FakeReader(socket) - writer = FakeWriter(socket) - conn = cls(reader, writer, encoding=encoding, - address='fakeredis', parser=parser) - - try: - if password is not None: - await conn.auth(password) - if db is not None: - await conn.select(db) - except Exception: - conn.close() - await conn.wait_closed() - raise - return conn - - -async def create_redis(server=None, *, db=None, password=None, ssl=None, - encoding=None, commands_factory=aioredis.Redis, - parser=None, timeout=None, - connection_cls=None, loop=None): - conn = await create_connection(server, db=db, - password=password, - ssl=ssl, - encoding=encoding, - parser=parser, - timeout=timeout, - connection_cls=connection_cls, - loop=loop) - return commands_factory(conn) - - -async def create_pool(server=None, *, db=None, password=None, ssl=None, - encoding=None, minsize=1, maxsize=10, - parser=None, loop=None, create_connection_timeout=None, - pool_cls=None, connection_cls=None): - # Mostly copied from aioredis.pool.create_pool. - if pool_cls: - assert issubclass(pool_cls, aioredis.AbcPool),\ - "pool_class does not meet the AbcPool contract" - cls = pool_cls - else: - cls = FakeConnectionsPool - - pool = cls(server, db, password, encoding, - minsize=minsize, maxsize=maxsize, - ssl=ssl, parser=parser, - create_connection_timeout=create_connection_timeout, - connection_cls=connection_cls, - loop=loop) - try: - await pool._fill_free(override_min=False) - except Exception: - pool.close() - await pool.wait_closed() - raise - return pool - - -async def create_redis_pool(server=None, *, db=None, password=None, ssl=None, - encoding=None, commands_factory=aioredis.Redis, - minsize=1, maxsize=10, parser=None, - timeout=None, pool_cls=None, - connection_cls=None, loop=None): - pool = await create_pool(server, db=db, - password=password, - ssl=ssl, - encoding=encoding, - minsize=minsize, - maxsize=maxsize, - parser=parser, - create_connection_timeout=timeout, - pool_cls=pool_cls, - connection_cls=connection_cls, - loop=loop) - return commands_factory(pool) diff --git a/fakeredis/_aioredis2.py b/fakeredis/_aioredis2.py deleted file mode 100644 index d07d197..0000000 --- a/fakeredis/_aioredis2.py +++ /dev/null @@ -1,170 +0,0 @@ -import asyncio -from typing import Union - -import aioredis - -from . import _async, _server - - -class FakeSocket(_async.AsyncFakeSocket): - _connection_error_class = aioredis.ConnectionError - - def _decode_error(self, error): - return aioredis.connection.BaseParser(1).parse_error(error.value) - - -class FakeReader: - pass - - -class FakeWriter: - def __init__(self, socket: FakeSocket) -> None: - self._socket = socket - - def close(self): - self._socket = None - - async def wait_closed(self): - pass - - async def drain(self): - pass - - def writelines(self, data): - for chunk in data: - self._socket.sendall(chunk) - - -class FakeConnection(aioredis.Connection): - def __init__(self, *args, **kwargs): - self._server = kwargs.pop('server') - self._sock = None - super().__init__(*args, **kwargs) - - async def _connect(self): - if not self._server.connected: - raise aioredis.ConnectionError(_server.CONNECTION_ERROR_MSG) - self._sock = FakeSocket(self._server) - self._reader = FakeReader() - self._writer = FakeWriter(self._sock) - - async def disconnect(self): - await super().disconnect() - self._sock = None - - async def can_read(self, timeout: float = 0): - if not self.is_connected: - await self.connect() - if timeout == 0: - return not self._sock.responses.empty() - # asyncio.Queue doesn't have a way to wait for the queue to be - # non-empty without consuming an item, so kludge it with a sleep/poll - # loop. - loop = asyncio.get_event_loop() - start = loop.time() - while True: - if not self._sock.responses.empty(): - return True - await asyncio.sleep(0.01) - now = loop.time() - if timeout is not None and now > start + timeout: - return False - - def _decode(self, response): - if isinstance(response, list): - return [self._decode(item) for item in response] - elif isinstance(response, bytes): - return self.encoder.decode(response) - else: - return response - - async def read_response(self): - if not self._server.connected: - try: - response = self._sock.responses.get_nowait() - except asyncio.QueueEmpty: - raise aioredis.ConnectionError(_server.CONNECTION_ERROR_MSG) - else: - response = await self._sock.responses.get() - if isinstance(response, aioredis.ResponseError): - raise response - return self._decode(response) - - def repr_pieces(self): - pieces = [ - ('server', self._server), - ('db', self.db) - ] - if self.client_name: - pieces.append(('client_name', self.client_name)) - return pieces - - -class FakeRedis(aioredis.Redis): - def __init__( - self, - *, - db: Union[str, int] = 0, - password: str = None, - socket_timeout: float = None, - connection_pool: aioredis.ConnectionPool = None, - encoding: str = "utf-8", - encoding_errors: str = "strict", - decode_responses: bool = False, - retry_on_timeout: bool = False, - max_connections: int = None, - health_check_interval: int = 0, - client_name: str = None, - username: str = None, - server: _server.FakeServer = None, - connected: bool = True, - **kwargs - ): - if not connection_pool: - # Adapted from aioredis - if server is None: - server = _server.FakeServer() - server.connected = connected - connection_kwargs = { - "db": db, - "username": username, - "password": password, - "socket_timeout": socket_timeout, - "encoding": encoding, - "encoding_errors": encoding_errors, - "decode_responses": decode_responses, - "retry_on_timeout": retry_on_timeout, - "max_connections": max_connections, - "health_check_interval": health_check_interval, - "client_name": client_name, - "server": server, - "connection_class": FakeConnection - } - connection_pool = aioredis.ConnectionPool(**connection_kwargs) - super().__init__( - db=db, - password=password, - socket_timeout=socket_timeout, - connection_pool=connection_pool, - encoding=encoding, - encoding_errors=encoding_errors, - decode_responses=decode_responses, - retry_on_timeout=retry_on_timeout, - max_connections=max_connections, - health_check_interval=health_check_interval, - client_name=client_name, - username=username, - **kwargs - ) - - @classmethod - def from_url(https://github.com/cls,%20url:%20str,%20**kwargs): - server = kwargs.pop('server', None) - if server is None: - server = _server.FakeServer() - self = super().from_url(https://github.com/url,%20**kwargs) - # Now override how it creates connections - pool = self.connection_pool - pool.connection_class = FakeConnection - pool.connection_kwargs['server'] = server - return self diff --git a/fakeredis/_async.py b/fakeredis/_async.py deleted file mode 100644 index ec51d1e..0000000 --- a/fakeredis/_async.py +++ /dev/null @@ -1,51 +0,0 @@ -import asyncio - -import async_timeout - -from . import _server - - -class AsyncFakeSocket(_server.FakeSocket): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.responses = asyncio.Queue() - - def put_response(self, msg): - self.responses.put_nowait(msg) - - async def _async_blocking(self, timeout, func, event, callback): - try: - result = None - with async_timeout.timeout(timeout if timeout else None): - while True: - await event.wait() - event.clear() - # This is a coroutine outside the normal control flow that - # locks the server, so we have to take our own lock. - with self._server.lock: - ret = func(False) - if ret is not None: - result = self._decode_result(ret) - self.put_response(result) - break - except asyncio.TimeoutError: - result = None - finally: - with self._server.lock: - self._db.remove_change_callback(callback) - self.put_response(result) - self.resume() - - def _blocking(self, timeout, func): - loop = asyncio.get_event_loop() - ret = func(True) - if ret is not None or self._in_transaction: - return ret - event = asyncio.Event() - - def callback(): - loop.call_soon_threadsafe(event.set) - self._db.add_change_callback(callback) - self.pause() - loop.create_task(self._async_blocking(timeout, func, event, callback)) - return _server.NoResponse() diff --git a/fakeredis/_server.py b/fakeredis/_server.py deleted file mode 100644 index dbeb438..0000000 --- a/fakeredis/_server.py +++ /dev/null @@ -1,2850 +0,0 @@ -import functools -import hashlib -import inspect -import itertools -import logging -import math -import pickle -import queue -import random -import re -import threading -import time -import warnings -import weakref -from collections import defaultdict -from collections.abc import MutableMapping - -import redis -import six - -from ._zset import ZSet - -LOGGER = logging.getLogger('fakeredis') -REDIS_LOG_LEVELS = { - b'LOG_DEBUG': 0, - b'LOG_VERBOSE': 1, - b'LOG_NOTICE': 2, - b'LOG_WARNING': 3 -} -REDIS_LOG_LEVELS_TO_LOGGING = { - 0: logging.DEBUG, - 1: logging.INFO, - 2: logging.INFO, - 3: logging.WARNING -} - -MAX_STRING_SIZE = 512 * 1024 * 1024 - -INVALID_EXPIRE_MSG = "ERR invalid expire time in {}" -WRONGTYPE_MSG = \ - "WRONGTYPE Operation against a key holding the wrong kind of value" -SYNTAX_ERROR_MSG = "ERR syntax error" -INVALID_INT_MSG = "ERR value is not an integer or out of range" -INVALID_FLOAT_MSG = "ERR value is not a valid float" -INVALID_OFFSET_MSG = "ERR offset is out of range" -INVALID_BIT_OFFSET_MSG = "ERR bit offset is not an integer or out of range" -INVALID_BIT_VALUE_MSG = "ERR bit is not an integer or out of range" -INVALID_DB_MSG = "ERR DB index is out of range" -INVALID_MIN_MAX_FLOAT_MSG = "ERR min or max is not a float" -INVALID_MIN_MAX_STR_MSG = "ERR min or max not a valid string range item" -STRING_OVERFLOW_MSG = "ERR string exceeds maximum allowed size (512MB)" -OVERFLOW_MSG = "ERR increment or decrement would overflow" -NONFINITE_MSG = "ERR increment would produce NaN or Infinity" -SCORE_NAN_MSG = "ERR resulting score is not a number (NaN)" -INVALID_SORT_FLOAT_MSG = "ERR One or more scores can't be converted into double" -SRC_DST_SAME_MSG = "ERR source and destination objects are the same" -NO_KEY_MSG = "ERR no such key" -INDEX_ERROR_MSG = "ERR index out of range" -ZADD_NX_XX_ERROR_MSG = "ERR ZADD allows either 'nx' or 'xx', not both" -ZADD_INCR_LEN_ERROR_MSG = "ERR INCR option supports a single increment-element pair" -ZUNIONSTORE_KEYS_MSG = "ERR at least 1 input key is needed for ZUNIONSTORE/ZINTERSTORE" -WRONG_ARGS_MSG = "ERR wrong number of arguments for '{}' command" -UNKNOWN_COMMAND_MSG = "ERR unknown command '{}'" -EXECABORT_MSG = "EXECABORT Transaction discarded because of previous errors." -MULTI_NESTED_MSG = "ERR MULTI calls can not be nested" -WITHOUT_MULTI_MSG = "ERR {0} without MULTI" -WATCH_INSIDE_MULTI_MSG = "ERR WATCH inside MULTI is not allowed" -NEGATIVE_KEYS_MSG = "ERR Number of keys can't be negative" -TOO_MANY_KEYS_MSG = "ERR Number of keys can't be greater than number of args" -TIMEOUT_NEGATIVE_MSG = "ERR timeout is negative" -NO_MATCHING_SCRIPT_MSG = "NOSCRIPT No matching script. Please use EVAL." -GLOBAL_VARIABLE_MSG = "ERR Script attempted to set global variables: {}" -COMMAND_IN_SCRIPT_MSG = "ERR This Redis command is not allowed from scripts" -BAD_SUBCOMMAND_MSG = "ERR Unknown {} subcommand or wrong # of args." -BAD_COMMAND_IN_PUBSUB_MSG = \ - "ERR only (P)SUBSCRIBE / (P)UNSUBSCRIBE / PING / QUIT allowed in this context" -CONNECTION_ERROR_MSG = "FakeRedis is emulating a connection error." -REQUIRES_MORE_ARGS_MSG = "ERR {} requires {} arguments or more." -LOG_INVALID_DEBUG_LEVEL_MSG = "ERR Invalid debug level." -LUA_COMMAND_ARG_MSG = "ERR Lua redis() command arguments must be strings or integers" -LUA_WRONG_NUMBER_ARGS_MSG = "ERR wrong number or type of arguments" -SCRIPT_ERROR_MSG = "ERR Error running script (call to f_{}): @user_script:?: {}" -RESTORE_KEY_EXISTS = "BUSYKEY Target key name already exists." -RESTORE_INVALID_CHECKSUM_MSG = "ERR DUMP payload version or checksum are wrong" -RESTORE_INVALID_TTL_MSG = "ERR Invalid TTL value, must be >= 0" - -FLAG_NO_SCRIPT = 's' # Command not allowed in scripts - -# This needs to be grabbed early to avoid breaking tests that mock redis.Redis. -_ORIG_SIG = inspect.signature(redis.Redis) - - -class SimpleString: - def __init__(self, value): - assert isinstance(value, bytes) - self.value = value - - -class SimpleError(Exception): - """Exception that will be turned into a frontend-specific exception.""" - - def __init__(self, value): - assert isinstance(value, str) - self.value = value - - -class NoResponse: - """Returned by pub/sub commands to indicate that no response should be returned""" - pass - - -OK = SimpleString(b'OK') -QUEUED = SimpleString(b'QUEUED') -PONG = SimpleString(b'PONG') -BGSAVE_STARTED = SimpleString(b'Background saving started') - - -def null_terminate(s): - # Redis uses C functions on some strings, which means they stop at the - # first NULL. - if b'\0' in s: - return s[:s.find(b'\0')] - return s - - -def casenorm(s): - return null_terminate(s).lower() - - -def casematch(a, b): - return casenorm(a) == casenorm(b) - - -def compile_pattern(pattern): - """Compile a glob pattern (e.g. for keys) to a bytes regex. - - fnmatch.fnmatchcase doesn't work for this, because it uses different - escaping rules to redis, uses ! instead of ^ to negate a character set, - and handles invalid cases (such as a [ without a ]) differently. This - implementation was written by studying the redis implementation. - """ - # It's easier to work with text than bytes, because indexing bytes - # doesn't behave the same in Python 3. Latin-1 will round-trip safely. - pattern = pattern.decode('latin-1') - parts = ['^'] - i = 0 - L = len(pattern) - while i < L: - c = pattern[i] - i += 1 - if c == '?': - parts.append('.') - elif c == '*': - parts.append('.*') - elif c == '\\': - if i == L: - i -= 1 - parts.append(re.escape(pattern[i])) - i += 1 - elif c == '[': - parts.append('[') - if i < L and pattern[i] == '^': - i += 1 - parts.append('^') - parts_len = len(parts) # To detect if anything was added - while i < L: - if pattern[i] == '\\' and i + 1 < L: - i += 1 - parts.append(re.escape(pattern[i])) - elif pattern[i] == ']': - i += 1 - break - elif i + 2 < L and pattern[i + 1] == '-': - start = pattern[i] - end = pattern[i + 2] - if start > end: - start, end = end, start - parts.append(re.escape(start) + '-' + re.escape(end)) - i += 2 - else: - parts.append(re.escape(pattern[i])) - i += 1 - if len(parts) == parts_len: - if parts[-1] == '[': - # Empty group - will never match - parts[-1] = '(?:$.)' - else: - # Negated empty group - matches any character - assert parts[-1] == '^' - parts.pop() - parts[-1] = '.' - else: - parts.append(']') - else: - parts.append(re.escape(c)) - parts.append('\\Z') - regex = ''.join(parts).encode('latin-1') - return re.compile(regex, re.S) - - -class Item: - """An item stored in the database""" - - __slots__ = ['value', 'expireat'] - - def __init__(self, value): - self.value = value - self.expireat = None - - -class CommandItem: - """An item referenced by a command. - - It wraps an Item but has extra fields to manage updates and notifications. - """ - def __init__(self, key, db, item=None, default=None): - if item is None: - self._value = default - self._expireat = None - else: - self._value = item.value - self._expireat = item.expireat - self.key = key - self.db = db - self._modified = False - self._expireat_modified = False - - @property - def value(self): - return self._value - - @value.setter - def value(self, new_value): - self._value = new_value - self._modified = True - self.expireat = None - - @property - def expireat(self): - return self._expireat - - @expireat.setter - def expireat(self, value): - self._expireat = value - self._expireat_modified = True - self._modified = True # Since redis 6.0.7 - - def get(self, default): - return self._value if self else default - - def update(self, new_value): - self._value = new_value - self._modified = True - - def updated(self): - self._modified = True - - def writeback(self): - if self._modified: - self.db.notify_watch(self.key) - if not isinstance(self.value, bytes) and not self.value: - self.db.pop(self.key, None) - return - else: - item = self.db.setdefault(self.key, Item(None)) - item.value = self.value - item.expireat = self.expireat - elif self._expireat_modified and self.key in self.db: - self.db[self.key].expireat = self.expireat - - def __bool__(self): - return bool(self._value) or isinstance(self._value, bytes) - - __nonzero__ = __bool__ # For Python 2 - - -class Database(MutableMapping): - def __init__(self, lock, *args, **kwargs): - self._dict = dict(*args, **kwargs) - self.time = 0.0 - self._watches = defaultdict(weakref.WeakSet) # key to set of connections - self.condition = threading.Condition(lock) - self._change_callbacks = set() - - def swap(self, other): - self._dict, other._dict = other._dict, self._dict - self.time, other.time = other.time, self.time - - def notify_watch(self, key): - for sock in self._watches.get(key, set()): - sock.notify_watch() - self.condition.notify_all() - for callback in self._change_callbacks: - callback() - - def add_watch(self, key, sock): - self._watches[key].add(sock) - - def remove_watch(self, key, sock): - watches = self._watches[key] - watches.discard(sock) - if not watches: - del self._watches[key] - - def add_change_callback(self, callback): - self._change_callbacks.add(callback) - - def remove_change_callback(self, callback): - self._change_callbacks.remove(callback) - - def clear(self): - for key in self: - self.notify_watch(key) - self._dict.clear() - - def expired(self, item): - return item.expireat is not None and item.expireat < self.time - - def _remove_expired(self): - for key in list(self._dict): - item = self._dict[key] - if self.expired(item): - del self._dict[key] - - def __getitem__(self, key): - item = self._dict[key] - if self.expired(item): - del self._dict[key] - raise KeyError(key) - return item - - def __setitem__(self, key, value): - self._dict[key] = value - - def __delitem__(self, key): - del self._dict[key] - - def __iter__(self): - self._remove_expired() - return iter(self._dict) - - def __len__(self): - self._remove_expired() - return len(self._dict) - - def __hash__(self): - return hash(super(object, self)) - - def __eq__(self, other): - return super(object, self) == other - - -class Hash(dict): - redis_type = b'hash' - - -class Int: - """Argument converter for 64-bit signed integers""" - - DECODE_ERROR = INVALID_INT_MSG - ENCODE_ERROR = OVERFLOW_MSG - MIN_VALUE = -2**63 - MAX_VALUE = 2**63 - 1 - - @classmethod - def valid(cls, value): - return cls.MIN_VALUE <= value <= cls.MAX_VALUE - - @classmethod - def decode(cls, value): - try: - out = int(value) - if not cls.valid(out) or str(out).encode() != value: - raise ValueError - except ValueError: - raise SimpleError(cls.DECODE_ERROR) - return out - - @classmethod - def encode(cls, value): - if cls.valid(value): - return str(value).encode() - else: - raise SimpleError(cls.ENCODE_ERROR) - - -class BitOffset(Int): - """Argument converter for unsigned bit positions""" - - DECODE_ERROR = INVALID_BIT_OFFSET_MSG - MIN_VALUE = 0 - MAX_VALUE = 8 * MAX_STRING_SIZE - 1 # Redis imposes 512MB limit on keys - - -class BitValue(Int): - DECODE_ERROR = INVALID_BIT_VALUE_MSG - MIN_VALUE = 0 - MAX_VALUE = 1 - - -class DbIndex(Int): - """Argument converter for database indices""" - - DECODE_ERROR = INVALID_DB_MSG - MIN_VALUE = 0 - MAX_VALUE = 15 - - -class Timeout(Int): - """Argument converter for timeouts""" - - DECODE_ERROR = TIMEOUT_NEGATIVE_MSG - MIN_VALUE = 0 - - -class Float: - """Argument converter for floating-point values. - - Redis uses long double for some cases (INCRBYFLOAT, HINCRBYFLOAT) - and double for others (zset scores), but Python doesn't support - long double. - """ - - DECODE_ERROR = INVALID_FLOAT_MSG - - @classmethod - def decode(cls, value, - allow_leading_whitespace=False, - allow_erange=False, - allow_empty=False, - crop_null=False): - # redis has some quirks in float parsing, with several variants. - # See https://github.com/antirez/redis/issues/5706 - try: - if crop_null: - value = null_terminate(value) - if allow_empty and value == b'': - value = b'0.0' - if not allow_leading_whitespace and value[:1].isspace(): - raise ValueError - if value[-1:].isspace(): - raise ValueError - out = float(value) - if math.isnan(out): - raise ValueError - if not allow_erange: - # Values that over- or underflow- are explicitly rejected by - # redis. This is a crude hack to determine whether the input - # may have been such a value. - if out in (math.inf, -math.inf, 0.0) and re.match(b'^[^a-zA-Z]*[1-9]', value): - raise ValueError - return out - except ValueError: - raise SimpleError(cls.DECODE_ERROR) - - @classmethod - def encode(cls, value, humanfriendly): - if math.isinf(value): - return str(value).encode() - elif humanfriendly: - # Algorithm from ld2string in redis - out = '{:.17f}'.format(value) - out = re.sub(r'(?:\.)?0+$', '', out) - return out.encode() - else: - return '{:.17g}'.format(value).encode() - - -class SortFloat(Float): - DECODE_ERROR = INVALID_SORT_FLOAT_MSG - - @classmethod - def decode(cls, value): - return super().decode( - value, allow_leading_whitespace=True, allow_empty=True, crop_null=True) - - -class ScoreTest: - """Argument converter for sorted set score endpoints.""" - def __init__(self, value, exclusive=False): - self.value = value - self.exclusive = exclusive - - @classmethod - def decode(cls, value): - try: - exclusive = False - if value[:1] == b'(': - exclusive = True - value = value[1:] - value = Float.decode( - value, allow_leading_whitespace=True, allow_erange=True, - allow_empty=True, crop_null=True) - return cls(value, exclusive) - except SimpleError: - raise SimpleError(INVALID_MIN_MAX_FLOAT_MSG) - - def __str__(self): - if self.exclusive: - return '({!r}'.format(self.value) - else: - return repr(self.value) - - @property - def lower_bound(self): - return (self.value, AfterAny() if self.exclusive else BeforeAny()) - - @property - def upper_bound(self): - return (self.value, BeforeAny() if self.exclusive else AfterAny()) - - -class StringTest: - """Argument converter for sorted set LEX endpoints.""" - def __init__(self, value, exclusive): - self.value = value - self.exclusive = exclusive - - @classmethod - def decode(cls, value): - if value == b'-': - return cls(BeforeAny(), True) - elif value == b'+': - return cls(AfterAny(), True) - elif value[:1] == b'(': - return cls(value[1:], True) - elif value[:1] == b'[': - return cls(value[1:], False) - else: - raise SimpleError(INVALID_MIN_MAX_STR_MSG) - - -@functools.total_ordering -class BeforeAny: - def __gt__(self, other): - return False - - def __eq__(self, other): - return isinstance(other, BeforeAny) - - -@functools.total_ordering -class AfterAny: - def __lt__(self, other): - return False - - def __eq__(self, other): - return isinstance(other, AfterAny) - - -class Key: - """Marker to indicate that argument in signature is a key""" - UNSPECIFIED = object() - - def __init__(self, type_=None, missing_return=UNSPECIFIED): - self.type_ = type_ - self.missing_return = missing_return - - -class Signature: - def __init__(self, name, fixed, repeat=(), flags=""): - self.name = name - self.fixed = fixed - self.repeat = repeat - self.flags = flags - - def check_arity(self, args): - if len(args) != len(self.fixed): - delta = len(args) - len(self.fixed) - if delta < 0 or not self.repeat: - raise SimpleError(WRONG_ARGS_MSG.format(self.name)) - - def apply(self, args, db): - """Returns a tuple, which is either: - - transformed args and a dict of CommandItems; or - - a single containing a short-circuit return value - """ - self.check_arity(args) - if self.repeat: - delta = len(args) - len(self.fixed) - if delta % len(self.repeat) != 0: - raise SimpleError(WRONG_ARGS_MSG.format(self.name)) - - types = list(self.fixed) - for i in range(len(args) - len(types)): - types.append(self.repeat[i % len(self.repeat)]) - - args = list(args) - # First pass: convert/validate non-keys, and short-circuit on missing keys - for i, (arg, type_) in enumerate(zip(args, types)): - if isinstance(type_, Key): - if type_.missing_return is not Key.UNSPECIFIED and arg not in db: - return (type_.missing_return,) - elif type_ != bytes: - args[i] = type_.decode(args[i]) - - # Second pass: read keys and check their types - command_items = [] - for i, (arg, type_) in enumerate(zip(args, types)): - if isinstance(type_, Key): - item = db.get(arg) - default = None - if type_.type_ is not None: - if item is not None and type(item.value) != type_.type_: - raise SimpleError(WRONGTYPE_MSG) - if item is None: - if type_.type_ is not bytes: - default = type_.type_() - args[i] = CommandItem(arg, db, item, default=default) - command_items.append(args[i]) - - return args, command_items - - -def valid_response_type(value, nested=False): - if isinstance(value, NoResponse) and not nested: - return True - if value is not None and not isinstance(value, (bytes, SimpleString, SimpleError, - int, list)): - return False - if isinstance(value, list): - if any(not valid_response_type(item, True) for item in value): - return False - return True - - -def command(*args, **kwargs): - def decorator(func): - name = kwargs.pop('name', func.__name__) - func._fakeredis_sig = Signature(name, *args, **kwargs) - return func - - return decorator - - -class FakeServer: - def __init__(self): - self.lock = threading.Lock() - self.dbs = defaultdict(lambda: Database(self.lock)) - # Maps SHA1 to script source - self.script_cache = {} - # Maps channel/pattern to weak set of sockets - self.subscribers = defaultdict(weakref.WeakSet) - self.psubscribers = defaultdict(weakref.WeakSet) - self.lastsave = int(time.time()) - self.connected = True - # List of weakrefs to sockets that are being closed lazily - self.closed_sockets = [] - - -class FakeSocket: - _connection_error_class = redis.ConnectionError - - def __init__(self, server): - self._server = server - self._db = server.dbs[0] - self._db_num = 0 - # When in a MULTI, set to a list of function calls - self._transaction = None - self._transaction_failed = False - # Set when executing the commands from EXEC - self._in_transaction = False - self._watch_notified = False - self._watches = set() - self._pubsub = 0 # Count of subscriptions - self.responses = queue.Queue() - # Prevents parser from processing commands. Not used in this module, - # but set by aioredis module to prevent new commands being processed - # while handling a blocking command. - self._paused = False - self._parser = self._parse_commands() - self._parser.send(None) - - def put_response(self, msg): - # redis.Connection.__del__ might call self.close at any time, which - # will set self.responses to None. We assume this will happen - # atomically, and the code below then protects us against this. - responses = self.responses - if responses: - responses.put(msg) - - def pause(self): - self._paused = True - - def resume(self): - self._paused = False - self._parser.send(b'') - - def shutdown(self, flags): - self._parser.close() - - def fileno(self): - # Our fake socket must return an integer from `FakeSocket.fileno()` since a real selector - # will be created. The value does not matter since we replace the selector with our own - # `FakeSelector` before it is ever used. - return 0 - - def _cleanup(self, server): - """Remove all the references to `self` from `server`. - - This is called with the server lock held, but it may be some time after - self.close. - """ - for subs in server.subscribers.values(): - subs.discard(self) - for subs in server.psubscribers.values(): - subs.discard(self) - self._clear_watches() - - def close(self): - # Mark ourselves for cleanup. This might be called from - # redis.Connection.__del__, which the garbage collection could call - # at any time, and hence we can't safely take the server lock. - # We rely on list.append being atomic. - self._server.closed_sockets.append(weakref.ref(self)) - self._server = None - self._db = None - self.responses = None - - @staticmethod - def _extract_line(buf): - pos = buf.find(b'\n') + 1 - assert pos > 0 - line = buf[:pos] - buf = buf[pos:] - assert line.endswith(b'\r\n') - return line, buf - - def _parse_commands(self): - """Generator that parses commands. - - It is fed pieces of redis protocol data (via `send`) and calls - `_process_command` whenever it has a complete one. - """ - buf = b'' - while True: - while self._paused or b'\n' not in buf: - buf += yield - line, buf = self._extract_line(buf) - assert line[:1] == b'*' # array - n_fields = int(line[1:-2]) - fields = [] - for i in range(n_fields): - while b'\n' not in buf: - buf += yield - line, buf = self._extract_line(buf) - assert line[:1] == b'$' # string - length = int(line[1:-2]) - while len(buf) < length + 2: - buf += yield - fields.append(buf[:length]) - buf = buf[length+2:] # +2 to skip the CRLF - self._process_command(fields) - - def _run_command(self, func, sig, args, from_script): - command_items = {} - try: - ret = sig.apply(args, self._db) - if len(ret) == 1: - result = ret[0] - else: - args, command_items = ret - if from_script and FLAG_NO_SCRIPT in sig.flags: - raise SimpleError(COMMAND_IN_SCRIPT_MSG) - if self._pubsub and sig.name not in [ - 'ping', 'subscribe', 'unsubscribe', - 'psubscribe', 'punsubscribe', 'quit']: - raise SimpleError(BAD_COMMAND_IN_PUBSUB_MSG) - result = func(*args) - assert valid_response_type(result) - except SimpleError as exc: - result = exc - for command_item in command_items: - command_item.writeback() - return result - - def _decode_error(self, error): - return redis.connection.BaseParser().parse_error(error.value) - - def _decode_result(self, result): - """Convert SimpleString and SimpleError, recursively""" - if isinstance(result, list): - return [self._decode_result(r) for r in result] - elif isinstance(result, SimpleString): - return result.value - elif isinstance(result, SimpleError): - return self._decode_error(result) - else: - return result - - def _blocking(self, timeout, func): - """Run a function until it succeeds or timeout is reached. - - The timeout must be an integer, and 0 means infinite. The function - is called with a boolean to indicate whether this is the first call. - If it returns None it is considered to have "failed" and is retried - each time the condition variable is notified, until the timeout is - reached. - - Returns the function return value, or None if the timeout was reached. - """ - ret = func(True) - if ret is not None or self._in_transaction: - return ret - if timeout: - deadline = time.time() + timeout - else: - deadline = None - while True: - timeout = deadline - time.time() if deadline is not None else None - if timeout is not None and timeout <= 0: - return None - # Python <3.2 doesn't return a status from wait. On Python 3.2+ - # we bail out early on False. - if self._db.condition.wait(timeout=timeout) is False: - return None # Timeout expired - ret = func(False) - if ret is not None: - return ret - - def _name_to_func(self, name): - name = six.ensure_str(name, encoding='utf-8', errors='replace') - func_name = name.lower() - func = getattr(self, func_name, None) - if name.startswith('_') or not func or not hasattr(func, '_fakeredis_sig'): - # redis remaps \r or \n in an error to ' ' to make it legal protocol - clean_name = name.replace('\r', ' ').replace('\n', ' ') - raise SimpleError(UNKNOWN_COMMAND_MSG.format(clean_name)) - return func, func_name - - def sendall(self, data): - if not self._server.connected: - raise self._connection_error_class(CONNECTION_ERROR_MSG) - if isinstance(data, str): - data = data.encode('ascii') - self._parser.send(data) - - def _process_command(self, fields): - if not fields: - return - func_name = None - try: - func, func_name = self._name_to_func(fields[0]) - sig = func._fakeredis_sig - with self._server.lock: - # Clean out old connections - while True: - try: - weak_sock = self._server.closed_sockets.pop() - except IndexError: - break - else: - sock = weak_sock() - if sock: - sock._cleanup(self._server) - now = time.time() - for db in self._server.dbs.values(): - db.time = now - sig.check_arity(fields[1:]) - # TODO: make a signature attribute for transactions - if self._transaction is not None \ - and func_name not in ('exec', 'discard', 'multi', 'watch'): - self._transaction.append((func, sig, fields[1:])) - result = QUEUED - else: - result = self._run_command(func, sig, fields[1:], False) - except SimpleError as exc: - if self._transaction is not None: - # TODO: should not apply if the exception is from _run_command - # e.g. watch inside multi - self._transaction_failed = True - if func_name == 'exec' and exc.value.startswith('ERR '): - exc.value = 'EXECABORT Transaction discarded because of: ' + exc.value[4:] - self._transaction = None - self._transaction_failed = False - self._clear_watches() - result = exc - result = self._decode_result(result) - if not isinstance(result, NoResponse): - self.put_response(result) - - def notify_watch(self): - self._watch_notified = True - - # redis has inconsistent handling of negative indices, hence two versions - # of this code. - - @staticmethod - def _fix_range_string(start, end, length): - # Negative number handling is based on the redis source code - if start < 0 and end < 0 and start > end: - return -1, -1 - if start < 0: - start = max(0, start + length) - if end < 0: - end = max(0, end + length) - end = min(end, length - 1) - return start, end + 1 - - @staticmethod - def _fix_range(start, end, length): - # Redis handles negative slightly differently for zrange - if start < 0: - start = max(0, start + length) - if end < 0: - end += length - if start > end or start >= length: - return -1, -1 - end = min(end, length - 1) - return start, end + 1 - - def _scan(self, keys, cursor, *args): - """ - This is the basis of most of the ``scan`` methods. - - This implementation is KNOWN to be un-performant, as it requires - grabbing the full set of keys over which we are investigating subsets. - - It also doesn't adhere to the guarantee that every key will be iterated - at least once even if the database is modified during the scan. - However, provided the database is not modified, every key will be - returned exactly once. - """ - pattern = None - type = None - count = 10 - if len(args) % 2 != 0: - raise SimpleError(SYNTAX_ERROR_MSG) - for i in range(0, len(args), 2): - if casematch(args[i], b'match'): - pattern = args[i + 1] - elif casematch(args[i], b'count'): - count = Int.decode(args[i + 1]) - if count <= 0: - raise SimpleError(SYNTAX_ERROR_MSG) - elif casematch(args[i], b'type'): - type = args[i + 1] - else: - raise SimpleError(SYNTAX_ERROR_MSG) - - if cursor >= len(keys): - return [0, []] - data = sorted(keys) - result_cursor = cursor + count - result_data = [] - - regex = compile_pattern(pattern) if pattern is not None else None - - def match_key(key): - return regex.match(key) if pattern is not None else True - - def match_type(key): - if type is not None: - return casematch(self.type(self._db[key]).value, type) - return True - - if pattern is not None or type is not None: - for val in itertools.islice(data, cursor, result_cursor): - compare_val = val[0] if isinstance(val, tuple) else val - if match_key(compare_val) and match_type(compare_val): - result_data.append(val) - else: - result_data = data[cursor:result_cursor] - - if result_cursor >= len(data): - result_cursor = 0 - return [result_cursor, result_data] - - # Connection commands - # TODO: auth, quit - - @command((bytes,)) - def echo(self, message): - return message - - @command((), (bytes,)) - def ping(self, *args): - if len(args) > 1: - raise SimpleError(WRONG_ARGS_MSG.format('ping')) - if self._pubsub: - return [b'pong', args[0] if args else b''] - else: - return args[0] if args else PONG - - @command((DbIndex,)) - def select(self, index): - self._db = self._server.dbs[index] - self._db_num = index - return OK - - @command((DbIndex, DbIndex)) - def swapdb(self, index1, index2): - if index1 != index2: - db1 = self._server.dbs[index1] - db2 = self._server.dbs[index2] - db1.swap(db2) - return OK - - # Key commands - # TODO: lots - - def _delete(self, *keys): - ans = 0 - done = set() - for key in keys: - if key and key.key not in done: - key.value = None - done.add(key.key) - ans += 1 - return ans - - @command((Key(),), (Key(),), name='del') - def del_(self, *keys): - return self._delete(*keys) - - @command((Key(),), (Key(),), name='unlink') - def unlink(self, *keys): - return self._delete(*keys) - - @command((Key(),), (Key(),)) - def exists(self, *keys): - ret = 0 - for key in keys: - if key: - ret += 1 - return ret - - def _expireat(self, key, timestamp): - if not key: - return 0 - else: - key.expireat = timestamp - return 1 - - def _ttl(self, key, scale): - if not key: - return -2 - elif key.expireat is None: - return -1 - else: - return int(round((key.expireat - self._db.time) * scale)) - - @command((Key(), Int)) - def expire(self, key, seconds): - return self._expireat(key, self._db.time + seconds) - - @command((Key(), Int)) - def expireat(self, key, timestamp): - return self._expireat(key, float(timestamp)) - - @command((Key(), Int)) - def pexpire(self, key, ms): - return self._expireat(key, self._db.time + ms / 1000.0) - - @command((Key(), Int)) - def pexpireat(self, key, ms_timestamp): - return self._expireat(key, ms_timestamp / 1000.0) - - @command((Key(),)) - def ttl(self, key): - return self._ttl(key, 1.0) - - @command((Key(),)) - def pttl(self, key): - return self._ttl(key, 1000.0) - - @command((Key(),)) - def type(self, key): - if key.value is None: - return SimpleString(b'none') - elif isinstance(key.value, bytes): - return SimpleString(b'string') - elif isinstance(key.value, list): - return SimpleString(b'list') - elif isinstance(key.value, set): - return SimpleString(b'set') - elif isinstance(key.value, ZSet): - return SimpleString(b'zset') - elif isinstance(key.value, dict): - return SimpleString(b'hash') - else: - assert False # pragma: nocover - - @command((Key(),)) - def persist(self, key): - if key.expireat is None: - return 0 - key.expireat = None - return 1 - - @command((bytes,)) - def keys(self, pattern): - if pattern == b'*': - return list(self._db) - else: - regex = compile_pattern(pattern) - return [key for key in self._db if regex.match(key)] - - @command((Key(), DbIndex)) - def move(self, key, db): - if db == self._db_num: - raise SimpleError(SRC_DST_SAME_MSG) - if not key or key.key in self._server.dbs[db]: - return 0 - # TODO: what is the interaction with expiry? - self._server.dbs[db][key.key] = self._server.dbs[self._db_num][key.key] - key.value = None # Causes deletion - return 1 - - @command(()) - def randomkey(self): - keys = list(self._db.keys()) - if not keys: - return None - return random.choice(keys) - - @command((Key(), Key())) - def rename(self, key, newkey): - if not key: - raise SimpleError(NO_KEY_MSG) - # TODO: check interaction with WATCH - if newkey.key != key.key: - newkey.value = key.value - newkey.expireat = key.expireat - key.value = None - return OK - - @command((Key(), Key())) - def renamenx(self, key, newkey): - if not key: - raise SimpleError(NO_KEY_MSG) - if newkey: - return 0 - self.rename(key, newkey) - return 1 - - @command((Int,), (bytes, bytes)) - def scan(self, cursor, *args): - return self._scan(list(self._db), cursor, *args) - - def _lookup_key(self, key, pattern): - """Python implementation of lookupKeyByPattern from redis""" - if pattern == b'#': - return key - p = pattern.find(b'*') - if p == -1: - return None - prefix = pattern[:p] - suffix = pattern[p+1:] - arrow = suffix.find(b'->', 0, -1) - if arrow != -1: - field = suffix[arrow+2:] - suffix = suffix[:arrow] - else: - field = None - new_key = prefix + key + suffix - item = CommandItem(new_key, self._db, item=self._db.get(new_key)) - if item.value is None: - return None - if field is not None: - if not isinstance(item.value, dict): - return None - return item.value.get(field) - else: - if not isinstance(item.value, bytes): - return None - return item.value - - @command((Key(),), (bytes,)) - def sort(self, key, *args): - i = 0 - desc = False - alpha = False - limit_start = 0 - limit_count = -1 - store = None - sortby = None - dontsort = False - get = [] - if key.value is not None: - if not isinstance(key.value, (set, list, ZSet)): - raise SimpleError(WRONGTYPE_MSG) - - while i < len(args): - arg = args[i] - if casematch(arg, b'asc'): - desc = False - elif casematch(arg, b'desc'): - desc = True - elif casematch(arg, b'alpha'): - alpha = True - elif casematch(arg, b'limit') and i + 2 < len(args): - try: - limit_start = Int.decode(args[i + 1]) - limit_count = Int.decode(args[i + 2]) - except SimpleError: - raise SimpleError(SYNTAX_ERROR_MSG) - else: - i += 2 - elif casematch(arg, b'store') and i + 1 < len(args): - store = args[i + 1] - i += 1 - elif casematch(arg, b'by') and i + 1 < len(args): - sortby = args[i + 1] - if b'*' not in sortby: - dontsort = True - i += 1 - elif casematch(arg, b'get') and i + 1 < len(args): - get.append(args[i + 1]) - i += 1 - else: - raise SimpleError(SYNTAX_ERROR_MSG) - i += 1 - - # TODO: force sorting if the object is a set and either in Lua or - # storing to a key, to match redis behaviour. - items = list(key.value) if key.value is not None else [] - - # These transformations are based on the redis implementation, but - # changed to produce a half-open range. - start = max(limit_start, 0) - end = len(items) if limit_count < 0 else start + limit_count - if start >= len(items): - start = end = len(items) - 1 - end = min(end, len(items)) - - if not get: - get.append(b'#') - if sortby is None: - sortby = b'#' - - if not dontsort: - if alpha: - def sort_key(v): - byval = self._lookup_key(v, sortby) - # TODO: use locale.strxfrm when not storing? But then need - # to decode too. - if byval is None: - byval = BeforeAny() - return byval - - else: - def sort_key(v): - byval = self._lookup_key(v, sortby) - score = SortFloat.decode(byval) if byval is not None else 0.0 - return (score, v) - - items.sort(key=sort_key, reverse=desc) - elif isinstance(key.value, (list, ZSet)): - items.reverse() - - out = [] - for row in items[start:end]: - for g in get: - v = self._lookup_key(row, g) - if store is not None and v is None: - v = b'' - out.append(v) - if store is not None: - item = CommandItem(store, self._db, item=self._db.get(store)) - item.value = out - item.writeback() - return len(out) - else: - return out - - @command((Key(missing_return=None),)) - def dump(self, key): - value = pickle.dumps(key.value) - checksum = hashlib.sha1(value).digest() - return checksum + value - - @command((Key(), Int, bytes), (bytes,)) - def restore(self, key, ttl, value, *args): - replace = False - i = 0 - while i < len(args): - if casematch(args[i], b'replace'): - replace = True - i += 1 - else: - raise SimpleError(SYNTAX_ERROR_MSG) - if key and not replace: - raise SimpleError(RESTORE_KEY_EXISTS) - checksum, value = value[:20], value[20:] - if hashlib.sha1(value).digest() != checksum: - raise SimpleError(RESTORE_INVALID_CHECKSUM_MSG) - if ttl < 0: - raise SimpleError(RESTORE_INVALID_TTL_MSG) - if ttl == 0: - expireat = None - else: - expireat = self._db.time + ttl / 1000.0 - key.value = pickle.loads(value) - key.expireat = expireat - return OK - - # Transaction commands - - def _clear_watches(self): - self._watch_notified = False - while self._watches: - (key, db) = self._watches.pop() - db.remove_watch(key, self) - - @command((), flags='s') - def multi(self): - if self._transaction is not None: - raise SimpleError(MULTI_NESTED_MSG) - self._transaction = [] - self._transaction_failed = False - return OK - - @command((), flags='s') - def discard(self): - if self._transaction is None: - raise SimpleError(WITHOUT_MULTI_MSG.format('DISCARD')) - self._transaction = None - self._transaction_failed = False - self._clear_watches() - return OK - - @command((), name='exec', flags='s') - def exec_(self): - if self._transaction is None: - raise SimpleError(WITHOUT_MULTI_MSG.format('EXEC')) - if self._transaction_failed: - self._transaction = None - self._clear_watches() - raise SimpleError(EXECABORT_MSG) - transaction = self._transaction - self._transaction = None - self._transaction_failed = False - watch_notified = self._watch_notified - self._clear_watches() - if watch_notified: - return None - result = [] - for func, sig, args in transaction: - try: - self._in_transaction = True - ans = self._run_command(func, sig, args, False) - except SimpleError as exc: - ans = exc - finally: - self._in_transaction = False - result.append(ans) - return result - - @command((Key(),), (Key(),), flags='s') - def watch(self, *keys): - if self._transaction is not None: - raise SimpleError(WATCH_INSIDE_MULTI_MSG) - for key in keys: - if key not in self._watches: - self._watches.add((key.key, self._db)) - self._db.add_watch(key.key, self) - return OK - - @command((), flags='s') - def unwatch(self): - self._clear_watches() - return OK - - # String commands - # TODO: bitfield, bitop, bitpos - - @command((Key(bytes), bytes)) - def append(self, key, value): - old = key.get(b'') - if len(old) + len(value) > MAX_STRING_SIZE: - raise SimpleError(STRING_OVERFLOW_MSG) - key.update(key.get(b'') + value) - return len(key.value) - - @command((Key(bytes, 0),), (bytes,)) - def bitcount(self, key, *args): - # Redis checks the argument count before decoding integers. That's why - # we can't declare them as Int. - if args: - if len(args) != 2: - raise SimpleError(SYNTAX_ERROR_MSG) - start = Int.decode(args[0]) - end = Int.decode(args[1]) - start, end = self._fix_range_string(start, end, len(key.value)) - value = key.value[start:end] - else: - value = key.value - return bin(int.from_bytes(value, 'little')).count('1') - - @command((Key(bytes), Int)) - def decrby(self, key, amount): - return self.incrby(key, -amount) - - @command((Key(bytes),)) - def decr(self, key): - return self.incrby(key, -1) - - @command((Key(bytes), Int)) - def incrby(self, key, amount): - c = Int.decode(key.get(b'0')) + amount - key.update(Int.encode(c)) - return c - - @command((Key(bytes),)) - def incr(self, key): - return self.incrby(key, 1) - - @command((Key(bytes), bytes)) - def incrbyfloat(self, key, amount): - # TODO: introduce convert_order so that we can specify amount is Float - c = Float.decode(key.get(b'0')) + Float.decode(amount) - if not math.isfinite(c): - raise SimpleError(NONFINITE_MSG) - encoded = Float.encode(c, True) - key.update(encoded) - return encoded - - @command((Key(bytes),)) - def get(self, key): - return key.get(None) - - @command((Key(bytes), BitOffset)) - def getbit(self, key, offset): - value = key.get(b'') - byte = offset // 8 - remaining = offset % 8 - actual_bitoffset = 7 - remaining - try: - actual_val = value[byte] - except IndexError: - return 0 - return 1 if (1 << actual_bitoffset) & actual_val else 0 - - @command((Key(bytes), BitOffset, BitValue)) - def setbit(self, key, offset, value): - val = key.get(b'\x00') - byte = offset // 8 - remaining = offset % 8 - actual_bitoffset = 7 - remaining - if len(val) - 1 < byte: - # We need to expand val so that we can set the appropriate - # bit. - needed = byte - (len(val) - 1) - val += b'\x00' * needed - old_byte = val[byte] - if value == 1: - new_byte = old_byte | (1 << actual_bitoffset) - else: - new_byte = old_byte & ~(1 << actual_bitoffset) - old_value = value if old_byte == new_byte else 1 - value - reconstructed = bytearray(val) - reconstructed[byte] = new_byte - key.update(bytes(reconstructed)) - return old_value - - @command((Key(bytes), Int, Int)) - def getrange(self, key, start, end): - value = key.get(b'') - start, end = self._fix_range_string(start, end, len(value)) - return value[start:end] - - # substr is a deprecated alias for getrange - @command((Key(bytes), Int, Int)) - def substr(self, key, start, end): - return self.getrange(key, start, end) - - @command((Key(bytes), bytes)) - def getset(self, key, value): - old = key.value - key.value = value - return old - - @command((Key(),), (Key(),)) - def mget(self, *keys): - return [key.value if isinstance(key.value, bytes) else None for key in keys] - - @command((Key(), bytes), (Key(), bytes)) - def mset(self, *args): - for i in range(0, len(args), 2): - args[i].value = args[i + 1] - return OK - - @command((Key(), bytes), (Key(), bytes)) - def msetnx(self, *args): - for i in range(0, len(args), 2): - if args[i]: - return 0 - for i in range(0, len(args), 2): - args[i].value = args[i + 1] - return 1 - - @command((Key(), bytes), (bytes,), name='set') - def set_(self, key, value, *args): - i = 0 - ex = None - px = None - xx = False - nx = False - keepttl = False - get = False - while i < len(args): - if casematch(args[i], b'nx'): - nx = True - i += 1 - elif casematch(args[i], b'xx'): - xx = True - i += 1 - elif casematch(args[i], b'ex') and i + 1 < len(args): - ex = Int.decode(args[i + 1]) - if ex <= 0 or (self._db.time + ex) * 1000 >= 2**63: - raise SimpleError(INVALID_EXPIRE_MSG.format('set')) - i += 2 - elif casematch(args[i], b'px') and i + 1 < len(args): - px = Int.decode(args[i + 1]) - if px <= 0 or self._db.time * 1000 + px >= 2**63: - raise SimpleError(INVALID_EXPIRE_MSG.format('set')) - i += 2 - elif casematch(args[i], b'keepttl'): - keepttl = True - i += 1 - elif casematch(args[i], b'get'): - get = True - i += 1 - else: - raise SimpleError(SYNTAX_ERROR_MSG) - if (xx and nx) or ((px is not None) + (ex is not None) + keepttl > 1): - raise SimpleError(SYNTAX_ERROR_MSG) - if nx and get: - # The command docs say this is allowed from Redis 7.0. - raise SimpleError(SYNTAX_ERROR_MSG) - - old_value = None - if get: - if key.value is not None and type(key.value) is not bytes: - raise SimpleError(WRONGTYPE_MSG) - old_value = key.value - - if nx and key: - return old_value - if xx and not key: - return old_value - if not keepttl: - key.value = value - else: - key.update(value) - if ex is not None: - key.expireat = self._db.time + ex - if px is not None: - key.expireat = self._db.time + px / 1000.0 - return OK if not get else old_value - - @command((Key(), Int, bytes)) - def setex(self, key, seconds, value): - if seconds <= 0 or (self._db.time + seconds) * 1000 >= 2**63: - raise SimpleError(INVALID_EXPIRE_MSG.format('setex')) - key.value = value - key.expireat = self._db.time + seconds - return OK - - @command((Key(), Int, bytes)) - def psetex(self, key, ms, value): - if ms <= 0 or self._db.time * 1000 + ms >= 2**63: - raise SimpleError(INVALID_EXPIRE_MSG.format('psetex')) - key.value = value - key.expireat = self._db.time + ms / 1000.0 - return OK - - @command((Key(), bytes)) - def setnx(self, key, value): - if key: - return 0 - key.value = value - return 1 - - @command((Key(bytes), Int, bytes)) - def setrange(self, key, offset, value): - if offset < 0: - raise SimpleError(INVALID_OFFSET_MSG) - elif not value: - return len(key.get(b'')) - elif offset + len(value) > MAX_STRING_SIZE: - raise SimpleError(STRING_OVERFLOW_MSG) - else: - out = key.get(b'') - if len(out) < offset: - out += b'\x00' * (offset - len(out)) - out = out[0:offset] + value + out[offset+len(value):] - key.update(out) - return len(out) - - @command((Key(bytes),)) - def strlen(self, key): - return len(key.get(b'')) - - # Hash commands - - @command((Key(Hash), bytes), (bytes,)) - def hdel(self, key, *fields): - h = key.value - rem = 0 - for field in fields: - if field in h: - del h[field] - key.updated() - rem += 1 - return rem - - @command((Key(Hash), bytes)) - def hexists(self, key, field): - return int(field in key.value) - - @command((Key(Hash), bytes)) - def hget(self, key, field): - return key.value.get(field) - - @command((Key(Hash),)) - def hgetall(self, key): - return list(itertools.chain(*key.value.items())) - - @command((Key(Hash), bytes, Int)) - def hincrby(self, key, field, amount): - c = Int.decode(key.value.get(field, b'0')) + amount - key.value[field] = Int.encode(c) - key.updated() - return c - - @command((Key(Hash), bytes, bytes)) - def hincrbyfloat(self, key, field, amount): - c = Float.decode(key.value.get(field, b'0')) + Float.decode(amount) - if not math.isfinite(c): - raise SimpleError(NONFINITE_MSG) - encoded = Float.encode(c, True) - key.value[field] = encoded - key.updated() - return encoded - - @command((Key(Hash),)) - def hkeys(self, key): - return list(key.value.keys()) - - @command((Key(Hash),)) - def hlen(self, key): - return len(key.value) - - @command((Key(Hash), bytes), (bytes,)) - def hmget(self, key, *fields): - return [key.value.get(field) for field in fields] - - @command((Key(Hash), bytes, bytes), (bytes, bytes)) - def hmset(self, key, *args): - self.hset(key, *args) - return OK - - @command((Key(Hash), Int,), (bytes, bytes)) - def hscan(self, key, cursor, *args): - cursor, keys = self._scan(key.value, cursor, *args) - items = [] - for k in keys: - items.append(k) - items.append(key.value[k]) - return [cursor, items] - - @command((Key(Hash), bytes, bytes), (bytes, bytes)) - def hset(self, key, *args): - h = key.value - created = 0 - for i in range(0, len(args), 2): - if args[i] not in h: - created += 1 - h[args[i]] = args[i + 1] - key.updated() - return created - - @command((Key(Hash), bytes, bytes)) - def hsetnx(self, key, field, value): - if field in key.value: - return 0 - return self.hset(key, field, value) - - @command((Key(Hash), bytes)) - def hstrlen(self, key, field): - return len(key.value.get(field, b'')) - - @command((Key(Hash),)) - def hvals(self, key): - return list(key.value.values()) - - # List commands - - def _bpop_pass(self, keys, op, first_pass): - for key in keys: - item = CommandItem(key, self._db, item=self._db.get(key), default=[]) - if not isinstance(item.value, list): - if first_pass: - raise SimpleError(WRONGTYPE_MSG) - else: - continue - if item.value: - ret = op(item.value) - item.updated() - item.writeback() - return [key, ret] - return None - - def _bpop(self, args, op): - keys = args[:-1] - timeout = Timeout.decode(args[-1]) - return self._blocking(timeout, functools.partial(self._bpop_pass, keys, op)) - - @command((bytes, bytes), (bytes,), flags='s') - def blpop(self, *args): - return self._bpop(args, lambda lst: lst.pop(0)) - - @command((bytes, bytes), (bytes,), flags='s') - def brpop(self, *args): - return self._bpop(args, lambda lst: lst.pop()) - - def _brpoplpush_pass(self, source, destination, first_pass): - src = CommandItem(source, self._db, item=self._db.get(source), default=[]) - if not isinstance(src.value, list): - if first_pass: - raise SimpleError(WRONGTYPE_MSG) - else: - return None - if not src.value: - return None # Empty list - dst = CommandItem(destination, self._db, item=self._db.get(destination), default=[]) - if not isinstance(dst.value, list): - raise SimpleError(WRONGTYPE_MSG) - el = src.value.pop() - dst.value.insert(0, el) - src.updated() - src.writeback() - if destination != source: - # Ensure writeback only happens once - dst.updated() - dst.writeback() - return el - - @command((bytes, bytes, Timeout), flags='s') - def brpoplpush(self, source, destination, timeout): - return self._blocking(timeout, - functools.partial(self._brpoplpush_pass, source, destination)) - - @command((Key(list, None), Int)) - def lindex(self, key, index): - try: - return key.value[index] - except IndexError: - return None - - @command((Key(list), bytes, bytes, bytes)) - def linsert(self, key, where, pivot, value): - if not casematch(where, b'before') and not casematch(where, b'after'): - raise SimpleError(SYNTAX_ERROR_MSG) - if not key: - return 0 - else: - try: - index = key.value.index(pivot) - except ValueError: - return -1 - if casematch(where, b'after'): - index += 1 - key.value.insert(index, value) - key.updated() - return len(key.value) - - @command((Key(list),)) - def llen(self, key): - return len(key.value) - - def _list_pop(self, get_slice, key, *args): - """Implements lpop and rpop. - - `get_slice` must take a count and return a slice expression for the - range to pop. - """ - # This implementation is somewhat contorted to match the odd - # behaviours described in https://github.com/redis/redis/issues/9680. - count = 1 - if len(args) > 1: - raise SimpleError(SYNTAX_ERROR_MSG) - elif len(args) == 1: - count = args[0] - if count < 0: - raise SimpleError(INDEX_ERROR_MSG) - elif count == 0: - return None - if not key: - return None - elif type(key.value) != list: - raise SimpleError(WRONGTYPE_MSG) - slc = get_slice(count) - ret = key.value[slc] - del key.value[slc] - key.updated() - if not args: - ret = ret[0] - return ret - - @command((Key(),), (Int(),)) - def lpop(self, key, *args): - return self._list_pop(lambda count: slice(None, count), key, *args) - - @command((Key(list), bytes), (bytes,)) - def lpush(self, key, *values): - for value in values: - key.value.insert(0, value) - key.updated() - return len(key.value) - - @command((Key(list), bytes), (bytes,)) - def lpushx(self, key, *values): - if not key: - return 0 - return self.lpush(key, *values) - - @command((Key(list), Int, Int)) - def lrange(self, key, start, stop): - start, stop = self._fix_range(start, stop, len(key.value)) - return key.value[start:stop] - - @command((Key(list), Int, bytes)) - def lrem(self, key, count, value): - a_list = key.value - found = [] - for i, el in enumerate(a_list): - if el == value: - found.append(i) - if count > 0: - indices_to_remove = found[:count] - elif count < 0: - indices_to_remove = found[count:] - else: - indices_to_remove = found - # Iterating in reverse order to ensure the indices - # remain valid during deletion. - for index in reversed(indices_to_remove): - del a_list[index] - if indices_to_remove: - key.updated() - return len(indices_to_remove) - - @command((Key(list), Int, bytes)) - def lset(self, key, index, value): - if not key: - raise SimpleError(NO_KEY_MSG) - try: - key.value[index] = value - key.updated() - except IndexError: - raise SimpleError(INDEX_ERROR_MSG) - return OK - - @command((Key(list), Int, Int)) - def ltrim(self, key, start, stop): - if key: - if stop == -1: - stop = None - else: - stop += 1 - new_value = key.value[start:stop] - # TODO: check if this should actually be conditional - if len(new_value) != len(key.value): - key.update(new_value) - return OK - - @command((Key(),), (Int(),)) - def rpop(self, key, *args): - return self._list_pop(lambda count: slice(None, -count - 1, -1), key, *args) - - @command((Key(list, None), Key(list))) - def rpoplpush(self, src, dst): - el = self.rpop(src) - self.lpush(dst, el) - return el - - @command((Key(list), bytes), (bytes,)) - def rpush(self, key, *values): - for value in values: - key.value.append(value) - key.updated() - return len(key.value) - - @command((Key(list), bytes), (bytes,)) - def rpushx(self, key, *values): - if not key: - return 0 - return self.rpush(key, *values) - - # Set commands - - @command((Key(set), bytes), (bytes,)) - def sadd(self, key, *members): - old_size = len(key.value) - key.value.update(members) - key.updated() - return len(key.value) - old_size - - @command((Key(set),)) - def scard(self, key): - return len(key.value) - - def _calc_setop(self, op, stop_if_missing, key, *keys): - if stop_if_missing and not key.value: - return set() - ans = key.value.copy() - for other in keys: - value = other.value if other.value is not None else set() - if not isinstance(value, set): - raise SimpleError(WRONGTYPE_MSG) - if stop_if_missing and not value: - return set() - ans = op(ans, value) - return ans - - def _setop(self, op, stop_if_missing, dst, key, *keys): - """Apply one of SINTER[STORE], SUNION[STORE], SDIFF[STORE]. - - If `stop_if_missing`, the output will be made an empty set as soon as - an empty input set is encountered (use for SINTER[STORE]). May assume - that `key` is a set (or empty), but `keys` could be anything. - """ - ans = self._calc_setop(op, stop_if_missing, key, *keys) - if dst is None: - return list(ans) - else: - dst.value = ans - return len(dst.value) - - @command((Key(set),), (Key(set),)) - def sdiff(self, *keys): - return self._setop(lambda a, b: a - b, False, None, *keys) - - @command((Key(), Key(set)), (Key(set),)) - def sdiffstore(self, dst, *keys): - return self._setop(lambda a, b: a - b, False, dst, *keys) - - @command((Key(set),), (Key(set),)) - def sinter(self, *keys): - return self._setop(lambda a, b: a & b, True, None, *keys) - - @command((Key(), Key(set)), (Key(set),)) - def sinterstore(self, dst, *keys): - return self._setop(lambda a, b: a & b, True, dst, *keys) - - @command((Key(set), bytes)) - def sismember(self, key, member): - return int(member in key.value) - - @command((Key(set),)) - def smembers(self, key): - return list(key.value) - - @command((Key(set, 0), Key(set), bytes)) - def smove(self, src, dst, member): - try: - src.value.remove(member) - src.updated() - except KeyError: - return 0 - else: - dst.value.add(member) - dst.updated() # TODO: is it updated if member was already present? - return 1 - - @command((Key(set),), (Int,)) - def spop(self, key, count=None): - if count is None: - if not key.value: - return None - item = random.sample(list(key.value), 1)[0] - key.value.remove(item) - key.updated() - return item - else: - if count < 0: - raise SimpleError(INDEX_ERROR_MSG) - items = self.srandmember(key, count) - for item in items: - key.value.remove(item) - key.updated() # Inside the loop because redis special-cases count=0 - return items - - @command((Key(set),), (Int,)) - def srandmember(self, key, count=None): - if count is None: - if not key.value: - return None - else: - return random.sample(list(key.value), 1)[0] - elif count >= 0: - count = min(count, len(key.value)) - return random.sample(list(key.value), count) - else: - items = list(key.value) - return [random.choice(items) for _ in range(-count)] - - @command((Key(set), bytes), (bytes,)) - def srem(self, key, *members): - old_size = len(key.value) - for member in members: - key.value.discard(member) - deleted = old_size - len(key.value) - if deleted: - key.updated() - return deleted - - @command((Key(set), Int), (bytes, bytes)) - def sscan(self, key, cursor, *args): - return self._scan(key.value, cursor, *args) - - @command((Key(set),), (Key(set),)) - def sunion(self, *keys): - return self._setop(lambda a, b: a | b, False, None, *keys) - - @command((Key(), Key(set)), (Key(set),)) - def sunionstore(self, dst, *keys): - return self._setop(lambda a, b: a | b, False, dst, *keys) - - # Hyperloglog commands - # These are not quite the same as the real redis ones, which are - # approximate and store the results in a string. Instead, it is implemented - # on top of sets. - - @command((Key(set),), (bytes,)) - def pfadd(self, key, *elements): - result = self.sadd(key, *elements) - # Per the documentation: - # - 1 if at least 1 HyperLogLog internal register was altered. 0 otherwise. - return 1 if result > 0 else 0 - - @command((Key(set),), (Key(set),)) - def pfcount(self, *keys): - """ - Return the approximated cardinality of - the set observed by the HyperLogLog at key(s). - """ - return len(self.sunion(*keys)) - - @command((Key(set), Key(set)), (Key(set),)) - def pfmerge(self, dest, *sources): - "Merge N different HyperLogLogs into a single one." - self.sunionstore(dest, *sources) - return OK - - # Sorted set commands - # TODO: [b]zpopmin/zpopmax, - - @staticmethod - def _limit_items(items, offset, count): - out = [] - for item in items: - if offset: # Note: not offset > 0, in order to match redis - offset -= 1 - continue - if count == 0: - break - count -= 1 - out.append(item) - return out - - @staticmethod - def _apply_withscores(items, withscores): - if withscores: - out = [] - for item in items: - out.append(item[1]) - out.append(Float.encode(item[0], False)) - else: - out = [item[1] for item in items] - return out - - @command((Key(ZSet), bytes, bytes), (bytes,)) - def zadd(self, key, *args): - zset = key.value - - i = 0 - ch = False - nx = False - xx = False - incr = False - while i < len(args): - if casematch(args[i], b'ch'): - ch = True - i += 1 - elif casematch(args[i], b'nx'): - nx = True - i += 1 - elif casematch(args[i], b'xx'): - xx = True - i += 1 - elif casematch(args[i], b'incr'): - incr = True - i += 1 - else: - # First argument not matching flags indicates the start of - # score pairs. - break - - if nx and xx: - raise SimpleError(ZADD_NX_XX_ERROR_MSG) - - elements = args[i:] - if not elements or len(elements) % 2 != 0: - raise SimpleError(SYNTAX_ERROR_MSG) - if incr and len(elements) != 2: - raise SimpleError(ZADD_INCR_LEN_ERROR_MSG) - # Parse all scores first, before updating - items = [ - (Float.decode(elements[j]), elements[j + 1]) - for j in range(0, len(elements), 2) - ] - old_len = len(zset) - changed_items = 0 - - if incr: - item_score, item_name = items[0] - if (nx and item_name in zset) or (xx and item_name not in zset): - return None - return self.zincrby(key, item_score, item_name) - - for item_score, item_name in items: - if ( - (not nx or item_name not in zset) - and (not xx or item_name in zset) - ): - if zset.add(item_name, item_score): - changed_items += 1 - - if changed_items: - key.updated() - - if ch: - return changed_items - return len(zset) - old_len - - @command((Key(ZSet),)) - def zcard(self, key): - return len(key.value) - - @command((Key(ZSet), ScoreTest, ScoreTest)) - def zcount(self, key, min, max): - return key.value.zcount(min.lower_bound, max.upper_bound) - - @command((Key(ZSet), Float, bytes)) - def zincrby(self, key, increment, member): - # Can't just default the old score to 0.0, because in IEEE754, adding - # 0.0 to something isn't a nop (e.g. 0.0 + -0.0 == 0.0). - try: - score = key.value.get(member, None) + increment - except TypeError: - score = increment - if math.isnan(score): - raise SimpleError(SCORE_NAN_MSG) - key.value[member] = score - key.updated() - return Float.encode(score, False) - - @command((Key(ZSet), StringTest, StringTest)) - def zlexcount(self, key, min, max): - return key.value.zlexcount(min.value, min.exclusive, max.value, max.exclusive) - - def _zrange(self, key, start, stop, reverse, *args): - zset = key.value - withscores = False - for arg in args: - if casematch(arg, b'withscores'): - withscores = True - else: - raise SimpleError(SYNTAX_ERROR_MSG) - start, stop = self._fix_range(start, stop, len(zset)) - if reverse: - start, stop = len(zset) - stop, len(zset) - start - items = zset.islice_score(start, stop, reverse) - items = self._apply_withscores(items, withscores) - return items - - @command((Key(ZSet), Int, Int), (bytes,)) - def zrange(self, key, start, stop, *args): - return self._zrange(key, start, stop, False, *args) - - @command((Key(ZSet), Int, Int), (bytes,)) - def zrevrange(self, key, start, stop, *args): - return self._zrange(key, start, stop, True, *args) - - def _zrangebylex(self, key, min, max, reverse, *args): - if args: - if len(args) != 3 or not casematch(args[0], b'limit'): - raise SimpleError(SYNTAX_ERROR_MSG) - offset = Int.decode(args[1]) - count = Int.decode(args[2]) - else: - offset = 0 - count = -1 - zset = key.value - items = zset.irange_lex(min.value, max.value, - inclusive=(not min.exclusive, not max.exclusive), - reverse=reverse) - items = self._limit_items(items, offset, count) - return items - - @command((Key(ZSet), StringTest, StringTest), (bytes,)) - def zrangebylex(self, key, min, max, *args): - return self._zrangebylex(key, min, max, False, *args) - - @command((Key(ZSet), StringTest, StringTest), (bytes,)) - def zrevrangebylex(self, key, max, min, *args): - return self._zrangebylex(key, min, max, True, *args) - - def _zrangebyscore(self, key, min, max, reverse, *args): - withscores = False - offset = 0 - count = -1 - i = 0 - while i < len(args): - if casematch(args[i], b'withscores'): - withscores = True - i += 1 - elif casematch(args[i], b'limit') and i + 2 < len(args): - offset = Int.decode(args[i + 1]) - count = Int.decode(args[i + 2]) - i += 3 - else: - raise SimpleError(SYNTAX_ERROR_MSG) - zset = key.value - items = list(zset.irange_score(min.lower_bound, max.upper_bound, reverse=reverse)) - items = self._limit_items(items, offset, count) - items = self._apply_withscores(items, withscores) - return items - - @command((Key(ZSet), ScoreTest, ScoreTest), (bytes,)) - def zrangebyscore(self, key, min, max, *args): - return self._zrangebyscore(key, min, max, False, *args) - - @command((Key(ZSet), ScoreTest, ScoreTest), (bytes,)) - def zrevrangebyscore(self, key, max, min, *args): - return self._zrangebyscore(key, min, max, True, *args) - - @command((Key(ZSet), bytes)) - def zrank(self, key, member): - try: - return key.value.rank(member) - except KeyError: - return None - - @command((Key(ZSet), bytes)) - def zrevrank(self, key, member): - try: - return len(key.value) - 1 - key.value.rank(member) - except KeyError: - return None - - @command((Key(ZSet), bytes), (bytes,)) - def zrem(self, key, *members): - old_size = len(key.value) - for member in members: - key.value.discard(member) - deleted = old_size - len(key.value) - if deleted: - key.updated() - return deleted - - @command((Key(ZSet), StringTest, StringTest)) - def zremrangebylex(self, key, min, max): - items = key.value.irange_lex(min.value, max.value, - inclusive=(not min.exclusive, not max.exclusive)) - return self.zrem(key, *items) - - @command((Key(ZSet), ScoreTest, ScoreTest)) - def zremrangebyscore(self, key, min, max): - items = key.value.irange_score(min.lower_bound, max.upper_bound) - return self.zrem(key, *[item[1] for item in items]) - - @command((Key(ZSet), Int, Int)) - def zremrangebyrank(self, key, start, stop): - zset = key.value - start, stop = self._fix_range(start, stop, len(zset)) - items = zset.islice_score(start, stop) - return self.zrem(key, *[item[1] for item in items]) - - @command((Key(ZSet), Int), (bytes, bytes)) - def zscan(self, key, cursor, *args): - new_cursor, ans = self._scan(key.value.items(), cursor, *args) - flat = [] - for (key, score) in ans: - flat.append(key) - flat.append(Float.encode(score, False)) - return [new_cursor, flat] - - @command((Key(ZSet), bytes)) - def zscore(self, key, member): - try: - return Float.encode(key.value[member], False) - except KeyError: - return None - - @staticmethod - def _get_zset(value): - if isinstance(value, set): - zset = ZSet() - for item in value: - zset[item] = 1.0 - return zset - elif isinstance(value, ZSet): - return value - else: - raise SimpleError(WRONGTYPE_MSG) - - def _zunioninter(self, func, dest, numkeys, *args): - if numkeys < 1: - raise SimpleError(ZUNIONSTORE_KEYS_MSG) - if numkeys > len(args): - raise SimpleError(SYNTAX_ERROR_MSG) - aggregate = b'sum' - sets = [] - for i in range(numkeys): - item = CommandItem(args[i], self._db, item=self._db.get(args[i]), default=ZSet()) - sets.append(self._get_zset(item.value)) - weights = [1.0] * numkeys - - i = numkeys - while i < len(args): - arg = args[i] - if casematch(arg, b'weights') and i + numkeys < len(args): - weights = [Float.decode(x) for x in args[i + 1:i + numkeys + 1]] - i += numkeys + 1 - elif casematch(arg, b'aggregate') and i + 1 < len(args): - aggregate = casenorm(args[i + 1]) - if aggregate not in (b'sum', b'min', b'max'): - raise SimpleError(SYNTAX_ERROR_MSG) - i += 2 - else: - raise SimpleError(SYNTAX_ERROR_MSG) - - out_members = set(sets[0]) - for s in sets[1:]: - if func == 'ZUNIONSTORE': - out_members |= set(s) - else: - out_members.intersection_update(s) - - # We first build a regular dict and turn it into a ZSet. The - # reason is subtle: a ZSet won't update a score from -0 to +0 - # (or vice versa) through assignment, but a regular dict will. - out = {} - # The sort affects the order of floating-point operations. - # Note that redis uses qsort(1), which has no stability guarantees, - # so we can't be sure to match it in all cases. - for s, w in sorted(zip(sets, weights), key=lambda x: len(x[0])): - for member, score in s.items(): - score *= w - # Redis only does this step for ZUNIONSTORE. See - # https://github.com/antirez/redis/issues/3954. - if func == 'ZUNIONSTORE' and math.isnan(score): - score = 0.0 - if member not in out_members: - continue - if member in out: - old = out[member] - if aggregate == b'sum': - score += old - if math.isnan(score): - score = 0.0 - elif aggregate == b'max': - score = max(old, score) - elif aggregate == b'min': - score = min(old, score) - else: - assert False # pragma: nocover - if math.isnan(score): - score = 0.0 - out[member] = score - - out_zset = ZSet() - for member, score in out.items(): - out_zset[member] = score - - dest.value = out_zset - return len(out_zset) - - @command((Key(), Int, bytes), (bytes,)) - def zunionstore(self, dest, numkeys, *args): - return self._zunioninter('ZUNIONSTORE', dest, numkeys, *args) - - @command((Key(), Int, bytes), (bytes,)) - def zinterstore(self, dest, numkeys, *args): - return self._zunioninter('ZINTERSTORE', dest, numkeys, *args) - - # Server commands - # TODO: lots - - @command((), (bytes,), flags='s') - def bgsave(self, *args): - if len(args) > 1 or (len(args) == 1 and not casematch(args[0], b'schedule')): - raise SimpleError(SYNTAX_ERROR_MSG) - self._server.lastsave = int(time.time()) - return BGSAVE_STARTED - - @command(()) - def dbsize(self): - return len(self._db) - - @command((), (bytes,)) - def flushdb(self, *args): - if args: - if len(args) != 1 or not casematch(args[0], b'async'): - raise SimpleError(SYNTAX_ERROR_MSG) - self._db.clear() - return OK - - @command((), (bytes,)) - def flushall(self, *args): - if args: - if len(args) != 1 or not casematch(args[0], b'async'): - raise SimpleError(SYNTAX_ERROR_MSG) - for db in self._server.dbs.values(): - db.clear() - # TODO: clear watches and/or pubsub as well? - return OK - - @command(()) - def lastsave(self): - return self._server.lastsave - - @command((), flags='s') - def save(self): - self._server.lastsave = int(time.time()) - return OK - - @command(()) - def time(self): - now_us = round(time.time() * 1000000) - now_s = now_us // 1000000 - now_us %= 1000000 - return [str(now_s).encode(), str(now_us).encode()] - - # Script commands - # script debug and script kill will probably not be supported - - def _convert_redis_arg(self, lua_runtime, value): - # Type checks are exact to avoid issues like bool being a subclass of int. - if type(value) is bytes: - return value - elif type(value) in {int, float}: - return '{:.17g}'.format(value).encode() - else: - # TODO: add the context - raise SimpleError(LUA_COMMAND_ARG_MSG) - - def _convert_redis_result(self, lua_runtime, result): - if isinstance(result, (bytes, int)): - return result - elif isinstance(result, SimpleString): - return lua_runtime.table_from({b"ok": result.value}) - elif result is None: - return False - elif isinstance(result, list): - converted = [ - self._convert_redis_result(lua_runtime, item) - for item in result - ] - return lua_runtime.table_from(converted) - elif isinstance(result, SimpleError): - raise result - else: - raise RuntimeError("Unexpected return type from redis: {}".format(type(result))) - - def _convert_lua_result(self, result, nested=True): - from lupa import lua_type - if lua_type(result) == 'table': - for key in (b'ok', b'err'): - if key in result: - msg = self._convert_lua_result(result[key]) - if not isinstance(msg, bytes): - raise SimpleError(LUA_WRONG_NUMBER_ARGS_MSG) - if key == b'ok': - return SimpleString(msg) - elif nested: - return SimpleError(msg.decode('utf-8', 'replace')) - else: - raise SimpleError(msg.decode('utf-8', 'replace')) - # Convert Lua tables into lists, starting from index 1, mimicking the behavior of StrictRedis. - result_list = [] - for index in itertools.count(1): - if index not in result: - break - item = result[index] - result_list.append(self._convert_lua_result(item)) - return result_list - elif isinstance(result, str): - return result.encode() - elif isinstance(result, float): - return int(result) - elif isinstance(result, bool): - return 1 if result else None - return result - - def _check_for_lua_globals(self, lua_runtime, expected_globals): - actual_globals = set(lua_runtime.globals().keys()) - if actual_globals != expected_globals: - unexpected = [six.ensure_str(var, 'utf-8', 'replace') - for var in actual_globals - expected_globals] - raise SimpleError(GLOBAL_VARIABLE_MSG.format(", ".join(unexpected))) - - def _lua_redis_call(self, lua_runtime, expected_globals, op, *args): - # Check if we've set any global variables before making any change. - self._check_for_lua_globals(lua_runtime, expected_globals) - func, func_name = self._name_to_func(op) - args = [self._convert_redis_arg(lua_runtime, arg) for arg in args] - result = self._run_command(func, func._fakeredis_sig, args, True) - return self._convert_redis_result(lua_runtime, result) - - def _lua_redis_pcall(self, lua_runtime, expected_globals, op, *args): - try: - return self._lua_redis_call(lua_runtime, expected_globals, op, *args) - except Exception as ex: - return lua_runtime.table_from({b"err": str(ex)}) - - def _lua_redis_log(self, lua_runtime, expected_globals, lvl, *args): - self._check_for_lua_globals(lua_runtime, expected_globals) - if len(args) < 1: - raise SimpleError(REQUIRES_MORE_ARGS_MSG.format("redis.log()", "two")) - if lvl not in REDIS_LOG_LEVELS.values(): - raise SimpleError(LOG_INVALID_DEBUG_LEVEL_MSG) - msg = ' '.join([x.decode('utf-8') - if isinstance(x, bytes) else str(x) - for x in args if not isinstance(x, bool)]) - LOGGER.log(REDIS_LOG_LEVELS_TO_LOGGING[lvl], msg) - - @command((bytes, Int), (bytes,), flags='s') - def eval(self, script, numkeys, *keys_and_args): - from lupa import LuaError, LuaRuntime, as_attrgetter - - if numkeys > len(keys_and_args): - raise SimpleError(TOO_MANY_KEYS_MSG) - if numkeys < 0: - raise SimpleError(NEGATIVE_KEYS_MSG) - sha1 = hashlib.sha1(script).hexdigest().encode() - self._server.script_cache[sha1] = script - lua_runtime = LuaRuntime(encoding=None, unpack_returned_tuples=True) - - set_globals = lua_runtime.eval( - """ - function(keys, argv, redis_call, redis_pcall, redis_log, redis_log_levels) - redis = {} - redis.call = redis_call - redis.pcall = redis_pcall - redis.log = redis_log - for level, pylevel in python.iterex(redis_log_levels.items()) do - redis[level] = pylevel - end - redis.error_reply = function(msg) return {err=msg} end - redis.status_reply = function(msg) return {ok=msg} end - KEYS = keys - ARGV = argv - end - """ - ) - expected_globals = set() - set_globals( - lua_runtime.table_from(keys_and_args[:numkeys]), - lua_runtime.table_from(keys_and_args[numkeys:]), - functools.partial(self._lua_redis_call, lua_runtime, expected_globals), - functools.partial(self._lua_redis_pcall, lua_runtime, expected_globals), - functools.partial(self._lua_redis_log, lua_runtime, expected_globals), - as_attrgetter(REDIS_LOG_LEVELS) - ) - expected_globals.update(lua_runtime.globals().keys()) - - try: - result = lua_runtime.execute(script) - except (LuaError, SimpleError) as ex: - raise SimpleError(SCRIPT_ERROR_MSG.format(sha1.decode(), ex)) - - self._check_for_lua_globals(lua_runtime, expected_globals) - - return self._convert_lua_result(result, nested=False) - - @command((bytes, Int), (bytes,), flags='s') - def evalsha(self, sha1, numkeys, *keys_and_args): - try: - script = self._server.script_cache[sha1] - except KeyError: - raise SimpleError(NO_MATCHING_SCRIPT_MSG) - return self.eval(script, numkeys, *keys_and_args) - - @command((bytes,), (bytes,), flags='s') - def script(self, subcmd, *args): - if casematch(subcmd, b'load'): - if len(args) != 1: - raise SimpleError(BAD_SUBCOMMAND_MSG.format('SCRIPT')) - script = args[0] - sha1 = hashlib.sha1(script).hexdigest().encode() - self._server.script_cache[sha1] = script - return sha1 - elif casematch(subcmd, b'exists'): - return [int(sha1 in self._server.script_cache) for sha1 in args] - elif casematch(subcmd, b'flush'): - if len(args) > 1 or (len(args) == 1 and casenorm(args[0]) not in {b'sync', b'async'}): - raise SimpleError(BAD_SUBCOMMAND_MSG.format('SCRIPT')) - self._server.script_cache = {} - return OK - else: - raise SimpleError(BAD_SUBCOMMAND_MSG.format('SCRIPT')) - - # Pubsub commands - # TODO: pubsub command - - def _subscribe(self, channels, subscribers, mtype): - for channel in channels: - subs = subscribers[channel] - if self not in subs: - subs.add(self) - self._pubsub += 1 - msg = [mtype, channel, self._pubsub] - self.put_response(msg) - return NoResponse() - - def _unsubscribe(self, channels, subscribers, mtype): - if not channels: - channels = [] - for (channel, subs) in subscribers.items(): - if self in subs: - channels.append(channel) - for channel in channels: - subs = subscribers.get(channel, set()) - if self in subs: - subs.remove(self) - if not subs: - del subscribers[channel] - self._pubsub -= 1 - msg = [mtype, channel, self._pubsub] - self.put_response(msg) - return NoResponse() - - @command((bytes,), (bytes,), flags='s') - def psubscribe(self, *patterns): - return self._subscribe(patterns, self._server.psubscribers, b'psubscribe') - - @command((bytes,), (bytes,), flags='s') - def subscribe(self, *channels): - return self._subscribe(channels, self._server.subscribers, b'subscribe') - - @command((), (bytes,), flags='s') - def punsubscribe(self, *patterns): - return self._unsubscribe(patterns, self._server.psubscribers, b'punsubscribe') - - @command((), (bytes,), flags='s') - def unsubscribe(self, *channels): - return self._unsubscribe(channels, self._server.subscribers, b'unsubscribe') - - @command((bytes, bytes)) - def publish(self, channel, message): - receivers = 0 - msg = [b'message', channel, message] - subs = self._server.subscribers.get(channel, set()) - for sock in subs: - sock.put_response(msg) - receivers += 1 - for (pattern, socks) in self._server.psubscribers.items(): - regex = compile_pattern(pattern) - if regex.match(channel): - msg = [b'pmessage', pattern, channel, message] - for sock in socks: - sock.put_response(msg) - receivers += 1 - return receivers - - -setattr(FakeSocket, 'del', FakeSocket.del_) -delattr(FakeSocket, 'del_') -setattr(FakeSocket, 'set', FakeSocket.set_) -delattr(FakeSocket, 'set_') -setattr(FakeSocket, 'exec', FakeSocket.exec_) -delattr(FakeSocket, 'exec_') - - -class _DummyParser: - def __init__(self, socket_read_size): - self.socket_read_size = socket_read_size - - def on_disconnect(self): - pass - - def on_connect(self, connection): - pass - - -# Redis <3.2 will not have a selector -try: - from redis.selector import BaseSelector -except ImportError: - class BaseSelector: - def __init__(self, sock): - self.sock = sock - - -class FakeSelector(BaseSelector): - def check_can_read(self, timeout): - if self.sock.responses.qsize(): - return True - if timeout is not None and timeout <= 0: - return False - - # A sleep/poll loop is easier to mock out than messing with condition - # variables. - start = time.time() - while True: - if self.sock.responses.qsize(): - return True - time.sleep(0.01) - now = time.time() - if timeout is not None and now > start + timeout: - return False - - def check_is_ready_for_command(self, timeout): - return True - - -class FakeConnection(redis.Connection): - description_format = "FakeConnection" - - def __init__(self, *args, **kwargs): - self._server = kwargs.pop('server') - super().__init__(*args, **kwargs) - - def connect(self): - super().connect() - # The selector is set in redis.Connection.connect() after _connect() is called - self._selector = FakeSelector(self._sock) - - def _connect(self): - if not self._server.connected: - raise redis.ConnectionError(CONNECTION_ERROR_MSG) - return FakeSocket(self._server) - - def can_read(self, timeout=0): - if not self._server.connected: - return True - if not self._sock: - self.connect() - # We use check_can_read rather than can_read, because on redis-py<3.2, - # FakeSelector inherits from a stub BaseSelector which doesn't - # implement can_read. Normally can_read provides retries on EINTR, - # but that's not necessary for the implementation of - # FakeSelector.check_can_read. - return self._selector.check_can_read(timeout) - - def _decode(self, response): - if isinstance(response, list): - return [self._decode(item) for item in response] - elif isinstance(response, bytes): - return self.encoder.decode(response) - else: - return response - - def read_response(self, disable_decoding=False): - if not self._server.connected: - try: - response = self._sock.responses.get_nowait() - except queue.Empty: - raise redis.ConnectionError(CONNECTION_ERROR_MSG) - else: - response = self._sock.responses.get() - if isinstance(response, redis.ResponseError): - raise response - if disable_decoding: - return response - else: - return self._decode(response) - - def repr_pieces(self): - pieces = [ - ('server', self._server), - ('db', self.db) - ] - if self.client_name: - pieces.append(('client_name', self.client_name)) - return pieces - - -class FakeRedisMixin: - def __init__(self, *args, server=None, connected=True, **kwargs): - # Interpret the positional and keyword arguments according to the - # version of redis in use. - bound = _ORIG_SIG.bind(*args, **kwargs) - bound.apply_defaults() - if not bound.arguments['connection_pool']: - charset = bound.arguments['charset'] - errors = bound.arguments['errors'] - # Adapted from redis-py - if charset is not None: - warnings.warn(DeprecationWarning( - '"charset" is deprecated. Use "encoding" instead')) - bound.arguments['encoding'] = charset - if errors is not None: - warnings.warn(DeprecationWarning( - '"errors" is deprecated. Use "encoding_errors" instead')) - bound.arguments['encoding_errors'] = errors - - if server is None: - server = FakeServer() - server.connected = connected - kwargs = { - 'connection_class': FakeConnection, - 'server': server - } - conn_pool_args = [ - 'db', - 'username', - 'password', - 'socket_timeout', - 'encoding', - 'encoding_errors', - 'decode_responses', - 'retry_on_timeout', - 'max_connections', - 'health_check_interval', - 'client_name' - ] - for arg in conn_pool_args: - if arg in bound.arguments: - kwargs[arg] = bound.arguments[arg] - bound.arguments['connection_pool'] = redis.connection.ConnectionPool(**kwargs) - super().__init__(*bound.args, **bound.kwargs) - - @classmethod - def from_url(https://github.com/cls,%20*args,%20**kwargs): - server = kwargs.pop('server', None) - if server is None: - server = FakeServer() - pool = redis.ConnectionPool.from_url(https://github.com/*args,%20**kwargs) - # Now override how it creates connections - pool.connection_class = FakeConnection - pool.connection_kwargs['server'] = server - # FakeConnection cannot handle the path kwarg (present when from_url - # is called with a unix socket) - pool.connection_kwargs.pop('path', None) - return cls(connection_pool=pool) - - -class FakeStrictRedis(FakeRedisMixin, redis.StrictRedis): - pass - - -class FakeRedis(FakeRedisMixin, redis.Redis): - pass diff --git a/fakeredis/_zset.py b/fakeredis/_zset.py deleted file mode 100644 index 47d1169..0000000 --- a/fakeredis/_zset.py +++ /dev/null @@ -1,87 +0,0 @@ -import sortedcontainers - - -class ZSet: - def __init__(self): - self._bylex = {} # Maps value to score - self._byscore = sortedcontainers.SortedList() - - def __contains__(self, value): - return value in self._bylex - - def add(self, value, score): - """Update the item and return whether it modified the zset""" - old_score = self._bylex.get(value, None) - if old_score is not None: - if score == old_score: - return False - self._byscore.remove((old_score, value)) - self._bylex[value] = score - self._byscore.add((score, value)) - return True - - def __setitem__(self, value, score): - self.add(value, score) - - def __getitem__(self, key): - return self._bylex[key] - - def get(self, key, default=None): - return self._bylex.get(key, default) - - def __len__(self): - return len(self._bylex) - - def __iter__(self): - def gen(): - for score, value in self._byscore: - yield value - - return gen() - - def discard(self, key): - try: - score = self._bylex.pop(key) - except KeyError: - return - else: - self._byscore.remove((score, key)) - - def zcount(self, min_, max_): - pos1 = self._byscore.bisect_left(min_) - pos2 = self._byscore.bisect_left(max_) - return max(0, pos2 - pos1) - - def zlexcount(self, min_value, min_exclusive, max_value, max_exclusive): - if not self._byscore: - return 0 - score = self._byscore[0][0] - if min_exclusive: - pos1 = self._byscore.bisect_right((score, min_value)) - else: - pos1 = self._byscore.bisect_left((score, min_value)) - if max_exclusive: - pos2 = self._byscore.bisect_left((score, max_value)) - else: - pos2 = self._byscore.bisect_right((score, max_value)) - return max(0, pos2 - pos1) - - def islice_score(self, start, stop, reverse=False): - return self._byscore.islice(start, stop, reverse) - - def irange_lex(self, start, stop, inclusive=(True, True), reverse=False): - if not self._byscore: - return iter([]) - score = self._byscore[0][0] - it = self._byscore.irange((score, start), (score, stop), - inclusive=inclusive, reverse=reverse) - return (item[1] for item in it) - - def irange_score(self, start, stop, reverse=False): - return self._byscore.irange(start, stop, reverse=reverse) - - def rank(self, member): - return self._byscore.index((self._bylex[member], member)) - - def items(self): - return self._bylex.items() diff --git a/fakeredis/aioredis.py b/fakeredis/aioredis.py deleted file mode 100644 index 7d5ba08..0000000 --- a/fakeredis/aioredis.py +++ /dev/null @@ -1,10 +0,0 @@ -import aioredis -import packaging.version - - -if packaging.version.Version(aioredis.__version__) >= packaging.version.Version('2.0.0a1'): - from ._aioredis2 import FakeConnection, FakeRedis # noqa: F401 -else: - from ._aioredis1 import ( # noqa: F401 - FakeConnectionsPool, create_connection, create_redis, create_pool, create_redis_pool - ) diff --git a/pyproject.toml b/pyproject.toml deleted file mode 100644 index 2a39a17..0000000 --- a/pyproject.toml +++ /dev/null @@ -1,2 +0,0 @@ -[build-system] -requires = ["setuptools", "wheel", "setuptools-scm"] diff --git a/requirements-dev.txt b/requirements-dev.txt deleted file mode 100644 index abadf71..0000000 --- a/requirements-dev.txt +++ /dev/null @@ -1,4 +0,0 @@ -invoke==0.22.1 -wheel==0.31.1 -tox==3.6.1 -twine==1.12.1 diff --git a/requirements.in b/requirements.in deleted file mode 100644 index a5d59d7..0000000 --- a/requirements.in +++ /dev/null @@ -1,15 +0,0 @@ -aioredis -coverage -flake8 -hypothesis -lupa -pytest -pytest-asyncio -pytest-cov -pytest-mock -redis==4.1.3 # Latest at time of writing -six -sortedcontainers - -# Not needed directly, but the latest versions don't support Python 3.5 -zipp<2 diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index a695ef0..0000000 --- a/requirements.txt +++ /dev/null @@ -1,72 +0,0 @@ -# -# This file is autogenerated by pip-compile with python 3.8 -# To update, run: -# -# pip-compile requirements.in -# -aioredis==1.3.1 - # via -r requirements.in -async-timeout==3.0.1 - # via aioredis -attrs==20.3.0 - # via - # hypothesis - # pytest -coverage==5.3 - # via - # -r requirements.in - # pytest-cov -deprecated==1.2.13 - # via redis -flake8==3.8.4 - # via -r requirements.in -hiredis==1.1.0 - # via aioredis -hypothesis==5.41.4 - # via -r requirements.in -iniconfig==1.1.1 - # via pytest -lupa==1.10 - # via -r requirements.in -mccabe==0.6.1 - # via flake8 -packaging==21.3 - # via - # pytest - # redis -pluggy==0.13.1 - # via pytest -py==1.10.0 - # via pytest -pycodestyle==2.6.0 - # via flake8 -pyflakes==2.2.0 - # via flake8 -pyparsing==2.4.7 - # via packaging -pytest==6.2.5 - # via - # -r requirements.in - # pytest-asyncio - # pytest-cov - # pytest-mock -pytest-asyncio==0.15.1 - # via -r requirements.in -pytest-cov==2.10.1 - # via -r requirements.in -pytest-mock==3.3.1 - # via -r requirements.in -redis==4.1.3 - # via -r requirements.in -six==1.15.0 - # via -r requirements.in -sortedcontainers==2.3.0 - # via - # -r requirements.in - # hypothesis -toml==0.10.2 - # via pytest -wrapt==1.13.3 - # via deprecated -zipp==1.2.0 - # via -r requirements.in diff --git a/scripts/supported b/scripts/supported deleted file mode 100755 index ffddc82..0000000 --- a/scripts/supported +++ /dev/null @@ -1,65 +0,0 @@ -#!/usr/bin/env python - -# Script will import fakeredis and list what -# commands it does not have support for, based on the -# command list from: -# https://raw.github.com/antirez/redis-doc/master/commands.json -# Because, who wants to do this by hand... - -from __future__ import print_function -import os -import json -import inspect -from collections import OrderedDict -import requests - -import fakeredis - -THIS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__))) -COMMANDS_FILE = os.path.join(THIS_DIR, '.commands.json') -COMMANDS_URL = 'https://raw.github.com/antirez/redis-doc/master/commands.json' - -if not os.path.exists(COMMANDS_FILE): - contents = requests.get(COMMANDS_URL).content - open(COMMANDS_FILE, 'wb').write(contents) -commands = json.load(open(COMMANDS_FILE), object_pairs_hook=OrderedDict) -for k, v in list(commands.items()): - commands[k.lower()] = v - del commands[k] - - -implemented_commands = set() -for name, method in inspect.getmembers(fakeredis._server.FakeSocket): - if hasattr(method, '_fakeredis_sig'): - implemented_commands.add(name) -# Currently no programmatic way to discover implemented subcommands -implemented_commands.add('script load') - -unimplemented_commands = [] -for command in commands: - if command not in implemented_commands: - unimplemented_commands.append(command) - -# Group by 'group' for easier to read output -groups = OrderedDict() -for command in unimplemented_commands: - group = commands[command]['group'] - groups.setdefault(group, []).append(command) - -print(""" - -Unimplemented Commands -====================== - -All of the redis commands are implemented in fakeredis with -these exceptions: - -""") - -for group in groups: - print(group) - print("-" * len(str(group))) - print() - for command in groups[group]: - print(" *", command) - print("\n") diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index a7b6b48..0000000 --- a/setup.cfg +++ /dev/null @@ -1,52 +0,0 @@ -[metadata] -name = fakeredis -version = attr: fakeredis.__version__ -description = Fake implementation of redis API for testing purposes. -long_description = file: README.rst -long_description_content_type = text/x-rst -license = BSD -url = https://github.com/jamesls/fakeredis -author = James Saryerwinnie -author_email = js@jamesls.com -maintainer = Bruce Merry -maintainer_email = bmerry@sarao.ac.za -classifiers = - Development Status :: 5 - Production/Stable - License :: OSI Approved :: BSD License - Programming Language :: Python :: 3 - Programming Language :: Python :: 3.6 - Programming Language :: Python :: 3.7 - Programming Language :: Python :: 3.8 - Programming Language :: Python :: 3.9 - -[options] -packages = fakeredis -install_requires = - packaging - # Minor version updates to redis tend to break fakeredis. If you - # need to use fakeredis with a newer redis, please submit a PR that - # relaxes this restriction and adds it to the Github Actions tests. - redis<4.2.0 - six>=1.12 - sortedcontainers -python_requires = >=3.5 - -[options.extras_require] -lua = - lupa -aioredis = - aioredis - -# Tool configurations below here - -[flake8] -max-line-length = 119 - -[tool:pytest] -markers = - slow: marks tests as slow (deselect with '-m "not slow"') - real: tests that run only against a real redis server - fake: tests that run only against fakeredis and do not require a real redis - disconnected - decode_responses - min_server diff --git a/setup.py b/setup.py deleted file mode 100644 index b024da8..0000000 --- a/setup.py +++ /dev/null @@ -1,4 +0,0 @@ -from setuptools import setup - - -setup() diff --git a/test/conftest.py b/test/conftest.py deleted file mode 100644 index 8a51b47..0000000 --- a/test/conftest.py +++ /dev/null @@ -1,24 +0,0 @@ -import pytest -import redis - -import fakeredis - - -@pytest.fixture(scope="session") -def is_redis_running(): - try: - r = redis.StrictRedis('localhost', port=6379) - r.ping() - return True - except redis.ConnectionError: - return False - finally: - if hasattr(r, 'close'): - r.close() # Absent in older versions of redis-py - - -@pytest.fixture -def fake_server(request): - server = fakeredis.FakeServer() - server.connected = request.node.get_closest_marker('disconnected') is None - return server diff --git a/test/test_aioredis1.py b/test/test_aioredis1.py deleted file mode 100644 index fe4082a..0000000 --- a/test/test_aioredis1.py +++ /dev/null @@ -1,158 +0,0 @@ -import asyncio - -from packaging.version import Version -import pytest -import aioredis - -import fakeredis.aioredis - - -aioredis2 = Version(aioredis.__version__) >= Version('2.0.0a1') -pytestmark = [ - pytest.mark.asyncio, - pytest.mark.skipif(aioredis2, reason="Test is only applicable to aioredis 1.x") -] - - -@pytest.fixture( - params=[ - pytest.param('fake', marks=pytest.mark.fake), - pytest.param('real', marks=pytest.mark.real) - ] -) -async def r(request): - if request.param == 'fake': - ret = await fakeredis.aioredis.create_redis_pool() - else: - if not request.getfixturevalue('is_redis_running'): - pytest.skip('Redis is not running') - ret = await aioredis.create_redis_pool('redis://localhost') - await ret.flushall() - - yield ret - - await ret.flushall() - ret.close() - await ret.wait_closed() - - -@pytest.fixture -async def conn(r): - """A single connection, rather than a pool.""" - with await r as conn: - yield conn - - -async def test_ping(r): - pong = await r.ping() - assert pong == b'PONG' - - -async def test_types(r): - await r.hmset_dict('hash', key1='value1', key2='value2', key3=123) - result = await r.hgetall('hash', encoding='utf-8') - assert result == { - 'key1': 'value1', - 'key2': 'value2', - 'key3': '123' - } - - -async def test_transaction(r): - tr = r.multi_exec() - tr.set('key1', 'value1') - tr.set('key2', 'value2') - ok1, ok2 = await tr.execute() - assert ok1 - assert ok2 - result = await r.get('key1') - assert result == b'value1' - - -async def test_transaction_fail(r, conn): - # ensure that the WATCH applies to the same connection as the MULTI/EXEC. - await r.set('foo', '1') - await conn.watch('foo') - await conn.set('foo', '2') # Different connection - tr = conn.multi_exec() - tr.get('foo') - with pytest.raises(aioredis.MultiExecError): - await tr.execute() - - -async def test_pubsub(r, event_loop): - ch, = await r.subscribe('channel') - queue = asyncio.Queue() - - async def reader(channel): - async for message in ch.iter(): - queue.put_nowait(message) - - task = event_loop.create_task(reader(ch)) - await r.publish('channel', 'message1') - await r.publish('channel', 'message2') - result1 = await queue.get() - result2 = await queue.get() - assert result1 == b'message1' - assert result2 == b'message2' - ch.close() - await task - - -async def test_blocking_ready(r, conn): - """Blocking command which does not need to block.""" - await r.rpush('list', 'x') - result = await conn.blpop('list', timeout=1) - assert result == [b'list', b'x'] - - -@pytest.mark.slow -async def test_blocking_timeout(conn): - """Blocking command that times out without completing.""" - result = await conn.blpop('missing', timeout=1) - assert result is None - - -@pytest.mark.slow -async def test_blocking_unblock(r, conn, event_loop): - """Blocking command that gets unblocked after some time.""" - async def unblock(): - await asyncio.sleep(0.1) - await r.rpush('list', 'y') - - task = event_loop.create_task(unblock()) - result = await conn.blpop('list', timeout=1) - assert result == [b'list', b'y'] - await task - - -@pytest.mark.slow -async def test_blocking_pipeline(conn): - """Blocking command with another command issued behind it.""" - await conn.set('foo', 'bar') - fut = asyncio.ensure_future(conn.blpop('list', timeout=1)) - assert (await conn.get('foo')) == b'bar' - assert (await fut) is None - - -async def test_wrongtype_error(r): - await r.set('foo', 'bar') - with pytest.raises(aioredis.ReplyError, match='^WRONGTYPE'): - await r.rpush('foo', 'baz') - - -async def test_syntax_error(r): - with pytest.raises(aioredis.ReplyError, - match="^ERR wrong number of arguments for 'get' command$"): - await r.execute('get') - - -async def test_no_script_error(r): - with pytest.raises(aioredis.ReplyError, match='^NOSCRIPT '): - await r.evalsha('0123456789abcdef0123456789abcdef') - - -async def test_failed_script_error(r): - await r.set('foo', 'bar') - with pytest.raises(aioredis.ReplyError, match='^ERR Error running script'): - await r.eval('return redis.call("ZCOUNT", KEYS[1])', ['foo']) diff --git a/test/test_aioredis2.py b/test/test_aioredis2.py deleted file mode 100644 index ac21443..0000000 --- a/test/test_aioredis2.py +++ /dev/null @@ -1,252 +0,0 @@ -import asyncio -import re - -from packaging.version import Version -import pytest -import aioredis -import async_timeout - -import fakeredis.aioredis - - -aioredis2 = Version(aioredis.__version__) >= Version('2.0.0a1') -pytestmark = [ - pytest.mark.asyncio, - pytest.mark.skipif(not aioredis2, reason="Test is only applicable to aioredis 2.x") -] -fake_only = pytest.mark.parametrize( - 'r', - [pytest.param('fake', marks=pytest.mark.fake)], - indirect=True -) - - -@pytest.fixture( - params=[ - pytest.param('fake', marks=pytest.mark.fake), - pytest.param('real', marks=pytest.mark.real) - ] -) -async def r(request): - if request.param == 'fake': - fake_server = request.getfixturevalue('fake_server') - ret = fakeredis.aioredis.FakeRedis(server=fake_server) - else: - if not request.getfixturevalue('is_redis_running'): - pytest.skip('Redis is not running') - ret = aioredis.Redis() - fake_server = None - if not fake_server or fake_server.connected: - await ret.flushall() - - yield ret - - if not fake_server or fake_server.connected: - await ret.flushall() - await ret.connection_pool.disconnect() - - -@pytest.fixture -async def conn(r): - """A single connection, rather than a pool.""" - async with r.client() as conn: - yield conn - - -async def test_ping(r): - pong = await r.ping() - assert pong is True - - -async def test_types(r): - await r.hset('hash', mapping={'key1': 'value1', 'key2': 'value2', 'key3': 123}) - result = await r.hgetall('hash') - assert result == { - b'key1': b'value1', - b'key2': b'value2', - b'key3': b'123' - } - - -async def test_transaction(r): - async with r.pipeline(transaction=True) as tr: - tr.set('key1', 'value1') - tr.set('key2', 'value2') - ok1, ok2 = await tr.execute() - assert ok1 - assert ok2 - result = await r.get('key1') - assert result == b'value1' - - -async def test_transaction_fail(r): - await r.set('foo', '1') - async with r.pipeline(transaction=True) as tr: - await tr.watch('foo') - await r.set('foo', '2') # Different connection - tr.multi() - tr.get('foo') - with pytest.raises(aioredis.exceptions.WatchError): - await tr.execute() - - -async def test_pubsub(r, event_loop): - queue = asyncio.Queue() - - async def reader(ps): - while True: - message = await ps.get_message(ignore_subscribe_messages=True, timeout=5) - if message is not None: - if message.get('data') == b'stop': - break - queue.put_nowait(message) - - async with async_timeout.timeout(5), r.pubsub() as ps: - await ps.subscribe('channel') - task = event_loop.create_task(reader(ps)) - await r.publish('channel', 'message1') - await r.publish('channel', 'message2') - result1 = await queue.get() - result2 = await queue.get() - assert result1 == { - 'channel': b'channel', - 'pattern': None, - 'type': 'message', - 'data': b'message1' - } - assert result2 == { - 'channel': b'channel', - 'pattern': None, - 'type': 'message', - 'data': b'message2' - } - await r.publish('channel', 'stop') - await task - - -@pytest.mark.slow -async def test_pubsub_timeout(r): - async with r.pubsub() as ps: - await ps.subscribe('channel') - await ps.get_message(timeout=0.5) # Subscription message - message = await ps.get_message(timeout=0.5) - assert message is None - - -@pytest.mark.slow -async def test_pubsub_disconnect(r): - async with r.pubsub() as ps: - await ps.subscribe('channel') - await ps.connection.disconnect() - message = await ps.get_message(timeout=0.5) # Subscription message - assert message is not None - message = await ps.get_message(timeout=0.5) - assert message is None - - -async def test_blocking_ready(r, conn): - """Blocking command which does not need to block.""" - await r.rpush('list', 'x') - result = await conn.blpop('list', timeout=1) - assert result == (b'list', b'x') - - -@pytest.mark.slow -async def test_blocking_timeout(conn): - """Blocking command that times out without completing.""" - result = await conn.blpop('missing', timeout=1) - assert result is None - - -@pytest.mark.slow -async def test_blocking_unblock(r, conn, event_loop): - """Blocking command that gets unblocked after some time.""" - async def unblock(): - await asyncio.sleep(0.1) - await r.rpush('list', 'y') - - task = event_loop.create_task(unblock()) - result = await conn.blpop('list', timeout=1) - assert result == (b'list', b'y') - await task - - -async def test_wrongtype_error(r): - await r.set('foo', 'bar') - with pytest.raises(aioredis.ResponseError, match='^WRONGTYPE'): - await r.rpush('foo', 'baz') - - -async def test_syntax_error(r): - with pytest.raises(aioredis.ResponseError, - match="^wrong number of arguments for 'get' command$"): - await r.execute_command('get') - - -async def test_no_script_error(r): - with pytest.raises(aioredis.exceptions.NoScriptError): - await r.evalsha('0123456789abcdef0123456789abcdef', 0) - - -async def test_failed_script_error(r): - await r.set('foo', 'bar') - with pytest.raises(aioredis.ResponseError, match='^Error running script'): - await r.eval('return redis.call("ZCOUNT", KEYS[1])', 1, 'foo') - - -@fake_only -def test_repr(r): - assert re.fullmatch( - r'ConnectionPool,db=0>>', - repr(r.connection_pool) - ) - - -@fake_only -@pytest.mark.disconnected -async def test_not_connected(r): - with pytest.raises(aioredis.ConnectionError): - await r.ping() - - -@fake_only -async def test_disconnect_server(r, fake_server): - await r.ping() - fake_server.connected = False - with pytest.raises(aioredis.ConnectionError): - await r.ping() - fake_server.connected = True - - -@pytest.mark.fake -async def test_from_url(): - r0 = fakeredis.aioredis.FakeRedis.from_url('/service/redis://localhost?db=0') - r1 = fakeredis.aioredis.FakeRedis.from_url('/service/redis://localhost?db=1') - # Check that they are indeed different databases - await r0.set('foo', 'a') - await r1.set('foo', 'b') - assert await r0.get('foo') == b'a' - assert await r1.get('foo') == b'b' - await r0.connection_pool.disconnect() - await r1.connection_pool.disconnect() - - -@fake_only -async def test_from_url_with_server(r, fake_server): - r2 = fakeredis.aioredis.FakeRedis.from_url('redis://localhost', server=fake_server) - await r.set('foo', 'bar') - assert await r2.get('foo') == b'bar' - await r2.connection_pool.disconnect() - - -@pytest.mark.fake -async def test_without_server(): - r = fakeredis.aioredis.FakeRedis() - assert await r.ping() - - -@pytest.mark.fake -async def test_without_server_disconnected(): - r = fakeredis.aioredis.FakeRedis(connected=False) - with pytest.raises(aioredis.ConnectionError): - await r.ping() diff --git a/test/test_fakeredis.py b/test/test_fakeredis.py deleted file mode 100644 index 81dc123..0000000 --- a/test/test_fakeredis.py +++ /dev/null @@ -1,5592 +0,0 @@ -from time import sleep, time -from redis.exceptions import ResponseError -from collections import OrderedDict -import os -import math -import threading -import logging -from queue import Queue - -import six -from packaging.version import Version -import pytest -import redis -import redis.client - -import fakeredis -from datetime import datetime, timedelta - - -REDIS_VERSION = Version(redis.__version__) -REDIS3 = REDIS_VERSION >= Version('3') - - -redis2_only = pytest.mark.skipif(REDIS3, reason="Test is only applicable to redis-py 2.x") -redis3_only = pytest.mark.skipif(not REDIS3, reason="Test is only applicable to redis-py 3.x") -fake_only = pytest.mark.parametrize( - 'create_redis', - [pytest.param('FakeStrictRedis', marks=pytest.mark.fake)], - indirect=True -) - - -def key_val_dict(size=100): - return {b'key:' + bytes([i]): b'val:' + bytes([i]) - for i in range(size)} - - -def round_str(x): - assert isinstance(x, bytes) - return round(float(x)) - - -def raw_command(r, *args): - """Like execute_command, but does not do command-specific response parsing""" - response_callbacks = r.response_callbacks - try: - r.response_callbacks = {} - return r.execute_command(*args) - finally: - r.response_callbacks = response_callbacks - - -# Wrap some redis commands to abstract differences between redis-py 2 and 3. -def zadd(r, key, d, *args, **kwargs): - if REDIS3: - return r.zadd(key, d, *args, **kwargs) - else: - return r.zadd(key, **d) - - -def zincrby(r, key, amount, value): - if REDIS3: - return r.zincrby(key, amount, value) - else: - return r.zincrby(key, value, amount) - - -@pytest.fixture(scope="session") -def is_redis_running(): - try: - r = redis.StrictRedis('localhost', port=6379) - r.ping() - except redis.ConnectionError: - return False - else: - return True - - -@pytest.fixture( - params=[ - pytest.param('StrictRedis', marks=pytest.mark.real), - pytest.param('FakeStrictRedis', marks=pytest.mark.fake) - ] -) -def create_redis(request): - name = request.param - if not name.startswith('Fake') and not request.getfixturevalue('is_redis_running'): - pytest.skip('Redis is not running') - decode_responses = request.node.get_closest_marker('decode_responses') is not None - - def factory(db=0): - if name.startswith('Fake'): - fake_server = request.getfixturevalue('fake_server') - cls = getattr(fakeredis, name) - return cls(db=db, decode_responses=decode_responses, server=fake_server) - else: - cls = getattr(redis, name) - conn = cls('localhost', port=6379, db=db, decode_responses=decode_responses) - min_server_marker = request.node.get_closest_marker('min_server') - if min_server_marker is not None: - server_version = conn.info()['redis_version'] - min_version = Version(min_server_marker.args[0]) - if Version(server_version) < min_version: - pytest.skip( - 'Redis server {} required but {} found'.format(min_version, server_version) - ) - return conn - - return factory - - -@pytest.fixture -def r(request, create_redis): - r = create_redis(db=0) - connected = request.node.get_closest_marker('disconnected') is None - if connected: - r.flushall() - yield r - if connected: - r.flushall() - if hasattr(r, 'close'): - r.close() # Older versions of redis-py don't have this method - - -def test_large_command(r): - r.set('foo', 'bar' * 10000) - assert r.get('foo') == b'bar' * 10000 - - -def test_dbsize(r): - assert r.dbsize() == 0 - r.set('foo', 'bar') - r.set('bar', 'foo') - assert r.dbsize() == 2 - - -def test_flushdb(r): - r.set('foo', 'bar') - assert r.keys() == [b'foo'] - assert r.flushdb() is True - assert r.keys() == [] - - -def test_dump_missing(r): - assert r.dump('foo') is None - - -def test_dump_restore(r): - r.set('foo', 'bar') - dump = r.dump('foo') - r.restore('baz', 0, dump) - assert r.get('baz') == b'bar' - assert r.ttl('baz') == -1 - - -def test_dump_restore_ttl(r): - r.set('foo', 'bar') - dump = r.dump('foo') - r.restore('baz', 2000, dump) - assert r.get('baz') == b'bar' - assert 1000 <= r.pttl('baz') <= 2000 - - -def test_dump_restore_replace(r): - r.set('foo', 'bar') - dump = r.dump('foo') - r.set('foo', 'baz') - r.restore('foo', 0, dump, replace=True) - assert r.get('foo') == b'bar' - - -def test_restore_exists(r): - r.set('foo', 'bar') - dump = r.dump('foo') - with pytest.raises(ResponseError): - r.restore('foo', 0, dump) - - -def test_restore_invalid_dump(r): - r.set('foo', 'bar') - dump = r.dump('foo') - with pytest.raises(ResponseError): - r.restore('baz', 0, dump[:-1]) - - -def test_restore_invalid_ttl(r): - r.set('foo', 'bar') - dump = r.dump('foo') - with pytest.raises(ResponseError): - r.restore('baz', -1, dump) - - -def test_set_then_get(r): - assert r.set('foo', 'bar') is True - assert r.get('foo') == b'bar' - - -@redis2_only -def test_set_None_value(r): - assert r.set('foo', None) is True - assert r.get('foo') == b'None' - - -def test_set_float_value(r): - x = 1.23456789123456789 - r.set('foo', x) - assert float(r.get('foo')) == x - - -def test_saving_non_ascii_chars_as_value(r): - assert r.set('foo', 'Ñandu') is True - assert r.get('foo') == 'Ñandu'.encode() - - -def test_saving_unicode_type_as_value(r): - assert r.set('foo', 'Ñandu') is True - assert r.get('foo') == 'Ñandu'.encode() - - -def test_saving_non_ascii_chars_as_key(r): - assert r.set('Ñandu', 'foo') is True - assert r.get('Ñandu') == b'foo' - - -def test_saving_unicode_type_as_key(r): - assert r.set('Ñandu', 'foo') is True - assert r.get('Ñandu') == b'foo' - - -def test_future_newbytes(r): - bytes = pytest.importorskip('builtins', reason='future.types not available').bytes - r.set(bytes(b'\xc3\x91andu'), 'foo') - assert r.get('Ñandu') == b'foo' - - -def test_future_newstr(r): - str = pytest.importorskip('builtins', reason='future.types not available').str - r.set(str('Ñandu'), 'foo') - assert r.get('Ñandu') == b'foo' - - -def test_get_does_not_exist(r): - assert r.get('foo') is None - - -def test_get_with_non_str_keys(r): - assert r.set('2', 'bar') is True - assert r.get(2) == b'bar' - - -def test_get_invalid_type(r): - assert r.hset('foo', 'key', 'value') == 1 - with pytest.raises(redis.ResponseError): - r.get('foo') - - -def test_set_non_str_keys(r): - assert r.set(2, 'bar') is True - assert r.get(2) == b'bar' - assert r.get('2') == b'bar' - - -def test_getbit(r): - r.setbit('foo', 3, 1) - assert r.getbit('foo', 0) == 0 - assert r.getbit('foo', 1) == 0 - assert r.getbit('foo', 2) == 0 - assert r.getbit('foo', 3) == 1 - assert r.getbit('foo', 4) == 0 - assert r.getbit('foo', 100) == 0 - - -def test_getbit_wrong_type(r): - r.rpush('foo', b'x') - with pytest.raises(redis.ResponseError): - r.getbit('foo', 1) - - -def test_multiple_bits_set(r): - r.setbit('foo', 1, 1) - r.setbit('foo', 3, 1) - r.setbit('foo', 5, 1) - - assert r.getbit('foo', 0) == 0 - assert r.getbit('foo', 1) == 1 - assert r.getbit('foo', 2) == 0 - assert r.getbit('foo', 3) == 1 - assert r.getbit('foo', 4) == 0 - assert r.getbit('foo', 5) == 1 - assert r.getbit('foo', 6) == 0 - - -def test_unset_bits(r): - r.setbit('foo', 1, 1) - r.setbit('foo', 2, 0) - r.setbit('foo', 3, 1) - assert r.getbit('foo', 1) == 1 - r.setbit('foo', 1, 0) - assert r.getbit('foo', 1) == 0 - r.setbit('foo', 3, 0) - assert r.getbit('foo', 3) == 0 - - -def test_get_set_bits(r): - # set bit 5 - assert not r.setbit('a', 5, True) - assert r.getbit('a', 5) - # unset bit 4 - assert not r.setbit('a', 4, False) - assert not r.getbit('a', 4) - # set bit 4 - assert not r.setbit('a', 4, True) - assert r.getbit('a', 4) - # set bit 5 again - assert r.setbit('a', 5, True) - assert r.getbit('a', 5) - - -def test_setbits_and_getkeys(r): - # The bit operations and the get commands - # should play nicely with each other. - r.setbit('foo', 1, 1) - assert r.get('foo') == b'@' - r.setbit('foo', 2, 1) - assert r.get('foo') == b'`' - r.setbit('foo', 3, 1) - assert r.get('foo') == b'p' - r.setbit('foo', 9, 1) - assert r.get('foo') == b'p@' - r.setbit('foo', 54, 1) - assert r.get('foo') == b'p@\x00\x00\x00\x00\x02' - - -def test_setbit_wrong_type(r): - r.rpush('foo', b'x') - with pytest.raises(redis.ResponseError): - r.setbit('foo', 0, 1) - - -def test_setbit_expiry(r): - r.set('foo', b'0x00', ex=10) - r.setbit('foo', 1, 1) - assert r.ttl('foo') > 0 - - -def test_bitcount(r): - r.delete('foo') - assert r.bitcount('foo') == 0 - r.setbit('foo', 1, 1) - assert r.bitcount('foo') == 1 - r.setbit('foo', 8, 1) - assert r.bitcount('foo') == 2 - assert r.bitcount('foo', 1, 1) == 1 - r.setbit('foo', 57, 1) - assert r.bitcount('foo') == 3 - r.set('foo', ' ') - assert r.bitcount('foo') == 1 - - -def test_bitcount_wrong_type(r): - r.rpush('foo', b'x') - with pytest.raises(redis.ResponseError): - r.bitcount('foo') - - -def test_getset_not_exist(r): - val = r.getset('foo', 'bar') - assert val is None - assert r.get('foo') == b'bar' - - -def test_getset_exists(r): - r.set('foo', 'bar') - val = r.getset('foo', b'baz') - assert val == b'bar' - val = r.getset('foo', b'baz2') - assert val == b'baz' - - -def test_getset_wrong_type(r): - r.rpush('foo', b'x') - with pytest.raises(redis.ResponseError): - r.getset('foo', 'bar') - - -def test_setitem_getitem(r): - assert r.keys() == [] - r['foo'] = 'bar' - assert r['foo'] == b'bar' - - -def test_getitem_non_existent_key(r): - assert r.keys() == [] - with pytest.raises(KeyError): - r['noexists'] - - -def test_strlen(r): - r['foo'] = 'bar' - - assert r.strlen('foo') == 3 - assert r.strlen('noexists') == 0 - - -def test_strlen_wrong_type(r): - r.rpush('foo', b'x') - with pytest.raises(redis.ResponseError): - r.strlen('foo') - - -def test_substr(r): - r['foo'] = 'one_two_three' - assert r.substr('foo', 0) == b'one_two_three' - assert r.substr('foo', 0, 2) == b'one' - assert r.substr('foo', 4, 6) == b'two' - assert r.substr('foo', -5) == b'three' - assert r.substr('foo', -4, -5) == b'' - assert r.substr('foo', -5, -3) == b'thr' - - -def test_substr_noexist_key(r): - assert r.substr('foo', 0) == b'' - assert r.substr('foo', 10) == b'' - assert r.substr('foo', -5, -1) == b'' - - -def test_substr_wrong_type(r): - r.rpush('foo', b'x') - with pytest.raises(redis.ResponseError): - r.substr('foo', 0) - - -def test_append(r): - assert r.set('foo', 'bar') - assert r.append('foo', 'baz') == 6 - assert r.get('foo') == b'barbaz' - - -def test_append_with_no_preexisting_key(r): - assert r.append('foo', 'bar') == 3 - assert r.get('foo') == b'bar' - - -def test_append_wrong_type(r): - r.rpush('foo', b'x') - with pytest.raises(redis.ResponseError): - r.append('foo', b'x') - - -def test_incr_with_no_preexisting_key(r): - assert r.incr('foo') == 1 - assert r.incr('bar', 2) == 2 - - -def test_incr_by(r): - assert r.incrby('foo') == 1 - assert r.incrby('bar', 2) == 2 - - -def test_incr_preexisting_key(r): - r.set('foo', 15) - assert r.incr('foo', 5) == 20 - assert r.get('foo') == b'20' - - -def test_incr_expiry(r): - r.set('foo', 15, ex=10) - r.incr('foo', 5) - assert r.ttl('foo') > 0 - - -def test_incr_bad_type(r): - r.set('foo', 'bar') - with pytest.raises(redis.ResponseError): - r.incr('foo', 15) - r.rpush('foo2', 1) - with pytest.raises(redis.ResponseError): - r.incr('foo2', 15) - - -def test_incr_with_float(r): - with pytest.raises(redis.ResponseError): - r.incr('foo', 2.0) - - -def test_incr_followed_by_mget(r): - r.set('foo', 15) - assert r.incr('foo', 5) == 20 - assert r.get('foo') == b'20' - - -def test_incr_followed_by_mget_returns_strings(r): - r.incr('foo', 1) - assert r.mget(['foo']) == [b'1'] - - -def test_incrbyfloat(r): - r.set('foo', 0) - assert r.incrbyfloat('foo', 1.0) == 1.0 - assert r.incrbyfloat('foo', 1.0) == 2.0 - - -def test_incrbyfloat_with_noexist(r): - assert r.incrbyfloat('foo', 1.0) == 1.0 - assert r.incrbyfloat('foo', 1.0) == 2.0 - - -def test_incrbyfloat_expiry(r): - r.set('foo', 1.5, ex=10) - r.incrbyfloat('foo', 2.5) - assert r.ttl('foo') > 0 - - -def test_incrbyfloat_bad_type(r): - r.set('foo', 'bar') - with pytest.raises(redis.ResponseError, match='not a valid float'): - r.incrbyfloat('foo', 1.0) - r.rpush('foo2', 1) - with pytest.raises(redis.ResponseError): - r.incrbyfloat('foo2', 1.0) - - -def test_incrbyfloat_precision(r): - x = 1.23456789123456789 - assert r.incrbyfloat('foo', x) == x - assert float(r.get('foo')) == x - - -def test_decr(r): - r.set('foo', 10) - assert r.decr('foo') == 9 - assert r.get('foo') == b'9' - - -def test_decr_newkey(r): - r.decr('foo') - assert r.get('foo') == b'-1' - - -def test_decr_expiry(r): - r.set('foo', 10, ex=10) - r.decr('foo', 5) - assert r.ttl('foo') > 0 - - -def test_decr_badtype(r): - r.set('foo', 'bar') - with pytest.raises(redis.ResponseError): - r.decr('foo', 15) - r.rpush('foo2', 1) - with pytest.raises(redis.ResponseError): - r.decr('foo2', 15) - - -def test_keys(r): - r.set('', 'empty') - r.set('abc\n', '') - r.set('abc\\', '') - r.set('abcde', '') - r.set(b'\xfe\xcd', '') - assert sorted(r.keys()) == [b'', b'abc\n', b'abc\\', b'abcde', b'\xfe\xcd'] - assert r.keys('??') == [b'\xfe\xcd'] - # empty pattern not the same as no pattern - assert r.keys('') == [b''] - # ? must match \n - assert sorted(r.keys('abc?')) == [b'abc\n', b'abc\\'] - # must be anchored at both ends - assert r.keys('abc') == [] - assert r.keys('bcd') == [] - # wildcard test - assert r.keys('a*de') == [b'abcde'] - # positive groups - assert sorted(r.keys('abc[d\n]*')) == [b'abc\n', b'abcde'] - assert r.keys('abc[c-e]?') == [b'abcde'] - assert r.keys('abc[e-c]?') == [b'abcde'] - assert r.keys('abc[e-e]?') == [] - assert r.keys('abcd[ef') == [b'abcde'] - assert r.keys('abcd[]') == [] - # negative groups - assert r.keys('abc[^d\\\\]*') == [b'abc\n'] - assert r.keys('abc[^]e') == [b'abcde'] - # escaping - assert r.keys(r'abc\?e') == [] - assert r.keys(r'abc\de') == [b'abcde'] - assert r.keys(r'abc[\d]e') == [b'abcde'] - # some escaping cases that redis handles strangely - assert r.keys('abc\\') == [b'abc\\'] - assert r.keys(r'abc[\c-e]e') == [] - assert r.keys(r'abc[c-\e]e') == [] - - -def test_exists(r): - assert 'foo' not in r - r.set('foo', 'bar') - assert 'foo' in r - - -def test_contains(r): - assert not r.exists('foo') - r.set('foo', 'bar') - assert r.exists('foo') - - -def test_rename(r): - r.set('foo', 'unique value') - assert r.rename('foo', 'bar') - assert r.get('foo') is None - assert r.get('bar') == b'unique value' - - -def test_rename_nonexistent_key(r): - with pytest.raises(redis.ResponseError): - r.rename('foo', 'bar') - - -def test_renamenx_doesnt_exist(r): - r.set('foo', 'unique value') - assert r.renamenx('foo', 'bar') - assert r.get('foo') is None - assert r.get('bar') == b'unique value' - - -def test_rename_does_exist(r): - r.set('foo', 'unique value') - r.set('bar', 'unique value2') - assert not r.renamenx('foo', 'bar') - assert r.get('foo') == b'unique value' - assert r.get('bar') == b'unique value2' - - -def test_rename_expiry(r): - r.set('foo', 'value1', ex=10) - r.set('bar', 'value2') - r.rename('foo', 'bar') - assert r.ttl('bar') > 0 - - -def test_mget(r): - r.set('foo', 'one') - r.set('bar', 'two') - assert r.mget(['foo', 'bar']) == [b'one', b'two'] - assert r.mget(['foo', 'bar', 'baz']) == [b'one', b'two', None] - assert r.mget('foo', 'bar') == [b'one', b'two'] - - -@redis2_only -def test_mget_none(r): - r.set('foo', 'one') - r.set('bar', 'two') - assert r.mget('foo', 'bar', None) == [b'one', b'two', None] - - -def test_mget_with_no_keys(r): - if REDIS3: - assert r.mget([]) == [] - else: - with pytest.raises(redis.ResponseError, match='wrong number of arguments'): - r.mget([]) - - -def test_mget_mixed_types(r): - r.hset('hash', 'bar', 'baz') - zadd(r, 'zset', {'bar': 1}) - r.sadd('set', 'member') - r.rpush('list', 'item1') - r.set('string', 'value') - assert ( - r.mget(['hash', 'zset', 'set', 'string', 'absent']) - == [None, None, None, b'value', None] - ) - - -def test_mset_with_no_keys(r): - with pytest.raises(redis.ResponseError): - r.mset({}) - - -def test_mset(r): - assert r.mset({'foo': 'one', 'bar': 'two'}) is True - assert r.mset({'foo': 'one', 'bar': 'two'}) is True - assert r.mget('foo', 'bar') == [b'one', b'two'] - - -@redis2_only -def test_mset_accepts_kwargs(r): - assert r.mset(foo='one', bar='two') is True - assert r.mset(foo='one', baz='three') is True - assert r.mget('foo', 'bar', 'baz') == [b'one', b'two', b'three'] - - -def test_msetnx(r): - assert r.msetnx({'foo': 'one', 'bar': 'two'}) is True - assert r.msetnx({'bar': 'two', 'baz': 'three'}) is False - assert r.mget('foo', 'bar', 'baz') == [b'one', b'two', None] - - -def test_setex(r): - assert r.setex('foo', 100, 'bar') is True - assert r.get('foo') == b'bar' - - -def test_setex_using_timedelta(r): - assert r.setex('foo', timedelta(seconds=100), 'bar') is True - assert r.get('foo') == b'bar' - - -def test_setex_using_float(r): - with pytest.raises(redis.ResponseError, match='integer'): - r.setex('foo', 1.2, 'bar') - - -@pytest.mark.min_server('6.2') -def test_setex_overflow(r): - with pytest.raises(ResponseError): - r.setex('foo', 18446744073709561, 'bar') # Overflows long long in ms - - -def test_set_ex(r): - assert r.set('foo', 'bar', ex=100) is True - assert r.get('foo') == b'bar' - - -def test_set_ex_using_timedelta(r): - assert r.set('foo', 'bar', ex=timedelta(seconds=100)) is True - assert r.get('foo') == b'bar' - - -def test_set_ex_overflow(r): - with pytest.raises(ResponseError): - r.set('foo', 'bar', ex=18446744073709561) # Overflows long long in ms - - -def test_set_px_overflow(r): - with pytest.raises(ResponseError): - r.set('foo', 'bar', px=2**63 - 2) # Overflows after adding current time - - -def test_set_px(r): - assert r.set('foo', 'bar', px=100) is True - assert r.get('foo') == b'bar' - - -def test_set_px_using_timedelta(r): - assert r.set('foo', 'bar', px=timedelta(milliseconds=100)) is True - assert r.get('foo') == b'bar' - - -@pytest.mark.skipif(REDIS_VERSION < Version('3.5'), reason="Test is only applicable to redis-py 3.5+") -@pytest.mark.min_server('6.0') -def test_set_keepttl(r): - r.set('foo', 'bar', ex=100) - assert r.set('foo', 'baz', keepttl=True) is True - assert r.ttl('foo') == 100 - assert r.get('foo') == b'baz' - - -def test_set_conflicting_expire_options(r): - with pytest.raises(ResponseError): - r.set('foo', 'bar', ex=1, px=1) - - -@pytest.mark.skipif(REDIS_VERSION < Version('3.5'), reason="Test is only applicable to redis-py 3.5+") -def test_set_conflicting_expire_options_w_keepttl(r): - with pytest.raises(ResponseError): - r.set('foo', 'bar', ex=1, keepttl=True) - with pytest.raises(ResponseError): - r.set('foo', 'bar', px=1, keepttl=True) - with pytest.raises(ResponseError): - r.set('foo', 'bar', ex=1, px=1, keepttl=True) - - -def test_set_raises_wrong_ex(r): - with pytest.raises(ResponseError): - r.set('foo', 'bar', ex=-100) - with pytest.raises(ResponseError): - r.set('foo', 'bar', ex=0) - assert not r.exists('foo') - - -def test_set_using_timedelta_raises_wrong_ex(r): - with pytest.raises(ResponseError): - r.set('foo', 'bar', ex=timedelta(seconds=-100)) - with pytest.raises(ResponseError): - r.set('foo', 'bar', ex=timedelta(seconds=0)) - assert not r.exists('foo') - - -def test_set_raises_wrong_px(r): - with pytest.raises(ResponseError): - r.set('foo', 'bar', px=-100) - with pytest.raises(ResponseError): - r.set('foo', 'bar', px=0) - assert not r.exists('foo') - - -def test_set_using_timedelta_raises_wrong_px(r): - with pytest.raises(ResponseError): - r.set('foo', 'bar', px=timedelta(milliseconds=-100)) - with pytest.raises(ResponseError): - r.set('foo', 'bar', px=timedelta(milliseconds=0)) - assert not r.exists('foo') - - -def test_setex_raises_wrong_ex(r): - with pytest.raises(ResponseError): - r.setex('foo', -100, 'bar') - with pytest.raises(ResponseError): - r.setex('foo', 0, 'bar') - assert not r.exists('foo') - - -def test_setex_using_timedelta_raises_wrong_ex(r): - with pytest.raises(ResponseError): - r.setex('foo', timedelta(seconds=-100), 'bar') - with pytest.raises(ResponseError): - r.setex('foo', timedelta(seconds=-100), 'bar') - assert not r.exists('foo') - - -def test_setnx(r): - assert r.setnx('foo', 'bar') is True - assert r.get('foo') == b'bar' - assert r.setnx('foo', 'baz') is False - assert r.get('foo') == b'bar' - - -def test_set_nx(r): - assert r.set('foo', 'bar', nx=True) is True - assert r.get('foo') == b'bar' - assert r.set('foo', 'bar', nx=True) is None - assert r.get('foo') == b'bar' - - -def test_set_xx(r): - assert r.set('foo', 'bar', xx=True) is None - r.set('foo', 'bar') - assert r.set('foo', 'bar', xx=True) is True - - -@pytest.mark.min_server('6.2') -def test_set_get(r): - assert raw_command(r, 'set', 'foo', 'bar', 'GET') is None - assert r.get('foo') == b'bar' - assert raw_command(r, 'set', 'foo', 'baz', 'GET') == b'bar' - assert r.get('foo') == b'baz' - - -@pytest.mark.min_server('6.2') -def test_set_get_xx(r): - assert raw_command(r, 'set', 'foo', 'bar', 'XX', 'GET') is None - assert r.get('foo') is None - r.set('foo', 'bar') - assert raw_command(r, 'set', 'foo', 'baz', 'XX', 'GET') == b'bar' - assert r.get('foo') == b'baz' - assert raw_command(r, 'set', 'foo', 'baz', 'GET') == b'baz' - - -@pytest.mark.min_server('6.2') -def test_set_get_nx(r): - # Note: this will most likely fail on a 7.0 server, based on the docs for SET - with pytest.raises(redis.ResponseError): - raw_command(r, 'set', 'foo', 'bar', 'NX', 'GET') - - -@pytest.mark.min_server('6.2') -def set_get_wrongtype(r): - r.lpush('foo', 'bar') - with pytest.raises(redis.ResponseError): - raw_command(r, 'set', 'foo', 'bar', 'GET') - - -def test_del_operator(r): - r['foo'] = 'bar' - del r['foo'] - assert r.get('foo') is None - - -def test_delete(r): - r['foo'] = 'bar' - assert r.delete('foo') == 1 - assert r.get('foo') is None - - -def test_echo(r): - assert r.echo(b'hello') == b'hello' - assert r.echo('hello') == b'hello' - - -@pytest.mark.slow -def test_delete_expire(r): - r.set("foo", "bar", ex=1) - r.delete("foo") - r.set("foo", "bar") - sleep(2) - assert r.get("foo") == b'bar' - - -def test_delete_multiple(r): - r['one'] = 'one' - r['two'] = 'two' - r['three'] = 'three' - # Since redis>=2.7.6 returns number of deleted items. - assert r.delete('one', 'two') == 2 - assert r.get('one') is None - assert r.get('two') is None - assert r.get('three') == b'three' - assert r.delete('one', 'two') == 0 - # If any keys are deleted, True is returned. - assert r.delete('two', 'three', 'three') == 1 - assert r.get('three') is None - - -def test_delete_nonexistent_key(r): - assert r.delete('foo') == 0 - - -# Tests for the list type. - -@redis2_only -def test_rpush_then_lrange_with_nested_list1(r): - assert r.rpush('foo', [12345, 6789]) == 1 - assert r.rpush('foo', [54321, 9876]) == 2 - assert r.lrange('foo', 0, -1) == [b'[12345, 6789]', b'[54321, 9876]'] - - -@redis2_only -def test_rpush_then_lrange_with_nested_list2(r): - assert r.rpush('foo', [12345, 'banana']) == 1 - assert r.rpush('foo', [54321, 'elephant']) == 2 - assert r.lrange('foo', 0, -1), [b'[12345, \'banana\']', b'[54321, \'elephant\']'] - - -@redis2_only -def test_rpush_then_lrange_with_nested_list3(r): - assert r.rpush('foo', [12345, []]) == 1 - assert r.rpush('foo', [54321, []]) == 2 - assert r.lrange('foo', 0, -1) == [b'[12345, []]', b'[54321, []]'] - - -def test_lpush_then_lrange_all(r): - assert r.lpush('foo', 'bar') == 1 - assert r.lpush('foo', 'baz') == 2 - assert r.lpush('foo', 'bam', 'buzz') == 4 - assert r.lrange('foo', 0, -1) == [b'buzz', b'bam', b'baz', b'bar'] - - -def test_lpush_then_lrange_portion(r): - r.lpush('foo', 'one') - r.lpush('foo', 'two') - r.lpush('foo', 'three') - r.lpush('foo', 'four') - assert r.lrange('foo', 0, 2) == [b'four', b'three', b'two'] - assert r.lrange('foo', 0, 3) == [b'four', b'three', b'two', b'one'] - - -def test_lrange_negative_indices(r): - r.rpush('foo', 'a', 'b', 'c') - assert r.lrange('foo', -1, -2) == [] - assert r.lrange('foo', -2, -1) == [b'b', b'c'] - - -def test_lpush_key_does_not_exist(r): - assert r.lrange('foo', 0, -1) == [] - - -def test_lpush_with_nonstr_key(r): - r.lpush(1, 'one') - r.lpush(1, 'two') - r.lpush(1, 'three') - assert r.lrange(1, 0, 2) == [b'three', b'two', b'one'] - assert r.lrange('1', 0, 2) == [b'three', b'two', b'one'] - - -def test_lpush_wrong_type(r): - r.set('foo', 'bar') - with pytest.raises(redis.ResponseError): - r.lpush('foo', 'element') - - -def test_llen(r): - r.lpush('foo', 'one') - r.lpush('foo', 'two') - r.lpush('foo', 'three') - assert r.llen('foo') == 3 - - -def test_llen_no_exist(r): - assert r.llen('foo') == 0 - - -def test_llen_wrong_type(r): - r.set('foo', 'bar') - with pytest.raises(redis.ResponseError): - r.llen('foo') - - -def test_lrem_positive_count(r): - r.lpush('foo', 'same') - r.lpush('foo', 'same') - r.lpush('foo', 'different') - r.lrem('foo', 2, 'same') - assert r.lrange('foo', 0, -1) == [b'different'] - - -def test_lrem_negative_count(r): - r.lpush('foo', 'removeme') - r.lpush('foo', 'three') - r.lpush('foo', 'two') - r.lpush('foo', 'one') - r.lpush('foo', 'removeme') - r.lrem('foo', -1, 'removeme') - # Should remove it from the end of the list, - # leaving the 'removeme' from the front of the list alone. - assert r.lrange('foo', 0, -1) == [b'removeme', b'one', b'two', b'three'] - - -def test_lrem_zero_count(r): - r.lpush('foo', 'one') - r.lpush('foo', 'one') - r.lpush('foo', 'one') - r.lrem('foo', 0, 'one') - assert r.lrange('foo', 0, -1) == [] - - -def test_lrem_default_value(r): - r.lpush('foo', 'one') - r.lpush('foo', 'one') - r.lpush('foo', 'one') - r.lrem('foo', 0, 'one') - assert r.lrange('foo', 0, -1) == [] - - -def test_lrem_does_not_exist(r): - r.lpush('foo', 'one') - r.lrem('foo', 0, 'one') - # These should be noops. - r.lrem('foo', -2, 'one') - r.lrem('foo', 2, 'one') - - -def test_lrem_return_value(r): - r.lpush('foo', 'one') - count = r.lrem('foo', 0, 'one') - assert count == 1 - assert r.lrem('foo', 0, 'one') == 0 - - -def test_lrem_wrong_type(r): - r.set('foo', 'bar') - with pytest.raises(redis.ResponseError): - r.lrem('foo', 0, 'element') - - -def test_rpush(r): - r.rpush('foo', 'one') - r.rpush('foo', 'two') - r.rpush('foo', 'three') - r.rpush('foo', 'four', 'five') - assert r.lrange('foo', 0, -1) == [b'one', b'two', b'three', b'four', b'five'] - - -def test_rpush_wrong_type(r): - r.set('foo', 'bar') - with pytest.raises(redis.ResponseError): - r.rpush('foo', 'element') - - -def test_lpop(r): - assert r.rpush('foo', 'one') == 1 - assert r.rpush('foo', 'two') == 2 - assert r.rpush('foo', 'three') == 3 - assert r.lpop('foo') == b'one' - assert r.lpop('foo') == b'two' - assert r.lpop('foo') == b'three' - - -def test_lpop_empty_list(r): - r.rpush('foo', 'one') - r.lpop('foo') - assert r.lpop('foo') is None - # Verify what happens if we try to pop from a key - # we've never seen before. - assert r.lpop('noexists') is None - - -def test_lpop_wrong_type(r): - r.set('foo', 'bar') - with pytest.raises(redis.ResponseError): - r.lpop('foo') - - -@pytest.mark.min_server('6.2') -def test_lpop_count(r): - assert r.rpush('foo', 'one') == 1 - assert r.rpush('foo', 'two') == 2 - assert r.rpush('foo', 'three') == 3 - assert raw_command(r, 'lpop', 'foo', 2) == [b'one', b'two'] - # See https://github.com/redis/redis/issues/9680 - assert raw_command(r, 'lpop', 'foo', 0) is None - - -@pytest.mark.min_server('6.2') -def test_lpop_count_negative(r): - with pytest.raises(redis.ResponseError): - raw_command(r, 'lpop', 'foo', -1) - - -def test_lset(r): - r.rpush('foo', 'one') - r.rpush('foo', 'two') - r.rpush('foo', 'three') - r.lset('foo', 0, 'four') - r.lset('foo', -2, 'five') - assert r.lrange('foo', 0, -1) == [b'four', b'five', b'three'] - - -def test_lset_index_out_of_range(r): - r.rpush('foo', 'one') - with pytest.raises(redis.ResponseError): - r.lset('foo', 3, 'three') - - -def test_lset_wrong_type(r): - r.set('foo', 'bar') - with pytest.raises(redis.ResponseError): - r.lset('foo', 0, 'element') - - -def test_rpushx(r): - r.rpush('foo', 'one') - r.rpushx('foo', 'two') - r.rpushx('bar', 'three') - assert r.lrange('foo', 0, -1) == [b'one', b'two'] - assert r.lrange('bar', 0, -1) == [] - - -def test_rpushx_wrong_type(r): - r.set('foo', 'bar') - with pytest.raises(redis.ResponseError): - r.rpushx('foo', 'element') - - -def test_ltrim(r): - r.rpush('foo', 'one') - r.rpush('foo', 'two') - r.rpush('foo', 'three') - r.rpush('foo', 'four') - - assert r.ltrim('foo', 1, 3) - assert r.lrange('foo', 0, -1) == [b'two', b'three', b'four'] - assert r.ltrim('foo', 1, -1) - assert r.lrange('foo', 0, -1) == [b'three', b'four'] - - -def test_ltrim_with_non_existent_key(r): - assert r.ltrim('foo', 0, -1) - - -def test_ltrim_expiry(r): - r.rpush('foo', 'one', 'two', 'three') - r.expire('foo', 10) - r.ltrim('foo', 1, 2) - assert r.ttl('foo') > 0 - - -def test_ltrim_wrong_type(r): - r.set('foo', 'bar') - with pytest.raises(redis.ResponseError): - r.ltrim('foo', 1, -1) - - -def test_lindex(r): - r.rpush('foo', 'one') - r.rpush('foo', 'two') - assert r.lindex('foo', 0) == b'one' - assert r.lindex('foo', 4) is None - assert r.lindex('bar', 4) is None - - -def test_lindex_wrong_type(r): - r.set('foo', 'bar') - with pytest.raises(redis.ResponseError): - r.lindex('foo', 0) - - -def test_lpushx(r): - r.lpush('foo', 'two') - r.lpushx('foo', 'one') - r.lpushx('bar', 'one') - assert r.lrange('foo', 0, -1) == [b'one', b'two'] - assert r.lrange('bar', 0, -1) == [] - - -def test_lpushx_wrong_type(r): - r.set('foo', 'bar') - with pytest.raises(redis.ResponseError): - r.lpushx('foo', 'element') - - -def test_rpop(r): - assert r.rpop('foo') is None - r.rpush('foo', 'one') - r.rpush('foo', 'two') - assert r.rpop('foo') == b'two' - assert r.rpop('foo') == b'one' - assert r.rpop('foo') is None - - -def test_rpop_wrong_type(r): - r.set('foo', 'bar') - with pytest.raises(redis.ResponseError): - r.rpop('foo') - - -@pytest.mark.min_server('6.2') -def test_rpop_count(r): - assert r.rpush('foo', 'one') == 1 - assert r.rpush('foo', 'two') == 2 - assert r.rpush('foo', 'three') == 3 - assert raw_command(r, 'rpop', 'foo', 2) == [b'three', b'two'] - # See https://github.com/redis/redis/issues/9680 - assert raw_command(r, 'rpop', 'foo', 0) is None - - -@pytest.mark.min_server('6.2') -def test_rpop_count_negative(r): - with pytest.raises(redis.ResponseError): - raw_command(r, 'rpop', 'foo', -1) - - -def test_linsert_before(r): - r.rpush('foo', 'hello') - r.rpush('foo', 'world') - assert r.linsert('foo', 'before', 'world', 'there') == 3 - assert r.lrange('foo', 0, -1) == [b'hello', b'there', b'world'] - - -def test_linsert_after(r): - r.rpush('foo', 'hello') - r.rpush('foo', 'world') - assert r.linsert('foo', 'after', 'hello', 'there') == 3 - assert r.lrange('foo', 0, -1) == [b'hello', b'there', b'world'] - - -def test_linsert_no_pivot(r): - r.rpush('foo', 'hello') - r.rpush('foo', 'world') - assert r.linsert('foo', 'after', 'goodbye', 'bar') == -1 - assert r.lrange('foo', 0, -1) == [b'hello', b'world'] - - -def test_linsert_wrong_type(r): - r.set('foo', 'bar') - with pytest.raises(redis.ResponseError): - r.linsert('foo', 'after', 'bar', 'element') - - -def test_rpoplpush(r): - assert r.rpoplpush('foo', 'bar') is None - assert r.lpop('bar') is None - r.rpush('foo', 'one') - r.rpush('foo', 'two') - r.rpush('bar', 'one') - - assert r.rpoplpush('foo', 'bar') == b'two' - assert r.lrange('foo', 0, -1) == [b'one'] - assert r.lrange('bar', 0, -1) == [b'two', b'one'] - - # Catch instances where we store bytes and strings inconsistently - # and thus bar = ['two', b'one'] - assert r.lrem('bar', -1, 'two') == 1 - - -def test_rpoplpush_to_nonexistent_destination(r): - r.rpush('foo', 'one') - assert r.rpoplpush('foo', 'bar') == b'one' - assert r.rpop('bar') == b'one' - - -def test_rpoplpush_expiry(r): - r.rpush('foo', 'one') - r.rpush('bar', 'two') - r.expire('bar', 10) - r.rpoplpush('foo', 'bar') - assert r.ttl('bar') > 0 - - -def test_rpoplpush_one_to_self(r): - r.rpush('list', 'element') - assert r.brpoplpush('list', 'list') == b'element' - assert r.lrange('list', 0, -1) == [b'element'] - - -def test_rpoplpush_wrong_type(r): - r.set('foo', 'bar') - r.rpush('list', 'element') - with pytest.raises(redis.ResponseError): - r.rpoplpush('foo', 'list') - assert r.get('foo') == b'bar' - assert r.lrange('list', 0, -1) == [b'element'] - with pytest.raises(redis.ResponseError): - r.rpoplpush('list', 'foo') - assert r.get('foo') == b'bar' - assert r.lrange('list', 0, -1) == [b'element'] - - -def test_blpop_single_list(r): - r.rpush('foo', 'one') - r.rpush('foo', 'two') - r.rpush('foo', 'three') - assert r.blpop(['foo'], timeout=1) == (b'foo', b'one') - - -def test_blpop_test_multiple_lists(r): - r.rpush('baz', 'zero') - assert r.blpop(['foo', 'baz'], timeout=1) == (b'baz', b'zero') - assert not r.exists('baz') - - r.rpush('foo', 'one') - r.rpush('foo', 'two') - # bar has nothing, so the returned value should come - # from foo. - assert r.blpop(['bar', 'foo'], timeout=1) == (b'foo', b'one') - r.rpush('bar', 'three') - # bar now has something, so the returned value should come - # from bar. - assert r.blpop(['bar', 'foo'], timeout=1) == (b'bar', b'three') - assert r.blpop(['bar', 'foo'], timeout=1) == (b'foo', b'two') - - -def test_blpop_allow_single_key(r): - # blpop converts single key arguments to a one element list. - r.rpush('foo', 'one') - assert r.blpop('foo', timeout=1) == (b'foo', b'one') - - -@pytest.mark.slow -def test_blpop_block(r): - def push_thread(): - sleep(0.5) - r.rpush('foo', 'value1') - sleep(0.5) - # Will wake the condition variable - r.set('bar', 'go back to sleep some more') - r.rpush('foo', 'value2') - - thread = threading.Thread(target=push_thread) - thread.start() - try: - assert r.blpop('foo') == (b'foo', b'value1') - assert r.blpop('foo', timeout=5) == (b'foo', b'value2') - finally: - thread.join() - - -def test_blpop_wrong_type(r): - r.set('foo', 'bar') - with pytest.raises(redis.ResponseError): - r.blpop('foo', timeout=1) - - -def test_blpop_transaction(r): - p = r.pipeline() - p.multi() - p.blpop('missing', timeout=1000) - result = p.execute() - # Blocking commands behave like non-blocking versions in transactions - assert result == [None] - - -def test_eval_blpop(r): - r.rpush('foo', 'bar') - with pytest.raises(redis.ResponseError, match='not allowed from scripts'): - r.eval('return redis.pcall("BLPOP", KEYS[1], 1)', 1, 'foo') - - -def test_brpop_test_multiple_lists(r): - r.rpush('baz', 'zero') - assert r.brpop(['foo', 'baz'], timeout=1) == (b'baz', b'zero') - assert not r.exists('baz') - - r.rpush('foo', 'one') - r.rpush('foo', 'two') - assert r.brpop(['bar', 'foo'], timeout=1) == (b'foo', b'two') - - -def test_brpop_single_key(r): - r.rpush('foo', 'one') - r.rpush('foo', 'two') - assert r.brpop('foo', timeout=1) == (b'foo', b'two') - - -@pytest.mark.slow -def test_brpop_block(r): - def push_thread(): - sleep(0.5) - r.rpush('foo', 'value1') - sleep(0.5) - # Will wake the condition variable - r.set('bar', 'go back to sleep some more') - r.rpush('foo', 'value2') - - thread = threading.Thread(target=push_thread) - thread.start() - try: - assert r.brpop('foo') == (b'foo', b'value1') - assert r.brpop('foo', timeout=5) == (b'foo', b'value2') - finally: - thread.join() - - -def test_brpop_wrong_type(r): - r.set('foo', 'bar') - with pytest.raises(redis.ResponseError): - r.brpop('foo', timeout=1) - - -def test_brpoplpush_multi_keys(r): - assert r.lpop('bar') is None - r.rpush('foo', 'one') - r.rpush('foo', 'two') - assert r.brpoplpush('foo', 'bar', timeout=1) == b'two' - assert r.lrange('bar', 0, -1) == [b'two'] - - # Catch instances where we store bytes and strings inconsistently - # and thus bar = ['two'] - assert r.lrem('bar', -1, 'two') == 1 - - -def test_brpoplpush_wrong_type(r): - r.set('foo', 'bar') - r.rpush('list', 'element') - with pytest.raises(redis.ResponseError): - r.brpoplpush('foo', 'list') - assert r.get('foo') == b'bar' - assert r.lrange('list', 0, -1) == [b'element'] - with pytest.raises(redis.ResponseError): - r.brpoplpush('list', 'foo') - assert r.get('foo') == b'bar' - assert r.lrange('list', 0, -1) == [b'element'] - - -@pytest.mark.slow -def test_blocking_operations_when_empty(r): - assert r.blpop(['foo'], timeout=1) is None - assert r.blpop(['bar', 'foo'], timeout=1) is None - assert r.brpop('foo', timeout=1) is None - assert r.brpoplpush('foo', 'bar', timeout=1) is None - - -def test_empty_list(r): - r.rpush('foo', 'bar') - r.rpop('foo') - assert not r.exists('foo') - - -# Tests for the hash type. - -def test_hstrlen_missing(r): - assert r.hstrlen('foo', 'doesnotexist') == 0 - - r.hset('foo', 'key', 'value') - assert r.hstrlen('foo', 'doesnotexist') == 0 - - -def test_hstrlen(r): - r.hset('foo', 'key', 'value') - assert r.hstrlen('foo', 'key') == 5 - - -def test_hset_then_hget(r): - assert r.hset('foo', 'key', 'value') == 1 - assert r.hget('foo', 'key') == b'value' - - -def test_hset_update(r): - assert r.hset('foo', 'key', 'value') == 1 - assert r.hset('foo', 'key', 'value') == 0 - - -def test_hset_wrong_type(r): - zadd(r, 'foo', {'bar': 1}) - with pytest.raises(redis.ResponseError): - r.hset('foo', 'key', 'value') - - -def test_hgetall(r): - assert r.hset('foo', 'k1', 'v1') == 1 - assert r.hset('foo', 'k2', 'v2') == 1 - assert r.hset('foo', 'k3', 'v3') == 1 - assert r.hgetall('foo') == { - b'k1': b'v1', - b'k2': b'v2', - b'k3': b'v3' - } - - -@redis2_only -def test_hgetall_with_tuples(r): - assert r.hset('foo', (1, 2), (1, 2, 3)) == 1 - assert r.hgetall('foo') == {b'(1, 2)': b'(1, 2, 3)'} - - -def test_hgetall_empty_key(r): - assert r.hgetall('foo') == {} - - -def test_hgetall_wrong_type(r): - zadd(r, 'foo', {'bar': 1}) - with pytest.raises(redis.ResponseError): - r.hgetall('foo') - - -def test_hexists(r): - r.hset('foo', 'bar', 'v1') - assert r.hexists('foo', 'bar') == 1 - assert r.hexists('foo', 'baz') == 0 - assert r.hexists('bar', 'bar') == 0 - - -def test_hexists_wrong_type(r): - zadd(r, 'foo', {'bar': 1}) - with pytest.raises(redis.ResponseError): - r.hexists('foo', 'key') - - -def test_hkeys(r): - r.hset('foo', 'k1', 'v1') - r.hset('foo', 'k2', 'v2') - assert set(r.hkeys('foo')) == {b'k1', b'k2'} - assert set(r.hkeys('bar')) == set() - - -def test_hkeys_wrong_type(r): - zadd(r, 'foo', {'bar': 1}) - with pytest.raises(redis.ResponseError): - r.hkeys('foo') - - -def test_hlen(r): - r.hset('foo', 'k1', 'v1') - r.hset('foo', 'k2', 'v2') - assert r.hlen('foo') == 2 - - -def test_hlen_wrong_type(r): - zadd(r, 'foo', {'bar': 1}) - with pytest.raises(redis.ResponseError): - r.hlen('foo') - - -def test_hvals(r): - r.hset('foo', 'k1', 'v1') - r.hset('foo', 'k2', 'v2') - assert set(r.hvals('foo')) == {b'v1', b'v2'} - assert set(r.hvals('bar')) == set() - - -def test_hvals_wrong_type(r): - zadd(r, 'foo', {'bar': 1}) - with pytest.raises(redis.ResponseError): - r.hvals('foo') - - -def test_hmget(r): - r.hset('foo', 'k1', 'v1') - r.hset('foo', 'k2', 'v2') - r.hset('foo', 'k3', 'v3') - # Normal case. - assert r.hmget('foo', ['k1', 'k3']) == [b'v1', b'v3'] - assert r.hmget('foo', 'k1', 'k3') == [b'v1', b'v3'] - # Key does not exist. - assert r.hmget('bar', ['k1', 'k3']) == [None, None] - assert r.hmget('bar', 'k1', 'k3') == [None, None] - # Some keys in the hash do not exist. - assert r.hmget('foo', ['k1', 'k500']) == [b'v1', None] - assert r.hmget('foo', 'k1', 'k500') == [b'v1', None] - - -def test_hmget_wrong_type(r): - zadd(r, 'foo', {'bar': 1}) - with pytest.raises(redis.ResponseError): - r.hmget('foo', 'key1', 'key2') - - -def test_hdel(r): - r.hset('foo', 'k1', 'v1') - r.hset('foo', 'k2', 'v2') - r.hset('foo', 'k3', 'v3') - assert r.hget('foo', 'k1') == b'v1' - assert r.hdel('foo', 'k1') == 1 - assert r.hget('foo', 'k1') is None - assert r.hdel('foo', 'k1') == 0 - # Since redis>=2.7.6 returns number of deleted items. - assert r.hdel('foo', 'k2', 'k3') == 2 - assert r.hget('foo', 'k2') is None - assert r.hget('foo', 'k3') is None - assert r.hdel('foo', 'k2', 'k3') == 0 - - -def test_hdel_wrong_type(r): - zadd(r, 'foo', {'bar': 1}) - with pytest.raises(redis.ResponseError): - r.hdel('foo', 'key') - - -def test_hincrby(r): - r.hset('foo', 'counter', 0) - assert r.hincrby('foo', 'counter') == 1 - assert r.hincrby('foo', 'counter') == 2 - assert r.hincrby('foo', 'counter') == 3 - - -def test_hincrby_with_no_starting_value(r): - assert r.hincrby('foo', 'counter') == 1 - assert r.hincrby('foo', 'counter') == 2 - assert r.hincrby('foo', 'counter') == 3 - - -def test_hincrby_with_range_param(r): - assert r.hincrby('foo', 'counter', 2) == 2 - assert r.hincrby('foo', 'counter', 2) == 4 - assert r.hincrby('foo', 'counter', 2) == 6 - - -def test_hincrby_wrong_type(r): - zadd(r, 'foo', {'bar': 1}) - with pytest.raises(redis.ResponseError): - r.hincrby('foo', 'key', 2) - - -def test_hincrbyfloat(r): - r.hset('foo', 'counter', 0.0) - assert r.hincrbyfloat('foo', 'counter') == 1.0 - assert r.hincrbyfloat('foo', 'counter') == 2.0 - assert r.hincrbyfloat('foo', 'counter') == 3.0 - - -def test_hincrbyfloat_with_no_starting_value(r): - assert r.hincrbyfloat('foo', 'counter') == 1.0 - assert r.hincrbyfloat('foo', 'counter') == 2.0 - assert r.hincrbyfloat('foo', 'counter') == 3.0 - - -def test_hincrbyfloat_with_range_param(r): - assert r.hincrbyfloat('foo', 'counter', 0.1) == pytest.approx(0.1) - assert r.hincrbyfloat('foo', 'counter', 0.1) == pytest.approx(0.2) - assert r.hincrbyfloat('foo', 'counter', 0.1) == pytest.approx(0.3) - - -def test_hincrbyfloat_on_non_float_value_raises_error(r): - r.hset('foo', 'counter', 'cat') - with pytest.raises(redis.ResponseError): - r.hincrbyfloat('foo', 'counter') - - -def test_hincrbyfloat_with_non_float_amount_raises_error(r): - with pytest.raises(redis.ResponseError): - r.hincrbyfloat('foo', 'counter', 'cat') - - -def test_hincrbyfloat_wrong_type(r): - zadd(r, 'foo', {'bar': 1}) - with pytest.raises(redis.ResponseError): - r.hincrbyfloat('foo', 'key', 0.1) - - -def test_hincrbyfloat_precision(r): - x = 1.23456789123456789 - assert r.hincrbyfloat('foo', 'bar', x) == x - assert float(r.hget('foo', 'bar')) == x - - -def test_hsetnx(r): - assert r.hsetnx('foo', 'newkey', 'v1') == 1 - assert r.hsetnx('foo', 'newkey', 'v1') == 0 - assert r.hget('foo', 'newkey') == b'v1' - - -def test_hmset_empty_raises_error(r): - with pytest.raises(redis.DataError): - r.hmset('foo', {}) - - -def test_hmset(r): - r.hset('foo', 'k1', 'v1') - assert r.hmset('foo', {'k2': 'v2', 'k3': 'v3'}) is True - - -@redis2_only -def test_hmset_convert_values(r): - r.hmset('foo', {'k1': True, 'k2': 1}) - assert r.hgetall('foo') == {b'k1': b'True', b'k2': b'1'} - - -@redis2_only -def test_hmset_does_not_mutate_input_params(r): - original = {'key': [123, 456]} - r.hmset('foo', original) - assert original == {'key': [123, 456]} - - -def test_hmset_wrong_type(r): - zadd(r, 'foo', {'bar': 1}) - with pytest.raises(redis.ResponseError): - r.hmset('foo', {'key': 'value'}) - - -def test_empty_hash(r): - r.hset('foo', 'bar', 'baz') - r.hdel('foo', 'bar') - assert not r.exists('foo') - - -def test_sadd(r): - assert r.sadd('foo', 'member1') == 1 - assert r.sadd('foo', 'member1') == 0 - assert r.smembers('foo') == {b'member1'} - assert r.sadd('foo', 'member2', 'member3') == 2 - assert r.smembers('foo') == {b'member1', b'member2', b'member3'} - assert r.sadd('foo', 'member3', 'member4') == 1 - assert r.smembers('foo') == {b'member1', b'member2', b'member3', b'member4'} - - -def test_sadd_as_str_type(r): - assert r.sadd('foo', *range(3)) == 3 - assert r.smembers('foo') == {b'0', b'1', b'2'} - - -def test_sadd_wrong_type(r): - zadd(r, 'foo', {'member': 1}) - with pytest.raises(redis.ResponseError): - r.sadd('foo', 'member2') - - -def test_scan_single(r): - r.set('foo1', 'bar1') - assert r.scan(match="foo*") == (0, [b'foo1']) - - -def test_scan_iter_single_page(r): - r.set('foo1', 'bar1') - r.set('foo2', 'bar2') - assert set(r.scan_iter(match="foo*")) == {b'foo1', b'foo2'} - assert set(r.scan_iter()) == {b'foo1', b'foo2'} - assert set(r.scan_iter(match="")) == set() - - -def test_scan_iter_multiple_pages(r): - all_keys = key_val_dict(size=100) - assert all(r.set(k, v) for k, v in all_keys.items()) - assert set(r.scan_iter()) == set(all_keys) - - -def test_scan_iter_multiple_pages_with_match(r): - all_keys = key_val_dict(size=100) - assert all(r.set(k, v) for k, v in all_keys.items()) - # Now add a few keys that don't match the key: pattern. - r.set('otherkey', 'foo') - r.set('andanother', 'bar') - actual = set(r.scan_iter(match='key:*')) - assert actual == set(all_keys) - - -@pytest.mark.skipif(REDIS_VERSION < Version('3.5'), reason="Test is only applicable to redis-py 3.5+") -@pytest.mark.min_server('6.0') -def test_scan_iter_multiple_pages_with_type(r): - all_keys = key_val_dict(size=100) - assert all(r.set(k, v) for k, v in all_keys.items()) - # Now add a few keys of another type - zadd(r, 'zset1', {'otherkey': 1}) - zadd(r, 'zset2', {'andanother': 1}) - actual = set(r.scan_iter(_type='string')) - assert actual == set(all_keys) - actual = set(r.scan_iter(_type='ZSET')) - assert actual == {b'zset1', b'zset2'} - - -def test_scan_multiple_pages_with_count_arg(r): - all_keys = key_val_dict(size=100) - assert all(r.set(k, v) for k, v in all_keys.items()) - assert set(r.scan_iter(count=1000)) == set(all_keys) - - -def test_scan_all_in_single_call(r): - all_keys = key_val_dict(size=100) - assert all(r.set(k, v) for k, v in all_keys.items()) - # Specify way more than the 100 keys we've added. - actual = r.scan(count=1000) - assert set(actual[1]) == set(all_keys) - assert actual[0] == 0 - - -@pytest.mark.slow -def test_scan_expired_key(r): - r.set('expiringkey', 'value') - r.pexpire('expiringkey', 1) - sleep(1) - assert r.scan()[1] == [] - - -def test_scard(r): - r.sadd('foo', 'member1') - r.sadd('foo', 'member2') - r.sadd('foo', 'member2') - assert r.scard('foo') == 2 - - -def test_scard_wrong_type(r): - zadd(r, 'foo', {'member': 1}) - with pytest.raises(redis.ResponseError): - r.scard('foo') - - -def test_sdiff(r): - r.sadd('foo', 'member1') - r.sadd('foo', 'member2') - r.sadd('bar', 'member2') - r.sadd('bar', 'member3') - assert r.sdiff('foo', 'bar') == {b'member1'} - # Original sets shouldn't be modified. - assert r.smembers('foo') == {b'member1', b'member2'} - assert r.smembers('bar') == {b'member2', b'member3'} - - -def test_sdiff_one_key(r): - r.sadd('foo', 'member1') - r.sadd('foo', 'member2') - assert r.sdiff('foo') == {b'member1', b'member2'} - - -def test_sdiff_empty(r): - assert r.sdiff('foo') == set() - - -def test_sdiff_wrong_type(r): - zadd(r, 'foo', {'member': 1}) - r.sadd('bar', 'member') - with pytest.raises(redis.ResponseError): - r.sdiff('foo', 'bar') - with pytest.raises(redis.ResponseError): - r.sdiff('bar', 'foo') - - -def test_sdiffstore(r): - r.sadd('foo', 'member1') - r.sadd('foo', 'member2') - r.sadd('bar', 'member2') - r.sadd('bar', 'member3') - assert r.sdiffstore('baz', 'foo', 'bar') == 1 - - # Catch instances where we store bytes and strings inconsistently - # and thus baz = {'member1', b'member1'} - r.sadd('baz', 'member1') - assert r.scard('baz') == 1 - - -def test_setrange(r): - r.set('foo', 'test') - assert r.setrange('foo', 1, 'aste') == 5 - assert r.get('foo') == b'taste' - - r.set('foo', 'test') - assert r.setrange('foo', 1, 'a') == 4 - assert r.get('foo') == b'tast' - - assert r.setrange('bar', 2, 'test') == 6 - assert r.get('bar') == b'\x00\x00test' - - -def test_setrange_expiry(r): - r.set('foo', 'test', ex=10) - r.setrange('foo', 1, 'aste') - assert r.ttl('foo') > 0 - - -def test_sinter(r): - r.sadd('foo', 'member1') - r.sadd('foo', 'member2') - r.sadd('bar', 'member2') - r.sadd('bar', 'member3') - assert r.sinter('foo', 'bar') == {b'member2'} - assert r.sinter('foo') == {b'member1', b'member2'} - - -def test_sinter_bytes_keys(r): - foo = os.urandom(10) - bar = os.urandom(10) - r.sadd(foo, 'member1') - r.sadd(foo, 'member2') - r.sadd(bar, 'member2') - r.sadd(bar, 'member3') - assert r.sinter(foo, bar) == {b'member2'} - assert r.sinter(foo) == {b'member1', b'member2'} - - -def test_sinter_wrong_type(r): - zadd(r, 'foo', {'member': 1}) - r.sadd('bar', 'member') - with pytest.raises(redis.ResponseError): - r.sinter('foo', 'bar') - with pytest.raises(redis.ResponseError): - r.sinter('bar', 'foo') - - -def test_sinterstore(r): - r.sadd('foo', 'member1') - r.sadd('foo', 'member2') - r.sadd('bar', 'member2') - r.sadd('bar', 'member3') - assert r.sinterstore('baz', 'foo', 'bar') == 1 - - # Catch instances where we store bytes and strings inconsistently - # and thus baz = {'member2', b'member2'} - r.sadd('baz', 'member2') - assert r.scard('baz') == 1 - - -def test_sismember(r): - assert r.sismember('foo', 'member1') is False - r.sadd('foo', 'member1') - assert r.sismember('foo', 'member1') is True - - -def test_sismember_wrong_type(r): - zadd(r, 'foo', {'member': 1}) - with pytest.raises(redis.ResponseError): - r.sismember('foo', 'member') - - -def test_smembers(r): - assert r.smembers('foo') == set() - - -def test_smembers_copy(r): - r.sadd('foo', 'member1') - set = r.smembers('foo') - r.sadd('foo', 'member2') - assert r.smembers('foo') != set - - -def test_smembers_wrong_type(r): - zadd(r, 'foo', {'member': 1}) - with pytest.raises(redis.ResponseError): - r.smembers('foo') - - -def test_smembers_runtime_error(r): - r.sadd('foo', 'member1', 'member2') - for member in r.smembers('foo'): - r.srem('foo', member) - - -def test_smove(r): - r.sadd('foo', 'member1') - r.sadd('foo', 'member2') - assert r.smove('foo', 'bar', 'member1') is True - assert r.smembers('bar') == {b'member1'} - - -def test_smove_non_existent_key(r): - assert r.smove('foo', 'bar', 'member1') is False - - -def test_smove_wrong_type(r): - zadd(r, 'foo', {'member': 1}) - r.sadd('bar', 'member') - with pytest.raises(redis.ResponseError): - r.smove('bar', 'foo', 'member') - # Must raise the error before removing member from bar - assert r.smembers('bar') == {b'member'} - with pytest.raises(redis.ResponseError): - r.smove('foo', 'bar', 'member') - - -def test_spop(r): - # This is tricky because it pops a random element. - r.sadd('foo', 'member1') - assert r.spop('foo') == b'member1' - assert r.spop('foo') is None - - -def test_spop_wrong_type(r): - zadd(r, 'foo', {'member': 1}) - with pytest.raises(redis.ResponseError): - r.spop('foo') - - -def test_srandmember(r): - r.sadd('foo', 'member1') - assert r.srandmember('foo') == b'member1' - # Shouldn't be removed from the set. - assert r.srandmember('foo') == b'member1' - - -def test_srandmember_number(r): - """srandmember works with the number argument.""" - assert r.srandmember('foo', 2) == [] - r.sadd('foo', b'member1') - assert r.srandmember('foo', 2) == [b'member1'] - r.sadd('foo', b'member2') - assert set(r.srandmember('foo', 2)) == {b'member1', b'member2'} - r.sadd('foo', b'member3') - res = r.srandmember('foo', 2) - assert len(res) == 2 - for e in res: - assert e in {b'member1', b'member2', b'member3'} - - -def test_srandmember_wrong_type(r): - zadd(r, 'foo', {'member': 1}) - with pytest.raises(redis.ResponseError): - r.srandmember('foo') - - -def test_srem(r): - r.sadd('foo', 'member1', 'member2', 'member3', 'member4') - assert r.smembers('foo') == {b'member1', b'member2', b'member3', b'member4'} - assert r.srem('foo', 'member1') == 1 - assert r.smembers('foo') == {b'member2', b'member3', b'member4'} - assert r.srem('foo', 'member1') == 0 - # Since redis>=2.7.6 returns number of deleted items. - assert r.srem('foo', 'member2', 'member3') == 2 - assert r.smembers('foo') == {b'member4'} - assert r.srem('foo', 'member3', 'member4') == 1 - assert r.smembers('foo') == set() - assert r.srem('foo', 'member3', 'member4') == 0 - - -def test_srem_wrong_type(r): - zadd(r, 'foo', {'member': 1}) - with pytest.raises(redis.ResponseError): - r.srem('foo', 'member') - - -def test_sunion(r): - r.sadd('foo', 'member1') - r.sadd('foo', 'member2') - r.sadd('bar', 'member2') - r.sadd('bar', 'member3') - assert r.sunion('foo', 'bar') == {b'member1', b'member2', b'member3'} - - -def test_sunion_wrong_type(r): - zadd(r, 'foo', {'member': 1}) - r.sadd('bar', 'member') - with pytest.raises(redis.ResponseError): - r.sunion('foo', 'bar') - with pytest.raises(redis.ResponseError): - r.sunion('bar', 'foo') - - -def test_sunionstore(r): - r.sadd('foo', 'member1') - r.sadd('foo', 'member2') - r.sadd('bar', 'member2') - r.sadd('bar', 'member3') - assert r.sunionstore('baz', 'foo', 'bar') == 3 - assert r.smembers('baz') == {b'member1', b'member2', b'member3'} - - # Catch instances where we store bytes and strings inconsistently - # and thus baz = {b'member1', b'member2', b'member3', 'member3'} - r.sadd('baz', 'member3') - assert r.scard('baz') == 3 - - -def test_empty_set(r): - r.sadd('foo', 'bar') - r.srem('foo', 'bar') - assert not r.exists('foo') - - -def test_zadd(r): - zadd(r, 'foo', {'four': 4}) - zadd(r, 'foo', {'three': 3}) - assert zadd(r, 'foo', {'two': 2, 'one': 1, 'zero': 0}) == 3 - assert r.zrange('foo', 0, -1) == [b'zero', b'one', b'two', b'three', b'four'] - assert zadd(r, 'foo', {'zero': 7, 'one': 1, 'five': 5}) == 1 - assert ( - r.zrange('foo', 0, -1) - == [b'one', b'two', b'three', b'four', b'five', b'zero'] - ) - - -@redis2_only -def test_zadd_uses_str(r): - r.zadd('foo', 12345, (1, 2, 3)) - assert r.zrange('foo', 0, 0) == [b'(1, 2, 3)'] - - -@redis2_only -def test_zadd_errors(r): - # The args are backwards, it should be 2, "two", so we - # expect an exception to be raised. - with pytest.raises(redis.ResponseError): - r.zadd('foo', 'two', 2) - with pytest.raises(redis.ResponseError): - r.zadd('foo', two='two') - # It's expected an equal number of values and scores - with pytest.raises(redis.RedisError): - r.zadd('foo', 'two') - - -def test_zadd_empty(r): - # Have to add at least one key/value pair - with pytest.raises(redis.RedisError): - zadd(r, 'foo', {}) - - -def test_zadd_minus_zero(r): - # Changing -0 to +0 is ignored - zadd(r, 'foo', {'a': -0.0}) - zadd(r, 'foo', {'a': 0.0}) - assert raw_command(r, 'zscore', 'foo', 'a') == b'-0' - - -def test_zadd_wrong_type(r): - r.sadd('foo', 'bar') - with pytest.raises(redis.ResponseError): - zadd(r, 'foo', {'two': 2}) - - -def test_zadd_multiple(r): - zadd(r, 'foo', {'one': 1, 'two': 2}) - assert r.zrange('foo', 0, 0) == [b'one'] - assert r.zrange('foo', 1, 1) == [b'two'] - - -@redis3_only -@pytest.mark.parametrize( - 'input,return_value,state', - [ - ({'four': 2.0, 'three': 1.0}, 0, [(b'three', 3.0), (b'four', 4.0)]), - ({'four': 2.0, 'three': 1.0, 'zero': 0.0}, 1, [(b'zero', 0.0), (b'three', 3.0), (b'four', 4.0)]), - ({'two': 2.0, 'one': 1.0}, 2, [(b'one', 1.0), (b'two', 2.0), (b'three', 3.0), (b'four', 4.0)]) - ] -) -@pytest.mark.parametrize('ch', [False, True]) -def test_zadd_with_nx(r, input, return_value, state, ch): - zadd(r, 'foo', {'four': 4.0, 'three': 3.0}) - assert zadd(r, 'foo', input, nx=True, ch=ch) == return_value - assert r.zrange('foo', 0, -1, withscores=True) == state - - -@redis3_only -@pytest.mark.parametrize( - 'input,return_value,state', - [ - ({'four': 4.0, 'three': 1.0}, 1, [(b'three', 1.0), (b'four', 4.0)]), - ({'four': 4.0, 'three': 1.0, 'zero': 0.0}, 2, [(b'zero', 0.0), (b'three', 1.0), (b'four', 4.0)]), - ({'two': 2.0, 'one': 1.0}, 2, [(b'one', 1.0), (b'two', 2.0), (b'three', 3.0), (b'four', 4.0)]) - ] -) -def test_zadd_with_ch(r, input, return_value, state): - zadd(r, 'foo', {'four': 4.0, 'three': 3.0}) - assert zadd(r, 'foo', input, ch=True) == return_value - assert r.zrange('foo', 0, -1, withscores=True) == state - - -@redis3_only -@pytest.mark.parametrize( - 'input,changed,state', - [ - ({'four': 2.0, 'three': 1.0}, 2, [(b'three', 1.0), (b'four', 2.0)]), - ({'four': 4.0, 'three': 3.0, 'zero': 0.0}, 0, [(b'three', 3.0), (b'four', 4.0)]), - ({'two': 2.0, 'one': 1.0}, 0, [(b'three', 3.0), (b'four', 4.0)]) - ] -) -@pytest.mark.parametrize('ch', [False, True]) -def test_zadd_with_xx(r, input, changed, state, ch): - zadd(r, 'foo', {'four': 4.0, 'three': 3.0}) - assert zadd(r, 'foo', input, xx=True, ch=ch) == (changed if ch else 0) - assert r.zrange('foo', 0, -1, withscores=True) == state - - -@redis3_only -@pytest.mark.parametrize('ch', [False, True]) -def test_zadd_with_nx_and_xx(r, ch): - zadd(r, 'foo', {'four': 4.0, 'three': 3.0}) - with pytest.raises(redis.DataError): - zadd(r, 'foo', {'four': -4.0, 'three': -3.0}, nx=True, xx=True, ch=ch) - - -@pytest.mark.skipif(REDIS_VERSION < Version('3.1'), reason="Test is only applicable to redis-py 3.1+") -@pytest.mark.parametrize('ch', [False, True]) -def test_zadd_incr(r, ch): - zadd(r, 'foo', {'four': 4.0, 'three': 3.0}) - assert zadd(r, 'foo', {'four': 1.0}, incr=True, ch=ch) == 5.0 - assert zadd(r, 'foo', {'three': 1.0}, incr=True, nx=True, ch=ch) is None - assert r.zscore('foo', 'three') == 3.0 - assert zadd(r, 'foo', {'bar': 1.0}, incr=True, xx=True, ch=ch) is None - assert zadd(r, 'foo', {'three': 1.0}, incr=True, xx=True, ch=ch) == 4.0 - - -def test_zrange_same_score(r): - zadd(r, 'foo', {'two_a': 2}) - zadd(r, 'foo', {'two_b': 2}) - zadd(r, 'foo', {'two_c': 2}) - zadd(r, 'foo', {'two_d': 2}) - zadd(r, 'foo', {'two_e': 2}) - assert r.zrange('foo', 2, 3) == [b'two_c', b'two_d'] - - -def test_zcard(r): - zadd(r, 'foo', {'one': 1}) - zadd(r, 'foo', {'two': 2}) - assert r.zcard('foo') == 2 - - -def test_zcard_non_existent_key(r): - assert r.zcard('foo') == 0 - - -def test_zcard_wrong_type(r): - r.sadd('foo', 'bar') - with pytest.raises(redis.ResponseError): - r.zcard('foo') - - -def test_zcount(r): - zadd(r, 'foo', {'one': 1}) - zadd(r, 'foo', {'three': 2}) - zadd(r, 'foo', {'five': 5}) - assert r.zcount('foo', 2, 4) == 1 - assert r.zcount('foo', 1, 4) == 2 - assert r.zcount('foo', 0, 5) == 3 - assert r.zcount('foo', 4, '+inf') == 1 - assert r.zcount('foo', '-inf', 4) == 2 - assert r.zcount('foo', '-inf', '+inf') == 3 - - -def test_zcount_exclusive(r): - zadd(r, 'foo', {'one': 1}) - zadd(r, 'foo', {'three': 2}) - zadd(r, 'foo', {'five': 5}) - assert r.zcount('foo', '-inf', '(2') == 1 - assert r.zcount('foo', '-inf', 2) == 2 - assert r.zcount('foo', '(5', '+inf') == 0 - assert r.zcount('foo', '(1', 5) == 2 - assert r.zcount('foo', '(2', '(5') == 0 - assert r.zcount('foo', '(1', '(5') == 1 - assert r.zcount('foo', 2, '(5') == 1 - - -def test_zcount_wrong_type(r): - r.sadd('foo', 'bar') - with pytest.raises(redis.ResponseError): - r.zcount('foo', '-inf', '+inf') - - -def test_zincrby(r): - zadd(r, 'foo', {'one': 1}) - assert zincrby(r, 'foo', 10, 'one') == 11 - assert r.zrange('foo', 0, -1, withscores=True) == [(b'one', 11)] - - -def test_zincrby_wrong_type(r): - r.sadd('foo', 'bar') - with pytest.raises(redis.ResponseError): - zincrby(r, 'foo', 10, 'one') - - -def test_zrange_descending(r): - zadd(r, 'foo', {'one': 1}) - zadd(r, 'foo', {'two': 2}) - zadd(r, 'foo', {'three': 3}) - assert r.zrange('foo', 0, -1, desc=True) == [b'three', b'two', b'one'] - - -def test_zrange_descending_with_scores(r): - zadd(r, 'foo', {'one': 1}) - zadd(r, 'foo', {'two': 2}) - zadd(r, 'foo', {'three': 3}) - assert ( - r.zrange('foo', 0, -1, desc=True, withscores=True) - == [(b'three', 3), (b'two', 2), (b'one', 1)] - ) - - -def test_zrange_with_positive_indices(r): - zadd(r, 'foo', {'one': 1}) - zadd(r, 'foo', {'two': 2}) - zadd(r, 'foo', {'three': 3}) - assert r.zrange('foo', 0, 1) == [b'one', b'two'] - - -def test_zrange_wrong_type(r): - r.sadd('foo', 'bar') - with pytest.raises(redis.ResponseError): - r.zrange('foo', 0, -1) - - -def test_zrange_score_cast(r): - zadd(r, 'foo', {'one': 1.2}) - zadd(r, 'foo', {'two': 2.2}) - - expected_without_cast_round = [(b'one', 1.2), (b'two', 2.2)] - expected_with_cast_round = [(b'one', 1.0), (b'two', 2.0)] - assert r.zrange('foo', 0, 2, withscores=True) == expected_without_cast_round - assert ( - r.zrange('foo', 0, 2, withscores=True, score_cast_func=round_str) - == expected_with_cast_round - ) - - -def test_zrank(r): - zadd(r, 'foo', {'one': 1}) - zadd(r, 'foo', {'two': 2}) - zadd(r, 'foo', {'three': 3}) - assert r.zrank('foo', 'one') == 0 - assert r.zrank('foo', 'two') == 1 - assert r.zrank('foo', 'three') == 2 - - -def test_zrank_non_existent_member(r): - assert r.zrank('foo', 'one') is None - - -def test_zrank_wrong_type(r): - r.sadd('foo', 'bar') - with pytest.raises(redis.ResponseError): - r.zrank('foo', 'one') - - -def test_zrem(r): - zadd(r, 'foo', {'one': 1}) - zadd(r, 'foo', {'two': 2}) - zadd(r, 'foo', {'three': 3}) - zadd(r, 'foo', {'four': 4}) - assert r.zrem('foo', 'one') == 1 - assert r.zrange('foo', 0, -1) == [b'two', b'three', b'four'] - # Since redis>=2.7.6 returns number of deleted items. - assert r.zrem('foo', 'two', 'three') == 2 - assert r.zrange('foo', 0, -1) == [b'four'] - assert r.zrem('foo', 'three', 'four') == 1 - assert r.zrange('foo', 0, -1) == [] - assert r.zrem('foo', 'three', 'four') == 0 - - -def test_zrem_non_existent_member(r): - assert not r.zrem('foo', 'one') - - -def test_zrem_numeric_member(r): - zadd(r, 'foo', {'128': 13.0, '129': 12.0}) - assert r.zrem('foo', 128) == 1 - assert r.zrange('foo', 0, -1) == [b'129'] - - -def test_zrem_wrong_type(r): - r.sadd('foo', 'bar') - with pytest.raises(redis.ResponseError): - r.zrem('foo', 'bar') - - -def test_zscore(r): - zadd(r, 'foo', {'one': 54}) - assert r.zscore('foo', 'one') == 54 - - -def test_zscore_non_existent_member(r): - assert r.zscore('foo', 'one') is None - - -def test_zscore_wrong_type(r): - r.sadd('foo', 'bar') - with pytest.raises(redis.ResponseError): - r.zscore('foo', 'one') - - -def test_zrevrank(r): - zadd(r, 'foo', {'one': 1}) - zadd(r, 'foo', {'two': 2}) - zadd(r, 'foo', {'three': 3}) - assert r.zrevrank('foo', 'one') == 2 - assert r.zrevrank('foo', 'two') == 1 - assert r.zrevrank('foo', 'three') == 0 - - -def test_zrevrank_non_existent_member(r): - assert r.zrevrank('foo', 'one') is None - - -def test_zrevrank_wrong_type(r): - r.sadd('foo', 'bar') - with pytest.raises(redis.ResponseError): - r.zrevrank('foo', 'one') - - -def test_zrevrange(r): - zadd(r, 'foo', {'one': 1}) - zadd(r, 'foo', {'two': 2}) - zadd(r, 'foo', {'three': 3}) - assert r.zrevrange('foo', 0, 1) == [b'three', b'two'] - assert r.zrevrange('foo', 0, -1) == [b'three', b'two', b'one'] - - -def test_zrevrange_sorted_keys(r): - zadd(r, 'foo', {'one': 1}) - zadd(r, 'foo', {'two': 2}) - zadd(r, 'foo', {'two_b': 2}) - zadd(r, 'foo', {'three': 3}) - assert r.zrevrange('foo', 0, 2) == [b'three', b'two_b', b'two'] - assert r.zrevrange('foo', 0, -1) == [b'three', b'two_b', b'two', b'one'] - - -def test_zrevrange_wrong_type(r): - r.sadd('foo', 'bar') - with pytest.raises(redis.ResponseError): - r.zrevrange('foo', 0, 2) - - -def test_zrevrange_score_cast(r): - zadd(r, 'foo', {'one': 1.2}) - zadd(r, 'foo', {'two': 2.2}) - - expected_without_cast_round = [(b'two', 2.2), (b'one', 1.2)] - expected_with_cast_round = [(b'two', 2.0), (b'one', 1.0)] - assert r.zrevrange('foo', 0, 2, withscores=True) == expected_without_cast_round - assert ( - r.zrevrange('foo', 0, 2, withscores=True, score_cast_func=round_str) - == expected_with_cast_round - ) - - -def test_zrangebyscore(r): - zadd(r, 'foo', {'zero': 0}) - zadd(r, 'foo', {'two': 2}) - zadd(r, 'foo', {'two_a_also': 2}) - zadd(r, 'foo', {'two_b_also': 2}) - zadd(r, 'foo', {'four': 4}) - assert r.zrangebyscore('foo', 1, 3) == [b'two', b'two_a_also', b'two_b_also'] - assert r.zrangebyscore('foo', 2, 3) == [b'two', b'two_a_also', b'two_b_also'] - assert ( - r.zrangebyscore('foo', 0, 4) - == [b'zero', b'two', b'two_a_also', b'two_b_also', b'four'] - ) - assert r.zrangebyscore('foo', '-inf', 1) == [b'zero'] - assert ( - r.zrangebyscore('foo', 2, '+inf') - == [b'two', b'two_a_also', b'two_b_also', b'four'] - ) - assert ( - r.zrangebyscore('foo', '-inf', '+inf') - == [b'zero', b'two', b'two_a_also', b'two_b_also', b'four'] - ) - - -def test_zrangebysore_exclusive(r): - zadd(r, 'foo', {'zero': 0}) - zadd(r, 'foo', {'two': 2}) - zadd(r, 'foo', {'four': 4}) - zadd(r, 'foo', {'five': 5}) - assert r.zrangebyscore('foo', '(0', 6) == [b'two', b'four', b'five'] - assert r.zrangebyscore('foo', '(2', '(5') == [b'four'] - assert r.zrangebyscore('foo', 0, '(4') == [b'zero', b'two'] - - -def test_zrangebyscore_raises_error(r): - zadd(r, 'foo', {'one': 1}) - zadd(r, 'foo', {'two': 2}) - zadd(r, 'foo', {'three': 3}) - with pytest.raises(redis.ResponseError): - r.zrangebyscore('foo', 'one', 2) - with pytest.raises(redis.ResponseError): - r.zrangebyscore('foo', 2, 'three') - with pytest.raises(redis.ResponseError): - r.zrangebyscore('foo', 2, '3)') - with pytest.raises(redis.RedisError): - r.zrangebyscore('foo', 2, '3)', 0, None) - - -def test_zrangebyscore_wrong_type(r): - r.sadd('foo', 'bar') - with pytest.raises(redis.ResponseError): - r.zrangebyscore('foo', '(1', '(2') - - -def test_zrangebyscore_slice(r): - zadd(r, 'foo', {'two_a': 2}) - zadd(r, 'foo', {'two_b': 2}) - zadd(r, 'foo', {'two_c': 2}) - zadd(r, 'foo', {'two_d': 2}) - assert r.zrangebyscore('foo', 0, 4, 0, 2) == [b'two_a', b'two_b'] - assert r.zrangebyscore('foo', 0, 4, 1, 3) == [b'two_b', b'two_c', b'two_d'] - - -def test_zrangebyscore_withscores(r): - zadd(r, 'foo', {'one': 1}) - zadd(r, 'foo', {'two': 2}) - zadd(r, 'foo', {'three': 3}) - assert r.zrangebyscore('foo', 1, 3, 0, 2, True) == [(b'one', 1), (b'two', 2)] - - -def test_zrangebyscore_cast_scores(r): - zadd(r, 'foo', {'two': 2}) - zadd(r, 'foo', {'two_a_also': 2.2}) - - expected_without_cast_round = [(b'two', 2.0), (b'two_a_also', 2.2)] - expected_with_cast_round = [(b'two', 2.0), (b'two_a_also', 2.0)] - assert ( - sorted(r.zrangebyscore('foo', 2, 3, withscores=True)) - == sorted(expected_without_cast_round) - ) - assert ( - sorted(r.zrangebyscore('foo', 2, 3, withscores=True, - score_cast_func=round_str)) - == sorted(expected_with_cast_round) - ) - - -def test_zrevrangebyscore(r): - zadd(r, 'foo', {'one': 1}) - zadd(r, 'foo', {'two': 2}) - zadd(r, 'foo', {'three': 3}) - assert r.zrevrangebyscore('foo', 3, 1) == [b'three', b'two', b'one'] - assert r.zrevrangebyscore('foo', 3, 2) == [b'three', b'two'] - assert r.zrevrangebyscore('foo', 3, 1, 0, 1) == [b'three'] - assert r.zrevrangebyscore('foo', 3, 1, 1, 2) == [b'two', b'one'] - - -def test_zrevrangebyscore_exclusive(r): - zadd(r, 'foo', {'one': 1}) - zadd(r, 'foo', {'two': 2}) - zadd(r, 'foo', {'three': 3}) - assert r.zrevrangebyscore('foo', '(3', 1) == [b'two', b'one'] - assert r.zrevrangebyscore('foo', 3, '(2') == [b'three'] - assert r.zrevrangebyscore('foo', '(3', '(1') == [b'two'] - assert r.zrevrangebyscore('foo', '(2', 1, 0, 1) == [b'one'] - assert r.zrevrangebyscore('foo', '(2', '(1', 0, 1) == [] - assert r.zrevrangebyscore('foo', '(3', '(0', 1, 2) == [b'one'] - - -def test_zrevrangebyscore_raises_error(r): - zadd(r, 'foo', {'one': 1}) - zadd(r, 'foo', {'two': 2}) - zadd(r, 'foo', {'three': 3}) - with pytest.raises(redis.ResponseError): - r.zrevrangebyscore('foo', 'three', 1) - with pytest.raises(redis.ResponseError): - r.zrevrangebyscore('foo', 3, 'one') - with pytest.raises(redis.ResponseError): - r.zrevrangebyscore('foo', 3, '1)') - with pytest.raises(redis.ResponseError): - r.zrevrangebyscore('foo', '((3', '1)') - - -def test_zrevrangebyscore_wrong_type(r): - r.sadd('foo', 'bar') - with pytest.raises(redis.ResponseError): - r.zrevrangebyscore('foo', '(3', '(1') - - -def test_zrevrangebyscore_cast_scores(r): - zadd(r, 'foo', {'two': 2}) - zadd(r, 'foo', {'two_a_also': 2.2}) - - expected_without_cast_round = [(b'two_a_also', 2.2), (b'two', 2.0)] - expected_with_cast_round = [(b'two_a_also', 2.0), (b'two', 2.0)] - assert ( - r.zrevrangebyscore('foo', 3, 2, withscores=True) - == expected_without_cast_round - ) - assert ( - r.zrevrangebyscore('foo', 3, 2, withscores=True, - score_cast_func=round_str) - == expected_with_cast_round - ) - - -def test_zrangebylex(r): - zadd(r, 'foo', {'one_a': 0}) - zadd(r, 'foo', {'two_a': 0}) - zadd(r, 'foo', {'two_b': 0}) - zadd(r, 'foo', {'three_a': 0}) - assert r.zrangebylex('foo', b'(t', b'+') == [b'three_a', b'two_a', b'two_b'] - assert r.zrangebylex('foo', b'(t', b'[two_b') == [b'three_a', b'two_a', b'two_b'] - assert r.zrangebylex('foo', b'(t', b'(two_b') == [b'three_a', b'two_a'] - assert ( - r.zrangebylex('foo', b'[three_a', b'[two_b') - == [b'three_a', b'two_a', b'two_b'] - ) - assert r.zrangebylex('foo', b'(three_a', b'[two_b') == [b'two_a', b'two_b'] - assert r.zrangebylex('foo', b'-', b'(two_b') == [b'one_a', b'three_a', b'two_a'] - assert r.zrangebylex('foo', b'[two_b', b'(two_b') == [] - # reversed max + and min - boundaries - # these will be always empty, but allowed by redis - assert r.zrangebylex('foo', b'+', b'-') == [] - assert r.zrangebylex('foo', b'+', b'[three_a') == [] - assert r.zrangebylex('foo', b'[o', b'-') == [] - - -def test_zrangebylex_wrong_type(r): - r.sadd('foo', 'bar') - with pytest.raises(redis.ResponseError): - r.zrangebylex('foo', b'-', b'+') - - -def test_zlexcount(r): - zadd(r, 'foo', {'one_a': 0}) - zadd(r, 'foo', {'two_a': 0}) - zadd(r, 'foo', {'two_b': 0}) - zadd(r, 'foo', {'three_a': 0}) - assert r.zlexcount('foo', b'(t', b'+') == 3 - assert r.zlexcount('foo', b'(t', b'[two_b') == 3 - assert r.zlexcount('foo', b'(t', b'(two_b') == 2 - assert r.zlexcount('foo', b'[three_a', b'[two_b') == 3 - assert r.zlexcount('foo', b'(three_a', b'[two_b') == 2 - assert r.zlexcount('foo', b'-', b'(two_b') == 3 - assert r.zlexcount('foo', b'[two_b', b'(two_b') == 0 - # reversed max + and min - boundaries - # these will be always empty, but allowed by redis - assert r.zlexcount('foo', b'+', b'-') == 0 - assert r.zlexcount('foo', b'+', b'[three_a') == 0 - assert r.zlexcount('foo', b'[o', b'-') == 0 - - -def test_zlexcount_wrong_type(r): - r.sadd('foo', 'bar') - with pytest.raises(redis.ResponseError): - r.zlexcount('foo', b'-', b'+') - - -def test_zrangebylex_with_limit(r): - zadd(r, 'foo', {'one_a': 0}) - zadd(r, 'foo', {'two_a': 0}) - zadd(r, 'foo', {'two_b': 0}) - zadd(r, 'foo', {'three_a': 0}) - assert r.zrangebylex('foo', b'-', b'+', 1, 2) == [b'three_a', b'two_a'] - - # negative offset no results - assert r.zrangebylex('foo', b'-', b'+', -1, 3) == [] - - # negative limit ignored - assert ( - r.zrangebylex('foo', b'-', b'+', 0, -2) - == [b'one_a', b'three_a', b'two_a', b'two_b'] - ) - assert r.zrangebylex('foo', b'-', b'+', 1, -2) == [b'three_a', b'two_a', b'two_b'] - assert r.zrangebylex('foo', b'+', b'-', 1, 1) == [] - - -def test_zrangebylex_raises_error(r): - zadd(r, 'foo', {'one_a': 0}) - zadd(r, 'foo', {'two_a': 0}) - zadd(r, 'foo', {'two_b': 0}) - zadd(r, 'foo', {'three_a': 0}) - - with pytest.raises(redis.ResponseError): - r.zrangebylex('foo', b'', b'[two_b') - - with pytest.raises(redis.ResponseError): - r.zrangebylex('foo', b'-', b'two_b') - - with pytest.raises(redis.ResponseError): - r.zrangebylex('foo', b'(t', b'two_b') - - with pytest.raises(redis.ResponseError): - r.zrangebylex('foo', b't', b'+') - - with pytest.raises(redis.ResponseError): - r.zrangebylex('foo', b'[two_a', b'') - - with pytest.raises(redis.RedisError): - r.zrangebylex('foo', b'(two_a', b'[two_b', 1) - - -def test_zrevrangebylex(r): - zadd(r, 'foo', {'one_a': 0}) - zadd(r, 'foo', {'two_a': 0}) - zadd(r, 'foo', {'two_b': 0}) - zadd(r, 'foo', {'three_a': 0}) - assert r.zrevrangebylex('foo', b'+', b'(t') == [b'two_b', b'two_a', b'three_a'] - assert ( - r.zrevrangebylex('foo', b'[two_b', b'(t') - == [b'two_b', b'two_a', b'three_a'] - ) - assert r.zrevrangebylex('foo', b'(two_b', b'(t') == [b'two_a', b'three_a'] - assert ( - r.zrevrangebylex('foo', b'[two_b', b'[three_a') - == [b'two_b', b'two_a', b'three_a'] - ) - assert r.zrevrangebylex('foo', b'[two_b', b'(three_a') == [b'two_b', b'two_a'] - assert r.zrevrangebylex('foo', b'(two_b', b'-') == [b'two_a', b'three_a', b'one_a'] - assert r.zrangebylex('foo', b'(two_b', b'[two_b') == [] - # reversed max + and min - boundaries - # these will be always empty, but allowed by redis - assert r.zrevrangebylex('foo', b'-', b'+') == [] - assert r.zrevrangebylex('foo', b'[three_a', b'+') == [] - assert r.zrevrangebylex('foo', b'-', b'[o') == [] - - -def test_zrevrangebylex_with_limit(r): - zadd(r, 'foo', {'one_a': 0}) - zadd(r, 'foo', {'two_a': 0}) - zadd(r, 'foo', {'two_b': 0}) - zadd(r, 'foo', {'three_a': 0}) - assert r.zrevrangebylex('foo', b'+', b'-', 1, 2) == [b'two_a', b'three_a'] - - -def test_zrevrangebylex_raises_error(r): - zadd(r, 'foo', {'one_a': 0}) - zadd(r, 'foo', {'two_a': 0}) - zadd(r, 'foo', {'two_b': 0}) - zadd(r, 'foo', {'three_a': 0}) - - with pytest.raises(redis.ResponseError): - r.zrevrangebylex('foo', b'[two_b', b'') - - with pytest.raises(redis.ResponseError): - r.zrevrangebylex('foo', b'two_b', b'-') - - with pytest.raises(redis.ResponseError): - r.zrevrangebylex('foo', b'two_b', b'(t') - - with pytest.raises(redis.ResponseError): - r.zrevrangebylex('foo', b'+', b't') - - with pytest.raises(redis.ResponseError): - r.zrevrangebylex('foo', b'', b'[two_a') - - with pytest.raises(redis.RedisError): - r.zrevrangebylex('foo', b'[two_a', b'(two_b', 1) - - -def test_zrevrangebylex_wrong_type(r): - r.sadd('foo', 'bar') - with pytest.raises(redis.ResponseError): - r.zrevrangebylex('foo', b'+', b'-') - - -def test_zremrangebyrank(r): - zadd(r, 'foo', {'one': 1}) - zadd(r, 'foo', {'two': 2}) - zadd(r, 'foo', {'three': 3}) - assert r.zremrangebyrank('foo', 0, 1) == 2 - assert r.zrange('foo', 0, -1) == [b'three'] - - -def test_zremrangebyrank_negative_indices(r): - zadd(r, 'foo', {'one': 1}) - zadd(r, 'foo', {'two': 2}) - zadd(r, 'foo', {'three': 3}) - assert r.zremrangebyrank('foo', -2, -1) == 2 - assert r.zrange('foo', 0, -1) == [b'one'] - - -def test_zremrangebyrank_out_of_bounds(r): - zadd(r, 'foo', {'one': 1}) - assert r.zremrangebyrank('foo', 1, 3) == 0 - - -def test_zremrangebyrank_wrong_type(r): - r.sadd('foo', 'bar') - with pytest.raises(redis.ResponseError): - r.zremrangebyrank('foo', 1, 3) - - -def test_zremrangebyscore(r): - zadd(r, 'foo', {'zero': 0}) - zadd(r, 'foo', {'two': 2}) - zadd(r, 'foo', {'four': 4}) - # Outside of range. - assert r.zremrangebyscore('foo', 5, 10) == 0 - assert r.zrange('foo', 0, -1) == [b'zero', b'two', b'four'] - # Middle of range. - assert r.zremrangebyscore('foo', 1, 3) == 1 - assert r.zrange('foo', 0, -1) == [b'zero', b'four'] - assert r.zremrangebyscore('foo', 1, 3) == 0 - # Entire range. - assert r.zremrangebyscore('foo', 0, 4) == 2 - assert r.zrange('foo', 0, -1) == [] - - -def test_zremrangebyscore_exclusive(r): - zadd(r, 'foo', {'zero': 0}) - zadd(r, 'foo', {'two': 2}) - zadd(r, 'foo', {'four': 4}) - assert r.zremrangebyscore('foo', '(0', 1) == 0 - assert r.zrange('foo', 0, -1) == [b'zero', b'two', b'four'] - assert r.zremrangebyscore('foo', '-inf', '(0') == 0 - assert r.zrange('foo', 0, -1) == [b'zero', b'two', b'four'] - assert r.zremrangebyscore('foo', '(2', 5) == 1 - assert r.zrange('foo', 0, -1) == [b'zero', b'two'] - assert r.zremrangebyscore('foo', 0, '(2') == 1 - assert r.zrange('foo', 0, -1) == [b'two'] - assert r.zremrangebyscore('foo', '(1', '(3') == 1 - assert r.zrange('foo', 0, -1) == [] - - -def test_zremrangebyscore_raises_error(r): - zadd(r, 'foo', {'zero': 0}) - zadd(r, 'foo', {'two': 2}) - zadd(r, 'foo', {'four': 4}) - with pytest.raises(redis.ResponseError): - r.zremrangebyscore('foo', 'three', 1) - with pytest.raises(redis.ResponseError): - r.zremrangebyscore('foo', 3, 'one') - with pytest.raises(redis.ResponseError): - r.zremrangebyscore('foo', 3, '1)') - with pytest.raises(redis.ResponseError): - r.zremrangebyscore('foo', '((3', '1)') - - -def test_zremrangebyscore_badkey(r): - assert r.zremrangebyscore('foo', 0, 2) == 0 - - -def test_zremrangebyscore_wrong_type(r): - r.sadd('foo', 'bar') - with pytest.raises(redis.ResponseError): - r.zremrangebyscore('foo', 0, 2) - - -def test_zremrangebylex(r): - zadd(r, 'foo', {'two_a': 0}) - zadd(r, 'foo', {'two_b': 0}) - zadd(r, 'foo', {'one_a': 0}) - zadd(r, 'foo', {'three_a': 0}) - assert r.zremrangebylex('foo', b'(three_a', b'[two_b') == 2 - assert r.zremrangebylex('foo', b'(three_a', b'[two_b') == 0 - assert r.zremrangebylex('foo', b'-', b'(o') == 0 - assert r.zremrangebylex('foo', b'-', b'[one_a') == 1 - assert r.zremrangebylex('foo', b'[tw', b'+') == 0 - assert r.zremrangebylex('foo', b'[t', b'+') == 1 - assert r.zremrangebylex('foo', b'[t', b'+') == 0 - - -def test_zremrangebylex_error(r): - zadd(r, 'foo', {'two_a': 0}) - zadd(r, 'foo', {'two_b': 0}) - zadd(r, 'foo', {'one_a': 0}) - zadd(r, 'foo', {'three_a': 0}) - with pytest.raises(redis.ResponseError): - r.zremrangebylex('foo', b'(t', b'two_b') - - with pytest.raises(redis.ResponseError): - r.zremrangebylex('foo', b't', b'+') - - with pytest.raises(redis.ResponseError): - r.zremrangebylex('foo', b'[two_a', b'') - - -def test_zremrangebylex_badkey(r): - assert r.zremrangebylex('foo', b'(three_a', b'[two_b') == 0 - - -def test_zremrangebylex_wrong_type(r): - r.sadd('foo', 'bar') - with pytest.raises(redis.ResponseError): - r.zremrangebylex('foo', b'bar', b'baz') - - -def test_zunionstore(r): - zadd(r, 'foo', {'one': 1}) - zadd(r, 'foo', {'two': 2}) - zadd(r, 'bar', {'one': 1}) - zadd(r, 'bar', {'two': 2}) - zadd(r, 'bar', {'three': 3}) - r.zunionstore('baz', ['foo', 'bar']) - assert ( - r.zrange('baz', 0, -1, withscores=True) - == [(b'one', 2), (b'three', 3), (b'two', 4)] - ) - - -def test_zunionstore_sum(r): - zadd(r, 'foo', {'one': 1}) - zadd(r, 'foo', {'two': 2}) - zadd(r, 'bar', {'one': 1}) - zadd(r, 'bar', {'two': 2}) - zadd(r, 'bar', {'three': 3}) - r.zunionstore('baz', ['foo', 'bar'], aggregate='SUM') - assert ( - r.zrange('baz', 0, -1, withscores=True) - == [(b'one', 2), (b'three', 3), (b'two', 4)] - ) - - -def test_zunionstore_max(r): - zadd(r, 'foo', {'one': 0}) - zadd(r, 'foo', {'two': 0}) - zadd(r, 'bar', {'one': 1}) - zadd(r, 'bar', {'two': 2}) - zadd(r, 'bar', {'three': 3}) - r.zunionstore('baz', ['foo', 'bar'], aggregate='MAX') - assert ( - r.zrange('baz', 0, -1, withscores=True) - == [(b'one', 1), (b'two', 2), (b'three', 3)] - ) - - -def test_zunionstore_min(r): - zadd(r, 'foo', {'one': 1}) - zadd(r, 'foo', {'two': 2}) - zadd(r, 'bar', {'one': 0}) - zadd(r, 'bar', {'two': 0}) - zadd(r, 'bar', {'three': 3}) - r.zunionstore('baz', ['foo', 'bar'], aggregate='MIN') - assert ( - r.zrange('baz', 0, -1, withscores=True) - == [(b'one', 0), (b'two', 0), (b'three', 3)] - ) - - -def test_zunionstore_weights(r): - zadd(r, 'foo', {'one': 1}) - zadd(r, 'foo', {'two': 2}) - zadd(r, 'bar', {'one': 1}) - zadd(r, 'bar', {'two': 2}) - zadd(r, 'bar', {'four': 4}) - r.zunionstore('baz', {'foo': 1, 'bar': 2}, aggregate='SUM') - assert ( - r.zrange('baz', 0, -1, withscores=True) - == [(b'one', 3), (b'two', 6), (b'four', 8)] - ) - - -def test_zunionstore_nan_to_zero(r): - zadd(r, 'foo', {'x': math.inf}) - zadd(r, 'foo2', {'x': math.inf}) - r.zunionstore('bar', OrderedDict([('foo', 1.0), ('foo2', 0.0)])) - # This is different to test_zinterstore_nan_to_zero because of a quirk - # in redis. See https://github.com/antirez/redis/issues/3954. - assert r.zscore('bar', 'x') == math.inf - - -def test_zunionstore_nan_to_zero2(r): - zadd(r, 'foo', {'zero': 0}) - zadd(r, 'foo2', {'one': 1}) - zadd(r, 'foo3', {'one': 1}) - r.zunionstore('bar', {'foo': math.inf}, aggregate='SUM') - assert r.zrange('bar', 0, -1, withscores=True) == [(b'zero', 0)] - r.zunionstore('bar', OrderedDict([('foo2', math.inf), ('foo3', -math.inf)])) - assert r.zrange('bar', 0, -1, withscores=True) == [(b'one', 0)] - - -def test_zunionstore_nan_to_zero_ordering(r): - zadd(r, 'foo', {'e1': math.inf}) - zadd(r, 'bar', {'e1': -math.inf, 'e2': 0.0}) - r.zunionstore('baz', ['foo', 'bar', 'foo']) - assert r.zscore('baz', 'e1') == 0.0 - - -def test_zunionstore_mixed_set_types(r): - # No score, redis will use 1.0. - r.sadd('foo', 'one') - r.sadd('foo', 'two') - zadd(r, 'bar', {'one': 1}) - zadd(r, 'bar', {'two': 2}) - zadd(r, 'bar', {'three': 3}) - r.zunionstore('baz', ['foo', 'bar'], aggregate='SUM') - assert ( - r.zrange('baz', 0, -1, withscores=True) - == [(b'one', 2), (b'three', 3), (b'two', 3)] - ) - - -def test_zunionstore_badkey(r): - zadd(r, 'foo', {'one': 1}) - zadd(r, 'foo', {'two': 2}) - r.zunionstore('baz', ['foo', 'bar'], aggregate='SUM') - assert r.zrange('baz', 0, -1, withscores=True) == [(b'one', 1), (b'two', 2)] - r.zunionstore('baz', {'foo': 1, 'bar': 2}, aggregate='SUM') - assert r.zrange('baz', 0, -1, withscores=True) == [(b'one', 1), (b'two', 2)] - - -def test_zunionstore_wrong_type(r): - r.set('foo', 'bar') - with pytest.raises(redis.ResponseError): - r.zunionstore('baz', ['foo', 'bar']) - - -def test_zinterstore(r): - zadd(r, 'foo', {'one': 1}) - zadd(r, 'foo', {'two': 2}) - zadd(r, 'bar', {'one': 1}) - zadd(r, 'bar', {'two': 2}) - zadd(r, 'bar', {'three': 3}) - r.zinterstore('baz', ['foo', 'bar']) - assert r.zrange('baz', 0, -1, withscores=True) == [(b'one', 2), (b'two', 4)] - - -def test_zinterstore_mixed_set_types(r): - r.sadd('foo', 'one') - r.sadd('foo', 'two') - zadd(r, 'bar', {'one': 1}) - zadd(r, 'bar', {'two': 2}) - zadd(r, 'bar', {'three': 3}) - r.zinterstore('baz', ['foo', 'bar'], aggregate='SUM') - assert r.zrange('baz', 0, -1, withscores=True) == [(b'one', 2), (b'two', 3)] - - -def test_zinterstore_max(r): - zadd(r, 'foo', {'one': 0}) - zadd(r, 'foo', {'two': 0}) - zadd(r, 'bar', {'one': 1}) - zadd(r, 'bar', {'two': 2}) - zadd(r, 'bar', {'three': 3}) - r.zinterstore('baz', ['foo', 'bar'], aggregate='MAX') - assert r.zrange('baz', 0, -1, withscores=True) == [(b'one', 1), (b'two', 2)] - - -def test_zinterstore_onekey(r): - zadd(r, 'foo', {'one': 1}) - r.zinterstore('baz', ['foo'], aggregate='MAX') - assert r.zrange('baz', 0, -1, withscores=True) == [(b'one', 1)] - - -def test_zinterstore_nokey(r): - with pytest.raises(redis.ResponseError): - r.zinterstore('baz', [], aggregate='MAX') - - -def test_zinterstore_nan_to_zero(r): - zadd(r, 'foo', {'x': math.inf}) - zadd(r, 'foo2', {'x': math.inf}) - r.zinterstore('bar', OrderedDict([('foo', 1.0), ('foo2', 0.0)])) - assert r.zscore('bar', 'x') == 0.0 - - -def test_zunionstore_nokey(r): - with pytest.raises(redis.ResponseError): - r.zunionstore('baz', [], aggregate='MAX') - - -def test_zinterstore_wrong_type(r): - r.set('foo', 'bar') - with pytest.raises(redis.ResponseError): - r.zinterstore('baz', ['foo', 'bar']) - - -def test_empty_zset(r): - zadd(r, 'foo', {'one': 1}) - r.zrem('foo', 'one') - assert not r.exists('foo') - - -def test_multidb(r, create_redis): - r1 = create_redis(db=0) - r2 = create_redis(db=1) - - r1['r1'] = 'r1' - r2['r2'] = 'r2' - - assert 'r2' not in r1 - assert 'r1' not in r2 - - assert r1['r1'] == b'r1' - assert r2['r2'] == b'r2' - - assert r1.flushall() is True - - assert 'r1' not in r1 - assert 'r2' not in r2 - - -def test_basic_sort(r): - r.rpush('foo', '2') - r.rpush('foo', '1') - r.rpush('foo', '3') - - assert r.sort('foo') == [b'1', b'2', b'3'] - - -def test_empty_sort(r): - assert r.sort('foo') == [] - - -def test_sort_range_offset_range(r): - r.rpush('foo', '2') - r.rpush('foo', '1') - r.rpush('foo', '4') - r.rpush('foo', '3') - - assert r.sort('foo', start=0, num=2) == [b'1', b'2'] - - -def test_sort_range_offset_range_and_desc(r): - r.rpush('foo', '2') - r.rpush('foo', '1') - r.rpush('foo', '4') - r.rpush('foo', '3') - - assert r.sort("foo", start=0, num=1, desc=True) == [b"4"] - - -def test_sort_range_offset_norange(r): - with pytest.raises(redis.RedisError): - r.sort('foo', start=1) - - -def test_sort_range_with_large_range(r): - r.rpush('foo', '2') - r.rpush('foo', '1') - r.rpush('foo', '4') - r.rpush('foo', '3') - # num=20 even though len(foo) is 4. - assert r.sort('foo', start=1, num=20) == [b'2', b'3', b'4'] - - -def test_sort_descending(r): - r.rpush('foo', '1') - r.rpush('foo', '2') - r.rpush('foo', '3') - assert r.sort('foo', desc=True) == [b'3', b'2', b'1'] - - -def test_sort_alpha(r): - r.rpush('foo', '2a') - r.rpush('foo', '1b') - r.rpush('foo', '2b') - r.rpush('foo', '1a') - - assert r.sort('foo', alpha=True) == [b'1a', b'1b', b'2a', b'2b'] - - -def test_sort_wrong_type(r): - r.set('string', '3') - with pytest.raises(redis.ResponseError): - r.sort('string') - - -def test_foo(r): - r.rpush('foo', '2a') - r.rpush('foo', '1b') - r.rpush('foo', '2b') - r.rpush('foo', '1a') - with pytest.raises(redis.ResponseError): - r.sort('foo', alpha=False) - - -def test_sort_with_store_option(r): - r.rpush('foo', '2') - r.rpush('foo', '1') - r.rpush('foo', '4') - r.rpush('foo', '3') - - assert r.sort('foo', store='bar') == 4 - assert r.lrange('bar', 0, -1) == [b'1', b'2', b'3', b'4'] - - -def test_sort_with_by_and_get_option(r): - r.rpush('foo', '2') - r.rpush('foo', '1') - r.rpush('foo', '4') - r.rpush('foo', '3') - - r['weight_1'] = '4' - r['weight_2'] = '3' - r['weight_3'] = '2' - r['weight_4'] = '1' - - r['data_1'] = 'one' - r['data_2'] = 'two' - r['data_3'] = 'three' - r['data_4'] = 'four' - - assert ( - r.sort('foo', by='weight_*', get='data_*') - == [b'four', b'three', b'two', b'one'] - ) - assert r.sort('foo', by='weight_*', get='#') == [b'4', b'3', b'2', b'1'] - assert ( - r.sort('foo', by='weight_*', get=('data_*', '#')) - == [b'four', b'4', b'three', b'3', b'two', b'2', b'one', b'1'] - ) - assert r.sort('foo', by='weight_*', get='data_1') == [None, None, None, None] - - -def test_sort_with_hash(r): - r.rpush('foo', 'middle') - r.rpush('foo', 'eldest') - r.rpush('foo', 'youngest') - r.hset('record_youngest', 'age', 1) - r.hset('record_youngest', 'name', 'baby') - - r.hset('record_middle', 'age', 10) - r.hset('record_middle', 'name', 'teen') - - r.hset('record_eldest', 'age', 20) - r.hset('record_eldest', 'name', 'adult') - - assert r.sort('foo', by='record_*->age') == [b'youngest', b'middle', b'eldest'] - assert ( - r.sort('foo', by='record_*->age', get='record_*->name') - == [b'baby', b'teen', b'adult'] - ) - - -def test_sort_with_set(r): - r.sadd('foo', '3') - r.sadd('foo', '1') - r.sadd('foo', '2') - assert r.sort('foo') == [b'1', b'2', b'3'] - - -def test_pipeline(r): - # The pipeline method returns an object for - # issuing multiple commands in a batch. - p = r.pipeline() - p.watch('bam') - p.multi() - p.set('foo', 'bar').get('foo') - p.lpush('baz', 'quux') - p.lpush('baz', 'quux2').lrange('baz', 0, -1) - res = p.execute() - - # Check return values returned as list. - assert res == [True, b'bar', 1, 2, [b'quux2', b'quux']] - - # Check side effects happened as expected. - assert r.lrange('baz', 0, -1) == [b'quux2', b'quux'] - - # Check that the command buffer has been emptied. - assert p.execute() == [] - - -def test_pipeline_ignore_errors(r): - """Test the pipeline ignoring errors when asked.""" - with r.pipeline() as p: - p.set('foo', 'bar') - p.rename('baz', 'bats') - with pytest.raises(redis.exceptions.ResponseError): - p.execute() - assert [] == p.execute() - with r.pipeline() as p: - p.set('foo', 'bar') - p.rename('baz', 'bats') - res = p.execute(raise_on_error=False) - - assert [] == p.execute() - - assert len(res) == 2 - assert isinstance(res[1], redis.exceptions.ResponseError) - - -def test_multiple_successful_watch_calls(r): - p = r.pipeline() - p.watch('bam') - p.multi() - p.set('foo', 'bar') - # Check that the watched keys buffer has been emptied. - p.execute() - - # bam is no longer being watched, so it's ok to modify - # it now. - p.watch('foo') - r.set('bam', 'boo') - p.multi() - p.set('foo', 'bats') - assert p.execute() == [True] - - -def test_pipeline_non_transactional(r): - # For our simple-minded model I don't think - # there is any observable difference. - p = r.pipeline(transaction=False) - res = p.set('baz', 'quux').get('baz').execute() - - assert res == [True, b'quux'] - - -def test_pipeline_raises_when_watched_key_changed(r): - r.set('foo', 'bar') - r.rpush('greet', 'hello') - p = r.pipeline() - try: - p.watch('greet', 'foo') - nextf = six.ensure_binary(p.get('foo')) + b'baz' - # Simulate change happening on another thread. - r.rpush('greet', 'world') - # Begin pipelining. - p.multi() - p.set('foo', nextf) - - with pytest.raises(redis.WatchError): - p.execute() - finally: - p.reset() - - -def test_pipeline_succeeds_despite_unwatched_key_changed(r): - # Same setup as before except for the params to the WATCH command. - r.set('foo', 'bar') - r.rpush('greet', 'hello') - p = r.pipeline() - try: - # Only watch one of the 2 keys. - p.watch('foo') - nextf = six.ensure_binary(p.get('foo')) + b'baz' - # Simulate change happening on another thread. - r.rpush('greet', 'world') - p.multi() - p.set('foo', nextf) - p.execute() - - # Check the commands were executed. - assert r.get('foo') == b'barbaz' - finally: - p.reset() - - -def test_pipeline_succeeds_when_watching_nonexistent_key(r): - r.set('foo', 'bar') - r.rpush('greet', 'hello') - p = r.pipeline() - try: - # Also watch a nonexistent key. - p.watch('foo', 'bam') - nextf = six.ensure_binary(p.get('foo')) + b'baz' - # Simulate change happening on another thread. - r.rpush('greet', 'world') - p.multi() - p.set('foo', nextf) - p.execute() - - # Check the commands were executed. - assert r.get('foo') == b'barbaz' - finally: - p.reset() - - -def test_watch_state_is_cleared_across_multiple_watches(r): - r.set('foo', 'one') - r.set('bar', 'baz') - p = r.pipeline() - - try: - p.watch('foo') - # Simulate change happening on another thread. - r.set('foo', 'three') - p.multi() - p.set('foo', 'three') - with pytest.raises(redis.WatchError): - p.execute() - - # Now watch another key. It should be ok to change - # foo as we're no longer watching it. - p.watch('bar') - r.set('foo', 'four') - p.multi() - p.set('bar', 'five') - assert p.execute() == [True] - finally: - p.reset() - - -def test_watch_state_is_cleared_after_abort(r): - # redis-py's pipeline handling and connection pooling interferes with this - # test, so raw commands are used instead. - raw_command(r, 'watch', 'foo') - raw_command(r, 'multi') - with pytest.raises(redis.ResponseError): - raw_command(r, 'mget') # Wrong number of arguments - with pytest.raises(redis.exceptions.ExecAbortError): - raw_command(r, 'exec') - - raw_command(r, 'set', 'foo', 'bar') # Should NOT trigger the watch from earlier - raw_command(r, 'multi') - raw_command(r, 'set', 'abc', 'done') - raw_command(r, 'exec') - - assert r.get('abc') == b'done' - - -def test_pipeline_transaction_shortcut(r): - # This example taken pretty much from the redis-py documentation. - r.set('OUR-SEQUENCE-KEY', 13) - calls = [] - - def client_side_incr(pipe): - calls.append((pipe,)) - current_value = pipe.get('OUR-SEQUENCE-KEY') - next_value = int(current_value) + 1 - - if len(calls) < 3: - # Simulate a change from another thread. - r.set('OUR-SEQUENCE-KEY', next_value) - - pipe.multi() - pipe.set('OUR-SEQUENCE-KEY', next_value) - - res = r.transaction(client_side_incr, 'OUR-SEQUENCE-KEY') - - assert res == [True] - assert int(r.get('OUR-SEQUENCE-KEY')) == 16 - assert len(calls) == 3 - - -def test_pipeline_transaction_value_from_callable(r): - def callback(pipe): - # No need to do anything here since we only want the return value - return 'OUR-RETURN-VALUE' - - res = r.transaction(callback, 'OUR-SEQUENCE-KEY', value_from_callable=True) - assert res == 'OUR-RETURN-VALUE' - - -def test_pipeline_empty(r): - p = r.pipeline() - assert len(p) == 0 - - -def test_pipeline_length(r): - p = r.pipeline() - p.set('baz', 'quux').get('baz') - assert len(p) == 2 - - -def test_pipeline_no_commands(r): - # Prior to 3.4, redis-py's execute is a nop if there are no commands - # queued, so it succeeds even if watched keys have been changed. - r.set('foo', '1') - p = r.pipeline() - p.watch('foo') - r.set('foo', '2') - if REDIS_VERSION >= Version('3.4'): - with pytest.raises(redis.WatchError): - p.execute() - else: - assert p.execute() == [] - - -def test_pipeline_failed_transaction(r): - p = r.pipeline() - p.multi() - p.set('foo', 'bar') - # Deliberately induce a syntax error - p.execute_command('set') - # It should be an ExecAbortError, but redis-py tries to DISCARD after the - # failed EXEC, which raises a ResponseError. - with pytest.raises(redis.ResponseError): - p.execute() - assert not r.exists('foo') - - -def test_pipeline_srem_no_change(r): - # A regression test for a case picked up by hypothesis tests. - p = r.pipeline() - p.watch('foo') - r.srem('foo', 'bar') - p.multi() - p.set('foo', 'baz') - p.execute() - assert r.get('foo') == b'baz' - - -# The behaviour changed in redis 6.0 (see https://github.com/redis/redis/issues/6594). -@pytest.mark.min_server('6.0') -def test_pipeline_move(r): - # A regression test for a case picked up by hypothesis tests. - r.set('foo', 'bar') - p = r.pipeline() - p.watch('foo') - r.move('foo', 1) - # Ensure the transaction isn't empty, which had different behaviour in - # older versions of redis-py. - p.multi() - p.set('bar', 'baz') - with pytest.raises(redis.exceptions.WatchError): - p.execute() - - -@pytest.mark.min_server('6.0.6') -def test_exec_bad_arguments(r): - # Redis 6.0.6 changed the behaviour of exec so that it always fails with - # EXECABORT, even when it's just bad syntax. - with pytest.raises(redis.exceptions.ExecAbortError): - r.execute_command('exec', 'blahblah') - - -@pytest.mark.min_server('6.0.6') -def test_exec_bad_arguments_abort(r): - r.execute_command('multi') - with pytest.raises(redis.exceptions.ExecAbortError): - r.execute_command('exec', 'blahblah') - # Should have aborted the transaction, so we can run another one - p = r.pipeline() - p.multi() - p.set('bar', 'baz') - p.execute() - assert r.get('bar') == b'baz' - - -def test_key_patterns(r): - r.mset({'one': 1, 'two': 2, 'three': 3, 'four': 4}) - assert sorted(r.keys('*o*')) == [b'four', b'one', b'two'] - assert r.keys('t??') == [b'two'] - assert sorted(r.keys('*')) == [b'four', b'one', b'three', b'two'] - assert sorted(r.keys()) == [b'four', b'one', b'three', b'two'] - - -def test_ping(r): - assert r.ping() - assert raw_command(r, 'ping', 'test') == b'test' - - -@redis3_only -def test_ping_pubsub(r): - p = r.pubsub() - p.subscribe('channel') - p.parse_response() # Consume the subscribe reply - p.ping() - assert p.parse_response() == [b'pong', b''] - p.ping('test') - assert p.parse_response() == [b'pong', b'test'] - - -@redis3_only -def test_swapdb(r, create_redis): - r1 = create_redis(1) - r.set('foo', 'abc') - r.set('bar', 'xyz') - r1.set('foo', 'foo') - r1.set('baz', 'baz') - assert r.swapdb(0, 1) - assert r.get('foo') == b'foo' - assert r.get('bar') is None - assert r.get('baz') == b'baz' - assert r1.get('foo') == b'abc' - assert r1.get('bar') == b'xyz' - assert r1.get('baz') is None - - -@redis3_only -def test_swapdb_same_db(r): - assert r.swapdb(1, 1) - - -def test_save(r): - assert r.save() - - -def test_bgsave(r): - assert r.bgsave() - with pytest.raises(ResponseError): - r.execute_command('BGSAVE', 'SCHEDULE', 'FOO') - with pytest.raises(ResponseError): - r.execute_command('BGSAVE', 'FOO') - - -def test_lastsave(r): - assert isinstance(r.lastsave(), datetime) - - -@fake_only -def test_time(r, mocker): - fake_time = mocker.patch('time.time') - fake_time.return_value = 1234567890.1234567 - assert r.time() == (1234567890, 123457) - fake_time.return_value = 1234567890.000001 - assert r.time() == (1234567890, 1) - fake_time.return_value = 1234567890.9999999 - assert r.time() == (1234567891, 0) - - -@pytest.mark.slow -def test_bgsave_timestamp_update(r): - early_timestamp = r.lastsave() - sleep(1) - assert r.bgsave() - sleep(1) - late_timestamp = r.lastsave() - assert early_timestamp < late_timestamp - - -@pytest.mark.slow -def test_save_timestamp_update(r): - early_timestamp = r.lastsave() - sleep(1) - assert r.save() - late_timestamp = r.lastsave() - assert early_timestamp < late_timestamp - - -def test_type(r): - r.set('string_key', "value") - r.lpush("list_key", "value") - r.sadd("set_key", "value") - zadd(r, "zset_key", {"value": 1}) - r.hset('hset_key', 'key', 'value') - - assert r.type('string_key') == b'string' - assert r.type('list_key') == b'list' - assert r.type('set_key') == b'set' - assert r.type('zset_key') == b'zset' - assert r.type('hset_key') == b'hash' - assert r.type('none_key') == b'none' - - -@pytest.mark.slow -def test_pubsub_subscribe(r): - pubsub = r.pubsub() - pubsub.subscribe("channel") - sleep(1) - expected_message = {'type': 'subscribe', 'pattern': None, - 'channel': b'channel', 'data': 1} - message = pubsub.get_message() - keys = list(pubsub.channels.keys()) - - key = keys[0] - key = (key if type(key) == bytes - else bytes(key, encoding='utf-8')) - - assert len(keys) == 1 - assert key == b'channel' - assert message == expected_message - - -@pytest.mark.slow -def test_pubsub_psubscribe(r): - pubsub = r.pubsub() - pubsub.psubscribe("channel.*") - sleep(1) - expected_message = {'type': 'psubscribe', 'pattern': None, - 'channel': b'channel.*', 'data': 1} - - message = pubsub.get_message() - keys = list(pubsub.patterns.keys()) - assert len(keys) == 1 - assert message == expected_message - - -@pytest.mark.slow -def test_pubsub_unsubscribe(r): - pubsub = r.pubsub() - pubsub.subscribe('channel-1', 'channel-2', 'channel-3') - sleep(1) - expected_message = {'type': 'unsubscribe', 'pattern': None, - 'channel': b'channel-1', 'data': 2} - pubsub.get_message() - pubsub.get_message() - pubsub.get_message() - - # unsubscribe from one - pubsub.unsubscribe('channel-1') - sleep(1) - message = pubsub.get_message() - keys = list(pubsub.channels.keys()) - assert message == expected_message - assert len(keys) == 2 - - # unsubscribe from multiple - pubsub.unsubscribe() - sleep(1) - pubsub.get_message() - pubsub.get_message() - keys = list(pubsub.channels.keys()) - assert message == expected_message - assert len(keys) == 0 - - -@pytest.mark.slow -def test_pubsub_punsubscribe(r): - pubsub = r.pubsub() - pubsub.psubscribe('channel-1.*', 'channel-2.*', 'channel-3.*') - sleep(1) - expected_message = {'type': 'punsubscribe', 'pattern': None, - 'channel': b'channel-1.*', 'data': 2} - pubsub.get_message() - pubsub.get_message() - pubsub.get_message() - - # unsubscribe from one - pubsub.punsubscribe('channel-1.*') - sleep(1) - message = pubsub.get_message() - keys = list(pubsub.patterns.keys()) - assert message == expected_message - assert len(keys) == 2 - - # unsubscribe from multiple - pubsub.punsubscribe() - sleep(1) - pubsub.get_message() - pubsub.get_message() - keys = list(pubsub.patterns.keys()) - assert len(keys) == 0 - - -@pytest.mark.slow -def test_pubsub_listen(r): - def _listen(pubsub, q): - count = 0 - for message in pubsub.listen(): - q.put(message) - count += 1 - if count == 4: - pubsub.close() - - channel = 'ch1' - patterns = ['ch1*', 'ch[1]', 'ch?'] - pubsub = r.pubsub() - pubsub.subscribe(channel) - pubsub.psubscribe(*patterns) - sleep(1) - msg1 = pubsub.get_message() - msg2 = pubsub.get_message() - msg3 = pubsub.get_message() - msg4 = pubsub.get_message() - assert msg1['type'] == 'subscribe' - assert msg2['type'] == 'psubscribe' - assert msg3['type'] == 'psubscribe' - assert msg4['type'] == 'psubscribe' - - q = Queue() - t = threading.Thread(target=_listen, args=(pubsub, q)) - t.start() - msg = 'hello world' - r.publish(channel, msg) - t.join() - - msg1 = q.get() - msg2 = q.get() - msg3 = q.get() - msg4 = q.get() - - bpatterns = [pattern.encode() for pattern in patterns] - bpatterns.append(channel.encode()) - msg = msg.encode() - assert msg1['data'] == msg - assert msg1['channel'] in bpatterns - assert msg2['data'] == msg - assert msg2['channel'] in bpatterns - assert msg3['data'] == msg - assert msg3['channel'] in bpatterns - assert msg4['data'] == msg - assert msg4['channel'] in bpatterns - - -@pytest.mark.slow -def test_pubsub_listen_handler(r): - def _handler(message): - calls.append(message) - - channel = 'ch1' - patterns = {'ch?': _handler} - calls = [] - - pubsub = r.pubsub() - pubsub.subscribe(ch1=_handler) - pubsub.psubscribe(**patterns) - sleep(1) - msg1 = pubsub.get_message() - msg2 = pubsub.get_message() - assert msg1['type'] == 'subscribe' - assert msg2['type'] == 'psubscribe' - msg = 'hello world' - r.publish(channel, msg) - sleep(1) - for i in range(2): - msg = pubsub.get_message() - assert msg is None # get_message returns None when handler is used - pubsub.close() - calls.sort(key=lambda call: call['type']) - assert calls == [ - {'pattern': None, 'channel': b'ch1', 'data': b'hello world', 'type': 'message'}, - {'pattern': b'ch?', 'channel': b'ch1', 'data': b'hello world', 'type': 'pmessage'} - ] - - -@pytest.mark.slow -def test_pubsub_ignore_sub_messages_listen(r): - def _listen(pubsub, q): - count = 0 - for message in pubsub.listen(): - q.put(message) - count += 1 - if count == 4: - pubsub.close() - - channel = 'ch1' - patterns = ['ch1*', 'ch[1]', 'ch?'] - pubsub = r.pubsub(ignore_subscribe_messages=True) - pubsub.subscribe(channel) - pubsub.psubscribe(*patterns) - sleep(1) - - q = Queue() - t = threading.Thread(target=_listen, args=(pubsub, q)) - t.start() - msg = 'hello world' - r.publish(channel, msg) - t.join() - - msg1 = q.get() - msg2 = q.get() - msg3 = q.get() - msg4 = q.get() - - bpatterns = [pattern.encode() for pattern in patterns] - bpatterns.append(channel.encode()) - msg = msg.encode() - assert msg1['data'] == msg - assert msg1['channel'] in bpatterns - assert msg2['data'] == msg - assert msg2['channel'] in bpatterns - assert msg3['data'] == msg - assert msg3['channel'] in bpatterns - assert msg4['data'] == msg - assert msg4['channel'] in bpatterns - - -@pytest.mark.slow -def test_pubsub_binary(r): - def _listen(pubsub, q): - for message in pubsub.listen(): - q.put(message) - pubsub.close() - - pubsub = r.pubsub(ignore_subscribe_messages=True) - pubsub.subscribe('channel\r\n\xff') - sleep(1) - - q = Queue() - t = threading.Thread(target=_listen, args=(pubsub, q)) - t.start() - msg = b'\x00hello world\r\n\xff' - r.publish('channel\r\n\xff', msg) - t.join() - - received = q.get() - assert received['data'] == msg - - -@pytest.mark.slow -def test_pubsub_run_in_thread(r): - q = Queue() - - pubsub = r.pubsub() - pubsub.subscribe(channel=q.put) - pubsub_thread = pubsub.run_in_thread() - - msg = b"Hello World" - r.publish("channel", msg) - - retrieved = q.get() - assert retrieved["data"] == msg - - pubsub_thread.stop() - # Newer versions of redis wait for an unsubscribe message, which sometimes comes early - # https://github.com/andymccurdy/redis-py/issues/1150 - if pubsub.channels: - pubsub.channels = {} - pubsub_thread.join() - assert not pubsub_thread.is_alive() - - pubsub.subscribe(channel=None) - with pytest.raises(redis.exceptions.PubSubError): - pubsub_thread = pubsub.run_in_thread() - - pubsub.unsubscribe("channel") - - pubsub.psubscribe(channel=None) - with pytest.raises(redis.exceptions.PubSubError): - pubsub_thread = pubsub.run_in_thread() - - -@pytest.mark.slow -@pytest.mark.parametrize( - "timeout_value", - [ - 1, - pytest.param( - None, - marks=pytest.mark.skipif( - Version("3.2") <= REDIS_VERSION < Version("3.3"), - reason="This test is not applicable to redis-py 3.2" - ) - ) - ] -) -def test_pubsub_timeout(r, timeout_value): - def publish(): - sleep(0.1) - r.publish('channel', 'hello') - - p = r.pubsub() - p.subscribe('channel') - p.parse_response() # Drains the subscribe message - publish_thread = threading.Thread(target=publish) - publish_thread.start() - message = p.get_message(timeout=timeout_value) - assert message == { - 'type': 'message', 'pattern': None, - 'channel': b'channel', 'data': b'hello' - } - publish_thread.join() - - if timeout_value is not None: - # For infinite timeout case don't wait for the message that will never appear. - message = p.get_message(timeout=timeout_value) - assert message is None - - -def test_pfadd(r): - key = "hll-pfadd" - assert r.pfadd(key, "a", "b", "c", "d", "e", "f", "g") == 1 - assert r.pfcount(key) == 7 - - -def test_pfcount(r): - key1 = "hll-pfcount01" - key2 = "hll-pfcount02" - key3 = "hll-pfcount03" - assert r.pfadd(key1, "foo", "bar", "zap") == 1 - assert r.pfadd(key1, "zap", "zap", "zap") == 0 - assert r.pfadd(key1, "foo", "bar") == 0 - assert r.pfcount(key1) == 3 - assert r.pfadd(key2, "1", "2", "3") == 1 - assert r.pfcount(key2) == 3 - assert r.pfcount(key1, key2) == 6 - assert r.pfadd(key3, "foo", "bar", "zip") == 1 - assert r.pfcount(key3) == 3 - assert r.pfcount(key1, key3) == 4 - assert r.pfcount(key1, key2, key3) == 7 - - -def test_pfmerge(r): - key1 = "hll-pfmerge01" - key2 = "hll-pfmerge02" - key3 = "hll-pfmerge03" - assert r.pfadd(key1, "foo", "bar", "zap", "a") == 1 - assert r.pfadd(key2, "a", "b", "c", "foo") == 1 - assert r.pfmerge(key3, key1, key2) - assert r.pfcount(key3) == 6 - - -def test_scan(r): - # Setup the data - for ix in range(20): - k = 'scan-test:%s' % ix - v = 'result:%s' % ix - r.set(k, v) - expected = r.keys() - assert len(expected) == 20 # Ensure we know what we're testing - - # Test that we page through the results and get everything out - results = [] - cursor = '0' - while cursor != 0: - cursor, data = r.scan(cursor, count=6) - results.extend(data) - assert set(expected) == set(results) - - # Now test that the MATCH functionality works - results = [] - cursor = '0' - while cursor != 0: - cursor, data = r.scan(cursor, match='*7', count=100) - results.extend(data) - assert b'scan-test:7' in results - assert b'scan-test:17' in results - assert len(results) == 2 - - # Test the match on iterator - results = [r for r in r.scan_iter(match='*7')] - assert b'scan-test:7' in results - assert b'scan-test:17' in results - assert len(results) == 2 - - -def test_sscan(r): - # Setup the data - name = 'sscan-test' - for ix in range(20): - k = 'sscan-test:%s' % ix - r.sadd(name, k) - expected = r.smembers(name) - assert len(expected) == 20 # Ensure we know what we're testing - - # Test that we page through the results and get everything out - results = [] - cursor = '0' - while cursor != 0: - cursor, data = r.sscan(name, cursor, count=6) - results.extend(data) - assert set(expected) == set(results) - - # Test the iterator version - results = [r for r in r.sscan_iter(name, count=6)] - assert set(expected) == set(results) - - # Now test that the MATCH functionality works - results = [] - cursor = '0' - while cursor != 0: - cursor, data = r.sscan(name, cursor, match='*7', count=100) - results.extend(data) - assert b'sscan-test:7' in results - assert b'sscan-test:17' in results - assert len(results) == 2 - - # Test the match on iterator - results = [r for r in r.sscan_iter(name, match='*7')] - assert b'sscan-test:7' in results - assert b'sscan-test:17' in results - assert len(results) == 2 - - -def test_hscan(r): - # Setup the data - name = 'hscan-test' - for ix in range(20): - k = 'key:%s' % ix - v = 'result:%s' % ix - r.hset(name, k, v) - expected = r.hgetall(name) - assert len(expected) == 20 # Ensure we know what we're testing - - # Test that we page through the results and get everything out - results = {} - cursor = '0' - while cursor != 0: - cursor, data = r.hscan(name, cursor, count=6) - results.update(data) - assert expected == results - - # Test the iterator version - results = {} - for key, val in r.hscan_iter(name, count=6): - results[key] = val - assert expected == results - - # Now test that the MATCH functionality works - results = {} - cursor = '0' - while cursor != 0: - cursor, data = r.hscan(name, cursor, match='*7', count=100) - results.update(data) - assert b'key:7' in results - assert b'key:17' in results - assert len(results) == 2 - - # Test the match on iterator - results = {} - for key, val in r.hscan_iter(name, match='*7'): - results[key] = val - assert b'key:7' in results - assert b'key:17' in results - assert len(results) == 2 - - -def test_zscan(r): - # Setup the data - name = 'zscan-test' - for ix in range(20): - zadd(r, name, {'key:%s' % ix: ix}) - expected = dict(r.zrange(name, 0, -1, withscores=True)) - - # Test the basic version - results = {} - for key, val in r.zscan_iter(name, count=6): - results[key] = val - assert results == expected - - # Now test that the MATCH functionality works - results = {} - cursor = '0' - while cursor != 0: - cursor, data = r.zscan(name, cursor, match='*7', count=6) - results.update(data) - assert results == {b'key:7': 7.0, b'key:17': 17.0} - - -@pytest.mark.slow -def test_set_ex_should_expire_value(r): - r.set('foo', 'bar') - assert r.get('foo') == b'bar' - r.set('foo', 'bar', ex=1) - sleep(2) - assert r.get('foo') is None - - -@pytest.mark.slow -def test_set_px_should_expire_value(r): - r.set('foo', 'bar', px=500) - sleep(1.5) - assert r.get('foo') is None - - -@pytest.mark.slow -def test_psetex_expire_value(r): - with pytest.raises(ResponseError): - r.psetex('foo', 0, 'bar') - r.psetex('foo', 500, 'bar') - sleep(1.5) - assert r.get('foo') is None - - -@pytest.mark.slow -def test_psetex_expire_value_using_timedelta(r): - with pytest.raises(ResponseError): - r.psetex('foo', timedelta(seconds=0), 'bar') - r.psetex('foo', timedelta(seconds=0.5), 'bar') - sleep(1.5) - assert r.get('foo') is None - - -@pytest.mark.slow -def test_expire_should_expire_key(r): - r.set('foo', 'bar') - assert r.get('foo') == b'bar' - r.expire('foo', 1) - sleep(1.5) - assert r.get('foo') is None - assert r.expire('bar', 1) is False - - -def test_expire_should_return_true_for_existing_key(r): - r.set('foo', 'bar') - assert r.expire('foo', 1) is True - - -def test_expire_should_return_false_for_missing_key(r): - assert r.expire('missing', 1) is False - - -@pytest.mark.slow -def test_expire_should_expire_key_using_timedelta(r): - r.set('foo', 'bar') - assert r.get('foo') == b'bar' - r.expire('foo', timedelta(seconds=1)) - sleep(1.5) - assert r.get('foo') is None - assert r.expire('bar', 1) is False - - -@pytest.mark.slow -def test_expire_should_expire_immediately_with_millisecond_timedelta(r): - r.set('foo', 'bar') - assert r.get('foo') == b'bar' - r.expire('foo', timedelta(milliseconds=750)) - assert r.get('foo') is None - assert r.expire('bar', 1) is False - - -def test_watch_expire(r): - """EXPIRE should mark a key as changed for WATCH.""" - r.set('foo', 'bar') - with r.pipeline() as p: - p.watch('foo') - r.expire('foo', 10000) - p.multi() - p.get('foo') - with pytest.raises(redis.exceptions.WatchError): - p.execute() - - -@pytest.mark.slow -def test_pexpire_should_expire_key(r): - r.set('foo', 'bar') - assert r.get('foo') == b'bar' - r.pexpire('foo', 150) - sleep(0.2) - assert r.get('foo') is None - assert r.pexpire('bar', 1) == 0 - - -def test_pexpire_should_return_truthy_for_existing_key(r): - r.set('foo', 'bar') - assert r.pexpire('foo', 1) - - -def test_pexpire_should_return_falsey_for_missing_key(r): - assert not r.pexpire('missing', 1) - - -@pytest.mark.slow -def test_pexpire_should_expire_key_using_timedelta(r): - r.set('foo', 'bar') - assert r.get('foo') == b'bar' - r.pexpire('foo', timedelta(milliseconds=750)) - sleep(0.5) - assert r.get('foo') == b'bar' - sleep(0.5) - assert r.get('foo') is None - assert r.pexpire('bar', 1) == 0 - - -@pytest.mark.slow -def test_expireat_should_expire_key_by_datetime(r): - r.set('foo', 'bar') - assert r.get('foo') == b'bar' - r.expireat('foo', datetime.now() + timedelta(seconds=1)) - sleep(1.5) - assert r.get('foo') is None - assert r.expireat('bar', datetime.now()) is False - - -@pytest.mark.slow -def test_expireat_should_expire_key_by_timestamp(r): - r.set('foo', 'bar') - assert r.get('foo') == b'bar' - r.expireat('foo', int(time() + 1)) - sleep(1.5) - assert r.get('foo') is None - assert r.expire('bar', 1) is False - - -def test_expireat_should_return_true_for_existing_key(r): - r.set('foo', 'bar') - assert r.expireat('foo', int(time() + 1)) is True - - -def test_expireat_should_return_false_for_missing_key(r): - assert r.expireat('missing', int(time() + 1)) is False - - -@pytest.mark.slow -def test_pexpireat_should_expire_key_by_datetime(r): - r.set('foo', 'bar') - assert r.get('foo') == b'bar' - r.pexpireat('foo', datetime.now() + timedelta(milliseconds=150)) - sleep(0.2) - assert r.get('foo') is None - assert r.pexpireat('bar', datetime.now()) == 0 - - -@pytest.mark.slow -def test_pexpireat_should_expire_key_by_timestamp(r): - r.set('foo', 'bar') - assert r.get('foo') == b'bar' - r.pexpireat('foo', int(time() * 1000 + 150)) - sleep(0.2) - assert r.get('foo') is None - assert r.expire('bar', 1) is False - - -def test_pexpireat_should_return_true_for_existing_key(r): - r.set('foo', 'bar') - assert r.pexpireat('foo', int(time() * 1000 + 150)) - - -def test_pexpireat_should_return_false_for_missing_key(r): - assert not r.pexpireat('missing', int(time() * 1000 + 150)) - - -def test_expire_should_not_handle_floating_point_values(r): - r.set('foo', 'bar') - with pytest.raises(redis.ResponseError, match='value is not an integer or out of range'): - r.expire('something_new', 1.2) - r.pexpire('something_new', 1000.2) - r.expire('some_unused_key', 1.2) - r.pexpire('some_unused_key', 1000.2) - - -def test_ttl_should_return_minus_one_for_non_expiring_key(r): - r.set('foo', 'bar') - assert r.get('foo') == b'bar' - assert r.ttl('foo') == -1 - - -def test_ttl_should_return_minus_two_for_non_existent_key(r): - assert r.get('foo') is None - assert r.ttl('foo') == -2 - - -def test_pttl_should_return_minus_one_for_non_expiring_key(r): - r.set('foo', 'bar') - assert r.get('foo') == b'bar' - assert r.pttl('foo') == -1 - - -def test_pttl_should_return_minus_two_for_non_existent_key(r): - assert r.get('foo') is None - assert r.pttl('foo') == -2 - - -def test_persist(r): - r.set('foo', 'bar', ex=20) - assert r.persist('foo') == 1 - assert r.ttl('foo') == -1 - assert r.persist('foo') == 0 - - -def test_watch_persist(r): - """PERSIST should mark a variable as changed.""" - r.set('foo', 'bar', ex=10000) - with r.pipeline() as p: - p.watch('foo') - r.persist('foo') - p.multi() - p.get('foo') - with pytest.raises(redis.exceptions.WatchError): - p.execute() - - -def test_set_existing_key_persists(r): - r.set('foo', 'bar', ex=20) - r.set('foo', 'foo') - assert r.ttl('foo') == -1 - - -def test_eval_set_value_to_arg(r): - r.eval('redis.call("SET", KEYS[1], ARGV[1])', 1, 'foo', 'bar') - val = r.get('foo') - assert val == b'bar' - - -def test_eval_conditional(r): - lua = """ - local val = redis.call("GET", KEYS[1]) - if val == ARGV[1] then - redis.call("SET", KEYS[1], ARGV[2]) - else - redis.call("SET", KEYS[1], ARGV[1]) - end - """ - r.eval(lua, 1, 'foo', 'bar', 'baz') - val = r.get('foo') - assert val == b'bar' - r.eval(lua, 1, 'foo', 'bar', 'baz') - val = r.get('foo') - assert val == b'baz' - - -def test_eval_table(r): - lua = """ - local a = {} - a[1] = "foo" - a[2] = "bar" - a[17] = "baz" - return a - """ - val = r.eval(lua, 0) - assert val == [b'foo', b'bar'] - - -def test_eval_table_with_nil(r): - lua = """ - local a = {} - a[1] = "foo" - a[2] = nil - a[3] = "bar" - return a - """ - val = r.eval(lua, 0) - assert val == [b'foo'] - - -def test_eval_table_with_numbers(r): - lua = """ - local a = {} - a[1] = 42 - return a - """ - val = r.eval(lua, 0) - assert val == [42] - - -def test_eval_nested_table(r): - lua = """ - local a = {} - a[1] = {} - a[1][1] = "foo" - return a - """ - val = r.eval(lua, 0) - assert val == [[b'foo']] - - -def test_eval_iterate_over_argv(r): - lua = """ - for i, v in ipairs(ARGV) do - end - return ARGV - """ - val = r.eval(lua, 0, "a", "b", "c") - assert val == [b"a", b"b", b"c"] - - -def test_eval_iterate_over_keys(r): - lua = """ - for i, v in ipairs(KEYS) do - end - return KEYS - """ - val = r.eval(lua, 2, "a", "b", "c") - assert val == [b"a", b"b"] - - -def test_eval_mget(r): - r.set('foo1', 'bar1') - r.set('foo2', 'bar2') - val = r.eval('return redis.call("mget", "foo1", "foo2")', 2, 'foo1', 'foo2') - assert val == [b'bar1', b'bar2'] - - -@redis2_only -def test_eval_mget_none(r): - r.set('foo1', None) - r.set('foo2', None) - val = r.eval('return redis.call("mget", "foo1", "foo2")', 2, 'foo1', 'foo2') - assert val == [b'None', b'None'] - - -def test_eval_mget_not_set(r): - val = r.eval('return redis.call("mget", "foo1", "foo2")', 2, 'foo1', 'foo2') - assert val == [None, None] - - -def test_eval_hgetall(r): - r.hset('foo', 'k1', 'bar') - r.hset('foo', 'k2', 'baz') - val = r.eval('return redis.call("hgetall", "foo")', 1, 'foo') - sorted_val = sorted([val[:2], val[2:]]) - assert sorted_val == [[b'k1', b'bar'], [b'k2', b'baz']] - - -def test_eval_hgetall_iterate(r): - r.hset('foo', 'k1', 'bar') - r.hset('foo', 'k2', 'baz') - lua = """ - local result = redis.call("hgetall", "foo") - for i, v in ipairs(result) do - end - return result - """ - val = r.eval(lua, 1, 'foo') - sorted_val = sorted([val[:2], val[2:]]) - assert sorted_val == [[b'k1', b'bar'], [b'k2', b'baz']] - - -@redis2_only -def test_eval_list_with_nil(r): - r.lpush('foo', 'bar') - r.lpush('foo', None) - r.lpush('foo', 'baz') - val = r.eval('return redis.call("lrange", KEYS[1], 0, 2)', 1, 'foo') - assert val == [b'baz', b'None', b'bar'] - - -def test_eval_invalid_command(r): - with pytest.raises(ResponseError): - r.eval( - 'return redis.call("FOO")', - 0 - ) - - -def test_eval_syntax_error(r): - with pytest.raises(ResponseError): - r.eval('return "', 0) - - -def test_eval_runtime_error(r): - with pytest.raises(ResponseError): - r.eval('error("CRASH")', 0) - - -def test_eval_more_keys_than_args(r): - with pytest.raises(ResponseError): - r.eval('return 1', 42) - - -def test_eval_numkeys_float_string(r): - with pytest.raises(ResponseError): - r.eval('return KEYS[1]', '0.7', 'foo') - - -def test_eval_numkeys_integer_string(r): - val = r.eval('return KEYS[1]', "1", "foo") - assert val == b'foo' - - -def test_eval_numkeys_negative(r): - with pytest.raises(ResponseError): - r.eval('return KEYS[1]', -1, "foo") - - -def test_eval_numkeys_float(r): - with pytest.raises(ResponseError): - r.eval('return KEYS[1]', 0.7, "foo") - - -def test_eval_global_variable(r): - # Redis doesn't allow script to define global variables - with pytest.raises(ResponseError): - r.eval('a=10', 0) - - -def test_eval_global_and_return_ok(r): - # Redis doesn't allow script to define global variables - with pytest.raises(ResponseError): - r.eval( - ''' - a=10 - return redis.status_reply("Everything is awesome") - ''', - 0 - ) - - -def test_eval_convert_number(r): - # Redis forces all Lua numbers to integer - val = r.eval('return 3.2', 0) - assert val == 3 - val = r.eval('return 3.8', 0) - assert val == 3 - val = r.eval('return -3.8', 0) - assert val == -3 - - -def test_eval_convert_bool(r): - # Redis converts true to 1 and false to nil (which redis-py converts to None) - assert r.eval('return false', 0) is None - val = r.eval('return true', 0) - assert val == 1 - assert not isinstance(val, bool) - - -def test_eval_call_bool(r): - # Redis doesn't allow Lua bools to be passed to [p]call - with pytest.raises(redis.ResponseError, - match=r'Lua redis\(\) command arguments must be strings or integers'): - r.eval('return redis.call("SET", KEYS[1], true)', 1, "testkey") - - -@redis2_only -def test_eval_none_arg(r): - val = r.eval('return ARGV[1] == "None"', 0, None) - assert val - - -def test_eval_return_error(r): - with pytest.raises(redis.ResponseError, match='Testing') as exc_info: - r.eval('return {err="Testing"}', 0) - assert isinstance(exc_info.value.args[0], str) - with pytest.raises(redis.ResponseError, match='Testing') as exc_info: - r.eval('return redis.error_reply("Testing")', 0) - assert isinstance(exc_info.value.args[0], str) - - -def test_eval_return_redis_error(r): - with pytest.raises(redis.ResponseError) as exc_info: - r.eval('return redis.pcall("BADCOMMAND")', 0) - assert isinstance(exc_info.value.args[0], str) - - -def test_eval_return_ok(r): - val = r.eval('return {ok="Testing"}', 0) - assert val == b'Testing' - val = r.eval('return redis.status_reply("Testing")', 0) - assert val == b'Testing' - - -def test_eval_return_ok_nested(r): - val = r.eval( - ''' - local a = {} - a[1] = {ok="Testing"} - return a - ''', - 0 - ) - assert val == [b'Testing'] - - -def test_eval_return_ok_wrong_type(r): - with pytest.raises(redis.ResponseError): - r.eval('return redis.status_reply(123)', 0) - - -def test_eval_pcall(r): - val = r.eval( - ''' - local a = {} - a[1] = redis.pcall("foo") - return a - ''', - 0 - ) - assert isinstance(val, list) - assert len(val) == 1 - assert isinstance(val[0], ResponseError) - - -def test_eval_pcall_return_value(r): - with pytest.raises(ResponseError): - r.eval('return redis.pcall("foo")', 0) - - -def test_eval_delete(r): - r.set('foo', 'bar') - val = r.get('foo') - assert val == b'bar' - val = r.eval('redis.call("DEL", KEYS[1])', 1, 'foo') - assert val is None - - -def test_eval_exists(r): - val = r.eval('return redis.call("exists", KEYS[1]) == 0', 1, 'foo') - assert val == 1 - - -def test_eval_flushdb(r): - r.set('foo', 'bar') - val = r.eval( - ''' - local value = redis.call("FLUSHDB"); - return type(value) == "table" and value.ok == "OK"; - ''', 0 - ) - assert val == 1 - - -def test_eval_flushall(r, create_redis): - r1 = create_redis(db=0) - r2 = create_redis(db=1) - - r1['r1'] = 'r1' - r2['r2'] = 'r2' - - val = r.eval( - ''' - local value = redis.call("FLUSHALL"); - return type(value) == "table" and value.ok == "OK"; - ''', 0 - ) - - assert val == 1 - assert 'r1' not in r1 - assert 'r2' not in r2 - - -def test_eval_incrbyfloat(r): - r.set('foo', 0.5) - val = r.eval( - ''' - local value = redis.call("INCRBYFLOAT", KEYS[1], 2.0); - return type(value) == "string" and tonumber(value) == 2.5; - ''', 1, 'foo' - ) - assert val == 1 - - -def test_eval_lrange(r): - r.rpush('foo', 'a', 'b') - val = r.eval( - ''' - local value = redis.call("LRANGE", KEYS[1], 0, -1); - return type(value) == "table" and value[1] == "a" and value[2] == "b"; - ''', 1, 'foo' - ) - assert val == 1 - - -def test_eval_ltrim(r): - r.rpush('foo', 'a', 'b', 'c', 'd') - val = r.eval( - ''' - local value = redis.call("LTRIM", KEYS[1], 1, 2); - return type(value) == "table" and value.ok == "OK"; - ''', 1, 'foo' - ) - assert val == 1 - assert r.lrange('foo', 0, -1) == [b'b', b'c'] - - -def test_eval_lset(r): - r.rpush('foo', 'a', 'b') - val = r.eval( - ''' - local value = redis.call("LSET", KEYS[1], 0, "z"); - return type(value) == "table" and value.ok == "OK"; - ''', 1, 'foo' - ) - assert val == 1 - assert r.lrange('foo', 0, -1) == [b'z', b'b'] - - -def test_eval_sdiff(r): - r.sadd('foo', 'a', 'b', 'c', 'f', 'e', 'd') - r.sadd('bar', 'b') - val = r.eval( - ''' - local value = redis.call("SDIFF", KEYS[1], KEYS[2]); - if type(value) ~= "table" then - return redis.error_reply(type(value) .. ", should be table"); - else - return value; - end - ''', 2, 'foo', 'bar') - # Note: while fakeredis sorts the result when using Lua, this isn't - # actually part of the redis contract (see - # https://github.com/antirez/redis/issues/5538), and for Redis 5 we - # need to sort val to pass the test. - assert sorted(val) == [b'a', b'c', b'd', b'e', b'f'] - - -def test_script(r): - script = r.register_script('return ARGV[1]') - result = script(args=[42]) - assert result == b'42' - - -def test_script_exists(r): - # test response for no arguments by bypassing the py-redis command - # as it requires at least one argument - assert raw_command(r, "SCRIPT EXISTS") == [] - - # use single character characters for non-existing scripts, as those - # will never be equal to an actual sha1 hash digest - assert r.script_exists("a") == [0] - assert r.script_exists("a", "b", "c", "d", "e", "f") == [0, 0, 0, 0, 0, 0] - - sha1_one = r.script_load("return 'a'") - assert r.script_exists(sha1_one) == [1] - assert r.script_exists(sha1_one, "a") == [1, 0] - assert r.script_exists("a", "b", "c", sha1_one, "e") == [0, 0, 0, 1, 0] - - sha1_two = r.script_load("return 'b'") - assert r.script_exists(sha1_one, sha1_two) == [1, 1] - assert r.script_exists("a", sha1_one, "c", sha1_two, "e", "f") == [0, 1, 0, 1, 0, 0] - - -@pytest.mark.parametrize("args", [("a",), tuple("abcdefghijklmn")]) -def test_script_flush_errors_with_args(r, args): - with pytest.raises(redis.ResponseError): - raw_command(r, "SCRIPT FLUSH %s" % " ".join(args)) - - -def test_script_flush(r): - # generate/load six unique scripts and store their sha1 hash values - sha1_values = [r.script_load("return '%s'" % char) for char in "abcdef"] - - # assert the scripts all exist prior to flushing - assert r.script_exists(*sha1_values) == [1] * len(sha1_values) - - # flush and assert OK response - assert r.script_flush() is True - - # assert none of the scripts exists after flushing - assert r.script_exists(*sha1_values) == [0] * len(sha1_values) - - -@fake_only -def test_lua_log(r, caplog): - logger = fakeredis._server.LOGGER - script = """ - redis.log(redis.LOG_DEBUG, "debug") - redis.log(redis.LOG_VERBOSE, "verbose") - redis.log(redis.LOG_NOTICE, "notice") - redis.log(redis.LOG_WARNING, "warning") - """ - script = r.register_script(script) - with caplog.at_level('DEBUG'): - script() - assert caplog.record_tuples == [ - (logger.name, logging.DEBUG, 'debug'), - (logger.name, logging.INFO, 'verbose'), - (logger.name, logging.INFO, 'notice'), - (logger.name, logging.WARNING, 'warning') - ] - - -def test_lua_log_no_message(r): - script = "redis.log(redis.LOG_DEBUG)" - script = r.register_script(script) - with pytest.raises(redis.ResponseError): - script() - - -@fake_only -def test_lua_log_different_types(r, caplog): - logger = fakeredis._server.LOGGER - script = "redis.log(redis.LOG_DEBUG, 'string', 1, true, 3.14, 'string')" - script = r.register_script(script) - with caplog.at_level('DEBUG'): - script() - assert caplog.record_tuples == [ - (logger.name, logging.DEBUG, 'string 1 3.14 string') - ] - - -def test_lua_log_wrong_level(r): - script = "redis.log(10, 'string')" - script = r.register_script(script) - with pytest.raises(redis.ResponseError): - script() - - -@fake_only -def test_lua_log_defined_vars(r, caplog): - logger = fakeredis._server.LOGGER - script = """ - local var='string' - redis.log(redis.LOG_DEBUG, var) - """ - script = r.register_script(script) - with caplog.at_level('DEBUG'): - script() - assert caplog.record_tuples == [(logger.name, logging.DEBUG, 'string')] - - -@redis3_only -def test_unlink(r): - r.set('foo', 'bar') - r.unlink('foo') - assert r.get('foo') is None - - -@pytest.mark.skipif(REDIS_VERSION < Version("3.4"), reason="Test requires redis-py 3.4+") -@pytest.mark.fake -def test_socket_cleanup_pubsub(fake_server): - r1 = fakeredis.FakeStrictRedis(server=fake_server) - r2 = fakeredis.FakeStrictRedis(server=fake_server) - ps = r1.pubsub() - with ps: - ps.subscribe('test') - ps.psubscribe('test*') - r2.publish('test', 'foo') - - -@pytest.mark.fake -def test_socket_cleanup_watch(fake_server): - r1 = fakeredis.FakeStrictRedis(server=fake_server) - r2 = fakeredis.FakeStrictRedis(server=fake_server) - pipeline = r1.pipeline(transaction=False) - # This needs some poking into redis-py internals to ensure that we reach - # FakeSocket._cleanup. We need to close the socket while there is still - # a watch in place, but not allow it to be garbage collected (hence we - # set 'sock' even though it is unused). - with pipeline: - pipeline.watch('test') - sock = pipeline.connection._sock # noqa: F841 - pipeline.connection.disconnect() - r2.set('test', 'foo') - - -@redis2_only -@pytest.mark.parametrize( - 'create_redis', - [ - pytest.param('FakeRedis', marks=pytest.mark.fake), - pytest.param('Redis', marks=pytest.mark.real) - ], - indirect=True -) -class TestNonStrict: - def test_setex(self, r): - assert r.setex('foo', 'bar', 100) is True - assert r.get('foo') == b'bar' - - def test_setex_using_timedelta(self, r): - assert r.setex('foo', 'bar', timedelta(seconds=100)) is True - assert r.get('foo') == b'bar' - - def test_lrem_positive_count(self, r): - r.lpush('foo', 'same') - r.lpush('foo', 'same') - r.lpush('foo', 'different') - r.lrem('foo', 'same', 2) - assert r.lrange('foo', 0, -1) == [b'different'] - - def test_lrem_negative_count(self, r): - r.lpush('foo', 'removeme') - r.lpush('foo', 'three') - r.lpush('foo', 'two') - r.lpush('foo', 'one') - r.lpush('foo', 'removeme') - r.lrem('foo', 'removeme', -1) - # Should remove it from the end of the list, - # leaving the 'removeme' from the front of the list alone. - assert r.lrange('foo', 0, -1) == [b'removeme', b'one', b'two', b'three'] - - def test_lrem_zero_count(self, r): - r.lpush('foo', 'one') - r.lpush('foo', 'one') - r.lpush('foo', 'one') - r.lrem('foo', 'one') - assert r.lrange('foo', 0, -1) == [] - - def test_lrem_default_value(self, r): - r.lpush('foo', 'one') - r.lpush('foo', 'one') - r.lpush('foo', 'one') - r.lrem('foo', 'one') - assert r.lrange('foo', 0, -1) == [] - - def test_lrem_does_not_exist(self, r): - r.lpush('foo', 'one') - r.lrem('foo', 'one') - # These should be noops. - r.lrem('foo', 'one', -2) - r.lrem('foo', 'one', 2) - - def test_lrem_return_value(self, r): - r.lpush('foo', 'one') - count = r.lrem('foo', 'one', 0) - assert count == 1 - assert r.lrem('foo', 'one') == 0 - - def test_zadd_deprecated(self, r): - result = r.zadd('foo', 'one', 1) - assert result == 1 - assert r.zrange('foo', 0, -1) == [b'one'] - - def test_zadd_missing_required_params(self, r): - with pytest.raises(redis.RedisError): - # Missing the 'score' param. - r.zadd('foo', 'one') - with pytest.raises(redis.RedisError): - # Missing the 'value' param. - r.zadd('foo', None, score=1) - with pytest.raises(redis.RedisError): - r.zadd('foo') - - def test_zadd_with_single_keypair(self, r): - result = r.zadd('foo', bar=1) - assert result == 1 - assert r.zrange('foo', 0, -1) == [b'bar'] - - def test_zadd_with_multiple_keypairs(self, r): - result = r.zadd('foo', bar=1, baz=9) - assert result == 2 - assert r.zrange('foo', 0, -1) == [b'bar', b'baz'] - - def test_zadd_with_name_is_non_string(self, r): - result = r.zadd('foo', 1, 9) - assert result == 1 - assert r.zrange('foo', 0, -1) == [b'1'] - - def test_ttl_should_return_none_for_non_expiring_key(self, r): - r.set('foo', 'bar') - assert r.get('foo') == b'bar' - assert r.ttl('foo') is None - - def test_ttl_should_return_value_for_expiring_key(self, r): - r.set('foo', 'bar') - r.expire('foo', 1) - assert r.ttl('foo') == 1 - r.expire('foo', 2) - assert r.ttl('foo') == 2 - # See https://github.com/antirez/redis/blob/unstable/src/db.c#L632 - ttl = 1000000000 - r.expire('foo', ttl) - assert r.ttl('foo') == ttl - - def test_pttl_should_return_none_for_non_expiring_key(self, r): - r.set('foo', 'bar') - assert r.get('foo') == b'bar' - assert r.pttl('foo') is None - - def test_pttl_should_return_value_for_expiring_key(self, r): - d = 100 - r.set('foo', 'bar') - r.expire('foo', 1) - assert 1000 - d <= r.pttl('foo') <= 1000 - r.expire('foo', 2) - assert 2000 - d <= r.pttl('foo') <= 2000 - ttl = 1000000000 - # See https://github.com/antirez/redis/blob/unstable/src/db.c#L632 - r.expire('foo', ttl) - assert ttl * 1000 - d <= r.pttl('foo') <= ttl * 1000 - - def test_expire_should_not_handle_floating_point_values(self, r): - r.set('foo', 'bar') - with pytest.raises(redis.ResponseError, match='value is not an integer or out of range'): - r.expire('something_new', 1.2) - r.pexpire('something_new', 1000.2) - r.expire('some_unused_key', 1.2) - r.pexpire('some_unused_key', 1000.2) - - def test_lock(self, r): - lock = r.lock('foo') - assert lock.acquire() - assert r.exists('foo') - lock.release() - assert not r.exists('foo') - with r.lock('bar'): - assert r.exists('bar') - assert not r.exists('bar') - - def test_unlock_without_lock(self, r): - lock = r.lock('foo') - with pytest.raises(redis.exceptions.LockError): - lock.release() - - @pytest.mark.slow - def test_unlock_expired(self, r): - lock = r.lock('foo', timeout=0.01, sleep=0.001) - assert lock.acquire() - sleep(0.1) - with pytest.raises(redis.exceptions.LockError): - lock.release() - - @pytest.mark.slow - def test_lock_blocking_timeout(self, r): - lock = r.lock('foo') - assert lock.acquire() - lock2 = r.lock('foo') - assert not lock2.acquire(blocking_timeout=1) - - def test_lock_nonblocking(self, r): - lock = r.lock('foo') - assert lock.acquire() - lock2 = r.lock('foo') - assert not lock2.acquire(blocking=False) - - def test_lock_twice(self, r): - lock = r.lock('foo') - assert lock.acquire(blocking=False) - assert not lock.acquire(blocking=False) - - def test_acquiring_lock_different_lock_release(self, r): - lock1 = r.lock('foo') - lock2 = r.lock('foo') - assert lock1.acquire(blocking=False) - assert not lock2.acquire(blocking=False) - - # Test only releasing lock1 actually releases the lock - with pytest.raises(redis.exceptions.LockError): - lock2.release() - assert not lock2.acquire(blocking=False) - lock1.release() - # Locking with lock2 now has the lock - assert lock2.acquire(blocking=False) - assert not lock1.acquire(blocking=False) - - def test_lock_extend(self, r): - lock = r.lock('foo', timeout=2) - lock.acquire() - lock.extend(3) - ttl = int(r.pttl('foo')) - assert 4000 < ttl <= 5000 - - def test_lock_extend_exceptions(self, r): - lock1 = r.lock('foo', timeout=2) - with pytest.raises(redis.exceptions.LockError): - lock1.extend(3) - lock2 = r.lock('foo') - lock2.acquire() - with pytest.raises(redis.exceptions.LockError): - lock2.extend(3) # Cannot extend a lock with no timeout - - @pytest.mark.slow - def test_lock_extend_expired(self, r): - lock = r.lock('foo', timeout=0.01, sleep=0.001) - lock.acquire() - sleep(0.1) - with pytest.raises(redis.exceptions.LockError): - lock.extend(3) - - -@pytest.mark.decode_responses -class TestDecodeResponses: - def test_decode_str(self, r): - r.set('foo', 'bar') - assert r.get('foo') == 'bar' - - def test_decode_set(self, r): - r.sadd('foo', 'member1') - assert r.smembers('foo') == {'member1'} - - def test_decode_list(self, r): - r.rpush('foo', 'a', 'b') - assert r.lrange('foo', 0, -1) == ['a', 'b'] - - def test_decode_dict(self, r): - r.hset('foo', 'key', 'value') - assert r.hgetall('foo') == {'key': 'value'} - - def test_decode_error(self, r): - r.set('foo', 'bar') - with pytest.raises(ResponseError) as exc_info: - r.hset('foo', 'bar', 'baz') - assert isinstance(exc_info.value.args[0], str) - - -@pytest.mark.fake -class TestInitArgs: - def test_singleton(self): - shared_server = fakeredis.FakeServer() - r1 = fakeredis.FakeStrictRedis() - r2 = fakeredis.FakeStrictRedis() - r3 = fakeredis.FakeStrictRedis(server=shared_server) - r4 = fakeredis.FakeStrictRedis(server=shared_server) - - r1.set('foo', 'bar') - r3.set('bar', 'baz') - - assert 'foo' in r1 - assert 'foo' not in r2 - assert 'foo' not in r3 - - assert 'bar' in r3 - assert 'bar' in r4 - assert 'bar' not in r1 - - def test_from_url(https://github.com/self): - db = fakeredis.FakeStrictRedis.from_url( - 'redis://localhost:6379/0') - db.set('foo', 'bar') - assert db.get('foo') == b'bar' - - def test_from_url_with_db_arg(self): - db = fakeredis.FakeStrictRedis.from_url( - 'redis://localhost:6379/0') - db1 = fakeredis.FakeStrictRedis.from_url( - 'redis://localhost:6379/1') - db2 = fakeredis.FakeStrictRedis.from_url( - 'redis://localhost:6379/', - db=2) - db.set('foo', 'foo0') - db1.set('foo', 'foo1') - db2.set('foo', 'foo2') - assert db.get('foo') == b'foo0' - assert db1.get('foo') == b'foo1' - assert db2.get('foo') == b'foo2' - - def test_from_url_db_value_error(self): - # In case of ValueError, should default to 0, or be absent in redis-py 4.0 - db = fakeredis.FakeStrictRedis.from_url( - 'redis://localhost:6379/a') - assert db.connection_pool.connection_kwargs.get('db', 0) == 0 - - def test_can_pass_through_extra_args(self): - db = fakeredis.FakeStrictRedis.from_url( - 'redis://localhost:6379/0', - decode_responses=True) - db.set('foo', 'bar') - assert db.get('foo') == 'bar' - - @redis3_only - def test_can_allow_extra_args(self): - db = fakeredis.FakeStrictRedis.from_url( - 'redis://localhost:6379/0', - socket_connect_timeout=11, socket_timeout=12, socket_keepalive=True, - socket_keepalive_options={60: 30}, socket_type=1, - retry_on_timeout=True, - ) - fake_conn = db.connection_pool.make_connection() - assert fake_conn.socket_connect_timeout == 11 - assert fake_conn.socket_timeout == 12 - assert fake_conn.socket_keepalive is True - assert fake_conn.socket_keepalive_options == {60: 30} - assert fake_conn.socket_type == 1 - assert fake_conn.retry_on_timeout is True - - # Make fallback logic match redis-py - db = fakeredis.FakeStrictRedis.from_url( - 'redis://localhost:6379/0', - socket_connect_timeout=None, socket_timeout=30 - ) - fake_conn = db.connection_pool.make_connection() - assert fake_conn.socket_connect_timeout == fake_conn.socket_timeout - assert fake_conn.socket_keepalive_options == {} - - def test_repr(self): - # repr is human-readable, so we only test that it doesn't crash, - # and that it contains the db number. - db = fakeredis.FakeStrictRedis.from_url('/service/redis://localhost:6379/11') - rep = repr(db) - assert 'db=11' in rep - - def test_from_unix_socket(self): - db = fakeredis.FakeStrictRedis.from_url('/service/unix://a/b/c') - db.set('foo', 'bar') - assert db.get('foo') == b'bar' - - -@pytest.mark.disconnected -@fake_only -class TestFakeStrictRedisConnectionErrors: - def test_flushdb(self, r): - with pytest.raises(redis.ConnectionError): - r.flushdb() - - def test_flushall(self, r): - with pytest.raises(redis.ConnectionError): - r.flushall() - - def test_append(self, r): - with pytest.raises(redis.ConnectionError): - r.append('key', 'value') - - def test_bitcount(self, r): - with pytest.raises(redis.ConnectionError): - r.bitcount('key', 0, 20) - - def test_decr(self, r): - with pytest.raises(redis.ConnectionError): - r.decr('key', 2) - - def test_exists(self, r): - with pytest.raises(redis.ConnectionError): - r.exists('key') - - def test_expire(self, r): - with pytest.raises(redis.ConnectionError): - r.expire('key', 20) - - def test_pexpire(self, r): - with pytest.raises(redis.ConnectionError): - r.pexpire('key', 20) - - def test_echo(self, r): - with pytest.raises(redis.ConnectionError): - r.echo('value') - - def test_get(self, r): - with pytest.raises(redis.ConnectionError): - r.get('key') - - def test_getbit(self, r): - with pytest.raises(redis.ConnectionError): - r.getbit('key', 2) - - def test_getset(self, r): - with pytest.raises(redis.ConnectionError): - r.getset('key', 'value') - - def test_incr(self, r): - with pytest.raises(redis.ConnectionError): - r.incr('key') - - def test_incrby(self, r): - with pytest.raises(redis.ConnectionError): - r.incrby('key') - - def test_ncrbyfloat(self, r): - with pytest.raises(redis.ConnectionError): - r.incrbyfloat('key') - - def test_keys(self, r): - with pytest.raises(redis.ConnectionError): - r.keys() - - def test_mget(self, r): - with pytest.raises(redis.ConnectionError): - r.mget(['key1', 'key2']) - - def test_mset(self, r): - with pytest.raises(redis.ConnectionError): - r.mset({'key': 'value'}) - - def test_msetnx(self, r): - with pytest.raises(redis.ConnectionError): - r.msetnx({'key': 'value'}) - - def test_persist(self, r): - with pytest.raises(redis.ConnectionError): - r.persist('key') - - def test_rename(self, r): - server = r.connection_pool.connection_kwargs['server'] - server.connected = True - r.set('key1', 'value') - server.connected = False - with pytest.raises(redis.ConnectionError): - r.rename('key1', 'key2') - server.connected = True - assert r.exists('key1') - - def test_eval(self, r): - with pytest.raises(redis.ConnectionError): - r.eval('', 0) - - def test_lpush(self, r): - with pytest.raises(redis.ConnectionError): - r.lpush('name', 1, 2) - - def test_lrange(self, r): - with pytest.raises(redis.ConnectionError): - r.lrange('name', 1, 5) - - def test_llen(self, r): - with pytest.raises(redis.ConnectionError): - r.llen('name') - - def test_lrem(self, r): - with pytest.raises(redis.ConnectionError): - r.lrem('name', 2, 2) - - def test_rpush(self, r): - with pytest.raises(redis.ConnectionError): - r.rpush('name', 1) - - def test_lpop(self, r): - with pytest.raises(redis.ConnectionError): - r.lpop('name') - - def test_lset(self, r): - with pytest.raises(redis.ConnectionError): - r.lset('name', 1, 4) - - def test_rpushx(self, r): - with pytest.raises(redis.ConnectionError): - r.rpushx('name', 1) - - def test_ltrim(self, r): - with pytest.raises(redis.ConnectionError): - r.ltrim('name', 1, 4) - - def test_lindex(self, r): - with pytest.raises(redis.ConnectionError): - r.lindex('name', 1) - - def test_lpushx(self, r): - with pytest.raises(redis.ConnectionError): - r.lpushx('name', 1) - - def test_rpop(self, r): - with pytest.raises(redis.ConnectionError): - r.rpop('name') - - def test_linsert(self, r): - with pytest.raises(redis.ConnectionError): - r.linsert('name', 'where', 'refvalue', 'value') - - def test_rpoplpush(self, r): - with pytest.raises(redis.ConnectionError): - r.rpoplpush('src', 'dst') - - def test_blpop(self, r): - with pytest.raises(redis.ConnectionError): - r.blpop('keys') - - def test_brpop(self, r): - with pytest.raises(redis.ConnectionError): - r.brpop('keys') - - def test_brpoplpush(self, r): - with pytest.raises(redis.ConnectionError): - r.brpoplpush('src', 'dst') - - def test_hdel(self, r): - with pytest.raises(redis.ConnectionError): - r.hdel('name') - - def test_hexists(self, r): - with pytest.raises(redis.ConnectionError): - r.hexists('name', 'key') - - def test_hget(self, r): - with pytest.raises(redis.ConnectionError): - r.hget('name', 'key') - - def test_hgetall(self, r): - with pytest.raises(redis.ConnectionError): - r.hgetall('name') - - def test_hincrby(self, r): - with pytest.raises(redis.ConnectionError): - r.hincrby('name', 'key') - - def test_hincrbyfloat(self, r): - with pytest.raises(redis.ConnectionError): - r.hincrbyfloat('name', 'key') - - def test_hkeys(self, r): - with pytest.raises(redis.ConnectionError): - r.hkeys('name') - - def test_hlen(self, r): - with pytest.raises(redis.ConnectionError): - r.hlen('name') - - def test_hset(self, r): - with pytest.raises(redis.ConnectionError): - r.hset('name', 'key', 1) - - def test_hsetnx(self, r): - with pytest.raises(redis.ConnectionError): - r.hsetnx('name', 'key', 2) - - def test_hmset(self, r): - with pytest.raises(redis.ConnectionError): - r.hmset('name', {'key': 1}) - - def test_hmget(self, r): - with pytest.raises(redis.ConnectionError): - r.hmget('name', ['a', 'b']) - - def test_hvals(self, r): - with pytest.raises(redis.ConnectionError): - r.hvals('name') - - def test_sadd(self, r): - with pytest.raises(redis.ConnectionError): - r.sadd('name', 1, 2) - - def test_scard(self, r): - with pytest.raises(redis.ConnectionError): - r.scard('name') - - def test_sdiff(self, r): - with pytest.raises(redis.ConnectionError): - r.sdiff(['a', 'b']) - - def test_sdiffstore(self, r): - with pytest.raises(redis.ConnectionError): - r.sdiffstore('dest', ['a', 'b']) - - def test_sinter(self, r): - with pytest.raises(redis.ConnectionError): - r.sinter(['a', 'b']) - - def test_sinterstore(self, r): - with pytest.raises(redis.ConnectionError): - r.sinterstore('dest', ['a', 'b']) - - def test_sismember(self, r): - with pytest.raises(redis.ConnectionError): - r.sismember('name', 20) - - def test_smembers(self, r): - with pytest.raises(redis.ConnectionError): - r.smembers('name') - - def test_smove(self, r): - with pytest.raises(redis.ConnectionError): - r.smove('src', 'dest', 20) - - def test_spop(self, r): - with pytest.raises(redis.ConnectionError): - r.spop('name') - - def test_srandmember(self, r): - with pytest.raises(redis.ConnectionError): - r.srandmember('name') - - def test_srem(self, r): - with pytest.raises(redis.ConnectionError): - r.srem('name') - - def test_sunion(self, r): - with pytest.raises(redis.ConnectionError): - r.sunion(['a', 'b']) - - def test_sunionstore(self, r): - with pytest.raises(redis.ConnectionError): - r.sunionstore('dest', ['a', 'b']) - - def test_zadd(self, r): - with pytest.raises(redis.ConnectionError): - zadd(r, 'name', {'key': 'value'}) - - def test_zcard(self, r): - with pytest.raises(redis.ConnectionError): - r.zcard('name') - - def test_zcount(self, r): - with pytest.raises(redis.ConnectionError): - r.zcount('name', 1, 5) - - def test_zincrby(self, r): - with pytest.raises(redis.ConnectionError): - r.zincrby('name', 1, 1) - - def test_zinterstore(self, r): - with pytest.raises(redis.ConnectionError): - r.zinterstore('dest', ['a', 'b']) - - def test_zrange(self, r): - with pytest.raises(redis.ConnectionError): - r.zrange('name', 1, 5) - - def test_zrangebyscore(self, r): - with pytest.raises(redis.ConnectionError): - r.zrangebyscore('name', 1, 5) - - def test_rangebylex(self, r): - with pytest.raises(redis.ConnectionError): - r.zrangebylex('name', 1, 4) - - def test_zrem(self, r): - with pytest.raises(redis.ConnectionError): - r.zrem('name', 'value') - - def test_zremrangebyrank(self, r): - with pytest.raises(redis.ConnectionError): - r.zremrangebyrank('name', 1, 5) - - def test_zremrangebyscore(self, r): - with pytest.raises(redis.ConnectionError): - r.zremrangebyscore('name', 1, 5) - - def test_zremrangebylex(self, r): - with pytest.raises(redis.ConnectionError): - r.zremrangebylex('name', 1, 5) - - def test_zlexcount(self, r): - with pytest.raises(redis.ConnectionError): - r.zlexcount('name', 1, 5) - - def test_zrevrange(self, r): - with pytest.raises(redis.ConnectionError): - r.zrevrange('name', 1, 5, 1) - - def test_zrevrangebyscore(self, r): - with pytest.raises(redis.ConnectionError): - r.zrevrangebyscore('name', 5, 1) - - def test_zrevrangebylex(self, r): - with pytest.raises(redis.ConnectionError): - r.zrevrangebylex('name', 5, 1) - - def test_zrevran(self, r): - with pytest.raises(redis.ConnectionError): - r.zrevrank('name', 2) - - def test_zscore(self, r): - with pytest.raises(redis.ConnectionError): - r.zscore('name', 2) - - def test_zunionstor(self, r): - with pytest.raises(redis.ConnectionError): - r.zunionstore('dest', ['1', '2']) - - def test_pipeline(self, r): - with pytest.raises(redis.ConnectionError): - r.pipeline().watch('key') - - def test_transaction(self, r): - with pytest.raises(redis.ConnectionError): - def func(a): - return a * a - - r.transaction(func, 3) - - def test_lock(self, r): - with pytest.raises(redis.ConnectionError): - with r.lock('name'): - pass - - def test_pubsub(self, r): - with pytest.raises(redis.ConnectionError): - r.pubsub().subscribe('channel') - - def test_pfadd(self, r): - with pytest.raises(redis.ConnectionError): - r.pfadd('name', 1) - - def test_pfmerge(self, r): - with pytest.raises(redis.ConnectionError): - r.pfmerge('dest', 'a', 'b') - - def test_scan(self, r): - with pytest.raises(redis.ConnectionError): - list(r.scan()) - - def test_sscan(self, r): - with pytest.raises(redis.ConnectionError): - r.sscan('name') - - def test_hscan(self, r): - with pytest.raises(redis.ConnectionError): - r.hscan('name') - - def test_scan_iter(self, r): - with pytest.raises(redis.ConnectionError): - list(r.scan_iter()) - - def test_sscan_iter(self, r): - with pytest.raises(redis.ConnectionError): - list(r.sscan_iter('name')) - - def test_hscan_iter(self, r): - with pytest.raises(redis.ConnectionError): - list(r.hscan_iter('name')) - - -@pytest.mark.disconnected -@fake_only -class TestPubSubConnected: - @pytest.fixture - def pubsub(self, r): - return r.pubsub() - - def test_basic_subscribe(self, pubsub): - with pytest.raises(redis.ConnectionError): - pubsub.subscribe('logs') - - def test_subscription_conn_lost(self, fake_server, pubsub): - fake_server.connected = True - pubsub.subscribe('logs') - fake_server.connected = False - # The initial message is already in the pipe - msg = pubsub.get_message() - check = { - 'type': 'subscribe', - 'pattern': None, - 'channel': b'logs', - 'data': 1 - } - assert msg == check, 'Message was not published to channel' - with pytest.raises(redis.ConnectionError): - pubsub.get_message() diff --git a/test/test_hypothesis.py b/test/test_hypothesis.py deleted file mode 100644 index e17f78b..0000000 --- a/test/test_hypothesis.py +++ /dev/null @@ -1,620 +0,0 @@ -import operator -import functools - -import hypothesis -import hypothesis.stateful -from hypothesis.stateful import rule, initialize, precondition -import hypothesis.strategies as st -import pytest -import redis - -import fakeredis - - -self_strategy = st.runner() - - -@st.composite -def sample_attr(draw, name): - """Strategy for sampling a specific attribute from a state machine""" - machine = draw(self_strategy) - values = getattr(machine, name) - position = draw(st.integers(min_value=0, max_value=len(values) - 1)) - return values[position] - - -keys = sample_attr('keys') -fields = sample_attr('fields') -values = sample_attr('values') -scores = sample_attr('scores') - -int_as_bytes = st.builds(lambda x: str(x).encode(), st.integers()) -float_as_bytes = st.builds(lambda x: repr(x).encode(), st.floats(width=32)) -counts = st.integers(min_value=-3, max_value=3) | st.integers() -limits = st.just(()) | st.tuples(st.just('limit'), counts, counts) -# Redis has an integer overflow bug in swapdb, so we confine the numbers to -# a limited range (https://github.com/antirez/redis/issues/5737). -dbnums = st.integers(min_value=0, max_value=3) | st.integers(min_value=-1000, max_value=1000) -# The filter is to work around https://github.com/antirez/redis/issues/5632 -patterns = (st.text(alphabet=st.sampled_from('[]^$*.?-azAZ\\\r\n\t')) - | st.binary().filter(lambda x: b'\0' not in x)) -score_tests = scores | st.builds(lambda x: b'(' + repr(x).encode(), scores) -string_tests = ( - st.sampled_from([b'+', b'-']) - | st.builds(operator.add, st.sampled_from([b'(', b'[']), fields)) -# Redis has integer overflow bugs in time computations, which is why we set a maximum. -expires_seconds = st.integers(min_value=100000, max_value=10000000000) -expires_ms = st.integers(min_value=100000000, max_value=10000000000000) - - -class WrappedException: - """Wraps an exception for the purposes of comparison.""" - def __init__(self, exc): - self.wrapped = exc - - def __str__(self): - return str(self.wrapped) - - def __repr__(self): - return 'WrappedException({!r})'.format(self.wrapped) - - def __eq__(self, other): - if not isinstance(other, WrappedException): - return NotImplemented - if type(self.wrapped) != type(other.wrapped): # noqa: E721 - return False - # TODO: re-enable after more carefully handling order of error checks - # return self.wrapped.args == other.wrapped.args - return True - - def __ne__(self, other): - if not isinstance(other, WrappedException): - return NotImplemented - return not self == other - - -def wrap_exceptions(obj): - if isinstance(obj, list): - return [wrap_exceptions(item) for item in obj] - elif isinstance(obj, Exception): - return WrappedException(obj) - else: - return obj - - -def sort_list(lst): - if isinstance(lst, list): - return sorted(lst) - else: - return lst - - -def flatten(args): - if isinstance(args, (list, tuple)): - for arg in args: - yield from flatten(arg) - elif args is not None: - yield args - - -def default_normalize(x): - return x - - -class Command: - def __init__(self, *args): - self.args = tuple(flatten(args)) - - def __repr__(self): - parts = [repr(arg) for arg in self.args] - return 'Command({})'.format(', '.join(parts)) - - @staticmethod - def encode(arg): - encoder = redis.connection.Encoder('utf-8', 'replace', False) - return encoder.encode(arg) - - @property - def normalize(self): - command = self.encode(self.args[0]).lower() if self.args else None - # Functions that return a list in arbitrary order - unordered = { - b'keys', - b'sort', - b'hgetall', b'hkeys', b'hvals', - b'sdiff', b'sinter', b'sunion', - b'smembers' - } - if command in unordered: - return sort_list - else: - return lambda x: x - - @property - def testable(self): - """Whether this command is suitable for a test. - - The fuzzer can create commands with behaviour that is - non-deterministic, not supported, or which hits redis bugs. - """ - N = len(self.args) - if N == 0: - return False - command = self.encode(self.args[0]).lower() - if not command.split(): - return False - if command == b'keys' and N == 2 and self.args[1] != b'*': - return False - # redis will ignore a NUL character in some commands but not others - # e.g. it recognises EXEC\0 but not MULTI\00. Rather than try to - # reproduce this quirky behaviour, just skip these tests. - if b'\0' in command: - return False - return True - - -def commands(*args, **kwargs): - return st.builds(functools.partial(Command, **kwargs), *args) - - -# TODO: all expiry-related commands -common_commands = ( - commands(st.sampled_from(['del', 'persist', 'type', 'unlink']), keys) - | commands(st.just('exists'), st.lists(keys)) - | commands(st.just('keys'), st.just('*')) - # Disabled for now due to redis giving wrong answers - # (https://github.com/antirez/redis/issues/5632) - # | st.tuples(st.just('keys'), patterns) - | commands(st.just('move'), keys, dbnums) - | commands(st.sampled_from(['rename', 'renamenx']), keys, keys) - # TODO: find a better solution to sort instability than throwing - # away the sort entirely with normalize. This also prevents us - # using LIMIT. - | commands(st.just('sort'), keys, - st.none() | st.just('asc'), - st.none() | st.just('desc'), - st.none() | st.just('alpha')) -) - -# TODO: tests for select -connection_commands = ( - commands(st.just('echo'), values) - | commands(st.just('ping'), st.lists(values, max_size=2)) - | commands(st.just('swapdb'), dbnums, dbnums) -) - -string_create_commands = commands(st.just('set'), keys, values) -string_commands = ( - commands(st.just('append'), keys, values) - | commands(st.just('bitcount'), keys) - | commands(st.just('bitcount'), keys, values, values) - | commands(st.sampled_from(['incr', 'decr']), keys) - | commands(st.sampled_from(['incrby', 'decrby']), keys, values) - # Disabled for now because Python can't exactly model the long doubles. - # TODO: make a more targeted test that checks the basics. - # TODO: check how it gets stringified, without relying on hypothesis - # to get generate a get call before it gets overwritten. - # | commands(st.just('incrbyfloat'), keys, st.floats(width=32)) - | commands(st.just('get'), keys) - | commands(st.just('getbit'), keys, counts) - | commands(st.just('setbit'), keys, counts, - st.integers(min_value=0, max_value=1) | st.integers()) - | commands(st.sampled_from(['substr', 'getrange']), keys, counts, counts) - | commands(st.just('getset'), keys, values) - | commands(st.just('mget'), st.lists(keys)) - | commands(st.sampled_from(['mset', 'msetnx']), st.lists(st.tuples(keys, values))) - | commands(st.just('set'), keys, values, - st.none() | st.just('nx'), - st.none() | st.just('xx'), - st.none() | st.just('keepttl'), - st.none() | st.just('get')) - | commands(st.just('setex'), keys, expires_seconds, values) - | commands(st.just('psetex'), keys, expires_ms, values) - | commands(st.just('setnx'), keys, values) - | commands(st.just('setrange'), keys, counts, values) - | commands(st.just('strlen'), keys) -) - -# TODO: add a test for hincrbyfloat. See incrbyfloat for why this is -# problematic. -hash_create_commands = ( - commands(st.just('hmset'), keys, st.lists(st.tuples(fields, values), min_size=1)) -) -hash_commands = ( - commands(st.just('hmset'), keys, st.lists(st.tuples(fields, values))) - | commands(st.just('hdel'), keys, st.lists(fields)) - | commands(st.just('hexists'), keys, fields) - | commands(st.just('hget'), keys, fields) - | commands(st.sampled_from(['hgetall', 'hkeys', 'hvals']), keys) - | commands(st.just('hincrby'), keys, fields, st.integers()) - | commands(st.just('hlen'), keys) - | commands(st.just('hmget'), keys, st.lists(fields)) - | commands(st.sampled_from(['hset', 'hmset']), keys, st.lists(st.tuples(fields, values))) - | commands(st.just('hsetnx'), keys, fields, values) - | commands(st.just('hstrlen'), keys, fields) -) - -# TODO: blocking commands -list_create_commands = commands(st.just('rpush'), keys, st.lists(values, min_size=1)) -list_commands = ( - commands(st.just('lindex'), keys, counts) - | commands(st.just('linsert'), keys, - st.sampled_from(['before', 'after', 'BEFORE', 'AFTER']) | st.binary(), - values, values) - | commands(st.just('llen'), keys) - | commands(st.sampled_from(['lpop', 'rpop']), keys, st.just(None) | st.integers()) - | commands(st.sampled_from(['lpush', 'lpushx', 'rpush', 'rpushx']), keys, st.lists(values)) - | commands(st.just('lrange'), keys, counts, counts) - | commands(st.just('lrem'), keys, counts, values) - | commands(st.just('lset'), keys, counts, values) - | commands(st.just('ltrim'), keys, counts, counts) - | commands(st.just('rpoplpush'), keys, keys) -) - -# TODO: -# - find a way to test srandmember, spop which are random -# - sscan -set_create_commands = ( - commands(st.just('sadd'), keys, st.lists(fields, min_size=1)) -) -set_commands = ( - commands(st.just('sadd'), keys, st.lists(fields,)) - | commands(st.just('scard'), keys) - | commands(st.sampled_from(['sdiff', 'sinter', 'sunion']), st.lists(keys)) - | commands(st.sampled_from(['sdiffstore', 'sinterstore', 'sunionstore']), - keys, st.lists(keys)) - | commands(st.just('sismember'), keys, fields) - | commands(st.just('smembers'), keys) - | commands(st.just('smove'), keys, keys, fields) - | commands(st.just('srem'), keys, st.lists(fields)) -) - - -def build_zstore(command, dest, sources, weights, aggregate): - args = [command, dest, len(sources)] - args += [source[0] for source in sources] - if weights: - args.append('weights') - args += [source[1] for source in sources] - if aggregate: - args += ['aggregate', aggregate] - return Command(args) - - -# TODO: zscan, zpopmin/zpopmax, bzpopmin/bzpopmax, probably more -zset_create_commands = ( - commands(st.just('zadd'), keys, st.lists(st.tuples(scores, fields), min_size=1)) -) -zset_commands = ( - commands(st.just('zadd'), keys, - st.none() | st.just('nx'), - st.none() | st.just('xx'), - st.none() | st.just('ch'), - st.none() | st.just('incr'), - st.lists(st.tuples(scores, fields))) - | commands(st.just('zcard'), keys) - | commands(st.just('zcount'), keys, score_tests, score_tests) - | commands(st.just('zincrby'), keys, scores, fields) - | commands(st.sampled_from(['zrange', 'zrevrange']), keys, counts, counts, - st.none() | st.just('withscores')) - | commands(st.sampled_from(['zrangebyscore', 'zrevrangebyscore']), - keys, score_tests, score_tests, - limits, - st.none() | st.just('withscores')) - | commands(st.sampled_from(['zrank', 'zrevrank']), keys, fields) - | commands(st.just('zrem'), keys, st.lists(fields)) - | commands(st.just('zremrangebyrank'), keys, counts, counts) - | commands(st.just('zremrangebyscore'), keys, score_tests, score_tests) - | commands(st.just('zscore'), keys, fields) - | st.builds(build_zstore, - command=st.sampled_from(['zunionstore', 'zinterstore']), - dest=keys, sources=st.lists(st.tuples(keys, float_as_bytes)), - weights=st.booleans(), - aggregate=st.sampled_from([None, 'sum', 'min', 'max'])) -) - -zset_no_score_create_commands = ( - commands(st.just('zadd'), keys, st.lists(st.tuples(st.just(0), fields), min_size=1)) -) -zset_no_score_commands = ( - # TODO: test incr - commands(st.just('zadd'), keys, - st.none() | st.just('nx'), - st.none() | st.just('xx'), - st.none() | st.just('ch'), - st.none() | st.just('incr'), - st.lists(st.tuples(st.just(0), fields))) - | commands(st.just('zlexcount'), keys, string_tests, string_tests) - | commands(st.sampled_from(['zrangebylex', 'zrevrangebylex']), - keys, string_tests, string_tests, - limits) - | commands(st.just('zremrangebylex'), keys, string_tests, string_tests) -) - -transaction_commands = ( - commands(st.sampled_from(['multi', 'discard', 'exec', 'unwatch'])) - | commands(st.just('watch'), keys) -) - -server_commands = ( - # TODO: real redis raises an error if there is a save already in progress. - # Find a better way to test this. - # commands(st.just('bgsave')) - commands(st.just('dbsize')) - | commands(st.sampled_from(['flushdb', 'flushall']), st.sampled_from([[], 'async'])) - # TODO: result is non-deterministic - # | commands(st.just('lastsave')) - | commands(st.just('save')) -) - -bad_commands = ( - # redis-py splits the command on spaces, and hangs if that ends up - # being an empty list - commands(st.text().filter(lambda x: bool(x.split())), - st.lists(st.binary() | st.text())) -) - -attrs = st.fixed_dictionaries({ - 'keys': st.lists(st.binary(), min_size=2, max_size=5, unique=True), - 'fields': st.lists(st.binary(), min_size=2, max_size=5, unique=True), - 'values': st.lists(st.binary() | int_as_bytes | float_as_bytes, - min_size=2, max_size=5, unique=True), - 'scores': st.lists(st.floats(width=32), min_size=2, max_size=5, unique=True) -}) - - -@hypothesis.settings(max_examples=1000) -class CommonMachine(hypothesis.stateful.RuleBasedStateMachine): - create_command_strategy = st.nothing() - - def __init__(self): - super().__init__() - try: - self.real = redis.StrictRedis('localhost', port=6379) - self.real.ping() - except redis.ConnectionError: - pytest.skip('redis is not running') - if self.real.info('server').get('arch_bits') != 64: - self.real.connection_pool.disconnect() - pytest.skip('redis server is not 64-bit') - self.fake = fakeredis.FakeStrictRedis() - # Disable the response parsing so that we can check the raw values returned - self.fake.response_callbacks.clear() - self.real.response_callbacks.clear() - self.transaction_normalize = [] - self.keys = [] - self.fields = [] - self.values = [] - self.scores = [] - self.initialized_data = False - try: - self.real.execute_command('discard') - except redis.ResponseError: - pass - self.real.flushall() - - def teardown(self): - self.real.connection_pool.disconnect() - self.fake.connection_pool.disconnect() - super().teardown() - - def _evaluate(self, client, command): - try: - result = client.execute_command(*command.args) - if result != 'QUEUED': - result = command.normalize(result) - exc = None - except Exception as e: - result = exc = e - return wrap_exceptions(result), exc - - def _compare(self, command): - fake_result, fake_exc = self._evaluate(self.fake, command) - real_result, real_exc = self._evaluate(self.real, command) - - if fake_exc is not None and real_exc is None: - raise fake_exc - elif real_exc is not None and fake_exc is None: - assert real_exc == fake_exc, "Expected exception {} not raised".format(real_exc) - elif (real_exc is None and isinstance(real_result, list) - and command.args and command.args[0].lower() == 'exec'): - assert fake_result is not None - # Transactions need to use the normalize functions of the - # component commands. - assert len(self.transaction_normalize) == len(real_result) - assert len(self.transaction_normalize) == len(fake_result) - for n, r, f in zip(self.transaction_normalize, real_result, fake_result): - assert n(f) == n(r) - self.transaction_normalize = [] - else: - assert fake_result == real_result - if real_result == b'QUEUED': - # Since redis removes the distinction between simple strings and - # bulk strings, this might not actually indicate that we're in a - # transaction. But it is extremely unlikely that hypothesis will - # find such examples. - self.transaction_normalize.append(command.normalize) - if (len(command.args) == 1 - and Command.encode(command.args[0]).lower() in (b'discard', b'exec')): - self.transaction_normalize = [] - - @initialize(attrs=attrs) - def init_attrs(self, attrs): - for key, value in attrs.items(): - setattr(self, key, value) - - # hypothesis doesn't allow ordering of @initialize, so we have to put - # preconditions on rules to ensure we call init_data exactly once and - # after init_attrs. - @precondition(lambda self: not self.initialized_data) - @rule(commands=self_strategy.flatmap( - lambda self: st.lists(self.create_command_strategy))) - def init_data(self, commands): - for command in commands: - self._compare(command) - self.initialized_data = True - - @precondition(lambda self: self.initialized_data) - @rule(command=self_strategy.flatmap(lambda self: self.command_strategy)) - def one_command(self, command): - self._compare(command) - - -class BaseTest: - """Base class for test classes.""" - - create_command_strategy = st.nothing() - - @pytest.mark.slow - def test(self): - class Machine(CommonMachine): - create_command_strategy = self.create_command_strategy - command_strategy = self.command_strategy - - hypothesis.stateful.run_state_machine_as_test(Machine) - - -class TestConnection(BaseTest): - command_strategy = connection_commands | common_commands - - -class TestString(BaseTest): - create_command_strategy = string_create_commands - command_strategy = string_commands | common_commands - - -class TestHash(BaseTest): - create_command_strategy = hash_create_commands - command_strategy = hash_commands | common_commands - - -class TestList(BaseTest): - create_command_strategy = list_create_commands - command_strategy = list_commands | common_commands - - -class TestSet(BaseTest): - create_command_strategy = set_create_commands - command_strategy = set_commands | common_commands - - -class TestZSet(BaseTest): - create_command_strategy = zset_create_commands - command_strategy = zset_commands | common_commands - - -class TestZSetNoScores(BaseTest): - create_command_strategy = zset_no_score_create_commands - command_strategy = zset_no_score_commands | common_commands - - -class TestTransaction(BaseTest): - create_command_strategy = string_create_commands - command_strategy = transaction_commands | string_commands | common_commands - - -class TestServer(BaseTest): - create_command_strategy = string_create_commands - command_strategy = server_commands | string_commands | common_commands - - -class TestJoint(BaseTest): - create_command_strategy = ( - string_create_commands | hash_create_commands | list_create_commands - | set_create_commands | zset_create_commands) - command_strategy = ( - transaction_commands | server_commands | connection_commands - | string_commands | hash_commands | list_commands | set_commands - | zset_commands | common_commands | bad_commands) - - -@st.composite -def delete_arg(draw, commands): - command = draw(commands) - if command.args: - pos = draw(st.integers(min_value=0, max_value=len(command.args) - 1)) - command.args = command.args[:pos] + command.args[pos + 1:] - return command - - -@st.composite -def command_args(draw, commands): - """Generate an argument from some command""" - command = draw(commands) - hypothesis.assume(len(command.args)) - return draw(st.sampled_from(command.args)) - - -def mutate_arg(draw, commands, mutate): - command = draw(commands) - if command.args: - pos = draw(st.integers(min_value=0, max_value=len(command.args) - 1)) - arg = mutate(Command.encode(command.args[pos])) - command.args = command.args[:pos] + (arg,) + command.args[pos + 1:] - return command - - -@st.composite -def replace_arg(draw, commands, replacements): - return mutate_arg(draw, commands, lambda arg: draw(replacements)) - - -@st.composite -def uppercase_arg(draw, commands): - return mutate_arg(draw, commands, lambda arg: arg.upper()) - - -@st.composite -def prefix_arg(draw, commands, prefixes): - return mutate_arg(draw, commands, lambda arg: draw(prefixes) + arg) - - -@st.composite -def suffix_arg(draw, commands, suffixes): - return mutate_arg(draw, commands, lambda arg: arg + draw(suffixes)) - - -@st.composite -def add_arg(draw, commands, arguments): - command = draw(commands) - arg = draw(arguments) - pos = draw(st.integers(min_value=0, max_value=len(command.args))) - command.args = command.args[:pos] + (arg,) + command.args[pos:] - return command - - -@st.composite -def swap_args(draw, commands): - command = draw(commands) - if len(command.args) >= 2: - pos1 = draw(st.integers(min_value=0, max_value=len(command.args) - 1)) - pos2 = draw(st.integers(min_value=0, max_value=len(command.args) - 1)) - hypothesis.assume(pos1 != pos2) - args = list(command.args) - arg1 = args[pos1] - arg2 = args[pos2] - args[pos1] = arg2 - args[pos2] = arg1 - command.args = tuple(args) - return command - - -def mutated_commands(commands): - args = st.sampled_from([b'withscores', b'xx', b'nx', b'ex', b'px', b'weights', b'aggregate', - b'', b'0', b'-1', b'nan', b'inf', b'-inf']) | command_args(commands) - affixes = st.sampled_from([b'\0', b'-', b'+', b'\t', b'\n', b'0000']) | st.binary() - return st.recursive( - commands, - lambda x: - delete_arg(x) - | replace_arg(x, args) - | uppercase_arg(x) - | prefix_arg(x, affixes) - | suffix_arg(x, affixes) - | add_arg(x, args) - | swap_args(x)) - - -class TestFuzz(BaseTest): - command_strategy = mutated_commands(TestJoint.command_strategy) - command_strategy = command_strategy.filter(lambda command: command.testable) diff --git a/tox.ini b/tox.ini deleted file mode 100644 index 3d7da2e..0000000 --- a/tox.ini +++ /dev/null @@ -1,11 +0,0 @@ -[tox] -envlist = - py{27,34,35,36,37,38,py} - -[testenv] -usedevelop = True -commands = pytest -v {posargs} -extras = lua -deps = - hypothesis - pytest From d5bfe297990626cc3a3d507d2c0c10852b1ce4e3 Mon Sep 17 00:00:00 2001 From: Bruce Merry Date: Mon, 9 May 2022 13:42:59 +0200 Subject: [PATCH 20/20] Update URL to new fakeredis repo --- README.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.rst b/README.rst index edad847..68f2425 100644 --- a/README.rst +++ b/README.rst @@ -1,4 +1,4 @@ **fakeredis has moved** ======================= Fakeredis has a new maintainer, and can now be found at -https://github.com/dsoftwareinc/fakeredis. +https://github.com/dsoftwareinc/fakeredis-py.