From 26aad34c48f46373391d99c4c439b86858fd359c Mon Sep 17 00:00:00 2001
From: RoomWithOutRoof <166608075+Jah-yee@users.noreply.github.com>
Date: Tue, 3 Mar 2026 08:45:54 +0800
Subject: [PATCH 1/2] =?UTF-8?q?Initial=20commit:=20grok=20=E5=BC=BA?=
=?UTF-8?q?=E5=85=BC=20-=20bridge=20openai/grok=20and=20xai-org/grok-1?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Made-with: Cursor
---
README.md | 246 +++
does_grok_grok.py | 287 ++++
grok-1-main/.gitignore | 2 +
grok-1-main/CODE_OF_CONDUCT.md | 1 +
grok-1-main/LICENSE.txt | 202 +++
grok-1-main/README.md | 72 +
grok-1-main/checkpoint.py | 221 +++
grok-1-main/checkpoints/README.md | 3 +
grok-1-main/model.py | 1462 +++++++++++++++++
grok-1-main/pyproject.toml | 14 +
grok-1-main/requirements.txt | 4 +
grok-1-main/run.py | 261 +++
grok-1-main/runners.py | 605 +++++++
grok-1-main/tokenizer.model | Bin 0 -> 2229219 bytes
grok-main/.gitignore | 133 ++
grok-main/LICENSE | 21 +
grok-main/README.md | 32 +
grok-main/grok/__init__.py | 20 +
grok-main/grok/data.py | 612 +++++++
grok-main/grok/measure.py | 139 ++
grok-main/grok/metrics.py | 372 +++++
grok-main/grok/training.py | 1134 +++++++++++++
grok-main/grok/transformer.py | 756 +++++++++
grok-main/grok/visualization.py | 516 ++++++
grok-main/nbs/flatness.ipynb | 71 +
grok-main/scripts/compute_sharpness.py | 16 +
grok-main/scripts/create_metric_graphs.py | 285 ++++
.../scripts/create_metrics_for_epochs.py | 74 +
grok-main/scripts/create_partial_metrics.py | 140 ++
grok-main/scripts/make_data.py | 10 +
grok-main/scripts/torch-setup.sh | 23 +
grok-main/scripts/train.py | 14 +
grok-main/scripts/visualize_metrics.py | 510 ++++++
grok-main/setup.py | 22 +
34 files changed, 8280 insertions(+)
create mode 100644 README.md
create mode 100644 does_grok_grok.py
create mode 100644 grok-1-main/.gitignore
create mode 100644 grok-1-main/CODE_OF_CONDUCT.md
create mode 100644 grok-1-main/LICENSE.txt
create mode 100644 grok-1-main/README.md
create mode 100644 grok-1-main/checkpoint.py
create mode 100644 grok-1-main/checkpoints/README.md
create mode 100644 grok-1-main/model.py
create mode 100644 grok-1-main/pyproject.toml
create mode 100644 grok-1-main/requirements.txt
create mode 100644 grok-1-main/run.py
create mode 100644 grok-1-main/runners.py
create mode 100644 grok-1-main/tokenizer.model
create mode 100644 grok-main/.gitignore
create mode 100644 grok-main/LICENSE
create mode 100644 grok-main/README.md
create mode 100644 grok-main/grok/__init__.py
create mode 100644 grok-main/grok/data.py
create mode 100755 grok-main/grok/measure.py
create mode 100644 grok-main/grok/metrics.py
create mode 100755 grok-main/grok/training.py
create mode 100644 grok-main/grok/transformer.py
create mode 100644 grok-main/grok/visualization.py
create mode 100644 grok-main/nbs/flatness.ipynb
create mode 100755 grok-main/scripts/compute_sharpness.py
create mode 100755 grok-main/scripts/create_metric_graphs.py
create mode 100755 grok-main/scripts/create_metrics_for_epochs.py
create mode 100755 grok-main/scripts/create_partial_metrics.py
create mode 100755 grok-main/scripts/make_data.py
create mode 100755 grok-main/scripts/torch-setup.sh
create mode 100755 grok-main/scripts/train.py
create mode 100755 grok-main/scripts/visualize_metrics.py
create mode 100644 grok-main/setup.py
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..04d8832
--- /dev/null
+++ b/README.md
@@ -0,0 +1,246 @@
+
+
+```
+╔══════════════════════════════════════════════════════════════════════════╗
+║ ║
+║ $ git merge openai/grok xai-org/grok-1 ║
+║ ║
+║ CONFLICT (rename): both repos are named "grok" ║
+║ CONFLICT (content): PyTorch ≠ JAX ║
+║ CONFLICT (scale): 100,000 params ≠ 314,000,000,000 params ║
+║ CONFLICT (purpose): studying understanding ≠ claiming to understand ║
+║ CONFLICT (history): OpenAI ≠ xAI ║
+║ ║
+║ Automatic merge failed. Proceeding anyway. ║
+║ ║
+╚══════════════════════════════════════════════════════════════════════════╝
+```
+
+# Does Grok grok grokking?
+
+[`openai/grok`](https://github.com/openai/grok) · [`xai-org/grok-1`](https://github.com/xai-org/grok-1) · bridged
+
+
+
+---
+
+**grok** — three words that happen to be the same word:
+
+| | Meaning | Source |
+|---|---|---|
+| **Grok** *(noun)* | xAI's 314B-parameter Mixture-of-Experts LLM, open-sourced March 17, 2024 | *xai-org/grok-1* |
+| **grok** *(verb)* | to understand something so thoroughly that observer and observed merge | Heinlein, *Stranger in a Strange Land*, 1961 |
+| **grokking** *(ML noun)* | a phase transition where neural networks suddenly generalize long after memorizing training data | Power et al., 2022; *openai/grok* |
+
+This repository asks one question with three interpretations: **does the model named Grok deeply understand the phenomenon called grokking?**
+
+And, more concretely: does Grok-1's architecture — when miniaturized and trained on arithmetic — exhibit different grokking dynamics than a standard transformer?
+
+---
+
+## Background
+
+In February 2023, Elon Musk publicly accused OpenAI of betraying its founding mission as an open-source nonprofit. In November 2023, xAI launched **Grok** as a closed product. On February 29, 2024, Musk filed a lawsuit against OpenAI. Eleven days later, he announced Grok would be open-sourced — and released it on March 17.
+
+Somewhere in that timeline, two GitHub repositories existed under the name `grok`:
+
+```
+github.com/openai/grok ← ~500 lines, PyTorch, studying when models learn
+github.com/xai-org/grok-1 ← ~1400 lines, JAX, a model claiming it has learned
+```
+
+No one had thought to connect them. This fork does.
+
+---
+
+## Architecture
+
+```
+ openai/grok (original) this fork adds xai-org/grok-1 (original)
+ ───────────────────────── ──────────────────────────
+
+ ┌───────────────────────┐ ┌──────────────────────────┐
+ │ Dense Transformer │ │ Mixture-of-Experts │
+ │ │ │ │
+ │ Input Embedding │ │ Input Embedding │
+ │ + sinusoidal PE │ ─── replicated as RoPE ────────────► │ + Rotary PE (RoPE) │
+ │ │ │ │
+ │ Multi-Head Attention │ ─── replicated as GQA ─────────────► │ GQA (48q / 8kv heads) │
+ │ │ │ │
+ │ FFN │ │ MoE FFN │
+ │ ReLU(xW₁)W₂ │ ─── replicated as MoE + gating ────► │ 8 experts, top-2 │
+ │ │ │ GELU(xW₁) ⊙ (xWᵥ) W₂ │
+ │ LayerNorm │ ─── replicated as RMSNorm ──────────► │ RMSNorm │
+ └───────────────────────┘ └──────────────────────────┘
+
+ ~100K parameters 314,000,000K parameters
+ PyTorch + Lightning JAX + Haiku
+ task: 42 + 55 mod 97 task: everything
+```
+
+The resulting class — `GrokOneTransformer` in `grok/transformer.py` — is a PyTorch implementation of Grok-1's decoder architecture that plugs directly into OpenAI's training and evaluation framework. The same training loop, the same arithmetic datasets, a fundamentally different optimization landscape.
+
+---
+
+## Bridges
+
+Three concrete connections were made between the two codebases:
+
+**Bridge A — Architecture port (JAX → PyTorch)**
+
+`GrokOneTransformer` brings Grok-1's full architectural stack into the grokking framework. Train it against the standard transformer on identical tasks to study whether MoE changes the memorization-to-generalization phase transition.
+
+```bash
+# Standard transformer (original)
+./grok-main/scripts/train.py --math_operator + --train_data_pct 5
+
+# Grok-1 architecture, same task
+./grok-main/scripts/train.py --architecture grok1 --math_operator + --num_experts 8
+
+# Auto-scaled miniature Grok-1
+./grok-main/scripts/train.py --architecture grok1_mini
+```
+
+New metrics track routing-specific grokking signals (routing entropy, expert specialization, collapse index) logged per epoch alongside the standard weight norm and generalization bounds.
+
+**Bridge B — Arithmetic evaluation (OpenAI tasks → Grok-1 inference)**
+
+`run.py` in `grok-1-main/` gains an `--eval-grokking` flag that generates modular arithmetic problems in the style of the OpenAI paper and scores Grok-1's responses.
+
+```bash
+# Dry run — inspect problem generation without a checkpoint
+python grok-1-main/run.py --eval-grokking --operator + --n-samples 50 --dry-run
+
+# Full evaluation (requires the 314B checkpoint, ~300GB)
+python grok-1-main/run.py --eval-grokking --operator + --n-samples 100
+```
+
+**Bridge C — Config export**
+
+`TransformerConfig` in `grok-1-main/model.py` can now export itself as a scaled-down PyTorch-compatible config:
+
+```python
+from model import TransformerConfig  # with grok-1-main/ on sys.path; the hyphenated directory is not an importable package
+
+full_config = TransformerConfig(
+ emb_size=6144, num_layers=64, num_q_heads=48,
+ num_kv_heads=8, num_experts=8, num_selected_experts=2,
+ widening_factor=8,
+)
+
+# Scale down by 1/24 for a trainable experiment
+mini = full_config.to_grokking_config(scale_factor=1/24)
+# → {'d_model': 256, 'n_layers': 2, 'n_heads': 2, 'num_experts': 8, ...}
+```
+
+---
+
+## The Research Question
+
+Beyond the provocation, there is a real empirical question here.
+
+The 2022 grokking paper studied dense transformers on algorithmic tasks and found a universal pattern: models memorize first, then — often thousands of steps later — abruptly generalize. The phase transition is sharp, nearly discontinuous, and poorly understood.
+
+Mixture-of-Experts models have a fundamentally different optimization geometry. The router introduces a discrete, non-differentiable dispatch decision at each layer. This creates a non-smooth loss landscape, load-balancing pressures, and the possibility of "routing collapse" — where one expert handles all inputs.
+
+Three testable hypotheses:
+
+1. **Routing entropy as a leading indicator** — Does the distribution over experts become more uniform *before* the grokking transition appears in the loss curve? If so, routing entropy might predict generalization before it happens.
+
+2. **Expert collapse → memorization trap** — If all tokens route to one expert during memorization, that expert overfits. Grokking might require the router to "spread out" first.
+
+3. **MoE capacity delays grokking onset** — More parameters means more room to memorize without generalizing. Does a larger expert count push the phase transition further out in training time?
+
+None of these have been tested. This fork provides the infrastructure to test them.
+
+---
+
+## Quick Start
+
+```bash
+# Verify both architectures instantiate and run forward passes
+python does_grok_grok.py --demo
+
+# Run a comparative grokking experiment (logs to ./logs/qiangjian/ by default; override with --logdir)
+python does_grok_grok.py --experiment --operator + --max-steps 50000
+
+# Test Grok-1's arithmetic ability (no checkpoint needed for dry run)
+python does_grok_grok.py --eval-grok1 --dry-run --n-samples 20
+```
+
+Illustrative output from `--demo` (abridged; the script's exact section numbering and wording differ):
+
+```
+[1] Standard transformer (dense, sinusoidal PE)
+ Params: 329,860 · 2L / 4H / 128D
+
+[2] Grok-1 architecture (MoE + RoPE + RMSNorm + gated GELU)
+ Params: 4,610,148 · 2L / 2H / 256D / 8 experts (top-2)
+
+[3] MoE routing (layer 0):
+ Expert 0 ████████░░░░░░░░ 0.125
+ Expert 1 ████████░░░░░░░░ 0.125
+ Expert 2 ████████░░░░░░░░ 0.125
+ ...
+
+ Routing entropy: 2.079 / 2.079 (perfectly uniform)
+
+[4] Config export (Grok-1 → PyTorch scale):
+ {'d_model': 256, 'n_layers': 2, 'n_heads': 2, 'num_experts': 8, ...}
+
+Both architectures operational. Bridge verified.
+```
+
+---
+
+## Structure
+
+```
+.
+├── does_grok_grok.py unified entry point
+├── README.md you are here
+│
+├── grok-main/ fork of openai/grok
+│ ├── grok/
+│ │ ├── transformer.py + GrokOneTransformer, MoE, RoPE, RMSNorm
+│ │ ├── training.py + architecture flag, MoE metrics logging
+│ │ ├── metrics.py + routing entropy, specialization, collapse
+│ │ ├── data.py + format_for_grok1(), eval suite generator
+│ │ └── __init__.py + bridge exports
+│ ├── setup.py
+│ └── README.md → original research documentation
+│
+└── grok-1-main/ fork of xai-org/grok-1
+ ├── model.py + to_grokking_config(), architecture_summary()
+ ├── run.py + --eval-grokking mode
+ └── README.md → original model documentation
+```
+
+All original files work exactly as before. Bridge code is only ever appended; it never replaces existing code. Every addition is marked with a `# Bridge:` comment. Running `git diff` against the upstream repos shows only additions.
+
+---
+
+## Why
+
+Because both projects are named `grok`. Because naming things is hard and irony compounds. Because Musk accused OpenAI of abandoning open source, named his AI after a word meaning deep understanding, then announced it would be open-sourced eleven days after filing a lawsuit — while OpenAI had a repository studying how models *learn* to understand, sitting quietly with the same name.
+
+The act of forcing incompatible things to work together has a name in Chinese: 强兼 *(qiáng jiān)*. It seemed appropriate.
+
+The answer to "Does Grok grok grokking?" is, genuinely, not yet known.
+
+---
+
+## License
+
+- `grok-main/` — [MIT License](grok-main/LICENSE)
+- `grok-1-main/` — [Apache 2.0](grok-1-main/LICENSE.txt)
+- Bridge code (`GrokOneTransformer`, `does_grok_grok.py`, and all `# Bridge:` additions) — public domain
+
+---
+
+
+
+"The word is much wider in meaning than any English word conceived to date — it means to understand so thoroughly that the observer becomes a part of the observed."
+— Heinlein, via Jubal Harshaw
+
+
diff --git a/does_grok_grok.py b/does_grok_grok.py
new file mode 100644
index 0000000..446b464
--- /dev/null
+++ b/does_grok_grok.py
@@ -0,0 +1,287 @@
+#!/usr/bin/env python
+"""
+does_grok_grok.py — The unified entry point of 强兼 (forceful compatibility)
+
+ ╔═══════════════════════════════╗
+ ║ Does Grok grok grokking? ║
+ ╚═══════════════════════════════╝
+
+This script bridges two open-source projects that share a name but nothing else:
+
+ • openai/grok — Studies "grokking": the phenomenon where neural networks
+ suddenly generalize after prolonged memorization. Trains
+ small transformers on modular arithmetic. (PyTorch)
+
+ • xai-org/grok-1 — The 314B-parameter Mixture-of-Experts language model
+ from xAI, open-sourced under Apache 2.0. (JAX/Haiku)
+
+This script asks a simple question with a beautiful double meaning:
+Does Grok (the model) grok (deeply understand) grokking (the phenomenon)?
+
+We answer it in two ways:
+
+ MODE 1: "Grok-1 Architecture Grokking" (--experiment)
+ Train a miniature version of Grok-1's architecture (MoE + RoPE + RMSNorm
+ + gated GELU) on OpenAI's grokking benchmark tasks, and compare the
+ grokking dynamics against the original dense transformer. Do MoE models
+ grok differently? Does the expert routing change during the phase
+ transition from memorization to generalization?
+
+ MODE 2: "Does Grok-1 Know Arithmetic?" (--eval-grok1)
+ Generate modular arithmetic problems from OpenAI's grokking dataset
+ and evaluate Grok-1 on them. (Requires Grok-1 checkpoint.)
+
+Usage:
+ # Compare grokking: standard transformer vs. Grok-1 architecture
+ python does_grok_grok.py --experiment --operator + --max-steps 50000
+
+ # Quick demo (no training, just shows the architecture bridge)
+ python does_grok_grok.py --demo
+
+ # Evaluate Grok-1 on arithmetic (requires checkpoint)
+ python does_grok_grok.py --eval-grok1 --checkpoint ./grok-1-main/checkpoints/
+
+Copyright: This bridge is a work of 行为艺术 (performance art).
+License: Both source projects' licenses apply to their respective code.
+ The bridge code itself is released into the public domain — do whatever you want with it.
+"""
+
+import argparse
+import sys
+import os
+import json
+import time
+
+# Add both project directories to path
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), "grok-main"))
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), "grok-1-main"))
+
+
+BANNER = r"""
+ ╭──────────────────────────────────────────────────────────────╮
+ │ │
+ │ ██████╗ ██████╗ ██████╗ ██╗ ██╗ │
+ │ ██╔════╝ ██╔══██╗██╔═══██╗██║ ██╔╝ │
+ │ ██║ ███╗██████╔╝██║ ██║█████╔╝ │
+ │ ██║ ██║██╔══██╗██║ ██║██╔═██╗ │
+ │ ╚██████╔╝██║ ██║╚██████╔╝██║ ██╗ │
+ │ ╚═════╝ ╚═╝ ╚═╝ ╚═════╝ ╚═╝ ╚═╝ │
+ │ │
+ │ 强 兼 │
+ │ — forceful compatibility — │
+ │ │
+ │ openai/grok ←──bridge──→ xai-org/grok-1 │
+ │ (grokking) (Grok-1 314B) │
+ │ PyTorch JAX/Haiku │
+ │ │
+ │ "Does Grok grok grokking?" │
+ │ │
+ ╰──────────────────────────────────────────────────────────────╯
+"""
+
+
+def demo():
+    """
+    Print the banner and smoke-test both model classes without any training.
+    Builds a dense Transformer and a 1/24-scale GrokOneTransformer, runs one
+    no-grad forward pass through each, then tries the Grok-1 config export.
+    """
+    print(BANNER)
+    print("=" * 66)
+    print(" MODE: Demo — Architecture Bridge Verification")
+    print("=" * 66)
+
+    # Import from OpenAI's grokking framework (grok-main/ is put on sys.path at module top)
+    from grok.transformer import Transformer, GrokOneTransformer
+    from grok.data import ArithmeticDataset, MODULUS  # neither name is used below
+
+    print("\n[1] Original OpenAI transformer (dense, sinusoidal PE):")
+    standard = Transformer(
+        n_layers=2, n_heads=4, d_model=128,
+        dropout=0.0, max_context_len=50, vocab_len=100,
+    )
+    n_params_standard = sum(p.numel() for p in standard.parameters())
+    print(f" Params: {n_params_standard:,}")
+    print(f" Architecture: {standard.n_layers}L / {standard.n_heads}H / {standard.d_model}D")
+
+    print("\n[2] Grok-1 architecture (MoE + RoPE + RMSNorm + gated GELU):")
+    grok1_mini = GrokOneTransformer.from_grok1_config(
+        scale_factor=1/24, vocab_len=100, max_context_len=50,
+    )
+    n_params_grok1 = sum(p.numel() for p in grok1_mini.parameters())
+    print(f" Params: {n_params_grok1:,}")
+    print(f" Architecture: {grok1_mini.n_layers}L / {grok1_mini.n_heads}H / "
+          f"{grok1_mini.d_model}D / {grok1_mini.num_experts}E (top-{grok1_mini.num_selected_experts})")
+
+    print(f"\n[3] Parameter ratio: Grok-1-mini is {n_params_grok1/n_params_standard:.1f}x "
+          f"the standard transformer")
+    print(f" (Real Grok-1 is ~314B params — "
+          f"{314_000_000_000/n_params_grok1:.0f}x this miniature)")
+
+    # Forward pass on random token ids in [0, 100), matching vocab_len=100 used for both models
+    import torch
+    print("\n[4] Forward pass test:")
+    dummy_input = torch.randint(0, 100, (2, 10))  # batch=2, seq=10
+
+    with torch.no_grad():
+        out_std, _, _ = standard(dummy_input)  # both calls are unpacked as 3-tuples; only item 0 is used
+        out_grok, _, _ = grok1_mini(dummy_input)
+
+    print(f" Standard: input {tuple(dummy_input.shape)} → output {tuple(out_std.shape)}")
+    print(f" Grok-1: input {tuple(dummy_input.shape)} → output {tuple(out_grok.shape)}")
+
+    # Show routing info; section [5] is skipped when last_router_probs is empty/falsy
+    if grok1_mini.last_router_probs:
+        rp = grok1_mini.last_router_probs[0]
+        print(f"\n[5] MoE Routing (layer 0):")
+        print(f" Router probs shape: {tuple(rp.shape)}")
+        mean_probs = rp.mean(dim=(0, 1))  # presumably averages over (batch, seq) — TODO confirm layout
+        for i, p in enumerate(mean_probs):
+            bar = "█" * int(p.item() * 40)  # bar length scales with the mean value (40 chars at 1.0)
+            print(f" Expert {i}: {p.item():.3f} {bar}")
+
+    # Cross-framework config export; "model" is expected to resolve to grok-1-main/model.py via sys.path
+    print("\n[6] Cross-framework config export (JAX → PyTorch):")
+    try:
+        from model import TransformerConfig as Grok1TransformerConfig
+        grok1_config = Grok1TransformerConfig(
+            emb_size=48 * 128, widening_factor=8, key_size=128,
+            num_q_heads=48, num_kv_heads=8, num_layers=64,
+            num_experts=8, num_selected_experts=2,
+        )
+        exported = grok1_config.to_grokking_config()
+        print(f" Grok-1 ({grok1_config.emb_size}D, {grok1_config.num_layers}L) →")
+        print(f" Grokking ({exported['d_model']}D, {exported['n_layers']}L, "
+              f"{exported['num_experts']}E)")
+    except ImportError:  # only import failures are handled here; any other error propagates
+        print(" (Grok-1 model.py not in path — run from project root)")
+
+    print("\n" + "=" * 66)
+    print(" Bridge verified. Both architectures are alive and compatible.")
+    print(" Run with --experiment to compare grokking dynamics.")
+    print("=" * 66 + "\n")
+
+
+def run_experiment(args):
+    """
+    Call grok.training.train twice ("standard", then "grok1") with shared flags and print each wall time.
+    """
+    print(BANNER)
+    print("=" * 66)
+    print(" MODE: Comparative Grokking Experiment")
+    print(f" Task: {args.operator} mod 97 | Steps: {args.max_steps}")  # "mod 97" is fixed text, not read from args
+    print("=" * 66)
+
+    from grok.training import train, add_args
+
+    # Flags shared by both runs; they are re-parsed below by the parser that add_args() returns
+    parser = add_args()
+    base_args = [
+        "--math_operator", args.operator,
+        "--max_steps", str(args.max_steps),
+        "--train_data_pct", str(args.train_pct),
+        "--d_model", str(args.d_model),
+        "--n_layers", str(args.n_layers),
+        "--n_heads", str(args.n_heads),
+    ]
+
+    # Run 1: Standard transformer (logs go to <logdir>/standard)
+    print("\n" + "─" * 66)
+    print(" [1/2] Training STANDARD transformer (OpenAI's original)")
+    print("─" * 66)
+    std_args = parser.parse_args(base_args + [
+        "--architecture", "standard",
+        "--logdir", os.path.join(args.logdir, "standard"),
+    ])
+    t0 = time.time()  # wall-clock time measured around the train() call
+    train(std_args)
+    t_std = time.time() - t0
+
+    # Run 2: Grok-1 architecture (adds the two expert-count flags; logs go to <logdir>/grok1)
+    print("\n" + "─" * 66)
+    print(" [2/2] Training GROK-1 architecture (MoE + RoPE + RMSNorm)")
+    print("─" * 66)
+    grok1_args = parser.parse_args(base_args + [
+        "--architecture", "grok1",
+        "--num_experts", str(args.num_experts),
+        "--num_selected_experts", str(args.num_selected_experts),
+        "--logdir", os.path.join(args.logdir, "grok1"),
+    ])
+    t0 = time.time()
+    train(grok1_args)
+    t_grok1 = time.time() - t0
+
+    print("\n" + "=" * 66)
+    print(" EXPERIMENT COMPLETE")
+    print(f" Standard: {t_std:.1f}s | Grok-1: {t_grok1:.1f}s")
+    print(f" Logs: {args.logdir}/")
+    print("=" * 66)
+    print("\n Compare training curves to see if MoE affects grokking dynamics.")
+    print(" Key questions:")
+    print(" • Does the Grok-1 architecture grok faster or slower?")
+    print(" • Does expert routing entropy change at the grokking transition?")
+    print(" • Do MoE models find different generalization shortcuts?\n")
+
+
+def eval_grok1(args):
+ """Evaluate Grok-1 on grokking arithmetic (delegates to grok-1-main/run.py)."""
+ print(BANNER)
+ os.chdir(os.path.join(os.path.dirname(__file__), "grok-1-main"))
+ from run import eval_grokking
+ eval_args = argparse.Namespace(
+ operator=args.operator,
+ n_samples=args.n_samples,
+ checkpoint_path=args.checkpoint,
+ tokenizer_path=args.tokenizer,
+ output=args.output,
+ dry_run=args.dry_run,
+ )
+ eval_grokking(eval_args)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="强兼 — Does Grok grok grokking?",
+        formatter_class=argparse.RawDescriptionHelpFormatter,
+        epilog="""
+Examples:
+ python does_grok_grok.py --demo
+ python does_grok_grok.py --experiment --operator + --max-steps 50000
+ python does_grok_grok.py --eval-grok1 --dry-run
+ """,
+    )
+
+    mode = parser.add_mutually_exclusive_group(required=True)  # exactly one mode flag must be given
+    mode.add_argument("--demo", action="store_true",
+                      help="Demo: verify the architecture bridge works")
+    mode.add_argument("--experiment", action="store_true",
+                      help="Run comparative grokking experiment")
+    mode.add_argument("--eval-grok1", action="store_true",
+                      help="Evaluate Grok-1 on arithmetic tasks")
+
+    # Experiment args (read by run_experiment; --operator is also read by eval_grok1)
+    parser.add_argument("--operator", type=str, default="+")
+    parser.add_argument("--max-steps", type=int, default=50000)
+    parser.add_argument("--train-pct", type=float, default=5)
+    parser.add_argument("--d-model", type=int, default=128)
+    parser.add_argument("--n-layers", type=int, default=2)
+    parser.add_argument("--n-heads", type=int, default=4)
+    parser.add_argument("--num-experts", type=int, default=8)
+    parser.add_argument("--num-selected-experts", type=int, default=2)
+    parser.add_argument("--logdir", type=str, default="logs/qiangjian")
+
+    # Eval args (read by eval_grok1 and forwarded to run.eval_grokking)
+    parser.add_argument("--checkpoint", type=str, default="./grok-1-main/checkpoints/")
+    parser.add_argument("--tokenizer", type=str, default="./grok-1-main/tokenizer.model")
+    parser.add_argument("--n-samples", type=int, default=50)
+    parser.add_argument("--output", type=str, default=None)
+    parser.add_argument("--dry-run", action="store_true")
+
+    args = parser.parse_args()
+
+    # Dispatch on the selected mode; demo() takes no arguments, the other two receive args
+    if args.demo:
+        demo()
+    elif args.experiment:
+        run_experiment(args)
+    elif args.eval_grok1:
+        eval_grok1(args)
diff --git a/grok-1-main/.gitignore b/grok-1-main/.gitignore
new file mode 100644
index 0000000..24d0d7e
--- /dev/null
+++ b/grok-1-main/.gitignore
@@ -0,0 +1,2 @@
+checkpoints/*
+!checkpoints/README.md
diff --git a/grok-1-main/CODE_OF_CONDUCT.md b/grok-1-main/CODE_OF_CONDUCT.md
new file mode 100644
index 0000000..d715425
--- /dev/null
+++ b/grok-1-main/CODE_OF_CONDUCT.md
@@ -0,0 +1 @@
+Be excellent to each other.
diff --git a/grok-1-main/LICENSE.txt b/grok-1-main/LICENSE.txt
new file mode 100644
index 0000000..d645695
--- /dev/null
+++ b/grok-1-main/LICENSE.txt
@@ -0,0 +1,202 @@
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/grok-1-main/README.md b/grok-1-main/README.md
new file mode 100644
index 0000000..f68f21a
--- /dev/null
+++ b/grok-1-main/README.md
@@ -0,0 +1,72 @@
+# Grok-1
+
+> **Note:** This is `xai-org/grok-1` — the 314B MoE language model. This fork bridges it with [`openai/grok`](../grok-main/) (the grokking research framework). See the [root README](../README.md) for context.
+
+This repository contains JAX example code for loading and running the Grok-1 open-weights model.
+
+Make sure to download the checkpoint and place the `ckpt-0` directory in `checkpoints`; see [Downloading the weights](#downloading-the-weights).
+
+Then, run
+
+```shell
+pip install -r requirements.txt
+python run.py
+```
+
+to test the code.
+
+The script loads the checkpoint and samples from the model on a test input.
+
+Due to the large size of the model (314B parameters), a machine with enough GPU memory is required to test the model with the example code.
+The implementation of the MoE layer in this repository is not efficient. The implementation was chosen to avoid the need for custom kernels to validate the correctness of the model.
+
+# Model Specifications
+
+Grok-1 is currently designed with the following specifications:
+
+- **Parameters:** 314B
+- **Architecture:** Mixture of 8 Experts (MoE)
+- **Experts Utilization:** 2 experts used per token
+- **Layers:** 64
+- **Attention Heads:** 48 for queries, 8 for keys/values
+- **Embedding Size:** 6,144
+- **Tokenization:** SentencePiece tokenizer with 131,072 tokens
+- **Additional Features:**
+ - Rotary embeddings (RoPE)
+ - Supports activation sharding and 8-bit quantization
+- **Maximum Sequence Length (context):** 8,192 tokens
+
+# Downloading the weights
+
+You can download the weights using a torrent client and this magnet link:
+
+```
+magnet:?xt=urn:btih:5f96d43576e3d386c9ba65b883210a393b68210e&tr=https%3A%2F%2Facademictorrents.com%2Fannounce.php&tr=udp%3A%2F%2Ftracker.coppersurfer.tk%3A6969&tr=udp%3A%2F%2Ftracker.opentrackr.org%3A1337%2Fannounce
+```
+
+or directly using [HuggingFace 🤗 Hub](https://huggingface.co/xai-org/grok-1):
+```
+git clone https://github.com/xai-org/grok-1.git && cd grok-1
+pip install "huggingface_hub[hf_transfer]"
+huggingface-cli download xai-org/grok-1 --repo-type model --include "ckpt-0/*" --local-dir checkpoints --local-dir-use-symlinks False
+```
+
+# Grokking Evaluation Mode
+
+This fork adds a `--eval-grokking` mode to `run.py` that evaluates Grok-1 on modular arithmetic problems from OpenAI's grokking research:
+
+```bash
+# Dry run (no checkpoint needed)
+python run.py --eval-grokking --operator + --n-samples 50 --dry-run
+
+# Full evaluation (requires checkpoint)
+python run.py --eval-grokking --operator + --n-samples 100
+```
+
+This fork also adds `TransformerConfig.to_grokking_config()` (in `model.py`), which exports Grok-1's architecture parameters scaled down to experiment size.
+
+# License
+
+The code and associated Grok-1 weights in this release are licensed under the
+Apache 2.0 license. The license only applies to the source files in this
+repository and the model weights of Grok-1.
diff --git a/grok-1-main/checkpoint.py b/grok-1-main/checkpoint.py
new file mode 100644
index 0000000..1c6e878
--- /dev/null
+++ b/grok-1-main/checkpoint.py
@@ -0,0 +1,221 @@
+# Copyright 2024 X.AI Corp.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import contextlib
+import logging
+import math
+import os
+import pickle
+import re
+import shutil
+import sys
+import tempfile
+from concurrent.futures import ThreadPoolExecutor, wait
+from typing import Any, Optional
+
+import jax
+import numpy as np
+from jax.experimental import multihost_utils
+
+from model import QuantizedWeight8bit
+
+logger = logging.getLogger(__name__)
+rank_logger = logging.getLogger("rank")
+
+# Needed for loading the checkpoint with pickle.
+sys.modules['__main__'].QuantizedWeight8bit = QuantizedWeight8bit
+
+
+@contextlib.contextmanager
+def copy_to_shm(file: str):
+ if file.startswith("/dev/shm/"):
+ # Nothing to do, the file is already in shared memory.
+ yield file
+ return
+
+ tmp_dir = "/dev/shm/"
+ fd, tmp_path = tempfile.mkstemp(dir=tmp_dir)
+ try:
+ shutil.copyfile(file, tmp_path)
+ yield tmp_path
+ finally:
+ os.remove(tmp_path)
+ os.close(fd)
+
+
+@contextlib.contextmanager
+def copy_from_shm(file: str):
+ tmp_dir = "/dev/shm/"
+ fd, tmp_path = tempfile.mkstemp(dir=tmp_dir)
+ try:
+ yield tmp_path
+ shutil.copyfile(tmp_path, file)
+ finally:
+ os.remove(tmp_path)
+ os.close(fd)
+
+
+def fast_unpickle(path: str) -> Any:
+    with copy_to_shm(path) as tmp_path:  # unpickle from the /dev/shm copy, not from `path` itself
+        with open(tmp_path, "rb") as f:
+            return pickle.load(f)  # NOTE: pickle can execute arbitrary code; load trusted files only
+
+
+def fast_pickle(obj: Any, path: str) -> None:
+ with copy_from_shm(path) as tmp_path:
+ with open(tmp_path, "wb") as f:
+ pickle.dump(obj, f)
+
+
+def load_tensors(shaped_arrays, directory, mesh_config, tensor_indices=None):
+    """Load one array per entry of ``shaped_arrays``, in order, using a thread pool.
+
+    Entry ``i`` is unpickled from ``directory/tensor{i:05d}_{idx:03d}`` (``idx`` is
+    derived from ``jax.process_index()`` and ``math.prod(mesh_config)``); ``i`` comes
+    from ``tensor_indices`` when given, otherwise from the position in the list.
+    """
+    num_replicas = 1
+    data_model_shards = math.prod(mesh_config)
+    process = jax.process_index()
+    replica = (process // data_model_shards) % num_replicas
+    idx = process // (num_replicas * data_model_shards) * data_model_shards + process % data_model_shards
+    iterator = enumerate(shaped_arrays) if tensor_indices is None else zip(tensor_indices, shaped_arrays)
+    fs = []
+    # The ``with`` block shuts the executor down on exit (and on error), so worker
+    # threads are not leaked; the original never called ``shutdown()``.
+    with ThreadPoolExecutor(max_workers=32) as pool:
+        for i, t in iterator:
+            if (i % num_replicas) == replica:
+                path = os.path.join(directory, f"tensor{i:05d}_{idx:03d}")
+                fs.append(pool.submit(fast_unpickle, path))
+            else:
+                # Not owned by this process' replica: zero placeholder of the same shape.
+                fs.append(pool.submit(np.zeros, t.shape, dtype=t.dtype))
+    return [f.result() for f in fs]
+
+
+def path_tuple_to_string(path: tuple) -> str:
+ pieces = []
+ for elem in path:
+ if isinstance(elem, jax.tree_util.DictKey):
+ pieces.append(elem.key)
+ elif isinstance(elem, jax.tree_util.GetAttrKey):
+ pieces.append(elem.name)
+ else:
+ assert isinstance(elem, (jax.tree_util.FlattenedIndexKey, jax.tree_util.SequenceKey))
+ return "/".join(pieces)
+
+
+def get_load_path_str(
+ init_path_str: str,
+ load_rename_rules: Optional[list[tuple[str, str]]] = None,
+ load_exclude_rules: Optional[list[str]] = None,
+) -> Optional[str]:
+ # Exclusion
+ if load_exclude_rules is not None:
+ for search_pattern in load_exclude_rules:
+ if re.search(search_pattern, init_path_str):
+ return None
+
+ # Renaming
+ load_path_str = init_path_str
+ if load_rename_rules is not None:
+ for search_pattern, replacement_pattern in load_rename_rules:
+ if re.search(search_pattern, load_path_str):
+ load_path_str = re.sub(search_pattern, replacement_pattern, load_path_str)
+ break
+
+ return load_path_str
+
+
+def replace_with_load_state(
+ init_state: Any,
+ load_state: Any,
+ load_rename_rules: Optional[list[tuple[str, str]]] = None,
+ load_exclude_rules: Optional[list[str]] = None,
+ mesh_config: tuple = (1, 1),
+) -> Any:
+ flatten_load, _ = jax.tree_util.tree_flatten_with_path(load_state)
+ flatten_init, structure_init = jax.tree_util.tree_flatten_with_path(init_state)
+ load_map = {path_tuple_to_string(path): tensor for path, tensor in flatten_load}
+
+ replaced = []
+ num_replicas = 1
+ data_model_shards = math.prod(mesh_config)
+ for i, (init_path, tensor) in enumerate(flatten_init):
+ init_path_str = path_tuple_to_string(init_path)
+ load_path_str = get_load_path_str(init_path_str, load_rename_rules, load_exclude_rules)
+ if load_path_str is None:
+ rank_logger.info(f"Excluded from restore: {init_path_str}.")
+ replaced.append(tensor)
+ elif load_path_str in load_map:
+ if load_path_str == init_path_str:
+ rank_logger.info(f"Restored from ckpt: {init_path_str}.")
+ else:
+ rank_logger.info(f"Restored from ckpt: {init_path_str} <-- {load_path_str}.")
+ replaced.append(load_map[load_path_str])
+ else:
+ rank_logger.info(f"Not found in ckpt: {init_path_str}.")
+ if (i % num_replicas) == ((jax.process_index() // data_model_shards) % num_replicas):
+ replaced.append(tensor)
+ else:
+ replaced.append(np.zeros_like(tensor))
+
+ return jax.tree_util.tree_unflatten(structure_init, replaced)
+
+
+def restore(
+ checkpoint_path: str,
+ state_shapes: Any,
+ mesh,
+ between_hosts_config,
+ params_only,
+ state_sharding,
+ init_state: Optional[Any] = None,
+) -> Any:
+ ckpt_path = os.path.join(checkpoint_path, "ckpt-0")
+
+ rank_logger.info("Loading checkpoint at {}".format(ckpt_path))
+ ckpt_shapes = state_shapes
+ ckpt_shapes_with_path, structure = jax.tree_util.tree_flatten_with_path(ckpt_shapes)
+
+ ckpt_shapes_flat = [elem[1] for elem in ckpt_shapes_with_path]
+ loaded_tensors = load_tensors(ckpt_shapes_flat, ckpt_path, between_hosts_config)
+
+ state = jax.tree_util.tree_unflatten(structure, loaded_tensors)
+
+ # Sanity check to give a better error message.
+ ckpt_keys = set(state.params.keys())
+ code_keys = set(state_sharding.params.keys())
+
+ if ckpt_keys != code_keys and init_state is None:
+ missing_in_ckpt = code_keys - ckpt_keys
+ missing_locally = ckpt_keys - code_keys
+ raise ValueError(
+ "Parameters in the code are not matching checkpoint parameters.\n"
+ "Params missing in checkpoint: {}\nParams missing in code: {}".format(
+ missing_in_ckpt, missing_locally
+ )
+ )
+ state_sharding = jax.tree_util.tree_map(
+ lambda x: jax.sharding.PartitionSpec() if x is None else x,
+ state_sharding,
+ is_leaf=lambda x: x is None,
+ )
+ state = multihost_utils.host_local_array_to_global_array(state, mesh, state_sharding)
+ if params_only:
+ state = state.params
+ return state
diff --git a/grok-1-main/checkpoints/README.md b/grok-1-main/checkpoints/README.md
new file mode 100644
index 0000000..fc34b62
--- /dev/null
+++ b/grok-1-main/checkpoints/README.md
@@ -0,0 +1,3 @@
+# Checkpoint directory
+
+Place Grok-1 checkpoints here so they can be loaded by the example script.
diff --git a/grok-1-main/model.py b/grok-1-main/model.py
new file mode 100644
index 0000000..4869808
--- /dev/null
+++ b/grok-1-main/model.py
@@ -0,0 +1,1462 @@
+# Copyright 2024 X.AI Corp.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import functools
+import logging
+import re
+from dataclasses import dataclass
+from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
+
+import haiku as hk
+import jax
+import jax.experimental.maps
+import jax.numpy as jnp
+from jax import config, tree_util
+from jax.experimental.shard_map import shard_map
+from jax.lax import with_sharding_constraint as pjit_sharding_constraint
+from jax.sharding import PartitionSpec
+from jax.sharding import PartitionSpec as P
+
+config.update("jax_spmd_mode", "allow_all")
+
+logger = logging.getLogger(__name__)
+rank_logger = logging.getLogger("rank")
+
+
+@dataclass
+class QuantizedWeight8bit:
+ weight: jnp.array
+ scales: jnp.array
+
+ @property
+ def shape(self):
+ return self.weight.shape
+
+
+tree_util.register_pytree_node(
+ QuantizedWeight8bit,
+ lambda qw: ([qw.weight, qw.scales], ()),
+ lambda _, children: QuantizedWeight8bit(children[0], children[1]),
+)
+
+
+class TrainingState(NamedTuple):
+ """Container for the training state."""
+
+ params: hk.Params
+
+
+def _match(qs, ks):
+ """Return True if regexes in qs match any window of strings in tuple ks."""
+ # compile regexes and force complete match
+ qts = tuple(map(lambda x: re.compile(x + "$"), qs))
+ for i in range(len(ks) - len(qs) + 1):
+ matches = [x.match(y) for x, y in zip(qts, ks[i:])]
+ if matches and all(matches):
+ return True
+ return False
+
+
+def with_sharding_constraint(x, constraint):
+ if jax.experimental.maps.thread_resources.env.physical_mesh.empty:
+ return x
+ else:
+ return pjit_sharding_constraint(x, constraint)
+
+
+def cast_bfloat16(x):
+ if x.dtype.kind == "f":
+ return x.astype(jnp.bfloat16)
+ else:
+ return x
+
+
+def ffn_size(emb_size, widening_factor):
+ _ffn_size = int(widening_factor * emb_size) * 2 // 3
+ _ffn_size = _ffn_size + (8 - _ffn_size) % 8 # ensure it's a multiple of 8
+ logger.debug(f"emd_size: {emb_size} adjusted ffn_size: {_ffn_size}")
+ return _ffn_size
+
+
+def apply_rules(rules):
+ def _apply_rules(path, value):
+ del value # Unused.
+
+ path_list = [str(i.key).split("/") for i in path if isinstance(i, jax.tree_util.DictKey)]
+ flattened_path = jax.tree_util.tree_flatten(path_list)[0]
+
+ for rule, replacement in rules:
+ if _match(rule, flattened_path):
+ if isinstance(replacement, PartitionSpec):
+ if "layer_stack" in flattened_path:
+ replacement = PartitionSpec(None, *replacement)
+ rank_logger.debug(f"Apply {replacement} to {flattened_path} with rule {rule}")
+ return replacement
+ rank_logger.info(f"{flattened_path} no matching found!")
+ return None
+
+ return _apply_rules
+
+
+TRANSFORMER_PARTITION_RULES = [
+ # attention
+ (("multi_head_attention", "(query|key|value)", "w"), P("data", "model")),
+ (("multi_head_attention", "(query|key|value)", "b"), P(None)),
+ (("multi_head_attention", "linear", "w"), P("model", "data")),
+ (("multi_head_attention", "linear", "b"), P(None)),
+ # mlp
+ ((r"decoder_layer_[0-9]+", "linear", "w"), P("data", "model")),
+ ((r"decoder_layer_[0-9]+", "linear", "b"), P(None)),
+ ((r"decoder_layer_[0-9]+", "linear_v", "w"), P("data", "model")),
+ ((r"decoder_layer_[0-9]+", "linear_v", "b"), P(None)),
+ (
+ (r"decoder_layer_[0-9]+", "linear_1", "w"),
+ P(
+ "model",
+ "data",
+ ),
+ ),
+ ((r"decoder_layer_[0-9]+", "linear_1", "b"), P(None)),
+ # layer norms
+ ((r"decoder_layer_[0-9]+", "layer_norm", "offset"), P(None)),
+ ((r"decoder_layer_[0-9]+", "layer_norm", "scale"), P(None)),
+ ((r"decoder_layer_[0-9]+", "layer_norm_1", "offset"), P(None)),
+ ((r"decoder_layer_[0-9]+", "layer_norm_1", "scale"), P(None)),
+ # rms norms
+ ((r"decoder_layer_[0-9]+", "rms_norm", "scale"), P(None)),
+ ((r"decoder_layer_[0-9]+", "rms_norm_1", "scale"), P(None)),
+ ((r"decoder_layer_[0-9]+", "rms_norm_2", "scale"), P(None)),
+ ((r"decoder_layer_[0-9]+", "rms_norm_3", "scale"), P(None)),
+ # router
+ (("router", "w"), P("data")),
+ # moe mlp
+ (("moe", "linear", "w"), P(None, "data", "model")),
+ (("moe", "linear", "b"), P(None)),
+ (("moe", "linear_v", "w"), P(None, "data", "model")),
+ (("moe", "linear_v", "b"), P(None)),
+ (("moe", "linear_1", "w"), P(None, "model", "data")),
+ (("moe", "linear_1", "b"), P(None)),
+ # layer norms
+ (("moe", "layer_norm", "offset"), P(None)),
+ (("moe", "layer_norm", "scale"), P(None)),
+ (("moe", "layer_norm_1", "offset"), P(None)),
+ (("moe", "layer_norm_1", "scale"), P(None)),
+ # rms norms
+ (("moe", "rms_norm", "scale"), P(None)),
+ (("moe", "rms_norm_1", "scale"), P(None)),
+ (("moe", "rms_norm_2", "scale"), P(None)),
+ (("moe", "rms_norm_3", "scale"), P(None)),
+]
+
+LM_PARTITION_RULES = [
+ # Embedding layer.
+ (
+ ("language_model", "positional_embeddings"),
+ P(None, ("data", "model")),
+ ),
+ (
+ ("language_model", "in_out_embed", "embeddings"),
+ P(None, ("data", "model")),
+ ),
+ # Final RMSNorm.
+ (("language_model", "rms_norm"), P(None)),
+]
+TOP_K = 8
+
+
+class KVMemory(NamedTuple):
+ k: Optional[jax.Array]
+ v: Optional[jax.Array]
+ step: Optional[jax.Array]
+
+
+def init_layer_memories(
+ batch_size: int,
+ sequence_len: int,
+ num_kv_heads: int,
+ key_size: int,
+ num_layers: int,
+ step: Optional[jax.Array] = None,
+ dtype=jnp.bfloat16,
+):
+ return [
+ KVMemory(
+ k=jnp.zeros((batch_size, sequence_len, num_kv_heads, key_size), dtype=dtype),
+ v=jnp.zeros((batch_size, sequence_len, num_kv_heads, key_size), dtype=dtype),
+ step=step,
+ )
+ for _ in range(num_layers)
+ ]
+
+
+class Memory(NamedTuple):
+ # Self-attention key/value cache.
+ layers: List[KVMemory]
+
+
+class Router(hk.Module):
+ def __init__(
+ self,
+ num_selected_experts: int,
+ data_axis: Union[str, Tuple[str, ...]] = "data",
+ model_axis: Union[str, Tuple[str, ...]] = "model",
+ shard_activations: bool = False,
+ mesh: Any = None,
+ name: str = "router",
+ ):
+ super().__init__(name)
+ self.shard_activations = shard_activations
+ self.data_axis = data_axis
+ self.model_axis = model_axis
+ self.mesh = mesh
+ self.num_selected_experts = num_selected_experts
+
+ def compute_routing_prob(
+ self, inputs: jax.Array, padding_mask: Optional[jax.Array], num_experts: int
+ ):
+ return self._compute_routing_prob(inputs, padding_mask, num_experts)
+
+ @hk.transparent
+ def _compute_routing_prob(
+ self,
+ inputs: jax.Array,
+ padding_mask: Optional[jax.Array],
+ num_experts: int,
+ ):
+ # Using fp32 for the routing prob computation.
+ inputs = jax.lax.convert_element_type(inputs, jnp.float32)
+
+ # [batch_size, seq_len, num_experts]
+ routing_logits = self._router_weights(inputs, num_experts, sharding=P("data"))
+ assert routing_logits.dtype == jnp.float32
+ routing_probs = jax.nn.softmax(routing_logits)
+
+ if padding_mask is not None:
+ routing_probs *= padding_mask
+
+ return routing_probs, routing_logits, 0
+
+ @hk.transparent
+ def _router_weights(
+ self,
+ x: jax.Array,
+ num_experts: int,
+ sharding: Optional[P] = None,
+ ):
+ fprop_dtype = x.dtype
+ if not x.shape:
+ raise ValueError("Input must not be scalar.")
+
+ input_size = self.input_size = x.shape[-1]
+ w = hk.get_parameter(
+ "w", [input_size, num_experts], jnp.float32, init=hk.initializers.Constant(0)
+ )
+ if sharding:
+ w = with_sharding_constraint(w, sharding)
+
+ out = jnp.dot(x, w.astype(fprop_dtype))
+ return out
+
+
+class MoELayer(hk.Module):
+    """Mixture-of-experts layer; only an inference path over quantized weights is implemented."""
+
+    def __init__(
+        self,
+        num_experts: int,
+        layer_fn: Callable,
+        router: Router,
+        mesh: Any = None,
+        shard_activations: bool = False,
+        data_axis: Union[str, Tuple[str, ...]] = "data",
+        model_axis: Union[str, Tuple[str, ...]] = "model",
+        name: Optional[str] = "moe",
+    ):
+        super().__init__(name)
+        self.num_experts = num_experts
+        self.layer_fn = layer_fn
+        self.router = router
+        self.mesh = mesh
+        self.shard_activations = shard_activations
+        self.data_axis = data_axis
+        self.model_axis = model_axis
+
+    @hk.transparent
+    def _inference_call(self, inputs: jax.Array, padding_mask: Optional[jax.Array] = None):
+        """Route each token to its top-k experts and return the gate-weighted sum of their outputs."""
+        routing_probs, _, _ = self.router.compute_routing_prob(inputs, padding_mask, self.num_experts)
+        expert_gate, expert_index = jax.lax.top_k(routing_probs, k=self.router.num_selected_experts)
+        tmp = jnp.reshape(inputs, (inputs.shape[0] * inputs.shape[1], inputs.shape[2]))
+        broad_inputs = jnp.tile(tmp[:, jnp.newaxis, :], (1, self.router.num_selected_experts, 1))
+        broad_inputs = jnp.reshape(
+            broad_inputs, (broad_inputs.shape[0] * broad_inputs.shape[1], broad_inputs.shape[2])
+        )
+        init_fn, _ = hk.transform(self.layer_fn)
+        vmapped_init_fn = jax.vmap(init_fn, in_axes=0, out_axes=0)
+        lifted_init_fn = hk.experimental.transparent_lift(vmapped_init_fn)
+        # Fetch the vmapped params of the DenseBlock.
+        params = lifted_init_fn(
+            jax.random.split(jax.random.PRNGKey(1), self.num_experts),
+            jnp.zeros((self.num_experts, 1, 1, inputs.shape[-1])),
+        )
+
+        # Index and prob are in the shape [m, 2] indicating which token assigned to which experts.
+        # b: num_expert
+        # m: token or sequence dim
+        # k: input embed dim
+        # n: output embed dim
+        # e: the number of experts chosen for each token
+        # The one-hot depth below must equal the leading (expert) axis ``b`` of the weights;
+        # it was hard-coded to 8, which only matched configs with ``num_experts == 8``.
+        @functools.partial(
+            shard_map,
+            mesh=self.mesh,
+            in_specs=(
+                P(self.data_axis, None),
+                P(None, None, self.model_axis),
+                P(None, None, self.model_axis),
+                P(None),
+                P(None),
+            ),
+            out_specs=P(self.data_axis, self.model_axis),
+            check_rep=False,
+        )
+        def moe_slow_matmul1(input, weight, scales, index, prob):
+            weight = weight * scales
+            one_hot_indices = jax.nn.one_hot(index.reshape(-1), self.num_experts, axis=0)
+            all_expert_output = jnp.einsum("mk,bkn->bmn", input, weight)
+            output = jnp.einsum("bm,bmn->mn", one_hot_indices, all_expert_output)
+            return output
+
+        @functools.partial(
+            shard_map,
+            mesh=self.mesh,
+            in_specs=(
+                P(self.data_axis, self.model_axis),
+                P(None, self.model_axis, None),
+                P(None, self.model_axis, None),
+                P(None),
+                P(None),
+            ),
+            out_specs=P(self.data_axis, None),
+            check_rep=False,
+        )
+        def moe_slow_matmul2(input, weight, scales, index, prob):
+            weight = weight * scales
+            one_hot_indices = jax.nn.one_hot(index.reshape(-1), self.num_experts, axis=0)
+            all_expert_output = jnp.einsum("mk,bkn->bmn", input, weight)
+            output = jnp.einsum("bm,bmn->mn", one_hot_indices, all_expert_output)
+            return jax.lax.psum(output, axis_name=self.model_axis)  # reduce over the axis `k` is sharded on
+
+        if hasattr(params["linear"]["w"], "scales"):
+            x = moe_slow_matmul1(
+                broad_inputs,
+                params["linear_v"]["w"].weight,
+                params["linear_v"]["w"].scales,
+                expert_index,
+                expert_gate,
+            )
+            y = moe_slow_matmul1(
+                broad_inputs,
+                params["linear"]["w"].weight,
+                params["linear"]["w"].scales,
+                expert_index,
+                expert_gate,
+            )
+            y = jax.nn.gelu(y)
+            out = moe_slow_matmul2(
+                x * y,
+                params["linear_1"]["w"].weight,
+                params["linear_1"]["w"].scales,
+                expert_index,
+                expert_gate,
+            )
+            out = jnp.reshape(
+                out, [inputs.shape[0], inputs.shape[1], self.router.num_selected_experts, out.shape[-1]]
+            )
+            out = expert_gate[:, :, :, None].astype(jnp.bfloat16) * out
+            out = jnp.sum(out, axis=2)
+            out = out.astype(jnp.bfloat16)
+        else:
+            # This is only here so that we can construct a valid init_fn with this code.
+            return inputs
+        return out
+
+    def __call__(self, inputs: jax.Array, padding_mask: jax.Array):
+        """Run the inference path on ``inputs``."""
+        # NOTE(review): ``padding_mask`` is accepted but not forwarded to
+        # ``_inference_call`` -- confirm this is intended.
+        return self._inference_call(inputs)
+
+
+class MHAOutput(NamedTuple):
+ """Outputs of the multi-head attention operation."""
+
+ embeddings: jax.Array
+ memory: Any
+
+
+class DecoderOutput(NamedTuple):
+ embeddings: jax.Array
+ memory: Any
+
+
+class TransformerOutput(NamedTuple):
+ embeddings: jax.Array
+ memory: Any
+
+
+@dataclass
+class TransformerConfig:
+    """Hyper-parameters of the Grok-1 transformer, with export helpers for the grokking bridge."""
+
+    emb_size: int
+    key_size: int
+    num_q_heads: int
+    num_kv_heads: int
+    num_layers: int
+    vocab_size: int = 128 * 1024
+    widening_factor: float = 4.0
+
+    attn_output_multiplier: float = 1.0
+
+    name: Optional[str] = None
+
+    num_experts: int = -1
+    capacity_factor: float = 1.0
+    num_selected_experts: int = 1
+
+    init_scale: float = 1.0
+    shard_activations: bool = False
+
+    # Used for activation sharding.
+    data_axis: Union[str, Tuple[str, ...]] = "data"
+    model_axis: Union[str, Tuple[str, ...]] = "model"
+
+    def __post_init__(self):
+        if isinstance(self.data_axis, list):
+            self.data_axis = tuple(self.data_axis)
+        if isinstance(self.model_axis, list):
+            self.model_axis = tuple(self.model_axis)
+
+    def partition_rules(self):
+        """Return the module-level ``TRANSFORMER_PARTITION_RULES`` list."""
+        return TRANSFORMER_PARTITION_RULES
+
+    def make(self, mesh=None) -> "Transformer":
+        """Instantiate the ``Transformer`` described by this config on ``mesh``."""
+        data_axis = tuple(self.data_axis) if isinstance(self.data_axis, list) else self.data_axis
+        model_axis = tuple(self.model_axis) if isinstance(self.model_axis, list) else self.model_axis
+        return Transformer(
+            num_q_heads=self.num_q_heads,
+            num_kv_heads=self.num_kv_heads,
+            widening_factor=self.widening_factor,
+            key_size=self.key_size,
+            init_scale=self.init_scale,
+            mesh=mesh,
+            attn_output_multiplier=self.attn_output_multiplier,
+            shard_activations=self.shard_activations,
+            num_layers=self.num_layers,
+            num_experts=self.num_experts,
+            num_selected_experts=self.num_selected_experts,
+            data_axis=data_axis,
+            model_axis=model_axis,
+        )
+
+    def get_memory_sharding(self):
+        """Return a ``Memory`` holding one ``KVMemory`` of partition specs per layer."""
+        return Memory(
+            layers=[
+                KVMemory(
+                    k=P(self.data_axis, self.model_axis),
+                    v=P(self.data_axis, self.model_axis),
+                    step=P(self.data_axis),
+                )
+                for _ in range(self.num_layers)
+            ],
+        )
+
+    # -- Bridge ("forced compatibility") --------------------------
+    # Exports Grok-1's architecture parameters in a format compatible
+    # with the OpenAI grokking framework (PyTorch).
+    # Source: https://github.com/openai/grok
+    # -------------------------------------------------------------
+
+    def to_grokking_config(self, scale_factor: float = 1/24) -> dict:
+        """Export a scaled-down copy of this config for grokking experiments.
+
+        ``emb_size``, ``num_q_heads`` and ``num_layers`` are multiplied by
+        ``scale_factor`` and truncated, with floors of 64, 2 and 2; ``d_model``
+        is then rounded down to a multiple of ``n_heads``. Expert counts are
+        copied unchanged and ``widening_factor`` is truncated to ``int``.
+
+        The result is meant for ``GrokOneTransformer.__init__()`` in
+        grok-main/grok/transformer.py; the ``_``-prefixed keys are provenance
+        metadata (NOTE(review): confirm that constructor accepts them).
+
+        Example: ``TransformerConfig(6144, 128, 48, 8, 64).to_grokking_config()``
+        gives ``n_layers=2``, ``n_heads=2`` and ``d_model=256``.
+        """
+        d_model = max(64, int(self.emb_size * scale_factor))
+        n_heads = max(2, int(self.num_q_heads * scale_factor))
+        d_model = d_model - (d_model % n_heads)  # ensure divisibility
+        n_layers = max(2, int(self.num_layers * scale_factor))
+
+        return {
+            "n_layers": n_layers,
+            "n_heads": n_heads,
+            "d_model": d_model,
+            "num_experts": self.num_experts,
+            "num_selected_experts": self.num_selected_experts,
+            "widening_factor": int(self.widening_factor),
+            # Metadata for provenance tracking
+            "_source": "grok-1",
+            "_original_emb_size": self.emb_size,
+            "_original_num_layers": self.num_layers,
+            "_original_num_q_heads": self.num_q_heads,
+            "_scale_factor": scale_factor,
+        }
+
+    def architecture_summary(self) -> str:
+        """Human-readable summary of this config, including an approximate parameter count."""
+        ffn = int(self.widening_factor * self.emb_size) * 2 // 3
+        ffn += (8 - ffn) % 8  # same round-up to a multiple of 8 as ``ffn_size``
+        # Q and output projections use ``num_q_heads``; K and V use ``num_kv_heads``.
+        attention = 2 * self.emb_size * self.key_size * (self.num_q_heads + self.num_kv_heads)
+        # ``num_experts`` defaults to -1: count one FFN then, never a negative number of them.
+        feed_forward = max(1, self.num_experts) * 3 * self.emb_size * ffn
+        total_params_approx = self.emb_size * self.vocab_size + self.num_layers * (attention + feed_forward)
+        return (
+            f"Grok-1 TransformerConfig\n"
+            f" Layers: {self.num_layers}\n"
+            f" Embedding: {self.emb_size}\n"
+            f" Q-Heads: {self.num_q_heads}\n"
+            f" KV-Heads: {self.num_kv_heads}\n"
+            f" Key size: {self.key_size}\n"
+            f" Experts: {self.num_experts} (top-{self.num_selected_experts})\n"
+            f" Widening: {self.widening_factor}x\n"
+            f" ~Params: {total_params_approx/1e9:.1f}B\n"
+        )
+
+
+def hk_rms_norm(
+    x: jax.Array,
+    fixed_scale=False,
+    sharding=P(None),
+) -> jax.Array:
+    """Applies a freshly constructed RMSNorm (last axis; no scale parameter if `fixed_scale`) to x."""
+    ln = RMSNorm(axis=-1, create_scale=not fixed_scale, sharding=sharding)
+    return ln(x)
+
+
+def make_attention_mask(
+ query_input: jax.Array,
+ key_input: jax.Array,
+ pairwise_fn: Callable[..., Any] = jnp.multiply,
+ dtype: Any = jnp.bfloat16,
+):
+ """Mask-making helper for attention weights.
+
+ In case of 1d inputs (i.e., `[batch..., len_q]`, `[batch..., len_kv]`, the
+ attention weights will be `[batch..., heads, len_q, len_kv]` and this
+ function will produce `[batch..., 1, len_q, len_kv]`.
+
+ Args:
+ query_input: a batched, flat input of query_length size
+ key_input: a batched, flat input of key_length size
+ pairwise_fn: broadcasting elementwise comparison function
+ dtype: mask return dtype
+
+ Returns:
+ A `[batch..., 1, len_q, len_kv]` shaped mask for 1d attention.
+ """
+ mask = pairwise_fn(jnp.expand_dims(query_input, axis=-1), jnp.expand_dims(key_input, axis=-2))
+ mask = jnp.expand_dims(mask, axis=-3)
+ return mask.astype(dtype)
+
+
+class Linear(hk.Linear):
+ def __init__(
+ self,
+ output_size: int,
+ with_bias: bool = True,
+ sharding: Optional[P] = None,
+ mesh: Any = None,
+ name: Optional[str] = None,
+ shard_axis: int = 0,
+ ):
+ super().__init__(
+ output_size=output_size,
+ with_bias=with_bias,
+ name=name,
+ )
+ self.sharding = sharding
+ self.mesh = mesh
+ self.shard_axis = shard_axis
+
+ def __call__(
+ self,
+ inputs: jax.Array,
+ ) -> jax.Array:
+ """Computes a linear transform of the input."""
+
+ fprop_dtype = inputs.dtype
+ if not inputs.shape:
+ raise ValueError("Input must not be scalar.")
+
+ input_size = self.input_size = inputs.shape[-1]
+ output_size = self.output_size
+
+ w = hk.get_parameter(
+ "w", [input_size, output_size], jnp.float32, init=hk.initializers.Constant(0)
+ )
+
+ if hasattr(w, "scales"):
+ shape = inputs.shape
+ inputs = jnp.reshape(inputs, (-1, shape[-1]))
+
+ @functools.partial(
+ shard_map,
+ mesh=self.mesh,
+ in_specs=(self.sharding, self.sharding),
+ out_specs=self.sharding,
+ check_rep=False,
+ )
+ def mul(w, s):
+ return w.astype(s.dtype) * s
+
+ w = mul(w.weight, w.scales)
+ out = jnp.dot(inputs, w.astype(fprop_dtype))
+ if self.with_bias:
+ b = hk.get_parameter(
+ "b", [self.output_size], jnp.float32, init=hk.initializers.Constant(0)
+ )
+ b = jnp.broadcast_to(b, out.shape)
+ out = out + b.astype(fprop_dtype)
+
+ return out
+
+
+class RMSNorm(hk.RMSNorm):
+    """RMS normalisation over the last axis, computed in float32.
+
+    Args:
+        axis, eps, name, create_scale: forwarded to ``hk.RMSNorm.__init__``.
+        sharding: optional ``PartitionSpec`` constraint applied to the ``scale`` parameter.
+    """
+
+    def __init__(
+        self,
+        axis: Union[int, Sequence[int], slice],
+        eps: float = 1e-5,
+        name: Optional[str] = None,
+        create_scale: bool = True,
+        sharding: Optional[P] = None,
+    ):
+        super().__init__(axis, eps, create_scale=create_scale, name=name)
+        self.sharding = sharding
+
+    def __call__(self, inputs: jax.Array):
+        """Return ``scale * inputs * rsqrt(mean(inputs**2, -1) + eps)`` in the input dtype."""
+        fprop_dtype = inputs.dtype
+        if self.create_scale:
+            init = hk.initializers.Constant(0)
+            scale = hk.get_parameter("scale", (inputs.shape[-1],), dtype=jnp.float32, init=init)
+            if self.sharding:
+                scale = with_sharding_constraint(scale, self.sharding)
+            scale = jnp.broadcast_to(scale.astype(jnp.float32), inputs.shape)
+        else:
+            # Was the Python float ``1.0``, which crashed on a later ``scale.astype`` call
+            # whenever ``create_scale=False`` (e.g. ``hk_rms_norm(x, fixed_scale=True)``).
+            scale = jnp.ones((), dtype=jnp.float32)
+        inputs = inputs.astype(jnp.float32)
+        mean_squared = jnp.mean(jnp.square(inputs), axis=[-1], keepdims=True)
+        mean_squared = jnp.broadcast_to(mean_squared, inputs.shape)
+        normed_inputs = inputs * jax.lax.rsqrt(mean_squared + self.eps)
+        outputs = scale * normed_inputs
+        return outputs.astype(fprop_dtype)
+
+
+def rotate_half(
+ x: jax.Array,
+) -> jax.Array:
+ """Obtain the rotated counterpart of each feature"""
+ x1, x2 = jnp.split(x, 2, axis=-1)
+ return jnp.concatenate((-x2, x1), axis=-1)
+
+
+class RotaryEmbedding(hk.Module):
+    """Applies rotary embeddings (RoPE) to the input sequence tensor,
+    as described in https://arxiv.org/abs/2104.09864.
+
+    Attributes:
+        dim (int): Dimensionality of the feature vectors
+        base_exponent (int): Base exponent to compute embeddings from
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        name: Optional[str] = None,
+        base_exponent: int = 10000,
+    ):
+        super().__init__(name)
+        self.dim = dim
+        self.base_exponent = base_exponent
+        assert self.dim % 2 == 0
+
+    def __call__(
+        self,
+        x: jax.Array,
+        seq_dim: int,
+        offset: jax.Array,
+        const_position: Optional[int] = None,
+        t: Optional[jax.Array] = None,
+    ) -> jax.Array:
+        """Rotate ``x`` by position-dependent phases and return it in ``x``'s dtype.
+
+        Args:
+            x: input whose ``seq_dim`` axis is the sequence axis.
+            seq_dim: index of the sequence axis of ``x``.
+            offset: scalar or per-batch-element offset added to the positions.
+            const_position: if not None, every position uses this value (0 included).
+            t: explicit positions; used only when ``const_position`` is None.
+        """
+        fprop_dtype = x.dtype
+        # Compute the per-dimension frequencies
+        exponents = jnp.arange(0, self.dim, 2, dtype=jnp.float32)
+        inv_freq = jnp.asarray(1.0 / (self.base_exponent ** (exponents / self.dim)), dtype=jnp.float32)
+
+        if jnp.shape(offset) == ():
+            # Offset can be a scalar or one offset per batch element.
+            offset = jnp.expand_dims(offset, 0)
+
+        # Compute the per element phase (to pass into sin and cos).
+        # ``is not None`` rather than truthiness: position 0 is a valid constant
+        # position and used to fall through to the ``arange`` branch.
+        if const_position is not None:
+            t = const_position * jnp.ones((1, x.shape[seq_dim]), dtype=jnp.float32)
+        elif t is None:
+            t = jnp.arange(x.shape[seq_dim], dtype=jnp.float32) + jnp.expand_dims(offset, -1)
+        phase = jnp.einsum("bi,j->bij", t, inv_freq)
+        phase = jnp.tile(phase, reps=(1, 2))[:, :, None, :]
+        x = x * jnp.cos(phase) + rotate_half(x) * jnp.sin(phase)
+        return x.astype(fprop_dtype)
+
+
+class MultiHeadAttention(hk.Module):
+ def __init__(
+ self,
+ num_q_heads: int,
+ num_kv_heads: int,
+ key_size: int,
+ *,
+ with_bias: bool = True,
+ value_size: Optional[int] = None,
+ model_size: Optional[int] = None,
+ attn_output_multiplier: 1.0,
+ data_axis: Union[str, Tuple[str, ...]] = "data",
+ model_axis: Union[str, Tuple[str, ...]] = "model",
+ name: Optional[str] = None,
+ ):
+ super().__init__(name=name)
+ self.num_q_heads = num_q_heads
+ self.num_kv_heads = num_kv_heads
+ self.key_size = key_size
+ self.value_size = value_size or key_size
+ self.model_size = model_size or key_size * num_q_heads
+ self.data_axis = data_axis
+ self.model_axis = model_axis
+ self.attn_output_multiplier = attn_output_multiplier
+ self.with_bias = with_bias
+
+ def __call__(
+ self,
+ query: jax.Array,
+ key: Optional[jax.Array],
+ value: Optional[jax.Array],
+ mask: Optional[jax.Array] = None,
+ kv_memory: Optional[KVMemory] = None,
+ mesh: Any = None,
+ ) -> MHAOutput:
+ # In shape hints below, we suppress the leading dims [...] for brevity.
+ # Hence e.g. [A, B] should be read in every case as [..., A, B].
+ sequence_length = query.shape[1]
+ projection = self._linear_projection
+ use_memory = False
+ if kv_memory is not None:
+ if kv_memory.k is None:
+ assert kv_memory.v is None
+ assert key is not None
+ assert value is not None
+ else:
+ assert kv_memory.v is not None
+ use_memory = True
+ else:
+ assert key is not None
+ assert value is not None
+
+ # Check that the keys and values have consistent batch size and sequence length.
+ if not use_memory:
+ assert key.shape[:2] == value.shape[:2], f"key/value shape: {key.shape}/{value.shape}"
+
+ if mask is not None:
+ assert mask.ndim == 4
+ assert mask.shape[0] in {
+ 1,
+ query.shape[0],
+ }, f"mask/query shape: {mask.shape}/{query.shape}"
+ if not use_memory:
+ assert key.shape[0] in {
+ 1,
+ query.shape[0],
+ }, f"key/query shape: {key.shape}/{query.shape}"
+ assert mask.shape[1] == 1
+ assert mask.shape[2] in {
+ 1,
+ query.shape[1],
+ }, f"mask/query shape: {mask.shape}/{query.shape}"
+ if not use_memory:
+ assert mask.shape[3] in {
+ 1,
+ key.shape[1],
+ }, f"mask/query shape: {mask.shape}/{key.shape}"
+
+ # Compute key/query/values (overload K/Q/V to denote the respective sizes).
+ assert self.num_q_heads % self.num_kv_heads == 0
+ query_heads = projection(
+ query,
+ self.key_size,
+ self.num_q_heads,
+ name="query",
+ sharding=P("data", "model"),
+ mesh=mesh,
+ ) # [B, T', H, Q=K]
+
+ new_memory = None
+ key_heads = projection(
+ key,
+ self.key_size,
+ self.num_kv_heads,
+ name="key",
+ sharding=P("data", "model"),
+ mesh=mesh,
+ ) # [B, T, H, K]
+ value_heads = projection(
+ value,
+ self.value_size,
+ self.num_kv_heads,
+ name="value",
+ sharding=P("data", "model"),
+ mesh=mesh,
+ ) # [B, T, H, V]
+
+ rotate = RotaryEmbedding(dim=self.key_size, base_exponent=int(1e4))
+ key_heads = rotate(key_heads, seq_dim=1, offset=(kv_memory.step if kv_memory else 0))
+ query_heads = rotate(query_heads, seq_dim=1, offset=(kv_memory.step if kv_memory else 0))
+
+ @functools.partial(jax.vmap)
+ def update_into(mem, start, update):
+ return jax.lax.dynamic_update_slice_in_dim(mem, update, start, axis=0)
+
+ if kv_memory:
+ if mesh is not None:
+
+ @functools.partial(
+ shard_map,
+ mesh=mesh,
+ in_specs=(
+ P("data", None, "model"),
+ P("data"),
+ P("data", None, "model"),
+ ),
+ out_specs=P("data", None, "model"),
+ check_rep=False,
+ )
+ def update_into_shmap(mems, starts, updates):
+ return update_into(mems, starts, updates)
+
+ key_heads = update_into_shmap(kv_memory.k, kv_memory.step, key_heads)
+ value_heads = update_into_shmap(kv_memory.v, kv_memory.step, value_heads)
+ else:
+ key_heads = update_into(kv_memory.k, kv_memory.step, key_heads)
+ value_heads = update_into(kv_memory.v, kv_memory.step, value_heads)
+
+ new_step = kv_memory.step + sequence_length
+ memory_mask = jnp.arange(kv_memory.k.shape[1]) < new_step[:, None]
+ memory_mask = memory_mask[:, None, None, :] # [B, H, T, T]
+ if mask is not None:
+ mask = memory_mask * mask
+ else:
+ mask = memory_mask
+
+ new_memory = KVMemory(
+ k=key_heads,
+ v=value_heads,
+ step=new_step,
+ )
+ # Add separate dimension for grouped query heads.
+ query_heads = with_sharding_constraint(query_heads, P(self.data_axis, None, "model", None))
+ key_heads = with_sharding_constraint(key_heads, P(self.data_axis, None, "model", None))
+ value_heads = with_sharding_constraint(value_heads, P(self.data_axis, None, "model", None))
+ b, t, h, d = query_heads.shape
+ _, _, kv_h, _ = key_heads.shape
+ assert h % kv_h == 0, f"query_heads {h} must be a multiple of kv_heads {kv_h}"
+
+ query_heads = jnp.reshape(query_heads, (b, t, kv_h, h // kv_h, d))
+ query_heads = with_sharding_constraint(
+ query_heads, P(self.data_axis, None, "model", None, None)
+ )
+
+ # Compute attention weights.
+ # Attention softmax is always carried out in fp32.
+ attn_logits = jnp.einsum("...thHd,...Thd->...hHtT", query_heads, key_heads).astype(
+ jnp.float32
+ )
+ attn_logits *= self.attn_output_multiplier
+ max_attn_val = jnp.array(30.0, dtype=attn_logits.dtype)
+ attn_logits = max_attn_val * jnp.tanh(attn_logits / max_attn_val)
+
+ mask = mask[:, :, None, :, :]
+
+ if mask is not None:
+ if mask.ndim != attn_logits.ndim:
+ raise ValueError(
+ f"Mask dimensionality {mask.ndim} must match logits dimensionality "
+ f"{attn_logits.ndim} for {mask.shape}/{attn_logits.shape}."
+ )
+ attn_logits = jnp.where(mask, attn_logits, -1e30)
+ attn_weights = jax.nn.softmax(attn_logits).astype(query.dtype) # [H, T', T]
+
+ # Weight the values by the attention and flatten the head vectors.
+ attn = jnp.einsum("...hHtT,...Thd->...thHd", attn_weights, value_heads)
+ attn = with_sharding_constraint(attn, P(self.data_axis, None, "model", None, None))
+ leading_dims = attn.shape[:2]
+ attn = jnp.reshape(attn, (*leading_dims, -1)) # [T', H*V]
+ attn = with_sharding_constraint(attn, P(self.data_axis, None, "model"))
+ # Apply another projection to get the final embeddings.
+ final_projection = Linear(
+ self.model_size,
+ with_bias=False,
+ sharding=P("model", "data"),
+ mesh=mesh,
+ )
+ return MHAOutput(final_projection(attn), new_memory)
+
+ @hk.transparent
+ def _linear_projection(
+ self,
+ x: jax.Array,
+ head_size: int,
+ num_heads: int,
+ sharding: Optional[P] = None,
+ name: Optional[str] = None,
+ mesh: Any = None,
+ ) -> jax.Array:
+ y = Linear(
+ num_heads * head_size,
+ with_bias=False,
+ name=name,
+ sharding=sharding,
+ mesh=mesh,
+ )(x)
+ *leading_dims, _ = x.shape
+ return y.reshape((*leading_dims, num_heads, head_size))
+
+
+@dataclass
+class MHABlock(hk.Module):
+ """A MHA Block"""
+
+ num_q_heads: int
+ num_kv_heads: int
+ key_size: int
+ attn_output_multiplier: float = 1.0
+ mesh: Any = None
+ data_axis: Union[str, Tuple[str, ...]] = "data"
+ model_axis: Union[str, Tuple[str, ...]] = "model"
+
+ @hk.transparent
+ def __call__(
+ self,
+ inputs: jax.Array, # [B, T, D]
+ mask: jax.Array, # [B, 1, T, T] or [B, 1, 1, T] or B[1, 1, 1, 1]
+ layer_memory: Optional[KVMemory],
+ ) -> MHAOutput:
+ _, _, model_size = inputs.shape
+ assert mask.ndim == 4, f"shape: {mask.shape}"
+ assert mask.shape[2] in {1, inputs.shape[1]}, str(mask.shape)
+ assert mask.shape[3] in {1, inputs.shape[1]}, str(mask.shape)
+ side_input = inputs
+
+ def attn_block(query, key, value, mask, memory) -> MHAOutput:
+ return MultiHeadAttention(
+ num_q_heads=self.num_q_heads,
+ num_kv_heads=self.num_kv_heads,
+ key_size=self.key_size,
+ model_size=model_size,
+ data_axis=self.data_axis,
+ model_axis=self.model_axis,
+ attn_output_multiplier=self.attn_output_multiplier,
+ )(
+ query,
+ key,
+ value,
+ mask,
+ memory,
+ mesh=self.mesh,
+ )
+
+ attn_output = attn_block(inputs, side_input, side_input, mask, layer_memory)
+ h_attn = attn_output.embeddings
+
+ return attn_output._replace(embeddings=h_attn)
+
+
+@dataclass
+class DenseBlock(hk.Module):
+ num_q_heads: int
+ num_kv_heads: int
+ key_size: int
+ widening_factor: float = 4.0
+ sharding_constraint: bool = False
+ mesh: Any = None
+
+ @hk.transparent
+ def __call__(
+ self,
+ inputs: jax.Array, # [B, T, D]
+ ) -> jax.Array: # [B, T, D]
+ _, _, model_size = inputs.shape
+ h_v = Linear(
+ ffn_size(
+ model_size,
+ self.widening_factor,
+ ),
+ with_bias=False,
+ mesh=self.mesh,
+ sharding=P("data", "model"),
+ name="linear_v",
+ )(inputs)
+ h_w1 = jax.nn.gelu(
+ Linear(
+ ffn_size(
+ model_size,
+ self.widening_factor,
+ ),
+ with_bias=False,
+ mesh=self.mesh,
+ sharding=P("data", "model"),
+ )(inputs)
+ )
+ h_dense = Linear(
+ model_size,
+ with_bias=False,
+ sharding=P("model", "data"),
+ mesh=self.mesh,
+ shard_axis=1,
+ )(h_w1 * h_v)
+
+ return h_dense
+
+
+@dataclass
+class DecoderLayer(hk.Module):
+    """A single transformer decoder layer.
+
+    The layer applies an attention sub-block followed by a feed-forward
+    sub-block (a mixture-of-experts layer when ``num_experts > 1``, otherwise a
+    plain ``DenseBlock``). Each sub-block sees an RMS-normalized input, and its
+    output is RMS-normalized again before being added to the residual stream.
+    """
+
+    num_q_heads: int
+    num_kv_heads: int
+    key_size: int
+    # NOTE(review): num_layers is not read anywhere inside this class.
+    num_layers: int
+    # MoE.
+    num_experts: int
+    # NOTE(review): layer_index is not read anywhere inside this class.
+    layer_index: Optional[int] = None
+    num_selected_experts: int = 1
+    widening_factor: float = 4.0
+    name: Optional[str] = None
+    data_axis: Union[str, Tuple[str, ...]] = "data"
+    model_axis: Union[str, Tuple[str, ...]] = "model"
+    shard_activations: bool = False
+    attn_output_multiplier: float = 1.0
+    mesh: Any = None
+
+    def __call__(
+        self,
+        inputs: jax.Array,  # [B, T, D]
+        mask: jax.Array,  # [B, 1, T, T] or [B, 1, 1, T]
+        padding_mask: Optional[jax.Array],
+        layer_memory: Optional[KVMemory],
+    ) -> DecoderOutput:
+        """Transforms input embedding sequences to output embedding sequences.
+
+        Args:
+            inputs: Layer input embeddings.
+            mask: Attention mask forwarded to the attention block.
+            padding_mask: Forwarded to the MoE layer; unused on the dense path.
+            layer_memory: Optional key/value cache for this layer.
+
+        Returns:
+            ``DecoderOutput`` with the new embeddings and the memory returned by
+            the attention block.
+        """
+
+        # Local shorthand; every normalization in this layer is an RMS norm.
+        def layer_norm(x):
+            return hk_rms_norm(x)
+
+        # Choose how the residual stream is constrained across the mesh.
+        if self.shard_activations:
+            sharding = P(self.data_axis, None, self.model_axis)
+        else:
+            sharding = P(self.data_axis, None)
+        h = with_sharding_constraint(inputs, sharding)
+
+        # Attention sub-block on the normalized input.
+        attn_output = MHABlock(
+            num_q_heads=self.num_q_heads,
+            num_kv_heads=self.num_kv_heads,
+            key_size=self.key_size,
+            attn_output_multiplier=self.attn_output_multiplier,
+            mesh=self.mesh,
+            data_axis=self.data_axis,
+            model_axis=self.model_axis,
+        )(layer_norm(h), mask, layer_memory)
+        h_attn = attn_output.embeddings
+
+        # Normalize the attention output before the residual add.
+        h_attn = layer_norm(h_attn)
+        h += h_attn
+        h = with_sharding_constraint(h, sharding)
+
+        # Feed-forward body; also used as the per-expert function in the MoE path.
+        def base_dense_block(h):
+            h = DenseBlock(
+                num_q_heads=self.num_q_heads,
+                num_kv_heads=self.num_kv_heads,
+                key_size=self.key_size,
+                widening_factor=self.widening_factor,
+                sharding_constraint=False,
+                mesh=self.mesh,
+            )(h)
+            return h
+
+        if self.num_experts > 1:
+            rank_logger.debug("Using MoE!")
+            router = Router(
+                num_selected_experts=self.num_selected_experts,
+                shard_activations=self.shard_activations,
+                data_axis=self.data_axis,
+                model_axis=self.model_axis,
+                mesh=self.mesh,
+            )
+            h_dense = MoELayer(
+                num_experts=self.num_experts,
+                mesh=self.mesh,
+                layer_fn=base_dense_block,
+                router=router,
+                shard_activations=self.shard_activations,
+                data_axis=self.data_axis,
+                model_axis=self.model_axis,
+            )(layer_norm(h), padding_mask)
+        else:
+            h_dense = base_dense_block(layer_norm(h))
+
+        # Normalize the feed-forward output before the residual add.
+        h_dense = layer_norm(h_dense)
+        h += h_dense
+        h = with_sharding_constraint(h, sharding)
+
+        return DecoderOutput(
+            embeddings=h,
+            memory=attn_output.memory,
+        )
+
+
+class LanguageModelOutput(NamedTuple):
+    """Result of a ``LanguageModel`` forward pass."""
+
+    # Output logits produced by decoding the final embeddings.
+    logits: jax.Array
+    # Memory returned by the underlying transformer.
+    model_state: Any
+
+
+class InOutEmbed(hk.Embed):
+ """Module for embedding tokens in a low-dimensional space."""
+
+ def __init__(
+ self,
+ vocab_size: Optional[int] = None,
+ embed_dim: Optional[int] = None,
+ sharding: Optional[P] = None,
+ name: Optional[str] = None,
+ ):
+ super().__init__(
+ vocab_size=vocab_size,
+ embed_dim=embed_dim,
+ name=name,
+ )
+ self.sharding = sharding
+
+ @property
+ def embeddings(self):
+ embed_mat = hk.get_parameter(
+ "embeddings",
+ [self.vocab_size, self.embed_dim],
+ dtype=jnp.float32,
+ init=hk.initializers.Constant(0),
+ )
+ if self.sharding:
+ embed_mat = with_sharding_constraint(embed_mat, self.sharding)
+ return embed_mat
+
+ def decode(
+ self,
+ inputs: jax.Array,
+ ) -> jax.Array:
+ return jnp.dot(inputs, self.embeddings.T.astype(inputs.dtype))
+
+
+@dataclass
+class LanguageModelConfig:
+ """An autoregressive transformer-based language model."""
+
+ model: Optional[TransformerConfig]
+ vocab_size: int
+ pad_token: int
+ eos_token: int
+ sequence_len: int
+ model_size: int = 0
+ embedding_init_scale: float = 1.0
+ embedding_multiplier_scale: float = 1.0
+ output_multiplier_scale: float = 1.0
+ name: Optional[str] = None
+ fprop_dtype: Any = jnp.bfloat16
+ model_type: Optional[str] = None
+ init_scale_override: Optional[float] = None
+ shard_embeddings: bool = True
+
+ _initialized = False
+
+ def initialize(self):
+ # We cannot specify [] as a default value (it is mutable), hence None.
+ model_config = self.model
+ assert self.init_scale_override is None, (
+ "Overriding model initialize scale is supported only for predefined models."
+ )
+ if self.model_size == 0:
+ self.model_size = model_config.emb_size
+ assert self.model is not None, "Model could not be initialized."
+ self._initialized = True
+ return self
+
+ def make(self, *args, **kwargs):
+ if not self._initialized:
+ logger.warning(
+ f"LanguageModel {self.name} is not initialized. Initializing for one replica."
+ )
+ self.initialize()
+
+ return LanguageModel(
+ model=self.model.make(*args, **kwargs),
+ config=self,
+ fprop_dtype=self.fprop_dtype,
+ mesh=kwargs.get("mesh", None),
+ )
+
+ def partition_rules(self):
+ return LM_PARTITION_RULES + self.model.partition_rules()
+
+
+def layer_norm(x, model):
+    """Normalize ``x`` with ``hk_rms_norm``.
+
+    ``model`` is accepted but not used by this function.
+    """
+    return hk_rms_norm(x)
+
+
+@dataclass
+class LanguageModel(hk.Module):
+ """An autoregressive transformer-based language model."""
+
+ model: "Transformer"
+ config: LanguageModelConfig
+ fprop_dtype: Any = jnp.bfloat16
+ name: Optional[str] = None
+ mesh: Any = None
+
+ def __call__(
+ self,
+ tokens: jax.Array,
+ memory: Optional[Memory] = None,
+ *,
+ batch: Dict[str, jax.Array] = {},
+ last_hid_only: bool = False,
+ length: Optional[jax.Array] = None,
+ ) -> LanguageModelOutput:
+ """Forward pass, producing a sequence of logits."""
+ del batch # Unused.
+
+ config = self.config
+
+ input_mask = jnp.greater(tokens, config.pad_token)
+
+ # Embed the input tokens and positions.
+ in_out_embed = InOutEmbed(
+ self.config.vocab_size,
+ embed_dim=self.config.model_size,
+ sharding=P(None, ("data", "model")),
+ )
+ input_embeddings = in_out_embed(tokens).astype(config.fprop_dtype)
+ input_embeddings = with_sharding_constraint(
+ input_embeddings, P("data", None, self.model.model_axis)
+ )
+ input_embeddings *= config.embedding_multiplier_scale
+
+ model_output = self.model(
+ input_embeddings,
+ input_mask,
+ memory=memory,
+ ) # [B, T, D]
+ embeddings, model_state = model_output.embeddings, model_output.memory
+ if self.model.shard_activations:
+ embeddings = with_sharding_constraint(
+ embeddings, P("data", None, self.model.model_axis)
+ )
+ else:
+ embeddings = with_sharding_constraint(embeddings, P("data", None))
+ rank_logger.debug(f"Final embedding shape: {embeddings.shape}")
+ embeddings = layer_norm(embeddings, self.model)
+ assert embeddings.dtype == self.fprop_dtype
+
+ if last_hid_only:
+ last_step = jnp.maximum(jnp.sum(input_mask.astype(jnp.int32), axis=1) - 1, 0)
+ last_hid = jax.vmap(lambda x, i: x[i], in_axes=0, out_axes=0)(embeddings, last_step)
+ return last_hid
+
+ if length is not None:
+ last_step = jnp.maximum(length.astype(jnp.int32) - 1, 0)
+ embeddings = jax.vmap(lambda x, i: x[i], in_axes=0, out_axes=0)(embeddings, last_step)
+ embeddings = jnp.expand_dims(embeddings, axis=1)
+
+ # Decode the embeddings (here, we use tied weights).
+ rank_logger.info(embeddings.shape)
+ out = in_out_embed.decode(embeddings)
+ rank_logger.info(out.shape)
+ out *= config.output_multiplier_scale
+
+ if self.model.shard_activations:
+ out = with_sharding_constraint(out, P("data", None, self.model.model_axis))
+ else:
+ out = with_sharding_constraint(out, P("data", None))
+
+ return LanguageModelOutput(
+ logits=out,
+ model_state=model_state,
+ )
+
+ def init_memory(self, batch_size: int, seq_len: int, dtype=jnp.bfloat16):
+ return self.model.init_memory(batch_size=batch_size, sequence_len=seq_len, dtype=dtype)
+
+ def prefill_memory(self, prompts, memory):
+ # Pad to the left and right align?
+ # Basically assume prompt is already padded
+ model_output = self(prompts, memory=memory)
+ return model_output.logits, model_output.model_state
+
+
+@dataclass
+class Transformer(hk.Module):
+ """A transformer stack."""
+
+ num_q_heads: int
+ num_kv_heads: int
+ key_size: int
+ widening_factor: float
+ init_scale: float
+ mesh: Any
+ attn_output_multiplier: float
+ shard_activations: bool
+ num_layers: int
+ # MoE
+ num_experts: int
+ num_selected_experts: int
+ name: Optional[str] = None
+
+ # Used for activation sharding
+ data_axis: Union[str, Tuple[str, ...]] = "data"
+ model_axis: Union[str, Tuple[str, ...]] = "model"
+
+ def init_memory(self, batch_size: int, sequence_len: int, dtype=jnp.bfloat16):
+ return Memory(
+ layers=init_layer_memories(
+ batch_size,
+ sequence_len,
+ self.num_kv_heads,
+ self.key_size,
+ self.num_layers,
+ step=jnp.zeros(batch_size, dtype=jnp.int32),
+ dtype=dtype,
+ ),
+ )
+
+ def __call__(
+ self,
+ embeddings: jax.Array, # [B, T, D]
+ mask: jax.Array, # [B, T]
+ memory: Optional[Memory],
+ ) -> TransformerOutput:
+ """Transforms input embedding sequences to output embedding sequences."""
+
+ fprop_dtype = embeddings.dtype
+ _, seq_len, model_size = embeddings.shape
+ padding_mask = mask.copy()
+ mask = mask[:, None, None, :] # [B, H=1, T'=1, T]
+
+ # Compute causal mask for autoregressive sequence modelling.
+ causal_mask = jnp.tril(jnp.ones((1, 1, seq_len, seq_len))).astype(
+ fprop_dtype
+ ) # [B=1, H=1, T, T]
+ mask = mask * causal_mask # [B, H=1, T, T]
+
+ h = embeddings
+ kv_memories = []
+
+ def block(
+ h,
+ mask,
+ padding_mask,
+ memory,
+ layer_index: Optional[int] = None,
+ widening_factor: Optional[int] = None,
+ name: Optional[str] = None,
+ ) -> DecoderOutput:
+ return DecoderLayer(
+ num_q_heads=self.num_q_heads,
+ num_kv_heads=self.num_kv_heads,
+ key_size=self.key_size,
+ widening_factor=widening_factor or self.widening_factor,
+ num_layers=self.num_layers,
+ mesh=self.mesh,
+ data_axis=self.data_axis,
+ model_axis=self.model_axis,
+ attn_output_multiplier=self.attn_output_multiplier,
+ shard_activations=self.shard_activations,
+ # MoE.
+ num_experts=self.num_experts,
+ num_selected_experts=self.num_selected_experts,
+ name=name,
+ layer_index=layer_index,
+ )(
+ h,
+ mask,
+ padding_mask,
+ memory,
+ )
+
+ for i in range(self.num_layers):
+ decoder_output = block(
+ h,
+ mask,
+ padding_mask,
+ memory.layers[i] if memory else None,
+ layer_index=i,
+ name=f"decoder_layer_{i}",
+ )
+ h, new_kv_memory = (
+ decoder_output.embeddings,
+ decoder_output.memory,
+ )
+ kv_memories.append(new_kv_memory)
+
+ return TransformerOutput(
+ embeddings=h,
+ memory=Memory(layers=kv_memories),
+ )
diff --git a/grok-1-main/pyproject.toml b/grok-1-main/pyproject.toml
new file mode 100644
index 0000000..aa55016
--- /dev/null
+++ b/grok-1-main/pyproject.toml
@@ -0,0 +1,14 @@
+[tool.ruff]
+indent-width = 4
+line-length = 100
+
+[tool.ruff.lint]
+ignore = [
+ "E722",
+ "E731",
+ "E741",
+ "F405",
+ "E402",
+ "F403",
+]
+select = ["ISC001"]
diff --git a/grok-1-main/requirements.txt b/grok-1-main/requirements.txt
new file mode 100644
index 0000000..a612687
--- /dev/null
+++ b/grok-1-main/requirements.txt
@@ -0,0 +1,4 @@
+dm_haiku==0.0.12
+jax[cuda12-pip]==0.4.25 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
+numpy==1.26.4
+sentencepiece==0.2.0
diff --git a/grok-1-main/run.py b/grok-1-main/run.py
new file mode 100644
index 0000000..bc75c3e
--- /dev/null
+++ b/grok-1-main/run.py
@@ -0,0 +1,261 @@
+# Copyright 2024 X.AI Corp.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import json
+import logging
+import re
+import sys
+
+from model import LanguageModelConfig, TransformerConfig, QuantizedWeight8bit as QW8Bit
+from runners import InferenceRunner, ModelRunner, sample_from_model
+
+
+# Default checkpoint directory: used by main() and as the default value of the
+# --checkpoint-path command-line option.
+CKPT_PATH = "./checkpoints/"
+
+
+def get_grok1_config():
+ """Returns the standard Grok-1 model configuration."""
+ return LanguageModelConfig(
+ vocab_size=128 * 1024,
+ pad_token=0,
+ eos_token=2,
+ sequence_len=8192,
+ embedding_init_scale=1.0,
+ output_multiplier_scale=0.5773502691896257,
+ embedding_multiplier_scale=78.38367176906169,
+ model=TransformerConfig(
+ emb_size=48 * 128,
+ widening_factor=8,
+ key_size=128,
+ num_q_heads=48,
+ num_kv_heads=8,
+ num_layers=64,
+ attn_output_multiplier=0.08838834764831845,
+ shard_activations=True,
+ # MoE.
+ num_experts=8,
+ num_selected_experts=2,
+ # Activation sharding.
+ data_axis="data",
+ model_axis="model",
+ ),
+ )
+
+
+def main():
+ grok_1_model = get_grok1_config()
+ inference_runner = InferenceRunner(
+ pad_sizes=(1024,),
+ runner=ModelRunner(
+ model=grok_1_model,
+ bs_per_device=0.125,
+ checkpoint_path=CKPT_PATH,
+ ),
+ name="local",
+ load=CKPT_PATH,
+ tokenizer_path="./tokenizer.model",
+ local_mesh_config=(1, 8),
+ between_hosts_config=(1, 1),
+ )
+ inference_runner.initialize()
+ gen = inference_runner.run()
+
+ inp = "The answer to life the universe and everything is of course"
+ print(f"Output for prompt: {inp}", sample_from_model(gen, inp, max_len=100, temperature=0.01))
+
+
+# =============================================================================
+# 强兼 BRIDGE — Grokking Evaluation Mode
+# =============================================================================
+#
+# Tests whether Grok-1 has "grokked" modular arithmetic — the same tasks
+# that OpenAI's grokking paper studies with small transformers.
+#
+# Usage: python run.py --eval-grokking [--operator +] [--n-samples 50]
+#
+# Source: https://github.com/openai/grok
+# =============================================================================
+
+GROKKING_OPERATORS = {
+ "+": ("addition", "What is {a} + {b} mod {p}?"),
+ "-": ("subtraction", "What is {a} - {b} mod {p}?"),
+ "*": ("multiplication", "What is {a} * {b} mod {p}?"),
+ "/": ("division", "What is {a} / {b} mod {p}? (modular inverse)"),
+}
+
+
+def generate_arithmetic_problems(operator="+", n_samples=50, modulus=97, seed=42):
+ """Generate arithmetic problems matching OpenAI's grokking dataset."""
+ import random as rng
+ rng.seed(seed)
+
+ problems = []
+ all_pairs = [(a, b) for a in range(modulus) for b in range(modulus)]
+ if operator == "/":
+ all_pairs = [(a, b) for a, b in all_pairs if b != 0]
+ rng.shuffle(all_pairs)
+
+ for a, b in all_pairs[:n_samples]:
+ if operator == "+":
+ answer = (a + b) % modulus
+ elif operator == "-":
+ answer = (a - b) % modulus
+ elif operator == "*":
+ answer = (a * b) % modulus
+ elif operator == "/":
+ answer = (a * pow(b, modulus - 2, modulus)) % modulus
+ else:
+ continue
+
+ _, template = GROKKING_OPERATORS[operator]
+ prompt = template.format(a=a, b=b, p=modulus) + " Answer with just the number."
+ problems.append({
+ "a": a, "b": b, "operator": operator,
+ "expected": answer, "prompt": prompt,
+ })
+
+ return problems
+
+
+def eval_grokking(args):
+ """
+ Run Grok-1 on modular arithmetic problems from OpenAI's grokking research.
+
+ This is the heart of the 强兼 bridge: does the 314B-parameter Grok-1
+ model — named after the concept of deep understanding — actually
+ demonstrate deep understanding of the exact mathematical tasks that
+ OpenAI's "grokking" paper investigates?
+ """
+ grok_1_model = get_grok1_config()
+
+ # Print architecture bridge info
+ print("=" * 70)
+ print("强兼 BRIDGE: Does Grok grok grokking?")
+ print("=" * 70)
+ print(grok_1_model.model.architecture_summary())
+ print(f"Grokking evaluation: {args.operator} mod 97")
+ print(f"Samples: {args.n_samples}")
+ print("=" * 70)
+
+ # Export config for the grokking framework
+ grokking_config = grok_1_model.model.to_grokking_config()
+ print(f"\n[Bridge] Grok-1 → grokking config: {json.dumps({k:v for k,v in grokking_config.items() if not k.startswith('_')}, indent=2)}\n")
+
+ # Generate problems
+ problems = generate_arithmetic_problems(
+ operator=args.operator,
+ n_samples=args.n_samples,
+ )
+
+ if not args.dry_run:
+ # Initialize the model
+ inference_runner = InferenceRunner(
+ pad_sizes=(1024,),
+ runner=ModelRunner(
+ model=grok_1_model,
+ bs_per_device=0.125,
+ checkpoint_path=args.checkpoint_path,
+ ),
+ name="grokking_eval",
+ load=args.checkpoint_path,
+ tokenizer_path=args.tokenizer_path,
+ local_mesh_config=(1, 8),
+ between_hosts_config=(1, 1),
+ )
+ inference_runner.initialize()
+ gen = inference_runner.run()
+
+ # Evaluate
+ correct = 0
+ total = 0
+ results = []
+ for problem in problems:
+ response = sample_from_model(
+ gen, problem["prompt"],
+ max_len=20, temperature=0.01,
+ )
+ # Parse the response for a number
+ numbers = re.findall(r'\b(\d+)\b', response)
+ predicted = int(numbers[-1]) if numbers else None
+ is_correct = predicted == problem["expected"] if predicted is not None else False
+
+ results.append({
+ **problem,
+ "predicted": predicted,
+ "correct": is_correct,
+ "raw_response": response,
+ })
+
+ if is_correct:
+ correct += 1
+ total += 1
+
+ status = "✓" if is_correct else "✗"
+ print(f" {status} {problem['a']} {problem['operator']} {problem['b']} mod 97 "
+ f"= {problem['expected']} (Grok-1: {predicted})")
+
+ # Summary
+ accuracy = correct / total * 100 if total > 0 else 0
+ print(f"\n{'=' * 70}")
+ print(f"RESULTS: {correct}/{total} correct ({accuracy:.1f}%)")
+ print(f"{'=' * 70}")
+
+ verdict = (
+ "Grok GROKS grokking! 🎉" if accuracy > 95 else
+ "Grok partially groks grokking." if accuracy > 50 else
+ "Grok does NOT grok grokking. (yet?)"
+ )
+ print(f"\nVerdict: {verdict}\n")
+
+ # Save results
+ if args.output:
+ with open(args.output, "w") as f:
+ json.dump({"accuracy": accuracy, "results": results}, f, indent=2)
+ print(f"Results saved to {args.output}")
+ else:
+ print("[Dry run] Generated problems (no inference):")
+ for p in problems[:5]:
+ print(f" {p['prompt']} → expected: {p['expected']}")
+ print(f" ... ({len(problems)} total)")
+ print(f"\n[Bridge] To run with Grok-1, remove --dry-run and provide checkpoint.")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(
+ description="Grok-1 inference & grokking evaluation (强兼 bridge)"
+ )
+ parser.add_argument("--eval-grokking", action="store_true",
+ help="Run grokking arithmetic evaluation instead of default inference")
+ parser.add_argument("--operator", type=str, default="+",
+ choices=["+", "-", "*", "/"],
+ help="Arithmetic operator for grokking eval")
+ parser.add_argument("--n-samples", type=int, default=50,
+ help="Number of arithmetic problems to evaluate")
+ parser.add_argument("--checkpoint-path", type=str, default=CKPT_PATH,
+ help="Path to Grok-1 checkpoints")
+ parser.add_argument("--tokenizer-path", type=str, default="./tokenizer.model",
+ help="Path to tokenizer model")
+ parser.add_argument("--output", type=str, default=None,
+ help="Path to save JSON results")
+ parser.add_argument("--dry-run", action="store_true",
+ help="Generate problems without running inference")
+
+ args = parser.parse_args()
+ logging.basicConfig(level=logging.INFO)
+
+ if args.eval_grokking:
+ eval_grokking(args)
+ else:
+ main()
diff --git a/grok-1-main/runners.py b/grok-1-main/runners.py
new file mode 100644
index 0000000..452c142
--- /dev/null
+++ b/grok-1-main/runners.py
@@ -0,0 +1,605 @@
+# Copyright 2024 X.AI Corp.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import bisect
+import functools
+import logging
+import math
+import re
+from dataclasses import dataclass
+from typing import Any, Callable, NamedTuple, Optional, Tuple
+
+import haiku as hk
+import jax
+import jax.experimental.pjit as pjit
+import jax.numpy as jnp
+import numpy as np
+import sentencepiece
+from jax.experimental import mesh_utils
+from jax.sharding import PartitionSpec as P
+from jax.typing import ArrayLike
+
+import checkpoint as xai_checkpoint
+from model import (
+ LanguageModelConfig,
+ LanguageModelOutput,
+ TrainingState,
+ apply_rules,
+ Memory,
+ KVMemory,
+)
+
+logger = logging.getLogger(__name__)
+rank_logger = logging.getLogger("rank")
+
+TOP_K = 8
+
+
+class SampleSettings(NamedTuple):
+    """Sampling configuration for a batch of requests (leading axis is the batch)."""
+
+    # Value the logits are divided by before sampling. [B]
+    temperature: ArrayLike
+    # Nucleus (top-p) threshold handed to `top_p_filter`. [B]
+    nucleus_p: ArrayLike
+    # Truthy where a vocabulary entry may be sampled. [B, V]
+    mask: ArrayLike
+    # Whether a given batch element is actively used. [B]
+    active: ArrayLike
+
+
+class SampleOutput(NamedTuple):
+    """Result of one sampling step for every batch element."""
+
+    # Sampled token id. Allocated as [B, 1] by InferenceRunner.run.
+    token_id: ArrayLike
+    # Probability of the sampled token. Allocated as [B, 1].
+    prob: ArrayLike
+    # Ids of the TOP_K most likely tokens. Allocated as [B, TOP_K].
+    top_k_token_ids: ArrayLike
+    # Probabilities matching `top_k_token_ids`. Allocated as [B, TOP_K].
+    top_k_probs: ArrayLike
+
+
+def insert_slice(memory: Memory, slice, length, i):
+    """Write a single-sequence KV-cache slice into batch row ``i`` of ``memory``.
+
+    Args:
+        memory: Batched memory whose row ``i`` is overwritten.
+        slice: Memory for one sequence; only index 0 along the leading axis of
+            each leaf is used. The name shadows the builtin but is part of the
+            signature, so it is kept.
+        length: Value stored as the ``step`` counter of every layer in the slice.
+        i: Batch index to overwrite.
+
+    Returns:
+        The memory pytree with row ``i`` of every leaf replaced.
+    """
+    # Rebuild the slice so that each layer's step counter is set to `length`.
+    slice = Memory(
+        layers=[
+            KVMemory(layer.k, layer.v, step=jnp.array([length]))
+            for layer in slice.layers
+        ],
+    )
+
+    # Use jax.tree_util.tree_map: the top-level jax.tree_map alias is deprecated
+    # and removed in newer JAX releases, and the rest of this module already
+    # goes through jax.tree_util.
+    return jax.tree_util.tree_map(
+        lambda m, u: jax.lax.dynamic_update_index_in_dim(m, u[0], i, axis=0),
+        memory,
+        slice,
+    )
+
+
+def pad_to_size(x, size):
+    """Return ``x`` with exactly ``size`` entries along its first axis.
+
+    Longer inputs keep their last ``size`` entries (left truncation); shorter
+    inputs are right-padded with zeros.
+    """
+    # Left truncate if the context is too long.
+    trimmed = x[-size:] if x.shape[0] > size else x
+    missing = size - trimmed.shape[0]
+    return np.pad(trimmed, [0, missing], mode="constant", constant_values=0)
+
+
+def top_p_filter(logits: jax.Array, top_p: jax.Array) -> jax.Array:
+    """Performs nucleus filtering on logits.
+
+    Logits below the nucleus threshold are replaced by -1e10; all others are
+    returned unchanged. ``top_p`` must have the same rank as ``logits``.
+    """
+    assert logits.ndim == top_p.ndim, f"Expected {logits.ndim} equal {top_p.ndim}"
+    ascending = jax.lax.sort(logits, is_stable=False)
+    ascending_probs = jax.nn.softmax(ascending)
+    # Cumulative mass counted from the least likely token upwards; the
+    # threshold is the first position where that mass reaches 1 - top_p.
+    tail_mass = jnp.cumsum(ascending_probs, -1)
+    cutoff_idx = jnp.argmax(tail_mass >= 1 - top_p, axis=-1)
+    cutoff_logit = jnp.take_along_axis(
+        ascending, cutoff_idx[..., jnp.newaxis], axis=-1
+    )
+    assert cutoff_logit.shape == logits.shape[:-1] + (1,)
+    # Filtered entries get -1e10, a large finite stand-in for -inf.
+    return jnp.where(logits >= cutoff_logit, logits, -1e10)
+
+
+def sample_token(
+    rngs: jax.random.PRNGKey,
+    lm_outputs: LanguageModelOutput,
+    settings: SampleSettings,
+) -> SampleOutput:
+    """Sample one token per batch element from the model's logits.
+
+    The logits are temperature-scaled, restricted to the allowed vocabulary
+    mask, nucleus-filtered and then sampled with one RNG key per batch element.
+    The TOP_K most likely tokens and their probabilities are returned as well.
+    """
+    # Expand the settings shape to match the logit shape.
+    temperature = jnp.expand_dims(settings.temperature, (1, 2))  # [B] -> [B, 1, 1]
+    nucleus_p = jnp.expand_dims(settings.nucleus_p, (1, 2))  # [B] -> [B, 1, 1]
+    allowed = jnp.expand_dims(settings.mask, 1)  # [B, V] -> [B, 1, V]
+
+    raw_logits = lm_outputs.logits
+    logits = raw_logits / temperature.astype(raw_logits.dtype)
+    # Disallowed tokens get a near-zero probability.
+    logits = jnp.where(allowed, logits, -1e10)
+    # Drop everything outside the top-p nucleus.
+    logits = top_p_filter(logits, nucleus_p.astype(logits.dtype))
+
+    new_token = jax.vmap(jax.random.categorical)(rngs, logits)
+
+    probabilities = jax.nn.softmax(logits)
+    token_prob = jnp.squeeze(
+        jnp.take_along_axis(probabilities, jnp.expand_dims(new_token, 1), axis=2),
+        1,
+    )
+
+    # Gather the top-k tokens and probabilities.
+    top_k_probs, top_k_token_ids = jax.lax.top_k(probabilities, TOP_K)
+    return SampleOutput(
+        token_id=new_token,
+        prob=token_prob,
+        top_k_token_ids=jnp.squeeze(top_k_token_ids, 1),
+        top_k_probs=jnp.squeeze(top_k_probs, 1),
+    )
+
+
+@dataclass
+class ModelRunner:
+    """Owns the device mesh, forward function and parameter state of one language model.
+
+    Call `initialize` first (it builds the mesh and derives the batch sizes),
+    then `load_or_init` to obtain the state either from `checkpoint_path` or
+    from a fresh initialisation.
+    """
+
+    model: LanguageModelConfig
+
+    # Batch elements per local device; the product with the local device count
+    # and the replica count is truncated to an int in `initialize`.
+    bs_per_device: float = 2.0
+
+    # NOTE(review): not read by any method of this class — presumably consumed
+    # elsewhere; confirm.
+    load_rename_rules: Optional[list[tuple[str, str]]] = None
+    load_exclude_rules: Optional[list[str]] = None
+
+    rng_seed: int = 42  # Initial rng seed.
+    # When True, `make_forward_fn` wraps the forward function in hk.transform.
+    transform_forward: bool = False
+
+    # Checkpoint location; an empty string makes `load_or_init` initialise from scratch.
+    checkpoint_path: str = ""
+
+    def make_forward_fn(self, mesh: Any):
+        """Return the forward function, hk.transform-ed when `transform_forward` is set."""
+
+        def forward(tokens):
+            out = self.model.make(mesh=mesh)(tokens)
+            return out, None
+
+        if self.transform_forward:
+            forward = hk.transform(forward)
+        return forward
+
+    def initialize(
+        self,
+        init_data,
+        local_mesh_config: tuple[int, int],
+        between_hosts_config: tuple[int, int],
+    ):
+        """Build the mesh, derive batch sizes and create the forward/logit functions.
+
+        Args:
+            init_data: Example batch; only used to derive the state sharding
+                when `transform_forward` is set.
+            local_mesh_config: Mesh shape within a host.
+            between_hosts_config: Mesh shape across hosts; its product is the
+                replica count.
+        """
+        num_replicas = math.prod(between_hosts_config)
+        self.model.initialize()
+        self.model.fprop_dtype = jnp.bfloat16
+        num_local_gpus = len(jax.local_devices())
+
+        # Calculate the global batch size from the local batch size.
+        self.batch_size = int(self.bs_per_device * num_local_gpus * num_replicas)
+
+        # Calculate the batch size per host from the global batch size.
+        self.local_batch_size = self.batch_size // jax.process_count()
+
+        self.local_mesh_config = local_mesh_config
+        self.between_hosts_config = between_hosts_config
+        rank_logger.info(
+            f"Initializing mesh for {self.local_mesh_config=} {self.between_hosts_config=}..."
+        )
+        self.mesh = make_mesh(self.local_mesh_config, self.between_hosts_config)
+        self.forward = self.make_forward_fn(mesh=self.mesh)
+        self.logits_fn = hk.transform(lambda tokens: self.forward(tokens)[0])
+
+        self.eval_forward = self.make_forward_fn(mesh=self.mesh)
+        self.logits_eval_fn = hk.transform(lambda tokens: self.eval_forward(tokens)[0])
+
+        # `state_sharding` and `init_fn` only exist when the forward function is transformed.
+        if self.transform_forward:
+            self.state_sharding = self.get_state_sharding(init_data)
+            rank_logger.info(f"State sharding type: {type(self.state_sharding)}")
+            self.init_fn = pjit.pjit(self.init, out_shardings=self.state_sharding)
+
+    def init(self, rng: jax.Array, data) -> TrainingState:
+        """Initialise the parameters from `data["inputs"]` and wrap them in a TrainingState."""
+        assert self.transform_forward
+        rng, init_rng = jax.random.split(rng)
+        params = self.forward.init(init_rng, data["inputs"])
+        return TrainingState(params=params)
+
+    def get_state_sharding(self, init_data):
+        """Map the model's partition rules over the abstract shape of the initial state."""
+        assert self.transform_forward
+        rng = jax.random.PRNGKey(self.rng_seed)
+        # Evaluate the rules once so the log shows the rules themselves rather
+        # than the repr of the bound method.
+        partition_rules = self.model.partition_rules()
+        rank_logger.info(f"partition rules: {partition_rules}")
+
+        with self.mesh:
+            # eval_shape yields shapes/dtypes only; no parameters are materialised.
+            shapes = jax.eval_shape(self.init, rng, init_data)
+            sharding = jax.tree_util.tree_map_with_path(
+                apply_rules(partition_rules),
+                shapes,
+            )
+        return sharding
+
+    def load_or_init(
+        self,
+        init_data: Any,
+        from_checkpoint: bool = True,
+        init_fn: Optional[Callable] = None,
+    ):
+        """Return the model state, restored from the checkpoint or freshly initialised.
+
+        Args:
+            init_data: Example batch passed to the init function.
+            from_checkpoint: When False, or when `checkpoint_path` is empty,
+                the state is initialised instead of restored.
+            init_fn: Optional replacement for `self.init_fn`; without it
+                `transform_forward` must be set.
+
+        Returns:
+            The initialised or restored state.
+        """
+        rng = jax.random.PRNGKey(self.rng_seed)
+
+        if not self.checkpoint_path or not from_checkpoint:
+            rank_logger.info("Initializing model...")
+            with self.mesh:
+                if init_fn is not None:
+                    state = init_fn(rng, init_data)
+                else:
+                    assert self.transform_forward
+                    state = self.init_fn(rng, init_data)
+            rank_logger.info("Model state is newly initialized.")
+            return state
+
+        with self.mesh:
+            # Only the abstract shapes are needed to restore from disk.
+            if init_fn is not None:
+                state_shapes = jax.eval_shape(init_fn, rng, init_data)
+            else:
+                assert self.transform_forward
+                state_shapes = jax.eval_shape(self.init_fn, rng, init_data)
+
+        # NOTE(review): `self.state_sharding` is only assigned by `initialize`
+        # when `transform_forward` is set — confirm callers that pass a custom
+        # `init_fn` also satisfy that.
+        return xai_checkpoint.restore(
+            checkpoint_path=self.checkpoint_path,
+            state_shapes=state_shapes,
+            mesh=self.mesh,
+            between_hosts_config=self.between_hosts_config,
+            state_sharding=self.state_sharding,
+            init_state=None,
+            params_only=True,
+        )
+
+
+@dataclass
+class Request:
+    """One generation request sent into the `InferenceRunner.run` generator."""
+
+    # Text that is tokenized and used as the prompt.
+    prompt: str
+    # Sampling temperature for this request.
+    temperature: float
+    # Nucleus (top-p) threshold for this request.
+    nucleus_p: float
+    # Seed of the per-request RNG key created during prefill.
+    rng_seed: int
+    # Generation stops once this many tokens have been collected.
+    max_len: int
+
+
+@dataclass
+class InferenceRunner:
+ """Batched text generation on top of a model runner.
+
+ `initialize` loads the tokenizer and parameters and compiles the prefill,
+ sampling and memory-allocation steps; `run` is a generator that receives
+ `Request` objects via send() and yields decoded strings.
+ """
+
+ # NOTE(review): `name` and `load` are not read by any method of this class.
+ name: str
+ # Runner object; this class uses its mesh, model, batch_size and load_or_init.
+ runner: Any
+ load: str
+ tokenizer_path: str = "/tmp/xai_data/tokenizer.model"
+ local_mesh_config: Tuple[int, int] = (1, 1)
+ between_hosts_config: Tuple[int, int] = (1, 1)
+ # Prompt-length buckets searched with bisect in get_pad_bucket, so presumably
+ # sorted ascending — TODO confirm.
+ # NOTE(review): `tuple[int]` annotates a 1-tuple; `tuple[int, ...]` would
+ # describe a variable-length tuple.
+ pad_sizes: tuple[int] = (1024,)
+
+ def get_pad_bucket(self, size):
+ """Return the first bucket in `pad_sizes` that is >= `size`, or the last bucket if none is."""
+ i = bisect.bisect_left(self.pad_sizes, size)
+ return self.pad_sizes[min(i, len(self.pad_sizes) - 1)]
+
+ def initialize(self):
+ """Initialize the runner, load tokenizer and parameters, and compile the inference steps.
+
+ Sets `tokenizer`, `vocab_size`, `params`, `params_sharding`, `sample_step`,
+ `prefill_memory` and `new_memory` on the instance.
+ """
+ runner = self.runner
+ # The forward function must be hk.transform-ed for runner.init / load_or_init.
+ self.runner.transform_forward = True
+ dummy_data = dict(
+ inputs=np.zeros((1, 256), dtype=np.int32),
+ targets=np.zeros((1, 256), dtype=np.int32),
+ )
+ runner.initialize(
+ dummy_data,
+ local_mesh_config=self.local_mesh_config,
+ between_hosts_config=self.between_hosts_config,
+ )
+
+ self.tokenizer = sentencepiece.SentencePieceProcessor(model_file=self.tokenizer_path)
+
+ max_len = runner.model.sequence_len
+
+ self.vocab_size = self.runner.model.vocab_size
+
+ params = runner.load_or_init(dummy_data)
+ self.params = params
+
+ def pad_to_max_len(x):
+ # Pads axis 1 up to max_len for leaves with more than one dimension; the
+ # four-entry pad spec means such leaves are expected to be rank 4.
+ if len(x.shape) > 1:
+ pad_width = max_len - x.shape[1]
+ return jnp.pad(x, [(0, 0), (0, pad_width), (0, 0), (0, 0)])
+ else:
+ return x
+
+ # Cached so every haiku function below shares one model instance.
+ @functools.lru_cache
+ def lm():
+ return runner.model.make(mesh=runner.mesh)
+
+ def hk_forward(
+ tokens,
+ memory=None,
+ length=None,
+ active=None,
+ ) -> LanguageModelOutput:
+ # Runs the model; when a memory is given, `active` is required.
+ if memory is not None:
+ assert active is not None
+ layers = []
+ for l in memory.layers:
+ # Reset steps to 0 for inactive requests to avoid unnecessary computations.
+ step = jnp.where(active, l.step, jnp.zeros_like(l.step))
+ layers.append(l._replace(step=step))
+ memory = memory._replace(layers=layers)
+ return lm()(tokens, memory, length=length)
+
+ def hk_sample_step(rngs, last_output: SampleOutput, memory, settings):
+ # One decode step for the whole batch: split each RNG key, feed the last
+ # sampled tokens through the model, and sample the next tokens.
+ rngs, rngs_ = jax.vmap(jax.random.split, out_axes=1)(rngs)
+ lm_outputs = hk_forward(last_output.token_id, memory=memory, active=settings.active)
+ sample_result = sample_token(rngs_, lm_outputs, settings)
+ return rngs, sample_result, lm_outputs.model_state
+
+ def hk_new_memory(batch_size, sequence_len):
+ return lm().init_memory(batch_size, sequence_len)
+
+ def hk_prefill_memory(
+ rngs,
+ memory,
+ settings,
+ last_output,
+ prompt,
+ length,
+ rng_seed,
+ new_settings,
+ i,
+ ):
+ # Processes one prompt and writes its RNG key, settings, first sampled token
+ # and KV memory into batch slot `i` of the joint tensors.
+ rng = jax.random.PRNGKey(seed=rng_seed)
+ rng, rng_ = jax.random.split(rng)
+
+ # Allocate new memory for this sample. The memory length is equal to the length of the
+ # prompt.
+ slice = hk_new_memory(1, prompt.shape[0])
+
+ # Move the settings for this individual batch entry into the joint settings tensor.
+ # NOTE(review): jax.tree_map is a deprecated alias of jax.tree_util.tree_map
+ # (used further down in this function) — consider switching.
+ settings = jax.tree_map(
+ lambda o, v: jax.lax.dynamic_update_index_in_dim(o, v, i, axis=0),
+ settings,
+ new_settings,
+ )
+
+ # Get the settings for the batch entry from the joint settings tensor.
+ settings_slice = jax.tree_map(lambda t: jnp.expand_dims(t[i], axis=0), settings)
+
+ # Process the first n-1 tokens of the prompt.
+ lm_outputs = hk_forward(
+ jnp.expand_dims(prompt, 0),
+ memory=slice,
+ length=jnp.expand_dims(length, 0),
+ active=settings_slice.active,
+ )
+
+ # The forward pass doesn't correctly set the `step` counter inside the memory. Manually
+ # override it so `hk_forward` uses the correct context length in the next call.
+ slice = lm_outputs.model_state
+ slice = slice._replace(
+ layers=[l._replace(step=jnp.array([length])) for l in slice.layers]
+ )
+
+ # Sample the actual output token.
+ rng_ = jnp.expand_dims(rng_, 0)
+ new_output = sample_token(rng_, lm_outputs, settings_slice)
+
+ # Update the KV cache/memory.
+ slice = jax.tree_map(pad_to_max_len, slice)
+ memory = insert_slice(memory, slice, length, i)
+
+ rng = jnp.expand_dims(rng, 0)
+ rngs = jax.lax.dynamic_update_index_in_dim(rngs, rng, i, axis=0)
+
+ # Move the network outputs for this batch entry into the joint output tensor.
+ last_output = jax.tree_util.tree_map(
+ lambda last, new: jax.lax.dynamic_update_index_in_dim(last, new, i, axis=0),
+ last_output,
+ new_output,
+ )
+ return rngs, last_output, memory, settings
+
+ # Turn the haiku functions into pure init/apply pairs whose apply takes no RNG argument.
+ sample_step_ = hk.without_apply_rng(hk.transform(hk_sample_step))
+ prefill_memory_ = hk.without_apply_rng(hk.transform(hk_prefill_memory))
+ new_memory_ = hk.without_apply_rng(hk.transform(hk_new_memory))
+ forward_ = hk.without_apply_rng(hk.transform(hk_forward))
+
+ rng = jax.random.PRNGKey(42)
+ dummy_tokens = jnp.zeros((1, max_len), jnp.int32)
+
+ # Derive parameter shapes abstractly, then map the partition rules over them.
+ with runner.mesh:
+ shapes = jax.eval_shape(forward_.init, rng, dummy_tokens)
+
+ self.params_sharding = jax.tree_util.tree_map_with_path(
+ apply_rules(runner.model.partition_rules()),
+ shapes,
+ )
+
+ ds = P("data")
+ ms = runner.model.model.get_memory_sharding()
+ # in_shardings order: params, rngs, last_output, memory, settings.
+ # The memory argument (index 3) is donated.
+ self.sample_step = pjit.pjit(
+ sample_step_.apply,
+ in_shardings=(self.params_sharding, None, ds, ms, None),
+ out_shardings=(None, ds, ms),
+ donate_argnums=3,
+ )
+ # in_shardings order: params, rngs, memory, settings, last_output, prompt,
+ # length, rng_seed, new_settings, i. The memory argument (index 2) is donated.
+ self.prefill_memory = pjit.pjit(
+ functools.partial(prefill_memory_.apply),
+ in_shardings=(
+ self.params_sharding,
+ None,
+ ms,
+ None,
+ ds,
+ None,
+ None,
+ None,
+ None,
+ None,
+ ),
+ out_shardings=(None, ds, ms, None),
+ donate_argnums=(2,),
+ )
+ # batch_size and sequence_len (arguments 1 and 2) are static.
+ self.new_memory = pjit.pjit(
+ new_memory_.apply,
+ static_argnums=(1, 2),
+ out_shardings=ms,
+ )
+
+ def run(self):
+ """Generator that accepts prompts.
+
+ First runs prefill and one sample step on a fixed dummy prompt (logged as
+ precompilation). It then loops forever: while a batch slot is free it yields
+ None and expects a `Request` via send(); after each sample step it collects
+ tokens and yields the decoded string of a request once `max_len` tokens
+ have been collected.
+ """
+ runner = self.runner
+ mesh = runner.mesh
+ max_len = runner.model.sequence_len
+ batch_size = runner.batch_size
+ params = self.params
+ rngs = jax.random.split(jax.random.PRNGKey(1), batch_size)
+ with mesh:
+ memory = self.new_memory(params, batch_size, max_len)
+ # Joint per-slot settings and outputs; every slot starts inactive.
+ settings = SampleSettings(
+ temperature=np.zeros((batch_size,), dtype=np.float32),
+ nucleus_p=np.zeros((batch_size,), dtype=np.float32),
+ mask=np.ones((batch_size, self.vocab_size), dtype=np.int32),
+ active=np.zeros((batch_size), dtype=np.int32),
+ )
+ last_output = SampleOutput(
+ token_id=np.zeros((batch_size, 1), dtype=np.int32),
+ prob=np.zeros((batch_size, 1), dtype=jnp.bfloat16),
+ top_k_token_ids=np.zeros((batch_size, TOP_K), dtype=np.int32),
+ top_k_probs=np.zeros((batch_size, TOP_K), dtype=jnp.bfloat16),
+ )
+
+ # Dummy prompt used only for the precompile calls below.
+ prompt = np.array([300, 400, 500, 600, 600, 700, 800])
+
+ new_settings = SampleSettings(
+ temperature=np.float32(1),
+ nucleus_p=np.float32(1),
+ mask=np.ones((self.vocab_size,), dtype=np.int32),
+ active=np.zeros((), dtype=np.int32),
+ )
+ rng_seed = np.uint64(1)
+
+ # NOTE(review): `prompt` is overwritten with its padded version, so with more
+ # than one pad size `prompt_len` on later iterations is the previous padded
+ # length rather than the dummy prompt's length.
+ for size in self.pad_sizes:
+ if size > runner.model.sequence_len:
+ break
+ logger.info("Precompile {}".format(size))
+ prompt_len = len(prompt)
+ prompt = pad_to_size(prompt, size)
+ rngs, last_output, memory, settings = self.prefill_memory(
+ params,
+ rngs,
+ memory,
+ settings,
+ last_output,
+ prompt,
+ prompt_len,
+ rng_seed,
+ new_settings,
+ 0,
+ )
+ with runner.mesh:
+ logger.info("Compiling...")
+ rngs, last_output, memory = self.sample_step(
+ params, rngs, last_output, memory, settings
+ )
+ logger.info("Done compiling.")
+
+ # NOTE(review): `all_tokens` is a single list shared by all batch slots; with
+ # more than one concurrent request their tokens would interleave — confirm
+ # the intended batch size.
+ all_tokens = []
+ free_slots = list(range(batch_size))
+ requests = [None] * batch_size
+ first_output = [None] * batch_size
+ jax.tree_map(lambda x: x.copy_to_host_async(), last_output)
+ prev_token = last_output
+ # NOTE(review): step, total_num_tokens and total_num_sequences are updated
+ # below but never read in this method.
+ step = 0
+ total_num_tokens = 0
+ total_num_sequences = 0
+ with mesh:
+ while True:
+ while free_slots:
+ # Hand control back to the caller and wait for a request via send().
+ request: Optional[Request] = yield
+ tokens = self.tokenizer.encode(request.prompt)
+ temperature = request.temperature
+ nucleus_p = request.nucleus_p
+ rng_seed = request.rng_seed
+
+ i = free_slots.pop()
+ prompt = np.array(tokens, dtype=np.int32)
+ prompt_len = len(prompt)
+ prompt = pad_to_size(prompt, self.get_pad_bucket(prompt.shape[0]))
+ # All tokens are allowed.
+ mask = np.ones((self.vocab_size,), dtype=np.int32)
+
+ new_settings = SampleSettings(
+ temperature=np.float32(temperature),
+ nucleus_p=np.float32(nucleus_p),
+ mask=mask,
+ active=np.ones((), dtype=np.int32),
+ )
+ rng_seed = np.uint64(rng_seed)
+ rngs, last_output, memory, settings = self.prefill_memory(
+ params,
+ rngs,
+ memory,
+ settings,
+ last_output,
+ prompt,
+ prompt_len,
+ rng_seed,
+ new_settings,
+ i,
+ )
+ jax.tree_map(lambda x: x.copy_to_host_async(), last_output)
+ # Keep the prefill output so the first sampled token is emitted for this slot.
+ first_output[i] = last_output
+ requests[i] = request
+ total_num_sequences += 1
+
+ rngs, last_output, memory = self.sample_step(
+ params, rngs, last_output, memory, settings
+ )
+ total_num_tokens += batch_size - len(free_slots)
+
+ # prev_token should already be on the host.
+ prev_token = jax.tree_map(np.array, prev_token)
+ for i in range(batch_size):
+ if requests[i] is not None:
+ if first_output[i] is not None:
+ first_output_i = jax.tree_map(np.array, first_output[i])
+ all_tokens.append(int(first_output_i.token_id[i][0]))
+ first_output[i] = None
+ continue
+
+ all_tokens.append(int(prev_token.token_id[i][0]))
+ cont = len(all_tokens) < requests[i].max_len
+
+ if not cont:
+ # Request finished: free the slot, mark it inactive, yield the text.
+ output_str = self.tokenizer.decode(all_tokens)
+ requests[i] = None
+ free_slots.append(i)
+ all_tokens = []
+ settings = settings._replace(active=settings.active.at[i].set(0))
+ yield output_str
+
+ # Start the device-to-host copy now; the tokens are read on the next iteration.
+ jax.tree_map(lambda x: x.copy_to_host_async(), last_output)
+ prev_token = last_output
+ step += 1
+
+
+def make_mesh(
+    local_mesh_config: tuple[int, ...], between_hosts_config: tuple[int, ...]
+) -> jax.sharding.Mesh:
+    """Create the ("data", "model") device mesh spanning all devices.
+
+    Both configs must have exactly two entries; they give the mesh shape within
+    a host and across hosts respectively.
+    """
+    for mesh_shape in (local_mesh_config, between_hosts_config):
+        assert len(mesh_shape) == 2
+    rank_logger.info("Detected %s devices in mesh", jax.device_count())
+    all_devices = jax.devices()
+    device_mesh = mesh_utils.create_hybrid_device_mesh(
+        local_mesh_config,
+        between_hosts_config,
+        devices=all_devices,
+        process_is_granule=True,
+    )
+    # Collapse runs of newlines so the mesh is logged compactly.
+    mesh_description = f"Job device mesh is:\n{device_mesh}"
+    rank_logger.debug(re.sub("\n+", "\n", mesh_description))
+    return jax.sharding.Mesh(device_mesh, ("data", "model"))
+
+
+def sample_from_model(server, prompt, max_len, temperature):
+    """Send one prompt to a running inference generator and return what it yields back.
+
+    The request uses nucleus_p=1.0 and rng_seed=42.
+    """
+    # Advance the generator to the point where it waits for a request.
+    next(server)
+    request = Request(
+        prompt=prompt,
+        temperature=temperature,
+        nucleus_p=1.0,
+        rng_seed=42,
+        max_len=max_len,
+    )
+    return server.send(request)
diff --git a/grok-1-main/tokenizer.model b/grok-1-main/tokenizer.model
new file mode 100644
index 0000000000000000000000000000000000000000..d2ff64d9c3329cfe2296d64126ef4e23452ee07f
GIT binary patch
literal 2229219
zcma&P3!G)yRo8zYAtaSo5fTXRo5_%wNhXu(*K|*qB&WK%d!{GT)0gR<3?T`*x2kS+
z-I=PoH+8GJXF^5~h=_<-0wN;d6%i2;OI}9gaWaXBh*&{HM8xZeh=_PaL_|ctzx6wN
zcjtV<|4)7T)1UQQ`?c2EkF)pMkGt@)g-?1{clq3VURCh->gO#8pS^tLWAOPFe(aIA
zH{qYSu(0q63r_p|g%>Y?6L`tO^9W4nr3=poWk@N48NYnt1%WeTap9BA$ed4EcwxiI
ztXC|&D6ul{Q!@u5m5z{0x6sFKp^x7}pSXoSc?*5&7W&>>=+n2*XKtbIyM=yHqA|{k
zGQc=5x`{9jC%v6KPLVyihd0L7EM2fe~YRg
z!@ot>kKx~8oqL(e`8bx2XFu{69qBR+(XFpjku$J>k(ZN|wq
z<5Zh*Z<}$t%{bF$+}CC-w;5;KjFmRyT$^#e&A8BF9BbR-Slb@Q+V(isw#TuyJ&v{Q
zajb2RV{LmJYun>k+aAZ-_BhtI$Fa6Oj?4EINr9$@wPpV
zx9xGfZI9z^dmL}u<9OR1$J_Qe-nPf_wmpuw?Qy(qkK=87oM_wQMB5%G+V(inw#SLK
zJx;XkaiVRH6K#8(XxrmN+a4#{_BheD$BDK*PPFZDqHT{8ZF`(-+v8;09w*!OIN7$x
z$+kUCw(W7UZI6>}dz@_B<7C?&C)@To*|x{YwmnX^?QybgkCSbCoNC+SRNEe>+V(ir
zw#TWqJx;alajI>PQ*C>kYTM&f+a9Od_BhqH$EmhGPPOfEs%?)`ZF}6?w#U6~d)(W$
z$GvTP+}pOty={Bk+qTEOZF}6?w#U6~d)(W$$GvTP+}pOty={Bk+qTEOZF`(<+v9ZG
z9;e&(INi3#>9##ix9xGdZI9D!dz^0D<8<2|r`z^8-L}W+wmnX_?QyzokJD{?oN3$R
zOxqr3+V(ipw#S*aJo+aBlI_BhwJ$GNsW&b94v
zu5FKVZF`(=+v9xO9_QQkIN!F%`L;dIx9xGhZIAP9dz^3E<9yp5=iBx;-?qp3wmr_b
z?Qy2-~H|_D#g%>>cjbFw;xxOkqxzC5*15`iWfWAU%Jg%QO+<`2E+#D!Mkh>zM
zL5>um4sx&yIja|(|6ar|KGbA6A#CwO%CfJbIv-m2)Ps;JZGA@A!GcoompCDPR@w!30Kt
zI)O3ZC!Om$5AQ;_m1o^3J+n?FCH
zk)(X#vyD~zJBs>}XJyDO)xu(sLz&~1LH4EXtDkMmz6tp3XG6uPsZPWhNtD}D344lk
zr?ZqMcLzCpUPKxx
z)29-D8SsH5(yEH6B356P`|=bpl&!zQWhme)gPbV(S2U64a*@(n@mOR{fFEnp+M#5X1Q^hnMBo0HrtfNx1aA>S4?cs=^&^jWriKAAe$
z+l)^5?Ji06`5j5HsP^xSfRO^e%Oa>De|N;`?+ruv?{TY`C?Y#7jgwj1+4!!uO?I
zqhtO+LNPM@kPBDFA9M;e!yir}-WW#CKjK*G13#7uISZ(h{#PTjXH!_0u#cp2qAUDF
z$hj8=_{kJf&c|EJ<_}0R!3MXM^mkW_~Wn=!`#~vUwc-F9iQd
zW&ev&@F;5Umu##pYrm6<9Y;s}-H5d%L;PNleT{~{pNhHw
z_=A+99~l%M)@ZaZi9X@O8frGf*q<2eO*P)BC>0$$)a((#i&
z+>C^qQsIRu)^)&(BDqJwg-G!}welB#xEUYUVrY9wgp8!=OH*F!viHj(`7q4!@|5OA
zw3Ee16OHauQkv)`uZVzwdh(}6^3lE;?JHBH=#Q_8kliT#(^5zb1fQPJ;3-nQIw|uo
z;4>0XexI4ZBY@9xLAx^EXImK!4WHvu1~T5~274^6KF=}4?WA1J16~vS7`$%_^0lvp
z``Q%c0l@8%V?W}2zUlbgTFH}*!EZ9t_T?JX>@zNGYF~QkTTfp
z$H;wmvi434v~P5>2)ZX(XriioQ{+Cp77b@9<+5SZ9!s*{M$F?WU0;3VMC7<1ReLgp
zOw@xc7Q0
zMeV%Rric#zwn#8k^DR??_;xQl(%O@RoDF(+pnkVv{;ueArngnk=Ogu6GXP(RkWqv@
z5Jkk0@}Svahl`h3=v-#1@Ff9=+P9iBe}0c`jOPUnu@s*-R7}KGmKd8h*)FQ<>Mb-IK8lR>tx}N
zLX+@|l6WfXok_(T!>ETZPC`W)?~3BCNxpXn*;ktPB>6BJ@Jk|OUyf~PWh&pmJtk`ba+R6y)GR8RZ!NuceRpj2SwUX9-JbwLlq)TT$#p1boQEr^g|e1ZRRWF(_ikh1$uXz?wE^JB3{!nu>wlwYwFA25k)^K
z8$-hOT==i26z+ULDWd7ZcSfxJUA25@Yq77a{zgut&Qk$h$L}EQyHZTe9xIS!i|_v5
zOwpJsRRDL5tzk^*zdIt1^kJ?VI%@Q9g*@xZ3(`jD$3**k65P=QvH~LbQ!?`3cA-y8
zRcP3|(&u|4K`gQq0<=m4)!N^QG%?4mpqs_qFuo~=k;m&_Pa;TE9qmV(`n#!`M$O0t
z18qk)WTWqkL>uZgkhaeq@V`H0K8g871-2MZD%B5I^r1!%XeYa$fFDfBC`w}l*ki2%ve(`~U)mJDR-xbjq8vme2M|Cxo{!;>x6dnZ2`H)C;f25{Zhg^
z15GL_Folg`4fb!Mu)FSt4e7SL8K(VZ$C87lf<}%F`*i+oDon1q3WT2F-kC*;p_YIZ
zv|GX(Q1#yhcp@gO6~M);4%Et@i$qomn!Lv975_fi{XKnwp~-94F5_QGr9Y}F04*-O
zh$cH8*f7O^Na3p*fk3D_31`x;I-6KUR{-0Hu^lyv|Co~P>8k-E*&uuqzZOEhMXdi4Nwn6lAf03w!^UqW1)khdbyXR!_M-><*A#tS)0PUTI**`v(EIU8fOb22
zoLVdXTS~O6x~YI98pitLw^Dhd9W8DvFkx#kC;IP>x2GlxZ87^X?fdNnS7YV@q=pBY
z!~KuYb8Snj+{%DSujaMCV{=A*R{+Y{j)4D}GTV*JfXL4JHx^&NYayb$Lu13CXR7}z
zedadYk*dsCYd(^{=cHPXRL~?H?Z@Kee@DO|#x`gLc;5K?P7zZK)1z2{{+}da)OZCX
z4Qi=L|3R|gjqv_dU|VW47GnQ5!T!1|0z~iK7=Me8h#pDJ650)SD;69T*ttAcef_J7
zItOpZDyaB`CMg1x(5~xM&6_gXpb6KYH!R#vtNs7|hbEsG8DV%CsRV^|s(t9DkxJMa
zNOe!t&z~2h=0K&%+cf4tl6I)>QA0cJc3&Uw=Q~^Z#7k%gP2&3r#H=(Ha;tM#EFg+c
zN>Se_zeowR&dd4?fFuvbs7>B^7bU$QnRI)kfwlyD5;gF8s)9kXx8T64b6n<3>qCr}z*UA@E11VuDS!1%VRbQDg-Ga$g1$N>MtxW3-cVkQE
z)hYG0k)}5#m{R-tjGB}$s)n>;6HO34BanRXvLQ%%#Y_*{Wes;Ix(2#~F~4a({GXZ1
z^LSE%NzKxR0--5|^@VzB?T~A+Bk)-fLa&v~rIA{%zEU8UxjRtP2BN@FqswQfvKUfi
z8F14WcA9T39P4#O%?#v5IN6IieDOIE+ZHja@a^y72uE2Sm|{LN~N$
zS3*<9{#LA44MwajNoR1QAO8K?=ZbCh4uMxmmnvM
znNJ0jWaWCha@nkRdNA9vb`hJ$CceU|TnU1Deo8FIE*y|&He7;I(wMq-`HGaVIEydmka
zqP@lvNSa-HD**`F*NUfh0GkhFP~2_P?}q6CwD&9e`a-&n_`+-Os9~U9c%QI#NgCe~2hE~}Fxz}od?Q6%t#eQ})}2ccG>jGvjh6bhR*Rzjmxi=zv4TM!TQR|L%>Gum
z5Z;=^*koj;Ffb8gzFL`%31SV6>_dAt-)0ai>=GdH!f6L>ldgwOP0*J7ab=|GdlUtr
z>^63;AW42rYvUTaNgm6r<Ku7hr4%cv%xf||NJVUG=&Ar7oyN`7yoN)elbAuD3zhn8aF&o(Bc|aM
z&?VxFQ_BLAYOqTI#rYIr11agM_RLBknrc-^Isg*BS({k3_9si
zmBSf(CVR2bUc*3!yVPkf2B}p{32AZIFAAq&W?ZA=NMIPwcQ-b09tx3qa$1(qP2$mR
zxZNGt*N$b~(8{!Ud&=Vxc|*z!*J2(9w6$bxQi9VIuK651>^6HXzFYvBQGfTzaPS!I
zd(K~jQ1Ui)-b;~&nwFXXrO_q^wgI`owM_|7gE6!Bh8{_xPPDyV0xe6w@gY@k$O^-*
zDk;cWP5QCnb2-J|XhxzE4w+wzf}x%Eob&`a?P>^FbVK;QdOswI_G2#++KH}c3IVb}
zlV;Xmd|?V-mcSrVq9}So)0S?;Hbf0@Jx-!eKblglGEpxwo6uBb_Dk5}L{UQObpMGePxxTegZSL8CmA>FDlcJBnB2GOGe*?E&MDL&XS)_N=t0Q4<+&;VvG%k^tDuR;
z7Ky(kj`y_N??%X??cfs9vTS-^JL4^{V!+sg!Zoa43uDVtu^=Kj94Ft3Bcv6$=?Q$&
z)OTPw{TGscGvu=;2^xFcqiq{V;;apHg0Kdl`=B{YIgD6x-ItIqYBjuJ71%YeQ#$(6
z)j7tFzXwmGI*auWh*EHMAyx&zQo|Tzp&JwUEA(W9O=4Vtwmx`Y;)CG8wM|T1L%S+A
zG`JL7saebC73XV8Sd_WvSpmRJ=wS&9LjU)^mxS99(9~23lAJpsC%{E)>5ypcFpN(F
zD|V9H3{Wft~*f4Zq5vdxID*9Xj0Gi%G#KrDf}bj7xA+PJGxi43|@#W3wTnWQ)jS2J_3;E_vT-l^J=clYB9wi_w+
zG7Vw^6l+Hf92x`0=7`Cd158q@2y-%;lyI;olX*Y|Akj6=tsu#36o+s3k~toa?fDWM
z^%kC=9JuK&0rpKD9R+D>t=JFaGkhya38I+TCo;LOi7TXyycTtE-2!MgpoDfC+_2*U
z?5NjuGNT5PU!2{&5u)*ZS26%PgvX|Wc7NDzre`%AXUg#oD!F(nlJ162x&%0I6syt-
z>=HM0ilxrMOcON>Kb@kFy$1uPWB^8m##;Sdc@Js_4C0{1`%~{>FGCwcWoFz%tDwJb8eof5vyFN<$5wVecArU^T{5M*%th
zHMQ>^ki0nPBAWy9z`3ni=M;{M=4*^%HCWA!L2U+UC06%iKrg;5l7}sWX0ZptFKAu+=#)H}sdMLigkw>HtxPzD5}(
zfS7DRYf$Vthz(pdOaP+XsS~nOXhgaA2qS8fOmrQYxYN6DjEN*w*WStsX9x
zIB!{I4jbmNJng|D=T#M$lk-i&)KhTN1UQVSSvCyUOKjTDfSw8UxdagFqv9(fx2L5h
zB;xPKrjXfMFZ3WS{;GGF!8X+fH9vL4r!|g3y0BesJI!*u8^kOAWYUbU2fJew2Zww(
z3lx4-BL?X;V)sCZGKd|KoKlC1Wb~;+huVj!O^52PO8gUainwx3oG)SjB~1Al{DUtGgh69
z*W%=Bnele)LG-}NJcc-E>pB^DAu@&I3guBYO3#VC(4OUBPe%WAioLEL)d9MhT%>AB
zkTd6_7EZ7p9EuJfN|i%*2cF4HbLzADI)H4H$#%G;XE5lIw3aW%%wK(rLD%Ce-hrje
zbs4KP9c{7)Z86pxPj3ZB_G6C%15ZKLIKDJBP~!j1G6zrOGwT6WXiNnw1BNpi!Zau3=!W_)7!z+}y_w@Rlj*MwS?FXG25qTV
zbt$t1x@Db}X^mLIM~BT}U5)xiOqyZ`0fc&+@`cTyo#n+a>Yt0CI4$3ScF;Kddq9^a
z&s#4CvX@Rp-h^5+Oo6rvGa${tz?8VJQ&ux*w_%2snBp8n#++x@WTXRfQl=^xxXc*0
z71iDYyP4qo_rW-fe7m(-N;Z(KH>+2#u3q4H>Iv|VB
z;tqX*6mo+(Ko8QE+8e~er*gvKPE6;gAXK5Cm!hQsTe)O817Pl$H5XrR;k*RsfNb5Z
zIOtk}n*qhWst3c;#4*S!$G;ZmN~aDUYC}xpFc1;YLth;Ob{J@z#WzH;eQ(ZoKrVK}
z^HH!Xkvh7riFgkVB75rv+Cpu~Z-RkgSg7SVT?3?4yOL))^4Ocb0P@BQ6JbK$vqR
z|01MwS(TA`z|=<9Ri$9dG-=!_)0|*iXD8|$f%6Pekux|{JPXv=zx&3tKJoFAILy|8
z=|abGIt7U6`@3rDJwW~dHitkni|>KGaZ^}WCeGi~PL2%`cm}WnzVK80MRzi-RyxoW
zixWJoP&$a2IHf|n*}Na60hs5?Rg4UYG6PSyYLKn5-ncYpaFAo9xnJ>3Nvj*0vvh!}
zyjU9Lppg!s^nk8Q7E2n#svI6BoC0kFc@w-VonufBs0ox@%`w-%hS%I
zbFT+>shiqgt}-0ReK7^Pl_eJ=3}?11n`X2w|Nb;(HPPv{4$v+XDwPAqb_?Y=>zau4
zfYBy2ELQ+E7EKkV(1QhY~}axFYTos2otDHXw#&Lj!Y1304_
zahSnj?7dQ_Qv+I$dLzKPQwo7I2Qn2=hy3d)7gkZ|2Sny)K7z)@cczRUyF!5=YtdK5l|b^1HCPWCS?;-;;+#f@
zIxyihr@4Px?SsDgxx!a1rkxCfJ3ElxZe9Nw*5{4Cf3x4<>2ouv>@g7$W
z;*f3_71W3-_e>2WBY8^vEt=)slvH=^LBDe0rIZ&3U(z80
zxE7MOu0_=phMSsJ!>aGvnJ~*mXK1S+zf$oxQ^A*Ho(@n|r6ZmtH2H=rHICHv92bvC
zbk@poVBX3$(H`JM=w#EHrdz|G(Z9YZWLT#fn!BE~RA
zbtl6_UqnFwS9BaMZ&tXqqu|m&mKngN8tgQeF#dLu+?%W&Fe0-4qotPsPYg7A*tD=b
z`1h~pFI)u^kzy@7bq2%exHZ6bVRHN^bkQe${VI{>S7Ike93lBW4R
zAjbLoRC5lT6}+$!{nQCI)J8S^mlb4zl34+
z*-b}ZV3YMAt~ND^Px6UUj@;s-+6YcV`QBmMHDlnCi&rB!|p
zgcXpLQMz)9;RXRyXe8;cHYbs5CuYU_~EQj+xjt2^d
z^SOzME~r^zXoOUIox-x9t8omb&Y1lqxp=^c_H^vF{k>3P!Wn9m1=tv@a-vz54j#tN
zT+e_VD%7X|8dUT-okBN_Zf&nQE?dKM?W}8f0ySTf#>-WFFbd|7iV6nUdOOXB9fnn~
zWIe~G0aGUx01L9m)mO`#0-bmGaP)n3Y}&$WC^`pgmp0ake?N`1J`fdF?ZBj>pShii
zhK&9aI}bgO4a|u_`{OHE)I2>EX{G=S!Mv7&b6jnq#P0hHR#V+^hvSEnrhOki@4&E~
zmKp3ykPBn&bA6~ubuXtFaXU$4Gz?1T?50WwIQs{}cbL=bnw>|E9AlynYVi-EyfFul
zWOxw5Z|S34LN`r#6!(yNj>;@a2`eLvI%e83&5_qfgNbYo1N&)PY6eNBqd39zBUz#S
z#>dxzMQ}WnP(q_UetZq}JtK-ZM}Mn8ZUfr`En=ol*Vl!c8XC2^lpS?6GbG`Y8Viaa
zZM1^!K)ZF|&5XuY!a>r@3?V&8H{88Z%B#ve?V9@P#}uBeyrWrj4YZ}!eIRiL1C@LT
z`^O^t4K(OLS_5xXln&a)nmtGqpKw}*YN$Z2_O;~AoWdeYwizf7+B$GKHMVwVPOhOB
z%M?GJ#9Pa0f^O8G7F_hsEA=4=71vCnX2P`QfRPJrbL+HNeeMm-J>18V=(1t~#=a
zvxpsNTLPz=zQQF8%I-fQxqAScZClk_Ww@eG!W4)oc>K{2+Lm$;n}MmL_b`?hA4-{w
zwP)J_lB%!U)+IElm`6A3RWOC`*u4Qj1BNj*?x};A0JCtaosQwC4B1Ra
zuK0&GQp{@_T3^W$+VV^=iTYqKg<^*a+M?WuTYytA_8Kt7uOY2DUPX;~v&>j--UaxH
zlr@^ht;!A@wc&6omVjtR{`NY^7j7$nnxd{JOijmHpmroC@%r2Yni+GP$I|Dn_(w@z
z4)5_n?tooevFTF+DRSTEGd-+n@*03ZBb~yULEB>7=ZjtU
z;wMvD*2Fr5)`3Cv__mkO>PmR~H$@m^1D_!tr{w&f-MN+FJ}?yk~J)w
zcffaE8&orAuo2q+lT?_lv!i7Ph6~%#fL&&YyOy#u+)`VrfX+J(Kd6@BG((PmWDEgd
zQ9(ZBI|CqmY2aLV><+gyOgQ**DzxKekaDUC`zA88vD(=}%
zp+9|LF)lIO@*hSZ0E5v<_4S#7r~#jTEB;xk;8+58fQWj>UFZ%;v|*go>j5nHjhL!e
zU<<;2bT~Ywa1dl5zf+w<@vqwd2S=I3esgh@64U1Shuz^QcuONs2|$Ktp26JaHdX-a+SZ+}o
zn3A7Rcd4P>BsQ9L5ez$7uIYo;IOd9D=bIk~xXx!i4*FMM
zM0wAhC<|$elOWFa)NoLsFSpms5%rE%@$+HV@z|>`kfpw%k9!HUB)Z5LkJR;G**&Js
zoBf?|fmY5$H>5R6f?T(8<|%r94G)v~KI9C5D9o`b|6hj8HxQ)*NgU2as@DLJgjsf^
z?*ZL5zEF2kGdhoyQ>SCrPr4dln{et;AHf{;tV;WZ)Y`Ya721J;$@#@bfgpPXd|FXYCM@k^->#6?fibl|8D#5M`E&Fb^-
zy^JsVE)ZC)@kvdrrU2B7V{E5IE?1Y$FHWws0qhaNqWIc+pk<_;V}_qC@~
zLVE;zY=_SVo1yL=S|VBY6|2(9PgW?S%WacJMGJ8_?t(v$N(7$iE9I
z_VhJ|bTQgZP*rCJ6KNgM>A|1^sn5@+SOedWoErE%VhwPujW_l6m}Sg9B?T4F-7E+x
zC>e%rx2O0e(CR%v4fb+uHp;^Z3ziG+y-n=igM$<_GmKFgj8;|gQ$SOc-gl}U=8i`l$TJRP>F3PL_AT?%i5Oq`{|zOJVFUVV|6Wb9t&2hYjT|okBOCf;d-2
zhBYk9rVH6HW|^S}>#$$(Ye|H)XssPEYH}b#Ghu)=VWC0c4&xdG(~Jt>(lpVFPw3Rq
zuW;H{BVY}5-QKXH9*pjd!=?C7p%tg6(A#u>Tf3OhA%J>*58y0TVm4WU=?Z;@>1hs{
z;7O)5Fjg9>tSh%97GdZk5gnoOTj^D
zwf-Eu63=lsW$bUva@6L{r~e$n%;}}e#J%(^%usJ|b
z!K2)vms9}Ceb*h7>vSGXP49<}|26U($tm1{b|uKOs4sm9hupZRLQCu!;H$$n6#zZd
z6Js_$1>44}O^eNzm~$E$
zmnt|2wjZz2O>^wEI6{&kkB?*F4C8ZD{I^D7eIub=8Ef&_VhN_${WS44R1=I@H;~4m
z18&GP1))-03#p-zhFv8BW;xYzTMZSLF*R{=3VL2dp{@
z7!@208?QG^p_7;zwQDe??yKWKdJy5gWjBiAf26#aMsp;VqTyIA993#JB@AlHuBl#*
z-`}?Hu!7@ucS+`)8pd-G=vCj0zhyu3432xC+>x{l27=;uqD(tF#OT1V5-ipcOt-6!pVZ9(BUR9rcQE*0sr$26zd2)|AZcA+w`q?~1ro*<0@<#)rvYyL_AassAe
z2m!~1>ZU+^J$kS#7e05sCY5ve=$C!;%Mk%!#i^tC51D4H8pGdpJGw(UTMj_S6d*G=x~Om36o#9^c6e)mmYb9IKF~4CNd`~FmVEJh
zsdV<<;}}*44kbRSZv~|4lXLm{T7YPWc&IUA1#sT63X++oun?Pr5i~n!BuTr@Gehkr
zg}wiGSZ|mPQFOqTn=MF+1vV{o#E9cQ2TwN9YyB0p%l6rA4KlE78xF8(&QQZZSYF@&
z%mA*fjj&Vk`>DLueKqzBab!_E&<#KnQeBYh!EyTSxKUVT%!h>@It9y49HdmyAVl5L
zGp94tk7^(+{!bWIuRwMntuPP#P-2FcVlNDsd#ILDB(C5!9br$0j;1-uLop2pQfVtR
zf*GVu@wgwnDgGej>1%JW1C3c2^VF^ZmSSC-TRk8}=wM46%BjF;#lZZ8#!`oZ*1Y`&
zb|vUiW%QgGEb`-KZQRN#{%@4U2_M>V2M~=z139!azJ&35Aj`{-EyXv>QTwXuDYVso
z$nHNU)v*qYGBZM^^Ay1&@{Kw}9U)i*Vv;-QIUPtBF=3At`!B(0EX*3-cachXm}VG<
z6}mYEPP14T!9Zr+kf-1=2yHl8N6nZX>G{V4Xp1}E)G^&DSb4H7@T%evpjOh@FGt;A
zYwuOX5tuCFz1|tLTOi)GnA*>CY<62vuK0w=OPaGPSp+r3hKH|{Lm%K2O$V
zn3Rs72^dm#K`u_eHxXSO4#My~Kv!rRj|W;~LkWgUW7m-62N{p!An_FJ?4mCc;}8yJ
zcsNF!8h{xF>;qG2M<5Hw+HEa#oxyT}SNt5r9EJ-#zaMAwJ~5E`dncTWS8*nz-~n1k6~ID`Rmk>D>!P%dPCs=n%v}s
zq^<$vMwOfc9fB+uCTHefJ1TQj0*?T0e!8vLxPfwlajb{tU<4
z!2E|FMpZ~u#owu~+o5A|Pl>)=11$EcCgVpy%Z9T>bFdi`la}moc2w
z%4L-wx)GlPSRyuMI+R|$O
zBIZtG$Gph2u7b2hXrkfIILI7kPEpfSI4+2N6(l(XBNhuP+D8qbsDXa0;3y~MwM8FX
zb5uWIGY43b_i;QR{2s*%vQiv^!Arjg!b|To;3&hf
zu7u5iNd21EP`{9<)7*eHxQtc1
zkk)dQ2Tf=jJs7rnIMv1uXDn#;>nj*!#Bhee4geH-Tf9%G^OYYVo3!h6B@bG|-S?A|Cq~&NFqy{3qJ(y}uISU=1=0d8T%f`so56{NrnX$haGbYyhII_2g4eVY
z9$;f|IB*Vcy`N7OB;~~8zlYFG1J>w7${G&p;I&gHHt;!;nSrk_%>lL&K3AU$EMA=S
z*@#Du7QvLm6oSfzmXT*szTai-DuDi
zm~4kJt{woaGYw8(hp9txKWc!^Nq6EsgF53&F*CBDl0vwb!KhNa
zB!JpHwFm7egN0y!VTpdd}elu>QSOFZEMI09N?xd7yDo56AG(UQ$x
z%DEFUHgxs1C|;UMRKH#XAv$+4D82(gxY1VgCg+j^8D)Qrngv)dDof|0(HoY{u&nd@
zJ)p&*33&j6gW}pmsvH+rK#UCsIj*L@I!!i(ibbqUgHE${e#
zR0+qGc8Q^{n^Q9<;La^p?K!nBu8kpWqu45}fOI3C1ETJs-4-UJ@H6D7>v~w|FbDHS
zi7FpVsjMzkjYl9d+SW9225kvhqQ)0zo;ft-__P!+kJPwqwJ=`rg69DZL|7Bnuk
zEoImbcMZ@j6HkLq#dTq*YFUT%AgLyH|7ZkbfCv4w1hBsrwnfC%#B>ks;n;}F%(bT=
zmpr{yikxQ=y{W1o}ZWC
z6gDo&q+?*`^Af;S^UipEzsz_gW_4Y#b6>HuvIm3e7|LfhhDI~I4Y6OJqcS(vjN<@?
zRlTgq!PK$iBah-B@)>B8sy5B}E2yXCK#95xP2A#BB8VHj(h=a&
z!tdB=!q`-yWg9ey4*^zrE#@@cjJe~Yxd~Vr>Nm*7IjngY5A53IIg>cqec;%AeIcgM
zuIUF5{1D{I+Mr_mpneUnseS#*#SzfTt`6en%nXKQ+Zb-qHRc@b6f^B9)(xuPnlG
z14kbJG(-m$g)qyA)9gz?49HG7LX{3fGZDyWt@PVwW3w
z>*N61cF;*gU7nx9p_^PK-Jv6*0aXiag*IdGJIa{Hn<#PyMphn0i#9aR!5guZEj~45
z9}TtgTm-0MbjGCP0ERlc;u3(e*BOOMNE>)rPg!>zvfM>ENK%jV%Qs_aOyXY>teiyO
z8uhUQL(T+;L{mql#nHwNo2bp5t{SNJ&-2yT@(3jHVuKADwex#-b1)LV{!KL6S2_`2
zUVV!}He_;`ci@l>$81py+GC796?#uJwn56ANDr}VknF;-+T9qIffRRs1yvzk0ff*-
zD=;$SmeSCBkkc*e7Z0WmY~KDj1i1k>r|yuqh9}ESG#xksc*>d%*uk8^v?vq3t~Lj{
z{N?cf6|YK7o@a{Wl~{z~`0;~49iY@#WiCOxKOKD^RSQBXonz=i+G@Ot7x#*KaGVPd
zw4ybTQqj0sCxSg&i??~2XR{9AI{zzb!A4hnD)ta$6bVG|z#t5nFs(6eCZzY8<5Q7}qm94x&4rr?V`9nz4^V>110vE{~v5b6tQXdL-7a
z>drlQNc5oEG^CZ0=_*k3-DAgNpnbUmSQJc`QB4CZf6NvRf#h@BJ=8EH`80O|j&i(p
z*>wga58or411v6weI!ls=_%6(*M%Jfg*OaW-d6-P3a?7m)~8ACfWdSN^+Rl%|Hn0q58-v_YJht-rj`xI<-AJwVm
zLm(VlydH4~)?nsyfHeqpj(%9nydwuO^j_A+KFFHr8m@+`IUFq2
z52w&)M?z*g)Y>AzHohRIN(X2i#kcgbRW!F%b
zIWz$FfEJm#2pQ|wSjO;N{oL);c7{F`6-?{FPr>Os=6uped1{Eb8=6ud0xf;qLe7Dk
zn%f;2QjbwLvmAGsTA630g<0{rkv{WTkh%DnV24OI@$8gWHcdk*uQ75h5zltA%$Ni>wE@WqT-;4vBKfU{BL)m_RFbY%-e
zfL#||nnM>vhl&Fky_}SZjuz!(7*=Ha>E>Z)I8FPAH=^V-Jb-7bGhAQ|AXeXsr;rbG
z*rOV?YhZ4rcm{}mlw(YE;(P|Q<)4fr?{hG@dF#CBxcaZhDtOU|<&@U79gy{Y#N*}?
z42w0~;Vs+}NXBvQK@GSIwkTZoW2n^{(@|)&>v0Hf48}HbY`lVY_IhTTnmK@h8TECh
z`viq}}@oxXgr
ziy$ZDXkpA@QgFBomvS%;IE`UtW=!_tDTyu&4WN1>wFgb1`e}->*>QKGf<{umY;j<4
zH6FT}0$lsN>f4+HKZKVw_eOIKhiu0*#v@3K)I0<=Gin~Ao;%v*D|GAIQXX`#@V_p?
zaqZ}`aP*-LEE^Qh875x>Bj$eCr3Bb6&5In}9Q#Uqm3tXK6!$F0PRc>sL9oUN%fz_>x7JB0L#GdIMC>7@XhudzbSYG%cJE6{k+*Qn!$9Pp3{lYIndJTDmB
zuTA00>Jy7VhmX1H%6Qbd$zaH&XD(&N4d2TsVOX;D=y+YQOC0PmbMFB$*5;m`H5dbO
zmHFsTPCe=Cmje$R-OiLLKrP)1>O+uK#_8`16m@3tE|)LP9XSnM=`5{l2J%8LUW;pH
zgmbtUSv&s4?MeEzxNE%#Cg^E|?Lbnx9!bJbOAfgG9-20!I#j%v*>x<+`{;?Zy`XUS~b0wNL(1z!yKVveslsPvB&0{k;2`3U6zJb;s
zABpB(Js4E95p2>&(|2e|ylH6xz_lO3A(Y7SvJe_ViCWYf&)DO`BKS4xr+o{hvIaK_6n75?FQ;-GnG@rMD+Om^dr-8Cw~UrjDaywPyusTjW27Hy|Z*y@2;pq
z>If|s0aVG5$w&v%^@oe3&xC#nh85u2RrJl$NqO0W+`BnGZw1Q-)pO`aD=nJF0Go@C
z1+lA)_t`0<0f14%Yf)%~&?4dx(n4rkgLE|vS4CJc=J`kPQhSX}E~mf!Zd8V}$+);d
zABnPy*QfYb;wN+$!B*@_^CEO7$9O32{w?JgOrGN?TM5VJjjOfVN$%c^)Vx^U%h~9%
zE=B;OB=a5&08M9MuC2-U0mu~}J1tY7b7!jNuhpT`j@6%QXbY^@y0rj5$_X}ys#z^e
zGfS)PC}j@IhU&)HUA!R`!?ug}W*1>02t!Dy*)e29!V6Cm$ncUXt;{ffXR-_QhvGGy
z{f;5mS%Semt8Y%rwQ$L0Gt1X^FNYsZTPH5?bU
zavAN7n&%#7AexVMVdjvwBX=w8mM`v3mfqIkx<#N(!m%18?10_in#UQI96v^A!pocS
zId4X#gPW?@@&&zK&VgV)=>Qj!ca_h84OwhWt@}A*2V7>Ky6kOvca%aDxpt>KR~RZ+KhN5yLG!
zg|^_!+G9A(aH1pBHPDs86jav;jxs|pHz3FiM$>hhZSy%Cm$n^(7xyIVaeK>8>@C9b
zpb#h2nzdPu$0JZaOR$t#Dn9D1gk^cxwN~uj0yW|RUH0TV
zn++q7^B!tn>@Y_@@27FVG;96n-jUN$8CP&P%yPt?*zKA-baiAe^P7@PlQHzXOpRnL7L
zn2uwxQ)nA!FHWZ(f}6_qg^xN%SdVvVkHBt*{LU};k2K=U;8_mMTr`Tzb2hq+#SF4o
zO6B+onME-2+@;A+2inPb*gYzA367<}%38gaFx*}AQ$}47dfm+_s$P!Cs#XVi#tz%Z
zh!v#sjf*=p%!AC)Ub;3Rr*K@oSL6dZ1iDx2C{0|fs9~j#Rs1M`zA8sBkzhm)Ge|Uv
zUp1IRtE%EBPl{v7Qrg}3OJa-gY;G1~2-E>vQ94SjO>+dV8_SHZjl%hm@#C>a(gRyZ
zc_Ng<7!IP{hG8p6M2m;24xqiHNca506*6_~F*_57kO=Shlxlzly~YE8C=6}{^}};B
z7*boOcjnM)b@sFt$0KA@W8NaPji>9i&AxO8rt_wK>-&}QlYTP_>^gZ=3Uz@lTAMA%
z(u3h9sCUWzw)8mj{90Ui8sQ;g%$5&K-+9+P)H$?mriq0H!$UYo#7eFC{RQyN5*XAF
z;|QLmid#7|pmZL`w0fQchkDtkIFU+V+dB417vWgvwV}2=J7BexxH`N9Ag|_}MG11@
z4D-4;CmX=C4dh*=0roP3U+2+Y?-+)~n`mph0=hhPKw5r||A23GPhnu3OY(jmLQ9w!
zyla3nI;VRrM~=6z^K&z38}7pHFs^0H;WW)+Gk&6^IBBzRhXW}VL6+h{78hvkP=-j|
zc3lEF@hW>&$Wi9#Y;MzmyI}W5z7lj0rk*`m7U99>7o5g$Fdj>YxS9-#Ml8J!04|=?
ze4wWpa{(KD4jtK#D%1csRo$>_R#HbWEdl#W-ld#55i1nh_Z)x~D&`r=sJnGPziU
zbTSgLs>T!@hl*UvpemW{=mg5hF?K|jBIj>?9#o|j?<
zc>qQ#KQRJrgH6~Eh$FUetf(AZsm13kWN4}#>{SVkP*
z1Iu>cAU5V+0=stDbEXFXU1pr#L5eOMH;Q;-#|^Xx6J2=Vl^A1y)Z7~A_oge5GvpCL
zv<6xPKDvJRd%<8R>MS#@SOV8LV1*E3(2_$8kb4rbBzEO~ZsW)Z(==RhXK
znzoEVjksRZft^`2OB6bQ(^Rn+9#dzq<$3lY$O6VIHg%3Y;D;Hp;SmgKVi3whyZvcu
z+&F#bP9a}n^MiT!M~Y2-s~4fky+84>i4M@}GwZ``uAZ_4FLhjg-67sJgoXa`wifz>
ztl|UyK?5)i{vGkd_G17A#H){(yE0T8)b9ly02^z>0S7v}6XTL5_8}k@z2TpQsbMvZ
zjp2$KK7wKU#P8h9fYyNPP0a)C^PJ|v=BL}=k_O-OM8D>ui*T(6-&cg|fSa5Jmw=X)
zDQWbU5{}DyBu`7bPLvmWIgi-+8UyJ;@7E8DnY=<9Ie^4q9}f)L}Q0OMh+Du8tT@wdN+vt
z;>h9@+70Tern-l~W~|Xq3&t9}hG)ZbL@j>P^C)vP<%uKtaLAJRHu^%l7I~^3(Rij=
zBC|t
A4J9k2_Gw-4guy5s~rq(p)e;PC@wggjfDf@Ka()DEFt8=9ySxrQOjd4hcevMENZ&9BO2j!y2%bv}n<
zIk?Aw_RA?pjdX-BX2umJD`?n(;RLIC0(A-Kn#8q21(2l7N%Zstj4ljJ$}{vN=mBg}
zHmYcEGmBuS6hd3IC-4`<_zRMKbfrG6%w?j=sHx;i+DuJ$e{jx#KKDw~1
z;G;3I=z&p?Uw~IX9vdM}{0>`=%a*n#GGOQasLpduf$m4XtrR&N<}^HBK=ZFN#uE%a
z0#hC(=vS_0AotZ~mW|MJcs@zVfTn|(jE1sZaW0yLeph%A(uS4SMcH=Xq#@Q%fyV$3
zx2fK}_|3NxhJ8<36SAGq7zMT!Tay)fKrA}aqcP*0g`ObC#1#zN)Pp%Ik^@-iL=}Z2
z4`d6o(SqEEU^3@2VjPF69n4#d*qq=a2lIr0KW;hoBXMJU4z`8P;bO8=s5l>%xa~G9
z0qMfnalv>UFrw`9;2$-#1d67Cg-U4JJ@!JwF2DxF{SadH9F-R+&>7m9^4n>0UxF=z
z-stub{R4P5*%P|mGzGd^^-(@FLgjGNgOGx87Hy%O^`?ImJ>yAK&>ZY+{bt#PfWB{p
zwuUsGWgRo>0GnpO7fQ`w363lLioEn_4TG`dEq9g{I(V?xB%GVjCKpM4W@1!Da3K6b~f$
zNE~ekBIG{%z32yUqt;K1rF$GpaNPhIEt?neVA-l#QK2f>X1K_7f)3Nm5jZnK7Y@Ly1Rl
zsPPHq&LFMAh7L+f&ACG_s>&~>q}ujXNsDmQaP&L39cZ$-cnL)zJ;2gv(<%BArp45y
zMrO!e7_OYhHRknzvn{>78+m6rEz@32-{&w~dvOQP8)e1YqqJ@P7R@5Gg~K1p
ztRFkTk(E6qT>@#qiX-$T!0Pjppp9=fr)Oal)7%~$>P3&ik0GrmpCyS?WyZwkS`Tg}
z-^ZV&ZVrdyNub%*A;{WpYdfn3Vkf*kXb$cAieWbT(kw^696K>UMA0qX;^AZneNLKr
z$RfjG&P5@mFSmkRBHTmUDK&IJSDs!>2h5ZZ$A&9fBzUWv;#h-9gzL7Ul>RvaVc3@f`Q<7<~@H
z=-dUNsT7Z-3g;BR5^4ZkgsGZnw$njwp4@lUTzLrw8Ss0D_n@&se0+mNf8
z)E#BW<92B7nDILF9NL2$^{4Sz(~hE>6woyvKFF8=mRr;nd0#*@pL_V(+$Bisq1WxH
za@*w0kl6Y-RnXYr@p!%T0AMSg=jmPlNR{KpFT@-|drrj+
znO@;j`88YwyU37#lq2Z56xIl_qMpG;kN#F1u$qG?@T1ky52E!Kh9r1aP#3zX@!gSX
z70o*EERanFYYCRM(Qi0Qdl+QMbQ918NHzTubPw8QIj3E|vEwn=W7`VgIv^QCfcF$&
zS_1!|R@6B>w0rb2wGHX~m?=`>K({Ym)b{iZj_ZXLFT%`njLDODu;I+%Sf7mnzi+P*
zxp*`RSYZjf2&r!Fb88)trO{#uXBZu92_EWg$uV2zaDC+0;>Vf0&Vi3$xFT|lX5W+g
z9>YtDYX%d~#Z^r*w)(K!0jTLFZuVex6RwwGga(AQYp07xcc^`k?ZfECyM;
zmgM+mSbWWZHYY!#%hNk*ItPwFjMty10C%^CXds6mjKC1+`BDwGINA)3nmck7JznX3
zRj{?;l)pava~RR^UC~$F*40cGp_?v=AsL1{nZbD|jmt}!(QgjL!oY&(QL_tatz#(c
zIb`_(bcVEr=z8&GtKeXgt1%=VIHn#!pF-LkTihhHj)(B9WAnQ^G;&x7I^aby$2l^C
z&s=|ZS0Bn53=88kLqJ+6FUU~Sj|F`t7H6if>jxP+&`$EGI_44(5t-l+Rsw99Yu@PY
zW(Lp1=o8h0K?x(xLdMW;s$AIBXa;s2uF#qeKpL4IMqNnCz}bZT4goGdev_;=?C-Ju
zBh?YmHr!*Ktx{$={zIy~Igq?qeE^DgL>XPGY!T8X*>I=KyT*U=8_{nb0>O2
z-nQ^r*zphhM?DkA^+8YdWsIJ4=wgR4q_tSp6s7`VmdEsx-~lwHh7T0o9g&@Sc6iVp
zi1TZGOnTkC>O
zW;CmaB^Z`Fj9UT`hb;!`uM42Y-iCLk2iYjk=R2O17(4Xr3fe+yAk<001L$TT9__E{
z$lUZMhdnT5d*2Dx3}aP(1+IV{W6-y%8fqWI#((e)|$2p8edQiy==&^Gz<&Fj-kFkPA2
zm}{=Mz{5z7JwiXo!T53IZpVf_cvgZZ6*Zp$Q^rVN`YMO-BKK8Cc^P0F!ViPUGV|CySMM@vdZ{ow(UY#*46Qq0y@EK6Kz9
z>(hFoa0!~+m_SSNGACz2>ci_@N51o}yHOe%rSm6hY)tFn6$8~i*v)}orqnQa00S+!
zXig`c=GfY6M`JKzT+k_=x{0`}+4m8^*<9U>hwEpKxvDedb7%yn%(A4i-)&>)m+4Ry
zgnZep(g8R9UWfjowJkXg>6!FVIc3HSvKXWbcHum48QmU6)9TeS5PY1I=wW@-2BXMu
zf-B`9C*RRsDYVYX*Dz?TQ!?k_W*Jw*QDw)-as+V8VxCHSo`Ee8{u1qB%^iWu6`Qd$
zT)ZcW=gburfTk6pYU~Ae45hmSUjo`5jKhs*y@Y4=-o*p+}%;DFTNziipR)fNe)&k-d@$$pab&Ch_@$LX?vl*)Wq>bsRVf_(U083?V4_8
z$(fEH7WHAF)Rps8IwT_chm9)pFLu!z($b=u(oA94WP11&-NBf1H%@#(lZk<|wD+
zPLp;=XAb3C&o^+v{l+6wG|3}ul1vj2#XMT8OY2M8r4(mf{N^JOpG-GStJYWt|B3_ScRHV%b3}<4SQu9USF2-
zZLup`hP=LDSCR#@Cb^Ju#};M7oDzlsIy9u%xeYuD<72SB%D@xG1^
z^N(^WhtFc=5lj);Vd9Hv{?Q|fIT(I;7=l_(*cy;`nP{k2IXbJU6lV>#BEvmrI#t0#
z#;K*;*SNL}lMkIy`LYd|$jK(DPc}^!SRUn3C
zj>ODc9pDNAZ2XffsjLaY{Z>4*edJaqZK;67>-Q-#K{LVtD<)|WV4lUTIJwdQh6rd8
zT7y_01hNbCVu#Swf&3j9%e=1qfjgUK36$SmC3LIUm&0u}iwQiwFoO
zLq+%%kW)UUBnY?$A=OS5Jcg`|97$6FgxkuIG>h^foenh8-S>gO9!Qp=JLH2Ld+SB<
zoX-dv@{V&Z>V3vw^Kt2ij;0VOA#-=y={dj_6E6uYftFEv>P=-_`GfcWeDFJn0pNCu
zDI%{1GWsi>6dJ~>9;-`bFdTkck9JiElNr-65~>++S_>t%)d0c9YG)^4X9O1#*UAEp2s1nDHwbGWsG$0pXkS**`YMp
z#^r|?C@T~fk6p9wMR=5kM0NBtWbgh@xlfC@EW4b=4saQgd1XdHi7Q|gKT2E!*Z{%<
zhYv+TOIDJ`GdYDr)<6+O##|s$8=pVm;;b|{fZ%f*tVl4t|3O-98RZYzl!KLV{xDsO
zpF*(eAIm|`+`sd>ZQABrLa=aj5lW`C%9O3-rf&_!}Jg@NiOw=rZEZ
zB<{XIz@X!~F{}b)J`N402o``x)~5g(_7ov
z1VOUkM5cPNG01HC#fzyUK2ftr(7E%54-Gvq)fN_6G42%v*9f+Gi^g#c(W242nR0Ah
z>K1%5abuvZ!E_m#bM0NnLLC^szZ}T-dkx@pIf8@gA1Grc#HL!GY-v99^S1^mEWTj{e%mBTwo
zf^gBYP)*fvRv=k(TRO8<%Q4u}@&;NXlk}W`PfTVKXOO*!(19dF(S}ZsFljQNjWv$}
z(8A)1d^}q^f`W}j5Yb=XKR!3>@oy$~N*3l=LAQ9)m3GV>!Q?I^Kg_^i9y^lD$w)Fy%UkNx=$`
z)#mr~HJzhO9$=9smvyywv7>`%D7{@QPeZE
zYlQSvgVrgg&sdaR7TL-ZXS&b}&5RzMct5%V#iHZh7QG!+xdzFi22;8Q7-78$S+oFV
z?txnNj(;OZRdIl2{%t)43ZjU*k3AwN+gItvj-(YubwUKflsEZJmGCVke_T6n*aAWeXh>v|Kq{rpa&x0_FGx_R%RMplEY7vvlf>{0*yi
zTK{Q5klnKEjAS|>)wQIHo`2LfC^j`PigVKx8I;Bm*em{c_fSI@q~^SVSrKF@{g8(hCu)DT$A%mC*u4-3=*aABzQB$%qouyO+O%)w-I4RTfw
zRBcMXQ52BAH;&XmAa7)homK#_+uEn31k-AM+}fk>R(s!Ybi15r0aAgv!yL_D2a+|Z
zha|*KFH__=O_%8gP%LC_izv|%1b6QYR9{@YwUbS$D
zmQXA^z5K8OSzITKL=Xj-6ge6h?MvyhklvBzuGAR~xXMsbQ?HAfDukT?c~X=bI90K9
z-#JHUg&VOJ*i6dBQ;ZQhkX+fdJH(toB7ug?c_tZ+44ekdlrCi<3@|!)^b}Jd=mK&o
zBu+I=GX-xEo&ZB3E!2->mjE;D?qUTlgm&(A@O6&zFylWX&{2krRcNMjhzKa#0UjB|{Gi`?rAehfj!~NnYf9w}m
z{IPEQp!odp2fEjR+A~4V>X6qGg1V5G^fN2MaP$LNzR9d;VD_)xNs``xWXn+aXiEcz
zio@T~+b}i4F5?xu(&(e%A3stf)XG2Ndr!=D{NrVe+z4CPnhwN>Yv7+54Wd`Vsb;V=
zAnA>vSPI)QRGETp@q5~roC7J!6I!9AgTKmfW#vda$zRAC=uBDs;v_t!DP4!kSgi+?
z0X9FaO{eGWprw4nW!Lt{uK$yoHM3&uyU{hn-4Z8d<=CB$uc`o(^7g}
z2pJWGJEXEHzQg!R6|MqBwT~mbs2gC0dePoDz%az@OQV(+$f<9sYIY3O&T8+kmw)Be
zZj9QWvWKR9=#hV9+!T%S7-;RcPA+`
z4A%tVkj2)Rx4tRTZ#15)Q>NV^}#aGB97Dh(liu{VmG`p*1*NEAYPk_DOLk2
zIWXZt;e>5r^3fa(YzLB)Wv4zSxIHKok5g`_age`YHmfx&RPW!t`aK5ZK0`NSrVx@P
zq!aLSC}xv`xX5$~Mk29#O%AL4!Bz66x`qM+>?CnjwDdNA9kZV@*zyb~HkAxxnXZ;$
z%uX7>)S|0>Kmh)|7&3K$W{rUr2GPr(x74E!fSBd?evLvC3=2l?W5Os{b7&d55fw@k
zkvS;opmg7EG#ptX
z=UR^bLG<4ZFnrNRJwXvIkaeviPcb*|`0L9!)!73$am`aTycqarUe(3=Mh0AABik{+
za+D2FpecmZ?&?_$mWL=hv?NTnxF|H(n!@)i1&~Kaes;`Bzp;RCUJ?;AZbJ#V3fQaZ?aaM}sWsEj?{4GgvzAH@-r-2|AY#597Mv
z8iF}bRSE`6zgZ+9|6_5Bu$)O87enf#T7l%GjT`$B)2>@smdnkN{z8bXz-W(+s0GnP>R9FWu0ZHt$xFR|LSf;#wty(pL;F*CY!g1?5mhB;%r^ECJ
z#uSQEeXd8r=RgYd5)PIGohA<+i-8pc*UY#C0a9UjMvlV&R;0ozRS2
z$*%?eOHnrjSE5gUZn!ar>Z0wca|f9T>O76m3T#{Vb9Az>wZEW+Q*TiEr;0CpV4(~y
zq)OxpFfEOSv#AhL4WgAU|2ff`HS%XYjT>4yMmiFUJaP=4w#Nv7GNQ0~!Y;!;<}vaw
z7iwk3DXf+2DPfoGadfwHFbeIvixo5%5l`)BgNunBbTr*Uzo0#0@03H{7OKq?VczUQJ#vBNmIbiLcP2aDUx=?KQN4}oa
zgFupWFJCS72|Gz|?G3;(-sc*p4)c#R{;Gsvr>(6yjVEs$-M*#92D&)()zSoDyR?&T
zuPjfYx`1-g6ic5of5jOo%zF+nuk2218Giu*b$$*vOMsQR$T;Y72__;ZCl4
zk2VAwL2MZlk6H!IIkHAiqFpdr{WzFErEDuVD%x3jHY7jPLSK`?LaUklQaUIMAi4)b|j*
z8KD)TvwWtt^y5h9ngeB+1Hht@^8^`T8Hz>2XL$9{k=0Fqf(=uiO8!KLpp;V;3OrxP
zwQtQ)&$z~l9RQ^K%V>TYAeVbfZDZ44oxG#~Ey6B9>R^>@8=^Tm>1$C^2GoIOVe0lK
zsdSwRm#MG>!q%X6-e?T`+l4Ly4hTDkI49y^#}J}(c!MeV`1;5nh4YIW!seKL8MF^{
z=2yf;dK3SCcA6Xs`O|RTJ0omfK8rSf4n`EZ)I-r0j>%{UjT27hpswl`h?>2bwgOl^
zSJz(VyMsbH3?j4~zs|ouwzmVGmSzRv?+DBYi6
z-iNxYfWk9@o$L!b`VIth>?=Z;R2LHE=OfjS9>}uWk!kjGd>RwGXB$9*$8aSwBVAAPJD&MnY|e#28Ku`|7bWX82(Ko^j|*jF}dAX@wlRDPdhYM}rV
z_^VMw?rn(Q4UpMmUQJc^k7*xc(?}QH(PsseTDt2?JUWnT6_UAl8V&_(ISTEB7X<5&
z)ZMbzgasIIn+UfI%^X?HN^FF-AX(tNu@Glt+L=O>=nXnhoF~_WP(e52$8k8RmoZ+E
zXxRH;tMr$CiAtZ*79H6TgqWVR(0u+(ribpqnU&Zf9%l-xXel0^zJiJtMYUU>e@^Hbiweo-1%$(5y|djjKRUjlys9I8EW|(
zRLP$lyg{o~pv%O0b#>S^2EG#zUDr%N5Q
z>kxKfd@nE7m%0!wSr%>S2YMNA$I!8#aWm{>0JcooQBLOwhS1>LqKGXM2-f96iWBe*4{2AV4(BIRDOVK$d2~Mw6-zJpR@7rm6}XpaN?#PiY@dtGFgI1
z-X~#i0WPHdkFCJAuCVnx$Mp*aYk<|D8v^8g55eq&pJ&zK(mzwAM`NH40POVkb-}O<
zbn$T^5|w^4f6$Dfb_MKAUiesOHK*`Qy9hO4YO3USSrf!MBrNK;IB3=&=p!QUoGr(g
zzR3?xHf|PU4LKcF^QZvu-iCx=bU6@{0AM-cbfcR;e(2$7&oS}Eixm0*7m*u#Gyp)G
z0n5P8aJNObGlXU-vsXk;BQOd%@LMZ4`B%+n(SJ^X*@s3~urtZD^oA_7DKwWUJS$=g
z5GVlU;nz99d49%~4*CJG*=(w#US_zh<5su+^ii~;D*y>|hb5hex`SY~v9}WoF6;dH
zmfl3UH)MQtuKHB^XCrS7x8vyQ1{e+X`Ke40HqO#%xlKd*)?-Z-DnM)3*LmT)3d!YQ
zH3D1%BRjdZra!6Y??08Rj|R{pdG4L^CIoZN3KpYH6z?ze-e)KsNb)-W7%PGuu+^s~
z*qdox$AL+ASjQ1Ka1AmN$_7AWE
zkQv%j2h;?rDr@}FB1{AE%m2_HmEV!#U4;%|7g;4ll3fUvB#GeU0TY0Nh}gR$_kNCl
zGbFX3Sy$9JKN|Vgwno6
z(>KdsA2SP#3ovs?mcyrdFlPa@##uF|6GmX0>VO|&l
zG)nku-@^_FS_`doc)13}^G)V(43P?V+4UTgTVC-PX9J4M-oncoHUYM1PR^Oav@+~z
zDAdm1$AAYL>Oeu?^2v!g0L=TT9)^iijXh}2;8lz~`d}o=Q*W_x4sv!c^mfoNM}8dj
zVU+O|9^SBojF*2FXG_4ypM!@q0EC^C{VnFG4`ZA@1Dm%a`4$fE*Yg~QTS6qi0K?m^
zy%8@#$+)@B@RHwep;*N+rcjf&g5*{%`MOZWzstlBlNuT9BPVwh=Oi(YoLpSck%g81
z`6AEa+Rz3R=gIhtrACr3L!#7NTEaB)&u`pNHN7x{Z8jMzo~z84SIx4}saJam-KT67>umXbSZCj1p3g|rWCogxfcMxFc
zLd)T6!q#3c^Syt@tsm6UL5
z0*s!qR8%8qljogjx+M)w>(H_}F%`w;8jz6u!?Sq(wh3|$JjNH}0tnWk9&*(OZ77xv
z<7kB10jt}OO}=jap0gFNvDO|0i|a@R+s_}L^t9$V$Ui?&3p320_k4$WOcz~*X8?_`fjLu>}+KHn{#U2nc3Y20(%ZmSheC1Bdif7q2-@^`P&$P0WSekq6WYv;v_Mr
zku(NxLbE8C_~SkqzFL_SCqC7awIMmb-SF|z0aJ3VsZlEtOcz3`{P9jOhLBT9&Z;p6
zNdE}V9Xy;IfZ+L1XWfQ`t)Ht`Ff{@?wZmh3D87M0q27xJgvJ1KqdP_P1sN~@K{%oW
zlTUg*eFmb6^r)+If-Y>{|6Js6w_a%8vIJYFz-%bDJYje=z@5jv&*4a0ek-4l{*#
z=NyXVe9FN@zv^@k4W6{SLU~L7LZA~nE#Cma;ocUFJwYc%KNg2?H=$TGpUTO51?VE+
zSIEQJDkKE?Sm(-Xgv~z7c$(nU^XJFmqy*^nSzKYS76X6+V>Gs3V=c0kznAkC)qV)@
zreMc0S(;GW+6B4ZC5Me_D|!&k>=$}jsh{DV+&K+^DAB&%hX4|O_!xB)v_4;4Yq9JG
z>@uDf4pPTCN;7sECg8$C+__Alxa5qabsBUA!3_E;XAI-#Ih_|5s?ZC6#p%!32}b`A
zobXOs*S&=TH^};0S3sn`6(qI!x~YFhm_l&}6=|;lE{S}gQQ!S7!j;R~YiXeLFBbCN
z=j#nnVamGXPzJjccm>e0icJV4uh;8QevYkUT6FU{Hg0!S<{AVmQ{u0KEXSj>=j?OA
zP$QF=F!+i@6H?k=^UK)2=K<9=tbGA?PvE7UKh|bbOlvxjV2{I46s-%e_Frls-vgpy
zotWRZuC4mLNveFwa(fKSPk;RL2Q~uV`Ph)(!GQa{*
z!<{3dWxVxb1V9@eJAyg|pt;k1jy>fioJ(wj=|yJI3`$wf<%9yade$|)=>}kKyw)0fx?_k=O5HX*Od!|@&UM&-3WVFUUhIv{a_lSSucSH8ge`V-AS@tI
z9?mdMsrUwZ8O_~0${|x#D&QJQs=qHH$~^>YiYX)oE{zJz
zf`?QOb`d^QiscNq1vi0C4u5%a&R2n8RWiM@pMWYPO7SceAx`!_ocY#8js
z4qX%AgyfSvR#94zTs~C+W#5Kk<}S6!-T_jUe0#YIM&4=q+XG>7Jxxmb`NMc@dH{50
zJ_$~_hdD-5t0g0#Pvy8b>3Ns`CVwbBcq=fI*hZyIhs_Bja^^B5HbB@C<3@FK0}v3;
z8TYd=0f6mU=baWg?x$XYTtcw29>*<>Td<1MMR6;_5_33=zsp~0Ljt|5{pF61!S~hh
zgWY^`uTv+!5K8G^E~Icu&l{u-NJ+o;UQ>pk9Cf&0ldyV7h6NP|w=vqPgVi6>+7nFb
z(UiCjH20sywUh?fiKi1yO$b(ldZTbW-hxDaTvbNfgejpelb~w=lIr=@_O7GRjZnJ@
zJDba>s(rAz*w?sV03;O@BQ5(oY9uNqoLVtFBB?XFfno;54jsrCLP|>Q>I6tqtPk_s
zG=CmfNs-SCii@+0&&Sxtnr9MbW?qCkXFBE>QR+R`GfM%i`$e5QdxgD_0fy}(KnKoY-1cJ-;V+_7g)di=gTJD@UPFGmrwEX-+G#&yC|vZ3$%oHV?X_8l$`=BzUAl
zM!dKA_cv^L#qrG*B=g9V)#{G#AXrG8RaJ9Hx@%~KAk`<|XG}WSbLn3RnLnj^ZV*hS
zeaHnjETNn!dWA=|e-nz6JLcLjS_f0};!H`EF!IpurW(Uqj_P&PIvAB4`u=AFWaib1
zFy3fFptN89SmxWxzwSJVL!E61&KHN&6tx2`s-e9dcXJHnO^hDc+Bwx`RUb(1xXPBa
z2S&W$A~f`;dk8E9KLTTidzi)zL5q}?$Kv^RXlA{oO%XKj?@<&Z8)Q5MQxVAWpx$P2CQ5j#vdj_{qTyVYeLB-CPVxgitFbeqM`&62B{%tR|
zz3#zgbT2*0@vj!i>%g2eyaB~^_~|bCHc~4@u{J>-Qqyj5q}mJi6{sFm;3b3$PjN=6
z3K3z;LDdmmY|SvbKN6|~k%x{PQsV(GA}0mnnqkvF`?;1@{+Y=@u-JxT6&&h>c&wjw
zpi$~Gby8h_u6=8?4zNs_R!oG
zQVP?*$>}f?^BzAG3n=#XQd>KLYN?a7Ca=@{;c2We(-=UpZ9Y_AGzU_eea(y)4uAiL
ztjsKPv}c---vX25G-*bAD@VRk&vNI;y4@!8U*~UmFo0_BnB-5D|JRD#chmxG05Jg$
z6zB?i83Zp}Bf$U&qc?nZ0xAF-qLvhB}(GG34^M**F6HN+V4~C9XzGnV}j;^O>hJ>0XrSW9`^q|g@gnz=sl3cEPrsw
zf)(1$SdQ1xa8BU!6P6RzsxF-X)lLJ!+x!7fBJt}AoXo5^odI8WIo8A20$hV#;H&G{
zOt^<&HPB;4g)aT;p-=8te-pI>D6MYlmI=cD8}{EA{BIi3nmuN=Q~|m855pH_6^u=!
zqZ|H_Ly_*(>Nzrt$_(5aU=rT41GXl}8u~=m)dIqRp7?C%@2o&%d(l
zM1AUId=js~_Z=JdWm$v#YYPGu8%#D@#~3;I4Vuaug7Cs2jS%rTfArCXW_J@XrKBrH
zq0i1h#mmFba9^R#lzV7a3MYTzu=FfFzCeCt15zEkEK~;H^pjddT
z;b{fn_MA_-M<0ifs}QMJtU*w-37aJ?Vd#@O6lZ@P(rSQF^CQfcwZ|sNYJRM#Obe(w
zDi3#U!mc|=hlLw>=|D28$JyIbH%H_8C(Sx_qL)e8-E$ujrC{PuGK2g(g|IOWA;AaR
z644-z3{l)C8V}qU!iNn2=tAPF{NfooPN29xG3U}9uPFoya->4d2$MIf2E?4_n2hmt
zs&D}Td2j+fyuBHb@z=hF`u!z4tE$-mTT7O`nyNd30V_YVP6wbJIs-emzkV3GIkLC%(CMVc5T&sd8(sgXX*8o-&
z>%QKyszWkUTR3!4y=XwO>g*0&-DnegLN0F9c(_?S@BB*!~BtF)>L~ah~OeURdAEj;5soJjlr^Y
zodBLVD!K~CTT`G*_^YZ|F>s$jLuzl?PC*895XHHsTV42Do==NTnAK5SA;SNA!#oW4XQFa1s->gRDRd;{zR(=B6^4~YbQ
z3S%?JKYgKhYAaw@{J0K=;}wKag-CgJ+tho4mVQreQDM0b#WlODt;Yt?MQUAez#Zjh
zf=;+0l2#_@_z5YrA*kLk>O%hj+Y0AukYyL_+OUrh+%C=>Mzwh-e*Yq)ck*=5aVgI-5X#&)?{0qUOE5#+2L!XzUiMy>|1)iv;w54
ztR6D$ssdzlenh7RvXHfWqNR>Hl)RwG!+CM&ssYi`yI{aiOVWgZa5x5IVYG5O;VzP*
zx1o>^byzOTJ4SG{hwQsKwkDG5V|)HR-s(X0AZMq+JD(3ASij-4S`FwBQrLZ5l>@jK
zk53MtYZi8sDci-#EGSsdwQe9x2Y}#SQw6GIf3Zq>GbCBDDvS;WnMtB
zjqPa|vIL^-Z+wB`9sKw)iqDZg*O>Ya2piw}*0)eP;Uv~K;rIF8pMCx7sQlk7e&d*$
zrfM4y3dQ)9R#GcNv|u=H@4NJykSvRrc+=A+Pz6dZZF;c~Lx743b*3arUjw@i@ZkJ~
z@8Z;MA8DfNE*%~$2J6n@>Y$3qhbWn5O;xQ>jOqg
zm*=i_J*ckO5BQ`Ha&jzK$~*@U%tM_1c2hCTvEn&r+Jj6u#$7$1-9T{R!nIzsgJWn(
zdurJx`5TTsXiPeVU_mn8h~s#(oQfTe#^?~5Te#C0xh}wzcV`!SAc#HrV%;r3q9iRu
zxdNr}KA9BBGJO41{4)wUk$bo*(8jbQ7b<$fum*|5e)MC=L*I1x
z>vAaGB5WP`ws@FV8zORfp+NybltYhJT{6QRMZWE37;370uYS)|L|ZlOFza9o;jeXc+?ec
zTEo%;B7%9K@Y3O3dE~q$?2bU|1-cUwo0Kca7A;FUe$ny{lEry?)>n(RPRXeCxzF*L
zUsG;!DgE1p@^o&M_5f`2J!oQVWgznX-g~H-phc?Zh}B$Ipvd@+Pt=cALGb>`?o$*@
z7%|W{Z~>zZFqd)(;I~2`k*r>SZ4z{boTEm|K+-&y`*f-~h;*2qVHh2N?Q9>nU+CCh
zvQQ9d_JCPNoZm|u+I`1T^AP*4Kr_P>%^PJy$I$R!mjy;ZIG_j8s>V(J{H>HS22#lI
zK{Fw2fgE$IL}$RKIr6rS$j*Rfm^U8yc@B1SaV1~lF~3_tqlB=>^7j%z)_W}fQ7>;n
zYo~WKx#thWHfkQ`{yIP@1kXYcdC`2(!WzgWk?Wrp*En%LO)kZ<^1yp
zOc!*KU=vEw8pE_Z>Iy^`gq9G_s^Fp^I^C|MYL3H-4;2$h)xkEzW4q04K#;ZY`U9;(
zGa=X~a?9w*S7+NGOhUJ?HB#>QNBS#SLDvX@{iR=~R)dl#uH7Q`q5Q*RCUO|$h|gZ&
zYZuB83WB(Jq1pEc0CS|e{8N+;a`}8OMpM@@G|S)xd!5wuj9F1aL(`nX8{|F%kvPK;
za+(|H%sf>bU;~uLrYPXjpK^jit?Ml)RS3aVhTBr;U51Z!3w{lBi?4fuuUHjT*L)8(
z=_h=S#F3l78CCA#6Y4lYtLDgFT+0v~pH@NZfar76@#XZpLQtBcJ0<%C-G-AE>%^@&`=pP(F~vdB2>*
zdVo}Rw`L*CzQ15cFHVFEfG!$dJ(H8lAvBu^I|$*`HpXTn$QBE?WyS3c1X`(U#6BkI
zLep|$Kr#SGNagS?`>d!R1ZT&Wa^^cTK$!e82RK=dRz@Qafn+~NmP>*bfKO4QeJJVb
z%Fp=|#)wuqD&1(j6@Lc>VaAvn(X9tL5wU>vNn7A!{x8=#b_4ndN3
z8A>W0Run4TbfiPh52nth0>x6pi4+4X7E4rfs*q(I_;&!+KvwWmRg!v!XBsIqGWT*jMcrUZ#_QHJc<4K)qWiHIvFSB~
zM)o{QCvO3Be}F!Ybogrlbb1GI1#}5^ma(d=CgwJO!Mq!d11a^>@HTb_pwthkj%$LF
zxI3tOkki2xiMQHHf2(N2FN*CKv<}f)-{-ctvMWQeydND<-$CT|gi2102bfvBQc>D!
z#uzv~e+_J#-oA)qa&-u@Hr)bh5O%uDKSs$Q*HUiM%C&$8^p213d^AnioUw!$FLZXG
zz{@Gs6MX|n?!vo&J^$2wV41WX9UpPKgNCp!dn0l&im<)S`HzskJ8){Hdi{7
z;87eTONa&|{CKLpzJu!O#ujO)G5lGrA-j|u6he6S;M8s^9nDYAXE4cscy
zzA%b51n1Ah^jy0WolK$UjCUluP^^tdl0gqh?z@H0;Jz{4s`uXqV5#4}x`!Z45Vi5?
z2;l5lwN{n5fq)t5ip7{9Vvn=elN?WOUofI61ecgqAn*)K9<+-xu{p?vV2duAwgnWW
zyGzAS!KY}Rpi_Am4|}h`wksBbIJyQycaU7xBYFH-`{R8K?B2h~skQPf{kw(C_p~h>
zd~857AG}l>bAmDyC-z=hMkjyx{OmIB)grn8s;PzK>Dya18<0Xkn2lv%0X+!FS@>eXAE6MCG!g6Qmduo45v
zkWB*&0BEUUhhE5oLW0+qRL#=AR|pKB2JGpR#|9*Sz@_lGz+Z;q3=XvvvkAloSdcim
zj7yvq2(V6@@(CGMjTj|qCRoegayJtT0NO-+aH9dX!f=Pi)c|_nCZyCCu|hz53r{U*
zE;k#B3=*TEYD4rOH!Ww_MGR)a4-}9q6!Zdwo#819L)3k+Gt{>AnQqVa9i1ULT>)Vi
z{z}a<2@Ihmk;|a~Z7lqZAfle;0S!R@J!W}!H8F-ro3Z>uor#?QQk_92lyC~c^4yaP
z*ja`LS`V26onbuh&7cFe7ZA<%=>=_}COOMY;_xwE!ekzh68A6d%y|U`QE)&S!^ppK
zB05e0t)ZZ}@Z~`J05YGiylztZy`mc0+k@9)H=vRv*ZAoh2qT9rINT&ik$IF6p;Z7b
zgBCT?T)qm?>El*SZDP$|e`goFAV{KY_R%yDCMB-6l57)TS#fU)L-22m)hKQoVBPYp
z9l3Tg-ix_aH)D33NUH}XE$So}SI|2U=AV3u+6kK97>Cm+BjBN)+#PzZdUyY%-?V`s>S9`Dls1llHqf^4F>3*wpX2cvqn+bD})BAw4O&Y&1pdbw{
z8cFBrQ}$)0?i`sxO*Anp|E^nQ{C)>Q;8A13iq)WOcm`|^@wBggrG
z3qcq{j>G1$E~~(P3!0^M%KEmxYUi(>oU701fYEX5F)aZ>^U2CD{)dR&gGBK?k?8|$
zp4_RA$LAoJtMC*OXLN?plA**DE~Gw!h;kWz&>Rr9z4l_gC&N9RO_*fJrZ$!UG)wXY
z5C%1a1j7%p^?Cl)Xa6Yh0+M;;?z^@emJrM%3Q8lgTWFU4(|EmcmGP^%b8rWCbNy{L
z;MO1*E}n|K2Ra9OE|s7(35H&=AhAJ^gt_79?T|9qaq%yx-S$n#{Z*u0$-l8<5j(pO
zEEpaw!~}q5awp!Ks)NmxR@E9H6muH>*_r^d`N`P-q4pg2Cw~W(gOahubglzN8$bPh
z(&^^kC?9QW4{UKAA9C0vyu9>t94=<^%>bMY%>{{kz
zwslqL7N|x>!-*APv+x>^0JH_hp52l9+HtSrI1gc%e|8aWNK3z86!?@k>dAK_6Lx(k
ztPH_vGelzivkA6TsaaYLA^l2@^GV#btb)~y#Tcswz{roOpml=ettM5K9#k}*FMVR2ng?Tut3}!5QXY|6a5oR0?|58G!heN&Wwq;q&|gU{tna`
z&N4g~oC75y>D|HB7S)|#YHN}6c}xZ&Ww!+^yA)^PR?t)1sOK?A=kDFdREF3$Sm#JA
z!Nv^W9t!#XCDzVL|3M*X=Cru}KG$?_1EPzTUN<1&GBn%6g9~M}nd2Z=R?$u=Z3P;|
ze;V6~(H=lBf67_=0W~PDLrob*{v9vpvHRcf?~IU@NfYEucj$Joi;Q{7hI-ou+fEoA
z7KgAp5K|Rmbf~?dE)?oaI799cMk810JZ2wgjeYEA4F?dgVSHmDh++Qq$q_G4U=0wQ
zHVaAND&YnSX@{D|gsttZ7OQWQ9Qo}0v$)YXg_8OL+{H&XI?JRt#pBa+NXY0pb3Gc-
zMNZ|}+2u1W94zxEcKb!0x`mW#YxFQHCvm2=VvHgup$}BBH59ns-9KP_Kp3uhX(Sf#
zN`I%2NK0q0v4%{*SG+zhLvUt1g#im8vz)pFgB2
zVgY0V0j_n~DsELy!7y0dL!9|%25v77wNf?*r*g|Trkc)0j)GfR@DhyVxPL-eD`odR39dOW>t
z6kA!0070qAWIP_ruw6*-zz_gk6Lzs*%V7z{?x)XniIg4yjGes)9EE`RL=Y}QrU@{VRJCj
zJY!~smH|$mNsN~Ymi}BLXDut-0-YT*Z1p57|4w>TTJ#LG?2b=#`eY4u0eOit-Lboe
zMwVD3wp03#Lb2b+Y&QrxZBo|aTN#1{ze`S=`R8Z6qX93KOlaC&YZZb`>LQ+%1(MH~
zKh~kqI>3zK77&AMfHA0t2k4t1*&JY)Fx2u7FJZn-(97HO&)ODnz3VtWMSg%ec0pEg
zcogEZUZzkLVxO`PC0mKT*9{=KuqS?2dkDb{eu`sD8vGpL0~n*YHxMWew=T$I3`k>u
zesO4K;s_^kcY6xJY)&;6m=QLM`>{Su4bPvyTi0FqbNRHPe&-0U7CC89^gVSHi=OV+}%L5%B7N%c#M0TQ{l8BE!hO@5iZis7fq&6v4yXHlb&Lh
zzxXIF!Op>!0?T``i?@JcMz1er3HgI&obtNOA8-OFt{|tbFO1;Bv?9n%yUwjjp>2rH42R8KwS^rhq`b?I
zU4Om*gdU?u7`sUgY9C~_IbX|cjCwSHr_h>0tY=k6f5SXa?R(HlTPK+n|>5Hy=Gfg=tzy`^c9&3<^&+5kt3r=tvI
zfaONdbM{ih(M|u(%_Kgq0FW{!so`=JWM;P5I3mBAe?GIb<$6MwaUY-oa`6u0Zf6s$
ze9yIh+j3ORJISRDa+zLJFLk7>lRx&zy$g21&Sk@K{jLX%1i1#F5rDyAJPJGjAorjB
z==Xm1qrZyYfhaN>wGm+#0pr$D?HdSY^YmJCvoW~Pf_nRyhD{(^Ow74_OkxU&M1Mi&
zZfAsxME#UKgrs;iAPd9dq1B~7JyiC$gk5yZ|4i~QG=yr^OotlU>KjkS5r|*pq-;`CB|Y#I9Bwic9>lk41GL
zSO;796{3F1j=K=u{%LK6Npl?h=s|VjZ{9Ef=^O4OUnv8iS$WKnU@QTwVriN(%3p4)
zls7=j`||kAzqcMx>j^s3?Rb1^0d~qP#jv|X
z&CfC1V&8%-VT}{QugD4#oW9lI^DZaA-OyMyUqgX49y_5>_W+BJc1s6SOMkB*Y_oAl
zE#H6u^ZT^#WrAupS>My>FUO1t1bWYeD?e3Pg<$42pOvO-P?UGfspe5D+cvv)K*@RMTKuSvEESwZwhT=gtQMitl)Uyd0
zeRx1Yvy<$h?v>_
zS649t*$5v*?YPO<{+!2-O>rK{1dP}@(4pxmNNobWB?e%Cj|yZp2T*9X9k2kx=IE9D
zBxwnB2`P;sV`0YA^;@Vmh;*fVmA}-5tXMR=g9P($q)D3p9KoASBz>O~(e#ogzVwGh
z_S|sO{R}Juk_EDNh#{9VJdYJ0AiTdg`3+2faMFejgr-N7t57X7Zl)mMnsGdS9nC-3
z1>q3AG|rF*&4Z9J*BudS*Wmm@GwELb?4_2T
z`aqikFRH|1*8mD^vYPIB`49|SIw*92fs;NG6b|
z-k;IEO$ovtH+D#G25>W#wp%IcJb!-OW4%J+SwM0{u(*{1G!pPONA)d7D4dNw7Ft
zwSwg0(J8-<`}$6X<3(EiS%WM@j%8v7K={(iEGF1HNSDBZ-UK(M->pXuOe6-@EZbufq^xyC~?2%w&0zBto;g$CG6GgpnH3r#5SWp^Vj
zfCaJqyXYO{5-_bo@%cwR8rcOoCC)0J>%wji3S1mZ^#m>VIHqzDn?TUaG^!-IA(*`I
zl&@*T2$ZCG(vMRnH&9^goOMYWet?CqfKvLs9K~zoHoIKSDDNJew1x|d~?k<11v-<=s1B=YkgzO$t;*||W
zQY`(aQBqbuqppLJY&4OHQv+paw)1p~Sys6T(dn>ir6ExTf;scaG87L5&NwPfA!`5&
zm~$b>p`O3r!*>jG^9F>XmtjYYL0a_NO~_;(Q~c~gs|C?q#vQ9@>e|rA?P0vr*a6s7
zzWgg#1IWB;-%@P>1j~iDZ#;A9L$aDKwR<#h0*|N=Fh9%*Xdj)@k22wn*THWfxOi;T
z$|S((?5%btCWPTkGe68BM}&3i?hI^JDiF?)}lDE%g7#h2C`UT;ul*D5Ur?!7@Ry1Re}_P6IC9&5x3@F|MFi()dXDzYKIz_hNF{?{CQi=kdRR?jJySa
zZ;tCx%^k+!A$2@|hu=uXB9OZNjs_H)>H)07u6AUjf
z*n#f_*yfyeHArNc6WP@zUi1tBPFn0`Qj`_QtTJqOO}T@VDj@!1B#$)|mzQ_hS;V>r
zBV=Z7D5x|m>Kn&SuC+N7#4~
zXsFv6<2VPMcEIrUwx=?7{VjYQ3HCC4pnkcZKZRdxX#h57xR_C6ID}y4k1!yuD2||D
znP2Ft!A*|NB@=axV+a(4`Jp*cIK(6qY-~k3a>AnelS3T!hd{u)wTJSGs8*
zjL}jrOt1mA6mYWW^|dA>v?|BIm;k^AV^7=u4THlHI{+6aj)`Lej$Awx(F3}wg_jLh
z`>+Ei;5C5(z@1%M1(PU;5Z!q0G%*5@`O_9s44G~qr5emU3L43dp;^-3!S4>;0NBEN
zbY;JL(@cpOChX6kAf^X8$vr1brqnQLNO9iTfDgA-&f;I?ay${)zXy#~-;|@$KuCvg4!qZ^@$GGnfQf)tVCm9t>fC`~G;Cu{*t~S#VwU&7loEq3zAXJo
z5I75+4jFDhFt;DPxWGRPs?U^DX3c6CH=&q6oKS%)U^0wWxJuYLs1IQVS%V5!~8iPd44s>#Pe#mGDQvenA
zE+@Kj1kuYhDx4l$qH*a%bNX5jiEaUk<+Z2QX9y(OXsu`qKrUOnBY_!Y!k%xUjPvI^
z8VgK-E|OgPhsVAtBvMSPEwh{ePcQ5E=-fY3O`{zEQ7?X_N+Uqfj9uyK-Yw82)Kyb&
zL#&`7O&^86`?>g-+#2YTo}TGy9;N^#iCbC~sr08s`CjR^2aP}`=msN%G6ZLMp&3B4Fh3VmX8G5
z1c1|($4O~p?w^qphPVJY*+&}qEHiu+_kn@V<&AHSfW;DeciIo-
zyDb&$9%|7vJXa4;`ZKc^)p`Tq91a+5VF=(tZMw%B%D-1_>j7hD7QFohn}SG`$e}jhqOyI9lF>LOyIUq>aoL1f&|xRhuxJGN$uF
zP6uq0<=z~pc(oANb^K=<_w@*yBTur?MIVyHcNNcb4?qP|`GUw0qEp3@CS9f)$`Lfn
z2)Fo@{w8BxzrYqUUjBic9!$VwqnZFuQ;_+vI||tZBqZ|-?_mrCu@z597XWj#6Zd>y
z$?5D8BB{|-wVu*GfUF7~zoVEd2v$04bsSSg_E7KgN86mGSm%!(`~2KJ1W(iT+?)Ib
z#ky?iKQ9XUX>2%efXzBX8^6O_hJ?K;8QJnC5cZdU_T^v0DgYE~hc%0;>3BlME^7eu
z%urB6-FhZm;hc{$HXvB)jM-Bg-h}4N^|)d*K{wSUTqOOP1!-Ac51P#vpNH{Ea36{@^CNLY8bG2n+oU%nNP7EPQy39W27(!L
z)tmy}WE!2q&u#h#0n
zHcJ1X_=2HzTp@=*W*;!Zp(iYZ$n0ryq_UY~>%1-rRva7aeee621<1+A-ljcln8z2dx;**Fm;vPMj<{lr=D{utts*S}
z&)F4|_e)2K;oO=G6i1B1~14?Q}yrA2JK#8y&
zjYnIcqM~TT*Y>ZE^vpem0J21ABj`(I+THvGyCjUgdSJ_#D-8BO(ud?ywUU>L0Tgrb
zDQj1#X$Zk``0#>^M)~(wZ(?)m1_ElNf#uh+0VnhU6M#9Un*~peV{c9xnSn7^+!$gE
zge@cPc!obY83jaV&ErTI1Q2<|!|^l&prm2F2uAuVD3(FGG%KrsX46*7%Y)l-Y-J6V
zoG)~->)zp0P0UOGMN#;e|6vrF@I)>vu!>9|HT6WB5J)O~URDSvv*f95hPN}L
z(dEk?L3reFK5YSE@;mT^cmPWJ$703f5Q-InQy?mK+@%~rG$U6Ug5Kmf^1B#rSWrDg
zGNE@s^Mq|_oHsSI(@aud8!u+fGKt3eG&;&TB=`ycMxG!opxB7^n5AO`V6*Z{8+*4v
zO4+A^to$w2Wrur9`ti=+svC*^e+@QUY(X(VM?LqM#D=vlDqsXZDU4uWm)9`@2+s1;
zm{$U=W$!er<=un^8~YrhuMk8IJ?-CBO+I3bQX}Ycd=gW*x__$v<0ze=m2kzGPSx!u
z6k9FIRW}?<-mRR3x?5TRFgd;VO~zRr!mhU0_9W1SV5XUBU;;pD>GDstJlXe$Te6w~
zVP_RS`jcm=p<_^8sPTj?wUhJDF5arRH<`3APY;sH7!r2Hr3|fUOpMTlcC@FUq7rJO
zl0&h?nL$Nnhh#A)DDM8Aj@&K)&d;Ywq;yCwog9Hcxb?Tc%vG@!K?_#<_b44~Y3lrg
zOnRMT&`rYXeU8C(yl9z9KP^=1%h3|d1~jKbEx1&TDno(ckNI(vpr!lxgqExVG=FS&
zXxvzZ;Gvt`U2FWA=KD2%NKM3x!jDfKS}HIb265e~0m7dR##nLV*KzVxyTFIrc#~
z?7smZa%wI63-B;DCf)t0FcUERaZMfF6L!japI%e285DS;Qc9@v{Ov>LbZFn-Udma<
zlCZ_fU`+G+TL>0QQ(e;vm;@KSErxLi2_1dyTd19&%Fb<1a=p)SS>E;fWa+;wbiq0V
zO9C?E4QS5ttsDrUbOeq{7-lES*eX<0Y09A*3csdje4ldN~qf
zGYxe6IZ`^41uFw+Xzr)Ki}Z(t&G}Q^iW&K*cqs3NuqDB^U97B*p};gtR%)3h{*+T~
zk@XblofkS`Zav}4MXs0B2IPb^1=28X4kZl)vz(^j{Z|+0@bsA{bdP@kZkgZ)y0kNf^#pR_hU~0-}txYt-8+z?DMRnWXcYf5xyG
zjMV+59#iJ02FLh19V1gipB|)E)=+;XgmDg
z17TC?3>+!XhK^mZDSd5NSj|e7fG1aNB$E!hsq>yf
zBK74hS*9G^#U^s}7|5cd
zJ1As$s8DBz*#gQB-s0=k|NcNjwp+(4F_(mseHXzxGf64)|W{&Wnc=D_T$u_~@
z0U@=)@cGR@AN&pt1Hcl~7NCUE&Cz)Sfu-AC#s@Y7uw@tzUJrm2O7~IexJiEK-+jyv
zTL8k|Go6T`Rd8&zI^5TxY-7icc@%}20BwqCZWUer)M?-sOZ8~xPe@~1XVT_C7Y~y}
z#|uXe=Q`-$sO5+2Zb6jRtCTT8TA4y(m~{A_xr1WO$