From 87d4cb21c6e54a916622b3384faf1d08452b2ebf Mon Sep 17 00:00:00 2001 From: Ilker Kesen Date: Thu, 26 Feb 2026 18:42:35 +0100 Subject: [PATCH 01/12] implement bits per bytes metric including tests and trainer adaption --- tests/test_metrics.py | 166 ++++++++++++++++++ tests/test_train.py | 11 +- tests/test_trainer.py | 47 +++++ .../machine-translation/run_clm.py | 15 +- welt_training/metrics.py | 25 +++ welt_training/trainer.py | 11 +- 6 files changed, 271 insertions(+), 4 deletions(-) create mode 100644 tests/test_metrics.py create mode 100644 welt_training/metrics.py diff --git a/tests/test_metrics.py b/tests/test_metrics.py new file mode 100644 index 0000000..de63596 --- /dev/null +++ b/tests/test_metrics.py @@ -0,0 +1,166 @@ +"""Tests for compute_bits_per_byte metric utility.""" + +import math + +import pytest + +from welt_training.metrics import compute_bits_per_byte + + +class TestComputeBitsPerByte: + """Tests for the compute_bits_per_byte function.""" + + def test_basic_computation(self): + """BPB = loss * num_tokens / (num_bytes * ln(2)).""" + loss = 2.0 + num_tokens = 100 + num_bytes = 100 + expected = 2.0 / math.log(2) + assert compute_bits_per_byte(loss, num_tokens, num_bytes) == pytest.approx(expected) + + def test_byte_level_model(self): + """When tokens == bytes (byte-level model), BPB = loss / ln(2).""" + loss = 1.5 + bpb = compute_bits_per_byte(loss, num_tokens=1000, num_bytes=1000) + assert bpb == pytest.approx(loss / math.log(2)) + + def test_subword_tokenizer(self): + """When tokens < bytes (subword tokenizer), BPB accounts for compression.""" + loss = 3.0 + num_tokens = 50 + num_bytes = 200 # 4 bytes per token on average + expected = 3.0 * 50 / (200 * math.log(2)) + assert compute_bits_per_byte(loss, num_tokens, num_bytes) == pytest.approx(expected) + + def test_zero_loss(self): + """Zero loss should give zero BPB (perfect prediction).""" + assert compute_bits_per_byte(0.0, 100, 100) == 0.0 + + def test_zero_bytes_returns_inf(self): + """Zero bytes should return infinity.""" + result = compute_bits_per_byte(1.0, 100, 0) + assert result == float("inf") + + def test_zero_tokens_returns_zero(self): + """Zero tokens should return zero BPB.""" + result = compute_bits_per_byte(1.0, 0, 100) + assert result == 0.0 + + def test_higher_loss_means_higher_bpb(self): + """Higher loss should give higher BPB.""" + bpb_low = compute_bits_per_byte(1.0, 100, 100) + bpb_high = compute_bits_per_byte(2.0, 100, 100) + assert bpb_high > bpb_low + + def test_more_bytes_means_lower_bpb(self): + """More bytes for same total bits should give lower BPB.""" + bpb_fewer = compute_bits_per_byte(1.0, 100, 100) + bpb_more = compute_bits_per_byte(1.0, 100, 200) + assert bpb_more < bpb_fewer + + def test_known_value_one_bit_per_byte(self): + """loss = ln(2) nats per token, 1 token, 1 byte -> 1.0 bit per byte.""" + loss = math.log(2) + assert compute_bits_per_byte(loss, 1, 1) == pytest.approx(1.0) + + def test_relationship_with_perplexity(self): + """BPB = log2(perplexity) when num_tokens == num_bytes.""" + loss = 2.5 + perplexity = math.exp(loss) + bpb = compute_bits_per_byte(loss, num_tokens=1, num_bytes=1) + assert bpb == pytest.approx(math.log2(perplexity)) + + def test_scaling_with_token_count(self): + """Doubling num_tokens with same num_bytes should double BPB.""" + bpb1 = compute_bits_per_byte(1.0, num_tokens=100, num_bytes=200) + bpb2 = compute_bits_per_byte(1.0, num_tokens=200, num_bytes=200) + assert bpb2 == pytest.approx(2 * bpb1) + + +class TestBPBTokenizerIntegration: + """Validate BPB computation patterns used in the training scripts.""" + + def test_byte_level_model_pattern(self): + """ + Validate the train.py (WeLTTrainer) pattern. + + WeLT is a byte-level model: each token IS a byte, so num_tokens == num_bytes. + The trainer passes num_tokens=1, num_bytes=1 to get the ratio BPB = loss / ln(2). + Verify this matches log2(perplexity). + """ + loss = 4.2 + perplexity = math.exp(loss) + + # This is exactly what WeLTTrainer._add_custom_metrics does + bpb = compute_bits_per_byte(loss, num_tokens=1, num_bytes=1) + + assert bpb == pytest.approx(loss / math.log(2)) + assert bpb == pytest.approx(math.log2(perplexity)) + + def test_subword_tokenizer_compression_ratio(self): + """ + Validate the run_clm.py pattern where tokens < bytes due to subword tokenization. + + For a BPE tokenizer, each token covers ~3-4 bytes on average. + BPB must account for this compression: + BPB = loss_per_token * (num_tokens / num_bytes) / ln(2) + """ + loss = 3.0 + + # Byte-level baseline: 1 token per byte + bpb_byte = compute_bits_per_byte(loss, num_tokens=400, num_bytes=400) + + # BPE tokenizer: ~3.5 bytes per token (typical for English) + bpb_bpe = compute_bits_per_byte(loss, num_tokens=400, num_bytes=1400) + + # BPE should produce lower BPB (same loss spread over more bytes) + assert bpb_bpe < bpb_byte + assert bpb_bpe == pytest.approx(bpb_byte * 400 / 1400) + + def test_decode_roundtrip_byte_counting(self): + """ + Validate the run_clm.py approach: decode token IDs → encode as UTF-8 → count bytes. + + Uses a real HuggingFace tokenizer to verify the decode-based byte counting + correctly captures the tokenizer's compression ratio. + """ + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained("gpt2") + + texts = [ + "The quick brown fox jumps over the lazy dog.", + "Machine translation is an important NLP task.", + ] + + # Simulate run_clm.py pipeline: tokenize, then decode to count bytes + all_token_ids = [] + for text in texts: + all_token_ids.extend(tokenizer.encode(text)) + + num_tokens = len(all_token_ids) + original_bytes = sum(len(t.encode("utf-8")) for t in texts) + + # run_clm.py decodes per-chunk; simulate with full sequence + decoded_text = tokenizer.decode(all_token_ids) + decoded_bytes = len(decoded_text.encode("utf-8")) + + # Decoded bytes should closely match original (tokenizer roundtrip) + assert abs(decoded_bytes - original_bytes) <= len(texts), \ + f"Decoded bytes ({decoded_bytes}) should match original ({original_bytes})" + + # Subword tokenizer: fewer tokens than bytes + assert num_tokens < original_bytes, \ + "BPE tokenizer should compress text (fewer tokens than bytes)" + + # Verify BPB with a known loss + loss = 2.5 + bpb = compute_bits_per_byte(loss, num_tokens, decoded_bytes) + + # Must be less than byte-level BPB (since tokens < bytes) + bpb_byte_level = loss / math.log(2) + assert bpb < bpb_byte_level + + # Must equal the analytical formula + expected = loss * num_tokens / (decoded_bytes * math.log(2)) + assert bpb == pytest.approx(expected) diff --git a/tests/test_train.py b/tests/test_train.py index e1f4e63..eae4ace 100644 --- a/tests/test_train.py +++ b/tests/test_train.py @@ -115,9 +115,17 @@ def test_basic_training_with_eval_chrf(temp_output_dir): assert "eval_samples" in eval_metrics, "eval_samples should be present" assert "perplexity" in eval_metrics, "perplexity should be present" + # Verify bits per byte is present and consistent with loss + assert "eval_bits_per_byte" in eval_metrics, "eval_bits_per_byte should be present" + import math + expected_bpb = eval_metrics["eval_loss"] / math.log(2) + assert abs(eval_metrics["eval_bits_per_byte"] - expected_bpb) < 0.001, \ + f"eval_bits_per_byte should equal loss/ln(2): {eval_metrics['eval_bits_per_byte']} vs {expected_bpb}" + print("\n✓ Training completed successfully!") print(f"✓ eval_chrf = {chrf_score:.2f}") print(f"✓ eval_loss = {eval_metrics['eval_loss']:.4f}") + print(f"✓ eval_bits_per_byte = {eval_metrics['eval_bits_per_byte']:.4f}") print(f"✓ eval_samples = {eval_metrics['eval_samples']}") print(f"✓ All metrics: {list(eval_metrics.keys())}") @@ -167,9 +175,10 @@ def test_training_without_generation_metrics(temp_output_dir): with open(eval_results_path) as f: eval_metrics = json.load(f) - # Should have loss and perplexity but no generation metrics + # Should have loss, perplexity, and bits_per_byte but no generation metrics assert "eval_loss" in eval_metrics assert "perplexity" in eval_metrics + assert "eval_bits_per_byte" in eval_metrics assert "eval_samples" in eval_metrics # Should not have generation metrics assert "eval_chrf" not in eval_metrics diff --git a/tests/test_trainer.py b/tests/test_trainer.py index a570f03..2be1e47 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -1012,6 +1012,53 @@ def test_dataloader_reuse(trainer_setup): assert abs(metrics1["eval_loss"] - metrics2["eval_loss"]) < 0.1 +def test_bits_per_byte_computation(trainer_setup): + """Test that bits_per_byte is computed correctly from loss.""" + model, processor, collator = trainer_setup + + eval_dataset = make_generation_dataset( + prefixes=["a", "b"], + completions=[" x", " y"], + ) + + training_args = Seq2SeqTrainingArguments( + output_dir="output/test_trainer", + per_device_eval_batch_size=2, + do_train=False, + do_eval=True, + remove_unused_columns=False, + predict_with_generate=True, + ) + + trainer = WeLTTrainer( + model=model, + args=training_args, + processor=processor, + data_collator=collator, + eval_metrics=None, + max_generated_words=3, + ) + + metrics = trainer.evaluate(eval_dataset) + + # Verify bits_per_byte is present and positive + assert "eval_bits_per_byte" in metrics, \ + f"eval_bits_per_byte should be in metrics. Found: {list(metrics.keys())}" + assert metrics["eval_bits_per_byte"] > 0, \ + f"eval_bits_per_byte should be positive, got {metrics['eval_bits_per_byte']}" + + # For byte-level model, BPB = loss / ln(2) + import math + expected_bpb = metrics["eval_loss"] / math.log(2) + assert abs(metrics["eval_bits_per_byte"] - expected_bpb) < 0.001, \ + f"eval_bits_per_byte should equal loss/ln(2): {metrics['eval_bits_per_byte']} vs {expected_bpb}" + + # BPB should also equal log2(perplexity) for byte-level model + assert abs(metrics["eval_bits_per_byte"] - math.log2(metrics["perplexity"])) < 0.001, \ + f"eval_bits_per_byte should equal log2(perplexity): " \ + f"{metrics['eval_bits_per_byte']} vs {math.log2(metrics['perplexity'])}" + + def test_accuracy_is_computed(trainer_setup): """Test that eval_accuracy is computed during evaluation.""" model, processor, collator = trainer_setup diff --git a/welt_training/experiments/machine-translation/run_clm.py b/welt_training/experiments/machine-translation/run_clm.py index 3cbd407..f85185d 100644 --- a/welt_training/experiments/machine-translation/run_clm.py +++ b/welt_training/experiments/machine-translation/run_clm.py @@ -70,6 +70,7 @@ from transformers.utils.versions import require_version from welt_training.data_utils import load_prepared_data +from welt_training.metrics import compute_bits_per_byte # Will error if the minimal version of Transformers is not installed. Remove at your own risks. @@ -570,7 +571,6 @@ def mapping_function(x): ) column_names = [text_column_name] - def tokenize_function(examples): with CaptureLogger(tok_logger) as cl: output = tokenizer(examples[text_column_name]) @@ -757,6 +757,19 @@ def compute_metrics(eval_preds): perplexity = float("inf") metrics["perplexity"] = perplexity + # Compute bits per byte from the evaluated subset. + # Decode tokens back to text to count the corresponding UTF-8 bytes. + if not data_args.streaming: + num_eval_tokens = len(eval_dataset) * block_size + num_eval_bytes = sum( + len(tokenizer.decode(example["input_ids"]).encode("utf-8")) + for example in eval_dataset + ) + if num_eval_bytes > 0: + metrics["eval_bits_per_byte"] = compute_bits_per_byte( + metrics["eval_loss"], num_eval_tokens, num_eval_bytes + ) + trainer.log_metrics("eval", metrics) trainer.save_metrics("eval", metrics) diff --git a/welt_training/metrics.py b/welt_training/metrics.py new file mode 100644 index 0000000..bac5c0f --- /dev/null +++ b/welt_training/metrics.py @@ -0,0 +1,25 @@ +"""Metric utilities for WeLT training.""" + +import math + + +def compute_bits_per_byte(loss: float, num_tokens: int, num_bytes: int) -> float: + """ + Compute bits per byte (BPB) from average cross-entropy loss. + + Converts per-token cross-entropy loss (in nats) to bits per byte. + For byte-level models where num_tokens == num_bytes, this simplifies + to loss / ln(2). + + Args: + loss: Average cross-entropy loss per token (in nats). + num_tokens: Number of tokens the loss was averaged over. + num_bytes: Total number of bytes in the original text. + + Returns: + Bits per byte. + """ + if num_bytes == 0: + return float("inf") + total_bits = loss * num_tokens / math.log(2) + return total_bits / num_bytes diff --git a/welt_training/trainer.py b/welt_training/trainer.py index cb61454..6e320c2 100644 --- a/welt_training/trainer.py +++ b/welt_training/trainer.py @@ -4,7 +4,7 @@ Minimal extension of Trainer that adds support for: - Generation-based metrics (BLEU, ROUGE, SacreBLEU, ChrF, etc.) - Byte-level accuracy (token-level) and word-level accuracy from logits -- Perplexity computation from loss +- Perplexity and bits-per-byte computation from loss Overrides prediction_step to generate text predictions and store logits, then computes all metrics in evaluate(). @@ -41,6 +41,7 @@ class WeLTTrainer(Trainer): Computed metrics: - eval_loss: Cross-entropy loss + - eval_bits_per_byte: Bits per byte (loss / ln(2) for byte-level model) - eval_byte_accuracy: Token/byte-level accuracy (always computed) - eval_word_accuracy: Word-level accuracy - all tokens in word must be correct (always computed) - eval_{metric}: Generation metrics (e.g., eval_sacrebleu, eval_chrf) @@ -417,7 +418,7 @@ def _add_custom_metrics(self, metrics): metrics[metric_key] = value additional_metrics[metric_key] = value - # Add perplexity if we have loss + # Add perplexity and bits per byte if we have loss if "eval_loss" in metrics: loss = metrics["eval_loss"] # Use 709 as threshold to avoid float overflow; exp(709) ~ 8.2e307 is the largest representable float @@ -425,6 +426,12 @@ def _add_custom_metrics(self, metrics): metrics["perplexity"] = perplexity additional_metrics["perplexity"] = perplexity + # Bits per byte: for byte-level model, tokens == bytes, so BPB = loss / ln(2) + from welt_training.metrics import compute_bits_per_byte + bpb = compute_bits_per_byte(loss, num_tokens=1, num_bytes=1) + metrics["eval_bits_per_byte"] = bpb + additional_metrics["eval_bits_per_byte"] = bpb + # Compute byte and word accuracy from stored logits if self._eval_logits and self._eval_labels_for_accuracy: accuracy_metrics = self._compute_accuracy( From 646117a9f6fcd7085b5f425b06bceffbc7f7cd28 Mon Sep 17 00:00:00 2001 From: Ilker Kesen Date: Thu, 26 Feb 2026 20:49:06 +0100 Subject: [PATCH 02/12] count EOS bytes during bits per byte computation --- tests/test_metrics.py | 21 +++++++------ tests/test_train.py | 8 ++--- tests/test_trainer.py | 16 +++++----- .../machine-translation/run_clm.py | 4 +-- welt_training/trainer.py | 31 ++++++++++++++----- 5 files changed, 47 insertions(+), 33 deletions(-) diff --git a/tests/test_metrics.py b/tests/test_metrics.py index de63596..3fd59f1 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -80,22 +80,23 @@ def test_scaling_with_token_count(self): class TestBPBTokenizerIntegration: """Validate BPB computation patterns used in the training scripts.""" - def test_byte_level_model_pattern(self): + def test_byte_level_model_with_eos_overhead(self): """ - Validate the train.py (WeLTTrainer) pattern. + Validate the WeLTTrainer BPB pattern. - WeLT is a byte-level model: each token IS a byte, so num_tokens == num_bytes. - The trainer passes num_tokens=1, num_bytes=1 to get the ratio BPB = loss / ln(2). - Verify this matches log2(perplexity). + WeLT labels contain content bytes + one EOS per word. Loss is averaged + over all non-PAD positions (content + EOS), but BPB divides total bits + by content bytes only, producing BPB > loss/ln(2). """ loss = 4.2 - perplexity = math.exp(loss) + # Simulate 100 content bytes across 20 words → 120 loss tokens (100 bytes + 20 EOS) + num_tokens = 120 + num_bytes = 100 - # This is exactly what WeLTTrainer._add_custom_metrics does - bpb = compute_bits_per_byte(loss, num_tokens=1, num_bytes=1) + bpb = compute_bits_per_byte(loss, num_tokens, num_bytes) - assert bpb == pytest.approx(loss / math.log(2)) - assert bpb == pytest.approx(math.log2(perplexity)) + assert bpb == pytest.approx(loss * 120 / (100 * math.log(2))) + assert bpb > loss / math.log(2) # must exceed naive estimate def test_subword_tokenizer_compression_ratio(self): """ diff --git a/tests/test_train.py b/tests/test_train.py index eae4ace..4130eae 100644 --- a/tests/test_train.py +++ b/tests/test_train.py @@ -115,12 +115,12 @@ def test_basic_training_with_eval_chrf(temp_output_dir): assert "eval_samples" in eval_metrics, "eval_samples should be present" assert "perplexity" in eval_metrics, "perplexity should be present" - # Verify bits per byte is present and consistent with loss + # Verify bits per byte is present and accounts for EOS tokens (BPB > loss/ln(2)) assert "eval_bits_per_byte" in eval_metrics, "eval_bits_per_byte should be present" import math - expected_bpb = eval_metrics["eval_loss"] / math.log(2) - assert abs(eval_metrics["eval_bits_per_byte"] - expected_bpb) < 0.001, \ - f"eval_bits_per_byte should equal loss/ln(2): {eval_metrics['eval_bits_per_byte']} vs {expected_bpb}" + naive_bpb = eval_metrics["eval_loss"] / math.log(2) + assert eval_metrics["eval_bits_per_byte"] > naive_bpb, \ + f"eval_bits_per_byte should exceed loss/ln(2) due to EOS overhead: {eval_metrics['eval_bits_per_byte']} vs {naive_bpb}" print("\n✓ Training completed successfully!") print(f"✓ eval_chrf = {chrf_score:.2f}") diff --git a/tests/test_trainer.py b/tests/test_trainer.py index 2be1e47..251f374 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -1047,16 +1047,14 @@ def test_bits_per_byte_computation(trainer_setup): assert metrics["eval_bits_per_byte"] > 0, \ f"eval_bits_per_byte should be positive, got {metrics['eval_bits_per_byte']}" - # For byte-level model, BPB = loss / ln(2) + # For a byte-level model with EOS tokens, BPB > loss / ln(2). + # labels_output contains content bytes + one EOS per word; EOS is counted + # in num_tokens (loss denominator) but not in num_bytes, so BPB is inflated. import math - expected_bpb = metrics["eval_loss"] / math.log(2) - assert abs(metrics["eval_bits_per_byte"] - expected_bpb) < 0.001, \ - f"eval_bits_per_byte should equal loss/ln(2): {metrics['eval_bits_per_byte']} vs {expected_bpb}" - - # BPB should also equal log2(perplexity) for byte-level model - assert abs(metrics["eval_bits_per_byte"] - math.log2(metrics["perplexity"])) < 0.001, \ - f"eval_bits_per_byte should equal log2(perplexity): " \ - f"{metrics['eval_bits_per_byte']} vs {math.log2(metrics['perplexity'])}" + naive_bpb = metrics["eval_loss"] / math.log(2) + assert metrics["eval_bits_per_byte"] > naive_bpb, \ + f"eval_bits_per_byte should exceed loss/ln(2) due to EOS overhead: " \ + f"{metrics['eval_bits_per_byte']} vs {naive_bpb}" def test_accuracy_is_computed(trainer_setup): diff --git a/welt_training/experiments/machine-translation/run_clm.py b/welt_training/experiments/machine-translation/run_clm.py index f85185d..9285de8 100644 --- a/welt_training/experiments/machine-translation/run_clm.py +++ b/welt_training/experiments/machine-translation/run_clm.py @@ -760,9 +760,9 @@ def compute_metrics(eval_preds): # Compute bits per byte from the evaluated subset. # Decode tokens back to text to count the corresponding UTF-8 bytes. if not data_args.streaming: - num_eval_tokens = len(eval_dataset) * block_size + num_eval_tokens = len(eval_dataset) * (block_size - 1) num_eval_bytes = sum( - len(tokenizer.decode(example["input_ids"]).encode("utf-8")) + len(tokenizer.decode(example["input_ids"][1:]).encode("utf-8")) for example in eval_dataset ) if num_eval_bytes > 0: diff --git a/welt_training/trainer.py b/welt_training/trainer.py index 6e320c2..42c2aa7 100644 --- a/welt_training/trainer.py +++ b/welt_training/trainer.py @@ -17,6 +17,7 @@ from transformers import GenerationConfig, Trainer from welt.processor import TextImageProcessor +from welt_training.metrics import compute_bits_per_byte logger = logging.getLogger(__name__) @@ -41,7 +42,7 @@ class WeLTTrainer(Trainer): Computed metrics: - eval_loss: Cross-entropy loss - - eval_bits_per_byte: Bits per byte (loss / ln(2) for byte-level model) + - eval_bits_per_byte: Bits per byte (derived from actual token/byte counts in labels) - eval_byte_accuracy: Token/byte-level accuracy (always computed) - eval_word_accuracy: Word-level accuracy - all tokens in word must be correct (always computed) - eval_{metric}: Generation metrics (e.g., eval_sacrebleu, eval_chrf) @@ -426,13 +427,7 @@ def _add_custom_metrics(self, metrics): metrics["perplexity"] = perplexity additional_metrics["perplexity"] = perplexity - # Bits per byte: for byte-level model, tokens == bytes, so BPB = loss / ln(2) - from welt_training.metrics import compute_bits_per_byte - bpb = compute_bits_per_byte(loss, num_tokens=1, num_bytes=1) - metrics["eval_bits_per_byte"] = bpb - additional_metrics["eval_bits_per_byte"] = bpb - - # Compute byte and word accuracy from stored logits + # Compute byte and word accuracy and bits per byte from stored logits if self._eval_logits and self._eval_labels_for_accuracy: accuracy_metrics = self._compute_accuracy( self._eval_logits, @@ -443,6 +438,11 @@ def _add_custom_metrics(self, metrics): additional_metrics["eval_byte_accuracy"] = accuracy_metrics["byte_accuracy"] additional_metrics["eval_word_accuracy"] = accuracy_metrics["word_accuracy"] + if "eval_loss" in metrics: + bpb = self._compute_bits_per_byte(metrics["eval_loss"], self._eval_labels_for_accuracy) + metrics["eval_bits_per_byte"] = bpb + additional_metrics["eval_bits_per_byte"] = bpb + # Add eval_samples count if self._eval_sample_count > 0: metrics["eval_samples"] = self._eval_sample_count @@ -538,6 +538,21 @@ def _compute_generation_metrics( return metrics + def _compute_bits_per_byte(self, loss: float, label_batches: list[torch.Tensor]) -> float: + """ + Compute bits per byte from loss and eval labels. + + labels_output contains content bytes + one EOS per word + PAD for alignment. + Loss is averaged over non-PAD positions (content bytes + EOS), but BPB + should reflect bits per actual text byte, so we separate the two counts. + """ + all_labels = torch.cat([l.flatten() for l in label_batches]) + pad_id = self.processor.tokenizer.pad_token_id + eos_id = self.processor.tokenizer.eos_token_id + num_tokens = (all_labels != pad_id).sum().item() + num_bytes = ((all_labels != pad_id) & (all_labels != eos_id)).sum().item() + return compute_bits_per_byte(loss, num_tokens, num_bytes) + def _compute_accuracy( self, pred_token_ids: list[torch.Tensor], From d7bc29c6ba5b475076dbb70ff25e07ccd68d1ac2 Mon Sep 17 00:00:00 2001 From: Ilker Kesen Date: Mon, 2 Mar 2026 18:44:58 +0100 Subject: [PATCH 03/12] handle uneqaul non-PAD counts across batches correctly --- welt_training/trainer.py | 34 ++++++++++++++++------------------ 1 file changed, 16 insertions(+), 18 deletions(-) diff --git a/welt_training/trainer.py b/welt_training/trainer.py index 42c2aa7..8fa5629 100644 --- a/welt_training/trainer.py +++ b/welt_training/trainer.py @@ -17,7 +17,6 @@ from transformers import GenerationConfig, Trainer from welt.processor import TextImageProcessor -from welt_training.metrics import compute_bits_per_byte logger = logging.getLogger(__name__) @@ -124,6 +123,11 @@ def _reset_eval_state(self): self._eval_sample_count = 0 self._eval_logits = [] self._eval_labels_for_accuracy = [] + # BPB accumulators: accumulate per-batch to avoid bias from + # unequal non-PAD counts across batches (eval_loss is an unweighted + # mean of per-batch means, so eval_loss * total_tokens != true total nats). + self._eval_total_nats = 0.0 + self._eval_total_content_bytes = 0 def get_eval_dataloader(self, eval_dataset=None): """ @@ -317,6 +321,15 @@ def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None) self._eval_logits.append(pred_token_ids.cpu()) self._eval_labels_for_accuracy.append(labels_output.cpu()) + # Accumulate exact per-batch nats and content byte counts for BPB + pad_id = self.processor.tokenizer.pad_token_id + eos_id = self.processor.tokenizer.eos_token_id + flat_labels = labels_output.flatten() + batch_non_pad = (flat_labels != pad_id).sum().item() + batch_content_bytes = ((flat_labels != pad_id) & (flat_labels != eos_id)).sum().item() + self._eval_total_nats += loss.item() * batch_non_pad + self._eval_total_content_bytes += batch_content_bytes + # Generate predictions if predict_with_generate is enabled # Only do generation when: (1) predict_with_generate is True, (2) we have prefixes, # (3) and either we have metrics or prediction_loss_only is False @@ -438,8 +451,8 @@ def _add_custom_metrics(self, metrics): additional_metrics["eval_byte_accuracy"] = accuracy_metrics["byte_accuracy"] additional_metrics["eval_word_accuracy"] = accuracy_metrics["word_accuracy"] - if "eval_loss" in metrics: - bpb = self._compute_bits_per_byte(metrics["eval_loss"], self._eval_labels_for_accuracy) + if self._eval_total_content_bytes > 0: + bpb = self._eval_total_nats / (self._eval_total_content_bytes * math.log(2)) metrics["eval_bits_per_byte"] = bpb additional_metrics["eval_bits_per_byte"] = bpb @@ -538,21 +551,6 @@ def _compute_generation_metrics( return metrics - def _compute_bits_per_byte(self, loss: float, label_batches: list[torch.Tensor]) -> float: - """ - Compute bits per byte from loss and eval labels. - - labels_output contains content bytes + one EOS per word + PAD for alignment. - Loss is averaged over non-PAD positions (content bytes + EOS), but BPB - should reflect bits per actual text byte, so we separate the two counts. - """ - all_labels = torch.cat([l.flatten() for l in label_batches]) - pad_id = self.processor.tokenizer.pad_token_id - eos_id = self.processor.tokenizer.eos_token_id - num_tokens = (all_labels != pad_id).sum().item() - num_bytes = ((all_labels != pad_id) & (all_labels != eos_id)).sum().item() - return compute_bits_per_byte(loss, num_tokens, num_bytes) - def _compute_accuracy( self, pred_token_ids: list[torch.Tensor], From 58d67390204c53c32b3ff378189804ee6ae1750c Mon Sep 17 00:00:00 2001 From: Ilker Kesen Date: Tue, 3 Mar 2026 21:00:22 +0100 Subject: [PATCH 04/12] make evaluation correctly working for text evaluation metrics --- tests/test_trainer.py | 133 +++++++++++++++ welt_training/trainer.py | 338 +++++++++++++++++++++++++++------------ 2 files changed, 372 insertions(+), 99 deletions(-) diff --git a/tests/test_trainer.py b/tests/test_trainer.py index 251f374..f63b238 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -1102,5 +1102,138 @@ def test_accuracy_is_computed(trainer_setup): # while longer words are only partially correct +def test_sync_eval_state_noop_single_gpu(trainer_setup): + """Test that _sync_eval_state does not modify values in single-GPU mode.""" + model, processor, collator = trainer_setup + + training_args = Seq2SeqTrainingArguments( + output_dir="output/test_trainer", + per_device_eval_batch_size=2, + do_train=False, + do_eval=True, + remove_unused_columns=False, + ) + + trainer = WeLTTrainer( + model=model, + args=training_args, + processor=processor, + data_collator=collator, + ) + + # Set known values on all accumulators + trainer._eval_correct_bytes = 10 + trainer._eval_total_bytes = 20 + trainer._eval_correct_words = 5 + trainer._eval_total_words = 10 + trainer._eval_total_nats = 100.0 + trainer._eval_total_content_bytes = 50 + trainer._eval_sample_count = 8 + trainer._eval_predictions = ["hello", "world"] + trainer._eval_labels = ["hi", "earth"] + + trainer._sync_eval_state() + + # Values should be unchanged (num_processes == 1 => no-op) + assert trainer._eval_correct_bytes == 10 + assert trainer._eval_total_bytes == 20 + assert trainer._eval_correct_words == 5 + assert trainer._eval_total_words == 10 + assert trainer._eval_total_nats == 100.0 + assert trainer._eval_total_content_bytes == 50 + assert trainer._eval_sample_count == 8 + assert trainer._eval_predictions == ["hello", "world"] + assert trainer._eval_labels == ["hi", "earth"] + + +def test_metrics_consistent_across_batch_sizes(trainer_setup): + """Test that metrics don't depend on batch size (validates correct incremental accumulation).""" + model, processor, collator = trainer_setup + + dataset = make_generation_dataset( + prefixes=["a ", "b ", "c ", "d "], + completions=["x", "y", "z", "w"], + ) + + results = {} + for bs in [1, 2, 4]: + args = Seq2SeqTrainingArguments( + output_dir="output/test_trainer", + per_device_eval_batch_size=bs, + do_train=False, + do_eval=True, + remove_unused_columns=False, + predict_with_generate=False, + report_to="none", + ) + trainer = WeLTTrainer( + model=model, + args=args, + processor=processor, + data_collator=collator, + ) + results[bs] = trainer.evaluate(dataset) + + # BPB and accuracy should be identical regardless of batch size + for bs in [2, 4]: + assert abs(results[1]["eval_bits_per_byte"] - results[bs]["eval_bits_per_byte"]) < 1e-6, \ + f"BPB differs between batch_size=1 ({results[1]['eval_bits_per_byte']}) " \ + f"and batch_size={bs} ({results[bs]['eval_bits_per_byte']})" + assert abs(results[1]["eval_byte_accuracy"] - results[bs]["eval_byte_accuracy"]) < 1e-6, \ + f"byte_accuracy differs between batch_size=1 and batch_size={bs}" + assert abs(results[1]["eval_word_accuracy"] - results[bs]["eval_word_accuracy"]) < 1e-6, \ + f"word_accuracy differs between batch_size=1 and batch_size={bs}" + + # Verify BPB > loss/ln(2) (due to EOS overhead) for at least one batch size + import math + ref = results[1] + naive_bpb = ref["eval_loss"] / math.log(2) + assert ref["eval_bits_per_byte"] > naive_bpb, \ + f"BPB ({ref['eval_bits_per_byte']}) should exceed loss/ln(2) ({naive_bpb}) due to EOS overhead" + + # Verify accuracy is in valid range + assert 0.0 <= ref["eval_byte_accuracy"] <= 1.0 + assert 0.0 <= ref["eval_word_accuracy"] <= 1.0 + + +def test_get_eval_dataloader_recognizes_custom_iterable(trainer_setup): + """Test that get_eval_dataloader takes the streaming branch for CustomIterableDataset.""" + from torch.utils.data import DataLoader + + from welt_training.streaming import CustomIterableDataset + + model, processor, collator = trainer_setup + + # Create a CustomIterableDataset (inherits from datasets.IterableDataset) + base_dataset = Dataset.from_dict({ + "text": ["Hello world", "Test text"], + "prefix": ["Hello ", "Test "], + "completion": ["world", "text"], + }) + iterable_ds = CustomIterableDataset(base_dataset.to_iterable_dataset()) + iterable_ds = iterable_ds.with_transform(processor) + + training_args = Seq2SeqTrainingArguments( + output_dir="output/test_trainer", + per_device_eval_batch_size=2, + do_train=False, + do_eval=True, + remove_unused_columns=False, + ) + + trainer = WeLTTrainer( + model=model, + args=training_args, + processor=processor, + data_collator=collator, + ) + + dataloader = trainer.get_eval_dataloader(iterable_ds) + + # Should return a plain DataLoader (not wrapped by accelerate's DataLoaderShard) + assert type(dataloader) is DataLoader, \ + f"Expected plain DataLoader for IterableDataset, got {type(dataloader).__name__}" + + if __name__ == "__main__": pytest.main([__file__, "-v", "-s"]) diff --git a/welt_training/trainer.py b/welt_training/trainer.py index 8fa5629..160257a 100644 --- a/welt_training/trainer.py +++ b/welt_training/trainer.py @@ -21,6 +21,22 @@ logger = logging.getLogger(__name__) +class _TorchIterableAdapter(torch.utils.data.IterableDataset): + """Wraps a HuggingFace datasets.IterableDataset as a torch IterableDataset. + + datasets.IterableDataset does not inherit from torch.utils.data.IterableDataset, + so PyTorch's DataLoader treats it as map-style and tries len()/__getitem__, + which fails. This thin adapter delegates __iter__ so DataLoader uses the + iterable protocol instead. + """ + + def __init__(self, hf_dataset): + self._dataset = hf_dataset + + def __iter__(self): + yield from self._dataset + + class WeLTTrainer(Trainer): """ Minimal trainer extension for WeLT generation-based evaluation. @@ -121,29 +137,120 @@ def _reset_eval_state(self): self._eval_predictions = [] self._eval_labels = [] self._eval_sample_count = 0 - self._eval_logits = [] - self._eval_labels_for_accuracy = [] + # Accuracy accumulators: incremental scalar counters (computed on-device + # in prediction_step, then all_reduced across ranks in _sync_eval_state). + self._eval_correct_bytes = 0 + self._eval_total_bytes = 0 + self._eval_correct_words = 0 + self._eval_total_words = 0 # BPB accumulators: accumulate per-batch to avoid bias from # unequal non-PAD counts across batches (eval_loss is an unweighted # mean of per-batch means, so eval_loss * total_tokens != true total nats). self._eval_total_nats = 0.0 self._eval_total_content_bytes = 0 + def _sync_eval_state(self): + """Synchronize evaluation accumulators across distributed processes. + + Uses accelerator.reduce for scalar counters and torch.distributed.all_gather_object + for string predictions/references. No-op when running on a single process. + """ + if self.accelerator.num_processes <= 1: + return + + import torch.distributed as dist + + device = self.accelerator.device + + # Reduce scalar counters across ranks (single all_reduce for integers) + int_counts = torch.tensor([ + self._eval_correct_bytes, + self._eval_total_bytes, + self._eval_correct_words, + self._eval_total_words, + self._eval_total_content_bytes, + self._eval_sample_count, + ], dtype=torch.long, device=device) + int_counts = self.accelerator.reduce(int_counts, reduction="sum") + + self._eval_correct_bytes = int_counts[0].item() + self._eval_total_bytes = int_counts[1].item() + self._eval_correct_words = int_counts[2].item() + self._eval_total_words = int_counts[3].item() + self._eval_total_content_bytes = int_counts[4].item() + self._eval_sample_count = int_counts[5].item() + + # Reduce float accumulator (separate tensor for precision) + float_counts = torch.tensor([self._eval_total_nats], dtype=torch.double, device=device) + float_counts = self.accelerator.reduce(float_counts, reduction="sum") + self._eval_total_nats = float_counts[0].item() + + # Gather string predictions/labels across ranks. + # Use raw all_gather_object (NOT gather_for_metrics) because + # gather_for_metrics trims to gradient_state.remainder, which is + # designed for per-batch use inside the eval loop. Calling it post-loop + # on the fully accumulated list would slice the entire evaluation down + # to only the last batch's remainder. + # Guard on config flags (not data presence) to ensure all ranks + # enter the collective and avoid deadlocks. + if self.args.predict_with_generate and self.loaded_metrics: + all_preds = [None] * self.accelerator.num_processes + all_labels = [None] * self.accelerator.num_processes + dist.all_gather_object(all_preds, self._eval_predictions) + dist.all_gather_object(all_labels, self._eval_labels) + self._eval_predictions = [s for rank_preds in all_preds for s in rank_preds] + self._eval_labels = [s for rank_labels in all_labels for s in rank_labels] + def get_eval_dataloader(self, eval_dataset=None): """ - Override to handle streaming eval datasets without accelerate's batch dispatch. + Override to handle streaming eval datasets. + + For IterableDataset (both torch and HuggingFace datasets variants): + - Shards across distributed ranks via split_dataset_by_node + - Wraps HF IterableDataset in a torch-compatible adapter so PyTorch's + DataLoader treats it as iterable (not map-style) + - Creates a DataLoader without accelerate's prepare() to avoid + string field concatenation errors - Streaming datasets (IterableDataset) with string fields like 'prefix' and 'completion' - fail when accelerate tries to concatenate batches. We create a dataloader without - accelerate's prepare() to avoid this issue. + For regular Dataset: + - Falls through to the base Trainer (handles DistributedSampler) """ - from torch.utils.data import DataLoader, IterableDataset + import datasets + from torch.utils.data import DataLoader + from torch.utils.data import IterableDataset as TorchIterableDataset eval_dataset = eval_dataset or self.eval_dataset - # For IterableDataset, create dataloader without accelerate's prepare - # to avoid string field concatenation errors - if isinstance(eval_dataset, IterableDataset): + # Check both torch and HF IterableDataset (they are unrelated classes; + # CustomIterableDataset inherits from datasets.IterableDataset only) + if isinstance(eval_dataset, (TorchIterableDataset, datasets.IterableDataset)): + # Shard across ranks for distributed evaluation + if self.accelerator.num_processes > 1: + from datasets.distributed import split_dataset_by_node + + from welt_training.streaming import CustomIterableDataset + + rank = self.accelerator.process_index + world_size = self.accelerator.num_processes + + if isinstance(eval_dataset, CustomIterableDataset): + old_transform = eval_dataset._transform + sharded_inner = split_dataset_by_node( + eval_dataset._dataset, rank=rank, world_size=world_size) + eval_dataset = CustomIterableDataset(sharded_inner) + if old_transform is not None: + eval_dataset.set_transform(old_transform) + elif isinstance(eval_dataset, datasets.IterableDataset): + eval_dataset = split_dataset_by_node( + eval_dataset, rank=rank, world_size=world_size) + + # datasets.IterableDataset does NOT inherit from + # torch.utils.data.IterableDataset, so PyTorch's DataLoader would + # treat it as map-style and try len()/__getitem__, which fails. + # Wrap in a thin torch-compatible adapter. + if not isinstance(eval_dataset, TorchIterableDataset): + eval_dataset = _TorchIterableAdapter(eval_dataset) + return DataLoader( eval_dataset, batch_size=self.args.per_device_eval_batch_size, @@ -287,9 +394,19 @@ def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None) prefixes = inputs.get("prefix", None) completions = inputs.get("completion", None) - # Count samples in this batch - if prefixes is not None: - self._eval_sample_count += len(prefixes) + # Count samples in this batch (adjusted for last-batch padding below) + batch_sample_count = len(prefixes) if prefixes is not None else 0 + if (batch_sample_count > 0 + and self.accelerator.num_processes > 1 + and self.accelerator.gradient_state.end_of_dataloader): + remainder = self.accelerator.gradient_state.remainder + if remainder > 0: + rank = self.accelerator.process_index + real_on_this_rank = max( + 0, min(batch_sample_count, remainder - rank * batch_sample_count) + ) + batch_sample_count = real_on_this_rank + self._eval_sample_count += batch_sample_count # Create model inputs without custom fields model_inputs = { @@ -315,20 +432,67 @@ def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None) # Get predicted token IDs (argmax over vocabulary) pred_token_ids = logits.argmax(dim=-1) # (batch, seq_len, vocab) -> (batch, seq_len) - # Store predictions and labels for accuracy + # Compute accuracy incrementally and accumulate BPB counters labels_output = inputs.get("labels_output") if labels_output is not None: - self._eval_logits.append(pred_token_ids.cpu()) - self._eval_labels_for_accuracy.append(labels_output.cpu()) + # Ensure labels are on the same device as predictions + # (streaming path skips accelerator.prepare, so labels may be on CPU) + if labels_output.device != pred_token_ids.device: + labels_output = labels_output.to(pred_token_ids.device) + + # Exclude padded last-batch replicas in distributed eval. + # DataLoaderShard pads the last batch by replaying earlier samples + # to keep batch sizes uniform across ranks. We trim to only real + # samples so these replicas don't inflate the eval counters. + if (self.accelerator.num_processes > 1 + and self.accelerator.gradient_state.end_of_dataloader): + remainder = self.accelerator.gradient_state.remainder + if remainder > 0: + rank = self.accelerator.process_index + per_device_bs = labels_output.shape[0] + real_on_this_rank = max( + 0, min(per_device_bs, remainder - rank * per_device_bs) + ) + if real_on_this_rank < per_device_bs: + labels_output = labels_output[:real_on_this_rank] + pred_token_ids = pred_token_ids[:real_on_this_rank] + logits = logits[:real_on_this_rank] - # Accumulate exact per-batch nats and content byte counts for BPB pad_id = self.processor.tokenizer.pad_token_id eos_id = self.processor.tokenizer.eos_token_id - flat_labels = labels_output.flatten() - batch_non_pad = (flat_labels != pad_id).sum().item() - batch_content_bytes = ((flat_labels != pad_id) & (flat_labels != eos_id)).sum().item() - self._eval_total_nats += loss.item() * batch_non_pad - self._eval_total_content_bytes += batch_content_bytes + + # Accuracy: compute on-device, accumulate scalars + non_pad_mask = labels_output != pad_id + byte_matches = (pred_token_ids == labels_output) & non_pad_mask + self._eval_correct_bytes += byte_matches.sum().item() + self._eval_total_bytes += non_pad_mask.sum().item() + + # Word-level: a word is correct if ALL its non-padding tokens match + word_nonpad = non_pad_mask.any(dim=2) + word_all_correct = ((pred_token_ids == labels_output) | ~non_pad_mask).all(dim=2) + self._eval_correct_words += (word_all_correct & word_nonpad).sum().item() + self._eval_total_words += word_nonpad.sum().item() + + # Accumulate exact per-batch nats and content byte counts for BPB. + # BPB is only well-defined for UTF-8 where each prediction is a byte. + # For UTF-16/UTF-32 the loss goes through bytes_decoder.compute_loss() + # with different semantics; skip BPB rather than report wrong values. + model_encoding = getattr(getattr(model, "config", None), "encoding", "UTF-8") + if model_encoding == "UTF-8": + # Recompute loss from (possibly trimmed) logits/labels so the + # numerator stays consistent with the trimmed token counts when + # padded last-batch replicas have been removed. + flat_labels = labels_output.flatten() + flat_logits = logits.reshape(-1, logits.size(-1)) + batch_non_pad = (flat_labels != pad_id).sum().item() + batch_content_bytes = ((flat_labels != pad_id) & (flat_labels != eos_id)).sum().item() + if batch_non_pad > 0: + batch_loss = torch.nn.functional.cross_entropy( + flat_logits, flat_labels, ignore_index=pad_id + ) + if torch.isfinite(batch_loss): + self._eval_total_nats += batch_loss.item() * batch_non_pad + self._eval_total_content_bytes += batch_content_bytes # Generate predictions if predict_with_generate is enabled # Only do generation when: (1) predict_with_generate is True, (2) we have prefixes, @@ -358,13 +522,28 @@ def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None) predictions_text = model.generate(**generation_inputs, **generation_kwargs) + # Trim padded last-batch replicas for generation as well + if (self.accelerator.num_processes > 1 + and self.accelerator.gradient_state.end_of_dataloader): + remainder = self.accelerator.gradient_state.remainder + if remainder > 0: + rank = self.accelerator.process_index + per_device_bs = len(predictions_text) + real_on_this_rank = max( + 0, min(per_device_bs, remainder - rank * per_device_bs) + ) + predictions_text = predictions_text[:real_on_this_rank] + if completions is not None: + completions = completions[:real_on_this_rank] + # Store predictions and labels for generation metrics computation self._eval_predictions.extend(predictions_text) if completions is not None: self._eval_labels.extend(completions) - # Log samples once per evaluation - if not self._logged_samples_this_eval and predictions_text and self.log_samples > 0: + # Log samples once per evaluation (main process only to avoid duplicates) + if (not self._logged_samples_this_eval and predictions_text + and self.log_samples > 0 and self.is_world_process_zero()): self._log_samples(predictions_text, prefixes, completions) self._logged_samples_this_eval = True @@ -396,7 +575,7 @@ def evaluate(self, eval_dataset=None, **kwargs): metrics = super().evaluate(eval_dataset=eval_dataset, **kwargs) # Add custom metrics (generation, accuracy, perplexity) - additional_metrics = self._add_custom_metrics(metrics) + additional_metrics = self._add_custom_metrics(metrics, eval_dataset) # Log only the additional metrics we computed (not all metrics) # This allows them to appear in wandb and progress bars without creating @@ -416,11 +595,18 @@ def _prepare_eval_dataset(self, eval_dataset): eval_dataset = eval_dataset.with_transform(self.processor) return eval_dataset - def _add_custom_metrics(self, metrics): - """Add custom metrics (generation, accuracy, perplexity) to metrics dict.""" + def _add_custom_metrics(self, metrics, eval_dataset=None): + """Add custom metrics (generation, accuracy, perplexity) to metrics dict. + + Synchronizes accumulators across distributed processes before computing + final metric values. + """ + # Synchronize accumulators across ranks (no-op for single process) + self._sync_eval_state() + additional_metrics = {} - # Compute generation metrics from stored predictions + # Compute generation metrics from stored predictions (globally gathered) if self._eval_predictions and self._eval_labels and self.loaded_metrics: generation_metrics = self._compute_generation_metrics( self._eval_predictions, @@ -432,7 +618,7 @@ def _add_custom_metrics(self, metrics): metrics[metric_key] = value additional_metrics[metric_key] = value - # Add perplexity and bits per byte if we have loss + # Add perplexity if we have loss if "eval_loss" in metrics: loss = metrics["eval_loss"] # Use 709 as threshold to avoid float overflow; exp(709) ~ 8.2e307 is the largest representable float @@ -440,27 +626,32 @@ def _add_custom_metrics(self, metrics): metrics["perplexity"] = perplexity additional_metrics["perplexity"] = perplexity - # Compute byte and word accuracy and bits per byte from stored logits - if self._eval_logits and self._eval_labels_for_accuracy: - accuracy_metrics = self._compute_accuracy( - self._eval_logits, - self._eval_labels_for_accuracy - ) - metrics["eval_byte_accuracy"] = accuracy_metrics["byte_accuracy"] - metrics["eval_word_accuracy"] = accuracy_metrics["word_accuracy"] - additional_metrics["eval_byte_accuracy"] = accuracy_metrics["byte_accuracy"] - additional_metrics["eval_word_accuracy"] = accuracy_metrics["word_accuracy"] - - if self._eval_total_content_bytes > 0: - bpb = self._eval_total_nats / (self._eval_total_content_bytes * math.log(2)) - metrics["eval_bits_per_byte"] = bpb - additional_metrics["eval_bits_per_byte"] = bpb - - # Add eval_samples count - if self._eval_sample_count > 0: - metrics["eval_samples"] = self._eval_sample_count + # Compute byte and word accuracy from globally-reduced scalar counters + if self._eval_total_bytes > 0: + byte_accuracy = self._eval_correct_bytes / self._eval_total_bytes + word_accuracy = (self._eval_correct_words / self._eval_total_words + if self._eval_total_words > 0 else 0.0) + metrics["eval_byte_accuracy"] = byte_accuracy + metrics["eval_word_accuracy"] = word_accuracy + additional_metrics["eval_byte_accuracy"] = byte_accuracy + additional_metrics["eval_word_accuracy"] = word_accuracy + + # Compute bits per byte from globally-reduced counters + if self._eval_total_content_bytes > 0: + bpb = self._eval_total_nats / (self._eval_total_content_bytes * math.log(2)) + metrics["eval_bits_per_byte"] = bpb + additional_metrics["eval_bits_per_byte"] = bpb + + # Add eval_samples count. + # For map-style datasets, use len() to avoid overcounting from accelerate's + # batch padding. For generation, use gathered predictions count. Fall back to + # the reduced counter for sharded iterables. + if eval_dataset is not None and hasattr(eval_dataset, "__len__"): + metrics["eval_samples"] = len(eval_dataset) elif self._eval_predictions: metrics["eval_samples"] = len(self._eval_predictions) + elif self._eval_sample_count > 0: + metrics["eval_samples"] = self._eval_sample_count return additional_metrics @@ -551,54 +742,3 @@ def _compute_generation_metrics( return metrics - def _compute_accuracy( - self, - pred_token_ids: list[torch.Tensor], - labels: list[torch.Tensor] - ) -> dict[str, float]: - """ - Compute byte-level and word-level accuracy from predictions and labels. - - Byte accuracy: Percentage of correctly predicted tokens - Word accuracy: Percentage of words where ALL tokens are correctly predicted - - Args: - pred_token_ids: List of predicted token ID tensors, shape (batch_size, num_words, tokens_per_word) - labels: List of label tensors, shape (batch_size, num_words, tokens_per_word) - - Returns: - dict with 'byte_accuracy' and 'word_accuracy' as floats between 0 and 1 - """ - pad_token_id = self.processor.tokenizer.pad_token_id - - total_bytes = 0 - correct_bytes = 0 - total_words = 0 - correct_words = 0 - - # Process each batch separately (can't concatenate if tokens_per_word differs between batches) - # but vectorize operations within each batch for efficiency - for preds, label in zip(pred_token_ids, labels, strict=False): - # preds/label shape: (batch_size, num_words, tokens_per_word) - # Mask for non-padding tokens - non_pad_mask = label != pad_token_id # (batch_size, num_words, tokens_per_word) - - # Byte-level: count matching non-padding tokens across entire batch - byte_matches = (preds == label) & non_pad_mask - correct_bytes += byte_matches.sum().item() - total_bytes += non_pad_mask.sum().item() - - # Word-level: a word is correct if all its non-padding tokens are correct - word_nonpad = non_pad_mask.any(dim=2) # (batch_size, num_words) - word_matches = ((preds == label) | ~non_pad_mask).all(dim=2) # (batch_size, num_words) - valid_words = word_nonpad.sum().item() - correct_words += (word_matches & word_nonpad).sum().item() - total_words += valid_words - - byte_accuracy = correct_bytes / total_bytes if total_bytes > 0 else 0.0 - word_accuracy = correct_words / total_words if total_words > 0 else 0.0 - - return { - "byte_accuracy": byte_accuracy, - "word_accuracy": word_accuracy - } From 8ffbf29604b3ffc08fb0a5b441d163838798be0a Mon Sep 17 00:00:00 2001 From: Ilker Kesen Date: Thu, 5 Mar 2026 16:21:41 +0100 Subject: [PATCH 05/12] add support for utf-16 and utf-32 encodings also as well --- tests/test_trainer.py | 58 ++++++++++++++++++++++++++++++++++++++++ welt_training/trainer.py | 30 ++++++++++++--------- 2 files changed, 76 insertions(+), 12 deletions(-) diff --git a/tests/test_trainer.py b/tests/test_trainer.py index f63b238..990eb48 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -20,6 +20,20 @@ def trainer_setup(): return model, processor, collator +@pytest.fixture(scope="module") +def trainer_setup_utf32(): + """Setup UTF-32 model, processor, and collator for trainer tests.""" + model, processor, collator = setup_tiny_model( + image_encoder_name=None, + encoding="UTF-32", + bytes_decoder_name="sign/utf8-lm-tiny", + ) + # Force CPU to avoid device placement issues during generation + model = model.to(torch.device("cpu")) + model.eval() + return model, processor, collator + + def make_generation_dataset(prefixes: list[str], completions: list[str]) -> Dataset: """ Create a dataset for generation-based evaluation. @@ -1057,6 +1071,50 @@ def test_bits_per_byte_computation(trainer_setup): f"{metrics['eval_bits_per_byte']} vs {naive_bpb}" +def test_bits_per_byte_computation_utf32(trainer_setup_utf32): + """Test that bits_per_byte is computed correctly for UTF-32.""" + model, processor, collator = trainer_setup_utf32 + + eval_dataset = make_generation_dataset( + prefixes=["a", "b"], + completions=[" x", " y"], + ) + + training_args = Seq2SeqTrainingArguments( + output_dir="output/test_trainer", + per_device_eval_batch_size=2, + do_train=False, + do_eval=True, + remove_unused_columns=False, + predict_with_generate=False, + ) + + trainer = WeLTTrainer( + model=model, + args=training_args, + processor=processor, + data_collator=collator, + eval_metrics=None, + max_generated_words=3, + ) + + metrics = trainer.evaluate(eval_dataset) + + # Verify bits_per_byte is present and positive + assert "eval_bits_per_byte" in metrics, \ + f"eval_bits_per_byte should be in metrics. Found: {list(metrics.keys())}" + assert metrics["eval_bits_per_byte"] > 0, \ + f"eval_bits_per_byte should be positive, got {metrics['eval_bits_per_byte']}" + + # Loss is averaged over all non-PAD bytes (content + EOS bytes), while BPB + # divides by content bytes only, so BPB must exceed loss/ln(2). + import math + naive_bpb = metrics["eval_loss"] / math.log(2) + assert metrics["eval_bits_per_byte"] > naive_bpb, \ + f"eval_bits_per_byte should exceed loss/ln(2) due to EOS overhead: " \ + f"{metrics['eval_bits_per_byte']} vs {naive_bpb}" + + def test_accuracy_is_computed(trainer_setup): """Test that eval_accuracy is computed during evaluation.""" model, processor, collator = trainer_setup diff --git a/welt_training/trainer.py b/welt_training/trainer.py index 160257a..1482d25 100644 --- a/welt_training/trainer.py +++ b/welt_training/trainer.py @@ -474,24 +474,31 @@ def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None) self._eval_total_words += word_nonpad.sum().item() # Accumulate exact per-batch nats and content byte counts for BPB. - # BPB is only well-defined for UTF-8 where each prediction is a byte. - # For UTF-16/UTF-32 the loss goes through bytes_decoder.compute_loss() - # with different semantics; skip BPB rather than report wrong values. + # UTF-8 uses direct byte-level CE, while UTF-16/UTF-32 use + # CharacterCausalLMWrapper.compute_loss() over split bytes. model_encoding = getattr(getattr(model, "config", None), "encoding", "UTF-8") - if model_encoding == "UTF-8": + bytes_per_token = {"UTF-8": 1, "UTF-16": 2, "UTF-32": 4}.get(model_encoding) + if bytes_per_token is not None: # Recompute loss from (possibly trimmed) logits/labels so the # numerator stays consistent with the trimmed token counts when # padded last-batch replicas have been removed. flat_labels = labels_output.flatten() flat_logits = logits.reshape(-1, logits.size(-1)) - batch_non_pad = (flat_labels != pad_id).sum().item() - batch_content_bytes = ((flat_labels != pad_id) & (flat_labels != eos_id)).sum().item() - if batch_non_pad > 0: - batch_loss = torch.nn.functional.cross_entropy( - flat_logits, flat_labels, ignore_index=pad_id - ) + batch_non_pad_tokens = (flat_labels != pad_id).sum().item() + batch_non_pad_bytes = batch_non_pad_tokens * bytes_per_token + batch_content_bytes = ( + ((flat_labels != pad_id) & (flat_labels != eos_id)).sum().item() * bytes_per_token + ) + if batch_non_pad_bytes > 0: + if model_encoding == "UTF-8": + batch_loss = torch.nn.functional.cross_entropy( + flat_logits, flat_labels, ignore_index=pad_id + ) + else: + batch_loss = model.bytes_decoder.compute_loss(flat_logits, flat_labels) + if torch.isfinite(batch_loss): - self._eval_total_nats += batch_loss.item() * batch_non_pad + self._eval_total_nats += batch_loss.item() * batch_non_pad_bytes self._eval_total_content_bytes += batch_content_bytes # Generate predictions if predict_with_generate is enabled @@ -741,4 +748,3 @@ def _compute_generation_metrics( logger.warning(f"Failed to compute metric '{metric_name}': {e}") return metrics - From cc2d4f39cbf64048ca6c9ff86011ff01d8a77213 Mon Sep 17 00:00:00 2001 From: Ilker Kesen Date: Fri, 6 Mar 2026 13:33:05 +0100 Subject: [PATCH 06/12] refactor batch counting in welt trainer --- welt_training/trainer.py | 69 +++++++++++++++++----------------------- 1 file changed, 29 insertions(+), 40 deletions(-) diff --git a/welt_training/trainer.py b/welt_training/trainer.py index 1482d25..cc359f5 100644 --- a/welt_training/trainer.py +++ b/welt_training/trainer.py @@ -377,6 +377,23 @@ def _create_dion2_optimizer(self): ) return self.optimizer + def _real_batch_count(self, nominal_count): + """Return the number of real samples on this rank for the current batch. + + In distributed evaluation, DataLoaderShard pads the last batch by + replaying earlier samples so every rank has the same batch size. + This method trims the count to only real (non-padded) samples. + Returns ``nominal_count`` unchanged for single-GPU or non-last batches. + """ + if (nominal_count > 0 + and self.accelerator.num_processes > 1 + and self.accelerator.gradient_state.end_of_dataloader): + remainder = self.accelerator.gradient_state.remainder + if remainder > 0: + rank = self.accelerator.process_index + return max(0, min(nominal_count, remainder - rank * nominal_count)) + return nominal_count + def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None): """ Override prediction_step to generate predictions and store data for metrics. @@ -394,19 +411,9 @@ def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None) prefixes = inputs.get("prefix", None) completions = inputs.get("completion", None) - # Count samples in this batch (adjusted for last-batch padding below) + # Count samples in this batch (adjusted for last-batch padding) batch_sample_count = len(prefixes) if prefixes is not None else 0 - if (batch_sample_count > 0 - and self.accelerator.num_processes > 1 - and self.accelerator.gradient_state.end_of_dataloader): - remainder = self.accelerator.gradient_state.remainder - if remainder > 0: - rank = self.accelerator.process_index - real_on_this_rank = max( - 0, min(batch_sample_count, remainder - rank * batch_sample_count) - ) - batch_sample_count = real_on_this_rank - self._eval_sample_count += batch_sample_count + self._eval_sample_count += self._real_batch_count(batch_sample_count) # Create model inputs without custom fields model_inputs = { @@ -441,22 +448,11 @@ def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None) labels_output = labels_output.to(pred_token_ids.device) # Exclude padded last-batch replicas in distributed eval. - # DataLoaderShard pads the last batch by replaying earlier samples - # to keep batch sizes uniform across ranks. We trim to only real - # samples so these replicas don't inflate the eval counters. - if (self.accelerator.num_processes > 1 - and self.accelerator.gradient_state.end_of_dataloader): - remainder = self.accelerator.gradient_state.remainder - if remainder > 0: - rank = self.accelerator.process_index - per_device_bs = labels_output.shape[0] - real_on_this_rank = max( - 0, min(per_device_bs, remainder - rank * per_device_bs) - ) - if real_on_this_rank < per_device_bs: - labels_output = labels_output[:real_on_this_rank] - pred_token_ids = pred_token_ids[:real_on_this_rank] - logits = logits[:real_on_this_rank] + real_count = self._real_batch_count(labels_output.shape[0]) + if real_count < labels_output.shape[0]: + labels_output = labels_output[:real_count] + pred_token_ids = pred_token_ids[:real_count] + logits = logits[:real_count] pad_id = self.processor.tokenizer.pad_token_id eos_id = self.processor.tokenizer.eos_token_id @@ -530,18 +526,11 @@ def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None) predictions_text = model.generate(**generation_inputs, **generation_kwargs) # Trim padded last-batch replicas for generation as well - if (self.accelerator.num_processes > 1 - and self.accelerator.gradient_state.end_of_dataloader): - remainder = self.accelerator.gradient_state.remainder - if remainder > 0: - rank = self.accelerator.process_index - per_device_bs = len(predictions_text) - real_on_this_rank = max( - 0, min(per_device_bs, remainder - rank * per_device_bs) - ) - predictions_text = predictions_text[:real_on_this_rank] - if completions is not None: - completions = completions[:real_on_this_rank] + real_count = self._real_batch_count(len(predictions_text)) + if real_count < len(predictions_text): + predictions_text = predictions_text[:real_count] + if completions is not None: + completions = completions[:real_count] # Store predictions and labels for generation metrics computation self._eval_predictions.extend(predictions_text) From 18d2e3e5c894d704d0b38331b86117ef9d2d0e28 Mon Sep 17 00:00:00 2001 From: Ilker Kesen Date: Fri, 6 Mar 2026 15:20:37 +0100 Subject: [PATCH 07/12] resolve ruff issues --- tests/test_train.py | 6 +- welt/processor.py | 2 +- welt_training/trainer.py | 190 ++++++++++++++++++++------------------- 3 files changed, 103 insertions(+), 95 deletions(-) diff --git a/tests/test_train.py b/tests/test_train.py index 4130eae..b7beca4 100644 --- a/tests/test_train.py +++ b/tests/test_train.py @@ -119,8 +119,10 @@ def test_basic_training_with_eval_chrf(temp_output_dir): assert "eval_bits_per_byte" in eval_metrics, "eval_bits_per_byte should be present" import math naive_bpb = eval_metrics["eval_loss"] / math.log(2) - assert eval_metrics["eval_bits_per_byte"] > naive_bpb, \ - f"eval_bits_per_byte should exceed loss/ln(2) due to EOS overhead: {eval_metrics['eval_bits_per_byte']} vs {naive_bpb}" + assert eval_metrics["eval_bits_per_byte"] > naive_bpb, ( + f"eval_bits_per_byte should exceed loss/ln(2) due to EOS overhead: " + f"{eval_metrics['eval_bits_per_byte']} vs {naive_bpb}" + ) print("\n✓ Training completed successfully!") print(f"✓ eval_chrf = {chrf_score:.2f}") diff --git a/welt/processor.py b/welt/processor.py index 54604a8..4589274 100644 --- a/welt/processor.py +++ b/welt/processor.py @@ -79,7 +79,7 @@ def save_pretrained(self, save_directory, push_to_hub=False, **kwargs): attr.save_pretrained(attr_dir) output = {k: v for k, v in self.__dict__.items() - if k not in self.attributes and isinstance(v, (int, float, str, bool))} + if k not in self.attributes and isinstance(v, int | float | str | bool)} output["processor_class"] = self.__class__.__name__ config_file = os.path.join(save_directory, PROCESSOR_CONFIG_NAME) with open(config_file, "w") as f: diff --git a/welt_training/trainer.py b/welt_training/trainer.py index cc359f5..bccad44 100644 --- a/welt_training/trainer.py +++ b/welt_training/trainer.py @@ -223,7 +223,7 @@ def get_eval_dataloader(self, eval_dataset=None): # Check both torch and HF IterableDataset (they are unrelated classes; # CustomIterableDataset inherits from datasets.IterableDataset only) - if isinstance(eval_dataset, (TorchIterableDataset, datasets.IterableDataset)): + if isinstance(eval_dataset, TorchIterableDataset | datasets.IterableDataset): # Shard across ranks for distributed evaluation if self.accelerator.num_processes > 1: from datasets.distributed import split_dataset_by_node @@ -394,6 +394,101 @@ def _real_batch_count(self, nominal_count): return max(0, min(nominal_count, remainder - rank * nominal_count)) return nominal_count + def _accumulate_accuracy_and_bpb(self, model, inputs, logits): + """Accumulate byte/word accuracy and bits-per-byte counters from logits.""" + pred_token_ids = logits.argmax(dim=-1) + + labels_output = inputs.get("labels_output") + if labels_output is None: + return + + if labels_output.device != pred_token_ids.device: + labels_output = labels_output.to(pred_token_ids.device) + + # Exclude padded last-batch replicas in distributed eval. + real_count = self._real_batch_count(labels_output.shape[0]) + if real_count < labels_output.shape[0]: + labels_output = labels_output[:real_count] + pred_token_ids = pred_token_ids[:real_count] + logits = logits[:real_count] + + pad_id = self.processor.tokenizer.pad_token_id + eos_id = self.processor.tokenizer.eos_token_id + + # Accuracy: compute on-device, accumulate scalars + non_pad_mask = labels_output != pad_id + byte_matches = (pred_token_ids == labels_output) & non_pad_mask + self._eval_correct_bytes += byte_matches.sum().item() + self._eval_total_bytes += non_pad_mask.sum().item() + + # Word-level: a word is correct if ALL its non-padding tokens match + word_nonpad = non_pad_mask.any(dim=2) + word_all_correct = ((pred_token_ids == labels_output) | ~non_pad_mask).all(dim=2) + self._eval_correct_words += (word_all_correct & word_nonpad).sum().item() + self._eval_total_words += word_nonpad.sum().item() + + # Accumulate exact per-batch nats and content byte counts for BPB. + # UTF-8 uses direct byte-level CE, while UTF-16/UTF-32 use + # CharacterCausalLMWrapper.compute_loss() over split bytes. + model_encoding = getattr(getattr(model, "config", None), "encoding", "UTF-8") + bytes_per_token = {"UTF-8": 1, "UTF-16": 2, "UTF-32": 4}.get(model_encoding) + if bytes_per_token is None: + return + + # Recompute loss from (possibly trimmed) logits/labels so the + # numerator stays consistent with the trimmed token counts when + # padded last-batch replicas have been removed. + flat_labels = labels_output.flatten() + flat_logits = logits.reshape(-1, logits.size(-1)) + batch_non_pad_tokens = (flat_labels != pad_id).sum().item() + batch_non_pad_bytes = batch_non_pad_tokens * bytes_per_token + batch_content_bytes = ( + ((flat_labels != pad_id) & (flat_labels != eos_id)).sum().item() * bytes_per_token + ) + if batch_non_pad_bytes > 0: + if model_encoding == "UTF-8": + batch_loss = torch.nn.functional.cross_entropy( + flat_logits, flat_labels, ignore_index=pad_id + ) + else: + batch_loss = model.bytes_decoder.compute_loss(flat_logits, flat_labels) + + if torch.isfinite(batch_loss): + self._eval_total_nats += batch_loss.item() * batch_non_pad_bytes + self._eval_total_content_bytes += batch_content_bytes + + def _generate_predictions(self, model, prefixes, completions): + """Generate text predictions and store them for metric computation.""" + with torch.no_grad(): + generation_inputs = self.processor(prefixes, collated=True) + generation_inputs = { + k: v.to(model.device) if isinstance(v, torch.Tensor) else v + for k, v in generation_inputs.items() + } + + generation_kwargs = { + "processor": self.processor, + "max_generated_words": self.max_generated_words, + } + if self.bytes_generation_config is not None: + generation_kwargs["bytes_generation_config"] = self.bytes_generation_config + + predictions_text = model.generate(**generation_inputs, **generation_kwargs) + + # Trim padded last-batch replicas for generation as well + real_count = self._real_batch_count(len(predictions_text)) + if real_count < len(predictions_text): + predictions_text = predictions_text[:real_count] + if completions is not None: + completions = completions[:real_count] + + # Store predictions and labels for generation metrics computation + self._eval_predictions.extend(predictions_text) + if completions is not None: + self._eval_labels.extend(completions) + + return predictions_text, completions + def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None): """ Override prediction_step to generate predictions and store data for metrics. @@ -433,73 +528,11 @@ def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None) loss = torch.tensor(0.0, device=model.device) # Store logits for accuracy computation (always during evaluation) - # Extract logits if available if hasattr(outputs, "logits") or "logits" in outputs: logits = outputs.logits if hasattr(outputs, "logits") else outputs["logits"] - # Get predicted token IDs (argmax over vocabulary) - pred_token_ids = logits.argmax(dim=-1) # (batch, seq_len, vocab) -> (batch, seq_len) - - # Compute accuracy incrementally and accumulate BPB counters - labels_output = inputs.get("labels_output") - if labels_output is not None: - # Ensure labels are on the same device as predictions - # (streaming path skips accelerator.prepare, so labels may be on CPU) - if labels_output.device != pred_token_ids.device: - labels_output = labels_output.to(pred_token_ids.device) - - # Exclude padded last-batch replicas in distributed eval. - real_count = self._real_batch_count(labels_output.shape[0]) - if real_count < labels_output.shape[0]: - labels_output = labels_output[:real_count] - pred_token_ids = pred_token_ids[:real_count] - logits = logits[:real_count] - - pad_id = self.processor.tokenizer.pad_token_id - eos_id = self.processor.tokenizer.eos_token_id - - # Accuracy: compute on-device, accumulate scalars - non_pad_mask = labels_output != pad_id - byte_matches = (pred_token_ids == labels_output) & non_pad_mask - self._eval_correct_bytes += byte_matches.sum().item() - self._eval_total_bytes += non_pad_mask.sum().item() - - # Word-level: a word is correct if ALL its non-padding tokens match - word_nonpad = non_pad_mask.any(dim=2) - word_all_correct = ((pred_token_ids == labels_output) | ~non_pad_mask).all(dim=2) - self._eval_correct_words += (word_all_correct & word_nonpad).sum().item() - self._eval_total_words += word_nonpad.sum().item() - - # Accumulate exact per-batch nats and content byte counts for BPB. - # UTF-8 uses direct byte-level CE, while UTF-16/UTF-32 use - # CharacterCausalLMWrapper.compute_loss() over split bytes. - model_encoding = getattr(getattr(model, "config", None), "encoding", "UTF-8") - bytes_per_token = {"UTF-8": 1, "UTF-16": 2, "UTF-32": 4}.get(model_encoding) - if bytes_per_token is not None: - # Recompute loss from (possibly trimmed) logits/labels so the - # numerator stays consistent with the trimmed token counts when - # padded last-batch replicas have been removed. - flat_labels = labels_output.flatten() - flat_logits = logits.reshape(-1, logits.size(-1)) - batch_non_pad_tokens = (flat_labels != pad_id).sum().item() - batch_non_pad_bytes = batch_non_pad_tokens * bytes_per_token - batch_content_bytes = ( - ((flat_labels != pad_id) & (flat_labels != eos_id)).sum().item() * bytes_per_token - ) - if batch_non_pad_bytes > 0: - if model_encoding == "UTF-8": - batch_loss = torch.nn.functional.cross_entropy( - flat_logits, flat_labels, ignore_index=pad_id - ) - else: - batch_loss = model.bytes_decoder.compute_loss(flat_logits, flat_labels) - - if torch.isfinite(batch_loss): - self._eval_total_nats += batch_loss.item() * batch_non_pad_bytes - self._eval_total_content_bytes += batch_content_bytes + self._accumulate_accuracy_and_bpb(model, inputs, logits) # Generate predictions if predict_with_generate is enabled - # Only do generation when: (1) predict_with_generate is True, (2) we have prefixes, - # (3) and either we have metrics or prediction_loss_only is False predictions_text = [] should_generate = ( self.args.predict_with_generate @@ -508,34 +541,7 @@ def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None) ) if should_generate: - with torch.no_grad(): - # Process prefixes for generation - generation_inputs = self.processor(prefixes, collated=True) - generation_inputs = { - k: v.to(model.device) if isinstance(v, torch.Tensor) else v - for k, v in generation_inputs.items() - } - - generation_kwargs = { - "processor": self.processor, - "max_generated_words": self.max_generated_words, - } - if self.bytes_generation_config is not None: - generation_kwargs["bytes_generation_config"] = self.bytes_generation_config - - predictions_text = model.generate(**generation_inputs, **generation_kwargs) - - # Trim padded last-batch replicas for generation as well - real_count = self._real_batch_count(len(predictions_text)) - if real_count < len(predictions_text): - predictions_text = predictions_text[:real_count] - if completions is not None: - completions = completions[:real_count] - - # Store predictions and labels for generation metrics computation - self._eval_predictions.extend(predictions_text) - if completions is not None: - self._eval_labels.extend(completions) + predictions_text, completions = self._generate_predictions(model, prefixes, completions) # Log samples once per evaluation (main process only to avoid duplicates) if (not self._logged_samples_this_eval and predictions_text From 405edebebedb46fea49957abd3c968365b4ad87e Mon Sep 17 00:00:00 2001 From: Ilker Kesen Date: Fri, 6 Mar 2026 16:05:52 +0100 Subject: [PATCH 08/12] fix eval sample counting when prefix field is absent --- tests/test_metrics.py | 1 + welt_training/trainer.py | 8 +++++++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/tests/test_metrics.py b/tests/test_metrics.py index 3fd59f1..5eb37f5 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -118,6 +118,7 @@ def test_subword_tokenizer_compression_ratio(self): assert bpb_bpe < bpb_byte assert bpb_bpe == pytest.approx(bpb_byte * 400 / 1400) + @pytest.mark.integration def test_decode_roundtrip_byte_counting(self): """ Validate the run_clm.py approach: decode token IDs → encode as UTF-8 → count bytes. diff --git a/welt_training/trainer.py b/welt_training/trainer.py index bccad44..137e0d9 100644 --- a/welt_training/trainer.py +++ b/welt_training/trainer.py @@ -507,7 +507,13 @@ def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None) completions = inputs.get("completion", None) # Count samples in this batch (adjusted for last-batch padding) - batch_sample_count = len(prefixes) if prefixes is not None else 0 + if prefixes is not None: + batch_sample_count = len(prefixes) + else: + batch_sample_count = next( + (v.shape[0] for v in inputs.values() if isinstance(v, torch.Tensor) and v.dim() > 0), + 0, + ) self._eval_sample_count += self._real_batch_count(batch_sample_count) # Create model inputs without custom fields From 518d5bcce2c18dd5d05487481d4ccc717086c657 Mon Sep 17 00:00:00 2001 From: Ilker Kesen Date: Wed, 11 Mar 2026 13:32:43 +0100 Subject: [PATCH 09/12] remove huggingface/evaluate dependency for accuracy: it complicates things on no-internet-access clusters. --- .../machine-translation/run_clm.py | 66 +++++++++++-------- 1 file changed, 37 insertions(+), 29 deletions(-) diff --git a/welt_training/experiments/machine-translation/run_clm.py b/welt_training/experiments/machine-translation/run_clm.py index 9285de8..b86eff1 100644 --- a/welt_training/experiments/machine-translation/run_clm.py +++ b/welt_training/experiments/machine-translation/run_clm.py @@ -46,7 +46,6 @@ from typing import Optional import datasets -import evaluate import torch from datasets import DatasetDict, IterableDataset, IterableDatasetDict, load_dataset @@ -683,22 +682,50 @@ def group_texts(examples): max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples) eval_dataset = eval_dataset.select(range(max_eval_samples)) + # Pre-compute token/byte counts for bits-per-byte on every evaluation + if not data_args.streaming: + num_eval_tokens = len(eval_dataset) * (block_size - 1) + num_eval_bytes = sum( + len(tokenizer.decode(example["input_ids"][1:]).encode("utf-8")) + for example in eval_dataset + ) + else: + num_eval_tokens = 0 + num_eval_bytes = 0 + def preprocess_logits_for_metrics(logits, labels): if isinstance(logits, tuple): - # Depending on the model and config, logits may contain extra tensors, - # like past_key_values, but logits always come first logits = logits[0] - return logits.argmax(dim=-1) - - metric = evaluate.load("accuracy", cache_dir=model_args.cache_dir) + preds = logits.argmax(dim=-1) + # Compute per-token cross-entropy for BPB calculation + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + per_token_loss = torch.nn.functional.cross_entropy( + shift_logits.view(-1, shift_logits.size(-1)), + shift_labels.view(-1), + reduction="none", + ).view(shift_labels.shape) + # Pad to same seq_len as preds so Trainer can concatenate them + per_token_loss = torch.nn.functional.pad(per_token_loss, (0, 1), value=0.0) + return preds, per_token_loss def compute_metrics(eval_preds): - preds, labels = eval_preds - # preds have the same shape as the labels, after the argmax(-1) has been calculated - # by preprocess_logits_for_metrics but we need to shift the labels + (preds, per_token_losses), labels = eval_preds + # Accuracy (shift preds/labels for next-token prediction) labels = labels[:, 1:].reshape(-1) preds = preds[:, :-1].reshape(-1) - return metric.compute(predictions=preds, references=labels) + metrics = {"accuracy": float((preds == labels).mean())} + # Perplexity and BPB from per-token losses (strip padding column) + avg_loss = float(per_token_losses[:, :-1].mean()) + try: + metrics["perplexity"] = math.exp(avg_loss) + except OverflowError: + metrics["perplexity"] = float("inf") + if num_eval_bytes > 0: + metrics["bits_per_byte"] = compute_bits_per_byte( + avg_loss, num_eval_tokens, num_eval_bytes + ) + return metrics # Initialize our Trainer trainer = Trainer( @@ -751,25 +778,6 @@ def compute_metrics(eval_preds): else: metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) - try: - perplexity = math.exp(metrics["eval_loss"]) - except OverflowError: - perplexity = float("inf") - metrics["perplexity"] = perplexity - - # Compute bits per byte from the evaluated subset. - # Decode tokens back to text to count the corresponding UTF-8 bytes. - if not data_args.streaming: - num_eval_tokens = len(eval_dataset) * (block_size - 1) - num_eval_bytes = sum( - len(tokenizer.decode(example["input_ids"][1:]).encode("utf-8")) - for example in eval_dataset - ) - if num_eval_bytes > 0: - metrics["eval_bits_per_byte"] = compute_bits_per_byte( - metrics["eval_loss"], num_eval_tokens, num_eval_bytes - ) - trainer.log_metrics("eval", metrics) trainer.save_metrics("eval", metrics) From 6d53b79cdc9bd8f4896fc7f18f7730b73dc27e84 Mon Sep 17 00:00:00 2001 From: Ilker Kesen Date: Sat, 21 Mar 2026 02:41:33 +0100 Subject: [PATCH 10/12] Align BPB computation for both implementations --- tests/test_train.py | 11 +- tests/test_trainer.py | 39 +++-- welt/model.py | 92 +++++++++- welt_training/args_data.py | 11 ++ .../machine-translation/run_clm.py | 160 +++++++++++++----- welt_training/train.py | 34 +++- welt_training/trainer.py | 21 ++- 7 files changed, 282 insertions(+), 86 deletions(-) diff --git a/tests/test_train.py b/tests/test_train.py index b7beca4..ca7c066 100644 --- a/tests/test_train.py +++ b/tests/test_train.py @@ -115,14 +115,11 @@ def test_basic_training_with_eval_chrf(temp_output_dir): assert "eval_samples" in eval_metrics, "eval_samples should be present" assert "perplexity" in eval_metrics, "perplexity should be present" - # Verify bits per byte is present and accounts for EOS tokens (BPB > loss/ln(2)) + # EOS is excluded from both numerator and denominator, so BPB reflects + # only content-byte prediction cost. assert "eval_bits_per_byte" in eval_metrics, "eval_bits_per_byte should be present" - import math - naive_bpb = eval_metrics["eval_loss"] / math.log(2) - assert eval_metrics["eval_bits_per_byte"] > naive_bpb, ( - f"eval_bits_per_byte should exceed loss/ln(2) due to EOS overhead: " - f"{eval_metrics['eval_bits_per_byte']} vs {naive_bpb}" - ) + assert eval_metrics["eval_bits_per_byte"] > 0, \ + f"eval_bits_per_byte should be positive, got {eval_metrics['eval_bits_per_byte']}" print("\n✓ Training completed successfully!") print(f"✓ eval_chrf = {chrf_score:.2f}") diff --git a/tests/test_trainer.py b/tests/test_trainer.py index 990eb48..cff3dee 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -61,6 +61,7 @@ def test_trainer_initialization(trainer_setup): do_eval=True, remove_unused_columns=False, predict_with_generate=True, + report_to="none", ) # Initialize trainer with no metrics @@ -90,6 +91,7 @@ def test_trainer_with_generation_config(trainer_setup): do_eval=True, remove_unused_columns=False, predict_with_generate=True, + report_to="none", ) # Initialize with generation config @@ -118,6 +120,7 @@ def test_trainer_with_metrics(trainer_setup): do_eval=True, remove_unused_columns=False, predict_with_generate=True, + report_to="none", ) # Initialize trainer with metrics (only bleu, as cer might not be available) @@ -149,6 +152,7 @@ def test_evaluation_with_generate(trainer_setup): do_eval=True, remove_unused_columns=False, predict_with_generate=True, + report_to="none", ) trainer = WeLTTrainer( @@ -610,6 +614,7 @@ def test_perplexity_computation(trainer_setup): do_eval=True, remove_unused_columns=False, predict_with_generate=True, + report_to="none", ) trainer = WeLTTrainer( @@ -1042,6 +1047,7 @@ def test_bits_per_byte_computation(trainer_setup): do_eval=True, remove_unused_columns=False, predict_with_generate=True, + report_to="none", ) trainer = WeLTTrainer( @@ -1061,14 +1067,12 @@ def test_bits_per_byte_computation(trainer_setup): assert metrics["eval_bits_per_byte"] > 0, \ f"eval_bits_per_byte should be positive, got {metrics['eval_bits_per_byte']}" - # For a byte-level model with EOS tokens, BPB > loss / ln(2). - # labels_output contains content bytes + one EOS per word; EOS is counted - # in num_tokens (loss denominator) but not in num_bytes, so BPB is inflated. + # EOS is excluded from both numerator and denominator, so BPB is derived + # from the trainer's content-only accumulators rather than eval_loss. import math - naive_bpb = metrics["eval_loss"] / math.log(2) - assert metrics["eval_bits_per_byte"] > naive_bpb, \ - f"eval_bits_per_byte should exceed loss/ln(2) due to EOS overhead: " \ - f"{metrics['eval_bits_per_byte']} vs {naive_bpb}" + expected_bpb = trainer._eval_total_nats / (trainer._eval_total_content_bytes * math.log(2)) + assert metrics["eval_bits_per_byte"] == pytest.approx(expected_bpb) + assert trainer._eval_total_content_bytes < trainer._eval_total_bytes def test_bits_per_byte_computation_utf32(trainer_setup_utf32): @@ -1087,6 +1091,7 @@ def test_bits_per_byte_computation_utf32(trainer_setup_utf32): do_eval=True, remove_unused_columns=False, predict_with_generate=False, + report_to="none", ) trainer = WeLTTrainer( @@ -1106,13 +1111,12 @@ def test_bits_per_byte_computation_utf32(trainer_setup_utf32): assert metrics["eval_bits_per_byte"] > 0, \ f"eval_bits_per_byte should be positive, got {metrics['eval_bits_per_byte']}" - # Loss is averaged over all non-PAD bytes (content + EOS bytes), while BPB - # divides by content bytes only, so BPB must exceed loss/ln(2). + # EOS is excluded from both numerator and denominator, so BPB is derived + # from the trainer's content-only accumulators rather than eval_loss. import math - naive_bpb = metrics["eval_loss"] / math.log(2) - assert metrics["eval_bits_per_byte"] > naive_bpb, \ - f"eval_bits_per_byte should exceed loss/ln(2) due to EOS overhead: " \ - f"{metrics['eval_bits_per_byte']} vs {naive_bpb}" + expected_bpb = trainer._eval_total_nats / (trainer._eval_total_content_bytes * math.log(2)) + assert metrics["eval_bits_per_byte"] == pytest.approx(expected_bpb) + assert trainer._eval_total_content_bytes < trainer._eval_total_bytes * 4 def test_accuracy_is_computed(trainer_setup): @@ -1242,12 +1246,11 @@ def test_metrics_consistent_across_batch_sizes(trainer_setup): assert abs(results[1]["eval_word_accuracy"] - results[bs]["eval_word_accuracy"]) < 1e-6, \ f"word_accuracy differs between batch_size=1 and batch_size={bs}" - # Verify BPB > loss/ln(2) (due to EOS overhead) for at least one batch size - import math + # EOS is excluded from both numerator and denominator, so BPB reflects + # only content-byte prediction cost. ref = results[1] - naive_bpb = ref["eval_loss"] / math.log(2) - assert ref["eval_bits_per_byte"] > naive_bpb, \ - f"BPB ({ref['eval_bits_per_byte']}) should exceed loss/ln(2) ({naive_bpb}) due to EOS overhead" + assert ref["eval_bits_per_byte"] > 0, \ + f"BPB should be positive, got {ref['eval_bits_per_byte']}" # Verify accuracy is in valid range assert 0.0 <= ref["eval_byte_accuracy"] <= 1.0 diff --git a/welt/model.py b/welt/model.py index cad710d..65ad9e5 100644 --- a/welt/model.py +++ b/welt/model.py @@ -581,7 +581,7 @@ def _prefill(self, encoded_input: torch.Tensor, attention_mask: torch.Tensor, batch_indices = torch.arange(logits.size(0), device=logits.device) mapped_latent = logits[batch_indices, num_words - 1].unsqueeze(1) # (B, 1, bytes_decoder_dim) - return latent_output.past_key_values, mapped_latent + return latent_output.past_key_values, mapped_latent, logits def _decode(self, past_key_values: Any, new_embedding: torch.Tensor, attention_mask: torch.Tensor, position_ids: torch.Tensor | None = None) -> tuple[Any, torch.Tensor]: @@ -682,6 +682,8 @@ def generate( processor: TextImageProcessor, max_generated_words: int = 50, bytes_generation_config: GenerationConfig | None = None, + return_entropy: bool = False, + prompt_words: list[str] | None = None, **_unused_kwargs): """ Generate text using prefill/decode with KV-cache. @@ -698,6 +700,8 @@ def generate( processor: TextImageProcessor for tokenization and rendering max_generated_words: maximum words to generate bytes_generation_config: optional GenerationConfig for bytes_decoder + return_entropy: if True, also return per-byte entropy (batch_size must be 1) + prompt_words: original prompt words (from processor.pretokenize), needed for prompt entropy """ tokenizer = processor.tokenizer device = input_ids.device @@ -725,7 +729,7 @@ def generate( # Use default position_ids for prefill (sequential), attention mask handles padding max_initial = initial_num_words.max().item() - past_key_values, latents = self._prefill(encoded_input, attention_mask, initial_num_words) + past_key_values, latents, prefill_logits = self._prefill(encoded_input, attention_mask, initial_num_words) # Pre-allocate decode attention mask (1s everywhere except padding positions) decode_mask_full = torch.ones((batch_size, max_initial + max_generated_words), device=device, @@ -736,6 +740,9 @@ def generate( # Generation loop all_generated_words = [[] for _ in range(batch_size)] + if return_entropy and batch_size != 1: + raise ValueError(f"return_entropy=True requires batch_size=1, got {batch_size}") + word_latents = [] if return_entropy else None words = None for step_idx in range(max_generated_words): @@ -752,6 +759,9 @@ def generate( past_key_values, new_embedding, decode_mask, decode_position_ids ) + if return_entropy: + word_latents.append(latents.detach()) + # Generate bytes from latents generated_bytes = self._generate_word_bytes( latents, tokenizer, bos_embed, bytes_generation_config, stopping_criteria, @@ -767,7 +777,83 @@ def generate( if not collected or collected[-1]: collected.append(word) - return ["".join(words) for words in all_generated_words] + texts = ["".join(words) for words in all_generated_words] + + if return_entropy: + # Prompt entropy: position i predicts word i+1 + prompt_entropies, prompt_byte_labels = [], [] + if prompt_words is not None and len(prompt_words) > 1: + num_prompt = min(initial_num_words[0].item(), len(prompt_words)) + prompt_latents = [prefill_logits[:, i:i+1, :] for i in range(num_prompt - 1)] + prompt_target_words = prompt_words[1:num_prompt] + prompt_entropies, prompt_byte_labels = self._compute_generation_entropy( + prompt_latents, prompt_target_words, tokenizer, device) + + gen_entropies, gen_byte_labels = self._compute_generation_entropy( + word_latents, all_generated_words[0], tokenizer, device) + + return (texts, + prompt_entropies + gen_entropies, + prompt_byte_labels + gen_byte_labels, + len(prompt_entropies)) + + return texts + + def _compute_generation_entropy( + self, + word_latents: list[torch.Tensor], + generated_words: list[str], + tokenizer, + device: torch.device, + ) -> tuple[list[float], list[str]]: + """Compute per-byte entropy for generated words using teacher-forced decoding.""" + # Filter to non-empty words and their corresponding latents + valid = [(lat, w) for lat, w in zip(word_latents, generated_words, strict=False) if w] + if not valid: + return [], [] + + latents_list, words_list = zip(*valid, strict=True) + encoding = tokenizer.encoding if hasattr(tokenizer, 'encoding') else 'utf-8' + bos_id = tokenizer.bos_token_id + pad_id = tokenizer.pad_token_id + + # Encode each word to bytes + word_byte_ids = [list(w.encode(encoding)) for w in words_list] + num_words = len(word_byte_ids) + max_len = max(len(ids) for ids in word_byte_ids) + 1 # +1 for BOS + + # Build labels_input: [BOS, b0, b1, ...] and mask + labels_input = torch.full((1, num_words, max_len), pad_id, device=device, dtype=torch.long) + labels_mask = torch.zeros((1, num_words, max_len), device=device, dtype=torch.long) + + for i, ids in enumerate(word_byte_ids): + seq = [bos_id] + ids + labels_input[0, i, :len(seq)] = torch.tensor(seq, device=device, dtype=torch.long) + labels_mask[0, i, :len(seq)] = 1 + + # Stack latents: (1, num_words, hidden_dim) + latents_stacked = torch.cat(list(latents_list), dim=1) + + # Teacher-forced forward pass to get logits + logits = self.parallel_causal_decode(latents_stacked, labels_input, labels_mask) + # logits: (1, num_words, max_len, vocab_size) + + probs = torch.softmax(logits.float(), dim=-1) + log2_probs = torch.log2(probs + 1e-10) + entropy = -(probs * log2_probs).sum(dim=-1) # (1, num_words, max_len) + + # Flatten per-byte entropies and create display labels + byte_entropies = [] + byte_labels = [] + for i, ids in enumerate(word_byte_ids): + for j, byte_val in enumerate(ids): + byte_entropies.append(entropy[0, i, j].item()) + if 32 <= byte_val < 127: + byte_labels.append(chr(byte_val)) + else: + byte_labels.append(f"\\x{byte_val:02x}") + + return byte_entropies, byte_labels AutoConfig.register(WordLatentTransformerConfig.model_type, WordLatentTransformerConfig) diff --git a/welt_training/args_data.py b/welt_training/args_data.py index edad1c0..7a5939c 100644 --- a/welt_training/args_data.py +++ b/welt_training/args_data.py @@ -102,6 +102,17 @@ class DataTrainingArguments: "help": "Path to prepared dataset shards (from welt-prepare-data). Skips download and text extraction." }, ) + pack_eval_dataset: bool = field( + default=False, + metadata={ + "help": ( + "Pack the evaluation dataset the same way as the training set. " + "Use this for apples-to-apples BPB comparison against CLM baselines (e.g. Pythia). " + "Incompatible with generation-based eval metrics (eval_metrics). " + "When enabled, eval_samples reflects packed chunks, not raw validation examples." + ) + }, + ) def __post_init__(self): if self.streaming: diff --git a/welt_training/experiments/machine-translation/run_clm.py b/welt_training/experiments/machine-translation/run_clm.py index b86eff1..f005181 100644 --- a/welt_training/experiments/machine-translation/run_clm.py +++ b/welt_training/experiments/machine-translation/run_clm.py @@ -253,6 +253,16 @@ class DataTrainingArguments: "help": "Path to prepared dataset shards (from welt-prepare-data). Skips download and text extraction." }, ) + eval_preserve_document_boundaries: bool = field( + default=False, + metadata={ + "help": ( + "Chunk each validation document independently instead of concatenating all text " + "into one stream. Preserves document boundaries so the model never sees cross-document " + "context during eval, matching WeLT's packed-eval protocol for fair BPB comparison." + ) + }, + ) def __post_init__(self): if self.streaming: @@ -570,6 +580,22 @@ def mapping_function(x): ) column_names = [text_column_name] + # Limit samples before tokenization/packing so max_eval_samples and + # max_train_samples select raw documents, consistent with WELT's train.py. + if training_args.do_train and data_args.max_train_samples is not None: + if data_args.streaming: + raw_datasets["train"] = raw_datasets["train"].take(data_args.max_train_samples) + else: + max_train_samples = min(len(raw_datasets["train"]), data_args.max_train_samples) + raw_datasets["train"] = raw_datasets["train"].select(range(max_train_samples)) + + if training_args.do_eval and data_args.max_eval_samples is not None: + if data_args.streaming: + raw_datasets["validation"] = raw_datasets["validation"].take(data_args.max_eval_samples) + else: + max_eval_samples = min(len(raw_datasets["validation"]), data_args.max_eval_samples) + raw_datasets["validation"] = raw_datasets["validation"].select(range(max_eval_samples)) + def tokenize_function(examples): with CaptureLogger(tok_logger) as cl: output = tokenizer(examples[text_column_name]) @@ -638,6 +664,33 @@ def group_texts(examples): result["labels"] = result["input_ids"].copy() return result + # Chunk each document independently, preserving document boundaries. + # The model never sees cross-document context, matching WeLT's packed-eval protocol. + # Remainder chunks are padded to block_size; labels use -100 for padding so the loss ignores them. + pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id + + def chunk_documents(examples): + result = {k: [] for k in examples} + result["labels"] = [] + for i in range(len(examples["input_ids"])): + doc_len = len(examples["input_ids"][i]) + if doc_len == 0: + continue + for start in range(0, doc_len, block_size): + real_len = min(block_size, doc_len - start) + pad_len = block_size - real_len + result["input_ids"].append( + examples["input_ids"][i][start:start + block_size] + [pad_id] * pad_len + ) + result["attention_mask"].append( + examples["attention_mask"][i][start:start + block_size] + [0] * pad_len + ) + # Labels: real tokens keep their IDs, padding positions = -100 (ignored by CE loss) + result["labels"].append( + examples["input_ids"][i][start:start + block_size] + [-100] * pad_len + ) + return result + # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower # to preprocess. @@ -645,50 +698,64 @@ def group_texts(examples): # To speed up this part, we use multiprocessing. See the documentation of the map method for more information: # https://huggingface.co/docs/datasets/process#map + use_chunk_documents = ( + data_args.eval_preserve_document_boundaries + and training_args.do_eval + ) + + map_kwargs = {} + if not data_args.streaming: + map_kwargs = { + "num_proc": data_args.preprocessing_num_workers, + "load_from_cache_file": not data_args.overwrite_cache, + } + with training_args.main_process_first(desc="grouping texts together"): - if not data_args.streaming: - lm_datasets = tokenized_datasets.map( - group_texts, - batched=True, - num_proc=data_args.preprocessing_num_workers, - load_from_cache_file=not data_args.overwrite_cache, - desc=f"Grouping texts in chunks of {block_size}", - ) - else: - lm_datasets = tokenized_datasets.map( - group_texts, - batched=True, + if training_args.do_train: + if "train" not in tokenized_datasets: + raise ValueError("--do_train requires a train dataset") + train_dataset = tokenized_datasets["train"].map( + group_texts, batched=True, + desc=f"Grouping train texts in chunks of {block_size}", + **map_kwargs, ) - if training_args.do_train: - if "train" not in tokenized_datasets: - raise ValueError("--do_train requires a train dataset") - train_dataset = lm_datasets["train"] - if data_args.max_train_samples is not None: - if data_args.streaming: - train_dataset = train_dataset.take(data_args.max_train_samples) - else: - max_train_samples = min(len(train_dataset), data_args.max_train_samples) - train_dataset = train_dataset.select(range(max_train_samples)) + if training_args.do_eval: + if "validation" not in tokenized_datasets: + raise ValueError("--do_eval requires a validation dataset") + eval_fn = chunk_documents if use_chunk_documents else group_texts + eval_desc = ( + f"Chunking eval documents in blocks of {block_size} (preserving boundaries)" + if use_chunk_documents + else f"Grouping eval texts in chunks of {block_size}" + ) + eval_dataset = tokenized_datasets["validation"].map( + eval_fn, batched=True, desc=eval_desc, **map_kwargs, + ) if training_args.do_eval: - if "validation" not in tokenized_datasets: - raise ValueError("--do_eval requires a validation dataset") - eval_dataset = lm_datasets["validation"] - if data_args.max_eval_samples is not None: - if data_args.streaming: - eval_dataset = eval_dataset.take(data_args.max_eval_samples) - else: - max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples) - eval_dataset = eval_dataset.select(range(max_eval_samples)) - - # Pre-compute token/byte counts for bits-per-byte on every evaluation + # Pre-compute token/byte counts for bits-per-byte on every evaluation. + # When chunk_documents is used, chunks may be shorter than block_size + # (labels use -100 for padding), so we count only real predicted positions. if not data_args.streaming: - num_eval_tokens = len(eval_dataset) * (block_size - 1) - num_eval_bytes = sum( - len(tokenizer.decode(example["input_ids"][1:]).encode("utf-8")) - for example in eval_dataset - ) + if use_chunk_documents: + num_eval_tokens = 0 + num_eval_bytes = 0 + for example in eval_dataset: + # Real length = number of non-padding labels + real_len = sum(1 for l in example["labels"] if l != -100) + # Predicted positions = real_len - 1 (first token is context only) + num_eval_tokens += max(0, real_len - 1) + if real_len > 1: + num_eval_bytes += len( + tokenizer.decode(example["input_ids"][1:real_len]).encode("utf-8") + ) + else: + num_eval_tokens = len(eval_dataset) * (block_size - 1) + num_eval_bytes = sum( + len(tokenizer.decode(example["input_ids"][1:]).encode("utf-8")) + for example in eval_dataset + ) else: num_eval_tokens = 0 num_eval_bytes = 0 @@ -704,6 +771,7 @@ def preprocess_logits_for_metrics(logits, labels): shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1), reduction="none", + ignore_index=-100, ).view(shift_labels.shape) # Pad to same seq_len as preds so Trainer can concatenate them per_token_loss = torch.nn.functional.pad(per_token_loss, (0, 1), value=0.0) @@ -711,12 +779,16 @@ def preprocess_logits_for_metrics(logits, labels): def compute_metrics(eval_preds): (preds, per_token_losses), labels = eval_preds - # Accuracy (shift preds/labels for next-token prediction) - labels = labels[:, 1:].reshape(-1) - preds = preds[:, :-1].reshape(-1) - metrics = {"accuracy": float((preds == labels).mean())} - # Perplexity and BPB from per-token losses (strip padding column) - avg_loss = float(per_token_losses[:, :-1].mean()) + # Accuracy (shift preds/labels for next-token prediction, ignoring padding) + shifted_labels = labels[:, 1:].reshape(-1) + shifted_preds = preds[:, :-1].reshape(-1) + valid = shifted_labels != -100 + metrics = {"accuracy": float((shifted_preds[valid] == shifted_labels[valid]).mean())} + # Perplexity and BPB from per-token losses (strip padding column). + # Padding positions have loss=0 (from ignore_index=-100), so sum / num_eval_tokens + # gives the correct weighted average for both padded and non-padded chunks. + total_loss = float(per_token_losses[:, :-1].sum()) + avg_loss = total_loss / num_eval_tokens if num_eval_tokens > 0 else 0.0 try: metrics["perplexity"] = math.exp(avg_loss) except OverflowError: diff --git a/welt_training/train.py b/welt_training/train.py index 83d374b..d3c5a06 100644 --- a/welt_training/train.py +++ b/welt_training/train.py @@ -402,10 +402,19 @@ def train(args: list[str] | None | str = None): # noqa: C901 streaming=data_args.streaming) # Sequence packing - if train_dataset: - block_size = min(data_args.block_size or math.inf, processor.max_seq_length) - train_dataset = processor.pretokenize_dataset(train_dataset, num_proc=data_args.preprocessing_num_workers) - train_dataset = pack_dataset(train_dataset, seq_length=block_size) + block_size = min(data_args.block_size or math.inf, processor.max_seq_length) + + def pretokenize_and_pack(dataset): + # Strip columns that can't survive packing (scalar strings from generation templates). + # Packing concatenates documents, destroying per-document prefix/completion boundaries. + col_names = dataset.column_names if hasattr(dataset, "column_names") else None + if col_names: + drop = [c for c in col_names if c not in {"text", "words"}] + if drop: + dataset = dataset.remove_columns(drop) + + dataset = processor.pretokenize_dataset(dataset, num_proc=data_args.preprocessing_num_workers) + dataset = pack_dataset(dataset, seq_length=block_size) # Pad to fixed length for CUDA kernel caching (consistent tensor shapes) def pad_to_fixed_length(example): @@ -418,7 +427,22 @@ def pad_to_fixed_length(example): example["seq_lengths"] = seq_lengths + [1] * pad_count # Each padding is a separate "sequence" return example - train_dataset = train_dataset.map(pad_to_fixed_length, batched=False) + return dataset.map(pad_to_fixed_length, batched=False) + + if train_dataset: + train_dataset = pretokenize_and_pack(train_dataset) + if eval_dataset and data_args.pack_eval_dataset: + # Validate: packed eval is incompatible with generation metrics + eval_cols = getattr(eval_dataset, "column_names", None) or [] + has_generation_cols = "prefix" in eval_cols or "completion" in eval_cols + has_generation_metrics = bool(training_args.eval_metrics) if hasattr(training_args, "eval_metrics") else False + if has_generation_cols or has_generation_metrics: + raise ValueError( + "pack_eval_dataset=True is incompatible with generation-based evaluation. " + "Packing concatenates documents, destroying per-document prefix/completion boundaries. " + "Either disable pack_eval_dataset or remove eval_metrics and the two-part dataset_text_template." + ) + eval_dataset = pretokenize_and_pack(eval_dataset) # Wrap streaming datasets with CustomIterableDataset to support with_transform if train_dataset: diff --git a/welt_training/trainer.py b/welt_training/trainer.py index 137e0d9..e47b663 100644 --- a/welt_training/trainer.py +++ b/welt_training/trainer.py @@ -438,23 +438,26 @@ def _accumulate_accuracy_and_bpb(self, model, inputs, logits): # Recompute loss from (possibly trimmed) logits/labels so the # numerator stays consistent with the trimmed token counts when # padded last-batch replicas have been removed. + # Exclude EOS from both numerator and denominator so BPB measures + # only content-byte prediction cost, comparable to CLM baselines. flat_labels = labels_output.flatten() flat_logits = logits.reshape(-1, logits.size(-1)) - batch_non_pad_tokens = (flat_labels != pad_id).sum().item() - batch_non_pad_bytes = batch_non_pad_tokens * bytes_per_token - batch_content_bytes = ( - ((flat_labels != pad_id) & (flat_labels != eos_id)).sum().item() * bytes_per_token - ) - if batch_non_pad_bytes > 0: + content_mask = (flat_labels != pad_id) & (flat_labels != eos_id) + batch_content_bytes = content_mask.sum().item() * bytes_per_token + if batch_content_bytes > 0: + # Mask EOS positions so they are ignored by the loss + flat_labels_content = flat_labels.clone() + flat_labels_content[flat_labels == eos_id] = pad_id + if model_encoding == "UTF-8": batch_loss = torch.nn.functional.cross_entropy( - flat_logits, flat_labels, ignore_index=pad_id + flat_logits, flat_labels_content, ignore_index=pad_id ) else: - batch_loss = model.bytes_decoder.compute_loss(flat_logits, flat_labels) + batch_loss = model.bytes_decoder.compute_loss(flat_logits, flat_labels_content) if torch.isfinite(batch_loss): - self._eval_total_nats += batch_loss.item() * batch_non_pad_bytes + self._eval_total_nats += batch_loss.item() * batch_content_bytes self._eval_total_content_bytes += batch_content_bytes def _generate_predictions(self, model, prefixes, completions): From 2493449b30534023c1efc4d58d0632d9c288db3e Mon Sep 17 00:00:00 2001 From: Ilker Kesen Date: Sat, 21 Mar 2026 02:42:29 +0100 Subject: [PATCH 11/12] make image-based models working --- welt/model_utils.py | 16 +++++++--------- welt/vision/navit.py | 9 +++++++-- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/welt/model_utils.py b/welt/model_utils.py index 09d6df4..02b94a8 100644 --- a/welt/model_utils.py +++ b/welt/model_utils.py @@ -10,6 +10,7 @@ AutoImageProcessor, AutoTokenizer, PretrainedConfig, + ViTImageProcessorFast, set_seed, ) from utf8_tokenizer.tokenizer import UTF8Tokenizer, UTF16Tokenizer, UTF32Tokenizer @@ -41,6 +42,7 @@ def get_attn_implementation(): CUSTOM_MODELS: dict[str, PretrainedConfig] = { "NaViT-tiny": NaViTConfig( + image_size=(16, 1536), # height=16px (rendered text), width up to 1536px patch_size=16, hidden_size=128, dim=64, @@ -52,6 +54,7 @@ def get_attn_implementation(): token_dropout_prob=0.1, ), "NaViT-small": NaViTConfig( + image_size=(16, 1536), # height=16px (rendered text), width up to 1536px patch_size=16, hidden_size=512, dim=256, @@ -63,10 +66,6 @@ def get_attn_implementation(): token_dropout_prob=0.1, ) } -CUSTOM_PROCESSORS_ALIAS: dict[str, str] = { - "NaViT-tiny": "WinKawaks/vit-tiny-patch16-224", - "NaViT-small": "WinKawaks/vit-tiny-patch16-224", -} def get_model_config(model_name, config_path: str | None = None): @@ -124,12 +123,11 @@ def setup_model( ): set_seed(seed) - # Load image processor - need a pretrained name even when using config - if image_encoder_name is not None: - image_processor_name = CUSTOM_PROCESSORS_ALIAS.get(image_encoder_name, image_encoder_name) - image_processor = AutoImageProcessor.from_pretrained(image_processor_name, use_fast=True) + if image_encoder_name in CUSTOM_MODELS: + image_processor = ViTImageProcessorFast() + elif image_encoder_name is not None: + image_processor = AutoImageProcessor.from_pretrained(image_encoder_name, use_fast=True) elif image_encoder_config is not None: - # When using config file, default to a standard ViT processor image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224", use_fast=True) else: image_processor = NoopImageProcessor() diff --git a/welt/vision/navit.py b/welt/vision/navit.py index 3dd8719..b1e935f 100644 --- a/welt/vision/navit.py +++ b/welt/vision/navit.py @@ -11,7 +11,7 @@ class NaViTConfig(PretrainedConfig): def __init__( self, - image_size: int = 256, # only used to set a default; NaViT handles var-size + image_size: int | tuple[int, int] | list = 256, # max image size for positional embeddings patch_size: int = 16, hidden_size: int = 512, dim: int = 512, @@ -42,8 +42,13 @@ class NaViTModel(PreTrainedModel): def __init__(self, config: NaViTConfig): super().__init__(config) + # Convert list to tuple for vit_pytorch's pair() check (JSON round-trips tuples as lists) + image_size = config.image_size + if isinstance(image_size, list): + image_size = tuple(image_size) + navit = NaViT( - image_size=config.image_size, + image_size=image_size, patch_size=config.patch_size, num_classes=config.hidden_size, dim=config.dim, From 22837b7b2be6222a79df20615a0f3fc01eb8ca16 Mon Sep 17 00:00:00 2001 From: Ilker Kesen Date: Sat, 21 Mar 2026 02:42:46 +0100 Subject: [PATCH 12/12] add demos --- welt_training/demo.py | 203 ++++++++++++++++++++++++++++++++++++++ welt_training/demo_clm.py | 145 +++++++++++++++++++++++++++ 2 files changed, 348 insertions(+) create mode 100644 welt_training/demo.py create mode 100644 welt_training/demo_clm.py diff --git a/welt_training/demo.py b/welt_training/demo.py new file mode 100644 index 0000000..dfbb503 --- /dev/null +++ b/welt_training/demo.py @@ -0,0 +1,203 @@ +"""Gradio demo for qualitative testing of WeLT models.""" + +import argparse +import os +from pathlib import Path + +import gradio as gr +import matplotlib.pyplot as plt +import numpy as np +import torch +from transformers import GenerationConfig +from transformers.trainer_utils import get_last_checkpoint + +from welt.attention import get_attention_mask_for_packed_sequence +from welt.model import WordLatentTransformerForCausalLM +from welt.processor import TextImageProcessor + +DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" +AUTOCAST_DTYPE = torch.bfloat16 if DEVICE == "cuda" else torch.float16 + +STRATEGY_GREEDY = "Greedy" +STRATEGY_BEAM = "Beam search" +STRATEGY_SAMPLING = "Sampling (top-k / top-p)" + + +def load_model_and_processor(model_path: str | Path): + model_path = str(model_path) + checkpoint_path = model_path + if os.path.isdir(model_path): + last_ckpt = get_last_checkpoint(model_path) + if last_ckpt is not None: + checkpoint_path = last_ckpt + print(f"Using last checkpoint: {checkpoint_path}") + + model: WordLatentTransformerForCausalLM = ( + WordLatentTransformerForCausalLM.from_pretrained( + checkpoint_path, + dtype=torch.bfloat16, + device_map=DEVICE, + ) + ) + model.enable_optimizations() + model.eval() + + processor = TextImageProcessor.from_pretrained(model_path) + + print(f"Model loaded on {DEVICE}") + return model, processor + + +def build_generation_config(strategy, num_beams, top_k, top_p, temperature, repetition_penalty): + kwargs = {} + + if repetition_penalty != 1.0: + kwargs["repetition_penalty"] = repetition_penalty + + if strategy == STRATEGY_BEAM: + kwargs["num_beams"] = int(num_beams) + elif strategy == STRATEGY_SAMPLING: + kwargs["do_sample"] = True + kwargs["temperature"] = temperature + if top_k > 0: + kwargs["top_k"] = int(top_k) + if top_p < 1.0: + kwargs["top_p"] = top_p + + if not kwargs: + return None + return GenerationConfig(**kwargs) + + +def create_entropy_plot(entropies: list[float], byte_labels: list[str], prompt_byte_count: int = 0): + """Create a matplotlib figure showing per-byte entropy.""" + if not entropies: + fig, ax = plt.subplots(figsize=(8, 3)) + ax.text(0.5, 0.5, "No bytes generated", ha="center", va="center", transform=ax.transAxes) + ax.set_xlabel("Bytes") + ax.set_ylabel("Entropy (bits)") + plt.tight_layout() + return fig + + fig, ax = plt.subplots(figsize=(max(8, len(entropies) * 0.35), 4)) + x = np.arange(len(entropies)) + max_ent = max(entropies) + 1e-8 + colors = plt.cm.RdYlGn_r(np.array(entropies) / max_ent) + + # Fade prompt bars to distinguish from generated + if prompt_byte_count > 0: + alphas = [0.4] * prompt_byte_count + [1.0] * (len(entropies) - prompt_byte_count) + for i, (xi, h, c, a) in enumerate(zip(x, entropies, colors, alphas)): + ax.bar(xi, h, color=c, alpha=a, edgecolor="none", width=0.8) + # Separator line between prompt and generated + ax.axvline(prompt_byte_count - 0.5, color="black", linewidth=1, linestyle="--", alpha=0.5) + else: + ax.bar(x, entropies, color=colors, edgecolor="none", width=0.8) + + ax.set_xticks(x) + ax.set_xticklabels(byte_labels, fontfamily="monospace", fontsize=8, rotation=0, ha="center") + ax.set_ylabel("Entropy (bits)") + ax.set_xlabel("Bytes (prompt | generated)") + ax.set_title("Per-byte entropy") + ax.set_xlim(-0.5, len(entropies) - 0.5) + + plt.tight_layout() + return fig + + +@torch.inference_mode() +@torch.autocast(device_type=DEVICE, dtype=AUTOCAST_DTYPE, enabled=DEVICE != "cpu") +def generate(prompt, max_words, strategy, num_beams, top_k, top_p, temperature, repetition_penalty, model, processor): + if not prompt.strip(): + return "", create_entropy_plot([], []) + + gen_config = build_generation_config(strategy, num_beams, top_k, top_p, temperature, repetition_penalty) + + words = processor.pretokenize(prompt) + tokenized = processor.tokenize_words(words, device=DEVICE) + input_images, input_images_dimensions = processor.render_texts(words, device=DEVICE) + attention_mask = get_attention_mask_for_packed_sequence([len(words)], words=words) + + inputs = { + "input_ids": tokenized.input_ids.unsqueeze(0), + "input_attention_mask": tokenized.attention_mask.unsqueeze(0), + "input_images": input_images.unsqueeze(0), + "input_images_dimensions": input_images_dimensions.unsqueeze(0), + "attention_mask": attention_mask.unsqueeze(0).to(device=DEVICE), + } + + texts, entropies, byte_labels, prompt_byte_count = model.generate( + **inputs, + processor=processor, + max_generated_words=int(max_words), + bytes_generation_config=gen_config, + return_entropy=True, + prompt_words=words, + ) + + fig = create_entropy_plot(entropies, byte_labels, prompt_byte_count) + return texts[0], fig + + +def main(): + parser = argparse.ArgumentParser(description="Gradio demo for WeLT model") + parser.add_argument("model_path", type=str, help="Path to a local WeLT model/checkpoint directory") + parser.add_argument("--port", type=int, default=7860) + parser.add_argument("--share", action="store_true", help="Create a public Gradio link") + args = parser.parse_args() + + model, processor = load_model_and_processor(args.model_path) + + def on_generate(prompt, max_words, strategy, num_beams, top_k, top_p, temperature, repetition_penalty): + return generate( + prompt, max_words, strategy, num_beams, top_k, top_p, temperature, repetition_penalty, model, processor) + + with gr.Blocks(title="WeLT Demo") as demo: + gr.Markdown("# WeLT – Qualitative Demo") + gr.Markdown(f"**Model:** `{args.model_path}`   **Device:** `{DEVICE}`") + + with gr.Row(): + with gr.Column(): + prompt = gr.Textbox(label="Prompt", lines=4, placeholder="Enter text here…") + max_words = gr.Slider(minimum=1, maximum=200, value=32, step=1, label="Max generated words") + strategy = gr.Radio( + [STRATEGY_GREEDY, STRATEGY_BEAM, STRATEGY_SAMPLING], + value=STRATEGY_GREEDY, label="Decoding strategy") + + with gr.Group(visible=False) as beam_params: + num_beams = gr.Slider(minimum=2, maximum=16, value=4, step=1, label="Number of beams") + + with gr.Group(visible=False) as sampling_params: + temperature = gr.Slider(minimum=0.01, maximum=2.0, value=1.0, step=0.01, label="Temperature") + top_k = gr.Slider(minimum=0, maximum=200, value=50, step=1, + label="Top-k (0 = disabled)") + top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.95, step=0.01, + label="Top-p (1.0 = disabled)") + + repetition_penalty = gr.Slider( + minimum=1.0, maximum=3.0, value=1.0, step=0.05, label="Repetition penalty (1.0 = off)") + + btn = gr.Button("Generate", variant="primary") + + with gr.Column(): + output = gr.Textbox(label="Generated output", lines=8, interactive=False) + entropy_plot = gr.Plot(label="Per-byte entropy") + + # Show/hide strategy-specific parameters + def on_strategy_change(choice): + return ( + gr.update(visible=choice == STRATEGY_BEAM), + gr.update(visible=choice == STRATEGY_SAMPLING), + ) + + strategy.change(fn=on_strategy_change, inputs=strategy, outputs=[beam_params, sampling_params]) + + all_inputs = [prompt, max_words, strategy, num_beams, top_k, top_p, temperature, repetition_penalty] + btn.click(fn=on_generate, inputs=all_inputs, outputs=[output, entropy_plot]) + prompt.submit(fn=on_generate, inputs=all_inputs, outputs=[output, entropy_plot]) + + demo.launch(server_port=args.port, share=args.share) + + +if __name__ == "__main__": + main() diff --git a/welt_training/demo_clm.py b/welt_training/demo_clm.py new file mode 100644 index 0000000..005e638 --- /dev/null +++ b/welt_training/demo_clm.py @@ -0,0 +1,145 @@ +"""Gradio demo for qualitative testing of sub-word causal language models (GPT-2, Pythia, etc.).""" + +import argparse +import os +from pathlib import Path + +import gradio as gr +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig +from transformers.trainer_utils import get_last_checkpoint + +DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" +AUTOCAST_DTYPE = torch.bfloat16 if DEVICE == "cuda" else torch.float16 + +STRATEGY_GREEDY = "Greedy" +STRATEGY_BEAM = "Beam search" +STRATEGY_SAMPLING = "Sampling (top-k / top-p)" + + +def load_model_and_tokenizer(model_path: str | Path): + model_path = str(model_path) + checkpoint_path = model_path + if os.path.isdir(model_path): + last_ckpt = get_last_checkpoint(model_path) + if last_ckpt is not None: + checkpoint_path = last_ckpt + print(f"Using last checkpoint: {checkpoint_path}") + + tokenizer = AutoTokenizer.from_pretrained(model_path) + model = AutoModelForCausalLM.from_pretrained( + checkpoint_path, + torch_dtype=torch.bfloat16, + device_map=DEVICE, + ) + model.eval() + + print(f"Model loaded on {DEVICE}") + return model, tokenizer + + +def build_generation_config(strategy, num_beams, top_k, top_p, temperature, repetition_penalty): + kwargs = {} + + if repetition_penalty != 1.0: + kwargs["repetition_penalty"] = repetition_penalty + + if strategy == STRATEGY_BEAM: + kwargs["num_beams"] = int(num_beams) + elif strategy == STRATEGY_SAMPLING: + kwargs["do_sample"] = True + kwargs["temperature"] = temperature + if top_k > 0: + kwargs["top_k"] = int(top_k) + if top_p < 1.0: + kwargs["top_p"] = top_p + + if not kwargs: + return None + return GenerationConfig(**kwargs) + + +@torch.inference_mode() +@torch.autocast(device_type=DEVICE, dtype=AUTOCAST_DTYPE, enabled=DEVICE != "cpu") +def generate(prompt, max_new_tokens, strategy, num_beams, top_k, top_p, temperature, repetition_penalty, model, tokenizer): + if not prompt.strip(): + return "" + + gen_config = build_generation_config(strategy, num_beams, top_k, top_p, temperature, repetition_penalty) + + inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE) + + gen_kwargs = { + **inputs, + "max_new_tokens": int(max_new_tokens), + } + if gen_config is not None: + gen_kwargs["generation_config"] = gen_config + + output_ids = model.generate(**gen_kwargs) + # Decode only the newly generated tokens + generated_ids = output_ids[0, inputs["input_ids"].shape[1]:] + return tokenizer.decode(generated_ids, skip_special_tokens=True) + + +def main(): + parser = argparse.ArgumentParser(description="Gradio demo for sub-word causal LMs") + parser.add_argument("model_path", type=str, help="Path to a local model/checkpoint directory") + parser.add_argument("--port", type=int, default=7860) + parser.add_argument("--share", action="store_true", help="Create a public Gradio link") + args = parser.parse_args() + + model, tokenizer = load_model_and_tokenizer(args.model_path) + + def on_generate(prompt, max_new_tokens, strategy, num_beams, top_k, top_p, temperature, repetition_penalty): + return generate( + prompt, max_new_tokens, strategy, num_beams, top_k, top_p, temperature, repetition_penalty, model, tokenizer) + + with gr.Blocks(title="CLM Demo") as demo: + gr.Markdown("# Sub-word CLM – Qualitative Demo") + gr.Markdown(f"**Model:** `{args.model_path}`   **Device:** `{DEVICE}`") + + with gr.Row(): + with gr.Column(): + prompt = gr.Textbox(label="Prompt", lines=4, placeholder="Enter text here…") + max_new_tokens = gr.Slider(minimum=1, maximum=512, value=64, step=1, label="Max new tokens") + strategy = gr.Radio( + [STRATEGY_GREEDY, STRATEGY_BEAM, STRATEGY_SAMPLING], + value=STRATEGY_GREEDY, label="Decoding strategy") + + with gr.Group(visible=False) as beam_params: + num_beams = gr.Slider(minimum=2, maximum=16, value=4, step=1, label="Number of beams") + + with gr.Group(visible=False) as sampling_params: + temperature = gr.Slider(minimum=0.01, maximum=2.0, value=1.0, step=0.01, label="Temperature") + top_k = gr.Slider(minimum=0, maximum=200, value=50, step=1, + label="Top-k (0 = disabled)") + top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.95, step=0.01, + label="Top-p (1.0 = disabled)") + + repetition_penalty = gr.Slider( + minimum=1.0, maximum=3.0, value=1.0, step=0.05, label="Repetition penalty (1.0 = off)") + + btn = gr.Button("Generate", variant="primary") + + with gr.Column(): + output = gr.Textbox(label="Generated output", lines=12, interactive=False) + + # Show/hide strategy-specific parameters + def on_strategy_change(choice): + return ( + gr.update(visible=choice == STRATEGY_BEAM), + gr.update(visible=choice == STRATEGY_SAMPLING), + ) + + strategy.change(fn=on_strategy_change, inputs=strategy, outputs=[beam_params, sampling_params]) + + all_inputs = [prompt, max_new_tokens, strategy, num_beams, top_k, top_p, temperature, repetition_penalty] + btn.click(fn=on_generate, inputs=all_inputs, outputs=output) + prompt.submit(fn=on_generate, inputs=all_inputs, outputs=output) + + demo.launch(server_port=args.port, share=args.share) + + +if __name__ == "__main__": + main()