-
-
Notifications
You must be signed in to change notification settings - Fork 5
Implement Bits Per Byte Metric Computation #61
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
ilkerkesen
wants to merge
12
commits into
sign:main
Choose a base branch
from
ilkerkesen:bits-per-byte
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
12 commits
Select commit
Hold shift + click to select a range
87d4cb2
implement bits per byte metric including tests and trainer adaptation
ilkerkesen 646117a
count EOS bytes during bits per byte computation
ilkerkesen d7bc29c
handle unequal non-PAD counts across batches correctly
ilkerkesen 58d6739
make evaluation work correctly for text evaluation metrics
ilkerkesen 8ffbf29
add support for utf-16 and utf-32 encodings as well
ilkerkesen cc2d4f3
refactor batch counting in welt trainer
ilkerkesen 18d2e3e
resolve ruff issues
ilkerkesen 405edeb
fix eval sample counting when prefix field is absent
ilkerkesen 518d5bc
remove huggingface/evaluate dependency for accuracy: it complicates t…
ilkerkesen 6d53b79
Align BPB computation for both implementations
ilkerkesen 2493449
make image-based models work
ilkerkesen 22837b7
add demos
ilkerkesen File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,168 @@ | ||
| """Tests for compute_bits_per_byte metric utility.""" | ||
|
|
||
| import math | ||
|
|
||
| import pytest | ||
|
|
||
| from welt_training.metrics import compute_bits_per_byte | ||
|
|
||
|
|
||
class TestComputeBitsPerByte:
    """Unit tests pinning the arithmetic of ``compute_bits_per_byte``."""

    def test_basic_computation(self):
        """With equal token and byte counts, BPB reduces to loss / ln(2)."""
        token_count = byte_count = 100
        bpb = compute_bits_per_byte(2.0, token_count, byte_count)
        assert bpb == pytest.approx(2.0 / math.log(2))

    def test_byte_level_model(self):
        """A byte-level model (tokens == bytes) yields BPB = loss / ln(2)."""
        nats_per_token = 1.5
        reference = nats_per_token / math.log(2)
        assert compute_bits_per_byte(nats_per_token, num_tokens=1000, num_bytes=1000) == pytest.approx(reference)

    def test_subword_tokenizer(self):
        """With fewer tokens than bytes, BPB is scaled by tokens / bytes."""
        nats_per_token = 3.0
        token_count, byte_count = 50, 200  # four bytes per token on average
        bpb = compute_bits_per_byte(nats_per_token, token_count, byte_count)
        assert bpb == pytest.approx(3.0 * 50 / (200 * math.log(2)))

    def test_zero_loss(self):
        """A loss of zero (perfect prediction) maps to exactly zero BPB."""
        bpb = compute_bits_per_byte(0.0, 100, 100)
        assert bpb == 0.0

    def test_zero_bytes_returns_inf(self):
        """A byte count of zero is reported as positive infinity."""
        assert compute_bits_per_byte(1.0, 100, 0) == float("inf")

    def test_zero_tokens_returns_zero(self):
        """A token count of zero is reported as zero BPB."""
        assert compute_bits_per_byte(1.0, 0, 100) == 0.0

    def test_higher_loss_means_higher_bpb(self):
        """BPB grows when the loss grows and the counts stay fixed."""
        bpb_low, bpb_high = (compute_bits_per_byte(nats, 100, 100) for nats in (1.0, 2.0))
        assert bpb_high > bpb_low

    def test_more_bytes_means_lower_bpb(self):
        """Spreading the same total bits over more bytes lowers BPB."""
        bpb_fewer, bpb_more = (compute_bits_per_byte(1.0, 100, n_bytes) for n_bytes in (100, 200))
        assert bpb_more < bpb_fewer

    def test_known_value_one_bit_per_byte(self):
        """ln(2) nats on one token covering one byte is exactly one bit per byte."""
        one_bit_in_nats = math.log(2)
        assert compute_bits_per_byte(one_bit_in_nats, 1, 1) == pytest.approx(1.0)

    def test_relationship_with_perplexity(self):
        """When tokens == bytes, BPB equals log2 of the perplexity."""
        nats_per_token = 2.5
        log2_perplexity = math.log2(math.exp(nats_per_token))
        bpb = compute_bits_per_byte(nats_per_token, num_tokens=1, num_bytes=1)
        assert bpb == pytest.approx(log2_perplexity)

    def test_scaling_with_token_count(self):
        """Twice the tokens over the same bytes gives twice the BPB."""
        bpb1, bpb2 = (compute_bits_per_byte(1.0, num_tokens=n_tokens, num_bytes=200) for n_tokens in (100, 200))
        assert bpb2 == pytest.approx(2 * bpb1)
|
|
||
|
|
||
class TestBPBTokenizerIntegration:
    """Exercise the BPB computation patterns the training scripts rely on."""

    def test_byte_level_model_with_eos_overhead(self):
        """
        Check the WeLTTrainer-style BPB computation.

        In that pattern the labels hold the content bytes plus one EOS per
        word. The loss is a mean over every non-PAD position (content and
        EOS), while the total bits are divided by the content bytes alone,
        so the result must be larger than loss/ln(2).
        """
        loss = 4.2
        # 20 words totalling 100 content bytes -> 100 bytes + 20 EOS = 120 loss positions
        content_bytes = 100
        loss_positions = 120

        bpb = compute_bits_per_byte(loss, loss_positions, content_bytes)

        naive_estimate = loss / math.log(2)
        assert bpb == pytest.approx(loss * 120 / (100 * math.log(2)))
        assert bpb > naive_estimate  # must exceed naive estimate

    def test_subword_tokenizer_compression_ratio(self):
        """
        Check the run_clm.py-style case where subword tokens are fewer than bytes.

        A BPE token spans roughly 3-4 bytes on average, and BPB has to
        reflect that compression:
            BPB = loss_per_token * (num_tokens / num_bytes) / ln(2)
        """
        loss = 3.0
        token_count = 400

        # Reference point: a byte-level model with one token per byte
        bpb_byte = compute_bits_per_byte(loss, num_tokens=token_count, num_bytes=400)

        # BPE-like setting: about 3.5 bytes per token (typical for English)
        bpb_bpe = compute_bits_per_byte(loss, num_tokens=token_count, num_bytes=1400)

        # The same loss spread across more bytes must yield a lower BPB
        assert bpb_bpe < bpb_byte
        assert bpb_bpe == pytest.approx(bpb_byte * 400 / 1400)

    @pytest.mark.integration
    def test_decode_roundtrip_byte_counting(self):
        """
        Check the run_clm.py approach: decode token IDs, encode as UTF-8, count bytes.

        A real HuggingFace tokenizer is loaded to confirm that counting bytes
        from decoded text reflects the tokenizer's compression ratio.
        """
        # Imported here so only this integration test needs transformers.
        from transformers import AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained("gpt2")

        texts = [
            "The quick brown fox jumps over the lazy dog.",
            "Machine translation is an important NLP task.",
        ]

        # Mimic the run_clm.py pipeline: tokenize each text, then decode to count bytes
        all_token_ids = [token_id for text in texts for token_id in tokenizer.encode(text)]
        num_tokens = len(all_token_ids)
        original_bytes = sum(len(text.encode("utf-8")) for text in texts)

        # run_clm.py decodes chunk by chunk; one full-sequence decode stands in for that
        decoded_bytes = len(tokenizer.decode(all_token_ids).encode("utf-8"))

        # The tokenizer roundtrip should leave the byte count (almost) unchanged
        byte_drift = abs(decoded_bytes - original_bytes)
        assert byte_drift <= len(texts), (
            f"Decoded bytes ({decoded_bytes}) should match original ({original_bytes})"
        )

        # A subword tokenizer emits fewer tokens than there are bytes
        assert num_tokens < original_bytes, (
            "BPE tokenizer should compress text (fewer tokens than bytes)"
        )

        # Evaluate BPB for a fixed, known loss
        loss = 2.5
        bpb = compute_bits_per_byte(loss, num_tokens, decoded_bytes)

        # Fewer tokens than bytes means BPB sits below the byte-level value
        assert bpb < loss / math.log(2)

        # And it has to agree with the closed-form expression
        analytical = loss * num_tokens / (decoded_bytes * math.log(2))
        assert bpb == pytest.approx(analytical)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.