From b3d42237ddf79eea1b345f916380f588dcd25074 Mon Sep 17 00:00:00 2001 From: Sung Won Chung Date: Wed, 6 Dec 2023 12:03:30 +1300 Subject: [PATCH 1/7] Make dbt data diffs concurrent (#776) * v0 of concurrency * concurrent logging * remove todo * remove todo * better var name * add node name to logger * format string logs * add optional logger param * avoid extra threads * use thread pools * not multithreaded at the connection level anymore * show errors as they happen * show full stacktrace on error * rearrange trace * more logs for debugging * update for threads mocking * clear log params * remove extra space * remove long traceback * Ensure log_message is optional Co-authored-by: Dan Lawin * map threaded result to proper model id * explicit type and optional * rm submodules again --------- Co-authored-by: Sung Won Chung Co-authored-by: Dan Lawin --- data_diff/databases/base.py | 15 ++++++--- data_diff/dbt.py | 43 +++++++++++++++---------- data_diff/dbt_parser.py | 6 ++-- data_diff/joindiff_tables.py | 61 +++++++++++++++++++++++++----------- data_diff/utils.py | 28 ++++++++--------- tests/test_dbt.py | 21 ++++++++----- 6 files changed, 110 insertions(+), 64 deletions(-) diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py index c5931979..bf165461 100644 --- a/data_diff/databases/base.py +++ b/data_diff/databases/base.py @@ -931,7 +931,7 @@ def name(self): def compile(self, sql_ast): return self.dialect.compile(Compiler(self), sql_ast) - def query(self, sql_ast: Union[Expr, Generator], res_type: type = None): + def query(self, sql_ast: Union[Expr, Generator], res_type: type = None, log_message: Optional[str] = None): """Query the given SQL code/AST, and attempt to convert the result to type 'res_type' If given a generator, it will execute all the yielded sql queries with the same thread and cursor. @@ -956,7 +956,10 @@ def query(self, sql_ast: Union[Expr, Generator], res_type: type = None): if sql_code is SKIP: return SKIP - logger.debug("Running SQL (%s):\n%s", self.name, sql_code) + if log_message: + logger.debug("Running SQL (%s): %s \n%s", self.name, log_message, sql_code) + else: + logger.debug("Running SQL (%s):\n%s", self.name, sql_code) if self._interactive and isinstance(sql_ast, Select): explained_sql = self.compile(Explain(sql_ast)) @@ -1022,7 +1025,7 @@ def query_table_schema(self, path: DbPath) -> Dict[str, tuple]: Note: This method exists instead of select_table_schema(), just because not all databases support accessing the schema using a SQL query. """ - rows = self.query(self.select_table_schema(path), list) + rows = self.query(self.select_table_schema(path), list, log_message=path) if not rows: raise RuntimeError(f"{self.name}: Table '{'.'.join(path)}' does not exist, or has no columns") @@ -1044,7 +1047,7 @@ def query_table_unique_columns(self, path: DbPath) -> List[str]: """Query the table for its unique columns for table in 'path', and return {column}""" if not self.SUPPORTS_UNIQUE_CONSTAINT: raise NotImplementedError("This database doesn't support 'unique' constraints") - res = self.query(self.select_table_unique_columns(path), List[str]) + res = self.query(self.select_table_unique_columns(path), List[str], log_message=path) return list(res) def _process_table_schema( @@ -1086,7 +1089,9 @@ def _refine_coltypes( fields = [Code(self.dialect.normalize_uuid(self.dialect.quote(c), String_UUID())) for c in text_columns] samples_by_row = self.query( - table(*table_path).select(*fields).where(Code(where) if where else SKIP).limit(sample_size), list + table(*table_path).select(*fields).where(Code(where) if where else SKIP).limit(sample_size), + list, + log_message=table_path, ) if not samples_by_row: raise ValueError(f"Table {table_path} is empty.") diff --git a/data_diff/dbt.py b/data_diff/dbt.py index bf36c4fc..ef780429 100644 --- a/data_diff/dbt.py +++ b/data_diff/dbt.py @@ -8,6 +8,7 @@ import pydantic import rich from rich.prompt import Prompt +from concurrent.futures import ThreadPoolExecutor, as_completed from data_diff.errors import ( DataDiffCustomSchemaNoConfigError, @@ -80,7 +81,6 @@ def dbt_diff( production_schema_flag: Optional[str] = None, ) -> None: print_version_info() - diff_threads = [] set_entrypoint_name(os.getenv("DATAFOLD_TRIGGERED_BY", "CLI-dbt")) dbt_parser = DbtParser(profiles_dir_override, project_dir_override, state) models = dbt_parser.get_models(dbt_selection) @@ -112,7 +112,11 @@ def dbt_diff( else: dbt_parser.set_connection() - with log_status_handler.status if log_status_handler else nullcontext(): + futures = {} + + with log_status_handler.status if log_status_handler else nullcontext(), ThreadPoolExecutor( + max_workers=dbt_parser.threads + ) as executor: for model in models: if log_status_handler: log_status_handler.set_prefix(f"Diffing {model.alias} \n") @@ -140,12 +144,12 @@ def dbt_diff( if diff_vars.primary_keys: if is_cloud: - diff_thread = run_as_daemon( + future = executor.submit( _cloud_diff, diff_vars, config.datasource_id, api, org_meta, log_status_handler ) - diff_threads.append(diff_thread) else: - _local_diff(diff_vars, json_output) + future = executor.submit(_local_diff, diff_vars, json_output, log_status_handler) + futures[future] = model else: if json_output: print( @@ -165,10 +169,12 @@ def dbt_diff( + "Skipped due to unknown primary key. Add uniqueness tests, meta, or tags.\n" ) - # wait for all threads - if diff_threads: - for thread in diff_threads: - thread.join() + for future in as_completed(futures): + model = futures[future] + try: + future.result() # if error occurred, it will be raised here + except Exception as e: + logger.error(f"An error occurred during the execution of a diff task: {model.unique_id} - {e}") _extension_notification() @@ -265,15 +271,17 @@ def _get_prod_path_from_manifest(model, prod_manifest) -> Union[Tuple[str, str, return prod_database, prod_schema, prod_alias -def _local_diff(diff_vars: TDiffVars, json_output: bool = False) -> None: +def _local_diff( + diff_vars: TDiffVars, json_output: bool = False, log_status_handler: Optional[LogStatusHandler] = None +) -> None: + if log_status_handler: + log_status_handler.diff_started(diff_vars.dev_path[-1]) dev_qualified_str = ".".join(diff_vars.dev_path) prod_qualified_str = ".".join(diff_vars.prod_path) diff_output_str = _diff_output_base(dev_qualified_str, prod_qualified_str) - table1 = connect_to_table( - diff_vars.connection, prod_qualified_str, tuple(diff_vars.primary_keys), diff_vars.threads - ) - table2 = connect_to_table(diff_vars.connection, dev_qualified_str, tuple(diff_vars.primary_keys), diff_vars.threads) + table1 = connect_to_table(diff_vars.connection, prod_qualified_str, tuple(diff_vars.primary_keys)) + table2 = connect_to_table(diff_vars.connection, dev_qualified_str, tuple(diff_vars.primary_keys)) try: table1_columns = table1.get_schema() @@ -373,6 +381,9 @@ def _local_diff(diff_vars: TDiffVars, json_output: bool = False) -> None: diff_output_str += no_differences_template() rich.print(diff_output_str) + if log_status_handler: + log_status_handler.diff_finished(diff_vars.dev_path[-1]) + def _initialize_api() -> Optional[DatafoldAPI]: datafold_host = os.environ.get("DATAFOLD_HOST") @@ -406,7 +417,7 @@ def _cloud_diff( log_status_handler: Optional[LogStatusHandler] = None, ) -> None: if log_status_handler: - log_status_handler.cloud_diff_started(diff_vars.dev_path[-1]) + log_status_handler.diff_started(diff_vars.dev_path[-1]) diff_output_str = _diff_output_base(".".join(diff_vars.dev_path), ".".join(diff_vars.prod_path)) payload = TCloudApiDataDiff( data_source1_id=datasource_id, @@ -476,7 +487,7 @@ def _cloud_diff( rich.print(diff_output_str) if log_status_handler: - log_status_handler.cloud_diff_finished(diff_vars.dev_path[-1]) + log_status_handler.diff_finished(diff_vars.dev_path[-1]) except BaseException as ex: # Catch KeyboardInterrupt too error = ex finally: diff --git a/data_diff/dbt_parser.py b/data_diff/dbt_parser.py index 4b6124d5..0d864a57 100644 --- a/data_diff/dbt_parser.py +++ b/data_diff/dbt_parser.py @@ -446,17 +446,17 @@ def get_pk_from_model(self, node, unique_columns: dict, pk_tag: str) -> List[str from_meta = [name for name, params in node.columns.items() if pk_tag in params.meta] or None if from_meta: - logger.debug("Found PKs via META: " + str(from_meta)) + logger.debug(f"Found PKs via META [{node.name}]: " + str(from_meta)) return from_meta from_tags = [name for name, params in node.columns.items() if pk_tag in params.tags] or None if from_tags: - logger.debug("Found PKs via Tags: " + str(from_tags)) + logger.debug(f"Found PKs via Tags [{node.name}]: " + str(from_tags)) return from_tags if node.unique_id in unique_columns: from_uniq = unique_columns.get(node.unique_id) if from_uniq is not None: - logger.debug("Found PKs via Uniqueness tests: " + str(from_uniq)) + logger.debug(f"Found PKs via Uniqueness tests [{node.name}]: {str(from_uniq)}") return list(from_uniq) except (KeyError, IndexError, TypeError) as e: diff --git a/data_diff/joindiff_tables.py b/data_diff/joindiff_tables.py index 6fadc5d8..8e7fcf30 100644 --- a/data_diff/joindiff_tables.py +++ b/data_diff/joindiff_tables.py @@ -162,7 +162,7 @@ def _diff_tables_root(self, table1: TableSegment, table2: TableSegment, info_tre yield from self._diff_segments(None, table1, table2, info_tree, None) else: yield from self._bisect_and_diff_tables(table1, table2, info_tree) - logger.info("Diffing complete") + logger.info(f"Diffing complete: {table1.table_path} <> {table2.table_path}") if self.materialize_to_table: logger.info("Materialized diff to table '%s'.", ".".join(self.materialize_to_table)) @@ -193,8 +193,8 @@ def _diff_segments( partial(self._collect_stats, 1, table1, info_tree), partial(self._collect_stats, 2, table2, info_tree), partial(self._test_null_keys, table1, table2), - partial(self._sample_and_count_exclusive, db, diff_rows, a_cols, b_cols), - partial(self._count_diff_per_column, db, diff_rows, list(a_cols), is_diff_cols), + partial(self._sample_and_count_exclusive, db, diff_rows, a_cols, b_cols, table1, table2), + partial(self._count_diff_per_column, db, diff_rows, list(a_cols), is_diff_cols, table1, table2), partial( self._materialize_diff, db, @@ -205,8 +205,8 @@ def _diff_segments( else None, ): assert len(a_cols) == len(b_cols) - logger.debug("Querying for different rows") - diff = db.query(diff_rows, list) + logger.debug(f"Querying for different rows: {table1.table_path}") + diff = db.query(diff_rows, list, log_message=table1.table_path) info_tree.info.set_diff(diff, schema=tuple(diff_rows.schema.items())) for is_xa, is_xb, *x in diff: if is_xa and is_xb: @@ -227,7 +227,7 @@ def _diff_segments( yield "+", tuple(b_row) def _test_duplicate_keys(self, table1: TableSegment, table2: TableSegment): - logger.debug("Testing for duplicate keys") + logger.debug(f"Testing for duplicate keys: {table1.table_path} <> {table2.table_path}") # Test duplicate keys for ts in [table1, table2]: @@ -240,16 +240,16 @@ def _test_duplicate_keys(self, table1: TableSegment, table2: TableSegment): unvalidated = list(set(key_columns) - set(unique)) if unvalidated: - logger.info(f"Validating that the are no duplicate keys in columns: {unvalidated}") + logger.info(f"Validating that the are no duplicate keys in columns: {unvalidated} for {ts.table_path}") # Validate that there are no duplicate keys self.stats["validated_unique_keys"] = self.stats.get("validated_unique_keys", []) + [unvalidated] q = t.select(total=Count(), total_distinct=Count(Concat(this[unvalidated]), distinct=True)) - total, total_distinct = ts.database.query(q, tuple) + total, total_distinct = ts.database.query(q, tuple, log_message=ts.table_path) if total != total_distinct: raise ValueError("Duplicate primary keys") def _test_null_keys(self, table1, table2): - logger.debug("Testing for null keys") + logger.debug(f"Testing for null keys: {table1.table_path} <> {table2.table_path}") # Test null keys for ts in [table1, table2]: @@ -257,7 +257,7 @@ def _test_null_keys(self, table1, table2): key_columns = ts.key_columns q = t.select(*this[key_columns]).where(or_(this[k] == None for k in key_columns)) - nulls = ts.database.query(q, list) + nulls = ts.database.query(q, list, log_message=ts.table_path) if nulls: if self.skip_null_keys: logger.warning( @@ -267,7 +267,7 @@ def _test_null_keys(self, table1, table2): raise ValueError(f"NULL values in one or more primary keys of {ts.table_path}") def _collect_stats(self, i, table_seg: TableSegment, info_tree: InfoTree): - logger.debug(f"Collecting stats for table #{i}") + logger.debug(f"Collecting stats for table #{i}: {table_seg.table_path}") db = table_seg.database # Metrics @@ -288,7 +288,7 @@ def _collect_stats(self, i, table_seg: TableSegment, info_tree: InfoTree): ) col_exprs["count"] = Count() - res = db.query(table_seg.make_select().select(**col_exprs), tuple) + res = db.query(table_seg.make_select().select(**col_exprs), tuple, log_message=table_seg.table_path) for col_name, value in safezip(col_exprs, res): if value is not None: @@ -303,7 +303,7 @@ def _collect_stats(self, i, table_seg: TableSegment, info_tree: InfoTree): else: self.stats[stat_name] = value - logger.debug("Done collecting stats for table #%s", i) + logger.debug("Done collecting stats for table #%s: %s", i, table_seg.table_path) def _create_outer_join(self, table1, table2): db = table1.database @@ -334,23 +334,46 @@ def _create_outer_join(self, table1, table2): diff_rows = all_rows.where(or_(this[c] == 1 for c in is_diff_cols)) return diff_rows, a_cols, b_cols, is_diff_cols, all_rows - def _count_diff_per_column(self, db, diff_rows, cols, is_diff_cols): - logger.debug("Counting differences per column") - is_diff_cols_counts = db.query(diff_rows.select(sum_(this[c]) for c in is_diff_cols), tuple) + def _count_diff_per_column( + self, + db, + diff_rows, + cols, + is_diff_cols, + table1: Optional[TableSegment] = None, + table2: Optional[TableSegment] = None, + ): + logger.info(type(table1)) + logger.debug(f"Counting differences per column: {table1.table_path} <> {table2.table_path}") + is_diff_cols_counts = db.query( + diff_rows.select(sum_(this[c]) for c in is_diff_cols), + tuple, + log_message=f"{table1.table_path} <> {table2.table_path}", + ) diff_counts = {} for name, count in safezip(cols, is_diff_cols_counts): diff_counts[name] = diff_counts.get(name, 0) + (count or 0) self.stats["diff_counts"] = diff_counts - def _sample_and_count_exclusive(self, db, diff_rows, a_cols, b_cols): + def _sample_and_count_exclusive( + self, + db, + diff_rows, + a_cols, + b_cols, + table1: Optional[TableSegment] = None, + table2: Optional[TableSegment] = None, + ): if isinstance(db, (Oracle, MsSQL)): exclusive_rows_query = diff_rows.where((this.is_exclusive_a == 1) | (this.is_exclusive_b == 1)) else: exclusive_rows_query = diff_rows.where(this.is_exclusive_a | this.is_exclusive_b) if not self.sample_exclusive_rows: - logger.debug("Counting exclusive rows") - self.stats["exclusive_count"] = db.query(exclusive_rows_query.count(), int) + logger.debug(f"Counting exclusive rows: {table1.table_path} <> {table2.table_path}") + self.stats["exclusive_count"] = db.query( + exclusive_rows_query.count(), int, log_message=f"{table1.table_path} <> {table2.table_path}" + ) return logger.info("Counting and sampling exclusive rows") diff --git a/data_diff/utils.py b/data_diff/utils.py index ee4a0f17..b9045cc1 100644 --- a/data_diff/utils.py +++ b/data_diff/utils.py @@ -485,31 +485,31 @@ def __init__(self): super().__init__() self.status = Status("") self.prefix = "" - self.cloud_diff_status = {} + self.diff_status = {} def emit(self, record): log_entry = self.format(record) - if self.cloud_diff_status: - self._update_cloud_status(log_entry) + if self.diff_status: + self._update_diff_status(log_entry) else: self.status.update(self.prefix + log_entry) def set_prefix(self, prefix_string): self.prefix = prefix_string - def cloud_diff_started(self, model_name): - self.cloud_diff_status[model_name] = "[yellow]In Progress[/]" - self._update_cloud_status() + def diff_started(self, model_name): + self.diff_status[model_name] = "[yellow]In Progress[/]" + self._update_diff_status() - def cloud_diff_finished(self, model_name): - self.cloud_diff_status[model_name] = "[green]Finished [/]" - self._update_cloud_status() + def diff_finished(self, model_name): + self.diff_status[model_name] = "[green]Finished [/]" + self._update_diff_status() - def _update_cloud_status(self, log=None): - cloud_status_string = "\n" - for model_name, status in self.cloud_diff_status.items(): - cloud_status_string += f"{status} {model_name}\n" - self.status.update(f"{cloud_status_string}{log or ''}") + def _update_diff_status(self, log=None): + status_string = "\n" + for model_name, status in self.diff_status.items(): + status_string += f"{status} {model_name}\n" + self.status.update(f"{status_string}{log or ''}") class UnknownMeta(type): diff --git a/tests/test_dbt.py b/tests/test_dbt.py index c281b6fb..31af99eb 100644 --- a/tests/test_dbt.py +++ b/tests/test_dbt.py @@ -93,8 +93,8 @@ def test_local_diff(self, mock_diff_tables): ) self.assertEqual(len(mock_diff_tables.call_args[1]["extra_columns"]), 2) self.assertEqual(mock_connect.call_count, 2) - mock_connect.assert_any_call(connection, ".".join(dev_qualified_list), tuple(expected_primary_keys), threads) - mock_connect.assert_any_call(connection, ".".join(prod_qualified_list), tuple(expected_primary_keys), threads) + mock_connect.assert_any_call(connection, ".".join(dev_qualified_list), tuple(expected_primary_keys)) + mock_connect.assert_any_call(connection, ".".join(prod_qualified_list), tuple(expected_primary_keys)) mock_diff.get_stats_string.assert_called_once() @patch("data_diff.dbt.diff_tables") @@ -180,8 +180,8 @@ def test_local_diff_no_diffs(self, mock_diff_tables): ) self.assertEqual(len(mock_diff_tables.call_args[1]["extra_columns"]), 2) self.assertEqual(mock_connect.call_count, 2) - mock_connect.assert_any_call(connection, ".".join(dev_qualified_list), tuple(expected_primary_keys), None) - mock_connect.assert_any_call(connection, ".".join(prod_qualified_list), tuple(expected_primary_keys), None) + mock_connect.assert_any_call(connection, ".".join(dev_qualified_list), tuple(expected_primary_keys)) + mock_connect.assert_any_call(connection, ".".join(prod_qualified_list), tuple(expected_primary_keys)) mock_diff.get_stats_string.assert_not_called() @patch("data_diff.dbt.rich.print") @@ -248,6 +248,7 @@ def test_diff_is_cloud( where = "a_string" config = TDatadiffConfig(prod_database="prod_db", prod_schema="prod_schema", datasource_id=1) mock_dbt_parser_inst = Mock() + mock_dbt_parser_inst.threads = threads mock_model = Mock() mock_api.get_data_source.return_value = TCloudApiDataSource(id=1, type="snowflake", name="snowflake") mock_initialize_api.return_value = mock_api @@ -386,6 +387,7 @@ def test_diff_is_not_cloud(self, mock_print, mock_dbt_parser, mock_cloud_diff, m threads = None where = "a_string" mock_dbt_parser_inst = Mock() + mock_dbt_parser_inst.threads = threads mock_dbt_parser.return_value = mock_dbt_parser_inst mock_model = Mock() mock_dbt_parser_inst.get_models.return_value = [mock_model] @@ -407,7 +409,7 @@ def test_diff_is_not_cloud(self, mock_print, mock_dbt_parser, mock_cloud_diff, m mock_dbt_parser_inst.get_models.assert_called_once() mock_dbt_parser_inst.set_connection.assert_called_once() mock_cloud_diff.assert_not_called() - mock_local_diff.assert_called_once_with(diff_vars, False) + mock_local_diff.assert_called_once_with(diff_vars, False, None) mock_print.assert_not_called() @patch("data_diff.dbt._get_diff_vars") @@ -423,6 +425,7 @@ def test_diff_state_model_dne( threads = None where = "a_string" mock_dbt_parser_inst = Mock() + mock_dbt_parser_inst.threads = threads mock_dbt_parser.return_value = mock_dbt_parser_inst mock_model = Mock() mock_dbt_parser_inst.get_models.return_value = [mock_model] @@ -460,6 +463,7 @@ def test_diff_only_prod_db(self, mock_print, mock_dbt_parser, mock_cloud_diff, m threads = None where = "a_string" mock_dbt_parser_inst = Mock() + mock_dbt_parser_inst.threads = threads mock_dbt_parser.return_value = mock_dbt_parser_inst mock_model = Mock() mock_dbt_parser_inst.get_models.return_value = [mock_model] @@ -481,7 +485,7 @@ def test_diff_only_prod_db(self, mock_print, mock_dbt_parser, mock_cloud_diff, m mock_dbt_parser_inst.get_models.assert_called_once() mock_dbt_parser_inst.set_connection.assert_called_once() mock_cloud_diff.assert_not_called() - mock_local_diff.assert_called_once_with(diff_vars, False) + mock_local_diff.assert_called_once_with(diff_vars, False, None) mock_print.assert_not_called() @patch("data_diff.dbt._get_diff_vars") @@ -497,6 +501,7 @@ def test_diff_only_prod_schema( threads = None where = "a_string" mock_dbt_parser_inst = Mock() + mock_dbt_parser_inst.threads = threads mock_dbt_parser.return_value = mock_dbt_parser_inst mock_model = Mock() mock_dbt_parser_inst.get_models.return_value = [mock_model] @@ -518,7 +523,7 @@ def test_diff_only_prod_schema( mock_dbt_parser_inst.get_models.assert_called_once() mock_dbt_parser_inst.set_connection.assert_called_once() mock_cloud_diff.assert_not_called() - mock_local_diff.assert_called_once_with(diff_vars, False) + mock_local_diff.assert_called_once_with(diff_vars, False, None) mock_print.assert_not_called() @patch("data_diff.dbt._initialize_api") @@ -543,6 +548,7 @@ def test_diff_is_cloud_no_pks( mock_model = Mock() connection = {} threads = None + mock_dbt_parser_inst.threads = threads where = "a_string" config = TDatadiffConfig(prod_database="prod_db", prod_schema="prod_schema", datasource_id=1) mock_api = Mock() @@ -584,6 +590,7 @@ def test_diff_not_is_cloud_no_pks( threads = None where = "a_string" mock_dbt_parser_inst = Mock() + mock_dbt_parser_inst.threads = threads mock_dbt_parser.return_value = mock_dbt_parser_inst mock_model = Mock() mock_dbt_parser_inst.get_models.return_value = [mock_model] From be1d947309ba8e5cefc5b1e47dae78dfe4df8d6e Mon Sep 17 00:00:00 2001 From: Valentin Khomutenko Date: Wed, 6 Dec 2023 16:23:53 +0100 Subject: [PATCH 2/7] accept either key file path or file itself --- data_diff/databases/snowflake.py | 40 ++++++++++++++++++++------------ 1 file changed, 25 insertions(+), 15 deletions(-) diff --git a/data_diff/databases/snowflake.py b/data_diff/databases/snowflake.py index d83c0f40..5ab48243 100644 --- a/data_diff/databases/snowflake.py +++ b/data_diff/databases/snowflake.py @@ -1,4 +1,5 @@ -from typing import Any, Union, List +import base64 +from typing import Any, Union, List, Optional import logging import attrs @@ -162,7 +163,7 @@ class Snowflake(Database): _conn: Any - def __init__(self, *, schema: str, **kw): + def __init__(self, *, schema: str, key: Optional[str] = None, key_content: Optional[str] = None, **kw): super().__init__() snowflake, serialization, default_backend = import_snowflake() logging.getLogger("snowflake.connector").setLevel(logging.WARNING) @@ -172,20 +173,29 @@ def __init__(self, *, schema: str, **kw): logging.getLogger("snowflake.connector.network").disabled = True assert '"' not in schema, "Schema name should not contain quotes!" + if key_content and key: + raise ConnectError("Only key value or key file path can be specified, not both") + + key_bytes = None + if key: + with open(key, "rb") as f: + key_bytes = f.read() + if key_content: + key_bytes = base64.b64decode(key_content) + # If a private key is used, read it from the specified path and pass it as "private_key" to the connector. - if "key" in kw: - with open(kw.get("key"), "rb") as key: - if "password" in kw: - raise ConnectError("Cannot use password and key at the same time") - if kw.get("private_key_passphrase"): - encoded_passphrase = kw.get("private_key_passphrase").encode() - else: - encoded_passphrase = None - p_key = serialization.load_pem_private_key( - key.read(), - password=encoded_passphrase, - backend=default_backend(), - ) + if key_bytes: + if "password" in kw: + raise ConnectError("Cannot use password and key at the same time") + if kw.get("private_key_passphrase"): + encoded_passphrase = kw.get("private_key_passphrase").encode() + else: + encoded_passphrase = None + p_key = serialization.load_pem_private_key( + key_bytes, + password=encoded_passphrase, + backend=default_backend(), + ) kw["private_key"] = p_key.private_bytes( encoding=serialization.Encoding.DER, From bf7276b072938d29548f94e03df69d5f046f200b Mon Sep 17 00:00:00 2001 From: Dan Date: Wed, 6 Dec 2023 12:52:26 -0700 Subject: [PATCH 3/7] mssql normalize_uuid --- data_diff/databases/mssql.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/data_diff/databases/mssql.py b/data_diff/databases/mssql.py index 8f5195ee..b637c285 100644 --- a/data_diff/databases/mssql.py +++ b/data_diff/databases/mssql.py @@ -13,6 +13,7 @@ ) from data_diff.abcs.database_types import ( JSON, + ColType_UUID, NumericType, Timestamp, TimestampTZ, @@ -154,6 +155,9 @@ def md5_as_int(self, s: str) -> str: def md5_as_hex(self, s: str) -> str: return f"HashBytes('MD5', {s})" + def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: + return f"TRIM(CAST({value} AS char)) AS {value}" + @attrs.define(frozen=False, init=False, kw_only=True) class MsSQL(ThreadedDatabase): From 8cfcb9b7d0e3611274c2cdd647c48afc611e3e23 Mon Sep 17 00:00:00 2001 From: Dan Lawin Date: Wed, 6 Dec 2023 16:10:36 -0700 Subject: [PATCH 4/7] Revert "mssql normalize_uuid" --- data_diff/databases/mssql.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/data_diff/databases/mssql.py b/data_diff/databases/mssql.py index b637c285..8f5195ee 100644 --- a/data_diff/databases/mssql.py +++ b/data_diff/databases/mssql.py @@ -13,7 +13,6 @@ ) from data_diff.abcs.database_types import ( JSON, - ColType_UUID, NumericType, Timestamp, TimestampTZ, @@ -155,9 +154,6 @@ def md5_as_int(self, s: str) -> str: def md5_as_hex(self, s: str) -> str: return f"HashBytes('MD5', {s})" - def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: - return f"TRIM(CAST({value} AS char)) AS {value}" - @attrs.define(frozen=False, init=False, kw_only=True) class MsSQL(ThreadedDatabase): From 1ddd783091fba89d7496770db35cea0685a4149c Mon Sep 17 00:00:00 2001 From: Dan Date: Wed, 6 Dec 2023 16:13:59 -0700 Subject: [PATCH 5/7] fix mssql limit_select --- data_diff/databases/mssql.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/data_diff/databases/mssql.py b/data_diff/databases/mssql.py index 8f5195ee..d15e7954 100644 --- a/data_diff/databases/mssql.py +++ b/data_diff/databases/mssql.py @@ -113,19 +113,21 @@ def is_distinct_from(self, a: str, b: str) -> str: def limit_select( self, select_query: str, - offset: Optional[int] = None, - limit: Optional[int] = None, - has_order_by: Optional[bool] = None, + offset: int | None = None, + limit: int | None = None, + has_order_by: bool | None = None, ) -> str: if offset: raise NotImplementedError("No support for OFFSET in query") - result = "" if not has_order_by: result += "ORDER BY 1" result += f" OFFSET 0 ROWS FETCH NEXT {limit} ROWS ONLY" - return f"SELECT * FROM ({select_query}) AS LIMITED_SELECT {result}" + + # mssql requires that subquery columns are all aliased, so + # don't wrap in an outer select + return f"{select_query} {result}" def constant_values(self, rows) -> str: values = ", ".join("(%s)" % ", ".join(self._constant_value(v) for v in row) for row in rows) From 643b1ea79cd5d63f03ce87e78de3646303394afa Mon Sep 17 00:00:00 2001 From: Dan Date: Wed, 6 Dec 2023 16:14:52 -0700 Subject: [PATCH 6/7] use Optional[] --- data_diff/databases/mssql.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/data_diff/databases/mssql.py b/data_diff/databases/mssql.py index d15e7954..834ed9cd 100644 --- a/data_diff/databases/mssql.py +++ b/data_diff/databases/mssql.py @@ -113,9 +113,9 @@ def is_distinct_from(self, a: str, b: str) -> str: def limit_select( self, select_query: str, - offset: int | None = None, - limit: int | None = None, - has_order_by: bool | None = None, + offset: Optional[int] = None, + limit: Optional[int] = None, + has_order_by: Optional[bool] = None, ) -> str: if offset: raise NotImplementedError("No support for OFFSET in query") From 6906c42fffda9fbc737a8695b53cb85a8a8873f9 Mon Sep 17 00:00:00 2001 From: Dan Date: Mon, 11 Dec 2023 11:59:06 -0700 Subject: [PATCH 7/7] v0.10.0rc0 --- data_diff/version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/data_diff/version.py b/data_diff/version.py index e1170d35..90ebae8a 100644 --- a/data_diff/version.py +++ b/data_diff/version.py @@ -1 +1 @@ -__version__ = "0.9.17" +__version__ = "0.10.0rc0" diff --git a/pyproject.toml b/pyproject.toml index c3448c01..1b380dde 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "data-diff" -version = "0.9.17" +version = "0.10.0rc0" description = "Command-line tool and Python library to efficiently diff rows across two different databases." authors = ["Datafold "] license = "MIT"