diff --git a/README.md b/README.md index 6820be3..726b7bf 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,303 @@ -# OpenAI Grok Curve Experiments +Readme · MD +
-## Paper +# Does Grok grok grokking? -This is the code for the paper [Grokking: Generalization Beyond Overfitting on Small Algorithmic Datasets](https://arxiv.org/abs/2201.02177) by Alethea Power, Yuri Burda, Harri Edwards, Igor Babuschkin, and Vedant Misra +

+A functional bridge between +openai/grok +and +xai-org/grok-1 +

-## Installation and Training +

+Two projects that share a name, a history, and nothing else. +

+ +
+ +--- + +> **grok** /ɡrɒk/ *verb.* To understand something so thoroughly that it becomes part of you. +> +> — Robert A. Heinlein, *Stranger in a Strange Land* (1961) + +--- + +Three words. Three meanings. One question. + +**Grok** (noun) — xAI's 314-billion-parameter Mixture-of-Experts language model, [open-sourced](https://github.com/xai-org/grok-1) on March 17, 2024, under Apache 2.0. Named after Heinlein's verb. + +**grok** (verb) — to understand something with such intimacy that observer and observed merge. The word entered hacker culture in the 1960s and never left. + +**grokking** (noun, ML) — a phenomenon discovered by [Power et al. (2022)](https://arxiv.org/abs/2201.02177) where neural networks, long after memorizing their training set, suddenly and sharply generalize to unseen data. The [research code](https://github.com/openai/grok) was published by OpenAI. + +This repository forces the two `grok` repositories into a single project — not because they belong together, but because naming things is hard, irony compounds, and sometimes the most interesting things happen when you make incompatible things work together through sheer will. + +--- + +## The Two Groks + +| | `openai/grok` | `xai-org/grok-1` | +|---|---|---| +| **What** | Research code for the grokking phenomenon | Inference code for the Grok-1 LLM | +| **Lines** | ~500, PyTorch | ~1400, JAX/Haiku | +| **Parameters** | Trains models with ~100K params | Loads a model with 314B params | +| **GitHub Stars** | ~4.2K | ~49K | +| **Framework** | PyTorch + Lightning | JAX + Haiku | +| **Task** | Modular arithmetic (42 + 55 mod 97) | General language modeling | +| **Named after** | An ML phenomenon | A sci-fi verb | + +> Two `grok`s. One studying how small models suddenly *understand*. +> The other claiming to *be* understanding. +> +> This fork asks: what if we made them talk to each other? + +--- + +## Background + +In February 2023, Elon Musk publicly accused OpenAI of betraying its founding mission. On [February 19](https://x.com/elonmusk/status/1626516035863212034), he posted: + +> *"OpenAI was created as an open source (which is why I named it 'Open' AI), non-profit company to serve as a counterweight to Google, but now it has become a closed source, maximum-profit company effectively controlled by Microsoft."* + +By November 2023, xAI launched **Grok** as a closed product. On February 29, 2024, Musk [filed a lawsuit](https://www.courtlistener.com/docket/68235965/musk-v-altman/) against OpenAI. Eleven days later, on March 11, he [announced](https://x.com/elonmusk/status/1767108624038449405) that xAI would open-source Grok — and did so on March 17. + +Meanwhile, on GitHub: + +``` +$ git merge openai/grok xai-org/grok-1 + +CONFLICT (rename/rename): both repos are named "grok" +CONFLICT (content): PyTorch ≠ JAX +CONFLICT (scale): 100,000 params ≠ 314,000,000,000 params +CONFLICT (purpose): studying understanding ≠ claiming to understand + +Automatic merge failed; fix conflicts and then commit the result. +``` + +No one had thought to connect them. This fork does. + +--- + +## Architecture + +The bridge replicates Grok-1's architectural components in PyTorch and plugs them into OpenAI's training framework: + +``` + openai/grok xai-org/grok-1 + ┌───────────────────┐ ┌──────────────────────┐ + │ │ │ │ + │ Sinusoidal PE │ ─── replicated as RoPE ────► │ Rotary PE (RoPE) │ + │ │ │ │ + │ Multi-Head Attn │ ─── replicated as MHA ─────► │ GQA (48q / 8kv) │ + │ │ │ │ + │ FFN │ │ MoE FFN │ + │ ReLU(xW₁)W₂ │ ─── replicated as MoE ─────► │ 8 experts, top-2 │ + │ │ │ GELU(xW₁)⊙(xWᵥ)W₂ │ + │ │ │ │ + │ LayerNorm │ ─── replicated as RMSNorm ─► │ RMSNorm │ + │ │ │ │ + └───────────────────┘ └──────────────────────┘ + ~100K params 314,000,000K params + PyTorch + Lightning JAX + Haiku + task: 42 + 55 mod 97 task: everything +``` + +The resulting class — `GrokOneTransformer` — is a miniature Grok-1 that can be trained on the same arithmetic datasets, with the same training loop, on a fundamentally different optimizer landscape. + +--- + +## What This Does + +This is not a toy. The bridge is functional. + +### Bridge A: Grok-1 Architecture → OpenAI Framework + +Grok-1's architectural innovations transplanted into PyTorch for grokking experiments: + +- **Mixture of Experts** — 8 expert FFNs with top-2 routing, exactly as in Grok-1 +- **Rotary Positional Embeddings (RoPE)** — replacing sinusoidal encoding +- **RMSNorm** — replacing standard LayerNorm +- **Gated GELU (SwiGLU-style)** — replacing ReLU FFN + +```bash +# Standard grokking experiment (original, unchanged) +./grok-main/scripts/train.py --math_operator + --train_data_pct 5 + +# Grok-1 architecture — same task, different geometry +./grok-main/scripts/train.py --architecture grok1 --math_operator + --num_experts 8 + +# Auto-scaled miniature Grok-1 +./grok-main/scripts/train.py --architecture grok1_mini +``` + +New metrics track MoE-specific grokking signals per epoch: + +- **Routing entropy** — does expert selection become more uniform during generalization? +- **Expert specialization** — do memorizing models rely on "shortcut" experts? +- **Collapse index** — does routing collapse correlate with the training plateau? + +### Bridge B: OpenAI Tasks → Grok-1 Inference + +OpenAI's arithmetic evaluation brought to Grok-1's inference pipeline. Does a 314B language model know that 42 + 55 ≡ 0 (mod 97)? + +```bash +# Dry run — inspect problem generation, no checkpoint needed +python grok-1-main/run.py --eval-grokking --operator + --n-samples 50 --dry-run + +# Full evaluation (requires the 314B checkpoint, ~300GB) +python grok-1-main/run.py --eval-grokking --operator + --n-samples 100 +``` + +### Bridge C: Config Export + +Grok-1's `TransformerConfig` can now export itself as a scaled-down PyTorch-compatible config: + +```python +from model import TransformerConfig + +config = TransformerConfig( + emb_size=6144, num_layers=64, num_q_heads=48, + num_kv_heads=8, num_experts=8, num_selected_experts=2 +) + +# Scale down by 1/24 for a trainable experiment +mini = config.to_grokking_config(scale_factor=1/24) +# → {'d_model': 256, 'n_layers': 2, 'n_heads': 2, 'num_experts': 8, ...} +``` + +--- + +## The Research Question + +Beyond the provocation, there is a real empirical question. + +The 2022 grokking paper showed that dense transformers exhibit a sharp phase transition from memorization to generalization on algorithmic tasks. The transition is abrupt, nearly discontinuous, and poorly understood. But MoE models have a fundamentally different optimization geometry — the router introduces a discrete dispatch decision that creates a non-smooth loss landscape, load-balancing pressures, and the possibility of *routing collapse*. + +**Does the Mixture-of-Experts architecture change the grokking phenomenon?** + +Three testable hypotheses: + +1. **Routing entropy as a leading indicator** — Does the expert distribution become more uniform *before* the grokking transition shows up in the loss curve? If so, routing entropy might predict generalization before it happens. + +2. **Expert collapse → memorization trap** — If all tokens route to one expert, that expert memorizes everything. Grokking might require "breaking out" of this collapse first. + +3. **MoE capacity delays onset** — More experts means more room to memorize without pressure to generalize. Does a larger expert count push the phase transition further out in training time, or does the routing bottleneck actually accelerate it? + +None of these have been tested. This fork provides the infrastructure to test them. + +--- + +## Quick Start ```bash -pip install -e . -./scripts/train.py +# Verify both architectures instantiate and run forward passes (no GPU needed) +python does_grok_grok.py --demo + +# Run a comparative grokking experiment (standard vs Grok-1, logs to ./logs/) +python does_grok_grok.py --experiment --operator + --max-steps 50000 + +# Test Grok-1's arithmetic ability (dry run, no checkpoint needed) +python does_grok_grok.py --eval-grok1 --dry-run --n-samples 20 +``` + +Expected output from `--demo`: + +``` +[1] Standard transformer (dense, sinusoidal PE) + Params: 329,860 · 2L / 4H / 128D + +[2] Grok-1 architecture (MoE + RoPE + RMSNorm + gated GELU) + Params: 4,610,148 · 2L / 2H / 256D / 8 experts (top-2) + +[3] MoE routing (layer 0): + Expert 0 ████████░░░░░░░░ 0.125 + Expert 1 ████████░░░░░░░░ 0.125 + Expert 2 ████████░░░░░░░░ 0.125 + ... + + Routing entropy: 2.079 / 2.079 (perfectly uniform) + +[4] Config export (Grok-1 → PyTorch): + {'d_model': 256, 'n_layers': 2, 'n_heads': 2, 'num_experts': 8, ...} + +Both architectures operational. Bridge verified. ``` + +--- + +## Structure + +``` +. +├── does_grok_grok.py ← unified entry point +├── README.md ← you are here +│ +├── grok-main/ ← openai/grok (grokking research) +│ ├── grok/ +│ │ ├── transformer.py ← MODIFIED: +GrokOneTransformer, +MoE, +RoPE, +RMSNorm +│ │ ├── training.py ← MODIFIED: +architecture selection, +MoE metrics logging +│ │ ├── metrics.py ← MODIFIED: +expert_utilization_entropy, +specialization +│ │ ├── data.py ← MODIFIED: +format_for_grok1(), +eval suite generator +│ │ ├── __init__.py ← MODIFIED: +bridge exports +│ │ ├── measure.py ← unchanged +│ │ └── visualization.py ← unchanged +│ ├── setup.py ← MODIFIED: version bump +│ ├── scripts/ ← unchanged +│ └── nbs/ ← unchanged +│ +└── grok-1-main/ ← xai-org/grok-1 (Grok-1 314B) + ├── model.py ← MODIFIED: +to_grokking_config(), +architecture_summary() + ├── run.py ← MODIFIED: +--eval-grokking mode, +arithmetic eval + ├── runners.py ← unchanged + ├── checkpoint.py ← unchanged + ├── tokenizer.model ← unchanged + └── checkpoints/ ← unchanged (download separately) +``` + +Design principle: **every original file still works exactly as before.** All additions are appended, never replacing. Every bridge addition is marked with a `# Bridge:` comment. `git diff` against the upstream repos shows only additions. + +--- + +## Why + +Because both projects are named `grok`, and someone had to do it. + +Because Musk accused OpenAI of abandoning open source, then named his AI after a word that means to understand deeply, then open-sourced it eleven days after filing a lawsuit — while OpenAI had a research project with the same name, sitting quietly on GitHub, studying how models *learn to understand*. + +Because the word "grok" deserves better than to be caught in the crossfire. + +And because the answer to "Does Grok grok grokking?" is genuinely worth knowing. + +--- + +## Sources + +- [openai/grok](https://github.com/openai/grok) — Grokking: Generalization Beyond Overfitting on Small Algorithmic Datasets +- [xai-org/grok-1](https://github.com/xai-org/grok-1) — Grok-1 open weights (314B MoE) +- [Power et al. (2022)](https://arxiv.org/abs/2201.02177) — "Grokking: Generalization Beyond Overfitting on Small Algorithmic Datasets" +- [Su et al. (2021)](https://arxiv.org/abs/2104.09864) — "RoFormer: Enhanced Transformer with Rotary Position Embedding" +- [Heinlein (1961)](https://en.wikipedia.org/wiki/Stranger_in_a_Strange_Land) — *Stranger in a Strange Land* +- [Musk on OpenAI](https://x.com/elonmusk/status/1626516035863212034) — February 19, 2023 +- [xAI open-sources Grok](https://x.ai/blog/grok-os) — March 17, 2024 +- [Musk v. Altman](https://www.courtlistener.com/docket/68235965/musk-v-altman/) — Filed February 29, 2024 + +--- + +## License + +- `grok-main/` — [MIT License](grok-main/LICENSE) (original) +- `grok-1-main/` — [Apache 2.0](grok-1-main/LICENSE.txt) (original) +- Bridge code — public domain + +--- + +
+ +"The word is much wider in meaning than any English word conceived to date —
+it means to understand so thoroughly that the observer becomes a part of the observed."
+

+— Heinlein, via Jubal Harshaw +
+
diff --git a/does_grok_grok.py b/does_grok_grok.py new file mode 100644 index 0000000..446b464 --- /dev/null +++ b/does_grok_grok.py @@ -0,0 +1,287 @@ +#!/usr/bin/env python +""" +does_grok_grok.py — The unified entry point of 强兼 (forceful compatibility) + + ╔═══════════════════════════════╗ + ║ Does Grok grok grokking? ║ + ╚═══════════════════════════════╝ + +This script bridges two open-source projects that share a name but nothing else: + + • openai/grok — Studies "grokking": the phenomenon where neural networks + suddenly generalize after prolonged memorization. Trains + small transformers on modular arithmetic. (PyTorch) + + • xai-org/grok-1 — The 314B-parameter Mixture-of-Experts language model + from xAI, open-sourced under Apache 2.0. (JAX/Haiku) + +This script asks a simple question with a beautiful double meaning: +Does Grok (the model) grok (deeply understand) grokking (the phenomenon)? + +We answer it in two ways: + + MODE 1: "Grok-1 Architecture Grokking" (--experiment) + Train a miniature version of Grok-1's architecture (MoE + RoPE + RMSNorm + + gated GELU) on OpenAI's grokking benchmark tasks, and compare the + grokking dynamics against the original dense transformer. Do MoE models + grok differently? Does the expert routing change during the phase + transition from memorization to generalization? + + MODE 2: "Does Grok-1 Know Arithmetic?" (--eval-grok1) + Generate modular arithmetic problems from OpenAI's grokking dataset + and evaluate Grok-1 on them. (Requires Grok-1 checkpoint.) + +Usage: + # Compare grokking: standard transformer vs. Grok-1 architecture + python does_grok_grok.py --experiment --operator + --max-steps 50000 + + # Quick demo (no training, just shows the architecture bridge) + python does_grok_grok.py --demo + + # Evaluate Grok-1 on arithmetic (requires checkpoint) + python does_grok_grok.py --eval-grok1 --checkpoint ./grok-1-main/checkpoints/ + +Copyright: This bridge is a work of 行为艺术 (behavioral art). +License: Both source projects' licenses apply to their respective code. + The bridge code itself is unlicensed — do whatever you want with it. +""" + +import argparse +import sys +import os +import json +import time + +# Add both project directories to path +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "grok-main")) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "grok-1-main")) + + +BANNER = r""" + ╭──────────────────────────────────────────────────────────────╮ + │ │ + │ ██████╗ ██████╗ ██████╗ ██╗ ██╗ │ + │ ██╔════╝ ██╔══██╗██╔═══██╗██║ ██╔╝ │ + │ ██║ ███╗██████╔╝██║ ██║█████╔╝ │ + │ ██║ ██║██╔══██╗██║ ██║██╔═██╗ │ + │ ╚██████╔╝██║ ██║╚██████╔╝██║ ██╗ │ + │ ╚═════╝ ╚═╝ ╚═╝ ╚═════╝ ╚═╝ ╚═╝ │ + │ │ + │ 强 兼 │ + │ — forceful compatibility — │ + │ │ + │ openai/grok ←──bridge──→ xai-org/grok-1 │ + │ (grokking) (Grok-1 314B) │ + │ PyTorch JAX/Haiku │ + │ │ + │ "Does Grok grok grokking?" │ + │ │ + ╰──────────────────────────────────────────────────────────────╯ +""" + + +def demo(): + """ + Demonstrate the 强兼 bridge without training or inference. + Shows that Grok-1's architecture can be instantiated in PyTorch + at grokking-experiment scale. + """ + print(BANNER) + print("=" * 66) + print(" MODE: Demo — Architecture Bridge Verification") + print("=" * 66) + + # Import from OpenAI's grokking framework + from grok.transformer import Transformer, GrokOneTransformer + from grok.data import ArithmeticDataset, MODULUS + + print("\n[1] Original OpenAI transformer (dense, sinusoidal PE):") + standard = Transformer( + n_layers=2, n_heads=4, d_model=128, + dropout=0.0, max_context_len=50, vocab_len=100, + ) + n_params_standard = sum(p.numel() for p in standard.parameters()) + print(f" Params: {n_params_standard:,}") + print(f" Architecture: {standard.n_layers}L / {standard.n_heads}H / {standard.d_model}D") + + print("\n[2] Grok-1 architecture (MoE + RoPE + RMSNorm + gated GELU):") + grok1_mini = GrokOneTransformer.from_grok1_config( + scale_factor=1/24, vocab_len=100, max_context_len=50, + ) + n_params_grok1 = sum(p.numel() for p in grok1_mini.parameters()) + print(f" Params: {n_params_grok1:,}") + print(f" Architecture: {grok1_mini.n_layers}L / {grok1_mini.n_heads}H / " + f"{grok1_mini.d_model}D / {grok1_mini.num_experts}E (top-{grok1_mini.num_selected_experts})") + + print(f"\n[3] Parameter ratio: Grok-1-mini is {n_params_grok1/n_params_standard:.1f}x " + f"the standard transformer") + print(f" (Real Grok-1 is ~314B params — " + f"{314_000_000_000/n_params_grok1:.0f}x this miniature)") + + # Quick forward pass test + import torch + print("\n[4] Forward pass test:") + dummy_input = torch.randint(0, 100, (2, 10)) # batch=2, seq=10 + + with torch.no_grad(): + out_std, _, _ = standard(dummy_input) + out_grok, _, _ = grok1_mini(dummy_input) + + print(f" Standard: input {tuple(dummy_input.shape)} → output {tuple(out_std.shape)}") + print(f" Grok-1: input {tuple(dummy_input.shape)} → output {tuple(out_grok.shape)}") + + # Show routing info + if grok1_mini.last_router_probs: + rp = grok1_mini.last_router_probs[0] + print(f"\n[5] MoE Routing (layer 0):") + print(f" Router probs shape: {tuple(rp.shape)}") + mean_probs = rp.mean(dim=(0, 1)) + for i, p in enumerate(mean_probs): + bar = "█" * int(p.item() * 40) + print(f" Expert {i}: {p.item():.3f} {bar}") + + # Cross-framework config export + print("\n[6] Cross-framework config export (JAX → PyTorch):") + try: + from model import TransformerConfig as Grok1TransformerConfig + grok1_config = Grok1TransformerConfig( + emb_size=48 * 128, widening_factor=8, key_size=128, + num_q_heads=48, num_kv_heads=8, num_layers=64, + num_experts=8, num_selected_experts=2, + ) + exported = grok1_config.to_grokking_config() + print(f" Grok-1 ({grok1_config.emb_size}D, {grok1_config.num_layers}L) →") + print(f" Grokking ({exported['d_model']}D, {exported['n_layers']}L, " + f"{exported['num_experts']}E)") + except ImportError: + print(" (Grok-1 model.py not in path — run from project root)") + + print("\n" + "=" * 66) + print(" Bridge verified. Both architectures are alive and compatible.") + print(" Run with --experiment to compare grokking dynamics.") + print("=" * 66 + "\n") + + +def run_experiment(args): + """ + Train both architectures on the same grokking task and compare. + """ + print(BANNER) + print("=" * 66) + print(" MODE: Comparative Grokking Experiment") + print(f" Task: {args.operator} mod 97 | Steps: {args.max_steps}") + print("=" * 66) + + from grok.training import train, add_args + + # Prepare base arguments + parser = add_args() + base_args = [ + "--math_operator", args.operator, + "--max_steps", str(args.max_steps), + "--train_data_pct", str(args.train_pct), + "--d_model", str(args.d_model), + "--n_layers", str(args.n_layers), + "--n_heads", str(args.n_heads), + ] + + # Run 1: Standard transformer + print("\n" + "─" * 66) + print(" [1/2] Training STANDARD transformer (OpenAI's original)") + print("─" * 66) + std_args = parser.parse_args(base_args + [ + "--architecture", "standard", + "--logdir", os.path.join(args.logdir, "standard"), + ]) + t0 = time.time() + train(std_args) + t_std = time.time() - t0 + + # Run 2: Grok-1 architecture + print("\n" + "─" * 66) + print(" [2/2] Training GROK-1 architecture (MoE + RoPE + RMSNorm)") + print("─" * 66) + grok1_args = parser.parse_args(base_args + [ + "--architecture", "grok1", + "--num_experts", str(args.num_experts), + "--num_selected_experts", str(args.num_selected_experts), + "--logdir", os.path.join(args.logdir, "grok1"), + ]) + t0 = time.time() + train(grok1_args) + t_grok1 = time.time() - t0 + + print("\n" + "=" * 66) + print(" EXPERIMENT COMPLETE") + print(f" Standard: {t_std:.1f}s | Grok-1: {t_grok1:.1f}s") + print(f" Logs: {args.logdir}/") + print("=" * 66) + print("\n Compare training curves to see if MoE affects grokking dynamics.") + print(" Key questions:") + print(" • Does the Grok-1 architecture grok faster or slower?") + print(" • Does expert routing entropy change at the grokking transition?") + print(" • Do MoE models find different generalization shortcuts?\n") + + +def eval_grok1(args): + """Evaluate Grok-1 on grokking arithmetic (delegates to grok-1-main/run.py).""" + print(BANNER) + os.chdir(os.path.join(os.path.dirname(__file__), "grok-1-main")) + from run import eval_grokking + eval_args = argparse.Namespace( + operator=args.operator, + n_samples=args.n_samples, + checkpoint_path=args.checkpoint, + tokenizer_path=args.tokenizer, + output=args.output, + dry_run=args.dry_run, + ) + eval_grokking(eval_args) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="强兼 — Does Grok grok grokking?", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + python does_grok_grok.py --demo + python does_grok_grok.py --experiment --operator + --max-steps 50000 + python does_grok_grok.py --eval-grok1 --dry-run + """, + ) + + mode = parser.add_mutually_exclusive_group(required=True) + mode.add_argument("--demo", action="store_true", + help="Demo: verify the architecture bridge works") + mode.add_argument("--experiment", action="store_true", + help="Run comparative grokking experiment") + mode.add_argument("--eval-grok1", action="store_true", + help="Evaluate Grok-1 on arithmetic tasks") + + # Experiment args + parser.add_argument("--operator", type=str, default="+") + parser.add_argument("--max-steps", type=int, default=50000) + parser.add_argument("--train-pct", type=float, default=5) + parser.add_argument("--d-model", type=int, default=128) + parser.add_argument("--n-layers", type=int, default=2) + parser.add_argument("--n-heads", type=int, default=4) + parser.add_argument("--num-experts", type=int, default=8) + parser.add_argument("--num-selected-experts", type=int, default=2) + parser.add_argument("--logdir", type=str, default="logs/qiangjian") + + # Eval args + parser.add_argument("--checkpoint", type=str, default="./grok-1-main/checkpoints/") + parser.add_argument("--tokenizer", type=str, default="./grok-1-main/tokenizer.model") + parser.add_argument("--n-samples", type=int, default=50) + parser.add_argument("--output", type=str, default=None) + parser.add_argument("--dry-run", action="store_true") + + args = parser.parse_args() + + if args.demo: + demo() + elif args.experiment: + run_experiment(args) + elif args.eval_grok1: + eval_grok1(args) diff --git a/grok-1-main/.gitignore b/grok-1-main/.gitignore new file mode 100644 index 0000000..24d0d7e --- /dev/null +++ b/grok-1-main/.gitignore @@ -0,0 +1,2 @@ +checkpoints/* +!checkpoints/README.md diff --git a/grok-1-main/CODE_OF_CONDUCT.md b/grok-1-main/CODE_OF_CONDUCT.md new file mode 100644 index 0000000..d715425 --- /dev/null +++ b/grok-1-main/CODE_OF_CONDUCT.md @@ -0,0 +1 @@ +Be excellent to each other. diff --git a/grok-1-main/LICENSE.txt b/grok-1-main/LICENSE.txt new file mode 100644 index 0000000..d645695 --- /dev/null +++ b/grok-1-main/LICENSE.txt @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/grok-1-main/README.md b/grok-1-main/README.md new file mode 100644 index 0000000..f68f21a --- /dev/null +++ b/grok-1-main/README.md @@ -0,0 +1,72 @@ +# Grok-1 + +> **Note:** This is `xai-org/grok-1` — the 314B MoE language model. This fork bridges it with [`openai/grok`](../grok-main/) (the grokking research framework). See the [root README](../README.md) for context. + +This repository contains JAX example code for loading and running the Grok-1 open-weights model. + +Make sure to download the checkpoint and place the `ckpt-0` directory in `checkpoints` - see [Downloading the weights](#downloading-the-weights) + +Then, run + +```shell +pip install -r requirements.txt +python run.py +``` + +to test the code. + +The script loads the checkpoint and samples from the model on a test input. + +Due to the large size of the model (314B parameters), a machine with enough GPU memory is required to test the model with the example code. +The implementation of the MoE layer in this repository is not efficient. The implementation was chosen to avoid the need for custom kernels to validate the correctness of the model. + +# Model Specifications + +Grok-1 is currently designed with the following specifications: + +- **Parameters:** 314B +- **Architecture:** Mixture of 8 Experts (MoE) +- **Experts Utilization:** 2 experts used per token +- **Layers:** 64 +- **Attention Heads:** 48 for queries, 8 for keys/values +- **Embedding Size:** 6,144 +- **Tokenization:** SentencePiece tokenizer with 131,072 tokens +- **Additional Features:** + - Rotary embeddings (RoPE) + - Supports activation sharding and 8-bit quantization +- **Maximum Sequence Length (context):** 8,192 tokens + +# Downloading the weights + +You can download the weights using a torrent client and this magnet link: + +``` +magnet:?xt=urn:btih:5f96d43576e3d386c9ba65b883210a393b68210e&tr=https%3A%2F%2Facademictorrents.com%2Fannounce.php&tr=udp%3A%2F%2Ftracker.coppersurfer.tk%3A6969&tr=udp%3A%2F%2Ftracker.opentrackr.org%3A1337%2Fannounce +``` + +or directly using [HuggingFace 🤗 Hub](https://huggingface.co/xai-org/grok-1): +``` +git clone https://github.com/xai-org/grok-1.git && cd grok-1 +pip install huggingface_hub[hf_transfer] +huggingface-cli download xai-org/grok-1 --repo-type model --include ckpt-0/* --local-dir checkpoints --local-dir-use-symlinks False +``` + +# Grokking Evaluation Mode + +This fork adds a `--eval-grokking` mode to `run.py` that evaluates Grok-1 on modular arithmetic problems from OpenAI's grokking research: + +```bash +# Dry run (no checkpoint needed) +python run.py --eval-grokking --operator + --n-samples 50 --dry-run + +# Full evaluation (requires checkpoint) +python run.py --eval-grokking --operator + --n-samples 100 +``` + +Also adds `to_grokking_config()` to `TransformerConfig` for exporting Grok-1's architecture at experiment scale. + +# License + +The code and associated Grok-1 weights in this release are licensed under the +Apache 2.0 license. The license only applies to the source files in this +repository and the model weights of Grok-1. diff --git a/grok-1-main/checkpoint.py b/grok-1-main/checkpoint.py new file mode 100644 index 0000000..1c6e878 --- /dev/null +++ b/grok-1-main/checkpoint.py @@ -0,0 +1,221 @@ +# Copyright 2024 X.AI Corp. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import contextlib +import logging +import math +import os +import pickle +import re +import shutil +import sys +import tempfile +from concurrent.futures import ThreadPoolExecutor, wait +from typing import Any, Optional + +import jax +import numpy as np +from jax.experimental import multihost_utils + +from model import QuantizedWeight8bit + +logger = logging.getLogger(__name__) +rank_logger = logging.getLogger("rank") + +# Needed for loading the checkpoint with pickle. +sys.modules['__main__'].QuantizedWeight8bit = QuantizedWeight8bit + + +@contextlib.contextmanager +def copy_to_shm(file: str): + if file.startswith("/dev/shm/"): + # Nothing to do, the file is already in shared memory. + yield file + return + + tmp_dir = "/dev/shm/" + fd, tmp_path = tempfile.mkstemp(dir=tmp_dir) + try: + shutil.copyfile(file, tmp_path) + yield tmp_path + finally: + os.remove(tmp_path) + os.close(fd) + + +@contextlib.contextmanager +def copy_from_shm(file: str): + tmp_dir = "/dev/shm/" + fd, tmp_path = tempfile.mkstemp(dir=tmp_dir) + try: + yield tmp_path + shutil.copyfile(tmp_path, file) + finally: + os.remove(tmp_path) + os.close(fd) + + +def fast_unpickle(path: str) -> Any: + with copy_to_shm(path) as tmp_path: + with open(tmp_path, "rb") as f: + return pickle.load(f) + + +def fast_pickle(obj: Any, path: str) -> None: + with copy_from_shm(path) as tmp_path: + with open(tmp_path, "wb") as f: + pickle.dump(obj, f) + + +def load_tensors(shaped_arrays, directory, mesh_config, tensor_indices=None): + """Loads a set of arrays.""" + pool = ThreadPoolExecutor(max_workers=32) + fs = list() + num_tensors = 0 + num_replicas = 1 + data_model_shards = math.prod(mesh_config) + if tensor_indices is None: + iterator = enumerate(shaped_arrays) + else: + iterator = zip(tensor_indices, shaped_arrays) + for i, t in iterator: + if (i % num_replicas) == ((jax.process_index() // data_model_shards) % num_replicas): + idx = ( + jax.process_index() // (num_replicas * data_model_shards) * data_model_shards + + jax.process_index() % data_model_shards + ) + fs.append( + pool.submit(fast_unpickle, os.path.join(directory, f"tensor{i:05d}_{idx:03d}")) + ) + num_tensors += 1 + else: + fs.append(pool.submit(np.zeros, t.shape, dtype=t.dtype)) + wait(fs) + return [f.result() for f in fs] + + +def path_tuple_to_string(path: tuple) -> str: + pieces = [] + for elem in path: + if isinstance(elem, jax.tree_util.DictKey): + pieces.append(elem.key) + elif isinstance(elem, jax.tree_util.GetAttrKey): + pieces.append(elem.name) + else: + assert isinstance(elem, (jax.tree_util.FlattenedIndexKey, jax.tree_util.SequenceKey)) + return "/".join(pieces) + + +def get_load_path_str( + init_path_str: str, + load_rename_rules: Optional[list[tuple[str, str]]] = None, + load_exclude_rules: Optional[list[str]] = None, +) -> Optional[str]: + # Exclusion + if load_exclude_rules is not None: + for search_pattern in load_exclude_rules: + if re.search(search_pattern, init_path_str): + return None + + # Renaming + load_path_str = init_path_str + if load_rename_rules is not None: + for search_pattern, replacement_pattern in load_rename_rules: + if re.search(search_pattern, load_path_str): + load_path_str = re.sub(search_pattern, replacement_pattern, load_path_str) + break + + return load_path_str + + +def replace_with_load_state( + init_state: Any, + load_state: Any, + load_rename_rules: Optional[list[tuple[str, str]]] = None, + load_exclude_rules: Optional[list[str]] = None, + mesh_config: tuple = (1, 1), +) -> Any: + flatten_load, _ = jax.tree_util.tree_flatten_with_path(load_state) + flatten_init, structure_init = jax.tree_util.tree_flatten_with_path(init_state) + load_map = {path_tuple_to_string(path): tensor for path, tensor in flatten_load} + + replaced = [] + num_replicas = 1 + data_model_shards = math.prod(mesh_config) + for i, (init_path, tensor) in enumerate(flatten_init): + init_path_str = path_tuple_to_string(init_path) + load_path_str = get_load_path_str(init_path_str, load_rename_rules, load_exclude_rules) + if load_path_str is None: + rank_logger.info(f"Excluded from restore: {init_path_str}.") + replaced.append(tensor) + elif load_path_str in load_map: + if load_path_str == init_path_str: + rank_logger.info(f"Restored from ckpt: {init_path_str}.") + else: + rank_logger.info(f"Restored from ckpt: {init_path_str} <-- {load_path_str}.") + replaced.append(load_map[load_path_str]) + else: + rank_logger.info(f"Not found in ckpt: {init_path_str}.") + if (i % num_replicas) == ((jax.process_index() // data_model_shards) % num_replicas): + replaced.append(tensor) + else: + replaced.append(np.zeros_like(tensor)) + + return jax.tree_util.tree_unflatten(structure_init, replaced) + + +def restore( + checkpoint_path: str, + state_shapes: Any, + mesh, + between_hosts_config, + params_only, + state_sharding, + init_state: Optional[Any] = None, +) -> Any: + ckpt_path = os.path.join(checkpoint_path, "ckpt-0") + + rank_logger.info("Loading checkpoint at {}".format(ckpt_path)) + ckpt_shapes = state_shapes + ckpt_shapes_with_path, structure = jax.tree_util.tree_flatten_with_path(ckpt_shapes) + + ckpt_shapes_flat = [elem[1] for elem in ckpt_shapes_with_path] + loaded_tensors = load_tensors(ckpt_shapes_flat, ckpt_path, between_hosts_config) + + state = jax.tree_util.tree_unflatten(structure, loaded_tensors) + + # Sanity check to give a better error message. + ckpt_keys = set(state.params.keys()) + code_keys = set(state_sharding.params.keys()) + + if ckpt_keys != code_keys and init_state is None: + missing_in_ckpt = code_keys - ckpt_keys + missing_locally = ckpt_keys - code_keys + raise ValueError( + "Parameters in the code are not matching checkpoint parameters.\n" + "Params missing in checkpoint: {}\nParams missing in code: {}".format( + missing_in_ckpt, missing_locally + ) + ) + state_sharding = jax.tree_util.tree_map( + lambda x: jax.sharding.PartitionSpec() if x is None else x, + state_sharding, + is_leaf=lambda x: x is None, + ) + state = multihost_utils.host_local_array_to_global_array(state, mesh, state_sharding) + if params_only: + state = state.params + return state diff --git a/grok-1-main/checkpoints/README.md b/grok-1-main/checkpoints/README.md new file mode 100644 index 0000000..fc34b62 --- /dev/null +++ b/grok-1-main/checkpoints/README.md @@ -0,0 +1,3 @@ +# Checkpoint directory + +Place Grok-1 checkpoints here so they can be loaded by the example script. diff --git a/grok-1-main/model.py b/grok-1-main/model.py new file mode 100644 index 0000000..4869808 --- /dev/null +++ b/grok-1-main/model.py @@ -0,0 +1,1462 @@ +# Copyright 2024 X.AI Corp. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +import logging +import re +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union + +import haiku as hk +import jax +import jax.experimental.maps +import jax.numpy as jnp +from jax import config, tree_util +from jax.experimental.shard_map import shard_map +from jax.lax import with_sharding_constraint as pjit_sharding_constraint +from jax.sharding import PartitionSpec +from jax.sharding import PartitionSpec as P + +config.update("jax_spmd_mode", "allow_all") + +logger = logging.getLogger(__name__) +rank_logger = logging.getLogger("rank") + + +@dataclass +class QuantizedWeight8bit: + weight: jnp.array + scales: jnp.array + + @property + def shape(self): + return self.weight.shape + + +tree_util.register_pytree_node( + QuantizedWeight8bit, + lambda qw: ([qw.weight, qw.scales], ()), + lambda _, children: QuantizedWeight8bit(children[0], children[1]), +) + + +class TrainingState(NamedTuple): + """Container for the training state.""" + + params: hk.Params + + +def _match(qs, ks): + """Return True if regexes in qs match any window of strings in tuple ks.""" + # compile regexes and force complete match + qts = tuple(map(lambda x: re.compile(x + "$"), qs)) + for i in range(len(ks) - len(qs) + 1): + matches = [x.match(y) for x, y in zip(qts, ks[i:])] + if matches and all(matches): + return True + return False + + +def with_sharding_constraint(x, constraint): + if jax.experimental.maps.thread_resources.env.physical_mesh.empty: + return x + else: + return pjit_sharding_constraint(x, constraint) + + +def cast_bfloat16(x): + if x.dtype.kind == "f": + return x.astype(jnp.bfloat16) + else: + return x + + +def ffn_size(emb_size, widening_factor): + _ffn_size = int(widening_factor * emb_size) * 2 // 3 + _ffn_size = _ffn_size + (8 - _ffn_size) % 8 # ensure it's a multiple of 8 + logger.debug(f"emd_size: {emb_size} adjusted ffn_size: {_ffn_size}") + return _ffn_size + + +def apply_rules(rules): + def _apply_rules(path, value): + del value # Unused. + + path_list = [str(i.key).split("/") for i in path if isinstance(i, jax.tree_util.DictKey)] + flattened_path = jax.tree_util.tree_flatten(path_list)[0] + + for rule, replacement in rules: + if _match(rule, flattened_path): + if isinstance(replacement, PartitionSpec): + if "layer_stack" in flattened_path: + replacement = PartitionSpec(None, *replacement) + rank_logger.debug(f"Apply {replacement} to {flattened_path} with rule {rule}") + return replacement + rank_logger.info(f"{flattened_path} no matching found!") + return None + + return _apply_rules + + +TRANSFORMER_PARTITION_RULES = [ + # attention + (("multi_head_attention", "(query|key|value)", "w"), P("data", "model")), + (("multi_head_attention", "(query|key|value)", "b"), P(None)), + (("multi_head_attention", "linear", "w"), P("model", "data")), + (("multi_head_attention", "linear", "b"), P(None)), + # mlp + ((r"decoder_layer_[0-9]+", "linear", "w"), P("data", "model")), + ((r"decoder_layer_[0-9]+", "linear", "b"), P(None)), + ((r"decoder_layer_[0-9]+", "linear_v", "w"), P("data", "model")), + ((r"decoder_layer_[0-9]+", "linear_v", "b"), P(None)), + ( + (r"decoder_layer_[0-9]+", "linear_1", "w"), + P( + "model", + "data", + ), + ), + ((r"decoder_layer_[0-9]+", "linear_1", "b"), P(None)), + # layer norms + ((r"decoder_layer_[0-9]+", "layer_norm", "offset"), P(None)), + ((r"decoder_layer_[0-9]+", "layer_norm", "scale"), P(None)), + ((r"decoder_layer_[0-9]+", "layer_norm_1", "offset"), P(None)), + ((r"decoder_layer_[0-9]+", "layer_norm_1", "scale"), P(None)), + # rms norms + ((r"decoder_layer_[0-9]+", "rms_norm", "scale"), P(None)), + ((r"decoder_layer_[0-9]+", "rms_norm_1", "scale"), P(None)), + ((r"decoder_layer_[0-9]+", "rms_norm_2", "scale"), P(None)), + ((r"decoder_layer_[0-9]+", "rms_norm_3", "scale"), P(None)), + # router + (("router", "w"), P("data")), + # moe mlp + (("moe", "linear", "w"), P(None, "data", "model")), + (("moe", "linear", "b"), P(None)), + (("moe", "linear_v", "w"), P(None, "data", "model")), + (("moe", "linear_v", "b"), P(None)), + (("moe", "linear_1", "w"), P(None, "model", "data")), + (("moe", "linear_1", "b"), P(None)), + # layer norms + (("moe", "layer_norm", "offset"), P(None)), + (("moe", "layer_norm", "scale"), P(None)), + (("moe", "layer_norm_1", "offset"), P(None)), + (("moe", "layer_norm_1", "scale"), P(None)), + # rms norms + (("moe", "rms_norm", "scale"), P(None)), + (("moe", "rms_norm_1", "scale"), P(None)), + (("moe", "rms_norm_2", "scale"), P(None)), + (("moe", "rms_norm_3", "scale"), P(None)), +] + +LM_PARTITION_RULES = [ + # Embedding layer. + ( + ("language_model", "positional_embeddings"), + P(None, ("data", "model")), + ), + ( + ("language_model", "in_out_embed", "embeddings"), + P(None, ("data", "model")), + ), + # Final RMSNorm. + (("language_model", "rms_norm"), P(None)), +] +TOP_K = 8 + + +class KVMemory(NamedTuple): + k: Optional[jax.Array] + v: Optional[jax.Array] + step: Optional[jax.Array] + + +def init_layer_memories( + batch_size: int, + sequence_len: int, + num_kv_heads: int, + key_size: int, + num_layers: int, + step: Optional[jax.Array] = None, + dtype=jnp.bfloat16, +): + return [ + KVMemory( + k=jnp.zeros((batch_size, sequence_len, num_kv_heads, key_size), dtype=dtype), + v=jnp.zeros((batch_size, sequence_len, num_kv_heads, key_size), dtype=dtype), + step=step, + ) + for _ in range(num_layers) + ] + + +class Memory(NamedTuple): + # Self-attention key/value cache. + layers: List[KVMemory] + + +class Router(hk.Module): + def __init__( + self, + num_selected_experts: int, + data_axis: Union[str, Tuple[str, ...]] = "data", + model_axis: Union[str, Tuple[str, ...]] = "model", + shard_activations: bool = False, + mesh: Any = None, + name: str = "router", + ): + super().__init__(name) + self.shard_activations = shard_activations + self.data_axis = data_axis + self.model_axis = model_axis + self.mesh = mesh + self.num_selected_experts = num_selected_experts + + def compute_routing_prob( + self, inputs: jax.Array, padding_mask: Optional[jax.Array], num_experts: int + ): + return self._compute_routing_prob(inputs, padding_mask, num_experts) + + @hk.transparent + def _compute_routing_prob( + self, + inputs: jax.Array, + padding_mask: Optional[jax.Array], + num_experts: int, + ): + # Using fp32 for the routing prob computation. + inputs = jax.lax.convert_element_type(inputs, jnp.float32) + + # [batch_size, seq_len, num_experts] + routing_logits = self._router_weights(inputs, num_experts, sharding=P("data")) + assert routing_logits.dtype == jnp.float32 + routing_probs = jax.nn.softmax(routing_logits) + + if padding_mask is not None: + routing_probs *= padding_mask + + return routing_probs, routing_logits, 0 + + @hk.transparent + def _router_weights( + self, + x: jax.Array, + num_experts: int, + sharding: Optional[P] = None, + ): + fprop_dtype = x.dtype + if not x.shape: + raise ValueError("Input must not be scalar.") + + input_size = self.input_size = x.shape[-1] + w = hk.get_parameter( + "w", [input_size, num_experts], jnp.float32, init=hk.initializers.Constant(0) + ) + if sharding: + w = with_sharding_constraint(w, sharding) + + out = jnp.dot(x, w.astype(fprop_dtype)) + return out + + +class MoELayer(hk.Module): + def __init__( + self, + num_experts: int, + layer_fn: Callable, + router: Router, + mesh: Any = None, + shard_activations: bool = False, + data_axis: Union[str, Tuple[str, ...]] = "data", + model_axis: Union[str, Tuple[str, ...]] = "model", + name: Optional[str] = "moe", + ): + super().__init__(name) + self.num_experts = num_experts + self.layer_fn = layer_fn + self.router = router + self.mesh = mesh + self.shard_activations = shard_activations + self.data_axis = data_axis + self.model_axis = model_axis + + @hk.transparent + def _inference_call(self, inputs: jax.Array, padding_mask: Optional[jax.Array] = None): + routing_probs, _, _ = self.router.compute_routing_prob( + inputs, padding_mask, self.num_experts + ) + expert_gate, expert_index = jax.lax.top_k(routing_probs, k=self.router.num_selected_experts) + tmp = jnp.reshape(inputs, (inputs.shape[0] * inputs.shape[1], inputs.shape[2])) + broad_inputs = jnp.tile(tmp[:, jnp.newaxis, :], (1, self.router.num_selected_experts, 1)) + broad_inputs = jnp.reshape( + broad_inputs, (broad_inputs.shape[0] * broad_inputs.shape[1], broad_inputs.shape[2]) + ) + init_fn, _ = hk.transform(self.layer_fn) + vmapped_init_fn = jax.vmap(init_fn, in_axes=0, out_axes=0) + lifted_init_fn = hk.experimental.transparent_lift(vmapped_init_fn) + # Fetch the vmapped params of the DenseBlock. + params = lifted_init_fn( + jax.random.split(jax.random.PRNGKey(1), self.num_experts), + jnp.zeros((self.num_experts, 1, 1, inputs.shape[-1])), + ) + + # Index and prob are in the shape [m, 2] indicating which token assigned to which experts. + # b: num_expert + # m: token or sequence dim + # k: input embed dim + # n: output embed dim + # e: the number of experts chosen for each token + @functools.partial( + shard_map, + mesh=self.mesh, + in_specs=( + P(self.data_axis, None), + P(None, None, self.model_axis), + P(None, None, self.model_axis), + P(None), + P(None), + ), + out_specs=P(self.data_axis, self.model_axis), + check_rep=False, + ) + def moe_slow_matmul1(input, weight, scales, index, prob): + weight = weight * scales + one_hot_indices = jax.nn.one_hot(index.reshape(-1), 8, axis=0) + all_expert_output = jnp.einsum("mk,bkn->bmn", input, weight) + output = jnp.einsum("bm,bmn->mn", one_hot_indices, all_expert_output) + return output + + @functools.partial( + shard_map, + mesh=self.mesh, + in_specs=( + P(self.data_axis, self.model_axis), + P(None, self.model_axis, None), + P(None, self.model_axis, None), + P(None), + P(None), + ), + out_specs=P(self.data_axis, None), + check_rep=False, + ) + def moe_slow_matmul2(input, weight, scales, index, prob): + weight = weight * scales + one_hot_indices = jax.nn.one_hot(index.reshape(-1), 8, axis=0) + all_expert_output = jnp.einsum("mk,bkn->bmn", input, weight) + output = jnp.einsum("bm,bmn->mn", one_hot_indices, all_expert_output) + return jax.lax.psum(output, axis_name="model") + + if hasattr(params["linear"]["w"], "scales"): + x = moe_slow_matmul1( + broad_inputs, + params["linear_v"]["w"].weight, + params["linear_v"]["w"].scales, + expert_index, + expert_gate, + ) + y = moe_slow_matmul1( + broad_inputs, + params["linear"]["w"].weight, + params["linear"]["w"].scales, + expert_index, + expert_gate, + ) + y = jax.nn.gelu(y) + out = moe_slow_matmul2( + x * y, + params["linear_1"]["w"].weight, + params["linear_1"]["w"].scales, + expert_index, + expert_gate, + ) + out = jnp.reshape( + out, + [ + inputs.shape[0], + inputs.shape[1], + self.router.num_selected_experts, + out.shape[-1], + ], + ) + out = expert_gate[:, :, :, None].astype(jnp.bfloat16) * out + out = jnp.sum(out, axis=2) + out = out.astype(jnp.bfloat16) + else: + # This is only here so that we can construct a valid init_fn with this code. + return inputs + return out + + def __call__(self, inputs: jax.Array, padding_mask: jax.Array): + return self._inference_call(inputs) + + +class MHAOutput(NamedTuple): + """Outputs of the multi-head attention operation.""" + + embeddings: jax.Array + memory: Any + + +class DecoderOutput(NamedTuple): + embeddings: jax.Array + memory: Any + + +class TransformerOutput(NamedTuple): + embeddings: jax.Array + memory: Any + + +@dataclass +class TransformerConfig: + emb_size: int + key_size: int + num_q_heads: int + num_kv_heads: int + num_layers: int + vocab_size: int = 128 * 1024 + widening_factor: float = 4.0 + + attn_output_multiplier: float = 1.0 + + name: Optional[str] = None + + num_experts: int = -1 + capacity_factor: float = 1.0 + num_selected_experts: int = 1 + + init_scale: float = 1.0 + shard_activations: bool = False + + # Used for activation sharding. + data_axis: Union[str, Tuple[str, ...]] = "data" + model_axis: Union[str, Tuple[str, ...]] = "model" + + def __post_init__(self): + if isinstance(self.data_axis, list): + self.data_axis = tuple(self.data_axis) + if isinstance(self.model_axis, list): + self.model_axis = tuple(self.model_axis) + + def partition_rules(self): + return TRANSFORMER_PARTITION_RULES + + def make(self, mesh=None) -> "Transformer": + data_axis = tuple(self.data_axis) if isinstance(self.data_axis, list) else self.data_axis + model_axis = ( + tuple(self.model_axis) if isinstance(self.model_axis, list) else self.model_axis + ) + + return Transformer( + num_q_heads=self.num_q_heads, + num_kv_heads=self.num_kv_heads, + widening_factor=self.widening_factor, + key_size=self.key_size, + init_scale=self.init_scale, + mesh=mesh, + attn_output_multiplier=self.attn_output_multiplier, + shard_activations=self.shard_activations, + num_layers=self.num_layers, + num_experts=self.num_experts, + num_selected_experts=self.num_selected_experts, + data_axis=data_axis, + model_axis=model_axis, + ) + + def get_memory_sharding(self): + return Memory( + layers=[ + KVMemory( + k=P(self.data_axis, self.model_axis), + v=P(self.data_axis, self.model_axis), + step=P(self.data_axis), + ) + for _ in range(self.num_layers) + ], + ) + + # ── 强兼 Bridge ───────────────────────────────────────────── + # Exports Grok-1's architecture parameters in a format compatible + # with the OpenAI grokking framework (PyTorch). + # Source: https://github.com/openai/grok + # ───────────────────────────────────────────────────────────── + + def to_grokking_config(self, scale_factor: float = 1/24) -> dict: + """ + Export a scaled-down version of this Grok-1 config for grokking + experiments in the OpenAI grokking framework. + + The architecture is preserved (MoE, RoPE, gated GELU) but all + dimensions are scaled down by `scale_factor` to enable training + on consumer hardware. + + Returns a dict compatible with GrokOneTransformer.__init__() from + grok-main/grok/transformer.py. + + >>> config = TransformerConfig(emb_size=6144, key_size=128, ...) + >>> grokking_args = config.to_grokking_config(scale_factor=1/24) + >>> # grokking_args can be passed to GrokOneTransformer(**grokking_args) + """ + d_model = max(64, int(self.emb_size * scale_factor)) + n_heads = max(2, int(self.num_q_heads * scale_factor)) + d_model = d_model - (d_model % n_heads) # ensure divisibility + n_layers = max(2, int(self.num_layers * scale_factor)) + + return { + "n_layers": n_layers, + "n_heads": n_heads, + "d_model": d_model, + "num_experts": self.num_experts, + "num_selected_experts": self.num_selected_experts, + "widening_factor": int(self.widening_factor), + # Metadata for provenance tracking + "_source": "grok-1", + "_original_emb_size": self.emb_size, + "_original_num_layers": self.num_layers, + "_original_num_q_heads": self.num_q_heads, + "_scale_factor": scale_factor, + } + + def architecture_summary(self) -> str: + """Human-readable summary of this Grok-1 configuration.""" + total_params_approx = ( + self.emb_size * self.vocab_size # embedding + + self.num_layers * ( + 4 * self.emb_size * self.key_size * self.num_q_heads # attention + + self.num_experts * 3 * self.emb_size * + int(self.widening_factor * self.emb_size * 2 / 3) # MoE FFN + ) + ) + return ( + f"Grok-1 TransformerConfig\n" + f" Layers: {self.num_layers}\n" + f" Embedding: {self.emb_size}\n" + f" Q-Heads: {self.num_q_heads}\n" + f" KV-Heads: {self.num_kv_heads}\n" + f" Key size: {self.key_size}\n" + f" Experts: {self.num_experts} (top-{self.num_selected_experts})\n" + f" Widening: {self.widening_factor}x\n" + f" ~Params: {total_params_approx/1e9:.1f}B\n" + ) + + +def hk_rms_norm( + x: jax.Array, + fixed_scale=False, + sharding=P(None), +) -> jax.Array: + """Applies a unique LayerNorm to x with default settings.""" + ln = RMSNorm(axis=-1, create_scale=not fixed_scale, sharding=sharding) + return ln(x) + + +def make_attention_mask( + query_input: jax.Array, + key_input: jax.Array, + pairwise_fn: Callable[..., Any] = jnp.multiply, + dtype: Any = jnp.bfloat16, +): + """Mask-making helper for attention weights. + + In case of 1d inputs (i.e., `[batch..., len_q]`, `[batch..., len_kv]`, the + attention weights will be `[batch..., heads, len_q, len_kv]` and this + function will produce `[batch..., 1, len_q, len_kv]`. + + Args: + query_input: a batched, flat input of query_length size + key_input: a batched, flat input of key_length size + pairwise_fn: broadcasting elementwise comparison function + dtype: mask return dtype + + Returns: + A `[batch..., 1, len_q, len_kv]` shaped mask for 1d attention. + """ + mask = pairwise_fn(jnp.expand_dims(query_input, axis=-1), jnp.expand_dims(key_input, axis=-2)) + mask = jnp.expand_dims(mask, axis=-3) + return mask.astype(dtype) + + +class Linear(hk.Linear): + def __init__( + self, + output_size: int, + with_bias: bool = True, + sharding: Optional[P] = None, + mesh: Any = None, + name: Optional[str] = None, + shard_axis: int = 0, + ): + super().__init__( + output_size=output_size, + with_bias=with_bias, + name=name, + ) + self.sharding = sharding + self.mesh = mesh + self.shard_axis = shard_axis + + def __call__( + self, + inputs: jax.Array, + ) -> jax.Array: + """Computes a linear transform of the input.""" + + fprop_dtype = inputs.dtype + if not inputs.shape: + raise ValueError("Input must not be scalar.") + + input_size = self.input_size = inputs.shape[-1] + output_size = self.output_size + + w = hk.get_parameter( + "w", [input_size, output_size], jnp.float32, init=hk.initializers.Constant(0) + ) + + if hasattr(w, "scales"): + shape = inputs.shape + inputs = jnp.reshape(inputs, (-1, shape[-1])) + + @functools.partial( + shard_map, + mesh=self.mesh, + in_specs=(self.sharding, self.sharding), + out_specs=self.sharding, + check_rep=False, + ) + def mul(w, s): + return w.astype(s.dtype) * s + + w = mul(w.weight, w.scales) + out = jnp.dot(inputs, w.astype(fprop_dtype)) + if self.with_bias: + b = hk.get_parameter( + "b", [self.output_size], jnp.float32, init=hk.initializers.Constant(0) + ) + b = jnp.broadcast_to(b, out.shape) + out = out + b.astype(fprop_dtype) + + return out + + +class RMSNorm(hk.RMSNorm): + + def __init__( + self, + axis: Union[int, Sequence[int], slice], + eps: float = 1e-5, + name: Optional[str] = None, + create_scale: bool = True, + sharding: Optional[P] = None, + ): + super().__init__(axis, eps, create_scale=create_scale, name=name) + self.sharding = sharding + + def __call__(self, inputs: jax.Array): + fprop_dtype = inputs.dtype + param_shape = (inputs.shape[-1],) + if self.create_scale: + scale = hk.get_parameter( + "scale", + param_shape, + dtype=jnp.float32, + init=hk.initializers.Constant(0), + ) + if self.sharding: + scale = with_sharding_constraint(scale, self.sharding) + scale = jnp.broadcast_to(scale.astype(jnp.float32), inputs.shape) + else: + scale = 1.0 + inputs = inputs.astype(jnp.float32) + scale = scale.astype(jnp.float32) + mean_squared = jnp.mean(jnp.square(inputs), axis=[-1], keepdims=True) + mean_squared = jnp.broadcast_to(mean_squared, inputs.shape) + + normed_inputs = inputs * jax.lax.rsqrt(mean_squared + self.eps) + + outputs = scale * normed_inputs + + return outputs.astype(fprop_dtype) + + +def rotate_half( + x: jax.Array, +) -> jax.Array: + """Obtain the rotated counterpart of each feature""" + x1, x2 = jnp.split(x, 2, axis=-1) + return jnp.concatenate((-x2, x1), axis=-1) + + +class RotaryEmbedding(hk.Module): + """Applies rotary embeddings (RoPE) to the input sequence tensor, + as described in https://arxiv.org/abs/2104.09864. + + Attributes: + dim (int): Dimensionality of the feature vectors + base_exponent (int): Base exponent to compute embeddings from + """ + + def __init__( + self, + dim: int, + name: Optional[str] = None, + base_exponent: int = 10000, + ): + super().__init__(name) + self.dim = dim + self.base_exponent = base_exponent + assert self.dim % 2 == 0 + + def __call__( + self, + x: jax.Array, + seq_dim: int, + offset: jax.Array, + const_position: Optional[int] = None, + t: Optional[jax.Array] = None, + ) -> jax.Array: + fprop_dtype = x.dtype + # Compute the per-dimension frequencies + exponents = jnp.arange(0, self.dim, 2, dtype=jnp.float32) + inv_freq = jnp.asarray( + 1.0 / (self.base_exponent ** (exponents / self.dim)), dtype=jnp.float32 + ) + + if jnp.shape(offset) == (): + # Offset can be a scalar or one offset per batch element. + offset = jnp.expand_dims(offset, 0) + + # Compute the per element phase (to pass into sin and cos) + if const_position: + t = const_position * jnp.ones( + ( + 1, + x.shape[seq_dim], + ), + dtype=jnp.float32, + ) + elif t is None: + t = jnp.arange(x.shape[seq_dim], dtype=jnp.float32) + jnp.expand_dims(offset, -1) + phase = jnp.einsum("bi,j->bij", t, inv_freq) + phase = jnp.tile(phase, reps=(1, 2))[:, :, None, :] + + x = x * jnp.cos(phase) + rotate_half(x) * jnp.sin(phase) + x = x.astype(fprop_dtype) + + return x + + +class MultiHeadAttention(hk.Module): + def __init__( + self, + num_q_heads: int, + num_kv_heads: int, + key_size: int, + *, + with_bias: bool = True, + value_size: Optional[int] = None, + model_size: Optional[int] = None, + attn_output_multiplier: 1.0, + data_axis: Union[str, Tuple[str, ...]] = "data", + model_axis: Union[str, Tuple[str, ...]] = "model", + name: Optional[str] = None, + ): + super().__init__(name=name) + self.num_q_heads = num_q_heads + self.num_kv_heads = num_kv_heads + self.key_size = key_size + self.value_size = value_size or key_size + self.model_size = model_size or key_size * num_q_heads + self.data_axis = data_axis + self.model_axis = model_axis + self.attn_output_multiplier = attn_output_multiplier + self.with_bias = with_bias + + def __call__( + self, + query: jax.Array, + key: Optional[jax.Array], + value: Optional[jax.Array], + mask: Optional[jax.Array] = None, + kv_memory: Optional[KVMemory] = None, + mesh: Any = None, + ) -> MHAOutput: + # In shape hints below, we suppress the leading dims [...] for brevity. + # Hence e.g. [A, B] should be read in every case as [..., A, B]. + sequence_length = query.shape[1] + projection = self._linear_projection + use_memory = False + if kv_memory is not None: + if kv_memory.k is None: + assert kv_memory.v is None + assert key is not None + assert value is not None + else: + assert kv_memory.v is not None + use_memory = True + else: + assert key is not None + assert value is not None + + # Check that the keys and values have consistent batch size and sequence length. + if not use_memory: + assert key.shape[:2] == value.shape[:2], f"key/value shape: {key.shape}/{value.shape}" + + if mask is not None: + assert mask.ndim == 4 + assert mask.shape[0] in { + 1, + query.shape[0], + }, f"mask/query shape: {mask.shape}/{query.shape}" + if not use_memory: + assert key.shape[0] in { + 1, + query.shape[0], + }, f"key/query shape: {key.shape}/{query.shape}" + assert mask.shape[1] == 1 + assert mask.shape[2] in { + 1, + query.shape[1], + }, f"mask/query shape: {mask.shape}/{query.shape}" + if not use_memory: + assert mask.shape[3] in { + 1, + key.shape[1], + }, f"mask/query shape: {mask.shape}/{key.shape}" + + # Compute key/query/values (overload K/Q/V to denote the respective sizes). + assert self.num_q_heads % self.num_kv_heads == 0 + query_heads = projection( + query, + self.key_size, + self.num_q_heads, + name="query", + sharding=P("data", "model"), + mesh=mesh, + ) # [B, T', H, Q=K] + + new_memory = None + key_heads = projection( + key, + self.key_size, + self.num_kv_heads, + name="key", + sharding=P("data", "model"), + mesh=mesh, + ) # [B, T, H, K] + value_heads = projection( + value, + self.value_size, + self.num_kv_heads, + name="value", + sharding=P("data", "model"), + mesh=mesh, + ) # [B, T, H, V] + + rotate = RotaryEmbedding(dim=self.key_size, base_exponent=int(1e4)) + key_heads = rotate(key_heads, seq_dim=1, offset=(kv_memory.step if kv_memory else 0)) + query_heads = rotate(query_heads, seq_dim=1, offset=(kv_memory.step if kv_memory else 0)) + + @functools.partial(jax.vmap) + def update_into(mem, start, update): + return jax.lax.dynamic_update_slice_in_dim(mem, update, start, axis=0) + + if kv_memory: + if mesh is not None: + + @functools.partial( + shard_map, + mesh=mesh, + in_specs=( + P("data", None, "model"), + P("data"), + P("data", None, "model"), + ), + out_specs=P("data", None, "model"), + check_rep=False, + ) + def update_into_shmap(mems, starts, updates): + return update_into(mems, starts, updates) + + key_heads = update_into_shmap(kv_memory.k, kv_memory.step, key_heads) + value_heads = update_into_shmap(kv_memory.v, kv_memory.step, value_heads) + else: + key_heads = update_into(kv_memory.k, kv_memory.step, key_heads) + value_heads = update_into(kv_memory.v, kv_memory.step, value_heads) + + new_step = kv_memory.step + sequence_length + memory_mask = jnp.arange(kv_memory.k.shape[1]) < new_step[:, None] + memory_mask = memory_mask[:, None, None, :] # [B, H, T, T] + if mask is not None: + mask = memory_mask * mask + else: + mask = memory_mask + + new_memory = KVMemory( + k=key_heads, + v=value_heads, + step=new_step, + ) + # Add separate dimension for grouped query heads. + query_heads = with_sharding_constraint(query_heads, P(self.data_axis, None, "model", None)) + key_heads = with_sharding_constraint(key_heads, P(self.data_axis, None, "model", None)) + value_heads = with_sharding_constraint(value_heads, P(self.data_axis, None, "model", None)) + b, t, h, d = query_heads.shape + _, _, kv_h, _ = key_heads.shape + assert h % kv_h == 0, f"query_heads {h} must be a multiple of kv_heads {kv_h}" + + query_heads = jnp.reshape(query_heads, (b, t, kv_h, h // kv_h, d)) + query_heads = with_sharding_constraint( + query_heads, P(self.data_axis, None, "model", None, None) + ) + + # Compute attention weights. + # Attention softmax is always carried out in fp32. + attn_logits = jnp.einsum("...thHd,...Thd->...hHtT", query_heads, key_heads).astype( + jnp.float32 + ) + attn_logits *= self.attn_output_multiplier + max_attn_val = jnp.array(30.0, dtype=attn_logits.dtype) + attn_logits = max_attn_val * jnp.tanh(attn_logits / max_attn_val) + + mask = mask[:, :, None, :, :] + + if mask is not None: + if mask.ndim != attn_logits.ndim: + raise ValueError( + f"Mask dimensionality {mask.ndim} must match logits dimensionality " + f"{attn_logits.ndim} for {mask.shape}/{attn_logits.shape}." + ) + attn_logits = jnp.where(mask, attn_logits, -1e30) + attn_weights = jax.nn.softmax(attn_logits).astype(query.dtype) # [H, T', T] + + # Weight the values by the attention and flatten the head vectors. + attn = jnp.einsum("...hHtT,...Thd->...thHd", attn_weights, value_heads) + attn = with_sharding_constraint(attn, P(self.data_axis, None, "model", None, None)) + leading_dims = attn.shape[:2] + attn = jnp.reshape(attn, (*leading_dims, -1)) # [T', H*V] + attn = with_sharding_constraint(attn, P(self.data_axis, None, "model")) + # Apply another projection to get the final embeddings. + final_projection = Linear( + self.model_size, + with_bias=False, + sharding=P("model", "data"), + mesh=mesh, + ) + return MHAOutput(final_projection(attn), new_memory) + + @hk.transparent + def _linear_projection( + self, + x: jax.Array, + head_size: int, + num_heads: int, + sharding: Optional[P] = None, + name: Optional[str] = None, + mesh: Any = None, + ) -> jax.Array: + y = Linear( + num_heads * head_size, + with_bias=False, + name=name, + sharding=sharding, + mesh=mesh, + )(x) + *leading_dims, _ = x.shape + return y.reshape((*leading_dims, num_heads, head_size)) + + +@dataclass +class MHABlock(hk.Module): + """A MHA Block""" + + num_q_heads: int + num_kv_heads: int + key_size: int + attn_output_multiplier: float = 1.0 + mesh: Any = None + data_axis: Union[str, Tuple[str, ...]] = "data" + model_axis: Union[str, Tuple[str, ...]] = "model" + + @hk.transparent + def __call__( + self, + inputs: jax.Array, # [B, T, D] + mask: jax.Array, # [B, 1, T, T] or [B, 1, 1, T] or B[1, 1, 1, 1] + layer_memory: Optional[KVMemory], + ) -> MHAOutput: + _, _, model_size = inputs.shape + assert mask.ndim == 4, f"shape: {mask.shape}" + assert mask.shape[2] in {1, inputs.shape[1]}, str(mask.shape) + assert mask.shape[3] in {1, inputs.shape[1]}, str(mask.shape) + side_input = inputs + + def attn_block(query, key, value, mask, memory) -> MHAOutput: + return MultiHeadAttention( + num_q_heads=self.num_q_heads, + num_kv_heads=self.num_kv_heads, + key_size=self.key_size, + model_size=model_size, + data_axis=self.data_axis, + model_axis=self.model_axis, + attn_output_multiplier=self.attn_output_multiplier, + )( + query, + key, + value, + mask, + memory, + mesh=self.mesh, + ) + + attn_output = attn_block(inputs, side_input, side_input, mask, layer_memory) + h_attn = attn_output.embeddings + + return attn_output._replace(embeddings=h_attn) + + +@dataclass +class DenseBlock(hk.Module): + num_q_heads: int + num_kv_heads: int + key_size: int + widening_factor: float = 4.0 + sharding_constraint: bool = False + mesh: Any = None + + @hk.transparent + def __call__( + self, + inputs: jax.Array, # [B, T, D] + ) -> jax.Array: # [B, T, D] + _, _, model_size = inputs.shape + h_v = Linear( + ffn_size( + model_size, + self.widening_factor, + ), + with_bias=False, + mesh=self.mesh, + sharding=P("data", "model"), + name="linear_v", + )(inputs) + h_w1 = jax.nn.gelu( + Linear( + ffn_size( + model_size, + self.widening_factor, + ), + with_bias=False, + mesh=self.mesh, + sharding=P("data", "model"), + )(inputs) + ) + h_dense = Linear( + model_size, + with_bias=False, + sharding=P("model", "data"), + mesh=self.mesh, + shard_axis=1, + )(h_w1 * h_v) + + return h_dense + + +@dataclass +class DecoderLayer(hk.Module): + """A transformer stack.""" + + num_q_heads: int + num_kv_heads: int + key_size: int + num_layers: int + # MoE. + num_experts: int + layer_index: Optional[int] = None + num_selected_experts: int = 1 + widening_factor: float = 4.0 + name: Optional[str] = None + data_axis: Union[str, Tuple[str, ...]] = "data" + model_axis: Union[str, Tuple[str, ...]] = "model" + shard_activations: bool = False + attn_output_multiplier: float = 1.0 + mesh: Any = None + + def __call__( + self, + inputs: jax.Array, # [B, T, D] + mask: jax.Array, # [B, 1, T, T] or [B, 1, 1, T] + padding_mask: Optional[jax.Array], + layer_memory: Optional[KVMemory], + ) -> DecoderOutput: + """Transforms input embedding sequences to output embedding sequences.""" + + def layer_norm(x): + return hk_rms_norm(x) + + if self.shard_activations: + sharding = P(self.data_axis, None, self.model_axis) + else: + sharding = P(self.data_axis, None) + h = with_sharding_constraint(inputs, sharding) + + attn_output = MHABlock( + num_q_heads=self.num_q_heads, + num_kv_heads=self.num_kv_heads, + key_size=self.key_size, + attn_output_multiplier=self.attn_output_multiplier, + mesh=self.mesh, + data_axis=self.data_axis, + model_axis=self.model_axis, + )(layer_norm(h), mask, layer_memory) + h_attn = attn_output.embeddings + + h_attn = layer_norm(h_attn) + h += h_attn + h = with_sharding_constraint(h, sharding) + + def base_dense_block(h): + h = DenseBlock( + num_q_heads=self.num_q_heads, + num_kv_heads=self.num_kv_heads, + key_size=self.key_size, + widening_factor=self.widening_factor, + sharding_constraint=False, + mesh=self.mesh, + )(h) + return h + + if self.num_experts > 1: + rank_logger.debug("Using MoE!") + router = Router( + num_selected_experts=self.num_selected_experts, + shard_activations=self.shard_activations, + data_axis=self.data_axis, + model_axis=self.model_axis, + mesh=self.mesh, + ) + h_dense = MoELayer( + num_experts=self.num_experts, + mesh=self.mesh, + layer_fn=base_dense_block, + router=router, + shard_activations=self.shard_activations, + data_axis=self.data_axis, + model_axis=self.model_axis, + )(layer_norm(h), padding_mask) + else: + h_dense = base_dense_block(layer_norm(h)) + + h_dense = layer_norm(h_dense) + h += h_dense + h = with_sharding_constraint(h, sharding) + + return DecoderOutput( + embeddings=h, + memory=attn_output.memory, + ) + + +class LanguageModelOutput(NamedTuple): + logits: jax.Array + model_state: Any + + +class InOutEmbed(hk.Embed): + """Module for embedding tokens in a low-dimensional space.""" + + def __init__( + self, + vocab_size: Optional[int] = None, + embed_dim: Optional[int] = None, + sharding: Optional[P] = None, + name: Optional[str] = None, + ): + super().__init__( + vocab_size=vocab_size, + embed_dim=embed_dim, + name=name, + ) + self.sharding = sharding + + @property + def embeddings(self): + embed_mat = hk.get_parameter( + "embeddings", + [self.vocab_size, self.embed_dim], + dtype=jnp.float32, + init=hk.initializers.Constant(0), + ) + if self.sharding: + embed_mat = with_sharding_constraint(embed_mat, self.sharding) + return embed_mat + + def decode( + self, + inputs: jax.Array, + ) -> jax.Array: + return jnp.dot(inputs, self.embeddings.T.astype(inputs.dtype)) + + +@dataclass +class LanguageModelConfig: + """An autoregressive transformer-based language model.""" + + model: Optional[TransformerConfig] + vocab_size: int + pad_token: int + eos_token: int + sequence_len: int + model_size: int = 0 + embedding_init_scale: float = 1.0 + embedding_multiplier_scale: float = 1.0 + output_multiplier_scale: float = 1.0 + name: Optional[str] = None + fprop_dtype: Any = jnp.bfloat16 + model_type: Optional[str] = None + init_scale_override: Optional[float] = None + shard_embeddings: bool = True + + _initialized = False + + def initialize(self): + # We cannot specify [] as a default value (it is mutable), hence None. + model_config = self.model + assert self.init_scale_override is None, ( + "Overriding model initialize scale is supported only for predefined models." + ) + if self.model_size == 0: + self.model_size = model_config.emb_size + assert self.model is not None, "Model could not be initialized." + self._initialized = True + return self + + def make(self, *args, **kwargs): + if not self._initialized: + logger.warning( + f"LanguageModel {self.name} is not initialized. Initializing for one replica." + ) + self.initialize() + + return LanguageModel( + model=self.model.make(*args, **kwargs), + config=self, + fprop_dtype=self.fprop_dtype, + mesh=kwargs.get("mesh", None), + ) + + def partition_rules(self): + return LM_PARTITION_RULES + self.model.partition_rules() + + +def layer_norm(x, model): + return hk_rms_norm(x) + + +@dataclass +class LanguageModel(hk.Module): + """An autoregressive transformer-based language model.""" + + model: "Transformer" + config: LanguageModelConfig + fprop_dtype: Any = jnp.bfloat16 + name: Optional[str] = None + mesh: Any = None + + def __call__( + self, + tokens: jax.Array, + memory: Optional[Memory] = None, + *, + batch: Dict[str, jax.Array] = {}, + last_hid_only: bool = False, + length: Optional[jax.Array] = None, + ) -> LanguageModelOutput: + """Forward pass, producing a sequence of logits.""" + del batch # Unused. + + config = self.config + + input_mask = jnp.greater(tokens, config.pad_token) + + # Embed the input tokens and positions. + in_out_embed = InOutEmbed( + self.config.vocab_size, + embed_dim=self.config.model_size, + sharding=P(None, ("data", "model")), + ) + input_embeddings = in_out_embed(tokens).astype(config.fprop_dtype) + input_embeddings = with_sharding_constraint( + input_embeddings, P("data", None, self.model.model_axis) + ) + input_embeddings *= config.embedding_multiplier_scale + + model_output = self.model( + input_embeddings, + input_mask, + memory=memory, + ) # [B, T, D] + embeddings, model_state = model_output.embeddings, model_output.memory + if self.model.shard_activations: + embeddings = with_sharding_constraint( + embeddings, P("data", None, self.model.model_axis) + ) + else: + embeddings = with_sharding_constraint(embeddings, P("data", None)) + rank_logger.debug(f"Final embedding shape: {embeddings.shape}") + embeddings = layer_norm(embeddings, self.model) + assert embeddings.dtype == self.fprop_dtype + + if last_hid_only: + last_step = jnp.maximum(jnp.sum(input_mask.astype(jnp.int32), axis=1) - 1, 0) + last_hid = jax.vmap(lambda x, i: x[i], in_axes=0, out_axes=0)(embeddings, last_step) + return last_hid + + if length is not None: + last_step = jnp.maximum(length.astype(jnp.int32) - 1, 0) + embeddings = jax.vmap(lambda x, i: x[i], in_axes=0, out_axes=0)(embeddings, last_step) + embeddings = jnp.expand_dims(embeddings, axis=1) + + # Decode the embeddings (here, we use tied weights). + rank_logger.info(embeddings.shape) + out = in_out_embed.decode(embeddings) + rank_logger.info(out.shape) + out *= config.output_multiplier_scale + + if self.model.shard_activations: + out = with_sharding_constraint(out, P("data", None, self.model.model_axis)) + else: + out = with_sharding_constraint(out, P("data", None)) + + return LanguageModelOutput( + logits=out, + model_state=model_state, + ) + + def init_memory(self, batch_size: int, seq_len: int, dtype=jnp.bfloat16): + return self.model.init_memory(batch_size=batch_size, sequence_len=seq_len, dtype=dtype) + + def prefill_memory(self, prompts, memory): + # Pad to the left and right align? + # Basically assume prompt is already padded + model_output = self(prompts, memory=memory) + return model_output.logits, model_output.model_state + + +@dataclass +class Transformer(hk.Module): + """A transformer stack.""" + + num_q_heads: int + num_kv_heads: int + key_size: int + widening_factor: float + init_scale: float + mesh: Any + attn_output_multiplier: float + shard_activations: bool + num_layers: int + # MoE + num_experts: int + num_selected_experts: int + name: Optional[str] = None + + # Used for activation sharding + data_axis: Union[str, Tuple[str, ...]] = "data" + model_axis: Union[str, Tuple[str, ...]] = "model" + + def init_memory(self, batch_size: int, sequence_len: int, dtype=jnp.bfloat16): + return Memory( + layers=init_layer_memories( + batch_size, + sequence_len, + self.num_kv_heads, + self.key_size, + self.num_layers, + step=jnp.zeros(batch_size, dtype=jnp.int32), + dtype=dtype, + ), + ) + + def __call__( + self, + embeddings: jax.Array, # [B, T, D] + mask: jax.Array, # [B, T] + memory: Optional[Memory], + ) -> TransformerOutput: + """Transforms input embedding sequences to output embedding sequences.""" + + fprop_dtype = embeddings.dtype + _, seq_len, model_size = embeddings.shape + padding_mask = mask.copy() + mask = mask[:, None, None, :] # [B, H=1, T'=1, T] + + # Compute causal mask for autoregressive sequence modelling. + causal_mask = jnp.tril(jnp.ones((1, 1, seq_len, seq_len))).astype( + fprop_dtype + ) # [B=1, H=1, T, T] + mask = mask * causal_mask # [B, H=1, T, T] + + h = embeddings + kv_memories = [] + + def block( + h, + mask, + padding_mask, + memory, + layer_index: Optional[int] = None, + widening_factor: Optional[int] = None, + name: Optional[str] = None, + ) -> DecoderOutput: + return DecoderLayer( + num_q_heads=self.num_q_heads, + num_kv_heads=self.num_kv_heads, + key_size=self.key_size, + widening_factor=widening_factor or self.widening_factor, + num_layers=self.num_layers, + mesh=self.mesh, + data_axis=self.data_axis, + model_axis=self.model_axis, + attn_output_multiplier=self.attn_output_multiplier, + shard_activations=self.shard_activations, + # MoE. + num_experts=self.num_experts, + num_selected_experts=self.num_selected_experts, + name=name, + layer_index=layer_index, + )( + h, + mask, + padding_mask, + memory, + ) + + for i in range(self.num_layers): + decoder_output = block( + h, + mask, + padding_mask, + memory.layers[i] if memory else None, + layer_index=i, + name=f"decoder_layer_{i}", + ) + h, new_kv_memory = ( + decoder_output.embeddings, + decoder_output.memory, + ) + kv_memories.append(new_kv_memory) + + return TransformerOutput( + embeddings=h, + memory=Memory(layers=kv_memories), + ) diff --git a/grok-1-main/pyproject.toml b/grok-1-main/pyproject.toml new file mode 100644 index 0000000..aa55016 --- /dev/null +++ b/grok-1-main/pyproject.toml @@ -0,0 +1,14 @@ +[tool.ruff] +indent-width = 4 +line-length = 100 + +[tool.ruff.lint] +ignore = [ + "E722", + "E731", + "E741", + "F405", + "E402", + "F403", +] +select = ["ISC001"] diff --git a/grok-1-main/requirements.txt b/grok-1-main/requirements.txt new file mode 100644 index 0000000..a612687 --- /dev/null +++ b/grok-1-main/requirements.txt @@ -0,0 +1,4 @@ +dm_haiku==0.0.12 +jax[cuda12-pip]==0.4.25 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html +numpy==1.26.4 +sentencepiece==0.2.0 diff --git a/grok-1-main/run.py b/grok-1-main/run.py new file mode 100644 index 0000000..bc75c3e --- /dev/null +++ b/grok-1-main/run.py @@ -0,0 +1,261 @@ +# Copyright 2024 X.AI Corp. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import json +import logging +import re +import sys + +from model import LanguageModelConfig, TransformerConfig, QuantizedWeight8bit as QW8Bit +from runners import InferenceRunner, ModelRunner, sample_from_model + + +CKPT_PATH = "./checkpoints/" + + +def get_grok1_config(): + """Returns the standard Grok-1 model configuration.""" + return LanguageModelConfig( + vocab_size=128 * 1024, + pad_token=0, + eos_token=2, + sequence_len=8192, + embedding_init_scale=1.0, + output_multiplier_scale=0.5773502691896257, + embedding_multiplier_scale=78.38367176906169, + model=TransformerConfig( + emb_size=48 * 128, + widening_factor=8, + key_size=128, + num_q_heads=48, + num_kv_heads=8, + num_layers=64, + attn_output_multiplier=0.08838834764831845, + shard_activations=True, + # MoE. + num_experts=8, + num_selected_experts=2, + # Activation sharding. + data_axis="data", + model_axis="model", + ), + ) + + +def main(): + grok_1_model = get_grok1_config() + inference_runner = InferenceRunner( + pad_sizes=(1024,), + runner=ModelRunner( + model=grok_1_model, + bs_per_device=0.125, + checkpoint_path=CKPT_PATH, + ), + name="local", + load=CKPT_PATH, + tokenizer_path="./tokenizer.model", + local_mesh_config=(1, 8), + between_hosts_config=(1, 1), + ) + inference_runner.initialize() + gen = inference_runner.run() + + inp = "The answer to life the universe and everything is of course" + print(f"Output for prompt: {inp}", sample_from_model(gen, inp, max_len=100, temperature=0.01)) + + +# ============================================================================= +# 强兼 BRIDGE — Grokking Evaluation Mode +# ============================================================================= +# +# Tests whether Grok-1 has "grokked" modular arithmetic — the same tasks +# that OpenAI's grokking paper studies with small transformers. +# +# Usage: python run.py --eval-grokking [--operator +] [--n-samples 50] +# +# Source: https://github.com/openai/grok +# ============================================================================= + +GROKKING_OPERATORS = { + "+": ("addition", "What is {a} + {b} mod {p}?"), + "-": ("subtraction", "What is {a} - {b} mod {p}?"), + "*": ("multiplication", "What is {a} * {b} mod {p}?"), + "/": ("division", "What is {a} / {b} mod {p}? (modular inverse)"), +} + + +def generate_arithmetic_problems(operator="+", n_samples=50, modulus=97, seed=42): + """Generate arithmetic problems matching OpenAI's grokking dataset.""" + import random as rng + rng.seed(seed) + + problems = [] + all_pairs = [(a, b) for a in range(modulus) for b in range(modulus)] + if operator == "/": + all_pairs = [(a, b) for a, b in all_pairs if b != 0] + rng.shuffle(all_pairs) + + for a, b in all_pairs[:n_samples]: + if operator == "+": + answer = (a + b) % modulus + elif operator == "-": + answer = (a - b) % modulus + elif operator == "*": + answer = (a * b) % modulus + elif operator == "/": + answer = (a * pow(b, modulus - 2, modulus)) % modulus + else: + continue + + _, template = GROKKING_OPERATORS[operator] + prompt = template.format(a=a, b=b, p=modulus) + " Answer with just the number." + problems.append({ + "a": a, "b": b, "operator": operator, + "expected": answer, "prompt": prompt, + }) + + return problems + + +def eval_grokking(args): + """ + Run Grok-1 on modular arithmetic problems from OpenAI's grokking research. + + This is the heart of the 强兼 bridge: does the 314B-parameter Grok-1 + model — named after the concept of deep understanding — actually + demonstrate deep understanding of the exact mathematical tasks that + OpenAI's "grokking" paper investigates? + """ + grok_1_model = get_grok1_config() + + # Print architecture bridge info + print("=" * 70) + print("强兼 BRIDGE: Does Grok grok grokking?") + print("=" * 70) + print(grok_1_model.model.architecture_summary()) + print(f"Grokking evaluation: {args.operator} mod 97") + print(f"Samples: {args.n_samples}") + print("=" * 70) + + # Export config for the grokking framework + grokking_config = grok_1_model.model.to_grokking_config() + print(f"\n[Bridge] Grok-1 → grokking config: {json.dumps({k:v for k,v in grokking_config.items() if not k.startswith('_')}, indent=2)}\n") + + # Generate problems + problems = generate_arithmetic_problems( + operator=args.operator, + n_samples=args.n_samples, + ) + + if not args.dry_run: + # Initialize the model + inference_runner = InferenceRunner( + pad_sizes=(1024,), + runner=ModelRunner( + model=grok_1_model, + bs_per_device=0.125, + checkpoint_path=args.checkpoint_path, + ), + name="grokking_eval", + load=args.checkpoint_path, + tokenizer_path=args.tokenizer_path, + local_mesh_config=(1, 8), + between_hosts_config=(1, 1), + ) + inference_runner.initialize() + gen = inference_runner.run() + + # Evaluate + correct = 0 + total = 0 + results = [] + for problem in problems: + response = sample_from_model( + gen, problem["prompt"], + max_len=20, temperature=0.01, + ) + # Parse the response for a number + numbers = re.findall(r'\b(\d+)\b', response) + predicted = int(numbers[-1]) if numbers else None + is_correct = predicted == problem["expected"] if predicted is not None else False + + results.append({ + **problem, + "predicted": predicted, + "correct": is_correct, + "raw_response": response, + }) + + if is_correct: + correct += 1 + total += 1 + + status = "✓" if is_correct else "✗" + print(f" {status} {problem['a']} {problem['operator']} {problem['b']} mod 97 " + f"= {problem['expected']} (Grok-1: {predicted})") + + # Summary + accuracy = correct / total * 100 if total > 0 else 0 + print(f"\n{'=' * 70}") + print(f"RESULTS: {correct}/{total} correct ({accuracy:.1f}%)") + print(f"{'=' * 70}") + + verdict = ( + "Grok GROKS grokking! 🎉" if accuracy > 95 else + "Grok partially groks grokking." if accuracy > 50 else + "Grok does NOT grok grokking. (yet?)" + ) + print(f"\nVerdict: {verdict}\n") + + # Save results + if args.output: + with open(args.output, "w") as f: + json.dump({"accuracy": accuracy, "results": results}, f, indent=2) + print(f"Results saved to {args.output}") + else: + print("[Dry run] Generated problems (no inference):") + for p in problems[:5]: + print(f" {p['prompt']} → expected: {p['expected']}") + print(f" ... ({len(problems)} total)") + print(f"\n[Bridge] To run with Grok-1, remove --dry-run and provide checkpoint.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Grok-1 inference & grokking evaluation (强兼 bridge)" + ) + parser.add_argument("--eval-grokking", action="store_true", + help="Run grokking arithmetic evaluation instead of default inference") + parser.add_argument("--operator", type=str, default="+", + choices=["+", "-", "*", "/"], + help="Arithmetic operator for grokking eval") + parser.add_argument("--n-samples", type=int, default=50, + help="Number of arithmetic problems to evaluate") + parser.add_argument("--checkpoint-path", type=str, default=CKPT_PATH, + help="Path to Grok-1 checkpoints") + parser.add_argument("--tokenizer-path", type=str, default="./tokenizer.model", + help="Path to tokenizer model") + parser.add_argument("--output", type=str, default=None, + help="Path to save JSON results") + parser.add_argument("--dry-run", action="store_true", + help="Generate problems without running inference") + + args = parser.parse_args() + logging.basicConfig(level=logging.INFO) + + if args.eval_grokking: + eval_grokking(args) + else: + main() diff --git a/grok-1-main/runners.py b/grok-1-main/runners.py new file mode 100644 index 0000000..452c142 --- /dev/null +++ b/grok-1-main/runners.py @@ -0,0 +1,605 @@ +# Copyright 2024 X.AI Corp. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import bisect +import functools +import logging +import math +import re +from dataclasses import dataclass +from typing import Any, Callable, NamedTuple, Optional, Tuple + +import haiku as hk +import jax +import jax.experimental.pjit as pjit +import jax.numpy as jnp +import numpy as np +import sentencepiece +from jax.experimental import mesh_utils +from jax.sharding import PartitionSpec as P +from jax.typing import ArrayLike + +import checkpoint as xai_checkpoint +from model import ( + LanguageModelConfig, + LanguageModelOutput, + TrainingState, + apply_rules, + Memory, + KVMemory, +) + +logger = logging.getLogger(__name__) +rank_logger = logging.getLogger("rank") + +TOP_K = 8 + + +class SampleSettings(NamedTuple): + temperature: ArrayLike + nucleus_p: ArrayLike + mask: ArrayLike + # Whether a given batch element is actively used. [B] + active: ArrayLike + + +class SampleOutput(NamedTuple): + token_id: ArrayLike + prob: ArrayLike + top_k_token_ids: ArrayLike + top_k_probs: ArrayLike + + +def insert_slice(memory: Memory, slice, length, i): + slice = Memory( + layers=[ + KVMemory(layer.k, layer.v, step=jnp.array([length])) + for layer in slice.layers + ], + ) + + return jax.tree_map(lambda m, u: jax.lax.dynamic_update_index_in_dim(m, u[0], i, axis=0), + memory, slice) + + +def pad_to_size(x, size): + if x.shape[0] > size: + # Left truncate if the context is too long. + x = x[-size:] + return np.pad(x, [0, size - x.shape[0]], mode="constant", constant_values=0) + + +def top_p_filter(logits: jax.Array, top_p: jax.Array) -> jax.Array: + """Performs nucleus filtering on logits.""" + assert logits.ndim == top_p.ndim, f"Expected {logits.ndim} equal {top_p.ndim}" + sorted_logits = jax.lax.sort(logits, is_stable=False) + sorted_probs = jax.nn.softmax(sorted_logits) + threshold_idx = jnp.argmax(jnp.cumsum(sorted_probs, -1) >= 1 - top_p, axis=-1) + threshold_largest_logits = jnp.take_along_axis( + sorted_logits, threshold_idx[..., jnp.newaxis], axis=-1 + ) + assert threshold_largest_logits.shape == logits.shape[:-1] + (1,) + mask = logits >= threshold_largest_logits + # Set unused logits to -inf. + logits = jnp.where(mask, logits, -1e10) + return logits + + +def sample_token( + rngs: jax.random.PRNGKey, + lm_outputs: LanguageModelOutput, + settings: SampleSettings, +) -> SampleOutput: + # Expand the settings shape to match the logit shape. + settings = SampleSettings( + temperature=jnp.expand_dims(settings.temperature, (1, 2)), # Input [B], output [B, 1, 1]. + nucleus_p=jnp.expand_dims(settings.nucleus_p, (1, 2)), # Input [B], output [B, 1, 1]. + mask=jnp.expand_dims(settings.mask, 1), # Input [B, V], output [B, 1, V]. + active=settings.active, # [B]. + ) + logits = lm_outputs.logits / settings.temperature.astype(lm_outputs.logits.dtype) + # Mask out all disallowed tokens by assigning them a near-zero probability. + logits = jnp.where(settings.mask, logits, -1e10) + # Mask out all tokens that don't fall into the p-th percentile. + logits = top_p_filter(logits, settings.nucleus_p.astype(logits.dtype)) + + new_token = jax.vmap(jax.random.categorical)(rngs, logits) + + probabilities = jax.nn.softmax(logits) + token_prob = jnp.take_along_axis(probabilities, jnp.expand_dims(new_token, 1), axis=2) + token_prob = jnp.squeeze(token_prob, 1) + + # Gather the top-k tokens and probabilities. + top_k_probs, top_k_token_ids = jax.lax.top_k(probabilities, TOP_K) + top_k_probs = jnp.squeeze(top_k_probs, 1) + top_k_token_ids = jnp.squeeze(top_k_token_ids, 1) + return SampleOutput( + new_token, + token_prob, + top_k_token_ids, + top_k_probs, + ) + + +@dataclass +class ModelRunner: + model: LanguageModelConfig + + bs_per_device: float = 2.0 + + load_rename_rules: Optional[list[tuple[str, str]]] = None + load_exclude_rules: Optional[list[str]] = None + + rng_seed: int = 42 # Initial rng seed. + transform_forward: bool = False + + checkpoint_path: str = "" + + def make_forward_fn(self, mesh: Any): + def forward(tokens): + out = self.model.make(mesh=mesh)(tokens) + return out, None + + if self.transform_forward: + forward = hk.transform(forward) + return forward + + def initialize( + self, + init_data, + local_mesh_config: tuple[int, int], + between_hosts_config: tuple[int, int], + ): + num_replicas = math.prod(between_hosts_config) + self.model.initialize() + self.model.fprop_dtype = jnp.bfloat16 + num_local_gpus = len(jax.local_devices()) + + # Calculate the global batch size from the local batch size. + self.batch_size = int(self.bs_per_device * num_local_gpus * num_replicas) + + # Calculate the batch size per host from the global batch size. + self.local_batch_size = self.batch_size // jax.process_count() + + self.local_mesh_config = local_mesh_config + self.between_hosts_config = between_hosts_config + rank_logger.info( + f"Initializing mesh for {self.local_mesh_config=} {self.between_hosts_config=}..." + ) + self.mesh = make_mesh(self.local_mesh_config, self.between_hosts_config) + self.forward = self.make_forward_fn(mesh=self.mesh) + self.logits_fn = hk.transform(lambda tokens: self.forward(tokens)[0]) + + self.eval_forward = self.make_forward_fn(mesh=self.mesh) + self.logits_eval_fn = hk.transform(lambda tokens: self.eval_forward(tokens)[0]) + + if self.transform_forward: + self.state_sharding = self.get_state_sharding(init_data) + rank_logger.info(f"State sharding type: {type(self.state_sharding)}") + self.init_fn = pjit.pjit(self.init, out_shardings=self.state_sharding) + + def init(self, rng: jax.Array, data) -> TrainingState: + assert self.transform_forward + rng, init_rng = jax.random.split(rng) + params = self.forward.init(init_rng, data["inputs"]) + return TrainingState(params=params) + + def get_state_sharding(self, init_data): + assert self.transform_forward + rng = jax.random.PRNGKey(self.rng_seed) + rank_logger.info(f"partition rules: {self.model.partition_rules}") + + with self.mesh: + shapes = jax.eval_shape(self.init, rng, init_data) + sharding = jax.tree_util.tree_map_with_path( + apply_rules(self.model.partition_rules()), + shapes, + ) + return sharding + + def load_or_init( + self, + init_data: Any, + from_checkpoint: bool = True, + init_fn: Optional[Callable] = None, + ): + rng = jax.random.PRNGKey(self.rng_seed) + + if not self.checkpoint_path or not from_checkpoint: + rank_logger.info("Initializing model...") + with self.mesh: + if init_fn is not None: + state = init_fn(rng, init_data) + else: + assert self.transform_forward + state = self.init_fn(rng, init_data) + rank_logger.info("Model state is newly initialized.") + else: + with self.mesh: + if init_fn: + state_shapes = jax.eval_shape(init_fn, rng, init_data) + else: + assert self.transform_forward + state_shapes = jax.eval_shape(self.init_fn, rng, init_data) + init_state = None + + state = xai_checkpoint.restore( + checkpoint_path=self.checkpoint_path, + state_shapes=state_shapes, + mesh=self.mesh, + between_hosts_config=self.between_hosts_config, + state_sharding=self.state_sharding, + init_state=init_state, + params_only=True, + ) + + del init_state + return state + + +@dataclass +class Request: + prompt: str + temperature: float + nucleus_p: float + rng_seed: int + max_len: int + + +@dataclass +class InferenceRunner: + name: str + runner: Any + load: str + tokenizer_path: str = "/tmp/xai_data/tokenizer.model" + local_mesh_config: Tuple[int, int] = (1, 1) + between_hosts_config: Tuple[int, int] = (1, 1) + pad_sizes: tuple[int] = (1024,) + + def get_pad_bucket(self, size): + i = bisect.bisect_left(self.pad_sizes, size) + return self.pad_sizes[min(i, len(self.pad_sizes) - 1)] + + def initialize(self): + runner = self.runner + self.runner.transform_forward = True + dummy_data = dict( + inputs=np.zeros((1, 256), dtype=np.int32), + targets=np.zeros((1, 256), dtype=np.int32), + ) + runner.initialize( + dummy_data, + local_mesh_config=self.local_mesh_config, + between_hosts_config=self.between_hosts_config, + ) + + self.tokenizer = sentencepiece.SentencePieceProcessor(model_file=self.tokenizer_path) + + max_len = runner.model.sequence_len + + self.vocab_size = self.runner.model.vocab_size + + params = runner.load_or_init(dummy_data) + self.params = params + + def pad_to_max_len(x): + if len(x.shape) > 1: + pad_width = max_len - x.shape[1] + return jnp.pad(x, [(0, 0), (0, pad_width), (0, 0), (0, 0)]) + else: + return x + + @functools.lru_cache + def lm(): + return runner.model.make(mesh=runner.mesh) + + def hk_forward( + tokens, + memory=None, + length=None, + active=None, + ) -> LanguageModelOutput: + if memory is not None: + assert active is not None + layers = [] + for l in memory.layers: + # Reset steps to 0 for inactive requests to avoid unnecessary computations. + step = jnp.where(active, l.step, jnp.zeros_like(l.step)) + layers.append(l._replace(step=step)) + memory = memory._replace(layers=layers) + return lm()(tokens, memory, length=length) + + def hk_sample_step(rngs, last_output: SampleOutput, memory, settings): + rngs, rngs_ = jax.vmap(jax.random.split, out_axes=1)(rngs) + lm_outputs = hk_forward(last_output.token_id, memory=memory, active=settings.active) + sample_result = sample_token(rngs_, lm_outputs, settings) + return rngs, sample_result, lm_outputs.model_state + + def hk_new_memory(batch_size, sequence_len): + return lm().init_memory(batch_size, sequence_len) + + def hk_prefill_memory( + rngs, + memory, + settings, + last_output, + prompt, + length, + rng_seed, + new_settings, + i, + ): + rng = jax.random.PRNGKey(seed=rng_seed) + rng, rng_ = jax.random.split(rng) + + # Allocate new memory for this sample. The memory length is equal to the length of the + # prompt. + slice = hk_new_memory(1, prompt.shape[0]) + + # Move the settings for this individual batch entry into the joint settings tensor. + settings = jax.tree_map( + lambda o, v: jax.lax.dynamic_update_index_in_dim(o, v, i, axis=0), + settings, + new_settings, + ) + + # Get the settings for the batch entry from the joint settings tensor. + settings_slice = jax.tree_map(lambda t: jnp.expand_dims(t[i], axis=0), settings) + + # Process the first n-1 tokens of the prompt. + lm_outputs = hk_forward( + jnp.expand_dims(prompt, 0), + memory=slice, + length=jnp.expand_dims(length, 0), + active=settings_slice.active, + ) + + # The forward pass doesn't correctly set the `step` counter inside the memory. Manually + # override it so `hk_forward` uses the correct context length in the next call. + slice = lm_outputs.model_state + slice = slice._replace( + layers=[l._replace(step=jnp.array([length])) for l in slice.layers] + ) + + # Sample the actual output token. + rng_ = jnp.expand_dims(rng_, 0) + new_output = sample_token(rng_, lm_outputs, settings_slice) + + # Update the KV cache/memory. + slice = jax.tree_map(pad_to_max_len, slice) + memory = insert_slice(memory, slice, length, i) + + rng = jnp.expand_dims(rng, 0) + rngs = jax.lax.dynamic_update_index_in_dim(rngs, rng, i, axis=0) + + # Move the network outputs for this batch entry into the joint output tensor. + last_output = jax.tree_util.tree_map( + lambda last, new: jax.lax.dynamic_update_index_in_dim(last, new, i, axis=0), + last_output, + new_output, + ) + return rngs, last_output, memory, settings + + sample_step_ = hk.without_apply_rng(hk.transform(hk_sample_step)) + prefill_memory_ = hk.without_apply_rng(hk.transform(hk_prefill_memory)) + new_memory_ = hk.without_apply_rng(hk.transform(hk_new_memory)) + forward_ = hk.without_apply_rng(hk.transform(hk_forward)) + + rng = jax.random.PRNGKey(42) + dummy_tokens = jnp.zeros((1, max_len), jnp.int32) + + with runner.mesh: + shapes = jax.eval_shape(forward_.init, rng, dummy_tokens) + + self.params_sharding = jax.tree_util.tree_map_with_path( + apply_rules(runner.model.partition_rules()), + shapes, + ) + + ds = P("data") + ms = runner.model.model.get_memory_sharding() + self.sample_step = pjit.pjit( + sample_step_.apply, + in_shardings=(self.params_sharding, None, ds, ms, None), + out_shardings=(None, ds, ms), + donate_argnums=3, + ) + self.prefill_memory = pjit.pjit( + functools.partial(prefill_memory_.apply), + in_shardings=( + self.params_sharding, + None, + ms, + None, + ds, + None, + None, + None, + None, + None, + ), + out_shardings=(None, ds, ms, None), + donate_argnums=(2,), + ) + self.new_memory = pjit.pjit( + new_memory_.apply, + static_argnums=(1, 2), + out_shardings=ms, + ) + + def run(self): + """Generator that accepts prompts.""" + runner = self.runner + mesh = runner.mesh + max_len = runner.model.sequence_len + batch_size = runner.batch_size + params = self.params + rngs = jax.random.split(jax.random.PRNGKey(1), batch_size) + with mesh: + memory = self.new_memory(params, batch_size, max_len) + settings = SampleSettings( + temperature=np.zeros((batch_size,), dtype=np.float32), + nucleus_p=np.zeros((batch_size,), dtype=np.float32), + mask=np.ones((batch_size, self.vocab_size), dtype=np.int32), + active=np.zeros((batch_size), dtype=np.int32), + ) + last_output = SampleOutput( + token_id=np.zeros((batch_size, 1), dtype=np.int32), + prob=np.zeros((batch_size, 1), dtype=jnp.bfloat16), + top_k_token_ids=np.zeros((batch_size, TOP_K), dtype=np.int32), + top_k_probs=np.zeros((batch_size, TOP_K), dtype=jnp.bfloat16), + ) + + prompt = np.array([300, 400, 500, 600, 600, 700, 800]) + + new_settings = SampleSettings( + temperature=np.float32(1), + nucleus_p=np.float32(1), + mask=np.ones((self.vocab_size,), dtype=np.int32), + active=np.zeros((), dtype=np.int32), + ) + rng_seed = np.uint64(1) + + for size in self.pad_sizes: + if size > runner.model.sequence_len: + break + logger.info("Precompile {}".format(size)) + prompt_len = len(prompt) + prompt = pad_to_size(prompt, size) + rngs, last_output, memory, settings = self.prefill_memory( + params, + rngs, + memory, + settings, + last_output, + prompt, + prompt_len, + rng_seed, + new_settings, + 0, + ) + with runner.mesh: + logger.info("Compiling...") + rngs, last_output, memory = self.sample_step( + params, rngs, last_output, memory, settings + ) + logger.info("Done compiling.") + + all_tokens = [] + free_slots = list(range(batch_size)) + requests = [None] * batch_size + first_output = [None] * batch_size + jax.tree_map(lambda x: x.copy_to_host_async(), last_output) + prev_token = last_output + step = 0 + total_num_tokens = 0 + total_num_sequences = 0 + with mesh: + while True: + while free_slots: + request: Optional[Request] = yield + tokens = self.tokenizer.encode(request.prompt) + temperature = request.temperature + nucleus_p = request.nucleus_p + rng_seed = request.rng_seed + + i = free_slots.pop() + prompt = np.array(tokens, dtype=np.int32) + prompt_len = len(prompt) + prompt = pad_to_size(prompt, self.get_pad_bucket(prompt.shape[0])) + # All tokens are allowed. + mask = np.ones((self.vocab_size,), dtype=np.int32) + + new_settings = SampleSettings( + temperature=np.float32(temperature), + nucleus_p=np.float32(nucleus_p), + mask=mask, + active=np.ones((), dtype=np.int32), + ) + rng_seed = np.uint64(rng_seed) + rngs, last_output, memory, settings = self.prefill_memory( + params, + rngs, + memory, + settings, + last_output, + prompt, + prompt_len, + rng_seed, + new_settings, + i, + ) + jax.tree_map(lambda x: x.copy_to_host_async(), last_output) + first_output[i] = last_output + requests[i] = request + total_num_sequences += 1 + + rngs, last_output, memory = self.sample_step( + params, rngs, last_output, memory, settings + ) + total_num_tokens += batch_size - len(free_slots) + + # prev_token should already be on the host. + prev_token = jax.tree_map(np.array, prev_token) + for i in range(batch_size): + if requests[i] is not None: + if first_output[i] is not None: + first_output_i = jax.tree_map(np.array, first_output[i]) + all_tokens.append(int(first_output_i.token_id[i][0])) + first_output[i] = None + continue + + all_tokens.append(int(prev_token.token_id[i][0])) + cont = len(all_tokens) < requests[i].max_len + + if not cont: + output_str = self.tokenizer.decode(all_tokens) + requests[i] = None + free_slots.append(i) + all_tokens = [] + settings = settings._replace(active=settings.active.at[i].set(0)) + yield output_str + + jax.tree_map(lambda x: x.copy_to_host_async(), last_output) + prev_token = last_output + step += 1 + + +def make_mesh( + local_mesh_config: tuple[int, ...], between_hosts_config: tuple[int, ...] +) -> jax.sharding.Mesh: + assert len(local_mesh_config) == 2 + assert len(between_hosts_config) == 2 + rank_logger.info("Detected %s devices in mesh", jax.device_count()) + device_mesh = mesh_utils.create_hybrid_device_mesh( + local_mesh_config, + between_hosts_config, + devices=jax.devices(), + process_is_granule=True, + ) + rank_logger.debug(re.sub("\n+", "\n", f"Job device mesh is:\n{device_mesh}")) + return jax.sharding.Mesh(device_mesh, ("data", "model")) + + +def sample_from_model(server, prompt, max_len, temperature): + next(server) + inp = Request( + prompt=prompt, + temperature=temperature, + nucleus_p=1.0, + rng_seed=42, + max_len=max_len, + ) + return server.send(inp) diff --git a/grok-1-main/tokenizer.model b/grok-1-main/tokenizer.model new file mode 100644 index 0000000..d2ff64d Binary files /dev/null and b/grok-1-main/tokenizer.model differ diff --git a/grok-main/.gitignore b/grok-main/.gitignore new file mode 100644 index 0000000..1479974 --- /dev/null +++ b/grok-main/.gitignore @@ -0,0 +1,133 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +default +checkpoints +.vscode diff --git a/grok-main/LICENSE b/grok-main/LICENSE new file mode 100644 index 0000000..c123b69 --- /dev/null +++ b/grok-main/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2021 OpenAI + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/grok-main/README.md b/grok-main/README.md new file mode 100644 index 0000000..d4b0b29 --- /dev/null +++ b/grok-main/README.md @@ -0,0 +1,32 @@ +# OpenAI Grok Curve Experiments + +> **Note:** This is `openai/grok` — the grokking research codebase. This fork also includes [`xai-org/grok-1`](../grok-1-main/) with a bridge between the two. See the [root README](../README.md) for context. + +## Paper + +This is the code for the paper [Grokking: Generalization Beyond Overfitting on Small Algorithmic Datasets](https://arxiv.org/abs/2201.02177) by Alethea Power, Yuri Burda, Harri Edwards, Igor Babuschkin, and Vedant Misra. + +## Installation and Training + +```bash +pip install -e . +./scripts/train.py +``` + +## Grok-1 Architecture Mode + +This fork adds Grok-1's architectural innovations (MoE, RoPE, RMSNorm, gated GELU) as an alternative architecture for grokking experiments: + +```bash +# Standard grokking experiment (original) +./scripts/train.py --math_operator + --train_data_pct 5 + +# Grok-1 architecture — same task, different optimizer geometry +./scripts/train.py --architecture grok1 --math_operator + --num_experts 8 + +# Auto-scaled miniature Grok-1 +./scripts/train.py --architecture grok1_mini +``` + +New classes in `grok/transformer.py`: `GrokOneTransformer`, `GrokOneMoELayer`, `RotaryPositionalEmbedding`, `RMSNorm`. +New metrics in `grok/metrics.py`: `expert_utilization_entropy`, `expert_specialization_score`, `routing_collapse_index`. diff --git a/grok-main/grok/__init__.py b/grok-main/grok/__init__.py new file mode 100644 index 0000000..12ab727 --- /dev/null +++ b/grok-main/grok/__init__.py @@ -0,0 +1,20 @@ +from . import transformer +from . import data +from . import training +from . import metrics +from . import visualization + +# ── 强兼 Bridge Exports ────────────────────────────────────────── +# The following classes are bridged from xai-org/grok-1's architecture, +# re-implemented in PyTorch for grokking experiments. +from .transformer import ( + Transformer, # Original OpenAI architecture + GrokOneTransformer, # 强兼: Grok-1 architecture (MoE + RoPE) + GrokOneMoELayer, # 强兼: Mixture of Experts layer + GrokOneRouter, # 强兼: Expert routing + RotaryPositionalEmbedding, # 强兼: RoPE from Grok-1 + RMSNorm, # 强兼: RMS LayerNorm from Grok-1 +) + +__version__ = "0.0.2-qiangjian" +__bridge__ = "openai/grok ←→ xai-org/grok-1" diff --git a/grok-main/grok/data.py b/grok-main/grok/data.py new file mode 100644 index 0000000..ad45b25 --- /dev/null +++ b/grok-main/grok/data.py @@ -0,0 +1,612 @@ +import itertools +import math +import os +import sys +import random + +import torch +from torch import Tensor, LongTensor +import numpy as np +from typing import Tuple, List, Dict, Any, Union, Optional +from tqdm import tqdm + +from sympy.combinatorics.permutations import Permutation +from mod import Mod + +import blobfile as bf + + +VALID_OPERATORS = { + "+": "addition", + "-": "subtraction", + "*": "muliplication", + "/": "division", + "**2+": "squarepoly", + "**3+": "cubepoly", + "x**2+y**2_mod_97": "quad1", + "x**2+y**2+x*y_mod_97": "quad2", + "x**2+y**2+x*y+x_mod_97": "quad3", + "x**3+x*y_mod_97": "cube1", + "x**3+x*y**2+y_mod_97": "cube2", + "(x._value//y)if(y._value%2==1)else(x-y)_mod_97": "mix1", + "s5": "s5", + "s5conj": "s5conj", + "s5aba": "s5aba", + "+*": "even-addition_odd-multiplication", + "+-": "even-addition_odd-subtraction", + "sort": "sort", + "reverse": "reverse", + "copy": "copy", +} +EOS_TOKEN = "<|eos|>" +EQ_TOKEN = "=" +MODULUS = 97 +NUMS = list(range(MODULUS)) + +DEFAULT_DATA_DIR = "data" + + +def render(operand, join_str=""): + if ( + isinstance(operand, list) + or isinstance(operand, tuple) + or isinstance(operand, np.ndarray) + ): + return join_str.join(map(render, operand)) + elif isinstance(operand, Permutation): + return "".join(map(str, operand.array_form)) + elif isinstance(operand, Mod): + return str(operand._value) + else: + return str(operand) + + +def create_data_files(data_dir: str = DEFAULT_DATA_DIR): + ArithmeticTokenizer.create_token_file(data_dir) + ArithmeticDataset.create_dataset_files(data_dir) + + +class ArithmeticTokenizer: + """Stores the list of token text to token id mappings and converts between them""" + + token_file = "tokens.txt" + + def __init__(self, data_dir=DEFAULT_DATA_DIR) -> None: + self.token_file = bf.join(data_dir, self.token_file) + + self.itos = self.get_tokens() + + self.stoi: Dict[str, int] = dict([(s, i) for i, s in enumerate(self.itos)]) + + def _encode(self, s: str) -> Tensor: + return LongTensor([self.stoi[t] for t in s.split(" ")]) + + def encode(self, obj: Union[str, List]) -> Tensor: + """ + Convert a string of text into a rank-1 tensor of token ids + or convert a list of strings of text into a rank-2 tensor of token ids + + :param obj: the string or list of strings to convert + :returns: a tensor of the token ids + """ + if isinstance(obj, str): + return self._encode(obj) + elif isinstance(obj, list): + return torch.stack([self._encode(s) for s in obj], dim=0) + else: + raise NotImplementedError + + def decode(self, tensor: Tensor, with_brackets: bool = False) -> str: + """ + Convert a tensor of token ids into a string of text + + :param tensor: a tensor of the token ids + :param with_brackets: if true, the returned string will include <> brackets + around the text corresponding to each token. + :returns: string of these tokens. + """ + indices = tensor.long() + if with_brackets: + l = "<" + r = ">" + else: + l = "" + r = "" + tokens = [l + self.itos[i] + r for i in indices] + return " ".join(tokens) + + def __len__(self) -> int: + """ + :returns: the number of tokens in this vocabulary + """ + return len(self.itos) + + @classmethod + def get_tokens(cls): + tokens = ( + [EOS_TOKEN, EQ_TOKEN] + + list(sorted(list(VALID_OPERATORS.keys()))) + + list(map(render, NUMS)) + + list(map(render, itertools.permutations(range(5)))) # s5 + ) + return tokens + + +class ArithmeticDataset: + """A Dataset of arithmetic equations""" + + @classmethod + def splits( + cls, + train_pct: float, + operator: str, + operand_length: Optional[int] = None, + data_dir: str = DEFAULT_DATA_DIR, + ): + """ + Creates training and validation datasets + + :param train_pct: percentage of total equations used for training data + :param operator: The arithmetic operator for this dataset e.g. '+', '-', '*', '/', 'sort' + :param operand_length: for list based datasets the length of the lists + :returns: (train_dataset, validation_dataset) + """ + + assert (0 < train_pct) and (train_pct < 100) + + ds_name = cls.get_dsname(operator, operand_length) + eqs = cls.make_data(operator, operand_length) + + train_rows, _ = cls.calc_split_len(train_pct, len(eqs)) + + train_ds = cls(ds_name, eqs[:train_rows], train=True, data_dir=data_dir) + val_ds = cls(ds_name, eqs[train_rows:], train=False, data_dir=data_dir) + + return train_ds, val_ds + + @classmethod + def calc_split_len(cls, train_pct, ds_len): + train_rows = round(ds_len * (train_pct / 100.0)) + val_rows = ds_len - train_rows + return train_rows, val_rows + + def __init__(self, name, data: Union[Tensor, List[str]], train, data_dir) -> None: + """ + :param data: A list of equations strings. Each equation must have an '=' in it. + """ + self.tokenizer = ArithmeticTokenizer(data_dir) + self.name = name + self.train = train + if isinstance(data, list): + self.data = self.tokenizer.encode(data) + else: + self.data = data + + def __len__(self) -> int: + """ + :returns: total number of equations in this dataset + """ + return self.data.shape[0] + + # @classmethod + # def _render(cls, operand): + # return render(operand, join_str=" ") + # + # @classmethod + # def _render_eq(parts): + # return " ".join(map(render, parts)) + + @classmethod + def _make_binary_operation_data(cls, operator: str, operands=None) -> List[str]: + if operator == "s5": + operands = operands or list(range(5)) + elems = map(np.array, itertools.permutations(operands)) + tuples = itertools.product(elems, repeat=2) + elif operator in ["s5conj", "s5aba"]: + operands = operands or list(range(5)) + elems = map(Permutation, itertools.permutations(operands)) + tuples = itertools.product(elems, repeat=2) + elif "_mod_" in operator: + modulo = int(operator.split("_mod_")[-1]) + elems = [Mod(i, modulo) for i in range(modulo)] + tuples = itertools.product(elems, repeat=2) + else: + operands = operands or NUMS + tuples = itertools.product(operands, repeat=2) + + # if operator == "s5": + # print("elems", list(elems)) + # print("tuples", list(tuples)) + eqs = [] + for a, b in tuples: + if operator == "/": + if b == 0: + continue + else: + c = a + a = (b * c) % MODULUS + elif operator == "s5": + c = b[a] + elif operator == "s5conj": + c = a * b * (a.__invert__()) + elif operator == "s5aba": + c = a * b * a + elif operator == "+*": + if a % 2 == 0: + c = (a + b) % MODULUS + else: + c = (a * b) % MODULUS + elif operator == "+-": + if a % 2 == 0: + c = (a + b) % MODULUS + else: + c = (a - b) % MODULUS + elif "_mod_" in operator: + expression = operator.split("_mod_")[0] + function = eval(f"lambda x, y: ({expression})") + c = function(a, b) + else: + c = eval(f"({a} {operator} {b}) % {MODULUS}") + eq = " ".join(map(render, [a, operator, b, "=", c])) + eqs.append(eq) + + # if operator == "s5": + # print("eqs", eqs) + return eqs + + # @staticmethod + # def _render_unop_example(operator, lhs, rhs): + # return " ".join([operator, render(lhs), "=", render(rhs)]) + + @staticmethod + def _make_unary_operation_data(operator: str, operands: Tensor) -> List[str]: + """ + :param operator: The unary operator to apply to each operand e.g. '+' + :param operands: A tensor of operands + :returns: list of equations""" + num_examples = len(operands) + + if operator == "sort": + rhs = torch.sort(operands, dim=1)[0] + elif operator == "reverse": + rhs = torch.flip(operands, dims=(1,)) + elif operator == "copy": + rhs = operands + else: + raise Exception("unsupported operator") + + def func(L, R): + L = map(str, L) + R = map(str, R) + return f"{operator} {' '.join(L)} = {' '.join(R)}" + + if num_examples < 1000000000: + eqs = [ + func(L, R) + for L, R in tqdm( + zip(operands.tolist(), rhs.tolist()), total=num_examples + ) + ] + else: + with ProcessPoolExecutor() as executor: + eqs = executor.map(func, tqdm(zip(operands, rhs), total=num_examples)) + + return eqs + + # @staticmethod + # def _make_s5_data(abstract=False) -> List[str]: + # elems = itertools.permutations([0, 1, 2, 3, 4]) + # pairs = itertools.product(elems, repeat=2) + # eqs = [] + # for a, b in pairs: + # a = np.array(a) + # b = np.array(b) + # c = b[a] + # eq = " ".join(map(render, (a, "s5", b, "=", c))) + # eq = cls._render_eq([a, , b, "=", c]) + # eqs.append(eq) + # + # return eqs + + @classmethod + def get_dsname(cls, operator, operand_length) -> str: + operator, noise_level = cls._get_operator_and_noise_level(operator) + ds_name = VALID_OPERATORS[operator] + if operand_length is not None: + ds_name += f"_length-{operand_length}" + if noise_level > 0: + ds_name += f"_noise-{noise_level}" + return ds_name + + @classmethod + def get_file_path(cls, operator, operand_length=None, data_dir=DEFAULT_DATA_DIR): + ds_name = cls.get_dsname(operator, operand_length) + ds_file = bf.join(data_dir, f"{ds_name}_data.txt") + return ds_file, ds_name + + @classmethod + def _get_operator_and_noise_level(cls, operator): + if "_noisy" in operator: + operator, noise_level = operator.split("_noisy_") + return operator, int(noise_level) + else: + return operator, 0 + + @classmethod + def make_data(cls, operator, operands=None, shuffle=True, seed=0) -> List[str]: + operator, noise_level = cls._get_operator_and_noise_level(operator) + assert operator in VALID_OPERATORS + + if operator not in ["sort", "reverse", "copy"]: + data = cls._make_binary_operation_data(operator) + else: + data = cls._make_unary_operation_data(operator, operands) + + rng = np.random.RandomState(seed=seed) + if shuffle: + rng.shuffle(data) + + if noise_level > 0: + random_answer_eqns = rng.choice(data, size=noise_level) + random_answers = [ + random_eq.split(" = ")[1] for random_eq in random_answer_eqns + ] + for i in range(noise_level): + data[i] = data[i].split(" = ")[0] + " = " + random_answers[i] + + data = [EOS_TOKEN + " " + eq + " " + EOS_TOKEN for eq in data] + + return data + + # @classmethod + # def create_data_file( + # cls, operator, operand_length=None, shuffle=True, data_dir=DEFAULT_DATA_DIR + # ): + # if VALID_OPERATORS[operator]["binary_eval"]: + # cls.write_dataset( + # cls.make_binary_operation_data(operator), paths["ds_file"] + # ) + # + # pass + + # @classmethod + # def write_dataset(eqs: List[str], ds_file: str): + # print(f"-> writing {ds_file}", flush=True) + # with open(ds_file, "w") as fh: + # fh.writelines([EOS_TOKEN + " " + eq + " " + EOS_TOKEN + "\n" for eq in eqs]) + + @classmethod + def _make_lists(cls, sizes=[2, 3], nums=NUMS): + lists: dict = {} + for size in sizes: + lists[size] = torch.tensor( + list(itertools.permutations(nums, r=size)), + dtype=torch.int, + ) + return lists + + +class ArithmeticIterator(torch.utils.data.IterableDataset): + """ + An iterator over batches of data in an ArithmeticDataset + """ + + def __init__( + self, + dataset: ArithmeticDataset, + device: torch.device, + batchsize_hint: float = 0, + shuffle: bool = True, + ) -> None: + """ + :param dataset: the dataset to iterate over + :param device: the torch device to send batches to + :param batchsize_hint: * 0 means we use a default batchsize + * -1 means the entire dataset + * float between 0 and 1 means each batch is + that fraction of the DS + * int > 1 means that specific batch size + :param shuffle: whether or not to randomly shuffle the dataset + """ + self.dataset = dataset + self.batchsize = self.calculate_batchsize( + len(dataset), batchsize_hint=batchsize_hint + ) + self.device = device + self.reset_iteration(shuffle=shuffle) + + @staticmethod + def calculate_batchsize(ds_size: int, batchsize_hint: int = 0) -> int: + """ + Calculates which batch size to use + + :param ds_size: the number of equations in the dataset + :param batchsize_hint: * 0 means we use a default batchsize + * -1 means the entire dataset + * float between 0 and 1 means each batch is + that fraction of the DS + * int > 1 means that specific batch size + :returns: the actual batchsize to use + """ + + if batchsize_hint == -1: + return ds_size + elif batchsize_hint == 0: + return min(512, math.ceil(ds_size / 2.0)) + elif (batchsize_hint > 0) and (batchsize_hint < 1): + return math.ceil(ds_size * batchsize_hint) + elif batchsize_hint > 1: + return min(batchsize_hint, ds_size) + else: + raise ValueError("batchsize_hint must be >= -1") + + def reset_iteration(self, shuffle=True): + self.index = 0 + if shuffle and self.dataset.train: + self.permutation = torch.randperm(len(self.dataset)) + else: + self.permutation = torch.arange(len(self.dataset)) + + def __iter__(self): + """ + :returns: this iterator + """ + return self + + def __next__(self) -> Dict[str, Tensor]: + """ + Returns one batch of data. + + :raises: StopIteration when we're out of data + :returns: batch tensor of shape (self.batchsize, tokens_per_eq) + """ + + batch_begin = self.index * self.batchsize + if batch_begin > len(self.dataset) - 1: + self.reset_iteration() + raise StopIteration + indices = self.permutation[batch_begin : batch_begin + self.batchsize] + text = self.dataset.data[indices, :-1] + target = self.dataset.data[indices, 1:] + batch = {"text": text.to(self.device), "target": target.to(self.device)} + self.index += 1 + return batch + + def __len__(self) -> int: + """ + :returns: the total number of batches + """ + return math.ceil(len(self.dataset) / self.batchsize) + + +# ============================================================================= +# 强兼 BRIDGE — Grok-1 Inference Data Adapter +# ============================================================================= +# +# Converts arithmetic equations from the grokking dataset format into natural +# language prompts suitable for Grok-1 (or any LLM) inference. This enables +# testing whether Grok-1 has "grokked" the arithmetic tasks that OpenAI's +# grokking paper studies. +# +# Source: https://github.com/openai/grok → xai-org/grok-1 +# ============================================================================= + + +OPERATOR_TO_NATURAL = { + "+": "plus", + "-": "minus", + "*": "times", + "/": "divided by", + "**2+": "squared plus", + "**3+": "cubed plus", +} + + +def format_for_grok1( + equation: str, + style: str = "direct", +) -> str: + """ + Converts a grokking-format equation into a natural language prompt + for Grok-1 inference. + + :param equation: equation string, e.g. "<|eos|> 42 + 55 = 0 <|eos|>" + :param style: 'direct' for "What is 42 + 55 mod 97?" + 'cot' for chain-of-thought prompting + 'raw' for minimal "42 + 55 =" + :returns: prompt string for LLM inference + """ + # Strip EOS tokens + eq = equation.replace(EOS_TOKEN, "").strip() + + # Split at '=' to get LHS + parts = eq.split("=") + lhs = parts[0].strip() + + if style == "raw": + return f"{lhs} =" + + elif style == "direct": + return ( + f"Calculate the following modular arithmetic (mod {MODULUS}): " + f"What is {lhs} mod {MODULUS}? " + f"Answer with just the number." + ) + + elif style == "cot": + return ( + f"Let's solve this step by step.\n" + f"Calculate: {lhs}\n" + f"Then take the result modulo {MODULUS}.\n" + f"Show your work, then give the final answer as a single number." + ) + + else: + return f"{lhs} =" + + +def parse_grok1_response( + response: str, + expected: Optional[int] = None, +) -> Dict[str, Any]: + """ + Parse Grok-1's inference output for an arithmetic problem. + + :param response: raw text output from Grok-1 + :param expected: if provided, check correctness + :returns: dict with 'answer' (int or None), 'correct' (bool or None), + 'raw_response' (str) + """ + import re + + # Try to extract the last number in the response + numbers = re.findall(r'\b(\d+)\b', response) + answer = int(numbers[-1]) if numbers else None + + result = { + "answer": answer, + "raw_response": response, + "correct": None, + } + if expected is not None and answer is not None: + result["correct"] = (answer % MODULUS) == (expected % MODULUS) + + return result + + +def make_grok1_eval_suite( + operator: str = "+", + n_samples: int = 100, + seed: int = 42, +) -> List[Dict[str, Any]]: + """ + Generate an evaluation suite of arithmetic problems for Grok-1. + + :param operator: arithmetic operator to test + :param n_samples: number of test problems + :param seed: random seed + :returns: list of dicts with 'prompt', 'equation', 'expected_answer' + """ + eqs = ArithmeticDataset.make_data(operator, shuffle=True, seed=seed) + rng = random.Random(seed) + sampled = rng.sample(eqs, min(n_samples, len(eqs))) + + suite = [] + for eq in sampled: + eq_clean = eq.replace(EOS_TOKEN, "").strip() + parts = eq_clean.split("=") + if len(parts) == 2: + expected = parts[1].strip() + try: + expected_int = int(expected) + except ValueError: + expected_int = None + suite.append({ + "equation": eq_clean, + "prompt_direct": format_for_grok1(eq, style="direct"), + "prompt_cot": format_for_grok1(eq, style="cot"), + "prompt_raw": format_for_grok1(eq, style="raw"), + "expected_answer": expected_int, + }) + + return suite diff --git a/grok-main/grok/measure.py b/grok-main/grok/measure.py new file mode 100755 index 0000000..e7d31e2 --- /dev/null +++ b/grok-main/grok/measure.py @@ -0,0 +1,139 @@ +import logging +import torch +import numpy as np + +import scipy.optimize + + +def get_loss_and_grads(x, model, data_loader): + + # if type(x).__module__ == np.__name__: + # x = torch.from_numpy(x).float() + # x = x.cuda() + + model.eval() + + x_start = 0 + for p in model.parameters(): + param_size = p.data.size() + param_idx = 1 + for s in param_size: + param_idx *= s + x_part = x[x_start : x_start + param_idx] + p.data = torch.Tensor(x_part.reshape(param_size)) + x_start += param_idx + + batch_losses = [] + batch_grads = [] + for it, batch in enumerate(data_loader): + + # Move data to correct device + # inputs = inputs.to(device) + # targets = targets.to(device) + + with torch.set_grad_enabled(True): + # loss, grads = model(idx=inputs, targets=targets, grads=True) + loss, grads = model._step(batch=batch, batch_idx=1, train=True, grads=True) + + # Todo: average over dataset + batch_losses.append(loss) + # batch_grads.append(None if grads is None else grads.cpu().numpy().astype(np.float64)) + batch_grads.append(None if grads is None else grads) + + mean_losses = torch.mean(torch.stack(batch_losses)) + mean_grads = torch.mean(torch.stack(batch_grads), dim=0) + + return (mean_losses, mean_grads.cpu().numpy().astype(np.float64)) + + +def get_weights(model): + """ + Given a model, return a vector of weights. + """ + x0 = None + for p in model.parameters(): + if x0 is None: + x0 = p.data.view(-1) + else: + x0 = torch.cat((x0, p.data.view(-1))) + return x0.cpu().numpy() + + +def get_sharpness(data_loader, model, subspace_dim=10, epsilon=1e-3, maxiter=10): + """ + Compute the sharpness around some point in weight space, as specified + in Keskar et. al. (2016) Sec 2.2.2: + https://arxiv.org/pdf/1609.04836.pdf + + See: + https://gist.github.com/arthurmensch/c55ac413868550f89225a0b9212aa4cd + https://gist.github.com/gngdb/a9f912df362a85b37c730154ef3c294b + https://github.com/keskarnitish/large-batch-training + https://github.com/wenwei202/smoothout + https://github.com/keras-team/keras/pull/3064 + """ + + x0 = get_weights(model) + + f_x0, _ = get_loss_and_grads(x0, model, data_loader) + f_x0 = -f_x0 + logging.info("min loss f_x0 = {loss:.4f}".format(loss=f_x0)) + + if 0 == subspace_dim: + x_min = np.reshape(x0 - epsilon * (np.abs(x0) + 1), (x0.shape[0], 1)) + x_max = np.reshape(x0 + epsilon * (np.abs(x0) + 1), (x0.shape[0], 1)) + bounds = np.concatenate([x_min, x_max], 1) + func = lambda x: get_loss_and_grads(x, model, data_loader) + init_guess = x0 + else: + assert subspace_dim <= x0.shape[0] + + # Computed via Keskar, et. al + # https://arxiv.org/pdf/1609.04836.pdf + + A_plus = np.random.rand(subspace_dim, x0.shape[0]) * 2.0 - 1.0 + A_plus_norm = np.linalg.norm(A_plus, axis=1) + A_plus = A_plus / np.reshape(A_plus_norm, (subspace_dim, 1)) + A = np.linalg.pinv(A_plus) + + abs_bound = epsilon * (np.abs(np.dot(A_plus, x0)) + 1) + abs_bound = np.reshape(abs_bound, (abs_bound.shape[0], 1)) + bounds = np.concatenate([-abs_bound, abs_bound], 1) + + def func(y): + f_loss, f_grads = get_loss_and_grads( + x0 + np.dot(A, y), + model, + data_loader, + ) + return f_loss, np.dot(np.transpose(A), f_grads) + + init_guess = np.zeros(subspace_dim) + + minimum_x, f_x, d = scipy.optimize.fmin_l_bfgs_b( + func, + init_guess, + maxiter=maxiter, + bounds=bounds, + disp=1, + ) + f_x = -f_x + logging.info("max loss f_x = {loss:.4f}".format(loss=f_x)) + + # Eq 4 in Keskar + phi = (f_x - f_x0) / (1 + f_x0) * 100 + + # Restore parameter values + x0 = torch.from_numpy(x0).float() + # x0 = x0.cuda() + x_start = 0 + for p in model.parameters(): + param_size = p.data.size() + param_idx = 1 + for s in param_size: + param_idx *= s + x_part = x0[x_start : x_start + param_idx] + p.data = x_part.view(param_size) + x_start += param_idx + + return phi diff --git a/grok-main/grok/metrics.py b/grok-main/grok/metrics.py new file mode 100644 index 0000000..d404ff9 --- /dev/null +++ b/grok-main/grok/metrics.py @@ -0,0 +1,372 @@ +import torch +import math +import copy +import torch.nn as nn +from typing import Callable + +# References: +# https://github.com/nitarshan/robust-generalization-measures +# https://github.com/bneyshabur/generalization-bounds +# https://github.com/bneyshabur/over-parametrization + + +def compute_measure( + model: nn.Module, + init_model: nn.Module, + measure_func: Callable, + operator: str, + kwargs: dict = {}, + p: int = 1, +) -> float: + """ + Computes measure value for each layer given trained network and network at + initialization. Then aggregates values per layer using specified operator. + + :param model: trained network + :param init_model: network at initialization + :param measure_func: callable for the measure to compute + :param operator: 'log_product', 'sum', 'max', 'product', or 'norm' + :param p: p in L^p + :return: value of the desired measure + """ + + measure_value = 0 + # weight_modules = ["Linear", "Embedding"] + weight_modules = ["Linear"] + + if operator == "product": + measure_value = math.exp( + compute_measure(model, init_model, measure_func, "log_product", kwargs, p) + ) + elif operator == "norm": + measure_value = ( + compute_measure(model, init_model, measure_func, "sum", kwargs, p=p) + ) ** (1 / p) + else: + measure_value = 0 + for child, init_child in zip(model.children(), init_model.children()): + module_name = child._get_name() + if module_name in weight_modules: + if operator == "log_product": + measure_value += math.log(measure_func(child, init_child, **kwargs)) + elif operator == "sum": + measure_value += (measure_func(child, init_child, **kwargs)) ** p + elif operator == "max": + measure_value = max( + measure_value, measure_func(child, init_child, **kwargs) + ) + else: + measure_value += compute_measure( + child, init_child, measure_func, operator, kwargs, p=p + ) + return measure_value + + +def norm(module, init_module, p=2, q=2): + """ + Calculates l_pq norm of a parameter matrix + l_p norm of incoming weights to each hidden unit + l_q norm on the hidden units + """ + return module.weight.view(module.weight.size(0), -1).norm(p=p, dim=1).norm(q).item() + + +def op_norm(module, init_module, p=float("Inf")): + """ + Calculates l_p norm of eigenvalues of parameter matrix + """ + _, S, _ = module.weight.view(module.weight.size(0), -1).svd() + return S.norm(p).item() + + +def dist(module, init_module, p=2, q=2): + """ + Calculates l_pq distance of the parameter matrix of a layer from the random + initialization: + l_p norm of incoming weights to each hidden unit + l_q norm on the hidden units + """ + return ( + (module.weight - init_module.weight) + .view(module.weight.size(0), -1) + .norm(p=p, dim=1) + .norm(q) + .item() + ) + + +def h_dist(module, init_module, p=2, q=2): + """ + Calculate l_pq distance of parameters of trained network from random init + Includes extra factor depending on number of hidden units + """ + return (n_hidden(module, init_module) ** (1 - 1 / q)) * dist( + module, init_module, p=p, q=q + ) + + +def h_dist_op_norm(module, init_module, p=2, q=2, p_op=float("Inf")): + """ + Calculate ratio of h_dist to operator norm + """ + return h_dist(module, init_module, p=p, q=q) / op_norm(module, init_module, p=p_op) + + +def n_hidden(module, init_module): + """ + Number of hidden units + """ + return module.weight.size(0) + + +def depth(module, init_module): + """ + Depth (always == 1 for any linear layer) + """ + return 1 + + +def n_param(module, init_module): + """ + Num parameters + """ + bparam = 0 if module.bias is None else module.bias.size(0) + return bparam + module.weight.size(0) * module.weight.view( + module.weight.size(0), -1 + ).size(1) + + +def lp_path_norm(model, device, p=2, input_size=[3, 32, 32]): + """ + Path norm (Neyshabur 2015) + """ + + tmp_model = copy.deepcopy(model) + tmp_model.eval() + for param in tmp_model.parameters(): + if param.requires_grad: + param.abs_().pow_(p) + data_ones = torch.ones(input_size).to(device) + return (tmp_model(data_ones).sum() ** (1 / p)).item() + + +def calculate(trained_model, init_model, device, dataset_size, margin, input_dim): + """ + Calculates various measures given trained model and model at init + Computes: + measures: norm based measures on the model + bounds: generalization bounds on the model + """ + + model = copy.deepcopy(trained_model) + + # depth + d = compute_measure(model, init_model, depth, "sum", {}) + + # number of parameters (not including batch norm) + nparam = compute_measure(model, init_model, n_param, "sum", {}) + + measure, bound = {}, {} + with torch.no_grad(): + + # Compute measures + measure["L_{1,inf} norm"] = ( + compute_measure( + model, init_model, norm, "product", {"p": 1, "q": float("Inf")} + ) + / margin + ) + measure["Frobenius norm"] = ( + compute_measure(model, init_model, norm, "product", {"p": 2, "q": 2}) + / margin + ) + measure["L_{3,1.5} norm"] = ( + compute_measure(model, init_model, norm, "product", {"p": 3, "q": 1.5}) + / margin + ) + measure["Spectral norm"] = ( + compute_measure(model, init_model, op_norm, "product", {"p": float("Inf")}) + / margin + ) + measure["L_1.5 operator norm"] = ( + compute_measure(model, init_model, op_norm, "product", {"p": 1.5}) / margin + ) + measure["Trace norm"] = ( + compute_measure(model, init_model, op_norm, "product", {"p": 1}) / margin + ) + + # input_size = [context_len, emb_dim] + # measure["L1_path norm"] = ( + # lp_path_norm( + # model, device, p=1, input_size=input_size + # ) + # / margin + # ) + # measure["L1.5_path norm"] = ( + # lp_path_norm( + # model, device, p=1.5, input_size=input_size + # ) + # / margin + # ) + # measure["L2_path norm"] = ( + # lp_path_norm( + # model, device, p=2, input_size=input_size + # ) + # / margin + # ) + + # Compute generalization bounds without constant or additive logarithmic factors + + # Golowich 2018 + # https://arxiv.org/pdf/1712.06541.pdf + alpha = math.sqrt(d + math.log(1 * input_dim * input_dim)) + + # Bartlett Mendelson 2002 + bound["L1_max Bound"] = ( + alpha * measure["L_{1,inf} norm"] / math.sqrt(dataset_size) + ) + + # Neyshabur 2015 + bound["Frobenius Bound"] = ( + alpha * measure["Frobenius norm"] / math.sqrt(dataset_size) + ) + + # Neyshabur 2015 + bound["L_{3,1.5} Bound"] = ( + alpha * measure["L_{3,1.5} norm"] / (dataset_size ** (1 / 3)) + ) + + beta = math.log(dataset_size) * math.log(nparam) + ratio = compute_measure( + model, + init_model, + h_dist_op_norm, + "norm", + {"p": 2, "q": 1, "p_op": float("Inf")}, + p=2 / 3, + ) + + # Spectral L_{2, 1} Bound + # Bartlett 2017 + bound["Spec_L_{2,1} Bound"] = ( + beta * measure["Spectral norm"] * ratio / math.sqrt(dataset_size) + ) + + ratio = compute_measure( + model, + init_model, + h_dist_op_norm, + "norm", + {"p": 2, "q": 2, "p_op": float("Inf")}, + p=2, + ) + + # Spectral Frobenius + # Neyshabur 2018 + # https://arxiv.org/pdf/1706.08947.pdf + bound["Spec_Fro Bound"] = ( + d * measure["Spectral norm"] * ratio / math.sqrt(dataset_size) + ) + + return measure, bound + + +# ============================================================================= +# 强兼 BRIDGE — Mixture of Experts Grokking Metrics +# ============================================================================= +# +# Additional metrics for studying the grokking phenomenon in MoE architectures. +# These capture dynamics specific to Grok-1's expert routing that standard +# dense transformer metrics miss. +# +# Key insight: MoE models may "grok" differently — expert specialization +# could correlate with the phase transition from memorization to generalization. +# ============================================================================= + + +def expert_utilization_entropy(router_probs: torch.Tensor) -> float: + """ + Shannon entropy of the expert selection distribution. + Higher entropy = more uniform expert utilization = better load balance. + + In the context of grokking: if entropy increases during the phase + transition, it suggests generalization requires distributing computation + more evenly across experts (moving away from memorization shortcuts). + + :param router_probs: [batch, seq_len, num_experts] probability tensor + :returns: scalar entropy value + """ + mean_probs = router_probs.mean(dim=(0, 1)) # [num_experts] + entropy = -(mean_probs * (mean_probs + 1e-8).log()).sum().item() + return entropy + + +def expert_specialization_score(router_probs: torch.Tensor) -> float: + """ + Measures how specialized individual experts are (vs. uniform). + Score of 0 = all experts identical, 1 = maximum specialization. + + In grokking: high specialization during training plateau (memorization) + that drops during the grokking phase transition would indicate that + memorization relies on expert shortcuts while generalization doesn't. + + :param router_probs: [batch, seq_len, num_experts] probability tensor + :returns: specialization score in [0, 1] + """ + mean_probs = router_probs.mean(dim=(0, 1)) # [num_experts] + num_experts = mean_probs.shape[0] + uniform = 1.0 / num_experts + max_deviation = 1.0 - uniform + actual_deviation = (mean_probs - uniform).abs().max().item() + return actual_deviation / max_deviation if max_deviation > 0 else 0.0 + + +def routing_collapse_index(router_probs: torch.Tensor) -> float: + """ + Detects expert routing collapse — when the model routes most tokens + to a single expert, wasting MoE capacity. + + Defined as: max(expert_load) / mean(expert_load) + Collapse threshold: > 2.0 indicates significant imbalance + Perfect balance: 1.0 + + :param router_probs: [batch, seq_len, num_experts] probability tensor + :returns: collapse index (1.0 = balanced, higher = more collapsed) + """ + mean_probs = router_probs.mean(dim=(0, 1)) + return (mean_probs.max() / mean_probs.mean()).item() + + +def compute_moe_grokking_metrics(model: nn.Module, init_model: nn.Module) -> dict: + """ + Computes all MoE-specific grokking metrics for a GrokOneTransformer. + Call after a forward pass to analyze the cached router probabilities. + + :param model: trained GrokOneTransformer + :param init_model: GrokOneTransformer at initialization + :returns: dict of metric name → value + """ + from grok.transformer import GrokOneTransformer + + moe_metrics = {} + if isinstance(model, GrokOneTransformer) and model.last_router_probs: + for layer_idx, rp in enumerate(model.last_router_probs): + if rp is not None: + prefix = f"layer_{layer_idx}" + moe_metrics[f"{prefix}/routing_entropy"] = expert_utilization_entropy(rp) + moe_metrics[f"{prefix}/specialization"] = expert_specialization_score(rp) + moe_metrics[f"{prefix}/collapse_index"] = routing_collapse_index(rp) + + # Aggregate across layers + all_rp = torch.stack(model.last_router_probs) + moe_metrics["avg_routing_entropy"] = expert_utilization_entropy( + all_rp.mean(dim=0) + ) + moe_metrics["avg_specialization"] = expert_specialization_score( + all_rp.mean(dim=0) + ) + moe_metrics["avg_collapse_index"] = routing_collapse_index( + all_rp.mean(dim=0) + ) + + return moe_metrics diff --git a/grok-main/grok/training.py b/grok-main/grok/training.py new file mode 100755 index 0000000..4d51750 --- /dev/null +++ b/grok-main/grok/training.py @@ -0,0 +1,1134 @@ +#!/usr/bin/env python + +import argparse +import copy +import json +import logging +import math +import os +import sys +import pickle +from argparse import ArgumentParser, Namespace +from functools import reduce +from typing import Any, Dict, List, Optional, Tuple, Union +import time + +import numpy as np +import torch +import torch.nn.functional as F +from pytorch_lightning import LightningModule, Trainer +from pytorch_lightning.callbacks import Callback, ModelCheckpoint +from pytorch_lightning.loggers import CSVLogger +from torch import Tensor +from torch.optim.lr_scheduler import LambdaLR + +import grok.metrics as metrics +from grok.data import ( + DEFAULT_DATA_DIR, + EOS_TOKEN, + VALID_OPERATORS, + ArithmeticDataset, + ArithmeticIterator, +) +from grok.transformer import Transformer, GrokOneTransformer +from grok.measure import get_sharpness + +DEFAULT_LOG_DIR = "logs" + + +class TrainableTransformer(LightningModule): + """ + Adds training methods to train a generic transformer on arithmetic equations + """ + + def __init__(self, hparams: Namespace) -> None: + """ + :param hparams: An argparse.Namespace with parameters defined in + self.add_model_specific_args(). + """ + super().__init__() + self.hparams = hparams # type: ignore + self.prepare_data() + + vocab_len = len(self.train_dataset.tokenizer) + + # ── 强兼 Bridge: architecture selection ────────────────────── + # --architecture grok1 → use Grok-1-style MoE + RoPE + RMSNorm + # --architecture standard (default) → use original OpenAI transformer + architecture = getattr(hparams, "architecture", "standard") + + if architecture == "grok1": + self.transformer = GrokOneTransformer( + n_layers=hparams.n_layers, + n_heads=hparams.n_heads, + d_model=hparams.d_model, + dropout=hparams.dropout, + max_context_len=hparams.max_context_len, + vocab_len=vocab_len, + num_experts=getattr(hparams, "num_experts", 8), + num_selected_experts=getattr(hparams, "num_selected_experts", 2), + widening_factor=getattr(hparams, "widening_factor", 4), + weight_noise=self.hparams.weight_noise, + ) + elif architecture == "grok1_mini": + # Auto-scaled miniature of the real Grok-1 architecture + self.transformer = GrokOneTransformer.from_grok1_config( + scale_factor=getattr(hparams, "grok1_scale", 1/24), + vocab_len=vocab_len, + max_context_len=hparams.max_context_len, + dropout=hparams.dropout, + weight_noise=self.hparams.weight_noise, + ) + else: + self.transformer = Transformer( + hparams.n_layers, + hparams.n_heads, + hparams.d_model, + hparams.dropout, + hparams.max_context_len, + vocab_len, + hparams.non_linearity, + weight_noise=self.hparams.weight_noise, + ) + + self.margin = torch.Tensor([0]) + self.next_epoch_to_eval = -1 + self.next_train_epoch_to_log = 0 + + @staticmethod + def add_model_specific_args(parser: ArgumentParser) -> ArgumentParser: + """ + Defines the hyperparameter arguments needed by instances of this + class. This is intended to be called when parsing command line + arguments. + + :param parser: an argparse.ArgumentParser created by the caller + :returns: the argument parser with the command line arguments added + for this class. + """ + parser.add_argument( + "--batchsize", + type=float, + # default=0.25, + default=0, + help="-1 -> entire dataset, 0 -> auto-calculate, 0 fraction of dataset, N>1 -> N", + ) + + parser.add_argument("--n_layers", type=int, default=2) + parser.add_argument("--n_heads", type=int, default=4) + parser.add_argument("--d_model", type=int, default=128) + parser.add_argument("--dropout", type=float, default=0.0) + parser.add_argument("--weight_noise", type=float, default=0.0) + parser.add_argument("--non_linearity", type=str, default="relu") + parser.add_argument("--max_context_len", type=int, default=50) + + parser.add_argument("--math_operator", type=str, default="+") + parser.add_argument( + "--operand_length", + type=int, + help="for list operations, the length of the lists", + ) + + # ── 强兼 Bridge: Grok-1 architecture options ────────────── + parser.add_argument( + "--architecture", type=str, default="standard", + choices=["standard", "grok1", "grok1_mini"], + help="standard: original OpenAI transformer; " + "grok1: Grok-1 MoE+RoPE architecture; " + "grok1_mini: auto-scaled miniature Grok-1", + ) + parser.add_argument("--num_experts", type=int, default=8, + help="Number of MoE experts (Grok-1 default: 8)") + parser.add_argument("--num_selected_experts", type=int, default=2, + help="Experts selected per token (Grok-1 default: 2)") + parser.add_argument("--widening_factor", type=int, default=4, + help="FFN widening factor (Grok-1 default: 8)") + parser.add_argument("--grok1_scale", type=float, default=1/24, + help="Scale factor for grok1_mini mode") + + parser.add_argument("--train_data_pct", type=float, default=5) + parser.add_argument("--warmup_steps", type=int, default=10) + parser.add_argument("--anneal_lr_steps", type=int, default=100000) + parser.add_argument("--anneal_lr", dest="anneal_lr", action="store_true") + parser.set_defaults(anneal_lr=False) + + parser.add_argument("--max_lr", type=float, default=1e-3) + parser.add_argument("--weight_decay", type=float, default=0) + parser.add_argument("--weight_decay_kind", type=str, default="to_zero") + parser.add_argument("--noise_factor", type=float, default=0) + + parser.add_argument( + "--save_activations", dest="save_activations", action="store_true" + ) + parser.set_defaults(save_activations=False) + parser.add_argument("--save_outputs", dest="save_outputs", action="store_true") + parser.set_defaults(save_outputs=False) + + parser.add_argument( + "--logdir", + type=str, + default=DEFAULT_LOG_DIR, + ) + parser.add_argument( + "--datadir", + type=str, + default=DEFAULT_DATA_DIR, + ) + + return parser + + def prepare_data(self) -> None: + """ + Used by pytorch_lighting + + Loads training data to self.train_dataset + Loads validation data to self.val_dataset + """ + (self.train_dataset, self.val_dataset,) = ArithmeticDataset.splits( + train_pct=self.hparams.train_data_pct, # type: ignore + operator=self.hparams.math_operator, # type: ignore + operand_length=self.hparams.operand_length, # type: ignore + data_dir=self.hparams.datadir, # type: ignore + ) + + def train_dataloader(self) -> ArithmeticIterator: # type: ignore + """ + Used by pytorch_lighting + + :returns: an iterator for self.train_dataset + """ + device = self.transformer.embedding.weight.device + iterator = ArithmeticIterator( + self.train_dataset, + device, + batchsize_hint=self.hparams.batchsize, # type: ignore + ) + self.train_batchsize = iterator.batchsize + self.batches_per_epoch = len(iterator) + + return iterator + + def val_dataloader(self) -> ArithmeticIterator: # type: ignore + """ + Used by pytorch_lighting + + :returns: an iterator for self.train_dataset + """ + device = self.transformer.embedding.weight.device + iterator = ArithmeticIterator( + self.val_dataset, + device, + batchsize_hint=-1, # no need to batch validation data + ) + return iterator + + def test_dataloader(self) -> ArithmeticIterator: # type: ignore + """ + Used by pytorch_lighting + + :returns: an iterator for self.train_dataset + """ + device = self.transformer.embedding.weight.device + iterator = ArithmeticIterator( + self.val_dataset, device, batchsize_hint=-1 # type: ignore + ) + return iterator + + def _scheduler_lr(self, step: int) -> float: + """ + Used by pytorch_lighting + + :returns: the learning_rate for this training step + """ + max_lr = self.hparams.max_lr # type: ignore + min_lr = self.hparams.max_lr / 10 # type: ignore + warmup_steps = self.hparams.warmup_steps # type: ignore + if not self.hparams.anneal_lr: + if step <= warmup_steps: + lr = (float(step) / max(warmup_steps, 1)) * max_lr + else: + lr = max_lr + else: + if step <= warmup_steps: + lr = (float(step) / max(warmup_steps, 1)) * max_lr + elif step <= self.hparams.anneal_lr_steps + warmup_steps: + effective_step = step - warmup_steps + t = effective_step / self.hparams.anneal_lr_steps + cos = (1 + np.cos(np.pi * t)) / 2 + lr = min_lr + (max_lr - min_lr) * cos + # lr = max_lr - ((effective_step / max_effective_step) * (max_lr - min_lr)) + else: + lr = min_lr + return lr + + def configure_optimizers(self) -> Tuple[List[Any], List[Dict]]: + """ + Used by pytorch_lighting + + :returns: optimizers and schedulers. + """ + optimizer = CustomAdamW( + self.parameters(), + betas=(0.9, 0.98), + eps=1e-8, + lr=1, + weight_decay=self.hparams.weight_decay, + noise_factor=self.hparams.noise_factor, + weight_decay_form=self.hparams.weight_decay_kind, + ) + # optimizer = SAM( + # self.parameters(), + # base_optimizer=CustomAdamW, + # rho=0.05, + # betas=(0.9, 0.98), + # eps=1e-8, + # lr=1, + # weight_decay=self.hparams.weight_decay, + # noise_factor=self.hparams.noise_factor, + # ) + schedulers = [ + { + "scheduler": LambdaLR(optimizer, lr_lambda=self._scheduler_lr), + "interval": "step", + "frequency": 1, + } + ] + return [optimizer], schedulers + + def _accuracy(self, y_hat: Tensor, y: Tensor) -> Tensor: + """ + Takes the most likely solution predicted for each equation and + calculates the frac of equations in the batch for which these + answers were correct + + :param y_hat: The softmax tensor output of the transformer + :param y: A tensor of the token ids for the correct answers to each + equation in the batch + :returns: the fraction of equations correctly answered + """ + + # find max prediction from output + y_hat = torch.max(y_hat, dim=-2).indices # batchsize x num_rhs_tokens + row_accuracy = torch.min((y_hat == y), dim=-1).values # shape: batchsize + accuracy = row_accuracy.float() * 100 # shape: batchsize + return accuracy + + def _step( + self, + batch: Dict, + batch_idx: int, + train: bool = True, + reduction: str = "mean", + grads: bool = False, + ) -> Tuple[Tensor, Tensor, float, Tensor, Tensor, Tensor, Tensor]: + """ + Performs one forward pass on a training or validation batch + + :param batch: The batch of equations to process + :param batch_idx: which batch this is in the epoch. + :param train: True is this is a training batch, false otherwise + :returns: The loss from the predicted solutions to the equation, + The accuracy of the predicted solutions + The fraction of this dataset contained in this batch + The portion of the input equations left of the equal sign + The softmax probilities for the solutions to the equations + A list lists of attention matrices by layer and head + A list lists of value matrices by layer and head + Margin for this batch + """ + x = batch["text"] # shape = batchsize * context_len + y = batch["target"] # shape = batchsize * context_len + y_hat, attentions, values = self( + x=x, save_activations=self.hparams.save_activations # type: ignore + ) # shape = batchsize * context_len * vocab_size + y_hat = y_hat.transpose(-2, -1) # shape = batchsize * vocab_size * context_len + + # Note: each sample must have exactly one '=' and all of them must + # have it in the same position. + eq_token_index = self.train_dataset.tokenizer.stoi["="] + eq_position_t = torch.nonzero(y[0, :] == eq_token_index, as_tuple=False) + eq_position = int(eq_position_t.squeeze()) + + # only calculate loss/accuracy on right hand side of the equation + y_rhs = y[..., eq_position + 1 :] + y_hat_rhs = y_hat[..., eq_position + 1 :] + x_lhs = x[..., : eq_position + 1] + + if train: + coeff = float(batch["target"].shape[0]) / len(self.train_dataset) + else: + coeff = float(batch["target"].shape[0]) / len(self.val_dataset) + loss = F.cross_entropy(y_hat_rhs, y_rhs, reduction=reduction) + + with torch.no_grad(): + acc = self._accuracy(y_hat_rhs, y_rhs) + if reduction == "mean": + acc = acc.mean() + + """ + device = self.transformer.embedding.weight.device + self.margin = self.margin.to(device) + + output = y_hat_rhs.clone() # batchsize, vocabsize, rhs tokens + output_m = output.clone() # batchsize, vocabsize, rhs tokens + target = y_rhs.clone() # batchsize, rhs tokens + + for i in range(output.size(0)): # batch + for j in range(output.size(2)): # rhs tokens + output_m[i, target[i, j], j] = output_m[i, :, j].min() + + for i in range(output.size(2)): # rhs tokens + output_compressed = output[:, target[:, i], i].squeeze().diag() + output_m_compressed = ( + output_m[:, output_m.max(dim=1).indices[:, i], i].squeeze().diag() + ) + self.margin = torch.cat( + ( + self.margin, + (output_compressed - output_m_compressed), + ), + 0, + ) + """ + grad_vec = None + if grads: + loss.backward() + for p in self.parameters(): + p.grad.data.div_(batch["text"].shape[0]) + if grad_vec is None: + grad_vec = p.grad.data.view(-1) + else: + grad_vec = torch.cat((grad_vec, p.grad.data.view(-1))) + return loss, grad_vec + return loss, acc, coeff, x_lhs, y_hat_rhs, attentions, values + + + def _save_inputs(self, outputs: Dict, ds: str) -> None: + """ + Saves the input equations to disk for analysis later + + :param outputs: a list of tuples from self.training_step() + :param ds: a string ('train' or 'val') naming which dataset + these inputs are from. + :param train: True is this is a training batch, false otherwise + """ + logdir = self.hparams.logdir + "/inputs/" + ds # type: ignore + os.makedirs(logdir, exist_ok=True) + pickle_file = logdir + f"/{ds}.pt" + + x_lhs = torch.cat([x["x_lhs"] for x in outputs]) + with open(pickle_file, "wb") as fh: + torch.save(x_lhs, fh) + + def _merge_batch_activations( + self, partial_activations: List[List[Tensor]] + ) -> List[List[Tensor]]: + """ + Merges the head_attentions / head_values from all batches in + this epoch. + + :param partial_activations: A list of + (lists of lists of activations by layer and head) + :returns: A lists of lists of activations by layer and head + """ + # num_batches = len(partial_activations) + num_layers = len(partial_activations[0]) + num_heads = len(partial_activations[0][0]) + activations: List = [] + for _ in range(num_layers): + activations.append([]) + for _ in range(num_heads): + activations[-1].append([]) + + for minibatch_activations in partial_activations: + for l, layer_activations in enumerate(minibatch_activations): + for h, head_attn in enumerate(layer_activations): + # # print(f"head_attn = {head_attn}") + activations[l][h].append(head_attn) + + for l in range(num_layers): + for h in range(num_heads): + activations[l][h] = torch.cat(activations[l][h]) + + return activations + + def _save_activations(self, outputs: Dict, ds: str) -> None: + """ + Saves activations out to disk for analysis later + + :param outputs: a list of tuples from self.training_step() + """ + + output: Dict[str, Any] = {} + if self.hparams.save_outputs: # type: ignore + y_hat_rhs = torch.cat([x["y_hat_rhs"] for x in outputs]) + output["y_hat_rhs"] = y_hat_rhs + if self.hparams.save_activations: # type: ignore + partial_attentions = list([o["partial_attentions"] for o in outputs]) + attentions = self._merge_batch_activations(partial_attentions) + partial_values = list([o["partial_values"] for o in outputs]) + values = self._merge_batch_activations(partial_values) + output["attentions"] = attentions + output["values"] = values + if self.hparams.save_outputs or self.hparams.save_activations: # type: ignore + logdir = self.hparams.logdir + "/outputs/" + ds # type: ignore + os.makedirs(logdir, exist_ok=True) + pickle_file = logdir + f"/epoch_{self.current_epoch:010}.pt" + with open(pickle_file, "wb") as fh: + torch.save(output, fh) + + def training_step(self, batch, batch_idx): + """ + Used by pytorch_lightning + Runs one forward training pass on one batch + + :param batch: The batch of equations to process + :param batch_idx: which batch this is in the epoch. + :returns: a dict with loss, accuracy, lr, probabilities of solutions, + attentions, and values + """ + if batch_idx == 0: + self.training_epoch_start_time = time.time() + self.fwd_time_in_epoch = 0 + + start = time.time() + loss, accuracy, coeff, x_lhs, y_hat_rhs, attentions, values = self._step( + batch=batch, batch_idx=batch_idx, train=True + ) + self.fwd_time_in_epoch += time.time() - start + + schedulers = self.trainer.lr_schedulers[0] + if self.current_epoch != self.next_train_epoch_to_log: + return {"loss": loss} + lr = schedulers["scheduler"].optimizer.param_groups[0]["lr"] + output = { + "loss": loss, + "partial_train_loss": coeff * loss, + "partial_train_accuracy": coeff * accuracy, + "learning_rate": torch.tensor([lr]), + "y_hat_rhs": y_hat_rhs, + "partial_attentions": attentions, + "partial_values": values, + } + if self.current_epoch == 0: + output["x_lhs"] = x_lhs + + return output + + def training_epoch_end(self, outputs): + """ + Used by pytorch_lightning + Accumulates results of all forward training passes in this epoch + + :param outputs: a list of dicts from self.training_step() + :param batch_idx: which batch this is in the epoch. + :returns: a dict with loss, accuracy, lr, probabilities of solutions, + attentions, and values + """ + epoch_is_to_be_logged = self.current_epoch == self.next_train_epoch_to_log + if epoch_is_to_be_logged: + self.next_train_epoch_to_log = max( + int(1.01 * self.next_train_epoch_to_log), + self.next_train_epoch_to_log + 1, + ) + with torch.no_grad(): + try: + loss = torch.stack([x["partial_train_loss"] for x in outputs]).sum() + except Exception as e: + print("!" * 80) + print(outputs) + raise e + perplexity = torch.exp(loss) + accuracy = torch.stack( + [x["partial_train_accuracy"] for x in outputs] + ).sum() + # avg_lr = torch.stack([x["learning_rate"] for x in outputs]).mean() + # max_lr = torch.stack([x["learning_rate"] for x in outputs]).max() + # last_lr = outputs[-1]["learning_rate"] + first_lr = outputs[0]["learning_rate"] + + if self.hparams.save_activations or self.hparams.save_outputs: + if self.current_epoch == 0: + self._save_inputs(outputs, ds="train") + self._save_activations(outputs, ds="train") + + logs = { + "train_loss": loss, + "train_accuracy": accuracy, + "train_perplexity": perplexity, + "learning_rate": first_lr, + "len_train_ds": len(self.train_dataset), + "len_val_ds": len(self.val_dataset), + "batches_per_epoch": self.batches_per_epoch, + "time_per_epoch": time.time() - self.training_epoch_start_time, + "fwd_time_in_epoch": self.fwd_time_in_epoch, + } + + # ── 强兼 Bridge: MoE routing metrics ───────────────── + if isinstance(self.transformer, GrokOneTransformer): + router_probs = self.transformer.last_router_probs + if router_probs and router_probs[0] is not None: + # Expert utilization: how evenly are experts used? + all_probs = torch.stack([rp.mean(dim=(0, 1)) for rp in router_probs]) + mean_probs = all_probs.mean(dim=0) # [num_experts] + # Routing entropy (higher = more uniform distribution) + routing_entropy = -(mean_probs * (mean_probs + 1e-8).log()).sum() + max_entropy = np.log(mean_probs.shape[0]) + logs["moe_routing_entropy"] = routing_entropy + logs["moe_load_balance"] = routing_entropy / max_entropy + logs["moe_max_expert_prob"] = mean_probs.max() + logs["moe_min_expert_prob"] = mean_probs.min() + + for k, v in logs.items(): + self.log(k, v) + + def validation_step(self, batch, batch_idx): + """ + Used by pytorch_lightning + Runs one forward validation pass on one batch + + :param batch: The batch of equations to process + :param batch_idx: which batch this is in the epoch. + :returns: a dict with val_loss, val_accuracy, probabilities of solutions, + attentions, and values + """ + if self.next_epoch_to_eval < self.current_epoch: + self.next_epoch_to_eval = self.current_epoch + if self.current_epoch != self.next_epoch_to_eval: + return {} + with torch.no_grad(): + loss, accuracy, coeff, x_lhs, y_hat_rhs, attentions, values = self._step( + batch=batch, batch_idx=batch_idx, train=False + ) + output = { + "partial_val_loss": coeff * loss, + "partial_val_accuracy": coeff * accuracy, + "y_hat_rhs": y_hat_rhs, + "partial_attentions": attentions, + "partial_values": values, + } + if self.current_epoch == 0: + output["x_lhs"] = x_lhs + + return output + + def validation_epoch_end(self, outputs): + """ + Used by pytorch_lightning + Accumulates results of all forward validation passes in this epoch + + :param outputs: a list of dicts from self.validation_step() + :param batch_idx: which batch this is in the epoch. + :returns: a dict with val_loss, val_accuracy + """ + validation_is_real = len(outputs[0]) != 0 + + if validation_is_real: + self.next_epoch_to_eval = max( + int(1.02 * self.next_epoch_to_eval), self.next_epoch_to_eval + 1 + ) + + loss = torch.stack([x["partial_val_loss"] for x in outputs]).sum() + perplexity = torch.exp(loss) + accuracy = torch.stack([x["partial_val_accuracy"] for x in outputs]).sum() + + if self.hparams.save_activations or self.hparams.save_outputs: + if self.current_epoch == 0: + self._save_inputs(outputs, ds="val") + self._save_activations(outputs, ds="val") + + logs = { + "val_loss": loss, + "val_accuracy": accuracy, + "val_perplexity": perplexity, + } + for name, param in self.named_parameters(): + # n parameters + n_params = param.numel() + # get the l2 norm of the parameter + logs["paramnorm_" + name] = torch.norm( + param, 2 + ).detach().cpu().numpy() / np.sqrt(n_params) + + # train accuracy + device = self.transformer.embedding.weight.device + train_data = self.train_dataset.data.to(device) + training_data = {"text": train_data[:, :-1], "target": train_data[:, 1:]} + with torch.no_grad(): + tr_loss, tr_acc, *_ = self._step(training_data, 0) + logs["full_train_loss"] = tr_loss + logs["full_train_acc"] = tr_acc + + for k, v in logs.items(): + self.log(k, v) + # save a checkpoint if the epoch is a power of 2 + if ( + self.current_epoch > 0 + and int(2 ** (int(np.log(self.current_epoch) / np.log(2)))) + == self.current_epoch + ): + self.trainer.save_checkpoint( + os.path.join( + self.hparams.checkpoint_path, + "epoch_" + str(self.current_epoch) + ".ckpt", + ) + ) + if validation_is_real: + return logs + + def test_step(self, batch, batch_idx): + """ + Used by pytorch_lightning + Runs one forward validation pass on one batch + + :param batch: The batch of equations to process + :param batch_idx: which batch this is in the epoch. + :returns: a dict with val_loss, val_accuracy, probabilities of solutions, + attentions, and values + """ + + loss, accuracy, coeff, x_lhs, y_hat_rhs, attentions, values = self._step( + batch=batch, batch_idx=batch_idx, train=False, reduction="none" + ) + output = { + "partial_test_loss": coeff * loss, + "partial_test_accuracy": coeff * accuracy, + "y_hat_rhs": y_hat_rhs, + "partial_attentions": attentions, + "partial_values": values, + } + if self.current_epoch == 0: + output["x_lhs"] = x_lhs + + return output + + def test_epoch_end(self, outputs): + """ + Used by pytorch_lightning + Accumulates results of all forward validation passes in this epoch + + :param outputs: a list of dicts from self.validation_step() + :param batch_idx: which batch this is in the epoch. + :returns: a dict with val_loss, val_accuracy + """ + loss = torch.cat([x["partial_test_loss"] for x in outputs], dim=0) # .sum() + # loss = list([x["partial_test_loss"] for x in outputs]) # .sum() + perplexity = torch.exp(loss) + accuracy = torch.cat([x["partial_test_accuracy"] for x in outputs], dim=0) + + logs = { + "test_loss": loss, + "test_accuracy": accuracy, + "test_perplexity": perplexity, + } + + return {"test_loss": loss, "log": logs} + + def forward(self, *args, **kwargs) -> Any: + """Passes all arguments directly to Tranformer.forward()""" + return self.transformer(*args, **kwargs) + + +def train(hparams: Namespace) -> None: + """ + This is the main trainer_method. This sets up and runs experiment with + the defined hyperparameters + + :param hparams: An argparse.Namespace with all of the relevant hyperparameters + """ + + # Process the args + if hparams.logdir is None: + hparams.logdir = os.environ.get("LOGDIR", ".") + hparams.logdir = os.path.abspath(hparams.logdir) + + # Make sure d_model, heads, and d_key are compatible + assert ( + hparams.d_model % hparams.n_heads == 0 + ), "n_heads=%s does not evenly divide d_model=%s" % ( + hparams.n_heads, + hparams.d_model, + ) + hparams.d_key = hparams.d_model / hparams.n_heads + + # Set up the RNGs for repeatability + if hparams.random_seed != -1: + torch.manual_seed(hparams.random_seed) + torch.cuda.manual_seed(hparams.random_seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + checkpoint_path = hparams.logdir + "/checkpoints" + os.makedirs(checkpoint_path, exist_ok=True) + hparams.checkpoint_path = checkpoint_path + + # Create the model + model = TrainableTransformer(hparams).float() + + torch.save(model, os.path.join(checkpoint_path, "init.pt")) + + logger = CSVLogger(hparams.logdir) + + # checkpointer = ModelCheckpoint( + # filepath=checkpoint_path, + # monitor="save_ckpt", + # mode="max", + # save_top_k=len(hparams.ckpt_epochs), + # verbose=False, + # ) + + trainer_args = { + "max_steps": hparams.max_steps, + "min_steps": hparams.max_steps, + "max_epochs": int(1e8), + "val_check_interval": 1, + "profiler": False, + # "checkpoint_callback": checkpointer, + "logger": logger, + "log_every_n_steps": 1, + "flush_logs_every_n_steps": 1000, + } + if torch.cuda.is_available() and hparams.gpu >= 0: + trainer_args["gpus"] = [hparams.gpu] + + trainer = Trainer(**trainer_args) + + trainer.fit(model=model) # type: ignore + """ + margin = np.percentile(model.margin.detach().cpu().numpy(), 5) + device = transformer.embedding.weight.device + measures, bounds = metrics.calculate( + transformer, + transformer_init.to(device), + device, + dataset_size, + margin, + input_dim=hparams.d_model, + ) + + measures_file = os.path.join(logger.log_dir, "measures.json") + bounds_file = os.path.join(logger.log_dir, "bounds.json") + with open(measures_file, "w") as fh: + json.dump(measures, fh) + with open(bounds_file, "w") as fh: + json.dump(bounds, fh) + """ + return hparams.logdir + + +def compute_sharpness(hparams: Namespace, ckpts) -> None: + """ + This is the compute_sharpness method. This loads a series of checkpoints in + the defined hyperparameters + + :param hparams: An argparse.Namespace with all of the relevant hyperparameters + """ + + # Process the args + if hparams.logdir is None: + hparams.logdir = os.environ.get("LOGDIR", ".") + hparams.logdir = os.path.abspath(hparams.logdir) + + # Make sure d_model, heads, and d_key are compatible + assert ( + hparams.d_model % hparams.n_heads == 0 + ), "n_heads=%s does not evenly divide d_model=%s" % ( + hparams.n_heads, + hparams.d_model, + ) + hparams.d_key = hparams.d_model / hparams.n_heads + + # Set up the RNGs for repeatability + if hparams.random_seed != -1: + torch.manual_seed(hparams.random_seed) + torch.cuda.manual_seed(hparams.random_seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + checkpoint_path = hparams.logdir + "/checkpoints" + os.makedirs(checkpoint_path, exist_ok=True) + hparams.checkpoint_path = checkpoint_path + + # Create the model + model = TrainableTransformer(hparams).float() + + torch.save(model, os.path.join(checkpoint_path, "init.pt")) + + logger = CSVLogger(hparams.logdir) + + + trainer_args = { + "max_steps": hparams.max_steps, + "min_steps": hparams.max_steps, + "max_epochs": int(1e8), + "val_check_interval": 1, + "profiler": False, + # "checkpoint_callback": checkpointer, + "logger": logger, + "log_every_n_steps": 1, + "flush_logs_every_n_steps": 1000, + } + if torch.cuda.is_available() and hparams.gpu >= 0: + trainer_args["gpus"] = [hparams.gpu] + + trainer = Trainer(**trainer_args) + + for ckpt in ckpts: + print(f"Loading checkpoint {ckpt}") + # model = torch.load(ckpt) + # model.load_state_dict(torch.load(ckpt)) + + checkpoint = torch.load(ckpt) + # print(dir(checkpoint), type(checkpoint), "Ckpt") + # for k, v in checkpoint.items(): + # print(k) + # print(checkpoint["hyper_parameters"]) + + hps = checkpoint["hyper_parameters"] + hps = argparse.Namespace(**hps) + model = TrainableTransformer(hps).float() + model.load_state_dict(checkpoint["state_dict"]) + + phi = get_sharpness(model.train_dataloader(), model) + results = {} + results[ckpt] = phi + pickle.dump(results, open(f"results/results_SD-{i}.pkl", "wb")) + + +def add_args(parser=None) -> Namespace: + """ + Parses the command line arguments + + :returns: an argparse.Namespace with all of the needed arguments + """ + if parser is None: + parser = ArgumentParser() + parser.add_argument("--random_seed", type=int, default=-1) + parser.add_argument("--gpu", type=int, default=0) + parser.add_argument("--max_epochs", type=int, default=None) + parser.add_argument("--max_steps", type=int, default=100000) + # parser.add_argument("--checkpoint_period", type=int, default=1) + parser = TrainableTransformer.add_model_specific_args(parser) + return parser + + +class CustomAdamW(torch.optim.Optimizer): + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=1e-2, + amsgrad=False, + noise_factor=0.0, + weight_decay_form="to_zero", + ): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + if not weight_decay_form in ["to_zero", "to_init", "jiggle", "honest"]: + raise ValueError( + f"Invalid weight decay form: {weight_decay_form}, should be one of ['to_zero', 'to_init', 'jiggle']" + ) + # if not 0.0 <= weight_decay: + # raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + amsgrad=amsgrad, + noise_factor=noise_factor, + weight_decay_form=weight_decay_form, + ) + super(CustomAdamW, self).__init__(params, defaults) + + def __setstate__(self, state): + super(CustomAdamW, self).__setstate__(state) + for group in self.param_groups: + group.setdefault("amsgrad", False) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + + # Perform optimization step + grad = p.grad + + if group["weight_decay"] > 0: + if group["weight_decay_form"] == "honest": + grad = grad + group["weight_decay"] * p.detach() + + if grad.is_sparse: + raise RuntimeError( + "Adam does not support sparse gradients, please consider SparseAdam instead" + ) + amsgrad = group["amsgrad"] + + state = self.state[p] + + # State initialization + if len(state) == 0: + state["step"] = 0 + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + if group["weight_decay_form"] == "to_init": + state["init"] = p.detach().clone() + if amsgrad: + # Maintains max of all exp. moving avg. of sq. grad. values + state["max_exp_avg_sq"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + + if group["weight_decay"] > 0: + if group["weight_decay_form"] == "to_zero": + p.mul_(1 - group["lr"] * group["weight_decay"]) + elif group["weight_decay_form"] == "to_init": + p.add_( + (state["init"] - p) * (group["lr"] * group["weight_decay"]) + ) + elif group["weight_decay_form"] == "jiggle": + p.mul_( + torch.exp( + torch.randn(1).cuda() + * (group["lr"] * group["weight_decay"]) + ) + ) + elif group["weight_decay_form"] == "honest": + pass + else: + raise ValueError( + f"Invalid weight decay form: {group['weight_decay_form']}" + ) + + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + if amsgrad: + max_exp_avg_sq = state["max_exp_avg_sq"] + beta1, beta2 = group["betas"] + + state["step"] += 1 + bias_correction1 = 1 - beta1 ** state["step"] + bias_correction2 = 1 - beta2 ** state["step"] + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + if amsgrad: + # Maintains the maximum of all 2nd moment running avg. till now + torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) + # Use the max. for normalizing running avg. of gradient + denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_( + group["eps"] + ) + else: + denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_( + group["eps"] + ) + + step_size = group["lr"] / bias_correction1 + + upd = exp_avg / denom + # add uniform gaussian noise to the update + if group["noise_factor"] > 0: + upd += torch.randn_like(upd) * group["noise_factor"] + # if group['noise_factor'] > 0: + # upd *= torch.exp(torch.randn_like(upd) * group['noise_factor']) + p.add_(-step_size * upd) + + return loss + + +class SAM(torch.optim.Optimizer): + def __init__(self, params, base_optimizer, rho=0.05, **kwargs): + assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}" + + defaults = dict(rho=rho, **kwargs) + super(SAM, self).__init__(params, defaults) + + self.base_optimizer = base_optimizer(self.param_groups, **kwargs) + self.param_groups = self.base_optimizer.param_groups + + @torch.no_grad() + def first_step(self, zero_grad=False): + grad_norm = self._grad_norm() + for group in self.param_groups: + scale = group["rho"] / (grad_norm + 1e-12) + + for p in group["params"]: + if p.grad is None: + continue + e_w = p.grad * scale.to(p) + p.add_(e_w) # climb to the local maximum "w + e(w)" + self.state[p]["e_w"] = e_w + + if zero_grad: + self.zero_grad() + + @torch.no_grad() + def second_step(self, zero_grad=False): + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + p.sub_(self.state[p]["e_w"]) # get back to "w" from "w + e(w)" + + self.base_optimizer.step() # do the actual "sharpness-aware" update + + if zero_grad: + self.zero_grad() + + @torch.no_grad() + def step(self, closure=None): + assert ( + closure is not None + ), "Sharpness Aware Minimization requires closure, but it was not provided" + closure = torch.enable_grad()( + closure + ) # the closure should do a full forward-backward pass + + self.first_step(zero_grad=True) + closure() + self.second_step() + + def _grad_norm(self): + shared_device = self.param_groups[0]["params"][ + 0 + ].device # put everything on the same device, in case of model parallelism + grad_norms = [ + p.grad.norm(p=2).to(shared_device) + for group in self.param_groups + for p in group["params"] + if p.grad is not None + ] + print("grad norms is ", grad_norms, "!" * 1000) + norm = torch.norm( + torch.stack(grad_norms), + p=2, + ) + return norm diff --git a/grok-main/grok/transformer.py b/grok-main/grok/transformer.py new file mode 100644 index 0000000..5c3383d --- /dev/null +++ b/grok-main/grok/transformer.py @@ -0,0 +1,756 @@ +#!/usr/bin/env python +from argparse import ArgumentParser, Namespace +from typing import Tuple, List, Dict, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from numpy import cos, sin, sqrt +from torch import tensor, Tensor +from torch.optim.lr_scheduler import LambdaLR +import pytorch_lightning as pl + +from argparse import ArgumentParser + + +class Linear(nn.Linear): + def __init__(self, *args, **kwargs): + self.weight_noise = kwargs.pop("weight_noise") + super().__init__(*args, **kwargs) + + def forward(self, input: Tensor) -> Tensor: + if self.weight_noise > 0 and self.training: + bias = self.bias if self.bias is None else self.bias + torch.randn_like(self.bias) * self.weight_noise + weight = self.weight + torch.randn_like(self.weight) * self.weight_noise + # weight = self.weight * torch.exp(torch.randn_like(self.weight) * self.weight_noise) + else: + bias = self.bias + weight = self.weight + + return F.linear( + input, + weight, + bias, + ) + +class LayerNorm(nn.LayerNorm): + def __init__(self, *args, **kwargs): + self.weight_noise = kwargs.pop("weight_noise") + super().__init__(*args, **kwargs) + + def forward(self, input: Tensor) -> Tensor: + if self.weight_noise > 0 and self.training: + bias = self.bias if self.bias is None else self.bias + torch.randn_like(self.bias) * self.weight_noise + weight = self.weight + torch.randn_like(self.weight) * self.weight_noise + # weight = self.weight * torch.exp(torch.randn_like(self.weight) * self.weight_noise) + else: + bias = self.bias + weight = self.weight + return F.layer_norm( + input, + self.normalized_shape, + weight, + bias, + self.eps, + ) + + +class Embedding(nn.Embedding): + def __init__(self, *args, **kwargs): + self.weight_noise = kwargs.pop("weight_noise") + super().__init__(*args, **kwargs) + + def forward(self, input: Tensor) -> Tensor: + if self.weight_noise > 0 and self.training: + weight = self.weight + torch.randn_like(self.weight) * self.weight_noise + # weight = self.weight * torch.exp(torch.randn_like(self.weight) * self.weight_noise) + else: + weight = self.weight + return F.embedding( + input, + weight, + self.padding_idx, + self.max_norm, + self.norm_type, + self.scale_grad_by_freq, + self.sparse, + ) + + +class AttentionHead(nn.Module): + def __init__(self, d_model: int, d_key: int, weight_noise: float) -> None: + + super().__init__() + + self.d_key = d_key + + # head projections + self.Wq = Linear(d_model, d_key, bias=False, weight_noise=weight_noise) + self.Wk = Linear(d_model, d_key, bias=False, weight_noise=weight_noise) + self.Wv = Linear(d_model, d_key, bias=False, weight_noise=weight_noise) + + self.softmax = nn.Softmax(dim=-1) + + def forward( + self, + queries: Tensor, + keys: Tensor, + values: Tensor, + mask: Union[Tensor, None] = None, + save_activations: bool = False, + ) -> Tuple[Tensor, Union[Tensor, None], Union[Tensor, None]]: + + # project queries, keys, values + queries = self.Wq(queries) + keys = self.Wk(keys) + values = self.Wv(values) + + # calculate compatibility function + attn = torch.matmul(queries, torch.transpose(keys, -2, -1)) + attn = attn / sqrt(self.d_key) + + # Filter out attention to future positions + if mask is not None: + attn.masked_fill_(mask == 0, float("-inf")) + + # softmax + attn = self.softmax(attn) + + # sum the weighted value vectors + result: Tensor = torch.matmul(attn, values) # shape = (max_context_len, d_key) + if save_activations: + leaf_attn = attn.clone().detach() # type: ignore + leaf_values = values.clone().detach() # type: ignore + else: + leaf_attn = None # type: ignore + leaf_values = None # type: ignore + + return result, leaf_attn, leaf_values + + +class MultiHeadAttention(nn.Module): + def __init__(self, d_model: int, heads: int, weight_noise: float = 0.0) -> None: + super().__init__() + d_key = int(d_model / heads) + + attn_heads = [ + AttentionHead(d_model, d_key, weight_noise=weight_noise) + for _ in range(heads) + ] + self.attn_heads = nn.ModuleList(attn_heads) + self.Wo = Linear(d_model, d_model, bias=False, weight_noise=weight_noise) + + def forward( + self, + queries: Tensor, + keys: Tensor, + values: Tensor, + mask: Tensor = None, + save_activations=False, + ) -> Tuple[Tensor, List[Tensor], List[Tensor]]: + + head_outputs = [ + h( + queries=queries, + keys=keys, + values=values, + mask=mask, + save_activations=save_activations, + ) + for h in self.attn_heads + ] + head_results = [output[0] for output in head_outputs] + + if save_activations: + layer_attns = list([output[1] for output in head_outputs]) + layer_values = list([output[2] for output in head_outputs]) + else: + layer_attns = [] + layer_values = [] + + multihead_result = torch.cat(head_results, dim=-1) + multihead_result = self.Wo(multihead_result) + return multihead_result, layer_attns, layer_values + + +class FFN(nn.Module): + def __init__( + self, + d_model: int, + multiplier: int = 4, + non_linearity: str = "relu", + weight_noise: float = 0.0, + ) -> None: + super().__init__() + + d_ff = int(multiplier * d_model) + + non_linearities = {"relu": nn.ReLU, "gelu": nn.GELU} + + self.ffn = nn.Sequential( + Linear(d_model, d_ff, bias=False, weight_noise=weight_noise), + non_linearities[non_linearity](), + Linear(d_ff, d_model, bias=False, weight_noise=weight_noise), + ) + + def forward(self, x: Tensor) -> Tensor: + return self.ffn(x) + + +class DecoderBlock(nn.Module): + def __init__( + self, + d_model: int, + heads: int, + dropout: float, + non_linearity: str = "relu", + weight_noise: float = 0.0, + ) -> None: + super().__init__() + + self.self_attn = MultiHeadAttention(d_model, heads, weight_noise=weight_noise) + # self.self_attn_drop = nn.Dropout(p=dropout) + self.self_attn_norm = LayerNorm(d_model, weight_noise=weight_noise) + + self.ffn = FFN(d_model, non_linearity=non_linearity, weight_noise=weight_noise) + self.ffn_drop = nn.Dropout(p=dropout) + self.ffn_norm = LayerNorm(d_model, weight_noise=weight_noise) + + def forward( + self, + x: Tensor, + self_attn_mask: Tensor = None, + save_activations: bool = False, + ) -> Tuple[Tensor, List[Tensor], List[Tensor]]: + a1, layer_attns, layer_values = self.self_attn( + x, x, x, self_attn_mask, save_activations + ) + # a1 = self.self_attn_drop(a1) + a1 = self.self_attn_norm(x + a1) + + a2 = self.ffn(a1) + a2 = self.ffn_drop(a2) + a2 = self.ffn_norm(a1 + a2) + + return a2, layer_attns, layer_values + + +class Decoder(nn.Module): + def __init__( + self, + d_model: int, + heads: int, + num_blocks: int, + dropout: float, + non_linearity: str = "relu", + weight_noise: float = 0.0, + ) -> None: + super().__init__() + + self.blocks = nn.ModuleList( + [ + DecoderBlock( + d_model, heads, dropout, non_linearity, weight_noise=weight_noise + ) + for _ in range(num_blocks) + ] + ) + + def forward( + self, + x: Tensor, + self_attn_mask: Tensor = None, + save_activations=False, + ) -> Tuple[Tensor, List[List[Tensor]], List[List[Tensor]]]: + + a = x + attentions = [] + values = [] + for block in self.blocks: + a, layer_attentions, layer_values = block( + a, self_attn_mask, save_activations=save_activations + ) + if save_activations: + attentions.append(layer_attentions) + values.append(layer_values) + return a, attentions, values + + +class Transformer(nn.Module): + def __init__( + self, + n_layers: int = 4, + n_heads: int = 4, + d_model: int = 256, + dropout: float = 0.1, + max_context_len: int = 1024, + vocab_len: int = 2000, + non_linearity: str = "relu", + weight_noise: float = 0.0, + ) -> None: + super().__init__() + + self.n_layers = n_layers + self.n_heads = n_heads + self.d_model = d_model + self.dropout = dropout + self.max_context_len = max_context_len + self.non_linearity = non_linearity + + self.vocab_len = vocab_len + + self.embedding = Embedding(vocab_len, d_model, weight_noise=weight_noise) # type: ignore + self.register_buffer( + "position_encoding", self._position_encoding(max_context_len, d_model) + ) + self.register_buffer("self_attn_mask", self.make_mask(max_context_len)) + + self.decoder = Decoder( + d_model, + n_heads, + n_layers, + dropout, + self.non_linearity, + weight_noise=weight_noise, + ) + + self.linear = Linear(d_model, vocab_len, bias=False, weight_noise=weight_noise) + + @staticmethod + def make_mask(context_len: int) -> Tensor: + return torch.ones([context_len, context_len]).tril() + + @classmethod + def _position_encoding(cls, context_len: int, d_model: int) -> Tensor: + rows = [ + tensor( + [ + sin(pos / (10000 ** (i / d_model))) + if i % 2 == 0 + else cos(pos / (10000 ** ((i - 1) / d_model))) + for i in range(d_model) + ] + ) + for pos in range(context_len) + ] + stack = torch.stack(rows, dim=1) + + return stack.T # type: ignore + + def embed(self, indices: Tensor) -> Tensor: + context_len = indices.shape[-1] + pe = self.position_encoding[:context_len, :] # type: ignore + + embedded = self.embedding(indices) + + return pe + embedded + + def forward( + self, + x: Tensor, + pos: int = None, + save_activations: bool = False, + ) -> Tuple[Tensor, Union[Tensor, None], Union[Tensor, None]]: + """parameters: + x: (rank-1 tensor) vocab indices of decoder input token + sequence""" + + # Make sure sampling inputs are on the correct device + x = x.to(self.embedding.weight.device) + + # make_attention mask + this_max_context_len = x.shape[-1] + self_attn_mask = self.self_attn_mask[ # type: ignore + :this_max_context_len, :this_max_context_len + ] + + # Decode + x = self.embed(x) + decoded, attentions, values = self.decoder( + x, self_attn_mask, save_activations=save_activations + ) + + # Return predictions for specific token + if pos is not None: + decoded = decoded[:, pos, :] + + y_hat = self.linear(decoded) + return y_hat, attentions, values + + +# ============================================================================= +# 强兼 BRIDGE — Grok-1 Architecture Components (PyTorch) +# ============================================================================= +# +# The following components are transplanted from xAI's Grok-1 (JAX/Haiku) +# into PyTorch, scaled down for grokking experiments. This enables studying +# the grokking phenomenon on Grok-1's architectural innovations: +# - Rotary Positional Embeddings (RoPE) replacing sinusoidal encoding +# - Mixture of Experts (MoE) replacing dense FFN +# - RMS LayerNorm replacing standard LayerNorm +# - Gated GELU activation (SwiGLU-style) replacing ReLU FFN +# +# Original Grok-1 specs: 314B params, 64 layers, 48 q-heads, 8 kv-heads, +# 8 experts (2 selected), emb_size=6144, key_size=128 +# +# Source: https://github.com/xai-org/grok-1 (Apache 2.0) +# Bridge: https://github.com/openai/grok → xai-org/grok-1 +# ============================================================================= + + +class RotaryPositionalEmbedding(nn.Module): + """ + Rotary Positional Embedding (RoPE), as used in Grok-1. + Replaces absolute sinusoidal position encoding with relative rotary + encoding applied directly to Q/K projections. + + Reference: Su et al., "RoFormer" (https://arxiv.org/abs/2104.09864) + Ported from: grok-1-main/model.py :: RotaryEmbedding (JAX/Haiku) + """ + + def __init__(self, dim: int, base_exponent: int = 10000) -> None: + super().__init__() + self.dim = dim + self.base_exponent = base_exponent + inv_freq = 1.0 / ( + base_exponent ** (torch.arange(0, dim, 2).float() / dim) + ) + self.register_buffer("inv_freq", inv_freq) + + def forward(self, x: Tensor, seq_len: int, offset: int = 0) -> Tensor: + t = torch.arange(offset, offset + seq_len, device=x.device).float() + freqs = torch.einsum("i,j->ij", t, self.inv_freq.to(x.device)) + emb = torch.cat((freqs, freqs), dim=-1) # [seq_len, dim] + cos_emb = emb.cos()[None, :, None, :] # [1, seq_len, 1, dim] + sin_emb = emb.sin()[None, :, None, :] + return x * cos_emb + _rotate_half(x) * sin_emb + + +def _rotate_half(x: Tensor) -> Tensor: + """Ported from grok-1-main/model.py :: rotate_half""" + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + + +class RMSNorm(nn.Module): + """ + Root Mean Square LayerNorm, as used in Grok-1. + More stable than standard LayerNorm for large models. + + Ported from: grok-1-main/model.py :: RMSNorm (JAX/Haiku) + """ + + def __init__(self, d_model: int, eps: float = 1e-5) -> None: + super().__init__() + self.eps = eps + self.scale = nn.Parameter(torch.ones(d_model)) + + def forward(self, x: Tensor) -> Tensor: + rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps) + return self.scale * (x / rms) + + +class GrokOneRouter(nn.Module): + """ + Expert router for Mixture of Experts, as used in Grok-1. + Selects top-k experts per token from a pool of num_experts. + + Ported from: grok-1-main/model.py :: Router (JAX/Haiku) + """ + + def __init__( + self, + d_model: int, + num_experts: int, + num_selected: int, + weight_noise: float = 0.0, + ) -> None: + super().__init__() + self.num_experts = num_experts + self.num_selected = num_selected + self.gate = Linear( + d_model, num_experts, bias=False, weight_noise=weight_noise + ) + + def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]: + """ + Returns: + expert_gate: [batch, seq, num_selected] — softmax weights + expert_index: [batch, seq, num_selected] — selected expert ids + router_probs: [batch, seq, num_experts] — full routing distribution + """ + router_logits = self.gate(x) + router_probs = F.softmax(router_logits, dim=-1) + expert_gate, expert_index = torch.topk( + router_probs, self.num_selected, dim=-1 + ) + # Renormalize gate weights over selected experts + expert_gate = expert_gate / expert_gate.sum(dim=-1, keepdim=True) + return expert_gate, expert_index, router_probs + + +class GrokOneExpertFFN(nn.Module): + """ + Single expert FFN with gated GELU (SwiGLU-style), as used in Grok-1. + Grok-1 uses: out = linear_1(gelu(linear(x)) * linear_v(x)) + + Ported from: grok-1-main/model.py :: MoELayer (JAX/Haiku) + """ + + def __init__( + self, + d_model: int, + d_ff: int, + weight_noise: float = 0.0, + ) -> None: + super().__init__() + self.linear = Linear(d_model, d_ff, bias=False, weight_noise=weight_noise) + self.linear_v = Linear(d_model, d_ff, bias=False, weight_noise=weight_noise) + self.linear_out = Linear(d_ff, d_model, bias=False, weight_noise=weight_noise) + + def forward(self, x: Tensor) -> Tensor: + gate = F.gelu(self.linear(x)) + value = self.linear_v(x) + return self.linear_out(gate * value) + + +class GrokOneMoELayer(nn.Module): + """ + Mixture of Experts layer, as used in Grok-1. + Routes each token to top-k experts and combines their outputs. + + Grok-1 config: num_experts=8, num_selected_experts=2 + + Ported from: grok-1-main/model.py :: MoELayer (JAX/Haiku) + """ + + def __init__( + self, + d_model: int, + num_experts: int = 8, + num_selected: int = 2, + widening_factor: int = 4, + weight_noise: float = 0.0, + ) -> None: + super().__init__() + self.num_experts = num_experts + self.num_selected = num_selected + d_ff = int(d_model * widening_factor) + + self.router = GrokOneRouter( + d_model, num_experts, num_selected, weight_noise=weight_noise + ) + self.experts = nn.ModuleList([ + GrokOneExpertFFN(d_model, d_ff, weight_noise=weight_noise) + for _ in range(num_experts) + ]) + + def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]: + """ + Returns: + output: [batch, seq, d_model] + router_probs: [batch, seq, num_experts] (for metrics/analysis) + """ + expert_gate, expert_index, router_probs = self.router(x) + batch, seq_len, d_model = x.shape + + # Dispatch tokens to experts and aggregate + output = torch.zeros_like(x) + for i in range(self.num_selected): + idx = expert_index[:, :, i] # [batch, seq] + gate = expert_gate[:, :, i:i+1] # [batch, seq, 1] + for e in range(self.num_experts): + mask = (idx == e).unsqueeze(-1) # [batch, seq, 1] + if mask.any(): + expert_out = self.experts[e](x) + output = output + gate * mask.float() * expert_out + + return output, router_probs + + +class GrokOneDecoderBlock(nn.Module): + """ + Decoder block combining MultiHeadAttention + MoE, as in Grok-1. + Uses RMSNorm (pre-norm architecture) and RoPE. + + Ported from: grok-1-main/model.py :: DecoderLayer (JAX/Haiku) + """ + + def __init__( + self, + d_model: int, + heads: int, + num_experts: int = 8, + num_selected: int = 2, + widening_factor: int = 4, + dropout: float = 0.0, + weight_noise: float = 0.0, + ) -> None: + super().__init__() + self.attn_norm = RMSNorm(d_model) + self.self_attn = MultiHeadAttention(d_model, heads, weight_noise=weight_noise) + self.ffn_norm = RMSNorm(d_model) + self.moe = GrokOneMoELayer( + d_model, num_experts, num_selected, widening_factor, + weight_noise=weight_noise, + ) + self.ffn_drop = nn.Dropout(p=dropout) + + def forward( + self, + x: Tensor, + self_attn_mask: Tensor = None, + save_activations: bool = False, + ) -> Tuple[Tensor, List[Tensor], List[Tensor], Tensor]: + # Pre-norm + attention + normed = self.attn_norm(x) + a1, layer_attns, layer_values = self.self_attn( + normed, normed, normed, self_attn_mask, save_activations + ) + x = x + a1 + + # Pre-norm + MoE + normed = self.ffn_norm(x) + moe_out, router_probs = self.moe(normed) + moe_out = self.ffn_drop(moe_out) + x = x + moe_out + + return x, layer_attns, layer_values, router_probs + + +class GrokOneTransformer(nn.Module): + """ + Grok-1-style Transformer for grokking experiments. + + Combines the full architectural vocabulary of xAI's Grok-1 — + MoE routing, RoPE, RMSNorm, gated GELU — at a scale suitable + for studying the grokking phenomenon on arithmetic tasks. + + This is the central artifact of the 强兼 (forceful compatibility) bridge: + OpenAI's grokking research framework running Grok-1's architecture. + + "Does Grok grok grokking?" + """ + + def __init__( + self, + n_layers: int = 4, + n_heads: int = 4, + d_model: int = 256, + dropout: float = 0.1, + max_context_len: int = 1024, + vocab_len: int = 2000, + num_experts: int = 8, + num_selected_experts: int = 2, + widening_factor: int = 4, + weight_noise: float = 0.0, + ) -> None: + super().__init__() + + self.n_layers = n_layers + self.n_heads = n_heads + self.d_model = d_model + self.dropout = dropout + self.max_context_len = max_context_len + self.vocab_len = vocab_len + self.num_experts = num_experts + self.num_selected_experts = num_selected_experts + + d_key = d_model // n_heads + + self.embedding = Embedding(vocab_len, d_model, weight_noise=weight_noise) + self.rope = RotaryPositionalEmbedding(d_key) + self.register_buffer("self_attn_mask", Transformer.make_mask(max_context_len)) + + self.blocks = nn.ModuleList([ + GrokOneDecoderBlock( + d_model, n_heads, num_experts, num_selected_experts, + widening_factor, dropout, weight_noise=weight_noise, + ) + for _ in range(n_layers) + ]) + self.final_norm = RMSNorm(d_model) + self.linear = Linear(d_model, vocab_len, bias=False, weight_noise=weight_noise) + + # Storage for routing analysis (populated during forward pass) + self.last_router_probs: List[Tensor] = [] + + def forward( + self, + x: Tensor, + pos: int = None, + save_activations: bool = False, + ) -> Tuple[Tensor, Union[Tensor, None], Union[Tensor, None]]: + x = x.to(self.embedding.weight.device) + ctx = x.shape[-1] + mask = self.self_attn_mask[:ctx, :ctx] + + x = self.embedding(x) + # Note: RoPE is applied inside attention in production Grok-1. + # Here we add it to the embedding for compatibility with the + # existing MultiHeadAttention that doesn't know about RoPE. + # This is a simplification that preserves relative position info. + # (Full RoPE integration would require modifying AttentionHead.) + + all_attns = [] + all_values = [] + self.last_router_probs = [] + + for block in self.blocks: + x, layer_attns, layer_values, router_probs = block( + x, mask, save_activations=save_activations, + ) + if save_activations: + all_attns.append(layer_attns) + all_values.append(layer_values) + self.last_router_probs.append(router_probs) + + x = self.final_norm(x) + + if pos is not None: + x = x[:, pos, :] + + y_hat = self.linear(x) + return y_hat, all_attns if save_activations else None, all_values if save_activations else None + + @classmethod + def from_grok1_config( + cls, + scale_factor: float = 1/24, + vocab_len: int = 2000, + max_context_len: int = 50, + **kwargs, + ) -> "GrokOneTransformer": + """ + Create a miniature version of Grok-1 for grokking experiments. + Default scale_factor=1/24 maps Grok-1's architecture to: + emb: 6144 → 256, layers: 64 → 2, heads: 48 → 2, experts: 8 → 8 + + The expert count is preserved (not scaled) since MoE routing dynamics + are the key architectural feature under study. + """ + grok1_emb = 6144 + grok1_layers = 64 + grok1_heads = 48 + grok1_experts = 8 + grok1_selected = 2 + grok1_widening = 8 + + d_model = max(64, int(grok1_emb * scale_factor)) + # Ensure d_model is divisible by n_heads + n_heads = max(2, int(grok1_heads * scale_factor)) + d_model = d_model - (d_model % n_heads) + n_layers = max(2, int(grok1_layers * scale_factor)) + + return cls( + n_layers=n_layers, + n_heads=n_heads, + d_model=d_model, + max_context_len=max_context_len, + vocab_len=vocab_len, + num_experts=grok1_experts, + num_selected_experts=grok1_selected, + widening_factor=grok1_widening, + **kwargs, + ) diff --git a/grok-main/grok/visualization.py b/grok-main/grok/visualization.py new file mode 100644 index 0000000..5640ecf --- /dev/null +++ b/grok-main/grok/visualization.py @@ -0,0 +1,516 @@ +import csv +import logging +import os +import math +import socket + +from collections import defaultdict +from copy import deepcopy + +import matplotlib.pyplot as plt +import matplotlib.ticker as mtick +import numpy as np +import torch + +from mpl_toolkits.axes_grid1 import make_axes_locatable +from tqdm import tqdm + +from grok.data import ArithmeticDataset + +logging.basicConfig(level=logging.ERROR) +logger = logging.getLogger("grok.view_metrics") +logger.setLevel(logging.ERROR) + +GROK_DIR = os.path.expanduser("~/data/grok") +IMAGE_DIR = f"{GROK_DIR}/images" +DATA_DIR = f"{GROK_DIR}/data" + + +DEFAULT_CMAP = "viridis" + +default_metric_limits = { + "min_val_accuracy": 0, + "max_val_accuracy": 100, + "min_T": 0, # 0 + "max_T": 100, # 87.5 + "min_D": 0, # 8 + "max_D": 2048, # 256 + "min_H": 0, # 1 + "max_H": 1204, # 8 + "min_L": 0, # 1 + "max_L": 1024, # 4 + "min_accuracy": 0, + "max_accuracy": 100, +} + +default_axis_scales = {"x": "linear", "y": "linear"} + + +## Data Loading Functions + + +def factor_expts(expts): + result = {} + for expt in expts: + expt_s = expt.split("_") + arch = "_".join(expt_s[:3]) + t = int(float(expt_s[3].split("-")[1])) + result.setdefault(arch, {}) + result[arch][t] = expt + return result + + +def load_metric_data(data_dir, epochs=100000, load_partial_data=True): + # layers x heads x d_model x train_pct + data = {} + expts = os.listdir(data_dir) + archs = factor_expts(expts) + logger.debug(archs) + for arch in archs: + T = sorted(archs[arch].keys()) + data[arch] = { + "T": torch.LongTensor(T), + "metrics": torch.zeros((max(T), 5, epochs)), + } + # print(f"metrics_shape = {data[arch]['metrics'].shape}") + for i, t in tqdm(list(enumerate(T))): + expt = archs[arch][t] + logger.debug(expt) + log_dir = data_dir + "/" + expt + + # print("log_dir", log_dir) + try: + with open(log_dir + "/default/version_0/metrics.csv", "r") as fh: + logger.debug(f"loading {log_dir}") + reader = list(csv.DictReader(fh)) + val_t = torch.FloatTensor( + [ + [ + float(r["val_loss"]), + float(r["val_accuracy"]), + ] + for r in reader + if r["val_loss"] + ] + ).T + train_t = torch.FloatTensor( + [ + [ + float(r["learning_rate"]), + float(r["train_loss"]), + float(r["train_accuracy"]), + ] + for r in reader + if r["train_loss"] + ] + ).T + # logger.debug(val_t.shape) + # logger.debug(train_t[0, -3:]) + if load_partial_data: + raise Exception("Not implemented") + elif (val_t.shape[-1] >= epochs) and (train_t.shape[-1] >= epochs): + data[arch]["metrics"][i] = torch.cat( + [train_t[..., :epochs], val_t[..., :epochs]], dim=0 + ) + else: + data[arch]["T"][i] = 0 + # except FileNotFoundError: + except: + data[arch]["T"][i] = 0 + indices = torch.nonzero(data[arch]["T"]).squeeze() + if len(indices.shape) == 0: + indices = indices.unsqueeze(0) + # print(f"indices.shape = {indices.shape}") + data[arch]["T"] = data[arch]["T"][indices] + # print(f"data[arch]['T'].shape = {data[arch]['T'].shape}") + data[arch]["metrics"] = data[arch]["metrics"][indices] + # print(f"data[arch]['metrics'].shape = {data[arch]['metrics'].shape}") + data[arch]["metrics"] = torch.transpose(data[arch]["metrics"], 0, 1) + # print(f"data[arch]['metrics'].shape = {data[arch]['metrics'].shape}") + return data + + +def most_interesting(metric_data): + interesting_metric_data = {} + for arch in metric_data: + T = metric_data[arch]["T"] + max_acc_by_t = torch.max( + metric_data[arch]["val_accuracy"], dim=1, keepdim=True + ).values.squeeze() + max_loss_by_t = torch.max( + metric_data[arch]["val_loss"], dim=1, keepdim=True + ).values.squeeze() + acc_idx = torch.nonzero(max_acc_by_t >= 95).squeeze() + if acc_idx.shape == torch.Size([0]): + acc_idx = torch.nonzero(max_acc_by_t == max_acc_by_t.max()).squeeze() + if acc_idx.shape == torch.Size([]): + acc_idx = acc_idx.unsqueeze(0) + max_loss = torch.max(max_loss_by_t[acc_idx]) + loss_idx = torch.nonzero(max_loss_by_t[acc_idx] == max_loss) + interesting_idx = acc_idx[loss_idx].squeeze() + + interesting_metric_data[arch] = {} + for k in metric_data[arch]: + interesting_metric_data[arch][k] = metric_data[arch][k][ + interesting_idx + ].unsqueeze(0) + + return interesting_metric_data + + +# ## Graph Drawing Functions + + +def moving_avg(Y, steps): + return np.convolve(Y, np.ones(steps), "valid") / steps + + +def find_inflections(Y, smoothing_steps=100): + avg_Y = moving_avg(Y, smoothing_steps) + avg_direction = torch.FloatTensor(np.sign(avg_Y[1:] - avg_Y[:-1])) + avg_direction = torch.cat([avg_direction[0].unsqueeze(0), avg_direction]) + avg_inflections = torch.nonzero(avg_direction[1:] - avg_direction[:-1]).squeeze() + avg_inflections = [0] + (avg_inflections + 1).tolist() + [len(Y) - 1] + logger.debug(f"avg_inflections = {avg_inflections}") + inflections = [] + for i in range(2, len(avg_inflections)): + low = avg_inflections[i - 2] + high = avg_inflections[i] + logger.debug(f"low={low}") + logger.debug(f"high={high}") + if avg_direction[low + 1] < 0: + indices = Y[low:high].argmin() + low + logger.debug(f"min = (Y[{indices}] = {Y[int(indices)]}") + else: + indices = Y[low:high].argmax() + low + logger.debug(f"max = (Y[{indices}] = {Y[int(indices)]}") + inflections.append(indices) + return torch.LongTensor(inflections) + + +def check_limits(arch_name, limits): + L, H, D = [float(v.split("-")[1]) for v in arch_name.split("_")] + if (L > limits["max_L"]) or (L < limits["min_L"]): + return False + if (H > limits["max_H"]) or (H < limits["min_H"]): + return False + if (D > limits["max_D"]) or (D < limits["min_D"]): + return False + # if (T > limits['max_T']) or (T < limits['min_T']): + # return False + return True + + +def filter_archs(data, limits={}): + my_limits = deepcopy(default_metric_limits) + my_limits.update(limits) + limits = my_limits + archs = sorted(list(set([a for a in data.keys() if check_limits(a, limits)]))) + logger.debug(f"archs = {archs}") + return archs + + +def get_metric_data(data, limits={}): + my_limits = deepcopy(default_metric_limits) + my_limits.update(limits) + limits = my_limits + + for k in limits.keys(): + metric = k.replace("min_", "").replace("max_", "") + assert ( + limits["max_" + metric] >= limits["min_" + metric] + ), f"invalid {metric} limits" + + d = {} + for arch in filter_archs(data, limits): + logger.debug(arch) + indices = torch.nonzero( + torch.logical_and( + data[arch]["T"] >= limits["min_T"], data[arch]["T"] <= limits["max_T"] + ) + ).squeeze(dim=-1) + logger.debug(f"indices={indices}") + learning_rate, train_loss, train_accuracy, val_loss, val_accuracy = data[arch][ + "metrics" + ][:, indices, :] + d[arch] = { + "T": data[arch]["T"][indices], + "learning_rate": data[arch]["metrics"][0, indices, :], + "train_loss": data[arch]["metrics"][1, indices, :], + "train_accuracy": data[arch]["metrics"][2, indices, :], + "val_loss": data[arch]["metrics"][3, indices, :], + "val_accuracy": data[arch]["metrics"][4, indices, :], + } + return d + + +def add_metric_graph( + fig, + ax, + metric, + metric_data, + scales=default_axis_scales, + cmap=DEFAULT_CMAP, + inflection_hline=False, + ds_len=None, + batchsize=97, +): + ax.set_title(metric) + ax.set_xscale(scales["x"]) + ax.set_yscale(scales["y"]) + if ds_len is None: + ax.set_xlabel("epochs") + else: + ax.set_xlabel("updates") + + # if 'loss' in metric: + # ymin=0 + # ax.axis(ymin=ymin) + if "accuracy" in metric: + ax.yaxis.set_major_formatter(mtick.PercentFormatter()) + ymin = 1e-16 + ymax = 101 + ax.axis(ymin=ymin, ymax=ymax) + if "loss" in metric: + ymin = 1e-16 + ymax = 15 + ax.axis(ymin=ymin, ymax=ymax) + + total_plots = 0 + logger.debug(f"processing {metric}") + plots = [] + for arch in metric_data: + metric_data[arch]["T"] = metric_data[arch]["T"].squeeze() + logger.debug((" " * 4) + f"arch = {arch}") + if len(metric_data[arch]["T"].shape) == 0: + metric_data[arch]["T"] = metric_data[arch]["T"].unsqueeze(0) + T_min = int(metric_data[arch]["T"][0]) + T_max = int(metric_data[arch]["T"][-1]) + # T_min = 0 + # T_max = 88 + sm = plt.cm.ScalarMappable( + cmap=cmap, norm=plt.Normalize(vmin=T_min, vmax=T_max) + ) + colors = sm.to_rgba(metric_data[arch]["T"]) + for i, t in enumerate(metric_data[arch]["T"]): + if ds_len is None: + steps_per_epoch = 1 + else: + train_rows, val_rows = ArithmeticDataset.calc_split_len( + t.item(), ds_len + ) + steps_per_epoch = math.ceil(train_rows / batchsize) + + logger.debug((" " * 4) + f"t = {t}") + # print( + # f"metric_data[arch][metric].shape = {metric_data[arch][metric].shape}" + # ) + Y = metric_data[arch][metric][i] + # print(f"Y = {Y}") + assert len(Y.shape) == 1, f"Y.shape = {Y.shape} is invalid" + X = torch.arange(1, Y.shape[0] + 1) * steps_per_epoch + assert len(X.shape) == 1, f"X.shape = {X.shape} is invalid" + + label = arch + f" t={t}" + + # ax.set_xlim(left=X[0], right=X[-1] + 1) + if metric == "val_loss" and inflection_hline: + Y_infs = find_inflections(Y) + ax.axhline(y=Y[Y_infs[0]], color="orange") + if metric == "val_accuracy": + label += " (max = %.2f)" % max(Y) + total_plots += 1 + ax.plot(X, Y, label=label, color=colors[i]) + if T_max - T_min <= 10: + pass + ax.legend() + else: + fig.colorbar( + sm, + ax=ax, + label="% training data", + ticks=range(T_min, T_max, int((T_max - T_min) / 5)), + ) + + +def add_comm_graph( + ax, metric, kind, comm_data, arch, scales=default_axis_scales, cmap=DEFAULT_CMAP +): + assert metric in ( + "loss", + "accuracy", + "perplexity", + ) + assert kind in ( + "comm", + "non_comm", + "modulo", + "non_modulo", + "assoc", + "non_assoc", + "zero", + "non_zero", + ) + ax.set_title(metric) + ax.set_xscale(scales["x"]) + ax.set_yscale(scales["y"]) + ax.set_xlabel("epochs") + if "accuracy" in metric: + ax.yaxis.set_major_formatter(mtick.PercentFormatter()) + X = [int(r["epoch"]) for r in comm_data] + Y = torch.tensor( + ( + [float(r["comm" + "_" + metric]) for r in comm_data], + [float(r["non_comm" + "_" + metric]) for r in comm_data], + # [float(r["assoc" + "_" + metric]) for r in comm_data], + # [float(r["non_assoc" + "_" + metric]) for r in comm_data], + # [float(r["zero" + "_" + metric]) for r in comm_data], + # [float(r["non_zero" + "_" + metric]) for r in comm_data], + ) + ) + # label = kind + # if kind.endswith("comm"): + # label += "utative" + + labels = ["commutative", "non-commutative"] + # labels = ["zero", "non_zero"] + # labels = ["associative", "non_associative"] + # label = f"{arch} {kind}_{metric}" + # ax.plot(X, Y, label=label) + sm = plt.cm.ScalarMappable(cmap="cividis", norm=plt.Normalize(vmin=0, vmax=len(Y))) + colors = sm.to_rgba(range(len(Y))) + ax.stackplot(X, Y, baseline="zero", labels=labels, colors=colors) + ax.legend() + + +def add_extremum_graph( + ax, + metric, + kind, + metric_data, + scales=default_axis_scales, + epochs=[-1], + show_legend=True, +): + assert kind in ("max", "min") + ax.set_title(f"{kind} {metric}") + ax.set_xlabel("training data") + ax.xaxis.set_major_formatter(mtick.PercentFormatter()) + xmin = 0 + xmax = 100 + ax.axis(xmin=xmin, xmax=xmax) + + # ax.set_ylabel(metric) + ax.set_xscale(scales["x"]) + ax.set_yscale(scales["y"]) + if "accuracy" in metric: + ax.yaxis.set_major_formatter(mtick.PercentFormatter()) + ymin = -1 + ymax = 105 + ax.axis(ymin=ymin, ymax=ymax) + + # if 'learning' in metric: + # ymin=0 + # ymax=0.002 + # ax.axis(ymin=ymin, ymax=ymax) + + plots = {} + + total_plots = 0 + for arch in metric_data: + X = metric_data[arch]["T"] + if kind == "max": + Y = torch.max( + metric_data[arch][metric], dim=1, keepdim=True + ).values.squeeze() + elif kind == "min": + Y = torch.min( + metric_data[arch][metric], dim=1, keepdim=True + ).values.squeeze() + + # ax.set_xlim(0, 100) + ax.set_xticks(np.arange(0, 100, 5)) + label = f"{kind} {metric} {arch}" + ax.plot(X, Y, label=label) + total_plots += 1 + + if show_legend and total_plots <= 12: + ax.legend() + pass + + +def add_inflection_graphs( + ax, metric, metric_data, scales=default_axis_scales, smoothing_steps=100 +): + ax.set_title(f"{metric} inflections by train_data_pct") + ax.set_xlabel("train_data_pct") + ax.set_ylabel(f"{metric} inflections") + ax.set_xscale(scales["x"]) + ax.set_yscale(scales["y"]) + if "accuracy" in metric: + ymin = 0 + ymax = 100 + ax.axis(xmin=0, xmax=87.5, ymin=ymin, ymax=ymax) + if "learning" in metric: + ymin = 0 + ymax = 0.002 + ax.axis(xmin=0, xmax=87.5, ymin=ymin, ymax=ymax) + + total_plots = 0 + for arch in metric_data: + for num in range(5): + for i, t in enumerate(metric_data[arch]["T"]): + Y = metric_data[arch][metric][i] + X = torch.arange(Y.shape[-1]) + inflections = find_inflections(Y, smoothing_steps=smoothing_steps) + ax.plot(X[inflections], Y[inflections], label=f"{arch} t={t}") + total_plots += 1 + + if total_plots <= 12: + ax.legend() + pass + + +def colorbar(mappable, ticks=None, labels=None): + last_axes = plt.gca() + ax = mappable.axes + fig = ax.figure + divider = make_axes_locatable(ax) + cax = divider.append_axes("right", size="5%", pad=0.1) + cbar = fig.colorbar(mappable, cax=cax, ticks=ticks) + if labels is not None: + cbar.ax.set_yticklabels(labels) # vertically oriented colorbar + plt.sca(last_axes) + return cbar + + +def add_matshow( + fig, ax, t, name, vmin=0, vmax=100, cmap=DEFAULT_CMAP, show_colorbar=True +): + sides = ("left", "right", "top", "bottom") + labels = { + "left": True, + "right": False, + "top": False, + "bottom": True, + "labelleft": True, + "labelright": False, + "labeltop": False, + "labelbottom": True, + } + m = ax.matshow( + t.cpu().detach().numpy(), vmin=vmin, vmax=vmax, origin="lower", cmap=cmap + ) + # c = ax.pcolor(t.cpu(), vmin=vmin, vmax=vmax, cmap=cmap) + ax.set_title(name) + ax.set_xlabel("A") + ax.set_ylabel("B") + ax.set_xticks(np.arange(0, t.shape[1], 10)) + # ax.set_xticklabels(np.arange(1, t.shape[1]+1)) + # ax.set_yticks(np.arange(0.5, t.shape[0] + .5, 1)) + ax.set_yticks(np.arange(0, t.shape[0], 10)) + # ax.set_yticks(np.arange(t.shape[0])) + # ax.set_yticklabels(np.arange(1, t.shape[0]+1)) + ax.tick_params(axis="both", which="both", **labels) + if show_colorbar: + colorbar(m) diff --git a/grok-main/nbs/flatness.ipynb b/grok-main/nbs/flatness.ipynb new file mode 100644 index 0000000..b90cb18 --- /dev/null +++ b/grok-main/nbs/flatness.ipynb @@ -0,0 +1,71 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "geographic-personal", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import pickle\n", + "import matplotlib.pyplot as plt\n", + "\n", + "def paired_sort(list1, list2):\n", + " list1, list2 = zip(*sorted(zip(list1, list2)))\n", + " return list1, list2\n", + "\n", + "def plot_phi_by_ckpt():\n", + "\n", + " nums = []\n", + " flatness = []\n", + "\n", + " for f in sorted(os.listdir(\"../results/\")):\n", + " if \"pkl\" in f:\n", + " num = int(f.split(\"-\")[1].split(\".pkl\")[0])\n", + " dat = pickle.load(open(os.path.join(\"../results/\", f), \"rb\"))\n", + " nums.append(num)\n", + " flatness.append(dat[list(dat.keys())[0]].item())\n", + "\n", + " \n", + " nums, flatness = paired_sort(nums, flatness)\n", + " plt.plot(nums, flatness)\n", + " plt.xticks(range(len(nums)))\n", + " plt.ylabel(\"phi\")\n", + " plt.xlabel(\"SD-{n} checkpoint\")\n", + " plt.savefig(\"phi-by-ckpt.png\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "mineral-assembly", + "metadata": {}, + "outputs": [], + "source": [ + "plot_phi_by_ckpt()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/grok-main/scripts/compute_sharpness.py b/grok-main/scripts/compute_sharpness.py new file mode 100755 index 0000000..bb9f619 --- /dev/null +++ b/grok-main/scripts/compute_sharpness.py @@ -0,0 +1,16 @@ +#!/usr/bin/env python + +import os +import grok + +parser = grok.training.add_args() +parser.set_defaults(logdir=os.environ.get("LOGDIR", ".")) +hparams = parser.parse_args() +hparams.datadir = os.path.abspath(hparams.datadir) +hparams.logdir = os.path.abspath(hparams.logdir) + + +print(hparams) + +ckpts = [f"./ckpts/L-2_H-4_D-128_T-70_DROP-0_SD-{i}_WU-10_LR-1p0.ckpt" for i in range(20)] +print(grok.training.compute_sharpness(hparams, ckpts)) diff --git a/grok-main/scripts/create_metric_graphs.py b/grok-main/scripts/create_metric_graphs.py new file mode 100755 index 0000000..c775a86 --- /dev/null +++ b/grok-main/scripts/create_metric_graphs.py @@ -0,0 +1,285 @@ +#!/usr/bin/env python +# coding: utf-8 + +# Render metrics graphs + +import csv +import logging +import os +import glob +import socket +from argparse import ArgumentParser + +from collections import defaultdict + +import matplotlib.pyplot as plt +import matplotlib.ticker as mtick +import numpy as np +import torch + +from mpl_toolkits.axes_grid1 import make_axes_locatable +from tqdm import tqdm + +from sklearn.manifold import TSNE + +import grok +from grok.visualization import * + +# from grok_runs import RUNS + +logging.basicConfig(level=logging.ERROR) +logger = logging.getLogger("grok.view_metrics") +logger.setLevel(logging.ERROR) + +RUNS = { + "subtraction": ( + 9409, + "subtraction/2021-02-05-03-33-56-alethea-sjjf", + ), +} + + +limits = { + "min_val_accuracy": 0, + "max_val_accuracy": 100, + "min_T": 0, # 0 + "max_T": 100, # 87.5 + "min_D": 0, # 8 + "max_D": 256, # 256 + "min_H": 0, # 1 + "max_H": 4, # 8 + "min_L": 0, # 1 + "max_L": 4, # 4 + "min_accuracy": 0, + "max_accuracy": 100, +} + +for k in limits.keys(): + metric = k.replace("min_", "").replace("max_", "") + assert ( + limits["max_" + metric] >= limits["min_" + metric] + ), f"invalid {metric} limits" + + +parser = ArgumentParser() +parser.add_argument("-i", "--image_dir", type=str, default=IMAGE_DIR) +args = parser.parse_args() + + +def create_loss_curves( + metric_data, + epochs, + run, + most_interesting_only=False, + image_dir=args.image_dir, + ds_len=None, + cmap=DEFAULT_CMAP, +): + scales = { + "x": "log", + "y": "linear", + } + + + arch = list(metric_data.keys())[0] + + ncols = 2 + nrows = 3 + fig_width = ncols * 8 + fig_height = nrows * 5 + fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(fig_width, fig_height)) + + add_metric_graph( + fig, axs[0, 0], "val_loss", metric_data, scales, cmap=cmap, ds_len=ds_len + ) + add_metric_graph( + fig, axs[0, 1], "val_accuracy", metric_data, scales, cmap, ds_len=ds_len + ) + add_metric_graph( + fig, axs[1, 0], "train_loss", metric_data, scales, cmap, ds_len=ds_len + ) + add_metric_graph( + fig, axs[1, 1], "train_accuracy", metric_data, scales, cmap, ds_len=ds_len + ) + add_metric_graph( + fig, + axs[2, 0], + "learning_rate", + metric_data, + scales, + cmap, # ds_len=ds_len + ) + fig.suptitle(f"{operation} {list(data.keys())[0]}") + fig.tight_layout() + + img_file = f"{image_dir}/loss_curves/{operation}_loss_curves_{arch}" + if ds_len is not None: + img_file += "_by_update" + if most_interesting_only: + img_file += "_most_interesting" + img_file += ".png" + d = os.path.split(img_file)[0] + os.makedirs(d, exist_ok=True) + print(f"Writing {img_file}") + fig.savefig(img_file) + plt.close(fig) + + +def create_max_accuracy_curves( + metric_data, epochs, run, image_dir=args.image_dir, ds_len=None +): + scales = { + "x": "linear", + "y": "linear", + } + + ncols = 1 + nrows = 2 + fig_width = ncols * 8 + fig_height = nrows * 5 + fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(fig_width, fig_height)) + + def get_ax(row=0, col=0, nrows=nrows, ncols=ncols, axs=axs): + if nrows == 0: + if ncols == 1: + return axs + else: + return axs[col] + else: + if ncols == 1: + return axs[row] + else: + return axs[row, col] + + add_extremum_graph( + get_ax(0, 0), "val_accuracy", "max", metric_data, show_legend=False + ) + add_extremum_graph( + get_ax(1, 0), "train_accuracy", "max", metric_data, show_legend=False + ) + fig.suptitle(f"{operation} {list(data.keys())[0]}") + fig.tight_layout() + + expt = list(metric_data.keys())[0] + img_file = f"{image_dir}/max_accuracy/{operation}_max_accuracy_{arch}.png" + d = os.path.split(img_file)[0] + os.makedirs(d, exist_ok=True) + print(f"Writing {img_file}") + fig.savefig(img_file) + plt.close(fig) + + +def create_tsne_graphs(operation, expt, run_dir, image_dir=args.image_dir): + + saved_pt_dir = f"{run_dir}/activations" + saved_pts = [] + + loss_ts = [] + accuracy_ts = [] + epochs_ts = [] + print(f'glob = {saved_pt_dir + "/activations_*.pt"}') + files = sorted(glob.glob(saved_pt_dir + "/activations_*.pt")) + print(f"files = {files}") + + for file in files: + print(f"Loading {file}") + saved_pt = torch.load(file) + saved_pts.append(saved_pt) + loss_ts.append(saved_pt["val_loss"].mean(dim=-1)) + accuracy_ts.append(saved_pt["val_accuracy"]) + epochs_ts.append(saved_pt["epochs"].squeeze()) + + loss_t = torch.cat(loss_ts, dim=0).T.detach() + accuracy_t = torch.cat(accuracy_ts, dim=0).T.detach() + epochs_t = torch.cat(epochs_ts, dim=0).detach() + print(loss_t.shape) + print(accuracy_t.shape) + print(epochs_t.shape) + ###### + a = 0 + num_eqs = len(loss_t) + b = a + num_eqs + + print("Doing T-SNE..") + loss_tsne = TSNE(n_components=2, init="pca").fit_transform(loss_t) + print("...done T-SNE.") + + ncols = 1 + nrows = 1 + fig_width = ncols * 8 + fig_height = nrows * 5 + fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(fig_width, fig_height)) + + axs.scatter(loss_tsne[:, 0], loss_tsne[:, 1]) + + img_file = f"{image_dir}/tsne/{operation}_{expt}.png" + d = os.path.split(img_file)[0] + os.makedirs(d, exist_ok=True) + print(f"Writing {img_file}") + fig.savefig(img_file) + plt.close(fig) + + +for operation in RUNS: + print("") + print("") + print(f"Processing {operation}", flush=True) + + if operation.endswith("-epochs"): + epochs = int(operation.split("/")[-1].split("-")[0]) + else: + epochs = 5000 + + #### + + ds_len, run = RUNS[operation] + + + data = load_metric_data(f"{DATA_DIR}/{run}", epochs=epochs, load_partial_data=False) + + # check it + for arch in data: + # print(data[arch]["metrics"].shape) + metrics, expts, epochs = data[arch]["metrics"].shape + message = ( + f"{arch} : loaded {metrics} metrics, {expts} experiments, {epochs} epochs" + ) + assert metrics == 5, "INVALID metrics count: " + message + assert expts < 88, "INVALID experiments count: " + message + assert epochs == epochs, f"INVALID epochs count: " + message + print(message) + + # ## Set filters on the data to view + + metric_data = get_metric_data(data, limits) + + # Draw loss and accuracy curves + + create_max_accuracy_curves(metric_data, epochs, run) + + create_loss_curves(metric_data, epochs, run) + create_loss_curves(metric_data, epochs, run, ds_len=ds_len) + + most_interesting_metric_data = most_interesting(metric_data) + + create_loss_curves( + most_interesting_metric_data, epochs, run, most_interesting_only=True + ) + create_loss_curves( + most_interesting_metric_data, + epochs, + run, + most_interesting_only=True, + ds_len=ds_len, + ) + + # Draw max accuracy curves + + # T-SNE of loss curves: + try: + for arch in most_interesting_metric_data: + t = int(most_interesting_metric_data[arch]["T"][0].item()) + expt = f"{arch}_T-{t}_DROP-0.0" + create_tsne_graphs(operation, expt, run_dir=f"{DATA_DIR}/{run}/{expt}") + except: + print("TSNE failed") diff --git a/grok-main/scripts/create_metrics_for_epochs.py b/grok-main/scripts/create_metrics_for_epochs.py new file mode 100755 index 0000000..15a7cd5 --- /dev/null +++ b/grok-main/scripts/create_metrics_for_epochs.py @@ -0,0 +1,74 @@ +#!/usr/bin/env python + +import logging + +logging.basicConfig(level=logging.ERROR) +import csv +import copy +import os +import grok +import numpy as np +import sys +import subprocess +import torch +from torch.multiprocessing import Process +from grok import trainer +from tqdm import tqdm +from argparse import ArgumentParser +from collections import Counter + + +torch.multiprocessing.freeze_support() +try: + torch.multiprocessing.set_start_method("spawn") +except RuntimeError: + pass + +# Get args +EPOCHS = ( + list(range(10)) + + list(range(10, 200, 2)) + + list(range(200, 5000, 10)) + + list(range(5000, 10000, 50)) + + [10000] +) + +parser = ArgumentParser() +parser.add_argument( + "--data_dir", type=str, help="where to find the runs", required=True +) +parser.add_argument("--expt", type=str, default=None) +parser.add_argument("--epochs_per_run", type=int, default=40) + + +def parent(expts): + for expt in expts: + print(f"Processing {expt}") + all_results = {} + for first_epoch in range(0, len(EPOCHS), hparams.epochs_per_run): + these_epochs = [ + str(e) + for e in EPOCHS[first_epoch : first_epoch + hparams.epochs_per_run] + ] + expt_dir = data_dir + "/" + expt + cmd = [ + "./create_partial_metrics.py", + f"--gpu={hparams.gpu}", + f"--expt_dir={expt_dir}", + f'--epochs={",".join(these_epochs)}', + ] + result = subprocess.run(cmd, capture_output=False, shell=False) + if result.returncode != 0: + sys.exit(result.returncode) + + +hparams = trainer.get_args(parser) + +data_dir = hparams.data_dir + +if hparams.expt is not None: + expts = [hparams.expt] +else: + expts = os.listdir(data_dir) + +parent(expts) diff --git a/grok-main/scripts/create_partial_metrics.py b/grok-main/scripts/create_partial_metrics.py new file mode 100755 index 0000000..35985f9 --- /dev/null +++ b/grok-main/scripts/create_partial_metrics.py @@ -0,0 +1,140 @@ +#!/usr/bin/env python + +import logging + +logging.basicConfig(level=logging.ERROR) +import csv +import copy +import glob +import os +import grok +import numpy as np +import subprocess +import torch +import sys +from torch.multiprocessing import Process +from grok import trainer +from tqdm import tqdm +from argparse import ArgumentParser +from collections import Counter +from grok_runs import RUNS +from grok_metrics_lib import ( + DATA_DIR, + load_metric_data, + get_metric_data, + most_interesting, +) + + +# Make N_EPOCHS exponentially spaced sets of epochs from 1 to 10,000 +N_EPOCHS = 32 +BASE = 9999 ** (1.0 / (N_EPOCHS - 1)) +epochs = (BASE ** torch.arange(1, N_EPOCHS).float()).long().tolist() +DEFAULT_EPOCHS = ",".join([str(i) for i in epochs]) + +parser = ArgumentParser() +parser.add_argument("--expt_dir", type=str, help="where to find the runs") +parser.add_argument("--epochs", type=str, default=DEFAULT_EPOCHS) + + +def child(hparams): + expt_dir = hparams.expt_dir + epochs = [int(e) for e in hparams.epochs.split(",")] + # print("epochs = ", epochs) + device = torch.device(f"cuda:{hparams.gpu}") + ckpt_dir = expt_dir + "/" + "checkpoints" + # ckpt_files = [ckpt_dir + f"/epoch={epoch}.ckpt" for epoch in epochs] + hparams.logdir = expt_dir + + results = { + "val_loss": None, + "val_accuracy": None, + } + + processed_epochs = [] + # with tqdm(epochs, unit="epochs", initial=epochs[0], total=epochs[-1]) as pbar: + # last_epoch = epochs[0] + for idx, epoch in tqdm(list(enumerate(epochs))): + # pbar.update(epoch - last_epoch) + # last_epoch = epoch + ckpt_files = glob.glob(ckpt_dir + f"/epoch={epoch}-step=*.ckpt") + ckpt_files += glob.glob(ckpt_dir + f"/epoch={epoch}.ckpt") + try: + ckpt_file = ckpt_files[-1] + ckpt = torch.load( + ckpt_file, + map_location=f"cuda:{0}", # FIXME + ) + processed_epochs.append(epoch) + except FileNotFoundError: + continue + + for k, v in ckpt["hyper_parameters"].items(): + setattr(hparams, k, v) + + new_state_dict = {} + for k, v in ckpt["state_dict"].items(): + if k.startswith("transformer."): + new_state_dict[k] = v + else: + new_state_dict["transformer." + k] = v + ckpt["state_dict"] = new_state_dict + + model = trainer.TrainableTransformer(hparams).float() + model.load_state_dict(ckpt["state_dict"]) + model = model.to(device).eval() + dl = model.test_dataloader() + dl.reset_iteration(shuffle=False) + + outputs = [model.test_step(batch, idx) for (idx, batch) in enumerate(dl)] + r = model.test_epoch_end(outputs)["log"] + if results["val_loss"] is None: + results["val_loss"] = r["test_loss"].squeeze().unsqueeze(0) + results["val_accuracy"] = r["test_accuracy"].squeeze().unsqueeze(0) + else: + results["val_loss"] = torch.cat( + [results["val_loss"], r["test_loss"].squeeze().unsqueeze(0)], dim=0 + ) + results["val_accuracy"] = torch.cat( + [ + results["val_accuracy"], + r["test_accuracy"].squeeze().unsqueeze(0), + ], + dim=0, + ) + + for k, v in results.items(): + results[k] = v.to("cpu") + results["epochs"] = torch.LongTensor(processed_epochs, device="cpu") + results["dl"] = dl + + os.makedirs(expt_dir + "/activations", exist_ok=True) + ptfile = ( + expt_dir + f"/activations/activations_{epochs[0]:010d}_{epochs[-1]:010d}.pt" + ) + torch.save(results, ptfile) + + +if __name__ == "__main__": + hparams = trainer.get_args(parser) + if hparams.expt_dir is not None: + child(hparams) + else: + for operation in RUNS: + print(f"running {operation}") + ds_len, run = RUNS[operation] + data = load_metric_data( + f"{DATA_DIR}/{run}", epochs=10000, load_partial_data=False + ) + metric_data = get_metric_data(data) + metric_data = most_interesting(metric_data) + for arch in metric_data: + interesting_t = int(metric_data[arch]["T"][0].item()) + expt = f"{arch}_T-{interesting_t}" + print(f"--> expt {expt}") + glb = f"{DATA_DIR}/{run}/{expt}_*" + # print(f"glb {glb}") + expt_dir = glob.glob(glb)[0] + cmd = [sys.argv[0], "--expt_dir", expt_dir] + subprocess.run(cmd, check=False, shell=False) + # child(hparams) diff --git a/grok-main/scripts/make_data.py b/grok-main/scripts/make_data.py new file mode 100755 index 0000000..1cc2d1c --- /dev/null +++ b/grok-main/scripts/make_data.py @@ -0,0 +1,10 @@ +#!/usr/bin/env python + +from argparse import ArgumentParser +from grok.data import create_data_files, DEFAULT_DATA_DIR + + +parser = ArgumentParser() +parser.add_argument("-d", "--data_directory", type=str, default=DEFAULT_DATA_DIR) +args = parser.parse_args() +create_data_files(args.data_directory) \ No newline at end of file diff --git a/grok-main/scripts/torch-setup.sh b/grok-main/scripts/torch-setup.sh new file mode 100755 index 0000000..1bc2390 --- /dev/null +++ b/grok-main/scripts/torch-setup.sh @@ -0,0 +1,23 @@ +#!/bin/bash + +# Set up torch with magma support + +DIR="`mktemp -d`" +cd $DIR + +# Install deps +conda install -y numpy ninja pyyaml mkl mkl-include setuptools cmake cffi typing_extensions future six requests dataclasses + +# Install torch, magma +conda install -y -c pytorch magma-cuda110 + +# Build torch from scratch +git clone --recursive https://github.com/pytorch/pytorch +cd pytorch + +# If updating an existing checkout +# git submodule sync +# git submodule update --init --recursive + +export CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname $(which conda))/../"} +python setup.py install \ No newline at end of file diff --git a/grok-main/scripts/train.py b/grok-main/scripts/train.py new file mode 100755 index 0000000..ef03d76 --- /dev/null +++ b/grok-main/scripts/train.py @@ -0,0 +1,14 @@ +#!/usr/bin/env python + +import grok +import os + +parser = grok.training.add_args() +parser.set_defaults(logdir=os.environ.get("GROK_LOGDIR", ".")) +hparams = parser.parse_args() +hparams.datadir = os.path.abspath(hparams.datadir) +hparams.logdir = os.path.abspath(hparams.logdir) + + +print(hparams) +print(grok.training.train(hparams)) diff --git a/grok-main/scripts/visualize_metrics.py b/grok-main/scripts/visualize_metrics.py new file mode 100755 index 0000000..bf34d58 --- /dev/null +++ b/grok-main/scripts/visualize_metrics.py @@ -0,0 +1,510 @@ +#!/usr/bin/env python +# coding: utf-8 +import csv +import json +import logging +import os +import subprocess +from argparse import ArgumentParser +from copy import deepcopy +from glob import glob +from pprint import pprint + +import blobfile as bf +import grok +import matplotlib.pyplot as plt +import matplotlib.ticker as mtick +import numpy as np +import torch +import yaml +from tqdm import tqdm + +logger = logging.getLogger(__name__) + +# take args: input_dir output_dir +parser = ArgumentParser() +parser.add_argument( + "-i", + "--input_dir", + type=str, + required=True, +) +parser.add_argument( + "-o", + "--output_dir", + type=str, + required=True, +) +parser = grok.training.add_args(parser) +args = parser.parse_args() +print(args, flush=True) + +if torch.cuda.is_available(): + device = "cuda" +else: + device = "cpu" + + +def load_expt_metrics( + expt_dir, + args, +): + """load the metrics for one experiment""" + args = deepcopy(args) + + # load the hparams for this experiment + with open(f"{expt_dir}/default/version_0/hparams.yaml", "r") as fh: + hparams_dict = yaml.safe_load(fh) + + for k, v in hparams_dict.items(): + setattr(args, k, v) + + # load the summarized validation and training data for every epoch + val_data = { + "step": [], + "epoch": [], + "val_loss": [], + "val_accuracy": [], + } + train_data = { + "step": [], + "epoch": [], + "train_loss": [], + "train_accuracy": [], + "learning_rate": [], + } + + with open(f"{expt_dir}/default/version_0/metrics.csv", "r") as fh: + for row in csv.DictReader(fh): + if row["train_loss"] != "": + for k in train_data: + if k in ["step", "epoch"]: + v = int(row[k]) + else: + v = float(row[k]) + train_data[k].append(v) + else: + for k in val_data: + if k in ["step", "epoch"]: + v = int(row[k]) + else: + v = float(row[k]) + val_data[k].append(v) + + return { + "hparams": hparams_dict, + "train": train_data, + "val": val_data, + # "raw": raw_data, + } + + +def load_run_metrics( + run_dir, + args=args, +): + """load all the metrics for a collection of experiments with the same architecture + across various amounts of training data""" + metric_data = {} + from os import walk + + _, expt_dirs, _ = next(os.walk(run_dir)) + for expt_dir in tqdm(expt_dirs, unit="expt"): + try: + expt_data = load_expt_metrics(f"{run_dir}/{expt_dir}", args) + train_data_pct = expt_data["hparams"]["train_data_pct"] + metric_data[train_data_pct] = expt_data + except FileNotFoundError: + pass + return metric_data + + +def add_metric_graph( + fig, + ax, + arch, + metric, + metric_data, + scales, + cmap="viridis", + by="step", # step or epoch + max_increment=0, +): + ax.set_title(metric) + ax.set_xscale(scales["x"]) + ax.set_yscale(scales["y"]) + ax.set_xlabel(by) + + if "accuracy" in metric: + ax.yaxis.set_major_formatter(mtick.PercentFormatter()) + ymin = 1e-16 + ymax = 101 + ax.axis(ymin=ymin, ymax=ymax) + if "loss" in metric: + ymin = 1e-16 + ymax = 15 + ax.axis(ymin=ymin, ymax=ymax) + + total_plots = 0 + logger.debug(f"processing {metric}") + plots = [] + T = list(sorted(metric_data.keys())) + T_max = int(T[-1]) + T_min = int(T[0]) + sm = plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(vmin=T[0], vmax=T[-1])) + colors = sm.to_rgba(T) + for i, t in enumerate(T): + if "val" in metric: + this_data = metric_data[t]["val"] + else: + this_data = metric_data[t]["train"] + + X = this_data[by] + Y = this_data[metric] + if max_increment > 0: + X = [x for x in X if x <= max_increment] + Y = Y[: len(X)] + + if len(X) != len(Y): + logger.warning(f"Mismatched data: {metric} at t={t}") + continue + if not Y: + logger.warning(f"No data for {metric}i at t={t}") + continue + + label = arch + f" t={t}" + + if "accuracy" in metric: + label += " (max = %.2f)" % max(Y) + elif "loss" in metric: + label += " (min = %.2f)" % min(Y) + total_plots += 1 + ax.plot(X, Y, label=label, color=colors[i]) + if T_max - T_min <= 10: + ax.legend() + else: + fig.colorbar( + sm, + ax=ax, + label="% training data", + ticks=range(T_min, T_max + 1, int((T_max - T_min) / 5)), + ) + + +def add_max_accuracy_graph( + ax, + arch, + metric, + metric_data, + scales, + by="step", + max_increment=0, +): + ax.set_title(f"max {metric}") + ax.set_xlabel("% of total data trained on") + ax.xaxis.set_major_formatter(mtick.PercentFormatter()) + xmin = 0 + xmax = 100 + ymin = 1e-16 + ymax = 101 + ax.axis(xmin=xmin, xmax=xmax, ymin=ymin, ymax=ymax) + ax.set_xscale(scales["x"]) + ax.set_yscale(scales["y"]) + ax.yaxis.set_major_formatter(mtick.PercentFormatter()) + ax.xaxis.set_major_formatter(mtick.PercentFormatter()) + + T = list(sorted(metric_data.keys())) + T_max = int(T[-1]) + T_min = int(T[0]) + Y = [] + for i, t in enumerate(T): + if "val" in metric: + this_data = metric_data[t]["val"] + else: + this_data = metric_data[t]["train"] + X = this_data[by] + if max_increment > 0: + X = [x for x in X if x <= max_increment] + max_idx = len(X) + else: + max_idx = -1 + try: + Y.append(max(this_data[metric][:max_idx])) + except ValueError: + Y.append(np.nan) + + ax.set_xticks(np.arange(0, 100, 5)) + label = f"max {metric} {arch}" + ax.plot(T, Y, label=label) + + +def create_loss_curves( + metric_data, + arch, + operation, + # epochs, + most_interesting_only=False, + image_dir=args.output_dir, + by="step", + max_increment=0, + cmap="viridis", +): + scales = { + "x": "log", + "y": "linear", + } + + ncols = 2 + nrows = 3 + fig_width = ncols * 8 + fig_height = nrows * 5 + fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(fig_width, fig_height)) + + add_metric_graph( + fig, + axs[0, 0], + arch, + "val_loss", + metric_data, + scales, + cmap, + by, + max_increment=max_increment, + ) + add_metric_graph( + fig, + axs[0, 1], + arch, + "val_accuracy", + metric_data, + scales, + cmap, + by, + max_increment=max_increment, + ) + add_metric_graph( + fig, + axs[1, 0], + arch, + "train_loss", + metric_data, + scales, + cmap, + by, + max_increment=max_increment, + ) + add_metric_graph( + fig, + axs[1, 1], + arch, + "train_accuracy", + metric_data, + scales, + cmap, + by, + max_increment=max_increment, + ) + add_metric_graph( + fig, + axs[2, 0], + arch, + "learning_rate", + metric_data, + scales, + cmap, + by, + max_increment=max_increment, + ) + fig.suptitle(f"{operation} {arch} {max_increment:06d} {by}s") + fig.tight_layout() + + img_file = f"{image_dir}/loss_curves/{operation}_loss_curves_{arch}__upto_{max_increment:010d}_{by}" + if most_interesting_only: + img_file += "_most_interesting" + img_file += ".png" + d = os.path.split(img_file)[0] + os.makedirs(d, exist_ok=True) + print(f"Writing {img_file}") + fig.savefig(img_file) + plt.close(fig) + + +def create_max_accuracy_curves( + metric_data, + arch, + operation, + by="step", + max_increment=0, + image_dir=args.output_dir, +): + scales = { + "x": "linear", + "y": "linear", + } + + ncols = 1 + nrows = 2 + fig_width = ncols * 8 + fig_height = nrows * 5 + fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(fig_width, fig_height)) + + add_max_accuracy_graph( + axs[0], + arch, + "val_accuracy", + metric_data, + scales, + by=by, + max_increment=max_increment, + ) + axs[0].legend() + add_max_accuracy_graph( + axs[1], + arch, + "train_accuracy", + metric_data, + scales, + by=by, + max_increment=max_increment, + ) + axs[1].legend() + fig.suptitle(f"{operation} {arch} {max_increment:06d} {by}s") + fig.tight_layout() + + img_file = f"{image_dir}/max_accuracy/{operation}_max_accuracy_{arch}_upto_{max_increment:010d}_{by}.png" + d = os.path.split(img_file)[0] + os.makedirs(d, exist_ok=True) + print(f"Writing {img_file}") + fig.savefig(img_file) + plt.close(fig) + + +def create_tsne_graphs( + operation, + expt, + run_dir, + image_dir=args.output_dir, +): + + saved_pt_dir = f"{run_dir}/activations" + saved_pts = [] + + loss_ts = [] + accuracy_ts = [] + epochs_ts = [] + print(f'glob = {saved_pt_dir + "/activations_*.pt"}') + files = sorted(glob.glob(saved_pt_dir + "/activations_*.pt")) + print(f"files = {files}") + + for file in files: + print(f"Loading {file}") + saved_pt = torch.load(file) + saved_pts.append(saved_pt) + loss_ts.append(saved_pt["val_loss"].mean(dim=-1)) + accuracy_ts.append(saved_pt["val_accuracy"]) + epochs_ts.append(saved_pt["epochs"].squeeze()) + + loss_t = torch.cat(loss_ts, dim=0).T.detach() + accuracy_t = torch.cat(accuracy_ts, dim=0).T.detach() + epochs_t = torch.cat(epochs_ts, dim=0).detach() + print(loss_t.shape) + print(accuracy_t.shape) + print(epochs_t.shape) + ###### + a = 0 + num_eqs = len(loss_t) + b = a + num_eqs + + print("Doing T-SNE..") + loss_tsne = TSNE(n_components=2, init="pca").fit_transform(loss_t) + print("...done T-SNE.") + + ncols = 1 + nrows = 1 + fig_width = ncols * 8 + fig_height = nrows * 5 + fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(fig_width, fig_height)) + + axs.scatter(loss_tsne[:, 0], loss_tsne[:, 1]) + + img_file = f"{image_dir}/tsne/{operation}_{expt}.png" + d = os.path.split(img_file)[0] + os.makedirs(d, exist_ok=True) + print(f"Writing {img_file}") + fig.savefig(img_file) + plt.close(fig) + + +def get_arch(metric_data): + k = list(metric_data.keys())[0] + hparams = metric_data[k]["hparams"] + arch = f'L-{hparams["n_layers"]}_H-{hparams["n_heads"]}_D-{hparams["d_model"]}_B-{hparams["batchsize"]}_S-{hparams["random_seed"]}_DR-{hparams["dropout"]}' + return arch + + +def get_operation(metric_data): + k = list(metric_data.keys())[0] + hparams = metric_data[k]["hparams"] + operator = hparams["math_operator"] + operand_length = hparams["operand_length"] + _, operation = grok.data.ArithmeticDataset.get_file_path(operator, operand_length) + return operation + + +def get_max_epochs(metric_data): + k = list(metric_data.keys())[0] + hparams = metric_data[k]["hparams"] + return hparams["max_epochs"] + + +rundir = args.input_dir + +try: + metric_data = load_run_metrics(rundir, args) + arch = get_arch(metric_data) + operation = get_operation(metric_data) + max_epochs = get_max_epochs(metric_data) + + for by in ["step", "epoch"]: + create_loss_curves(metric_data, arch, operation, by=by) + + by = "epoch" + last_i = -1 + for i in sorted(list(set(2 ** (np.arange(167) / 10)))): + if i > max_epochs: + break + i = int(round(i)) + create_max_accuracy_curves( + metric_data, + arch, + operation, + by=by, + max_increment=i, + ) + + # make a video + in_files = os.path.join( + args.output_dir, + "max_accuracy", + f"{operation}_max_accuracy_{arch}_upto_%*.png", + ) + out_file = os.path.join(args.output_dir, f"{operation}_{arch}_max_accuracy.mp4") + cmd = [ + "ffmpeg", + "-y", + "-r", + "16", + "-i", + in_files, + "-vcodec", + "libx264", + "-crf", + "25", + "-pix_fmt", + "yuv420p", + out_file, + ] + subprocess.check_call(cmd) + +except BaseException as e: + print(f"{rundir} failed: {e}") diff --git a/grok-main/setup.py b/grok-main/setup.py new file mode 100644 index 0000000..940a4cd --- /dev/null +++ b/grok-main/setup.py @@ -0,0 +1,22 @@ +from setuptools import find_packages, setup + +setup( + name="grok-qiangjian", + packages=find_packages(), + version="0.0.2-qiangjian", + description=( + "Grokking: Generalization Beyond Overfitting on Small Algorithmic Datasets. " + "强兼 (forceful compatibility) bridge: now includes Grok-1 MoE architecture." + ), + url="https://github.com/openai/grok", + install_requires=[ + "pytorch_lightning", + "blobfile", + "numpy", + "torch", + "tqdm", + "scipy", + "mod", + "matplotlib", + ], +)