78
78
from .crud import Schema
79
79
from .constants import SSLMode , Auth
80
80
from .helpers import escape , get_item_or_attr , iani_to_openssl_cs_name
81
- from .protocol import Protocol , MessageReaderWriter
81
+ from .protocol import Protocol , MessageReader , MessageWriter , HAVE_LZ4
82
82
from .result import Result , RowResult , SqlResult , DocResult
83
83
from .statement import SqlStatement , AddStatement , quote_identifier
84
84
from .protobuf import Protobuf
@@ -558,7 +558,6 @@ class Connection(object):
558
558
def __init__ (self , settings ):
559
559
self .settings = settings
560
560
self .stream = SocketStream ()
561
- self .reader_writer = None
562
561
self .protocol = None
563
562
self .keep_open = None
564
563
self ._user = settings .get ("user" )
@@ -617,10 +616,31 @@ def connect(self):
617
616
router = self .router_manager .get_next_router ()
618
617
self .stream .connect (router .get_connection_params (),
619
618
self ._connect_timeout )
620
- self .reader_writer = MessageReaderWriter (self .stream )
621
- self .protocol = Protocol (self .reader_writer )
622
- self ._handle_capabilities ()
619
+ reader = MessageReader (self .stream )
620
+ writer = MessageWriter (self .stream )
621
+ self .protocol = Protocol (reader , writer )
622
+
623
+ caps_data = self .protocol .get_capabilites ().capabilities
624
+ caps = {
625
+ get_item_or_attr (cap , "name" ).lower ():
626
+ cap for cap in caps_data
627
+ } if caps_data else {}
628
+
629
+ # Set TLS capabilities
630
+ self ._set_tls_capabilities (caps )
631
+
632
+ # Set connection attributes capabilities
633
+ if "attributes" in self .settings :
634
+ conn_attrs = self .settings ["attributes" ]
635
+ self .protocol .set_capabilities (
636
+ session_connect_attrs = conn_attrs )
637
+
638
+ # Set compression capabilities
639
+ compression = self .settings .get ("compression" , "preferred" )
640
+ algorithm = None if compression == "disabled" \
641
+ else self ._set_compression_capabilities (caps , compression )
623
642
self ._authenticate ()
643
+ self .protocol .set_compression (algorithm )
624
644
return
625
645
except (socket .error , RuntimeError ) as err :
626
646
error = err
@@ -643,30 +663,31 @@ def connect(self):
643
663
raise InterfaceError ("Cannot connect to host: {0}" .format (error ))
644
664
raise InterfaceError ("Unable to connect to any of the target hosts" , 4001 )
645
665
646
- def _handle_capabilities (self ):
647
- """Handle capabilities.
666
+ def _set_tls_capabilities (self , caps ):
667
+ """Sets the TLS capabilities.
668
+
669
+ Args:
670
+ caps (dict): Dictionary with the server capabilities.
648
671
649
672
Raises:
650
673
:class:`mysqlx.OperationalError`: If SSL is not enabled at the
651
674
server.
652
675
:class:`mysqlx.RuntimeError`: If support for SSL is not available
653
676
in Python.
677
+
678
+ .. versionadded:: 8.0.21
654
679
"""
655
680
if self .settings .get ("ssl-mode" ) == SSLMode .DISABLED :
656
681
return
682
+
657
683
if self .stream .is_socket ():
658
684
if self .settings .get ("ssl-mode" ):
659
685
_LOGGER .warning ("SSL not required when using Unix socket." )
660
686
return
661
687
662
- try :
663
- data = self .protocol .get_capabilites ().capabilities
664
- if not (get_item_or_attr (data [0 ], "name" ).lower () == "tls"
665
- if data else False ):
666
- self .close_connection ()
667
- raise OperationalError ("SSL not enabled at server" )
668
- except (AttributeError , KeyError ):
669
- pass
688
+ if "tls" not in caps :
689
+ self .close_connection ()
690
+ raise OperationalError ("SSL not enabled at server" )
670
691
671
692
is_ol7 = False
672
693
if platform .system () == "Linux" :
@@ -694,6 +715,65 @@ def _handle_capabilities(self):
694
715
conn_attrs = self .settings ["attributes" ]
695
716
self .protocol .set_capabilities (session_connect_attrs = conn_attrs )
696
717
718
+ def _set_compression_capabilities (self , caps , compression ):
719
+ """Sets the compression capabilities.
720
+
721
+ If compression is available, negociates client and server algorithms.
722
+ Using the following priority:
723
+
724
+ 1) lz4_message
725
+ 2) deflate_stream
726
+
727
+ Args:
728
+ caps (dict): Dictionary with the server capabilities.
729
+ compression (str): The compression connection setting.
730
+
731
+ Returns:
732
+ str: The compression algorithm.
733
+
734
+ .. versionadded:: 8.0.21
735
+ """
736
+ compression_data = caps .get ("compression" )
737
+ if compression_data is None :
738
+ msg = "Compression requested but the server does not support it"
739
+ if compression == "required" :
740
+ raise NotSupportedError (msg )
741
+ else :
742
+ _LOGGER .warning (msg )
743
+ return None
744
+
745
+ compression_dict = {}
746
+ if isinstance (compression_data , dict ): # C extension is being used
747
+ for fld in compression_data ["value" ]["obj" ]["fld" ]:
748
+ compression_dict [fld ["key" ]] = [
749
+ value ["scalar" ]["v_string" ]["value" ].decode ("utf-8" )
750
+ for value in fld ["value" ]["array" ]["value" ]
751
+ ]
752
+ else :
753
+ for fld in compression_data .value .obj .fld :
754
+ compression_dict [fld .key ] = [
755
+ value .scalar .v_string .value .decode ("utf-8" )
756
+ for value in fld .value .array .value
757
+ ]
758
+
759
+ server_algorithms = compression_dict .get ("algorithm" , [])
760
+ if HAVE_LZ4 and "lz4_message" in server_algorithms :
761
+ algorithm = "lz4_message"
762
+ else :
763
+ algorithm = "deflate_stream"
764
+
765
+ if algorithm not in server_algorithms :
766
+ msg = ("Compression requested but the compression algorithm "
767
+ "negotiation failed" )
768
+ if compression == "required" :
769
+ raise InterfaceError (msg )
770
+ else :
771
+ _LOGGER .warning (msg )
772
+ return None
773
+
774
+ self .protocol .set_capabilities (compression = {"algorithm" : algorithm })
775
+ return algorithm
776
+
697
777
def _authenticate (self ):
698
778
"""Authenticate with the MySQL server."""
699
779
auth = self .settings .get ("auth" )
0 commit comments