diff --git a/.github/workflows/readme.yml b/.github/workflows/readme.yml
index 55a2b351..bcfda983 100644
--- a/.github/workflows/readme.yml
+++ b/.github/workflows/readme.yml
@@ -27,5 +27,6 @@ jobs:
python -m pip install --upgrade pip
python -m pip install invoke rundoc .
python -m pip install tomli
+ python -m pip install slack-sdk
- name: Run the README.md
run: invoke readme
diff --git a/.github/workflows/run_benchmark.yml b/.github/workflows/run_benchmark.yml
new file mode 100644
index 00000000..7e66075a
--- /dev/null
+++ b/.github/workflows/run_benchmark.yml
@@ -0,0 +1,31 @@
+name: Run SDGym Benchmark
+
+on:
+ workflow_dispatch:
+ schedule:
+ - cron: '0 5 5 * *'
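+    # Runs at 05:00 UTC on the 5th day of every month.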
+
+jobs:
+ run-sdgym-benchmark:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v4
+ with:
+ fetch-depth: 0
+ - name: Set up latest Python
+ uses: actions/setup-python@v5
+ with:
+ python-version-file: 'pyproject.toml'
+ - name: Install dependencies
+ run: |
+ python -m pip install --upgrade pip
+ python -m pip install -e .[dev]
+
+ - name: Run SDGym Benchmark
+ env:
+ SLACK_TOKEN: ${{ secrets.SLACK_TOKEN }}
+ AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
+ AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
+ AWS_DEFAULT_REGION: ${{ secrets.AWS_REGION }}
+
+ run: invoke run-sdgym-benchmark
diff --git a/.github/workflows/upload_benchmark_results.yml b/.github/workflows/upload_benchmark_results.yml
new file mode 100644
index 00000000..ead247f0
--- /dev/null
+++ b/.github/workflows/upload_benchmark_results.yml
@@ -0,0 +1,91 @@
+name: Upload SDGym Benchmark results
+
+on:
+ workflow_run:
+ workflows: ["Run SDGym Benchmark"]
+ types:
+ - completed
+ workflow_dispatch:
+ schedule:
+ - cron: '0 6 * * *'
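+    # Runs daily at 06:00 UTC; the upload script exits early if the benchmark run is not complete yet.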
+
+jobs:
+ upload-sdgym-benchmark:
+ runs-on: ubuntu-latest
+
+ steps:
+ - uses: actions/checkout@v4
+ with:
+ fetch-depth: 0
+
+ - name: Set up latest Python
+ uses: actions/setup-python@v5
+ with:
+ python-version-file: 'pyproject.toml'
+
+ - name: Install dependencies
+ run: |
+ python -m pip install --upgrade pip
+ python -m pip install -e .[dev]
+
+ - name: Upload SDGym Benchmark
+ env:
+ PYDRIVE_CREDENTIALS: ${{ secrets.PYDRIVE_CREDENTIALS }}
+ AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
+ AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
+ GITHUB_LOCAL_RESULTS_DIR: ${{ runner.temp }}/sdgym-leaderboard-files
+ run: |
+ invoke upload-benchmark-results
+ echo "GITHUB_LOCAL_RESULTS_DIR=$GITHUB_LOCAL_RESULTS_DIR" >> $GITHUB_ENV
+
+ - name: Prepare files for commit
+ if: env.SKIP_UPLOAD != 'true'
+ run: |
+ mkdir pr-staging
+ echo "Looking for files in: $GITHUB_LOCAL_RESULTS_DIR"
+ ls -l "$GITHUB_LOCAL_RESULTS_DIR"
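+        # Stage only this run's CSV files (named <FOLDER_NAME>_*.csv) for the commit below.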
+ for f in "$GITHUB_LOCAL_RESULTS_DIR"/${FOLDER_NAME}_*.csv; do
+ base=$(basename "$f")
+ cp "$f" "pr-staging/${base}"
+ done
+
+ echo "Files staged for PR:"
+ ls -l pr-staging
+
+ - name: Checkout target repo (sdv-dev.github.io)
+ if: env.SKIP_UPLOAD != 'true'
+ run: |
+ git clone https://github.com/sdv-dev/sdv-dev.github.io.git target-repo
+ cd target-repo
+ git checkout gatsby-home
+
+ - name: Copy results and commit
+ if: env.SKIP_UPLOAD != 'true'
+ env:
+ GH_TOKEN: ${{ secrets.GH_TOKEN }}
+ FOLDER_NAME: ${{ env.FOLDER_NAME }}
+ run: |
+ cp pr-staging/* target-repo/assets/sdgym-leaderboard-files/
+ cd target-repo
+ git checkout gatsby-home
+ git config --local user.name "github-actions[bot]"
+ git config --local user.email "41898282+github-actions[bot]@users.noreply.github.com"
+ git add assets/
+ git commit -m "Upload SDGym Benchmark Results ($FOLDER_NAME)" || echo "No changes to commit"
+ git remote set-url origin https://x-access-token:${GH_TOKEN}@github.com/sdv-dev/sdv-dev.github.io.git
+ git push origin gatsby-home
+
+ COMMIT_HASH=$(git rev-parse HEAD)
+        COMMIT_URL="https://github.com/sdv-dev/sdv-dev.github.io/commit/$%7BCOMMIT_HASH%7D"
+
+ echo "Commit URL: $COMMIT_URL"
+ echo "COMMIT_URL=$COMMIT_URL" >> $GITHUB_ENV
+
+ - name: Send Slack notification
+ if: env.SKIP_UPLOAD != 'true'
+ env:
+ SLACK_TOKEN: ${{ secrets.SLACK_TOKEN }}
+ run: |
+ invoke notify-sdgym-benchmark-uploaded \
+ --folder-name "$FOLDER_NAME" \
+ --commit-url "$COMMIT_URL"
diff --git a/pyproject.toml b/pyproject.toml
index 0553c69f..6651e264 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -47,15 +47,17 @@ dependencies = [
"scipy>=1.12.0;python_version>='3.12' and python_version<'3.13'",
"scipy>=1.14.1;python_version>='3.13'",
'tabulate>=0.8.3,<0.9',
- "torch>=1.13.0;python_version<'3.11'",
- "torch>=2.0.0;python_version>='3.11' and python_version<'3.12'",
- "torch>=2.2.0;python_version>='3.12' and python_version<'3.13'",
- "torch>=2.6.0;python_version>='3.13'",
+ "torch>=2.2.0;python_version>='3.8' and python_version<'3.9'",
+ "torch>=2.6.0;python_version>='3.9'",
'tqdm>=4.66.3',
'XlsxWriter>=1.2.8',
'rdt>=1.17.0',
'sdmetrics>=0.20.1',
'sdv>=1.21.0',
+ 'openpyxl>=3.0.0',
+ 'kaleido>=0.2.1',
+ 'pillow>=9.0.0',
+ 'pydrive2>=1.3.1'
]
[project.urls]
@@ -71,10 +73,9 @@ sdgym = { main = 'sdgym.cli.__main__:main' }
[project.optional-dependencies]
dask = ['dask', 'distributed']
realtabformer = [
- 'realtabformer>=0.2.2',
- "torch>=2.1.0;python_version>='3.8' and python_version<'3.12'",
- "torch>=2.2.0;python_version>='3.12' and python_version<'3.13'",
- "torch>=2.6.0;python_version>='3.13'",
+ 'realtabformer>=0.2.3',
+ "torch>=2.2.0;python_version>='3.8' and python_version<'3.9'",
+ "torch>=2.6.0;python_version>='3.9'",
'transformers<4.51',
]
test = [
@@ -83,6 +84,7 @@ test = [
'pytest-cov>=2.6.0',
'jupyter>=1.0.0,<2',
'tomli>=2.0.0,<3',
+ 'slack-sdk>=3.23,<4.0'
]
dev = [
'sdgym[dask, test]',
@@ -196,6 +198,7 @@ exclude = [
".ipynb_checkpoints",
"tasks.py",
"static_code_analysis.txt",
+ "*.ipynb"
]
[tool.ruff.lint]
diff --git a/sdgym/benchmark.py b/sdgym/benchmark.py
index 6eca2c91..384203b0 100644
--- a/sdgym/benchmark.py
+++ b/sdgym/benchmark.py
@@ -1,6 +1,5 @@
"""Main SDGym benchmarking module."""
-import base64
import concurrent
import logging
import math
@@ -9,6 +8,7 @@
import pickle
import re
import textwrap
+import threading
import tracemalloc
import warnings
from collections import defaultdict
@@ -24,6 +24,7 @@
import numpy as np
import pandas as pd
import tqdm
+from botocore.config import Config
from sdmetrics.reports.multi_table import (
DiagnosticReport as MultiTableDiagnosticReport,
)
@@ -42,9 +43,10 @@
from sdgym.errors import SDGymError
from sdgym.metrics import get_metrics
from sdgym.progress import TqdmLogger, progress
-from sdgym.result_writer import LocalResultsWriter
+from sdgym.result_writer import LocalResultsWriter, S3ResultsWriter
from sdgym.s3 import (
S3_PREFIX,
+ S3_REGION,
is_s3_path,
parse_s3_path,
write_csv,
@@ -168,6 +170,11 @@ def _setup_output_destination_aws(output_destination, synthesizers, datasets, s3
'run_id': f's3://{bucket_name}/{top_folder}/run_{today}_{increment}.yaml',
}
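+        # Seed the run file with 'completed_date: null'; it is filled in when the run finishes,
+        # which is how the upload tooling detects whether all runs are complete.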
+ s3_client.put_object(
+ Bucket=bucket_name,
+ Key=f'{top_folder}/run_{today}_{increment}.yaml',
+ Body='completed_date: null\n'.encode('utf-8'),
+ )
return paths
@@ -236,11 +243,25 @@ def _generate_job_args_list(
synthesizers = get_synthesizers(synthesizers + custom_synthesizers)
# Get list of dataset paths
- sdv_datasets = [] if sdv_datasets is None else get_dataset_paths(datasets=sdv_datasets)
+ aws_access_key_id = os.getenv('AWS_ACCESS_KEY_ID')
+    aws_secret_access_key = os.getenv('AWS_SECRET_ACCESS_KEY')
+ sdv_datasets = (
+ []
+ if sdv_datasets is None
+ else get_dataset_paths(
+ datasets=sdv_datasets,
+ aws_access_key_id=aws_access_key_id,
+            aws_secret_access_key=aws_secret_access_key,
+ )
+ )
additional_datasets = (
[]
if additional_datasets_folder is None
- else get_dataset_paths(bucket=additional_datasets_folder)
+ else get_dataset_paths(
+ bucket=additional_datasets_folder,
+ aws_access_key_id=aws_access_key_id,
+            aws_secret_access_key=aws_secret_access_key,
+ )
)
datasets = sdv_datasets + additional_datasets
synthesizer_names = [synthesizer['name'] for synthesizer in synthesizers]
@@ -524,27 +545,36 @@ def _score_with_timeout(
synthesizer_path=None,
result_writer=None,
):
+ output = {} if isinstance(result_writer, S3ResultsWriter) else None
+ args = (
+ synthesizer,
+ data,
+ metadata,
+ metrics,
+ output,
+ compute_quality_score,
+ compute_diagnostic_score,
+ compute_privacy_score,
+ modality,
+ dataset_name,
+ synthesizer_path,
+ result_writer,
+ )
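+    # When results are written to S3, scoring runs in a daemon thread so a hung synthesizer
+    # can be abandoned after the timeout and reported as a timeout error.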
+ if isinstance(result_writer, S3ResultsWriter):
+ thread = threading.Thread(target=_score, args=args, daemon=True)
+ thread.start()
+ thread.join(timeout)
+ if thread.is_alive():
+ LOGGER.error('Timeout running %s on dataset %s;', synthesizer['name'], dataset_name)
+ return {'timeout': True, 'error': 'Timeout'}
+
+ return output
+
with multiprocessing_context():
with multiprocessing.Manager() as manager:
output = manager.dict()
- process = multiprocessing.Process(
- target=_score,
- args=(
- synthesizer,
- data,
- metadata,
- metrics,
- output,
- compute_quality_score,
- compute_diagnostic_score,
- compute_privacy_score,
- modality,
- dataset_name,
- synthesizer_path,
- result_writer,
- ),
- )
-
+ args = args[:4] + (output,) + args[5:] # replace output=None with manager.dict()
+ process = multiprocessing.Process(target=_score, args=args)
process.start()
process.join(timeout)
process.terminate()
@@ -697,7 +727,6 @@ def _run_job(args):
compute_privacy_score,
cache_dir,
)
-
if synthesizer_path and result_writer:
result_writer.write_dataframe(scores, synthesizer_path['benchmark_result'])
@@ -998,9 +1027,10 @@ def _write_run_id_file(synthesizers, job_args_list, result_writer=None):
}
for synthesizer in synthesizers:
if synthesizer not in SDV_SINGLE_TABLE_SYNTHESIZERS:
- ext_lib = EXTERNAL_SYNTHESIZER_TO_LIBRARY[synthesizer]
- library_version = version(ext_lib)
- metadata[f'{ext_lib}_version'] = library_version
+ ext_lib = EXTERNAL_SYNTHESIZER_TO_LIBRARY.get(synthesizer)
+ if ext_lib:
+ library_version = version(ext_lib)
+ metadata[f'{ext_lib}_version'] = library_version
elif 'sdv' not in metadata.keys():
metadata['sdv_version'] = version('sdv')
@@ -1180,20 +1210,22 @@ def _validate_aws_inputs(output_destination, aws_access_key_id, aws_secret_acces
if not output_destination.startswith('s3://'):
raise ValueError("'output_destination' must be an S3 URL starting with 's3://'. ")
- parsed_url = urlparse(output_destination)
- bucket_name = parsed_url.netloc
+ bucket_name, _ = parse_s3_path(output_destination)
if not bucket_name:
raise ValueError(f'Invalid S3 URL: {output_destination}')
+ config = Config(connect_timeout=30, read_timeout=300)
if aws_access_key_id and aws_secret_access_key:
s3_client = boto3.client(
's3',
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
+ region_name=S3_REGION,
+ config=config,
)
else:
         # No credentials provided - rely on default session
- s3_client = boto3.client('s3')
+ s3_client = boto3.client('s3', config=config)
s3_client.head_bucket(Bucket=bucket_name)
if not _check_write_permissions(s3_client, bucket_name):
@@ -1223,8 +1255,7 @@ def _store_job_args_in_s3(output_destination, job_args_list, s3_client):
job_args_key = f'{path}{job_args_key}' if path else job_args_key
serialized_data = pickle.dumps(job_args_list)
- encoded_data = base64.b64encode(serialized_data).decode('utf-8')
- s3_client.put_object(Bucket=bucket_name, Key=job_args_key, Body=encoded_data)
+ s3_client.put_object(Bucket=bucket_name, Key=job_args_key, Body=serialized_data)
return bucket_name, job_args_key
@@ -1235,15 +1266,6 @@ def _get_s3_script_content(
return f"""
import boto3
import pickle
-import base64
-import pandas as pd
-import sdgym
-from sdgym.synthesizers.sdv import (
- CopulaGANSynthesizer, CTGANSynthesizer,
- GaussianCopulaSynthesizer, HMASynthesizer, PARSynthesizer,
- SDVRelationalSynthesizer, SDVTabularSynthesizer, TVAESynthesizer
-)
-from sdgym.synthesizers import RealTabFormerSynthesizer
from sdgym.benchmark import _run_jobs, _write_run_id_file, _update_run_id_file
from io import StringIO
from sdgym.result_writer import S3ResultsWriter
@@ -1255,9 +1277,7 @@ def _get_s3_script_content(
region_name='{region_name}'
)
response = s3_client.get_object(Bucket='{bucket_name}', Key='{job_args_key}')
-encoded_data = response['Body'].read().decode('utf-8')
-serialized_data = base64.b64decode(encoded_data.encode('utf-8'))
-job_args_list = pickle.loads(serialized_data)
+job_args_list = pickle.loads(response['Body'].read())
result_writer = S3ResultsWriter(s3_client=s3_client)
_write_run_id_file({synthesizers}, job_args_list, result_writer)
scores = _run_jobs(None, job_args_list, False, result_writer=result_writer)
@@ -1287,7 +1307,7 @@ def _get_user_data_script(access_key, secret_key, region_name, script_content):
echo "======== Install Dependencies in venv ============"
pip install --upgrade pip
- pip install "sdgym[all]"
+ pip install "sdgym[all] @ git+https://github.com/sdv-dev/SDGym.git@issue-425-workflow-sdgym#egg=sdgym"
pip install s3fs
echo "======== Write Script ==========="
@@ -1313,11 +1333,10 @@ def _run_on_aws(
aws_secret_access_key,
):
bucket_name, job_args_key = _store_job_args_in_s3(output_destination, job_args_list, s3_client)
- region_name = 'us-east-1'
script_content = _get_s3_script_content(
aws_access_key_id,
aws_secret_access_key,
- region_name,
+ S3_REGION,
bucket_name,
job_args_key,
synthesizers,
@@ -1327,12 +1346,12 @@ def _run_on_aws(
session = boto3.session.Session(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
- region_name=region_name,
+ region_name=S3_REGION,
)
ec2_client = session.client('ec2')
print(f'This instance is being created in region: {session.region_name}') # noqa
user_data_script = _get_user_data_script(
- aws_access_key_id, aws_secret_access_key, region_name, script_content
+ aws_access_key_id, aws_secret_access_key, S3_REGION, script_content
)
response = ec2_client.run_instances(
ImageId='ami-080e1f13689e07408',
diff --git a/sdgym/cli/__main__.py b/sdgym/cli/__main__.py
index 431531d4..59a812e8 100644
--- a/sdgym/cli/__main__.py
+++ b/sdgym/cli/__main__.py
@@ -98,12 +98,16 @@ def _download_datasets(args):
datasets = args.datasets
if not datasets:
datasets = sdgym.datasets.get_available_datasets(
- args.bucket, args.aws_key, args.aws_secret
+ args.bucket, args.aws_access_key_id, args.aws_secret_access_key
)['name']
for dataset in tqdm.tqdm(datasets):
sdgym.datasets.load_dataset(
- dataset, args.datasets_path, args.bucket, args.aws_key, args.aws_secret
+ dataset,
+ args.datasets_path,
+ args.bucket,
+ args.aws_access_key_id,
+ args.aws_secret_access_key,
)
@@ -114,7 +118,9 @@ def _list_downloaded(args):
def _list_available(args):
- datasets = sdgym.datasets.get_available_datasets(args.bucket, args.aws_key, args.aws_secret)
+ datasets = sdgym.datasets.get_available_datasets(
+ args.bucket, args.aws_access_key_id, args.aws_secret_access_key
+ )
_print_table(datasets, args.sort, args.reverse, {'size': humanfriendly.format_size})
@@ -125,7 +131,7 @@ def _list_synthesizers(args):
def _collect(args):
sdgym.cli.collect.collect_results(
- args.input_path, args.output_file, args.aws_key, args.aws_secret
+ args.input_path, args.output_file, args.aws_access_key_id, args.aws_secret_access_key
)
@@ -133,8 +139,8 @@ def _summary(args):
sdgym.cli.summary.make_summary_spreadsheet(
args.input_path,
output_path=args.output_file,
- aws_key=args.aws_key,
- aws_secret=args.aws_secret,
+ aws_access_key_id=args.aws_access_key_id,
+ aws_secret_access_key=args.aws_secret_access_key,
)
diff --git a/sdgym/cli/collect.py b/sdgym/cli/collect.py
index 350fd291..8468e251 100644
--- a/sdgym/cli/collect.py
+++ b/sdgym/cli/collect.py
@@ -4,7 +4,9 @@
from sdgym.s3 import write_csv
-def collect_results(input_path, output_file=None, aws_key=None, aws_secret=None):
+def collect_results(
+ input_path, output_file=None, aws_access_key_id=None, aws_secret_access_key=None
+):
"""Collect the results in the given input directory.
Write all the results into one csv file.
@@ -15,15 +17,15 @@ def collect_results(input_path, output_file=None, aws_key=None, aws_secret=None)
output_file (str):
If ``output_file`` is provided, the consolidated results will be written there.
Otherwise, they will be written to ``input_path``/results.csv.
- aws_key (str):
- If an ``aws_key`` is provided, the given access key id will be used to read from
- and/or write to any s3 paths.
- aws_secret (str):
- If an ``aws_secret`` is provided, the given secret access key will be used to read
- from and/or write to any s3 paths.
+ aws_access_key_id (str):
+ If an ``aws_access_key_id`` is provided, the given access key id will be used
+ to read from and/or write to any s3 paths.
+ aws_secret_access_key (str):
+ If an ``aws_secret_access_key`` is provided, the given secret access key will
+ be used to read from and/or write to any s3 paths.
"""
print(f'Reading results from {input_path}') # noqa: T201
- scores = read_csv_from_path(input_path, aws_key, aws_secret)
+ scores = read_csv_from_path(input_path, aws_access_key_id, aws_secret_access_key)
scores = scores.drop_duplicates()
if output_file:
@@ -32,4 +34,4 @@ def collect_results(input_path, output_file=None, aws_key=None, aws_secret=None)
output = f'{input_path}/results.csv'
print(f'Storing results at {output}') # noqa: T201
- write_csv(scores, output, aws_key, aws_secret)
+ write_csv(scores, output, aws_access_key_id, aws_secret_access_key)
diff --git a/sdgym/cli/summary.py b/sdgym/cli/summary.py
index 06d872a3..cbbb9a98 100644
--- a/sdgym/cli/summary.py
+++ b/sdgym/cli/summary.py
@@ -289,7 +289,11 @@ def _add_summary(data, modality, baselines, writer):
def make_summary_spreadsheet(
- results_csv_path, output_path=None, baselines=None, aws_key=None, aws_secret=None
+ results_csv_path,
+ output_path=None,
+ baselines=None,
+ aws_access_key_id=None,
+ aws_secret_access_key=None,
):
"""Create a spreadsheet document organizing information from results.
@@ -307,7 +311,7 @@ def make_summary_spreadsheet(
Optional dict mapping modalities to a list of baseline
model names. If not provided, a default dict is used.
"""
- results = read_csv(results_csv_path, aws_key, aws_secret)
+ results = read_csv(results_csv_path, aws_access_key_id, aws_secret_access_key)
data = preprocess(results)
baselines = baselines or MODALITY_BASELINES
output_path = output_path or re.sub('.csv$', '.xlsx', results_csv_path)
@@ -319,4 +323,4 @@ def make_summary_spreadsheet(
_add_summary(df, modality, modality_baselines, writer)
writer.save()
- write_file(output.getvalue(), output_path, aws_key, aws_secret)
+ write_file(output.getvalue(), output_path, aws_access_key_id, aws_secret_access_key)
diff --git a/sdgym/cli/utils.py b/sdgym/cli/utils.py
index 77346277..1d1425b4 100644
--- a/sdgym/cli/utils.py
+++ b/sdgym/cli/utils.py
@@ -11,7 +11,7 @@
from sdgym.s3 import get_s3_client, is_s3_path, parse_s3_path
-def read_file(path, aws_key, aws_secret):
+def read_file(path, aws_access_key_id, aws_secret_access_key):
"""Read file from path.
The path can either be a local path or an s3 directory.
@@ -19,9 +19,9 @@ def read_file(path, aws_key, aws_secret):
Args:
path (str):
The path to the file.
- aws_key (str):
+ aws_access_key_id (str):
The access key id that will be used to communicate with s3, if provided.
- aws_secret (str):
+ aws_secret_access_key (str):
The secret access key that will be used to communicate with s3, if provided.
Returns:
@@ -29,7 +29,7 @@ def read_file(path, aws_key, aws_secret):
The content of the file in bytes.
"""
if is_s3_path(path):
- s3 = get_s3_client(aws_key, aws_secret)
+ s3 = get_s3_client(aws_access_key_id, aws_secret_access_key)
bucket_name, key = parse_s3_path(path)
obj = s3.get_object(Bucket=bucket_name, Key=key)
contents = obj['Body'].read()
@@ -40,7 +40,7 @@ def read_file(path, aws_key, aws_secret):
return contents
-def read_csv(path, aws_key, aws_secret):
+def read_csv(path, aws_access_key_id, aws_secret_access_key):
"""Read csv file from path.
The path can either be a local path or an s3 directory.
@@ -48,20 +48,20 @@ def read_csv(path, aws_key, aws_secret):
Args:
path (str):
The path to the csv file.
- aws_key (str):
+ aws_access_key_id (str):
The access key id that will be used to communicate with s3, if provided.
- aws_secret (str):
+ aws_secret_access_key (str):
The secret access key that will be used to communicate with s3, if provided.
Returns:
pandas.DataFrame:
A DataFrame containing the contents of the csv file.
"""
- contents = read_file(path, aws_key, aws_secret)
+ contents = read_file(path, aws_access_key_id, aws_secret_access_key)
return pd.read_csv(io.BytesIO(contents))
-def read_csv_from_path(path, aws_key, aws_secret):
+def read_csv_from_path(path, aws_access_key_id, aws_secret_access_key):
"""Read all csv content within a path.
All csv content within a path will be read and returned in a
@@ -70,9 +70,9 @@ def read_csv_from_path(path, aws_key, aws_secret):
Args:
path (str):
The path to read from, which can be either local or an s3 path.
- aws_key (str):
+ aws_access_key_id (str):
The access key id that will be used to communicate with s3, if provided.
- aws_secret (str):
+ aws_secret_access_key (str):
The secret access key that will be used to communicate with s3, if provided.
Returns:
@@ -81,13 +81,17 @@ def read_csv_from_path(path, aws_key, aws_secret):
"""
csv_contents = []
if is_s3_path(path):
- s3 = get_s3_client(aws_key, aws_secret)
+ s3 = get_s3_client(aws_access_key_id, aws_secret_access_key)
bucket_name, key_prefix = parse_s3_path(path)
resp = s3.list_objects(Bucket=bucket_name, Prefix=key_prefix)
csv_files = [f for f in resp['Contents'] if f['Key'].endswith('.csv')]
for csv_file in csv_files:
csv_file_key = csv_file['Key']
- csv_contents.append(read_csv(f's3://{bucket_name}/{csv_file_key}', aws_key, aws_secret))
+ csv_contents.append(
+ read_csv(
+ f's3://{bucket_name}/{csv_file_key}', aws_access_key_id, aws_secret_access_key
+ )
+ )
else:
run_path = pathlib.Path(path)
diff --git a/sdgym/datasets.py b/sdgym/datasets.py
index 13a3b237..b04b00d5 100644
--- a/sdgym/datasets.py
+++ b/sdgym/datasets.py
@@ -28,7 +28,12 @@ def _get_bucket_name(bucket):
def _download_dataset(
- modality, dataset_name, datasets_path=None, bucket=None, aws_key=None, aws_secret=None
+ modality,
+ dataset_name,
+ datasets_path=None,
+ bucket=None,
+ aws_access_key_id=None,
+ aws_secret_access_key=None,
):
"""Download a dataset and extract it into the given ``datasets_path``."""
datasets_path = datasets_path or DATASETS_PATH / dataset_name
@@ -36,7 +41,7 @@ def _download_dataset(
bucket_name = _get_bucket_name(bucket)
LOGGER.info('Downloading dataset %s from %s', dataset_name, bucket)
- s3 = get_s3_client(aws_key, aws_secret)
+ s3 = get_s3_client(aws_access_key_id, aws_secret_access_key)
obj = s3.get_object(Bucket=bucket_name, Key=f'{modality.upper()}/{dataset_name}.zip')
bytes_io = io.BytesIO(obj['Body'].read())
@@ -45,7 +50,14 @@ def _download_dataset(
zf.extractall(datasets_path)
-def _get_dataset_path(modality, dataset, datasets_path, bucket=None, aws_key=None, aws_secret=None):
+def _get_dataset_path(
+ modality,
+ dataset,
+ datasets_path,
+ bucket=None,
+ aws_access_key_id=None,
+ aws_secret_access_key=None,
+):
dataset = Path(dataset)
if dataset.exists():
return dataset
@@ -62,7 +74,12 @@ def _get_dataset_path(modality, dataset, datasets_path, bucket=None, aws_key=Non
return local_path
_download_dataset(
- modality, dataset, dataset_path, bucket=bucket, aws_key=aws_key, aws_secret=aws_secret
+ modality,
+ dataset,
+ dataset_path,
+ bucket=bucket,
+ aws_access_key_id=aws_access_key_id,
+ aws_secret_access_key=aws_secret_access_key,
)
return dataset_path
@@ -88,8 +105,8 @@ def load_dataset(
dataset,
datasets_path=None,
bucket=None,
- aws_key=None,
- aws_secret=None,
+ aws_access_key_id=None,
+ aws_secret_access_key=None,
limit_dataset_size=None,
):
"""Get the data and metadata of a dataset.
@@ -105,9 +122,9 @@ def load_dataset(
bucket (str):
The AWS bucket where to get the dataset. This will only be used if both ``dataset``
and ``dataset_path`` don't exist.
- aws_key (str):
+ aws_access_key_id (str):
The access key id that will be used to communicate with s3, if provided.
- aws_secret (str):
+ aws_secret_access_key (str):
The secret access key that will be used to communicate with s3, if provided.
limit_dataset_size (bool):
Use this flag to limit the size of the datasets for faster evaluation. If ``True``,
@@ -118,7 +135,9 @@ def load_dataset(
pd.DataFrame, dict:
The data and medatata of a dataset.
"""
- dataset_path = _get_dataset_path(modality, dataset, datasets_path, bucket, aws_key, aws_secret)
+ dataset_path = _get_dataset_path(
+ modality, dataset, datasets_path, bucket, aws_access_key_id, aws_secret_access_key
+ )
with open(dataset_path / f'{dataset_path.name}.csv') as data_csv:
data = pd.read_csv(data_csv)
@@ -153,12 +172,14 @@ def get_available_datasets(modality='single_table'):
return _get_available_datasets(modality)
-def _get_available_datasets(modality, bucket=None, aws_key=None, aws_secret=None):
+def _get_available_datasets(
+ modality, bucket=None, aws_access_key_id=None, aws_secret_access_key=None
+):
if modality not in MODALITIES:
modalities_list = ', '.join(MODALITIES)
raise ValueError(f'Modality `{modality}` not recognized. Must be one of {modalities_list}')
- s3 = get_s3_client(aws_key, aws_secret)
+ s3 = get_s3_client(aws_access_key_id, aws_secret_access_key)
bucket = bucket or BUCKET
bucket_name = _get_bucket_name(bucket)
@@ -182,7 +203,11 @@ def _get_available_datasets(modality, bucket=None, aws_key=None, aws_secret=None
def get_dataset_paths(
- datasets=None, datasets_path=None, bucket=None, aws_key=None, aws_secret=None
+ datasets=None,
+ datasets_path=None,
+ bucket=None,
+ aws_access_key_id=None,
+ aws_secret_access_key=None,
):
"""Build the full path to datasets and ensure they exist.
@@ -193,9 +218,9 @@ def get_dataset_paths(
The path of the datasets.
bucket (str):
The AWS bucket where to get the dataset.
- aws_key (str):
+ aws_access_key_id (str):
The access key id that will be used to communicate with s3, if provided.
- aws_secret (str):
+ aws_secret_access_key (str):
The secret access key that will be used to communicate with s3, if provided.
Returns:
@@ -230,6 +255,8 @@ def get_dataset_paths(
].tolist()
return [
- _get_dataset_path('single_table', dataset, datasets_path, bucket, aws_key, aws_secret)
+ _get_dataset_path(
+ 'single_table', dataset, datasets_path, bucket, aws_access_key_id, aws_secret_access_key
+ )
for dataset in datasets
]
diff --git a/sdgym/result_writer.py b/sdgym/result_writer.py
index 33a280fb..2e94aa22 100644
--- a/sdgym/result_writer.py
+++ b/sdgym/result_writer.py
@@ -6,7 +6,10 @@
from pathlib import Path
import pandas as pd
+import plotly.graph_objects as go
import yaml
+from openpyxl import load_workbook
+from openpyxl.drawing.image import Image as XLImage
from sdgym.s3 import parse_s3_path
@@ -30,16 +33,61 @@ def write_yaml(self, data, file_path, append=False):
pass
-class LocalResultsWriter(ResultsWriter):
- """Results writer for local file system."""
+class LocalResultsWriter:
+    """Results writer that saves results to the local filesystem."""
- def write_dataframe(self, data, file_path, append=False):
+ def write_dataframe(self, data, file_path, append=False, index=False):
"""Write a DataFrame to a CSV file."""
file_path = Path(file_path)
if file_path.exists() and append:
- data.to_csv(file_path, mode='a', index=False, header=False)
+ data.to_csv(file_path, mode='a', index=index, header=False)
+ else:
+ data.to_csv(file_path, mode='w', index=index, header=True)
+
+ def process_data(self, writer, file_path, temp_images, sheet_name, obj, index=False):
+ """Process a data item (DataFrame or Figure) and write it to the Excel writer."""
+ if isinstance(obj, pd.DataFrame):
+ obj.to_excel(writer, sheet_name=sheet_name, index=index)
+ elif isinstance(obj, go.Figure):
+ img_path = file_path.parent / f'{sheet_name}.png'
+ obj.write_image(img_path)
+ temp_images[sheet_name] = img_path
+
+ def write_xlsx(self, data, file_path, index=False):
+ """Write DataFrames and Plotly figures to an Excel file.
+
+ - DataFrames are saved as tables in their own sheets.
+ - Plotly figures are exported to PNG and embedded in their own sheets.
+ - Temporary PNG files are deleted after embedding.
+ - Newly written sheets are moved to the front.
+ """
+ file_path = Path(file_path)
+ temp_images = {}
+ file_path.parent.mkdir(parents=True, exist_ok=True)
+ if file_path.exists():
+ writer = pd.ExcelWriter(
+ file_path, mode='a', engine='openpyxl', if_sheet_exists='replace'
+ )
else:
- data.to_csv(file_path, mode='w', index=False)
+ writer = pd.ExcelWriter(file_path, mode='w', engine='openpyxl')
+
+ with writer:
+ for sheet_name, obj in data.items():
+ self.process_data(writer, file_path, temp_images, sheet_name, obj, index=index)
+
+ wb = load_workbook(file_path)
+ for sheet_name, img_path in temp_images.items():
+ ws = wb[sheet_name] if sheet_name in wb.sheetnames else wb.create_sheet(sheet_name)
+ ws.add_image(XLImage(img_path), 'A1')
+
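+        # Move the newly written sheets to the front of the workbook (uses openpyxl's private _sheets list).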
+ for sheet_name in reversed(data.keys()):
+ ws = wb[sheet_name]
+ wb._sheets.remove(ws)
+ wb._sheets.insert(0, ws)
+
+ wb.save(file_path)
+ for img_path in temp_images.values():
+ img_path.unlink(missing_ok=True)
def write_pickle(self, obj, file_path):
"""Write a Python object to a pickle file."""
@@ -68,7 +116,7 @@ class S3ResultsWriter(ResultsWriter):
def __init__(self, s3_client):
self.s3_client = s3_client
- def write_dataframe(self, data, file_path, append=False):
+ def write_dataframe(self, data, file_path, append=False, index=False):
"""Write a DataFrame to S3 as a CSV file."""
bucket, key = parse_s3_path(file_path)
if append:
@@ -81,7 +129,7 @@ def write_dataframe(self, data, file_path, append=False):
except Exception:
pass # If the file does not exist, we will create it
- csv_buffer = data.to_csv(index=False).encode()
+ csv_buffer = data.to_csv(index=index).encode()
self.s3_client.put_object(Body=csv_buffer, Bucket=bucket, Key=key)
def write_pickle(self, obj, file_path):
diff --git a/sdgym/run_benchmark/run_benchmark.py b/sdgym/run_benchmark/run_benchmark.py
new file mode 100644
index 00000000..5ae5c609
--- /dev/null
+++ b/sdgym/run_benchmark/run_benchmark.py
@@ -0,0 +1,64 @@
+"""Script to run a benchmark and upload results to S3."""
+
+import json
+import os
+from datetime import datetime, timezone
+
+from botocore.exceptions import ClientError
+
+from sdgym.benchmark import benchmark_single_table_aws
+from sdgym.run_benchmark.utils import (
+ KEY_DATE_FILE,
+ OUTPUT_DESTINATION_AWS,
+ SYNTHESIZERS_SPLIT,
+ get_result_folder_name,
+ post_benchmark_launch_message,
+)
+from sdgym.s3 import get_s3_client, parse_s3_path
+
+
+def append_benchmark_run(aws_access_key_id, aws_secret_access_key, date_str):
+ """Append a new benchmark run to the benchmark dates file in S3."""
+ s3_client = get_s3_client(
+ aws_access_key_id=aws_access_key_id,
+ aws_secret_access_key=aws_secret_access_key,
+ )
+ bucket, prefix = parse_s3_path(OUTPUT_DESTINATION_AWS)
+ try:
+        response = s3_client.get_object(Bucket=bucket, Key=f'{prefix}{KEY_DATE_FILE}')
+        body = response['Body'].read().decode('utf-8')
+ data = json.loads(body)
+ except ClientError as e:
+ if e.response['Error']['Code'] == 'NoSuchKey':
+ data = {'runs': []}
+ else:
+ raise RuntimeError(f'Failed to read {KEY_DATE_FILE} from S3: {e}')
+
+ data['runs'].append({'date': date_str, 'folder_name': get_result_folder_name(date_str)})
+ data['runs'] = sorted(data['runs'], key=lambda x: x['date'])
+ s3_client.put_object(
+ Bucket=bucket, Key=f'{prefix}{KEY_DATE_FILE}', Body=json.dumps(data).encode('utf-8')
+ )
+
+
+def main():
+ """Main function to run the benchmark and upload results."""
+ aws_access_key_id = os.getenv('AWS_ACCESS_KEY_ID')
+ aws_secret_access_key = os.getenv('AWS_SECRET_ACCESS_KEY')
+ date_str = datetime.now(timezone.utc).strftime('%Y-%m-%d')
+ for synthesizer_group in SYNTHESIZERS_SPLIT:
+ benchmark_single_table_aws(
+ output_destination=OUTPUT_DESTINATION_AWS,
+ aws_access_key_id=aws_access_key_id,
+ aws_secret_access_key=aws_secret_access_key,
+ synthesizers=synthesizer_group,
+ compute_privacy_score=False,
+ timeout=345600, # 4 days
+ )
+
+ append_benchmark_run(aws_access_key_id, aws_secret_access_key, date_str)
+ post_benchmark_launch_message(date_str)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/sdgym/run_benchmark/upload_benchmark_results.py b/sdgym/run_benchmark/upload_benchmark_results.py
new file mode 100644
index 00000000..be9b6858
--- /dev/null
+++ b/sdgym/run_benchmark/upload_benchmark_results.py
@@ -0,0 +1,281 @@
+"""Script to upload benchmark results to S3."""
+
+import json
+import logging
+import os
+import shutil
+import sys
+import tempfile
+from pathlib import Path
+
+import boto3
+import numpy as np
+import plotly.express as px
+from botocore.exceptions import ClientError
+from oauth2client.client import OAuth2Credentials
+from plotly import graph_objects as go
+from pydrive2.auth import GoogleAuth
+from pydrive2.drive import GoogleDrive
+from scipy.interpolate import interp1d
+
+from sdgym.result_writer import LocalResultsWriter
+from sdgym.run_benchmark.utils import OUTPUT_DESTINATION_AWS, get_df_to_plot
+from sdgym.s3 import S3_REGION, parse_s3_path
+from sdgym.sdgym_result_explorer.result_explorer import SDGymResultsExplorer
+
+LOGGER = logging.getLogger(__name__)
+SYNTHESIZER_TO_GLOBAL_POSITION = {
+ 'CTGAN': 'middle right',
+ 'TVAE': 'middle left',
+ 'GaussianCopula': 'bottom center',
+ 'Uniform': 'top center',
+ 'Column': 'top center',
+ 'CopulaGAN': 'top center',
+ 'RealTabFormer': 'bottom center',
+}
+SDGYM_FILE_ID = '1W3tsGOOtbtTw3g0EVE0irLgY_TN_cy2W4ONiZQ57OPo'
+RESULT_FILENAME = 'SDGym Monthly Run.xlsx'
+
+
+def get_latest_run_from_file(s3_client, bucket, key):
+ """Get the latest run folder name from the benchmark dates file in S3."""
+ try:
+        response = s3_client.get_object(Bucket=bucket, Key=key)
+        body = response['Body'].read().decode('utf-8')
+ data = json.loads(body)
+ latest = sorted(data['runs'], key=lambda x: x['date'])[-1]
+ return latest
+ except s3_client.exceptions.ClientError as e:
+ raise RuntimeError(f'Failed to read {key} from S3: {e}')
+
+
+def write_uploaded_marker(s3_client, bucket, prefix, folder_name):
+ """Write a marker file to indicate that the upload is complete."""
+ s3_client.put_object(
+ Bucket=bucket, Key=f'{prefix}{folder_name}/upload_complete.marker', Body=b'Upload complete'
+ )
+
+
+def upload_already_done(s3_client, bucket, prefix, folder_name):
+ """Check if the upload has already been done by looking for the marker file."""
+ try:
+ s3_client.head_object(Bucket=bucket, Key=f'{prefix}{folder_name}/upload_complete.marker')
+ return True
+ except ClientError as e:
+ if e.response['Error']['Code'] == '404':
+ return False
+
+ raise
+
+
+def get_result_folder_name_and_s3_vars(aws_access_key_id, aws_secret_access_key):
+ """Get the result folder name and S3 client variables."""
+ bucket, prefix = parse_s3_path(OUTPUT_DESTINATION_AWS)
+ s3_client = boto3.client(
+ 's3',
+ aws_access_key_id=aws_access_key_id,
+ aws_secret_access_key=aws_secret_access_key,
+ region_name=S3_REGION,
+ )
+ folder_infos = get_latest_run_from_file(s3_client, bucket, f'{prefix}_BENCHMARK_DATES.json')
+
+ return folder_infos, s3_client, bucket, prefix
+
+
+def generate_graph(plot_table):
+ """Generate a scatter plot for the benchmark results."""
+ fig = px.scatter(
+ plot_table,
+ x='Aggregated_Time',
+ y='Quality_Score',
+ color='Synthesizer',
+ text='Synthesizer',
+ title='Mean Quality Score vs Aggregated Time (Over All Datasets)',
+ labels={'Aggregated_Time': 'Aggregated Time [s]', 'Quality_Score': 'Mean Quality Score'},
+ log_x=True,
+ color_discrete_sequence=px.colors.qualitative.Plotly,
+ )
+
+ for trace in fig.data:
+ synthesizer_name = trace.name
+ shape = plot_table.loc[plot_table['Synthesizer'] == synthesizer_name, 'Marker'].values[0]
+ color = plot_table.loc[plot_table['Synthesizer'] == synthesizer_name, 'Color'].values[0]
+ trace_positions = SYNTHESIZER_TO_GLOBAL_POSITION.get(synthesizer_name, 'top center')
+ trace.update(
+ marker=dict(size=14, color=color), textposition=trace_positions, marker_symbol=shape
+ )
+
+ fig.update_layout(
+ xaxis=dict(
+ tickformat='.0e',
+ tickmode='array',
+ tickvals=[1e1, 1e2, 1e3, 1e4, 1e5, 1e6],
+ ticktext=[
+                '10<sup>1</sup>',
+                '10<sup>2</sup>',
+                '10<sup>3</sup>',
+                '10<sup>4</sup>',
+                '10<sup>5</sup>',
+                '10<sup>6</sup>',
+ ],
+ showgrid=False,
+ zeroline=False,
+ title='Aggregated Time [s]',
+ range=[0.6, 6],
+ ),
+ yaxis=dict(showgrid=False, zeroline=False, range=[0.54, 0.92]),
+ plot_bgcolor='#F5F5F8',
+ )
+
+ fig.update_traces(textfont=dict(size=16))
+ pareto_points = plot_table.loc[plot_table['Pareto']]
+ x_pareto = pareto_points['Aggregated_Time'].values
+ y_pareto = pareto_points['Quality_Score'].values
+ sorted_indices = np.argsort(x_pareto)
+ x_sorted = x_pareto[sorted_indices]
+ y_sorted = y_pareto[sorted_indices]
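+    # Interpolate the Pareto frontier in log10(time) space so the curve is smooth on the log-scaled x-axis.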
+ log_x_sorted = np.log10(x_sorted)
+ interp = interp1d(log_x_sorted, y_sorted, kind='linear', fill_value='extrapolate')
+ log_x_fit = np.linspace(0.7, 6, 100)
+ y_fit = interp(log_x_fit)
+ x_fit = np.power(10, log_x_fit)
+
+ # Plot smooth interpolation
+ fig.add_trace(
+ go.Scatter(
+ x=x_fit,
+ y=y_fit,
+ mode='lines',
+ name='Pareto Frontier',
+ line=dict(color='black', width=2),
+ )
+ )
+ x_shade = np.concatenate([x_fit, x_fit[::-1]])
+ y_shade = np.concatenate([y_fit, np.full_like(x_fit, min(y_fit))[::-1]])
+ fig.add_trace(
+ go.Scatter(
+ x=x_shade,
+ y=y_shade,
+ fill='toself',
+ fillcolor='rgba(0, 0, 54, 0.25)',
+ line=dict(color='#000036'),
+ hoverinfo='skip',
+ showlegend=False,
+ )
+ )
+
+ return fig
+
+
+def upload_to_drive(file_path, file_id):
+ """Upload a local file to a Google Drive folder.
+
+ Args:
+ file_path (str or Path): Path to the local file to upload.
+ file_id (str): Google Drive file ID.
+ """
+ file_path = Path(file_path)
+ if not file_path.exists():
+ raise FileNotFoundError(f'File not found: {file_path}')
+
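+    # PYDRIVE_CREDENTIALS is expected to hold a JSON object with the OAuth2 token fields used below.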
+ creds_dict = json.loads(os.environ['PYDRIVE_CREDENTIALS'])
+ creds = OAuth2Credentials(
+ access_token=creds_dict['access_token'],
+ client_id=creds_dict.get('client_id'),
+ client_secret=creds_dict.get('client_secret'),
+ refresh_token=creds_dict.get('refresh_token'),
+ token_expiry=None,
+        token_uri='https://oauth2.googleapis.com/token',
+ user_agent=None,
+ )
+ gauth = GoogleAuth()
+ gauth.credentials = creds
+ drive = GoogleDrive(gauth)
+
+ gfile = drive.CreateFile({'id': file_id})
+ gfile.SetContentFile(file_path)
+ gfile.Upload(param={'supportsAllDrives': True})
+
+
+def upload_results(
+ aws_access_key_id, aws_secret_access_key, folder_infos, s3_client, bucket, prefix, github_env
+):
+ """Upload benchmark results to S3, GDrive, and save locally."""
+ folder_name = folder_infos['folder_name']
+ run_date = folder_infos['date']
+ result_explorer = SDGymResultsExplorer(
+ OUTPUT_DESTINATION_AWS,
+ aws_access_key_id=aws_access_key_id,
+ aws_secret_access_key=aws_secret_access_key,
+ )
+ local_results_writer = LocalResultsWriter()
+ if not result_explorer.all_runs_complete(folder_name):
+ LOGGER.warning(f'Run {folder_name} is not complete yet. Exiting.')
+ if github_env:
+ with open(github_env, 'a') as env_file:
+ env_file.write('SKIP_UPLOAD=true\n')
+
+ sys.exit(0)
+
+ LOGGER.info(f'Run {folder_name} is complete! Proceeding with summarization...')
+ if github_env:
+ with open(github_env, 'a') as env_file:
+ env_file.write('SKIP_UPLOAD=false\n')
+ env_file.write(f'FOLDER_NAME={folder_name}\n')
+
+ summary, results = result_explorer.summarize(folder_name)
+ df_to_plot = get_df_to_plot(results)
+ figure = generate_graph(df_to_plot)
+ local_export_dir = os.environ.get('GITHUB_LOCAL_RESULTS_DIR')
+ temp_dir = None
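+    # Outside of CI (no GITHUB_LOCAL_RESULTS_DIR), export to a temporary directory that is removed afterwards.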
+ if not local_export_dir:
+ temp_dir = tempfile.mkdtemp()
+ local_export_dir = temp_dir
+
+ os.makedirs(local_export_dir, exist_ok=True)
+ local_file_path = os.path.join(local_export_dir, RESULT_FILENAME)
+ s3_key = f'{prefix}{RESULT_FILENAME}'
+ s3_client.download_file(bucket, s3_key, local_file_path)
+ datas = {
+ 'Wins': summary,
+ f'{run_date}_Detailed_results': results,
+ f'{run_date}_plot_data': df_to_plot,
+ f'{run_date}_plot_image': figure,
+ }
+ local_results_writer.write_xlsx(datas, local_file_path)
+ upload_to_drive(local_file_path, SDGYM_FILE_ID)
+ s3_client.upload_file(local_file_path, bucket, s3_key)
+ write_uploaded_marker(s3_client, bucket, prefix, folder_name)
+ if temp_dir:
+ shutil.rmtree(temp_dir)
+
+
+def main():
+ """Main function to upload benchmark results."""
+ aws_access_key_id = os.getenv('AWS_ACCESS_KEY_ID')
+ aws_secret_access_key = os.getenv('AWS_SECRET_ACCESS_KEY')
+ folder_infos, s3_client, bucket, prefix = get_result_folder_name_and_s3_vars(
+ aws_access_key_id, aws_secret_access_key
+ )
+ github_env = os.getenv('GITHUB_ENV')
+ if upload_already_done(s3_client, bucket, prefix, folder_infos['folder_name']):
+ LOGGER.warning('Benchmark results have already been uploaded. Exiting.')
+ if github_env:
+ with open(github_env, 'a') as env_file:
+ env_file.write('SKIP_UPLOAD=true\n')
+
+ sys.exit(0)
+
+ upload_results(
+ aws_access_key_id,
+ aws_secret_access_key,
+ folder_infos,
+ s3_client,
+ bucket,
+ prefix,
+ github_env,
+ )
+
+
+if __name__ == '__main__':
+ main()
diff --git a/sdgym/run_benchmark/utils.py b/sdgym/run_benchmark/utils.py
new file mode 100644
index 00000000..28a52d03
--- /dev/null
+++ b/sdgym/run_benchmark/utils.py
@@ -0,0 +1,155 @@
+"""Utils file for the run_benchmark module."""
+
+import os
+from datetime import datetime
+
+import numpy as np
+from slack_sdk import WebClient
+
+from sdgym.s3 import parse_s3_path
+
+OUTPUT_DESTINATION_AWS = 's3://sdgym-benchmark/Benchmarks/'
+UPLOAD_DESTINATION_AWS = 's3://sdgym-benchmark/Benchmarks/'
+DEBUG_SLACK_CHANNEL = 'sdv-alerts-debug'
+SLACK_CHANNEL = 'sdv-alerts'
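+# JSON file, stored under the S3 output prefix, recording the date and folder name of each benchmark run.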
+KEY_DATE_FILE = '_BENCHMARK_DATES.json'
+PLOTLY_MARKERS = [
+ 'circle',
+ 'square',
+ 'diamond',
+ 'cross',
+ 'x',
+ 'triangle-up',
+ 'triangle-down',
+ 'triangle-left',
+ 'triangle-right',
+ 'pentagon',
+ 'hexagon',
+ 'hexagon2',
+ 'octagon',
+ 'star',
+ 'hexagram',
+ 'star-triangle-up',
+ 'star-triangle-down',
+ 'star-square',
+ 'star-diamond',
+ 'diamond-tall',
+ 'diamond-wide',
+ 'hourglass',
+ 'bowtie',
+ 'circle-cross',
+ 'circle-x',
+ 'square-cross',
+ 'square-x',
+ 'diamond-cross',
+ 'diamond-x',
+]
+
+# The synthesizers inside the same list will be run by the same ec2 instance
+SYNTHESIZERS_SPLIT = [
+ ['UniformSynthesizer', 'ColumnSynthesizer', 'GaussianCopulaSynthesizer', 'TVAESynthesizer'],
+ ['CopulaGANSynthesizer'],
+ ['CTGANSynthesizer'],
+ ['RealTabFormerSynthesizer'],
+]
+
+
+def get_result_folder_name(date_str):
+ """Get the result folder name based on the date string."""
+ try:
+ date = datetime.strptime(date_str, '%Y-%m-%d')
+ except ValueError:
+ raise ValueError(f'Invalid date format: {date_str}. Expected YYYY-MM-DD.')
+
+ return f'SDGym_results_{date.month:02d}_{date.day:02d}_{date.year}'
+
+
+def get_s3_console_link(bucket, prefix):
+ """Get the S3 console link for the specified bucket and prefix."""
+ return (
+        f'https://s3.console.aws.amazon.com/s3/buckets/%7Bbucket%7D?prefix={prefix}&showversions=false'
+ )
+
+
+def _get_slack_client():
+ """Create an authenticated Slack client.
+
+ Returns:
+ WebClient:
+ An authenticated Slack WebClient instance.
+ """
+ token = os.getenv('SLACK_TOKEN')
+ client = WebClient(token=token)
+ return client
+
+
+def post_slack_message(channel, text):
+ """Post a message to a Slack channel."""
+ client = _get_slack_client()
+ client.chat_postMessage(channel=channel, text=text)
+
+
+def post_benchmark_launch_message(date_str):
+ """Post a message to the SDV Alerts Slack channel when the benchmark is launched."""
+ channel = DEBUG_SLACK_CHANNEL
+ folder_name = get_result_folder_name(date_str)
+ bucket, prefix = parse_s3_path(OUTPUT_DESTINATION_AWS)
+ url_link = get_s3_console_link(bucket, f'{prefix}{folder_name}/')
+    body = '🚀 SDGym benchmark has been launched! EC2 Instances are running. '
+ body += f'Intermediate results can be found <{url_link}|here>.\n'
+ post_slack_message(channel, body)
+
+
+def post_benchmark_uploaded_message(folder_name, commit_url=None):
+ """Post benchmark uploaded message to sdv-alerts slack channel."""
+ channel = DEBUG_SLACK_CHANNEL
+ bucket, prefix = parse_s3_path(OUTPUT_DESTINATION_AWS)
+ url_link = get_s3_console_link(bucket, f'{prefix}{folder_name}/{folder_name}_summary.csv')
+ body = (
+        f'🤸🏻‍♀️ SDGym benchmark results for *{folder_name}* are available! 🏋️‍♀️\n'
+        f'Check the results <{url_link}|here>'
+ )
+ if commit_url:
+ body += f' or on GitHub: <{commit_url}|Commit Link>\n'
+
+ post_slack_message(channel, body)
+
+
+def get_df_to_plot(benchmark_result):
+ """Get the data to plot from the benchmark result.
+
+ Args:
+ benchmark_result (DataFrame): The benchmark result DataFrame.
+
+ Returns:
+ DataFrame: The data to plot.
+ """
+ df_to_plot = benchmark_result.copy()
+ df_to_plot['total_time'] = df_to_plot['Train_Time'] + df_to_plot['Sample_Time']
+ df_to_plot['Aggregated_Time'] = df_to_plot.groupby('Synthesizer')['total_time'].transform('sum')
+ df_to_plot = (
+ df_to_plot.groupby('Synthesizer')[['Aggregated_Time', 'Quality_Score']].mean().reset_index()
+ )
+ df_to_plot['Log10 Aggregated_Time'] = df_to_plot['Aggregated_Time'].apply(
+ lambda x: np.log10(x) if x > 0 else 0
+ )
+ df_to_plot = df_to_plot.sort_values(
+ ['Aggregated_Time', 'Quality_Score'], ascending=[True, False]
+ )
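+    # With rows sorted by ascending time, a synthesizer is on the Pareto frontier when its quality
+    # score equals the running maximum, i.e. no faster synthesizer achieves a higher score.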
+ df_to_plot['Cumulative Quality Score'] = df_to_plot['Quality_Score'].cummax()
+ pareto_points = df_to_plot.loc[
+ df_to_plot['Quality_Score'] == df_to_plot['Cumulative Quality Score']
+ ]
+ df_to_plot['Pareto'] = df_to_plot.index.isin(pareto_points.index)
+ df_to_plot['Color'] = df_to_plot['Pareto'].apply(lambda x: '#01E0C9' if x else '#03AFF1')
+ df_to_plot['Synthesizer'] = df_to_plot['Synthesizer'].str.replace(
+ 'Synthesizer', '', regex=False
+ )
+
+ synthesizers = df_to_plot['Synthesizer'].unique()
+ marker_map = {
+ synth: PLOTLY_MARKERS[i % len(PLOTLY_MARKERS)] for i, synth in enumerate(synthesizers)
+ }
+ df_to_plot['Marker'] = df_to_plot['Synthesizer'].map(marker_map)
+
+ return df_to_plot.drop(columns=['Cumulative Quality Score']).reset_index(drop=True)
diff --git a/sdgym/s3.py b/sdgym/s3.py
index bfc22be9..d271f2c5 100644
--- a/sdgym/s3.py
+++ b/sdgym/s3.py
@@ -10,6 +10,7 @@
import pandas as pd
S3_PREFIX = 's3://'
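+# Default AWS region used for the SDGym S3 and EC2 clients.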
+S3_REGION = 'us-east-1'
LOGGER = logging.getLogger(__name__)
@@ -49,14 +50,14 @@ def parse_s3_path(path):
return bucket_name, key_prefix
-def get_s3_client(aws_key=None, aws_secret=None):
+def get_s3_client(aws_access_key_id=None, aws_secret_access_key=None):
"""Get the boto client for interfacing with AWS s3.
Args:
- aws_key (str):
+ aws_access_key_id (str):
The access key id that will be used to communicate with
s3, if provided.
- aws_secret (str):
+ aws_secret_access_key (str):
The secret access key that will be used to communicate
with s3, if provided.
@@ -64,9 +65,14 @@ def get_s3_client(aws_key=None, aws_secret=None):
boto3.session.Session.client:
The s3 client that can be used to read / write to s3.
"""
- if aws_key is not None and aws_secret is not None:
+ if aws_access_key_id is not None and aws_secret_access_key is not None:
# credentials available
- return boto3.client('s3', aws_access_key_id=aws_key, aws_secret_access_key=aws_secret)
+ return boto3.client(
+ 's3',
+ aws_access_key_id=aws_access_key_id,
+ aws_secret_access_key=aws_secret_access_key,
+ region_name=S3_REGION,
+ )
else:
if boto3.Session().get_credentials():
# credentials available and will be detected automatically
@@ -78,7 +84,7 @@ def get_s3_client(aws_key=None, aws_secret=None):
return boto3.client('s3', config=config)
-def write_file(data_contents, path, aws_key, aws_secret):
+def write_file(data_contents, path, aws_access_key_id, aws_secret_access_key):
"""Write a file to the given path with the given contents.
If the path is an s3 directory, we will use the given aws credentials
@@ -90,10 +96,10 @@ def write_file(data_contents, path, aws_key, aws_secret):
path (str):
The path to write the file to, which can be either local
or an s3 path.
- aws_key (str):
+ aws_access_key_id (str):
The access key id that will be used to communicate with s3,
if provided.
- aws_secret (str):
+ aws_secret_access_key (str):
The secret access key that will be used to communicate
with s3, if provided.
@@ -109,7 +115,7 @@ def write_file(data_contents, path, aws_key, aws_secret):
write_mode = 'wb'
if is_s3_path(path):
- s3 = get_s3_client(aws_key, aws_secret)
+ s3 = get_s3_client(aws_access_key_id, aws_secret_access_key)
bucket_name, key = parse_s3_path(path)
s3.put_object(
Bucket=bucket_name,
@@ -125,7 +131,7 @@ def write_file(data_contents, path, aws_key, aws_secret):
f.write(data_contents)
-def write_csv(data, path, aws_key, aws_secret):
+def write_csv(data, path, aws_access_key_id, aws_secret_access_key):
"""Write a csv file to the given path with the given contents.
If the path is an s3 directory, we will use the given aws credentials
@@ -137,10 +143,10 @@ def write_csv(data, path, aws_key, aws_secret):
path (str):
The path to write the file to, which can be either local
or an s3 path.
- aws_key (str):
+ aws_access_key_id (str):
The access key id that will be used to communicate with s3,
if provided.
- aws_secret (str):
+ aws_secret_access_key (str):
The secret access key that will be used to communicate
with s3, if provided.
@@ -148,7 +154,7 @@ def write_csv(data, path, aws_key, aws_secret):
none
"""
data_contents = data.to_csv(index=False).encode('utf-8')
- write_file(data_contents, path, aws_key, aws_secret)
+ write_file(data_contents, path, aws_access_key_id, aws_secret_access_key)
def _parse_s3_paths(s3_paths_dict):
@@ -203,6 +209,7 @@ def _get_s3_client(output_destination, aws_access_key_id=None, aws_secret_access
's3',
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
+ region_name=S3_REGION,
)
else:
s3_client = boto3.client('s3')
diff --git a/sdgym/sdgym_result_explorer/result_explorer.py b/sdgym/sdgym_result_explorer/result_explorer.py
index eb04b576..889fde95 100644
--- a/sdgym/sdgym_result_explorer/result_explorer.py
+++ b/sdgym/sdgym_result_explorer/result_explorer.py
@@ -65,8 +65,8 @@ def load_real_data(self, dataset_name):
if dataset_name in DEFAULT_DATASETS:
dataset_path = get_dataset_paths(
datasets=[dataset_name],
- aws_key=self.aws_access_key_id,
- aws_secret=self.aws_secret_access_key,
+ aws_access_key_id=self.aws_access_key_id,
+ aws_secret_access_key=self.aws_secret_access_key,
)[0]
else:
raise ValueError(
@@ -77,8 +77,8 @@ def load_real_data(self, dataset_name):
data, _ = load_dataset(
'single_table',
dataset_path,
- aws_key=self.aws_access_key_id,
- aws_secret=self.aws_secret_access_key,
+ aws_access_key_id=self.aws_access_key_id,
+ aws_secret_access_key=self.aws_secret_access_key,
)
return data
@@ -95,3 +95,7 @@ def summarize(self, folder_name):
- A DataFrame with the results of the benchmark for the specified folder.
"""
return self._handler.summarize(folder_name)
+
+ def all_runs_complete(self, folder_name):
+ """Check if all runs in the specified folder are complete."""
+ return self._handler.all_runs_complete(folder_name)
diff --git a/sdgym/sdgym_result_explorer/result_handler.py b/sdgym/sdgym_result_explorer/result_handler.py
index 3de27197..c8f4073f 100644
--- a/sdgym/sdgym_result_explorer/result_handler.py
+++ b/sdgym/sdgym_result_explorer/result_handler.py
@@ -86,6 +86,9 @@ def _get_summarize_table(self, folder_to_results, folder_infos):
summarized_results[column_name] = column_data
summarized_results = summarized_results.fillna('-')
+ summarized_results = summarized_results.reset_index()
+ summarized_results = summarized_results.rename(columns={'index': 'Synthesizer'})
+
return summarized_results
def _get_column_name_infos(self, folder_to_results):
@@ -121,6 +124,7 @@ def _process_results(self, results):
'summarize results.'
)
+ filtered_results = filtered_results.sort_values(by=['Dataset', 'Synthesizer'])
return filtered_results.reset_index(drop=True)
def summarize(self, folder_name):
@@ -155,6 +159,19 @@ def summarize(self, folder_name):
return summarized_table, folder_to_results[folder_name]
+ def all_runs_complete(self, folder_name):
+ """Check if all runs in the specified folder are complete."""
+ yaml_files = self._get_results_files(folder_name, prefix=RUN_ID_PREFIX, suffix='.yaml')
+ if not yaml_files:
+ return False
+
+ for yaml_file in yaml_files:
+ run_id_info = self._load_yaml_file(folder_name, yaml_file)
+ if run_id_info.get('completed_date') is None:
+ return False
+
+ return True
+
class LocalResultsHandler(ResultsHandler):
"""Results handler for local filesystem."""
diff --git a/tasks.py b/tasks.py
index 76eb01a0..0f8a47b5 100644
--- a/tasks.py
+++ b/tasks.py
@@ -10,7 +10,7 @@
from invoke import task
from packaging.requirements import Requirement
from packaging.version import Version
-
+from sdgym.run_benchmark.utils import post_benchmark_uploaded_message
+
COMPARISONS = {'>=': operator.ge, '>': operator.gt, '<': operator.lt, '<=': operator.le}
EGG_STRING = '#egg='
@@ -202,3 +202,18 @@ def rmdir(c, path):
shutil.rmtree(path, onerror=remove_readonly)
except PermissionError:
pass
+
+
+@task
+def run_sdgym_benchmark(c):
+ """Run the SDGym benchmark."""
+ c.run('python sdgym/run_benchmark/run_benchmark.py')
+
+
+@task
+def upload_benchmark_results(c):
+ """Upload the benchmark results to S3."""
+    c.run('python sdgym/run_benchmark/upload_benchmark_results.py')
+
+
+@task
+def notify_sdgym_benchmark_uploaded(c, folder_name, commit_url=None):
+ """Notify Slack about the SDGym benchmark upload."""
+ post_benchmark_uploaded_message(folder_name, commit_url)
diff --git a/tests/integration/sdgym_result_explorer/test_result_explorer.py b/tests/integration/sdgym_result_explorer/test_result_explorer.py
index 2a10270e..f56fd346 100644
--- a/tests/integration/sdgym_result_explorer/test_result_explorer.py
+++ b/tests/integration/sdgym_result_explorer/test_result_explorer.py
@@ -58,16 +58,19 @@ def test_summarize():
# Assert
expected_summary = pd.DataFrame({
+ 'Synthesizer': ['CTGANSynthesizer', 'CopulaGANSynthesizer', 'TVAESynthesizer'],
'10_11_2024 - # datasets: 9 - sdgym version: 0.9.1': [6, 4, 5],
'05_10_2024 - # datasets: 9 - sdgym version: 0.8.0': [4, 4, 5],
'04_05_2024 - # datasets: 9 - sdgym version: 0.7.0': [5, 3, 5],
- 'Synthesizer': ['CTGANSynthesizer', 'CopulaGANSynthesizer', 'TVAESynthesizer'],
})
- expected_results = pd.read_csv(
- 'tests/integration/sdgym_result_explorer/_benchmark_results/'
- 'SDGym_results_10_11_2024/results_10_11_2024_1.csv',
+ expected_results = (
+ pd.read_csv(
+ 'tests/integration/sdgym_result_explorer/_benchmark_results/'
+ 'SDGym_results_10_11_2024/results_10_11_2024_1.csv',
+ )
+ .sort_values(by=['Dataset', 'Synthesizer'])
+ .reset_index(drop=True)
)
expected_results['Win'] = expected_results['Win'].astype('int64')
- expected_summary = expected_summary.set_index('Synthesizer')
pd.testing.assert_frame_equal(summary, expected_summary)
pd.testing.assert_frame_equal(results, expected_results)
diff --git a/tests/unit/run_benchmark/test_run_benchmark.py b/tests/unit/run_benchmark/test_run_benchmark.py
new file mode 100644
index 00000000..aacab84e
--- /dev/null
+++ b/tests/unit/run_benchmark/test_run_benchmark.py
@@ -0,0 +1,143 @@
+import json
+from datetime import datetime, timezone
+from unittest.mock import Mock, call, patch
+
+from botocore.exceptions import ClientError
+
+from sdgym.run_benchmark.run_benchmark import append_benchmark_run, main
+from sdgym.run_benchmark.utils import OUTPUT_DESTINATION_AWS, SYNTHESIZERS_SPLIT
+
+
+@patch('sdgym.run_benchmark.run_benchmark.get_s3_client')
+@patch('sdgym.run_benchmark.run_benchmark.parse_s3_path')
+@patch('sdgym.run_benchmark.run_benchmark.get_result_folder_name')
+def test_append_benchmark_run(mock_get_result_folder_name, mock_parse_s3_path, mock_get_s3_client):
+ """Test the `append_benchmark_run` method."""
+ # Setup
+ aws_access_key_id = 'my_access_key'
+ aws_secret_access_key = 'my_secret_key'
+ date = '2023-10-01'
+ mock_get_result_folder_name.return_value = 'SDGym_results_10_01_2023'
+ mock_parse_s3_path.return_value = ('my-bucket', 'my-prefix/')
+ mock_s3_client = Mock()
+ benchmark_date = {
+ 'runs': [
+ {'date': '2023-09-30', 'folder_name': 'SDGym_results_09_30_2023'},
+ ]
+ }
+ mock_get_s3_client.return_value = mock_s3_client
+ mock_s3_client.get_object.return_value = {
+ 'Body': Mock(read=lambda: json.dumps(benchmark_date).encode('utf-8'))
+ }
+ expected_data = {
+ 'runs': [
+ {'date': '2023-09-30', 'folder_name': 'SDGym_results_09_30_2023'},
+ {'date': date, 'folder_name': 'SDGym_results_10_01_2023'},
+ ]
+ }
+
+ # Run
+ append_benchmark_run(aws_access_key_id, aws_secret_access_key, date)
+
+ # Assert
+ mock_get_s3_client.assert_called_once_with(
+ aws_access_key_id=aws_access_key_id,
+ aws_secret_access_key=aws_secret_access_key,
+ )
+ mock_parse_s3_path.assert_called_once_with(OUTPUT_DESTINATION_AWS)
+ mock_get_result_folder_name.assert_called_once_with(date)
+ mock_s3_client.get_object.assert_called_once_with(
+ Bucket='my-bucket', Key='my-prefix/_BENCHMARK_DATES.json'
+ )
+ mock_s3_client.put_object.assert_called_once_with(
+ Bucket='my-bucket',
+ Key='my-prefix/_BENCHMARK_DATES.json',
+ Body=json.dumps(expected_data).encode('utf-8'),
+ )
+
+
+@patch('sdgym.run_benchmark.run_benchmark.get_s3_client')
+@patch('sdgym.run_benchmark.run_benchmark.parse_s3_path')
+@patch('sdgym.run_benchmark.run_benchmark.get_result_folder_name')
+def test_append_benchmark_run_new_file(
+ mock_get_result_folder_name, mock_parse_s3_path, mock_get_s3_client
+):
+ """Test the `append_benchmark_run` with a new file."""
+ # Setup
+ aws_access_key_id = 'my_access_key'
+ aws_secret_access_key = 'my_secret_key'
+ date = '2023-10-01'
+ mock_get_result_folder_name.return_value = 'SDGym_results_10_01_2023'
+ mock_parse_s3_path.return_value = ('my-bucket', 'my-prefix/')
+ mock_s3_client = Mock()
+ mock_get_s3_client.return_value = mock_s3_client
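+ # Simulate a missing _BENCHMARK_DATES.json so append_benchmark_run starts a new file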
+ mock_s3_client.get_object.side_effect = ClientError(
+ {'Error': {'Code': 'NoSuchKey'}}, 'GetObject'
+ )
+ expected_data = {
+ 'runs': [
+ {'date': date, 'folder_name': 'SDGym_results_10_01_2023'},
+ ]
+ }
+
+ # Run
+ append_benchmark_run(aws_access_key_id, aws_secret_access_key, date)
+
+ # Assert
+ mock_get_s3_client.assert_called_once_with(
+ aws_access_key_id=aws_access_key_id,
+ aws_secret_access_key=aws_secret_access_key,
+ )
+ mock_parse_s3_path.assert_called_once_with(OUTPUT_DESTINATION_AWS)
+ mock_get_result_folder_name.assert_called_once_with(date)
+ mock_s3_client.get_object.assert_called_once_with(
+ Bucket='my-bucket', Key='my-prefix/_BENCHMARK_DATES.json'
+ )
+ mock_s3_client.put_object.assert_called_once_with(
+ Bucket='my-bucket',
+ Key='my-prefix/_BENCHMARK_DATES.json',
+ Body=json.dumps(expected_data).encode('utf-8'),
+ )
+
+
+@patch('sdgym.run_benchmark.run_benchmark.benchmark_single_table_aws')
+@patch('sdgym.run_benchmark.run_benchmark.os.getenv')
+@patch('sdgym.run_benchmark.run_benchmark.append_benchmark_run')
+@patch('sdgym.run_benchmark.run_benchmark.post_benchmark_launch_message')
+def test_main(
+ mock_post_benchmark_launch_message,
+ mock_append_benchmark_run,
+ mock_getenv,
+ mock_benchmark_single_table_aws,
+):
+ """Test the `main` method."""
+ # Setup
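+ # os.getenv returns the access key first, then the secret key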
+ mock_getenv.side_effect = ['my_access_key', 'my_secret_key']
+ date = datetime.now(timezone.utc).strftime('%Y-%m-%d')
+
+ # Run
+ main()
+
+ # Assert
+ mock_getenv.assert_any_call('AWS_ACCESS_KEY_ID')
+ mock_getenv.assert_any_call('AWS_SECRET_ACCESS_KEY')
+ expected_calls = []
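+ # One benchmark call is expected per synthesizer split, each with a 4-day timeout (345600 seconds)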
+ for synthesizer in SYNTHESIZERS_SPLIT:
+ expected_calls.append(
+ call(
+ output_destination=OUTPUT_DESTINATION_AWS,
+ aws_access_key_id='my_access_key',
+ aws_secret_access_key='my_secret_key',
+ synthesizers=synthesizer,
+ compute_privacy_score=False,
+ timeout=345600,
+ )
+ )
+
+ mock_benchmark_single_table_aws.assert_has_calls(expected_calls)
+ mock_append_benchmark_run.assert_called_once_with(
+ 'my_access_key',
+ 'my_secret_key',
+ date,
+ )
+ mock_post_benchmark_launch_message.assert_called_once_with(date)
diff --git a/tests/unit/run_benchmark/test_upload_benchmark_result.py b/tests/unit/run_benchmark/test_upload_benchmark_result.py
new file mode 100644
index 00000000..63c18c41
--- /dev/null
+++ b/tests/unit/run_benchmark/test_upload_benchmark_result.py
@@ -0,0 +1,293 @@
+from unittest.mock import Mock, patch
+
+import pytest
+from botocore.exceptions import ClientError
+
+from sdgym.run_benchmark.upload_benchmark_results import (
+ SDGYM_FILE_ID,
+ get_result_folder_name_and_s3_vars,
+ main,
+ upload_already_done,
+ upload_results,
+ write_uploaded_marker,
+)
+from sdgym.s3 import S3_REGION
+
+
+def test_write_uploaded_marker():
+ """Test the `write_uploaded_marker` method."""
+ # Setup
+ s3_client = Mock()
+ bucket = 'test-bucket'
+ prefix = 'test-prefix/'
+ run_name = 'test_run'
+
+ # Run
+ write_uploaded_marker(s3_client, bucket, prefix, run_name)
+
+ # Assert
+ s3_client.put_object.assert_called_once_with(
+ Bucket=bucket, Key=f'{prefix}{run_name}/upload_complete.marker', Body=b'Upload complete'
+ )
+
+
+def test_upload_already_done():
+ """Test the `upload_already_done` method."""
+ # Setup
+ s3_client = Mock()
+ bucket = 'test-bucket'
+ prefix = 'test-prefix/'
+ run_name = 'test_run'
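+ # Three consecutive head_object outcomes: marker found, marker missing (404), and an unrelated error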
+ s3_client.head_object.side_effect = [
+ '',
+ ClientError(
+ error_response={'Error': {'Code': '404', 'Message': 'Not Found'}},
+ operation_name='HeadObject',
+ ),
+ ClientError(
+ error_response={'Error': {'Code': '405', 'Message': 'Other Error'}},
+ operation_name='HeadObject',
+ ),
+ ]
+
+ # Run
+ result = upload_already_done(s3_client, bucket, prefix, run_name)
+ result_false = upload_already_done(s3_client, bucket, prefix, run_name)
+ with pytest.raises(ClientError):
+ upload_already_done(s3_client, bucket, prefix, run_name)
+
+ # Assert
+ assert result is True
+ assert result_false is False
+
+
+@patch('sdgym.run_benchmark.upload_benchmark_results.boto3.client')
+@patch('sdgym.run_benchmark.upload_benchmark_results.parse_s3_path')
+@patch('sdgym.run_benchmark.upload_benchmark_results.OUTPUT_DESTINATION_AWS')
+@patch('sdgym.run_benchmark.upload_benchmark_results.get_latest_run_from_file')
+def test_get_result_folder_name_and_s3_vars(
+ mock_get_latest_run_from_file,
+ mock_output_destination_aws,
+ mock_parse_s3_path,
+ mock_boto_client,
+):
+ """Test the `get_result_folder_name_and_s3_vars` method."""
+ # Setup
+ aws_access_key_id = 'my_access_key'
+ aws_secret_access_key = 'my_secret_key'
+ expected_result = ('SDGym_results_10_01_2023', 's3_client', 'bucket', 'prefix')
+ mock_boto_client.return_value = 's3_client'
+ mock_parse_s3_path.return_value = ('bucket', 'prefix')
+ mock_get_latest_run_from_file.return_value = 'SDGym_results_10_01_2023'
+
+ # Run
+ result = get_result_folder_name_and_s3_vars(aws_access_key_id, aws_secret_access_key)
+
+ # Assert
+ assert result == expected_result
+ mock_boto_client.assert_called_once_with(
+ 's3',
+ aws_access_key_id=aws_access_key_id,
+ aws_secret_access_key=aws_secret_access_key,
+ region_name=S3_REGION,
+ )
+ mock_parse_s3_path.assert_called_once_with(mock_output_destination_aws)
+ mock_get_latest_run_from_file.assert_called_once_with(
+ 's3_client', 'bucket', 'prefix_BENCHMARK_DATES.json'
+ )
+
+
+@patch('sdgym.run_benchmark.upload_benchmark_results.SDGymResultsExplorer')
+@patch('sdgym.run_benchmark.upload_benchmark_results.write_uploaded_marker')
+@patch('sdgym.run_benchmark.upload_benchmark_results.LOGGER')
+@patch('sdgym.run_benchmark.upload_benchmark_results.OUTPUT_DESTINATION_AWS')
+@patch('sdgym.run_benchmark.upload_benchmark_results.LocalResultsWriter')
+@patch('sdgym.run_benchmark.upload_benchmark_results.os.environ.get')
+@patch('sdgym.run_benchmark.upload_benchmark_results.get_df_to_plot')
+@patch('sdgym.run_benchmark.upload_benchmark_results.generate_graph')
+@patch('sdgym.run_benchmark.upload_benchmark_results.upload_to_drive')
+def test_upload_results(
+ mock_upload_to_drive,
+ mock_generate_graph,
+ mock_get_df_to_plot,
+ mock_os_environ_get,
+ mock_local_results_writer,
+ mock_output_destination_aws,
+ mock_logger,
+ mock_write_uploaded_marker,
+ mock_sdgym_results_explorer,
+):
+ """Test the `upload_results` method."""
+ # Setup
+ aws_access_key_id = 'my_access_key'
+ aws_secret_access_key = 'my_secret_key'
+ folder_infos = {'folder_name': 'SDGym_results_10_01_2023', 'date': '10_01_2023'}
+ run_name = folder_infos['folder_name']
+ s3_client = Mock()
+ bucket = 'bucket'
+ prefix = 'prefix'
+ result_explorer_instance = mock_sdgym_results_explorer.return_value
+ result_explorer_instance.all_runs_complete.return_value = True
+ result_explorer_instance.summarize.return_value = ('summary', 'results')
+ mock_os_environ_get.return_value = '/tmp/sdgym_results'
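+ # Local directory where the xlsx summary is expected to be written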
+ mock_get_df_to_plot.return_value = 'df_to_plot'
+ mock_generate_graph.return_value = 'plot_image'
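+ # Sheet name to content mapping expected to be passed to write_xlsx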
+ datas = {
+ 'Wins': 'summary',
+ '10_01_2023_Detailed_results': 'results',
+ '10_01_2023_plot_data': 'df_to_plot',
+ '10_01_2023_plot_image': 'plot_image',
+ }
+
+ # Run
+ upload_results(
+ aws_access_key_id,
+ aws_secret_access_key,
+ folder_infos,
+ s3_client,
+ bucket,
+ prefix,
+ github_env=None,
+ )
+
+ # Assert
+ mock_upload_to_drive.assert_called_once_with(
+ '/tmp/sdgym_results/SDGym Monthly Run.xlsx', SDGYM_FILE_ID
+ )
+ mock_generate_graph.assert_called_once()
+ mock_logger.info.assert_called_once_with(
+ f'Run {run_name} is complete! Proceeding with summarization...'
+ )
+ mock_sdgym_results_explorer.assert_called_once_with(
+ mock_output_destination_aws,
+ aws_access_key_id=aws_access_key_id,
+ aws_secret_access_key=aws_secret_access_key,
+ )
+ result_explorer_instance.all_runs_complete.assert_called_once_with(run_name)
+ result_explorer_instance.summarize.assert_called_once_with(run_name)
+ mock_write_uploaded_marker.assert_called_once_with(s3_client, bucket, prefix, run_name)
+ mock_local_results_writer.return_value.write_xlsx.assert_called_once_with(
+ datas, '/tmp/sdgym_results/SDGym Monthly Run.xlsx'
+ )
+ mock_get_df_to_plot.assert_called_once_with('results')
+
+
+@patch('sdgym.run_benchmark.upload_benchmark_results.SDGymResultsExplorer')
+@patch('sdgym.run_benchmark.upload_benchmark_results.write_uploaded_marker')
+@patch('sdgym.run_benchmark.upload_benchmark_results.LOGGER')
+@patch('sdgym.run_benchmark.upload_benchmark_results.OUTPUT_DESTINATION_AWS')
+def test_upload_results_not_all_runs_complete(
+ mock_output_destination_aws,
+ mock_logger,
+ mock_write_uploaded_marker,
+ mock_sdgym_results_explorer,
+):
+ """Test the `upload_results` when not all runs are complete."""
+ # Setup
+ aws_access_key_id = 'my_access_key'
+ aws_secret_access_key = 'my_secret_key'
+ folder_infos = {'folder_name': 'SDGym_results_10_01_2023', 'date': '10_01_2023'}
+ run_name = folder_infos['folder_name']
+ s3_client = Mock()
+ bucket = 'bucket'
+ prefix = 'prefix'
+ result_explorer_instance = mock_sdgym_results_explorer.return_value
+ result_explorer_instance.all_runs_complete.return_value = False
+ result_explorer_instance.summarize.return_value = ('summary', 'results')
+
+ # Run
+ with pytest.raises(SystemExit, match='0'):
+ upload_results(
+ aws_access_key_id,
+ aws_secret_access_key,
+ folder_infos,
+ s3_client,
+ bucket,
+ prefix,
+ github_env=None,
+ )
+
+ # Assert
+ mock_logger.warning.assert_called_once_with(f'Run {run_name} is not complete yet. Exiting.')
+ mock_sdgym_results_explorer.assert_called_once_with(
+ mock_output_destination_aws,
+ aws_access_key_id=aws_access_key_id,
+ aws_secret_access_key=aws_secret_access_key,
+ )
+ result_explorer_instance.all_runs_complete.assert_called_once_with(run_name)
+ result_explorer_instance.summarize.assert_not_called()
+ mock_write_uploaded_marker.assert_not_called()
+
+
+@patch('sdgym.run_benchmark.upload_benchmark_results.get_result_folder_name_and_s3_vars')
+@patch('sdgym.run_benchmark.upload_benchmark_results.upload_results')
+@patch('sdgym.run_benchmark.upload_benchmark_results.upload_already_done')
+@patch('sdgym.run_benchmark.upload_benchmark_results.LOGGER')
+@patch('sdgym.run_benchmark.upload_benchmark_results.os.getenv')
+def test_main_already_upload(
+ mock_getenv,
+ mock_logger,
+ mock_upload_already_done,
+ mock_upload_results,
+ mock_get_result_folder_name_and_s3_vars,
+):
+ """Test the `method` when results are already uploaded."""
+ # Setup
+ mock_getenv.side_effect = ['my_access_key', 'my_secret_key', None]
+ folder_infos = {'folder_name': 'SDGym_results_10_01_2023', 'date': '10_01_2023'}
+ mock_get_result_folder_name_and_s3_vars.return_value = (
+ folder_infos,
+ 's3_client',
+ 'bucket',
+ 'prefix',
+ )
+ mock_upload_already_done.return_value = True
+ expected_log_message = 'Benchmark results have already been uploaded. Exiting.'
+
+ # Run
+ with pytest.raises(SystemExit, match='0'):
+ main()
+
+ # Assert
+ mock_get_result_folder_name_and_s3_vars.assert_called_once_with(
+ 'my_access_key', 'my_secret_key'
+ )
+ mock_logger.warning.assert_called_once_with(expected_log_message)
+ mock_upload_results.assert_not_called()
+
+
+@patch('sdgym.run_benchmark.upload_benchmark_results.get_result_folder_name_and_s3_vars')
+@patch('sdgym.run_benchmark.upload_benchmark_results.upload_results')
+@patch('sdgym.run_benchmark.upload_benchmark_results.upload_already_done')
+@patch('sdgym.run_benchmark.upload_benchmark_results.os.getenv')
+def test_main(
+ mock_getenv,
+ mock_upload_already_done,
+ mock_upload_results,
+ mock_get_result_folder_name_and_s3_vars,
+):
+ """Test the `main` method."""
+ # Setup
+ mock_getenv.side_effect = ['my_access_key', 'my_secret_key', None]
+ folder_infos = {'folder_name': 'SDGym_results_10_11_2024', 'date': '10_11_2024'}
+ mock_get_result_folder_name_and_s3_vars.return_value = (
+ folder_infos,
+ 's3_client',
+ 'bucket',
+ 'prefix',
+ )
+ mock_upload_already_done.return_value = False
+
+ # Run
+ main()
+
+ # Assert
+ mock_get_result_folder_name_and_s3_vars.assert_called_once_with(
+ 'my_access_key', 'my_secret_key'
+ )
+ mock_upload_already_done.assert_called_once_with(
+ 's3_client', 'bucket', 'prefix', folder_infos['folder_name']
+ )
+ mock_upload_results.assert_called_once_with(
+ 'my_access_key', 'my_secret_key', folder_infos, 's3_client', 'bucket', 'prefix', None
+ )
diff --git a/tests/unit/run_benchmark/test_utils.py b/tests/unit/run_benchmark/test_utils.py
new file mode 100644
index 00000000..7aa8c126
--- /dev/null
+++ b/tests/unit/run_benchmark/test_utils.py
@@ -0,0 +1,198 @@
+from unittest.mock import patch
+
+import pandas as pd
+import pytest
+
+from sdgym.run_benchmark.utils import (
+ DEBUG_SLACK_CHANNEL,
+ OUTPUT_DESTINATION_AWS,
+ _get_slack_client,
+ get_df_to_plot,
+ get_result_folder_name,
+ get_s3_console_link,
+ post_benchmark_launch_message,
+ post_benchmark_uploaded_message,
+ post_slack_message,
+)
+
+
+def test_get_result_folder_name():
+ """Test the `get_result_folder_name` method."""
+ # Setup
+ expected_error_message = 'Invalid date format: invalid-date. Expected YYYY-MM-DD.'
+
+ # Run and Assert
+ assert get_result_folder_name('2023-10-01') == 'SDGym_results_10_01_2023'
+ with pytest.raises(ValueError, match=expected_error_message):
+ get_result_folder_name('invalid-date')
+
+
+def test_get_s3_console_link():
+ """Test the `get_s3_console_link` method."""
+ # Setup
+ bucket = 'my-bucket'
+ prefix = 'my-prefix/'
+
+ # Run
+ link = get_s3_console_link(bucket, prefix)
+
+ # Assert
+ expected_link = (
+ f'https://s3.console.aws.amazon.com/s3/buckets/{bucket}?prefix={prefix}&showversions=false'
+ )
+ assert link == expected_link
+
+
+@patch('sdgym.run_benchmark.utils.WebClient')
+@patch('sdgym.run_benchmark.utils.os.getenv')
+def test_get_slack_client(mock_getenv, mock_web_client):
+ """Test the `_get_slack_client` method."""
+ # Setup
+ mock_getenv.return_value = 'xoxb-test-token'
+
+ # Run
+ client = _get_slack_client()
+
+ # Assert
+ mock_getenv.assert_called_once_with('SLACK_TOKEN')
+ mock_web_client.assert_called_once_with(token='xoxb-test-token')
+ assert client is mock_web_client.return_value
+
+
+@patch('sdgym.run_benchmark.utils._get_slack_client')
+def test_post_slack_message(mock_get_slack_client):
+ """Test the `post_slack_message` method."""
+ # Setup
+ mock_slack_client = mock_get_slack_client.return_value
+ channel = 'test-channel'
+ text = 'Test message'
+
+ # Run
+ post_slack_message(channel, text)
+
+ # Assert
+ mock_get_slack_client.assert_called_once()
+ mock_slack_client.chat_postMessage.assert_called_once_with(channel=channel, text=text)
+
+
+@patch('sdgym.run_benchmark.utils.post_slack_message')
+@patch('sdgym.run_benchmark.utils.get_s3_console_link')
+@patch('sdgym.run_benchmark.utils.parse_s3_path')
+@patch('sdgym.run_benchmark.utils.get_result_folder_name')
+def test_post_benchmark_launch_message(
+ mock_get_result_folder_name,
+ mock_parse_s3_path,
+ mock_get_s3_console_link,
+ mock_post_slack_message,
+):
+ """Test the `post_benchmark_launch_message` method."""
+ # Setup
+ date_str = '2023-10-01'
+ folder_name = 'SDGym_results_10_01_2023'
+ mock_get_result_folder_name.return_value = folder_name
+ mock_parse_s3_path.return_value = ('my-bucket', 'my-prefix/')
+ url = 'https://s3.console.aws.amazon.com/'
+ mock_get_s3_console_link.return_value = url
+ expected_body = (
+ '🚀 SDGym benchmark has been launched! EC2 Instances are running. '
+ f'Intermediate results can be found <{url}|here>.\n'
+ )
+
+ # Run
+ post_benchmark_launch_message(date_str)
+
+ # Assert
+ mock_get_result_folder_name.assert_called_once_with(date_str)
+ mock_parse_s3_path.assert_called_once_with(OUTPUT_DESTINATION_AWS)
+ mock_get_s3_console_link.assert_called_once_with('my-bucket', f'my-prefix/{folder_name}/')
+ mock_post_slack_message.assert_called_once_with(DEBUG_SLACK_CHANNEL, expected_body)
+
+
+@patch('sdgym.run_benchmark.utils.post_slack_message')
+@patch('sdgym.run_benchmark.utils.get_s3_console_link')
+@patch('sdgym.run_benchmark.utils.parse_s3_path')
+def test_post_benchmark_uploaded_message(
+ mock_parse_s3_path,
+ mock_get_s3_console_link,
+ mock_post_slack_message,
+):
+ """Test the `post_benchmark_uploaded_message` method."""
+ # Setup
+ folder_name = 'SDGym_results_10_01_2023'
+ mock_parse_s3_path.return_value = ('my-bucket', 'my-prefix/')
+ url = 'https://s3.console.aws.amazon.com/'
+ mock_get_s3_console_link.return_value = url
+ expected_body = (
+ f'🤸🏻‍♀️ SDGym benchmark results for *{folder_name}* are available! 🏋️‍♀️\n'
+ f'Check the results <{url} |here>'
+ )
+
+ # Run
+ post_benchmark_uploaded_message(folder_name)
+
+ # Assert
+ mock_post_slack_message.assert_called_once_with(DEBUG_SLACK_CHANNEL, expected_body)
+ mock_parse_s3_path.assert_called_once_with(OUTPUT_DESTINATION_AWS)
+ mock_get_s3_console_link.assert_called_once_with(
+ 'my-bucket', f'my-prefix/{folder_name}/{folder_name}_summary.csv'
+ )
+
+
+@patch('sdgym.run_benchmark.utils.post_slack_message')
+@patch('sdgym.run_benchmark.utils.get_s3_console_link')
+@patch('sdgym.run_benchmark.utils.parse_s3_path')
+def test_post_benchmark_uploaded_message_with_commit(
+ mock_parse_s3_path,
+ mock_get_s3_console_link,
+ mock_post_slack_message,
+):
+ """Test the `post_benchmark_uploaded_message` with a commit URL."""
+ # Setup
+ folder_name = 'SDGym_results_10_01_2023'
+ commit_url = 'https://github.com/user/repo/pull/123'
+ mock_parse_s3_path.return_value = ('my-bucket', 'my-prefix/')
+ url = 'https://s3.console.aws.amazon.com/'
+ mock_get_s3_console_link.return_value = url
+ expected_body = (
+ f'🤸🏻‍♀️ SDGym benchmark results for *{folder_name}* are available! 🏋️‍♀️\n'
+ f'Check the results <{url} |here> '
+ f'or on GitHub: <{commit_url}|Commit Link>\n'
+ )
+
+ # Run
+ post_benchmark_uploaded_message(folder_name, commit_url)
+
+ # Assert
+ mock_post_slack_message.assert_called_once_with(DEBUG_SLACK_CHANNEL, expected_body)
+ mock_parse_s3_path.assert_called_once_with(OUTPUT_DESTINATION_AWS)
+ mock_get_s3_console_link.assert_called_once_with(
+ 'my-bucket', f'my-prefix/{folder_name}/{folder_name}_summary.csv'
+ )
+
+
+def test_get_df_to_plot():
+ """Test the `get_df_to_plot` method."""
+ # Setup
+ data = pd.DataFrame({
+ 'Synthesizer': (
+ ['GaussianCopulaSynthesizer'] * 2 + ['CTGANSynthesizer'] * 2 + ['TVAESynthesizer'] * 2
+ ),
+ 'Dataset': ['Dataset1', 'Dataset2'] * 3,
+ 'Train_Time': [1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
+ 'Sample_Time': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
+ 'Quality_Score': [0.8, 0.9, 0.7, 0.6, 0.5, 0.4],
+ })
+
+ # Run
+ result = get_df_to_plot(data)
+
+ # Assert
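+ # Aggregated_Time is Train_Time + Sample_Time summed per synthesizer (e.g. 1.0 + 2.0 + 0.1 + 0.2 = 3.3);
+ # Quality_Score is the mean across the two datasets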
+ expected_result = pd.DataFrame({
+ 'Synthesizer': ['GaussianCopula', 'CTGAN', 'TVAE'],
+ 'Aggregated_Time': [3.3, 7.7, 12.1],
+ 'Quality_Score': [0.85, 0.65, 0.45],
+ 'Log10 Aggregated_Time': [0.5185139398778875, 0.8864907251724818, 1.08278537031645],
+ 'Pareto': [True, False, False],
+ 'Color': ['#01E0C9', '#03AFF1', '#03AFF1'],
+ 'Marker': ['circle', 'square', 'diamond'],
+ })
+ pd.testing.assert_frame_equal(result, expected_result)
diff --git a/tests/unit/sdgym_result_explorer/test_result_explorer.py b/tests/unit/sdgym_result_explorer/test_result_explorer.py
index 3f64a78c..a9dd27bf 100644
--- a/tests/unit/sdgym_result_explorer/test_result_explorer.py
+++ b/tests/unit/sdgym_result_explorer/test_result_explorer.py
@@ -191,10 +191,13 @@ def test_load_real_data(self, mock_get_dataset_paths, mock_load_dataset, tmp_pat
# Assert
mock_get_dataset_paths.assert_called_once_with(
- datasets=[dataset_name], aws_key=None, aws_secret=None
+ datasets=[dataset_name], aws_access_key_id=None, aws_secret_access_key=None
)
mock_load_dataset.assert_called_once_with(
- 'single_table', 'path/to/adult/dataset', aws_key=None, aws_secret=None
+ 'single_table',
+ 'path/to/adult/dataset',
+ aws_access_key_id=None,
+ aws_secret_access_key=None,
)
pd.testing.assert_frame_equal(real_data, expected_data)
diff --git a/tests/unit/sdgym_result_explorer/test_result_handler.py b/tests/unit/sdgym_result_explorer/test_result_handler.py
index 96cbc4c8..4a3ef0cf 100644
--- a/tests/unit/sdgym_result_explorer/test_result_handler.py
+++ b/tests/unit/sdgym_result_explorer/test_result_handler.py
@@ -64,10 +64,9 @@ def test__get_summarize_table(self):
# Assert
expected_summary = pd.DataFrame({
- '07_15_2025 - # datasets: 3 - sdgym version: 0.9.0': [2, 1],
'Synthesizer': ['Synth1', 'Synth2'],
+ '07_15_2025 - # datasets: 3 - sdgym version: 0.9.0': [2, 1],
})
- expected_summary = expected_summary.set_index('Synthesizer')
pd.testing.assert_frame_equal(result, expected_summary)
def test_get_column_name_infos(self):
diff --git a/tests/unit/test_benchmark.py b/tests/unit/test_benchmark.py
index 33a846d2..4073d618 100644
--- a/tests/unit/test_benchmark.py
+++ b/tests/unit/test_benchmark.py
@@ -25,6 +25,7 @@
benchmark_single_table_aws,
)
from sdgym.result_writer import LocalResultsWriter
+from sdgym.s3 import S3_REGION
from sdgym.synthesizers import GaussianCopulaSynthesizer
@@ -371,21 +372,23 @@ def test__validate_output_destination(tmp_path):
@patch('sdgym.benchmark._validate_aws_inputs')
-def test__validate_output_destination_with_aws_keys(mock_validate):
+def test__validate_output_destination_with_aws_access_key_ids(mock_validate):
"""Test the `_validate_output_destination` function with AWS keys."""
# Setup
output_destination = 's3://my-bucket/path/to/file'
- aws_keys = {
+ aws_access_key_ids = {
'aws_access_key_id': 'mock_access_key',
'aws_secret_access_key': 'mock_secret_key',
}
# Run
- _validate_output_destination(output_destination, aws_keys)
+ _validate_output_destination(output_destination, aws_access_key_ids)
# Assert
mock_validate.assert_called_once_with(
- output_destination, aws_keys['aws_access_key_id'], aws_keys['aws_secret_access_key']
+ output_destination,
+ aws_access_key_ids['aws_access_key_id'],
+ aws_access_key_ids['aws_secret_access_key'],
)
@@ -542,9 +545,12 @@ def test_setup_output_destination_aws(mock_get_run_id_increment):
@patch('sdgym.benchmark.boto3.client')
@patch('sdgym.benchmark._check_write_permissions')
-def test_validate_aws_inputs_valid(mock_check_write_permissions, mock_boto3_client):
+@patch('sdgym.benchmark.Config')
+def test_validate_aws_inputs_valid(mock_config, mock_check_write_permissions, mock_boto3_client):
"""Test `_validate_aws_inputs` with valid inputs and credentials."""
# Setup
+ config_mock = Mock()
+ mock_config.return_value = config_mock
valid_url = 's3://my-bucket/some/path'
s3_client_mock = Mock()
mock_boto3_client.return_value = s3_client_mock
@@ -557,7 +563,11 @@ def test_validate_aws_inputs_valid(mock_check_write_permissions, mock_boto3_clie
# Assert
mock_boto3_client.assert_called_once_with(
- 's3', aws_access_key_id='AKIA...', aws_secret_access_key='SECRET'
+ 's3',
+ aws_access_key_id='AKIA...',
+ aws_secret_access_key='SECRET',
+ region_name=S3_REGION,
+ config=config_mock,
)
s3_client_mock.head_bucket.assert_called_once_with(Bucket='my-bucket')
mock_check_write_permissions.assert_called_once_with(s3_client_mock, 'my-bucket')
diff --git a/tests/unit/test_datasets.py b/tests/unit/test_datasets.py
index 1c9cdd04..f498d5f3 100644
--- a/tests/unit/test_datasets.py
+++ b/tests/unit/test_datasets.py
@@ -14,6 +14,7 @@
get_dataset_paths,
load_dataset,
)
+from sdgym.s3 import S3_REGION
class AnyConfigWith:
@@ -110,8 +111,8 @@ def test__download_dataset_private_bucket(boto3_mock, tmpdir):
modality = 'single_table'
dataset = 'my_dataset'
bucket = 's3://my_bucket'
- aws_key = 'my_key'
- aws_secret = 'my_secret'
+ aws_access_key_id = 'my_key'
+ aws_secret_access_key = 'my_secret'
bytesio = io.BytesIO()
with ZipFile(bytesio, mode='w') as zf:
@@ -130,13 +131,16 @@ def test__download_dataset_private_bucket(boto3_mock, tmpdir):
dataset,
datasets_path=str(tmpdir),
bucket=bucket,
- aws_key=aws_key,
- aws_secret=aws_secret,
+ aws_access_key_id=aws_access_key_id,
+ aws_secret_access_key=aws_secret_access_key,
)
# asserts
boto3_mock.client.assert_called_once_with(
- 's3', aws_access_key_id=aws_key, aws_secret_access_key=aws_secret
+ 's3',
+ aws_access_key_id=aws_access_key_id,
+ aws_secret_access_key=aws_secret_access_key,
+ region_name=S3_REGION,
)
s3_mock.get_object.assert_called_once_with(
Bucket='my_bucket', Key=f'{modality.upper()}/{dataset}.zip'
diff --git a/tests/unit/test_s3.py b/tests/unit/test_s3.py
index fff365ed..757653b3 100644
--- a/tests/unit/test_s3.py
+++ b/tests/unit/test_s3.py
@@ -9,6 +9,7 @@
from botocore.exceptions import NoCredentialsError
from sdgym.s3 import (
+ S3_REGION,
_get_s3_client,
_upload_dataframe_to_s3,
_upload_pickle_to_s3,
@@ -120,8 +121,8 @@ def test_write_file(tmpdir):
Input:
- contents of the local file
- path to the local file
- - aws_key is None
- - aws_secret is None
+ - aws_access_key_id is None
+ - aws_secret_access_key is None
Output:
- None
@@ -151,14 +152,14 @@ def test_write_file_s3(boto3_mock):
Input:
- contents of the s3 file
- path to the s3 file location
- - aws_key for aws authentication
- - aws_secret for aws authentication
+ - aws_access_key_id for aws authentication
+ - aws_secret_access_key for aws authentication
Output:
- None
Side effects:
- - s3 client creation with aws credentials (aws_key, aws_secret)
+ - s3 client creation with aws credentials (aws_access_key_id, aws_secret_access_key)
- s3 method call to create a file in the given bucket with the
given contents
"""
@@ -167,18 +168,21 @@ def test_write_file_s3(boto3_mock):
bucket_name = 'my-bucket'
key = 'test.txt'
path = f's3://{bucket_name}/{key}'
- aws_key = 'my-key'
- aws_secret = 'my-secret'
+ aws_access_key_id = 'my-key'
+ aws_secret_access_key = 'my-secret'
s3_mock = Mock()
boto3_mock.client.return_value = s3_mock
# run
- write_file(content_str.encode('utf-8'), path, aws_key, aws_secret)
+ write_file(content_str.encode('utf-8'), path, aws_access_key_id, aws_secret_access_key)
# asserts
boto3_mock.client.assert_called_once_with(
- 's3', aws_access_key_id=aws_key, aws_secret_access_key=aws_secret
+ 's3',
+ aws_access_key_id=aws_access_key_id,
+ aws_secret_access_key=aws_secret_access_key,
+ region_name=S3_REGION,
)
s3_mock.put_object.assert_called_once_with(
Bucket=bucket_name,
@@ -199,8 +203,8 @@ def test_write_csv(write_file_mock):
Input:
- data to be written to the csv file
- path of the desired csv file
- - aws_key is None
- - aws_secret is None
+ - aws_access_key_id is None
+ - aws_secret_access_key is None
Output:
- None
@@ -307,6 +311,7 @@ def test__get_s3_client_with_credentials(mock_boto_client):
's3',
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
+ region_name=S3_REGION,
)
mock_s3_client.head_bucket.assert_called_once_with(Bucket='my-bucket')
diff --git a/tests/unit/test_summary.py b/tests/unit/test_summary.py
index b34d6fdd..650ec3ec 100644
--- a/tests/unit/test_summary.py
+++ b/tests/unit/test_summary.py
@@ -26,8 +26,8 @@ def test_make_summary_spreadsheet(
The ``make_summary_spreadsheet`` function is expected to extract the correct
columns from the input file and add them to the correct sheets. It should
- then use the ``aws_key`` and ``aws_secret`` provided to call ``sdgym.s3.write_file``
- and save the output document.
+ then use the ``aws_access_key_id`` and ``aws_secret_access_key`` provided to
+ call ``sdgym.s3.write_file`` and save the output document.
Input:
- file path to results csv.