Fix T4 OOM error visibility and parameter count claim in Lab 5.4#32
Closed
Yassir-249 wants to merge 1 commit into
Closed
Fix T4 OOM error visibility and parameter count claim in Lab 5.4#32Yassir-249 wants to merge 1 commit into
Yassir-249 wants to merge 1 commit into
Conversation
Lab 5.4 attempts full-parameter fine-tuning of Gemma3-1B on a T4 GPU and tells students they will see an "out of memory (OOM) error". With the package versions pinned by the project (keras==3.13.2, keras_hub==0.26.0, jax==0.7.2), the actual error students see is 'XlaRuntimeError: INTERNAL: No reference output found!' -- an opaque XLA internal that does not mention memory. Root cause is the same as in PR google-deepmind#29 (which fixes Labs 7.5 and 7.6): XLA_PYTHON_CLIENT_MEM_FRACTION=0.95 makes JAX preallocate ~14.25 GiB and never release compilation buffers from model loading. By the time .fit() compiles the train step, the cuDNN autotuner cannot get workspace memory and the failure surfaces as an INTERNAL error instead of RESOURCE_EXHAUSTED. This change applies the same allocator pattern as PR google-deepmind#29 to Lab 5.4: move the JAX env vars before 'import keras', replace XLA_PYTHON_CLIENT_MEM_FRACTION=0.95 with XLA_PYTHON_CLIENT_ALLOCATOR=platform, and add XLA_FLAGS=--xla_gpu_enable_command_buffer=. After the change the same fit call produces: XlaRuntimeError: RESOURCE_EXHAUSTED: Failed to allocate request for 12.52GiB (13438550016B) on device ordinal 0 The 12.52 GiB matches the predicted Adam optimizer state (m + v + gradient at fp32 for ~1B trainable params), giving students a clear, sized OOM that maps onto the lab's stated lesson. Also corrects the "Understanding Gemma's Memory Footprint" markdown. The lab said the model has "approximately 1.3 billion parameters (1B for the main transformer blocks and 300M for the token embeddings)", but Keras reports Total params: 999,885,952. The token embedding is implemented as a ReversibleEmbedding shared between input and output, so the ~302M embedding parameters are already counted inside the backbone's ~999M. Updated to "approximately 1 billion parameters in total (around 700M for the main transformer blocks and around 300M for the token embeddings, which are shared between the input and output)". Verified end-to-end on Colab T4 GPU runtime against live HEAD of the repo on main.
Author
|
Resolved on |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Problem
Lab 5.4 (
course_5/gdm_lab_5_4_full_parameter_fine_tuning_of_gemma.ipynb) has two issues that disrupt the lesson the lab is trying to deliver.1. The OOM error doesn't match what the narrative tells students to expect
The lab text says students "should observe an 'out of memory' (OOM) error" when calling
model.fit(...). With the package versions pinned by the project (keras==3.13.2,keras_hub==0.26.0,jax==0.7.2, cuDNN 9.10) the actual error students see is:Root cause: setting
XLA_PYTHON_CLIENT_MEM_FRACTION=0.95makes JAX preallocate 14.25 GiB and never release compilation buffers from model loading — so by the time.fit()runs, only ~1.7 GiB is free in the JAX pool. The cuDNN autotuner can't get workspace memory and surfaces an opaqueINTERNALerror instead of the cleanRESOURCE_EXHAUSTEDthe narrative implies.Same root cause as #29 (which fixes the analogous issue in Labs 7.5 and 7.6).
2. The "1.3 billion parameters" narrative claim double-counts the token embedding
The "Understanding Gemma's Memory Footprint" cell says:
But Keras reports
Total params: 999,885,952forgemma3_1b. The token embedding is implemented as aReversibleEmbeddingshared between input and output, so its ~302M parameters are already counted inside the backbone's ~999M. Unique parameter count is ~1B, not 1.3B.Fix
1. Apply the same allocator pattern as #29
In the install/imports cell:
import keraslineXLA_PYTHON_CLIENT_MEM_FRACTION=0.95withXLA_PYTHON_CLIENT_ALLOCATOR=platformXLA_FLAGS=--xla_gpu_enable_command_buffer=After the fix, the same
model.fit(...)call produces:The 12.52 GiB matches the predicted Adam optimizer state (m + v + gradient at fp32 for ~1B trainable params), giving students a clear, sized OOM that maps directly onto the lab's stated lesson.
2. Correct the parameter count claim
Update the "Understanding Gemma's Memory Footprint" markdown cell from:
to:
Verification
Tested end-to-end on Colab T4 GPU runtime against live HEAD of
main:XlaRuntimeError: INTERNAL: No reference output found!(orINTERNAL: No valid config found!at lower batch size — both opaque)XlaRuntimeError: RESOURCE_EXHAUSTED: Failed to allocate request for 12.52GiB (13438550016B) on device ordinal 0Related