diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index e344e5bd..d475d6b6 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -28,6 +28,9 @@ jobs: python -m pip install --upgrade pip python -m pip install --no-cache-dir invoke .[test] - name: Run integration tests + env: + AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }} + AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }} run: invoke integration - if: matrix.os == 'ubuntu-latest' && matrix.python-version == 3.14 name: Upload integration codecov report diff --git a/.github/workflows/minimum.yml b/.github/workflows/minimum.yml index 7eaf8edc..8e143056 100644 --- a/.github/workflows/minimum.yml +++ b/.github/workflows/minimum.yml @@ -39,4 +39,7 @@ jobs: python -m pip install --no-cache-dir invoke .[test] - name: Test with minimum versions + env: + AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }} + AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }} run: invoke minimum diff --git a/pyproject.toml b/pyproject.toml index 46c6f0a9..83b325d7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,10 +64,8 @@ dependencies = [ 'XlsxWriter>=1.2.8', "rdt>=1.18.2;python_version<'3.14'", "rdt>=1.20.0;python_version>='3.14'", - "sdmetrics>=0.21.0;python_version<'3.14'", - "sdmetrics>=0.26.0;python_version>='3.14'", - "sdv>=1.21.0;python_version<'3.14'", - "sdv>=1.33.0;python_version>='3.14'", + "sdmetrics>=0.28.0", + "sdv @ git+https://github.com/sdv-dev/SDV.git@main", ] [project.urls] diff --git a/sdgym/benchmark.py b/sdgym/benchmark.py index 30d8e608..8638b9b3 100644 --- a/sdgym/benchmark.py +++ b/sdgym/benchmark.py @@ -40,7 +40,14 @@ ) from sdmetrics.single_table import DCRBaselineProtection -from sdgym.datasets import _load_dataset_with_client, get_dataset_paths +from sdgym.datasets import ( + SDV_DATASETS_PRIVATE_BUCKET, + SDV_DATASETS_PUBLIC_BUCKET, + _get_dataset_bucket_mapping, + _load_dataset_with_client, + _load_sdv_demo_dataset, + get_dataset_paths, +) from sdgym.errors import BenchmarkError, SDGymError from sdgym.metrics import get_metrics from sdgym.progress import TqdmLogger @@ -123,6 +130,14 @@ class JobArgs(NamedTuple): output_directions: Optional[dict] +class ResolvedDataset(NamedTuple): + """Resolved dataset data and metadata for benchmark job creation.""" + + name: str + data: Any + metadata: Any + + def _import_and_validate_synthesizers(synthesizers, custom_synthesizers, modality): """Import user-provided synthesizer and validate modality and uniqueness. @@ -323,29 +338,14 @@ def _setup_output_destination( return paths -def _generate_job_args_list( - limit_dataset_size, +def _resolve_dataset( + modality, sdv_datasets, additional_datasets_folder, - sdmetrics, - timeout, - output_destination, - compute_quality_score, - compute_diagnostic_score, - compute_privacy_score, - synthesizers, - s3_client, - modality, + limit_dataset_size, + s3_client=None, ): - sdv_datasets = ( - [] - if sdv_datasets is None - else get_dataset_paths( - modality=modality, - datasets=sdv_datasets, - s3_client=s3_client, - ) - ) + sdv_dataset_names = [] if sdv_datasets is None else sdv_datasets additional_datasets = ( [] if additional_datasets_folder is None @@ -359,13 +359,76 @@ def _generate_job_args_list( s3_client=s3_client, ) ) - datasets = sdv_datasets + additional_datasets + + dataset_bucket_mapping = None + if sdv_dataset_names: + dataset_bucket_mapping = _get_dataset_bucket_mapping( + modality, + [SDV_DATASETS_PUBLIC_BUCKET, SDV_DATASETS_PRIVATE_BUCKET], + s3_client, + skip_inaccessible=True, + ) + missing_names = [name for name in sdv_dataset_names if name not in dataset_bucket_mapping] + if missing_names: + missing_to_print = "', '".join(missing_names) + raise ValueError( + f'The following SDV demo datasets were not found in the expected buckets: ' + f"'{missing_to_print}'. Please check that the dataset names are correct." + ) + + datasets = [] + for dataset_name in sdv_dataset_names: + data, metadata = _load_sdv_demo_dataset( + modality=modality, + dataset_name=dataset_name, + bucket=dataset_bucket_mapping.get(dataset_name), + s3_client=s3_client, + limit_dataset_size=limit_dataset_size, + ) + datasets.append(ResolvedDataset(dataset_name, data, metadata)) + + for dataset in additional_datasets: + data, metadata = _load_dataset_with_client( + modality, + dataset, + limit_dataset_size=limit_dataset_size, + s3_client=s3_client, + ) + datasets.append(ResolvedDataset(dataset.name, data, metadata)) + + return datasets + + +def _generate_job_args_list( + limit_dataset_size, + sdv_datasets, + additional_datasets_folder, + sdmetrics, + timeout, + output_destination, + compute_quality_score, + compute_diagnostic_score, + compute_privacy_score, + synthesizers, + s3_client, + modality, +): + if not synthesizers: + return [] + + datasets = _resolve_dataset( + modality=modality, + sdv_datasets=sdv_datasets, + additional_datasets_folder=additional_datasets_folder, + limit_dataset_size=limit_dataset_size, + s3_client=s3_client, + ) synthesizer_names = [synthesizer['name'] for synthesizer in synthesizers] dataset_names = [dataset.name for dataset in datasets] paths = _setup_output_destination( output_destination, synthesizer_names, dataset_names, modality=modality, s3_client=s3_client ) - job_tuples = [] + job_args_list = [] for dataset in datasets: for synthesizer in synthesizers: if paths: @@ -377,29 +440,22 @@ def _generate_job_args_list( final_name = synthesizer['name'] synthesizer['name'] = final_name - job_tuples.append((synthesizer, dataset)) - - job_args_list = [] - for synthesizer, dataset in job_tuples: - data, metadata_dict = _load_dataset_with_client( - modality, dataset, limit_dataset_size=limit_dataset_size, s3_client=s3_client - ) - path = paths.get(dataset.name, {}).get(synthesizer['name'], None) - job_args_list.append( - JobArgs( - synthesizer=synthesizer, - data=data, - metadata=metadata_dict, - metrics=sdmetrics, - timeout=timeout, - compute_quality_score=compute_quality_score, - compute_diagnostic_score=compute_diagnostic_score, - compute_privacy_score=compute_privacy_score, - dataset_name=dataset.name, - modality=modality, - output_directions=path, + path = paths.get(dataset.name, {}).get(synthesizer['name'], None) + job_args_list.append( + JobArgs( + synthesizer=synthesizer, + data=dataset.data, + metadata=dataset.metadata, + metrics=sdmetrics, + timeout=timeout, + compute_quality_score=compute_quality_score, + compute_diagnostic_score=compute_diagnostic_score, + compute_privacy_score=compute_privacy_score, + dataset_name=dataset.name, + modality=modality, + output_directions=path, + ) ) - ) return job_args_list diff --git a/sdgym/datasets.py b/sdgym/datasets.py index 8b923b11..4d9f32b9 100644 --- a/sdgym/datasets.py +++ b/sdgym/datasets.py @@ -1,12 +1,23 @@ """SDGym module to handle datasets.""" +import io import logging import os from pathlib import Path import appdirs +import botocore import numpy as np import pandas as pd +from sdv.datasets.demo import ( + _find_data_zip_key, + _get_data_from_bucket, + _get_first_v1_metadata_bytes, + _get_metadata, + _list_objects, + _load_data_from_zip, + download_demo, +) from sdgym._dataset_utils import ( _get_dataset_subset, @@ -251,6 +262,88 @@ def _get_available_datasets( return pd.DataFrame(datasets_info) +def _get_dataset_bucket_mapping(modality, buckets, s3_client, skip_inaccessible=False): + """Map SDV demo dataset names to the bucket they should be loaded from.""" + dataset_buckets = {} + for bucket in buckets: + try: + available_datasets = _get_available_datasets( + modality, + bucket=bucket, + s3_client=s3_client, + ) + except (botocore.exceptions.BotoCoreError, botocore.exceptions.ClientError) as error: + if skip_inaccessible: + LOGGER.info("Skipping inaccessible bucket '%s': %s", bucket, error) + continue + + raise ValueError( + f"Bucket '{bucket}' is not accessible with the provided credentials." + ) from error + + for dataset_name in available_datasets['dataset_name'].tolist(): + existing_bucket = dataset_buckets.get(dataset_name) + if existing_bucket and bucket != SDV_DATASETS_PRIVATE_BUCKET: + continue + + dataset_buckets[dataset_name] = bucket + + return dataset_buckets + + +def _load_private_sdv_demo_dataset(modality, dataset_name, bucket, s3_client=None): + """Load an SDV demo dataset from a private bucket with an SDGym S3 client.""" + bucket_name = _get_bucket_name(bucket) + s3_client = s3_client or get_s3_client() + dataset_prefix = f'{modality}/{dataset_name}/' + contents = _list_objects(dataset_prefix, bucket=bucket_name, client=s3_client) + data_key = _find_data_zip_key(contents, dataset_prefix, bucket_name) + data_bytes = io.BytesIO(_get_data_from_bucket(data_key, bucket=bucket_name, client=s3_client)) + metadata_bytes = _get_first_v1_metadata_bytes( + contents, dataset_prefix, bucket=bucket_name, client=s3_client + ) + data = _load_data_from_zip(data_bytes, bucket_name, dataset_name) + if modality != 'multi_table': + data = data.popitem()[1] + + metadata = _get_metadata(metadata_bytes, dataset_name) + return data, metadata.to_dict() + + +def _load_sdv_demo_dataset( + modality, + dataset_name, + bucket, + s3_client=None, + limit_dataset_size=False, +): + """Load an SDV demo dataset from the resolved public or private bucket.""" + _validate_modality(modality) + bucket_name = _get_bucket_name(bucket) + try: + data, metadata = download_demo( + modality=modality, + dataset_name=dataset_name, + s3_bucket_name=bucket_name, + ) + metadata = metadata.to_dict() + except ValueError: + if bucket != SDV_DATASETS_PRIVATE_BUCKET: + raise + + data, metadata = _load_private_sdv_demo_dataset( + modality, + dataset_name, + bucket, + s3_client=s3_client, + ) + + if limit_dataset_size: + data, metadata = _get_dataset_subset(data, metadata, modality=modality) + + return data, metadata + + def load_dataset( modality, dataset, diff --git a/tests/integration/test_benchmark.py b/tests/integration/test_benchmark.py index 0c7fe86a..cdf87685 100644 --- a/tests/integration/test_benchmark.py +++ b/tests/integration/test_benchmark.py @@ -929,6 +929,30 @@ def test_benchmark_multi_table_basic_synthesizers(): ] +@pytest.mark.skipif( + not os.getenv('AWS_ACCESS_KEY_ID') or not os.getenv('AWS_SECRET_ACCESS_KEY'), + reason='MovieLens benchmark requires AWS credentials for private dataset access.', +) +def test_benchmark_multi_table_private_dataset(): + """Test multi-table benchmark with private dataset `MovieLens`.""" + # Setup + datasets = ['MovieLens'] + synthesizers = ['HMASynthesizer'] + timeout = 10 + + # Run + result = benchmark_multi_table( + synthesizers=synthesizers, + sdv_datasets=datasets, + timeout=timeout, + ) + + # Assert + assert result['Dataset'].tolist() == ['MovieLens', 'MovieLens'] + assert result['Synthesizer'].tolist() == ['HMASynthesizer', 'MultiTableUniformSynthesizer'] + assert result['Quality_Score'].tolist() == [None, None] + + def test_benchmark_multi_table_with_output_destination_multiple_runs(tmp_path): """Test saving in ``output_destination`` with multiple runs in multi-table mode. diff --git a/tests/unit/_benchmark_launcher/test_utils.py b/tests/unit/_benchmark_launcher/test_utils.py index 10f26f4d..d2626196 100644 --- a/tests/unit/_benchmark_launcher/test_utils.py +++ b/tests/unit/_benchmark_launcher/test_utils.py @@ -516,10 +516,33 @@ def test_resolve_credentials_with_filepath_deep_merges_file_over_env( assert credentials == expected -def test_resolve_credentials_file_mode(tmp_path): +@patch('sdgym._benchmark_launcher.utils._get_env_credentials') +def test_resolve_credentials_file_mode(mock_get_env_credentials, tmp_path): """Test `resolve_credentials` returns credentials from a file merged over env defaults.""" # Setup credential_file = tmp_path / 'credentials.json' + mock_get_env_credentials.return_value = { + 'aws': { + 'AWS_ACCESS_KEY_ID': None, + 'AWS_SECRET_ACCESS_KEY': None, + }, + 'gcp': { + 'type': None, + 'project_id': None, + 'private_key_id': None, + 'private_key': None, + 'client_email': None, + 'client_id': None, + 'auth_uri': None, + 'token_uri': None, + 'auth_provider_x509_cert_url': None, + 'client_x509_cert_url': None, + }, + 'sdv_enterprise': { + 'SDV_ENTERPRISE_USERNAME': None, + 'SDV_ENTERPRISE_LICENSE_KEY': None, + }, + } file_credentials = { 'aws': { 'AWS_ACCESS_KEY_ID': 'FILE_AKIA', @@ -559,6 +582,7 @@ def test_resolve_credentials_file_mode(tmp_path): credentials = resolve_credentials(str(credential_file)) # Assert + mock_get_env_credentials.assert_called_once_with() assert credentials == expected_credentials diff --git a/tests/unit/test_benchmark.py b/tests/unit/test_benchmark.py index 9c8065af..5f5fc166 100644 --- a/tests/unit/test_benchmark.py +++ b/tests/unit/test_benchmark.py @@ -21,6 +21,7 @@ _generate_job_args_list, _get_metainfo_increment, _import_and_validate_synthesizers, + _resolve_dataset, _setup_output_destination, _setup_output_destination_aws, _store_job_args_in_s3, @@ -202,10 +203,18 @@ def test__get_metainfo_increment_local(mock_logger, tmp_path): assert result_3 == 3 +@patch('sdgym.benchmark._get_dataset_bucket_mapping') +@patch('sdgym.benchmark._load_sdv_demo_dataset') @patch('sdgym.benchmark.tqdm.tqdm') -def test_benchmark_single_table_progress_bar(tqdm_mock): +def test_benchmark_single_table_progress_bar( + tqdm_mock, mock_load_sdv_demo_dataset, mock_get_dataset_bucket_mapping +): """Test that the benchmarking function updates the progress bar on one line.""" # Setup + data = pd.DataFrame({'column': [1, 2]}) + metadata = {'tables': {'student_placements': {'columns': {'column': {'sdtype': 'numerical'}}}}} + mock_get_dataset_bucket_mapping.return_value = {'student_placements': 'bucket'} + mock_load_sdv_demo_dataset.return_value = data, metadata scores_mock = MagicMock() scores_mock.__iter__.return_value = [ pd.DataFrame({ @@ -225,14 +234,38 @@ def test_benchmark_single_table_progress_bar(tqdm_mock): ) # Assert + mock_get_dataset_bucket_mapping.assert_called_once_with( + 'single_table', + ['s3://sdv-datasets-public', 's3://sdv-datasets-private'], + None, + skip_inaccessible=True, + ) + mock_load_sdv_demo_dataset.assert_called_once_with( + modality='single_table', + dataset_name='student_placements', + bucket='bucket', + s3_client=None, + limit_dataset_size=False, + ) tqdm_mock.assert_called_once_with(ANY, total=1, position=0, leave=True) +@patch('sdgym.benchmark._get_dataset_bucket_mapping') +@patch('sdgym.benchmark._load_sdv_demo_dataset') @patch('sdgym.benchmark._score') @patch('sdgym.benchmark.multiprocessing') -def test_benchmark_single_table_with_timeout(mock_multiprocessing, mock__score): +def test_benchmark_single_table_with_timeout( + mock_multiprocessing, + mock__score, + mock_load_sdv_demo_dataset, + mock_get_dataset_bucket_mapping, +): """Test that benchmark runs with timeout.""" # Setup + data = pd.DataFrame({'column': [1, 2]}) + metadata = {'tables': {'student_placements': {'columns': {'column': {'sdtype': 'numerical'}}}}} + mock_get_dataset_bucket_mapping.return_value = {'student_placements': 'bucket'} + mock_load_sdv_demo_dataset.return_value = data, metadata mocked_process = mock_multiprocessing.Process.return_value manager = mock_multiprocessing.Manager.return_value manager_dict = {'timeout': True, 'Error': 'Synthesizer Timeout'} @@ -246,6 +279,19 @@ def test_benchmark_single_table_with_timeout(mock_multiprocessing, mock__score): ) # Assert + mock_get_dataset_bucket_mapping.assert_called_once_with( + 'single_table', + ['s3://sdv-datasets-public', 's3://sdv-datasets-private'], + None, + skip_inaccessible=True, + ) + mock_load_sdv_demo_dataset.assert_called_once_with( + modality='single_table', + dataset_name='student_placements', + bucket='bucket', + s3_client=None, + limit_dataset_size=False, + ) mocked_process.start.assert_called_once_with() mocked_process.join.assert_called_once_with(1) mocked_process.terminate.assert_called_once_with() @@ -1059,10 +1105,112 @@ def test__add_adjusted_scores_missing_fallback(): assert scores.equals(expected) +@patch('sdgym.benchmark._load_dataset_with_client') +@patch('sdgym.benchmark._load_sdv_demo_dataset') +@patch('sdgym.benchmark._get_dataset_bucket_mapping') +@patch('sdgym.benchmark.get_dataset_paths') +def test__resolve_dataset_loads_sdv_and_additional_datasets( + mock_get_dataset_paths, + mock_get_dataset_bucket_mapping, + mock_load_sdv_demo_dataset, + mock_load_dataset, + tmp_path, +): + """Test the `_resolve_dataset` method.""" + # Setup + additional_folder = tmp_path / 'additional' + additional_dataset_path = additional_folder / 'single_table' / 'custom_dataset' + sdv_data = Mock(name='sdv_data') + sdv_metadata = Mock(name='sdv_metadata') + additional_data = Mock(name='additional_data') + additional_metadata = Mock(name='additional_metadata') + s3_client = Mock() + mock_get_dataset_paths.return_value = [additional_dataset_path] + mock_get_dataset_bucket_mapping.return_value = {'sdv_dataset': 'bucket'} + mock_load_sdv_demo_dataset.return_value = sdv_data, sdv_metadata + mock_load_dataset.return_value = additional_data, additional_metadata + + # Run + result = _resolve_dataset( + modality='single_table', + sdv_datasets=['sdv_dataset'], + additional_datasets_folder=str(additional_folder), + limit_dataset_size=True, + s3_client=s3_client, + ) + + # Assert + mock_get_dataset_paths.assert_called_once_with( + modality='single_table', + bucket=str(additional_folder / 'single_table'), + s3_client=s3_client, + ) + mock_get_dataset_bucket_mapping.assert_called_once_with( + 'single_table', + ['s3://sdv-datasets-public', 's3://sdv-datasets-private'], + s3_client, + skip_inaccessible=True, + ) + mock_load_sdv_demo_dataset.assert_called_once_with( + modality='single_table', + dataset_name='sdv_dataset', + bucket='bucket', + s3_client=s3_client, + limit_dataset_size=True, + ) + mock_load_dataset.assert_called_once_with( + 'single_table', + additional_dataset_path, + limit_dataset_size=True, + s3_client=s3_client, + ) + assert [dataset.name for dataset in result] == ['sdv_dataset', 'custom_dataset'] + assert [dataset.data for dataset in result] == [sdv_data, additional_data] + assert [dataset.metadata for dataset in result] == [sdv_metadata, additional_metadata] + + +@patch('sdgym.benchmark._get_dataset_bucket_mapping') +@patch('sdgym.benchmark._load_sdv_demo_dataset') +def test__resolve_dataset_raises_when_sdv_dataset_is_missing_from_buckets( + mock_load_sdv_demo_dataset, mock_get_dataset_bucket_mapping +): + """Test `_resolve_dataset` raises when an SDV dataset is not found in any bucket.""" + # Setup + mock_get_dataset_bucket_mapping.return_value = {'available_dataset': 'bucket'} + + # Run and Assert + with pytest.raises( + ValueError, + match=( + 'The following SDV demo datasets were not found in the expected buckets: ' + "'missing_dataset'. Please check that the dataset names are correct." + ), + ): + _resolve_dataset( + modality='single_table', + sdv_datasets=['available_dataset', 'missing_dataset'], + additional_datasets_folder=None, + limit_dataset_size=False, + s3_client='s3_client', + ) + + mock_get_dataset_bucket_mapping.assert_called_once_with( + 'single_table', + ['s3://sdv-datasets-public', 's3://sdv-datasets-private'], + 's3_client', + skip_inaccessible=True, + ) + mock_load_sdv_demo_dataset.assert_not_called() + + @pytest.mark.parametrize('modality', ['single_table', 'multi_table']) +@patch('sdgym.benchmark._setup_output_destination') +@patch('sdgym.benchmark._load_dataset_with_client') @patch('sdgym.benchmark.get_dataset_paths') def test__generate_job_args_list_local_root_additional_folder( get_dataset_paths_mock, + mock_load_dataset, + mock__setup_output_destination, tmp_path, modality, ): @@ -1072,6 +1220,8 @@ def test__generate_job_args_list_local_root_additional_folder( local_root.mkdir() dataset_path = tmp_path / 'my_root' / modality / 'datasetA' get_dataset_paths_mock.return_value = [dataset_path] + mock_load_dataset.return_value = Mock(), Mock() + mock__setup_output_destination.return_value = {} # Run _generate_job_args_list( @@ -1084,7 +1234,7 @@ def test__generate_job_args_list_local_root_additional_folder( compute_quality_score=False, compute_diagnostic_score=False, compute_privacy_score=False, - synthesizers=[], + synthesizers=[{'name': 'UniformSynthesizer'}], s3_client=None, modality=modality, ) @@ -1095,6 +1245,109 @@ def test__generate_job_args_list_local_root_additional_folder( bucket=str(local_root / modality), s3_client=None, ) + mock_load_dataset.assert_called_once_with( + modality, + dataset_path, + limit_dataset_size=False, + s3_client=None, + ) + mock__setup_output_destination.assert_called_once_with( + None, + ['UniformSynthesizer'], + ['datasetA'], + modality=modality, + s3_client=None, + ) + + +@patch('sdgym.benchmark._get_dataset_bucket_mapping') +@patch('sdgym.benchmark._load_sdv_demo_dataset') +@patch('sdgym.benchmark._setup_output_destination') +@patch('sdgym.benchmark.get_dataset_paths') +def test__generate_job_args_list_loads_each_dataset_once( + mock_get_dataset_paths, + mock__setup_output_destination, + mock_load_sdv_demo_dataset, + mock_get_dataset_bucket_mapping, +): + """Test that each dataset is loaded once even when there are multiple synthesizers.""" + # Setup + mock_get_dataset_paths.return_value = [] + mock_get_dataset_bucket_mapping.return_value = {'datasetA': 'bucket-a', 'datasetB': 'bucket-b'} + mock__setup_output_destination.return_value = {} + data_a = Mock(name='data_a') + metadata_a = Mock(name='metadata_a') + data_b = Mock(name='data_b') + metadata_b = Mock(name='metadata_b') + mock_load_sdv_demo_dataset.side_effect = [(data_a, metadata_a), (data_b, metadata_b)] + synthesizers = [ + {'name': 'GaussianCopulaSynthesizer'}, + {'name': 'UniformSynthesizer'}, + ] + s3_client = Mock() + + # Run + job_args_list = _generate_job_args_list( + limit_dataset_size=True, + sdv_datasets=['datasetA', 'datasetB'], + additional_datasets_folder=None, + sdmetrics=None, + timeout=None, + output_destination=None, + compute_quality_score=False, + compute_diagnostic_score=False, + compute_privacy_score=False, + synthesizers=synthesizers, + s3_client=s3_client, + modality='single_table', + ) + + # Assert + mock_get_dataset_paths.assert_not_called() + mock_get_dataset_bucket_mapping.assert_called_once_with( + 'single_table', + ['s3://sdv-datasets-public', 's3://sdv-datasets-private'], + s3_client, + skip_inaccessible=True, + ) + mock__setup_output_destination.assert_called_once_with( + None, + ['GaussianCopulaSynthesizer', 'UniformSynthesizer'], + ['datasetA', 'datasetB'], + modality='single_table', + s3_client=s3_client, + ) + mock_load_sdv_demo_dataset.assert_has_calls([ + call( + modality='single_table', + dataset_name='datasetA', + bucket='bucket-a', + s3_client=s3_client, + limit_dataset_size=True, + ), + call( + modality='single_table', + dataset_name='datasetB', + bucket='bucket-b', + s3_client=s3_client, + limit_dataset_size=True, + ), + ]) + assert mock_load_sdv_demo_dataset.call_count == 2 + assert len(job_args_list) == 4 + assert [job.dataset_name for job in job_args_list] == [ + 'datasetA', + 'datasetA', + 'datasetB', + 'datasetB', + ] + assert [job.data for job in job_args_list] == [data_a, data_a, data_b, data_b] + assert [job.metadata for job in job_args_list] == [ + metadata_a, + metadata_a, + metadata_b, + metadata_b, + ] @patch('sdgym.benchmark.get_dataset_paths') @@ -1112,7 +1365,10 @@ def test__generate_job_args_list_s3_root_additional_folder( get_dataset_paths_mock.return_value = [dataset_path] s3_client = Mock() mock__setup_output_destination.return_value = {} - mock_load_dataset.return_value = pd.DataFrame({'col1': [1, 2], 'col2': [3, 4]}) + mock_load_dataset.return_value = ( + pd.DataFrame({'col1': [1, 2], 'col2': [3, 4]}), + {'tables': {}}, + ) # Run _generate_job_args_list( @@ -1148,9 +1404,27 @@ def test__generate_job_args_list_s3_root_additional_folder( ) -def test_benchmark_single_table_no_warning_uniform_synthesizer(recwarn): +@patch('sdgym.benchmark._get_dataset_bucket_mapping') +@patch('sdgym.benchmark._load_sdv_demo_dataset') +def test_benchmark_single_table_no_warning_uniform_synthesizer( + mock_load_sdv_demo_dataset, mock_get_dataset_bucket_mapping, recwarn +): """Test that no UserWarning is raised when running `UniformSynthesizer`.""" # Setup + data = pd.DataFrame({'column': [1, 2, 3]}) + metadata = { + 'tables': { + 'fake_hotel_guests': { + 'columns': { + 'column': { + 'sdtype': 'numerical', + } + } + } + } + } + mock_get_dataset_bucket_mapping.return_value = {'fake_hotel_guests': 'bucket'} + mock_load_sdv_demo_dataset.return_value = data, metadata expected_result = pd.DataFrame({ 'Synthesizer': {0: 'UniformSynthesizer'}, 'Dataset': {0: 'fake_hotel_guests'}, @@ -1172,6 +1446,19 @@ def test_benchmark_single_table_no_warning_uniform_synthesizer(recwarn): ) # Assert + mock_get_dataset_bucket_mapping.assert_called_once_with( + 'single_table', + ['s3://sdv-datasets-public', 's3://sdv-datasets-private'], + None, + skip_inaccessible=True, + ) + mock_load_sdv_demo_dataset.assert_called_once_with( + modality='single_table', + dataset_name='fake_hotel_guests', + bucket='bucket', + s3_client=None, + limit_dataset_size=False, + ) warnings_text = ' '.join(str(w.message) for w in recwarn) assert 'is incompatible with transformer' not in warnings_text pd.testing.assert_frame_equal(result[expected_result.columns], expected_result) diff --git a/tests/unit/test_datasets.py b/tests/unit/test_datasets.py index f03c7d00..e38efbb7 100644 --- a/tests/unit/test_datasets.py +++ b/tests/unit/test_datasets.py @@ -1,16 +1,22 @@ from pathlib import Path from unittest.mock import Mock, call, patch +import botocore import numpy as np +import pandas as pd import pytest from sdgym.datasets import ( DATASETS_PATH, + SDV_DATASETS_PRIVATE_BUCKET, + SDV_DATASETS_PUBLIC_BUCKET, _download_dataset, _genereate_dataset_info, _get_bucket_name, + _get_dataset_bucket_mapping, _get_dataset_path_and_download, _load_dataset_with_client, + _load_sdv_demo_dataset, _path_contains_data_and_metadata, _validate_modality, get_data_and_metadata_from_path, @@ -362,6 +368,168 @@ def test_get_bucket_name_local_folder(): assert bucket_name == 'bucket-name' +@patch('sdgym.datasets._get_available_datasets') +def test__get_dataset_bucket_mapping_prefers_private(get_available_mock): + """Test that datasets are mapped to private when duplicated across buckets.""" + # Setup + get_available_mock.side_effect = [ + pd.DataFrame({'dataset_name': ['public_only', 'duplicate']}), + pd.DataFrame({'dataset_name': ['private_only', 'duplicate']}), + ] + + # Run + result = _get_dataset_bucket_mapping( + 'single_table', + [SDV_DATASETS_PUBLIC_BUCKET, SDV_DATASETS_PRIVATE_BUCKET], + s3_client='s3_client', + ) + + # Assert + assert result == { + 'public_only': SDV_DATASETS_PUBLIC_BUCKET, + 'private_only': SDV_DATASETS_PRIVATE_BUCKET, + 'duplicate': SDV_DATASETS_PRIVATE_BUCKET, + } + get_available_mock.assert_has_calls([ + call('single_table', bucket=SDV_DATASETS_PUBLIC_BUCKET, s3_client='s3_client'), + call('single_table', bucket=SDV_DATASETS_PRIVATE_BUCKET, s3_client='s3_client'), + ]) + + +@patch('sdgym.datasets._get_available_datasets') +def test__get_dataset_bucket_mapping_skips_inaccessible_bucket(get_available_mock): + """Test inaccessible buckets can be skipped while building the mapping.""" + # Setup + error = botocore.exceptions.ClientError( + {'Error': {'Code': 'AccessDenied', 'Message': 'denied'}}, + 'ListObjectsV2', + ) + get_available_mock.side_effect = [ + pd.DataFrame({'dataset_name': ['public_only']}), + error, + ] + + # Run + result = _get_dataset_bucket_mapping( + 'single_table', + [SDV_DATASETS_PUBLIC_BUCKET, SDV_DATASETS_PRIVATE_BUCKET], + s3_client='s3_client', + skip_inaccessible=True, + ) + + # Assert + assert result == {'public_only': SDV_DATASETS_PUBLIC_BUCKET} + + +@patch('sdgym.datasets._get_available_datasets') +def test__get_dataset_bucket_mapping_raises_inaccessible_bucket(get_available_mock): + """Test inaccessible buckets raise by default.""" + # Setup + get_available_mock.side_effect = botocore.exceptions.ClientError( + {'Error': {'Code': 'AccessDenied', 'Message': 'denied'}}, + 'ListObjectsV2', + ) + + # Run and Assert + with pytest.raises(ValueError, match="Bucket 's3://sdv-datasets-private' is not accessible"): + _get_dataset_bucket_mapping( + 'single_table', + [SDV_DATASETS_PRIVATE_BUCKET], + s3_client='s3_client', + ) + + +@patch('sdgym.datasets.download_demo') +def test__load_sdv_demo_dataset_uses_download_demo(download_demo_mock): + """Test SDV demo datasets are loaded through SDV's download_demo.""" + # Setup + data = pd.DataFrame({'column': [1, 2]}) + metadata = Mock() + metadata.to_dict.return_value = {'tables': {'demo': {'columns': {'column': {}}}}} + download_demo_mock.return_value = data, metadata + + # Run + result = _load_sdv_demo_dataset( + modality='single_table', + dataset_name='demo', + bucket=SDV_DATASETS_PUBLIC_BUCKET, + ) + + # Assert + result_data, result_metadata = result + pd.testing.assert_frame_equal(result_data, data) + assert result_metadata == metadata.to_dict.return_value + download_demo_mock.assert_called_once_with( + modality='single_table', + dataset_name='demo', + s3_bucket_name='sdv-datasets-public', + ) + + +@patch('sdgym.datasets._load_private_sdv_demo_dataset') +@patch('sdgym.datasets.download_demo') +def test__load_sdv_demo_dataset_falls_back_for_private_bucket( + download_demo_mock, load_private_mock +): + """Test SDV private-bucket errors fall back to SDGym private loading.""" + # Setup + data = pd.DataFrame({'column': [1, 2]}) + metadata = {'tables': {'demo': {'columns': {'column': {}}}}} + download_demo_mock.side_effect = ValueError('Private buckets are only supported') + load_private_mock.return_value = data, metadata + + # Run + result = _load_sdv_demo_dataset( + modality='single_table', + dataset_name='demo', + bucket=SDV_DATASETS_PRIVATE_BUCKET, + s3_client='s3_client', + ) + + # Assert + result_data, result_metadata = result + pd.testing.assert_frame_equal(result_data, data) + assert result_metadata == metadata + load_private_mock.assert_called_once_with( + 'single_table', + 'demo', + SDV_DATASETS_PRIVATE_BUCKET, + s3_client='s3_client', + ) + + +@patch('sdgym.datasets._get_dataset_subset') +@patch('sdgym.datasets.download_demo') +def test__load_sdv_demo_dataset_limits_dataset_size(download_demo_mock, subset_mock): + """Test SDV demo dataset loading applies the dataset size limit.""" + # Setup + data = pd.DataFrame({'column': [1, 2]}) + metadata = Mock() + metadata.to_dict.return_value = {'tables': {'demo': {'columns': {'column': {}}}}} + limited_data = pd.DataFrame({'column': [1]}) + limited_metadata = {'tables': {'demo': {'columns': {'column': {}}}}} + download_demo_mock.return_value = data, metadata + subset_mock.return_value = limited_data, limited_metadata + + # Run + result = _load_sdv_demo_dataset( + modality='single_table', + dataset_name='demo', + bucket=SDV_DATASETS_PUBLIC_BUCKET, + limit_dataset_size=True, + ) + + # Assert + result_data, result_metadata = result + pd.testing.assert_frame_equal(result_data, limited_data) + assert result_metadata == limited_metadata + subset_mock.assert_called_once_with( + data, + metadata.to_dict.return_value, + modality='single_table', + ) + + @patch('sdgym.datasets._get_dataset_path_and_download') @patch('sdgym.datasets._path_contains_data_and_metadata', return_value=True) @patch('sdgym.datasets.Path') diff --git a/tests/unit/test_s3.py b/tests/unit/test_s3.py index 00bb7d8d..3dd0a20f 100644 --- a/tests/unit/test_s3.py +++ b/tests/unit/test_s3.py @@ -320,11 +320,14 @@ def test__get_s3_client_with_credentials(mock_boto_client): mock_s3_client.head_bucket.assert_called_once_with(Bucket='my-bucket') -def test__get_s3_client_errors(): +@patch('sdgym.s3.boto3.client') +def test__get_s3_client_errors(mock_boto_client): """Test `_get_s3_client` raises error for invalid input.""" # Setup output_destination = 's3:/' expected_error = re.escape(f'Invalid S3 URL: {output_destination}') + mock_s3_client = mock_boto_client.return_value + mock_s3_client.head_bucket.side_effect = NoCredentialsError() # Run and Assert with pytest.raises(ValueError, match=expected_error): @@ -333,6 +336,9 @@ def test__get_s3_client_errors(): with pytest.raises(NoCredentialsError, match='Unable to locate credentials'): _get_s3_client('s3://bucket_name/') + mock_boto_client.assert_called_once_with('s3') + mock_s3_client.head_bucket.assert_called_once_with(Bucket='bucket_name') + def test__read_data_from_bucket_key_reads_body(): """Test that the function reads data from S3 object body."""