Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 61 additions & 49 deletions mssql_python/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,8 @@ def __init__(self, connection_str: str = "", autocommit: bool = False, attrs_bef
# Initialize encoding settings with defaults for Python 3
# Python 3 only has str (which is Unicode); by default we encode it as utf-8 and send it as SQL_CHAR
self._encoding_settings = {
'encoding': 'utf-16le',
'ctype': ConstantsDDBC.SQL_WCHAR.value
'encoding': 'utf-8',
'ctype': ConstantsDDBC.SQL_CHAR.value
}

# Initialize decoding settings with Python 3 defaults
Expand Down Expand Up @@ -326,13 +326,11 @@ def setencoding(self, encoding=None, ctype=None):
Raises:
ProgrammingError: If the encoding is not valid or not supported.
InterfaceError: If the connection is closed.
ValueError: If attempting to use non-UTF-16LE encoding with SQL_WCHAR.

Example:
# For databases that only communicate with UTF-8
cnxn.setencoding(encoding='utf-8')

# For explicitly using SQL_CHAR
cnxn.setencoding(encoding='utf-8', ctype=mssql_python.SQL_CHAR)
Note:
SQL_WCHAR must always use UTF-16LE encoding as required by SQL Server.
Custom encodings are only supported with SQL_CHAR.
"""
if self._closed:
raise InterfaceError(
Expand All @@ -342,7 +340,7 @@ def setencoding(self, encoding=None, ctype=None):

# Set default encoding if not provided
if encoding is None:
encoding = 'utf-16le'
encoding = 'utf-8'

# Validate encoding using cached validation for better performance
if not _validate_encoding(encoding):
Expand Down Expand Up @@ -373,6 +371,14 @@ def setencoding(self, encoding=None, ctype=None):
ddbc_error=f"ctype must be SQL_CHAR ({ConstantsDDBC.SQL_CHAR.value}) or SQL_WCHAR ({ConstantsDDBC.SQL_WCHAR.value})",
)

# Enforce UTF-16LE for SQL_WCHAR
if ctype == ConstantsDDBC.SQL_WCHAR.value and encoding not in UTF16_ENCODINGS:
raise ProgrammingError(
driver_error=f"SQL_WCHAR requires UTF-16LE encoding",
ddbc_error=f"SQL_WCHAR must use UTF-16LE encoding. '{encoding}' is not supported for SQL_WCHAR. "
f"Use SQL_CHAR if you need to use '{encoding}' encoding."
)

# Store the encoding settings
self._encoding_settings = {
'encoding': encoding,
Expand Down Expand Up @@ -428,16 +434,12 @@ def setdecoding(self, sqltype, encoding=None, ctype=None):
Raises:
ProgrammingError: If the sqltype, encoding, or ctype is invalid.
InterfaceError: If the connection is closed.
ValueError: If attempting to use non-UTF-16LE encoding with SQL_WCHAR.

Example:
# Configure SQL_CHAR to use UTF-8 decoding
cnxn.setdecoding(mssql_python.SQL_CHAR, encoding='utf-8')

# Configure column metadata decoding
cnxn.setdecoding(mssql_python.SQL_WMETADATA, encoding='utf-16le')

# Use explicit ctype
cnxn.setdecoding(mssql_python.SQL_WCHAR, encoding='utf-16le', ctype=mssql_python.SQL_WCHAR)
Note:
SQL_WCHAR and SQL_WMETADATA data from SQL Server is always encoded as UTF-16LE
and must use SQL_WCHAR ctype as required by the SQL Server ODBC driver.
Custom encodings are only supported for SQL_CHAR.
"""
if self._closed:
raise InterfaceError(
Expand All @@ -458,39 +460,49 @@ def setdecoding(self, sqltype, encoding=None, ctype=None):
ddbc_error=f"sqltype must be SQL_CHAR ({ConstantsDDBC.SQL_CHAR.value}), SQL_WCHAR ({ConstantsDDBC.SQL_WCHAR.value}), or SQL_WMETADATA ({SQL_WMETADATA})",
)

# Set default encoding based on sqltype if not provided
if encoding is None:
if sqltype == ConstantsDDBC.SQL_CHAR.value:
# For SQL_WCHAR and SQL_WMETADATA, enforce UTF-16LE encoding and SQL_WCHAR ctype
if sqltype in (ConstantsDDBC.SQL_WCHAR.value, SQL_WMETADATA):
if encoding is not None and encoding.lower() not in UTF16_ENCODINGS:
raise ProgrammingError(
driver_error=f"SQL_WCHAR and SQL_WMETADATA must use UTF-16LE encoding. '{encoding}' is not supported.",
ddbc_error=f"Custom encodings are only supported for SQL_CHAR. '{encoding}' is not valid for SQL_WCHAR or SQL_WMETADATA."
)
# Always enforce UTF-16LE for wide character types
encoding = 'utf-16le'
# Always enforce SQL_WCHAR ctype for wide character types
ctype = ConstantsDDBC.SQL_WCHAR.value
else:
# For SQL_CHAR, allow custom encoding settings
# Set default encoding for SQL_CHAR if not provided
if encoding is None:
encoding = 'utf-8' # Default for SQL_CHAR in Python 3
else: # SQL_WCHAR or SQL_WMETADATA
encoding = 'utf-16le' # Default for SQL_WCHAR in Python 3

# Validate encoding using cached validation for better performance
if not _validate_encoding(encoding):
log('warning', "Invalid encoding attempted: %s", sanitize_user_input(str(encoding)))
raise ProgrammingError(
driver_error=f"Unsupported encoding: {encoding}",
ddbc_error=f"The encoding '{encoding}' is not supported by Python",
)

# Normalize encoding to lowercase for consistency
encoding = encoding.lower()

# Set default ctype based on encoding if not provided
if ctype is None:
if encoding in UTF16_ENCODINGS:
ctype = ConstantsDDBC.SQL_WCHAR.value
else:
ctype = ConstantsDDBC.SQL_CHAR.value

# Validate ctype
valid_ctypes = [ConstantsDDBC.SQL_CHAR.value, ConstantsDDBC.SQL_WCHAR.value]
if ctype not in valid_ctypes:
log('warning', "Invalid ctype attempted: %s", sanitize_user_input(str(ctype)))
raise ProgrammingError(
driver_error=f"Invalid ctype: {ctype}",
ddbc_error=f"ctype must be SQL_CHAR ({ConstantsDDBC.SQL_CHAR.value}) or SQL_WCHAR ({ConstantsDDBC.SQL_WCHAR.value})",
)
# Validate encoding
if not _validate_encoding(encoding):
log('warning', "Invalid encoding attempted: %s", sanitize_user_input(str(encoding)))
raise ProgrammingError(
driver_error=f"Unsupported encoding: {encoding}",
ddbc_error=f"The encoding '{encoding}' is not supported by Python",
)

# Normalize encoding to lowercase for consistency
encoding = encoding.lower()

# Set default ctype based on encoding if not provided
if ctype is None:
if encoding in UTF16_ENCODINGS:
ctype = ConstantsDDBC.SQL_WCHAR.value
else:
ctype = ConstantsDDBC.SQL_CHAR.value

# Validate ctype
valid_ctypes = [ConstantsDDBC.SQL_CHAR.value, ConstantsDDBC.SQL_WCHAR.value]
if ctype not in valid_ctypes:
log('warning', "Invalid ctype attempted: %s", sanitize_user_input(str(ctype)))
raise ProgrammingError(
driver_error=f"Invalid ctype: {ctype}",
ddbc_error=f"ctype must be SQL_CHAR ({ConstantsDDBC.SQL_CHAR.value}) or SQL_WCHAR ({ConstantsDDBC.SQL_WCHAR.value})",
)

# Store the decoding settings for the specified sqltype
self._decoding_settings[sqltype] = {
Expand Down
78 changes: 73 additions & 5 deletions mssql_python/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from mssql_python.constants import ConstantsDDBC as ddbc_sql_const, SQLTypes
from mssql_python.helpers import check_error, log
from mssql_python import ddbc_bindings
from mssql_python.exceptions import InterfaceError, NotSupportedError, ProgrammingError
from mssql_python.exceptions import InterfaceError, NotSupportedError, ProgrammingError, OperationalError, DatabaseError
from mssql_python.row import Row
from mssql_python import get_settings

Expand Down Expand Up @@ -104,6 +104,53 @@ def __init__(self, connection, timeout: int = 0) -> None:

self.messages = [] # Store diagnostic messages

def _get_encoding_settings(self):
    """
    Look up the encoding settings configured on the owning connection.

    The connection's ``getencoding()`` result is returned unchanged when the
    call succeeds. A fallback of ``utf-8`` / ``SQL_CHAR`` is returned when the
    connection object has no ``getencoding`` attribute, or when the call
    raises ``OperationalError`` or ``DatabaseError`` (a warning is logged in
    that case).

    Returns:
        dict: A mapping with 'encoding' and 'ctype' keys.

    Raises:
        Any exception other than OperationalError / DatabaseError raised by
        ``getencoding()`` is deliberately not caught, so that programming
        errors surface to the caller.
    """
    def _fallback():
        # Built lazily so the constant lookup only happens when it is needed.
        return {
            'encoding': 'utf-8',
            'ctype': ddbc_sql_const.SQL_CHAR.value
        }

    # Guard clause: connection objects without getencoding get the defaults.
    if not hasattr(self._connection, 'getencoding'):
        return _fallback()

    try:
        settings = self._connection.getencoding()
    except (OperationalError, DatabaseError) as db_error:
        # Database-level failure: degrade to the defaults instead of aborting.
        log('warning', f"Failed to get encoding settings from connection due to database error: {db_error}")
        return _fallback()
    return settings

def _get_decoding_settings(self, sql_type):
    """
    Look up the decoding settings the connection holds for one SQL type.

    Args:
        sql_type: SQL type constant (SQL_CHAR, SQL_WCHAR, etc.) that is
            forwarded to the connection's ``getdecoding()``.

    Returns:
        dict: The value returned by ``getdecoding(sql_type)``. If that call
        raises ``OperationalError`` or ``DatabaseError``, a warning is logged
        and a fallback is returned instead: ``utf-16le`` / ``SQL_WCHAR`` when
        ``sql_type`` equals ``SQL_WCHAR``, otherwise ``utf-8`` / ``SQL_CHAR``.

    Raises:
        Any other exception raised by ``getdecoding()`` is not caught, so
        that programming errors surface to the caller.
    """
    try:
        settings = self._connection.getdecoding(sql_type)
    except (OperationalError, DatabaseError) as db_error:
        log('warning', f"Failed to get decoding settings for SQL type {sql_type} due to database error: {db_error}")
        # NOTE(review): only SQL_WCHAR is treated as a wide type here; any
        # other sql_type falls back to utf-8 / SQL_CHAR - confirm that this
        # is the intended fallback for every type callers may pass.
        is_wide = sql_type == ddbc_sql_const.SQL_WCHAR.value
        return {
            'encoding': 'utf-16le' if is_wide else 'utf-8',
            'ctype': ddbc_sql_const.SQL_WCHAR.value if is_wide else ddbc_sql_const.SQL_CHAR.value,
        }
    return settings

def _is_unicode_string(self, param):
"""
Check if a string contains non-ASCII characters.
Expand Down Expand Up @@ -966,6 +1013,8 @@ def execute(
parameters_type[i].decimalDigits,
parameters_type[i].inputOutputType,
)

encoding_settings = self._get_encoding_settings()

ret = ddbc_bindings.DDBCSQLExecute(
self.hstmt,
Expand All @@ -974,6 +1023,8 @@ def execute(
parameters_type,
self.is_stmt_prepared,
use_prepare,
encoding_settings.get('encoding'),
encoding_settings.get('ctype')
)
# Check return code
try:
Expand Down Expand Up @@ -1666,12 +1717,16 @@ def executemany(self, operation: str, seq_of_parameters: list) -> None:
len(seq_of_parameters), "\n".join(f" {i+1}: {tuple(p) if isinstance(p, (list, tuple)) else p}" for i, p in enumerate(seq_of_parameters[:5])) # Limit to first 5 rows for large batches
)

encoding_settings = self._get_encoding_settings()

ret = ddbc_bindings.SQLExecuteMany(
self.hstmt,
operation,
columnwise_params,
parameters_type,
row_count
row_count,
encoding_settings.get('encoding'),
encoding_settings.get('ctype')
)

# Capture any diagnostic messages after execution
Expand Down Expand Up @@ -1703,10 +1758,14 @@ def fetchone(self) -> Union[None, Row]:
"""
self._check_closed() # Check if the cursor is closed

# Get decoding settings for character data
char_decoding = self._get_decoding_settings(ddbc_sql_const.SQL_CHAR.value)
wchar_decoding = self._get_decoding_settings(ddbc_sql_const.SQL_WCHAR.value)

# Fetch raw data
row_data = []
try:
ret = ddbc_bindings.DDBCSQLFetchOne(self.hstmt, row_data)
ret = ddbc_bindings.DDBCSQLFetchOne(self.hstmt, row_data, char_decoding.get('encoding', 'utf-8'), wchar_decoding.get('encoding', 'utf-16le'))

if self.hstmt:
self.messages.extend(ddbc_bindings.DDBCSQLGetAllDiagRecords(self.hstmt))
Expand Down Expand Up @@ -1753,11 +1812,16 @@ def fetchmany(self, size: int = None) -> List[Row]:

if size <= 0:
return []

# Get decoding settings for character data
char_decoding = self._get_decoding_settings(ddbc_sql_const.SQL_CHAR.value)
wchar_decoding = self._get_decoding_settings(ddbc_sql_const.SQL_WCHAR.value)

# Fetch raw data
rows_data = []
try:
ret = ddbc_bindings.DDBCSQLFetchMany(self.hstmt, rows_data, size)
ret = ddbc_bindings.DDBCSQLFetchMany(self.hstmt, rows_data, size, char_decoding.get('encoding', 'utf-8'), wchar_decoding.get('encoding', 'utf-16le'))


if self.hstmt:
self.messages.extend(ddbc_bindings.DDBCSQLGetAllDiagRecords(self.hstmt))
Expand Down Expand Up @@ -1793,10 +1857,14 @@ def fetchall(self) -> List[Row]:
if not self._has_result_set and self.description:
self._reset_rownumber()

# Get decoding settings for character data
char_decoding = self._get_decoding_settings(ddbc_sql_const.SQL_CHAR.value)
wchar_decoding = self._get_decoding_settings(ddbc_sql_const.SQL_WCHAR.value)

# Fetch raw data
rows_data = []
try:
ret = ddbc_bindings.DDBCSQLFetchAll(self.hstmt, rows_data)
ret = ddbc_bindings.DDBCSQLFetchAll(self.hstmt, rows_data, char_decoding.get('encoding', 'utf-8'), wchar_decoding.get('encoding', 'utf-16le'))

if self.hstmt:
self.messages.extend(ddbc_bindings.DDBCSQLGetAllDiagRecords(self.hstmt))
Expand Down
Loading
Loading