Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 0343a6e

Browse files
committed
Now checks for declared unique key constraints (where the database exposes them) instead of always actively validating key uniqueness (+ tests)
1 parent d86950e commit 0343a6e

File tree

11 files changed: 103 additions and 10 deletions

data_diff/databases/base.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ class Database(AbstractDatabase):
108108
TYPE_CLASSES: Dict[str, type] = {}
109109
default_schema: str = None
110110
SUPPORTS_ALPHANUMS = True
111+
SUPPORTS_PRIMARY_KEY = False
111112

112113
_interactive = False
113114

@@ -232,6 +233,20 @@ def query_table_schema(self, path: DbPath) -> Dict[str, tuple]:
232233
assert len(d) == len(rows)
233234
return d
234235

236+
237+
def select_table_unique_columns(self, path: DbPath) -> str:
238+
schema, table = self._normalize_table_path(path)
239+
240+
return (
241+
"SELECT column_name "
242+
"FROM information_schema.key_column_usage "
243+
f"WHERE table_name = '{table}' AND table_schema = '{schema}'"
244+
)
245+
246+
def query_table_unique_columns(self, path: DbPath) -> List[str]:
247+
res = self.query(self.select_table_unique_columns(path), List[str])
248+
return list(res)
249+
235250
def _process_table_schema(
236251
self, path: DbPath, raw_schema: Dict[str, tuple], filter_columns: Sequence[str], where: str = None
237252
):

data_diff/databases/bigquery.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Union
1+
from typing import List, Union
22
from .database_types import Timestamp, Datetime, Integer, Decimal, Float, Text, DbPath, FractionalType, TemporalType
33
from .base import Database, import_helper, parse_table_name, ConnectError, apply_query
44
from .base import TIMESTAMP_PRECISION_POS, ThreadLocalInterpreter
@@ -78,6 +78,9 @@ def select_table_schema(self, path: DbPath) -> str:
7878
f"WHERE table_name = '{table}' AND table_schema = '{schema}'"
7979
)
8080

81+
def query_table_unique_columns(self, path: DbPath) -> List[str]:
82+
return []
83+
8184
def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
8285
if coltype.rounds:
8386
timestamp = f"timestamp_micros(cast(round(unix_micros(cast({value} as timestamp))/1000000, {coltype.precision})*1000000 as int))"

data_diff/databases/database_types.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,17 @@ def query_table_schema(self, path: DbPath) -> Dict[str, tuple]:
253253
"""
254254
...
255255

256+
@abstractmethod
257+
def select_table_unique_columns(self, path: DbPath) -> str:
258+
"Provide SQL for selecting the names of unique columns in the table"
259+
...
260+
261+
@abstractmethod
262+
def query_table_unique_columns(self, path: DbPath) -> List[str]:
263+
"""Query the table for its unique columns for table in 'path', and return {column}
264+
"""
265+
...
266+
256267
@abstractmethod
257268
def _process_table_schema(
258269
self, path: DbPath, raw_schema: Dict[str, tuple], filter_columns: Sequence[str], where: str = None

data_diff/databases/mysql.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ class MySQL(ThreadedDatabase):
3939
}
4040
ROUNDS_ON_PREC_LOSS = True
4141
SUPPORTS_ALPHANUMS = False
42+
SUPPORTS_PRIMARY_KEY = True
4243

4344
def __init__(self, *, thread_count, **kw):
4445
self._args = kw

data_diff/databases/oracle.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ class Oracle(ThreadedDatabase):
3939
"VARCHAR2": Text,
4040
}
4141
ROUNDS_ON_PREC_LOSS = True
42+
SUPPORTS_PRIMARY_KEY = True
4243

4344
def __init__(self, *, host, database, thread_count, **kw):
4445
self.kwargs = dict(dsn=f"{host}/{database}" if database else host, **kw)

data_diff/databases/postgresql.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ class PostgreSQL(ThreadedDatabase):
4646
"uuid": Native_UUID,
4747
}
4848
ROUNDS_ON_PREC_LOSS = True
49+
SUPPORTS_PRIMARY_KEY = True
4950

5051
default_schema = "public"
5152

data_diff/databases/snowflake.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Union
1+
from typing import Union, List
22
import logging
33

44
from .database_types import Timestamp, TimestampTZ, Decimal, Float, Text, FractionalType, TemporalType, DbPath
@@ -95,3 +95,6 @@ def is_autocommit(self) -> bool:
9595

9696
def explain_as_text(self, query: str) -> str:
9797
return f"EXPLAIN USING TEXT {query}"
98+
99+
def query_table_unique_columns(self, path: DbPath) -> List[str]:
100+
return []

data_diff/joindiff_tables.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -190,18 +190,24 @@ def _diff_segments(
190190
if not is_xa:
191191
yield "+", tuple(b_row)
192192

193-
def _test_duplicate_keys(self, table1, table2):
193+
def _test_duplicate_keys(self, table1: TableSegment, table2: TableSegment):
194194
logger.debug("Testing for duplicate keys")
195195

196196
# Test duplicate keys
197197
for ts in [table1, table2]:
198+
unique = ts.database.query_table_unique_columns(ts.table_path)
199+
198200
t = ts.make_select()
199201
key_columns = ts.key_columns
200202

201-
q = t.select(total=Count(), total_distinct=Count(Concat(this[key_columns]), distinct=True))
202-
total, total_distinct = ts.database.query(q, tuple)
203-
if total != total_distinct:
204-
raise ValueError("Duplicate primary keys")
203+
unvalidated = list(set(key_columns) - set(unique))
204+
if unvalidated:
205+
# Validate that there are no duplicate keys
206+
self.stats['validated_unique_keys'] = self.stats.get('validated_unique_keys', []) + [unvalidated]
207+
q = t.select(total=Count(), total_distinct=Count(Concat(this[unvalidated]), distinct=True))
208+
total, total_distinct = ts.database.query(q, tuple)
209+
if total != total_distinct:
210+
raise ValueError("Duplicate primary keys")
205211

206212
def _test_null_keys(self, table1, table2):
207213
logger.debug("Testing for null keys")

data_diff/queries/ast_classes.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -299,12 +299,12 @@ class TablePath(ExprNode, ITable):
299299
path: DbPath
300300
schema: Optional[Schema] = field(default=None, repr=False)
301301

302-
def create(self, source_table: ITable = None, *, if_not_exists=False):
302+
def create(self, source_table: ITable = None, *, if_not_exists=False, primary_keys=None):
303303
if source_table is None and not self.schema:
304304
raise ValueError("Either schema or source table needed to create table")
305305
if isinstance(source_table, TablePath):
306306
source_table = source_table.select()
307-
return CreateTable(self, source_table, if_not_exists=if_not_exists)
307+
return CreateTable(self, source_table, if_not_exists=if_not_exists, primary_keys=primary_keys)
308308

309309
def drop(self, if_exists=False):
310310
return DropTable(self, if_exists=if_exists)
@@ -647,14 +647,16 @@ class CreateTable(Statement):
647647
path: TablePath
648648
source_table: Expr = None
649649
if_not_exists: bool = False
650+
primary_keys: List[str] = None
650651

651652
def compile(self, c: Compiler) -> str:
652653
ne = "IF NOT EXISTS " if self.if_not_exists else ""
653654
if self.source_table:
654655
return f"CREATE TABLE {ne}{c.compile(self.path)} AS {c.compile(self.source_table)}"
655656

656657
schema = ", ".join(f"{c.database.quote(k)} {c.database.type_repr(v)}" for k, v in self.path.schema.items())
657-
return f"CREATE TABLE {ne}{c.compile(self.path)}({schema})"
658+
pks = ", PRIMARY KEY (%s)" % ', '.join(self.primary_keys) if self.primary_keys and c.database.SUPPORTS_PRIMARY_KEY else ""
659+
return f"CREATE TABLE {ne}{c.compile(self.path)}({schema}{pks})"
658660

659661

660662
@dataclass

tests/common.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ def tearDown(self):
150150

151151

152152
def _parameterized_class_per_conn(test_databases):
153+
test_databases = set(test_databases)
153154
names = [(cls.__name__, cls) for cls in CONN_STRINGS if cls in test_databases]
154155
return parameterized_class(("name", "db_cls"), names)
155156

0 commit comments

Comments
 (0)