Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit b9ce7ed

Browse files
authored
Merge pull request #257 from datafold/test_unique_keys
tests for unique key constraints (if possible) instead of always actively validating (+ tests)
2 parents 7dc4ca0 + 01267fd commit b9ce7ed

File tree

11 files changed

+108
-11
lines changed

11 files changed

+108
-11
lines changed

data_diff/databases/base.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,8 @@ class Database(AbstractDatabase):
110110
TYPE_CLASSES: Dict[str, type] = {}
111111
default_schema: str = None
112112
SUPPORTS_ALPHANUMS = True
113+
SUPPORTS_PRIMARY_KEY = False
114+
SUPPORTS_UNIQUE_CONSTAINT = False
113115

114116
_interactive = False
115117

@@ -235,6 +237,21 @@ def query_table_schema(self, path: DbPath) -> Dict[str, tuple]:
235237
assert len(d) == len(rows)
236238
return d
237239

240+
def select_table_unique_columns(self, path: DbPath) -> str:
241+
schema, table = self._normalize_table_path(path)
242+
243+
return (
244+
"SELECT column_name "
245+
"FROM information_schema.key_column_usage "
246+
f"WHERE table_name = '{table}' AND table_schema = '{schema}'"
247+
)
248+
249+
def query_table_unique_columns(self, path: DbPath) -> List[str]:
250+
if not self.SUPPORTS_UNIQUE_CONSTAINT:
251+
raise NotImplementedError("This database doesn't support 'unique' constraints")
252+
res = self.query(self.select_table_unique_columns(path), List[str])
253+
return list(res)
254+
238255
def _process_table_schema(
239256
self, path: DbPath, raw_schema: Dict[str, tuple], filter_columns: Sequence[str], where: str = None
240257
):

data_diff/databases/bigquery.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Union
1+
from typing import List, Union
22
from .database_types import Timestamp, Datetime, Integer, Decimal, Float, Text, DbPath, FractionalType, TemporalType
33
from .base import Database, import_helper, parse_table_name, ConnectError, apply_query
44
from .base import TIMESTAMP_PRECISION_POS, ThreadLocalInterpreter
@@ -78,6 +78,9 @@ def select_table_schema(self, path: DbPath) -> str:
7878
f"WHERE table_name = '{table}' AND table_schema = '{schema}'"
7979
)
8080

81+
def query_table_unique_columns(self, path: DbPath) -> List[str]:
82+
return []
83+
8184
def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
8285
if coltype.rounds:
8386
timestamp = f"timestamp_micros(cast(round(unix_micros(cast({value} as timestamp))/1000000, {coltype.precision})*1000000 as int))"

data_diff/databases/database_types.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,16 @@ def query_table_schema(self, path: DbPath) -> Dict[str, tuple]:
254254
"""
255255
...
256256

257+
@abstractmethod
258+
def select_table_unique_columns(self, path: DbPath) -> str:
259+
"Provide SQL for selecting the names of unique columns in the table"
260+
...
261+
262+
@abstractmethod
263+
def query_table_unique_columns(self, path: DbPath) -> List[str]:
264+
"""Query the table for its unique columns for table in 'path', and return {column}"""
265+
...
266+
257267
@abstractmethod
258268
def _process_table_schema(
259269
self, path: DbPath, raw_schema: Dict[str, tuple], filter_columns: Sequence[str], where: str = None

data_diff/databases/mysql.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ class MySQL(ThreadedDatabase):
3939
}
4040
ROUNDS_ON_PREC_LOSS = True
4141
SUPPORTS_ALPHANUMS = False
42+
SUPPORTS_PRIMARY_KEY = True
43+
SUPPORTS_UNIQUE_CONSTAINT = True
4244

4345
def __init__(self, *, thread_count, **kw):
4446
self._args = kw

data_diff/databases/oracle.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ class Oracle(ThreadedDatabase):
3838
"VARCHAR2": Text,
3939
}
4040
ROUNDS_ON_PREC_LOSS = True
41+
SUPPORTS_PRIMARY_KEY = True
4142

4243
def __init__(self, *, host, database, thread_count, **kw):
4344
self.kwargs = dict(dsn=f"{host}/{database}" if database else host, **kw)

data_diff/databases/postgresql.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ class PostgreSQL(ThreadedDatabase):
4646
"uuid": Native_UUID,
4747
}
4848
ROUNDS_ON_PREC_LOSS = True
49+
SUPPORTS_PRIMARY_KEY = True
50+
SUPPORTS_UNIQUE_CONSTAINT = True
4951

5052
default_schema = "public"
5153

data_diff/databases/snowflake.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Union
1+
from typing import Union, List
22
import logging
33

44
from .database_types import Timestamp, TimestampTZ, Decimal, Float, Text, FractionalType, TemporalType, DbPath
@@ -95,3 +95,6 @@ def is_autocommit(self) -> bool:
9595

9696
def explain_as_text(self, query: str) -> str:
9797
return f"EXPLAIN USING TEXT {query}"
98+
99+
def query_table_unique_columns(self, path: DbPath) -> List[str]:
100+
return []

data_diff/joindiff_tables.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -195,18 +195,24 @@ def _diff_segments(
195195
if not is_xa:
196196
yield "+", tuple(b_row)
197197

198-
def _test_duplicate_keys(self, table1, table2):
198+
def _test_duplicate_keys(self, table1: TableSegment, table2: TableSegment):
199199
logger.debug("Testing for duplicate keys")
200200

201201
# Test duplicate keys
202202
for ts in [table1, table2]:
203+
unique = ts.database.query_table_unique_columns(ts.table_path) if ts.database.SUPPORTS_UNIQUE_CONSTAINT else []
204+
203205
t = ts.make_select()
204206
key_columns = ts.key_columns
205207

206-
q = t.select(total=Count(), total_distinct=Count(Concat(this[key_columns]), distinct=True))
207-
total, total_distinct = ts.database.query(q, tuple)
208-
if total != total_distinct:
209-
raise ValueError("Duplicate primary keys")
208+
unvalidated = list(set(key_columns) - set(unique))
209+
if unvalidated:
210+
# Validate that there are no duplicate keys
211+
self.stats["validated_unique_keys"] = self.stats.get("validated_unique_keys", []) + [unvalidated]
212+
q = t.select(total=Count(), total_distinct=Count(Concat(this[unvalidated]), distinct=True))
213+
total, total_distinct = ts.database.query(q, tuple)
214+
if total != total_distinct:
215+
raise ValueError("Duplicate primary keys")
210216

211217
def _test_null_keys(self, table1, table2):
212218
logger.debug("Testing for null keys")

data_diff/queries/ast_classes.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -306,13 +306,13 @@ def compile(self, c: Compiler) -> str:
306306
return ".".join(map(c.quote, path))
307307

308308
# Statement shorthands
309+
def create(self, source_table: ITable = None, *, if_not_exists=False, primary_keys=None):
309310

310-
def create(self, source_table: ITable = None, *, if_not_exists=False):
311311
if source_table is None and not self.schema:
312312
raise ValueError("Either schema or source table needed to create table")
313313
if isinstance(source_table, TablePath):
314314
source_table = source_table.select()
315-
return CreateTable(self, source_table, if_not_exists=if_not_exists)
315+
return CreateTable(self, source_table, if_not_exists=if_not_exists, primary_keys=primary_keys)
316316

317317
def drop(self, if_exists=False):
318318
return DropTable(self, if_exists=if_exists)
@@ -641,14 +641,20 @@ class CreateTable(Statement):
641641
path: TablePath
642642
source_table: Expr = None
643643
if_not_exists: bool = False
644+
primary_keys: List[str] = None
644645

645646
def compile(self, c: Compiler) -> str:
646647
ne = "IF NOT EXISTS " if self.if_not_exists else ""
647648
if self.source_table:
648649
return f"CREATE TABLE {ne}{c.compile(self.path)} AS {c.compile(self.source_table)}"
649650

650-
schema = ", ".join(f"{c.quote(k)} {c.database.type_repr(v)}" for k, v in self.path.schema.items())
651-
return f"CREATE TABLE {ne}{c.compile(self.path)}({schema})"
651+
schema = ", ".join(f"{c.database.quote(k)} {c.database.type_repr(v)}" for k, v in self.path.schema.items())
652+
pks = (
653+
", PRIMARY KEY (%s)" % ", ".join(self.primary_keys)
654+
if self.primary_keys and c.database.SUPPORTS_PRIMARY_KEY
655+
else ""
656+
)
657+
return f"CREATE TABLE {ne}{c.compile(self.path)}({schema}{pks})"
652658

653659

654660
@dataclass

tests/common.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ def tearDown(self):
149149

150150

151151
def _parameterized_class_per_conn(test_databases):
152+
test_databases = set(test_databases)
152153
names = [(cls.__name__, cls) for cls in CONN_STRINGS if cls in test_databases]
153154
return parameterized_class(("name", "db_cls"), names)
154155

0 commit comments

Comments (0)