A Julia prototype of Arora, Li, Liang, Ma & Risteski, "A Latent Variable Model Approach to PMI-based Word Embeddings" (TACL 2016) — implemented as a journal-club presentation aid.
- Builds a vocabulary from a tokenised corpus (defaults sized for text8, ~17 M tokens).
- Builds the sliding-window co-occurrence matrix as a symmetric `SparseMatrixCSC{Float32, Int32}` via `OhMyThreads`-parallel dict-merge.
- Trains word vectors by stochastic gradient descent on the SN objective from §3 of the paper, with hand-coded AdaGrad fused into the inner loop.
- Empirically verifies the paper's predictions via five diagnostics (D.1–D.6 below); each emits a JLD2 record and a Makie figure.
- Optionally reformulates the discourse walk as a continuous-time Stratonovich SDE on $S^{d-1}$ and integrates it with a projected Euler–Maruyama scheme (tangent-space increment + renormalisation).
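
The co-occurrence step in the list above can be sketched as follows. This is a serial, minimal illustration and not the project's code: the function name `cooccurrence`, the `window` keyword and its default are hypothetical, and the real pipeline merges per-thread dictionaries with `OhMyThreads`.

```julia
using SparseArrays

# Count symmetric co-occurrences of token ids within `window` positions.
# Hypothetical helper: serial sketch of the dict-then-sparse approach.
function cooccurrence(ids::Vector{Int32}, nvocab::Integer; window::Integer = 5)
    counts = Dict{Tuple{Int32,Int32},Float32}()
    n = length(ids)
    for i in 1:n, j in (i + 1):min(i + window, n)
        a, b = minmax(ids[i], ids[j])          # store each unordered pair once
        counts[(a, b)] = get(counts, (a, b), 0f0) + 1f0
    end
    rows = Int32[]; cols = Int32[]; vals = Float32[]
    for ((a, b), c) in counts
        push!(rows, a); push!(cols, b); push!(vals, c)
        if a != b                              # mirror off-diagonal entries
            push!(rows, b); push!(cols, a); push!(vals, c)
        end
    end
    return sparse(rows, cols, vals, nvocab, nvocab)
end

ids = Int32[1, 2, 3, 2, 1]
X = cooccurrence(ids, 3; window = 2)
```

Because the index vectors are `Int32` and the values `Float32`, `sparse` returns a `SparseMatrixCSC{Float32, Int32}`, matching the storage type named above.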
```sh
cd LatentRandomWalk
# put text8 and (optionally) questions-words.txt in data/raw/
julia --project --threads=auto scripts/run_all.jl
```

Each phase script is also runnable on its own and is skipped by `run_all.jl` if its checkpoint already exists. `FORCE=1 julia --project scripts/run_all.jl` re-runs everything.

Override training hyperparameters with environment variables:

```sh
DIM=100 EPOCHS=20 julia --project --threads=auto scripts/03_train.jl
```

Verifying the paper rather than the implementation is the whole point. The five diagnostics are:

| ID | Test | Paper reference | Predicted outcome |
|---|---|---|---|
| D.1 | Partition function | Lemma 2.1 / Fig. 1a | |
| D.2 | | Theorem 2.2 | Pearson |
| D.3 | Singular values of $V$ | Theorem 4.1 | |
| D.4 | | Corollary 2.3 (paper's headline eq. 1.1) | slope |
| D.5 | Google analogy testbed | §5.2 | ~35–50 % on text8 |
| D.6 | | §5.3 / RELATIONS=LINES | |

All five run in well under a minute against trained text8 vectors. They
constitute the central slides of the journal-club talk.
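
As an illustration of what a D.1-style check involves, here is a minimal sketch. It assumes the paper's definition $Z_c = \sum_w \exp(\langle v_w, c \rangle)$, word vectors stored as the columns of a $d \times n$ matrix, and discourse vectors drawn uniformly from the unit sphere. The function name `sample_partition_functions` is hypothetical and is not taken from `verify.jl`.

```julia
using LinearAlgebra, Random

# Evaluate Z_c = sum_w exp(<v_w, c>) for `nsamples` random unit vectors c.
# V is d x n with one word vector per column (an assumed layout).
function sample_partition_functions(V::Matrix{Float32}, nsamples::Int;
                                    rng::AbstractRNG = Random.default_rng())
    d, n = size(V)
    Z = Vector{Float32}(undef, nsamples)
    scores = Vector{Float32}(undef, n)
    c = Vector{Float32}(undef, d)
    for s in 1:nsamples
        randn!(rng, c)
        normalize!(c)              # uniform direction on the unit sphere
        mul!(scores, V', c)        # one matrix-vector product per sample
        Z[s] = sum(exp, scores)
    end
    return Z
end
```

The spread of the returned values relative to their mean (for example `std(Z) / mean(Z)` from the `Statistics` standard library) is the kind of quantity the concentration claim of Lemma 2.1 is about.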
```
LatentRandomWalk/
├── Project.toml / Manifest.toml
├── src/
│ ├── LatentRandomWalk.jl # module top, includes & exports
│ ├── corpus.jl # Vocabulary, tokenisation
│ ├── cooccurrence.jl # sparse pair-count matrix
│ ├── model.jl # Embeddings, SN training, AdaGrad
│ ├── analogies.jl # Google/MSR analogy evaluation
│ ├── verify.jl # D.1, D.2, D.3, D.4, D.6 diagnostics
│ └── sde.jl # LatentRandomWalk.SDE submodule
├── scripts/
│ ├── 00_download_corpus.jl
│ ├── 01_build_vocab.jl
│ ├── 02_build_cooccurrence.jl
│ ├── 03_train.jl
│ ├── 04_verify.jl
│ ├── 05_analogies.jl
│ ├── 06_sde_demo.jl
│ └── run_all.jl
├── test/ # gradient check vs Zygote, etc.
├── notebooks/walkthrough.jl # Pluto journal-club companion
├── data/ # gitignored: raw/, processed/, results/
└── figures/ # generated PDFs
```

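The SN loss is

$$L(\{v_w\}, C) = \sum_{w, w'} X_{w,w'} \left( \log X_{w,w'} - \Vert v_w + v_{w'} \Vert^2 - C \right)^2 .$$

Per pair, with residual $r_{w,w'} = \log X_{w,w'} - \Vert v_w + v_{w'} \Vert^2 - C$, the gradients of that pair's term are

$$\frac{\partial L}{\partial v_w} = \frac{\partial L}{\partial v_{w'}} = -4 X_{w,w'} \, r_{w,w'} \, (v_w + v_{w'}), \qquad \frac{\partial L}{\partial C} = -2 X_{w,w'} \, r_{w,w'} .$$
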
We iterate over the upper triangle of the symmetric co-occurrence matrix
(`row < col`) and apply per-parameter AdaGrad inline. The inner loop is
`@inbounds @simd`, type-stable (`@code_warntype`-clean), and allocation-free
after the upfront `randperm`. AdaGrad is not delegated to
`Optimisers.jl`: the optimiser step is fused with the sparse gradient
access pattern, which a generic library cannot exploit (see
the discussion in `implementation-plan.md`).
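
A fused per-pair update of this kind could look like the sketch below. It assumes the SN term $X_{w,w'} (\log X_{w,w'} - \Vert v_w + v_{w'} \Vert^2 - C)^2$ from the paper, vectors stored as columns of `V`, and a same-shaped accumulator `G` of squared gradients. The function name, argument layout, learning rate and `epsilon` are all hypothetical, not the contents of `model.jl`.

```julia
# One SN + AdaGrad update for the pair (i, j) with count x (assumes i != j).
# Returns the pair's loss before the update.
function sn_adagrad_pair!(V::Matrix{Float32}, G::Matrix{Float32},
                          C::Ref{Float32}, GC::Ref{Float32},
                          i::Int, j::Int, x::Float32;
                          lr::Float32 = 0.05f0, epsilon::Float32 = 1f-8)
    d = size(V, 1)
    sq = 0f0
    @inbounds @simd for k in 1:d
        s = V[k, i] + V[k, j]
        sq += s * s                          # ||v_i + v_j||^2
    end
    r = log(x) - sq - C[]                    # residual
    coef = -4f0 * x * r                      # gradient is coef * (v_i + v_j)
    @inbounds @simd for k in 1:d
        g = coef * (V[k, i] + V[k, j])       # same gradient for both words
        G[k, i] += g * g
        G[k, j] += g * g
        V[k, i] -= lr * g / sqrt(G[k, i] + epsilon)
        V[k, j] -= lr * g / sqrt(G[k, j] + epsilon)
    end
    gC = -2f0 * x * r                        # gradient w.r.t. the bias C
    GC[] += gC * gC
    C[] -= lr * gC / sqrt(GC[] + epsilon)
    return x * r * r
end

V = Float32[0.1 0.2; -0.1 0.3]               # d = 2, two words
G = zeros(Float32, 2, 2)
C, GC = Ref(0f0), Ref(0f0)
losses = [sn_adagrad_pair!(V, G, C, GC, 1, 2, 5f0) for _ in 1:200]
```

Fusing the accumulator update into the same loop that forms the gradient means no gradient vector is ever materialised, which is the property the paragraph above relies on.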
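SN units vs model units. Theorem 2.2 of the paper is written in vectors $v_w$ satisfying $\log p(w, w') \approx \Vert v_w + v_{w'} \Vert^2 / (2d) - 2 \log Z$.
The SN objective drops the $1/(2d)$ factor, so the trained vectors are the model vectors scaled by $1/\sqrt{2d}$.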
- Float32 throughout. This halves memory use and cache pressure, and the extra precision of Float64 is not needed for embeddings. (Note: `svdvals(V)` in D.3 promotes to Float64 internally inside LAPACK; the trained vectors themselves remain Float32 everywhere else.)
- Symmetric upper-triangle storage during training (each unique pair visited once).
- BLAS-vectorised verification. D.1 is a series of GEMVs; D.5 is one GEMM per batch of analogy queries.
- `OhMyThreads` for the embarrassingly-parallel parts of the pipeline (Phase B co-occurrence build; D.6 SVDs).
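
To illustrate the "one GEMM per batch" point for D.5, here is a minimal sketch. It assumes unit-norm word vectors stored as columns and the common 3CosAdd scoring rule, which may differ from what `analogies.jl` actually uses. The function name `solve_analogies` is hypothetical.

```julia
using LinearAlgebra

# Answer a batch of queries a : b :: c : ? given as (a, b, c) index triples.
# V is d x n with unit-norm columns (assumed layout and normalisation).
function solve_analogies(V::Matrix{Float32}, queries::Vector{NTuple{3,Int}})
    d = size(V, 1)
    T = Matrix{Float32}(undef, d, length(queries))
    for (q, (a, b, c)) in enumerate(queries)
        @views T[:, q] .= V[:, b] .- V[:, a] .+ V[:, c]   # target direction
    end
    S = V' * T                                 # n x q score matrix, one GEMM
    for (q, (a, b, c)) in enumerate(queries)
        S[a, q] = S[b, q] = S[c, q] = -Inf32   # never return a query word
    end
    return [argmax(@view S[:, q]) for q in axes(S, 2)]
end
```

Each returned entry is the vocabulary index of the predicted fourth word; comparing it with the gold index gives the accuracy for that batch.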
The plan's stretch goal: take the discrete random walk's continuous-time
limit, which is Brownian motion on $S^{d-1}$, and integrate it
with a projected Euler step: a tangent-space increment followed by
renormalisation back to the sphere. This is the standard geometric
integrator for SDEs on $S^{d-1}$. We do not use `StochasticDiffEq.EM()`; the projected scheme is
both simpler and exact-on-the-sphere by construction.
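`partition_function_along_path` then evaluates the partition function $Z_{c_t}$ along the simulated path, which Lemma 2.1 predicts should stay nearly constant.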
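
The projected Euler step described above can be sketched in a few lines. This is an illustration under assumptions, not the contents of `sde.jl`: the function names, the unit diffusion rate and the default `dt` are hypothetical.

```julia
using LinearAlgebra, Random

# One projected Euler step on the unit sphere: draw an ambient Gaussian
# increment, project it onto the tangent space at c, move, renormalise.
function projected_euler_step!(c::Vector{Float64}, dt::Float64, rng::AbstractRNG)
    xi = sqrt(dt) .* randn(rng, length(c))
    xi .-= dot(c, xi) .* c        # tangent projection (I - c c') * xi
    c .+= xi
    normalize!(c)                 # pull the point back onto the sphere
    return c
end

# Simulate a path of `nsteps` steps in dimension d; columns are c_0, c_1, ...
function simulate_walk(d::Int, nsteps::Int; dt::Float64 = 1e-3,
                       rng::AbstractRNG = Xoshiro(0))
    c = normalize!(randn(rng, d))
    path = Matrix{Float64}(undef, d, nsteps + 1)
    path[:, 1] = c
    for t in 1:nsteps
        projected_euler_step!(c, dt, rng)
        path[:, t + 1] = c
    end
    return path
end

path = simulate_walk(10, 100)
```

Because every step ends with `normalize!`, each column of `path` has unit norm up to rounding, independent of `dt`.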
- The paper is sometimes ambiguous about whether $X_{w,w'}$ is the raw count or a distance-weighted one. The SN derivation works with raw counts, and that's what we use.
- The bias scalar $C$ absorbs $-2 \log Z$; we don't try to recover $Z$ separately.
- The Stratonovich SDE is integrated with a projected Euler step (tangent increment then renormalise), not a generic Euler–Maruyama solver. The ambient-space Itô form carries a stiff $(d-1)/2$ drift that would force a tiny timestep at $d = 300$; the projected scheme avoids it and keeps $\Vert c_t \Vert = 1$ to floating-point precision at every step.
- The Google testbed contains some questions whose answers aren't in the text8 vocabulary; we skip those and report coverage.
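
For reference, the two forms contrasted in the SDE note above can be written out. This is the textbook statement for Brownian motion on $S^{d-1}$ driven by a standard Brownian motion $B_t$ in $\mathbb{R}^d$; that the project uses exactly this time scaling is an assumption.

$$\mathrm{d}c_t = (I - c_t c_t^\top) \circ \mathrm{d}B_t \qquad \text{(Stratonovich)}$$

$$\mathrm{d}c_t = -\frac{d-1}{2} \, c_t \, \mathrm{d}t + (I - c_t c_t^\top) \, \mathrm{d}B_t \qquad \text{(Itô)}$$

The radial drift in the Itô form is what a plain Euler–Maruyama step has to resolve, and its coefficient grows linearly with $d$.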
```sh
julia --project --threads=auto scripts/run_all.jl
ls figures/
```

All randomness is seeded from a single seed (`SEED=` env var; default 0).
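
Reading such a setting from the environment might look like the sketch below. The helper `env_int` is hypothetical, and using `Xoshiro` as the generator is an assumption; only the `SEED` variable and its default of 0 come from the text above.

```julia
using Random

# Parse an integer setting from the environment, falling back to `default`.
env_int(name::AbstractString, default::Integer) =
    parse(Int, get(ENV, name, string(default)))

seed = env_int("SEED", 0)      # default 0, as stated above
rng = Xoshiro(seed)            # one RNG that can be passed to every phase
```

The same helper would cover integer overrides such as `DIM` and `EPOCHS`, with whatever defaults the scripts define.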