Fix T4 OOM error visibility and parameter count claim in Lab 5.4 #32

Closed
Yassir-249 wants to merge 1 commit into google-deepmind:main from Yassir-249:fix/lab-5-4-oom-and-param-count

@Yassir-249

Problem

Lab 5.4 (course_5/gdm_lab_5_4_full_parameter_fine_tuning_of_gemma.ipynb) has two issues that disrupt the lesson the lab is trying to deliver.

1. The OOM error doesn't match what the narrative tells students to expect

The lab text says students "should observe an 'out of memory' (OOM) error" when calling model.fit(...). With the package versions pinned by the project (keras==3.13.2, keras_hub==0.26.0, jax==0.7.2, cuDNN 9.10), the actual error students see is:

XlaRuntimeError: INTERNAL: No reference output found!

Root cause: setting XLA_PYTHON_CLIENT_MEM_FRACTION=0.95 makes JAX preallocate 14.25 GiB and never release the compilation buffers left over from model loading, so by the time .fit() runs, only ~1.7 GiB is free in the JAX pool. The cuDNN autotuner can't get workspace memory and surfaces an opaque INTERNAL error instead of the clean RESOURCE_EXHAUSTED the narrative implies.

Same root cause as #29 (which fixes the analogous issue in Labs 7.5 and 7.6).

2. The "1.3 billion parameters" narrative claim double-counts the token embedding

The "Understanding Gemma's Memory Footprint" cell says:

The Gemma 1B model has approximately 1.3 billion parameters (1B for the main transformer blocks and 300M for the token embeddings).

But Keras reports Total params: 999,885,952 for gemma3_1b. The token embedding is implemented as a ReversibleEmbedding shared between input and output, so its ~302M parameters are already counted inside the backbone's ~999M. The unique parameter count is therefore ~1B, not 1.3B.
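
The arithmetic behind this claim can be sketched as follows. The vocabulary size and embedding width are assumptions based on the published Gemma 3 1B configuration, not values stated in this PR; only the 999,885,952 total comes from the Keras summary quoted above.

```python
# Back-of-the-envelope check of the parameter split.
# VOCAB_SIZE and HIDDEN_DIM are assumptions (published Gemma 3 1B config),
# not values taken from this PR.
VOCAB_SIZE = 262_144
HIDDEN_DIM = 1_152

# Total params as reported by Keras for gemma3_1b (quoted in this PR).
TOTAL_PARAMS = 999_885_952

# One shared embedding matrix serves both input lookup and output projection.
embedding_params = VOCAB_SIZE * HIDDEN_DIM
non_embedding_params = TOTAL_PARAMS - embedding_params

# What you get if the shared embedding is added on top of the total again.
double_counted = TOTAL_PARAMS + embedding_params

print(f"embedding:      {embedding_params:,}")
print(f"non-embedding:  {non_embedding_params:,}")
print(f"double-counted: {double_counted:,}")
```

Under these assumptions the embedding comes to about 302M, the remaining blocks to about 698M, and adding the embedding a second time gives about 1.3B, which is consistent with the figure in the lab text.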

Fix

1. Apply the same allocator pattern as #29

In the install/imports cell:

  • Move the JAX env-var settings before the import keras line
  • Replace XLA_PYTHON_CLIENT_MEM_FRACTION=0.95 with XLA_PYTHON_CLIENT_ALLOCATOR=platform
  • Add XLA_FLAGS=--xla_gpu_enable_command_buffer=
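
A minimal sketch of what the reordered cell could look like is below. The helper name configure_jax_allocator is hypothetical and not part of the notebook, and preserving any pre-existing XLA_FLAGS value is a design choice of this sketch rather than something the PR specifies.

```python
import os


def configure_jax_allocator(env=os.environ):
    """Apply the three allocator changes listed above to an env mapping.

    Hypothetical helper for illustration; the notebook may simply assign
    the variables inline.
    """
    # Drop the old preallocation fraction if an earlier cell set it.
    env.pop("XLA_PYTHON_CLIENT_MEM_FRACTION", None)
    # Allocate on demand through the platform allocator.
    env["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
    # Append the command-buffer flag, keeping any flags already present
    # (an assumption of this sketch, not stated in the PR).
    flag = "--xla_gpu_enable_command_buffer="
    existing = env.get("XLA_FLAGS", "")
    if flag not in existing.split():
        env["XLA_FLAGS"] = (existing + " " + flag).strip()
    return env


configure_jax_allocator()

# The Keras import belongs after this point, so that JAX reads the
# variables when it initializes:
# import keras
```

The ordering is the important part: JAX reads these variables when it initializes, so setting them after the import keras line has no effect on the current session.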

After the fix, the same model.fit(...) call produces:

XlaRuntimeError: RESOURCE_EXHAUSTED: Failed to allocate request for 12.52GiB
                 (13438550016B) on device ordinal 0

The 12.52 GiB request is close to the predicted Adam training state (m + v + gradient at fp32 for ~1B trainable params, i.e. about 3 × 4 bytes × 1B ≈ 12 GB), giving students a clear, sized OOM that maps directly onto the lab's stated lesson.
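
The estimate can be reproduced with a few lines of arithmetic. This is a rough sketch that assumes every parameter is trainable and that m, v, and the gradient are each stored once at fp32; it ignores activations, weights, and any allocator overhead.

```python
# Rough size of the fp32 training state for full-parameter fine-tuning.
TRAINABLE_PARAMS = 999_885_952   # Total params reported by Keras (from this PR)
BYTES_PER_FP32 = 4
STATE_TENSORS = 3                # Adam m, Adam v, and the gradient

predicted_bytes = TRAINABLE_PARAMS * BYTES_PER_FP32 * STATE_TENSORS

# Size of the failed allocation in the RESOURCE_EXHAUSTED message.
observed_bytes = 13_438_550_016

GIB = 2**30
print(f"predicted: {predicted_bytes / GIB:.2f} GiB")
print(f"observed:  {observed_bytes / GIB:.2f} GiB")
```

Under these assumptions the prediction is about 11.17 GiB against an observed request of 12.52 GiB. The two are the same order of magnitude; the sketch does not try to account for the remaining difference of roughly 1.3 GiB.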

2. Correct the parameter count claim

Update the "Understanding Gemma's Memory Footprint" markdown cell from:

The Gemma 1B model has approximately 1.3 billion parameters (1B for the main transformer blocks and 300M for the token embeddings).

to:

The Gemma 1B model has approximately 1 billion parameters in total (around 700M for the main transformer blocks and around 300M for the token embeddings, which are shared between the input and output).

Verification

Tested end-to-end on a Colab T4 GPU runtime against the live HEAD of main:

| | Error |
|---|---|
| Before fix | XlaRuntimeError: INTERNAL: No reference output found! (or INTERNAL: No valid config found! at lower batch size; both are opaque) |
| After fix | XlaRuntimeError: RESOURCE_EXHAUSTED: Failed to allocate request for 12.52GiB (13438550016B) on device ordinal 0 |

Related

Lab 5.4 attempts full-parameter fine-tuning of Gemma3-1B on a T4 GPU
and tells students they will see an "out of memory (OOM) error". With
the package versions pinned by the project (keras==3.13.2,
keras_hub==0.26.0, jax==0.7.2), the actual error students see is
'XlaRuntimeError: INTERNAL: No reference output found!' -- an opaque
XLA internal that does not mention memory.

Root cause is the same as in PR google-deepmind#29 (which fixes Labs 7.5 and 7.6):
XLA_PYTHON_CLIENT_MEM_FRACTION=0.95 makes JAX preallocate ~14.25 GiB
and never release compilation buffers from model loading. By the time
.fit() compiles the train step, the cuDNN autotuner cannot get
workspace memory and the failure surfaces as an INTERNAL error
instead of RESOURCE_EXHAUSTED.

This change applies the same allocator pattern as PR google-deepmind#29 to Lab 5.4:
move the JAX env vars before 'import keras', replace
XLA_PYTHON_CLIENT_MEM_FRACTION=0.95 with
XLA_PYTHON_CLIENT_ALLOCATOR=platform, and add
XLA_FLAGS=--xla_gpu_enable_command_buffer=. After the change the same
fit call produces:

    XlaRuntimeError: RESOURCE_EXHAUSTED: Failed to allocate request
    for 12.52GiB (13438550016B) on device ordinal 0

The 12.52 GiB matches the predicted Adam optimizer state (m + v +
gradient at fp32 for ~1B trainable params), giving students a clear,
sized OOM that maps onto the lab's stated lesson.

Also corrects the "Understanding Gemma's Memory Footprint" markdown.
The lab said the model has "approximately 1.3 billion parameters
(1B for the main transformer blocks and 300M for the token
embeddings)", but Keras reports Total params: 999,885,952. The token
embedding is implemented as a ReversibleEmbedding shared between
input and output, so the ~302M embedding parameters are already
counted inside the backbone's ~999M. Updated to "approximately 1
billion parameters in total (around 700M for the main transformer
blocks and around 300M for the token embeddings, which are shared
between the input and output)".

Verified end-to-end on Colab T4 GPU runtime against live HEAD of the
repo on main.
upaq pushed a commit that referenced this pull request Jun 18, 2026
, #32, #33, #34, #35 and #36 by @Yassir-249.

PiperOrigin-RevId: 934537389
Change-Id: I5f6035aa2fb1ed3c291e26f9058e537ee5394996
@Yassir-249

Author

Resolved on main — the T4 OOM visibility and the parameter-count claim are both addressed now. Closing, thanks!

Yassir-249 closed this Jun 22, 2026