Skip to content

Commit e9c748a

Browse files
committed
- ensure rowcount is returned for an UPDATE with no implicit returning
- modernize test for that - use py3k compatible next() in test_returning/test_versioning
1 parent df1113a commit e9c748a

File tree

5 files changed

+35
-36
lines changed

5 files changed

+35
-36
lines changed

lib/sqlalchemy/engine/base.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -898,11 +898,10 @@ def _execute_context(self, dialect, constructor,
898898
elif not context._is_explicit_returning:
899899
result.close(_autoclose_connection=False)
900900
result._metadata = None
901-
elif context.isupdate:
902-
if context._is_implicit_returning:
903-
context._fetch_implicit_update_returning(result)
904-
result.close(_autoclose_connection=False)
905-
result._metadata = None
901+
elif context.isupdate and context._is_implicit_returning:
902+
context._fetch_implicit_update_returning(result)
903+
result.close(_autoclose_connection=False)
904+
result._metadata = None
906905

907906
elif result._metadata is None:
908907
# no results, get rowcount

lib/sqlalchemy/testing/mock.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44
from ..util import py33
55

66
if py33:
7-
from unittest.mock import MagicMock, Mock, call
7+
from unittest.mock import MagicMock, Mock, call, patch
88
else:
99
try:
10-
from mock import MagicMock, Mock, call
10+
from mock import MagicMock, Mock, call, patch
1111
except ImportError:
1212
raise ImportError(
1313
"SQLAlchemy's test suite requires the "

test/engine/test_execute.py

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@
1717
import logging.handlers
1818
from sqlalchemy.dialects.oracle.zxjdbc import ReturningParam
1919
from sqlalchemy.engine import result as _result, default
20-
from sqlalchemy.engine.base import Connection, Engine
20+
from sqlalchemy.engine.base import Engine
2121
from sqlalchemy.testing import fixtures
22-
from sqlalchemy.testing.mock import Mock, call
22+
from sqlalchemy.testing.mock import Mock, call, patch
2323

2424

2525
users, metadata, users_autoinc = None, None, None
@@ -29,11 +29,11 @@ def setup_class(cls):
2929
global users, users_autoinc, metadata
3030
metadata = MetaData(testing.db)
3131
users = Table('users', metadata,
32-
Column('user_id', INT, primary_key = True, autoincrement=False),
32+
Column('user_id', INT, primary_key=True, autoincrement=False),
3333
Column('user_name', VARCHAR(20)),
3434
)
3535
users_autoinc = Table('users_autoinc', metadata,
36-
Column('user_id', INT, primary_key = True,
36+
Column('user_id', INT, primary_key=True,
3737
test_needs_autoincrement=True),
3838
Column('user_name', VARCHAR(20)),
3939
)
@@ -892,42 +892,42 @@ def __getitem__(self, i):
892892
def test_no_rowcount_on_selects_inserts(self):
893893
"""assert that rowcount is only called on deletes and updates.
894894
895-
This because cursor.rowcount can be expensive on some dialects
896-
such as Firebird.
895+
This because cursor.rowcount may can be expensive on some dialects
896+
such as Firebird, however many dialects require it be called
897+
before the cursor is closed.
897898
898899
"""
899900

900901
metadata = self.metadata
901902

902903
engine = engines.testing_engine()
903-
metadata.bind = engine
904904

905905
t = Table('t1', metadata,
906906
Column('data', String(10))
907907
)
908-
metadata.create_all()
908+
metadata.create_all(engine)
909909

910-
class BreakRowcountMixin(object):
911-
@property
912-
def rowcount(self):
913-
assert False
910+
with patch.object(engine.dialect.execution_ctx_cls, "rowcount") as mock_rowcount:
911+
mock_rowcount.__get__ = Mock()
912+
engine.execute(t.insert(),
913+
{'data': 'd1'},
914+
{'data': 'd2'},
915+
{'data': 'd3'})
914916

915-
execution_ctx_cls = engine.dialect.execution_ctx_cls
916-
engine.dialect.execution_ctx_cls = type("FakeCtx",
917-
(BreakRowcountMixin,
918-
execution_ctx_cls),
919-
{})
917+
eq_(len(mock_rowcount.__get__.mock_calls), 0)
920918

921-
try:
922-
r = t.insert().execute({'data': 'd1'}, {'data': 'd2'},
923-
{'data': 'd3'})
924-
eq_(t.select().execute().fetchall(), [('d1', ), ('d2', ),
925-
('d3', )])
926-
assert_raises(AssertionError, t.update().execute, {'data'
927-
: 'd4'})
928-
assert_raises(AssertionError, t.delete().execute)
929-
finally:
930-
engine.dialect.execution_ctx_cls = execution_ctx_cls
919+
eq_(
920+
engine.execute(t.select()).fetchall(),
921+
[('d1', ), ('d2', ), ('d3', )]
922+
)
923+
eq_(len(mock_rowcount.__get__.mock_calls), 0)
924+
925+
engine.execute(t.update(), {'data': 'd4'})
926+
927+
eq_(len(mock_rowcount.__get__.mock_calls), 1)
928+
929+
engine.execute(t.delete())
930+
eq_(len(mock_rowcount.__get__.mock_calls), 2)
931931

932932

933933
@testing.requires.python26

test/orm/test_versioning.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -668,7 +668,7 @@ def compile(element, compiler, **kw):
668668
if hasattr(stmt, "_counter"):
669669
return stmt._counter
670670
else:
671-
stmt._counter = str(counter.next())
671+
stmt._counter = str(next(counter))
672672
return stmt._counter
673673

674674
Table('version_table', metadata,

test/sql/test_returning.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ class IncDefault(ColumnElement):
201201

202202
@compiles(IncDefault)
203203
def compile(element, compiler, **kw):
204-
return str(counter.next())
204+
return str(next(counter))
205205

206206
Table("t1", metadata,
207207
Column("id", Integer, primary_key=True, test_needs_autoincrement=True),

0 commit comments

Comments
 (0)