diff --git a/data_diff/databases/databricks.py b/data_diff/databases/databricks.py
index de5ea8b7..c755cfa9 100644
--- a/data_diff/databases/databricks.py
+++ b/data_diff/databases/databricks.py
@@ -139,47 +139,38 @@ def create_connection(self):
             raise ConnectionError(*e.args) from e
 
     def query_table_schema(self, path: DbPath) -> Dict[str, tuple]:
+        # Databricks has INFORMATION_SCHEMA only for Databricks Runtime, not for Databricks SQL.
+        # https://docs.databricks.com/spark/latest/spark-sql/language-manual/information-schema/columns.html
+        # So, to obtain information about schema, we should use another approach.
+
         conn = self.create_connection()
 
-        table_schema = {}
-        try:
-            table_schema = super().query_table_schema(path)
-        except:
-            logging.warning("Failed to get schema from information_schema, falling back to legacy approach.")
-
-        if not table_schema:
-            # This legacy approach can cause bugs. e.g. VARCHAR(255) -> VARCHAR(255)
-            # and not the expected VARCHAR
-
-            # I don't think we'll fall back to this approach, but if so, see above
-            catalog, schema, table = self._normalize_table_path(path)
-            with conn.cursor() as cursor:
-                cursor.columns(catalog_name=catalog, schema_name=schema, table_name=table)
-                try:
-                    rows = cursor.fetchall()
-                finally:
-                    conn.close()
-                if not rows:
-                    raise RuntimeError(f"{self.name}: Table '{'.'.join(path)}' does not exist, or has no columns")
-
-                table_schema = {r.COLUMN_NAME: (r.COLUMN_NAME, r.TYPE_NAME, r.DECIMAL_DIGITS, None, None) for r in rows}
-                assert len(table_schema) == len(rows)
-                return table_schema
-        else:
-            return table_schema
-
-    def select_table_schema(self, path: DbPath) -> str:
-        """Provide SQL for selecting the table schema as (name, type, date_prec, num_prec)"""
-        database, schema, name = self._normalize_table_path(path)
-        info_schema_path = ["information_schema", "columns"]
-        if database:
-            info_schema_path.insert(0, database)
-
-        return (
-            "SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale "
-            f"FROM {'.'.join(info_schema_path)} "
-            f"WHERE table_name = '{name}' AND table_schema = '{schema}'"
-        )
+        catalog, schema, table = self._normalize_table_path(path)
+        with conn.cursor() as cursor:
+            cursor.columns(catalog_name=catalog, schema_name=schema, table_name=table)
+            try:
+                rows = cursor.fetchall()
+            finally:
+                conn.close()
+            if not rows:
+                raise RuntimeError(f"{self.name}: Table '{'.'.join(path)}' does not exist, or has no columns")
+
+            d = {r.COLUMN_NAME: (r.COLUMN_NAME, r.TYPE_NAME, r.DECIMAL_DIGITS, None, None) for r in rows}
+            assert len(d) == len(rows)
+            return d
+
+    # def select_table_schema(self, path: DbPath) -> str:
+    #     """Provide SQL for selecting the table schema as (name, type, date_prec, num_prec)"""
+    #     database, schema, name = self._normalize_table_path(path)
+    #     info_schema_path = ["information_schema", "columns"]
+    #     if database:
+    #         info_schema_path.insert(0, database)
+
+    #     return (
+    #         "SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale "
+    #         f"FROM {'.'.join(info_schema_path)} "
+    #         f"WHERE table_name = '{name}' AND table_schema = '{schema}'"
+    #     )
 
     def _process_table_schema(
         self, path: DbPath, raw_schema: Dict[str, tuple], filter_columns: Sequence[str], where: str = None
diff --git a/data_diff/version.py b/data_diff/version.py
index d2c52686..0db237a6 100644
--- a/data_diff/version.py
+++ b/data_diff/version.py
@@ -1 +1 @@
-__version__ = "0.9.15"
+__version__ = "0.9.16"
diff --git a/pyproject.toml b/pyproject.toml
index 1e240c67..9394dd8b 100755
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "data-diff"
-version = "0.9.15"
+version = "0.9.16"
 description = "Command-line tool and Python library to efficiently diff rows across two different databases."
 authors = ["Datafold <data-diff@datafold.com>"]
 license = "MIT"