Skip to content

Commit d1cd365

Browse files
committed
completely refactored pubsub. this is backwards incompatible, but quite necessary.
1 parent 58390a7 commit d1cd365

File tree

1 file changed

+96
-70
lines changed

1 file changed

+96
-70
lines changed

redis/client.py

Lines changed: 96 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -167,16 +167,10 @@ class Redis(object):
167167
}
168168
)
169169

170-
# commands that should NOT pull data off the network buffer when executed
171-
SUBSCRIPTION_COMMANDS = set([
172-
'SUBSCRIBE', 'UNSUBSCRIBE', 'PSUBSCRIBE', 'PUNSUBSCRIBE'
173-
])
174-
175170
def __init__(self, host='localhost', port=6379,
176171
db=0, password=None, socket_timeout=None,
177172
connection_pool=None,
178173
charset='utf-8', errors='strict'):
179-
self.subscribed = False
180174
if connection_pool:
181175
self.connection_pool = connection_pool
182176
else:
@@ -214,26 +208,20 @@ def lock(self, name, timeout=None, sleep=0.1):
214208
"""
215209
return Lock(self, name, timeout=timeout, sleep=sleep)
216210

211+
def pubsub(self):
212+
return PubSub(self.connection_pool)
213+
217214
#### COMMAND EXECUTION AND PROTOCOL PARSING ####
218215
def execute_command(self, *args, **options):
219216
command_name = args[0]
220217
connection = self.connection_pool.get_connection(command_name)
221218
try:
222-
subscription_command = command_name in self.SUBSCRIPTION_COMMANDS
223-
if self.subscribed and not subscription_command:
224-
raise PubSubError("Cannot issue commands other than SUBSCRIBE "
225-
"and UNSUBSCRIBE while channels are open")
226-
try:
227-
connection.send_command(*args)
228-
if subscription_command:
229-
return None
230-
return self.parse_response(connection, command_name, **options)
231-
except ConnectionError:
232-
connection.disconnect()
233-
connection.send_command(*args)
234-
if subscription_command:
235-
return None
236-
return self.parse_response(connection, command_name, **options)
219+
connection.send_command(*args)
220+
return self.parse_response(connection, command_name, **options)
221+
except ConnectionError:
222+
connection.disconnect()
223+
connection.send_command(*args)
224+
return self.parse_response(connection, command_name, **options)
237225
finally:
238226
self.connection_pool.release(connection)
239227

@@ -244,23 +232,6 @@ def parse_response(self, connection, command_name, **options):
244232
return self.RESPONSE_CALLBACKS[command_name](response, **options)
245233
return response
246234

247-
#### CONNECTION HANDLING ####
248-
def select(self, db):
249-
"SELECT a differnet Redis database."
250-
return self.execute_command('SELECT', db)
251-
252-
def shutdown(self):
253-
"Shutdown the server"
254-
if self.subscribed:
255-
raise PubSubError("Can't call 'shutdown' when 'subscribed'")
256-
try:
257-
self.execute_command('SHUTDOWN')
258-
except ConnectionError:
259-
# a ConnectionError here is expected
260-
return
261-
raise RedisError("SHUTDOWN seems to have failed.")
262-
263-
264235
#### SERVER INFORMATION ####
265236
def bgrewriteaof(self):
266237
"Tell the Redis server to rewrite the AOF file from data in memory."
@@ -328,6 +299,19 @@ def save(self):
328299
"""
329300
return self.execute_command('SAVE')
330301

302+
def select(self, db):
303+
"Select a differnet Redis database"
304+
return self.execute_command('SELECT', db)
305+
306+
def shutdown(self):
307+
"Shutdown the server"
308+
try:
309+
self.execute_command('SHUTDOWN')
310+
except ConnectionError:
311+
# a ConnectionError here is expected
312+
return
313+
raise RedisError("SHUTDOWN seems to have failed.")
314+
331315
def slaveof(self, host=None, port=None):
332316
"""
333317
Set the server to be a replicated slave of the instance identified
@@ -552,18 +536,12 @@ def watch(self, *names):
552536
"""
553537
Watches the values at keys ``names``, or None if the key doesn't exist
554538
"""
555-
if self.subscribed:
556-
raise PubSubError("Can't call 'watch' when 'subscribed'")
557-
558539
return self.execute_command('WATCH', *names)
559540

560541
def unwatch(self):
561542
"""
562543
Unwatches the value at key ``name``, or None of the key doesn't exist
563544
"""
564-
if self.subscribed:
565-
raise PubSubError("Can't call 'unwatch' when 'subscribed'")
566-
567545
return self.execute_command('UNWATCH')
568546

569547
#### LIST COMMANDS ####
@@ -1107,20 +1085,63 @@ def hvals(self, name):
11071085
"Return the list of values within hash ``name``"
11081086
return self.execute_command('HVALS', name)
11091087

1110-
def pubsub(self):
1111-
return PubSub(self.connection_pool)
1088+
def publish(self, channel, message):
1089+
"""
1090+
Publish ``message`` on ``channel``.
1091+
Returns the number of subscribers the message was delivered to.
1092+
"""
1093+
return self.execute_command('PUBLISH', channel, message)
11121094

11131095

1114-
# channels
1096+
class PubSub(object):
1097+
def __init__(self, connection_pool):
1098+
self.connection_pool = connection_pool
1099+
self.connection = None
1100+
self.channels = set()
1101+
self.patterns = set()
1102+
self.subscription_count = 0
1103+
self.subscribe_commands = set(
1104+
('subscribe', 'psusbscribe', 'unsubscribe', 'punsubscribe')
1105+
)
1106+
1107+
def execute_command(self, *args, **kwargs):
1108+
"Execute a publish/subscribe command"
1109+
if self.connection is None:
1110+
self.connection = self.connection_pool.get_connection('pubsub')
1111+
connection = self.connection
1112+
try:
1113+
connection.send_command(*args)
1114+
return self.parse_response()
1115+
except ConnectionError:
1116+
connection.disconnect()
1117+
# resubscribe to all channels and patterns before
1118+
# resending the current command
1119+
for channel in self.channels:
1120+
self.subscribe(channel)
1121+
for pattern in self.patterns:
1122+
self.psubscribe(pattern)
1123+
connection.send_command(*args)
1124+
return self.parse_response()
1125+
1126+
def parse_response(self):
1127+
"Parse the response from a publish/subscribe command"
1128+
response = self.connection.read_response()
1129+
if response[0] in self.subscribe_commands:
1130+
self.subscription_count = response[2]
1131+
# if we've just unsubscribed from the remaining channels,
1132+
# release the connection back to the pool
1133+
if not self.subscription_count:
1134+
self.connection_pool.release(self.connection)
1135+
self.connection = None
1136+
return response
1137+
11151138
def psubscribe(self, patterns):
11161139
"Subscribe to all channels matching any pattern in ``patterns``"
11171140
if isinstance(patterns, basestring):
11181141
patterns = [patterns]
1119-
response = self.execute_command('PSUBSCRIBE', *patterns)
1120-
# this is *after* the SUBSCRIBE in order to allow for lazy and broken
1121-
# connections that need to issue AUTH and SELECT commands
1122-
self.subscribed = True
1123-
return response
1142+
for pattern in patterns:
1143+
self.patterns.add(pattern)
1144+
return self.execute_command('PSUBSCRIBE', *patterns)
11241145

11251146
def punsubscribe(self, patterns=[]):
11261147
"""
@@ -1129,17 +1150,20 @@ def punsubscribe(self, patterns=[]):
11291150
"""
11301151
if isinstance(patterns, basestring):
11311152
patterns = [patterns]
1153+
for pattern in patterns:
1154+
try:
1155+
self.patterns.remove(pattern)
1156+
except KeyError:
1157+
pass
11321158
return self.execute_command('PUNSUBSCRIBE', *patterns)
11331159

11341160
def subscribe(self, channels):
11351161
"Subscribe to ``channels``, waiting for messages to be published"
11361162
if isinstance(channels, basestring):
11371163
channels = [channels]
1138-
response = self.execute_command('SUBSCRIBE', *channels)
1139-
# this is *after* the SUBSCRIBE in order to allow for lazy and broken
1140-
# connections that need to issue AUTH and SELECT commands
1141-
self.subscribed = True
1142-
return response
1164+
for channel in channels:
1165+
self.channels.add(channel)
1166+
return self.execute_command('SUBSCRIBE', *channels)
11431167

11441168
def unsubscribe(self, channels=[]):
11451169
"""
@@ -1148,6 +1172,11 @@ def unsubscribe(self, channels=[]):
11481172
"""
11491173
if isinstance(channels, basestring):
11501174
channels = [channels]
1175+
for channel in channels:
1176+
try:
1177+
self.channels.remove(channel)
1178+
except KeyError:
1179+
pass
11511180
return self.execute_command('UNSUBSCRIBE', *channels)
11521181

11531182
def publish(self, channel, message):
@@ -1159,24 +1188,22 @@ def publish(self, channel, message):
11591188

11601189
def listen(self):
11611190
"Listen for messages on channels this client has been subscribed to"
1162-
while self.subscribed:
1163-
r = self.parse_response('LISTEN')
1191+
while self.subscription_count:
1192+
r = self.parse_response()
11641193
if r[0] == 'pmessage':
11651194
msg = {
1166-
'type': r[0],
1167-
'pattern': r[1],
1168-
'channel': r[2],
1169-
'data': r[3]
1195+
'type': r[0],
1196+
'pattern': r[1],
1197+
'channel': r[2],
1198+
'data': r[3]
11701199
}
11711200
else:
11721201
msg = {
1173-
'type': r[0],
1174-
'pattern': None,
1175-
'channel': r[1],
1176-
'data': r[2]
1202+
'type': r[0],
1203+
'pattern': None,
1204+
'channel': r[1],
1205+
'data': r[2]
11771206
}
1178-
if r[0] == 'unsubscribe' and r[2] == 0:
1179-
self.subscribed = False
11801207
yield msg
11811208

11821209

@@ -1201,7 +1228,6 @@ class Pipeline(Redis):
12011228
def __init__(self, connection_pool, transaction):
12021229
self.connection_pool = connection_pool
12031230
self.transaction = transaction
1204-
self.subscribed = False # NOTE not in use, but necessary
12051231
self.reset()
12061232

12071233
def reset(self):

0 commit comments

Comments
 (0)