Skip to content

Commit 4459f87

Browse files
committed
WL#16164: Implement support for new vector data type
This work log adds support for the new MySQL `VECTOR` type. Change-Id: Id73305c9eaf2a76c1ba10fea7916bf5f319dd324
1 parent f3e285e commit 4459f87

File tree

12 files changed

+1182
-46
lines changed

12 files changed

+1182
-46
lines changed

CHANGES.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ v8.4.0
1313

1414
- WL#16203: GPL License Exception Update
1515
- WL#16173: Update allowed cipher and cipher-suite lists
16+
- WL#16164: Implement support for new vector data type
1617
- WL#16053: Support GSSAPI/Kerberos authentication on Windows using authentication_ldap_sasl_client plug-in for C-extension
1718

1819
v8.3.0

mysql-connector-python/lib/mysql/connector/aio/charsets.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def set_mysql_major_version(self, version: int) -> None:
9191
self._name_store.clear()
9292

9393
charsets_tuple: Sequence[Tuple[int, str, str, bool]] = None
94-
if version == 8:
94+
if version >= 8:
9595
charsets_tuple = MYSQL_8_CHARSETS
9696
elif version == 5:
9797
charsets_tuple = MYSQL_5_CHARSETS

mysql-connector-python/lib/mysql/connector/aio/protocol.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -305,13 +305,13 @@ async def read_text_result( # type: ignore[override]
305305
datas.append(packet[4:])
306306
packet = await sock.read()
307307
datas.append(packet[4:])
308-
rowdata = read_lc_string_list(bytearray(b"").join(datas))
308+
rowdata = read_lc_string_list(b"".join(datas))
309309
elif packet[4] == 254 and packet[0] < 7:
310310
eof = self.parse_eof(packet)
311311
rowdata = None
312312
else:
313313
eof = None
314-
rowdata = read_lc_string_list(packet[4:])
314+
rowdata = read_lc_string_list(bytes(packet[4:]))
315315
if eof is None and rowdata is not None:
316316
rows.append(rowdata)
317317
elif eof is None and rowdata is None:

mysql-connector-python/lib/mysql/connector/connection_cext.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050

5151
from . import version
5252
from .abstracts import CMySQLPrepStmt, MySQLConnectionAbstract
53-
from .constants import ClientFlag, FieldFlag, ServerFlag, ShutdownType
53+
from .constants import ClientFlag, FieldFlag, FieldType, ServerFlag, ShutdownType
5454
from .conversion import MySQLConverter
5555
from .errors import (
5656
InterfaceError,
@@ -473,18 +473,36 @@ def get_rows(
473473
# convert the values. This can be accomplished by setting
474474
# the raw option to True.
475475
self._cmysql.raw(True)
476+
476477
row = fetch_row()
477478
while row:
479+
row = list(row)
480+
481+
if not self._cmysql.raw() and not raw:
482+
# `not _cmysql.raw()` means the c-ext conversion layer will be applied.
483+
# `not raw` means the caller wants conversion to happen.
484+
# For a VECTOR type, the c-ext conversion layer cannot return
485+
# an array.array type since such a type isn't part of the Python/C
486+
# API. Therefore, the c-ext will treat VECTOR types as if they
487+
# were BLOB types, i.e. they are always returned as `bytes`.
488+
# Hence, a VECTOR type must be cast to an array.array type using the
489+
# built-in python conversion layer.
490+
# pylint: disable=protected-access
491+
for i, dsc in enumerate(self._columns):
492+
if dsc[1] == FieldType.VECTOR:
493+
row[i] = MySQLConverter._vector_to_python(row[i])
494+
478495
if not self._raw and self.converter:
479-
row = list(row)
480496
for i, _ in enumerate(row):
481497
if not raw:
482498
row[i] = self.converter.to_python(self._columns[i], row[i])
483-
row = tuple(row)
484-
rows.append(row)
499+
500+
rows.append(tuple(row))
485501
counter += 1
502+
486503
if count and counter == count:
487504
break
505+
488506
row = fetch_row()
489507
if not row:
490508
_eof: Optional[CextEofPacketType] = self.fetch_eof_columns(prep_stmt)[

mysql-connector-python/lib/mysql/connector/constants.py

Lines changed: 35 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,12 @@
3939

4040
NET_BUFFER_LENGTH: int = 8192
4141
MAX_MYSQL_TABLE_COLUMNS: int = 4096
42-
# Flag used to send the Query Attributes with 0 (or more) parameters.
4342
PARAMETER_COUNT_AVAILABLE: int = 8
43+
"""Flag used to send the Query Attributes with 0 (or more) parameters."""
44+
MYSQL_VECTOR_TYPE_CODE = "f"
45+
"""Expected `typecode` when decoding VECTOR values from
46+
MySQL (blob) to Python (array.array).
47+
"""
4448

4549
DEFAULT_CONFIGURATION: Dict[str, Optional[Union[str, bool, int]]] = {
4650
"database": None,
@@ -235,6 +239,7 @@ class FieldType(_Constants):
235239
NEWDATE: int = 0x0E
236240
VARCHAR: int = 0x0F
237241
BIT: int = 0x10
242+
VECTOR: int = 0xF2
238243
JSON: int = 0xF5
239244
NEWDECIMAL: int = 0xF6
240245
ENUM: int = 0xF7
@@ -248,34 +253,35 @@ class FieldType(_Constants):
248253
GEOMETRY: int = 0xFF
249254

250255
desc: Dict[str, Tuple[int, str]] = {
251-
"DECIMAL": (0x00, "DECIMAL"),
252-
"TINY": (0x01, "TINY"),
253-
"SHORT": (0x02, "SHORT"),
254-
"LONG": (0x03, "LONG"),
255-
"FLOAT": (0x04, "FLOAT"),
256-
"DOUBLE": (0x05, "DOUBLE"),
257-
"NULL": (0x06, "NULL"),
258-
"TIMESTAMP": (0x07, "TIMESTAMP"),
259-
"LONGLONG": (0x08, "LONGLONG"),
260-
"INT24": (0x09, "INT24"),
261-
"DATE": (0x0A, "DATE"),
262-
"TIME": (0x0B, "TIME"),
263-
"DATETIME": (0x0C, "DATETIME"),
264-
"YEAR": (0x0D, "YEAR"),
265-
"NEWDATE": (0x0E, "NEWDATE"),
266-
"VARCHAR": (0x0F, "VARCHAR"),
267-
"BIT": (0x10, "BIT"),
268-
"JSON": (0xF5, "JSON"),
269-
"NEWDECIMAL": (0xF6, "NEWDECIMAL"),
270-
"ENUM": (0xF7, "ENUM"),
271-
"SET": (0xF8, "SET"),
272-
"TINY_BLOB": (0xF9, "TINY_BLOB"),
273-
"MEDIUM_BLOB": (0xFA, "MEDIUM_BLOB"),
274-
"LONG_BLOB": (0xFB, "LONG_BLOB"),
275-
"BLOB": (0xFC, "BLOB"),
276-
"VAR_STRING": (0xFD, "VAR_STRING"),
277-
"STRING": (0xFE, "STRING"),
278-
"GEOMETRY": (0xFF, "GEOMETRY"),
256+
"DECIMAL": (DECIMAL, "DECIMAL"),
257+
"TINY": (TINY, "TINY"),
258+
"SHORT": (SHORT, "SHORT"),
259+
"LONG": (LONG, "LONG"),
260+
"FLOAT": (FLOAT, "FLOAT"),
261+
"DOUBLE": (DOUBLE, "DOUBLE"),
262+
"NULL": (NULL, "NULL"),
263+
"TIMESTAMP": (TIMESTAMP, "TIMESTAMP"),
264+
"LONGLONG": (LONGLONG, "LONGLONG"),
265+
"INT24": (INT24, "INT24"),
266+
"DATE": (DATE, "DATE"),
267+
"TIME": (TIME, "TIME"),
268+
"DATETIME": (DATETIME, "DATETIME"),
269+
"YEAR": (YEAR, "YEAR"),
270+
"NEWDATE": (NEWDATE, "NEWDATE"),
271+
"VARCHAR": (VARCHAR, "VARCHAR"),
272+
"BIT": (BIT, "BIT"),
273+
"VECTOR": (VECTOR, "VECTOR"),
274+
"JSON": (JSON, "JSON"),
275+
"NEWDECIMAL": (NEWDECIMAL, "NEWDECIMAL"),
276+
"ENUM": (ENUM, "ENUM"),
277+
"SET": (SET, "SET"),
278+
"TINY_BLOB": (TINY_BLOB, "TINY_BLOB"),
279+
"MEDIUM_BLOB": (MEDIUM_BLOB, "MEDIUM_BLOB"),
280+
"LONG_BLOB": (LONG_BLOB, "LONG_BLOB"),
281+
"BLOB": (BLOB, "BLOB"),
282+
"VAR_STRING": (VAR_STRING, "VAR_STRING"),
283+
"STRING": (STRING, "STRING"),
284+
"GEOMETRY": (GEOMETRY, "GEOMETRY"),
279285
}
280286

281287
@classmethod

mysql-connector-python/lib/mysql/connector/conversion.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
"""Converting MySQL and Python types
3030
"""
3131

32+
import array
3233
import datetime
3334
import math
3435
import struct
@@ -37,7 +38,7 @@
3738
from decimal import Decimal
3839
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
3940

40-
from .constants import CharacterSet, FieldFlag, FieldType
41+
from .constants import MYSQL_VECTOR_TYPE_CODE, CharacterSet, FieldFlag, FieldType
4142
from .custom_types import HexLiteral
4243
from .types import (
4344
DescriptionType,
@@ -740,6 +741,22 @@ def _blob_to_python(
740741
return bytes(value)
741742
return self._string_to_python(value, dsc)
742743

744+
@staticmethod
def _vector_to_python(
    value: Optional[bytes], desc: Optional[DescriptionType] = None
) -> Optional[array.array]:
    """
    Turn a MySQL VECTOR value into a Python `array.array`.

    Args:
        value: The value to convert. `bytes`/`bytearray` payloads are
               decoded; `None` and `array.array` instances are handed
               back unchanged.
        desc: Column description. Accepted for signature parity with the
              other converters; this conversion does not read it.

    Returns:
        An `array.array` built with the `MYSQL_VECTOR_TYPE_CODE` type code,
        or `None` when `value` is `None`.

    Raises:
        TypeError: If `value` is none of the supported types.
    """
    # Raw payload: let array.array interpret the buffer using the
    # configured type code.
    if isinstance(value, (bytes, bytearray)):
        return array.array(MYSQL_VECTOR_TYPE_CODE, value)

    # Nothing to decode: SQL NULL, or a value that is already an array.
    if value is None or isinstance(value, array.array):
        return value

    raise TypeError(f"Got unsupported type {value.__class__.__name__}")
759+
743760
_long_blob_to_python = _blob_to_python
744761
_medium_blob_to_python = _blob_to_python
745762
_tiny_blob_to_python = _blob_to_python

mysql-connector-python/lib/mysql/connector/protocol.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
FieldType,
5555
ServerCmd,
5656
)
57+
from .conversion import MySQLConverter
5758
from .errors import DatabaseError, InterfaceError, ProgrammingError, get_exception
5859
from .logger import logger
5960
from .plugins import MySQLAuthPlugin, get_auth_plugin
@@ -798,6 +799,10 @@ def _parse_binary_values(
798799
elif field[1] == FieldType.TIME:
799800
(packet, value) = self._parse_binary_time(packet)
800801
values.append(value)
802+
elif field[1] == FieldType.VECTOR:
803+
# pylint: disable=protected-access
804+
(packet, value) = utils.read_lc_string(packet)
805+
values.append(MySQLConverter._vector_to_python(value))
801806
elif field[7] == FieldFlag.BINARY or field[8] == 63: # "binary" charset
802807
(packet, value) = utils.read_lc_string(packet)
803808
values.append(value)

mysql-connector-python/src/mysql_capi.c

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2695,7 +2695,7 @@ MySQL_fetch_row(MySQL *self)
26952695
}
26962696
else {
26972697
PyTuple_SET_ITEM(result_row, i,
2698-
PyByteArray_FromStringAndSize(row[i], field_lengths[i]));
2698+
PyBytes_FromStringAndSize(row[i], field_lengths[i]));
26992699
}
27002700
continue;
27012701
}
@@ -2782,7 +2782,12 @@ MySQL_fetch_row(MySQL *self)
27822782
else if (field_type == MYSQL_TYPE_BIT) {
27832783
PyTuple_SET_ITEM(result_row, i, mytopy_bit(row[i], field_lengths[i]));
27842784
}
2785-
else if (field_type == MYSQL_TYPE_BLOB) {
2785+
#if MYSQL_VERSION_ID >= 90000
2786+
else if (field_type == MYSQL_TYPE_BLOB || field_type == MYSQL_TYPE_VECTOR)
2787+
#else
2788+
else if (field_type == MYSQL_TYPE_BLOB)
2789+
#endif
2790+
{
27862791
if ((field_flags & BLOB_FLAG) &&
27872792
(field_flags & BINARY_FLAG) && field_charsetnr == 63) {
27882793
value = PyBytes_FromStringAndSize(row[i], field_lengths[i]);
@@ -2795,7 +2800,7 @@ MySQL_fetch_row(MySQL *self)
27952800
}
27962801
else if (field_type == MYSQL_TYPE_GEOMETRY) {
27972802
PyTuple_SET_ITEM(result_row, i,
2798-
PyByteArray_FromStringAndSize(row[i], field_lengths[i]));
2803+
PyBytes_FromStringAndSize(row[i], field_lengths[i]));
27992804
}
28002805
else {
28012806
// Do our best to convert whatever we got from MySQL to a str/bytes
@@ -3740,8 +3745,8 @@ MySQLPrepStmt_fetch_row(MySQLPrepStmt *self)
37403745
Py_XDECREF(obj);
37413746
}
37423747
else if (field->type == MYSQL_TYPE_GEOMETRY) {
3743-
obj = PyByteArray_FromStringAndSize(NULL, self->cols[i].length);
3744-
self->bind[i].buffer = PyByteArray_AsString(obj);
3748+
obj = PyBytes_FromStringAndSize(NULL, self->cols[i].length);
3749+
self->bind[i].buffer = PyBytes_AsString(obj);
37453750
self->bind[i].buffer_length = self->cols[i].length;
37463751

37473752
Py_BEGIN_ALLOW_THREADS

mysql-connector-python/src/mysql_capi_conversion.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -671,13 +671,13 @@ mytopy_string(const char *data, enum_field_types field_type,
671671
}
672672

673673
if (strcmp(charset, "binary") == 0) {
674-
return PyByteArray_FromStringAndSize(data, field_length);
674+
return PyBytes_FromStringAndSize(data, field_length);
675675
}
676676

677677
/* 'binary' charset = 63 */
678678
if (use_unicode && (field_type == MYSQL_TYPE_JSON || field_charsetnr != 63)) {
679679
return PyUnicode_Decode(data, field_length, charset, "replace");
680680
}
681681

682-
return PyByteArray_FromStringAndSize(data, field_length);
682+
return PyBytes_FromStringAndSize(data, field_length);
683683
}

mysql-connector-python/tests/cext/test_cext_api.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -833,8 +833,8 @@ def test_query(self):
833833

834834
self.assertTrue(cmy.query("SELECT 'ham', 'spam', 5", raw=True))
835835
row = cmy.fetch_row()
836-
self.assertTrue(isinstance(row[0], bytearray))
837-
self.assertEqual(bytearray(b"spam"), row[1])
836+
self.assertTrue(isinstance(row[0], bytes))
837+
self.assertEqual(b"spam", row[1])
838838
self.assertEqual(None, cmy.fetch_row())
839839
cmy.free_result()
840840

0 commit comments

Comments
 (0)