diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py
index b68739d8..06b6ea08 100644
--- a/data_diff/databases/base.py
+++ b/data_diff/databases/base.py
@@ -110,6 +110,8 @@ class Database(AbstractDatabase):
     TYPE_CLASSES: Dict[str, type] = {}
     default_schema: str = None
     SUPPORTS_ALPHANUMS = True
+    SUPPORTS_PRIMARY_KEY = False
+    SUPPORTS_UNIQUE_CONSTAINT = False
 
     _interactive = False
 
@@ -235,6 +237,21 @@ def query_table_schema(self, path: DbPath) -> Dict[str, tuple]:
         assert len(d) == len(rows)
         return d
 
+    def select_table_unique_columns(self, path: DbPath) -> str:
+        schema, table = self._normalize_table_path(path)
+
+        return (
+            "SELECT column_name "
+            "FROM information_schema.key_column_usage "
+            f"WHERE table_name = '{table}' AND table_schema = '{schema}'"
+        )
+
+    def query_table_unique_columns(self, path: DbPath) -> List[str]:
+        if not self.SUPPORTS_UNIQUE_CONSTAINT:
+            raise NotImplementedError("This database doesn't support 'unique' constraints")
+        res = self.query(self.select_table_unique_columns(path), List[str])
+        return list(res)
+
     def _process_table_schema(
         self, path: DbPath, raw_schema: Dict[str, tuple], filter_columns: Sequence[str], where: str = None
     ):
diff --git a/data_diff/databases/bigquery.py b/data_diff/databases/bigquery.py
index 9c500dd5..3d3720b6 100644
--- a/data_diff/databases/bigquery.py
+++ b/data_diff/databases/bigquery.py
@@ -1,4 +1,4 @@
-from typing import Union
+from typing import List, Union
 from .database_types import Timestamp, Datetime, Integer, Decimal, Float, Text, DbPath, FractionalType, TemporalType
 from .base import Database, import_helper, parse_table_name, ConnectError, apply_query
 from .base import TIMESTAMP_PRECISION_POS, ThreadLocalInterpreter
@@ -78,6 +78,9 @@ def select_table_schema(self, path: DbPath) -> str:
         f"WHERE table_name = '{table}' AND table_schema = '{schema}'"
     )
 
+    def query_table_unique_columns(self, path: DbPath) -> List[str]:
+        return []
+
     def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
         if coltype.rounds:
             timestamp = f"timestamp_micros(cast(round(unix_micros(cast({value} as timestamp))/1000000, {coltype.precision})*1000000 as int))"
diff --git a/data_diff/databases/database_types.py b/data_diff/databases/database_types.py
index 8adc9fbb..6c9af301 100644
--- a/data_diff/databases/database_types.py
+++ b/data_diff/databases/database_types.py
@@ -254,6 +254,16 @@ def query_table_schema(self, path: DbPath) -> Dict[str, tuple]:
         """
         ...
 
+    @abstractmethod
+    def select_table_unique_columns(self, path: DbPath) -> str:
+        "Provide SQL for selecting the names of unique columns in the table"
+        ...
+
+    @abstractmethod
+    def query_table_unique_columns(self, path: DbPath) -> List[str]:
+        """Query the table for its unique columns for table in 'path', and return {column}"""
+        ...
+
     @abstractmethod
     def _process_table_schema(
         self, path: DbPath, raw_schema: Dict[str, tuple], filter_columns: Sequence[str], where: str = None
diff --git a/data_diff/databases/mysql.py b/data_diff/databases/mysql.py
index 3f9eb98c..e8e47b1b 100644
--- a/data_diff/databases/mysql.py
+++ b/data_diff/databases/mysql.py
@@ -39,6 +39,8 @@ class MySQL(ThreadedDatabase):
     }
     ROUNDS_ON_PREC_LOSS = True
     SUPPORTS_ALPHANUMS = False
+    SUPPORTS_PRIMARY_KEY = True
+    SUPPORTS_UNIQUE_CONSTAINT = True
 
     def __init__(self, *, thread_count, **kw):
         self._args = kw
diff --git a/data_diff/databases/oracle.py b/data_diff/databases/oracle.py
index 6b4ebe2c..73b53492 100644
--- a/data_diff/databases/oracle.py
+++ b/data_diff/databases/oracle.py
@@ -38,6 +38,7 @@ class Oracle(ThreadedDatabase):
         "VARCHAR2": Text,
     }
     ROUNDS_ON_PREC_LOSS = True
+    SUPPORTS_PRIMARY_KEY = True
 
     def __init__(self, *, host, database, thread_count, **kw):
         self.kwargs = dict(dsn=f"{host}/{database}" if database else host, **kw)
diff --git a/data_diff/databases/postgresql.py b/data_diff/databases/postgresql.py
index 72d26d07..3181dab1 100644
--- a/data_diff/databases/postgresql.py
+++ b/data_diff/databases/postgresql.py
@@ -46,6 +46,8 @@ class PostgreSQL(ThreadedDatabase):
         "uuid": Native_UUID,
     }
     ROUNDS_ON_PREC_LOSS = True
+    SUPPORTS_PRIMARY_KEY = True
+    SUPPORTS_UNIQUE_CONSTAINT = True
 
     default_schema = "public"
 
diff --git a/data_diff/databases/snowflake.py b/data_diff/databases/snowflake.py
index 635ba8f4..afd52ba8 100644
--- a/data_diff/databases/snowflake.py
+++ b/data_diff/databases/snowflake.py
@@ -1,4 +1,4 @@
-from typing import Union
+from typing import Union, List
 import logging
 
 from .database_types import Timestamp, TimestampTZ, Decimal, Float, Text, FractionalType, TemporalType, DbPath
@@ -95,3 +95,6 @@ def is_autocommit(self) -> bool:
 
     def explain_as_text(self, query: str) -> str:
         return f"EXPLAIN USING TEXT {query}"
+
+    def query_table_unique_columns(self, path: DbPath) -> List[str]:
+        return []
diff --git a/data_diff/joindiff_tables.py b/data_diff/joindiff_tables.py
index d2dbca61..95e1a4d5 100644
--- a/data_diff/joindiff_tables.py
+++ b/data_diff/joindiff_tables.py
@@ -195,18 +195,24 @@ def _diff_segments(
             if not is_xa:
                 yield "+", tuple(b_row)
 
-    def _test_duplicate_keys(self, table1, table2):
+    def _test_duplicate_keys(self, table1: TableSegment, table2: TableSegment):
         logger.debug("Testing for duplicate keys")
 
         # Test duplicate keys
         for ts in [table1, table2]:
+            unique = ts.database.query_table_unique_columns(ts.table_path) if ts.database.SUPPORTS_UNIQUE_CONSTAINT else []
+
             t = ts.make_select()
             key_columns = ts.key_columns
 
-            q = t.select(total=Count(), total_distinct=Count(Concat(this[key_columns]), distinct=True))
-            total, total_distinct = ts.database.query(q, tuple)
-            if total != total_distinct:
-                raise ValueError("Duplicate primary keys")
+            unvalidated = list(set(key_columns) - set(unique))
+            if unvalidated:
+                # Validate that there are no duplicate keys
+                self.stats["validated_unique_keys"] = self.stats.get("validated_unique_keys", []) + [unvalidated]
+                q = t.select(total=Count(), total_distinct=Count(Concat(this[unvalidated]), distinct=True))
+                total, total_distinct = ts.database.query(q, tuple)
+                if total != total_distinct:
+                    raise ValueError("Duplicate primary keys")
 
     def _test_null_keys(self, table1, table2):
         logger.debug("Testing for null keys")
diff --git a/data_diff/queries/ast_classes.py b/data_diff/queries/ast_classes.py
index f363df14..66d62783 100644
--- a/data_diff/queries/ast_classes.py
+++ b/data_diff/queries/ast_classes.py
@@ -306,13 +306,13 @@ def compile(self, c: Compiler) -> str:
         return ".".join(map(c.quote, path))
 
     # Statement shorthands
+    def create(self, source_table: ITable = None, *, if_not_exists=False, primary_keys=None):
-    def create(self, source_table: ITable = None, *, if_not_exists=False):
         if source_table is None and not self.schema:
             raise ValueError("Either schema or source table needed to create table")
 
         if isinstance(source_table, TablePath):
            source_table = source_table.select()
 
-        return CreateTable(self, source_table, if_not_exists=if_not_exists)
+        return CreateTable(self, source_table, if_not_exists=if_not_exists, primary_keys=primary_keys)
 
     def drop(self, if_exists=False):
         return DropTable(self, if_exists=if_exists)
@@ -641,14 +641,20 @@ class CreateTable(Statement):
     path: TablePath
     source_table: Expr = None
     if_not_exists: bool = False
+    primary_keys: List[str] = None
 
     def compile(self, c: Compiler) -> str:
         ne = "IF NOT EXISTS " if self.if_not_exists else ""
         if self.source_table:
             return f"CREATE TABLE {ne}{c.compile(self.path)} AS {c.compile(self.source_table)}"
 
-        schema = ", ".join(f"{c.quote(k)} {c.database.type_repr(v)}" for k, v in self.path.schema.items())
-        return f"CREATE TABLE {ne}{c.compile(self.path)}({schema})"
+        schema = ", ".join(f"{c.database.quote(k)} {c.database.type_repr(v)}" for k, v in self.path.schema.items())
+        pks = (
+            ", PRIMARY KEY (%s)" % ", ".join(self.primary_keys)
+            if self.primary_keys and c.database.SUPPORTS_PRIMARY_KEY
+            else ""
+        )
+        return f"CREATE TABLE {ne}{c.compile(self.path)}({schema}{pks})"
 
 
 @dataclass
diff --git a/tests/common.py b/tests/common.py
index aad75074..cd974e34 100644
--- a/tests/common.py
+++ b/tests/common.py
@@ -149,6 +149,7 @@ def tearDown(self):
 
 
 def _parameterized_class_per_conn(test_databases):
+    test_databases = set(test_databases)
     names = [(cls.__name__, cls) for cls in CONN_STRINGS if cls in test_databases]
     return parameterized_class(("name", "db_cls"), names)
 
diff --git a/tests/test_joindiff.py b/tests/test_joindiff.py
index 03ca3d69..22ed217d 100644
--- a/tests/test_joindiff.py
+++ b/tests/test_joindiff.py
@@ -273,3 +273,49 @@ def test_null_pks(self):
 
         x = self.differ.diff_tables(self.table, self.table2)
         self.assertRaises(ValueError, list, x)
+
+
+@test_each_database_in_list(d for d in TEST_DATABASES if d.SUPPORTS_PRIMARY_KEY and d.SUPPORTS_UNIQUE_CONSTAINT)
+class TestUniqueConstraint(TestPerDatabase):
+    def setUp(self):
+        super().setUp()
+
+        self.src_table = table(
+            self.table_src_path,
+            schema={"id": int, "userid": int, "movieid": int, "rating": float},
+        )
+        self.dst_table = table(
+            self.table_dst_path,
+            schema={"id": int, "userid": int, "movieid": int, "rating": float},
+        )
+
+        self.connection.query(
+            [self.src_table.create(primary_keys=["id"]), self.dst_table.create(primary_keys=["id", "userid"]), commit]
+        )
+
+        self.differ = JoinDiffer()
+
+    def test_unique_constraint(self):
+        self.connection.query(
+            [
+                self.src_table.insert_rows([[1, 1, 1, 9], [2, 2, 2, 9]]),
+                self.dst_table.insert_rows([[1, 1, 1, 9], [2, 2, 2, 9]]),
+                commit,
+            ]
+        )
+
+        # Test no active validation
+        table = TableSegment(self.connection, self.table_src_path, ("id",), case_sensitive=False)
+        table2 = TableSegment(self.connection, self.table_dst_path, ("id",), case_sensitive=False)
+
+        res = list(self.differ.diff_tables(table, table2))
+        assert not res
+        assert "validated_unique_keys" not in self.differ.stats
+
+        # Test active validation
+        table = TableSegment(self.connection, self.table_src_path, ("userid",), case_sensitive=False)
+        table2 = TableSegment(self.connection, self.table_dst_path, ("userid",), case_sensitive=False)
+
+        res = list(self.differ.diff_tables(table, table2))
+        assert not res
+        self.assertEqual(self.differ.stats["validated_unique_keys"], [["userid"]])