FIX AdaMSS save/load reproduction by making slice_pca SVD deterministic #3310

Open
AshNicolus wants to merge 1 commit into huggingface:main from AshNicolus:fix-adamss-save-load-nondeterministic-scatter


What does this PR do?

Saving an AdaMSS adapter with save_pretrained and reloading it with
from_pretrained does not reproduce the model's outputs once the adapter has
been trained.

Root cause

When an AdaMSS adapter is loaded, the layer is rebuilt from the base weights
(update_layer → slice_pca → clustering_Z). This recomputes scatter_index,
the mapping that places each subspace's contribution into the correct output
dimensions. slice_pca uses torch.svd_lowrank, which draws a random projection
from the global RNG, so the SVD result, and therefore the clustering and
scatter_index, depends on the RNG state at construction time. Because that
state differs between saving and loading, the reloaded adapter rebuilds a
different scatter_index. The trained adamss_A/adamss_B weights are restored
correctly, but are then scattered to the wrong output dimensions, so the output
changes.
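
The RNG dependence can be illustrated in isolation. The snippet below is a minimal sketch, not code from this PR: the matrix shape, q=4, and the seeds are arbitrary assumptions chosen for illustration. It calls torch.svd_lowrank on the same matrix under different global RNG states and compares the results.

```python
import torch

# Toy "base weight"; shape and rank are arbitrary choices for illustration.
weight = torch.randn(64, 32, generator=torch.Generator().manual_seed(0))

torch.manual_seed(1)  # global RNG state "at save time"
U_save, S_save, V_save = torch.svd_lowrank(weight, q=4)

torch.manual_seed(2)  # a different global RNG state "at load time"
U_load, S_load, V_load = torch.svd_lowrank(weight, q=4)

torch.manual_seed(1)  # same state as at save time
U_same, S_same, V_same = torch.svd_lowrank(weight, q=4)

# Different RNG state: the approximate factors are expected to differ.
differs = not torch.allclose(U_save, U_load)
# Same RNG state: the factors are expected to match.
repeats = torch.allclose(U_save, U_same)
print(differs, repeats)
```

Anything derived from these factors, such as a clustering and the resulting index mapping, inherits the same dependence on the RNG state.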

This stayed hidden because AdaMSS initializes B = 0: an untrained adapter is a
no-op and reloads trivially regardless of scatter_index. The discrepancy only
appears once the adapter is trained.
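
Why B = 0 masks the problem can be seen in a toy model. This is a simplified sketch under stated assumptions, not the AdaMSS implementation: adapter_delta, the shapes, and both index tensors are hypothetical stand-ins for a single subspace whose low-rank update is placed into the output dimensions named by scatter_index.

```python
import torch

def adapter_delta(x, A, B, scatter_index, out_features):
    """Toy stand-in for one subspace: compute x @ A^T @ B^T, then place the
    result into the output columns listed in scatter_index."""
    sub = x @ A.T @ B.T                          # (batch, len(scatter_index))
    out = torch.zeros(x.shape[0], out_features)
    out[:, scatter_index] = sub
    return out

gen = torch.Generator().manual_seed(0)
x = torch.randn(2, 8, generator=gen)
A = torch.randn(2, 8, generator=gen)             # rank-2 down-projection
idx_saved = torch.tensor([0, 2, 5])              # hypothetical index at save time
idx_reloaded = torch.tensor([1, 2, 4])           # hypothetical index after reload

B_init = torch.zeros(3, 2)                       # B = 0 at initialization
B_trained = torch.randn(3, 2, generator=gen)     # stands in for trained weights

# Untrained: the update is all zeros, so the index mismatch is invisible.
untrained_match = torch.equal(
    adapter_delta(x, A, B_init, idx_saved, 6),
    adapter_delta(x, A, B_init, idx_reloaded, 6),
)
# Trained: the same weights land in different output columns.
trained_match = torch.allclose(
    adapter_delta(x, A, B_trained, idx_saved, 6),
    adapter_delta(x, A, B_trained, idx_reloaded, 6),
)
print(untrained_match, trained_match)
```

In this toy setup the untrained outputs agree and the trained outputs do not, which mirrors why a save/load round trip on a freshly initialized adapter would not catch the bug.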

clustering_Z already pins KMeans(random_state=...), so deterministic
initialization is clearly intended; the only remaining source of randomness was
svd_lowrank.

Fix

Seed a forked RNG around the svd_lowrank calls in slice_pca so the
decomposition is deterministic and the reconstructed scatter_index matches the
one used at save time. torch.random.fork_rng leaves the global RNG stream
untouched.
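
The pattern can be sketched as follows. This is an illustrative sketch, not the code in this PR: the helper name deterministic_svd_lowrank, the seed value, and the q and niter defaults are assumptions. It relies only on torch.random.fork_rng, torch.manual_seed, and torch.svd_lowrank.

```python
import torch

def deterministic_svd_lowrank(matrix, q=4, niter=2, seed=0):
    """Hypothetical helper: run torch.svd_lowrank under a fixed seed without
    disturbing the caller's RNG stream."""
    # fork_rng saves the RNG state on entry and restores it on exit, so the
    # manual_seed call below does not leak into the surrounding code.
    with torch.random.fork_rng():
        torch.manual_seed(seed)
        return torch.svd_lowrank(matrix, q=q, niter=niter)

weight = torch.randn(64, 32, generator=torch.Generator().manual_seed(0))

torch.manual_seed(1)                      # RNG state "at save time"
U_save, S_save, V_save = deterministic_svd_lowrank(weight)

torch.manual_seed(2)                      # different RNG state "at load time"
U_load, S_load, V_load = deterministic_svd_lowrank(weight)

# The decomposition no longer depends on the global RNG state.
same_factors = torch.allclose(U_save, U_load) and torch.allclose(V_save, V_load)

# The global stream is left as it was: the next draw is the same with or
# without the helper call in between.
torch.manual_seed(3)
expected = torch.rand(3)
torch.manual_seed(3)
deterministic_svd_lowrank(weight)
stream_untouched = torch.equal(expected, torch.rand(3))
print(same_factors, stream_untouched)
```

By default fork_rng also saves and restores the RNG state of every visible CUDA device, which may emit a warning on machines with several GPUs; passing an explicit devices list narrows that.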

Tests

Added TestAdamssSaveLoad::test_save_load_reproduces_output in
tests/test_adamss_asa.py: it trains an AdaMSS adapter, saves it, perturbs the
global RNG state, reloads into a fresh base model, and asserts the outputs match.
The test fails on main and passes with this change; the existing AdaMSS tests
still pass.
