Skip to content

Commit ae71200

Browse files
committed
new pubsub tests
1 parent def45c6 commit ae71200

File tree

6 files changed

+90
-111
lines changed

6 files changed

+90
-111
lines changed

redis/client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1011,7 +1011,7 @@ def __init__(self, connection_pool, shard_hint=None):
10111011
self.patterns = set()
10121012
self.subscription_count = 0
10131013
self.subscribe_commands = set(
1014-
('subscribe', 'psusbscribe', 'unsubscribe', 'punsubscribe')
1014+
('subscribe', 'psubscribe', 'unsubscribe', 'punsubscribe')
10151015
)
10161016

10171017
def execute_command(self, *args, **kwargs):

redis/connection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ def get_connection(self, command_name, *keys):
262262
def make_connection(self):
263263
"Create a new connection"
264264
if self._created_connections >= self.max_connections:
265-
raise Exception("Too many connections")
265+
raise ConnectionError("Too many connections")
266266
self._created_connections += 1
267267
return self.connection_class(**self.connection_kwargs)
268268

tests/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from connection_pool import ConnectionPoolTestCase
44
from pipeline import PipelineTestCase
55
from lock import LockTestCase
6+
from pubsub import PubSubTestCase
67

78
use_hiredis = False
89
try:
@@ -17,4 +18,5 @@ def all_tests():
1718
suite.addTest(unittest.makeSuite(ConnectionPoolTestCase))
1819
suite.addTest(unittest.makeSuite(PipelineTestCase))
1920
suite.addTest(unittest.makeSuite(LockTestCase))
21+
suite.addTest(unittest.makeSuite(PubSubTestCase))
2022
return suite

tests/connection_pool.py

Lines changed: 35 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,57 +1,39 @@
11
import redis
2-
import threading
3-
import time
42
import unittest
53

6-
class ConnectionPoolTestCase(unittest.TestCase):
7-
# TODO:
8-
# THIS TEST IS INVALID WITH THE DEFAULT CONNECTIONPOOL
9-
#
10-
# def test_multiple_connections(self):
11-
# # 2 clients to the same host/port/db/pool should use the same connection
12-
# pool = redis.ConnectionPool()
13-
# r1 = redis.Redis(host='localhost', port=6379, db=9, connection_pool=pool)
14-
# r2 = redis.Redis(host='localhost', port=6379, db=9, connection_pool=pool)
15-
# self.assertEquals(r1.connection, r2.connection)
16-
17-
# # if one of them switches, they should have
18-
# # separate conncetion objects
19-
# r2.select(db=10, host='localhost', port=6379)
20-
# self.assertNotEqual(r1.connection, r2.connection)
21-
22-
# conns = [r1.connection, r2.connection]
23-
# conns.sort()
24-
25-
# # but returning to the original state shares the object again
26-
# r2.select(db=9, host='localhost', port=6379)
27-
# self.assertEquals(r1.connection, r2.connection)
28-
29-
# # the connection manager should still have just 2 connections
30-
# mgr_conns = pool.get_all_connections()
31-
# mgr_conns.sort()
32-
# self.assertEquals(conns, mgr_conns)
33-
34-
def test_threaded_workers(self):
35-
# TODO: review this, does it even make sense anymore?
36-
r = redis.Redis(host='localhost', port=6379, db=9)
37-
r.set('a', 'foo')
38-
r.set('b', 'bar')
39-
40-
def _info_worker():
41-
for i in range(50):
42-
_ = r.info()
43-
time.sleep(0.01)
44-
45-
def _keys_worker():
46-
for i in range(50):
47-
_ = r.keys()
48-
time.sleep(0.01)
49-
50-
t1 = threading.Thread(target=_info_worker)
51-
t2 = threading.Thread(target=_keys_worker)
52-
t1.start()
53-
t2.start()
54-
55-
for i in [t1, t2]:
56-
i.join()
4+
class DummyConnection(object):
5+
def __init__(self, **kwargs):
6+
self.kwargs = kwargs
577

8+
class ConnectionPoolTestCase(unittest.TestCase):
9+
def get_pool(self, connection_info=None, max_connections=None):
10+
connection_info = connection_info or {'a': 1, 'b': 2, 'c': 3}
11+
pool = redis.ConnectionPool(
12+
connection_class=DummyConnection, max_connections=max_connections,
13+
**connection_info)
14+
return pool
15+
16+
def test_connection_creation(self):
17+
connection_info = {'foo': 'bar', 'biz': 'baz'}
18+
pool = self.get_pool(connection_info=connection_info)
19+
connection = pool.get_connection('_')
20+
self.assertEquals(connection.kwargs, connection_info)
21+
22+
def test_multiple_connections(self):
23+
pool = self.get_pool()
24+
c1 = pool.get_connection('_')
25+
c2 = pool.get_connection('_')
26+
self.assert_(c1 != c2)
27+
28+
def test_max_connections(self):
29+
pool = self.get_pool(max_connections=2)
30+
c1 = pool.get_connection('_')
31+
c2 = pool.get_connection('_')
32+
self.assertRaises(redis.ConnectionError, pool.get_connection, '_')
33+
34+
def test_release(self):
35+
pool = self.get_pool()
36+
c1 = pool.get_connection('_')
37+
pool.release(c1)
38+
c2 = pool.get_connection('_')
39+
self.assertEquals(c1, c2)

tests/pubsub.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import redis
2+
import unittest
3+
4+
class PubSubTestCase(unittest.TestCase):
5+
def setUp(self):
6+
self.connection_pool = redis.ConnectionPool()
7+
self.client = redis.Redis(connection_pool=self.connection_pool)
8+
self.pubsub = self.client.pubsub()
9+
10+
def tearDown(self):
11+
self.connection_pool.disconnect()
12+
13+
def test_channel_subscribe(self):
14+
self.assertEquals(
15+
self.pubsub.subscribe('foo'),
16+
['subscribe', 'foo', 1]
17+
)
18+
self.assertEquals(self.client.publish('foo', 'hello foo'), 1)
19+
self.assertEquals(
20+
self.pubsub.listen().next(),
21+
{
22+
'type': 'message',
23+
'pattern': None,
24+
'channel': 'foo',
25+
'data': 'hello foo'
26+
}
27+
)
28+
self.assertEquals(
29+
self.pubsub.unsubscribe('foo'),
30+
['unsubscribe', 'foo', 0]
31+
)
32+
33+
def test_pattern_subscribe(self):
34+
self.assertEquals(
35+
self.pubsub.psubscribe('fo*'),
36+
['psubscribe', 'fo*', 1]
37+
)
38+
self.assertEquals(self.client.publish('foo', 'hello foo'), 1)
39+
self.assertEquals(
40+
self.pubsub.listen().next(),
41+
{
42+
'type': 'pmessage',
43+
'pattern': 'fo*',
44+
'channel': 'foo',
45+
'data': 'hello foo'
46+
}
47+
)
48+
self.assertEquals(
49+
self.pubsub.punsubscribe('fo*'),
50+
['punsubscribe', 'fo*', 0]
51+
)

tests/server_commands.py

Lines changed: 0 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import redis
22
import unittest
33
import datetime
4-
import threading
54
import time
65
from distutils.version import StrictVersion
76

@@ -1194,61 +1193,6 @@ def test_sort_all_options(self):
11941193
self.assertEquals(self.client.lrange('sorted', 0, 10),
11951194
['vodka', 'milk', 'gin', 'apple juice'])
11961195

1197-
# PUBSUB
1198-
# def test_pubsub(self):
1199-
# # create a new client to not polute the existing one
1200-
# r = self.get_client()
1201-
# channels = ('a1', 'a2', 'a3')
1202-
# for c in channels:
1203-
# r.subscribe(c)
1204-
# # state variable should be flipped
1205-
# self.assertEquals(r.subscribed, True)
1206-
1207-
# channels_to_publish_to = channels + ('a4',)
1208-
# messages_per_channel = 4
1209-
# def publish():
1210-
# for i in range(messages_per_channel):
1211-
# for c in channels_to_publish_to:
1212-
# self.client.publish(c, 'a message')
1213-
# time.sleep(0.01)
1214-
# for c in channels_to_publish_to:
1215-
# self.client.publish(c, 'unsubscribe')
1216-
# time.sleep(0.01)
1217-
1218-
# messages = []
1219-
# self.assertRaises(redis.PubSubError, r.set, 'foo', 'bar')
1220-
# # should receive a message for each subscribe/unsubscribe command
1221-
# # plus a message for each iteration of the loop * num channels
1222-
# # we hide the data messages that tell the client to unsubscribe
1223-
# num_messages_to_expect = len(channels)*2 + \
1224-
# (messages_per_channel*len(channels))
1225-
# t = threading.Thread(target=publish)
1226-
# t.start()
1227-
# for msg in r.listen():
1228-
# if msg['data'] == 'unsubscribe':
1229-
# r.unsubscribe(msg['channel'])
1230-
# else:
1231-
# messages.append(msg)
1232-
1233-
# self.assertEquals(r.subscribed, False)
1234-
# self.assertEquals(len(messages), num_messages_to_expect)
1235-
# sent_types, sent_channels = {}, {}
1236-
# for msg in messages:
1237-
# msg_type = msg['type']
1238-
# channel = msg['channel']
1239-
# sent_types.setdefault(msg_type, 0)
1240-
# sent_types[msg_type] += 1
1241-
# if msg_type == 'message':
1242-
# sent_channels.setdefault(channel, 0)
1243-
# sent_channels[channel] += 1
1244-
# for channel in channels:
1245-
# self.assert_(channel in sent_channels)
1246-
# self.assertEquals(sent_channels[channel], messages_per_channel)
1247-
# self.assertEquals(sent_types['subscribe'], len(channels))
1248-
# self.assertEquals(sent_types['unsubscribe'], len(channels))
1249-
# self.assertEquals(sent_types['message'],
1250-
# len(channels) * messages_per_channel)
1251-
12521196
## BINARY SAFE
12531197
# TODO add more tests
12541198
def test_binary_get_set(self):

0 commit comments

Comments
 (0)