diff --git a/README.md b/README.md index 83bd2d0..ec3f936 100644 --- a/README.md +++ b/README.md @@ -96,6 +96,80 @@ You can also turn off a specific encoder after training has completed, for testi > all the embeddings of the previous tokens (on the word level). This is done since not all causal LMs support > cross-attention, and so we want to avoid using it, and rely on the self-attention mechanism instead. +## Data Preparation + +You can prepare datasets offline using the `welt-prepare-data` CLI. +It streams a HuggingFace dataset, samples raw text with unit-based limits, and writes sharded `.jsonl.gz` files: + +```shell +welt-prepare-data \ + --dataset_name HuggingFaceFW/fineweb \ + --dataset_config sample-10BT \ + --train_split_units 3200000000 \ + --validation_split_units 100000000 \ + --num_units_per_file 100000000 \ + --max_seq_length 512 \ + --seed 42 \ + --output_path /scratch/data/pretrain +``` + +Multiple datasets can be prepared into the same output directory: + +```shell +welt-prepare-data \ + --dataset_name monology/pile-uncopyrighted \ + --train_split_units 1000000000 \ + --validation_split_units 50000000 \ + --num_units_per_file 100000000 \ + --max_seq_length 512 \ + --output_path /scratch/data/pretrain +``` + +The output directory contains sharded `.jsonl.gz` files and a `{prefix}-metadata.json` per dataset: + +``` +/scratch/data/pretrain/ +├── fineweb-sample-10BT-00000000.jsonl.gz +├── fineweb-sample-10BT-00000001.jsonl.gz +├── fineweb-sample-10BT-metadata.json +├── pile-uncopyrighted-00000000.jsonl.gz +└── pile-uncopyrighted-metadata.json +``` + +Then train using the prepared data: + +```shell +welt-train config.yaml --prepared_data_path /scratch/data/pretrain +``` + +| Argument | Description | +|----------|-------------| +| `--dataset_name` | HuggingFace dataset identifier (required) | +| `--dataset_config` | Dataset config name (optional) | +| `--dataset_split` | Split to use (default: "train") | +| `--text_column` | Column containing text (default: "text") | +| `--text_template` | Python format string template (optional) | +| `--language` | Language tag to store with each example (e.g., "eng_Latn") | +| `--unit_type` | Unit type for counting: "words" or "chars" (default: "words") | +| `--train_split_units` | Number of units for the train split (default: 0, no train shards) | +| `--validation_split_units` | Number of units for the validation split (default: 0, no validation shards) | +| `--num_units_per_file` | Max units per shard file (optional) | +| `--max_seq_length` | Max words per example; splits long documents using word segmentation | +| `--max_bytes_per_word` | Max UTF-8 bytes per word; should match training config `max_word_length - 2` (default: 126) | +| `--seed` | Random seed for shuffling | +| `--drop_remainder` | Drop partial chunks at document boundaries | +| `--output_path` | Output directory path (required) | + +### Verifying Prepared Data + +After preparing data, verify integrity with `welt-verify-data`: + +```shell +welt-verify-data --data_path /scratch/data/pretrain +``` + +This checks that shard counts and example counts match the metadata, and warns if train/validation splits from the same source were created separately (risking data contamination). + ## Training Training instructions are available in the [welt_training/README.md](./welt_training/README.md). diff --git a/pyproject.toml b/pyproject.toml index 037842a..92a70a2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,7 @@ train = [ "scikit-learn", # For "accuracy" metric in evaluate "sacrebleu", # For usual bleu/chrF metrics "wandb", # For experiment tracking + "zstandard", # For compressing the data ] dion = [ @@ -79,3 +80,5 @@ testpaths = [ [project.scripts] welt-train = "welt_training.train:train" +welt-prepare-data = "welt_training.prepare_data:main" +welt-verify-data = "welt_training.verify_data:main" diff --git a/tests/test_prepare_data.py b/tests/test_prepare_data.py new file mode 100644 index 0000000..8f47385 --- /dev/null +++ b/tests/test_prepare_data.py @@ -0,0 +1,364 @@ +import glob +import gzip +import json +import shutil +import tempfile + +import pytest +from datasets import load_dataset + +from welt_training.data_utils import load_prepared_data +from welt_training.prepare_data import get_shard_prefix, main + +WIKITEXT_DATASET = "wikitext" +WIKITEXT_CONFIG = "wikitext-2-raw-v1" + + +# --- Helpers --- + + +def read_shard_examples(output_dir, pattern="*.jsonl.gz"): + """Read all examples from shards matching *pattern* in *output_dir*.""" + examples = [] + for path in sorted(glob.glob(f"{output_dir}/{pattern}")): + with gzip.open(path, "rt") as f: + for line in f: + examples.append(json.loads(line)) + return examples + + +def read_metadata(output_dir, split_name, dataset_name=WIKITEXT_DATASET, dataset_config=WIKITEXT_CONFIG): + """Load the per-split metadata JSON produced by prepare_data.""" + prefix = get_shard_prefix(dataset_name, dataset_config) + with open(f"{output_dir}/{prefix}-{split_name}-metadata.json") as f: + return json.load(f) + + +def shard_paths(output_dir, pattern="*.jsonl.gz"): + """Return sorted list of shard file paths matching *pattern*.""" + return sorted(glob.glob(f"{output_dir}/{pattern}")) + + +# --- get_shard_prefix --- + + +def test_get_shard_prefix_with_org(): + assert get_shard_prefix("HuggingFaceFW/fineweb", None) == "fineweb" + + +def test_get_shard_prefix_with_config(): + assert get_shard_prefix("HuggingFaceFW/fineweb", "sample-10BT") == "fineweb-sample-10BT" + + +def test_get_shard_prefix_no_org(): + assert get_shard_prefix("wikitext", "wikitext-2-raw-v1") == "wikitext-wikitext-2-raw-v1" + + +# --- Integration tests --- + + +@pytest.fixture +def temp_output_dir(): + temp_dir = tempfile.mkdtemp(prefix="test_prepare_data_") + yield temp_dir + shutil.rmtree(temp_dir, ignore_errors=True) + + +def test_prepare_data_creates_shards(temp_output_dir, monkeypatch): + """Test that welt-prepare-data creates sharded .jsonl.gz files that can be loaded.""" + monkeypatch.setattr( + "sys.argv", + [ + "welt-prepare-data", + "--dataset_name", WIKITEXT_DATASET, + "--dataset_config", WIKITEXT_CONFIG, + "--train_split_units", "400", + "--validation_split_units", "100", + "--num_units_per_file", "200", + "--max_seq_length", "1024", + "--language", "eng_Latn", + "--seed", "42", + "--output_path", temp_output_dir, + ], + ) + main() + + shards = shard_paths(temp_output_dir) + assert len(shards) >= 2, f"Expected at least 2 shards, got {len(shards)}" + + train_meta = read_metadata(temp_output_dir, "train") + val_meta = read_metadata(temp_output_dir, "validation") + assert train_meta["format"] == "welt-preprocessed-v1" + assert train_meta["total_units"] + val_meta["total_units"] <= 500 + assert train_meta["unit_type"] == "words" + assert train_meta["num_shards"] + val_meta["num_shards"] == len(shards) + + examples = read_shard_examples(temp_output_dir) + for example in examples: + assert "text" in example + assert isinstance(example["text"], str) + assert len(examples) == train_meta["num_examples"] + val_meta["num_examples"] + + # Verify loading with HuggingFace datasets (same path as train.py) + ds = load_dataset("json", data_files=shards, split="train") + assert len(ds) == len(examples) + assert "text" in ds.features + + +def test_prepare_data_with_language(temp_output_dir, monkeypatch): + """Test that --language stores language metadata in each example.""" + monkeypatch.setattr( + "sys.argv", + [ + "welt-prepare-data", + "--dataset_name", WIKITEXT_DATASET, + "--dataset_config", WIKITEXT_CONFIG, + "--train_split_units", "160", + "--validation_split_units", "40", + "--max_seq_length", "1024", + "--language", "eng_Latn", + "--seed", "42", + "--output_path", temp_output_dir, + ], + ) + main() + + for example in read_shard_examples(temp_output_dir): + assert example["language"] == "eng_Latn" + + assert read_metadata(temp_output_dir, "train")["language"] == "eng_Latn" + assert read_metadata(temp_output_dir, "validation")["language"] == "eng_Latn" + + +def test_prepare_data_unit_type_chars(temp_output_dir, monkeypatch): + """Test that --unit_type chars counts characters instead of words.""" + monkeypatch.setattr( + "sys.argv", + [ + "welt-prepare-data", + "--dataset_name", WIKITEXT_DATASET, + "--dataset_config", WIKITEXT_CONFIG, + "--train_split_units", "400", + "--validation_split_units", "100", + "--unit_type", "chars", + "--max_seq_length", "1024", + "--language", "eng_Latn", + "--seed", "42", + "--output_path", temp_output_dir, + ], + ) + main() + + meta = read_metadata(temp_output_dir, "validation") + assert meta["unit_type"] == "chars" + assert meta["total_units"] <= 500 + + +def test_prepare_data_with_max_seq_length(temp_output_dir, monkeypatch): + """Test that --max_seq_length splits long documents into chunks.""" + monkeypatch.setattr( + "sys.argv", + [ + "welt-prepare-data", + "--dataset_name", WIKITEXT_DATASET, + "--dataset_config", WIKITEXT_CONFIG, + "--train_split_units", "400", + "--validation_split_units", "100", + "--max_seq_length", "32", + "--language", "eng_Latn", + "--seed", "42", + "--output_path", temp_output_dir, + ], + ) + main() + + train_meta = read_metadata(temp_output_dir, "train") + val_meta = read_metadata(temp_output_dir, "validation") + assert train_meta["max_seq_length"] == 32 + assert train_meta["total_units"] + val_meta["total_units"] <= 500 + + # Verify each example has at most max_seq_length words + from words_segmentation.tokenizer import WordsSegmentationTokenizer + pretokenizer = WordsSegmentationTokenizer(max_bytes=126) + + for example in read_shard_examples(temp_output_dir): + words = pretokenizer.tokenize(example["text"]) + assert len(words) <= 32 + + +def test_prepare_data_with_validation_split(temp_output_dir, monkeypatch): + """Test that --validation_split_units creates split-aware shards.""" + monkeypatch.setattr( + "sys.argv", + [ + "welt-prepare-data", + "--dataset_name", WIKITEXT_DATASET, + "--dataset_config", WIKITEXT_CONFIG, + "--train_split_units", "800", + "--validation_split_units", "200", + "--max_seq_length", "1024", + "--language", "eng_Latn", + "--seed", "42", + "--output_path", temp_output_dir, + ], + ) + main() + + prefix = get_shard_prefix(WIKITEXT_DATASET, WIKITEXT_CONFIG) + + # Verify split-aware shard files were created + train_files = shard_paths(temp_output_dir, f"{prefix}-train-*.jsonl.gz") + val_files = shard_paths(temp_output_dir, f"{prefix}-validation-*.jsonl.gz") + assert len(train_files) >= 1, f"Expected at least 1 train shard, got {len(train_files)}" + assert len(val_files) >= 1, f"Expected at least 1 validation shard, got {len(val_files)}" + + # No legacy (unsplit) shards should exist + all_shards = shard_paths(temp_output_dir) + assert len(all_shards) == len(train_files) + len(val_files) + + # Count examples per split + train_examples = read_shard_examples(temp_output_dir, f"{prefix}-train-*.jsonl.gz") + val_examples = read_shard_examples(temp_output_dir, f"{prefix}-validation-*.jsonl.gz") + for example in train_examples + val_examples: + assert "text" in example + + total_examples = len(train_examples) + len(val_examples) + assert total_examples > 0 + + # Verify validation fraction is roughly correct (20% +/- tolerance) + val_fraction = len(val_examples) / total_examples + assert 0.05 < val_fraction < 0.45, f"Expected ~20% validation, got {val_fraction:.1%}" + + # Verify per-split metadata + train_meta = read_metadata(temp_output_dir, "train") + val_meta = read_metadata(temp_output_dir, "validation") + assert train_meta["format"] == "welt-preprocessed-v1" + assert train_meta["num_examples"] == len(train_examples) + assert val_meta["num_examples"] == len(val_examples) + assert train_meta["num_shards"] == len(train_files) + assert val_meta["num_shards"] == len(val_files) + + +def test_load_prepared_data_split_aware(temp_output_dir, monkeypatch): + """Test that load_prepared_data detects and loads split-aware shards.""" + monkeypatch.setattr( + "sys.argv", + [ + "welt-prepare-data", + "--dataset_name", WIKITEXT_DATASET, + "--dataset_config", WIKITEXT_CONFIG, + "--train_split_units", "800", + "--validation_split_units", "200", + "--max_seq_length", "1024", + "--language", "eng_Latn", + "--seed", "42", + "--output_path", temp_output_dir, + ], + ) + main() + + # load_prepared_data should detect split-aware files and load them directly + result = load_prepared_data(temp_output_dir) + assert "train" in result + assert "validation" in result + assert len(result["train"]) > 0 + assert len(result["validation"]) > 0 + assert "text" in result["train"].features + assert "text" in result["validation"].features + + # Total should match what was prepared + train_meta = read_metadata(temp_output_dir, "train") + val_meta = read_metadata(temp_output_dir, "validation") + assert len(result["train"]) + len(result["validation"]) == train_meta["num_examples"] + val_meta["num_examples"] + + +def test_load_prepared_data_requires_some_shards(temp_output_dir): + """Test that load_prepared_data raises when no shards exist.""" + with pytest.raises(ValueError, match="No"): + load_prepared_data(temp_output_dir) + + +def test_prepare_data_train_only(temp_output_dir, monkeypatch): + """Test that setting validation_split_units=0 creates only train shards.""" + monkeypatch.setattr( + "sys.argv", + [ + "welt-prepare-data", + "--dataset_name", WIKITEXT_DATASET, + "--dataset_config", WIKITEXT_CONFIG, + "--train_split_units", "400", + "--max_seq_length", "1024", + "--language", "eng_Latn", + "--seed", "42", + "--output_path", temp_output_dir, + ], + ) + main() + + prefix = get_shard_prefix(WIKITEXT_DATASET, WIKITEXT_CONFIG) + train_files = shard_paths(temp_output_dir, f"{prefix}-train-*.jsonl.gz") + val_files = shard_paths(temp_output_dir, f"{prefix}-validation-*.jsonl.gz") + assert len(train_files) >= 1 + assert len(val_files) == 0 + + metadata = read_metadata(temp_output_dir, "train") + assert metadata["num_examples"] > 0 + assert not glob.glob(f"{temp_output_dir}/*-validation-metadata.json") + + result = load_prepared_data(temp_output_dir) + assert "train" in result + assert "validation" not in result + + +def test_prepare_data_validation_only(temp_output_dir, monkeypatch): + """Test that setting train_split_units=0 creates only validation shards.""" + monkeypatch.setattr( + "sys.argv", + [ + "welt-prepare-data", + "--dataset_name", WIKITEXT_DATASET, + "--dataset_config", WIKITEXT_CONFIG, + "--validation_split_units", "200", + "--max_seq_length", "1024", + "--language", "eng_Latn", + "--seed", "42", + "--output_path", temp_output_dir, + ], + ) + main() + + prefix = get_shard_prefix(WIKITEXT_DATASET, WIKITEXT_CONFIG) + train_files = shard_paths(temp_output_dir, f"{prefix}-train-*.jsonl.gz") + val_files = shard_paths(temp_output_dir, f"{prefix}-validation-*.jsonl.gz") + assert len(train_files) == 0 + assert len(val_files) >= 1 + + metadata = read_metadata(temp_output_dir, "validation") + assert metadata["num_examples"] > 0 + assert not glob.glob(f"{temp_output_dir}/*-train-metadata.json") + + result = load_prepared_data(temp_output_dir) + assert "train" not in result + assert "validation" in result + + +def test_prepare_data_with_id_column(temp_output_dir, monkeypatch): + """Test that --id_column preserves the source column as 'id' in output.""" + monkeypatch.setattr( + "sys.argv", + [ + "welt-prepare-data", + "--dataset_name", "HuggingFaceFW/fineweb-2", + "--dataset_config", "tur_Latn", + "--train_split_units", "400", + "--id_column", "id", + "--max_seq_length", "1024", + "--language", "tur_Latn", + "--seed", "42", + "--output_path", temp_output_dir, + ], + ) + main() + + for example in read_shard_examples(temp_output_dir): + assert "id" in example diff --git a/tests/test_verify_data.py b/tests/test_verify_data.py new file mode 100644 index 0000000..fdf18ea --- /dev/null +++ b/tests/test_verify_data.py @@ -0,0 +1,148 @@ +import gzip +import json +import shutil +import tempfile + +import pytest + +from welt_training.verify_data import verify + + +@pytest.fixture +def temp_dir(): + d = tempfile.mkdtemp(prefix="test_verify_data_") + yield d + shutil.rmtree(d, ignore_errors=True) + + +def write_metadata(path, metadata): + with open(path, "w") as f: + json.dump(metadata, f) + + +def write_shard(path, examples): + with gzip.open(path, "wt") as f: + for ex in examples: + f.write(json.dumps(ex) + "\n") + + +def make_metadata(split, num_examples, num_shards, **overrides): + base = { + "format": "welt-preprocessed-v1", + "split": split, + "num_examples": num_examples, + "num_shards": num_shards, + "total_units": num_examples * 10, + "unit_type": "words", + "source_dataset": "test-dataset", + "source_config": "test-config", + "source_split": "train", + "language": "eng_Latn", + "max_seq_length": None, + "seed": 42, + "text_column": "text", + "text_template": None, + "created_with_another_split": False, + } + base.update(overrides) + return base + + +def test_verify_empty_dir(temp_dir): + passed, messages = verify(temp_dir) + assert not passed + assert any("No" in m for m in messages) + + +def test_verify_consistent_single_split(temp_dir): + examples = [{"text": f"example {i}"} for i in range(5)] + write_shard(f"{temp_dir}/ds-train-00000000.jsonl.gz", examples) + write_metadata(f"{temp_dir}/ds-train-metadata.json", make_metadata("train", 5, 1)) + + passed, messages = verify(temp_dir) + assert passed + + +def test_verify_shard_count_mismatch(temp_dir): + examples = [{"text": "hello"}] + write_shard(f"{temp_dir}/ds-train-00000000.jsonl.gz", examples) + write_metadata(f"{temp_dir}/ds-train-metadata.json", make_metadata("train", 1, 2)) + + passed, messages = verify(temp_dir) + assert not passed + assert any("Expected 2 shards" in m for m in messages) + + +def test_verify_example_count_mismatch(temp_dir): + examples = [{"text": "hello"}, {"text": "world"}] + write_shard(f"{temp_dir}/ds-train-00000000.jsonl.gz", examples) + write_metadata(f"{temp_dir}/ds-train-metadata.json", make_metadata("train", 5, 1)) + + passed, messages = verify(temp_dir) + assert not passed + assert any("Expected 5 examples" in m for m in messages) + + +def test_verify_contamination_warning(temp_dir): + """Splits from same source created separately should warn.""" + write_shard(f"{temp_dir}/ds-train-00000000.jsonl.gz", [{"text": "a"}]) + write_shard(f"{temp_dir}/ds-validation-00000000.jsonl.gz", [{"text": "b"}]) + write_metadata(f"{temp_dir}/ds-train-metadata.json", + make_metadata("train", 1, 1, created_with_another_split=False)) + write_metadata(f"{temp_dir}/ds-validation-metadata.json", + make_metadata("validation", 1, 1, created_with_another_split=False)) + + passed, messages = verify(temp_dir) + assert passed # warning, not failure + assert any("WARN" in m and "overlap" in m for m in messages) + + +def test_verify_multiple_datasets(temp_dir): + """Multiple datasets in the same directory should be verified independently.""" + # Dataset A: 3 train examples in 1 shard + write_shard(f"{temp_dir}/ds-a-train-00000000.jsonl.gz", + [{"text": f"a{i}"} for i in range(3)]) + write_metadata(f"{temp_dir}/ds-a-train-metadata.json", + make_metadata("train", 3, 1, source_dataset="dataset-A")) + + # Dataset B: 2 train examples in 1 shard + write_shard(f"{temp_dir}/ds-b-train-00000000.jsonl.gz", + [{"text": f"b{i}"} for i in range(2)]) + write_metadata(f"{temp_dir}/ds-b-train-metadata.json", + make_metadata("train", 2, 1, source_dataset="dataset-B")) + + passed, messages = verify(temp_dir) + assert passed + assert any("ds-a/train" in m for m in messages) + assert any("ds-b/train" in m for m in messages) + + +def test_verify_no_contamination_when_created_together(temp_dir): + """Splits created together should pass contamination check.""" + write_shard(f"{temp_dir}/ds-train-00000000.jsonl.gz", [{"text": "a"}]) + write_shard(f"{temp_dir}/ds-validation-00000000.jsonl.gz", [{"text": "b"}]) + write_metadata(f"{temp_dir}/ds-train-metadata.json", + make_metadata("train", 1, 1, created_with_another_split=True)) + write_metadata(f"{temp_dir}/ds-validation-metadata.json", + make_metadata("validation", 1, 1, created_with_another_split=True)) + + passed, messages = verify(temp_dir) + assert passed + assert any("no overlap" in m for m in messages) + assert any("ds/contamination" in m for m in messages) + + +def test_verify_no_contamination_different_sources(temp_dir): + """Splits from different sources should skip contamination check.""" + write_shard(f"{temp_dir}/ds1-train-00000000.jsonl.gz", [{"text": "a"}]) + write_shard(f"{temp_dir}/ds2-validation-00000000.jsonl.gz", [{"text": "b"}]) + write_metadata(f"{temp_dir}/ds1-train-metadata.json", + make_metadata("train", 1, 1, source_dataset="dataset-A")) + write_metadata(f"{temp_dir}/ds2-validation-metadata.json", + make_metadata("validation", 1, 1, source_dataset="dataset-B")) + + passed, messages = verify(temp_dir) + assert passed + # No contamination message since sources differ + assert not any("overlap" in m for m in messages) + assert not any("contamination" in m for m in messages) diff --git a/welt_training/args_data.py b/welt_training/args_data.py index b9bc069..edad1c0 100644 --- a/welt_training/args_data.py +++ b/welt_training/args_data.py @@ -96,6 +96,12 @@ class DataTrainingArguments: keep_linebreaks: bool = field( default=True, metadata={"help": "Whether to keep line breaks when using TXT files or not."} ) + prepared_data_path: str | None = field( + default=None, + metadata={ + "help": "Path to prepared dataset shards (from welt-prepare-data). Skips download and text extraction." + }, + ) def __post_init__(self): if self.streaming: @@ -117,8 +123,13 @@ def __post_init__(self): ) raise ValueError(msg) - if self.dataset_name is None and self.train_file is None and self.validation_file is None: - raise ValueError("Need either a dataset name or a training/validation file.") + if ( + self.dataset_name is None + and self.train_file is None + and self.validation_file is None + and self.prepared_data_path is None + ): + raise ValueError("Need either a dataset name, a training/validation file, or a prepared data path.") else: if self.train_file is not None: extension = self.train_file.split(".")[-1] diff --git a/welt_training/data_utils.py b/welt_training/data_utils.py new file mode 100644 index 0000000..dbbdfd7 --- /dev/null +++ b/welt_training/data_utils.py @@ -0,0 +1,61 @@ +import glob +import logging +import os + +from datasets import load_dataset + +logger = logging.getLogger(__name__) + + +def extract_text(example: dict, text_column: str = "text", text_template: str | None = None) -> str: + """Extract text from a dataset example using a column name or format template.""" + if text_template is not None: + return text_template.format(**example) + return example[text_column] + + +def find_shard_files(data_path: str, split_name: str, prefix: str | None = None) -> list[str]: + """Return sorted shard files for a given split in a prepared data directory. + + When *prefix* is given, only shards for that specific dataset are matched. + Without it, shards from all datasets in the directory are returned (used by + :func:`load_prepared_data` to load multi-dataset mixtures). + """ + name = f"{prefix}-{split_name}" if prefix else f"*-{split_name}" + return sorted(glob.glob(os.path.join(data_path, f"{name}-*.jsonl.gz"))) + + +def load_prepared_data(prepared_data_path: str): + """Load preprocessed shards produced by prepare_data.py. + + Loads ``{prefix}-train-*.jsonl.gz`` shards as the train split and/or + ``{prefix}-validation-*.jsonl.gz`` shards as the validation split. + + At least one split must be present. Missing splits are omitted from the result. + + Args: + prepared_data_path: Directory containing ``*.jsonl.gz`` shard files. + + Returns: + A dict with ``"train"`` and/or ``"validation"`` datasets. + """ + train_files = find_shard_files(prepared_data_path, "train") + validation_files = find_shard_files(prepared_data_path, "validation") + + if not train_files and not validation_files: + raise ValueError( + f"No *-train-*.jsonl.gz or *-validation-*.jsonl.gz files found in {prepared_data_path}. " + "Prepare data with --train_split_units and/or --validation_split_units." + ) + + result = {} + if train_files: + result["train"] = load_dataset("json", data_files=train_files, split="train") + if validation_files: + result["validation"] = load_dataset("json", data_files=validation_files, split="train") + + logger.info( + f"Loading prepared data: {len(train_files)} train shard(s), " + f"{len(validation_files)} validation shard(s) from {prepared_data_path}" + ) + return result diff --git a/welt_training/experiments/machine-translation/run_clm.py b/welt_training/experiments/machine-translation/run_clm.py index 38cd19e..3cbd407 100644 --- a/welt_training/experiments/machine-translation/run_clm.py +++ b/welt_training/experiments/machine-translation/run_clm.py @@ -48,7 +48,7 @@ import datasets import evaluate import torch -from datasets import IterableDataset, IterableDatasetDict, load_dataset +from datasets import DatasetDict, IterableDataset, IterableDatasetDict, load_dataset import transformers from transformers import ( @@ -69,6 +69,8 @@ from transformers.utils import check_min_version, send_example_telemetry from transformers.utils.versions import require_version +from welt_training.data_utils import load_prepared_data + # Will error if the minimal version of Transformers is not installed. Remove at your own risks. # check_min_version("4.57.0.dev0") @@ -245,13 +247,19 @@ class DataTrainingArguments: ) }, ) + prepared_data_path: Optional[str] = field( + default=None, + metadata={ + "help": "Path to prepared dataset shards (from welt-prepare-data). Skips download and text extraction." + }, + ) def __post_init__(self): if self.streaming: require_version("datasets>=2.0.0", "The streaming feature requires `datasets>=2.0.0`") - if self.dataset_name is None and self.train_file is None and self.validation_file is None: - raise ValueError("Need either a dataset name or a training/validation file.") + if self.dataset_name is None and self.train_file is None and self.validation_file is None and self.prepared_data_path is None: + raise ValueError("Need either a dataset name, a training/validation file, or a prepared data path.") else: if self.train_file is not None: extension = self.train_file.split(".")[-1] @@ -310,6 +318,8 @@ def main(): # If we pass only one argument to the script and it's the path to a json file, # let's parse it to get our arguments. model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) + elif len(sys.argv) == 2 and sys.argv[1].endswith((".yaml", ".yml")): + model_args, data_args, training_args = parser.parse_yaml_file(yaml_file=os.path.abspath(sys.argv[1])) else: model_args, data_args, training_args = parser.parse_args_into_dataclasses() @@ -369,7 +379,12 @@ def main(): # # In distributed training, the load_dataset function guarantee that only one local process can concurrently # download the dataset. - if data_args.dataset_name is not None: + if data_args.prepared_data_path is not None: + if data_args.validation_split_percentage is not None: + logger.warning("Ignoring validation_split_percentage because prepared_data_path is set.") + raw_datasets = DatasetDict(load_prepared_data(data_args.prepared_data_path)) + + elif data_args.dataset_name is not None: # Downloading and loading a dataset from the hub. raw_datasets = load_dataset( data_args.dataset_name, @@ -553,6 +568,7 @@ def mapping_function(x): desc="Keep only the text column & apply template", **map_args ) + column_names = [text_column_name] def tokenize_function(examples): diff --git a/welt_training/prepare_data.py b/welt_training/prepare_data.py new file mode 100644 index 0000000..d03ceca --- /dev/null +++ b/welt_training/prepare_data.py @@ -0,0 +1,358 @@ +""" +Data preparation script for offline use. + +Downloads HuggingFace datasets, samples raw text with unit-based limits, +and saves sharded .jsonl.gz files for offline training. +""" + +import argparse +import gzip +import json +import logging +from pathlib import Path + +from datasets import load_dataset +from words_segmentation.tokenizer import WordsSegmentationTokenizer + +from welt_training.data_utils import extract_text + +logger = logging.getLogger(__name__) + + +def get_shard_prefix(dataset_name: str, dataset_config: str | None) -> str: + """Derive a shard filename prefix from dataset name and config.""" + name = dataset_name.split("/")[-1] + if dataset_config: + return f"{name}-{dataset_config}" + return name + + +class ShardWriter: + """Manages writing sharded .jsonl.gz files for a single data split. + + Files are opened lazily on first write, so creating a writer for + a split that receives no data produces no files on disk. + """ + + def __init__(self, output_path: Path, prefix: str, split_name: str | None, num_units_per_file: int | None): + self.output_path = output_path + self.prefix = prefix + self.split_name = split_name + self.num_units_per_file = num_units_per_file + self._shard_index = 0 + self._shard_units = 0 + self._num_shards = 0 + self._current_file = None + self.total_units = 0 + self.num_examples = 0 + + def _shard_path(self, index: int) -> Path: + if self.split_name: + return self.output_path / f"{self.prefix}-{self.split_name}-{index:08d}.jsonl.gz" + return self.output_path / f"{self.prefix}-{index:08d}.jsonl.gz" + + def _open_shard(self): + path = self._shard_path(self._shard_index) + logger.info(f"Writing shard: {path.name}") + self._current_file = gzip.open(path, "wt") + + def write(self, record: dict, text_units: int): + if self._current_file is None: + self._open_shard() + self._current_file.write(json.dumps(record, ensure_ascii=False) + "\n") + self._shard_units += text_units + self.total_units += text_units + self.num_examples += 1 + + if self.num_units_per_file is not None and self._shard_units >= self.num_units_per_file: + self._current_file.close() + logger.info(f"Completed shard {self._shard_index} ({self._shard_units} units)") + self._num_shards += 1 + self._shard_index += 1 + self._shard_units = 0 + self._current_file = None + + def close(self): + if self._current_file is not None: + self._current_file.close() + self._current_file = None + self._num_shards += 1 + + def __enter__(self): + return self + + def __exit__(self, *exc): + self.close() + + @property + def num_shards(self) -> int: + return self._num_shards + + +def stream_texts(args): + """Stream raw text examples from a HuggingFace dataset.""" + load_kwargs = { + "path": args.dataset_name, + "split": args.dataset_split, + "streaming": True, + } + if args.dataset_config: + load_kwargs["name"] = args.dataset_config + + logger.info(f"Loading dataset: {args.dataset_name} (config: {args.dataset_config}, split: {args.dataset_split})") + stream = load_dataset(**load_kwargs) + + # Shuffle with seed + stream = stream.shuffle(seed=args.seed, buffer_size=args.shuffle_buffer_size) + + for example in stream: + text = extract_text(example, text_column=args.text_column, text_template=args.text_template) + if text: + id_value = example.get(args.id_column) if args.id_column else None + yield text, id_value + + +def stream_examples(args, pretokenizer: WordsSegmentationTokenizer): + """Stream text examples, optionally chunked by max_seq_length. + + Uses word segmentation to split text into words (handles Thai and other + languages without whitespace). When max_seq_length is set, long documents + are split into chunks of at most max_seq_length words. + + Yields (text, unit_count) tuples where unit_count matches args.unit_type. + When unit_type is "chars", characters are counted via token lengths + (whitespace excluded by the tokenizer). + """ + count_chars = args.unit_type == "chars" + + for text, id_value in stream_texts(args): + words = pretokenizer.tokenize(text) + unit_count = sum(len(w) for w in words) if count_chars else len(words) + + if args.max_seq_length is None: + yield text, unit_count, id_value + else: + for i in range(0, len(words), args.max_seq_length): + chunk_words = words[i:i + args.max_seq_length] + if args.drop_remainder and len(chunk_words) < args.max_seq_length: + continue + chunk_text = "".join(chunk_words) + chunk_units = sum(len(w) for w in chunk_words) if count_chars else len(chunk_words) + yield chunk_text, chunk_units, id_value + + +def main(): + parser = argparse.ArgumentParser( + description="Prepare HuggingFace datasets for offline training." + ) + + # Dataset arguments + parser.add_argument( + "--dataset_name", + type=str, + required=True, + help="HuggingFace dataset identifier (required)", + ) + parser.add_argument( + "--dataset_config", + type=str, + default=None, + help="Dataset config name (optional)", + ) + parser.add_argument( + "--dataset_split", + type=str, + default="train", + help="Split to use (default: 'train')", + ) + parser.add_argument( + "--text_column", + type=str, + default="text", + help="Column containing text (default: 'text')", + ) + parser.add_argument( + "--text_template", + type=str, + default=None, + help="Python format string template (optional)", + ) + parser.add_argument( + "--language", + type=str, + required=True, + help="Language tag to store with each example (e.g., 'eng_Latn')", + ) + parser.add_argument( + "--id_column", + type=str, + default=None, + help="Source column to preserve as 'id' in output (optional)", + ) + + # Processing arguments + parser.add_argument( + "--unit_type", + type=str, + choices=["words", "chars"], + default="words", + help="Unit type for counting (default: 'words')", + ) + parser.add_argument( + "--train_split_units", + type=int, + default=0, + help="Number of units for the train split (default: 0, meaning no train shards).", + ) + parser.add_argument( + "--num_units_per_file", + type=int, + default=None, + help="Max units per shard file. If not set, all data goes into one file.", + ) + parser.add_argument( + "--max_seq_length", + type=int, + default=None, + help="Max words per example. Long documents are split using word segmentation.", + ) + parser.add_argument( + "--max_bytes_per_word", + type=int, + default=126, + help="Max UTF-8 bytes per word. Words exceeding this are split. " + "Should match training config: max_word_length - 2 (default: 128 - 2 = 126).", + ) + parser.add_argument( + "--drop_remainder", + action="store_true", + help="Drop partial chunks when splitting documents by max_seq_length", + ) + parser.add_argument( + "--validation_split_units", + type=int, + default=0, + help="Number of units for the validation split (default: 0, meaning no validation shards). " + "Data is already shuffled before splitting.", + ) + parser.add_argument( + "--seed", + type=int, + default=42, + help="Random seed for shuffling", + ) + parser.add_argument( + "--shuffle_buffer_size", + type=int, + default=10000, + help="Buffer size for streaming shuffle (default: 10000)", + ) + + # Output arguments + parser.add_argument( + "--output_path", + type=str, + required=True, + help="Output directory path (required)", + ) + + args = parser.parse_args() + + if args.train_split_units <= 0 and args.validation_split_units <= 0: + parser.error("At least one of --train_split_units or --validation_split_units must be > 0.") + + # Setup logging + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + + # Create output directory + output_path = Path(args.output_path) + output_path.mkdir(parents=True, exist_ok=True) + + prefix = get_shard_prefix(args.dataset_name, args.dataset_config) + pretokenizer = WordsSegmentationTokenizer(max_bytes=args.max_bytes_per_word) + + logger.info("Starting data preparation...") + logger.info(f" unit_type={args.unit_type}, train_split_units={args.train_split_units}, " + f"num_units_per_file={args.num_units_per_file}, max_seq_length={args.max_seq_length}, " + f"language={args.language}") + + # Create shard writers + max_total_units = args.train_split_units + args.validation_split_units + logger.info(f" validation_split_units={args.validation_split_units}, max_total_units={max_total_units}") + with ( + ShardWriter(output_path, prefix, "train", args.num_units_per_file) as train_writer, + ShardWriter(output_path, prefix, "validation", args.num_units_per_file) as val_writer, + ): + total_units = 0 + total_examples = 0 + + for text, text_units, id_value in stream_examples(args, pretokenizer): + + # Check global limit + if total_units + text_units > max_total_units: + logger.info(f"Reached total units limit ({max_total_units})") + break + + record = {"text": text, "language": args.language} + if id_value is not None: + record["id"] = id_value + + # Fill validation first, then route to train + if val_writer.total_units < args.validation_split_units: + val_writer.write(record, text_units) + else: + train_writer.write(record, text_units) + + total_units += text_units + total_examples += 1 + + if total_examples % 10000 == 0: + logger.info(f"Processed {total_examples} examples, {total_units} total {args.unit_type}") + + if total_examples == 0: + logger.warning("No examples were written. Check dataset and filter settings.") + + # Save per-split metadata (both writers are closed, final counts are stable) + base_metadata = { + "format": "welt-preprocessed-v1", + "unit_type": args.unit_type, + "source_dataset": args.dataset_name, + "source_config": args.dataset_config, + "source_split": args.dataset_split, + "language": args.language, + "max_seq_length": args.max_seq_length, + "seed": args.seed, + "text_column": args.text_column, + "text_template": args.text_template, + } + + active_writers = [w for w in [train_writer, val_writer] if w.num_examples > 0] + created_with_another_split = len(active_writers) > 1 + + for writer in active_writers: + metadata = { + **base_metadata, + "split": writer.split_name, + "created_with_another_split": created_with_another_split, + "num_examples": writer.num_examples, + "total_units": writer.total_units, + "num_shards": writer.num_shards, + } + metadata_path = output_path / f"{prefix}-{writer.split_name}-metadata.json" + logger.info(f"Saving metadata to {metadata_path}") + with open(metadata_path, "w") as f: + json.dump(metadata, f, indent=2) + + logger.info("Data preparation complete!") + logger.info(f" - Examples: {total_examples}") + logger.info(f" - Total {args.unit_type}: {total_units}") + logger.info(f" - Shards: {train_writer.num_shards + val_writer.num_shards}") + logger.info(f" - Output: {output_path}") + + +if __name__ == "__main__": + main() diff --git a/welt_training/train.py b/welt_training/train.py index 2cb950d..4cd10e6 100644 --- a/welt_training/train.py +++ b/welt_training/train.py @@ -21,6 +21,7 @@ from welt_training.args_data import DataTrainingArguments from welt_training.args_model import ModelArguments from welt_training.args_trainer import WeLTTrainingArguments +from welt_training.data_utils import extract_text, load_prepared_data from welt_training.extendable_yaml import resolve_yaml_file from welt_training.flops_callback import FlopsCallback from welt_training.freeze_callback import FreezeWarmupCallback @@ -174,6 +175,12 @@ def init_datasets(data_args: DataTrainingArguments, # noqa: C901 # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at # https://huggingface.co/docs/datasets/loading_datasets. + # Load preprocessed data if path provided + if data_args.prepared_data_path is not None: + if data_args.validation_split_percentage is not None: + logger.warning("Ignoring validation_split_percentage because prepared_data_path is set.") + return load_prepared_data(data_args.prepared_data_path) + if data_args.dataset_name is not None: # Downloading and loading a dataset from the hub. raw_datasets = load_dataset( @@ -279,7 +286,7 @@ def process_split(dataset, split_name: str): template = data_args.dataset_text_template if template is None: def mapping_fn(example): - return {"text": example[text_column_name]} + return {"text": extract_text(example, text_column=text_column_name)} else: is_single_text_template = isinstance(template, str) single_text_template = template \ @@ -287,7 +294,7 @@ def mapping_fn(example): def mapping_fn(example): if is_single_text_template or split_name == "train": - return {"text": single_text_template.format(**example)} + return {"text": extract_text(example, text_template=single_text_template)} prefix = template[0].format(**example) completion = template[1].format(**example) @@ -368,6 +375,12 @@ def train(args: list[str] | None | str = None): # noqa: C901 trust_remote_code=model_args.trust_remote_code, do_train=training_args.do_train) + # Drop columns not needed for training (e.g. "language" from prepared data) + for split in list(text_datasets): + extra_cols = [c for c in text_datasets[split].column_names if c != "text"] + if extra_cols: + text_datasets[split] = text_datasets[split].remove_columns(extra_cols) + train_dataset = None if training_args.do_train: if "train" not in text_datasets: diff --git a/welt_training/verify_data.py b/welt_training/verify_data.py new file mode 100644 index 0000000..1d0a5a8 --- /dev/null +++ b/welt_training/verify_data.py @@ -0,0 +1,127 @@ +"""Verify integrity of prepared data directories. + +Checks metadata consistency, shard counts, example counts, +and warns about potential data contamination between splits. +""" + +import argparse +import glob +import gzip +import json +import os +import sys + +from welt_training.data_utils import find_shard_files + + +def discover_metadata(data_path): + """Find all per-split metadata files and return ``[(prefix, metadata)]``. + + Each metadata file is named ``{prefix}-{split}-metadata.json``. The prefix + is derived from the filename and the ``split`` field inside the JSON so that + multiple datasets in one directory are handled correctly. + """ + entries = [] + for path in sorted(glob.glob(os.path.join(data_path, "*-metadata.json"))): + with open(path) as f: + metadata = json.load(f) + split_name = metadata["split"] + filename = os.path.basename(path) + suffix = f"-{split_name}-metadata.json" + prefix = filename[: -len(suffix)] + entries.append((prefix, metadata)) + return entries + + +def count_shard_examples(shard_files): + """Count total examples across a list of .jsonl.gz shard files.""" + count = 0 + for path in shard_files: + with gzip.open(path, "rt") as f: + for _ in f: + count += 1 + return count + + +def verify(data_path): + """Run all verification checks. Returns (passed: bool, messages: list[str]).""" + messages = [] + passed = True + + # 1. Discover metadata (prefix-aware) + entries = discover_metadata(data_path) + if not entries: + messages.append("FAIL: No *-metadata.json files found.") + return False, messages + + prefixes = sorted({p for p, _ in entries}) + splits = sorted({m["split"] for _, m in entries}) + messages.append(f"Found {len(entries)} metadata file(s): {len(prefixes)} dataset(s), splits: {', '.join(splits)}") + + # 2. Per-dataset, per-split shard consistency + for prefix, metadata in entries: + split_name = metadata["split"] + label = f"{prefix}/{split_name}" + shard_files = find_shard_files(data_path, split_name, prefix=prefix) + + expected_shards = metadata["num_shards"] + actual_shards = len(shard_files) + if actual_shards != expected_shards: + messages.append( + f"FAIL [{label}]: Expected {expected_shards} shards, found {actual_shards}." + ) + passed = False + else: + messages.append(f"OK [{label}]: {actual_shards} shard(s)") + + expected_examples = metadata["num_examples"] + actual_examples = count_shard_examples(shard_files) + if actual_examples != expected_examples: + messages.append( + f"FAIL [{label}]: Expected {expected_examples} examples, found {actual_examples}." + ) + passed = False + else: + messages.append(f"OK [{label}]: {actual_examples} example(s)") + + # 3. Data contamination check (per dataset) + by_prefix = {} + for prefix, metadata in entries: + by_prefix.setdefault(prefix, {})[metadata["split"]] = metadata + + for prefix, split_metas in by_prefix.items(): + if "train" not in split_metas or "validation" not in split_metas: + continue + train_meta = split_metas["train"] + val_meta = split_metas["validation"] + same_source = ( + train_meta["source_dataset"] == val_meta["source_dataset"] + and train_meta["source_config"] == val_meta["source_config"] + and train_meta["source_split"] == val_meta["source_split"] + ) + if same_source: + if not train_meta.get("created_with_another_split") or not val_meta.get("created_with_another_split"): + messages.append( + f"WARN [{prefix}]: Train and validation share the same source but were not created together. " + "Examples may overlap." + ) + else: + messages.append(f"OK [{prefix}/contamination]: Splits created together, no overlap risk.") + + return passed, messages + + +def main(): + parser = argparse.ArgumentParser(description="Verify integrity of prepared data.") + parser.add_argument("--data_path", type=str, required=True, help="Path to prepared data directory.") + args = parser.parse_args() + + passed, messages = verify(args.data_path) + for msg in messages: + print(msg) + + sys.exit(0 if passed else 1) + + +if __name__ == "__main__": + main()