Skip to content

Commit 9324f28

Browse files
author
Reggie Burnett
committed
implemented table.update and collection.modify
updated some of the methods of result to meet spec
1 parent a515895 commit 9324f28

File tree

8 files changed

+140
-25
lines changed

8 files changed

+140
-25
lines changed

lib/mysqlx/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from .result import (ColumnMetaData, Row, Result, BufferingResult, RowResult,
3232
SqlResult)
3333
from .statement import (Statement, FilterableStatement, SqlStatement,
34-
AddStatement, RemoveStatement, TableDeleteStatement,
34+
AddStatement, RemoveStatement, DeleteStatement,
3535
CreateCollectionIndexStatement,
3636
DropCollectionIndexStatement)
3737
from .dbdoc import DbDoc

lib/mysqlx/connection.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,10 @@ def delete(self, statement):
104104
self.protocol.send_delete(statement)
105105
return Result(self)
106106

107+
def update(self, statement):
108+
self.protocol.send_update(statement)
109+
return Result(self)
110+
107111
def execute_nonquery(self, namespace, cmd, raise_on_fail=True, *args):
108112
self.protocol.send_execute_statement(namespace, cmd, args)
109113
return Result(self)

lib/mysqlx/crud.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,9 @@
2121
# along with this program; if not, write to the Free Software
2222
# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
2323

24-
from .statement import (AddStatement, RemoveStatement, TableDeleteStatement,
25-
FindStatement, SelectStatement,
26-
CreateCollectionIndexStatement, InsertStatement,
27-
DropCollectionIndexStatement)
24+
from .statement import (FindStatement, AddStatement, RemoveStatement, ModifyStatement,
25+
SelectStatement, InsertStatement, DeleteStatement, UpdateStatement,
26+
CreateCollectionIndexStatement, DropCollectionIndexStatement)
2827

2928

3029
_COUNT_TABLES_QUERY = ("SELECT COUNT(*) FROM information_schema.tables "
@@ -145,6 +144,9 @@ def remove(self, condition=None):
145144
rs.where(condition)
146145
return rs
147146

147+
def modify(self, condition=None):
148+
return ModifyStatement(self, condition)
149+
148150
def count(self):
149151
sql = _COUNT_QUERY.format(self._schema.name, self._name)
150152
return self._connection.execute_sql_scalar(sql)
@@ -182,10 +184,10 @@ def insert(self, *fields):
182184
return InsertStatement(self, *fields)
183185

184186
def update(self):
185-
pass
187+
return UpdateStatement(self)
186188

187189
def delete(self, condition=None):
188-
return TableDeleteStatement(self, condition)
190+
return DeleteStatement(self, condition)
189191

190192
def count(self):
191193
sql = _COUNT_QUERY.format(self._schema.name, self._name)

lib/mysqlx/expr.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -653,6 +653,9 @@ def expr(self):
653653
def parse_table_insert_field(self):
654654
return Column(name=self.consume_token(TokenType.IDENT))
655655

656+
def parse_table_update_field(self):
657+
return self.column_identifier().identifier
658+
656659
def parse_table_select_projection(self):
657660
project_expr = []
658661
first = True

lib/mysqlx/protocol.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -140,13 +140,18 @@ def send_find(self, stmt):
140140
self._apply_filter(find, stmt)
141141
self._writer.write_message(MySQLx.ClientMessages.CRUD_FIND, find)
142142

143-
def send_insert(self, schema, target, is_docs, rows, cols):
144-
stmt = MySQLxCrud.Insert(
145-
datamodel=MySQLxCrud.DOCUMENT if is_docs else MySQLxCrud.TABLE,
146-
collection=MySQLxCrud.Collection(name=target, schema=schema))
147-
for row in rows:
148-
typed_row = MySQLxCrud.Insert.TypedRow()
149-
stmt.rows.extend(row)
143+
def send_update(self, statement):
144+
update = MySQLxCrud.Update(
145+
data_model=MySQLxCrud.DOCUMENT if statement._doc_based else MySQLxCrud.TABLE,
146+
collection=MySQLxCrud.Collection(name=statement.target.name, schema=statement.schema.name))
147+
self._apply_filter(update, statement)
148+
for update_op in statement._update_ops:
149+
opexpr = UpdateOperation(operation=update_op.update_type, source=update_op.source)
150+
if update_op.value != None:
151+
opexpr.value.CopyFrom(self.arg_object_to_expr(update_op.value, not statement._doc_based))
152+
update.operation.extend([opexpr])
153+
self._writer.write_message(MySQLx.ClientMessages.CRUD_UPDATE, update)
154+
150155

151156
def send_delete(self, stmt):
152157
delete = MySQLxCrud.Delete(

lib/mysqlx/result.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -448,18 +448,25 @@ def __init__(self, connection):
448448
connection._active_result = None
449449

450450
@property
451-
def Warnings(self):
451+
def get_warnings(self):
452452
return self._warnings
453453

454+
def get_warnings_count(self):
455+
return len(self._warnings)
456+
454457
class Result(BaseResult):
455458
def __init__(self, connection):
456459
super(Result, self).__init__(connection)
457460
self._protocol.close_result(self)
458461

459462
@property
460-
def rows_affected(self):
463+
def get_affected_items_count(self):
461464
return self._rows_affected
462465

466+
@property
467+
def get_autoincrement_value(self):
468+
pass
469+
463470
class BufferingResult(BaseResult):
464471
def __init__(self, connection):
465472
super(BufferingResult, self).__init__(connection)

lib/mysqlx/statement.py

Lines changed: 62 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
2323

2424

25-
25+
from .protobuf import mysqlx_crud_pb2 as MySQLxCrud
2626
from .result import SqlResult
2727
from .expr import ExprParser
2828
from .dbdoc import DbDoc
@@ -121,6 +121,52 @@ def execute(self):
121121
doc.ensure_id()
122122
return self._connection.send_insert(self)
123123

124+
class UpdateSpec(object):
125+
def __init__(self, type, source, value=None):
126+
if type == MySQLxCrud.UpdateOperation.SET:
127+
self._table_set(source, value)
128+
else:
129+
self.update_type = type
130+
self.source = source
131+
if len(source) > 0 and source[0] == '$':
132+
self.source = source[1:]
133+
self.source = ExprParser(self.source, False).document_field().identifier
134+
self.value = value
135+
136+
def _table_set(self, source, value):
137+
self.update_type = MySQLxCrud.UpdateOperation.SET
138+
self.source = ExprParser(source, True).parse_table_update_field()
139+
self.value = value
140+
141+
class ModifyStatement(FilterableStatement):
142+
def __init__(self, collection, condition=None):
143+
super(ModifyStatement, self).__init__(target=collection, condition=condition)
144+
self._update_ops = []
145+
146+
def set(self, doc_path, value):
147+
self._update_ops.append(UpdateSpec(MySQLxCrud.UpdateOperation.ITEM_SET, doc_path, value))
148+
return self
149+
150+
def change(self, doc_path, value):
151+
self._update_ops.append(UpdateSpec(MySQLxCrud.UpdateOperation.ITEM_REPLACE, doc_path, value))
152+
return self
153+
154+
def unset(self, doc_path):
155+
self._update_ops.append(UpdateSpec(MySQLxCrud.UpdateOperation.ITEM_REMOVE, doc_path))
156+
return self
157+
158+
def array_insert(self, field, value):
159+
self._update_ops.append(UpdateSpec(MySQLxCrud.UpdateOperation.ARRAY_INSERT, field, value))
160+
return self
161+
162+
def array_append(self, doc_path, value):
163+
self._update_ops.append(UpdateSpec(MySQLxCrud.UpdateOperation.UpdateType.ARRAY_APPEND, doc_path, value))
164+
return self
165+
166+
def execute(self):
167+
return self._connection.update(self)
168+
169+
124170
class FindStatement(FilterableStatement):
125171
def __init__(self, collection, condition=None):
126172
super(FindStatement, self).__init__(collection, True, condition)
@@ -170,6 +216,19 @@ def values(self, *values):
170216
def execute(self):
171217
return self._connection.send_insert(self)
172218

219+
class UpdateStatement(FilterableStatement):
220+
def __init__(self, table, *fields):
221+
super(UpdateStatement, self).__init__(target=table, doc_based=False)
222+
self._update_ops = []
223+
224+
def set(self, field, value):
225+
self._update_ops.append(UpdateSpec(MySQLxCrud.UpdateOperation.SET, field, value))
226+
return self
227+
228+
def execute(self):
229+
return self._connection.update(self)
230+
231+
173232
class RemoveStatement(FilterableStatement):
174233
def __init__(self, collection):
175234
super(RemoveStatement, self).__init__(target=collection)
@@ -178,9 +237,9 @@ def execute(self):
178237
return self._connection.delete(self)
179238

180239

181-
class TableDeleteStatement(FilterableStatement):
240+
class DeleteStatement(FilterableStatement):
182241
def __init__(self, table, condition=None):
183-
super(TableDeleteStatement, self).__init__(target=table,
242+
super(DeleteStatement, self).__init__(target=table,
184243
condition=condition,
185244
doc_based=False)
186245

tests/test_mysqlx_crud.py

Lines changed: 41 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -181,17 +181,17 @@ def test_add(self):
181181
collection_name = "collection_test"
182182
collection = self.schema.create_collection(collection_name)
183183
result = collection.add({"name":"Fred", "age":21}).execute()
184-
self.assertEqual(result.rows_affected, 1)
184+
self.assertEqual(result.get_affected_items_count, 1)
185185
self.assertEqual(1, collection.count())
186186

187187
# now add multiple dictionaries at once
188188
result = collection.add({"name": "Wilma", "age": 33}, {"name": "Barney", "age": 42}).execute()
189-
self.assertEqual(result.rows_affected, 2)
189+
self.assertEqual(result.get_affected_items_count, 2)
190190
self.assertEqual(3, collection.count())
191191

192192
# now let's try adding strings
193193
result = collection.add('{"name": "Bambam", "age": 8}', '{"name": "Pebbles", "age": 8}').execute()
194-
self.assertEqual(result.rows_affected, 2)
194+
self.assertEqual(result.get_affected_items_count, 2)
195195
self.assertEqual(5, collection.count())
196196

197197
def test_remove(self):
@@ -200,7 +200,7 @@ def test_remove(self):
200200
collection.add({"name":"Fred", "age":21}).execute()
201201
self.assertEqual(1, collection.count())
202202
result = collection.remove("age == 21").execute()
203-
self.assertEqual(1, result.rows_affected)
203+
self.assertEqual(1, result.get_affected_items_count)
204204
self.assertEqual(0, collection.count())
205205

206206
def test_find(self):
@@ -229,6 +229,31 @@ def test_find(self):
229229
self.assertEqual(42, docs[1]["age"])
230230
self.assertEqual(1, len(docs[1].keys()))
231231

232+
def test_modify(self):
233+
collection_name = "collection_test"
234+
collection = self.schema.create_collection(collection_name)
235+
result = collection.add(
236+
{"name":"Fred", "age":21},
237+
{"name": "Barney", "age": 28},
238+
{"name": "Wilma", "age": 42},
239+
{"name": "Betty", "age": 67},
240+
241+
).execute()
242+
243+
result = collection.modify("age < 67").set("young", True).execute()
244+
self.assertEqual(3, result.get_affected_items_count)
245+
doc = collection.find("name = 'Fred'").execute().fetch_all()[0]
246+
self.assertEqual(True, doc.young)
247+
248+
result = collection.modify("age == 28").change("young", False).execute()
249+
self.assertEqual(1, result.get_affected_items_count)
250+
docs = collection.find("young = True").execute().fetch_all()
251+
self.assertEqual(2, len(docs))
252+
253+
result = collection.modify("young == True").unset("young").execute()
254+
self.assertEqual(2, result.get_affected_items_count)
255+
docs = collection.find("young = True").execute().fetch_all()
256+
self.assertEqual(0, len(docs))
232257

233258
def test_results(self):
234259
collection_name = "collection_test"
@@ -376,8 +401,18 @@ def test_insert(self):
376401
self.assertEqual(4, len(rows))
377402

378403
def test_update(self):
379-
# TODO: To implement
380-
pass
404+
table = self.schema.get_table("test")
405+
406+
self.node_session.sql("CREATE TABLE {0}.test(age INT, name VARCHAR(50), gender CHAR(1))".format(self.schema_name)).execute()
407+
408+
result = table.insert("age", "name") \
409+
.values(21, 'Fred') \
410+
.values(28, 'Barney') \
411+
.values(42, 'Wilma') \
412+
.values(67, 'Betty').execute()
413+
414+
result = table.update().set("age", 25).where("age == 21").execute()
415+
self.assertEqual(1, result.get_affected_items_count)
381416

382417
def test_delete(self):
383418
table_name = "table_test"

0 commit comments

Comments
 (0)