From 616fd63f5b2af888ad6514c1f7fe75fb11f7f37a Mon Sep 17 00:00:00 2001 From: Simon Eskildsen Date: Fri, 17 Jun 2022 22:17:09 +0000 Subject: [PATCH 01/32] tests: failing float/numeric tests --- tests/test_database_types.py | 70 ++++++++++++++++++++++++------------ 1 file changed, 48 insertions(+), 22 deletions(-) diff --git a/tests/test_database_types.py b/tests/test_database_types.py index 6b95d310..c20b5fd8 100644 --- a/tests/test_database_types.py +++ b/tests/test_database_types.py @@ -24,7 +24,7 @@ "2022-05-01 15:10:03.003030", "2022-06-01 15:10:05.009900", ], - "float": [0.0, 0.1, 0.10, 10.0, 100.98], + "float": [0.0, 0.1, 0.00188, 0.99999, 0.091919, 0.10, 10.0, 100.98], } DATABASE_TYPES = { @@ -37,15 +37,15 @@ ], # https://www.postgresql.org/docs/current/datatype-datetime.html "datetime_no_timezone": [ - "timestamp(6) without time zone", - "timestamp(3) without time zone", - "timestamp(0) without time zone", + # "timestamp(6) without time zone", + # "timestamp(3) without time zone", + # "timestamp(0) without time zone", ], # https://www.postgresql.org/docs/current/datatype-numeric.html "float": [ - # "real", - # "double precision", - # "numeric(6,3)", + "real", + "double precision", + "numeric(6,3)", ], }, db.MySQL: { @@ -58,12 +58,19 @@ # "bigint", # 8 bytes ], # https://dev.mysql.com/doc/refman/8.0/en/datetime.html - "datetime_no_timezone": ["timestamp(6)", "timestamp(3)", "timestamp(0)", "timestamp", "datetime(6)"], + "datetime_no_timezone": [ + # "timestamp(6)", + # "timestamp(3)", + # "timestamp(0)", + # "timestamp", + # "datetime(6)", + ], # https://dev.mysql.com/doc/refman/8.0/en/numeric-types.html "float": [ - # "float", - # "double", - # "numeric", + "float", + "double", + "numeric", + "numeric(65, 10)", ], }, db.BigQuery: { @@ -71,6 +78,10 @@ "timestamp", # "datetime", ], + "float": [ + "numeric", + "float64", + ] }, db.Snowflake: { # https://docs.snowflake.com/en/sql-reference/data-types-numeric.html#int-integer-bigint-smallint-tinyint-byteint @@ -85,15 +96,15 @@ ], # https://docs.snowflake.com/en/sql-reference/data-types-datetime.html "datetime_no_timezone": [ - "timestamp(0)", - "timestamp(3)", - "timestamp(6)", - "timestamp(9)", + # "timestamp(0)", + # "timestamp(3)", + # "timestamp(6)", + # "timestamp(9)", ], # https://docs.snowflake.com/en/sql-reference/data-types-numeric.html#decimal-numeric "float": [ - # "float" - # "numeric", + "float", + "numeric", ], }, db.Redshift: { @@ -132,7 +143,13 @@ # "int", # 4 bytes # "bigint", # 8 bytes ], - "datetime_no_timezone": ["timestamp(6)", "timestamp(3)", "timestamp(0)", "timestamp", "datetime(6)"], + "datetime_no_timezone": [ + # "timestamp(6)", + # "timestamp(3)", + # "timestamp(0)", + # "timestamp", + # "datetime(6)", + ], "float": [ # "float", # "double", @@ -150,7 +167,10 @@ # target_type: (int, bigint) } for source_db, source_type_categories in DATABASE_TYPES.items(): for target_db, target_type_categories in DATABASE_TYPES.items(): - for type_category, source_types in source_type_categories.items(): # int, datetime, .. + for ( + type_category, + source_types, + ) in source_type_categories.items(): # int, datetime, .. 
for source_type in source_types: for target_type in target_type_categories[type_category]: if CONNS.get(source_db, False) and CONNS.get(target_db, False): @@ -229,13 +249,19 @@ def test_types(self, source_db, target_db, source_type, target_type, type_catego dst_conn.query(f"CREATE TABLE {dst_table}(id int, col {target_type})", None) _insert_to_table(dst_conn, dst_table, values_in_source) - self.table = TableSegment(self.src_conn, src_table_path, "id", None, ("col",), case_sensitive=False) - self.table2 = TableSegment(self.dst_conn, dst_table_path, "id", None, ("col",), case_sensitive=False) + self.table = TableSegment( + self.src_conn, src_table_path, "id", None, ("col",), case_sensitive=False + ) + self.table2 = TableSegment( + self.dst_conn, dst_table_path, "id", None, ("col",), case_sensitive=False + ) self.assertEqual(len(sample_values), self.table.count()) self.assertEqual(len(sample_values), self.table2.count()) - differ = TableDiffer(bisection_threshold=3, bisection_factor=2) # ensure we actually checksum + differ = TableDiffer( + bisection_threshold=3, bisection_factor=2 + ) # ensure we actually checksum diff = list(differ.diff_tables(self.table, self.table2)) expected = [] self.assertEqual(expected, diff) From 06b1b55440fc0a4230edbf71e6e2c31a9005aabb Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Mon, 20 Jun 2022 11:50:06 +0200 Subject: [PATCH 02/32] MySQL+Postgres numeric precision passing (WIP) --- data_diff/database.py | 69 ++++++++++++++++++++++++++++------ data_diff/diff_tables.py | 14 ++++++- tests/test_database_types.py | 14 ++----- tests/test_normalize_fields.py | 2 +- 4 files changed, 76 insertions(+), 23 deletions(-) diff --git a/data_diff/database.py b/data_diff/database.py index 5e39e8a3..fad589f4 100644 --- a/data_diff/database.py +++ b/data_diff/database.py @@ -105,6 +105,19 @@ class Datetime(TemporalType): pass +@dataclass +class NumericType(ColType): + precision: int + + +class Float(NumericType): + pass + + +class Decimal(NumericType): + pass + + @dataclass class UnknownColType(ColType): text: str @@ -212,23 +225,35 @@ def query(self, sql_ast: SqlOrStr, res_type: type): def enable_interactive(self): self._interactive = True - def _parse_type(self, type_repr: str, datetime_precision: int = None, numeric_precision: int = None) -> ColType: + def _parse_type( + self, type_repr: str, datetime_precision: int = None, numeric_precision: int = None, numeric_scale: int = None + ) -> ColType: """ """ cls = self.DATETIME_TYPES.get(type_repr) if cls: return cls( - precision=datetime_precision if datetime_precision is not None else DEFAULT_PRECISION, + precision=datetime_precision if datetime_precision is not None else DEFAULT_DATETIME_PRECISION, rounds=self.ROUNDS_ON_PREC_LOSS, ) + cls = self.NUMERIC_TYPES.get(type_repr) + if cls: + assert numeric_precision is not None + if cls is Decimal: + assert numeric_scale is not None + return cls(precision=numeric_scale) + + assert numeric_scale is None + return cls(precision=numeric_precision // 3) + return UnknownColType(type_repr) def select_table_schema(self, path: DbPath) -> str: schema, table = self._normalize_table_path(path) return ( - "SELECT column_name, data_type, datetime_precision, numeric_precision FROM information_schema.columns " + "SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale FROM information_schema.columns " f"WHERE table_name = '{table}' AND table_schema = '{schema}'" ) @@ -250,7 +275,9 @@ def _normalize_table_path(self, path: DbPath) -> DbPath: elif len(path) == 2: return path 
- raise ValueError(f"{self.__class__.__name__}: Bad table path for {self}: '{'.'.join(path)}'. Expected form: schema.table") + raise ValueError( + f"{self.__class__.__name__}: Bad table path for {self}: '{'.'.join(path)}'. Expected form: schema.table" + ) def parse_table_name(self, name: str) -> DbPath: return parse_table_name(name) @@ -295,7 +322,8 @@ def close(self): _CHECKSUM_BITSIZE = CHECKSUM_HEXDIGITS << 2 CHECKSUM_MASK = (2**_CHECKSUM_BITSIZE) - 1 -DEFAULT_PRECISION = 6 +DEFAULT_DATETIME_PRECISION = 6 +DEFAULT_NUMERIC_PRECISION = 6 TIMESTAMP_PRECISION_POS = 20 # len("2022-06-03 12:24:35.") == 20 @@ -307,6 +335,13 @@ class Postgres(ThreadedDatabase): "timestamp": Timestamp, # "datetime": Datetime, } + NUMERIC_TYPES = { + "double precision": Float, + "real": Float, + "decimal": Decimal, + "integer": Decimal, + "numeric": Decimal, + } ROUNDS_ON_PREC_LOSS = True default_schema = "public" @@ -351,6 +386,9 @@ def normalize_value_by_type(self, value: str, coltype: ColType) -> str: timestamp6 = f"to_char({value}::timestamp(6), 'YYYY-mm-dd HH24:MI:SS.US')" return f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')" + elif isinstance(coltype, NumericType): + value = f"{value}::decimal(38,{coltype.precision})" + return self.to_string(f"{value}") @@ -422,7 +460,8 @@ def _parse_type(self, type_repr: str, datetime_precision: int = None, numeric_pr if m: datetime_precision = int(m.group(1)) return cls( - precision=datetime_precision if datetime_precision is not None else DEFAULT_PRECISION, rounds=False + precision=datetime_precision if datetime_precision is not None else DEFAULT_DATETIME_PRECISION, + rounds=False, ) return UnknownColType(type_repr) @@ -433,6 +472,12 @@ class MySQL(ThreadedDatabase): "datetime": Datetime, "timestamp": Timestamp, } + NUMERIC_TYPES = { + "double": Float, + "float": Float, + "decimal": Decimal, + "int": Decimal, + } ROUNDS_ON_PREC_LOSS = True def __init__(self, host, port, user, password, *, database, thread_count, **kw): @@ -472,6 +517,9 @@ def normalize_value_by_type(self, value: str, coltype: ColType) -> str: s = self.to_string(f"cast({value} as datetime(6))") return f"RPAD(RPAD({s}, {TIMESTAMP_PRECISION_POS+coltype.precision}, '.'), {TIMESTAMP_PRECISION_POS+6}, '0')" + elif isinstance(coltype, NumericType): + value = f"cast({value} as decimal(38,{coltype.precision}))" + return self.to_string(f"{value}") @@ -518,7 +566,7 @@ def select_table_schema(self, path: DbPath) -> str: ) def normalize_value_by_type(self, value: str, coltype: ColType) -> str: - if isinstance(coltype, PrecisionType): + if isinstance(coltype, TemporalType): return f"to_char(cast({value} as timestamp({coltype.precision})), 'YYYY-MM-DD HH24:MI:SS.FF6')" return self.to_string(f"{value}") @@ -532,9 +580,8 @@ def _parse_type(self, type_repr: str, datetime_precision: int = None, numeric_pr m = re.match(regexp + "$", type_repr) if m: datetime_precision = int(m.group(1)) - return cls(precision=datetime_precision if datetime_precision is not None else DEFAULT_PRECISION, + return cls(precision=datetime_precision if datetime_precision is not None else DEFAULT_DATETIME_PRECISION, rounds=self.ROUNDS_ON_PREC_LOSS - ) return UnknownColType(type_repr) @@ -645,7 +692,7 @@ def select_table_schema(self, path: DbPath) -> str: ) def normalize_value_by_type(self, value: str, coltype: ColType) -> str: - if isinstance(coltype, PrecisionType): + if isinstance(coltype, TemporalType): if coltype.rounds: timestamp = f"timestamp_micros(cast(round(unix_micros(cast({value} 
as timestamp))/1000000, {coltype.precision})*1000000 as int))" return f"FORMAT_TIMESTAMP('%F %H:%M:%E6S', {timestamp})" @@ -729,7 +776,7 @@ def select_table_schema(self, path: DbPath) -> str: return super().select_table_schema((schema, table)) def normalize_value_by_type(self, value: str, coltype: ColType) -> str: - if isinstance(coltype, PrecisionType): + if isinstance(coltype, TemporalType): if coltype.rounds: timestamp = f"to_timestamp(round(date_part(epoch_nanosecond, {value}::timestamp(9))/1000000000, {coltype.precision}))" else: diff --git a/data_diff/diff_tables.py b/data_diff/diff_tables.py index 4087b49d..9838189c 100644 --- a/data_diff/diff_tables.py +++ b/data_diff/diff_tables.py @@ -12,7 +12,7 @@ from runtype import dataclass from .sql import Select, Checksum, Compare, DbPath, DbKey, DbTime, Count, TableName, Time, Min, Max -from .database import Database, PrecisionType, ColType +from .database import Database, NumericType, PrecisionType, ColType logger = logging.getLogger("diff_tables") @@ -369,6 +369,18 @@ def _validate_and_adjust_columns(self, table1, table2): table1._schema[c] = col1.replace(precision=lowest.precision, rounds=lowest.rounds) table2._schema[c] = col2.replace(precision=lowest.precision, rounds=lowest.rounds) + elif isinstance(col1, NumericType): + if not isinstance(col2, NumericType): + raise TypeError(f"Incompatible types for column {c}: {col1} <-> {col2}") + + lowest = min(col1, col2, key=attrgetter("precision")) + + if col1.precision != col2.precision: + logger.warn(f"Using reduced precision {lowest} for column '{c}'. Types={col1}, {col2}") + + table1._schema[c] = col1.replace(precision=lowest.precision) + table2._schema[c] = col2.replace(precision=lowest.precision) + def _bisect_and_diff_tables(self, table1, table2, level=0, max_rows=None): assert table1.is_bounded and table2.is_bounded diff --git a/tests/test_database_types.py b/tests/test_database_types.py index c20b5fd8..89107a34 100644 --- a/tests/test_database_types.py +++ b/tests/test_database_types.py @@ -81,7 +81,7 @@ "float": [ "numeric", "float64", - ] + ], }, db.Snowflake: { # https://docs.snowflake.com/en/sql-reference/data-types-numeric.html#int-integer-bigint-smallint-tinyint-byteint @@ -249,19 +249,13 @@ def test_types(self, source_db, target_db, source_type, target_type, type_catego dst_conn.query(f"CREATE TABLE {dst_table}(id int, col {target_type})", None) _insert_to_table(dst_conn, dst_table, values_in_source) - self.table = TableSegment( - self.src_conn, src_table_path, "id", None, ("col",), case_sensitive=False - ) - self.table2 = TableSegment( - self.dst_conn, dst_table_path, "id", None, ("col",), case_sensitive=False - ) + self.table = TableSegment(self.src_conn, src_table_path, "id", None, ("col",), case_sensitive=False) + self.table2 = TableSegment(self.dst_conn, dst_table_path, "id", None, ("col",), case_sensitive=False) self.assertEqual(len(sample_values), self.table.count()) self.assertEqual(len(sample_values), self.table2.count()) - differ = TableDiffer( - bisection_threshold=3, bisection_factor=2 - ) # ensure we actually checksum + differ = TableDiffer(bisection_threshold=3, bisection_factor=2) # ensure we actually checksum diff = list(differ.diff_tables(self.table, self.table2)) expected = [] self.assertEqual(expected, diff) diff --git a/tests/test_normalize_fields.py b/tests/test_normalize_fields.py index 2953f8ad..468d2667 100644 --- a/tests/test_normalize_fields.py +++ b/tests/test_normalize_fields.py @@ -5,7 +5,7 @@ import preql -from data_diff.database import BigQuery, 
MySQL, Snowflake, connect_to_uri, Oracle, DEFAULT_PRECISION +from data_diff.database import BigQuery, MySQL, Snowflake, connect_to_uri, Oracle from data_diff.sql import Select from data_diff import database as db From 902542f1a0a7f60614b20c0eeb9c7895d47a1c6d Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Mon, 20 Jun 2022 13:43:33 +0200 Subject: [PATCH 03/32] Numeric-precision BigQuery tests passing (WIP) --- data_diff/database.py | 29 ++++++++++++++++++++++++----- tests/test_database_types.py | 14 +++++++++++++- 2 files changed, 37 insertions(+), 6 deletions(-) diff --git a/data_diff/database.py b/data_diff/database.py index fad589f4..fa191367 100644 --- a/data_diff/database.py +++ b/data_diff/database.py @@ -239,13 +239,14 @@ def _parse_type( cls = self.NUMERIC_TYPES.get(type_repr) if cls: - assert numeric_precision is not None if cls is Decimal: - assert numeric_scale is not None + assert numeric_precision is not None + assert numeric_scale is not None, (type_repr, numeric_precision, numeric_scale) return cls(precision=numeric_scale) - assert numeric_scale is None - return cls(precision=numeric_precision // 3) + assert cls is Float + # assert numeric_scale is None + return cls(precision=(numeric_precision if numeric_precision is not None else 15) // 3) return UnknownColType(type_repr) @@ -642,6 +643,13 @@ class BigQuery(Database): "TIMESTAMP": Timestamp, "DATETIME": Datetime, } + NUMERIC_TYPES = { + "INT64": Decimal, + "INT32": Decimal, + "NUMERIC": Decimal, + "FLOAT64": Float, + "FLOAT32": Float, + } ROUNDS_ON_PREC_LOSS = False # Technically BigQuery doesn't allow implicit rounding or truncation def __init__(self, project, *, dataset, **kw): @@ -687,7 +695,7 @@ def select_table_schema(self, path: DbPath) -> str: schema, table = self._normalize_table_path(path) return ( - f"SELECT column_name, data_type, 6 as datetime_precision, 6 as numeric_precision FROM {schema}.INFORMATION_SCHEMA.COLUMNS " + f"SELECT column_name, data_type, 6 as datetime_precision, 38 as numeric_precision, 9 as numeric_scale FROM {schema}.INFORMATION_SCHEMA.COLUMNS " f"WHERE table_name = '{table}' AND table_schema = '{schema}'" ) @@ -705,6 +713,10 @@ def normalize_value_by_type(self, value: str, coltype: ColType) -> str: timestamp6 = f"FORMAT_TIMESTAMP('%F %H:%M:%E6S', {value})" return f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')" + elif isinstance(coltype, NumericType): + # value = f"cast({value} as decimal)" + return f"format('%.{coltype.precision}f', cast({value} as decimal))" + return self.to_string(f"{value}") def parse_table_name(self, name: str) -> DbPath: @@ -718,6 +730,10 @@ class Snowflake(Database): "TIMESTAMP_LTZ": Timestamp, "TIMESTAMP_TZ": TimestampTZ, } + NUMERIC_TYPES = { + "NUMBER": Decimal, + "FLOAT": Float, + } ROUNDS_ON_PREC_LOSS = False def __init__( @@ -784,6 +800,9 @@ def normalize_value_by_type(self, value: str, coltype: ColType) -> str: return f"to_char({timestamp}, 'YYYY-MM-DD HH24:MI:SS.FF6')" + elif isinstance(coltype, NumericType): + value = f"cast({value} as decimal(38, {coltype.precision}))" + return self.to_string(f"{value}") diff --git a/tests/test_database_types.py b/tests/test_database_types.py index 89107a34..1dbb5308 100644 --- a/tests/test_database_types.py +++ b/tests/test_database_types.py @@ -1,12 +1,19 @@ from contextlib import suppress import unittest import time +import logging +from decimal import Decimal + +from parameterized import parameterized, parameterized_class +import preql + from data_diff import 
database as db from data_diff.diff_tables import TableDiffer, TableSegment from parameterized import parameterized, parameterized_class from .common import CONN_STRINGS import logging + logging.getLogger("diff_tables").setLevel(logging.WARN) logging.getLogger("database").setLevel(logging.WARN) @@ -209,7 +216,12 @@ def _insert_to_table(conn, table, values): else: insertion_query += ' VALUES ' for j, sample in values: - insertion_query += f"({j}, '{sample}')," + if isinstance(sample, (float, Decimal)): + value = str(sample) + else: + value = f"'{sample}'" + insertion_query += f"({j}, {value})," + insertion_query = insertion_query[0:-1] conn.query(insertion_query, None) From a91dbab73f8282e54ba29ac2a0703af39505f898 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Mon, 20 Jun 2022 14:42:39 +0200 Subject: [PATCH 04/32] Presto numeric-precision passing tests (WIP) --- data_diff/database.py | 41 ++++++++++++++++++++++++++++++------ tests/test_database_types.py | 15 +++++++------ 2 files changed, 42 insertions(+), 14 deletions(-) diff --git a/data_diff/database.py b/data_diff/database.py index fa191367..6dd76397 100644 --- a/data_diff/database.py +++ b/data_diff/database.py @@ -63,6 +63,7 @@ def import_presto(): class ConnectError(Exception): pass + class QueryError(Exception): pass @@ -118,6 +119,12 @@ class Decimal(NumericType): pass +@dataclass +class Integer(Decimal): + def __post_init__(self): + assert self.precision == 0 + + @dataclass class UnknownColType(ColType): text: str @@ -239,12 +246,12 @@ def _parse_type( cls = self.NUMERIC_TYPES.get(type_repr) if cls: - if cls is Decimal: + if issubclass(cls, Decimal): assert numeric_precision is not None assert numeric_scale is not None, (type_repr, numeric_precision, numeric_scale) return cls(precision=numeric_scale) - assert cls is Float + assert issubclass(cls, Float) # assert numeric_scale is None return cls(precision=(numeric_precision if numeric_precision is not None else 15) // 3) @@ -340,7 +347,7 @@ class Postgres(ThreadedDatabase): "double precision": Float, "real": Float, "decimal": Decimal, - "integer": Decimal, + "integer": Integer, "numeric": Decimal, } ROUNDS_ON_PREC_LOSS = True @@ -401,6 +408,10 @@ class Presto(Database): "timestamp": Timestamp, # "datetime": Datetime, } + NUMERIC_TYPES = { + "integer": Integer, + "real": Float, + } ROUNDS_ON_PREC_LOSS = True def __init__(self, host, port, user, password, *, catalog, schema=None, **kw): @@ -440,6 +451,9 @@ def normalize_value_by_type(self, value: str, coltype: ColType) -> str: f"RPAD(RPAD({s}, {TIMESTAMP_PRECISION_POS+coltype.precision}, '.'), {TIMESTAMP_PRECISION_POS+6}, '0')" ) + elif isinstance(coltype, NumericType): + value = f"cast({value} as decimal(38,{coltype.precision}))" + return self.to_string(value) def select_table_schema(self, path: DbPath) -> str: @@ -465,6 +479,17 @@ def _parse_type(self, type_repr: str, datetime_precision: int = None, numeric_pr rounds=False, ) + cls = self.NUMERIC_TYPES.get(type_repr) + if cls: + if issubclass(cls, Integer): + assert numeric_precision is not None + return cls(0) + elif issubclass(cls, Decimal): + return cls(6) + + assert issubclass(cls, Float) + return cls(precision=(numeric_precision if numeric_precision is not None else 15) // 3) + return UnknownColType(type_repr) @@ -581,8 +606,10 @@ def _parse_type(self, type_repr: str, datetime_precision: int = None, numeric_pr m = re.match(regexp + "$", type_repr) if m: datetime_precision = int(m.group(1)) - return cls(precision=datetime_precision if datetime_precision is not None else 
DEFAULT_DATETIME_PRECISION, - rounds=self.ROUNDS_ON_PREC_LOSS + return cls( + precision=datetime_precision if datetime_precision is not None else DEFAULT_DATETIME_PRECISION, + rounds=self.ROUNDS_ON_PREC_LOSS, + ) return UnknownColType(type_repr) @@ -644,8 +671,8 @@ class BigQuery(Database): "DATETIME": Datetime, } NUMERIC_TYPES = { - "INT64": Decimal, - "INT32": Decimal, + "INT64": Integer, + "INT32": Integer, "NUMERIC": Decimal, "FLOAT64": Float, "FLOAT32": Float, diff --git a/tests/test_database_types.py b/tests/test_database_types.py index 1dbb5308..69e83d93 100644 --- a/tests/test_database_types.py +++ b/tests/test_database_types.py @@ -158,9 +158,9 @@ # "datetime(6)", ], "float": [ - # "float", - # "double", - # "numeric", + "float", + "double", + "numeric", ], }, } @@ -211,10 +211,10 @@ def _insert_to_table(conn, table, values): if isinstance(conn, db.Oracle): selects = [] for j, sample in values: - selects.append( f"SELECT {j}, timestamp '{sample}' FROM dual" ) - insertion_query += ' UNION ALL '.join(selects) + selects.append(f"SELECT {j}, timestamp '{sample}' FROM dual") + insertion_query += " UNION ALL ".join(selects) else: - insertion_query += ' VALUES ' + insertion_query += " VALUES " for j, sample in values: if isinstance(sample, (float, Decimal)): value = str(sample) @@ -228,6 +228,7 @@ def _insert_to_table(conn, table, values): if not isinstance(conn, db.BigQuery): conn.query("COMMIT", None) + def _drop_table_if_exists(conn, table): with suppress(db.QueryError): if isinstance(conn, db.Oracle): @@ -235,6 +236,7 @@ def _drop_table_if_exists(conn, table): else: conn.query(f"DROP TABLE IF EXISTS {table}", None) + class TestDiffCrossDatabaseTables(unittest.TestCase): @parameterized.expand(type_pairs, name_func=expand_params) def test_types(self, source_db, target_db, source_type, target_type, type_category): @@ -282,4 +284,3 @@ def test_types(self, source_db, target_db, source_type, target_type, type_catego duration = time.time() - start # print(f"source_db={source_db.__name__} target_db={target_db.__name__} source_type={source_type} target_type={target_type} duration={round(duration * 1000, 2)}ms") - From c2e86975a2985bc30c9c1ef96f3e9449ebec59ff Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Mon, 20 Jun 2022 17:19:33 +0200 Subject: [PATCH 05/32] Oracle numeric-precision passing tests --- data_diff/database.py | 24 ++++++++++++++++++++++-- data_diff/diff_tables.py | 12 ++++++------ tests/test_database_types.py | 13 +++++++++---- 3 files changed, 37 insertions(+), 12 deletions(-) diff --git a/data_diff/database.py b/data_diff/database.py index 6dd76397..ceafe0c0 100644 --- a/data_diff/database.py +++ b/data_diff/database.py @@ -587,16 +587,24 @@ def select_table_schema(self, path: DbPath) -> str: (table,) = path return ( - f"SELECT column_name, data_type, 6 as datetime_precision, data_precision as numeric_precision" + f"SELECT column_name, data_type, 6 as datetime_precision, data_precision as numeric_precision, data_scale as numeric_scale" f" FROM USER_TAB_COLUMNS WHERE table_name = '{table.upper()}'" ) def normalize_value_by_type(self, value: str, coltype: ColType) -> str: if isinstance(coltype, TemporalType): return f"to_char(cast({value} as timestamp({coltype.precision})), 'YYYY-MM-DD HH24:MI:SS.FF6')" + elif isinstance(coltype, NumericType): + # FM999.9990 + format_str = "FM" + "9" * (38 - coltype.precision) + if coltype.precision: + format_str += "0." 
+ "9" * (coltype.precision - 1) + "0" + return f"to_char({value}, '{format_str}')" return self.to_string(f"{value}") - def _parse_type(self, type_repr: str, datetime_precision: int = None, numeric_precision: int = None) -> ColType: + def _parse_type( + self, type_repr: str, datetime_precision: int = None, numeric_precision: int = None, numeric_scale: int = None + ) -> ColType: """ """ regexps = { r"TIMESTAMP\((\d)\) WITH LOCAL TIME ZONE": Timestamp, @@ -611,6 +619,18 @@ def _parse_type(self, type_repr: str, datetime_precision: int = None, numeric_pr rounds=self.ROUNDS_ON_PREC_LOSS, ) + cls = { + "NUMBER": Decimal, + "FLOAT": Float, + }.get(type_repr, None) + if cls: + if issubclass(cls, Decimal): + assert numeric_scale is not None, (type_repr, numeric_precision, numeric_scale) + return cls(precision=numeric_scale) + + assert issubclass(cls, Float) + return cls(precision=(numeric_precision if numeric_precision is not None else 15) // 3) + return UnknownColType(type_repr) diff --git a/data_diff/diff_tables.py b/data_diff/diff_tables.py index 9838189c..565d51a2 100644 --- a/data_diff/diff_tables.py +++ b/data_diff/diff_tables.py @@ -147,10 +147,10 @@ def with_schema(self) -> "TableSegment": schema = Schema_CaseSensitive(schema) else: if len({k.lower() for k in schema}) < len(schema): - logger.warn( + logger.warning( f'Ambiguous schema for {self.database}:{".".join(self.table_path)} | Columns = {", ".join(list(schema))}' ) - logger.warn("We recommend to disable case-insensitivity (remove --any-case).") + logger.warning("We recommend to disable case-insensitivity (remove --any-case).") schema = Schema_CaseInsensitive(schema) return self.new(_schema=schema) @@ -241,7 +241,7 @@ def count_and_checksum(self) -> Tuple[int, int]: ) duration = time.time() - start if duration > RECOMMENDED_CHECKSUM_DURATION: - logger.warn( + logger.warning( f"Checksum is taking longer than expected ({duration:.2f}s). " "We recommend increasing --bisection-factor or decreasing --threads." ) @@ -364,7 +364,7 @@ def _validate_and_adjust_columns(self, table1, table2): lowest = min(col1, col2, key=attrgetter("precision")) if col1.precision != col2.precision: - logger.warn(f"Using reduced precision {lowest} for column '{c}'. Types={col1}, {col2}") + logger.warning(f"Using reduced precision {lowest} for column '{c}'. Types={col1}, {col2}") table1._schema[c] = col1.replace(precision=lowest.precision, rounds=lowest.rounds) table2._schema[c] = col2.replace(precision=lowest.precision, rounds=lowest.rounds) @@ -376,7 +376,7 @@ def _validate_and_adjust_columns(self, table1, table2): lowest = min(col1, col2, key=attrgetter("precision")) if col1.precision != col2.precision: - logger.warn(f"Using reduced precision {lowest} for column '{c}'. Types={col1}, {col2}") + logger.warning(f"Using reduced precision {lowest} for column '{c}'. Types={col1}, {col2}") table1._schema[c] = col1.replace(precision=lowest.precision) table2._schema[c] = col2.replace(precision=lowest.precision) @@ -424,7 +424,7 @@ def _diff_tables(self, table1, table2, level=0, segment_index=None, segment_coun (count1, checksum1), (count2, checksum2) = self._threaded_call("count_and_checksum", [table1, table2]) if count1 == 0 and count2 == 0: - logger.warn( + logger.warning( "Uneven distribution of keys detected. (big gaps in the key column). " "For better performance, we recommend to increase the bisection-threshold." 
) diff --git a/tests/test_database_types.py b/tests/test_database_types.py index 69e83d93..97348aa3 100644 --- a/tests/test_database_types.py +++ b/tests/test_database_types.py @@ -14,7 +14,7 @@ import logging -logging.getLogger("diff_tables").setLevel(logging.WARN) +logging.getLogger("diff_tables").setLevel(logging.ERROR) logging.getLogger("database").setLevel(logging.WARN) CONNS = {k: db.connect_to_uri(v) for k, v in CONN_STRINGS.items()} @@ -138,8 +138,8 @@ "timestamp(9) with local time zone", ], "float": [ - # "float", - # "numeric", + "float", + "numeric", ], }, db.Presto: { @@ -211,7 +211,11 @@ def _insert_to_table(conn, table, values): if isinstance(conn, db.Oracle): selects = [] for j, sample in values: - selects.append(f"SELECT {j}, timestamp '{sample}' FROM dual") + if isinstance(sample, (float, Decimal, int)): + value = str(sample) + else: + value = f"timestamp '{sample}'" + selects.append(f"SELECT {j}, {value} FROM dual") insertion_query += " UNION ALL ".join(selects) else: insertion_query += " VALUES " @@ -233,6 +237,7 @@ def _drop_table_if_exists(conn, table): with suppress(db.QueryError): if isinstance(conn, db.Oracle): conn.query(f"DROP TABLE {table}", None) + conn.query(f"DROP TABLE {table}", None) else: conn.query(f"DROP TABLE IF EXISTS {table}", None) From ff1a6d6ad22a251341b2d74f84c62a10f14cb5c3 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Mon, 20 Jun 2022 18:17:58 +0200 Subject: [PATCH 06/32] Numeric precision: Convert the precision correctly. Added redshift. --- data_diff/database.py | 39 +++++++++++++++++++++++++++++++----- tests/test_database_types.py | 6 +++--- 2 files changed, 37 insertions(+), 8 deletions(-) diff --git a/data_diff/database.py b/data_diff/database.py index ceafe0c0..18dd275b 100644 --- a/data_diff/database.py +++ b/data_diff/database.py @@ -1,3 +1,4 @@ +import math from functools import lru_cache from itertools import zip_longest import re @@ -232,6 +233,10 @@ def query(self, sql_ast: SqlOrStr, res_type: type): def enable_interactive(self): self._interactive = True + def _convert_db_precision_to_digits(self, p: int) -> int: + # See: https://en.wikipedia.org/wiki/Single-precision_floating-point_format + return math.floor(math.log(2**p, 10)) + def _parse_type( self, type_repr: str, datetime_precision: int = None, numeric_precision: int = None, numeric_scale: int = None ) -> ColType: @@ -253,7 +258,11 @@ def _parse_type( assert issubclass(cls, Float) # assert numeric_scale is None - return cls(precision=(numeric_precision if numeric_precision is not None else 15) // 3) + return cls( + precision=self._convert_db_precision_to_digits( + numeric_precision if numeric_precision is not None else DEFAULT_NUMERIC_PRECISION + ) + ) return UnknownColType(type_repr) @@ -331,7 +340,7 @@ def close(self): CHECKSUM_MASK = (2**_CHECKSUM_BITSIZE) - 1 DEFAULT_DATETIME_PRECISION = 6 -DEFAULT_NUMERIC_PRECISION = 6 +DEFAULT_NUMERIC_PRECISION = 24 TIMESTAMP_PRECISION_POS = 20 # len("2022-06-03 12:24:35.") == 20 @@ -395,7 +404,7 @@ def normalize_value_by_type(self, value: str, coltype: ColType) -> str: return f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')" elif isinstance(coltype, NumericType): - value = f"{value}::decimal(38,{coltype.precision})" + value = f"{value}::decimal(38, {coltype.precision})" return self.to_string(f"{value}") @@ -488,7 +497,11 @@ def _parse_type(self, type_repr: str, datetime_precision: int = None, numeric_pr return cls(6) assert issubclass(cls, Float) - return 
cls(precision=(numeric_precision if numeric_precision is not None else 15) // 3) + return cls( + precision=self._convert_db_precision_to_digits( + numeric_precision if numeric_precision is not None else DEFAULT_NUMERIC_PRECISION + ) + ) return UnknownColType(type_repr) @@ -629,12 +642,25 @@ def _parse_type( return cls(precision=numeric_scale) assert issubclass(cls, Float) - return cls(precision=(numeric_precision if numeric_precision is not None else 15) // 3) + return cls( + precision=self._convert_db_precision_to_digits( + numeric_precision if numeric_precision is not None else DEFAULT_NUMERIC_PRECISION + ) + ) return UnknownColType(type_repr) class Redshift(Postgres): + NUMERIC_TYPES = { + **Postgres.NUMERIC_TYPES, + "double": Float, + "real": Float, + } + + def _convert_db_precision_to_digits(self, p: int) -> int: + return super()._convert_db_precision_to_digits(p // 2) + def md5_to_int(self, s: str) -> str: return f"strtol(substring(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16)::decimal(38)" @@ -655,6 +681,9 @@ def normalize_value_by_type(self, value: str, coltype: ColType) -> str: timestamp6 = f"to_char({value}::timestamp(6), 'YYYY-mm-dd HH24:MI:SS.US')" return f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')" + elif isinstance(coltype, NumericType): + value = f"{value}::decimal(38,{coltype.precision})" + return self.to_string(f"{value}") diff --git a/tests/test_database_types.py b/tests/test_database_types.py index 97348aa3..7dfc45da 100644 --- a/tests/test_database_types.py +++ b/tests/test_database_types.py @@ -123,9 +123,9 @@ ], # https://docs.aws.amazon.com/redshift/latest/dg/r_Numeric_types201.html#r_Numeric_types201-floating-point-types "float": [ - # "float4", - # "float8", - # "numeric", + "float4", + "float8", + "numeric", ], }, db.Oracle: { From 315c244907b401896deccdf0b30f2c923d744443 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Mon, 20 Jun 2022 19:44:15 +0200 Subject: [PATCH 07/32] Fix for BigQuery + more tests. --- data_diff/database.py | 10 +++++++--- tests/common.py | 6 +++++- tests/test_database_types.py | 1 + 3 files changed, 13 insertions(+), 4 deletions(-) diff --git a/data_diff/database.py b/data_diff/database.py index 18dd275b..5dc257f8 100644 --- a/data_diff/database.py +++ b/data_diff/database.py @@ -251,9 +251,12 @@ def _parse_type( cls = self.NUMERIC_TYPES.get(type_repr) if cls: - if issubclass(cls, Decimal): - assert numeric_precision is not None - assert numeric_scale is not None, (type_repr, numeric_precision, numeric_scale) + if issubclass(cls, Integer): + # Some DBs have a constant numeric_scale, so they don't report it. + # We fill in the constant, so we need to ignore it for integers. + return cls(precision=0) + + elif issubclass(cls, Decimal): return cls(precision=numeric_scale) assert issubclass(cls, Float) @@ -723,6 +726,7 @@ class BigQuery(Database): "INT64": Integer, "INT32": Integer, "NUMERIC": Decimal, + "BIGNUMERIC": Decimal, "FLOAT64": Float, "FLOAT32": Float, } diff --git a/tests/common.py b/tests/common.py index 33281861..32c4c30b 100644 --- a/tests/common.py +++ b/tests/common.py @@ -19,8 +19,12 @@ except ImportError: pass # No local settings +if TEST_BIGQUERY_CONN_STRING and TEST_SNOWFLAKE_CONN_STRING: + # TODO Fix this. Seems to have something to do with pyarrow + raise RuntimeError("Using BigQuery at the same time as Snowflake causes an error!!") + CONN_STRINGS = { - # db.BigQuery: TEST_BIGQUERY_CONN_STRING, # TODO BigQuery before/after Snowflake causes an error! 
+ db.BigQuery: TEST_BIGQUERY_CONN_STRING, db.MySQL: TEST_MYSQL_CONN_STRING, db.Postgres: TEST_POSTGRES_CONN_STRING, db.Snowflake: TEST_SNOWFLAKE_CONN_STRING, diff --git a/tests/test_database_types.py b/tests/test_database_types.py index 7dfc45da..84e422ef 100644 --- a/tests/test_database_types.py +++ b/tests/test_database_types.py @@ -88,6 +88,7 @@ "float": [ "numeric", "float64", + "bignumeric", ], }, db.Snowflake: { From 56371dad878c010a0dc63a11452e486cea8f0755 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Mon, 20 Jun 2022 20:02:49 +0200 Subject: [PATCH 08/32] README: Added docs link --- README.md | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index aab3ac0b..8e92f3d7 100644 --- a/README.md +++ b/README.md @@ -167,6 +167,10 @@ Options: ## How to use from Python +API reference: [https://data-diff.readthedocs.io/en/latest/](https://data-diff.readthedocs.io/en/latest/) + +Example: + ```python # Optional: Set logging to display the progress of the diff import logging @@ -182,7 +186,7 @@ for different_row in diff_tables(table1, table2): print(plus_or_minus, columns) ``` -Run `help(diff_tables)` or read the docs [ADD LINK] to learn about the different options. +Run `help(diff_tables)` or [read the docs](https://data-diff.readthedocs.io/en/latest/) to learn about the different options. # Technical Explanation From fab39853f0cb2d695f8dc7597498c5739f520d40 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Mon, 20 Jun 2022 20:29:04 +0200 Subject: [PATCH 09/32] Docs: Small update --- data_diff/__main__.py | 1 - docs/index.rst | 16 +++++++++++----- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/data_diff/__main__.py b/data_diff/__main__.py index dcdd7f27..5b357153 100644 --- a/data_diff/__main__.py +++ b/data_diff/__main__.py @@ -1,4 +1,3 @@ -from multiprocessing.sharedctypes import Value import sys import time import logging diff --git a/docs/index.rst b/docs/index.rst index 6caab45d..094e2009 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -5,11 +5,13 @@ python-api +Introduction +------------ -**data-diff** is a command-line tool and Python library to efficiently diff +**Data-diff** is a command-line tool and Python library to efficiently diff rows across two different databases. -⇄ Verifies across many different databases (e.g. Postgres -> Snowflake) ! +⇄ Verifies across many different databases (e.g. *Postgres* -> *Snowflake*) ! 
🔍 Outputs diff of rows in detail @@ -51,9 +53,13 @@ How to use from Python table1 = connect_to_table("postgres:///", "table_name", "id") table2 = connect_to_table("mysql:///", "table_name", "id") - for different_row in diff_tables(table1, table2): - plus_or_minus, columns = different_row - print(plus_or_minus, columns) + for sign, columns in diff_tables(table1, table2): + print(sign, columns) + + # Example output: + + ('4775622148347', '2022-06-05 16:57:32.000000') + - ('4775622312187', '2022-06-05 16:57:32.000000') + - ('4777375432955', '2022-06-07 16:57:36.000000') Resources From c7367ba8de4b31d491efa805cb0d1ca5535c2540 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Tue, 21 Jun 2022 09:28:23 +0200 Subject: [PATCH 10/32] Added some comments --- data_diff/database.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/data_diff/database.py b/data_diff/database.py index 5dc257f8..158b82a5 100644 --- a/data_diff/database.py +++ b/data_diff/database.py @@ -109,6 +109,7 @@ class Datetime(TemporalType): @dataclass class NumericType(ColType): + # 'precision' signifies how many fractional digits (after the dot) we want to compare precision: int @@ -183,6 +184,19 @@ def normalize_value_by_type(value: str, coltype: ColType) -> str: Rounded up/down according to coltype.rounds + - Floats/Decimals are expected in the format + "I.P" + + Where I is the integer part of the number (as many digits as necessary), + and must be at least one digit (0). + P is the fractional digits, the amount of which is specified with + coltype.precision. Trailing zeroes may be necessary. + + Note: This precision is different than the one used by databases. For decimals, + it's the same as "numeric_scale", and for floats, who use binary precision, + it can be calculated as log10(2**p) + + """ ... 
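To make the precision note above concrete, here is a standalone sketch (an illustration, not part of any patch) of how binary float precision maps to comparable decimal digits. It mirrors `_convert_db_precision_to_digits` introduced in patch 06:

```python
import math

def convert_db_precision_to_digits(p: int) -> int:
    # A float with a p-bit mantissa faithfully represents about
    # log10(2**p) decimal digits; floor() errs on the safe side.
    return math.floor(math.log(2**p, 10))

print(convert_db_precision_to_digits(24))  # 7  -- float32 / REAL
print(convert_db_precision_to_digits(53))  # 15 -- float64 / DOUBLE PRECISION
```

This is why a `real` column ends up being compared to roughly 7 fractional digits, while a `double precision` column allows roughly 15.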
@@ -234,6 +248,7 @@ def enable_interactive(self): self._interactive = True def _convert_db_precision_to_digits(self, p: int) -> int: + """Convert from binary precision, used by floats, to decimal precision.""" # See: https://en.wikipedia.org/wiki/Single-precision_floating-point_format return math.floor(math.log(2**p, 10)) From e7bd98a5827dd20f66bbbe332d1990feb9fa1e22 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Tue, 21 Jun 2022 15:25:58 +0200 Subject: [PATCH 11/32] Fix for BigQuery: Table-name no longer needs dataset, takes it from URI Now we can run: data_diff bigquery://datafold-dev-2/data_diff rating bigquery://datafold-dev-2/data_diff rating_del1 --- data_diff/sql.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/data_diff/sql.py b/data_diff/sql.py index 1c19aef1..a81839eb 100644 --- a/data_diff/sql.py +++ b/data_diff/sql.py @@ -46,7 +46,8 @@ class TableName(Sql): name: DbPath def compile(self, c: Compiler): - return ".".join(map(c.quote, self.name)) + path = c.database._normalize_table_path(self.name) + return ".".join(map(c.quote, path)) @dataclass From 2c65adf8857b18c7432bbe38e9527eb4f6ab77b2 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Tue, 21 Jun 2022 16:40:33 +0200 Subject: [PATCH 12/32] Some fixes for Presto --- data_diff/database.py | 15 +++++++++++---- tests/test_database_types.py | 5 +++-- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/data_diff/database.py b/data_diff/database.py index 158b82a5..7afdba66 100644 --- a/data_diff/database.py +++ b/data_diff/database.py @@ -438,6 +438,7 @@ class Presto(Database): NUMERIC_TYPES = { "integer": Integer, "real": Float, + "double": Float, } ROUNDS_ON_PREC_LOSS = True @@ -492,7 +493,6 @@ def select_table_schema(self, path: DbPath) -> str: ) def _parse_type(self, type_repr: str, datetime_precision: int = None, numeric_precision: int = None) -> ColType: - """ """ regexps = { r"timestamp\((\d)\)": Timestamp, r"timestamp\((\d)\) with time zone": TimestampTZ, @@ -506,13 +506,20 @@ def _parse_type(self, type_repr: str, datetime_precision: int = None, numeric_pr rounds=False, ) + regexps = { + r"decimal\((\d+),(\d+)\)": Decimal + } + for regexp, cls in regexps.items(): + m = re.match(regexp + "$", type_repr) + if m: + prec, scale = map(int, m.groups()) + return cls(scale) + cls = self.NUMERIC_TYPES.get(type_repr) if cls: if issubclass(cls, Integer): assert numeric_precision is not None return cls(0) - elif issubclass(cls, Decimal): - return cls(6) assert issubclass(cls, Float) return cls( @@ -533,7 +540,7 @@ class MySQL(ThreadedDatabase): "double": Float, "float": Float, "decimal": Decimal, - "int": Decimal, + "int": Integer, } ROUNDS_ON_PREC_LOSS = True diff --git a/tests/test_database_types.py b/tests/test_database_types.py index 84e422ef..cb852ca7 100644 --- a/tests/test_database_types.py +++ b/tests/test_database_types.py @@ -159,9 +159,10 @@ # "datetime(6)", ], "float": [ - "float", + "real", "double", - "numeric", + "decimal(10,2)", + "decimal(30,6)", ], }, } From 79670f84ee54834706d19c9bb3f687cfcb738549 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Tue, 21 Jun 2022 17:51:56 +0200 Subject: [PATCH 13/32] Test dates again; More float tests; Patch for weird Postgres behavior --- data_diff/database.py | 12 ++++---- tests/test_database_types.py | 55 ++++++++++++++++++++++++------------ 2 files changed, 44 insertions(+), 23 deletions(-) diff --git a/data_diff/database.py b/data_diff/database.py index 7afdba66..58ee74bd 100644 --- a/data_diff/database.py +++ b/data_diff/database.py @@ -386,6 
+386,10 @@ def __init__(self, host, port, user, password, *, database, thread_count, **kw):
         super().__init__(thread_count=thread_count)
 
+    def _convert_db_precision_to_digits(self, p: int) -> int:
+        # Subtracting 2 due to weird precision issues in Postgres
+        return super()._convert_db_precision_to_digits(p) - 2
+
     def create_connection(self):
         postgres = import_postgres()
         try:
@@ -506,9 +510,7 @@ def _parse_type(self, type_repr: str, datetime_precision: int = None, numeric_pr
                 rounds=False,
             )
 
-        regexps = {
-            r"decimal\((\d+),(\d+)\)": Decimal
-        }
+        regexps = {r"decimal\((\d+),(\d+)\)": Decimal}
         for regexp, cls in regexps.items():
             m = re.match(regexp + "$", type_repr)
             if m:
@@ -683,8 +685,8 @@ class Redshift(Postgres):
         "real": Float,
     }
 
-    def _convert_db_precision_to_digits(self, p: int) -> int:
-        return super()._convert_db_precision_to_digits(p // 2)
+    # def _convert_db_precision_to_digits(self, p: int) -> int:
+    #     return super()._convert_db_precision_to_digits(p // 2)
 
     def md5_to_int(self, s: str) -> str:
         return f"strtol(substring(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16)::decimal(38)"
diff --git a/tests/test_database_types.py b/tests/test_database_types.py
index cb852ca7..6cab7b1c 100644
--- a/tests/test_database_types.py
+++ b/tests/test_database_types.py
@@ -31,7 +31,24 @@
         "2022-05-01 15:10:03.003030",
         "2022-06-01 15:10:05.009900",
     ],
-    "float": [0.0, 0.1, 0.00188, 0.99999, 0.091919, 0.10, 10.0, 100.98],
+    "float": [
+        0.0,
+        0.1,
+        0.00188,
+        0.99999,
+        0.091919,
+        0.10,
+        10.0,
+        100.98,
+        0.001201923076923077,
+        1 / 3,
+        1 / 5,
+        1 / 109,
+        1 / 109489,
+        1 / 1094893892389,
+        1 / 10948938923893289,
+        3.141592653589793,
+    ],
 }
 
 DATABASE_TYPES = {
@@ -44,13 +61,14 @@
         ],
         # https://www.postgresql.org/docs/current/datatype-datetime.html
         "datetime_no_timezone": [
-            # "timestamp(6) without time zone",
-            # "timestamp(3) without time zone",
-            # "timestamp(0) without time zone",
+            "timestamp(6) without time zone",
+            "timestamp(3) without time zone",
+            "timestamp(0) without time zone",
         ],
         # https://www.postgresql.org/docs/current/datatype-numeric.html
         "float": [
             "real",
+            "float",
             "double precision",
             "numeric(6,3)",
         ],
@@ -66,11 +84,11 @@
         ],
         # https://dev.mysql.com/doc/refman/8.0/en/datetime.html
         "datetime_no_timezone": [
-            # "timestamp(6)",
-            # "timestamp(3)",
-            # "timestamp(0)",
-            # "timestamp",
-            # "datetime(6)",
+            "timestamp(6)",
+            "timestamp(3)",
+            "timestamp(0)",
+            "timestamp",
+            "datetime(6)",
         ],
         # https://dev.mysql.com/doc/refman/8.0/en/numeric-types.html
         "float": [
@@ -104,10 +122,10 @@
         ],
         # https://docs.snowflake.com/en/sql-reference/data-types-datetime.html
         "datetime_no_timezone": [
-            # "timestamp(0)",
-            # "timestamp(3)",
-            # "timestamp(6)",
-            # "timestamp(9)",
+            "timestamp(0)",
+            "timestamp(3)",
+            "timestamp(6)",
+            "timestamp(9)",
         ],
         # https://docs.snowflake.com/en/sql-reference/data-types-numeric.html#decimal-numeric
         "float": [
@@ -152,11 +170,11 @@
             # "bigint", # 8 bytes
         ],
         "datetime_no_timezone": [
-            # "timestamp(6)",
-            # "timestamp(3)",
-            # "timestamp(0)",
-            # "timestamp",
-            # "datetime(6)",
+            "timestamp(6)",
+            "timestamp(3)",
+            "timestamp(0)",
+            "timestamp",
+            "datetime(6)",
         ],
         "float": [
             "real",
@@ -231,6 +249,7 @@
 
             insertion_query = insertion_query[0:-1]
         conn.query(insertion_query, None)
+
     if not isinstance(conn, db.BigQuery):
         conn.query("COMMIT", None)
 

From 154f2de5ddfd2596ecdcf3a843c0318cabb6a271 Mon Sep 17 00:00:00 2001
From: Erez Shinan
Date: Tue, 21 Jun 2022 18:29:56 +0200
Subject: [PATCH 14/32] Allow to run tests in parallel
Before: Ran 271 tests in 794.271s
After: Ran 271 tests in 85.712s
---
 tests/test_api.py            |  1 +
 tests/test_database_types.py |  8 ++++++--
 tests/test_diff_tables.py    | 19 ++++++++++++++++++-
 3 files changed, 25 insertions(+), 3 deletions(-)

diff --git a/tests/test_api.py b/tests/test_api.py
index cd5b9c19..2a532edd 100644
--- a/tests/test_api.py
+++ b/tests/test_api.py
@@ -15,6 +15,7 @@ def setUpClass(cls):
         cls.preql = preql.Preql(TEST_MYSQL_CONN_STRING)
 
     def setUp(self) -> None:
+        self.preql = preql.Preql(TEST_MYSQL_CONN_STRING)
         self.preql(
             r"""
                 table test_api {
diff --git a/tests/test_database_types.py b/tests/test_database_types.py
index 6cab7b1c..1e5602f9 100644
--- a/tests/test_database_types.py
+++ b/tests/test_database_types.py
@@ -274,8 +274,12 @@ def test_types(self, source_db, target_db, source_type, target_type, type_catego
         self.connections = [self.src_conn, self.dst_conn]
         sample_values = TYPE_SAMPLES[type_category]
 
-        src_table_path = src_conn.parse_table_name("src")
-        dst_table_path = dst_conn.parse_table_name("dst")
+        # Limit in MySQL is 64
+        src_table_name = f"src_{self._testMethodName[:60]}"
+        dst_table_name = f"dst_{self._testMethodName[:60]}"
+
+        src_table_path = src_conn.parse_table_name(src_table_name)
+        dst_table_path = dst_conn.parse_table_name(dst_table_name)
 
         src_table = src_conn.quote(".".join(src_table_path))
         dst_table = dst_conn.quote(".".join(dst_table_path))
diff --git a/tests/test_diff_tables.py b/tests/test_diff_tables.py
index a457081d..84bb6b9a 100644
--- a/tests/test_diff_tables.py
+++ b/tests/test_diff_tables.py
@@ -33,8 +33,24 @@ def tearDownClass(cls):
         cls.connection.close()
 
+    # Fallback for test runners that don't support setUpClass/tearDownClass
+    def setUp(self) -> None:
+        if not hasattr(self, 'connection'):
+            self.setUpClass.__func__(self)
+            self.private_connection = True
+
+        return super().setUp()
+
+    def tearDown(self) -> None:
+        if hasattr(self, 'private_connection'):
+            self.tearDownClass.__func__(self)
+
+        return super().tearDown()
+
 
 class TestDates(TestWithConnection):
     def setUp(self):
+        super().setUp()
         self.connection.query("DROP TABLE IF EXISTS a", None)
         self.connection.query("DROP TABLE IF EXISTS b", None)
         self.preql(
@@ -110,6 +126,7 @@ def test_offset(self):
 
 class TestDiffTables(TestWithConnection):
     def setUp(self):
+        super().setUp()
         self.connection.query("DROP TABLE IF EXISTS ratings_test", None)
         self.connection.query("DROP TABLE IF EXISTS ratings_test2", None)
         self.preql.load("./tests/setup.pql")
@@ -221,9 +238,9 @@ def test_diff_sorted_by_key(self):
 
 class TestTableSegment(TestWithConnection):
     def setUp(self) -> None:
+        super().setUp()
         self.table = TableSegment(self.connection, ("ratings_test",), "id", "timestamp")
         self.table2 = TableSegment(self.connection, ("ratings_test2",), "id", "timestamp")
-        return super().setUp()
 
     def test_table_segment(self):
         early = datetime.datetime(2021, 1, 1, 0, 0)

From 6ef0a07e48bfd0d30ed77a9def45d19193696d89 Mon Sep 17 00:00:00 2001
From: Erez Shinan
Date: Tue, 21 Jun 2022 19:02:49 +0200
Subject: [PATCH 15/32] Fix for BigQuery

---
 data_diff/database.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/data_diff/database.py b/data_diff/database.py
index 58ee74bd..0a4d900e 100644
--- a/data_diff/database.py
+++ b/data_diff/database.py
@@ -816,10 +816,12 @@ def normalize_value_by_type(self, value: str, coltype: ColType) -> str:
             timestamp6 = f"FORMAT_TIMESTAMP('%F %H:%M:%E6S', {value})"
             return f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6},
'0')" + elif isinstance(coltype, Integer): + pass elif isinstance(coltype, NumericType): # value = f"cast({value} as decimal)" - return f"format('%.{coltype.precision}f', cast({value} as decimal))" + return f"format('%.{coltype.precision}f', ({value}))" return self.to_string(f"{value}") From f09a86625990ec466c1e9a2ff4a2e32c5881aee5 Mon Sep 17 00:00:00 2001 From: Simon Eskildsen Date: Tue, 21 Jun 2022 13:47:50 -0400 Subject: [PATCH 16/32] cli: output diff as jsonl, stats as json --- data_diff/__main__.py | 18 ++++++++++++++---- data_diff/diff_tables.py | 8 +++++++- 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/data_diff/__main__.py b/data_diff/__main__.py index 5b357153..93ccf231 100644 --- a/data_diff/__main__.py +++ b/data_diff/__main__.py @@ -1,5 +1,6 @@ import sys import time +import json import logging from itertools import islice @@ -146,15 +147,24 @@ def main( unique_diff_count = len({i[0] for _, i in diff}) table1_count = differ.stats.get("table1_count") percent = 100 * unique_diff_count / (table1_count or 1) - print(f"Diff-Total: {len(diff)} changed rows out of {table1_count}") - print(f"Diff-Percent: {percent:.4f}%") plus = len([1 for op, _ in diff if op == "+"]) minus = len([1 for op, _ in diff if op == "-"]) - print(f"Diff-Split: +{plus} -{minus}") + + count = differ.stats["table_count"] + diff = { + "different_rows": len(diff), + "different_percent": percent, + "different_+": plus, + "different_-": minus, + "total": count, + } + + print(json.dumps(diff, indent=2)) else: for op, key in diff_iter: color = COLOR_SCHEME[op] - rich.print(f"[{color}]{op} {key!r}[/{color}]") + jsonl = json.dumps([op, list(key)]) + rich.print(f"[{color}]{jsonl}[/{color}]") sys.stdout.flush() end = time.time() diff --git a/data_diff/diff_tables.py b/data_diff/diff_tables.py index 4087b49d..05d082fa 100644 --- a/data_diff/diff_tables.py +++ b/data_diff/diff_tables.py @@ -381,6 +381,12 @@ def _bisect_and_diff_tables(self, table1, table2, level=0, max_rows=None): if max_rows < self.bisection_threshold: rows1, rows2 = self._threaded_call("get_values", [table1, table2]) diff = list(diff_sets(rows1, rows2)) + + # This happens when the initial bisection threshold is larger than + # the table itself. + if level == 0 and not self.stats.get("table_count", False): + self.stats["table_count"] = self.stats.get("table_count", 0) + max(len(rows1), len(rows2)) + logger.info(". 
" * level + f"Diff found {len(diff)} different rows.") self.stats["rows_downloaded"] = self.stats.get("rows_downloaded", 0) + max(len(rows1), len(rows2)) yield from diff @@ -420,7 +426,7 @@ def _diff_tables(self, table1, table2, level=0, segment_index=None, segment_coun return if level == 1: - self.stats["table1_count"] = self.stats.get("table1_count", 0) + count1 + self.stats["table_count"] = self.stats.get("table_count", 0) + max(count1, count2) if checksum1 != checksum2: yield from self._bisect_and_diff_tables(table1, table2, level=level, max_rows=max(count1, count2)) From 94d1419091ee4c6266a4ea66713ff8ee280d9376 Mon Sep 17 00:00:00 2001 From: Simon Eskildsen Date: Tue, 21 Jun 2022 14:36:24 -0400 Subject: [PATCH 17/32] cli: add --json for stats, table1 + table2 counts --- data_diff/__main__.py | 29 +++++++++++++++++------------ data_diff/diff_tables.py | 8 +++++--- 2 files changed, 22 insertions(+), 15 deletions(-) diff --git a/data_diff/__main__.py b/data_diff/__main__.py index 93ccf231..fc32899e 100644 --- a/data_diff/__main__.py +++ b/data_diff/__main__.py @@ -51,6 +51,7 @@ @click.option("--max-age", default=None, help="Considers only rows younger than specified. See --min-age.") @click.option("-s", "--stats", is_flag=True, help="Print stats instead of a detailed diff") @click.option("-d", "--debug", is_flag=True, help="Print debug info") +@click.option("--json", 'json_output', is_flag=True, help="Print JSON output for --stats") @click.option("-v", "--verbose", is_flag=True, help="Print extra info") @click.option("-i", "--interactive", is_flag=True, help="Confirm queries, implies --debug") @click.option("--keep-column-case", is_flag=True, help="Don't use the schema to fix the case of given column names.") @@ -81,6 +82,7 @@ def main( interactive, threads, keep_column_case, + json_output, ): if limit and stats: print("Error: cannot specify a limit when using the -s/--stats switch") @@ -145,21 +147,24 @@ def main( if stats: diff = list(diff_iter) unique_diff_count = len({i[0] for _, i in diff}) - table1_count = differ.stats.get("table1_count") - percent = 100 * unique_diff_count / (table1_count or 1) + max_table_count = max(differ.stats["table1_count"], differ.stats["table2_count"]) + percent = 100 * unique_diff_count / (max_table_count or 1) plus = len([1 for op, _ in diff if op == "+"]) minus = len([1 for op, _ in diff if op == "-"]) - count = differ.stats["table_count"] - diff = { - "different_rows": len(diff), - "different_percent": percent, - "different_+": plus, - "different_-": minus, - "total": count, - } - - print(json.dumps(diff, indent=2)) + if json_output: + json_output = { + "different_rows": len(diff), + "different_percent": percent, + "different_+": plus, + "different_-": minus, + "total": max_table_count, + } + print(json.dumps(json_output, indent=2)) + else: + print(f"Diff-Total: {len(diff)} changed rows out of {max_table_count}") + print(f"Diff-Percent: {percent:.14f}%") + print(f"Diff-Split: +{plus} -{minus}") else: for op, key in diff_iter: color = COLOR_SCHEME[op] diff --git a/data_diff/diff_tables.py b/data_diff/diff_tables.py index 05d082fa..48099c8f 100644 --- a/data_diff/diff_tables.py +++ b/data_diff/diff_tables.py @@ -384,8 +384,9 @@ def _bisect_and_diff_tables(self, table1, table2, level=0, max_rows=None): # This happens when the initial bisection threshold is larger than # the table itself. 
-            if level == 0 and not self.stats.get("table_count", False):
-                self.stats["table_count"] = self.stats.get("table_count", 0) + max(len(rows1), len(rows2))
+            if level == 0 and not self.stats.get("table1_count", False):
+                self.stats["table1_count"] = self.stats.get("table1_count", 0) + len(rows1)
+                self.stats["table2_count"] = self.stats.get("table2_count", 0) + len(rows2)
 
             logger.info(". " * level + f"Diff found {len(diff)} different rows.")
             self.stats["rows_downloaded"] = self.stats.get("rows_downloaded", 0) + max(len(rows1), len(rows2))
             yield from diff
@@ -426,7 +427,8 @@ def _diff_tables(self, table1, table2, level=0, segment_index=None, segment_coun
             return
 
         if level == 1:
-            self.stats["table_count"] = self.stats.get("table_count", 0) + max(count1, count2)
+            self.stats["table1_count"] = self.stats.get("table1_count", 0) + count1
+            self.stats["table2_count"] = self.stats.get("table2_count", 0) + count2
 
         if checksum1 != checksum2:
             yield from self._bisect_and_diff_tables(table1, table2, level=level, max_rows=max(count1, count2))

From b3e1d8a8a73f348c970629a6b3c614f907bd8aec Mon Sep 17 00:00:00 2001
From: Erez Shinan
Date: Tue, 21 Jun 2022 21:05:48 +0200
Subject: [PATCH 18/32] Fix for Redshift

---
 data_diff/database.py | 26 ++++++++++++++++++--------
 1 file changed, 18 insertions(+), 8 deletions(-)

diff --git a/data_diff/database.py b/data_diff/database.py
index 0a4d900e..4f2d93f4 100644
--- a/data_diff/database.py
+++ b/data_diff/database.py
@@ -209,8 +209,8 @@ class Database(AbstractDatabase):
 
     Instantiated using :meth:`~data_diff.connect_to_uri`
     """
 
-    DATETIME_TYPES = NotImplemented
-    default_schema = NotImplemented
+    DATETIME_TYPES = {}
+    default_schema = None
 
     def query(self, sql_ast: SqlOrStr, res_type: type):
         "Query the given SQL code/AST, and attempt to convert the result to type 'res_type'"
@@ -306,13 +306,15 @@ def query_table_schema(self, path: DbPath) -> Dict[str, ColType]:
 
     def _normalize_table_path(self, path: DbPath) -> DbPath:
         if len(path) == 1:
-            return self.default_schema, path[0]
-        elif len(path) == 2:
-            return path
+            if self.default_schema:
+                return self.default_schema, path[0]
+        elif len(path) != 2:
+            raise ValueError(
+                f"{self.__class__.__name__}: Bad table path for {self}: '{'.'.join(path)}'. Expected form: schema.table"
+            )
+
+        return path
 
-        raise ValueError(
-            f"{self.__class__.__name__}: Bad table path for {self}: '{'.'.join(path)}'. Expected form: schema.table"
-        )
 
     def parse_table_name(self, name: str) -> DbPath:
         return parse_table_name(name)
@@ -713,6 +715,14 @@ def normalize_value_by_type(self, value: str, coltype: ColType) -> str:
 
         return self.to_string(f"{value}")
 
+    def select_table_schema(self, path: DbPath) -> str:
+        schema, table = self._normalize_table_path(path)
+
+        return (
+            "SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale FROM information_schema.columns "
+            f"WHERE table_name = '{table.lower()}' AND table_schema = '{schema.lower()}'"
+        )
+
 
 class MsSQL(ThreadedDatabase):
     "AKA sql-server"

From 152ebbe6a0c8fb55d042bba2bd77631192bab17d Mon Sep 17 00:00:00 2001
From: Simon Eskildsen
Date: Tue, 21 Jun 2022 13:22:21 -0400
Subject: [PATCH 19/32] readme: add performance graph

---
 README.md | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/README.md b/README.md
index 8e92f3d7..39b396b8 100644
--- a/README.md
+++ b/README.md
@@ -24,6 +24,13 @@ there are few/no changes, but is able to output each differing row!
 By pushing the compute into the databases, it's _much_ faster than querying for
 and comparing every row.
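To give intuition for the claim above, here is a minimal sketch of the checksum-bisection loop (illustrative only; `checksum`, `download`, and `compare_rows` are hypothetical stand-ins for the in-database aggregate query and the row download that `diff_tables.py` actually performs):

```python
def diff_segment(db1, db2, lo, hi, bisection_factor=32, bisection_threshold=10_000):
    # One cheap aggregate per side: only a checksum crosses the network.
    if checksum(db1, lo, hi) == checksum(db2, lo, hi):
        return  # segment is identical; skip it entirely
    if hi - lo <= bisection_threshold:
        # Small enough to download both sides and compare row by row.
        yield from compare_rows(download(db1, lo, hi), download(db2, lo, hi))
        return
    # Otherwise split the key range and recurse only into unequal segments.
    step = max((hi - lo) // bisection_factor, 1)
    for start in range(lo, hi, step):
        yield from diff_segment(db1, db2, start, min(start + step, hi),
                                bisection_factor, bisection_threshold)
```

The parameter names mirror the `bisection_factor` and `bisection_threshold` options that `TableDiffer` exposes in the patches above; the default values here are placeholders.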
+![Performance for 100M rows](https://user-images.githubusercontent.com/97400/174860361-35158d2b-0cad-4089-be66-8bf467058387.png)
+
+**†:** The implementation for downloading all rows that `data-diff` and
+`count(*)` are compared to is not optimal. It is a single Python multi-threaded
+process. The performance is fairly driver-specific, e.g. Postgres' driver performs 10x
+better than MySQL's.
+
 ## Table of Contents
 
 - [Common use-cases](#common-use-cases)

From 469f1420649d629aa401fa3076b19c90dd1534dd Mon Sep 17 00:00:00 2001
From: Erez Shinan
Date: Tue, 21 Jun 2022 11:03:39 +0200
Subject: [PATCH 20/32] Refactor into normalize_timestamp() normalize_number()

---
 data_diff/database.py | 228 +++++++++++++++++++++---------------------
 1 file changed, 114 insertions(+), 114 deletions(-)

diff --git a/data_diff/database.py b/data_diff/database.py
index 4f2d93f4..c4d5edf4 100644
--- a/data_diff/database.py
+++ b/data_diff/database.py
@@ -131,6 +131,10 @@ class UnknownColType(ColType):
     text: str
 
+    def __post_init__(self):
+        logger.warn(f"Column of type '{self.text}' has no compatibility handling. "
+            "If encoding/formatting differs between databases, it may result in false positives.")
+
 
 class AbstractDatabase(ABC):
     @abstractmethod
@@ -173,16 +177,24 @@ def close(self):
         "Close connection(s) to the database instance. Querying will stop functioning."
         ...
 
+
     @abstractmethod
-    def normalize_value_by_type(value: str, coltype: ColType) -> str:
-        """Creates an SQL expression, that converts 'value' to a normalized representation.
+    def normalize_timestamp(self, value: str, coltype: ColType) -> str:
+        """Creates an SQL expression, that converts 'value' to a normalized timestamp.
 
-        The returned expression must accept any SQL value, and return a string.
+        The returned expression must accept any SQL datetime/timestamp, and return a string.
+
+        Date format: "YYYY-MM-DD HH:mm:SS.FFFFFF"
+
+        Precision of dates should be rounded up/down according to coltype.rounds
+        """
+        ...
 
-        - Dates are expected in the format:
-            "YYYY-MM-DD HH:mm:SS.FFFFFF"
+    @abstractmethod
+    def normalize_number(self, value: str, coltype: ColType) -> str:
+        """Creates an SQL expression, that converts 'value' to a normalized number.
 
-            Rounded up/down according to coltype.rounds
+        The returned expression must accept any SQL int/numeric/float, and return a string.
 
         - Floats/Decimals are expected in the format "I.P"
 
@@ -191,14 +203,31 @@ def normalize_value_by_type(value: str, coltype: ColType) -> str:
             and must be at least one digit (0).
             P is the fractional digits, the amount of which is specified with
             coltype.precision. Trailing zeroes may be necessary.
+            If P is 0, the dot is omitted.
 
             Note: This precision is different than the one used by databases. For decimals,
-            it's the same as "numeric_scale", and for floats, who use binary precision,
-            it can be calculated as log10(2**p)
+            it's the same as ``numeric_scale``, and for floats, who use binary precision,
+            it can be calculated as ``log10(2**numeric_precision)``.
+        """
+        ...
+
+    def normalize_value_by_type(self, value: str, coltype: ColType) -> str:
+        """Creates an SQL expression, that converts 'value' to a normalized representation.
+
+        The returned expression must accept any SQL value, and return a string.
+
+        The default implementation dispatches to a method according to ``coltype``:
+            TemporalType    ->  normalize_timestamp()
+            NumericType     ->  normalize_number()
+            -else-          ->  to_string()
         """
-        ...
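The float rule in the docstring above, precision digits = `log10(2**numeric_precision)`, is easy to sanity-check in plain Python. This is only a back-of-the-envelope verification; `db_precision_to_digits` is a made-up name, not the library's helper:

```python
import math

def db_precision_to_digits(numeric_precision: int) -> int:
    # A float with p bits of mantissa carries about log10(2**p) decimal digits.
    return math.floor(math.log(2 ** numeric_precision, 10))

assert db_precision_to_digits(53) == 15  # IEEE-754 double precision
assert db_precision_to_digits(24) == 7   # IEEE-754 single precision

# A normalized number in the "I.P" format with precision 3 keeps
# exactly three fractional digits, padding with trailing zeroes:
value, precision = 100.98, 3
assert f"{value:.{precision}f}" == "100.980"
```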
+ if isinstance(coltype, TemporalType): + return self.normalize_timestamp(value, coltype) + elif isinstance(coltype, NumericType): + return self.normalize_number(value, coltype) + return self.to_string(f"{value}") class Database(AbstractDatabase): @@ -410,27 +439,16 @@ def md5_to_int(self, s: str) -> str: def to_string(self, s: str): return f"{s}::varchar" - def normalize_value_by_type(self, value: str, coltype: ColType) -> str: - if isinstance(coltype, TemporalType): - # if coltype.precision == 0: - # return f"to_char({value}::timestamp(0), 'YYYY-mm-dd HH24:MI:SS')" - # if coltype.precision == 3: - # return f"to_char({value}, 'YYYY-mm-dd HH24:MI:SS.US')" - # elif coltype.precision == 6: - # return f"to_char({value}::timestamp({coltype.precision}), 'YYYY-mm-dd HH24:MI:SS.US')" - # else: - # # Postgres/Redshift doesn't support arbitrary precision - # raise TypeError(f"Bad precision for {type(self).__name__}: {coltype})") - if coltype.rounds: - return f"to_char({value}::timestamp({coltype.precision}), 'YYYY-mm-dd HH24:MI:SS.US')" - else: - timestamp6 = f"to_char({value}::timestamp(6), 'YYYY-mm-dd HH24:MI:SS.US')" - return f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')" - elif isinstance(coltype, NumericType): - value = f"{value}::decimal(38, {coltype.precision})" + def normalize_timestamp(self, value: str, coltype: ColType) -> str: + if coltype.rounds: + return f"to_char({value}::timestamp({coltype.precision}), 'YYYY-mm-dd HH24:MI:SS.US')" - return self.to_string(f"{value}") + timestamp6 = f"to_char({value}::timestamp(6), 'YYYY-mm-dd HH24:MI:SS.US')" + return f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')" + + def normalize_number(self, value: str, coltype: ColType) -> str: + return self.to_string(f"{value}::decimal(38, {coltype.precision})") class Presto(Database): @@ -470,25 +488,19 @@ def _query(self, sql_code: str) -> list: def close(self): self._conn.close() - def normalize_value_by_type(self, value: str, coltype: ColType) -> str: - if isinstance(coltype, TemporalType): - if coltype.rounds: - if coltype.precision > 3: - pass - s = f"date_format(cast({value} as timestamp(6)), '%Y-%m-%d %H:%i:%S.%f')" - else: - s = f"date_format(cast({value} as timestamp(6)), '%Y-%m-%d %H:%i:%S.%f')" - # datetime = f"date_format(cast({value} as timestamp(6), '%Y-%m-%d %H:%i:%S.%f'))" - # datetime = self.to_string(f"cast({value} as datetime(6))") - - return ( - f"RPAD(RPAD({s}, {TIMESTAMP_PRECISION_POS+coltype.precision}, '.'), {TIMESTAMP_PRECISION_POS+6}, '0')" - ) + def normalize_timestamp(self, value: str, coltype: ColType) -> str: + # TODO + if coltype.rounds: + s = f"date_format(cast({value} as timestamp(6)), '%Y-%m-%d %H:%i:%S.%f')" + else: + s = f"date_format(cast({value} as timestamp(6)), '%Y-%m-%d %H:%i:%S.%f')" - elif isinstance(coltype, NumericType): - value = f"cast({value} as decimal(38,{coltype.precision}))" + return ( + f"RPAD(RPAD({s}, {TIMESTAMP_PRECISION_POS+coltype.precision}, '.'), {TIMESTAMP_PRECISION_POS+6}, '0')" + ) - return self.to_string(value) + def normalize_number(self, value: str, coltype: ColType) -> str: + return self.to_string(f"cast({value} as decimal(38,{coltype.precision}))") def select_table_schema(self, path: DbPath) -> str: schema, table = self._normalize_table_path(path) @@ -577,18 +589,16 @@ def md5_to_int(self, s: str) -> str: def to_string(self, s: str): return f"cast({s} as char)" - def normalize_value_by_type(self, value: str, coltype: ColType) 
-> str: - if isinstance(coltype, TemporalType): - if coltype.rounds: - return self.to_string(f"cast( cast({value} as datetime({coltype.precision})) as datetime(6))") - else: - s = self.to_string(f"cast({value} as datetime(6))") - return f"RPAD(RPAD({s}, {TIMESTAMP_PRECISION_POS+coltype.precision}, '.'), {TIMESTAMP_PRECISION_POS+6}, '0')" + def normalize_timestamp(self, value: str, coltype: ColType) -> str: + if coltype.rounds: + return self.to_string(f"cast( cast({value} as datetime({coltype.precision})) as datetime(6))") - elif isinstance(coltype, NumericType): - value = f"cast({value} as decimal(38,{coltype.precision}))" + s = self.to_string(f"cast({value} as datetime(6))") + return f"RPAD(RPAD({s}, {TIMESTAMP_PRECISION_POS+coltype.precision}, '.'), {TIMESTAMP_PRECISION_POS+6}, '0')" + + def normalize_number(self, value: str, coltype: ColType) -> str: + return self.to_string(f"cast({value} as decimal(38, {coltype.precision}))") - return self.to_string(f"{value}") class Oracle(ThreadedDatabase): @@ -633,16 +643,15 @@ def select_table_schema(self, path: DbPath) -> str: f" FROM USER_TAB_COLUMNS WHERE table_name = '{table.upper()}'" ) - def normalize_value_by_type(self, value: str, coltype: ColType) -> str: - if isinstance(coltype, TemporalType): - return f"to_char(cast({value} as timestamp({coltype.precision})), 'YYYY-MM-DD HH24:MI:SS.FF6')" - elif isinstance(coltype, NumericType): - # FM999.9990 - format_str = "FM" + "9" * (38 - coltype.precision) - if coltype.precision: - format_str += "0." + "9" * (coltype.precision - 1) + "0" - return f"to_char({value}, '{format_str}')" - return self.to_string(f"{value}") + def normalize_timestamp(self, value: str, coltype: ColType) -> str: + return f"to_char(cast({value} as timestamp({coltype.precision})), 'YYYY-MM-DD HH24:MI:SS.FF6')" + + def normalize_number(self, value: str, coltype: ColType) -> str: + # FM999.9990 + format_str = "FM" + "9" * (38 - coltype.precision) + if coltype.precision: + format_str += "0." + "9" * (coltype.precision - 1) + "0" + return f"to_char({value}, '{format_str}')" def _parse_type( self, type_repr: str, datetime_precision: int = None, numeric_precision: int = None, numeric_scale: int = None @@ -693,27 +702,25 @@ class Redshift(Postgres): def md5_to_int(self, s: str) -> str: return f"strtol(substring(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16)::decimal(38)" - def normalize_value_by_type(self, value: str, coltype: ColType) -> str: - if isinstance(coltype, TemporalType): - if coltype.rounds: - timestamp = f"{value}::timestamp(6)" - # Get seconds since epoch. Redshift doesn't support milli- or micro-seconds. - secs = f"timestamp 'epoch' + round(extract(epoch from {timestamp})::decimal(38)" - # Get the milliseconds from timestamp. - ms = f"extract(ms from {timestamp})" - # Get the microseconds from timestamp, without the milliseconds! - us = f"extract(us from {timestamp})" - # epoch = Total time since epoch in microseconds. 
- epoch = f"{secs}*1000000 + {ms}*1000 + {us}" - timestamp6 = f"to_char({epoch}, -6+{coltype.precision}) * interval '0.000001 seconds', 'YYYY-mm-dd HH24:MI:SS.US')" - else: - timestamp6 = f"to_char({value}::timestamp(6), 'YYYY-mm-dd HH24:MI:SS.US')" - return f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')" - - elif isinstance(coltype, NumericType): - value = f"{value}::decimal(38,{coltype.precision})" + def normalize_timestamp(self, value: str, coltype: ColType) -> str: + if coltype.rounds: + timestamp = f"{value}::timestamp(6)" + # Get seconds since epoch. Redshift doesn't support milli- or micro-seconds. + secs = f"timestamp 'epoch' + round(extract(epoch from {timestamp})::decimal(38)" + # Get the milliseconds from timestamp. + ms = f"extract(ms from {timestamp})" + # Get the microseconds from timestamp, without the milliseconds! + us = f"extract(us from {timestamp})" + # epoch = Total time since epoch in microseconds. + epoch = f"{secs}*1000000 + {ms}*1000 + {us}" + timestamp6 = f"to_char({epoch}, -6+{coltype.precision}) * interval '0.000001 seconds', 'YYYY-mm-dd HH24:MI:SS.US')" + else: + timestamp6 = f"to_char({value}::timestamp(6), 'YYYY-mm-dd HH24:MI:SS.US')" + return f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')" + + def normalize_number(self, value: str, coltype: ColType) -> str: + return self.to_string(f"{value}::decimal(38,{coltype.precision})") - return self.to_string(f"{value}") def select_table_schema(self, path: DbPath) -> str: schema, table = self._normalize_table_path(path) @@ -813,27 +820,23 @@ def select_table_schema(self, path: DbPath) -> str: f"WHERE table_name = '{table}' AND table_schema = '{schema}'" ) - def normalize_value_by_type(self, value: str, coltype: ColType) -> str: - if isinstance(coltype, TemporalType): - if coltype.rounds: - timestamp = f"timestamp_micros(cast(round(unix_micros(cast({value} as timestamp))/1000000, {coltype.precision})*1000000 as int))" - return f"FORMAT_TIMESTAMP('%F %H:%M:%E6S', {timestamp})" - else: - if coltype.precision == 0: - return f"FORMAT_TIMESTAMP('%F %H:%M:%S.000000, {value})" - elif coltype.precision == 6: - return f"FORMAT_TIMESTAMP('%F %H:%M:%E6S', {value})" + def normalize_timestamp(self, value: str, coltype: ColType) -> str: + if coltype.rounds: + timestamp = f"timestamp_micros(cast(round(unix_micros(cast({value} as timestamp))/1000000, {coltype.precision})*1000000 as int))" + return f"FORMAT_TIMESTAMP('%F %H:%M:%E6S', {timestamp})" - timestamp6 = f"FORMAT_TIMESTAMP('%F %H:%M:%E6S', {value})" - return f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')" - elif isinstance(coltype, Integer): - pass + if coltype.precision == 0: + return f"FORMAT_TIMESTAMP('%F %H:%M:%S.000000, {value})" + elif coltype.precision == 6: + return f"FORMAT_TIMESTAMP('%F %H:%M:%E6S', {value})" - elif isinstance(coltype, NumericType): - # value = f"cast({value} as decimal)" - return f"format('%.{coltype.precision}f', ({value}))" + timestamp6 = f"FORMAT_TIMESTAMP('%F %H:%M:%E6S', {value})" + return f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')" - return self.to_string(f"{value}") + def normalize_number(self, value: str, coltype: ColType) -> str: + if isinstance(coltype, Integer): + return self.to_string(value) + return f"format('%.{coltype.precision}f', {value})" def parse_table_name(self, name: str) -> DbPath: path = 
parse_table_name(name) @@ -907,19 +910,16 @@ def select_table_schema(self, path: DbPath) -> str: schema, table = self._normalize_table_path(path) return super().select_table_schema((schema, table)) - def normalize_value_by_type(self, value: str, coltype: ColType) -> str: - if isinstance(coltype, TemporalType): - if coltype.rounds: - timestamp = f"to_timestamp(round(date_part(epoch_nanosecond, {value}::timestamp(9))/1000000000, {coltype.precision}))" - else: - timestamp = f"cast({value} as timestamp({coltype.precision}))" - - return f"to_char({timestamp}, 'YYYY-MM-DD HH24:MI:SS.FF6')" + def normalize_timestamp(self, value: str, coltype: ColType) -> str: + if coltype.rounds: + timestamp = f"to_timestamp(round(date_part(epoch_nanosecond, {value}::timestamp(9))/1000000000, {coltype.precision}))" + else: + timestamp = f"cast({value} as timestamp({coltype.precision}))" - elif isinstance(coltype, NumericType): - value = f"cast({value} as decimal(38, {coltype.precision}))" + return f"to_char({timestamp}, 'YYYY-MM-DD HH24:MI:SS.FF6')" - return self.to_string(f"{value}") + def normalize_number(self, value: str, coltype: ColType) -> str: + return self.to_string(f"cast({value} as decimal(38, {coltype.precision}))") @dataclass From 4e4958c64fec2d40ca29742b283abc1306133dee Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Tue, 21 Jun 2022 11:56:32 +0200 Subject: [PATCH 21/32] Better errors for missing imports --- data_diff/database.py | 29 ++++++++++++++++++++++++++--- pyproject.toml | 1 + 2 files changed, 27 insertions(+), 3 deletions(-) diff --git a/data_diff/database.py b/data_diff/database.py index c4d5edf4..7a3829e5 100644 --- a/data_diff/database.py +++ b/data_diff/database.py @@ -1,5 +1,5 @@ import math -from functools import lru_cache +from functools import lru_cache, wraps from itertools import zip_longest import re from abc import ABC, abstractmethod @@ -23,6 +23,21 @@ def parse_table_name(t): return tuple(t.split(".")) +def import_helper(s: str): + def dec(f): + @wraps(f) + def _inner(): + try: + return f() + except ModuleNotFoundError as e: + raise ModuleNotFoundError(f"{e}\n\nYou can install it using 'pip install data-diff[{s}]'.") + + return _inner + + return dec + + +@import_helper("pgsql") def import_postgres(): import psycopg2 import psycopg2.extras @@ -31,12 +46,14 @@ def import_postgres(): return psycopg2 +@import_helper("mysql") def import_mysql(): import mysql.connector return mysql.connector +@import_helper("snowflake") def import_snowflake(): import snowflake.connector @@ -55,6 +72,7 @@ def import_oracle(): return cx_Oracle +@import_helper("presto") def import_presto(): import prestodb @@ -344,7 +362,6 @@ def _normalize_table_path(self, path: DbPath) -> DbPath: return path - def parse_table_name(self, name: str) -> DbPath: return parse_table_name(name) @@ -356,12 +373,16 @@ class ThreadedDatabase(Database): """ def __init__(self, thread_count=1): + self._init_error = None self._queue = ThreadPoolExecutor(thread_count, initializer=self.set_conn) self.thread_local = threading.local() def set_conn(self): assert not hasattr(self.thread_local, "conn") - self.thread_local.conn = self.create_connection() + try: + self.thread_local.conn = self.create_connection() + except ModuleNotFoundError as e: + self._init_error = e def _query(self, sql_code: str): r = self._queue.submit(self._query_in_worker, sql_code) @@ -369,6 +390,8 @@ def _query(self, sql_code: str): def _query_in_worker(self, sql_code: str): "This method runs in a worker thread" + if self._init_error: + raise self._init_error 
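The `_init_error` handling just added deserves a note: a `ThreadPoolExecutor` initializer runs inside the worker thread, where an uncaught exception would silently break the pool. Storing the error and re-raising it on the first query surfaces it to the caller instead. A toy reproduction of the pattern, with a simulated failed import:

```python
from concurrent.futures import ThreadPoolExecutor

class LazyInitWorker:
    def __init__(self):
        self._init_error = None
        self._pool = ThreadPoolExecutor(1, initializer=self._set_conn)

    def _set_conn(self):
        try:
            # Simulated import failure; a real worker would connect here.
            raise ModuleNotFoundError("No module named 'fake_driver'")
        except ModuleNotFoundError as e:
            self._init_error = e  # swallow now, re-raise at query time

    def query(self):
        def work():
            if self._init_error:
                raise self._init_error
            return ["row"]
        return self._pool.submit(work).result()  # result() re-raises in the caller

try:
    LazyInitWorker().query()
except ModuleNotFoundError as e:
    print("deferred:", e)
```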
        return _query_conn(self.thread_local.conn, sql_code)
 
     def close(self):
diff --git a/pyproject.toml b/pyproject.toml
index a1d60be4..685143f6 100755
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -50,6 +50,7 @@ mysql = ["mysql-connector-python"]
 pgsql = ["psycopg2"]
 snowflake = ["snowflake-connector-python"]
 presto = ["presto-python-client"]
+oracle = ["cx_Oracle"]
 
 [build-system]
 requires = ["poetry-core>=1.0.0"]

From 4b670fa60d5e10f69d3cef06c66e00bfe42b2c1d Mon Sep 17 00:00:00 2001
From: Erez Shinan
Date: Tue, 21 Jun 2022 12:03:38 +0200
Subject: [PATCH 22/32] Fix for BigQuery

---
 data_diff/database.py | 16 +++++++++++++---
 1 file changed, 13 insertions(+), 3 deletions(-)

diff --git a/data_diff/database.py b/data_diff/database.py
index 7a3829e5..fd033e6e 100644
--- a/data_diff/database.py
+++ b/data_diff/database.py
@@ -23,14 +23,17 @@ def parse_table_name(t):
     return tuple(t.split("."))
 
 
-def import_helper(s: str):
+def import_helper(package: str = None, text=""):
     def dec(f):
         @wraps(f)
         def _inner():
             try:
                 return f()
             except ModuleNotFoundError as e:
-                raise ModuleNotFoundError(f"{e}\n\nYou can install it using 'pip install data-diff[{s}]'.")
+                s = text
+                if package:
+                    s += f"You can install it using 'pip install data-diff[{package}]'."
+                raise ModuleNotFoundError(f"{e}\n\n{s}\n")
 
         return _inner
 
@@ -79,6 +82,13 @@ def import_presto():
     return prestodb
 
 
+@import_helper(text="Please install BigQuery and configure your google-cloud access.")
+def import_bigquery():
+    from google.cloud import bigquery
+
+    return bigquery
+
+
 class ConnectError(Exception):
     pass
 
@@ -797,7 +807,7 @@ class BigQuery(Database):
     ROUNDS_ON_PREC_LOSS = False  # Technically BigQuery doesn't allow implicit rounding or truncation
 
     def __init__(self, project, *, dataset, **kw):
-        from google.cloud import bigquery
+        bigquery = import_bigquery()
 
         self._client = bigquery.Client(project, **kw)
         self.project = project

From d958f6995d3104fafcc1d62df1471bd6f5871b97 Mon Sep 17 00:00:00 2001
From: Erez Shinan
Date: Wed, 22 Jun 2022 10:00:49 +0200
Subject: [PATCH 23/32] Use proper name: pgsql,postgres -> PostgreSQL

SQLAlchemy uses "postgresql://" so now we do too.
---
 README.md                      | 36 +++++++++++++++++-----------------
 data_diff/__init__.py          |  2 +-
 data_diff/database.py          | 22 ++++++++++-----------
 docs/index.rst                 |  8 ++++----
 poetry.lock                    |  4 ++--
 pyproject.toml                 |  2 +-
 tests/common.py                |  4 ++--
 tests/test_database.py         |  2 +-
 tests/test_database_types.py   |  2 +-
 tests/test_normalize_fields.py |  2 +-
 10 files changed, 42 insertions(+), 42 deletions(-)

diff --git a/README.md b/README.md
index 39b396b8..216bb638 100644
--- a/README.md
+++ b/README.md
@@ -7,7 +7,7 @@ also find us in `#tools-data-diff` in the [Locally Optimistic Slack.][slack]**
 
 **data-diff** is a command-line tool and Python library to efficiently
 diff rows across two different databases.
 
-* ⇄  Verifies across [many different databases][dbs] (e.g. Postgres -> Snowflake)
+* ⇄  Verifies across [many different databases][dbs] (e.g. PostgreSQL -> Snowflake)
 * 🔍 Outputs [diff of rows](#example-command-and-output) in detail
 * 🚨 Simple CLI/API to create monitoring and alerts
 * 🔥 Verify 25M+ rows in <10s, and 1B+ rows in ~5min.
@@ -28,7 +28,7 @@ comparing every row.
 
 **†:** The implementation for downloading all rows that `data-diff` and
 `count(*)` are compared to is not optimal. It is a single Python multi-threaded
-process. The performance is fairly driver-specific, e.g. Postgres' driver performs 10x
+process. The performance is fairly driver-specific, e.g. PostgreSQL's driver performs 10x
 better than MySQL's.
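As an aside on the `import_helper` decorator finalized in patch 22, here is a self-contained run showing the error message it assembles. The decorated function and module name below are deliberate placeholders:

```python
from functools import wraps

def import_helper(package=None, text=""):
    def dec(f):
        @wraps(f)
        def _inner():
            try:
                return f()
            except ModuleNotFoundError as e:
                s = text
                if package:
                    s += f"You can install it using 'pip install data-diff[{package}]'."
                raise ModuleNotFoundError(f"{e}\n\n{s}\n")
        return _inner
    return dec

@import_helper("postgresql")
def import_driver():
    import driver_that_is_not_installed  # placeholder for the real driver import
    return driver_that_is_not_installed

try:
    import_driver()
except ModuleNotFoundError as e:
    print(e)  # original error, followed by the pip-extras installation hint
```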
## Table of Contents @@ -45,7 +45,7 @@ better than MySQL. ## Common use-cases * **Verify data migrations.** Verify that all data was copied when doing a - critical data migration. For example, migrating from Heroku Postgres to Amazon RDS. + critical data migration. For example, migrating from Heroku PostgreSQL to Amazon RDS. * **Verifying data pipelines.** Moving data from a relational database to a warehouse/data lake with Fivetran, Airbyte, Debezium, or some other pipeline. * **Alerting and maintaining data integrity SLOs.** You can create and monitor @@ -63,13 +63,13 @@ better than MySQL. ## Example Command and Output -Below we run a comparison with the CLI for 25M rows in Postgres where the +Below we run a comparison with the CLI for 25M rows in PostgreSQL where the right-hand table is missing single row with `id=12500048`: ``` $ data-diff \ - postgres://postgres:password@localhost/postgres rating \ - postgres://postgres:password@localhost/postgres rating_del1 \ + postgresql://user:password@localhost/database rating \ + postgresql://user:password@localhost/database rating_del1 \ --bisection-threshold 100000 \ # for readability, try default first --bisection-factor 6 \ # for readability, try default first --update-column timestamp \ @@ -111,7 +111,7 @@ $ data-diff \ | Database | Connection string | Status | |---------------|-----------------------------------------------------------------------------------------|--------| -| Postgres | `postgres://user:password@hostname:5432/database` | 💚 | +| PostgreSQL | `postgresql://user:password@hostname:5432/database` | 💚 | | MySQL | `mysql://user:password@hostname:5432/database` | 💚 | | Snowflake | `snowflake://user:password@account/database/SCHEMA?warehouse=WAREHOUSE&role=role` | 💚 | | Oracle | `oracle://username:password@hostname/database` | 💛 | @@ -140,9 +140,9 @@ Requires Python 3.7+ with pip. ```pip install data-diff``` -or when you need extras like mysql and postgres +or when you need extras like mysql and postgresql -```pip install "data-diff[mysql,pgsql]"``` +```pip install "data-diff[mysql,postgresql]"``` # How to use @@ -185,7 +185,7 @@ logging.basicConfig(level=logging.INFO) from data_diff import connect_to_table, diff_tables -table1 = connect_to_table("postgres:///", "table_name", "id") +table1 = connect_to_table("postgresql:///", "table_name", "id") table2 = connect_to_table("mysql:///", "table_name", "id") for different_row in diff_tables(table1, table2): @@ -201,11 +201,11 @@ In this section we'll be doing a walk-through of exactly how **data-diff** works, and how to tune `--bisection-factor` and `--bisection-threshold`. Let's consider a scenario with an `orders` table with 1M rows. Fivetran is -replicating it contionously from Postgres to Snowflake: +replicating it contionously from PostgreSQL to Snowflake: ``` ┌─────────────┐ ┌─────────────┐ -│ Postgres │ │ Snowflake │ +│ PostgreSQL │ │ Snowflake │ ├─────────────┤ ├─────────────┤ │ │ │ │ │ │ │ │ @@ -233,7 +233,7 @@ of the table. 
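In key terms, the segmentation described next is plain arithmetic. A rough illustration of it (hypothetical helper, not necessarily the library's exact implementation):

```python
def split_key_space(min_key: int, max_key: int, bisection_factor: int) -> list:
    # Cut [min_key, max_key) into bisection_factor roughly equal segments
    # and return the inner checkpoint keys.
    size = (max_key - min_key) // bisection_factor
    return [min_key + i * size for i in range(1, bisection_factor)]

# 1M ids split 10 ways -> a checkpoint every ~100k keys
print(split_key_space(1, 1_000_001, 10))
# [100001, 200001, 300001, 400001, 500001, 600001, 700001, 800001, 900001]
```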
Then it splits the table into `--bisection-factor=10` segments of ``` ┌──────────────────────┐ ┌──────────────────────┐ -│ Postgres │ │ Snowflake │ +│ PostgreSQL │ │ Snowflake │ ├──────────────────────┤ ├──────────────────────┤ │ id=1..100k │ │ id=1..100k │ ├──────────────────────┤ ├──────────────────────┤ @@ -281,7 +281,7 @@ are the same except `id=100k..200k`: ``` ┌──────────────────────┐ ┌──────────────────────┐ -│ Postgres │ │ Snowflake │ +│ PostgreSQL │ │ Snowflake │ ├──────────────────────┤ ├──────────────────────┤ │ checksum=0102 │ │ checksum=0102 │ ├──────────────────────┤ mismatch! ├──────────────────────┤ @@ -306,7 +306,7 @@ and compare them in memory in **data-diff**. ``` ┌──────────────────────┐ ┌──────────────────────┐ -│ Postgres │ │ Snowflake │ +│ PostgreSQL │ │ Snowflake │ ├──────────────────────┤ ├──────────────────────┤ │ id=100k..110k │ │ id=100k..110k │ ├──────────────────────┤ ├──────────────────────┤ @@ -337,7 +337,7 @@ If you pass `--stats` you'll see e.g. what % of rows were different. queries. * Consider increasing the number of simultaneous threads executing queries per database with `--threads`. For databases that limit concurrency - per query, e.g. Postgres/MySQL, this can improve performance dramatically. + per query, e.g. PostgreSQL/MySQL, this can improve performance dramatically. * If you are only interested in _whether_ something changed, pass `--limit 1`. This can be useful if changes are very rare. This is often faster than doing a `count(*)`, for the reason mentioned above. @@ -419,7 +419,7 @@ Now you can insert it into the testing database(s): ```shell-session # It's optional to seed more than one to run data-diff(1) against. $ poetry run preql -f dev/prepare_db.pql mysql://mysql:Password1@127.0.0.1:3306/mysql -$ poetry run preql -f dev/prepare_db.pql postgres://postgres:Password1@127.0.0.1:5432/postgres +$ poetry run preql -f dev/prepare_db.pql postgresql://postgres:Password1@127.0.0.1:5432/postgres # Cloud databases $ poetry run preql -f dev/prepare_db.pql snowflake:// @@ -430,7 +430,7 @@ $ poetry run preql -f dev/prepare_db.pql bigquery:/// **5. Run **data-diff** against seeded database** ```bash -poetry run python3 -m data_diff postgres://postgres:Password1@localhost/postgres rating postgres://postgres:Password1@localhost/postgres rating_del1 --verbose +poetry run python3 -m data_diff postgresql://postgres:Password1@localhost/postgres rating postgresql://postgres:Password1@localhost/postgres rating_del1 --verbose ``` # License diff --git a/data_diff/__init__.py b/data_diff/__init__.py index 2688308e..4bc73733 100644 --- a/data_diff/__init__.py +++ b/data_diff/__init__.py @@ -56,7 +56,7 @@ def diff_tables( """Efficiently finds the diff between table1 and table2. 
Example: - >>> table1 = connect_to_table('postgres:///', 'Rating', 'id') + >>> table1 = connect_to_table('postgresql:///', 'Rating', 'id') >>> list(diff_tables(table1, table1)) [] diff --git a/data_diff/database.py b/data_diff/database.py index fd033e6e..8bbef736 100644 --- a/data_diff/database.py +++ b/data_diff/database.py @@ -40,8 +40,8 @@ def _inner(): return dec -@import_helper("pgsql") -def import_postgres(): +@import_helper("postgresql") +def import_postgresql(): import psycopg2 import psycopg2.extras @@ -427,7 +427,7 @@ def close(self): TIMESTAMP_PRECISION_POS = 20 # len("2022-06-03 12:24:35.") == 20 -class Postgres(ThreadedDatabase): +class PostgreSQL(ThreadedDatabase): DATETIME_TYPES = { "timestamp with time zone": TimestampTZ, "timestamp without time zone": Timestamp, @@ -451,16 +451,16 @@ def __init__(self, host, port, user, password, *, database, thread_count, **kw): super().__init__(thread_count=thread_count) def _convert_db_precision_to_digits(self, p: int) -> int: - # Subtracting 2 due to wierd precision issues in Postgres + # Subtracting 2 due to wierd precision issues in PostgreSQL return super()._convert_db_precision_to_digits(p) - 2 def create_connection(self): - postgres = import_postgres() + pg = import_postgresql() try: - c = postgres.connect(**self.args) + c = pg.connect(**self.args) # c.cursor().execute("SET TIME ZONE 'UTC'") return c - except postgres.OperationalError as e: + except pg.OperationalError as e: raise ConnectError(*e.args) from e def quote(self, s: str): @@ -722,9 +722,9 @@ def _parse_type( return UnknownColType(type_repr) -class Redshift(Postgres): +class Redshift(PostgreSQL): NUMERIC_TYPES = { - **Postgres.NUMERIC_TYPES, + **PostgreSQL.NUMERIC_TYPES, "double": Float, "real": Float, } @@ -1005,7 +1005,7 @@ def match_path(self, dsn): MATCH_URI_PATH = { - "postgres": MatchUriPath(Postgres, ["database?"], help_str="postgres://:@/"), + "postgresql": MatchUriPath(PostgreSQL, ["database?"], help_str="postgresql://:@/"), "mysql": MatchUriPath(MySQL, ["database?"], help_str="mysql://:@/"), "oracle": MatchUriPath(Oracle, ["database?"], help_str="oracle://:@/"), "mssql": MatchUriPath(MsSQL, ["database?"], help_str="mssql://:@/"), @@ -1034,7 +1034,7 @@ def connect_to_uri(db_uri: str, thread_count: Optional[int] = 1) -> Database: Note: For non-cloud databases, a low thread-pool size may be a performance bottleneck. Supported schemes: - - postgres + - postgresql - mysql - mssql - oracle diff --git a/docs/index.rst b/docs/index.rst index 094e2009..372de44c 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -11,7 +11,7 @@ Introduction **Data-diff** is a command-line tool and Python library to efficiently diff rows across two different databases. -⇄ Verifies across many different databases (e.g. *Postgres* -> *Snowflake*) ! +⇄ Verifies across many different databases (e.g. *PostgreSQL* -> *Snowflake*) ! 🔍 Outputs diff of rows in detail @@ -32,11 +32,11 @@ Requires Python 3.7+ with pip. 
pip install data-diff -or when you need extras like mysql and postgres: +or when you need extras like mysql and postgresql: :: - pip install "data-diff[mysql,pgsql]" + pip install "data-diff[mysql,postgresql]" How to use from Python @@ -50,7 +50,7 @@ How to use from Python from data_diff import connect_to_table, diff_tables - table1 = connect_to_table("postgres:///", "table_name", "id") + table1 = connect_to_table("postgresql:///", "table_name", "id") table2 = connect_to_table("mysql:///", "table_name", "id") for sign, columns in diff_tables(table1, table2): diff --git a/poetry.lock b/poetry.lock index 8841f2c3..3fa1a67a 100644 --- a/poetry.lock +++ b/poetry.lock @@ -200,7 +200,7 @@ rich = ">=10.7.0,<11.0.0" runtype = ">=0.2.4,<0.3.0" [package.extras] -pgsql = ["psycopg2"] +postgresql = ["psycopg2"] mysql = ["mysqlclient"] server = ["starlette"] @@ -443,7 +443,7 @@ testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest- [extras] mysql = ["mysql-connector-python"] -pgsql = ["psycopg2"] +postgresql = ["psycopg2"] preql = ["preql"] presto = [] snowflake = ["snowflake-connector-python"] diff --git a/pyproject.toml b/pyproject.toml index 685143f6..b141ca62 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,7 +47,7 @@ parameterized = "*" # When adding, update also: README + dev deps just above preql = ["preql"] mysql = ["mysql-connector-python"] -pgsql = ["psycopg2"] +postgresql = ["psycopg2"] snowflake = ["snowflake-connector-python"] presto = ["presto-python-client"] oracle = ["cx_Oracle"] diff --git a/tests/common.py b/tests/common.py index 32c4c30b..1fd610a0 100644 --- a/tests/common.py +++ b/tests/common.py @@ -6,7 +6,7 @@ logging.basicConfig(level=logging.INFO) TEST_MYSQL_CONN_STRING: str = "mysql://mysql:Password1@localhost/mysql" -TEST_POSTGRES_CONN_STRING: str = None +TEST_POSTGRESQL_CONN_STRING: str = None TEST_SNOWFLAKE_CONN_STRING: str = None TEST_BIGQUERY_CONN_STRING: str = None TEST_REDSHIFT_CONN_STRING: str = None @@ -26,7 +26,7 @@ CONN_STRINGS = { db.BigQuery: TEST_BIGQUERY_CONN_STRING, db.MySQL: TEST_MYSQL_CONN_STRING, - db.Postgres: TEST_POSTGRES_CONN_STRING, + db.PostgreSQL: TEST_POSTGRESQL_CONN_STRING, db.Snowflake: TEST_SNOWFLAKE_CONN_STRING, db.Redshift: TEST_REDSHIFT_CONN_STRING, db.Oracle: TEST_ORACLE_CONN_STRING, diff --git a/tests/test_database.py b/tests/test_database.py index eabed2f7..924925c2 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -22,7 +22,7 @@ def test_md5_to_int(self): class TestConnect(unittest.TestCase): def test_bad_uris(self): self.assertRaises(ValueError, connect_to_uri, "p") - self.assertRaises(ValueError, connect_to_uri, "postgres:///bla/foo") + self.assertRaises(ValueError, connect_to_uri, "postgresql:///bla/foo") self.assertRaises(ValueError, connect_to_uri, "snowflake://erez:erez27Snow@bya42734/xdiffdev/TEST1") self.assertRaises( ValueError, connect_to_uri, "snowflake://erez:erez27Snow@bya42734/xdiffdev/TEST1?warehouse=ha&schema=dup" diff --git a/tests/test_database_types.py b/tests/test_database_types.py index 1e5602f9..2f665618 100644 --- a/tests/test_database_types.py +++ b/tests/test_database_types.py @@ -52,7 +52,7 @@ } DATABASE_TYPES = { - db.Postgres: { + db.PostgreSQL: { # https://www.postgresql.org/docs/current/datatype-numeric.html#DATATYPE-INT "int": [ # "smallint", # 2 bytes diff --git a/tests/test_normalize_fields.py b/tests/test_normalize_fields.py index 468d2667..7893022f 100644 --- a/tests/test_normalize_fields.py +++ b/tests/test_normalize_fields.py @@ -14,7 +14,7 @@ logger = 
logging.getLogger() DATE_TYPES = { - db.Postgres: ["timestamp({p}) with time zone", "timestamp({p}) without time zone"], + db.PostgreSQL: ["timestamp({p}) with time zone", "timestamp({p}) without time zone"], db.MySQL: ["datetime({p})", "timestamp({p})"], db.Snowflake: ["timestamp({p})", "timestamp_tz({p})", "timestamp_ntz({p})"], db.BigQuery: ["timestamp", "datetime"], From be727d915a9f52f50ec1bfe856dca31d180b60b4 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Wed, 22 Jun 2022 10:34:21 +0200 Subject: [PATCH 24/32] Updated README --- README.md | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 216bb638..6f79b9c9 100644 --- a/README.md +++ b/README.md @@ -140,9 +140,28 @@ Requires Python 3.7+ with pip. ```pip install data-diff``` -or when you need extras like mysql and postgresql +## Install drivers -```pip install "data-diff[mysql,postgresql]"``` +To connect to a database, we need to have its driver installed, in the form of a Python library. + +While you may install them manually, we offer an easy way to install them along with data-diff: + +- `pip install 'data-diff[mysql]'` + +- `pip install 'data-diff[postgresql]'` + +- `pip install 'data-diff[snowflake]'` + +- `pip install 'data-diff[presto]'` + +- `pip install 'data-diff[oracle]'` + +- For BigQuery, see: https://pypi.org/project/google-cloud-bigquery/ + + +Users can also install several drivers at once: + +```pip install 'data-diff[mysql,postgresql,snowflake]'``` # How to use From 101707bb58c8c2514db6a1a8b115677470411fd4 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Wed, 22 Jun 2022 11:07:25 +0200 Subject: [PATCH 25/32] Fix: Only parse relevant columns. Only warn on relevant columns. --- data_diff/database.py | 25 +++++++++++++++---------- data_diff/diff_tables.py | 12 ++++++++++-- 2 files changed, 25 insertions(+), 12 deletions(-) diff --git a/data_diff/database.py b/data_diff/database.py index c4d5edf4..79117f0e 100644 --- a/data_diff/database.py +++ b/data_diff/database.py @@ -5,7 +5,7 @@ from abc import ABC, abstractmethod from runtype import dataclass import logging -from typing import Tuple, Optional, List +from typing import Sequence, Tuple, Optional, List from concurrent.futures import ThreadPoolExecutor import threading from typing import Dict @@ -131,10 +131,6 @@ def __post_init__(self): class UnknownColType(ColType): text: str - def __post_init__(self): - logger.warn(f"Column of type '{self.text}' has no compatibility handling. " - "If encoding/formatting differs between databases, it may result in false positives.") - class AbstractDatabase(ABC): @abstractmethod @@ -163,7 +159,7 @@ def select_table_schema(self, path: DbPath) -> str: ... @abstractmethod - def query_table_schema(self, path: DbPath) -> Dict[str, ColType]: + def query_table_schema(self, path: DbPath, filter_columns: Optional[Sequence[str]] = None) -> Dict[str, ColType]: "Query the table for its schema for table in 'path', and return {column: type}" ... 
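The filtering contract this patch introduces (implemented in the next hunk) is worth spelling out: requested columns are matched case-insensitively against the information-schema rows, so unrequested columns are never parsed and never trigger compatibility warnings. With toy schema rows:

```python
# Toy information-schema rows:
# (column_name, data_type, datetime_precision, numeric_precision, numeric_scale)
rows = [
    ("ID", "integer", None, 32, 0),
    ("Col", "timestamp", 6, None, None),
    ("other", "text", None, None, None),
]
filter_columns = ("id", "col")

accept = {c.lower() for c in filter_columns}
rows = [r for r in rows if r[0].lower() in accept]
assert [r[0] for r in rows] == ["ID", "Col"]  # "other" is dropped before parsing
```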
@@ -241,6 +237,10 @@ class Database(AbstractDatabase): DATETIME_TYPES = {} default_schema = None + @property + def name(self): + return type(self).__name__ + def query(self, sql_ast: SqlOrStr, res_type: type): "Query the given SQL code/AST, and attempt to convert the result to type 'res_type'" @@ -321,12 +321,16 @@ def select_table_schema(self, path: DbPath) -> str: f"WHERE table_name = '{table}' AND table_schema = '{schema}'" ) - def query_table_schema(self, path: DbPath) -> Dict[str, ColType]: + def query_table_schema(self, path: DbPath, filter_columns: Optional[Sequence[str]] = None) -> Dict[str, ColType]: rows = self.query(self.select_table_schema(path), list) if not rows: - raise RuntimeError(f"{self.__class__.__name__}: Table '{'.'.join(path)}' does not exist, or has no columns") + raise RuntimeError(f"{self.name}: Table '{'.'.join(path)}' does not exist, or has no columns") + + if filter_columns is not None: + accept = {i.lower() for i in filter_columns} + rows = [r for r in rows if r[0].lower() in accept] - # Return a dict of form {name: type} after canonizaation + # Return a dict of form {name: type} after normalization return {row[0]: self._parse_type(*row[1:]) for row in rows} # @lru_cache() @@ -339,7 +343,7 @@ def _normalize_table_path(self, path: DbPath) -> DbPath: return self.default_schema, path[0] elif len(path) != 2: raise ValueError( - f"{self.__class__.__name__}: Bad table path for {self}: '{'.'.join(path)}'. Expected form: schema.table" + f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. Expected form: schema.table" ) return path @@ -407,6 +411,7 @@ class Postgres(ThreadedDatabase): "decimal": Decimal, "integer": Integer, "numeric": Decimal, + "bigint": Integer, } ROUNDS_ON_PREC_LOSS = True diff --git a/data_diff/diff_tables.py b/data_diff/diff_tables.py index 565d51a2..2c9a538c 100644 --- a/data_diff/diff_tables.py +++ b/data_diff/diff_tables.py @@ -12,7 +12,7 @@ from runtype import dataclass from .sql import Select, Checksum, Compare, DbPath, DbKey, DbTime, Count, TableName, Time, Min, Max -from .database import Database, NumericType, PrecisionType, ColType +from .database import Database, NumericType, PrecisionType, ColType, UnknownColType logger = logging.getLogger("diff_tables") @@ -142,7 +142,8 @@ def with_schema(self) -> "TableSegment": "Queries the table schema from the database, and returns a new instance of TableSegmentWithSchema." if self._schema: return self - schema = self.database.query_table_schema(self.table_path) + + schema = self.database.query_table_schema(self.table_path, self._relevant_columns) if self.case_sensitive: schema = Schema_CaseSensitive(schema) else: @@ -381,6 +382,13 @@ def _validate_and_adjust_columns(self, table1, table2): table1._schema[c] = col1.replace(precision=lowest.precision) table2._schema[c] = col2.replace(precision=lowest.precision) + for t in [table1, table2]: + for c in t._relevant_columns: + ctype = t._schema[c] + if isinstance(ctype, UnknownColType): + logger.warn(f"[{t.database.name}] Column '{c}' of type '{ctype.text}' has no compatibility handling. 
" + "If encoding/formatting differs between databases, it may result in false positives.") + def _bisect_and_diff_tables(self, table1, table2, level=0, max_rows=None): assert table1.is_bounded and table2.is_bounded From 6df9d37c7e80499fbb6aada1c9c7344f13f7d993 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Wed, 22 Jun 2022 11:09:53 +0200 Subject: [PATCH 26/32] A better error message --- data_diff/database.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/data_diff/database.py b/data_diff/database.py index 79117f0e..1c03d25c 100644 --- a/data_diff/database.py +++ b/data_diff/database.py @@ -301,6 +301,8 @@ def _parse_type( return cls(precision=0) elif issubclass(cls, Decimal): + if numeric_scale is None: + raise ValueError(f"{self.name}: Unexpected numeric_scale is NULL, for column of type {type_repr}.") return cls(precision=numeric_scale) assert issubclass(cls, Float) From 75e3e528e4f1423f06d697385a766e17af2b02b2 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Wed, 22 Jun 2022 12:04:58 +0200 Subject: [PATCH 27/32] Fixes for PR --- data_diff/database.py | 51 ++++++++++++++++++++++++--------------- data_diff/diff_tables.py | 6 +++-- tests/test_diff_tables.py | 5 ++-- 3 files changed, 37 insertions(+), 25 deletions(-) diff --git a/data_diff/database.py b/data_diff/database.py index 1c03d25c..fffbdbb8 100644 --- a/data_diff/database.py +++ b/data_diff/database.py @@ -173,7 +173,6 @@ def close(self): "Close connection(s) to the database instance. Querying will stop functioning." ... - @abstractmethod def normalize_timestamp(self, value: str, coltype: ColType) -> str: """Creates an SQL expression, that converts 'value' to a normalized timestamp. @@ -282,7 +281,12 @@ def _convert_db_precision_to_digits(self, p: int) -> int: return math.floor(math.log(2**p, 10)) def _parse_type( - self, type_repr: str, datetime_precision: int = None, numeric_precision: int = None, numeric_scale: int = None + self, + col_name: str, + type_repr: str, + datetime_precision: int = None, + numeric_precision: int = None, + numeric_scale: int = None, ) -> ColType: """ """ @@ -302,7 +306,7 @@ def _parse_type( elif issubclass(cls, Decimal): if numeric_scale is None: - raise ValueError(f"{self.name}: Unexpected numeric_scale is NULL, for column of type {type_repr}.") + raise ValueError(f"{self.name}: Unexpected numeric_scale is NULL, for column {col_name} of type {type_repr}.") return cls(precision=numeric_scale) assert issubclass(cls, Float) @@ -333,7 +337,7 @@ def query_table_schema(self, path: DbPath, filter_columns: Optional[Sequence[str rows = [r for r in rows if r[0].lower() in accept] # Return a dict of form {name: type} after normalization - return {row[0]: self._parse_type(*row[1:]) for row in rows} + return {row[0]: self._parse_type(*row) for row in rows} # @lru_cache() # def get_table_schema(self, path: DbPath) -> Dict[str, ColType]: @@ -344,13 +348,10 @@ def _normalize_table_path(self, path: DbPath) -> DbPath: if self.default_schema: return self.default_schema, path[0] elif len(path) != 2: - raise ValueError( - f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. Expected form: schema.table" - ) + raise ValueError(f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. 
Expected form: schema.table") return path - def parse_table_name(self, name: str) -> DbPath: return parse_table_name(name) @@ -446,13 +447,14 @@ def md5_to_int(self, s: str) -> str: def to_string(self, s: str): return f"{s}::varchar" - def normalize_timestamp(self, value: str, coltype: ColType) -> str: if coltype.rounds: return f"to_char({value}::timestamp({coltype.precision}), 'YYYY-mm-dd HH24:MI:SS.US')" timestamp6 = f"to_char({value}::timestamp(6), 'YYYY-mm-dd HH24:MI:SS.US')" - return f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')" + return ( + f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')" + ) def normalize_number(self, value: str, coltype: ColType) -> str: return self.to_string(f"{value}::decimal(38, {coltype.precision})") @@ -502,9 +504,7 @@ def normalize_timestamp(self, value: str, coltype: ColType) -> str: else: s = f"date_format(cast({value} as timestamp(6)), '%Y-%m-%d %H:%i:%S.%f')" - return ( - f"RPAD(RPAD({s}, {TIMESTAMP_PRECISION_POS+coltype.precision}, '.'), {TIMESTAMP_PRECISION_POS+6}, '0')" - ) + return f"RPAD(RPAD({s}, {TIMESTAMP_PRECISION_POS+coltype.precision}, '.'), {TIMESTAMP_PRECISION_POS+6}, '0')" def normalize_number(self, value: str, coltype: ColType) -> str: return self.to_string(f"cast({value} as decimal(38,{coltype.precision}))") @@ -517,7 +517,9 @@ def select_table_schema(self, path: DbPath) -> str: f"WHERE table_name = '{table}' AND table_schema = '{schema}'" ) - def _parse_type(self, type_repr: str, datetime_precision: int = None, numeric_precision: int = None) -> ColType: + def _parse_type( + self, col_name: str, type_repr: str, datetime_precision: int = None, numeric_precision: int = None + ) -> ColType: regexps = { r"timestamp\((\d)\)": Timestamp, r"timestamp\((\d)\) with time zone": TimestampTZ, @@ -607,7 +609,6 @@ def normalize_number(self, value: str, coltype: ColType) -> str: return self.to_string(f"cast({value} as decimal(38, {coltype.precision}))") - class Oracle(ThreadedDatabase): ROUNDS_ON_PREC_LOSS = True @@ -661,7 +662,12 @@ def normalize_number(self, value: str, coltype: ColType) -> str: return f"to_char({value}, '{format_str}')" def _parse_type( - self, type_repr: str, datetime_precision: int = None, numeric_precision: int = None, numeric_scale: int = None + self, + col_name: str, + type_repr: str, + datetime_precision: int = None, + numeric_precision: int = None, + numeric_scale: int = None, ) -> ColType: """ """ regexps = { @@ -720,15 +726,18 @@ def normalize_timestamp(self, value: str, coltype: ColType) -> str: us = f"extract(us from {timestamp})" # epoch = Total time since epoch in microseconds. 
epoch = f"{secs}*1000000 + {ms}*1000 + {us}" - timestamp6 = f"to_char({epoch}, -6+{coltype.precision}) * interval '0.000001 seconds', 'YYYY-mm-dd HH24:MI:SS.US')" + timestamp6 = ( + f"to_char({epoch}, -6+{coltype.precision}) * interval '0.000001 seconds', 'YYYY-mm-dd HH24:MI:SS.US')" + ) else: timestamp6 = f"to_char({value}::timestamp(6), 'YYYY-mm-dd HH24:MI:SS.US')" - return f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')" + return ( + f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')" + ) def normalize_number(self, value: str, coltype: ColType) -> str: return self.to_string(f"{value}::decimal(38,{coltype.precision})") - def select_table_schema(self, path: DbPath) -> str: schema, table = self._normalize_table_path(path) @@ -838,7 +847,9 @@ def normalize_timestamp(self, value: str, coltype: ColType) -> str: return f"FORMAT_TIMESTAMP('%F %H:%M:%E6S', {value})" timestamp6 = f"FORMAT_TIMESTAMP('%F %H:%M:%E6S', {value})" - return f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')" + return ( + f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')" + ) def normalize_number(self, value: str, coltype: ColType) -> str: if isinstance(coltype, Integer): diff --git a/data_diff/diff_tables.py b/data_diff/diff_tables.py index 2c9a538c..7df04321 100644 --- a/data_diff/diff_tables.py +++ b/data_diff/diff_tables.py @@ -386,8 +386,10 @@ def _validate_and_adjust_columns(self, table1, table2): for c in t._relevant_columns: ctype = t._schema[c] if isinstance(ctype, UnknownColType): - logger.warn(f"[{t.database.name}] Column '{c}' of type '{ctype.text}' has no compatibility handling. " - "If encoding/formatting differs between databases, it may result in false positives.") + logger.warn( + f"[{t.database.name}] Column '{c}' of type '{ctype.text}' has no compatibility handling. " + "If encoding/formatting differs between databases, it may result in false positives." + ) def _bisect_and_diff_tables(self, table1, table2, level=0, max_rows=None): assert table1.is_bounded and table2.is_bounded diff --git a/tests/test_diff_tables.py b/tests/test_diff_tables.py index 84bb6b9a..3649b37f 100644 --- a/tests/test_diff_tables.py +++ b/tests/test_diff_tables.py @@ -32,17 +32,16 @@ def tearDownClass(cls): cls.preql.close() cls.connection.close() - # Fallback for test runners that doesn't support setUpClass/tearDownClass def setUp(self) -> None: - if not hasattr(self, 'connection'): + if not hasattr(self, "connection"): self.setUpClass.__func__(self) self.private_connection = True return super().setUp() def tearDown(self) -> None: - if hasattr(self, 'private_connection'): + if hasattr(self, "private_connection"): self.tearDownClass.__func__(self) return super().tearDown() From 37b47a0bdde598ce3c22a1e4f76e79c68c994dc4 Mon Sep 17 00:00:00 2001 From: Simon Eskildsen Date: Wed, 22 Jun 2022 09:44:57 -0400 Subject: [PATCH 28/32] cli: only json from standard diff with --json --- data_diff/__main__.py | 16 +++++++++++----- data_diff/diff_tables.py | 15 ++++++++------- tests/test_diff_tables.py | 4 ++++ 3 files changed, 23 insertions(+), 12 deletions(-) diff --git a/data_diff/__main__.py b/data_diff/__main__.py index fc32899e..6ee6992c 100644 --- a/data_diff/__main__.py +++ b/data_diff/__main__.py @@ -51,7 +51,7 @@ @click.option("--max-age", default=None, help="Considers only rows younger than specified. 
See --min-age.") @click.option("-s", "--stats", is_flag=True, help="Print stats instead of a detailed diff") @click.option("-d", "--debug", is_flag=True, help="Print debug info") -@click.option("--json", 'json_output', is_flag=True, help="Print JSON output for --stats") +@click.option("--json", 'json_output', is_flag=True, help="Print JSONL output for machine readability") @click.option("-v", "--verbose", is_flag=True, help="Print extra info") @click.option("-i", "--interactive", is_flag=True, help="Confirm queries, implies --debug") @click.option("--keep-column-case", is_flag=True, help="Don't use the schema to fix the case of given column names.") @@ -160,16 +160,22 @@ def main( "different_-": minus, "total": max_table_count, } - print(json.dumps(json_output, indent=2)) + print(json.dumps(json_output)) else: print(f"Diff-Total: {len(diff)} changed rows out of {max_table_count}") print(f"Diff-Percent: {percent:.14f}%") print(f"Diff-Split: +{plus} -{minus}") else: - for op, key in diff_iter: + for op, columns in diff_iter: color = COLOR_SCHEME[op] - jsonl = json.dumps([op, list(key)]) - rich.print(f"[{color}]{jsonl}[/{color}]") + + if json_output: + jsonl = json.dumps([op, list(columns)]) + rich.print(f"[{color}]{jsonl}[/{color}]") + else: + text = f"{op} {', '.join(columns)}" + rich.print(f"[{color}]{text}[/{color}]") + sys.stdout.flush() end = time.time() diff --git a/data_diff/diff_tables.py b/data_diff/diff_tables.py index 48099c8f..ddad1ef1 100644 --- a/data_diff/diff_tables.py +++ b/data_diff/diff_tables.py @@ -382,11 +382,12 @@ def _bisect_and_diff_tables(self, table1, table2, level=0, max_rows=None): rows1, rows2 = self._threaded_call("get_values", [table1, table2]) diff = list(diff_sets(rows1, rows2)) - # This happens when the initial bisection threshold is larger than - # the table itself. - if level == 0 and not self.stats.get("table1_count", False): - self.stats["table1_count"] = self.stats.get("table1_count", 0) + len(rows1) - self.stats["table2_count"] = self.stats.get("table2_count", 0) + len(rows2) + # Initial bisection_threshold larger than count. Normally we always + # checksum and count segments, even if we get the values. At the + # first level, however, that won't be true. + if level == 0: + self.stats["table1_count"] = len(rows1) + self.stats["table2_count"] = len(rows2) logger.info(". 
" * level + f"Diff found {len(diff)} different rows.") self.stats["rows_downloaded"] = self.stats.get("rows_downloaded", 0) + max(len(rows1), len(rows2)) @@ -427,8 +428,8 @@ def _diff_tables(self, table1, table2, level=0, segment_index=None, segment_coun return if level == 1: - self.stats["table1_count"] = self.stats.get("table_count1", 0) + count1 - self.stats["table2_count"] = self.stats.get("table_count2", 0) + count2 + self.stats["table1_count"] = self.stats.get("table1_count", 0) + count1 + self.stats["table2_count"] = self.stats.get("table2_count", 0) + count2 if checksum1 != checksum2: yield from self._bisect_and_diff_tables(table1, table2, level=level, max_rows=max(count1, count2)) diff --git a/tests/test_diff_tables.py b/tests/test_diff_tables.py index a457081d..f2c9da19 100644 --- a/tests/test_diff_tables.py +++ b/tests/test_diff_tables.py @@ -155,6 +155,8 @@ def test_diff_small_tables(self): diff = list(self.differ.diff_tables(self.table, self.table2)) expected = [("-", ("2", time + ".000000"))] self.assertEqual(expected, diff) + self.assertEqual(2, self.differ.stats["table1_count"]) + self.assertEqual(1, self.differ.stats["table2_count"]) def test_diff_table_above_bisection_threshold(self): time = "2022-01-01 00:00:00" @@ -176,6 +178,8 @@ def test_diff_table_above_bisection_threshold(self): diff = list(self.differ.diff_tables(self.table, self.table2)) expected = [("-", ("5", time + ".000000"))] self.assertEqual(expected, diff) + self.assertEqual(5, self.differ.stats["table1_count"]) + self.assertEqual(4, self.differ.stats["table2_count"]) def test_return_empty_array_when_same(self): time = "2022-01-01 00:00:00" From 51e9417c437f7c724a9ec07154a0a041454e6ad0 Mon Sep 17 00:00:00 2001 From: Simon Eskildsen Date: Wed, 22 Jun 2022 10:45:41 -0400 Subject: [PATCH 29/32] readme: add --json --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 6f79b9c9..a003afa0 100644 --- a/README.md +++ b/README.md @@ -183,6 +183,7 @@ Options: - `-d` or `--debug` - Print debug info - `-v` or `--verbose` - Print extra info - `-i` or `--interactive` - Confirm queries, implies `--debug` + - `---json` - Print JSONL output for machine readability - `--min-age` - Considers only rows older than specified. Example: `--min-age=5min` ignores rows from the last 5 minutes. Valid units: `d, days, h, hours, min, minutes, mon, months, s, seconds, w, weeks, y, years` From 7db2f80444a4349ea426fd1a5f7fb480b2f49962 Mon Sep 17 00:00:00 2001 From: Simon Eskildsen Date: Wed, 22 Jun 2022 10:47:07 -0400 Subject: [PATCH 30/32] readme: ---json to --json --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index a003afa0..02c02049 100644 --- a/README.md +++ b/README.md @@ -183,7 +183,7 @@ Options: - `-d` or `--debug` - Print debug info - `-v` or `--verbose` - Print extra info - `-i` or `--interactive` - Confirm queries, implies `--debug` - - `---json` - Print JSONL output for machine readability + - `--json` - Print JSONL output for machine readability - `--min-age` - Considers only rows older than specified. Example: `--min-age=5min` ignores rows from the last 5 minutes. 
Valid units: `d, days, h, hours, min, minutes, mon, months, s, seconds, w, weeks, y, years` From 98b9bd9fdb29fa59c3a61891130ebf5f74367f4e Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Wed, 22 Jun 2022 16:48:27 +0200 Subject: [PATCH 31/32] Version bump (0.1.0) --- poetry.lock | 23 ++++++++++++----------- pyproject.toml | 2 +- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/poetry.lock b/poetry.lock index 3fa1a67a..7172ba56 100644 --- a/poetry.lock +++ b/poetry.lock @@ -20,7 +20,7 @@ python-versions = "*" [[package]] name = "certifi" -version = "2022.5.18.1" +version = "2022.6.15" description = "Python package for providing Mozilla's CA Bundle." category = "main" optional = false @@ -62,7 +62,7 @@ importlib-metadata = {version = "*", markers = "python_version < \"3.8\""} [[package]] name = "colorama" -version = "0.4.4" +version = "0.4.5" description = "Cross-platform colored terminal text." category = "main" optional = false @@ -200,7 +200,7 @@ rich = ">=10.7.0,<11.0.0" runtype = ">=0.2.4,<0.3.0" [package.extras] -postgresql = ["psycopg2"] +pgsql = ["psycopg2"] mysql = ["mysqlclient"] server = ["starlette"] @@ -359,7 +359,7 @@ jupyter = ["ipywidgets (>=7.5.1,<8.0.0)"] [[package]] name = "runtype" -version = "0.2.4" +version = "0.2.6" description = "Type dispatch and validation for run-time Python" category = "main" optional = false @@ -443,6 +443,7 @@ testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest- [extras] mysql = ["mysql-connector-python"] +oracle = [] postgresql = ["psycopg2"] preql = ["preql"] presto = [] @@ -451,7 +452,7 @@ snowflake = ["snowflake-connector-python"] [metadata] lock-version = "1.1" python-versions = "^3.7" -content-hash = "cd595c78ae0024cb9d980d4a2d83d8011f82947fe557537eea0280057bcbb535" +content-hash = "e1b2b05a166d2d6d81bec8e15e562480998b6e578592a4a0ed04b6fb6a2e046c" [metadata.files] arrow = [ @@ -463,8 +464,8 @@ asn1crypto = [ {file = "asn1crypto-1.5.1.tar.gz", hash = "sha256:13ae38502be632115abf8a24cbe5f4da52e3b5231990aff31123c805306ccb9c"}, ] certifi = [ - {file = "certifi-2022.5.18.1-py3-none-any.whl", hash = "sha256:f1d53542ee8cbedbe2118b5686372fb33c297fcd6379b050cca0ef13a597382a"}, - {file = "certifi-2022.5.18.1.tar.gz", hash = "sha256:9c5705e395cd70084351dd8ad5c41e65655e08ce46f2ec9cf6c2c08390f71eb7"}, + {file = "certifi-2022.6.15-py3-none-any.whl", hash = "sha256:fe86415d55e84719d75f8b69414f6438ac3547d2078ab91b67e779ef69378412"}, + {file = "certifi-2022.6.15.tar.gz", hash = "sha256:84c85a9078b11105f04f3036a9482ae10e4621616db313fe045dd24743a0820d"}, ] cffi = [ {file = "cffi-1.15.0-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:c2502a1a03b6312837279c8c1bd3ebedf6c12c4228ddbad40912d671ccc8a962"}, @@ -527,8 +528,8 @@ click = [ {file = "click-8.1.3.tar.gz", hash = "sha256:7682dc8afb30297001674575ea00d1814d808d6a36af415a82bd481d37ba7b8e"}, ] colorama = [ - {file = "colorama-0.4.4-py2.py3-none-any.whl", hash = "sha256:9f47eda37229f68eee03b24b9748937c7dc3868f906e8ba69fbcbdd3bc5dc3e2"}, - {file = "colorama-0.4.4.tar.gz", hash = "sha256:5941b2b48a20143d2267e95b1c2a7603ce057ee39fd88e7329b0c292aa16869b"}, + {file = "colorama-0.4.5-py2.py3-none-any.whl", hash = "sha256:854bf444933e37f5824ae7bfc1e98d5bce2ebe4160d46b5edf346a89358e99da"}, + {file = "colorama-0.4.5.tar.gz", hash = "sha256:e6c6b4334fc50988a639d9b98aa429a0b57da6e17b9a44f0451f930b6967b7a4"}, ] commonmark = [ {file = "commonmark-0.9.1-py2.py3-none-any.whl", hash = "sha256:da2f38c92590f83de410ba1a3cbceafbc74fee9def35f9251ba9a971d6d66fd9"}, @@ -701,8 +702,8 @@ 
rich = [ {file = "rich-10.16.2.tar.gz", hash = "sha256:720974689960e06c2efdb54327f8bf0cdbdf4eae4ad73b6c94213cad405c371b"}, ] runtype = [ - {file = "runtype-0.2.4-py3-none-any.whl", hash = "sha256:1adab62f867199536820898ce04df22586ba2a52084448385004faa532b19e97"}, - {file = "runtype-0.2.4.tar.gz", hash = "sha256:642f747b199fd872deb79d361d47ea83a1a0db49986fbeaa0c375d2bd9805e00"}, + {file = "runtype-0.2.6-py3-none-any.whl", hash = "sha256:1739136f46551240a9f68807d167b5acbe4c18512de08ebdcc2fa0648d97c834"}, + {file = "runtype-0.2.6.tar.gz", hash = "sha256:31818c1991c8b5e01ec2e54a53ad44af104dcf1bb4e82efd1aa7eba1047dc6dd"}, ] six = [ {file = "six-1.16.0-py2.py3-none-any.whl", hash = "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"}, diff --git a/pyproject.toml b/pyproject.toml index b141ca62..ed922f4e 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "data-diff" -version = "0.0.8" +version = "0.1.0" description = "Command-line tool and Python library to efficiently diff rows across two different databases." authors = ["Erez Shinnan ", "Simon Eskildsen "] license = "MIT" From df1bfaa65bc108ffa05932ad944b4e5f4d90433e Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Wed, 22 Jun 2022 17:00:57 +0200 Subject: [PATCH 32/32] Version bump (0.2.0) --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index ed922f4e..11c697b1 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "data-diff" -version = "0.1.0" +version = "0.2.0" description = "Command-line tool and Python library to efficiently diff rows across two different databases." authors = ["Erez Shinnan ", "Simon Eskildsen "] license = "MIT"