Skip to content

Add sweep for relu2max capped hsnorm peri ln compat#838

Open
klei22 wants to merge 3 commits into
ReaLLMASIC:masterfrom
klei22:add_sweep_for_relu2max_capped_hsnorm_peri_ln_compat
Open

Add sweep for relu2max capped hsnorm peri ln compat#838
klei22 wants to merge 3 commits into
ReaLLMASIC:masterfrom
klei22:add_sweep_for_relu2max_capped_hsnorm_peri_ln_compat

Conversation

@klei22

@klei22 klei22 commented Jun 14, 2026

Copy link
Copy Markdown
Collaborator

No description provided.

Copilot AI left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR adds a new exploration sweep for testing ReLU2Max + Infinite Attention in peri-LN mode using CappedHyperSphereNorm, and extends CappedHyperSphereNorm to optionally apply a learnable gain (to align better with existing HyperSphereNorm configuration patterns).

Changes:

  • Add optional hsnorm_gain support to CappedHyperSphereNorm so it can apply a learnable per-channel gain.
  • Add a new YAML sweep (relu2max_capped_hypersphere_peri_ln.yaml) to compare capped hypersphere norms vs RMSNorm under peri-LN and pre-LN settings.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 8 comments.

File Description
variations/norm_variations.py Extends CappedHyperSphereNorm with optional gain, but currently misses other hsnorm_* behaviors used by the new sweep.
explorations/relu2max_capped_hypersphere_peri_ln.yaml Introduces a new sweep configuration for peri-LN + ReLU2Max + capped hypersphere norms (with a couple of sweep-definition issues to fix).

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines +212 to +224
ndim = config.n_embd

self.radius = math.sqrt(ndim)

if config.hsnorm_gain:
self.gain = nn.Parameter(torch.ones(ndim))
else:
self.gain = 1.0

def forward(self, x):
norms = x.norm(2, dim=-1, keepdim=True)
scale = torch.where(norms > self.radius, self.radius / (norms + 1e-8), torch.ones_like(norms))
return x * scale
return x * scale * self.gain
Comment on lines +67 to +70
named_variation_groups:
- named_group: "wte_norm_var"
named_group_alternates: ["capped_rmsnorm", "capped_pair", "rmsnorm"]

Comment on lines +92 to +94
named_group_variations:
- "wte_norm_var"
# Peri-LN WTE Norm
Comment on lines +107 to +109
hsnorm_radius_learning: [true]
named_group_variations:
- "wte_norm_var"
Comment on lines +120 to +122
hsnorm_radius_learning: [true]
named_group_variations:
- "wte_norm_var"
Comment on lines +83 to +88
- "qk_norm"
- "peri_ln"
- "rotary"
- "relu2max"
- "infinite"
- "hd_150"
Comment on lines +96 to +101
- "qk_norm"
- "peri_ln"
- "rotary"
- "relu2max"
- "infinite"
- "hd_150"
Comment on lines +112 to +117
- "qk_norm"
- "pre_ln"
- "rotary"
- "relu2max"
- "infinite"
- "hd_150"
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants