Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 58 additions & 6 deletions mssql_python/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -884,6 +884,53 @@ def next(self):
"""
return self.__next__()

def _buffer_intermediate_results(self):
"""
Buffer intermediate results automatically.

This method skips "rows affected" messages and empty result sets,
positioning the cursor on the first meaningful result set that contains
actual data. This eliminates the need for SET NOCOUNT ON detection.
"""
try:
# Keep advancing through result sets until we find one with actual data
# or reach the end
while True:
# Check if current result set has actual columns/data
if self.description and len(self.description) > 0:
# We have a meaningful result set with columns, stop here
break

# Try to advance to next result set
try:
ret = ddbc_bindings.DDBCSQLMoreResults(self.hstmt)

# If no more result sets, we're done
if ret == ddbc_sql_const.SQL_NO_DATA.value:
break

# Check for errors
check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, ret)

# Update description for the new result set
column_metadata = []
try:
ddbc_bindings.DDBCSQLDescribeCol(self.hstmt, column_metadata)
self._initialize_description(column_metadata)
except Exception:
# If describe fails, it's likely there are no results (e.g., for INSERT)
self.description = None

except Exception:
# If we can't advance further, stop
break

except Exception as e:
log('warning', "Exception occurred during `_buffer_intermediate_results` %s", e)
# If anything goes wrong during buffering, continue with current state
# This ensures we don't break existing functionality
pass

def execute(
self,
operation: str,
Expand Down Expand Up @@ -965,6 +1012,7 @@ def execute(
# Executing a new statement. Reset is_stmt_prepared to false
self.is_stmt_prepared = [False]


log('debug', "Executing query: %s", operation)
for i, param in enumerate(parameters):
log('debug',
Expand Down Expand Up @@ -1005,9 +1053,9 @@ def execute(

self.last_executed_stmt = operation

# Update rowcount after execution
# Update rowcount after execution (before buffering)
# TODO: rowcount return code from SQL needs to be handled
self.rowcount = ddbc_bindings.DDBCSQLRowCount(self.hstmt)
initial_rowcount = ddbc_bindings.DDBCSQLRowCount(self.hstmt)

# Initialize description after execution
# After successful execution, initialize description if there are results
Expand All @@ -1019,12 +1067,16 @@ def execute(
# If describe fails, it's likely there are no results (e.g., for INSERT)
self.description = None

# Reset rownumber for new result set (only for SELECT statements)
# Buffer intermediate results automatically
self._buffer_intermediate_results()

# Set final rowcount based on result type (preserve original rowcount for non-SELECT)
if self.description: # If we have column descriptions, it's likely a SELECT
self.rowcount = -1
self._reset_rownumber()
else:
self.rowcount = ddbc_bindings.DDBCSQLRowCount(self.hstmt)
# For non-SELECT statements (INSERT/UPDATE/DELETE), preserve the original rowcount
self.rowcount = initial_rowcount
self._clear_rownumber()

# After successful execution, initialize description if there are results
Expand Down Expand Up @@ -2183,11 +2235,11 @@ def tables(self, table=None, catalog=None, schema=None, tableType=None):
("table_type", str, None, 128, 128, 0, False),
("remarks", str, None, 254, 254, 0, True)
]

# Use the helper method to prepare the result set
return self._prepare_metadata_result_set(fallback_description=fallback_description)

except Exception as e:
# Log the error and re-raise
log('error', f"Error executing tables query: {e}")
raise
raise
145 changes: 144 additions & 1 deletion tests/test_004_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10744,6 +10744,149 @@ def test_datetime_string_parameter_binding(cursor, db_connection):
drop_table_if_exists(cursor, table_name)
db_connection.commit()

def test_multi_statement_query(cursor, db_connection):
"""Test multi-statement query with temp tables"""
table_name = "#temp1"
try:
drop_table_if_exists(cursor, table_name)
# Single SQL with multiple statements - tests pyODBC-style buffering
multi_statement_sql = f"""
SELECT 1 as col1, 'test' as col2 INTO {table_name};
SELECT * FROM {table_name};
"""

cursor.execute(multi_statement_sql)
results = cursor.fetchall()

assert len(results) > 0, "Multi-statement query should return results"
assert results[0][1] == 'test', "Should return string 'test'"

except Exception as e:
pytest.fail(f"Multi-statement query test failed: {e}")
finally:
drop_table_if_exists(cursor, table_name)
db_connection.commit()

def test_multiple_result_sets_with_nextset(cursor, db_connection):
"""Test multiple result sets with multiple select statements on temp tables with nextset()"""
table_name1 = "#TempData1"
table_name2 = "#TempData2"
try:
drop_table_if_exists(cursor, table_name1)
drop_table_if_exists(cursor, table_name2)

# Create temp tables and execute multiple SELECT statements
multi_select_sql = f"""
CREATE TABLE {table_name1} (id INT, name NVARCHAR(50));
INSERT INTO {table_name1} VALUES (1, 'First'), (2, 'Second');

CREATE TABLE {table_name2} (id INT, value INT);
INSERT INTO {table_name2} VALUES (1, 100), (2, 200);

SELECT id, name FROM {table_name1} ORDER BY id;
SELECT id, value FROM {table_name2} ORDER BY id;
SELECT t1.name, t2.value FROM {table_name1} t1 JOIN {table_name2} t2 ON t1.id = t2.id ORDER BY t1.id;
"""

cursor.execute(multi_select_sql)

# First result set
results1 = cursor.fetchall()
assert len(results1) == 2, "First result set should have 2 rows"
assert results1[0][1] == 'First', "First row should contain 'First'"

# Move to second result set
assert cursor.nextset() is True, "Should have second result set"
results2 = cursor.fetchall()
assert len(results2) == 2, "Second result set should have 2 rows"
assert results2[0][1] == 100, "First row should contain value 100"

# Move to third result set
assert cursor.nextset() is True, "Should have third result set"
results3 = cursor.fetchall()
assert len(results3) == 2, "Third result set should have 2 rows"
assert results3[0][0] == 'First', "First row should contain 'First'"
assert results3[0][1] == 100, "First row should contain value 100"

# Check if there are more result sets (there shouldn't be any more SELECT results)
next_result = cursor.nextset()
if next_result is not None:
# If there are more, they should be empty (from CREATE/INSERT statements)
remaining_results = cursor.fetchall()
assert len(remaining_results) == 0, "Any remaining result sets should be empty"

except Exception as e:
pytest.fail(f"Multiple result sets with nextset test failed: {e}")
finally:
drop_table_if_exists(cursor, table_name1)
drop_table_if_exists(cursor, table_name2)
db_connection.commit()

def test_semicolons_in_string_literals(cursor, db_connection):
"""Test semicolons in string literals to ensure no false positives in buffering logic"""
table_name = "#StringTest"
try:
drop_table_if_exists(cursor, table_name)
# SQL with semicolons inside string literals - should not be treated as statement separators
sql_with_semicolons = f"""
CREATE TABLE {table_name} (id INT, data NVARCHAR(200));
INSERT INTO {table_name} VALUES
(1, 'Value with; semicolon inside'),
(2, 'Another; value; with; multiple; semicolons'),
(3, 'Normal value');
SELECT id, data, 'Status: OK; Processing: Complete' as status_message
FROM {table_name}
WHERE data LIKE '%semicolon%' OR data = 'Normal value'
ORDER BY id;
"""

cursor.execute(sql_with_semicolons)
results = cursor.fetchall()

assert len(results) == 3, "Should return 3 rows"
assert 'semicolon inside' in results[0][1], "Should preserve semicolon in string literal"
assert 'multiple; semicolons' in results[1][1], "Should preserve multiple semicolons in string literal"
assert 'Status: OK; Processing: Complete' in results[0][2], "Should preserve semicolons in status message"

except Exception as e:
pytest.fail(f"Semicolons in string literals test failed: {e}")
finally:
drop_table_if_exists(cursor, table_name)
db_connection.commit()

def test_multi_statement_batch_final_non_select(cursor, db_connection):
"""Test multi-statement batch where the final statement is not a SELECT"""
table_name = "#BatchTest"
try:
drop_table_if_exists(cursor, table_name)
# Multi-statement batch ending with non-SELECT statement
multi_statement_non_select = f"""
CREATE TABLE {table_name} (id INT, name NVARCHAR(50), created_at DATETIME);
INSERT INTO {table_name} VALUES (1, 'Test1', GETDATE()), (2, 'Test2', GETDATE());
SELECT COUNT(*) as record_count FROM {table_name};
UPDATE {table_name} SET name = name + '_updated' WHERE id IN (1, 2);
"""

cursor.execute(multi_statement_non_select)

# Should be able to fetch results from the SELECT statement
results = cursor.fetchall()
assert len(results) == 1, "Should return 1 row from COUNT query"
assert results[0][0] == 2, "Should count 2 records"

# Verify the UPDATE was executed by checking the updated records
cursor.execute(f"SELECT name FROM {table_name} ORDER BY id")
updated_results = cursor.fetchall()
assert len(updated_results) == 2, "Should have 2 updated records"
assert updated_results[0][0] == 'Test1_updated', "First record should be updated"
assert updated_results[1][0] == 'Test2_updated', "Second record should be updated"

except Exception as e:
pytest.fail(f"Multi-statement batch with final non-SELECT test failed: {e}")
finally:
drop_table_if_exists(cursor, table_name)
db_connection.commit()

def test_close(db_connection):
"""Test closing the cursor"""
try:
Expand All @@ -10753,4 +10896,4 @@ def test_close(db_connection):
except Exception as e:
pytest.fail(f"Cursor close test failed: {e}")
finally:
cursor = db_connection.cursor()
cursor = db_connection.cursor()