Skip to content

Commit c650073

Browse files
committed
all tests passing with new connection pool
1 parent f64c4ad commit c650073

File tree

4 files changed

+165
-165
lines changed

4 files changed

+165
-165
lines changed

redis/client.py

Lines changed: 85 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -177,14 +177,21 @@ def __init__(self, host='localhost', port=6379,
177177
db=0, password=None, socket_timeout=None,
178178
connection_pool=None,
179179
charset='utf-8', errors='strict'):
180-
self.encoding = charset
181-
self.errors = errors
182-
self.connection = None
183180
self.subscribed = False
184-
self.connection_pool = connection_pool and connection_pool or ConnectionPool()
185-
self.select(db, host, port, password, socket_timeout)
186-
187-
#### Legacty accessors of connection information ####
181+
if connection_pool:
182+
self.connection_pool = connection_pool
183+
else:
184+
self.connection_pool = ConnectionPool(
185+
host=host,
186+
port=port,
187+
db=db,
188+
password=password,
189+
socket_timeout=socket_timeout,
190+
encoding=charset,
191+
encoding_errors=errors
192+
)
193+
194+
#### Legacy accessors of connection information ####
188195
def _get_host(self):
189196
return self.connection.host
190197
host = property(_get_host)
@@ -205,12 +212,7 @@ def pipeline(self, transaction=True):
205212
pipelines are useful for batch loading of data as they reduce the
206213
number of back and forth network operations between client and server.
207214
"""
208-
return Pipeline(
209-
self.connection,
210-
transaction,
211-
self.encoding,
212-
self.errors
213-
)
215+
return Pipeline(self.connection_pool, transaction)
214216

215217
def lock(self, name, timeout=None, sleep=0.1):
216218
"""
@@ -228,77 +230,38 @@ def lock(self, name, timeout=None, sleep=0.1):
228230

229231
#### COMMAND EXECUTION AND PROTOCOL PARSING ####
230232
def execute_command(self, *args, **options):
231-
command_name = args[0]
232-
subscription_command = command_name in self.SUBSCRIPTION_COMMANDS
233-
if self.subscribed and not subscription_command:
234-
raise PubSubError("Cannot issue commands other than SUBSCRIBE and "
235-
"UNSUBSCRIBE while channels are open")
233+
connection = self.connection_pool.get_connection()
236234
try:
237-
self.connection.send_command(*args)
238-
if subscription_command:
239-
return None
240-
return self.parse_response(command_name, **options)
241-
except ConnectionError:
242-
self.connection.disconnect()
243-
self.connection.send_command(*args)
244-
if subscription_command:
245-
return None
246-
return self.parse_response(command_name, **options)
247-
248-
def parse_response(self, command_name, **options):
235+
command_name = args[0]
236+
subscription_command = command_name in self.SUBSCRIPTION_COMMANDS
237+
if self.subscribed and not subscription_command:
238+
raise PubSubError("Cannot issue commands other than SUBSCRIBE "
239+
"and UNSUBSCRIBE while channels are open")
240+
try:
241+
connection.send_command(*args)
242+
if subscription_command:
243+
return None
244+
return self.parse_response(connection, command_name, **options)
245+
except ConnectionError:
246+
connection.disconnect()
247+
connection.send_command(*args)
248+
if subscription_command:
249+
return None
250+
return self.parse_response(connection, command_name, **options)
251+
finally:
252+
self.connection_pool.release(connection)
253+
254+
def parse_response(self, connection, command_name, **options):
249255
"Parses a response from the Redis server"
250-
response = self.connection.read_response()
256+
response = connection.read_response()
251257
if command_name in self.RESPONSE_CALLBACKS:
252258
return self.RESPONSE_CALLBACKS[command_name](response, **options)
253259
return response
254260

255261
#### CONNECTION HANDLING ####
256-
def get_connection(self, host, port, db, password, socket_timeout):
257-
"Returns a connection object"
258-
conn = self.connection_pool.get_connection(
259-
host, port, db, password, socket_timeout)
260-
# if for whatever reason the connection gets a bad password, make
261-
# sure a subsequent attempt with the right password makes its way
262-
# to the connection
263-
conn.password = password
264-
return conn
265-
266-
def _setup_connection(self):
267-
"""
268-
After successfully opening a socket to the Redis server, the
269-
connection object calls this method to authenticate and select
270-
the appropriate database.
271-
"""
272-
self.subscribed = False
273-
if self.connection.password:
274-
if not self.execute_command('AUTH', self.connection.password):
275-
raise AuthenticationError("Invalid Password")
276-
self.execute_command('SELECT', self.connection.db)
277-
278-
def select(self, db, host=None, port=None, password=None,
279-
socket_timeout=None):
280-
"""
281-
Switch to a different Redis connection.
282-
283-
If the host and port aren't provided and there's an existing
284-
connection, use the existing connection's host and port instead.
285-
286-
Note this method actually replaces the underlying connection object
287-
prior to issuing the SELECT command. This makes sure we protect
288-
the thread-safe connections
289-
"""
290-
if host is None:
291-
if self.connection is None:
292-
raise RedisError("A valid hostname or IP address "
293-
"must be specified")
294-
host = self.connection.host
295-
if port is None:
296-
if self.connection is None:
297-
raise RedisError("A valid port must be specified")
298-
port = self.connection.port
299-
300-
self.connection = self.get_connection(
301-
host, port, db, password, socket_timeout)
262+
def select(self, db):
263+
"SELECT a differnet Redis database."
264+
return self.execute_command('SELECT', db)
302265

303266
def shutdown(self):
304267
"Shutdown the server"
@@ -1246,11 +1209,9 @@ class Pipeline(Redis):
12461209
ResponseError exceptions, such as those raised when issuing a command
12471210
on a key of a different datatype.
12481211
"""
1249-
def __init__(self, connection, transaction, charset, errors):
1250-
self.connection = connection
1212+
def __init__(self, connection_pool, transaction):
1213+
self.connection_pool = connection_pool
12511214
self.transaction = transaction
1252-
self.encoding = charset
1253-
self.errors = errors
12541215
self.subscribed = False # NOTE not in use, but necessary
12551216
self.reset()
12561217

@@ -1275,42 +1236,51 @@ def execute_command(self, *args, **options):
12751236
return self
12761237

12771238
def _execute_transaction(self, commands):
1278-
all_cmds = ''.join(starmap(self.connection.pack_command,
1279-
[args for args, options in commands]))
1280-
self.connection.send_packed_command(all_cmds)
1281-
# we don't care about the multi/exec any longer
1282-
commands = commands[1:-1]
1283-
# parse off the response for MULTI and all commands prior to EXEC
1284-
# the only data we care about is the response the EXEC, the last command
1285-
for i in range(len(commands)+1):
1286-
_ = self.parse_response('_')
1287-
# parse the EXEC.
1288-
response = self.parse_response('_')
1289-
1290-
if response is None:
1291-
raise WatchError("Watched variable changed.")
1292-
1293-
if len(response) != len(commands):
1294-
raise ResponseError("Wrong number of response items from "
1295-
"pipeline execution")
1296-
# We have to run response callbacks manually
1297-
data = []
1298-
for r, cmd in izip(response, commands):
1299-
if not isinstance(r, Exception):
1300-
args, options = cmd
1301-
command_name = args[0]
1302-
if command_name in self.RESPONSE_CALLBACKS:
1303-
r = self.RESPONSE_CALLBACKS[command_name](r, **options)
1304-
data.append(r)
1305-
return data
1239+
connection = self.connection_pool.get_connection()
1240+
try:
1241+
all_cmds = ''.join(starmap(connection.pack_command,
1242+
[args for args, options in commands]))
1243+
connection.send_packed_command(all_cmds)
1244+
# we don't care about the multi/exec any longer
1245+
commands = commands[1:-1]
1246+
# parse off the response for MULTI and all commands prior to EXEC.
1247+
# the only data we care about is the response the EXEC
1248+
# which is the last command
1249+
for i in range(len(commands)+1):
1250+
_ = self.parse_response(connection, '_')
1251+
# parse the EXEC.
1252+
response = self.parse_response(connection, '_')
1253+
1254+
if response is None:
1255+
raise WatchError("Watched variable changed.")
1256+
1257+
if len(response) != len(commands):
1258+
raise ResponseError("Wrong number of response items from "
1259+
"pipeline execution")
1260+
# We have to run response callbacks manually
1261+
data = []
1262+
for r, cmd in izip(response, commands):
1263+
if not isinstance(r, Exception):
1264+
args, options = cmd
1265+
command_name = args[0]
1266+
if command_name in self.RESPONSE_CALLBACKS:
1267+
r = self.RESPONSE_CALLBACKS[command_name](r, **options)
1268+
data.append(r)
1269+
return data
1270+
finally:
1271+
self.connection_pool.release(connection)
13061272

13071273
def _execute_pipeline(self, commands):
13081274
# build up all commands into a single request to increase network perf
1309-
all_cmds = ''.join(starmap(self.connection.pack_command,
1310-
[args for args, options in commands]))
1311-
self.connection.send_packed_command(all_cmds)
1312-
return [self.parse_response(args[0], **options)
1313-
for args, options in commands]
1275+
connection = self.connection_pool.get_connection()
1276+
try:
1277+
all_cmds = ''.join(starmap(connection.pack_command,
1278+
[args for args, options in commands]))
1279+
connection.send_packed_command(all_cmds)
1280+
return [self.parse_response(connection, args[0], **options)
1281+
for args, options in commands]
1282+
finally:
1283+
self.connection_pool.release(connection)
13141284

13151285
def execute(self):
13161286
"Execute all the commands in the current pipeline"
@@ -1324,7 +1294,7 @@ def execute(self):
13241294
try:
13251295
return execute(stack)
13261296
except ConnectionError:
1327-
self.connection.disconnect()
1297+
connection.disconnect()
13281298
return execute(stack)
13291299

13301300
def select(self, *args, **kwargs):

redis/connection.py

Lines changed: 53 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -116,9 +116,7 @@ def connect(self):
116116
if self._sock:
117117
return
118118
try:
119-
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
120-
sock.settimeout(self.socket_timeout)
121-
sock.connect((self.host, self.port))
119+
sock = self._connect()
122120
except socket.error, e:
123121
# args for socket.error can either be (errno, "message")
124122
# or just "message"
@@ -133,6 +131,13 @@ def connect(self):
133131
self._sock = sock
134132
self.on_connect()
135133

134+
def _connect(self):
135+
"Create a TCP socket connection"
136+
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
137+
sock.settimeout(self.socket_timeout)
138+
sock.connect((self.host, self.port))
139+
return sock
140+
136141
def on_connect(self):
137142
"Initialize the connection, authenticate and select a database"
138143
self._parser.on_connect(self)
@@ -199,7 +204,7 @@ def read_response(self):
199204
return response
200205

201206
def encode(self, value):
202-
"Return a bytestring of the value"
207+
"Return a bytestring representation of the value"
203208
if isinstance(value, unicode):
204209
return value.encode(self.encoding, self.encoding_errors)
205210
return str(value)
@@ -210,24 +215,47 @@ def pack_command(self, *args):
210215
for enc_value in imap(self.encode, args)]
211216
return '*%s\r\n%s' % (len(command), ''.join(command))
212217

213-
class ConnectionPool(threading.local):
214-
"Manages a list of connections on the local thread"
215-
def __init__(self, connection_class=None):
216-
self.connections = {}
217-
self.connection_class = connection_class or Connection
218-
219-
def make_connection_key(self, host, port, db):
220-
"Create a unique key for the specified host, port and db"
221-
return '%s:%s:%s' % (host, port, db)
222-
223-
def get_connection(self, host, port, db, password, socket_timeout):
224-
"Return a specific connection for the specified host, port and db"
225-
key = self.make_connection_key(host, port, db)
226-
if key not in self.connections:
227-
self.connections[key] = self.connection_class(
228-
host, port, db, password, socket_timeout)
229-
return self.connections[key]
230-
231-
def get_all_connections(self):
232-
"Return a list of all connection objects the manager knows about"
233-
return self.connections.values()
218+
class UnixDomainSocketConnection(Connection):
219+
def __init__(self, path='', db=0, password=None,
220+
socket_timeout=None, encoding='utf-8',
221+
encoding_errors='strict', parser_class=DefaultParser):
222+
self.path = path
223+
self.db = db
224+
self.password = password
225+
self.socket_timeout = socket_timeout
226+
self.encoding = encoding
227+
self.encoding_errors = encoding_errors
228+
self._sock = None
229+
self._parser = parser_class()
230+
231+
def _connect(self):
232+
"Create a Unix domain socket connection"
233+
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
234+
sock.settimeout(self.socket_timeout)
235+
sock.connect(self.path)
236+
return sock
237+
238+
class ConnectionPool(object):
239+
"""
240+
A connection pool that maintains only one connection. Great for
241+
single-threaded apps with no sharding
242+
"""
243+
def __init__(self, connection_class=Connection, **kwargs):
244+
self.connection_class = connection_class
245+
self.kwargs = kwargs
246+
self._connection = None
247+
248+
def get_connection(self, *args, **kwargs):
249+
"Get a connection from the pool"
250+
if not self._connection:
251+
self._connection = self.connection_class(**self.kwargs)
252+
return self._connection
253+
254+
def release(self, connection):
255+
"Releases the connection back to the pool"
256+
pass
257+
258+
def disconnect(self):
259+
"Disconnects all connections in the pool"
260+
if self._connection:
261+
self._connection.disconnect()

0 commit comments

Comments
 (0)