From a16808beb3d321ecd074df25c94c783eb607beae Mon Sep 17 00:00:00 2001 From: Ilker Kesen Date: Wed, 11 Feb 2026 01:11:56 +0100 Subject: [PATCH 01/22] implement pretraining data preparation script, which creates a subset of the specified data resource. --- pyproject.toml | 1 + tests/test_prepare_data.py | 164 ++++++++++++++++++++ welt_training/prepare_data.py | 273 ++++++++++++++++++++++++++++++++++ 3 files changed, 438 insertions(+) create mode 100644 tests/test_prepare_data.py create mode 100644 welt_training/prepare_data.py diff --git a/pyproject.toml b/pyproject.toml index 037842a..bc8515e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,7 @@ train = [ "scikit-learn", # For "accuracy" metric in evaluate "sacrebleu", # For usual bleu/chrF metrics "wandb", # For experiment tracking + "zstandard", # For compressing the data ] dion = [ diff --git a/tests/test_prepare_data.py b/tests/test_prepare_data.py new file mode 100644 index 0000000..80f6ac0 --- /dev/null +++ b/tests/test_prepare_data.py @@ -0,0 +1,164 @@ +import glob +import gzip +import json +import shutil +import tempfile + +import pytest +from datasets import load_dataset + +from welt_training.prepare_data import get_shard_prefix, main + + +# --- get_shard_prefix --- + + +def test_get_shard_prefix_with_org(): + assert get_shard_prefix("HuggingFaceFW/fineweb", None) == "fineweb" + + +def test_get_shard_prefix_with_config(): + assert get_shard_prefix("HuggingFaceFW/fineweb", "sample-10BT") == "fineweb-sample-10BT" + + +def test_get_shard_prefix_no_org(): + assert get_shard_prefix("wikitext", "wikitext-2-raw-v1") == "wikitext-wikitext-2-raw-v1" + + +# --- Integration tests --- + + +@pytest.fixture +def temp_output_dir(): + temp_dir = tempfile.mkdtemp(prefix="test_prepare_data_") + yield temp_dir + shutil.rmtree(temp_dir, ignore_errors=True) + + +def test_prepare_data_creates_shards(temp_output_dir, monkeypatch): + """Test that welt-prepare-data creates sharded .jsonl.gz files that can be 
loaded.""" + monkeypatch.setattr( + "sys.argv", + [ + "welt-prepare-data", + "--dataset_name", "wikitext", + "--dataset_config", "wikitext-2-raw-v1", + "--max_total_units", "500", + "--num_units_per_file", "200", + "--seed", "42", + "--output_path", temp_output_dir, + ], + ) + main() + + # Verify shards were created + shard_files = sorted(glob.glob(f"{temp_output_dir}/*.jsonl.gz")) + assert len(shard_files) >= 2, f"Expected at least 2 shards, got {len(shard_files)}" + + # Verify metadata + with open(f"{temp_output_dir}/metadata.json") as f: + metadata = json.load(f) + assert metadata["format"] == "welt-preprocessed-v1" + assert metadata["total_units"] <= 500 + assert metadata["unit_type"] == "words" + assert metadata["num_shards"] == len(shard_files) + + # Verify each shard is valid gzipped JSONL with a "text" field + total_examples = 0 + for path in shard_files: + with gzip.open(path, "rt") as f: + for line in f: + example = json.loads(line) + assert "text" in example + assert isinstance(example["text"], str) + total_examples += 1 + assert total_examples == metadata["num_examples"] + + # Verify loading with HuggingFace datasets (same path as train.py) + ds = load_dataset("json", data_files=shard_files, split="train") + assert len(ds) == total_examples + assert "text" in ds.features + + +def test_prepare_data_with_language(temp_output_dir, monkeypatch): + """Test that --language stores language metadata in each example.""" + monkeypatch.setattr( + "sys.argv", + [ + "welt-prepare-data", + "--dataset_name", "wikitext", + "--dataset_config", "wikitext-2-raw-v1", + "--max_total_units", "200", + "--language", "eng_Latn", + "--seed", "42", + "--output_path", temp_output_dir, + ], + ) + main() + + shard_files = sorted(glob.glob(f"{temp_output_dir}/*.jsonl.gz")) + for path in shard_files: + with gzip.open(path, "rt") as f: + for line in f: + example = json.loads(line) + assert example["language"] == "eng_Latn" + + with open(f"{temp_output_dir}/metadata.json") as f: + 
metadata = json.load(f) + assert metadata["language"] == "eng_Latn" + + +def test_prepare_data_unit_type_chars(temp_output_dir, monkeypatch): + """Test that --unit_type chars counts characters instead of words.""" + monkeypatch.setattr( + "sys.argv", + [ + "welt-prepare-data", + "--dataset_name", "wikitext", + "--dataset_config", "wikitext-2-raw-v1", + "--max_total_units", "500", + "--unit_type", "chars", + "--seed", "42", + "--output_path", temp_output_dir, + ], + ) + main() + + with open(f"{temp_output_dir}/metadata.json") as f: + metadata = json.load(f) + assert metadata["unit_type"] == "chars" + assert metadata["total_units"] <= 500 + + +def test_prepare_data_with_max_seq_length(temp_output_dir, monkeypatch): + """Test that --max_seq_length splits long documents into chunks.""" + monkeypatch.setattr( + "sys.argv", + [ + "welt-prepare-data", + "--dataset_name", "wikitext", + "--dataset_config", "wikitext-2-raw-v1", + "--max_total_units", "500", + "--max_seq_length", "32", + "--seed", "42", + "--output_path", temp_output_dir, + ], + ) + main() + + with open(f"{temp_output_dir}/metadata.json") as f: + metadata = json.load(f) + assert metadata["max_seq_length"] == 32 + assert metadata["total_units"] <= 500 + + # Verify each example has at most max_seq_length words + from words_segmentation.tokenizer import WordsSegmentationTokenizer + pretokenizer = WordsSegmentationTokenizer() + + shard_files = sorted(glob.glob(f"{temp_output_dir}/*.jsonl.gz")) + for path in shard_files: + with gzip.open(path, "rt") as f: + for line in f: + example = json.loads(line) + words = pretokenizer.tokenize(example["text"]) + assert len(words) <= 32 diff --git a/welt_training/prepare_data.py b/welt_training/prepare_data.py new file mode 100644 index 0000000..15490c8 --- /dev/null +++ b/welt_training/prepare_data.py @@ -0,0 +1,273 @@ +""" +Data preparation script for offline use. 
+ +Downloads HuggingFace datasets, samples raw text with unit-based limits, +and saves sharded .jsonl.gz files for offline training. +""" + +import argparse +import gzip +import json +import logging +from pathlib import Path + +from datasets import load_dataset +from words_segmentation.tokenizer import WordsSegmentationTokenizer + +logger = logging.getLogger(__name__) + + +def get_shard_prefix(dataset_name: str, dataset_config: str | None) -> str: + """Derive a shard filename prefix from dataset name and config.""" + name = dataset_name.split("/")[-1] + if dataset_config: + return f"{name}-{dataset_config}" + return name + + +def stream_texts(args): + """Stream raw text examples from a HuggingFace dataset.""" + load_kwargs = { + "path": args.dataset_name, + "split": args.dataset_split, + "streaming": True, + } + if args.dataset_config: + load_kwargs["name"] = args.dataset_config + + logger.info(f"Loading dataset: {args.dataset_name} (config: {args.dataset_config}, split: {args.dataset_split})") + stream = load_dataset(**load_kwargs) + + # Shuffle with seed + stream = stream.shuffle(seed=args.seed, buffer_size=args.shuffle_buffer_size) + + for example in stream: + if args.text_template: + text = args.text_template.format(**example) + else: + text = example[args.text_column] + + if text: + yield text + + +def stream_examples(args, pretokenizer: WordsSegmentationTokenizer): + """Stream text examples, optionally chunked by max_seq_length. + + Uses word segmentation to split text into words (handles Thai and other + languages without whitespace). When max_seq_length is set, long documents + are split into chunks of at most max_seq_length words. + + Yields (text, num_words) tuples. 
+ """ + for text in stream_texts(args): + words = pretokenizer.tokenize(text) + + if args.max_seq_length is None: + yield text, len(words) + continue + + for i in range(0, len(words), args.max_seq_length): + chunk_words = words[i:i + args.max_seq_length] + if args.drop_remainder and len(chunk_words) < args.max_seq_length: + continue + yield "".join(chunk_words), len(chunk_words) + + +def main(): + parser = argparse.ArgumentParser( + description="Prepare HuggingFace datasets for offline training." + ) + + # Dataset arguments + parser.add_argument( + "--dataset_name", + type=str, + required=True, + help="HuggingFace dataset identifier (required)", + ) + parser.add_argument( + "--dataset_config", + type=str, + default=None, + help="Dataset config name (optional)", + ) + parser.add_argument( + "--dataset_split", + type=str, + default="train", + help="Split to use (default: 'train')", + ) + parser.add_argument( + "--text_column", + type=str, + default="text", + help="Column containing text (default: 'text')", + ) + parser.add_argument( + "--text_template", + type=str, + default=None, + help="Python format string template (optional)", + ) + parser.add_argument( + "--language", + type=str, + default=None, + help="Language tag to store with each example (e.g., 'eng_Latn')", + ) + + # Processing arguments + parser.add_argument( + "--unit_type", + type=str, + choices=["words", "chars"], + default="words", + help="Unit type for counting (default: 'words')", + ) + parser.add_argument( + "--max_total_units", + type=int, + default=None, + help="Max total units to sample. If not set, processes the entire dataset.", + ) + parser.add_argument( + "--num_units_per_file", + type=int, + default=None, + help="Max units per shard file. If not set, all data goes into one file.", + ) + parser.add_argument( + "--max_seq_length", + type=int, + default=None, + help="Max words per example. 
Long documents are split using word segmentation.", + ) + parser.add_argument( + "--drop_remainder", + action="store_true", + help="Drop partial chunks when splitting documents by max_seq_length", + ) + parser.add_argument( + "--seed", + type=int, + default=42, + help="Random seed for shuffling", + ) + parser.add_argument( + "--shuffle_buffer_size", + type=int, + default=10000, + help="Buffer size for streaming shuffle (default: 10000)", + ) + + # Output arguments + parser.add_argument( + "--output_path", + type=str, + required=True, + help="Output directory path (required)", + ) + + args = parser.parse_args() + + # Setup logging + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + + # Create output directory + output_path = Path(args.output_path) + output_path.mkdir(parents=True, exist_ok=True) + + prefix = get_shard_prefix(args.dataset_name, args.dataset_config) + pretokenizer = WordsSegmentationTokenizer() + + logger.info("Starting data preparation...") + logger.info(f" unit_type={args.unit_type}, max_total_units={args.max_total_units}, " + f"num_units_per_file={args.num_units_per_file}, max_seq_length={args.max_seq_length}, " + f"language={args.language}") + + shard_index = 0 + shard_units = 0 + total_units = 0 + num_examples = 0 + + def open_shard(index: int): + path = output_path / f"{prefix}-{index:05d}.jsonl.gz" + logger.info(f"Writing shard: {path.name}") + return gzip.open(path, "wt") + + current_file = open_shard(shard_index) + + for text, num_words in stream_examples(args, pretokenizer): + text_units = num_words if args.unit_type == "words" else len(text) + + # Check global limit + if args.max_total_units is not None and total_units + text_units > args.max_total_units: + logger.info(f"Reached max_total_units limit ({args.max_total_units})") + break + + record = {"text": text} + if args.language: + record["language"] = args.language + + 
current_file.write(json.dumps(record, ensure_ascii=False) + "\n") + shard_units += text_units + total_units += text_units + num_examples += 1 + + if num_examples % 10000 == 0: + logger.info(f"Processed {num_examples} examples, {total_units} total {args.unit_type}") + + # Check shard limit + if args.num_units_per_file is not None and shard_units >= args.num_units_per_file: + current_file.close() + logger.info(f"Completed shard {shard_index} ({shard_units} {args.unit_type})") + shard_index += 1 + shard_units = 0 + current_file = open_shard(shard_index) + + current_file.close() + + # Remove empty last shard + last_shard_path = output_path / f"{prefix}-{shard_index:05d}.jsonl.gz" + if shard_units == 0 and shard_index > 0: + last_shard_path.unlink() + shard_index -= 1 + + num_shards = shard_index + 1 + + # Save metadata + metadata = { + "format": "welt-preprocessed-v1", + "num_examples": num_examples, + "total_units": total_units, + "unit_type": args.unit_type, + "num_shards": num_shards, + "source_dataset": args.dataset_name, + "source_config": args.dataset_config, + "source_split": args.dataset_split, + "language": args.language, + "max_seq_length": args.max_seq_length, + "seed": args.seed, + "text_column": args.text_column, + "text_template": args.text_template, + } + + metadata_path = output_path / f"{prefix}-metadata.json" + logger.info(f"Saving metadata to {metadata_path}") + with open(metadata_path, "w") as f: + json.dump(metadata, f, indent=2) + + logger.info("Data preparation complete!") + logger.info(f" - Examples: {num_examples}") + logger.info(f" - Total {args.unit_type}: {total_units}") + logger.info(f" - Shards: {num_shards}") + logger.info(f" - Output: {output_path}") + + +if __name__ == "__main__": + main() From a7faae79570950114dd8e15535466e5c2647eaa4 Mon Sep 17 00:00:00 2001 From: Ilker Kesen Date: Wed, 11 Feb 2026 01:14:17 +0100 Subject: [PATCH 02/22] adapt welt pretraining implementation to work with the new data format --- welt_training/args_data.py 
| 13 +++++++++++-- welt_training/train.py | 17 +++++++++++++++++ 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/welt_training/args_data.py b/welt_training/args_data.py index b9bc069..b12624b 100644 --- a/welt_training/args_data.py +++ b/welt_training/args_data.py @@ -96,6 +96,10 @@ class DataTrainingArguments: keep_linebreaks: bool = field( default=True, metadata={"help": "Whether to keep line breaks when using TXT files or not."} ) + preprocessed_data_path: str | None = field( + default=None, + metadata={"help": "Path to preprocessed dataset (from welt-prepare-data). Skips download and pretokenization."}, + ) def __post_init__(self): if self.streaming: @@ -117,8 +121,13 @@ def __post_init__(self): ) raise ValueError(msg) - if self.dataset_name is None and self.train_file is None and self.validation_file is None: - raise ValueError("Need either a dataset name or a training/validation file.") + if ( + self.dataset_name is None + and self.train_file is None + and self.validation_file is None + and self.preprocessed_data_path is None + ): + raise ValueError("Need either a dataset name, a training/validation file, or a preprocessed data path.") else: if self.train_file is not None: extension = self.train_file.split(".")[-1] diff --git a/welt_training/train.py b/welt_training/train.py index 2cb950d..e5b4416 100644 --- a/welt_training/train.py +++ b/welt_training/train.py @@ -174,6 +174,23 @@ def init_datasets(data_args: DataTrainingArguments, # noqa: C901 # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at # https://huggingface.co/docs/datasets/loading_datasets. 
+ # Load preprocessed data if path provided + if data_args.preprocessed_data_path is not None: + import glob + data_files = sorted(glob.glob(os.path.join(data_args.preprocessed_data_path, "*.jsonl.gz"))) + if not data_files: + raise ValueError(f"No .jsonl.gz files found in {data_args.preprocessed_data_path}") + logger.info(f"Loading preprocessed data from {len(data_files)} shard(s) in {data_args.preprocessed_data_path}") + train_data = load_dataset("json", data_files=data_files, split="train") + + # Create validation split if needed + if data_args.validation_split_percentage: + split = train_data.train_test_split( + test_size=data_args.validation_split_percentage / 100, seed=42 + ) + return {"train": split["train"], "validation": split["test"]} + return {"train": train_data} + if data_args.dataset_name is not None: # Downloading and loading a dataset from the hub. raw_datasets = load_dataset( From 371e8ffa7251a6fd6479b7428541189c563be24b Mon Sep 17 00:00:00 2001 From: Ilker Kesen Date: Wed, 11 Feb 2026 01:14:57 +0100 Subject: [PATCH 03/22] document how to run the data preparation script --- README.md | 60 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/README.md b/README.md index 83bd2d0..d8c7aed 100644 --- a/README.md +++ b/README.md @@ -96,6 +96,66 @@ You can also turn off a specific encoder after training has completed, for testi > all the embeddings of the previous tokens (on the word level). This is done since not all causal LMs support > cross-attention, and so we want to avoid using it, and rely on the self-attention mechanism instead. +## Data Preparation + +You can prepare datasets offline using the `welt-prepare-data` CLI. 
+It streams a HuggingFace dataset, samples raw text with unit-based limits, and writes sharded `.jsonl.gz` files: + +```shell +welt-prepare-data \ + --dataset_name HuggingFaceFW/fineweb \ + --dataset_config sample-10BT \ + --max_total_units 3200000000 \ + --num_units_per_file 100000000 \ + --max_seq_length 512 \ + --seed 42 \ + --output_path /scratch/data/pretrain +``` + +Multiple datasets can be prepared into the same output directory: + +```shell +welt-prepare-data \ + --dataset_name monology/pile-uncopyrighted \ + --max_total_units 1000000000 \ + --num_units_per_file 100000000 \ + --max_seq_length 512 \ + --output_path /scratch/data/pretrain +``` + +The output directory contains sharded `.jsonl.gz` files and a `metadata.json` per run: + +``` +/scratch/data/pretrain/ +├── fineweb-sample-10BT-00000.jsonl.gz +├── fineweb-sample-10BT-00001.jsonl.gz +├── pile-uncopyrighted-00000.jsonl.gz +└── metadata.json +``` + +Then train using the preprocessed data: + +```shell +welt-train config.yaml --preprocessed_data_path /scratch/data/pretrain +``` + +| Argument | Description | +|----------|-------------| +| `--dataset_name` | HuggingFace dataset identifier (required) | +| `--dataset_config` | Dataset config name (optional) | +| `--dataset_split` | Split to use (default: "train") | +| `--text_column` | Column containing text (default: "text") | +| `--text_template` | Python format string template (optional) | +| `--language` | Language tag to store with each example (e.g., "eng_Latn") | +| `--unit_type` | Unit type for counting: "words" or "chars" (default: "words") | +| `--max_total_units` | Max total units to sample (optional) | +| `--num_units_per_file` | Max units per shard file (optional) | +| `--max_seq_length` | Max words per example; splits long documents using word segmentation | +| `--max_bytes_per_word` | Max bytes per word for word segmentation (default: 126) | +| `--seed` | Random seed for shuffling | +| `--drop_remainder` | Drop partial chunks at document 
boundaries | +| `--output_path` | Output directory path (required) | + ## Training Training instructions are available in the [welt_training/README.md](./welt_training/README.md). From d95038d99477ae0f13648a5580ca6b5d64f09040 Mon Sep 17 00:00:00 2001 From: Ilker Kesen Date: Wed, 11 Feb 2026 15:58:07 +0100 Subject: [PATCH 04/22] make shard limit naming more flexible (100_000 -> 100_000_000) --- welt_training/prepare_data.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/welt_training/prepare_data.py b/welt_training/prepare_data.py index 15490c8..8bcffc8 100644 --- a/welt_training/prepare_data.py +++ b/welt_training/prepare_data.py @@ -196,7 +196,7 @@ def main(): num_examples = 0 def open_shard(index: int): - path = output_path / f"{prefix}-{index:05d}.jsonl.gz" + path = output_path / f"{prefix}-{index:08d}.jsonl.gz" logger.info(f"Writing shard: {path.name}") return gzip.open(path, "wt") @@ -233,7 +233,7 @@ def open_shard(index: int): current_file.close() # Remove empty last shard - last_shard_path = output_path / f"{prefix}-{shard_index:05d}.jsonl.gz" + last_shard_path = output_path / f"{prefix}-{shard_index:08d}.jsonl.gz" if shard_units == 0 and shard_index > 0: last_shard_path.unlink() shard_index -= 1 From 2986d4d89909b163f5c02fda378c890e6dc7f149 Mon Sep 17 00:00:00 2001 From: Ilker Kesen Date: Wed, 11 Feb 2026 16:03:33 +0100 Subject: [PATCH 05/22] move the module import to top-level --- welt_training/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/welt_training/train.py b/welt_training/train.py index e5b4416..77ef6a3 100644 --- a/welt_training/train.py +++ b/welt_training/train.py @@ -4,6 +4,7 @@ import math import os import sys +import glob import datasets import transformers @@ -176,7 +177,6 @@ def init_datasets(data_args: DataTrainingArguments, # noqa: C901 # Load preprocessed data if path provided if data_args.preprocessed_data_path is not None: - import glob data_files = 
sorted(glob.glob(os.path.join(data_args.preprocessed_data_path, "*.jsonl.gz"))) if not data_files: raise ValueError(f"No .jsonl.gz files found in {data_args.preprocessed_data_path}") From 6843bf284aaa940139e1fc36f93bc1c823635bc2 Mon Sep 17 00:00:00 2001 From: Ilker Kesen Date: Wed, 11 Feb 2026 16:52:26 +0100 Subject: [PATCH 06/22] fix bugs listed in the PR review --- README.md | 12 ++++++------ pyproject.toml | 1 + tests/test_prepare_data.py | 12 ++++++++---- welt_training/train.py | 8 +++++--- 4 files changed, 20 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index d8c7aed..63808c0 100644 --- a/README.md +++ b/README.md @@ -123,14 +123,15 @@ welt-prepare-data \ --output_path /scratch/data/pretrain ``` -The output directory contains sharded `.jsonl.gz` files and a `metadata.json` per run: +The output directory contains sharded `.jsonl.gz` files and a `{prefix}-metadata.json` per dataset: ``` /scratch/data/pretrain/ -├── fineweb-sample-10BT-00000.jsonl.gz -├── fineweb-sample-10BT-00001.jsonl.gz -├── pile-uncopyrighted-00000.jsonl.gz -└── metadata.json +├── fineweb-sample-10BT-00000000.jsonl.gz +├── fineweb-sample-10BT-00000001.jsonl.gz +├── fineweb-sample-10BT-metadata.json +├── pile-uncopyrighted-00000000.jsonl.gz +└── pile-uncopyrighted-metadata.json ``` Then train using the preprocessed data: @@ -151,7 +152,6 @@ welt-train config.yaml --preprocessed_data_path /scratch/data/pretrain | `--max_total_units` | Max total units to sample (optional) | | `--num_units_per_file` | Max units per shard file (optional) | | `--max_seq_length` | Max words per example; splits long documents using word segmentation | -| `--max_bytes_per_word` | Max bytes per word for word segmentation (default: 126) | | `--seed` | Random seed for shuffling | | `--drop_remainder` | Drop partial chunks at document boundaries | | `--output_path` | Output directory path (required) | diff --git a/pyproject.toml b/pyproject.toml index bc8515e..be10423 100644 --- a/pyproject.toml +++ 
b/pyproject.toml @@ -80,3 +80,4 @@ testpaths = [ [project.scripts] welt-train = "welt_training.train:train" +welt-prepare-data = "welt_training.prepare_data:main" diff --git a/tests/test_prepare_data.py b/tests/test_prepare_data.py index 80f6ac0..6723534 100644 --- a/tests/test_prepare_data.py +++ b/tests/test_prepare_data.py @@ -56,7 +56,8 @@ def test_prepare_data_creates_shards(temp_output_dir, monkeypatch): assert len(shard_files) >= 2, f"Expected at least 2 shards, got {len(shard_files)}" # Verify metadata - with open(f"{temp_output_dir}/metadata.json") as f: + prefix = get_shard_prefix("wikitext", "wikitext-2-raw-v1") + with open(f"{temp_output_dir}/{prefix}-metadata.json") as f: metadata = json.load(f) assert metadata["format"] == "welt-preprocessed-v1" assert metadata["total_units"] <= 500 @@ -103,7 +104,8 @@ def test_prepare_data_with_language(temp_output_dir, monkeypatch): example = json.loads(line) assert example["language"] == "eng_Latn" - with open(f"{temp_output_dir}/metadata.json") as f: + prefix = get_shard_prefix("wikitext", "wikitext-2-raw-v1") + with open(f"{temp_output_dir}/{prefix}-metadata.json") as f: metadata = json.load(f) assert metadata["language"] == "eng_Latn" @@ -124,7 +126,8 @@ def test_prepare_data_unit_type_chars(temp_output_dir, monkeypatch): ) main() - with open(f"{temp_output_dir}/metadata.json") as f: + prefix = get_shard_prefix("wikitext", "wikitext-2-raw-v1") + with open(f"{temp_output_dir}/{prefix}-metadata.json") as f: metadata = json.load(f) assert metadata["unit_type"] == "chars" assert metadata["total_units"] <= 500 @@ -146,7 +149,8 @@ def test_prepare_data_with_max_seq_length(temp_output_dir, monkeypatch): ) main() - with open(f"{temp_output_dir}/metadata.json") as f: + prefix = get_shard_prefix("wikitext", "wikitext-2-raw-v1") + with open(f"{temp_output_dir}/{prefix}-metadata.json") as f: metadata = json.load(f) assert metadata["max_seq_length"] == 32 assert metadata["total_units"] <= 500 diff --git 
a/welt_training/train.py b/welt_training/train.py index 77ef6a3..edc6504 100644 --- a/welt_training/train.py +++ b/welt_training/train.py @@ -161,7 +161,8 @@ def detect_last_checkpoint(training_args: TrainingArguments): def init_datasets(data_args: DataTrainingArguments, # noqa: C901 trust_remote_code: bool, do_train: bool = True, - cache_dir: str = None): + cache_dir: str = None, + seed: int = 42): # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ # (the dataset will be downloaded automatically from the datasets Hub). @@ -186,7 +187,7 @@ def init_datasets(data_args: DataTrainingArguments, # noqa: C901 # Create validation split if needed if data_args.validation_split_percentage: split = train_data.train_test_split( - test_size=data_args.validation_split_percentage / 100, seed=42 + test_size=data_args.validation_split_percentage / 100, seed=seed ) return {"train": split["train"], "validation": split["test"]} return {"train": train_data} @@ -383,7 +384,8 @@ def train(args: list[str] | None | str = None): # noqa: C901 text_datasets = init_datasets(data_args, cache_dir=cache_dir, trust_remote_code=model_args.trust_remote_code, - do_train=training_args.do_train) + do_train=training_args.do_train, + seed=training_args.seed) train_dataset = None if training_args.do_train: From 630666e9728e06a0aa38a39729809209ccd3515c Mon Sep 17 00:00:00 2001 From: Ilker Kesen Date: Wed, 11 Feb 2026 18:14:28 +0100 Subject: [PATCH 07/22] fix the issues raised in the PR review --- README.md | 5 +++-- tests/test_prepare_data.py | 2 +- welt_training/args_data.py | 8 ++++---- welt_training/prepare_data.py | 12 +++++++++++- welt_training/train.py | 10 +++++----- 5 files changed, 24 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index 63808c0..a3c1cf7 100644 --- a/README.md +++ b/README.md @@ -134,10 +134,10 
@@ The output directory contains sharded `.jsonl.gz` files and a `{prefix}-metadata └── pile-uncopyrighted-metadata.json ``` -Then train using the preprocessed data: +Then train using the prepared data: ```shell -welt-train config.yaml --preprocessed_data_path /scratch/data/pretrain +welt-train config.yaml --prepared_data_path /scratch/data/pretrain ``` | Argument | Description | @@ -152,6 +152,7 @@ welt-train config.yaml --preprocessed_data_path /scratch/data/pretrain | `--max_total_units` | Max total units to sample (optional) | | `--num_units_per_file` | Max units per shard file (optional) | | `--max_seq_length` | Max words per example; splits long documents using word segmentation | +| `--max_bytes_per_word` | Max UTF-8 bytes per word; should match training config `max_word_length - 2` (default: 126) | | `--seed` | Random seed for shuffling | | `--drop_remainder` | Drop partial chunks at document boundaries | | `--output_path` | Output directory path (required) | diff --git a/tests/test_prepare_data.py b/tests/test_prepare_data.py index 6723534..2c33154 100644 --- a/tests/test_prepare_data.py +++ b/tests/test_prepare_data.py @@ -157,7 +157,7 @@ def test_prepare_data_with_max_seq_length(temp_output_dir, monkeypatch): # Verify each example has at most max_seq_length words from words_segmentation.tokenizer import WordsSegmentationTokenizer - pretokenizer = WordsSegmentationTokenizer() + pretokenizer = WordsSegmentationTokenizer(max_bytes=126) shard_files = sorted(glob.glob(f"{temp_output_dir}/*.jsonl.gz")) for path in shard_files: diff --git a/welt_training/args_data.py b/welt_training/args_data.py index b12624b..d2b2786 100644 --- a/welt_training/args_data.py +++ b/welt_training/args_data.py @@ -96,9 +96,9 @@ class DataTrainingArguments: keep_linebreaks: bool = field( default=True, metadata={"help": "Whether to keep line breaks when using TXT files or not."} ) - preprocessed_data_path: str | None = field( + prepared_data_path: str | None = field( default=None, - 
metadata={"help": "Path to preprocessed dataset (from welt-prepare-data). Skips download and pretokenization."}, + metadata={"help": "Path to prepared dataset shards (from welt-prepare-data). Skips download and text extraction."}, ) def __post_init__(self): @@ -125,9 +125,9 @@ def __post_init__(self): self.dataset_name is None and self.train_file is None and self.validation_file is None - and self.preprocessed_data_path is None + and self.prepared_data_path is None ): - raise ValueError("Need either a dataset name, a training/validation file, or a preprocessed data path.") + raise ValueError("Need either a dataset name, a training/validation file, or a prepared data path.") else: if self.train_file is not None: extension = self.train_file.split(".")[-1] diff --git a/welt_training/prepare_data.py b/welt_training/prepare_data.py index 8bcffc8..cd7a55a 100644 --- a/welt_training/prepare_data.py +++ b/welt_training/prepare_data.py @@ -143,6 +143,13 @@ def main(): default=None, help="Max words per example. Long documents are split using word segmentation.", ) + parser.add_argument( + "--max_bytes_per_word", + type=int, + default=126, + help="Max UTF-8 bytes per word. Words exceeding this are split. " + "Should match training config: max_word_length - 2 (default: 128 - 2 = 126).", + ) parser.add_argument( "--drop_remainder", action="store_true", @@ -183,7 +190,7 @@ def main(): output_path.mkdir(parents=True, exist_ok=True) prefix = get_shard_prefix(args.dataset_name, args.dataset_config) - pretokenizer = WordsSegmentationTokenizer() + pretokenizer = WordsSegmentationTokenizer(max_bytes=args.max_bytes_per_word) logger.info("Starting data preparation...") logger.info(f" unit_type={args.unit_type}, max_total_units={args.max_total_units}, " @@ -240,6 +247,9 @@ def open_shard(index: int): num_shards = shard_index + 1 + if num_examples == 0: + logger.warning("No examples were written. 
Check dataset and filter settings.") + # Save metadata metadata = { "format": "welt-preprocessed-v1", diff --git a/welt_training/train.py b/welt_training/train.py index edc6504..0ffe149 100644 --- a/welt_training/train.py +++ b/welt_training/train.py @@ -1,10 +1,10 @@ # Heavily adapted from # https://github.com/huggingface/transformers/edit/main/examples/pytorch/language-modeling/run_clm.py +import glob import logging import math import os import sys -import glob import datasets import transformers @@ -177,11 +177,11 @@ def init_datasets(data_args: DataTrainingArguments, # noqa: C901 # https://huggingface.co/docs/datasets/loading_datasets. # Load preprocessed data if path provided - if data_args.preprocessed_data_path is not None: - data_files = sorted(glob.glob(os.path.join(data_args.preprocessed_data_path, "*.jsonl.gz"))) + if data_args.prepared_data_path is not None: + data_files = sorted(glob.glob(os.path.join(data_args.prepared_data_path, "*.jsonl.gz"))) if not data_files: - raise ValueError(f"No .jsonl.gz files found in {data_args.preprocessed_data_path}") - logger.info(f"Loading preprocessed data from {len(data_files)} shard(s) in {data_args.preprocessed_data_path}") + raise ValueError(f"No .jsonl.gz files found in {data_args.prepared_data_path}") + logger.info(f"Loading prepared data from {len(data_files)} shard(s) in {data_args.prepared_data_path}") train_data = load_dataset("json", data_files=data_files, split="train") # Create validation split if needed From b81d46de3b35f1fb65f0fca9ccefcc2549eed138 Mon Sep 17 00:00:00 2001 From: Ilker Kesen Date: Wed, 11 Feb 2026 18:20:30 +0100 Subject: [PATCH 08/22] fix lint errors --- tests/test_prepare_data.py | 1 - welt_training/args_data.py | 4 +++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/test_prepare_data.py b/tests/test_prepare_data.py index 2c33154..65eea69 100644 --- a/tests/test_prepare_data.py +++ b/tests/test_prepare_data.py @@ -9,7 +9,6 @@ from welt_training.prepare_data import 
get_shard_prefix, main - # --- get_shard_prefix --- diff --git a/welt_training/args_data.py b/welt_training/args_data.py index d2b2786..edad1c0 100644 --- a/welt_training/args_data.py +++ b/welt_training/args_data.py @@ -98,7 +98,9 @@ class DataTrainingArguments: ) prepared_data_path: str | None = field( default=None, - metadata={"help": "Path to prepared dataset shards (from welt-prepare-data). Skips download and text extraction."}, + metadata={ + "help": "Path to prepared dataset shards (from welt-prepare-data). Skips download and text extraction." + }, ) def __post_init__(self): From b3667eaa0d4823a7bbaaa618145d8f7684e7969f Mon Sep 17 00:00:00 2001 From: Ilker Kesen Date: Wed, 11 Feb 2026 18:36:45 +0100 Subject: [PATCH 09/22] implement data loading for the prepared data within the package, so we could import it and re-use when needed. --- welt_training/data_utils.py | 32 +++++++++++++++++++ .../machine-translation/run_clm.py | 21 ++++++++++-- welt_training/train.py | 20 ++++-------- 3 files changed, 56 insertions(+), 17 deletions(-) create mode 100644 welt_training/data_utils.py diff --git a/welt_training/data_utils.py b/welt_training/data_utils.py new file mode 100644 index 0000000..608c987 --- /dev/null +++ b/welt_training/data_utils.py @@ -0,0 +1,32 @@ +import glob +import logging +import os + +from datasets import load_dataset + +logger = logging.getLogger(__name__) + + +def load_prepared_data(prepared_data_path: str, validation_split_percentage: int | None = None, seed: int = 42): + """Load preprocessed shards produced by prepare_data.py. + + Args: + prepared_data_path: Directory containing ``*.jsonl.gz`` shard files. + validation_split_percentage: If set, split the data into train/validation. + seed: Random seed for the train/test split. + + Returns: + A dict with ``"train"`` (and optionally ``"validation"``) datasets. 
+ """ + data_files = sorted(glob.glob(os.path.join(prepared_data_path, "*.jsonl.gz"))) + if not data_files: + raise ValueError(f"No .jsonl.gz files found in {prepared_data_path}") + logger.info(f"Loading prepared data from {len(data_files)} shard(s) in {prepared_data_path}") + train_data = load_dataset("json", data_files=data_files, split="train") + + if validation_split_percentage: + split = train_data.train_test_split( + test_size=validation_split_percentage / 100, seed=seed + ) + return {"train": split["train"], "validation": split["test"]} + return {"train": train_data} diff --git a/welt_training/experiments/machine-translation/run_clm.py b/welt_training/experiments/machine-translation/run_clm.py index 38cd19e..3218c23 100644 --- a/welt_training/experiments/machine-translation/run_clm.py +++ b/welt_training/experiments/machine-translation/run_clm.py @@ -69,6 +69,8 @@ from transformers.utils import check_min_version, send_example_telemetry from transformers.utils.versions import require_version +from welt_training.data_utils import load_prepared_data + # Will error if the minimal version of Transformers is not installed. Remove at your own risks. # check_min_version("4.57.0.dev0") @@ -245,13 +247,19 @@ class DataTrainingArguments: ) }, ) + prepared_data_path: Optional[str] = field( + default=None, + metadata={ + "help": "Path to prepared dataset shards (from welt-prepare-data). Skips download and text extraction." 
+ }, + ) def __post_init__(self): if self.streaming: require_version("datasets>=2.0.0", "The streaming feature requires `datasets>=2.0.0`") - if self.dataset_name is None and self.train_file is None and self.validation_file is None: - raise ValueError("Need either a dataset name or a training/validation file.") + if self.dataset_name is None and self.train_file is None and self.validation_file is None and self.prepared_data_path is None: + raise ValueError("Need either a dataset name, a training/validation file, or a prepared data path.") else: if self.train_file is not None: extension = self.train_file.split(".")[-1] @@ -369,7 +377,14 @@ def main(): # # In distributed training, the load_dataset function guarantee that only one local process can concurrently # download the dataset. - if data_args.dataset_name is not None: + if data_args.prepared_data_path is not None: + raw_datasets = load_prepared_data( + data_args.prepared_data_path, + validation_split_percentage=data_args.validation_split_percentage, + seed=training_args.seed, + ) + + elif data_args.dataset_name is not None: # Downloading and loading a dataset from the hub. 
raw_datasets = load_dataset( data_args.dataset_name, diff --git a/welt_training/train.py b/welt_training/train.py index 0ffe149..a46db02 100644 --- a/welt_training/train.py +++ b/welt_training/train.py @@ -1,6 +1,5 @@ # Heavily adapted from # https://github.com/huggingface/transformers/edit/main/examples/pytorch/language-modeling/run_clm.py -import glob import logging import math import os @@ -22,6 +21,7 @@ from welt_training.args_data import DataTrainingArguments from welt_training.args_model import ModelArguments from welt_training.args_trainer import WeLTTrainingArguments +from welt_training.data_utils import load_prepared_data from welt_training.extendable_yaml import resolve_yaml_file from welt_training.flops_callback import FlopsCallback from welt_training.freeze_callback import FreezeWarmupCallback @@ -178,19 +178,11 @@ def init_datasets(data_args: DataTrainingArguments, # noqa: C901 # Load preprocessed data if path provided if data_args.prepared_data_path is not None: - data_files = sorted(glob.glob(os.path.join(data_args.prepared_data_path, "*.jsonl.gz"))) - if not data_files: - raise ValueError(f"No .jsonl.gz files found in {data_args.prepared_data_path}") - logger.info(f"Loading prepared data from {len(data_files)} shard(s) in {data_args.prepared_data_path}") - train_data = load_dataset("json", data_files=data_files, split="train") - - # Create validation split if needed - if data_args.validation_split_percentage: - split = train_data.train_test_split( - test_size=data_args.validation_split_percentage / 100, seed=seed - ) - return {"train": split["train"], "validation": split["test"]} - return {"train": train_data} + return load_prepared_data( + data_args.prepared_data_path, + validation_split_percentage=data_args.validation_split_percentage, + seed=seed, + ) if data_args.dataset_name is not None: # Downloading and loading a dataset from the hub. 
From fba018b896a79a88328e580a9f01cd4c32853c89 Mon Sep 17 00:00:00 2001 From: Ilker Kesen Date: Wed, 11 Feb 2026 19:11:47 +0100 Subject: [PATCH 10/22] create train / validation splits at preparation time --- tests/test_prepare_data.py | 121 ++++++++++++++++++++++++++++ welt_training/data_utils.py | 32 +++++++- welt_training/prepare_data.py | 144 +++++++++++++++++++++++++--------- 3 files changed, 258 insertions(+), 39 deletions(-) diff --git a/tests/test_prepare_data.py b/tests/test_prepare_data.py index 65eea69..86381fb 100644 --- a/tests/test_prepare_data.py +++ b/tests/test_prepare_data.py @@ -7,6 +7,7 @@ import pytest from datasets import load_dataset +from welt_training.data_utils import load_prepared_data from welt_training.prepare_data import get_shard_prefix, main # --- get_shard_prefix --- @@ -165,3 +166,123 @@ def test_prepare_data_with_max_seq_length(temp_output_dir, monkeypatch): example = json.loads(line) words = pretokenizer.tokenize(example["text"]) assert len(words) <= 32 + + +def test_prepare_data_with_validation_split(temp_output_dir, monkeypatch): + """Test that --validation_split_percentage creates split-aware shards.""" + monkeypatch.setattr( + "sys.argv", + [ + "welt-prepare-data", + "--dataset_name", "wikitext", + "--dataset_config", "wikitext-2-raw-v1", + "--max_total_units", "1000", + "--validation_split_percentage", "20", + "--seed", "42", + "--output_path", temp_output_dir, + ], + ) + main() + + prefix = get_shard_prefix("wikitext", "wikitext-2-raw-v1") + + # Verify split-aware shard files were created + train_files = sorted(glob.glob(f"{temp_output_dir}/{prefix}-train-*.jsonl.gz")) + val_files = sorted(glob.glob(f"{temp_output_dir}/{prefix}-validation-*.jsonl.gz")) + assert len(train_files) >= 1, f"Expected at least 1 train shard, got {len(train_files)}" + assert len(val_files) >= 1, f"Expected at least 1 validation shard, got {len(val_files)}" + + # No legacy (unsplit) shards should exist + all_shards = 
sorted(glob.glob(f"{temp_output_dir}/*.jsonl.gz")) + assert len(all_shards) == len(train_files) + len(val_files) + + # Count examples per split + train_examples = 0 + for path in train_files: + with gzip.open(path, "rt") as f: + for line in f: + example = json.loads(line) + assert "text" in example + train_examples += 1 + + val_examples = 0 + for path in val_files: + with gzip.open(path, "rt") as f: + for line in f: + example = json.loads(line) + assert "text" in example + val_examples += 1 + + total_examples = train_examples + val_examples + assert total_examples > 0 + + # Verify validation fraction is roughly correct (20% +/- tolerance) + val_fraction = val_examples / total_examples + assert 0.05 < val_fraction < 0.45, f"Expected ~20% validation, got {val_fraction:.1%}" + + # Verify metadata + with open(f"{temp_output_dir}/{prefix}-metadata.json") as f: + metadata = json.load(f) + assert metadata["format"] == "welt-preprocessed-v1" + assert metadata["validation_split_percentage"] == 20 + assert metadata["num_examples"] == total_examples + assert "splits" in metadata + assert metadata["splits"]["train"]["num_examples"] == train_examples + assert metadata["splits"]["validation"]["num_examples"] == val_examples + assert metadata["splits"]["train"]["num_shards"] == len(train_files) + assert metadata["splits"]["validation"]["num_shards"] == len(val_files) + + +def test_load_prepared_data_split_aware(temp_output_dir, monkeypatch): + """Test that load_prepared_data detects and loads split-aware shards.""" + monkeypatch.setattr( + "sys.argv", + [ + "welt-prepare-data", + "--dataset_name", "wikitext", + "--dataset_config", "wikitext-2-raw-v1", + "--max_total_units", "1000", + "--validation_split_percentage", "20", + "--seed", "42", + "--output_path", temp_output_dir, + ], + ) + main() + + # load_prepared_data should detect split-aware files and load them directly + result = load_prepared_data(temp_output_dir) + assert "train" in result + assert "validation" in result + 
assert len(result["train"]) > 0 + assert len(result["validation"]) > 0 + assert "text" in result["train"].features + assert "text" in result["validation"].features + + # Total should match what was prepared + prefix = get_shard_prefix("wikitext", "wikitext-2-raw-v1") + with open(f"{temp_output_dir}/{prefix}-metadata.json") as f: + metadata = json.load(f) + assert len(result["train"]) + len(result["validation"]) == metadata["num_examples"] + + +def test_load_prepared_data_legacy(temp_output_dir, monkeypatch): + """Test that load_prepared_data falls back to legacy mode for unsplit shards.""" + monkeypatch.setattr( + "sys.argv", + [ + "welt-prepare-data", + "--dataset_name", "wikitext", + "--dataset_config", "wikitext-2-raw-v1", + "--max_total_units", "500", + "--seed", "42", + "--output_path", temp_output_dir, + ], + ) + main() + + # Without --validation_split_percentage, shards have no split marker + result = load_prepared_data(temp_output_dir, validation_split_percentage=10, seed=42) + assert "train" in result + assert "validation" in result + assert len(result["train"]) > 0 + assert len(result["validation"]) > 0 diff --git a/welt_training/data_utils.py b/welt_training/data_utils.py index 608c987..5292d87 100644 --- a/welt_training/data_utils.py +++ b/welt_training/data_utils.py @@ -10,14 +10,42 @@ def load_prepared_data(prepared_data_path: str, validation_split_percentage: int | None = None, seed: int = 42): """Load preprocessed shards produced by prepare_data.py. + Supports two shard naming conventions: + + - **Split-aware** (preferred): Files named ``{prefix}-train-*.jsonl.gz`` + and ``{prefix}-validation-*.jsonl.gz``. Train and validation datasets + are loaded directly from their respective shards, ensuring each source + dataset contributes proportionally to both splits. The + ``validation_split_percentage`` parameter is ignored in this mode. + + - **Legacy**: Files named ``{prefix}-*.jsonl.gz`` without split markers. 
+ A random ``train_test_split`` is applied using + ``validation_split_percentage``. + Args: prepared_data_path: Directory containing ``*.jsonl.gz`` shard files. - validation_split_percentage: If set, split the data into train/validation. - seed: Random seed for the train/test split. + validation_split_percentage: If set, split the data into train/validation + (only used for legacy shards without split markers). + seed: Random seed for the train/test split (legacy mode only). Returns: A dict with ``"train"`` (and optionally ``"validation"``) datasets. """ + # Check for split-aware files produced with --validation_split_percentage + train_files = sorted(glob.glob(os.path.join(prepared_data_path, "*-train-*.jsonl.gz"))) + validation_files = sorted(glob.glob(os.path.join(prepared_data_path, "*-validation-*.jsonl.gz"))) + + if train_files: + logger.info( + f"Loading split-aware data: {len(train_files)} train shard(s), " + f"{len(validation_files)} validation shard(s) from {prepared_data_path}" + ) + result = {"train": load_dataset("json", data_files=train_files, split="train")} + if validation_files: + result["validation"] = load_dataset("json", data_files=validation_files, split="train") + return result + + # Legacy mode: all shards in a single pool data_files = sorted(glob.glob(os.path.join(prepared_data_path, "*.jsonl.gz"))) if not data_files: raise ValueError(f"No .jsonl.gz files found in {prepared_data_path}") diff --git a/welt_training/prepare_data.py b/welt_training/prepare_data.py index cd7a55a..45e0a96 100644 --- a/welt_training/prepare_data.py +++ b/welt_training/prepare_data.py @@ -25,6 +25,60 @@ def get_shard_prefix(dataset_name: str, dataset_config: str | None) -> str: return name +class ShardWriter: + """Manages writing sharded .jsonl.gz files for a single data split.""" + + def __init__(self, output_path: Path, prefix: str, split_name: str | None, num_units_per_file: int | None): + self.output_path = output_path + self.prefix = prefix + self.split_name = 
split_name + self.num_units_per_file = num_units_per_file + self.shard_index = 0 + self.shard_units = 0 + self.total_units = 0 + self.num_examples = 0 + self._current_file = self._open_shard() + + def _shard_path(self, index: int) -> Path: + if self.split_name: + return self.output_path / f"{self.prefix}-{self.split_name}-{index:08d}.jsonl.gz" + return self.output_path / f"{self.prefix}-{index:08d}.jsonl.gz" + + def _open_shard(self): + path = self._shard_path(self.shard_index) + logger.info(f"Writing shard: {path.name}") + return gzip.open(path, "wt") + + def write(self, record: dict, text_units: int): + self._current_file.write(json.dumps(record, ensure_ascii=False) + "\n") + self.shard_units += text_units + self.total_units += text_units + self.num_examples += 1 + + if self.num_units_per_file is not None and self.shard_units >= self.num_units_per_file: + self._current_file.close() + logger.info(f"Completed shard {self.shard_index} ({self.shard_units} units)") + self.shard_index += 1 + self.shard_units = 0 + self._current_file = self._open_shard() + + def close(self): + self._current_file.close() + # Remove empty last shard + if self.shard_units == 0 and self.shard_index > 0: + self._shard_path(self.shard_index).unlink() + self.shard_index -= 1 + # Remove shard file if no examples were written at all + if self.num_examples == 0: + self._shard_path(self.shard_index).unlink(missing_ok=True) + + @property + def num_shards(self) -> int: + if self.num_examples == 0: + return 0 + return self.shard_index + 1 + + def stream_texts(args): """Stream raw text examples from a HuggingFace dataset.""" load_kwargs = { @@ -155,6 +209,14 @@ def main(): action="store_true", help="Drop partial chunks when splitting documents by max_seq_length", ) + parser.add_argument( + "--validation_split_percentage", + type=int, + default=None, + help="Percentage of examples to assign to validation split (e.g., 5 for 5%%). " + "Requires --max_total_units. 
The first (100 - N)%% of units go to train, " + "the remaining N%% go to validation. Data is already shuffled before splitting.", + ) parser.add_argument( "--seed", type=int, @@ -197,17 +259,20 @@ def main(): f"num_units_per_file={args.num_units_per_file}, max_seq_length={args.max_seq_length}, " f"language={args.language}") - shard_index = 0 - shard_units = 0 - total_units = 0 - num_examples = 0 - - def open_shard(index: int): - path = output_path / f"{prefix}-{index:08d}.jsonl.gz" - logger.info(f"Writing shard: {path.name}") - return gzip.open(path, "wt") + # Create shard writers + if args.validation_split_percentage is not None: + if args.max_total_units is None: + parser.error("--max_total_units is required when using --validation_split_percentage") + train_budget = int(args.max_total_units * (100 - args.validation_split_percentage) / 100) + logger.info(f" validation_split_percentage={args.validation_split_percentage}, train_budget={train_budget}") + train_writer = ShardWriter(output_path, prefix, "train", args.num_units_per_file) + val_writer = ShardWriter(output_path, prefix, "validation", args.num_units_per_file) + else: + train_writer = ShardWriter(output_path, prefix, None, args.num_units_per_file) + val_writer = None - current_file = open_shard(shard_index) + total_units = 0 + total_examples = 0 for text, num_words in stream_examples(args, pretokenizer): text_units = num_words if args.unit_type == "words" else len(text) @@ -221,42 +286,30 @@ def open_shard(index: int): if args.language: record["language"] = args.language - current_file.write(json.dumps(record, ensure_ascii=False) + "\n") - shard_units += text_units - total_units += text_units - num_examples += 1 - - if num_examples % 10000 == 0: - logger.info(f"Processed {num_examples} examples, {total_units} total {args.unit_type}") - - # Check shard limit - if args.num_units_per_file is not None and shard_units >= args.num_units_per_file: - current_file.close() - logger.info(f"Completed shard 
{shard_index} ({shard_units} {args.unit_type})") - shard_index += 1 - shard_units = 0 - current_file = open_shard(shard_index) + # Once the train budget is filled, route remaining examples to validation + if val_writer is not None and train_writer.total_units + text_units > train_budget: + val_writer.write(record, text_units) + else: + train_writer.write(record, text_units) - current_file.close() + total_units += text_units + total_examples += 1 - # Remove empty last shard - last_shard_path = output_path / f"{prefix}-{shard_index:08d}.jsonl.gz" - if shard_units == 0 and shard_index > 0: - last_shard_path.unlink() - shard_index -= 1 + if total_examples % 10000 == 0: + logger.info(f"Processed {total_examples} examples, {total_units} total {args.unit_type}") - num_shards = shard_index + 1 + train_writer.close() - if num_examples == 0: + if total_examples == 0: logger.warning("No examples were written. Check dataset and filter settings.") # Save metadata metadata = { "format": "welt-preprocessed-v1", - "num_examples": num_examples, + "num_examples": total_examples, "total_units": total_units, "unit_type": args.unit_type, - "num_shards": num_shards, + "num_shards": train_writer.num_shards, "source_dataset": args.dataset_name, "source_config": args.dataset_config, "source_split": args.dataset_split, @@ -267,15 +320,32 @@ def open_shard(index: int): "text_template": args.text_template, } + if val_writer is not None: + val_writer.close() + metadata["num_shards"] += val_writer.num_shards + metadata["validation_split_percentage"] = args.validation_split_percentage + metadata["splits"] = { + "train": { + "num_examples": train_writer.num_examples, + "total_units": train_writer.total_units, + "num_shards": train_writer.num_shards, + }, + "validation": { + "num_examples": val_writer.num_examples, + "total_units": val_writer.total_units, + "num_shards": val_writer.num_shards, + }, + } + metadata_path = output_path / f"{prefix}-metadata.json" logger.info(f"Saving metadata to 
{metadata_path}") with open(metadata_path, "w") as f: json.dump(metadata, f, indent=2) logger.info("Data preparation complete!") - logger.info(f" - Examples: {num_examples}") + logger.info(f" - Examples: {total_examples}") logger.info(f" - Total {args.unit_type}: {total_units}") - logger.info(f" - Shards: {num_shards}") + logger.info(f" - Shards: {metadata['num_shards']}") logger.info(f" - Output: {output_path}") From 68003297b43447d0ab5628a3bd438a5a569005c6 Mon Sep 17 00:00:00 2001 From: Ilker Kesen Date: Wed, 11 Feb 2026 19:56:10 +0100 Subject: [PATCH 11/22] separate train and validation splits at data preparation phase --- README.md | 7 ++- tests/test_prepare_data.py | 35 +++++------ welt_training/data_utils.py | 60 +++++++------------ .../machine-translation/run_clm.py | 8 +-- welt_training/prepare_data.py | 34 +++++------ welt_training/train.py | 8 +-- 6 files changed, 66 insertions(+), 86 deletions(-) diff --git a/README.md b/README.md index a3c1cf7..c522999 100644 --- a/README.md +++ b/README.md @@ -105,7 +105,7 @@ It streams a HuggingFace dataset, samples raw text with unit-based limits, and w welt-prepare-data \ --dataset_name HuggingFaceFW/fineweb \ --dataset_config sample-10BT \ - --max_total_units 3200000000 \ + --train_split_units 3200000000 \ --num_units_per_file 100000000 \ --max_seq_length 512 \ --seed 42 \ @@ -117,7 +117,7 @@ Multiple datasets can be prepared into the same output directory: ```shell welt-prepare-data \ --dataset_name monology/pile-uncopyrighted \ - --max_total_units 1000000000 \ + --train_split_units 1000000000 \ --num_units_per_file 100000000 \ --max_seq_length 512 \ --output_path /scratch/data/pretrain @@ -149,7 +149,8 @@ welt-train config.yaml --prepared_data_path /scratch/data/pretrain | `--text_template` | Python format string template (optional) | | `--language` | Language tag to store with each example (e.g., "eng_Latn") | | `--unit_type` | Unit type for counting: "words" or "chars" (default: "words") | -| 
`--max_total_units` | Max total units to sample (optional) | +| `--train_split_units` | Number of units for the train split (optional) | +| `--validation_split_units` | Number of units for the validation split (optional; requires `--train_split_units`) | | `--num_units_per_file` | Max units per shard file (optional) | | `--max_seq_length` | Max words per example; splits long documents using word segmentation | | `--max_bytes_per_word` | Max UTF-8 bytes per word; should match training config `max_word_length - 2` (default: 126) | diff --git a/tests/test_prepare_data.py b/tests/test_prepare_data.py index 86381fb..9f7aa5c 100644 --- a/tests/test_prepare_data.py +++ b/tests/test_prepare_data.py @@ -43,7 +43,7 @@ def test_prepare_data_creates_shards(temp_output_dir, monkeypatch): "welt-prepare-data", "--dataset_name", "wikitext", "--dataset_config", "wikitext-2-raw-v1", - "--max_total_units", "500", + "--train_split_units", "500", "--num_units_per_file", "200", "--seed", "42", "--output_path", temp_output_dir, @@ -89,7 +89,7 @@ def test_prepare_data_with_language(temp_output_dir, monkeypatch): "welt-prepare-data", "--dataset_name", "wikitext", "--dataset_config", "wikitext-2-raw-v1", - "--max_total_units", "200", + "--train_split_units", "200", "--language", "eng_Latn", "--seed", "42", "--output_path", temp_output_dir, @@ -118,7 +118,7 @@ def test_prepare_data_unit_type_chars(temp_output_dir, monkeypatch): "welt-prepare-data", "--dataset_name", "wikitext", "--dataset_config", "wikitext-2-raw-v1", - "--max_total_units", "500", + "--train_split_units", "500", "--unit_type", "chars", "--seed", "42", "--output_path", temp_output_dir, @@ -141,7 +141,7 @@ def test_prepare_data_with_max_seq_length(temp_output_dir, monkeypatch): "welt-prepare-data", "--dataset_name", "wikitext", "--dataset_config", "wikitext-2-raw-v1", - "--max_total_units", "500", + "--train_split_units", "500", "--max_seq_length", "32", "--seed", "42", "--output_path", temp_output_dir, @@ -169,15 +169,15 @@ 
def test_prepare_data_with_max_seq_length(temp_output_dir, monkeypatch): def test_prepare_data_with_validation_split(temp_output_dir, monkeypatch): - """Test that --validation_split_percentage creates split-aware shards.""" + """Test that --validation_split_units creates split-aware shards.""" monkeypatch.setattr( "sys.argv", [ "welt-prepare-data", "--dataset_name", "wikitext", "--dataset_config", "wikitext-2-raw-v1", - "--max_total_units", "1000", - "--validation_split_percentage", "20", + "--train_split_units", "800", + "--validation_split_units", "200", "--seed", "42", "--output_path", temp_output_dir, ], @@ -224,7 +224,7 @@ def test_prepare_data_with_validation_split(temp_output_dir, monkeypatch): with open(f"{temp_output_dir}/{prefix}-metadata.json") as f: metadata = json.load(f) assert metadata["format"] == "welt-preprocessed-v1" - assert metadata["validation_split_percentage"] == 20 + assert metadata["validation_split_units"] == 200 assert metadata["num_examples"] == total_examples assert "splits" in metadata assert metadata["splits"]["train"]["num_examples"] == train_examples @@ -241,8 +241,8 @@ def test_load_prepared_data_split_aware(temp_output_dir, monkeypatch): "welt-prepare-data", "--dataset_name", "wikitext", "--dataset_config", "wikitext-2-raw-v1", - "--max_total_units", "1000", - "--validation_split_percentage", "20", + "--train_split_units", "800", + "--validation_split_units", "200", "--seed", "42", "--output_path", temp_output_dir, ], @@ -265,24 +265,21 @@ def test_load_prepared_data_split_aware(temp_output_dir, monkeypatch): assert len(result["train"]) + len(result["validation"]) == metadata["num_examples"] -def test_load_prepared_data_legacy(temp_output_dir, monkeypatch): - """Test that load_prepared_data falls back to legacy mode for unsplit shards.""" +def test_load_prepared_data_requires_validation_shards(temp_output_dir, monkeypatch): + """Test that load_prepared_data raises when validation shards are missing.""" monkeypatch.setattr( 
"sys.argv", [ "welt-prepare-data", "--dataset_name", "wikitext", "--dataset_config", "wikitext-2-raw-v1", - "--max_total_units", "500", + "--train_split_units", "500", "--seed", "42", "--output_path", temp_output_dir, ], ) main() - # Without --validation_split_percentage, shards have no split marker - result = load_prepared_data(temp_output_dir, validation_split_percentage=10, seed=42) - assert "train" in result - assert "validation" in result - assert len(result["train"]) > 0 - assert len(result["validation"]) > 0 + # Without --validation_split_units, shards have no split marker → error + with pytest.raises(ValueError, match="validation"): + load_prepared_data(temp_output_dir) diff --git a/welt_training/data_utils.py b/welt_training/data_utils.py index 5292d87..3c82515 100644 --- a/welt_training/data_utils.py +++ b/welt_training/data_utils.py @@ -7,54 +7,40 @@ logger = logging.getLogger(__name__) -def load_prepared_data(prepared_data_path: str, validation_split_percentage: int | None = None, seed: int = 42): +def load_prepared_data(prepared_data_path: str): """Load preprocessed shards produced by prepare_data.py. - Supports two shard naming conventions: + Loads ``{prefix}-train-*.jsonl.gz`` shards as the train split and + ``{prefix}-validation-*.jsonl.gz`` shards as the validation split. - - **Split-aware** (preferred): Files named ``{prefix}-train-*.jsonl.gz`` - and ``{prefix}-validation-*.jsonl.gz``. Train and validation datasets - are loaded directly from their respective shards, ensuring each source - dataset contributes proportionally to both splits. The - ``validation_split_percentage`` parameter is ignored in this mode. - - - **Legacy**: Files named ``{prefix}-*.jsonl.gz`` without split markers. - A random ``train_test_split`` is applied using - ``validation_split_percentage``. + Both train and validation shards are required. Prepare data with + ``--validation_split_units`` to produce them. 
Args: prepared_data_path: Directory containing ``*.jsonl.gz`` shard files. - validation_split_percentage: If set, split the data into train/validation - (only used for legacy shards without split markers). - seed: Random seed for the train/test split (legacy mode only). Returns: - A dict with ``"train"`` (and optionally ``"validation"``) datasets. + A dict with ``"train"`` and ``"validation"`` datasets. """ - # Check for split-aware files produced with --validation_split_percentage train_files = sorted(glob.glob(os.path.join(prepared_data_path, "*-train-*.jsonl.gz"))) validation_files = sorted(glob.glob(os.path.join(prepared_data_path, "*-validation-*.jsonl.gz"))) - if train_files: - logger.info( - f"Loading split-aware data: {len(train_files)} train shard(s), " - f"{len(validation_files)} validation shard(s) from {prepared_data_path}" + if not train_files: + raise ValueError( + f"No *-train-*.jsonl.gz files found in {prepared_data_path}. " + "Prepare data with --train_split_units and --validation_split_units." ) - result = {"train": load_dataset("json", data_files=train_files, split="train")} - if validation_files: - result["validation"] = load_dataset("json", data_files=validation_files, split="train") - return result - - # Legacy mode: all shards in a single pool - data_files = sorted(glob.glob(os.path.join(prepared_data_path, "*.jsonl.gz"))) - if not data_files: - raise ValueError(f"No .jsonl.gz files found in {prepared_data_path}") - logger.info(f"Loading prepared data from {len(data_files)} shard(s) in {prepared_data_path}") - train_data = load_dataset("json", data_files=data_files, split="train") - - if validation_split_percentage: - split = train_data.train_test_split( - test_size=validation_split_percentage / 100, seed=seed + if not validation_files: + raise ValueError( + f"No *-validation-*.jsonl.gz files found in {prepared_data_path}. " + "Prepare data with --validation_split_units to create a validation split." 
) - return {"train": split["train"], "validation": split["test"]} - return {"train": train_data} + + logger.info( + f"Loading prepared data: {len(train_files)} train shard(s), " + f"{len(validation_files)} validation shard(s) from {prepared_data_path}" + ) + return { + "train": load_dataset("json", data_files=train_files, split="train"), + "validation": load_dataset("json", data_files=validation_files, split="train"), + } diff --git a/welt_training/experiments/machine-translation/run_clm.py b/welt_training/experiments/machine-translation/run_clm.py index 3218c23..7be3a7a 100644 --- a/welt_training/experiments/machine-translation/run_clm.py +++ b/welt_training/experiments/machine-translation/run_clm.py @@ -378,11 +378,9 @@ def main(): # In distributed training, the load_dataset function guarantee that only one local process can concurrently # download the dataset. if data_args.prepared_data_path is not None: - raw_datasets = load_prepared_data( - data_args.prepared_data_path, - validation_split_percentage=data_args.validation_split_percentage, - seed=training_args.seed, - ) + if data_args.validation_split_percentage is not None: + logger.warning("Ignoring validation_split_percentage because prepared_data_path is set.") + raw_datasets = load_prepared_data(data_args.prepared_data_path) elif data_args.dataset_name is not None: # Downloading and loading a dataset from the hub. diff --git a/welt_training/prepare_data.py b/welt_training/prepare_data.py index 45e0a96..4be7437 100644 --- a/welt_training/prepare_data.py +++ b/welt_training/prepare_data.py @@ -180,10 +180,10 @@ def main(): help="Unit type for counting (default: 'words')", ) parser.add_argument( - "--max_total_units", + "--train_split_units", type=int, default=None, - help="Max total units to sample. If not set, processes the entire dataset.", + help="Number of units for the train split. 
If not set, processes the entire dataset.", ) parser.add_argument( "--num_units_per_file", @@ -210,12 +210,11 @@ def main(): help="Drop partial chunks when splitting documents by max_seq_length", ) parser.add_argument( - "--validation_split_percentage", + "--validation_split_units", type=int, default=None, - help="Percentage of examples to assign to validation split (e.g., 5 for 5%%). " - "Requires --max_total_units. The first (100 - N)%% of units go to train, " - "the remaining N%% go to validation. Data is already shuffled before splitting.", + help="Number of units for the validation split. " + "Requires --train_split_units. Data is already shuffled before splitting.", ) parser.add_argument( "--seed", @@ -255,19 +254,20 @@ def main(): pretokenizer = WordsSegmentationTokenizer(max_bytes=args.max_bytes_per_word) logger.info("Starting data preparation...") - logger.info(f" unit_type={args.unit_type}, max_total_units={args.max_total_units}, " + logger.info(f" unit_type={args.unit_type}, train_split_units={args.train_split_units}, " f"num_units_per_file={args.num_units_per_file}, max_seq_length={args.max_seq_length}, " f"language={args.language}") # Create shard writers - if args.validation_split_percentage is not None: - if args.max_total_units is None: - parser.error("--max_total_units is required when using --validation_split_percentage") - train_budget = int(args.max_total_units * (100 - args.validation_split_percentage) / 100) - logger.info(f" validation_split_percentage={args.validation_split_percentage}, train_budget={train_budget}") + if args.validation_split_units is not None: + if args.train_split_units is None: + parser.error("--train_split_units is required when using --validation_split_units") + max_total_units = args.train_split_units + args.validation_split_units + logger.info(f" validation_split_units={args.validation_split_units}, max_total_units={max_total_units}") train_writer = ShardWriter(output_path, prefix, "train", args.num_units_per_file) 
val_writer = ShardWriter(output_path, prefix, "validation", args.num_units_per_file) else: + max_total_units = args.train_split_units train_writer = ShardWriter(output_path, prefix, None, args.num_units_per_file) val_writer = None @@ -278,16 +278,16 @@ def main(): text_units = num_words if args.unit_type == "words" else len(text) # Check global limit - if args.max_total_units is not None and total_units + text_units > args.max_total_units: - logger.info(f"Reached max_total_units limit ({args.max_total_units})") + if max_total_units is not None and total_units + text_units > max_total_units: + logger.info(f"Reached total units limit ({max_total_units})") break record = {"text": text} if args.language: record["language"] = args.language - # Once the train budget is filled, route remaining examples to validation - if val_writer is not None and train_writer.total_units + text_units > train_budget: + # Fill validation first, then route to train + if val_writer is not None and val_writer.total_units < args.validation_split_units: val_writer.write(record, text_units) else: train_writer.write(record, text_units) @@ -323,7 +323,7 @@ def main(): if val_writer is not None: val_writer.close() metadata["num_shards"] += val_writer.num_shards - metadata["validation_split_percentage"] = args.validation_split_percentage + metadata["validation_split_units"] = args.validation_split_units metadata["splits"] = { "train": { "num_examples": train_writer.num_examples, diff --git a/welt_training/train.py b/welt_training/train.py index a46db02..83a222f 100644 --- a/welt_training/train.py +++ b/welt_training/train.py @@ -178,11 +178,9 @@ def init_datasets(data_args: DataTrainingArguments, # noqa: C901 # Load preprocessed data if path provided if data_args.prepared_data_path is not None: - return load_prepared_data( - data_args.prepared_data_path, - validation_split_percentage=data_args.validation_split_percentage, - seed=seed, - ) + if data_args.validation_split_percentage is not None: + 
logger.warning("Ignoring validation_split_percentage because prepared_data_path is set.") + return load_prepared_data(data_args.prepared_data_path) if data_args.dataset_name is not None: # Downloading and loading a dataset from the hub. From 69f85a3f4b4166f3c84779595234e4c21d27b14b Mon Sep 17 00:00:00 2001 From: Ilker Kesen Date: Wed, 11 Feb 2026 20:08:03 +0100 Subject: [PATCH 12/22] make [(train/validation)]_split_units args required for data preparation phase. --- README.md | 6 ++-- tests/test_prepare_data.py | 30 ++++++------------ welt_training/prepare_data.py | 58 +++++++++++++++-------------------- 3 files changed, 39 insertions(+), 55 deletions(-) diff --git a/README.md b/README.md index c522999..1caf24a 100644 --- a/README.md +++ b/README.md @@ -106,6 +106,7 @@ welt-prepare-data \ --dataset_name HuggingFaceFW/fineweb \ --dataset_config sample-10BT \ --train_split_units 3200000000 \ + --validation_split_units 100000000 \ --num_units_per_file 100000000 \ --max_seq_length 512 \ --seed 42 \ @@ -118,6 +119,7 @@ Multiple datasets can be prepared into the same output directory: welt-prepare-data \ --dataset_name monology/pile-uncopyrighted \ --train_split_units 1000000000 \ + --validation_split_units 50000000 \ --num_units_per_file 100000000 \ --max_seq_length 512 \ --output_path /scratch/data/pretrain @@ -149,8 +151,8 @@ welt-train config.yaml --prepared_data_path /scratch/data/pretrain | `--text_template` | Python format string template (optional) | | `--language` | Language tag to store with each example (e.g., "eng_Latn") | | `--unit_type` | Unit type for counting: "words" or "chars" (default: "words") | -| `--train_split_units` | Number of units for the train split (optional) | -| `--validation_split_units` | Number of units for the validation split (optional; requires `--train_split_units`) | +| `--train_split_units` | Number of units for the train split (required) | +| `--validation_split_units` | Number of units for the validation split (required) | | 
`--num_units_per_file` | Max units per shard file (optional) | | `--max_seq_length` | Max words per example; splits long documents using word segmentation | | `--max_bytes_per_word` | Max UTF-8 bytes per word; should match training config `max_word_length - 2` (default: 126) | diff --git a/tests/test_prepare_data.py b/tests/test_prepare_data.py index 9f7aa5c..93a2eaa 100644 --- a/tests/test_prepare_data.py +++ b/tests/test_prepare_data.py @@ -43,7 +43,8 @@ def test_prepare_data_creates_shards(temp_output_dir, monkeypatch): "welt-prepare-data", "--dataset_name", "wikitext", "--dataset_config", "wikitext-2-raw-v1", - "--train_split_units", "500", + "--train_split_units", "400", + "--validation_split_units", "100", "--num_units_per_file", "200", "--seed", "42", "--output_path", temp_output_dir, @@ -89,7 +90,8 @@ def test_prepare_data_with_language(temp_output_dir, monkeypatch): "welt-prepare-data", "--dataset_name", "wikitext", "--dataset_config", "wikitext-2-raw-v1", - "--train_split_units", "200", + "--train_split_units", "160", + "--validation_split_units", "40", "--language", "eng_Latn", "--seed", "42", "--output_path", temp_output_dir, @@ -118,7 +120,8 @@ def test_prepare_data_unit_type_chars(temp_output_dir, monkeypatch): "welt-prepare-data", "--dataset_name", "wikitext", "--dataset_config", "wikitext-2-raw-v1", - "--train_split_units", "500", + "--train_split_units", "400", + "--validation_split_units", "100", "--unit_type", "chars", "--seed", "42", "--output_path", temp_output_dir, @@ -141,7 +144,8 @@ def test_prepare_data_with_max_seq_length(temp_output_dir, monkeypatch): "welt-prepare-data", "--dataset_name", "wikitext", "--dataset_config", "wikitext-2-raw-v1", - "--train_split_units", "500", + "--train_split_units", "400", + "--validation_split_units", "100", "--max_seq_length", "32", "--seed", "42", "--output_path", temp_output_dir, @@ -265,21 +269,7 @@ def test_load_prepared_data_split_aware(temp_output_dir, monkeypatch): assert len(result["train"]) + 
len(result["validation"]) == metadata["num_examples"] -def test_load_prepared_data_requires_validation_shards(temp_output_dir, monkeypatch): +def test_load_prepared_data_requires_validation_shards(temp_output_dir): """Test that load_prepared_data raises when validation shards are missing.""" - monkeypatch.setattr( - "sys.argv", - [ - "welt-prepare-data", - "--dataset_name", "wikitext", - "--dataset_config", "wikitext-2-raw-v1", - "--train_split_units", "500", - "--seed", "42", - "--output_path", temp_output_dir, - ], - ) - main() - - # Without --validation_split_units, shards have no split marker → error - with pytest.raises(ValueError, match="validation"): + with pytest.raises(ValueError, match="train"): load_prepared_data(temp_output_dir) diff --git a/welt_training/prepare_data.py b/welt_training/prepare_data.py index 4be7437..bc03361 100644 --- a/welt_training/prepare_data.py +++ b/welt_training/prepare_data.py @@ -182,8 +182,8 @@ def main(): parser.add_argument( "--train_split_units", type=int, - default=None, - help="Number of units for the train split. If not set, processes the entire dataset.", + required=True, + help="Number of units for the train split.", ) parser.add_argument( "--num_units_per_file", @@ -212,9 +212,9 @@ def main(): parser.add_argument( "--validation_split_units", type=int, - default=None, + required=True, help="Number of units for the validation split. " - "Requires --train_split_units. 
Data is already shuffled before splitting.", + "Data is already shuffled before splitting.", ) parser.add_argument( "--seed", @@ -259,17 +259,10 @@ def main(): f"language={args.language}") # Create shard writers - if args.validation_split_units is not None: - if args.train_split_units is None: - parser.error("--train_split_units is required when using --validation_split_units") - max_total_units = args.train_split_units + args.validation_split_units - logger.info(f" validation_split_units={args.validation_split_units}, max_total_units={max_total_units}") - train_writer = ShardWriter(output_path, prefix, "train", args.num_units_per_file) - val_writer = ShardWriter(output_path, prefix, "validation", args.num_units_per_file) - else: - max_total_units = args.train_split_units - train_writer = ShardWriter(output_path, prefix, None, args.num_units_per_file) - val_writer = None + max_total_units = args.train_split_units + args.validation_split_units + logger.info(f" validation_split_units={args.validation_split_units}, max_total_units={max_total_units}") + train_writer = ShardWriter(output_path, prefix, "train", args.num_units_per_file) + val_writer = ShardWriter(output_path, prefix, "validation", args.num_units_per_file) total_units = 0 total_examples = 0 @@ -278,7 +271,7 @@ def main(): text_units = num_words if args.unit_type == "words" else len(text) # Check global limit - if max_total_units is not None and total_units + text_units > max_total_units: + if total_units + text_units > max_total_units: logger.info(f"Reached total units limit ({max_total_units})") break @@ -287,7 +280,7 @@ def main(): record["language"] = args.language # Fill validation first, then route to train - if val_writer is not None and val_writer.total_units < args.validation_split_units: + if val_writer.total_units < args.validation_split_units: val_writer.write(record, text_units) else: train_writer.write(record, text_units) @@ -320,22 +313,21 @@ def main(): "text_template": args.text_template, } 
- if val_writer is not None: - val_writer.close() - metadata["num_shards"] += val_writer.num_shards - metadata["validation_split_units"] = args.validation_split_units - metadata["splits"] = { - "train": { - "num_examples": train_writer.num_examples, - "total_units": train_writer.total_units, - "num_shards": train_writer.num_shards, - }, - "validation": { - "num_examples": val_writer.num_examples, - "total_units": val_writer.total_units, - "num_shards": val_writer.num_shards, - }, - } + val_writer.close() + metadata["num_shards"] += val_writer.num_shards + metadata["validation_split_units"] = args.validation_split_units + metadata["splits"] = { + "train": { + "num_examples": train_writer.num_examples, + "total_units": train_writer.total_units, + "num_shards": train_writer.num_shards, + }, + "validation": { + "num_examples": val_writer.num_examples, + "total_units": val_writer.total_units, + "num_shards": val_writer.num_shards, + }, + } metadata_path = output_path / f"{prefix}-metadata.json" logger.info(f"Saving metadata to {metadata_path}") From afbc16ccfd5faa0fe94c8fe29722c3a4f6c13bf5 Mon Sep 17 00:00:00 2001 From: Ilker Kesen Date: Wed, 11 Feb 2026 20:39:20 +0100 Subject: [PATCH 13/22] implement extract_text procedure to prevent duplicated text_template processing as suggested by Claude --- welt_training/data_utils.py | 7 +++++++ welt_training/prepare_data.py | 8 +++----- welt_training/train.py | 6 +++--- 3 files changed, 13 insertions(+), 8 deletions(-) diff --git a/welt_training/data_utils.py b/welt_training/data_utils.py index 3c82515..d313823 100644 --- a/welt_training/data_utils.py +++ b/welt_training/data_utils.py @@ -7,6 +7,13 @@ logger = logging.getLogger(__name__) +def extract_text(example: dict, text_column: str = "text", text_template: str | None = None) -> str: + """Extract text from a dataset example using a column name or format template.""" + if text_template is not None: + return text_template.format(**example) + return example[text_column] + + def 
load_prepared_data(prepared_data_path: str): """Load preprocessed shards produced by prepare_data.py. diff --git a/welt_training/prepare_data.py b/welt_training/prepare_data.py index bc03361..05b9acf 100644 --- a/welt_training/prepare_data.py +++ b/welt_training/prepare_data.py @@ -14,6 +14,8 @@ from datasets import load_dataset from words_segmentation.tokenizer import WordsSegmentationTokenizer +from welt_training.data_utils import extract_text + logger = logging.getLogger(__name__) @@ -96,11 +98,7 @@ def stream_texts(args): stream = stream.shuffle(seed=args.seed, buffer_size=args.shuffle_buffer_size) for example in stream: - if args.text_template: - text = args.text_template.format(**example) - else: - text = example[args.text_column] - + text = extract_text(example, text_column=args.text_column, text_template=args.text_template) if text: yield text diff --git a/welt_training/train.py b/welt_training/train.py index 83a222f..ac74eb5 100644 --- a/welt_training/train.py +++ b/welt_training/train.py @@ -21,7 +21,7 @@ from welt_training.args_data import DataTrainingArguments from welt_training.args_model import ModelArguments from welt_training.args_trainer import WeLTTrainingArguments -from welt_training.data_utils import load_prepared_data +from welt_training.data_utils import extract_text, load_prepared_data from welt_training.extendable_yaml import resolve_yaml_file from welt_training.flops_callback import FlopsCallback from welt_training.freeze_callback import FreezeWarmupCallback @@ -287,7 +287,7 @@ def process_split(dataset, split_name: str): template = data_args.dataset_text_template if template is None: def mapping_fn(example): - return {"text": example[text_column_name]} + return {"text": extract_text(example, text_column=text_column_name)} else: is_single_text_template = isinstance(template, str) single_text_template = template \ @@ -295,7 +295,7 @@ def mapping_fn(example): def mapping_fn(example): if is_single_text_template or split_name == "train": - 
return {"text": single_text_template.format(**example)} + return {"text": extract_text(example, text_template=single_text_template)} prefix = template[0].format(**example) completion = template[1].format(**example) From fee1721c52a83fd74e896143974adae840dfca3f Mon Sep 17 00:00:00 2001 From: Ilker Kesen Date: Wed, 11 Feb 2026 20:44:54 +0100 Subject: [PATCH 14/22] make ShardWriter a context manager, initiated with /with/ statements for secure handling --- welt_training/prepare_data.py | 86 ++++++++++++++++++----------------- 1 file changed, 44 insertions(+), 42 deletions(-) diff --git a/welt_training/prepare_data.py b/welt_training/prepare_data.py index 05b9acf..9219b04 100644 --- a/welt_training/prepare_data.py +++ b/welt_training/prepare_data.py @@ -74,6 +74,12 @@ def close(self): if self.num_examples == 0: self._shard_path(self.shard_index).unlink(missing_ok=True) + def __enter__(self): + return self + + def __exit__(self, *exc): + self.close() + @property def num_shards(self) -> int: if self.num_examples == 0: @@ -259,48 +265,47 @@ def main(): # Create shard writers max_total_units = args.train_split_units + args.validation_split_units logger.info(f" validation_split_units={args.validation_split_units}, max_total_units={max_total_units}") - train_writer = ShardWriter(output_path, prefix, "train", args.num_units_per_file) - val_writer = ShardWriter(output_path, prefix, "validation", args.num_units_per_file) - - total_units = 0 - total_examples = 0 - - for text, num_words in stream_examples(args, pretokenizer): - text_units = num_words if args.unit_type == "words" else len(text) + with ( + ShardWriter(output_path, prefix, "train", args.num_units_per_file) as train_writer, + ShardWriter(output_path, prefix, "validation", args.num_units_per_file) as val_writer, + ): + total_units = 0 + total_examples = 0 - # Check global limit - if total_units + text_units > max_total_units: - logger.info(f"Reached total units limit ({max_total_units})") - break + for text, num_words 
in stream_examples(args, pretokenizer): + text_units = num_words if args.unit_type == "words" else len(text) - record = {"text": text} - if args.language: - record["language"] = args.language + # Check global limit + if total_units + text_units > max_total_units: + logger.info(f"Reached total units limit ({max_total_units})") + break - # Fill validation first, then route to train - if val_writer.total_units < args.validation_split_units: - val_writer.write(record, text_units) - else: - train_writer.write(record, text_units) + record = {"text": text} + if args.language: + record["language"] = args.language - total_units += text_units - total_examples += 1 + # Fill validation first, then route to train + if val_writer.total_units < args.validation_split_units: + val_writer.write(record, text_units) + else: + train_writer.write(record, text_units) - if total_examples % 10000 == 0: - logger.info(f"Processed {total_examples} examples, {total_units} total {args.unit_type}") + total_units += text_units + total_examples += 1 - train_writer.close() + if total_examples % 10000 == 0: + logger.info(f"Processed {total_examples} examples, {total_units} total {args.unit_type}") if total_examples == 0: logger.warning("No examples were written. 
Check dataset and filter settings.") - # Save metadata + # Save metadata (both writers are closed, final counts are stable) metadata = { "format": "welt-preprocessed-v1", "num_examples": total_examples, "total_units": total_units, "unit_type": args.unit_type, - "num_shards": train_writer.num_shards, + "num_shards": train_writer.num_shards + val_writer.num_shards, "source_dataset": args.dataset_name, "source_config": args.dataset_config, "source_split": args.dataset_split, @@ -309,21 +314,18 @@ def main(): "seed": args.seed, "text_column": args.text_column, "text_template": args.text_template, - } - - val_writer.close() - metadata["num_shards"] += val_writer.num_shards - metadata["validation_split_units"] = args.validation_split_units - metadata["splits"] = { - "train": { - "num_examples": train_writer.num_examples, - "total_units": train_writer.total_units, - "num_shards": train_writer.num_shards, - }, - "validation": { - "num_examples": val_writer.num_examples, - "total_units": val_writer.total_units, - "num_shards": val_writer.num_shards, + "validation_split_units": args.validation_split_units, + "splits": { + "train": { + "num_examples": train_writer.num_examples, + "total_units": train_writer.total_units, + "num_shards": train_writer.num_shards, + }, + "validation": { + "num_examples": val_writer.num_examples, + "total_units": val_writer.total_units, + "num_shards": val_writer.num_shards, + }, }, } From d1e127c43fd623d272405ec7e112cac94e251579 Mon Sep 17 00:00:00 2001 From: Ilker Kesen Date: Wed, 11 Feb 2026 20:49:53 +0100 Subject: [PATCH 15/22] do not count whitespace chars while keeping statistics --- welt_training/prepare_data.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/welt_training/prepare_data.py b/welt_training/prepare_data.py index 9219b04..5b5532d 100644 --- a/welt_training/prepare_data.py +++ b/welt_training/prepare_data.py @@ -109,6 +109,11 @@ def stream_texts(args): yield text +def 
_count_non_whitespace(text: str) -> int: + """Count non-whitespace characters in text.""" + return sum(1 for c in text if not c.isspace()) + + def stream_examples(args, pretokenizer: WordsSegmentationTokenizer): """Stream text examples, optionally chunked by max_seq_length. @@ -116,9 +121,16 @@ def stream_examples(args, pretokenizer: WordsSegmentationTokenizer): languages without whitespace). When max_seq_length is set, long documents are split into chunks of at most max_seq_length words. - Yields (text, num_words) tuples. + Yields (text, unit_count) tuples where unit_count matches args.unit_type. + When unit_type is "chars", only non-whitespace characters are counted. """ + count_chars = args.unit_type == "chars" + for text in stream_texts(args): + if args.max_seq_length is None and count_chars: + yield text, _count_non_whitespace(text) + continue + words = pretokenizer.tokenize(text) if args.max_seq_length is None: @@ -129,7 +141,8 @@ def stream_examples(args, pretokenizer: WordsSegmentationTokenizer): chunk_words = words[i:i + args.max_seq_length] if args.drop_remainder and len(chunk_words) < args.max_seq_length: continue - yield "".join(chunk_words), len(chunk_words) + chunk_text = "".join(chunk_words) + yield chunk_text, _count_non_whitespace(chunk_text) if count_chars else len(chunk_words) def main(): @@ -272,8 +285,7 @@ def main(): total_units = 0 total_examples = 0 - for text, num_words in stream_examples(args, pretokenizer): - text_units = num_words if args.unit_type == "words" else len(text) + for text, text_units in stream_examples(args, pretokenizer): # Check global limit if total_units + text_units > max_total_units: From 332333039ce299ae6154b3541018e73f3ee740fe Mon Sep 17 00:00:00 2001 From: Ilker Kesen Date: Wed, 11 Feb 2026 20:56:35 +0100 Subject: [PATCH 16/22] refactor data preparation tests --- tests/test_prepare_data.py | 156 ++++++++++++++++++------------------- 1 file changed, 74 insertions(+), 82 deletions(-) diff --git 
a/tests/test_prepare_data.py b/tests/test_prepare_data.py index 93a2eaa..11253b7 100644 --- a/tests/test_prepare_data.py +++ b/tests/test_prepare_data.py @@ -10,6 +10,35 @@ from welt_training.data_utils import load_prepared_data from welt_training.prepare_data import get_shard_prefix, main +WIKITEXT_DATASET = "wikitext" +WIKITEXT_CONFIG = "wikitext-2-raw-v1" + + +# --- Helpers --- + + +def read_shard_examples(output_dir, pattern="*.jsonl.gz"): + """Read all examples from shards matching *pattern* in *output_dir*.""" + examples = [] + for path in sorted(glob.glob(f"{output_dir}/{pattern}")): + with gzip.open(path, "rt") as f: + for line in f: + examples.append(json.loads(line)) + return examples + + +def read_metadata(output_dir, dataset_name=WIKITEXT_DATASET, dataset_config=WIKITEXT_CONFIG): + """Load the metadata JSON produced by prepare_data.""" + prefix = get_shard_prefix(dataset_name, dataset_config) + with open(f"{output_dir}/{prefix}-metadata.json") as f: + return json.load(f) + + +def shard_paths(output_dir, pattern="*.jsonl.gz"): + """Return sorted list of shard file paths matching *pattern*.""" + return sorted(glob.glob(f"{output_dir}/{pattern}")) + + # --- get_shard_prefix --- @@ -41,8 +70,8 @@ def test_prepare_data_creates_shards(temp_output_dir, monkeypatch): "sys.argv", [ "welt-prepare-data", - "--dataset_name", "wikitext", - "--dataset_config", "wikitext-2-raw-v1", + "--dataset_name", WIKITEXT_DATASET, + "--dataset_config", WIKITEXT_CONFIG, "--train_split_units", "400", "--validation_split_units", "100", "--num_units_per_file", "200", @@ -52,33 +81,24 @@ def test_prepare_data_creates_shards(temp_output_dir, monkeypatch): ) main() - # Verify shards were created - shard_files = sorted(glob.glob(f"{temp_output_dir}/*.jsonl.gz")) - assert len(shard_files) >= 2, f"Expected at least 2 shards, got {len(shard_files)}" + shards = shard_paths(temp_output_dir) + assert len(shards) >= 2, f"Expected at least 2 shards, got {len(shards)}" - # Verify metadata - 
prefix = get_shard_prefix("wikitext", "wikitext-2-raw-v1") - with open(f"{temp_output_dir}/{prefix}-metadata.json") as f: - metadata = json.load(f) + metadata = read_metadata(temp_output_dir) assert metadata["format"] == "welt-preprocessed-v1" assert metadata["total_units"] <= 500 assert metadata["unit_type"] == "words" - assert metadata["num_shards"] == len(shard_files) + assert metadata["num_shards"] == len(shards) - # Verify each shard is valid gzipped JSONL with a "text" field - total_examples = 0 - for path in shard_files: - with gzip.open(path, "rt") as f: - for line in f: - example = json.loads(line) - assert "text" in example - assert isinstance(example["text"], str) - total_examples += 1 - assert total_examples == metadata["num_examples"] + examples = read_shard_examples(temp_output_dir) + for example in examples: + assert "text" in example + assert isinstance(example["text"], str) + assert len(examples) == metadata["num_examples"] # Verify loading with HuggingFace datasets (same path as train.py) - ds = load_dataset("json", data_files=shard_files, split="train") - assert len(ds) == total_examples + ds = load_dataset("json", data_files=shards, split="train") + assert len(ds) == len(examples) assert "text" in ds.features @@ -88,8 +108,8 @@ def test_prepare_data_with_language(temp_output_dir, monkeypatch): "sys.argv", [ "welt-prepare-data", - "--dataset_name", "wikitext", - "--dataset_config", "wikitext-2-raw-v1", + "--dataset_name", WIKITEXT_DATASET, + "--dataset_config", WIKITEXT_CONFIG, "--train_split_units", "160", "--validation_split_units", "40", "--language", "eng_Latn", @@ -99,16 +119,10 @@ def test_prepare_data_with_language(temp_output_dir, monkeypatch): ) main() - shard_files = sorted(glob.glob(f"{temp_output_dir}/*.jsonl.gz")) - for path in shard_files: - with gzip.open(path, "rt") as f: - for line in f: - example = json.loads(line) - assert example["language"] == "eng_Latn" + for example in read_shard_examples(temp_output_dir): + assert 
example["language"] == "eng_Latn" - prefix = get_shard_prefix("wikitext", "wikitext-2-raw-v1") - with open(f"{temp_output_dir}/{prefix}-metadata.json") as f: - metadata = json.load(f) + metadata = read_metadata(temp_output_dir) assert metadata["language"] == "eng_Latn" @@ -118,8 +132,8 @@ def test_prepare_data_unit_type_chars(temp_output_dir, monkeypatch): "sys.argv", [ "welt-prepare-data", - "--dataset_name", "wikitext", - "--dataset_config", "wikitext-2-raw-v1", + "--dataset_name", WIKITEXT_DATASET, + "--dataset_config", WIKITEXT_CONFIG, "--train_split_units", "400", "--validation_split_units", "100", "--unit_type", "chars", @@ -129,9 +143,7 @@ def test_prepare_data_unit_type_chars(temp_output_dir, monkeypatch): ) main() - prefix = get_shard_prefix("wikitext", "wikitext-2-raw-v1") - with open(f"{temp_output_dir}/{prefix}-metadata.json") as f: - metadata = json.load(f) + metadata = read_metadata(temp_output_dir) assert metadata["unit_type"] == "chars" assert metadata["total_units"] <= 500 @@ -142,8 +154,8 @@ def test_prepare_data_with_max_seq_length(temp_output_dir, monkeypatch): "sys.argv", [ "welt-prepare-data", - "--dataset_name", "wikitext", - "--dataset_config", "wikitext-2-raw-v1", + "--dataset_name", WIKITEXT_DATASET, + "--dataset_config", WIKITEXT_CONFIG, "--train_split_units", "400", "--validation_split_units", "100", "--max_seq_length", "32", @@ -153,9 +165,7 @@ def test_prepare_data_with_max_seq_length(temp_output_dir, monkeypatch): ) main() - prefix = get_shard_prefix("wikitext", "wikitext-2-raw-v1") - with open(f"{temp_output_dir}/{prefix}-metadata.json") as f: - metadata = json.load(f) + metadata = read_metadata(temp_output_dir) assert metadata["max_seq_length"] == 32 assert metadata["total_units"] <= 500 @@ -163,13 +173,9 @@ def test_prepare_data_with_max_seq_length(temp_output_dir, monkeypatch): from words_segmentation.tokenizer import WordsSegmentationTokenizer pretokenizer = WordsSegmentationTokenizer(max_bytes=126) - shard_files = 
sorted(glob.glob(f"{temp_output_dir}/*.jsonl.gz")) - for path in shard_files: - with gzip.open(path, "rt") as f: - for line in f: - example = json.loads(line) - words = pretokenizer.tokenize(example["text"]) - assert len(words) <= 32 + for example in read_shard_examples(temp_output_dir): + words = pretokenizer.tokenize(example["text"]) + assert len(words) <= 32 def test_prepare_data_with_validation_split(temp_output_dir, monkeypatch): @@ -178,8 +184,8 @@ def test_prepare_data_with_validation_split(temp_output_dir, monkeypatch): "sys.argv", [ "welt-prepare-data", - "--dataset_name", "wikitext", - "--dataset_config", "wikitext-2-raw-v1", + "--dataset_name", WIKITEXT_DATASET, + "--dataset_config", WIKITEXT_CONFIG, "--train_split_units", "800", "--validation_split_units", "200", "--seed", "42", @@ -188,51 +194,39 @@ def test_prepare_data_with_validation_split(temp_output_dir, monkeypatch): ) main() - prefix = get_shard_prefix("wikitext", "wikitext-2-raw-v1") + prefix = get_shard_prefix(WIKITEXT_DATASET, WIKITEXT_CONFIG) # Verify split-aware shard files were created - train_files = sorted(glob.glob(f"{temp_output_dir}/{prefix}-train-*.jsonl.gz")) - val_files = sorted(glob.glob(f"{temp_output_dir}/{prefix}-validation-*.jsonl.gz")) + train_files = shard_paths(temp_output_dir, f"{prefix}-train-*.jsonl.gz") + val_files = shard_paths(temp_output_dir, f"{prefix}-validation-*.jsonl.gz") assert len(train_files) >= 1, f"Expected at least 1 train shard, got {len(train_files)}" assert len(val_files) >= 1, f"Expected at least 1 validation shard, got {len(val_files)}" # No legacy (unsplit) shards should exist - all_shards = sorted(glob.glob(f"{temp_output_dir}/*.jsonl.gz")) + all_shards = shard_paths(temp_output_dir) assert len(all_shards) == len(train_files) + len(val_files) # Count examples per split - train_examples = 0 - for path in train_files: - with gzip.open(path, "rt") as f: - for line in f: - example = json.loads(line) - assert "text" in example - train_examples += 1 - - 
val_examples = 0 - for path in val_files: - with gzip.open(path, "rt") as f: - for line in f: - example = json.loads(line) - assert "text" in example - val_examples += 1 + train_examples = read_shard_examples(temp_output_dir, f"{prefix}-train-*.jsonl.gz") + val_examples = read_shard_examples(temp_output_dir, f"{prefix}-validation-*.jsonl.gz") + for example in train_examples + val_examples: + assert "text" in example - total_examples = train_examples + val_examples + total_examples = len(train_examples) + len(val_examples) assert total_examples > 0 # Verify validation fraction is roughly correct (20% +/- tolerance) - val_fraction = val_examples / total_examples + val_fraction = len(val_examples) / total_examples assert 0.05 < val_fraction < 0.45, f"Expected ~20% validation, got {val_fraction:.1%}" # Verify metadata - with open(f"{temp_output_dir}/{prefix}-metadata.json") as f: - metadata = json.load(f) + metadata = read_metadata(temp_output_dir) assert metadata["format"] == "welt-preprocessed-v1" assert metadata["validation_split_units"] == 200 assert metadata["num_examples"] == total_examples assert "splits" in metadata - assert metadata["splits"]["train"]["num_examples"] == train_examples - assert metadata["splits"]["validation"]["num_examples"] == val_examples + assert metadata["splits"]["train"]["num_examples"] == len(train_examples) + assert metadata["splits"]["validation"]["num_examples"] == len(val_examples) assert metadata["splits"]["train"]["num_shards"] == len(train_files) assert metadata["splits"]["validation"]["num_shards"] == len(val_files) @@ -243,8 +237,8 @@ def test_load_prepared_data_split_aware(temp_output_dir, monkeypatch): "sys.argv", [ "welt-prepare-data", - "--dataset_name", "wikitext", - "--dataset_config", "wikitext-2-raw-v1", + "--dataset_name", WIKITEXT_DATASET, + "--dataset_config", WIKITEXT_CONFIG, "--train_split_units", "800", "--validation_split_units", "200", "--seed", "42", @@ -263,9 +257,7 @@ def 
test_load_prepared_data_split_aware(temp_output_dir, monkeypatch): assert "text" in result["validation"].features # Total should match what was prepared - prefix = get_shard_prefix("wikitext", "wikitext-2-raw-v1") - with open(f"{temp_output_dir}/{prefix}-metadata.json") as f: - metadata = json.load(f) + metadata = read_metadata(temp_output_dir) assert len(result["train"]) + len(result["validation"]) == metadata["num_examples"] From fc33081e8b2f29c4a87d0642cabdcd652f062af1 Mon Sep 17 00:00:00 2001 From: Ilker Kesen Date: Wed, 11 Feb 2026 21:07:32 +0100 Subject: [PATCH 17/22] rely on the words segmenter to count number of characters --- welt_training/prepare_data.py | 31 ++++++++++++------------------- 1 file changed, 12 insertions(+), 19 deletions(-) diff --git a/welt_training/prepare_data.py b/welt_training/prepare_data.py index 5b5532d..5d46a5a 100644 --- a/welt_training/prepare_data.py +++ b/welt_training/prepare_data.py @@ -109,11 +109,6 @@ def stream_texts(args): yield text -def _count_non_whitespace(text: str) -> int: - """Count non-whitespace characters in text.""" - return sum(1 for c in text if not c.isspace()) - - def stream_examples(args, pretokenizer: WordsSegmentationTokenizer): """Stream text examples, optionally chunked by max_seq_length. @@ -122,27 +117,25 @@ def stream_examples(args, pretokenizer: WordsSegmentationTokenizer): are split into chunks of at most max_seq_length words. Yields (text, unit_count) tuples where unit_count matches args.unit_type. - When unit_type is "chars", only non-whitespace characters are counted. + When unit_type is "chars", characters are counted via token lengths + (whitespace excluded by the tokenizer). 
""" count_chars = args.unit_type == "chars" for text in stream_texts(args): - if args.max_seq_length is None and count_chars: - yield text, _count_non_whitespace(text) - continue - words = pretokenizer.tokenize(text) + unit_count = sum(len(w) for w in words) if count_chars else len(words) if args.max_seq_length is None: - yield text, len(words) - continue - - for i in range(0, len(words), args.max_seq_length): - chunk_words = words[i:i + args.max_seq_length] - if args.drop_remainder and len(chunk_words) < args.max_seq_length: - continue - chunk_text = "".join(chunk_words) - yield chunk_text, _count_non_whitespace(chunk_text) if count_chars else len(chunk_words) + yield text, unit_count + else: + for i in range(0, len(words), args.max_seq_length): + chunk_words = words[i:i + args.max_seq_length] + if args.drop_remainder and len(chunk_words) < args.max_seq_length: + continue + chunk_text = "".join(chunk_words) + chunk_units = sum(len(w) for w in chunk_words) if count_chars else len(chunk_words) + yield chunk_text, chunk_units def main(): From fc19d878f9767a16784295a4626ea88b526591fc Mon Sep 17 00:00:00 2001 From: Ilker Kesen Date: Thu, 12 Feb 2026 19:50:37 +0100 Subject: [PATCH 18/22] separate split metadata files and enable preparation per split also makes the language arg required for the data preparation script --- README.md | 4 +- tests/test_prepare_data.py | 161 +++++++++++++++++++++++++++------- welt_training/data_utils.py | 29 +++--- welt_training/prepare_data.py | 74 +++++++++------- 4 files changed, 185 insertions(+), 83 deletions(-) diff --git a/README.md b/README.md index 1caf24a..b7ea637 100644 --- a/README.md +++ b/README.md @@ -151,8 +151,8 @@ welt-train config.yaml --prepared_data_path /scratch/data/pretrain | `--text_template` | Python format string template (optional) | | `--language` | Language tag to store with each example (e.g., "eng_Latn") | | `--unit_type` | Unit type for counting: "words" or "chars" (default: "words") | -| 
`--train_split_units` | Number of units for the train split (required) | -| `--validation_split_units` | Number of units for the validation split (required) | +| `--train_split_units` | Number of units for the train split (default: 0, no train shards) | +| `--validation_split_units` | Number of units for the validation split (default: 0, no validation shards) | | `--num_units_per_file` | Max units per shard file (optional) | | `--max_seq_length` | Max words per example; splits long documents using word segmentation | | `--max_bytes_per_word` | Max UTF-8 bytes per word; should match training config `max_word_length - 2` (default: 126) | diff --git a/tests/test_prepare_data.py b/tests/test_prepare_data.py index 11253b7..8f47385 100644 --- a/tests/test_prepare_data.py +++ b/tests/test_prepare_data.py @@ -27,10 +27,10 @@ def read_shard_examples(output_dir, pattern="*.jsonl.gz"): return examples -def read_metadata(output_dir, dataset_name=WIKITEXT_DATASET, dataset_config=WIKITEXT_CONFIG): - """Load the metadata JSON produced by prepare_data.""" +def read_metadata(output_dir, split_name, dataset_name=WIKITEXT_DATASET, dataset_config=WIKITEXT_CONFIG): + """Load the per-split metadata JSON produced by prepare_data.""" prefix = get_shard_prefix(dataset_name, dataset_config) - with open(f"{output_dir}/{prefix}-metadata.json") as f: + with open(f"{output_dir}/{prefix}-{split_name}-metadata.json") as f: return json.load(f) @@ -75,6 +75,8 @@ def test_prepare_data_creates_shards(temp_output_dir, monkeypatch): "--train_split_units", "400", "--validation_split_units", "100", "--num_units_per_file", "200", + "--max_seq_length", "1024", + "--language", "eng_Latn", "--seed", "42", "--output_path", temp_output_dir, ], @@ -84,17 +86,18 @@ def test_prepare_data_creates_shards(temp_output_dir, monkeypatch): shards = shard_paths(temp_output_dir) assert len(shards) >= 2, f"Expected at least 2 shards, got {len(shards)}" - metadata = read_metadata(temp_output_dir) - assert metadata["format"] 
== "welt-preprocessed-v1" - assert metadata["total_units"] <= 500 - assert metadata["unit_type"] == "words" - assert metadata["num_shards"] == len(shards) + train_meta = read_metadata(temp_output_dir, "train") + val_meta = read_metadata(temp_output_dir, "validation") + assert train_meta["format"] == "welt-preprocessed-v1" + assert train_meta["total_units"] + val_meta["total_units"] <= 500 + assert train_meta["unit_type"] == "words" + assert train_meta["num_shards"] + val_meta["num_shards"] == len(shards) examples = read_shard_examples(temp_output_dir) for example in examples: assert "text" in example assert isinstance(example["text"], str) - assert len(examples) == metadata["num_examples"] + assert len(examples) == train_meta["num_examples"] + val_meta["num_examples"] # Verify loading with HuggingFace datasets (same path as train.py) ds = load_dataset("json", data_files=shards, split="train") @@ -112,6 +115,7 @@ def test_prepare_data_with_language(temp_output_dir, monkeypatch): "--dataset_config", WIKITEXT_CONFIG, "--train_split_units", "160", "--validation_split_units", "40", + "--max_seq_length", "1024", "--language", "eng_Latn", "--seed", "42", "--output_path", temp_output_dir, @@ -122,8 +126,8 @@ def test_prepare_data_with_language(temp_output_dir, monkeypatch): for example in read_shard_examples(temp_output_dir): assert example["language"] == "eng_Latn" - metadata = read_metadata(temp_output_dir) - assert metadata["language"] == "eng_Latn" + assert read_metadata(temp_output_dir, "train")["language"] == "eng_Latn" + assert read_metadata(temp_output_dir, "validation")["language"] == "eng_Latn" def test_prepare_data_unit_type_chars(temp_output_dir, monkeypatch): @@ -137,15 +141,17 @@ def test_prepare_data_unit_type_chars(temp_output_dir, monkeypatch): "--train_split_units", "400", "--validation_split_units", "100", "--unit_type", "chars", + "--max_seq_length", "1024", + "--language", "eng_Latn", "--seed", "42", "--output_path", temp_output_dir, ], ) main() - 
metadata = read_metadata(temp_output_dir) - assert metadata["unit_type"] == "chars" - assert metadata["total_units"] <= 500 + meta = read_metadata(temp_output_dir, "validation") + assert meta["unit_type"] == "chars" + assert meta["total_units"] <= 500 def test_prepare_data_with_max_seq_length(temp_output_dir, monkeypatch): @@ -159,15 +165,17 @@ def test_prepare_data_with_max_seq_length(temp_output_dir, monkeypatch): "--train_split_units", "400", "--validation_split_units", "100", "--max_seq_length", "32", + "--language", "eng_Latn", "--seed", "42", "--output_path", temp_output_dir, ], ) main() - metadata = read_metadata(temp_output_dir) - assert metadata["max_seq_length"] == 32 - assert metadata["total_units"] <= 500 + train_meta = read_metadata(temp_output_dir, "train") + val_meta = read_metadata(temp_output_dir, "validation") + assert train_meta["max_seq_length"] == 32 + assert train_meta["total_units"] + val_meta["total_units"] <= 500 # Verify each example has at most max_seq_length words from words_segmentation.tokenizer import WordsSegmentationTokenizer @@ -188,6 +196,8 @@ def test_prepare_data_with_validation_split(temp_output_dir, monkeypatch): "--dataset_config", WIKITEXT_CONFIG, "--train_split_units", "800", "--validation_split_units", "200", + "--max_seq_length", "1024", + "--language", "eng_Latn", "--seed", "42", "--output_path", temp_output_dir, ], @@ -219,16 +229,14 @@ def test_prepare_data_with_validation_split(temp_output_dir, monkeypatch): val_fraction = len(val_examples) / total_examples assert 0.05 < val_fraction < 0.45, f"Expected ~20% validation, got {val_fraction:.1%}" - # Verify metadata - metadata = read_metadata(temp_output_dir) - assert metadata["format"] == "welt-preprocessed-v1" - assert metadata["validation_split_units"] == 200 - assert metadata["num_examples"] == total_examples - assert "splits" in metadata - assert metadata["splits"]["train"]["num_examples"] == len(train_examples) - assert 
metadata["splits"]["validation"]["num_examples"] == len(val_examples) - assert metadata["splits"]["train"]["num_shards"] == len(train_files) - assert metadata["splits"]["validation"]["num_shards"] == len(val_files) + # Verify per-split metadata + train_meta = read_metadata(temp_output_dir, "train") + val_meta = read_metadata(temp_output_dir, "validation") + assert train_meta["format"] == "welt-preprocessed-v1" + assert train_meta["num_examples"] == len(train_examples) + assert val_meta["num_examples"] == len(val_examples) + assert train_meta["num_shards"] == len(train_files) + assert val_meta["num_shards"] == len(val_files) def test_load_prepared_data_split_aware(temp_output_dir, monkeypatch): @@ -241,6 +249,8 @@ def test_load_prepared_data_split_aware(temp_output_dir, monkeypatch): "--dataset_config", WIKITEXT_CONFIG, "--train_split_units", "800", "--validation_split_units", "200", + "--max_seq_length", "1024", + "--language", "eng_Latn", "--seed", "42", "--output_path", temp_output_dir, ], @@ -257,11 +267,98 @@ def test_load_prepared_data_split_aware(temp_output_dir, monkeypatch): assert "text" in result["validation"].features # Total should match what was prepared - metadata = read_metadata(temp_output_dir) - assert len(result["train"]) + len(result["validation"]) == metadata["num_examples"] + train_meta = read_metadata(temp_output_dir, "train") + val_meta = read_metadata(temp_output_dir, "validation") + assert len(result["train"]) + len(result["validation"]) == train_meta["num_examples"] + val_meta["num_examples"] -def test_load_prepared_data_requires_validation_shards(temp_output_dir): - """Test that load_prepared_data raises when validation shards are missing.""" - with pytest.raises(ValueError, match="train"): +def test_load_prepared_data_requires_some_shards(temp_output_dir): + """Test that load_prepared_data raises when no shards exist.""" + with pytest.raises(ValueError, match="No"): load_prepared_data(temp_output_dir) + + +def 
test_prepare_data_train_only(temp_output_dir, monkeypatch): + """Test that setting validation_split_units=0 creates only train shards.""" + monkeypatch.setattr( + "sys.argv", + [ + "welt-prepare-data", + "--dataset_name", WIKITEXT_DATASET, + "--dataset_config", WIKITEXT_CONFIG, + "--train_split_units", "400", + "--max_seq_length", "1024", + "--language", "eng_Latn", + "--seed", "42", + "--output_path", temp_output_dir, + ], + ) + main() + + prefix = get_shard_prefix(WIKITEXT_DATASET, WIKITEXT_CONFIG) + train_files = shard_paths(temp_output_dir, f"{prefix}-train-*.jsonl.gz") + val_files = shard_paths(temp_output_dir, f"{prefix}-validation-*.jsonl.gz") + assert len(train_files) >= 1 + assert len(val_files) == 0 + + metadata = read_metadata(temp_output_dir, "train") + assert metadata["num_examples"] > 0 + assert not glob.glob(f"{temp_output_dir}/*-validation-metadata.json") + + result = load_prepared_data(temp_output_dir) + assert "train" in result + assert "validation" not in result + + +def test_prepare_data_validation_only(temp_output_dir, monkeypatch): + """Test that setting train_split_units=0 creates only validation shards.""" + monkeypatch.setattr( + "sys.argv", + [ + "welt-prepare-data", + "--dataset_name", WIKITEXT_DATASET, + "--dataset_config", WIKITEXT_CONFIG, + "--validation_split_units", "200", + "--max_seq_length", "1024", + "--language", "eng_Latn", + "--seed", "42", + "--output_path", temp_output_dir, + ], + ) + main() + + prefix = get_shard_prefix(WIKITEXT_DATASET, WIKITEXT_CONFIG) + train_files = shard_paths(temp_output_dir, f"{prefix}-train-*.jsonl.gz") + val_files = shard_paths(temp_output_dir, f"{prefix}-validation-*.jsonl.gz") + assert len(train_files) == 0 + assert len(val_files) >= 1 + + metadata = read_metadata(temp_output_dir, "validation") + assert metadata["num_examples"] > 0 + assert not glob.glob(f"{temp_output_dir}/*-train-metadata.json") + + result = load_prepared_data(temp_output_dir) + assert "train" not in result + assert 
"validation" in result + + +def test_prepare_data_with_id_column(temp_output_dir, monkeypatch): + """Test that --id_column preserves the source column as 'id' in output.""" + monkeypatch.setattr( + "sys.argv", + [ + "welt-prepare-data", + "--dataset_name", "HuggingFaceFW/fineweb-2", + "--dataset_config", "tur_Latn", + "--train_split_units", "400", + "--id_column", "id", + "--max_seq_length", "1024", + "--language", "tur_Latn", + "--seed", "42", + "--output_path", temp_output_dir, + ], + ) + main() + + for example in read_shard_examples(temp_output_dir): + assert "id" in example diff --git a/welt_training/data_utils.py b/welt_training/data_utils.py index d313823..fd9de11 100644 --- a/welt_training/data_utils.py +++ b/welt_training/data_utils.py @@ -17,37 +17,34 @@ def extract_text(example: dict, text_column: str = "text", text_template: str | def load_prepared_data(prepared_data_path: str): """Load preprocessed shards produced by prepare_data.py. - Loads ``{prefix}-train-*.jsonl.gz`` shards as the train split and + Loads ``{prefix}-train-*.jsonl.gz`` shards as the train split and/or ``{prefix}-validation-*.jsonl.gz`` shards as the validation split. - Both train and validation shards are required. Prepare data with - ``--validation_split_units`` to produce them. + At least one split must be present. Missing splits are omitted from the result. Args: prepared_data_path: Directory containing ``*.jsonl.gz`` shard files. Returns: - A dict with ``"train"`` and ``"validation"`` datasets. + A dict with ``"train"`` and/or ``"validation"`` datasets. """ train_files = sorted(glob.glob(os.path.join(prepared_data_path, "*-train-*.jsonl.gz"))) validation_files = sorted(glob.glob(os.path.join(prepared_data_path, "*-validation-*.jsonl.gz"))) - if not train_files: + if not train_files and not validation_files: raise ValueError( - f"No *-train-*.jsonl.gz files found in {prepared_data_path}. " - "Prepare data with --train_split_units and --validation_split_units." 
- ) - if not validation_files: - raise ValueError( - f"No *-validation-*.jsonl.gz files found in {prepared_data_path}. " - "Prepare data with --validation_split_units to create a validation split." + f"No *-train-*.jsonl.gz or *-validation-*.jsonl.gz files found in {prepared_data_path}. " + "Prepare data with --train_split_units and/or --validation_split_units." ) + result = {} + if train_files: + result["train"] = load_dataset("json", data_files=train_files, split="train") + if validation_files: + result["validation"] = load_dataset("json", data_files=validation_files, split="train") + logger.info( f"Loading prepared data: {len(train_files)} train shard(s), " f"{len(validation_files)} validation shard(s) from {prepared_data_path}" ) - return { - "train": load_dataset("json", data_files=train_files, split="train"), - "validation": load_dataset("json", data_files=validation_files, split="train"), - } + return result diff --git a/welt_training/prepare_data.py b/welt_training/prepare_data.py index 5d46a5a..2cb9b05 100644 --- a/welt_training/prepare_data.py +++ b/welt_training/prepare_data.py @@ -106,7 +106,8 @@ def stream_texts(args): for example in stream: text = extract_text(example, text_column=args.text_column, text_template=args.text_template) if text: - yield text + id_value = example.get(args.id_column) if args.id_column else None + yield text, id_value def stream_examples(args, pretokenizer: WordsSegmentationTokenizer): @@ -122,12 +123,12 @@ def stream_examples(args, pretokenizer: WordsSegmentationTokenizer): """ count_chars = args.unit_type == "chars" - for text in stream_texts(args): + for text, id_value in stream_texts(args): words = pretokenizer.tokenize(text) unit_count = sum(len(w) for w in words) if count_chars else len(words) if args.max_seq_length is None: - yield text, unit_count + yield text, unit_count, id_value else: for i in range(0, len(words), args.max_seq_length): chunk_words = words[i:i + args.max_seq_length] @@ -135,7 +136,7 @@ def 
stream_examples(args, pretokenizer: WordsSegmentationTokenizer): continue chunk_text = "".join(chunk_words) chunk_units = sum(len(w) for w in chunk_words) if count_chars else len(chunk_words) - yield chunk_text, chunk_units + yield chunk_text, chunk_units, id_value def main(): @@ -177,9 +178,15 @@ def main(): parser.add_argument( "--language", type=str, - default=None, + required=True, help="Language tag to store with each example (e.g., 'eng_Latn')", ) + parser.add_argument( + "--id_column", + type=str, + default=None, + help="Source column to preserve as 'id' in output (optional)", + ) # Processing arguments parser.add_argument( @@ -192,8 +199,8 @@ def main(): parser.add_argument( "--train_split_units", type=int, - required=True, - help="Number of units for the train split.", + default=0, + help="Number of units for the train split (default: 0, meaning no train shards).", ) parser.add_argument( "--num_units_per_file", @@ -222,8 +229,8 @@ def main(): parser.add_argument( "--validation_split_units", type=int, - required=True, - help="Number of units for the validation split. " + default=0, + help="Number of units for the validation split (default: 0, meaning no validation shards). 
" "Data is already shuffled before splitting.", ) parser.add_argument( @@ -249,6 +256,9 @@ def main(): args = parser.parse_args() + if args.train_split_units <= 0 and args.validation_split_units <= 0: + parser.error("At least one of --train_split_units or --validation_split_units must be > 0.") + # Setup logging logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", @@ -278,7 +288,7 @@ def main(): total_units = 0 total_examples = 0 - for text, text_units in stream_examples(args, pretokenizer): + for text, text_units, id_value in stream_examples(args, pretokenizer): # Check global limit if total_units + text_units > max_total_units: @@ -286,6 +296,8 @@ def main(): break record = {"text": text} + if id_value is not None: + record["id"] = id_value if args.language: record["language"] = args.language @@ -304,13 +316,10 @@ def main(): if total_examples == 0: logger.warning("No examples were written. Check dataset and filter settings.") - # Save metadata (both writers are closed, final counts are stable) - metadata = { + # Save per-split metadata (both writers are closed, final counts are stable) + base_metadata = { "format": "welt-preprocessed-v1", - "num_examples": total_examples, - "total_units": total_units, "unit_type": args.unit_type, - "num_shards": train_writer.num_shards + val_writer.num_shards, "source_dataset": args.dataset_name, "source_config": args.dataset_config, "source_split": args.dataset_split, @@ -319,30 +328,29 @@ def main(): "seed": args.seed, "text_column": args.text_column, "text_template": args.text_template, - "validation_split_units": args.validation_split_units, - "splits": { - "train": { - "num_examples": train_writer.num_examples, - "total_units": train_writer.total_units, - "num_shards": train_writer.num_shards, - }, - "validation": { - "num_examples": val_writer.num_examples, - "total_units": val_writer.total_units, - "num_shards": val_writer.num_shards, - }, - }, } - metadata_path = output_path / 
f"{prefix}-metadata.json" - logger.info(f"Saving metadata to {metadata_path}") - with open(metadata_path, "w") as f: - json.dump(metadata, f, indent=2) + active_writers = [w for w in [train_writer, val_writer] if w.num_examples > 0] + created_with_another_split = len(active_writers) > 1 + + for writer in active_writers: + metadata = { + **base_metadata, + "split": writer.split_name, + "created_with_another_split": created_with_another_split, + "num_examples": writer.num_examples, + "total_units": writer.total_units, + "num_shards": writer.num_shards, + } + metadata_path = output_path / f"{prefix}-{writer.split_name}-metadata.json" + logger.info(f"Saving metadata to {metadata_path}") + with open(metadata_path, "w") as f: + json.dump(metadata, f, indent=2) logger.info("Data preparation complete!") logger.info(f" - Examples: {total_examples}") logger.info(f" - Total {args.unit_type}: {total_units}") - logger.info(f" - Shards: {metadata['num_shards']}") + logger.info(f" - Shards: {train_writer.num_shards + val_writer.num_shards}") logger.info(f" - Output: {output_path}") From dc203772014567601132e1f21a90edf4ae38e5da Mon Sep 17 00:00:00 2001 From: Ilker Kesen Date: Thu, 12 Feb 2026 19:52:23 +0100 Subject: [PATCH 19/22] implement verification for the prepared data --- README.md | 10 +++ pyproject.toml | 1 + tests/test_verify_data.py | 127 +++++++++++++++++++++++++++++++++++ welt_training/verify_data.py | 108 +++++++++++++++++++++++++++++ 4 files changed, 246 insertions(+) create mode 100644 tests/test_verify_data.py create mode 100644 welt_training/verify_data.py diff --git a/README.md b/README.md index b7ea637..ec3f936 100644 --- a/README.md +++ b/README.md @@ -160,6 +160,16 @@ welt-train config.yaml --prepared_data_path /scratch/data/pretrain | `--drop_remainder` | Drop partial chunks at document boundaries | | `--output_path` | Output directory path (required) | +### Verifying Prepared Data + +After preparing data, verify integrity with `welt-verify-data`: + +```shell 
+welt-verify-data --data_path /scratch/data/pretrain +``` + +This checks that shard counts and example counts match the metadata, and warns if train/validation splits from the same source were created separately (risking data contamination). + ## Training Training instructions are available in the [welt_training/README.md](./welt_training/README.md). diff --git a/pyproject.toml b/pyproject.toml index be10423..92a70a2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -81,3 +81,4 @@ testpaths = [ [project.scripts] welt-train = "welt_training.train:train" welt-prepare-data = "welt_training.prepare_data:main" +welt-verify-data = "welt_training.verify_data:main" diff --git a/tests/test_verify_data.py b/tests/test_verify_data.py new file mode 100644 index 0000000..3c08923 --- /dev/null +++ b/tests/test_verify_data.py @@ -0,0 +1,127 @@ +import gzip +import json +import shutil +import tempfile + +import pytest + +from welt_training.verify_data import verify + + +@pytest.fixture +def temp_dir(): + d = tempfile.mkdtemp(prefix="test_verify_data_") + yield d + shutil.rmtree(d, ignore_errors=True) + + +def write_metadata(path, metadata): + with open(path, "w") as f: + json.dump(metadata, f) + + +def write_shard(path, examples): + with gzip.open(path, "wt") as f: + for ex in examples: + f.write(json.dumps(ex) + "\n") + + +def make_metadata(split, num_examples, num_shards, **overrides): + base = { + "format": "welt-preprocessed-v1", + "split": split, + "num_examples": num_examples, + "num_shards": num_shards, + "total_units": num_examples * 10, + "unit_type": "words", + "source_dataset": "test-dataset", + "source_config": "test-config", + "source_split": "train", + "language": "eng_Latn", + "max_seq_length": None, + "seed": 42, + "text_column": "text", + "text_template": None, + "created_with_another_split": False, + } + base.update(overrides) + return base + + +def test_verify_empty_dir(temp_dir): + passed, messages = verify(temp_dir) + assert not passed + assert any("No" in m 
for m in messages) + + +def test_verify_consistent_single_split(temp_dir): + examples = [{"text": f"example {i}"} for i in range(5)] + write_shard(f"{temp_dir}/ds-train-00000000.jsonl.gz", examples) + write_metadata(f"{temp_dir}/ds-train-metadata.json", make_metadata("train", 5, 1)) + + passed, messages = verify(temp_dir) + assert passed + + +def test_verify_shard_count_mismatch(temp_dir): + examples = [{"text": "hello"}] + write_shard(f"{temp_dir}/ds-train-00000000.jsonl.gz", examples) + write_metadata(f"{temp_dir}/ds-train-metadata.json", make_metadata("train", 1, 2)) + + passed, messages = verify(temp_dir) + assert not passed + assert any("Expected 2 shards" in m for m in messages) + + +def test_verify_example_count_mismatch(temp_dir): + examples = [{"text": "hello"}, {"text": "world"}] + write_shard(f"{temp_dir}/ds-train-00000000.jsonl.gz", examples) + write_metadata(f"{temp_dir}/ds-train-metadata.json", make_metadata("train", 5, 1)) + + passed, messages = verify(temp_dir) + assert not passed + assert any("Expected 5 examples" in m for m in messages) + + +def test_verify_contamination_warning(temp_dir): + """Splits from same source created separately should warn.""" + write_shard(f"{temp_dir}/ds-train-00000000.jsonl.gz", [{"text": "a"}]) + write_shard(f"{temp_dir}/ds-validation-00000000.jsonl.gz", [{"text": "b"}]) + write_metadata(f"{temp_dir}/ds-train-metadata.json", + make_metadata("train", 1, 1, created_with_another_split=False)) + write_metadata(f"{temp_dir}/ds-validation-metadata.json", + make_metadata("validation", 1, 1, created_with_another_split=False)) + + passed, messages = verify(temp_dir) + assert passed # warning, not failure + assert any("WARN" in m and "overlap" in m for m in messages) + + +def test_verify_no_contamination_when_created_together(temp_dir): + """Splits created together should pass contamination check.""" + write_shard(f"{temp_dir}/ds-train-00000000.jsonl.gz", [{"text": "a"}]) + 
write_shard(f"{temp_dir}/ds-validation-00000000.jsonl.gz", [{"text": "b"}]) + write_metadata(f"{temp_dir}/ds-train-metadata.json", + make_metadata("train", 1, 1, created_with_another_split=True)) + write_metadata(f"{temp_dir}/ds-validation-metadata.json", + make_metadata("validation", 1, 1, created_with_another_split=True)) + + passed, messages = verify(temp_dir) + assert passed + assert any("no overlap" in m for m in messages) + + +def test_verify_no_contamination_different_sources(temp_dir): + """Splits from different sources should skip contamination check.""" + write_shard(f"{temp_dir}/ds1-train-00000000.jsonl.gz", [{"text": "a"}]) + write_shard(f"{temp_dir}/ds2-validation-00000000.jsonl.gz", [{"text": "b"}]) + write_metadata(f"{temp_dir}/ds1-train-metadata.json", + make_metadata("train", 1, 1, source_dataset="dataset-A")) + write_metadata(f"{temp_dir}/ds2-validation-metadata.json", + make_metadata("validation", 1, 1, source_dataset="dataset-B")) + + passed, messages = verify(temp_dir) + assert passed + # No contamination message since sources differ + assert not any("overlap" in m for m in messages) + assert not any("contamination" in m for m in messages) diff --git a/welt_training/verify_data.py b/welt_training/verify_data.py new file mode 100644 index 0000000..7d1a2a3 --- /dev/null +++ b/welt_training/verify_data.py @@ -0,0 +1,108 @@ +"""Verify integrity of prepared data directories. + +Checks metadata consistency, shard counts, example counts, +and warns about potential data contamination between splits. 
+""" + +import argparse +import glob +import gzip +import json +import os +import sys + + +def discover_splits(data_path): + """Find all per-split metadata files and return {split_name: metadata_dict}.""" + splits = {} + for path in sorted(glob.glob(os.path.join(data_path, "*-metadata.json"))): + with open(path) as f: + metadata = json.load(f) + split_name = metadata["split"] + splits[split_name] = metadata + return splits + + +def count_shard_examples(shard_files): + """Count total examples across a list of .jsonl.gz shard files.""" + count = 0 + for path in shard_files: + with gzip.open(path, "rt") as f: + for _ in f: + count += 1 + return count + + +def verify(data_path): + """Run all verification checks. Returns (passed: bool, messages: list[str]).""" + messages = [] + passed = True + + # 1. Discover splits + splits = discover_splits(data_path) + if not splits: + messages.append("FAIL: No *-metadata.json files found.") + return False, messages + + messages.append(f"Found splits: {', '.join(splits)}") + + # 2. Per-split shard consistency + for split_name, metadata in splits.items(): + pattern = os.path.join(data_path, f"*-{split_name}-*.jsonl.gz") + shard_files = sorted(glob.glob(pattern)) + + expected_shards = metadata["num_shards"] + actual_shards = len(shard_files) + if actual_shards != expected_shards: + messages.append( + f"FAIL [{split_name}]: Expected {expected_shards} shards, found {actual_shards}." + ) + passed = False + else: + messages.append(f"OK [{split_name}]: {actual_shards} shard(s)") + + expected_examples = metadata["num_examples"] + actual_examples = count_shard_examples(shard_files) + if actual_examples != expected_examples: + messages.append( + f"FAIL [{split_name}]: Expected {expected_examples} examples, found {actual_examples}." + ) + passed = False + else: + messages.append(f"OK [{split_name}]: {actual_examples} example(s)") + + # 3. 
Data contamination check + if "train" in splits and "validation" in splits: + train_meta = splits["train"] + val_meta = splits["validation"] + same_source = ( + train_meta["source_dataset"] == val_meta["source_dataset"] + and train_meta["source_config"] == val_meta["source_config"] + and train_meta["source_split"] == val_meta["source_split"] + ) + if same_source: + if not train_meta.get("created_with_another_split") or not val_meta.get("created_with_another_split"): + messages.append( + "WARN: Train and validation share the same source but were not created together. " + "Examples may overlap." + ) + else: + messages.append("OK [contamination]: Splits created together, no overlap risk.") + + return passed, messages + + +def main(): + parser = argparse.ArgumentParser(description="Verify integrity of prepared data.") + parser.add_argument("--data_path", type=str, required=True, help="Path to prepared data directory.") + args = parser.parse_args() + + passed, messages = verify(args.data_path) + for msg in messages: + print(msg) + + sys.exit(0 if passed else 1) + + +if __name__ == "__main__": + main() From c68e192b40d90929a8c554b756e2b6c5375908fa Mon Sep 17 00:00:00 2001 From: Ilker Kesen Date: Thu, 12 Feb 2026 20:21:06 +0100 Subject: [PATCH 20/22] apply refactorings suggested by claude --- welt_training/data_utils.py | 9 +++++-- welt_training/prepare_data.py | 47 ++++++++++++++++++----------------- welt_training/train.py | 6 ++--- welt_training/verify_data.py | 5 ++-- 4 files changed, 36 insertions(+), 31 deletions(-) diff --git a/welt_training/data_utils.py b/welt_training/data_utils.py index fd9de11..0fac007 100644 --- a/welt_training/data_utils.py +++ b/welt_training/data_utils.py @@ -14,6 +14,11 @@ def extract_text(example: dict, text_column: str = "text", text_template: str | return example[text_column] +def find_shard_files(data_path: str, split_name: str) -> list[str]: + """Return sorted shard files for a given split in a prepared data directory.""" + return 
sorted(glob.glob(os.path.join(data_path, f"*-{split_name}-*.jsonl.gz"))) + + def load_prepared_data(prepared_data_path: str): """Load preprocessed shards produced by prepare_data.py. @@ -28,8 +33,8 @@ def load_prepared_data(prepared_data_path: str): Returns: A dict with ``"train"`` and/or ``"validation"`` datasets. """ - train_files = sorted(glob.glob(os.path.join(prepared_data_path, "*-train-*.jsonl.gz"))) - validation_files = sorted(glob.glob(os.path.join(prepared_data_path, "*-validation-*.jsonl.gz"))) + train_files = find_shard_files(prepared_data_path, "train") + validation_files = find_shard_files(prepared_data_path, "validation") if not train_files and not validation_files: raise ValueError( diff --git a/welt_training/prepare_data.py b/welt_training/prepare_data.py index 2cb9b05..21ac787 100644 --- a/welt_training/prepare_data.py +++ b/welt_training/prepare_data.py @@ -28,18 +28,23 @@ def get_shard_prefix(dataset_name: str, dataset_config: str | None) -> str: class ShardWriter: - """Manages writing sharded .jsonl.gz files for a single data split.""" + """Manages writing sharded .jsonl.gz files for a single data split. + + Files are opened lazily on first write, so creating a writer for + a split that receives no data produces no files on disk. 
+ """ def __init__(self, output_path: Path, prefix: str, split_name: str | None, num_units_per_file: int | None): self.output_path = output_path self.prefix = prefix self.split_name = split_name self.num_units_per_file = num_units_per_file - self.shard_index = 0 - self.shard_units = 0 + self._shard_index = 0 + self._shard_units = 0 + self._num_shards = 0 + self._current_file = None self.total_units = 0 self.num_examples = 0 - self._current_file = self._open_shard() def _shard_path(self, index: int) -> Path: if self.split_name: @@ -47,32 +52,30 @@ def _shard_path(self, index: int) -> Path: return self.output_path / f"{self.prefix}-{index:08d}.jsonl.gz" def _open_shard(self): - path = self._shard_path(self.shard_index) + path = self._shard_path(self._shard_index) logger.info(f"Writing shard: {path.name}") - return gzip.open(path, "wt") + self._current_file = gzip.open(path, "wt") def write(self, record: dict, text_units: int): + if self._current_file is None: + self._open_shard() self._current_file.write(json.dumps(record, ensure_ascii=False) + "\n") - self.shard_units += text_units + self._shard_units += text_units self.total_units += text_units self.num_examples += 1 - if self.num_units_per_file is not None and self.shard_units >= self.num_units_per_file: + if self.num_units_per_file is not None and self._shard_units >= self.num_units_per_file: self._current_file.close() - logger.info(f"Completed shard {self.shard_index} ({self.shard_units} units)") - self.shard_index += 1 - self.shard_units = 0 - self._current_file = self._open_shard() + logger.info(f"Completed shard {self._shard_index} ({self._shard_units} units)") + self._num_shards += 1 + self._shard_index += 1 + self._shard_units = 0 + self._current_file = None def close(self): - self._current_file.close() - # Remove empty last shard - if self.shard_units == 0 and self.shard_index > 0: - self._shard_path(self.shard_index).unlink() - self.shard_index -= 1 - # Remove shard file if no examples were written at all 
- if self.num_examples == 0: - self._shard_path(self.shard_index).unlink(missing_ok=True) + if self._current_file is not None: + self._current_file.close() + self._num_shards += 1 def __enter__(self): return self @@ -82,9 +85,7 @@ def __exit__(self, *exc): @property def num_shards(self) -> int: - if self.num_examples == 0: - return 0 - return self.shard_index + 1 + return self._num_shards def stream_texts(args): diff --git a/welt_training/train.py b/welt_training/train.py index ac74eb5..87f74bd 100644 --- a/welt_training/train.py +++ b/welt_training/train.py @@ -161,8 +161,7 @@ def detect_last_checkpoint(training_args: TrainingArguments): def init_datasets(data_args: DataTrainingArguments, # noqa: C901 trust_remote_code: bool, do_train: bool = True, - cache_dir: str = None, - seed: int = 42): + cache_dir: str = None): # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ # (the dataset will be downloaded automatically from the datasets Hub). @@ -374,8 +373,7 @@ def train(args: list[str] | None | str = None): # noqa: C901 text_datasets = init_datasets(data_args, cache_dir=cache_dir, trust_remote_code=model_args.trust_remote_code, - do_train=training_args.do_train, - seed=training_args.seed) + do_train=training_args.do_train) train_dataset = None if training_args.do_train: diff --git a/welt_training/verify_data.py b/welt_training/verify_data.py index 7d1a2a3..df79001 100644 --- a/welt_training/verify_data.py +++ b/welt_training/verify_data.py @@ -11,6 +11,8 @@ import os import sys +from welt_training.data_utils import find_shard_files + def discover_splits(data_path): """Find all per-split metadata files and return {split_name: metadata_dict}.""" @@ -48,8 +50,7 @@ def verify(data_path): # 2. 
Per-split shard consistency for split_name, metadata in splits.items(): - pattern = os.path.join(data_path, f"*-{split_name}-*.jsonl.gz") - shard_files = sorted(glob.glob(pattern)) + shard_files = find_shard_files(data_path, split_name) expected_shards = metadata["num_shards"] actual_shards = len(shard_files) From 0651fc8e7c6dac45bb89574c28fdb0616a9b86a2 Mon Sep 17 00:00:00 2001 From: Ilker Kesen Date: Thu, 12 Feb 2026 22:14:33 +0100 Subject: [PATCH 21/22] handle multiple resources while verifying the data --- tests/test_verify_data.py | 21 ++++++++++++ welt_training/data_utils.py | 12 +++++-- welt_training/prepare_data.py | 5 ++- welt_training/verify_data.py | 64 ++++++++++++++++++++++------------- 4 files changed, 73 insertions(+), 29 deletions(-) diff --git a/tests/test_verify_data.py b/tests/test_verify_data.py index 3c08923..fdf18ea 100644 --- a/tests/test_verify_data.py +++ b/tests/test_verify_data.py @@ -97,6 +97,26 @@ def test_verify_contamination_warning(temp_dir): assert any("WARN" in m and "overlap" in m for m in messages) +def test_verify_multiple_datasets(temp_dir): + """Multiple datasets in the same directory should be verified independently.""" + # Dataset A: 3 train examples in 1 shard + write_shard(f"{temp_dir}/ds-a-train-00000000.jsonl.gz", + [{"text": f"a{i}"} for i in range(3)]) + write_metadata(f"{temp_dir}/ds-a-train-metadata.json", + make_metadata("train", 3, 1, source_dataset="dataset-A")) + + # Dataset B: 2 train examples in 1 shard + write_shard(f"{temp_dir}/ds-b-train-00000000.jsonl.gz", + [{"text": f"b{i}"} for i in range(2)]) + write_metadata(f"{temp_dir}/ds-b-train-metadata.json", + make_metadata("train", 2, 1, source_dataset="dataset-B")) + + passed, messages = verify(temp_dir) + assert passed + assert any("ds-a/train" in m for m in messages) + assert any("ds-b/train" in m for m in messages) + + def test_verify_no_contamination_when_created_together(temp_dir): """Splits created together should pass contamination check.""" 
write_shard(f"{temp_dir}/ds-train-00000000.jsonl.gz", [{"text": "a"}]) @@ -109,6 +129,7 @@ def test_verify_no_contamination_when_created_together(temp_dir): passed, messages = verify(temp_dir) assert passed assert any("no overlap" in m for m in messages) + assert any("ds/contamination" in m for m in messages) def test_verify_no_contamination_different_sources(temp_dir): diff --git a/welt_training/data_utils.py b/welt_training/data_utils.py index 0fac007..dbbdfd7 100644 --- a/welt_training/data_utils.py +++ b/welt_training/data_utils.py @@ -14,9 +14,15 @@ def extract_text(example: dict, text_column: str = "text", text_template: str | return example[text_column] -def find_shard_files(data_path: str, split_name: str) -> list[str]: - """Return sorted shard files for a given split in a prepared data directory.""" - return sorted(glob.glob(os.path.join(data_path, f"*-{split_name}-*.jsonl.gz"))) +def find_shard_files(data_path: str, split_name: str, prefix: str | None = None) -> list[str]: + """Return sorted shard files for a given split in a prepared data directory. + + When *prefix* is given, only shards for that specific dataset are matched. + Without it, shards from all datasets in the directory are returned (used by + :func:`load_prepared_data` to load multi-dataset mixtures). 
+ """ + name = f"{prefix}-{split_name}" if prefix else f"*-{split_name}" + return sorted(glob.glob(os.path.join(data_path, f"{name}-*.jsonl.gz"))) def load_prepared_data(prepared_data_path: str): diff --git a/welt_training/prepare_data.py b/welt_training/prepare_data.py index 21ac787..d03ceca 100644 --- a/welt_training/prepare_data.py +++ b/welt_training/prepare_data.py @@ -75,6 +75,7 @@ def write(self, record: dict, text_units: int): def close(self): if self._current_file is not None: self._current_file.close() + self._current_file = None self._num_shards += 1 def __enter__(self): @@ -296,11 +297,9 @@ def main(): logger.info(f"Reached total units limit ({max_total_units})") break - record = {"text": text} + record = {"text": text, "language": args.language} if id_value is not None: record["id"] = id_value - if args.language: - record["language"] = args.language # Fill validation first, then route to train if val_writer.total_units < args.validation_split_units: diff --git a/welt_training/verify_data.py b/welt_training/verify_data.py index df79001..1d0a5a8 100644 --- a/welt_training/verify_data.py +++ b/welt_training/verify_data.py @@ -14,15 +14,23 @@ from welt_training.data_utils import find_shard_files -def discover_splits(data_path): - """Find all per-split metadata files and return {split_name: metadata_dict}.""" - splits = {} +def discover_metadata(data_path): + """Find all per-split metadata files and return ``[(prefix, metadata)]``. + + Each metadata file is named ``{prefix}-{split}-metadata.json``. The prefix + is derived from the filename and the ``split`` field inside the JSON so that + multiple datasets in one directory are handled correctly. 
+ """ + entries = [] for path in sorted(glob.glob(os.path.join(data_path, "*-metadata.json"))): with open(path) as f: metadata = json.load(f) split_name = metadata["split"] - splits[split_name] = metadata - return splits + filename = os.path.basename(path) + suffix = f"-{split_name}-metadata.json" + prefix = filename[: -len(suffix)] + entries.append((prefix, metadata)) + return entries def count_shard_examples(shard_files): @@ -40,42 +48,52 @@ def verify(data_path): messages = [] passed = True - # 1. Discover splits - splits = discover_splits(data_path) - if not splits: + # 1. Discover metadata (prefix-aware) + entries = discover_metadata(data_path) + if not entries: messages.append("FAIL: No *-metadata.json files found.") return False, messages - messages.append(f"Found splits: {', '.join(splits)}") + prefixes = sorted({p for p, _ in entries}) + splits = sorted({m["split"] for _, m in entries}) + messages.append(f"Found {len(entries)} metadata file(s): {len(prefixes)} dataset(s), splits: {', '.join(splits)}") - # 2. Per-split shard consistency - for split_name, metadata in splits.items(): - shard_files = find_shard_files(data_path, split_name) + # 2. Per-dataset, per-split shard consistency + for prefix, metadata in entries: + split_name = metadata["split"] + label = f"{prefix}/{split_name}" + shard_files = find_shard_files(data_path, split_name, prefix=prefix) expected_shards = metadata["num_shards"] actual_shards = len(shard_files) if actual_shards != expected_shards: messages.append( - f"FAIL [{split_name}]: Expected {expected_shards} shards, found {actual_shards}." + f"FAIL [{label}]: Expected {expected_shards} shards, found {actual_shards}." 
) passed = False else: - messages.append(f"OK [{split_name}]: {actual_shards} shard(s)") + messages.append(f"OK [{label}]: {actual_shards} shard(s)") expected_examples = metadata["num_examples"] actual_examples = count_shard_examples(shard_files) if actual_examples != expected_examples: messages.append( - f"FAIL [{split_name}]: Expected {expected_examples} examples, found {actual_examples}." + f"FAIL [{label}]: Expected {expected_examples} examples, found {actual_examples}." ) passed = False else: - messages.append(f"OK [{split_name}]: {actual_examples} example(s)") - - # 3. Data contamination check - if "train" in splits and "validation" in splits: - train_meta = splits["train"] - val_meta = splits["validation"] + messages.append(f"OK [{label}]: {actual_examples} example(s)") + + # 3. Data contamination check (per dataset) + by_prefix = {} + for prefix, metadata in entries: + by_prefix.setdefault(prefix, {})[metadata["split"]] = metadata + + for prefix, split_metas in by_prefix.items(): + if "train" not in split_metas or "validation" not in split_metas: + continue + train_meta = split_metas["train"] + val_meta = split_metas["validation"] same_source = ( train_meta["source_dataset"] == val_meta["source_dataset"] and train_meta["source_config"] == val_meta["source_config"] @@ -84,11 +102,11 @@ def verify(data_path): if same_source: if not train_meta.get("created_with_another_split") or not val_meta.get("created_with_another_split"): messages.append( - "WARN: Train and validation share the same source but were not created together. " + f"WARN [{prefix}]: Train and validation share the same source but were not created together. " "Examples may overlap." 
) else: - messages.append("OK [contamination]: Splits created together, no overlap risk.") + messages.append(f"OK [{prefix}/contamination]: Splits created together, no overlap risk.") return passed, messages From 333053995de062669fe06bef0caff99735ce6dd8 Mon Sep 17 00:00:00 2001 From: Ilker Kesen Date: Thu, 12 Feb 2026 22:14:56 +0100 Subject: [PATCH 22/22] discard extra columns in training script --- welt_training/experiments/machine-translation/run_clm.py | 7 +++++-- welt_training/train.py | 6 ++++++ 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/welt_training/experiments/machine-translation/run_clm.py b/welt_training/experiments/machine-translation/run_clm.py index 7be3a7a..3cbd407 100644 --- a/welt_training/experiments/machine-translation/run_clm.py +++ b/welt_training/experiments/machine-translation/run_clm.py @@ -48,7 +48,7 @@ import datasets import evaluate import torch -from datasets import IterableDataset, IterableDatasetDict, load_dataset +from datasets import DatasetDict, IterableDataset, IterableDatasetDict, load_dataset import transformers from transformers import ( @@ -318,6 +318,8 @@ def main(): # If we pass only one argument to the script and it's the path to a json file, # let's parse it to get our arguments. 
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) + elif len(sys.argv) == 2 and sys.argv[1].endswith((".yaml", ".yml")): + model_args, data_args, training_args = parser.parse_yaml_file(yaml_file=os.path.abspath(sys.argv[1])) else: model_args, data_args, training_args = parser.parse_args_into_dataclasses() @@ -380,7 +382,7 @@ def main(): if data_args.prepared_data_path is not None: if data_args.validation_split_percentage is not None: logger.warning("Ignoring validation_split_percentage because prepared_data_path is set.") - raw_datasets = load_prepared_data(data_args.prepared_data_path) + raw_datasets = DatasetDict(load_prepared_data(data_args.prepared_data_path)) elif data_args.dataset_name is not None: # Downloading and loading a dataset from the hub. @@ -566,6 +568,7 @@ def mapping_function(x): desc="Keep only the text column & apply template", **map_args ) + column_names = [text_column_name] def tokenize_function(examples): diff --git a/welt_training/train.py b/welt_training/train.py index 87f74bd..4cd10e6 100644 --- a/welt_training/train.py +++ b/welt_training/train.py @@ -375,6 +375,12 @@ def train(args: list[str] | None | str = None): # noqa: C901 trust_remote_code=model_args.trust_remote_code, do_train=training_args.do_train) + # Drop columns not needed for training (e.g. "language" from prepared data) + for split in list(text_datasets): + extra_cols = [c for c in text_datasets[split].column_names if c != "text"] + if extra_cols: + text_datasets[split] = text_datasets[split].remove_columns(extra_cols) + train_dataset = None if training_args.do_train: if "train" not in text_datasets: