Skip to content

Commit e03eb3b

Browse files
committed
WL13847: Add support for context managers
Context managers in Python, allows the allocation and releasing of resources precisely when it's needed. The most widely used example of context managers is the with statement. This worklog implements this functionality in the classic protocol for Connection objects and in the X DevAPI protocol for Session and Client objects.
1 parent 1401bd7 commit e03eb3b

File tree

5 files changed

+72
-3
lines changed

5 files changed

+72
-3
lines changed

lib/mysql/connector/abstracts.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2014, 2019, Oracle and/or its affiliates. All rights reserved.
1+
# Copyright (c) 2014, 2020, Oracle and/or its affiliates. All rights reserved.
22
#
33
# This program is free software; you can redistribute it and/or modify
44
# it under the terms of the GNU General Public License, version 2.0, as
@@ -111,6 +111,12 @@ def __init__(self, **kwargs):
111111

112112
self._consume_results = False
113113

114+
def __enter__(self):
115+
return self
116+
117+
def __exit__(self, exc_type, exc_value, traceback):
118+
self.close()
119+
114120
def _get_self(self):
115121
"""Return self for weakref.proxy
116122
@@ -1232,6 +1238,12 @@ def __init__(self):
12321238
self._warnings = None
12331239
self.arraysize = 1
12341240

1241+
def __enter__(self):
1242+
return self
1243+
1244+
def __exit__(self, exc_type, exc_value, traceback):
1245+
self.close()
1246+
12351247
@abstractmethod
12361248
def callproc(self, procname, args=()):
12371249
"""Calls a stored procedure with the given arguments

lib/mysqlx/connection.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1845,6 +1845,12 @@ def __init__(self, settings):
18451845
else "Default schema '{}' does not exists".format(schema)
18461846
raise InterfaceError(errmsg, err.errno)
18471847

1848+
def __enter__(self):
1849+
return self
1850+
1851+
def __exit__(self, exc_type, exc_value, traceback):
1852+
self.close()
1853+
18481854
def _init_attributes(self):
18491855
"""Setup default and user defined connection-attributes."""
18501856
if os.name == "nt":
@@ -2125,6 +2131,12 @@ def __init__(self, connection_dict, options_dict=None):
21252131
self.settings["max_size"] = self.max_size
21262132
self.settings["client_id"] = self.client_id
21272133

2134+
def __enter__(self):
2135+
return self
2136+
2137+
def __exit__(self, exc_type, exc_value, traceback):
2138+
self.close()
2139+
21282140
def _set_pool_size(self, pool_size):
21292141
"""Set the size of the pool.
21302142

tests/test_connection.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646

4747
try:
4848
from mysql.connector.connection_cext import HAVE_CMYSQL, CMySQLConnection
49+
from mysql.connector import cursor_cext
4950
except ImportError:
5051
# Test without C Extension
5152
CMySQLConnection = None
@@ -2001,7 +2002,6 @@ def test_get_connection_with_tls_version(self):
20012002
cur = cnx.cursor()
20022003
cur.execute("SHOW STATUS LIKE 'Ssl_version%'")
20032004
res = cur.fetchall()
2004-
20052005
self.assertEqual(res[0][1], expected_ssl_version,
20062006
err_msg.format(expected_ssl_version, res))
20072007

@@ -2195,6 +2195,29 @@ def test_dns_srv(self):
21952195
self.assertRaises(InterfaceError, connect, **config)
21962196
del config["unix_socket"]
21972197

2198+
@unittest.skipIf(not HAVE_CMYSQL, "C Extension not available")
2199+
def test_context_manager_cext(self):
2200+
"""Test connection and cursor context manager using the C extension."""
2201+
config = tests.get_mysql_config().copy()
2202+
config["use_pure"] = False
2203+
with connect(**config) as conn:
2204+
self.assertTrue(conn.is_connected())
2205+
with conn.cursor() as cur:
2206+
self.assertIsInstance(cur, cursor_cext.CMySQLCursor)
2207+
self.assertIsNone(cur._cnx)
2208+
self.assertFalse(conn.is_connected())
2209+
2210+
def test_context_manager_pure(self):
2211+
"""Test connection and cursor context manager using pure Python."""
2212+
config = tests.get_mysql_config().copy()
2213+
config["use_pure"] = True
2214+
with connect(**config) as conn:
2215+
self.assertTrue(conn.is_connected())
2216+
with conn.cursor() as cur:
2217+
self.assertIsInstance(cur, cursor.MySQLCursor)
2218+
self.assertIsNone(cur._connection)
2219+
self.assertFalse(conn.is_connected())
2220+
21982221

21992222
class WL13335(tests.MySQLConnectorTests):
22002223
"""WL#13335: Avoid set config values whit flag CAN_HANDLE_EXPIRED_PASSWORDS

tests/test_mysqlx_connection.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1583,6 +1583,13 @@ def test_dns_srv(self):
15831583
uri = "mysqlx+srv://root:@localhost/myschema?dns-srv=true"
15841584
self.assertRaises(InterfaceError, mysqlx.get_session, uri)
15851585

1586+
def test_context_manager(self):
1587+
"""Test mysqlx.get_session() context manager."""
1588+
with mysqlx.get_session(self.connect_kwargs) as session:
1589+
self.assertIsInstance(session, mysqlx.Session)
1590+
self.assertTrue(session.is_open())
1591+
self.assertFalse(session.is_open())
1592+
15861593

15871594
@unittest.skipIf(tests.MYSQL_VERSION < (8, 0, 20), "XPlugin not compatible")
15881595
class MySQLxInnitialNoticeTests(tests.MySQLxTests):

tests/test_mysqlx_pooling.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# -*- coding: utf-8 -*-
22

3-
# Copyright (c) 2018, 2019, Oracle and/or its affiliates. All rights reserved.
3+
# Copyright (c) 2018, 2020, Oracle and/or its affiliates. All rights reserved.
44
#
55
# This program is free software; you can redistribute it and/or modify
66
# it under the terms of the GNU General Public License, version 2.0, as
@@ -593,6 +593,21 @@ def test_get_client_with_tls_version(self):
593593
session.close()
594594
client.close()
595595

596+
def test_context_manager(self):
597+
"""Test mysqlx.get_client() context manager."""
598+
settings = self.connect_kwargs.copy()
599+
pooling_dict = {"enabled": True, "max_size": 5}
600+
cnx_options = {"pooling": pooling_dict}
601+
with mysqlx.get_client(settings, cnx_options) as client:
602+
with client.get_session() as session:
603+
self.assertIsInstance(session, mysqlx.Session)
604+
self.assertTrue(session.is_open())
605+
self.assertFalse(session.is_open())
606+
# Create one more session
607+
_ = client.get_session()
608+
for session in client.sessions:
609+
self.assertFalse(session.is_open())
610+
596611

597612
@unittest.skipIf(tests.MYSQL_VERSION < (5, 7, 12), "XPlugin not compatible")
598613
class MySQLxClientPoolingTests(tests.MySQLxTests):

0 commit comments

Comments
 (0)