8
8
import re
9
9
10
10
from abc import ABC , abstractmethod
11
+ from contextlib import nullcontext
11
12
from typing import TYPE_CHECKING , Any , Callable , Collection , Dict , Optional , Union
12
13
13
14
# pylint: disable=cyclic-import
@@ -187,7 +188,22 @@ def set_connection_span_attrs(
187
188
cnx_span .set_attributes (attrs )
188
189
189
190
190
- def instrument_execution (
191
+ def with_connection_span_attached (method : Callable ) -> Callable :
192
+ """Attach the connection span while executing a connection method."""
193
+
194
+ def wrapper (
195
+ cnx : Union ["MySQLConnection" , "CMySQLConnection" ], * args : Any , ** kwargs : Any
196
+ ) -> Any :
197
+ """Context propagation decorator."""
198
+ with trace .use_span (
199
+ cnx ._span , end_on_exit = False
200
+ ) if cnx ._span and cnx ._span .is_recording () else nullcontext ():
201
+ return method (cnx , * args , ** kwargs )
202
+
203
+ return wrapper
204
+
205
+
206
+ def _instrument_execution (
191
207
query_method : Callable ,
192
208
tracer : trace .Tracer ,
193
209
connection_span_link : trace .Link ,
@@ -288,7 +304,7 @@ def __init__(
288
304
289
305
def execute (self , * args : Any , ** kwargs : Any ) -> Any :
290
306
"""Instruments execute method."""
291
- return instrument_execution (
307
+ return _instrument_execution (
292
308
self ._wrapped .execute ,
293
309
self ._tracer ,
294
310
self ._connection_span_link ,
@@ -299,7 +315,7 @@ def execute(self, *args: Any, **kwargs: Any) -> Any:
299
315
300
316
def executemany (self , * args : Any , ** kwargs : Any ) -> Any :
301
317
"""Instruments executemany method."""
302
- return instrument_execution (
318
+ return _instrument_execution (
303
319
self ._wrapped .executemany ,
304
320
self ._tracer ,
305
321
self ._connection_span_link ,
@@ -310,7 +326,7 @@ def executemany(self, *args: Any, **kwargs: Any) -> Any:
310
326
311
327
def callproc (self , * args : Any , ** kwargs : Any ) -> Any :
312
328
"""Instruments callproc method."""
313
- return instrument_execution (
329
+ return _instrument_execution (
314
330
self ._wrapped .callproc ,
315
331
self ._tracer ,
316
332
self ._connection_span_link ,
@@ -332,15 +348,20 @@ def __init__(self, wrapped: Union["MySQLConnection", "CMySQLConnection"]) -> Non
332
348
_ = self ._wrapped .sql_mode
333
349
334
350
def cursor (self , * args : Any , ** kwargs : Any ) -> TracedMySQLCursor :
335
- """Wraps the cursor object."""
351
+ """Wraps the object method ."""
336
352
return TracedMySQLCursor (
337
353
wrapped = self ._wrapped .cursor (* args , ** kwargs ),
338
354
tracer = self ._tracer ,
339
355
connection_span = self ._span ,
340
356
)
341
357
358
+ @with_connection_span_attached
359
+ def cmd_change_user (self , * args : Any , ** kwargs : Any ) -> Any :
360
+ """Wraps the object method."""
361
+ return self ._wrapped .cmd_change_user (* args , ** kwargs )
362
+
342
363
343
- def instrument_connect (
364
+ def _instrument_connect (
344
365
connect : Callable [
345
366
..., Union ["MySQLConnection" , "CMySQLConnection" , "PooledMySQLConnection" ]
346
367
],
@@ -376,19 +397,21 @@ def wrapper(
376
397
)
377
398
kwargs [OPTION_CNX_TRACER ] = tracer
378
399
379
- # Add basic net information.
380
- set_connection_span_attrs (None , kwargs [OPTION_CNX_SPAN ], kwargs )
400
+ # attach connection span
401
+ with trace .use_span (kwargs [OPTION_CNX_SPAN ], end_on_exit = False ) as cnx_span :
402
+ # Add basic net information.
403
+ set_connection_span_attrs (None , cnx_span , kwargs )
381
404
382
- # Connection may fail at this point, in case it does, basic net info is already
383
- # included so the user can check the net configuration she/he provided.
384
- cnx = connect (* args , ** kwargs )
405
+ # Connection may fail at this point, in case it does, basic net info is already
406
+ # included so the user can check the net configuration she/he provided.
407
+ cnx = connect (* args , ** kwargs )
385
408
386
- # connection went ok, let's refine the net information.
387
- set_connection_span_attrs (cnx , cnx . _span , kwargs ) # type: ignore[arg-type]
409
+ # connection went ok, let's refine the net information.
410
+ set_connection_span_attrs (cnx , cnx_span , kwargs ) # type: ignore[arg-type]
388
411
389
- return TracedMySQLConnection (
390
- wrapped = cnx , # type: ignore[return-value, arg-type]
391
- )
412
+ return TracedMySQLConnection (
413
+ wrapped = cnx , # type: ignore[return-value, arg-type]
414
+ )
392
415
393
416
return wrapper
394
417
@@ -427,7 +450,7 @@ def instrument(self, **kwargs: Any) -> None:
427
450
if connector .connect != getattr (self , "_original_connect" ):
428
451
logger .warning ("MySQL Connector/Python module already instrumented." )
429
452
return
430
- connector .connect = instrument_connect (
453
+ connector .connect = _instrument_connect (
431
454
connect = getattr (self , "_original_connect" ),
432
455
tracer_provider = kwargs .get ("tracer_provider" ),
433
456
)
0 commit comments