Skip to content

Transient Gemma-1B Loading Error in Colab GPU Runtime#9

Open
boiled-darvari wants to merge 5 commits into
google-deepmind:mainfrom
boiled-darvari:model-loading
Open

Transient Gemma-1B Loading Error in Colab GPU Runtime#9
boiled-darvari wants to merge 5 commits into
google-deepmind:mainfrom
boiled-darvari:model-loading

Conversation

@boiled-darvari

Copy link
Copy Markdown

When running the "Compare N-Gram Models and Transformer Language Models" lab in Google Colab with a GPU runtime, the initial attempt to load the Gemma-1B model using ai_foundations.generation.load_gemma() occasionally fails with a RuntimeError related to CUDA backend initialization.

The error observed is:

Loaded Africa Galore dataset with 232 paragraphs.

Loaded trigram model.

Loading Gemma-1B model...
---------------------------------------------------------------------------
XlaRuntimeError                           Traceback (most recent call last)
/usr/local/lib/python3.12/dist-packages/jax/_src/xla_bridge.py in backends()
    811 
--> 812         backend = _init_backend(platform)
    813         _backends[platform] = backend

13 frames
/usr/local/lib/python3.12/dist-packages/jax/_src/xla_bridge.py in _init_backend(platform)
    895   logger.debug("Initializing backend '%s'", platform)
--> 896   backend = registration.factory()
    897   # TODO(skye): consider raising more descriptive errors directly from backend

/usr/local/lib/python3.12/dist-packages/jax/_src/xla_bridge.py in make_pjrt_c_api_client(plugin_name, options)
    548   if distributed.global_state.client is None:
--> 549     return xla_client.make_c_api_client(plugin_name, updated_options, None)
    550 

/usr/local/lib/python3.12/dist-packages/jaxlib/xla_client.py in make_c_api_client(plugin_name, options, distributed_client, transfer_server_factory)
    155     options = {}
--> 156   return _xla.get_c_api_client(
    157       plugin_name,

XlaRuntimeError: INVALID_ARGUMENT: Unexpected option name passed to PJRT_Client_Create: use_tfrt_gpu_client

During handling of the above exception, another exception occurred:

RuntimeError                              Traceback (most recent call last)
/tmp/ipython-input-2823723241.py in <cell line: 0>()
     12 
     13 print("Loading Gemma-1B model...")
---> 14 gemma_model = generation.load_gemma()
     15 print("Loaded Gemma-1B model.")

/usr/local/lib/python3.12/dist-packages/ai_foundations/generation/loaders.py in load_gemma(model_name)
     57     else:
     58       model = attention.AttentionWeightGemma3_1B()
---> 59     params = gm.ckpts.load_params(gm.ckpts.CheckpointPath.GEMMA3_1B_PT)
     60   elif model_name == "Gemma-4B":
     61     tokenizer = gm.text.Gemma3Tokenizer()

/usr/local/lib/python3.12/dist-packages/gemma/gm/ckpts/_checkpoint.py in load_params(path, params, donate, text_only, sharding, quantize)
    184     raise ValueError('`sharding` and `params` are mutually exclusive.')
    185 
--> 186   ckpt = ocp.StandardCheckpointer()
    187 
    188   metadata, path = _get_metadata_and_path(ckpt, path)

/usr/local/lib/python3.12/dist-packages/orbax/checkpoint/_src/checkpointers/standard_checkpointer.py in __init__(self, async_options, multiprocessing_options, file_options, checkpoint_metadata_store, temporary_path_class, **kwargs)
     83     """
     84     super().__init__(
---> 85         standard_checkpoint_handler.StandardCheckpointHandler(
     86             multiprocessing_options=multiprocessing_options,
     87             **kwargs,

/usr/local/lib/python3.12/dist-packages/orbax/checkpoint/_src/handlers/standard_checkpoint_handler.py in __init__(self, save_concurrent_gb, restore_concurrent_gb, multiprocessing_options, pytree_metadata_options)
     95     """
     96     self._supported_types = checkpoint_utils.STANDARD_ARRAY_TYPES
---> 97     self._impl = pytree_checkpoint_handler.PyTreeCheckpointHandler(
     98         save_concurrent_gb=save_concurrent_gb,
     99         restore_concurrent_gb=restore_concurrent_gb,

/usr/local/lib/python3.12/dist-packages/orbax/checkpoint/_src/handlers/pytree_checkpoint_handler.py in __init__(self, aggregate_filename, save_concurrent_gb, restore_concurrent_gb, use_ocdbt, use_zarr3, multiprocessing_options, type_handler_registry, handler_impl, pytree_metadata_options, array_metadata_validator, enable_pinned_host_transfer)
    536     self._save_concurrent_bytes = _concurrent_bytes(save_concurrent_gb)
    537     self._restore_concurrent_bytes = _concurrent_bytes(restore_concurrent_gb)
--> 538     self._handler_impl = handler_impl or BasePyTreeCheckpointHandler(
    539         save_concurrent_bytes=self._save_concurrent_bytes,
    540         restore_concurrent_bytes=self._restore_concurrent_bytes,

/usr/local/lib/python3.12/dist-packages/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py in __init__(self, save_concurrent_bytes, restore_concurrent_bytes, use_ocdbt, use_zarr3, multiprocessing_options, type_handler_registry, enable_post_merge_validation, pytree_metadata_options, array_metadata_validator, enable_pinned_host_transfer)
    367 
    368     if enable_pinned_host_transfer is None:
--> 369       enable_pinned_host_transfer = jax.default_backend() == 'gpu'
    370     self._enable_pinned_host_transfer = enable_pinned_host_transfer
    371 

/usr/local/lib/python3.12/dist-packages/jax/_src/xla_bridge.py in default_backend()
   1013 def default_backend() -> str:
   1014   """Returns the platform name of the default XLA backend."""
-> 1015   return get_backend(None).platform
   1016 
   1017 

/usr/local/lib/python3.12/dist-packages/jax/_src/xla_bridge.py in get_backend(platform)
    942     platform: None | str | xla_client.Client = None
    943 ) -> xla_client.Client:
--> 944   return _get_backend_uncached(platform)
    945 
    946 

/usr/local/lib/python3.12/dist-packages/jax/_src/xla_bridge.py in _get_backend_uncached(platform)
    921   platform = (platform or _XLA_BACKEND.value or _PLATFORM_NAME.value or None)
    922 
--> 923   bs = backends()
    924   if platform is not None:
    925     platform = canonicalize_platform(platform)

/usr/local/lib/python3.12/dist-packages/jax/_src/xla_bridge.py in backends()
    826           else:
    827             err_msg += " (you may need to uninstall the failing plugin package, or set JAX_PLATFORMS=cpu to skip this backend.)"
--> 828           raise RuntimeError(err_msg)
    829 
    830     assert _default_backend is not None

RuntimeError: Unable to initialize backend 'cuda': INVALID_ARGUMENT: Unexpected option name passed to PJRT_Client_Create: use_tfrt_gpu_client (you may need to uninstall the failing plugin package, or set JAX_PLATFORMS=cpu to skip this backend.)

Subsequent attempts to run the cell and load the model typically succeed without error. This suggests a potential timing or transient issue during the initial setup of the JAX-CUDA backend in the Colab environment.

Steps to Reproduce:

  1. Open the "Compare N-Gram Models and Transformer Language Models" lab in Google Colab.
  2. Ensure a GPU runtime is selected (Runtime > Change runtime type > GPU).
  3. Run all cells sequentially until the cell containing generation.load_gemma().
  4. Observe the RuntimeError on the first execution of the cell loading Gemma.
  5. Re-running the same cell often succeeds.

Expected Behavior:

The Gemma-1B model should load successfully on the first attempt in a correctly configured Colab GPU runtime.

Observed Behavior:

The initial attempt to load the Gemma-1B model sometimes fails with a CUDA backend initialization error.

Environment:

  • Google Colab
  • GPU Runtime (e.g., T4 GPU)
  • Libraries as installed by the notebook's setup cells (including ai_foundations, jax, orbax-checkpoint, etc.)

Proposed Solution/Workaround:

A temporary workaround is to implement a retry mechanism around the generation.load_gemma() call, catching the RuntimeError and retrying the load after a short delay. This is the approach currently implemented in the notebook to mitigate the issue for users.

Note on GPU Initialization:

While a retry mechanism was implemented to address a transient issue observed during the initial loading of the Gemma model on Colab's GPU runtime, a definitive, quick, and correct method to explicitly "wait" for the underlying JAX CUDA backend to be fully ready before attempting the load was not identified. The current approach provides a practical workaround for the observed behavior, where the first loading attempt might fail due to initialization timing.

After fix:

Loaded Africa Galore dataset with 232 paragraphs.

Loaded trigram model.

Loading Gemma-1B model...
Attempt 1/3: Error loading Gemma model: Unable to initialize backend 'cuda': INVALID_ARGUMENT: Unexpected option name passed to PJRT_Client_Create: use_tfrt_gpu_client (you may need to uninstall the failing plugin package, or set JAX_PLATFORMS=cpu to skip this backend.)
Retrying in 10 seconds...
Loaded Gemma-1B model.

P.S.: If the retry logic seems confusing for beginners, another option is to insert a single sleep(n) before loading the model. Alternatively, you can simply let learners decide whether to rerun the code cell if they encounter a loading error on their first attempt.

capyBearista and others added 5 commits October 25, 2025 02:14
…les.

- fix grammatical error in README.md: "The eight courses are in the curriculum are:" -> "The eight courses in the curriculum are:"
- fix misspelling in ai_foundations/visualizations/plots.py: "correclty" -> "correctly"
- remove word repetition "and and" in course_2/gdm_lab_2_1_preprocess_data.ipynb
- remove word repetition "to to" in course_2/gdm_lab_2_2_tokenize_texts_into_characters_and_words.ipynb
- fix misspelling "occurence" -> "occurrence" in course_3/gdm_lab_3_1_distinguish_between_signal_and_noise.ipynb
- remove word repetition "at at" in course_3/gdm_lab_3_6_mitigate_overfitting.ipynb
- fix "fo" -> "of" in course_3/gdm_lab_3_6_mitigate_overfitting.ipynb
- remove word repetition "a a" in course_4/gdm_lab_4_5_reflection_on_trainable_parameters.ipynb
Fix grammatical errors in course materials
docs: fixed a small typo in the notebook
Transient Gemma-1B Loading Error in Colab GPU Runtime
@google-cla

google-cla Bot commented Nov 10, 2025

Copy link
Copy Markdown

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@boiled-darvari

Copy link
Copy Markdown
Author

I think I found the main issue:

Using cached jax-0.6.2-py3-none-any.whl (2.7 MB)
Using cached jaxlib-0.6.2-cp312-cp312-manylinux2014_x86_64.whl (89.9 MB)
Installing collected packages: jaxlib, jax
  Attempting uninstall: jaxlib
    Found existing installation: jaxlib 0.7.2
    Uninstalling jaxlib-0.7.2:
      Successfully uninstalled jaxlib-0.7.2
  Attempting uninstall: jax
    Found existing installation: jax 0.7.2
    Uninstalling jax-0.7.2:
      Successfully uninstalled jax-0.7.2
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
ai-foundations 0.1.0 requires jax==0.7.2, but you have jax 0.6.2 which is incompatible.
Successfully installed jax-0.6.2 jaxlib-0.6.2
WARNING: The following packages were previously imported in this runtime:
  [jax,jaxlib]
You must restart the runtime in order to use newly installed versions.

Changing the JAX version to 0.7.2 would fix the issue with some new errors:

ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
ipython-sql 0.5.0 requires sqlalchemy>=2.0, but you have sqlalchemy 1.2.19 which is incompatible.
google-adk 1.17.0 requires sqlalchemy<3.0.0,>=2.0, but you have sqlalchemy 1.2.19 which is incompatible.
langchain 0.3.27 requires SQLAlchemy<3,>=1.4, but you have sqlalchemy 1.2.19 which is incompatible.
Successfully installed ai_foundations-0.1.0 alembic-1.4.3 async_generator-1.10 cloud-sql-python-connector-1.18.5 clu-0.0.12 dnspython-2.8.0 docker-7.1.0 durationpy-0.10 gemma-3.0.0 grain-0.2.14 jaxtyping-0.3.3 jedi-0.19.2 kauldron-1.3.0 kubernetes-34.1.0 mediapy-1.2.4 ml_collections-1.1.0 python-editor-1.0.4 sqlalchemy-1.2.19 urllib3-2.3.0 wadler-lindig-0.1.7 xmanager-0.7.1
WARNING: The following packages were previously imported in this runtime:
  [google]
You must restart the runtime in order to use newly installed versions.

But Gemma-1B model loads successfully in the first try!

You can reproduce the issue with this code cell alone:

!pip install orbax-checkpoint==0.11.21 jax[cuda12]==0.6.2
!pip install "git+https://github.com/google-deepmind/ai-foundations.git@main"

# Packages used.
import os # For setting a variable needed to load the model onto the GPU.
import pandas as pd # For loading the Africa Galore dataset.

# Functions for clearing outputs and formatting.
from IPython.display import clear_output, display, HTML

# Functions for generating texts with a language model, visualizing probability
# distributions, and loading an n-gram model.
from ai_foundations import generation
from ai_foundations import visualizations
from ai_foundations.ngram import model as ngram_model

# Set the full GPU memory usage for JAX.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.00"

Then:

import time
import jax

# waiting for CUDA backend to be initialized
max_check_retries = 30
print("Waiting for CUDA backend initialization...")
for i in range(max_check_retries):
    try:
        print(jax.devices())
        print("CUDA backend initialized successfully.")
        break
    except Exception as e:
        print(".")
        if i < max_check_retries - 1:
            time.sleep(1)
        else:
            print("CUDA backend initialization failed, try again.")
            raise

With failure on the first try, same as loading Gemma-1B model if the JAX version is 0.6.2:

Waiting for CUDA backend initialization...
.
[CpuDevice(id=0)]
CUDA backend initialized successfully.

So it's not about waiting for the CUDA backend to be initialized, as I used sleep() in the commit!

It seems more complex than just simple PR. Let me know if you need to make any changes to all the notebooks. And I don't mind if you change yourself by even closing this one.

@boiled-darvari

Copy link
Copy Markdown
Author

Google should build a Gemini Dependency Hell Update Agent just for this

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants