Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/workflows/integration.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ jobs:
python -m pip install --upgrade pip
python -m pip install --no-cache-dir invoke .[test]
- name: Run integration tests
env:
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
run: invoke integration
- if: matrix.os == 'ubuntu-latest' && matrix.python-version == 3.14
name: Upload integration codecov report
Expand Down
3 changes: 3 additions & 0 deletions .github/workflows/minimum.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,7 @@ jobs:
python -m pip install --no-cache-dir invoke .[test]

- name: Test with minimum versions
env:
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
run: invoke minimum
6 changes: 2 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,8 @@ dependencies = [
'XlsxWriter>=1.2.8',
"rdt>=1.18.2;python_version<'3.14'",
"rdt>=1.20.0;python_version>='3.14'",
"sdmetrics>=0.21.0;python_version<'3.14'",
"sdmetrics>=0.26.0;python_version>='3.14'",
"sdv>=1.21.0;python_version<'3.14'",
"sdv>=1.33.0;python_version>='3.14'",
"sdmetrics>=0.28.0",
"sdv @ git+https://github.com/sdv-dev/SDV.git@main",
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO: Replace this Git-based `sdv` dependency with a released version specifier once the next SDV release is published.

]

[project.urls]
Expand Down
146 changes: 101 additions & 45 deletions sdgym/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,14 @@
)
from sdmetrics.single_table import DCRBaselineProtection

from sdgym.datasets import _load_dataset_with_client, get_dataset_paths
from sdgym.datasets import (
SDV_DATASETS_PRIVATE_BUCKET,
SDV_DATASETS_PUBLIC_BUCKET,
_get_dataset_bucket_mapping,
_load_dataset_with_client,
_load_sdv_demo_dataset,
get_dataset_paths,
)
from sdgym.errors import BenchmarkError, SDGymError
from sdgym.metrics import get_metrics
from sdgym.progress import TqdmLogger
Expand Down Expand Up @@ -123,6 +130,14 @@ class JobArgs(NamedTuple):
output_directions: Optional[dict]


class ResolvedDataset(NamedTuple):
    """Resolved dataset data and metadata for benchmark job creation.

    Bundles a dataset's name with its already-loaded data and metadata so the
    three values can be passed around together when benchmark jobs are built.
    """

    # Name of the dataset; used to label the dataset in job arguments and output paths.
    name: str
    # Loaded data, as returned by the dataset loader.
    data: Any
    # Metadata returned by the dataset loader alongside the data.
    metadata: Any


def _import_and_validate_synthesizers(synthesizers, custom_synthesizers, modality):
"""Import user-provided synthesizer and validate modality and uniqueness.

Expand Down Expand Up @@ -323,29 +338,14 @@ def _setup_output_destination(
return paths


def _generate_job_args_list(
limit_dataset_size,
def _resolve_dataset(
modality,
sdv_datasets,
additional_datasets_folder,
sdmetrics,
timeout,
output_destination,
compute_quality_score,
compute_diagnostic_score,
compute_privacy_score,
synthesizers,
s3_client,
modality,
limit_dataset_size,
s3_client=None,
):
sdv_datasets = (
[]
if sdv_datasets is None
else get_dataset_paths(
modality=modality,
datasets=sdv_datasets,
s3_client=s3_client,
)
)
sdv_dataset_names = [] if sdv_datasets is None else sdv_datasets
additional_datasets = (
[]
if additional_datasets_folder is None
Expand All @@ -359,13 +359,76 @@ def _generate_job_args_list(
s3_client=s3_client,
)
)
datasets = sdv_datasets + additional_datasets

dataset_bucket_mapping = None
if sdv_dataset_names:
dataset_bucket_mapping = _get_dataset_bucket_mapping(
modality,
[SDV_DATASETS_PUBLIC_BUCKET, SDV_DATASETS_PRIVATE_BUCKET],
s3_client,
skip_inaccessible=True,
)
missing_names = [name for name in sdv_dataset_names if name not in dataset_bucket_mapping]
if missing_names:
missing_to_print = "', '".join(missing_names)
raise ValueError(
f'The following SDV demo datasets were not found in the expected buckets: '
f"'{missing_to_print}'. Please check that the dataset names are correct."
)

datasets = []
for dataset_name in sdv_dataset_names:
data, metadata = _load_sdv_demo_dataset(
modality=modality,
dataset_name=dataset_name,
bucket=dataset_bucket_mapping.get(dataset_name),
s3_client=s3_client,
limit_dataset_size=limit_dataset_size,
)
datasets.append(ResolvedDataset(dataset_name, data, metadata))

for dataset in additional_datasets:
data, metadata = _load_dataset_with_client(
modality,
dataset,
limit_dataset_size=limit_dataset_size,
s3_client=s3_client,
)
datasets.append(ResolvedDataset(dataset.name, data, metadata))

return datasets


def _generate_job_args_list(
limit_dataset_size,
sdv_datasets,
additional_datasets_folder,
sdmetrics,
timeout,
output_destination,
compute_quality_score,
compute_diagnostic_score,
compute_privacy_score,
synthesizers,
s3_client,
modality,
):
if not synthesizers:
return []

datasets = _resolve_dataset(
modality=modality,
sdv_datasets=sdv_datasets,
additional_datasets_folder=additional_datasets_folder,
limit_dataset_size=limit_dataset_size,
s3_client=s3_client,
)
synthesizer_names = [synthesizer['name'] for synthesizer in synthesizers]
dataset_names = [dataset.name for dataset in datasets]
paths = _setup_output_destination(
output_destination, synthesizer_names, dataset_names, modality=modality, s3_client=s3_client
)
job_tuples = []
job_args_list = []
for dataset in datasets:
for synthesizer in synthesizers:
if paths:
Expand All @@ -377,29 +440,22 @@ def _generate_job_args_list(
final_name = synthesizer['name']

synthesizer['name'] = final_name
job_tuples.append((synthesizer, dataset))

job_args_list = []
for synthesizer, dataset in job_tuples:
data, metadata_dict = _load_dataset_with_client(
modality, dataset, limit_dataset_size=limit_dataset_size, s3_client=s3_client
)
path = paths.get(dataset.name, {}).get(synthesizer['name'], None)
job_args_list.append(
JobArgs(
synthesizer=synthesizer,
data=data,
metadata=metadata_dict,
metrics=sdmetrics,
timeout=timeout,
compute_quality_score=compute_quality_score,
compute_diagnostic_score=compute_diagnostic_score,
compute_privacy_score=compute_privacy_score,
dataset_name=dataset.name,
modality=modality,
output_directions=path,
path = paths.get(dataset.name, {}).get(synthesizer['name'], None)
job_args_list.append(
JobArgs(
synthesizer=synthesizer,
data=dataset.data,
metadata=dataset.metadata,
metrics=sdmetrics,
timeout=timeout,
compute_quality_score=compute_quality_score,
compute_diagnostic_score=compute_diagnostic_score,
compute_privacy_score=compute_privacy_score,
dataset_name=dataset.name,
modality=modality,
output_directions=path,
)
)
)

return job_args_list

Expand Down
93 changes: 93 additions & 0 deletions sdgym/datasets.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,23 @@
"""SDGym module to handle datasets."""

import io
import logging
import os
from pathlib import Path

import appdirs
import botocore
import numpy as np
import pandas as pd
from sdv.datasets.demo import (
_find_data_zip_key,
_get_data_from_bucket,
_get_first_v1_metadata_bytes,
_get_metadata,
_list_objects,
_load_data_from_zip,
download_demo,
)

from sdgym._dataset_utils import (
_get_dataset_subset,
Expand Down Expand Up @@ -251,6 +262,88 @@ def _get_available_datasets(
return pd.DataFrame(datasets_info)


def _get_dataset_bucket_mapping(modality, buckets, s3_client, skip_inaccessible=False):
    """Build a lookup from SDV demo dataset name to the bucket to load it from.

    Buckets are scanned in the order given. A dataset keeps the first bucket it
    is found in, except that the private SDV bucket always takes over the entry
    when it also lists the dataset.

    Args:
        modality (str):
            Modality whose datasets should be listed.
        buckets (list):
            Buckets to scan, in order.
        s3_client:
            S3 client forwarded to the dataset listing helper.
        skip_inaccessible (bool):
            If ``True``, a bucket whose listing fails with a botocore error is
            logged and skipped. If ``False``, a ``ValueError`` is raised instead.
            Defaults to ``False``.

    Returns:
        dict:
            Mapping of dataset name to bucket.

    Raises:
        ValueError:
            If a bucket cannot be listed and ``skip_inaccessible`` is ``False``.
    """
    mapping = {}
    for bucket in buckets:
        try:
            listing = _get_available_datasets(modality, bucket=bucket, s3_client=s3_client)
        except (botocore.exceptions.BotoCoreError, botocore.exceptions.ClientError) as error:
            if not skip_inaccessible:
                raise ValueError(
                    f"Bucket '{bucket}' is not accessible with the provided credentials."
                ) from error

            LOGGER.info("Skipping inaccessible bucket '%s': %s", bucket, error)
            continue

        is_private = bucket == SDV_DATASETS_PRIVATE_BUCKET
        for name in listing['dataset_name'].tolist():
            # The private bucket always wins; any other bucket only fills gaps.
            if is_private or not mapping.get(name):
                mapping[name] = bucket

    return mapping


def _load_private_sdv_demo_dataset(modality, dataset_name, bucket, s3_client=None):
    """Fetch an SDV demo dataset from a private bucket using an SDGym S3 client.

    Args:
        modality (str):
            Modality of the dataset; also the first component of the object prefix.
        dataset_name (str):
            Name of the dataset; the second component of the object prefix.
        bucket (str):
            Bucket holding the dataset. It is normalized with ``_get_bucket_name``.
        s3_client:
            Client used for every bucket request. When falsy, one is created with
            ``get_s3_client``. Defaults to ``None``.

    Returns:
        tuple:
            ``(data, metadata)`` where ``metadata`` is the result of ``to_dict()``
            on the loaded metadata. For any modality other than ``'multi_table'``,
            ``data`` is the value of a single entry popped from the loaded tables.
    """
    resolved_bucket = _get_bucket_name(bucket)
    client = s3_client or get_s3_client()
    prefix = f'{modality}/{dataset_name}/'
    location = {'bucket': resolved_bucket, 'client': client}

    # Request order is kept: listing, data archive, then metadata bytes.
    objects = _list_objects(prefix, **location)
    zip_key = _find_data_zip_key(objects, prefix, resolved_bucket)
    zip_buffer = io.BytesIO(_get_data_from_bucket(zip_key, **location))
    raw_metadata = _get_first_v1_metadata_bytes(objects, prefix, **location)

    tables = _load_data_from_zip(zip_buffer, resolved_bucket, dataset_name)
    if modality == 'multi_table':
        data = tables
    else:
        _, data = tables.popitem()

    metadata = _get_metadata(raw_metadata, dataset_name)
    return data, metadata.to_dict()


def _load_sdv_demo_dataset(
    modality,
    dataset_name,
    bucket,
    s3_client=None,
    limit_dataset_size=False,
):
    """Load an SDV demo dataset from the bucket it was resolved to.

    ``download_demo`` is tried first. If it raises ``ValueError`` and the bucket is
    the private SDV bucket, the dataset is loaded through
    ``_load_private_sdv_demo_dataset`` instead; for any other bucket the error
    propagates.

    Args:
        modality (str):
            Modality of the dataset. Validated before anything is loaded.
        dataset_name (str):
            Name of the dataset to load.
        bucket (str):
            Bucket the dataset was resolved to.
        s3_client:
            S3 client used only by the private-bucket fallback. Defaults to ``None``.
        limit_dataset_size (bool):
            If ``True``, the loaded data and metadata are passed through
            ``_get_dataset_subset``. Defaults to ``False``.

    Returns:
        tuple:
            ``(data, metadata)`` with the metadata as a dictionary.

    Raises:
        ValueError:
            If loading fails with a ``ValueError`` and the bucket is not the
            private SDV bucket.
    """
    _validate_modality(modality)
    resolved_bucket = _get_bucket_name(bucket)
    try:
        demo_data, demo_metadata = download_demo(
            modality=modality,
            dataset_name=dataset_name,
            s3_bucket_name=resolved_bucket,
        )
        data, metadata = demo_data, demo_metadata.to_dict()
    except ValueError:
        if bucket == SDV_DATASETS_PRIVATE_BUCKET:
            # Fall back to loading with the SDGym S3 client.
            data, metadata = _load_private_sdv_demo_dataset(
                modality,
                dataset_name,
                bucket,
                s3_client=s3_client,
            )
        else:
            raise

    if not limit_dataset_size:
        return data, metadata

    data, metadata = _get_dataset_subset(data, metadata, modality=modality)
    return data, metadata


def load_dataset(
modality,
dataset,
Expand Down
24 changes: 24 additions & 0 deletions tests/integration/test_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -929,6 +929,30 @@ def test_benchmark_multi_table_basic_synthesizers():
]


@pytest.mark.skipif(
    not (os.getenv('AWS_ACCESS_KEY_ID') and os.getenv('AWS_SECRET_ACCESS_KEY')),
    reason='MovieLens benchmark requires AWS credentials for private dataset access.',
)
def test_benchmark_multi_table_private_dataset():
    """Run the multi-table benchmark on the private ``MovieLens`` dataset.

    Skipped unless both AWS credential environment variables are set.
    """
    # Setup
    expected_datasets = ['MovieLens', 'MovieLens']
    expected_synthesizers = ['HMASynthesizer', 'MultiTableUniformSynthesizer']
    expected_quality_scores = [None, None]

    # Run
    result = benchmark_multi_table(
        synthesizers=['HMASynthesizer'],
        sdv_datasets=['MovieLens'],
        timeout=10,
    )

    # Assert
    assert result['Dataset'].tolist() == expected_datasets
    assert result['Synthesizer'].tolist() == expected_synthesizers
    assert result['Quality_Score'].tolist() == expected_quality_scores


def test_benchmark_multi_table_with_output_destination_multiple_runs(tmp_path):
"""Test saving in ``output_destination`` with multiple runs in multi-table mode.

Expand Down
Loading
Loading