From 82ecd6ae3917e4f3c7380dbcc14896491674d121 Mon Sep 17 00:00:00 2001 From: R-Palazzo Date: Mon, 25 May 2026 11:02:33 +0100 Subject: [PATCH 1/6] load dataset only once --- sdgym/benchmark.py | 37 ++++++++++----------- tests/unit/test_benchmark.py | 63 ++++++++++++++++++++++++++++++++++++ 2 files changed, 82 insertions(+), 18 deletions(-) diff --git a/sdgym/benchmark.py b/sdgym/benchmark.py index 30d8e608..b4665d7d 100644 --- a/sdgym/benchmark.py +++ b/sdgym/benchmark.py @@ -365,7 +365,7 @@ def _generate_job_args_list( paths = _setup_output_destination( output_destination, synthesizer_names, dataset_names, modality=modality, s3_client=s3_client ) - job_tuples = [] + job_tuples_by_dataset = defaultdict(list) for dataset in datasets: for synthesizer in synthesizers: if paths: @@ -377,29 +377,30 @@ def _generate_job_args_list( final_name = synthesizer['name'] synthesizer['name'] = final_name - job_tuples.append((synthesizer, dataset)) + job_tuples_by_dataset[dataset].append(synthesizer) job_args_list = [] - for synthesizer, dataset in job_tuples: + for dataset, synthesizers in job_tuples_by_dataset.items(): data, metadata_dict = _load_dataset_with_client( modality, dataset, limit_dataset_size=limit_dataset_size, s3_client=s3_client ) - path = paths.get(dataset.name, {}).get(synthesizer['name'], None) - job_args_list.append( - JobArgs( - synthesizer=synthesizer, - data=data, - metadata=metadata_dict, - metrics=sdmetrics, - timeout=timeout, - compute_quality_score=compute_quality_score, - compute_diagnostic_score=compute_diagnostic_score, - compute_privacy_score=compute_privacy_score, - dataset_name=dataset.name, - modality=modality, - output_directions=path, + for synthesizer in synthesizers: + path = paths.get(dataset.name, {}).get(synthesizer['name'], None) + job_args_list.append( + JobArgs( + synthesizer=synthesizer, + data=data, + metadata=metadata_dict, + metrics=sdmetrics, + timeout=timeout, + compute_quality_score=compute_quality_score, + 
compute_diagnostic_score=compute_diagnostic_score, + compute_privacy_score=compute_privacy_score, + dataset_name=dataset.name, + modality=modality, + output_directions=path, + ) ) - ) return job_args_list diff --git a/tests/unit/test_benchmark.py b/tests/unit/test_benchmark.py index 9c8065af..f020637b 100644 --- a/tests/unit/test_benchmark.py +++ b/tests/unit/test_benchmark.py @@ -1097,6 +1097,69 @@ def test__generate_job_args_list_local_root_additional_folder( ) +@patch('sdgym.benchmark.get_dataset_paths') +@patch('sdgym.benchmark._setup_output_destination') +@patch('sdgym.benchmark._load_dataset_with_client') +def test__generate_job_args_list_loads_each_dataset_once( + mock_load_dataset, + mock__setup_output_destination, + mock_get_dataset_paths, +): + """Test that each dataset is loaded once even when there are multiple synthesizers.""" + # Setup + dataset_a = Path('/dummy/single_table/datasetA') + dataset_b = Path('/dummy/single_table/datasetB') + mock_get_dataset_paths.return_value = [dataset_a, dataset_b] + mock__setup_output_destination.return_value = {} + data_a = Mock(name='data_a') + metadata_a = Mock(name='metadata_a') + data_b = Mock(name='data_b') + metadata_b = Mock(name='metadata_b') + mock_load_dataset.side_effect = [(data_a, metadata_a), (data_b, metadata_b)] + synthesizers = [ + {'name': 'GaussianCopulaSynthesizer'}, + {'name': 'UniformSynthesizer'}, + ] + s3_client = Mock() + + # Run + job_args_list = _generate_job_args_list( + limit_dataset_size=True, + sdv_datasets=['datasetA', 'datasetB'], + additional_datasets_folder=None, + sdmetrics=None, + timeout=None, + output_destination=None, + compute_quality_score=False, + compute_diagnostic_score=False, + compute_privacy_score=False, + synthesizers=synthesizers, + s3_client=s3_client, + modality='single_table', + ) + + # Assert + mock_load_dataset.assert_has_calls([ + call('single_table', dataset_a, limit_dataset_size=True, s3_client=s3_client), + call('single_table', dataset_b, 
limit_dataset_size=True, s3_client=s3_client), + ]) + assert mock_load_dataset.call_count == 2 + assert len(job_args_list) == 4 + assert [job.dataset_name for job in job_args_list] == [ + 'datasetA', + 'datasetA', + 'datasetB', + 'datasetB', + ] + assert [job.data for job in job_args_list] == [data_a, data_a, data_b, data_b] + assert [job.metadata for job in job_args_list] == [ + metadata_a, + metadata_a, + metadata_b, + metadata_b, + ] + + @patch('sdgym.benchmark.get_dataset_paths') @patch('sdgym.benchmark._setup_output_destination') @patch('sdgym.benchmark._load_dataset_with_client') From 3ac74c39a33607157f739edf3d69be9951f85dea Mon Sep 17 00:00:00 2001 From: R-Palazzo Date: Mon, 25 May 2026 13:48:42 +0100 Subject: [PATCH 2/6] use download_demo from sdv --- .github/workflows/integration.yml | 3 + .github/workflows/minimum.yml | 3 + pyproject.toml | 3 +- sdgym/benchmark.py | 102 ++++++++++++---- sdgym/datasets.py | 117 ++++++++++++++++++ tests/integration/test_benchmark.py | 24 ++++ tests/unit/test_benchmark.py | 86 ++++++++++--- tests/unit/test_datasets.py | 179 ++++++++++++++++++++++++++++ tests/unit/test_s3.py | 8 +- 9 files changed, 484 insertions(+), 41 deletions(-) diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index e344e5bd..d475d6b6 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -28,6 +28,9 @@ jobs: python -m pip install --upgrade pip python -m pip install --no-cache-dir invoke .[test] - name: Run integration tests + env: + AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }} + AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }} run: invoke integration - if: matrix.os == 'ubuntu-latest' && matrix.python-version == 3.14 name: Upload integration codecov report diff --git a/.github/workflows/minimum.yml b/.github/workflows/minimum.yml index 7eaf8edc..8e143056 100644 --- a/.github/workflows/minimum.yml +++ b/.github/workflows/minimum.yml @@ -39,4 +39,7 @@ jobs: python -m pip 
install --no-cache-dir invoke .[test] - name: Test with minimum versions + env: + AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }} + AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }} run: invoke minimum diff --git a/pyproject.toml b/pyproject.toml index 46c6f0a9..f1c918cd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -66,8 +66,7 @@ dependencies = [ "rdt>=1.20.0;python_version>='3.14'", "sdmetrics>=0.21.0;python_version<'3.14'", "sdmetrics>=0.26.0;python_version>='3.14'", - "sdv>=1.21.0;python_version<'3.14'", - "sdv>=1.33.0;python_version>='3.14'", + "sdv @ git+https://github.com/sdv-dev/SDV.git@main", ] [project.urls] diff --git a/sdgym/benchmark.py b/sdgym/benchmark.py index b4665d7d..7fa7be98 100644 --- a/sdgym/benchmark.py +++ b/sdgym/benchmark.py @@ -40,7 +40,14 @@ ) from sdmetrics.single_table import DCRBaselineProtection -from sdgym.datasets import _load_dataset_with_client, get_dataset_paths +from sdgym.datasets import ( + SDV_DATASETS_PRIVATE_BUCKET, + SDV_DATASETS_PUBLIC_BUCKET, + _get_dataset_bucket_mapping, + _load_dataset_with_client, + _load_sdv_demo_dataset, + get_dataset_paths, +) from sdgym.errors import BenchmarkError, SDGymError from sdgym.metrics import get_metrics from sdgym.progress import TqdmLogger @@ -123,6 +130,14 @@ class JobArgs(NamedTuple): output_directions: Optional[dict] +class ResolvedDataset(NamedTuple): + """Resolved dataset data and metadata for benchmark job creation.""" + + name: str + data: Any + metadata: Any + + def _import_and_validate_synthesizers(synthesizers, custom_synthesizers, modality): """Import user-provided synthesizer and validate modality and uniqueness. 
@@ -323,6 +338,33 @@ def _setup_output_destination( return paths +def _resolve_dataset( + modality, + dataset, + limit_dataset_size, + source, + s3_client=None, + dataset_bucket_mapping=None, +): + if source == 'sdv_demo': + data, metadata = _load_sdv_demo_dataset( + modality=modality, + dataset_name=dataset, + dataset_bucket_mapping=dataset_bucket_mapping, + s3_client=s3_client, + limit_dataset_size=limit_dataset_size, + ) + return ResolvedDataset(dataset, data, metadata) + + data, metadata = _load_dataset_with_client( + modality, + dataset, + limit_dataset_size=limit_dataset_size, + s3_client=s3_client, + ) + return ResolvedDataset(dataset.name, data, metadata) + + def _generate_job_args_list( limit_dataset_size, sdv_datasets, @@ -337,15 +379,7 @@ def _generate_job_args_list( s3_client, modality, ): - sdv_datasets = ( - [] - if sdv_datasets is None - else get_dataset_paths( - modality=modality, - datasets=sdv_datasets, - s3_client=s3_client, - ) - ) + sdv_dataset_names = [] if sdv_datasets is None else sdv_datasets additional_datasets = ( [] if additional_datasets_folder is None @@ -359,13 +393,45 @@ def _generate_job_args_list( s3_client=s3_client, ) ) - datasets = sdv_datasets + additional_datasets + if not synthesizers: + return [] + + dataset_bucket_mapping = None + if sdv_dataset_names: + dataset_bucket_mapping = _get_dataset_bucket_mapping( + modality, + [SDV_DATASETS_PUBLIC_BUCKET, SDV_DATASETS_PRIVATE_BUCKET], + s3_client, + skip_inaccessible=True, + ) + + datasets = [ + _resolve_dataset( + modality=modality, + dataset=dataset, + limit_dataset_size=limit_dataset_size, + source='sdv_demo', + s3_client=s3_client, + dataset_bucket_mapping=dataset_bucket_mapping, + ) + for dataset in sdv_dataset_names + ] + datasets.extend( + _resolve_dataset( + modality=modality, + dataset=dataset, + limit_dataset_size=limit_dataset_size, + source='additional', + s3_client=s3_client, + ) + for dataset in additional_datasets + ) synthesizer_names = [synthesizer['name'] for 
synthesizer in synthesizers] dataset_names = [dataset.name for dataset in datasets] paths = _setup_output_destination( output_destination, synthesizer_names, dataset_names, modality=modality, s3_client=s3_client ) - job_tuples_by_dataset = defaultdict(list) + job_args_list = [] for dataset in datasets: for synthesizer in synthesizers: if paths: @@ -377,20 +443,12 @@ def _generate_job_args_list( final_name = synthesizer['name'] synthesizer['name'] = final_name - job_tuples_by_dataset[dataset].append(synthesizer) - - job_args_list = [] - for dataset, synthesizers in job_tuples_by_dataset.items(): - data, metadata_dict = _load_dataset_with_client( - modality, dataset, limit_dataset_size=limit_dataset_size, s3_client=s3_client - ) - for synthesizer in synthesizers: path = paths.get(dataset.name, {}).get(synthesizer['name'], None) job_args_list.append( JobArgs( synthesizer=synthesizer, - data=data, - metadata=metadata_dict, + data=dataset.data, + metadata=dataset.metadata, metrics=sdmetrics, timeout=timeout, compute_quality_score=compute_quality_score, diff --git a/sdgym/datasets.py b/sdgym/datasets.py index 8b923b11..6be54e48 100644 --- a/sdgym/datasets.py +++ b/sdgym/datasets.py @@ -1,12 +1,23 @@ """SDGym module to handle datasets.""" +import io import logging import os from pathlib import Path import appdirs +import botocore import numpy as np import pandas as pd +from sdv.datasets.demo import ( + _find_data_zip_key, + _get_data_from_bucket, + _get_first_v1_metadata_bytes, + _get_metadata, + _list_objects, + _load_data_from_zip, + download_demo, +) from sdgym._dataset_utils import ( _get_dataset_subset, @@ -35,6 +46,13 @@ def _get_bucket_name(bucket): return bucket[len(S3_PREFIX) :] if bucket.startswith(S3_PREFIX) else bucket +def _metadata_to_dict(metadata): + if isinstance(metadata, dict): + return metadata + + return metadata.to_dict() + + def _raise_dataset_not_found_error( s3_client, bucket_name, @@ -251,6 +269,105 @@ def _get_available_datasets( return 
pd.DataFrame(datasets_info) +def _get_dataset_bucket_mapping(modality, buckets, s3_client, skip_inaccessible=False): + """Map SDV demo dataset names to the bucket they should be loaded from.""" + dataset_buckets = {} + for bucket in buckets: + try: + available_datasets = _get_available_datasets( + modality, + bucket=bucket, + s3_client=s3_client, + ) + except (botocore.exceptions.BotoCoreError, botocore.exceptions.ClientError) as error: + if skip_inaccessible: + LOGGER.info("Skipping inaccessible bucket '%s': %s", bucket, error) + continue + + raise ValueError( + f"Bucket '{bucket}' is not accessible with the provided credentials." + ) from error + + for dataset_name in available_datasets['dataset_name'].tolist(): + existing_bucket = dataset_buckets.get(dataset_name) + if existing_bucket and bucket != SDV_DATASETS_PRIVATE_BUCKET: + continue + + dataset_buckets[dataset_name] = bucket + + return dataset_buckets + + +def _load_private_sdv_demo_dataset(modality, dataset_name, bucket, s3_client=None): + """Load an SDV demo dataset from a private bucket with an SDGym S3 client.""" + bucket_name = _get_bucket_name(bucket) + s3_client = s3_client or get_s3_client() + dataset_prefix = f'{modality}/{dataset_name}/' + contents = _list_objects(dataset_prefix, bucket=bucket_name, client=s3_client) + data_key = _find_data_zip_key(contents, dataset_prefix, bucket_name) + data_bytes = io.BytesIO(_get_data_from_bucket(data_key, bucket=bucket_name, client=s3_client)) + metadata_bytes = _get_first_v1_metadata_bytes( + contents, dataset_prefix, bucket=bucket_name, client=s3_client + ) + data = _load_data_from_zip(data_bytes, bucket_name, dataset_name) + if modality != 'multi_table': + data = data.popitem()[1] + + metadata = _get_metadata(metadata_bytes, dataset_name) + return data, _metadata_to_dict(metadata) + + +def _load_sdv_demo_dataset( + modality, + dataset_name, + dataset_bucket_mapping=None, + s3_client=None, + limit_dataset_size=False, +): + """Load an SDV demo dataset from 
the resolved public or private bucket.""" + _validate_modality(modality) + buckets = [SDV_DATASETS_PUBLIC_BUCKET, SDV_DATASETS_PRIVATE_BUCKET] + if dataset_bucket_mapping is None: + dataset_bucket_mapping = _get_dataset_bucket_mapping( + modality, + buckets, + s3_client or get_s3_client(), + skip_inaccessible=True, + ) + + bucket = dataset_bucket_mapping.get(dataset_name) + if bucket is None: + buckets_list = ', '.join(buckets) + raise ValueError( + f"Dataset '{dataset_name}' not found in SDV demo buckets for modality " + f"'{modality}'. Checked buckets: {buckets_list}." + ) + + bucket_name = _get_bucket_name(bucket) + try: + data, metadata = download_demo( + modality=modality, + dataset_name=dataset_name, + s3_bucket_name=bucket_name, + ) + metadata = _metadata_to_dict(metadata) + except ValueError: + if bucket != SDV_DATASETS_PRIVATE_BUCKET: + raise + + data, metadata = _load_private_sdv_demo_dataset( + modality, + dataset_name, + bucket, + s3_client=s3_client, + ) + + if limit_dataset_size: + data, metadata = _get_dataset_subset(data, metadata, modality=modality) + + return data, metadata + + def load_dataset( modality, dataset, diff --git a/tests/integration/test_benchmark.py b/tests/integration/test_benchmark.py index 0c7fe86a..cdf87685 100644 --- a/tests/integration/test_benchmark.py +++ b/tests/integration/test_benchmark.py @@ -929,6 +929,30 @@ def test_benchmark_multi_table_basic_synthesizers(): ] +@pytest.mark.skipif( + not os.getenv('AWS_ACCESS_KEY_ID') or not os.getenv('AWS_SECRET_ACCESS_KEY'), + reason='MovieLens benchmark requires AWS credentials for private dataset access.', +) +def test_benchmark_multi_table_private_dataset(): + """Test multi-table benchmark with private dataset `MovieLens`.""" + # Setup + datasets = ['MovieLens'] + synthesizers = ['HMASynthesizer'] + timeout = 10 + + # Run + result = benchmark_multi_table( + synthesizers=synthesizers, + sdv_datasets=datasets, + timeout=timeout, + ) + + # Assert + assert result['Dataset'].tolist() == 
['MovieLens', 'MovieLens'] + assert result['Synthesizer'].tolist() == ['HMASynthesizer', 'MultiTableUniformSynthesizer'] + assert result['Quality_Score'].tolist() == [None, None] + + def test_benchmark_multi_table_with_output_destination_multiple_runs(tmp_path): """Test saving in ``output_destination`` with multiple runs in multi-table mode. diff --git a/tests/unit/test_benchmark.py b/tests/unit/test_benchmark.py index f020637b..2e3ee068 100644 --- a/tests/unit/test_benchmark.py +++ b/tests/unit/test_benchmark.py @@ -202,10 +202,18 @@ def test__get_metainfo_increment_local(mock_logger, tmp_path): assert result_3 == 3 +@patch('sdgym.benchmark._get_dataset_bucket_mapping') +@patch('sdgym.benchmark._load_sdv_demo_dataset') @patch('sdgym.benchmark.tqdm.tqdm') -def test_benchmark_single_table_progress_bar(tqdm_mock): +def test_benchmark_single_table_progress_bar( + tqdm_mock, mock_load_sdv_demo_dataset, mock_get_dataset_bucket_mapping +): """Test that the benchmarking function updates the progress bar on one line.""" # Setup + data = pd.DataFrame({'column': [1, 2]}) + metadata = {'tables': {'student_placements': {'columns': {'column': {'sdtype': 'numerical'}}}}} + mock_get_dataset_bucket_mapping.return_value = {'student_placements': 'bucket'} + mock_load_sdv_demo_dataset.return_value = data, metadata scores_mock = MagicMock() scores_mock.__iter__.return_value = [ pd.DataFrame({ @@ -228,11 +236,22 @@ def test_benchmark_single_table_progress_bar(tqdm_mock): tqdm_mock.assert_called_once_with(ANY, total=1, position=0, leave=True) +@patch('sdgym.benchmark._get_dataset_bucket_mapping') +@patch('sdgym.benchmark._load_sdv_demo_dataset') @patch('sdgym.benchmark._score') @patch('sdgym.benchmark.multiprocessing') -def test_benchmark_single_table_with_timeout(mock_multiprocessing, mock__score): +def test_benchmark_single_table_with_timeout( + mock_multiprocessing, + mock__score, + mock_load_sdv_demo_dataset, + mock_get_dataset_bucket_mapping, +): """Test that benchmark runs with 
timeout.""" # Setup + data = pd.DataFrame({'column': [1, 2]}) + metadata = {'tables': {'student_placements': {'columns': {'column': {'sdtype': 'numerical'}}}}} + mock_get_dataset_bucket_mapping.return_value = {'student_placements': 'bucket'} + mock_load_sdv_demo_dataset.return_value = data, metadata mocked_process = mock_multiprocessing.Process.return_value manager = mock_multiprocessing.Manager.return_value manager_dict = {'timeout': True, 'Error': 'Synthesizer Timeout'} @@ -1097,25 +1116,26 @@ def test__generate_job_args_list_local_root_additional_folder( ) -@patch('sdgym.benchmark.get_dataset_paths') +@patch('sdgym.benchmark._get_dataset_bucket_mapping') +@patch('sdgym.benchmark._load_sdv_demo_dataset') @patch('sdgym.benchmark._setup_output_destination') -@patch('sdgym.benchmark._load_dataset_with_client') +@patch('sdgym.benchmark.get_dataset_paths') def test__generate_job_args_list_loads_each_dataset_once( - mock_load_dataset, - mock__setup_output_destination, mock_get_dataset_paths, + mock__setup_output_destination, + mock_load_sdv_demo_dataset, + mock_get_dataset_bucket_mapping, ): """Test that each dataset is loaded once even when there are multiple synthesizers.""" # Setup - dataset_a = Path('/dummy/single_table/datasetA') - dataset_b = Path('/dummy/single_table/datasetB') - mock_get_dataset_paths.return_value = [dataset_a, dataset_b] + mock_get_dataset_paths.return_value = [] + mock_get_dataset_bucket_mapping.return_value = {'datasetA': 'bucket', 'datasetB': 'bucket'} mock__setup_output_destination.return_value = {} data_a = Mock(name='data_a') metadata_a = Mock(name='metadata_a') data_b = Mock(name='data_b') metadata_b = Mock(name='metadata_b') - mock_load_dataset.side_effect = [(data_a, metadata_a), (data_b, metadata_b)] + mock_load_sdv_demo_dataset.side_effect = [(data_a, metadata_a), (data_b, metadata_b)] synthesizers = [ {'name': 'GaussianCopulaSynthesizer'}, {'name': 'UniformSynthesizer'}, @@ -1139,11 +1159,24 @@ def 
test__generate_job_args_list_loads_each_dataset_once( ) # Assert - mock_load_dataset.assert_has_calls([ - call('single_table', dataset_a, limit_dataset_size=True, s3_client=s3_client), - call('single_table', dataset_b, limit_dataset_size=True, s3_client=s3_client), + mock_get_dataset_paths.assert_not_called() + mock_load_sdv_demo_dataset.assert_has_calls([ + call( + modality='single_table', + dataset_name='datasetA', + dataset_bucket_mapping={'datasetA': 'bucket', 'datasetB': 'bucket'}, + s3_client=s3_client, + limit_dataset_size=True, + ), + call( + modality='single_table', + dataset_name='datasetB', + dataset_bucket_mapping={'datasetA': 'bucket', 'datasetB': 'bucket'}, + s3_client=s3_client, + limit_dataset_size=True, + ), ]) - assert mock_load_dataset.call_count == 2 + assert mock_load_sdv_demo_dataset.call_count == 2 assert len(job_args_list) == 4 assert [job.dataset_name for job in job_args_list] == [ 'datasetA', @@ -1175,7 +1208,10 @@ def test__generate_job_args_list_s3_root_additional_folder( get_dataset_paths_mock.return_value = [dataset_path] s3_client = Mock() mock__setup_output_destination.return_value = {} - mock_load_dataset.return_value = pd.DataFrame({'col1': [1, 2], 'col2': [3, 4]}) + mock_load_dataset.return_value = ( + pd.DataFrame({'col1': [1, 2], 'col2': [3, 4]}), + {'tables': {}}, + ) # Run _generate_job_args_list( @@ -1211,9 +1247,27 @@ def test__generate_job_args_list_s3_root_additional_folder( ) -def test_benchmark_single_table_no_warning_uniform_synthesizer(recwarn): +@patch('sdgym.benchmark._get_dataset_bucket_mapping') +@patch('sdgym.benchmark._load_sdv_demo_dataset') +def test_benchmark_single_table_no_warning_uniform_synthesizer( + mock_load_sdv_demo_dataset, mock_get_dataset_bucket_mapping, recwarn +): """Test that no UserWarning is raised when running `UniformSynthesizer`.""" # Setup + data = pd.DataFrame({'column': [1, 2, 3]}) + metadata = { + 'tables': { + 'fake_hotel_guests': { + 'columns': { + 'column': { + 'sdtype': 'numerical', 
+ } + } + } + } + } + mock_get_dataset_bucket_mapping.return_value = {'fake_hotel_guests': 'bucket'} + mock_load_sdv_demo_dataset.return_value = data, metadata expected_result = pd.DataFrame({ 'Synthesizer': {0: 'UniformSynthesizer'}, 'Dataset': {0: 'fake_hotel_guests'}, diff --git a/tests/unit/test_datasets.py b/tests/unit/test_datasets.py index f03c7d00..c770846b 100644 --- a/tests/unit/test_datasets.py +++ b/tests/unit/test_datasets.py @@ -1,16 +1,22 @@ from pathlib import Path from unittest.mock import Mock, call, patch +import botocore import numpy as np +import pandas as pd import pytest from sdgym.datasets import ( DATASETS_PATH, + SDV_DATASETS_PRIVATE_BUCKET, + SDV_DATASETS_PUBLIC_BUCKET, _download_dataset, _genereate_dataset_info, _get_bucket_name, + _get_dataset_bucket_mapping, _get_dataset_path_and_download, _load_dataset_with_client, + _load_sdv_demo_dataset, _path_contains_data_and_metadata, _validate_modality, get_data_and_metadata_from_path, @@ -362,6 +368,179 @@ def test_get_bucket_name_local_folder(): assert bucket_name == 'bucket-name' +@patch('sdgym.datasets._get_available_datasets') +def test__get_dataset_bucket_mapping_prefers_private(get_available_mock): + """Test that datasets are mapped to private when duplicated across buckets.""" + # Setup + get_available_mock.side_effect = [ + pd.DataFrame({'dataset_name': ['public_only', 'duplicate']}), + pd.DataFrame({'dataset_name': ['private_only', 'duplicate']}), + ] + + # Run + result = _get_dataset_bucket_mapping( + 'single_table', + [SDV_DATASETS_PUBLIC_BUCKET, SDV_DATASETS_PRIVATE_BUCKET], + s3_client='s3_client', + ) + + # Assert + assert result == { + 'public_only': SDV_DATASETS_PUBLIC_BUCKET, + 'private_only': SDV_DATASETS_PRIVATE_BUCKET, + 'duplicate': SDV_DATASETS_PRIVATE_BUCKET, + } + get_available_mock.assert_has_calls([ + call('single_table', bucket=SDV_DATASETS_PUBLIC_BUCKET, s3_client='s3_client'), + call('single_table', bucket=SDV_DATASETS_PRIVATE_BUCKET, s3_client='s3_client'), + ]) + 
+ +@patch('sdgym.datasets._get_available_datasets') +def test__get_dataset_bucket_mapping_skips_inaccessible_bucket(get_available_mock): + """Test inaccessible buckets can be skipped while building the mapping.""" + # Setup + error = botocore.exceptions.ClientError( + {'Error': {'Code': 'AccessDenied', 'Message': 'denied'}}, + 'ListObjectsV2', + ) + get_available_mock.side_effect = [ + pd.DataFrame({'dataset_name': ['public_only']}), + error, + ] + + # Run + result = _get_dataset_bucket_mapping( + 'single_table', + [SDV_DATASETS_PUBLIC_BUCKET, SDV_DATASETS_PRIVATE_BUCKET], + s3_client='s3_client', + skip_inaccessible=True, + ) + + # Assert + assert result == {'public_only': SDV_DATASETS_PUBLIC_BUCKET} + + +@patch('sdgym.datasets._get_available_datasets') +def test__get_dataset_bucket_mapping_raises_inaccessible_bucket(get_available_mock): + """Test inaccessible buckets raise by default.""" + # Setup + get_available_mock.side_effect = botocore.exceptions.ClientError( + {'Error': {'Code': 'AccessDenied', 'Message': 'denied'}}, + 'ListObjectsV2', + ) + + # Run and Assert + with pytest.raises(ValueError, match="Bucket 's3://sdv-datasets-private' is not accessible"): + _get_dataset_bucket_mapping( + 'single_table', + [SDV_DATASETS_PRIVATE_BUCKET], + s3_client='s3_client', + ) + + +@patch('sdgym.datasets.download_demo') +def test__load_sdv_demo_dataset_uses_download_demo(download_demo_mock): + """Test SDV demo datasets are loaded through SDV's download_demo.""" + # Setup + data = pd.DataFrame({'column': [1, 2]}) + metadata = Mock() + metadata.to_dict.return_value = {'tables': {'demo': {'columns': {'column': {}}}}} + download_demo_mock.return_value = data, metadata + + # Run + result = _load_sdv_demo_dataset( + modality='single_table', + dataset_name='demo', + dataset_bucket_mapping={'demo': SDV_DATASETS_PUBLIC_BUCKET}, + ) + + # Assert + result_data, result_metadata = result + pd.testing.assert_frame_equal(result_data, data) + assert result_metadata == 
metadata.to_dict.return_value + download_demo_mock.assert_called_once_with( + modality='single_table', + dataset_name='demo', + s3_bucket_name='sdv-datasets-public', + ) + + +@patch('sdgym.datasets._load_private_sdv_demo_dataset') +@patch('sdgym.datasets.download_demo') +def test__load_sdv_demo_dataset_falls_back_for_private_bucket( + download_demo_mock, load_private_mock +): + """Test SDV private-bucket errors fall back to SDGym private loading.""" + # Setup + data = pd.DataFrame({'column': [1, 2]}) + metadata = {'tables': {'demo': {'columns': {'column': {}}}}} + download_demo_mock.side_effect = ValueError('Private buckets are only supported') + load_private_mock.return_value = data, metadata + + # Run + result = _load_sdv_demo_dataset( + modality='single_table', + dataset_name='demo', + dataset_bucket_mapping={'demo': SDV_DATASETS_PRIVATE_BUCKET}, + s3_client='s3_client', + ) + + # Assert + result_data, result_metadata = result + pd.testing.assert_frame_equal(result_data, data) + assert result_metadata == metadata + load_private_mock.assert_called_once_with( + 'single_table', + 'demo', + SDV_DATASETS_PRIVATE_BUCKET, + s3_client='s3_client', + ) + + +@patch('sdgym.datasets._get_dataset_subset') +@patch('sdgym.datasets.download_demo') +def test__load_sdv_demo_dataset_limits_dataset_size(download_demo_mock, subset_mock): + """Test SDV demo dataset loading applies the dataset size limit.""" + # Setup + data = pd.DataFrame({'column': [1, 2]}) + metadata = Mock() + metadata.to_dict.return_value = {'tables': {'demo': {'columns': {'column': {}}}}} + limited_data = pd.DataFrame({'column': [1]}) + limited_metadata = {'tables': {'demo': {'columns': {'column': {}}}}} + download_demo_mock.return_value = data, metadata + subset_mock.return_value = limited_data, limited_metadata + + # Run + result = _load_sdv_demo_dataset( + modality='single_table', + dataset_name='demo', + dataset_bucket_mapping={'demo': SDV_DATASETS_PUBLIC_BUCKET}, + limit_dataset_size=True, + ) + + # Assert 
+ result_data, result_metadata = result + pd.testing.assert_frame_equal(result_data, limited_data) + assert result_metadata == limited_metadata + subset_mock.assert_called_once_with( + data, + metadata.to_dict.return_value, + modality='single_table', + ) + + +def test__load_sdv_demo_dataset_raises_when_dataset_not_found(): + """Test a clear error is raised when a demo dataset is absent from all buckets.""" + # Run and Assert + with pytest.raises(ValueError, match="Dataset 'missing' not found in SDV demo buckets"): + _load_sdv_demo_dataset( + modality='single_table', + dataset_name='missing', + dataset_bucket_mapping={}, + ) + + @patch('sdgym.datasets._get_dataset_path_and_download') @patch('sdgym.datasets._path_contains_data_and_metadata', return_value=True) @patch('sdgym.datasets.Path') diff --git a/tests/unit/test_s3.py b/tests/unit/test_s3.py index 00bb7d8d..3dd0a20f 100644 --- a/tests/unit/test_s3.py +++ b/tests/unit/test_s3.py @@ -320,11 +320,14 @@ def test__get_s3_client_with_credentials(mock_boto_client): mock_s3_client.head_bucket.assert_called_once_with(Bucket='my-bucket') -def test__get_s3_client_errors(): +@patch('sdgym.s3.boto3.client') +def test__get_s3_client_errors(mock_boto_client): """Test `_get_s3_client` raises error for invalid input.""" # Setup output_destination = 's3:/' expected_error = re.escape(f'Invalid S3 URL: {output_destination}') + mock_s3_client = mock_boto_client.return_value + mock_s3_client.head_bucket.side_effect = NoCredentialsError() # Run and Assert with pytest.raises(ValueError, match=expected_error): @@ -333,6 +336,9 @@ def test__get_s3_client_errors(): with pytest.raises(NoCredentialsError, match='Unable to locate credentials'): _get_s3_client('s3://bucket_name/') + mock_boto_client.assert_called_once_with('s3') + mock_s3_client.head_bucket.assert_called_once_with(Bucket='bucket_name') + def test__read_data_from_bucket_key_reads_body(): """Test that the function reads data from S3 object body.""" From 
195780fe6bf02ee18d27fa05bc50135601237ba3 Mon Sep 17 00:00:00 2001 From: R-Palazzo Date: Mon, 25 May 2026 14:11:37 +0100 Subject: [PATCH 3/6] update _resolve_dataset --- pyproject.toml | 3 +- sdgym/_benchmark/benchmark.py | 3 +- sdgym/benchmark.py | 100 ++++++++++++------------ tests/unit/test_benchmark.py | 138 +++++++++++++++++++++++++++++++++- 4 files changed, 187 insertions(+), 57 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f1c918cd..83b325d7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,8 +64,7 @@ dependencies = [ 'XlsxWriter>=1.2.8', "rdt>=1.18.2;python_version<'3.14'", "rdt>=1.20.0;python_version>='3.14'", - "sdmetrics>=0.21.0;python_version<'3.14'", - "sdmetrics>=0.26.0;python_version>='3.14'", + "sdmetrics>=0.28.0", "sdv @ git+https://github.com/sdv-dev/SDV.git@main", ] diff --git a/sdgym/_benchmark/benchmark.py b/sdgym/_benchmark/benchmark.py index 19fb8de5..9ad30906 100644 --- a/sdgym/_benchmark/benchmark.py +++ b/sdgym/_benchmark/benchmark.py @@ -13,6 +13,7 @@ DEFAULT_SINGLE_TABLE_DATASETS, DEFAULT_SINGLE_TABLE_SYNTHESIZERS, S3_REGION, + SDGYM_BRANCH_INSTALL_COMMAND, _ensure_uniform_included, _generate_job_args_list, _get_empty_dataframe, @@ -206,7 +207,7 @@ def _get_user_data_script( log "======== Install Dependencies ==========" pip install --upgrade pip {sdv_install} - pip install "sdgym[all]" + {SDGYM_BRANCH_INSTALL_COMMAND} {gpu_block} diff --git a/sdgym/benchmark.py b/sdgym/benchmark.py index 7fa7be98..7191eee3 100644 --- a/sdgym/benchmark.py +++ b/sdgym/benchmark.py @@ -73,6 +73,10 @@ TIMEOUT = 345600 LOGGER = logging.getLogger(__name__) +SDGYM_BRANCH_INSTALL_COMMAND = ( + 'pip install "sdgym[all] @ git+https://github.com/sdv-dev/SDGym.git' + '@issue-604-2-private-bucket"' +) DEFAULT_SINGLE_TABLE_SYNTHESIZERS = [ 'GaussianCopulaSynthesizer', 'CTGANSynthesizer', @@ -340,44 +344,10 @@ def _setup_output_destination( def _resolve_dataset( modality, - dataset, - limit_dataset_size, - source, - s3_client=None, - 
dataset_bucket_mapping=None, -): - if source == 'sdv_demo': - data, metadata = _load_sdv_demo_dataset( - modality=modality, - dataset_name=dataset, - dataset_bucket_mapping=dataset_bucket_mapping, - s3_client=s3_client, - limit_dataset_size=limit_dataset_size, - ) - return ResolvedDataset(dataset, data, metadata) - - data, metadata = _load_dataset_with_client( - modality, - dataset, - limit_dataset_size=limit_dataset_size, - s3_client=s3_client, - ) - return ResolvedDataset(dataset.name, data, metadata) - - -def _generate_job_args_list( - limit_dataset_size, sdv_datasets, additional_datasets_folder, - sdmetrics, - timeout, - output_destination, - compute_quality_score, - compute_diagnostic_score, - compute_privacy_score, - synthesizers, - s3_client, - modality, + limit_dataset_size, + s3_client=None, ): sdv_dataset_names = [] if sdv_datasets is None else sdv_datasets additional_datasets = ( @@ -393,8 +363,6 @@ def _generate_job_args_list( s3_client=s3_client, ) ) - if not synthesizers: - return [] dataset_bucket_mapping = None if sdv_dataset_names: @@ -405,26 +373,52 @@ def _generate_job_args_list( skip_inaccessible=True, ) - datasets = [ - _resolve_dataset( + datasets = [] + for dataset_name in sdv_dataset_names: + data, metadata = _load_sdv_demo_dataset( modality=modality, - dataset=dataset, - limit_dataset_size=limit_dataset_size, - source='sdv_demo', - s3_client=s3_client, + dataset_name=dataset_name, dataset_bucket_mapping=dataset_bucket_mapping, + s3_client=s3_client, + limit_dataset_size=limit_dataset_size, ) - for dataset in sdv_dataset_names - ] - datasets.extend( - _resolve_dataset( - modality=modality, - dataset=dataset, + datasets.append(ResolvedDataset(dataset_name, data, metadata)) + + for dataset in additional_datasets: + data, metadata = _load_dataset_with_client( + modality, + dataset, limit_dataset_size=limit_dataset_size, - source='additional', s3_client=s3_client, ) - for dataset in additional_datasets + 
datasets.append(ResolvedDataset(dataset.name, data, metadata)) + + return datasets + + +def _generate_job_args_list( + limit_dataset_size, + sdv_datasets, + additional_datasets_folder, + sdmetrics, + timeout, + output_destination, + compute_quality_score, + compute_diagnostic_score, + compute_privacy_score, + synthesizers, + s3_client, + modality, +): + if not synthesizers: + return [] + + datasets = _resolve_dataset( + modality=modality, + sdv_datasets=sdv_datasets, + additional_datasets_folder=additional_datasets_folder, + limit_dataset_size=limit_dataset_size, + s3_client=s3_client, ) synthesizer_names = [synthesizer['name'] for synthesizer in synthesizers] dataset_names = [dataset.name for dataset in datasets] @@ -1427,7 +1421,7 @@ def _get_user_data_script(access_key, secret_key, region_name, script_content): echo "======== Install Dependencies in venv ============" pip install --upgrade pip - pip install sdgym[all] + {SDGYM_BRANCH_INSTALL_COMMAND} pip install s3fs echo "======== Write Script ===========" diff --git a/tests/unit/test_benchmark.py b/tests/unit/test_benchmark.py index 2e3ee068..b8d3a40b 100644 --- a/tests/unit/test_benchmark.py +++ b/tests/unit/test_benchmark.py @@ -21,6 +21,7 @@ _generate_job_args_list, _get_metainfo_increment, _import_and_validate_synthesizers, + _resolve_dataset, _setup_output_destination, _setup_output_destination_aws, _store_job_args_in_s3, @@ -233,6 +234,19 @@ def test_benchmark_single_table_progress_bar( ) # Assert + mock_get_dataset_bucket_mapping.assert_called_once_with( + 'single_table', + ['s3://sdv-datasets-public', 's3://sdv-datasets-private'], + None, + skip_inaccessible=True, + ) + mock_load_sdv_demo_dataset.assert_called_once_with( + modality='single_table', + dataset_name='student_placements', + dataset_bucket_mapping={'student_placements': 'bucket'}, + s3_client=None, + limit_dataset_size=False, + ) tqdm_mock.assert_called_once_with(ANY, total=1, position=0, leave=True) @@ -265,6 +279,19 @@ def 
test_benchmark_single_table_with_timeout( ) # Assert + mock_get_dataset_bucket_mapping.assert_called_once_with( + 'single_table', + ['s3://sdv-datasets-public', 's3://sdv-datasets-private'], + None, + skip_inaccessible=True, + ) + mock_load_sdv_demo_dataset.assert_called_once_with( + modality='single_table', + dataset_name='student_placements', + dataset_bucket_mapping={'student_placements': 'bucket'}, + s3_client=None, + limit_dataset_size=False, + ) mocked_process.start.assert_called_once_with() mocked_process.join.assert_called_once_with(1) mocked_process.terminate.assert_called_once_with() @@ -1078,10 +1105,78 @@ def test__add_adjusted_scores_missing_fallback(): assert scores.equals(expected) +@patch('sdgym.benchmark._load_dataset_with_client') +@patch('sdgym.benchmark._load_sdv_demo_dataset') +@patch('sdgym.benchmark._get_dataset_bucket_mapping') +@patch('sdgym.benchmark.get_dataset_paths') +def test__resolve_dataset_loads_sdv_and_additional_datasets( + mock_get_dataset_paths, + mock_get_dataset_bucket_mapping, + mock_load_sdv_demo_dataset, + mock_load_dataset, + tmp_path, +): + """Test the `_resolve_dataset` method.""" + # Setup + additional_folder = tmp_path / 'additional' + additional_dataset_path = additional_folder / 'single_table' / 'custom_dataset' + sdv_data = Mock(name='sdv_data') + sdv_metadata = Mock(name='sdv_metadata') + additional_data = Mock(name='additional_data') + additional_metadata = Mock(name='additional_metadata') + s3_client = Mock() + mock_get_dataset_paths.return_value = [additional_dataset_path] + mock_get_dataset_bucket_mapping.return_value = {'sdv_dataset': 'bucket'} + mock_load_sdv_demo_dataset.return_value = sdv_data, sdv_metadata + mock_load_dataset.return_value = additional_data, additional_metadata + + # Run + result = _resolve_dataset( + modality='single_table', + sdv_datasets=['sdv_dataset'], + additional_datasets_folder=str(additional_folder), + limit_dataset_size=True, + s3_client=s3_client, + ) + + # Assert + 
mock_get_dataset_paths.assert_called_once_with( + modality='single_table', + bucket=str(additional_folder / 'single_table'), + s3_client=s3_client, + ) + mock_get_dataset_bucket_mapping.assert_called_once_with( + 'single_table', + ['s3://sdv-datasets-public', 's3://sdv-datasets-private'], + s3_client, + skip_inaccessible=True, + ) + mock_load_sdv_demo_dataset.assert_called_once_with( + modality='single_table', + dataset_name='sdv_dataset', + dataset_bucket_mapping={'sdv_dataset': 'bucket'}, + s3_client=s3_client, + limit_dataset_size=True, + ) + mock_load_dataset.assert_called_once_with( + 'single_table', + additional_dataset_path, + limit_dataset_size=True, + s3_client=s3_client, + ) + assert [dataset.name for dataset in result] == ['sdv_dataset', 'custom_dataset'] + assert [dataset.data for dataset in result] == [sdv_data, additional_data] + assert [dataset.metadata for dataset in result] == [sdv_metadata, additional_metadata] + + @pytest.mark.parametrize('modality', ['single_table', 'multi_table']) +@patch('sdgym.benchmark._setup_output_destination') +@patch('sdgym.benchmark._load_dataset_with_client') @patch('sdgym.benchmark.get_dataset_paths') def test__generate_job_args_list_local_root_additional_folder( get_dataset_paths_mock, + mock_load_dataset, + mock__setup_output_destination, tmp_path, modality, ): @@ -1091,6 +1186,8 @@ def test__generate_job_args_list_local_root_additional_folder( local_root.mkdir() dataset_path = tmp_path / 'my_root' / modality / 'datasetA' get_dataset_paths_mock.return_value = [dataset_path] + mock_load_dataset.return_value = Mock(), Mock() + mock__setup_output_destination.return_value = {} # Run _generate_job_args_list( @@ -1103,7 +1200,7 @@ def test__generate_job_args_list_local_root_additional_folder( compute_quality_score=False, compute_diagnostic_score=False, compute_privacy_score=False, - synthesizers=[], + synthesizers=[{'name': 'UniformSynthesizer'}], s3_client=None, modality=modality, ) @@ -1114,6 +1211,19 @@ def 
test__generate_job_args_list_local_root_additional_folder( bucket=str(local_root / modality), s3_client=None, ) + mock_load_dataset.assert_called_once_with( + modality, + dataset_path, + limit_dataset_size=False, + s3_client=None, + ) + mock__setup_output_destination.assert_called_once_with( + None, + ['UniformSynthesizer'], + ['datasetA'], + modality=modality, + s3_client=None, + ) @patch('sdgym.benchmark._get_dataset_bucket_mapping') @@ -1160,6 +1270,19 @@ def test__generate_job_args_list_loads_each_dataset_once( # Assert mock_get_dataset_paths.assert_not_called() + mock_get_dataset_bucket_mapping.assert_called_once_with( + 'single_table', + ['s3://sdv-datasets-public', 's3://sdv-datasets-private'], + s3_client, + skip_inaccessible=True, + ) + mock__setup_output_destination.assert_called_once_with( + None, + ['GaussianCopulaSynthesizer', 'UniformSynthesizer'], + ['datasetA', 'datasetB'], + modality='single_table', + s3_client=s3_client, + ) mock_load_sdv_demo_dataset.assert_has_calls([ call( modality='single_table', @@ -1289,6 +1412,19 @@ def test_benchmark_single_table_no_warning_uniform_synthesizer( ) # Assert + mock_get_dataset_bucket_mapping.assert_called_once_with( + 'single_table', + ['s3://sdv-datasets-public', 's3://sdv-datasets-private'], + None, + skip_inaccessible=True, + ) + mock_load_sdv_demo_dataset.assert_called_once_with( + modality='single_table', + dataset_name='fake_hotel_guests', + dataset_bucket_mapping={'fake_hotel_guests': 'bucket'}, + s3_client=None, + limit_dataset_size=False, + ) warnings_text = ' '.join(str(w.message) for w in recwarn) assert 'is incompatible with transformer' not in warnings_text pd.testing.assert_frame_equal(result[expected_result.columns], expected_result) From b2bd4003efc64e110b633b4aeb3a1e0db73d20da Mon Sep 17 00:00:00 2001 From: R-Palazzo Date: Mon, 25 May 2026 15:40:18 +0100 Subject: [PATCH 4/6] fix lint --- sdgym/_benchmark/benchmark.py | 3 +-- sdgym/benchmark.py | 6 +----- 2 files changed, 2 insertions(+), 7 
deletions(-) diff --git a/sdgym/_benchmark/benchmark.py b/sdgym/_benchmark/benchmark.py index 9ad30906..19fb8de5 100644 --- a/sdgym/_benchmark/benchmark.py +++ b/sdgym/_benchmark/benchmark.py @@ -13,7 +13,6 @@ DEFAULT_SINGLE_TABLE_DATASETS, DEFAULT_SINGLE_TABLE_SYNTHESIZERS, S3_REGION, - SDGYM_BRANCH_INSTALL_COMMAND, _ensure_uniform_included, _generate_job_args_list, _get_empty_dataframe, @@ -207,7 +206,7 @@ def _get_user_data_script( log "======== Install Dependencies ==========" pip install --upgrade pip {sdv_install} - {SDGYM_BRANCH_INSTALL_COMMAND} + pip install "sdgym[all]" {gpu_block} diff --git a/sdgym/benchmark.py b/sdgym/benchmark.py index 7191eee3..8b19987d 100644 --- a/sdgym/benchmark.py +++ b/sdgym/benchmark.py @@ -73,10 +73,6 @@ TIMEOUT = 345600 LOGGER = logging.getLogger(__name__) -SDGYM_BRANCH_INSTALL_COMMAND = ( - 'pip install "sdgym[all] @ git+https://github.com/sdv-dev/SDGym.git' - '@issue-604-2-private-bucket"' -) DEFAULT_SINGLE_TABLE_SYNTHESIZERS = [ 'GaussianCopulaSynthesizer', 'CTGANSynthesizer', @@ -1421,7 +1417,7 @@ def _get_user_data_script(access_key, secret_key, region_name, script_content): echo "======== Install Dependencies in venv ============" pip install --upgrade pip - {SDGYM_BRANCH_INSTALL_COMMAND} + pip install sdgym[all] pip install s3fs echo "======== Write Script ===========" From 261893a8c2f89e42de2d8e6ba252daed0eb4df25 Mon Sep 17 00:00:00 2001 From: R-Palazzo Date: Mon, 25 May 2026 20:49:40 +0100 Subject: [PATCH 5/6] cleaning --- sdgym/datasets.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/sdgym/datasets.py b/sdgym/datasets.py index 6be54e48..6ef494dc 100644 --- a/sdgym/datasets.py +++ b/sdgym/datasets.py @@ -46,13 +46,6 @@ def _get_bucket_name(bucket): return bucket[len(S3_PREFIX) :] if bucket.startswith(S3_PREFIX) else bucket -def _metadata_to_dict(metadata): - if isinstance(metadata, dict): - return metadata - - return metadata.to_dict() - - def _raise_dataset_not_found_error( s3_client, 
bucket_name, @@ -314,7 +307,7 @@ def _load_private_sdv_demo_dataset(modality, dataset_name, bucket, s3_client=Non data = data.popitem()[1] metadata = _get_metadata(metadata_bytes, dataset_name) - return data, _metadata_to_dict(metadata) + return data, metadata.to_dict() def _load_sdv_demo_dataset( @@ -350,7 +343,7 @@ def _load_sdv_demo_dataset( dataset_name=dataset_name, s3_bucket_name=bucket_name, ) - metadata = _metadata_to_dict(metadata) + metadata = metadata.to_dict() except ValueError: if bucket != SDV_DATASETS_PRIVATE_BUCKET: raise From 3eeb8ae5d53ac3b3a6648ea1d19b9e19226a213a Mon Sep 17 00:00:00 2001 From: R-Palazzo Date: Tue, 26 May 2026 09:57:00 +0100 Subject: [PATCH 6/6] add validation --- sdgym/benchmark.py | 9 +++- sdgym/datasets.py | 19 +------- tests/unit/_benchmark_launcher/test_utils.py | 26 ++++++++++- tests/unit/test_benchmark.py | 48 +++++++++++++++++--- tests/unit/test_datasets.py | 17 ++----- 5 files changed, 78 insertions(+), 41 deletions(-) diff --git a/sdgym/benchmark.py b/sdgym/benchmark.py index 8b19987d..8638b9b3 100644 --- a/sdgym/benchmark.py +++ b/sdgym/benchmark.py @@ -368,13 +368,20 @@ def _resolve_dataset( s3_client, skip_inaccessible=True, ) + missing_names = [name for name in sdv_dataset_names if name not in dataset_bucket_mapping] + if missing_names: + missing_to_print = "', '".join(missing_names) + raise ValueError( + f'The following SDV demo datasets were not found in the expected buckets: ' + f"'{missing_to_print}'. Please check that the dataset names are correct." 
+ ) datasets = [] for dataset_name in sdv_dataset_names: data, metadata = _load_sdv_demo_dataset( modality=modality, dataset_name=dataset_name, - dataset_bucket_mapping=dataset_bucket_mapping, + bucket=dataset_bucket_mapping.get(dataset_name), s3_client=s3_client, limit_dataset_size=limit_dataset_size, ) diff --git a/sdgym/datasets.py b/sdgym/datasets.py index 6ef494dc..4d9f32b9 100644 --- a/sdgym/datasets.py +++ b/sdgym/datasets.py @@ -313,29 +313,12 @@ def _load_private_sdv_demo_dataset(modality, dataset_name, bucket, s3_client=Non def _load_sdv_demo_dataset( modality, dataset_name, - dataset_bucket_mapping=None, + bucket, s3_client=None, limit_dataset_size=False, ): """Load an SDV demo dataset from the resolved public or private bucket.""" _validate_modality(modality) - buckets = [SDV_DATASETS_PUBLIC_BUCKET, SDV_DATASETS_PRIVATE_BUCKET] - if dataset_bucket_mapping is None: - dataset_bucket_mapping = _get_dataset_bucket_mapping( - modality, - buckets, - s3_client or get_s3_client(), - skip_inaccessible=True, - ) - - bucket = dataset_bucket_mapping.get(dataset_name) - if bucket is None: - buckets_list = ', '.join(buckets) - raise ValueError( - f"Dataset '{dataset_name}' not found in SDV demo buckets for modality " - f"'{modality}'. Checked buckets: {buckets_list}." 
- ) - bucket_name = _get_bucket_name(bucket) try: data, metadata = download_demo( diff --git a/tests/unit/_benchmark_launcher/test_utils.py b/tests/unit/_benchmark_launcher/test_utils.py index 10f26f4d..d2626196 100644 --- a/tests/unit/_benchmark_launcher/test_utils.py +++ b/tests/unit/_benchmark_launcher/test_utils.py @@ -516,10 +516,33 @@ def test_resolve_credentials_with_filepath_deep_merges_file_over_env( assert credentials == expected -def test_resolve_credentials_file_mode(tmp_path): +@patch('sdgym._benchmark_launcher.utils._get_env_credentials') +def test_resolve_credentials_file_mode(mock_get_env_credentials, tmp_path): """Test `resolve_credentials` returns credentials from a file merged over env defaults.""" # Setup credential_file = tmp_path / 'credentials.json' + mock_get_env_credentials.return_value = { + 'aws': { + 'AWS_ACCESS_KEY_ID': None, + 'AWS_SECRET_ACCESS_KEY': None, + }, + 'gcp': { + 'type': None, + 'project_id': None, + 'private_key_id': None, + 'private_key': None, + 'client_email': None, + 'client_id': None, + 'auth_uri': None, + 'token_uri': None, + 'auth_provider_x509_cert_url': None, + 'client_x509_cert_url': None, + }, + 'sdv_enterprise': { + 'SDV_ENTERPRISE_USERNAME': None, + 'SDV_ENTERPRISE_LICENSE_KEY': None, + }, + } file_credentials = { 'aws': { 'AWS_ACCESS_KEY_ID': 'FILE_AKIA', @@ -559,6 +582,7 @@ def test_resolve_credentials_file_mode(tmp_path): credentials = resolve_credentials(str(credential_file)) # Assert + mock_get_env_credentials.assert_called_once_with() assert credentials == expected_credentials diff --git a/tests/unit/test_benchmark.py b/tests/unit/test_benchmark.py index b8d3a40b..5f5fc166 100644 --- a/tests/unit/test_benchmark.py +++ b/tests/unit/test_benchmark.py @@ -243,7 +243,7 @@ def test_benchmark_single_table_progress_bar( mock_load_sdv_demo_dataset.assert_called_once_with( modality='single_table', dataset_name='student_placements', - dataset_bucket_mapping={'student_placements': 'bucket'}, + bucket='bucket', 
s3_client=None, limit_dataset_size=False, ) @@ -288,7 +288,7 @@ def test_benchmark_single_table_with_timeout( mock_load_sdv_demo_dataset.assert_called_once_with( modality='single_table', dataset_name='student_placements', - dataset_bucket_mapping={'student_placements': 'bucket'}, + bucket='bucket', s3_client=None, limit_dataset_size=False, ) @@ -1154,7 +1154,7 @@ def test__resolve_dataset_loads_sdv_and_additional_datasets( mock_load_sdv_demo_dataset.assert_called_once_with( modality='single_table', dataset_name='sdv_dataset', - dataset_bucket_mapping={'sdv_dataset': 'bucket'}, + bucket='bucket', s3_client=s3_client, limit_dataset_size=True, ) @@ -1169,6 +1169,40 @@ def test__resolve_dataset_loads_sdv_and_additional_datasets( assert [dataset.metadata for dataset in result] == [sdv_metadata, additional_metadata] +@patch('sdgym.benchmark._get_dataset_bucket_mapping') +@patch('sdgym.benchmark._load_sdv_demo_dataset') +def test__resolve_dataset_raises_when_sdv_dataset_is_missing_from_buckets( + mock_load_sdv_demo_dataset, mock_get_dataset_bucket_mapping +): + """Test `_resolve_dataset` raises when an SDV dataset is not found in any bucket.""" + # Setup + mock_get_dataset_bucket_mapping.return_value = {'available_dataset': 'bucket'} + + # Run and Assert + with pytest.raises( + ValueError, + match=( + 'The following SDV demo datasets were not found in the expected buckets: ' + "'missing_dataset'. Please check that the dataset names are correct." 
+ ), + ): + _resolve_dataset( + modality='single_table', + sdv_datasets=['available_dataset', 'missing_dataset'], + additional_datasets_folder=None, + limit_dataset_size=False, + s3_client='s3_client', + ) + + mock_get_dataset_bucket_mapping.assert_called_once_with( + 'single_table', + ['s3://sdv-datasets-public', 's3://sdv-datasets-private'], + 's3_client', + skip_inaccessible=True, + ) + mock_load_sdv_demo_dataset.assert_not_called() + + @pytest.mark.parametrize('modality', ['single_table', 'multi_table']) @patch('sdgym.benchmark._setup_output_destination') @patch('sdgym.benchmark._load_dataset_with_client') @@ -1239,7 +1273,7 @@ def test__generate_job_args_list_loads_each_dataset_once( """Test that each dataset is loaded once even when there are multiple synthesizers.""" # Setup mock_get_dataset_paths.return_value = [] - mock_get_dataset_bucket_mapping.return_value = {'datasetA': 'bucket', 'datasetB': 'bucket'} + mock_get_dataset_bucket_mapping.return_value = {'datasetA': 'bucket-a', 'datasetB': 'bucket-b'} mock__setup_output_destination.return_value = {} data_a = Mock(name='data_a') metadata_a = Mock(name='metadata_a') @@ -1287,14 +1321,14 @@ def test__generate_job_args_list_loads_each_dataset_once( call( modality='single_table', dataset_name='datasetA', - dataset_bucket_mapping={'datasetA': 'bucket', 'datasetB': 'bucket'}, + bucket='bucket-a', s3_client=s3_client, limit_dataset_size=True, ), call( modality='single_table', dataset_name='datasetB', - dataset_bucket_mapping={'datasetA': 'bucket', 'datasetB': 'bucket'}, + bucket='bucket-b', s3_client=s3_client, limit_dataset_size=True, ), @@ -1421,7 +1455,7 @@ def test_benchmark_single_table_no_warning_uniform_synthesizer( mock_load_sdv_demo_dataset.assert_called_once_with( modality='single_table', dataset_name='fake_hotel_guests', - dataset_bucket_mapping={'fake_hotel_guests': 'bucket'}, + bucket='bucket', s3_client=None, limit_dataset_size=False, ) diff --git a/tests/unit/test_datasets.py 
b/tests/unit/test_datasets.py index c770846b..e38efbb7 100644 --- a/tests/unit/test_datasets.py +++ b/tests/unit/test_datasets.py @@ -452,7 +452,7 @@ def test__load_sdv_demo_dataset_uses_download_demo(download_demo_mock): result = _load_sdv_demo_dataset( modality='single_table', dataset_name='demo', - dataset_bucket_mapping={'demo': SDV_DATASETS_PUBLIC_BUCKET}, + bucket=SDV_DATASETS_PUBLIC_BUCKET, ) # Assert @@ -482,7 +482,7 @@ def test__load_sdv_demo_dataset_falls_back_for_private_bucket( result = _load_sdv_demo_dataset( modality='single_table', dataset_name='demo', - dataset_bucket_mapping={'demo': SDV_DATASETS_PRIVATE_BUCKET}, + bucket=SDV_DATASETS_PRIVATE_BUCKET, s3_client='s3_client', ) @@ -515,7 +515,7 @@ def test__load_sdv_demo_dataset_limits_dataset_size(download_demo_mock, subset_m result = _load_sdv_demo_dataset( modality='single_table', dataset_name='demo', - dataset_bucket_mapping={'demo': SDV_DATASETS_PUBLIC_BUCKET}, + bucket=SDV_DATASETS_PUBLIC_BUCKET, limit_dataset_size=True, ) @@ -530,17 +530,6 @@ def test__load_sdv_demo_dataset_limits_dataset_size(download_demo_mock, subset_m ) -def test__load_sdv_demo_dataset_raises_when_dataset_not_found(): - """Test a clear error is raised when a demo dataset is absent from all buckets.""" - # Run and Assert - with pytest.raises(ValueError, match="Dataset 'missing' not found in SDV demo buckets"): - _load_sdv_demo_dataset( - modality='single_table', - dataset_name='missing', - dataset_bucket_mapping={}, - ) - - @patch('sdgym.datasets._get_dataset_path_and_download') @patch('sdgym.datasets._path_contains_data_and_metadata', return_value=True) @patch('sdgym.datasets.Path')