diff --git a/data_diff/__main__.py b/data_diff/__main__.py index c4e698f5..02f3b31d 100644 --- a/data_diff/__main__.py +++ b/data_diff/__main__.py @@ -77,7 +77,7 @@ def _get_schema(pair: Tuple[Database, DbPath]) -> Dict[str, RawColumnInfo]: return db.query_table_schema(table_path) -def diff_schemas(table1, table2, schema1, schema2, columns): +def diff_schemas(table1, table2, schema1, schema2, columns) -> None: logging.info("Diffing schemas...") attrs = "name", "type", "datetime_precision", "numeric_precision", "numeric_scale" for c in columns: @@ -103,7 +103,7 @@ def diff_schemas(table1, table2, schema1, schema2, columns): class MyHelpFormatter(click.HelpFormatter): - def __init__(self, **kwargs): + def __init__(self, **kwargs) -> None: super().__init__(self, **kwargs) self.indent_increment = 6 @@ -281,7 +281,7 @@ def write_usage(self, prog: str, args: str = "", prefix: Optional[str] = None) - default=None, help="Override the dbt production schema configuration within dbt_project.yml", ) -def main(conf, run, **kw): +def main(conf, run, **kw) -> None: log_handlers = _get_log_handlers(kw["dbt"]) if kw["table2"] is None and kw["database2"]: # Use the "database table table" form @@ -341,9 +341,7 @@ def main(conf, run, **kw): production_schema_flag=kw["prod_schema"], ) else: - return _data_diff( - dbt_project_dir=project_dir_override, dbt_profiles_dir=profiles_dir_override, state=state, **kw - ) + _data_diff(dbt_project_dir=project_dir_override, dbt_profiles_dir=profiles_dir_override, state=state, **kw) except Exception as e: logging.error(e) raise @@ -389,7 +387,7 @@ def _data_diff( threads1=None, threads2=None, __conf__=None, -): +) -> None: if limit and stats: logging.error("Cannot specify a limit when using the -s/--stats switch") return diff --git a/data_diff/abcs/database_types.py b/data_diff/abcs/database_types.py index e5aa5fab..1c40883d 100644 --- a/data_diff/abcs/database_types.py +++ b/data_diff/abcs/database_types.py @@ -290,7 +290,7 @@ class 
Integer(NumericType, IKey): precision: int = 0 python_type: type = int - def __attrs_post_init__(self): + def __attrs_post_init__(self) -> None: assert self.precision == 0 diff --git a/data_diff/cloud/data_source.py b/data_diff/cloud/data_source.py index 32fd89b9..3f3b2e16 100644 --- a/data_diff/cloud/data_source.py +++ b/data_diff/cloud/data_source.py @@ -46,7 +46,7 @@ def process_response(self, value: str) -> str: return value -def _validate_temp_schema(temp_schema: str): +def _validate_temp_schema(temp_schema: str) -> None: if len(temp_schema.split(".")) != 2: raise ValueError("Temporary schema should have a format <database>.<temp_schema>") diff --git a/data_diff/cloud/datafold_api.py b/data_diff/cloud/datafold_api.py index 3ec637a0..2bba5243 100644 --- a/data_diff/cloud/datafold_api.py +++ b/data_diff/cloud/datafold_api.py @@ -185,7 +185,7 @@ class DatafoldAPI: host: str = "https://app.datafold.com/" timeout: int = 30 - def __attrs_post_init__(self): + def __attrs_post_init__(self) -> None: self.host = self.host.rstrip("/") self.headers = { "Authorization": f"Key {self.api_key}", diff --git a/data_diff/config.py b/data_diff/config.py index 1b091f07..3c4cbef9 100644 --- a/data_diff/config.py +++ b/data_diff/config.py @@ -99,7 +99,7 @@ def _apply_config(config: Dict[str, Any], run_name: str, kw: Dict[str, Any]): _ENV_VAR_PATTERN = r"\$\{([A-Za-z0-9_]+)\}" -def _resolve_env(config: Dict[str, Any]): +def _resolve_env(config: Dict[str, Any]) -> None: """ Resolve environment variables referenced as ${ENV_VAR_NAME}. Missing environment variables are replaced with an empty string. 
diff --git a/data_diff/databases/_connect.py b/data_diff/databases/_connect.py index 1e34ef62..df63c78b 100644 --- a/data_diff/databases/_connect.py +++ b/data_diff/databases/_connect.py @@ -100,7 +100,7 @@ class Connect: database_by_scheme: Dict[str, Database] conn_cache: MutableMapping[Hashable, Database] - def __init__(self, database_by_scheme: Dict[str, Database] = DATABASE_BY_SCHEME): + def __init__(self, database_by_scheme: Dict[str, Database] = DATABASE_BY_SCHEME) -> None: super().__init__() self.database_by_scheme = database_by_scheme self.conn_cache = weakref.WeakValueDictionary() diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py index 8beb00ff..d6549f71 100644 --- a/data_diff/databases/base.py +++ b/data_diff/databases/base.py @@ -5,7 +5,22 @@ import math import sys import logging -from typing import Any, Callable, ClassVar, Dict, Generator, Tuple, Optional, Sequence, Type, List, Union, TypeVar +from typing import ( + Any, + Callable, + ClassVar, + Dict, + Generator, + Iterator, + NewType, + Tuple, + Optional, + Sequence, + Type, + List, + Union, + TypeVar, +) from functools import partial, wraps from concurrent.futures import ThreadPoolExecutor import threading @@ -116,7 +131,7 @@ def dialect(self) -> "BaseDialect": def compile(self, elem, params=None) -> str: return self.dialect.compile(self, elem, params) - def new_unique_name(self, prefix="tmp"): + def new_unique_name(self, prefix="tmp") -> str: self._counter[0] += 1 return f"{prefix}{self._counter[0]}" @@ -173,7 +188,7 @@ class ThreadLocalInterpreter: compiler: Compiler gen: Generator - def apply_queries(self, callback: Callable[[str], Any]): + def apply_queries(self, callback: Callable[[str], Any]) -> None: q: Expr = next(self.gen) while True: sql = self.compiler.database.dialect.compile(self.compiler, q) @@ -885,20 +900,21 @@ def optimizer_hints(self, hints: str) -> str: T = TypeVar("T", bound=BaseDialect) +Row = Sequence[Any] @attrs.define(frozen=True) class QueryResult: - 
rows: list + rows: List[Row] columns: Optional[list] = None - def __iter__(self): + def __iter__(self) -> Iterator[Row]: return iter(self.rows) - def __len__(self): + def __len__(self) -> int: return len(self.rows) - def __getitem__(self, i): + def __getitem__(self, i) -> Row: return self.rows[i] @@ -1209,7 +1225,7 @@ class ThreadedDatabase(Database): _queue: Optional[ThreadPoolExecutor] = None thread_local: threading.local = attrs.field(factory=threading.local) - def __attrs_post_init__(self): + def __attrs_post_init__(self) -> None: self._queue = ThreadPoolExecutor(self.thread_count, initializer=self.set_conn) logger.info(f"[{self.name}] Starting a threadpool, size={self.thread_count}.") diff --git a/data_diff/databases/bigquery.py b/data_diff/databases/bigquery.py index 7caa14a1..c4470ee6 100644 --- a/data_diff/databases/bigquery.py +++ b/data_diff/databases/bigquery.py @@ -85,10 +85,10 @@ class Dialect(BaseDialect): def random(self) -> str: return "RAND()" - def quote(self, s: str): + def quote(self, s: str) -> str: return f"`{s}`" - def to_string(self, s: str): + def to_string(self, s: str) -> str: return f"cast({s} as string)" def type_repr(self, t) -> str: @@ -212,7 +212,7 @@ class BigQuery(Database): dataset: str _client: Any - def __init__(self, project, *, dataset, bigquery_credentials=None, **kw): + def __init__(self, project, *, dataset, bigquery_credentials=None, **kw) -> None: super().__init__() credentials = bigquery_credentials bigquery = import_bigquery() diff --git a/data_diff/databases/clickhouse.py b/data_diff/databases/clickhouse.py index 7bbc156f..43b69b3f 100644 --- a/data_diff/databases/clickhouse.py +++ b/data_diff/databases/clickhouse.py @@ -175,7 +175,7 @@ class Clickhouse(ThreadedDatabase): _args: Dict[str, Any] - def __init__(self, *, thread_count: int, **kw): + def __init__(self, *, thread_count: int, **kw) -> None: super().__init__(thread_count=thread_count) self._args = kw diff --git a/data_diff/databases/databricks.py 
b/data_diff/databases/databricks.py index efd35ea1..c3e40123 100644 --- a/data_diff/databases/databricks.py +++ b/data_diff/databases/databricks.py @@ -65,7 +65,7 @@ def type_repr(self, t) -> str: except KeyError: return super().type_repr(t) - def quote(self, s: str): + def quote(self, s: str) -> str: return f"`{s}`" def to_string(self, s: str) -> str: @@ -118,7 +118,7 @@ class Databricks(ThreadedDatabase): catalog: str _args: Dict[str, Any] - def __init__(self, *, thread_count, **kw): + def __init__(self, *, thread_count, **kw) -> None: super().__init__(thread_count=thread_count) logging.getLogger("databricks.sql").setLevel(logging.WARNING) diff --git a/data_diff/databases/duckdb.py b/data_diff/databases/duckdb.py index 48a44ffc..4cdacde8 100644 --- a/data_diff/databases/duckdb.py +++ b/data_diff/databases/duckdb.py @@ -126,7 +126,7 @@ class DuckDB(Database): _args: Dict[str, Any] = attrs.field(init=False) _conn: Any = attrs.field(init=False) - def __init__(self, **kw): + def __init__(self, **kw) -> None: super().__init__() self._args = kw self._conn = self.create_connection() diff --git a/data_diff/databases/mssql.py b/data_diff/databases/mssql.py index c5444610..e4f841ef 100644 --- a/data_diff/databases/mssql.py +++ b/data_diff/databases/mssql.py @@ -76,7 +76,7 @@ class Dialect(BaseDialect): "json": JSON, } - def quote(self, s: str): + def quote(self, s: str) -> str: return f"[{s}]" def set_timezone_to_utc(self) -> str: @@ -93,7 +93,7 @@ def current_schema(self) -> str: FROM sys.database_principals WHERE name = CURRENT_USER""" - def to_string(self, s: str): + def to_string(self, s: str) -> str: # Both convert(varchar(max), …) and convert(text, …) do work. 
return f"CONVERT(VARCHAR(MAX), {s})" @@ -168,7 +168,7 @@ class MsSQL(ThreadedDatabase): _args: Dict[str, Any] _mssql: Any - def __init__(self, host, port, user, password, *, database, thread_count, **kw): + def __init__(self, host, port, user, password, *, database, thread_count, **kw) -> None: super().__init__(thread_count=thread_count) args = dict(server=host, port=port, database=database, user=user, password=password, **kw) diff --git a/data_diff/databases/mysql.py b/data_diff/databases/mysql.py index 2b1e810f..1ee04460 100644 --- a/data_diff/databases/mysql.py +++ b/data_diff/databases/mysql.py @@ -70,10 +70,10 @@ class Dialect(BaseDialect): "boolean": Boolean, } - def quote(self, s: str): + def quote(self, s: str) -> str: return f"`{s}`" - def to_string(self, s: str): + def to_string(self, s: str) -> str: return f"cast({s} as char)" def is_distinct_from(self, a: str, b: str) -> str: @@ -129,7 +129,7 @@ class MySQL(ThreadedDatabase): _args: Dict[str, Any] - def __init__(self, *, thread_count, **kw): + def __init__(self, *, thread_count, **kw) -> None: super().__init__(thread_count=thread_count) self._args = kw diff --git a/data_diff/databases/oracle.py b/data_diff/databases/oracle.py index 2f6537be..f5960476 100644 --- a/data_diff/databases/oracle.py +++ b/data_diff/databases/oracle.py @@ -59,10 +59,10 @@ class Dialect( ROUNDS_ON_PREC_LOSS = True PLACEHOLDER_TABLE = "DUAL" - def quote(self, s: str): + def quote(self, s: str) -> str: return f'"{s}"' - def to_string(self, s: str): + def to_string(self, s: str) -> str: return f"cast({s} as varchar(1024))" def limit_select( @@ -164,7 +164,7 @@ class Oracle(ThreadedDatabase): kwargs: Dict[str, Any] _oracle: Any - def __init__(self, *, host, database, thread_count, **kw): + def __init__(self, *, host, database, thread_count, **kw) -> None: super().__init__(thread_count=thread_count) self.kwargs = dict(dsn=f"{host}/{database}" if database else host, **kw) self.default_schema = kw.get("user").upper() diff --git 
a/data_diff/databases/postgresql.py b/data_diff/databases/postgresql.py index 93e0dc2f..1ec34f4c 100644 --- a/data_diff/databases/postgresql.py +++ b/data_diff/databases/postgresql.py @@ -163,7 +163,7 @@ class PostgreSQL(ThreadedDatabase): _args: Dict[str, Any] _conn: Any - def __init__(self, *, thread_count, **kw): + def __init__(self, *, thread_count, **kw) -> None: super().__init__(thread_count=thread_count) self._args = kw self.default_schema = "public" diff --git a/data_diff/databases/presto.py b/data_diff/databases/presto.py index 034ac99f..42e28056 100644 --- a/data_diff/databases/presto.py +++ b/data_diff/databases/presto.py @@ -152,7 +152,7 @@ class Presto(Database): _conn: Any - def __init__(self, **kw): + def __init__(self, **kw) -> None: super().__init__() self.default_schema = "public" prestodb = import_presto() diff --git a/data_diff/databases/snowflake.py b/data_diff/databases/snowflake.py index 1b70085a..823ae5b2 100644 --- a/data_diff/databases/snowflake.py +++ b/data_diff/databases/snowflake.py @@ -104,7 +104,7 @@ class Snowflake(Database): _conn: Any - def __init__(self, *, schema: str, key: Optional[str] = None, key_content: Optional[str] = None, **kw): + def __init__(self, *, schema: str, key: Optional[str] = None, key_content: Optional[str] = None, **kw) -> None: super().__init__() snowflake, serialization, default_backend = import_snowflake() logging.getLogger("snowflake.connector").setLevel(logging.WARNING) diff --git a/data_diff/databases/trino.py b/data_diff/databases/trino.py index b76ba74b..5a432ee5 100644 --- a/data_diff/databases/trino.py +++ b/data_diff/databases/trino.py @@ -40,7 +40,7 @@ class Trino(presto.Presto): _conn: Any - def __init__(self, **kw): + def __init__(self, **kw) -> None: super().__init__() trino = import_trino() diff --git a/data_diff/databases/vertica.py b/data_diff/databases/vertica.py index 27d017bd..cfe046d2 100644 --- a/data_diff/databases/vertica.py +++ b/data_diff/databases/vertica.py @@ -60,7 +60,7 @@ class 
Dialect(BaseDialect): # https://www.vertica.com/docs/9.3.x/HTML/Content/Authoring/SQLReferenceManual/DataTypes/Numeric/NUMERIC.htm#Default DEFAULT_NUMERIC_PRECISION = 15 - def quote(self, s: str): + def quote(self, s: str) -> str: return f'"{s}"' def concat(self, items: List[str]) -> str: @@ -137,7 +137,7 @@ class Vertica(ThreadedDatabase): _args: Dict[str, Any] - def __init__(self, *, thread_count, **kw): + def __init__(self, *, thread_count, **kw) -> None: super().__init__(thread_count=thread_count) self._args = kw self._args["AUTOCOMMIT"] = False diff --git a/data_diff/dbt_parser.py b/data_diff/dbt_parser.py index 0d864a57..eda5f6c5 100644 --- a/data_diff/dbt_parser.py +++ b/data_diff/dbt_parser.py @@ -50,7 +50,7 @@ def try_get_dbt_runner(): # ProfileRenderer.render_data() fails without instantiating global flag MACRO_DEBUGGING in dbt-core 1.5 # hacky but seems to be a bug on dbt's end -def try_set_dbt_flags(): +def try_set_dbt_flags() -> None: try: from dbt.flags import set_flags diff --git a/data_diff/diff_tables.py b/data_diff/diff_tables.py index 5ec1f71b..74376f63 100644 --- a/data_diff/diff_tables.py +++ b/data_diff/diff_tables.py @@ -6,7 +6,7 @@ from enum import Enum from contextlib import contextmanager from operator import methodcaller -from typing import Dict, Set, List, Tuple, Iterator, Optional, Union +from typing import Any, Dict, Set, List, Tuple, Iterator, Optional, Union from concurrent.futures import ThreadPoolExecutor, as_completed import attrs @@ -89,7 +89,7 @@ class DiffResultWrapper: stats: dict result_list: list = attrs.field(factory=list) - def __iter__(self): + def __iter__(self) -> Iterator[Any]: yield from self.result_list for i in self.diff: self.result_list.append(i) diff --git a/data_diff/hashdiff_tables.py b/data_diff/hashdiff_tables.py index 29508965..0984d9f1 100644 --- a/data_diff/hashdiff_tables.py +++ b/data_diff/hashdiff_tables.py @@ -96,7 +96,7 @@ class HashDiffer(TableDiffer): stats: dict = attrs.field(factory=dict) - def 
__attrs_post_init__(self): + def __attrs_post_init__(self) -> None: # Validate options if self.bisection_factor >= self.bisection_threshold: raise ValueError("Incorrect param values (bisection factor must be lower than threshold)") diff --git a/data_diff/info_tree.py b/data_diff/info_tree.py index bc9430de..a9683c20 100644 --- a/data_diff/info_tree.py +++ b/data_diff/info_tree.py @@ -18,13 +18,15 @@ class SegmentInfo: rowcounts: Dict[int, int] = attrs.field(factory=dict) max_rows: Optional[int] = None - def set_diff(self, diff: List[Union[Tuple[Any, ...], List[Any]]], schema: Optional[Tuple[Tuple[str, type]]] = None): + def set_diff( + self, diff: List[Union[Tuple[Any, ...], List[Any]]], schema: Optional[Tuple[Tuple[str, type]]] = None + ) -> None: self.diff_schema = schema self.diff = diff self.diff_count = len(diff) self.is_diff = self.diff_count > 0 - def update_from_children(self, child_infos): + def update_from_children(self, child_infos) -> None: child_infos = list(child_infos) assert child_infos @@ -53,7 +55,7 @@ def add_node(self, table1: TableSegment, table2: TableSegment, max_rows: Optiona self.children.append(node) return node - def aggregate_info(self): + def aggregate_info(self) -> None: if self.children: for c in self.children: c.aggregate_info() diff --git a/data_diff/lexicographic_space.py b/data_diff/lexicographic_space.py index b7d88e36..32fdde70 100644 --- a/data_diff/lexicographic_space.py +++ b/data_diff/lexicographic_space.py @@ -66,11 +66,11 @@ class LexicographicSpace: All elements must be of the same length as the number of dimensions. (no rpadding) """ - def __init__(self, dims: Vector): + def __init__(self, dims: Vector) -> None: super().__init__() self.dims = dims - def __contains__(self, v: Vector): + def __contains__(self, v: Vector) -> bool: return all(0 <= i < d for i, d in safezip(v, self.dims)) def add(self, v1: Vector, v2: Vector) -> Vector: @@ -124,7 +124,7 @@ class BoundedLexicographicSpace: i.e. 
a space resticted by a "bounding-box" between two arbitrary points. """ - def __init__(self, min_bound: Vector, max_bound: Vector): + def __init__(self, min_bound: Vector, max_bound: Vector) -> None: super().__init__() dims = tuple(mx - mn for mn, mx in safezip(min_bound, max_bound)) @@ -138,7 +138,7 @@ def __init__(self, min_bound: Vector, max_bound: Vector): self.uspace = LexicographicSpace(dims) - def __contains__(self, p: Vector): + def __contains__(self, p: Vector) -> bool: return all(mn <= i < mx for i, mn, mx in safezip(p, self.min_bound, self.max_bound)) def to_uspace(self, v: Vector) -> Vector: diff --git a/data_diff/parse_time.py b/data_diff/parse_time.py index 39924798..ec80ccb4 100644 --- a/data_diff/parse_time.py +++ b/data_diff/parse_time.py @@ -33,7 +33,7 @@ class ParseError(ValueError): UNITS_STR = ", ".join(sorted(TIME_UNITS.keys())) -def string_similarity(a, b): +def string_similarity(a, b) -> float: return SequenceMatcher(None, a, b).ratio() @@ -53,7 +53,7 @@ def parse_time_atom(count, unit): return count, unit -def parse_time_delta(t: str): +def parse_time_delta(t: str) -> timedelta: time_dict = {} while t: m = TIME_RE.match(t) @@ -70,5 +70,5 @@ def parse_time_delta(t: str): return timedelta(**time_dict) -def parse_time_before(time: datetime, delta: str): +def parse_time_before(time: datetime, delta: str) -> datetime: return time - parse_time_delta(delta) diff --git a/data_diff/queries/api.py b/data_diff/queries/api.py index 82786871..0ff78719 100644 --- a/data_diff/queries/api.py +++ b/data_diff/queries/api.py @@ -26,7 +26,7 @@ def join(*tables: ITable) -> Join: return Join(tables) -def leftjoin(*tables: ITable): +def leftjoin(*tables: ITable) -> Join: """Left-joins a sequence of table expressions. See Also: ``join()`` @@ -34,7 +34,7 @@ def leftjoin(*tables: ITable): return Join(tables, "LEFT") -def rightjoin(*tables: ITable): +def rightjoin(*tables: ITable) -> Join: """Right-joins a sequence of table expressions. 
See Also: ``join()`` @@ -42,7 +42,7 @@ def rightjoin(*tables: ITable): return Join(tables, "RIGHT") -def outerjoin(*tables: ITable): +def outerjoin(*tables: ITable) -> Join: """Outer-joins a sequence of table expressions. See Also: ``join()`` @@ -50,7 +50,7 @@ def outerjoin(*tables: ITable): return Join(tables, "FULL OUTER") -def cte(expr: Expr, *, name: Optional[str] = None, params: Sequence[str] = None): +def cte(expr: Expr, *, name: Optional[str] = None, params: Sequence[str] = None) -> Cte: """Define a CTE""" return Cte(expr, name, params) @@ -72,7 +72,7 @@ def table(*path: str, schema: Union[dict, CaseAwareMapping] = None) -> TablePath return TablePath(path, schema) -def or_(*exprs: Expr): +def or_(*exprs: Expr) -> Union[BinBoolOp, Expr]: """Apply OR between a sequence of boolean expressions""" exprs = args_as_tuple(exprs) if len(exprs) == 1: @@ -80,7 +80,7 @@ def or_(*exprs: Expr): return BinBoolOp("OR", exprs) -def and_(*exprs: Expr): +def and_(*exprs: Expr) -> Union[BinBoolOp, Expr]: """Apply AND between a sequence of boolean expressions""" exprs = args_as_tuple(exprs) if len(exprs) == 1: @@ -88,32 +88,32 @@ def and_(*exprs: Expr): return BinBoolOp("AND", exprs) -def sum_(expr: Expr): +def sum_(expr: Expr) -> Func: """Call SUM(expr)""" return Func("sum", [expr]) -def avg(expr: Expr): +def avg(expr: Expr) -> Func: """Call AVG(expr)""" return Func("avg", [expr]) -def min_(expr: Expr): +def min_(expr: Expr) -> Func: """Call MIN(expr)""" return Func("min", [expr]) -def max_(expr: Expr): +def max_(expr: Expr) -> Func: """Call MAX(expr)""" return Func("max", [expr]) -def exists(expr: Expr): +def exists(expr: Expr) -> Func: """Call EXISTS(expr)""" return Func("exists", [expr]) -def if_(cond: Expr, then: Expr, else_: Optional[Expr] = None): +def if_(cond: Expr, then: Expr, else_: Optional[Expr] = None) -> CaseWhen: """Conditional expression, shortcut to when-then-else. 
Example: @@ -125,7 +125,7 @@ def if_(cond: Expr, then: Expr, else_: Optional[Expr] = None): return when(cond).then(then).else_(else_) -def when(*when_exprs: Expr): +def when(*when_exprs: Expr) -> QB_When: """Start a when-then expression Example: @@ -145,13 +145,13 @@ def when(*when_exprs: Expr): return CaseWhen([]).when(*when_exprs) -def coalesce(*exprs): +def coalesce(*exprs) -> Func: "Returns a call to COALESCE" exprs = args_as_tuple(exprs) return Func("COALESCE", exprs) -def insert_rows_in_batches(db, tbl: TablePath, rows, *, columns=None, batch_size=1024 * 8): +def insert_rows_in_batches(db, tbl: TablePath, rows, *, columns=None, batch_size=1024 * 8) -> None: assert batch_size > 0 rows = list(rows) @@ -160,7 +160,7 @@ def insert_rows_in_batches(db, tbl: TablePath, rows, *, columns=None, batch_size db.query(tbl.insert_rows(batch, columns=columns)) -def current_timestamp(): +def current_timestamp() -> CurrentTimestamp: """Returns CURRENT_TIMESTAMP() or NOW()""" return CurrentTimestamp() diff --git a/data_diff/queries/ast_classes.py b/data_diff/queries/ast_classes.py index 3bf14fd7..0580e824 100644 --- a/data_diff/queries/ast_classes.py +++ b/data_diff/queries/ast_classes.py @@ -45,7 +45,7 @@ def _dfs_values(self): if isinstance(v, ExprNode): yield from v._dfs_values() - def cast_to(self, to): + def cast_to(self, to) -> "Cast": return Cast(self, to) @@ -110,7 +110,7 @@ def select(self, *exprs, distinct=SKIP, optimizer_hints=SKIP, **named_exprs) -> resolve_names(self.source_table, exprs) return Select.make(self, columns=exprs, distinct=distinct, optimizer_hints=optimizer_hints) - def where(self, *exprs): + def where(self, *exprs) -> "Select": """Filter the rows, based on the given predicates. 
(aka Selection)""" exprs = args_as_tuple(exprs) exprs = _drop_skips(exprs) @@ -120,7 +120,7 @@ def where(self, *exprs): resolve_names(self.source_table, exprs) return Select.make(self, where_exprs=exprs) - def order_by(self, *exprs): + def order_by(self, *exprs) -> "Select": """Order the rows lexicographically, according to the given expressions.""" exprs = _drop_skips(exprs) if not exprs: @@ -129,14 +129,14 @@ def order_by(self, *exprs): resolve_names(self.source_table, exprs) return Select.make(self, order_by_exprs=exprs) - def limit(self, limit: int): + def limit(self, limit: int) -> "Select": """Stop yielding rows after the given limit. i.e. take the first 'n=limit' rows""" if limit is SKIP: return self return Select.make(self, limit_expr=limit) - def join(self, target: "ITable"): + def join(self, target: "ITable") -> "Join": """Join the current table with the target table, returning a new table containing both side-by-side. When joining, it's recommended to use explicit tables names, instead of `this`, in order to avoid potential name collisions. @@ -180,7 +180,7 @@ def group_by(self, *keys) -> "GroupBy": return GroupBy(self, keys) - def _get_column(self, name: str): + def _get_column(self, name: str) -> "Column": if self.schema: name = self.schema.get_key(name) # Get the actual name. Might be case-insensitive. 
return Column(self, name) @@ -188,29 +188,29 @@ def _get_column(self, name: str): # def __getattr__(self, column): # return self._get_column(column) - def __getitem__(self, column): + def __getitem__(self, column) -> "Column": if not isinstance(column, str): raise TypeError() return self._get_column(column) - def count(self): + def count(self) -> "Select": """SELECT count() FROM self""" return Select(self, [Count()]) - def union(self, other: "ITable"): + def union(self, other: "ITable") -> "TableOp": """SELECT * FROM self UNION other""" return TableOp("UNION", self, other) - def union_all(self, other: "ITable"): + def union_all(self, other: "ITable") -> "TableOp": """SELECT * FROM self UNION ALL other""" return TableOp("UNION ALL", self, other) - def minus(self, other: "ITable"): + def minus(self, other: "ITable") -> "TableOp": """SELECT * FROM self EXCEPT other""" # aka return TableOp("EXCEPT", self, other) - def intersect(self, other: "ITable"): + def intersect(self, other: "ITable") -> "TableOp": """SELECT * FROM self INTERSECT other""" return TableOp("INTERSECT", self, other) @@ -233,51 +233,51 @@ def type(self) -> Optional[type]: @attrs.define(frozen=False, eq=False) class LazyOps: - def __add__(self, other): + def __add__(self, other) -> "BinOp": return BinOp("+", [self, other]) - def __sub__(self, other): + def __sub__(self, other) -> "BinOp": return BinOp("-", [self, other]) - def __neg__(self): + def __neg__(self) -> "UnaryOp": return UnaryOp("-", self) - def __gt__(self, other): + def __gt__(self, other) -> "BinBoolOp": return BinBoolOp(">", [self, other]) - def __ge__(self, other): + def __ge__(self, other) -> "BinBoolOp": return BinBoolOp(">=", [self, other]) - def __eq__(self, other): + def __eq__(self, other) -> "BinBoolOp": if other is None: return BinBoolOp("IS", [self, None]) return BinBoolOp("=", [self, other]) - def __lt__(self, other): + def __lt__(self, other) -> "BinBoolOp": return BinBoolOp("<", [self, other]) - def __le__(self, other): + def 
__le__(self, other) -> "BinBoolOp": return BinBoolOp("<=", [self, other]) - def __or__(self, other): + def __or__(self, other) -> "BinBoolOp": return BinBoolOp("OR", [self, other]) - def __and__(self, other): + def __and__(self, other) -> "BinBoolOp": return BinBoolOp("AND", [self, other]) - def is_distinct_from(self, other): + def is_distinct_from(self, other) -> "IsDistinctFrom": return IsDistinctFrom(self, other) - def like(self, other): + def like(self, other) -> "BinBoolOp": return BinBoolOp("LIKE", [self, other]) - def sum(self): + def sum(self) -> "Func": return Func("SUM", [self]) - def max(self): + def max(self) -> "Func": return Func("MAX", [self]) - def min(self): + def min(self) -> "Func": return Func("MIN", [self]) @@ -523,7 +523,7 @@ class GroupBy(ExprNode, ITable, Root): values: Optional[Sequence[Expr]] = None having_exprs: Optional[Sequence[Expr]] = None - def __attrs_post_init__(self): + def __attrs_post_init__(self) -> None: assert self.keys or self.values def having(self, *exprs) -> Self: diff --git a/data_diff/queries/base.py b/data_diff/queries/base.py index ca8953c4..7f7cd1de 100644 --- a/data_diff/queries/base.py +++ b/data_diff/queries/base.py @@ -5,7 +5,7 @@ @attrs.define(frozen=True) class _SKIP: - def __repr__(self): + def __repr__(self) -> str: return "SKIP" diff --git a/data_diff/query_utils.py b/data_diff/query_utils.py index a4887728..ed753d31 100644 --- a/data_diff/query_utils.py +++ b/data_diff/query_utils.py @@ -23,7 +23,7 @@ def _drop_table(name: DbPath): yield commit -def drop_table(db, tbl): +def drop_table(db, tbl) -> None: if isinstance(db, Oracle): db.query(_drop_table_oracle(tbl)) else: @@ -51,6 +51,6 @@ def _append_to_table(path: DbPath, expr: Expr): yield commit -def append_to_table(db, path, expr): +def append_to_table(db, path, expr) -> None: f = _append_to_table_oracle if isinstance(db, Oracle) else _append_to_table db.query(f(path, expr)) diff --git a/data_diff/schema.py b/data_diff/schema.py index f0408935..d77e298b 
100644 --- a/data_diff/schema.py +++ b/data_diff/schema.py @@ -1,5 +1,5 @@ import logging -from typing import Any, Collection, Iterable, Optional +from typing import Any, Collection, Iterator, Optional import attrs @@ -28,7 +28,7 @@ class RawColumnInfo(Collection[Any]): collation_name: Optional[str] = None # It was a tuple once, so we keep it backward compatible temporarily, until remade to classes. - def __iter__(self) -> Iterable[Any]: + def __iter__(self) -> Iterator[Any]: return iter( (self.column_name, self.data_type, self.datetime_precision, self.numeric_precision, self.numeric_scale) ) diff --git a/data_diff/table_segment.py b/data_diff/table_segment.py index 924271ba..4e712445 100644 --- a/data_diff/table_segment.py +++ b/data_diff/table_segment.py @@ -126,7 +126,7 @@ class TableSegment: case_sensitive: Optional[bool] = True _schema: Optional[Schema] = None - def __attrs_post_init__(self): + def __attrs_post_init__(self) -> None: if not self.update_column and (self.min_update or self.max_update): raise ValueError("Error: the min_update/max_update feature requires 'update_column' to be set.") @@ -138,7 +138,7 @@ def __attrs_post_init__(self): f"Error: min_update expected to be smaller than max_update! 
({self.min_update} >= {self.max_update})" ) - def _where(self): + def _where(self) -> Optional[str]: return f"({self.where})" if self.where else None def _with_raw_schema(self, raw_schema: Dict[str, RawColumnInfo]) -> Self: diff --git a/data_diff/thread_utils.py b/data_diff/thread_utils.py index ba292ef5..2b9bb3db 100644 --- a/data_diff/thread_utils.py +++ b/data_diff/thread_utils.py @@ -5,7 +5,7 @@ from concurrent.futures import ThreadPoolExecutor from concurrent.futures.thread import _WorkItem from time import sleep -from typing import Callable, Iterator, Optional +from typing import Any, Callable, Iterator, Optional import attrs @@ -19,7 +19,7 @@ class AutoPriorityQueue(PriorityQueue): _counter = itertools.count().__next__ - def put(self, item: Optional[_WorkItem], block=True, timeout=None): + def put(self, item: Optional[_WorkItem], block=True, timeout=None) -> None: priority = item.kwargs.pop("priority") if item is not None else 0 super().put((-priority, self._counter(), item), block, timeout) @@ -34,7 +34,7 @@ class PriorityThreadPoolExecutor(ThreadPoolExecutor): XXX WARNING: Might break in future versions of Python """ - def __init__(self, *args): + def __init__(self, *args) -> None: super().__init__(*args) self._work_queue = AutoPriorityQueue() @@ -58,7 +58,7 @@ class ThreadedYielder(Iterable): _exception: Optional[None] yield_list: bool - def __init__(self, max_workers: Optional[int] = None, yield_list: bool = False): + def __init__(self, max_workers: Optional[int] = None, yield_list: bool = False) -> None: super().__init__() self._pool = PriorityThreadPoolExecutor(max_workers) self._futures = deque() @@ -66,7 +66,7 @@ def __init__(self, max_workers: Optional[int] = None, yield_list: bool = False): self._exception = None self.yield_list = yield_list - def _worker(self, fn, *args, **kwargs): + def _worker(self, fn, *args, **kwargs) -> None: try: res = fn(*args, **kwargs) if res is not None: @@ -77,10 +77,10 @@ def _worker(self, fn, *args, **kwargs): except 
Exception as e: self._exception = e - def submit(self, fn: Callable, *args, priority: int = 0, **kwargs): + def submit(self, fn: Callable, *args, priority: int = 0, **kwargs) -> None: self._futures.append(self._pool.submit(self._worker, fn, *args, priority=priority, **kwargs)) - def __iter__(self) -> Iterator: + def __iter__(self) -> Iterator[Any]: while True: if self._exception: raise self._exception diff --git a/data_diff/tracking.py b/data_diff/tracking.py index 42f44dbb..0fad464a 100644 --- a/data_diff/tracking.py +++ b/data_diff/tracking.py @@ -80,16 +80,16 @@ def bool_notify_about_extension() -> bool: entrypoint_name = "Python API" -def disable_tracking(): +def disable_tracking() -> None: global g_tracking_enabled g_tracking_enabled = False -def is_tracking_enabled(): +def is_tracking_enabled() -> bool: return g_tracking_enabled -def set_entrypoint_name(s): +def set_entrypoint_name(s) -> None: global entrypoint_name entrypoint_name = s @@ -99,22 +99,22 @@ def set_entrypoint_name(s): dbt_project_id = None -def set_dbt_user_id(s): +def set_dbt_user_id(s) -> None: global dbt_user_id dbt_user_id = s -def set_dbt_version(s): +def set_dbt_version(s) -> None: global dbt_version dbt_version = s -def set_dbt_project_id(s): +def set_dbt_project_id(s) -> None: global dbt_project_id dbt_project_id = s -def get_anonymous_id(): +def get_anonymous_id() -> str: global g_anonymous_id if g_anonymous_id is None: profile = _load_profile() @@ -201,7 +201,7 @@ def create_email_signup_event_json(email: str) -> Dict[str, Any]: } -def send_event_json(event_json): +def send_event_json(event_json) -> None: if not g_tracking_enabled: raise RuntimeError("Won't send; tracking is disabled!") diff --git a/data_diff/utils.py b/data_diff/utils.py index e16110e1..c4175426 100644 --- a/data_diff/utils.py +++ b/data_diff/utils.py @@ -80,7 +80,7 @@ def new(self, initial=()) -> Self: class CaseInsensitiveDict(CaseAwareMapping): - def __init__(self, initial): + def __init__(self, initial) -> None: 
super().__init__() self._dict = {k.lower(): (k, v) for k, v in dict(initial).items()} @@ -93,13 +93,13 @@ def __iter__(self) -> Iterator[V]: def __len__(self) -> int: return len(self._dict) - def __setitem__(self, key: str, value): + def __setitem__(self, key: str, value) -> None: k = key.lower() if k in self._dict: key = self._dict[k][0] self._dict[k] = key, value - def __delitem__(self, key: str): + def __delitem__(self, key: str) -> None: del self._dict[key.lower()] def get_key(self, key: str) -> str: @@ -161,7 +161,7 @@ def range(self, other: "ArithUUID", count: int) -> List[Self]: checkpoints = split_space(self.uuid.int, other.uuid.int, count) return [attrs.evolve(self, uuid=i) for i in checkpoints] - def __int__(self): + def __int__(self) -> int: return self.uuid.int def __add__(self, other: int) -> Self: @@ -241,7 +241,7 @@ class ArithAlphanumeric(ArithString): _str: str _max_len: Optional[int] = None - def __attrs_post_init__(self): + def __attrs_post_init__(self) -> None: if self._str is None: raise ValueError("Alphanum string cannot be None") if self._max_len and len(self._str) > self._max_len: @@ -255,16 +255,16 @@ def __attrs_post_init__(self): # def int(self): # return alphanumToNumber(self._str, alphanums) - def __str__(self): + def __str__(self) -> str: s = self._str if self._max_len: s = s.rjust(self._max_len, alphanums[0]) return s - def __len__(self): + def __len__(self) -> int: return len(self._str) - def __repr__(self): + def __repr__(self) -> str: return f'alphanum"{self._str}"' def __add__(self, other: "Union[ArithAlphanumeric, int]") -> Self: @@ -289,17 +289,17 @@ def __sub__(self, other: "Union[ArithAlphanumeric, int]") -> float: return NotImplemented - def __ge__(self, other): + def __ge__(self, other) -> bool: if not isinstance(other, type(self)): return NotImplemented return self._str >= other._str - def __lt__(self, other): + def __lt__(self, other) -> bool: if not isinstance(other, type(self)): return NotImplemented return self._str < 
other._str - def __eq__(self, other): + def __eq__(self, other) -> bool: if not isinstance(other, type(self)): return NotImplemented return self._str == other._str @@ -424,32 +424,32 @@ class Vector(tuple): Partial implementation: Only the needed functionality is implemented """ - def __lt__(self, other: "Vector"): + def __lt__(self, other: "Vector") -> bool: if isinstance(other, Vector): return all(a < b for a, b in safezip(self, other)) return NotImplemented - def __le__(self, other: "Vector"): + def __le__(self, other: "Vector") -> bool: if isinstance(other, Vector): return all(a <= b for a, b in safezip(self, other)) return NotImplemented - def __gt__(self, other: "Vector"): + def __gt__(self, other: "Vector") -> bool: if isinstance(other, Vector): return all(a > b for a, b in safezip(self, other)) return NotImplemented - def __ge__(self, other: "Vector"): + def __ge__(self, other: "Vector") -> bool: if isinstance(other, Vector): return all(a >= b for a, b in safezip(self, other)) return NotImplemented - def __eq__(self, other: "Vector"): + def __eq__(self, other: "Vector") -> bool: if isinstance(other, Vector): return all(a == b for a, b in safezip(self, other)) return NotImplemented - def __sub__(self, other: "Vector"): + def __sub__(self, other: "Vector") -> "Vector": if isinstance(other, Vector): return Vector((a - b) for a, b in safezip(self, other)) raise NotImplementedError() @@ -540,7 +540,7 @@ class LogStatusHandler(logging.Handler): This log handler can be used to update a rich.status every time a log is emitted. 
""" - def __init__(self): + def __init__(self) -> None: super().__init__() self.status = Status("") self.prefix = "" @@ -575,12 +575,12 @@ class UnknownMeta(type): def __instancecheck__(self, instance): return instance is Unknown - def __repr__(self): + def __repr__(self) -> str: return "Unknown" class Unknown(metaclass=UnknownMeta): - def __nonzero__(self): + def __bool__(self) -> bool: raise TypeError() def __new__(class_, *args, **kwargs): diff --git a/tests/test_database_types.py b/tests/test_database_types.py index e97ca484..4bbeb0ea 100644 --- a/tests/test_database_types.py +++ b/tests/test_database_types.py @@ -8,9 +8,11 @@ import logging from decimal import Decimal from itertools import islice, repeat, chain +from typing import Iterator from parameterized import parameterized +from data_diff.databases.base import Row from data_diff.utils import number_to_human from data_diff.queries.api import table, commit, this, Code from data_diff.queries.api import insert_rows_in_batches @@ -366,12 +368,12 @@ class PaginatedTable: # much memory. 
RECORDS_PER_BATCH = 1000000 - def __init__(self, table_path, conn): + def __init__(self, table_path, conn) -> None: super().__init__() self.table_path = table_path self.conn = conn - def __iter__(self): + def __iter__(self) -> Iterator[Row]: last_id = 0 while True: query = ( @@ -398,46 +400,46 @@ class DateTimeFaker: datetime.fromisoformat("2022-06-01 15:10:05.009900"), ] - def __init__(self, max): + def __init__(self, max) -> None: super().__init__() self.max = max - def __iter__(self): + def __iter__(self) -> Iterator[datetime]: initial = datetime(2000, 1, 1, 0, 0, 0, 0) step = timedelta(seconds=3, microseconds=571) return islice(chain(self.MANUAL_FAKES, accumulate(repeat(step), initial=initial)), self.max) - def __len__(self): + def __len__(self) -> int: return self.max class IntFaker: MANUAL_FAKES = [127, -3, -9, 37, 15, 0] - def __init__(self, max): + def __init__(self, max) -> None: super().__init__() self.max = max - def __iter__(self): + def __iter__(self) -> Iterator[int]: initial = -128 step = 1 return islice(chain(self.MANUAL_FAKES, accumulate(repeat(step), initial=initial)), self.max) - def __len__(self): + def __len__(self) -> int: return self.max class BooleanFaker: MANUAL_FAKES = [False, True, True, False] - def __init__(self, max): + def __init__(self, max) -> None: super().__init__() self.max = max - def __iter__(self): + def __iter__(self) -> Iterator[bool]: return iter(self.MANUAL_FAKES[: self.max]) - def __len__(self): + def __len__(self) -> int: return min(self.max, len(self.MANUAL_FAKES)) @@ -461,28 +463,28 @@ class FloatFaker: 3.141592653589793, ] - def __init__(self, max): + def __init__(self, max) -> None: super().__init__() self.max = max - def __iter__(self): + def __iter__(self) -> Iterator[float]: initial = -10.0001 step = 0.00571 return islice(chain(self.MANUAL_FAKES, accumulate(repeat(step), initial=initial)), self.max) - def __len__(self): + def __len__(self) -> int: return self.max class UUID_Faker: - def __init__(self, max): + def 
__init__(self, max) -> None: super().__init__() self.max = max - def __len__(self): + def __len__(self) -> int: return self.max - def __iter__(self): + def __iter__(self) -> Iterator[uuid.UUID]: return (uuid.uuid1(i) for i in range(self.max)) @@ -491,14 +493,14 @@ class JsonFaker: '{"keyText": "text", "keyInt": 3, "keyFloat": 5.4445, "keyBoolean": true}', ] - def __init__(self, max): + def __init__(self, max) -> None: super().__init__() self.max = max - def __iter__(self): + def __iter__(self) -> Iterator[str]: return iter(self.MANUAL_FAKES[: self.max]) - def __len__(self): + def __len__(self) -> int: return min(self.max, len(self.MANUAL_FAKES))