Merged
24 changes: 24 additions & 0 deletions .github/workflows/ci.yml
@@ -207,6 +207,20 @@ jobs:
      - name: Train PyTorch kws_mfcc (produces reference predictions + weights)
        run: uv run examples/kws_mfcc/train_pytorch.py

      - name: Cache kws_raw processed data (6-class)
        id: kws-raw-cache
        uses: actions/cache@v4
        with:
          path: examples/kws_raw/data/6class
          key: kws-raw-6class-${{ hashFiles('examples/kws_raw/prepare_data.py', 'examples/_shared/speechcommands_data.py') }}

      - name: Prepare kws_raw data (6-class; only on cache miss)
        if: steps.kws-raw-cache.outputs.cache-hit != 'true'
        run: uv run examples/kws_raw/prepare_data.py

      - name: Train PyTorch kws_raw (produces reference predictions + weights)
        run: uv run examples/kws_raw/train_pytorch.py

      - name: Configure
        run: cmake --preset examples

@@ -268,6 +282,16 @@ jobs:
            --c examples/kws_mfcc/outputs/6class/c_predictions.npy \
            --dtype int32

      - name: Run kws_raw in BIT_PARITY mode
        run: BIT_PARITY=1 build/examples/examples/kws_raw/train_c_kws_raw

      - name: Diff kws_raw predictions (int32, exact match required)
        run: |
          uv run examples/_shared/compare_predictions.py \
            --pytorch examples/kws_raw/outputs/6class/pytorch_predictions.npy \
            --c examples/kws_raw/outputs/6class/c_predictions.npy \
            --dtype int32

  python-test:
    runs-on: ubuntu-latest

1 change: 1 addition & 0 deletions examples/CMakeLists.txt
@@ -4,3 +4,4 @@ add_subdirectory(ecg_anomaly_ae)
add_subdirectory(mnist_mlp)
add_subdirectory(mnist_cnn)
add_subdirectory(kws_mfcc)
add_subdirectory(kws_raw)
3 changes: 2 additions & 1 deletion examples/README.md
@@ -12,7 +12,8 @@ checking and visualizations.
| `mnist_cnn/` | MNIST 1D-CNN digit classification | ✅ |
| `har_classifier/` | UCI HAR 6-class activity classification | Stage 1 |
| `ecg_anomaly_ae/` | ECG5000 reconstruction-based anomaly detection | Stage 2 ✅ |
| `kws_classifier/` | SpeechCommands 6-class keyword spotting | Stage 3 (planned) |
| `kws_mfcc/` | SpeechCommands keyword spotting (MFCC features) | Stage 3 ✅ |
| `kws_raw/` | SpeechCommands keyword spotting (raw waveform + in-model downsample) | Stage 3 ✅ |
| `kws_denoising_ae/` | SpeechCommands additive-noise denoising | Stage 4 (planned) |

## Running an example
68 changes: 68 additions & 0 deletions examples/kws_raw/CMakeLists.txt
@@ -0,0 +1,68 @@
add_executable(train_c_kws_raw train_c.c)

target_link_libraries(train_c_kws_raw PRIVATE
DataLoaderApi
DataLoader
NPYLoaderApi
NPYLoader

Layer

Conv1dApi
Conv1d

LinearApi
Linear

ReluApi
Relu

FlattenApi
Flatten

Pool1dApi
MaxPool1d
AvgPool1d

AdaptivePool1dApi
AdaptiveAvgPool1d

LayerNormApi
LayerNorm

QuantizationApi
Quantization

TensorApi
Tensor
Rounding

TrainingLoopApi
CalculateGradsSequential
TrainingBatchDefault
TrainingEpochDefault
Optimizer

LossFunction
CrossEntropy

SoftmaxApi
Softmax

Sgd
SgdApi

InferenceApi

StateDictApi
LayerWeightsApi
LayerQuant
LayerCommon
Distributions

Common
StorageApi
RNG

examples_shared
)
61 changes: 61 additions & 0 deletions examples/kws_raw/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# KWS Raw Waveform — PyTorch + C Parity Demo

Trains a 1D-CNN keyword-spotter on **raw 16 kHz SpeechCommands waveforms** in both
PyTorch (reference) and the ODT C framework. Companion to `kws_mfcc/`: same data
and harness, but instead of pre-computing MFCC features, the model consumes the
native `[1, 16000]` waveform and **downsamples in-framework** — its first layer is
`AvgPool1d(K=16, S=16)`, a decimation-by-16 box filter that turns 16 kHz into
1 kHz. Change `K` to change the effective rate (8 → 2 kHz, …) with no re-prep; the
`AdaptiveAvgPool1d(1)` head is length-agnostic so the rest of the model is
unchanged (only the three MaxPool nominal `inputLength`s in `train_c.c` need to
track the new lengths).
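
As a rough illustration of the decimation described above: a pooling layer with kernel and stride both equal to `K` and no padding amounts to averaging non-overlapping blocks of `K` samples. The NumPy sketch below is a stand-in under that assumption, not the framework's implementation, and `box_decimate` is a hypothetical helper name.

```python
import numpy as np


def box_decimate(x: np.ndarray, k: int) -> np.ndarray:
    """Average non-overlapping blocks of k samples along the last axis.

    Hypothetical helper: assumed to mirror AvgPool1d(kernel=k, stride=k)
    with no padding; trailing samples that do not fill a block are dropped.
    """
    n_out = x.shape[-1] // k
    trimmed = x[..., : n_out * k]
    return trimmed.reshape(*x.shape[:-1], n_out, k).mean(axis=-1)


# Synthetic one-second clip at 16 kHz (random data, for shape checking only).
wave = np.random.default_rng(0).standard_normal((1, 16000)).astype(np.float32)
print(box_decimate(wave, 16).shape)  # (1, 1000): 16 kHz -> 1 kHz
print(box_decimate(wave, 8).shape)   # (1, 2000): 16 kHz -> 2 kHz
```

Because the block size is the only parameter, changing the effective sample rate touches the model definition only and leaves the prepared `.npy` files as they are.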

One binary, two modes — **bit-parity** (`BIT_PARITY=1`, the exact CI gate) and a
**train-from-scratch** informational demo. See `kws_mfcc/README.md` for the mode
explanation and the `KWS_CLASSES` knob; commands are identical with `kws_raw`
substituted.

## Why LayerNorm + a longer schedule

Raw waveforms are far harder to train than MFCC features: at the `kws_mfcc`
settings (lr=0.001, 15 epochs) the raw model never escapes its random-init
fixed point (flat loss, every clip predicted as one class), which would make the
bit-parity gate degenerate. Two changes fix it without leaving the framework's
bit-parity-covered layers:

- a rate-agnostic **`LayerNorm(64)`** on the pooled features before the classifier
(the C framework has bit-parity LayerNorm; BatchNorm is not covered), and
- **lr=0.005, 20 epochs** (the model breaks through around epoch 15).

The reference then reaches ~0.59 test accuracy with predictions spread across all
six classes, so the gate genuinely exercises the `AvgPool1d[1,16000]` + Conv +
LayerNorm arithmetic (C reproduces PyTorch's predictions int32-exactly).
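
One way to check that a reference run has left the "every clip predicted as one class" regime is to histogram its predictions. The sketch below is illustrative only: `class_histogram` and `is_degenerate` are hypothetical helpers, and the arrays are made-up values rather than real model output.

```python
import numpy as np


def class_histogram(preds: np.ndarray, num_classes: int) -> np.ndarray:
    """Count how many clips were assigned to each class index."""
    return np.bincount(preds.astype(np.int64), minlength=num_classes)


def is_degenerate(preds: np.ndarray, num_classes: int) -> bool:
    """True when every prediction falls into a single class."""
    return int(np.count_nonzero(class_histogram(preds, num_classes))) <= 1


# Made-up predictions for illustration, not taken from any run.
collapsed = np.zeros(10, dtype=np.int32)
spread = np.array([0, 1, 2, 3, 4, 5, 0, 1], dtype=np.int32)
print(is_degenerate(collapsed, 6), is_degenerate(spread, 6))  # True False
```

An exact-match gate on a collapsed prediction vector would pass trivially, which is why a spread across classes matters for the parity check.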

## Run it (6-class)

```bash
uv run python examples/kws_raw/prepare_data.py
uv run python examples/kws_raw/train_pytorch.py
cmake --preset examples
cmake --build --preset examples --target train_c_kws_raw

BIT_PARITY=1 ./build/examples/examples/kws_raw/train_c_kws_raw
uv run python examples/_shared/compare_predictions.py \
--pytorch examples/kws_raw/outputs/6class/pytorch_predictions.npy \
--c examples/kws_raw/outputs/6class/c_predictions.npy --dtype int32
```
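
For orientation, an exact int32 comparison of two prediction arrays can be sketched as below. This is an assumed reading of what "exact match" means here, not the contents of `compare_predictions.py`; `exact_match` and `mismatch_indices` are hypothetical names, and the arrays are made-up values.

```python
import numpy as np


def exact_match(a: np.ndarray, b: np.ndarray) -> bool:
    """True only if both arrays have the same shape and identical int32 values."""
    a32 = np.asarray(a).astype(np.int32)
    b32 = np.asarray(b).astype(np.int32)
    return bool(np.array_equal(a32, b32))


def mismatch_indices(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Flat indices where two same-shaped arrays differ."""
    return np.flatnonzero(np.asarray(a) != np.asarray(b))


# Made-up predictions for illustration, not taken from any run.
ref = np.array([0, 3, 5, 1], dtype=np.int32)
same = ref.copy()
off_by_one = np.array([0, 3, 4, 1], dtype=np.int32)
print(exact_match(ref, same), exact_match(ref, off_by_one))  # True False
print(mismatch_indices(ref, off_by_one))
```

No tolerance is involved: a single differing class index fails the comparison.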

## Model

- Input: `[1, 16000]` → `reshapeItemsAddBatchDim` → `[1, 1, 16000]`
- `AvgPool1d(16) → Conv1d(1→16,K3,SAME) → ReLU → MaxPool(4) → Conv1d(16→32,K3,SAME)
→ ReLU → MaxPool(4) → Conv1d(32→64,K3,SAME) → ReLU → MaxPool(4) →
AdaptiveAvgPool1d(1) → Flatten → LayerNorm(64) → Linear(64→C) → Softmax → CE`
- Lengths: 16000 → 1000 → 250 → 62 → 15 → 1; ~10 K params
- State-dict layers: `conv1`, `conv2`, `conv3`, `ln`, `fc`
- Hyperparameters: SGD lr=0.005, momentum=0.9, batch=32, 20 epochs
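
The length chain above follows from floor division at each pooling stage, since the SAME-padded convolutions keep the length unchanged. The sketch below reproduces it and estimates the parameter count; it assumes a bias on every Conv1d and Linear layer, an affine weight and bias in LayerNorm, and C = 6, none of which is stated explicitly in this README.

```python
def pool_out(length: int, kernel: int) -> int:
    """Output length of a pool with stride == kernel and no padding."""
    return length // kernel


def conv1d_params(c_in: int, c_out: int, k: int) -> int:
    """Weights plus one bias per output channel (bias is an assumption)."""
    return c_out * c_in * k + c_out


lengths = [16000]
lengths.append(pool_out(lengths[-1], 16))      # AvgPool1d(16)
for _ in range(3):                             # Conv1d(SAME) keeps length, MaxPool(4) divides it
    lengths.append(pool_out(lengths[-1], 4))
lengths.append(1)                              # AdaptiveAvgPool1d(1)
print(lengths)  # [16000, 1000, 250, 62, 15, 1]

num_classes = 6  # assumed: the default 6-class setup
params = (
    conv1d_params(1, 16, 3)
    + conv1d_params(16, 32, 3)
    + conv1d_params(32, 64, 3)
    + 2 * 64                                   # LayerNorm(64) weight + bias (assumed affine)
    + 64 * num_classes + num_classes           # Linear(64 -> C) with bias
)
print(params)  # 8358
```

Under these assumptions the total comes to about 8.4 K, the same order as the rounded figure quoted above. When `K` in the first pool changes, only the entries of `lengths` after the first shift, which is what the three MaxPool `inputLength` values in `train_c.c` need to track.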

The train-from-scratch demo is the slowest in the suite (raw `[1,16000]` is the
heaviest input even after the AvgPool downsample) — run it offline. Bit-parity
mode requires exact equality; the train-from-scratch tolerances are informational
and match `kws_mfcc/`.
88 changes: 88 additions & 0 deletions examples/kws_raw/compare.py
@@ -0,0 +1,88 @@
"""Compare PyTorch and C runs of the kws_raw classifier.

Reads logs/<n>class/{pytorch,c}.json and outputs/<n>class/{pytorch,c}_predictions.npy.
Writes plots into plots/<n>class/. Prints a final-state parity report within tolerances.
INFORMATIONAL only — the bit-parity check (compare_predictions.py) is the gate.
"""
from __future__ import annotations

import os
import sys
from pathlib import Path

import numpy as np

REPO_ROOT = Path(__file__).resolve().parents[2]
sys.path.insert(0, str(REPO_ROOT))

from examples._shared.log_schema import load_log # noqa: E402
from examples._shared.parity import ParityCheck, run_parity_checks # noqa: E402
from examples._shared.plotting import (  # noqa: E402
    plot_accuracy_curves,
    plot_confusion_matrix,
    plot_loss_curves,
)

HERE = Path(__file__).resolve().parent
NUM_CLASSES = int(os.environ.get("KWS_CLASSES", "6"))
assert NUM_CLASSES in (6, 35), NUM_CLASSES
TAG = f"{NUM_CLASSES}class"
LOGS = HERE / "logs" / TAG
OUTPUTS = HERE / "outputs" / TAG
PLOTS = HERE / "plots" / TAG
DATA = HERE / "data" / TAG

CLASS_NAMES = (
    ["yes", "no", "up", "down", "silence", "unknown"]
    if NUM_CLASSES == 6
    else [str(i) for i in range(NUM_CLASSES)]
)

CHECKS = [
    ParityCheck("test_acc", abs_tol=0.025),  # ±2.5 pp
    ParityCheck("test_loss", abs_tol=0.15),  # ±0.15 nats (informational)
]


def confusion_matrix(preds: np.ndarray, labels: np.ndarray, num_classes: int) -> np.ndarray:
    # Rows index the predicted class, columns index the true label.
    cm = np.zeros((num_classes, num_classes), dtype=np.int64)
    for p, a in zip(preds, labels):
        cm[int(p), int(a)] += 1
    return cm


def main() -> int:
    PLOTS.mkdir(parents=True, exist_ok=True)
    pt = load_log(LOGS / "pytorch.json")
    c = load_log(LOGS / "c.json")

    plot_loss_curves(PLOTS / "loss_curves.png", pt, c)
    plot_accuracy_curves(PLOTS / "accuracy_curves.png", pt, c)

    test_y = np.load(DATA / "test_y.npy")
    pt_pred = np.load(OUTPUTS / "pytorch_predictions.npy")
    c_pred = np.load(OUTPUTS / "c_predictions.npy")
    cm_pt = confusion_matrix(pt_pred, test_y, len(CLASS_NAMES))
    cm_c = confusion_matrix(c_pred, test_y, len(CLASS_NAMES))
    plot_confusion_matrix(PLOTS / "confusion_matrix_pt.png", cm_pt, CLASS_NAMES, "PyTorch KWS Raw")
    plot_confusion_matrix(PLOTS / "confusion_matrix_c.png", cm_c, CLASS_NAMES, "C KWS Raw")

    pt_finals = pt["final"]
    c_finals = c["final"]
    overall_pass, results = run_parity_checks(
        CHECKS,
        {"test_acc": pt_finals["test_acc"], "test_loss": pt_finals["test_loss"]},
        {"test_acc": c_finals["test_acc"], "test_loss": c_finals["test_loss"]},
    )

    print("\nParity report (PyTorch vs C) — INFORMATIONAL:")
    print(f"{'metric':<14} {'pt':>10} {'c':>10} {'diff':>10} {'tol':>8} {'type':>5} {'pass':>6}")
    for r in results:
        print(f"{r.metric:<14} {r.pt_value:>10.5f} {r.c_value:>10.5f} {r.diff:>10.5f} "
              f"{r.tolerance:>8.4f} {r.tolerance_type:>5} {str(r.passed):>6}")
    print(f"\nOverall: {'PASS' if overall_pass else 'FAIL'} (informational; not a CI gate)")
    return 0 if overall_pass else 1


if __name__ == "__main__":
    sys.exit(main())
42 changes: 42 additions & 0 deletions examples/kws_raw/prepare_data.py
@@ -0,0 +1,42 @@
"""Prepare raw SpeechCommands waveforms for the kws_raw example.

Writes the native 16 kHz waveform directly — no resampling, no feature
extraction. Downsampling (16 kHz → 1 kHz via AvgPool1d) is the model's first
layer, so PyTorch and C read identical raw .npy.

Output (under examples/kws_raw/data/<n>class/, n = KWS_CLASSES in {6,35}, default 6):
    {train,val,test}_x.npy  [N,1,16000] f32
    {train,val,test}_y.npy  [N]         i32 (0..n-1)
"""
from __future__ import annotations

import os
import sys
from pathlib import Path

import numpy as np

REPO_ROOT = Path(__file__).resolve().parents[2]
sys.path.insert(0, str(REPO_ROOT))
from examples._shared.speechcommands_data import load_speechcommands # noqa: E402

HERE = Path(__file__).resolve().parent
RAW_ROOT = REPO_ROOT / "examples" / "_shared" / "data" / "speech_commands"


def main() -> None:
    num_classes = int(os.environ.get("KWS_CLASSES", "6"))
    assert num_classes in (6, 35), num_classes
    data_dir = HERE / "data" / f"{num_classes}class"
    data_dir.mkdir(parents=True, exist_ok=True)

    splits = load_speechcommands(RAW_ROOT, num_classes)
    for split in ("train", "val", "test"):
        x, y = splits[split]
        np.save(data_dir / f"{split}_x.npy", x.astype(np.float32))
        np.save(data_dir / f"{split}_y.npy", y.astype(np.int32))
        print(f"{split}: x={x.shape} y={y.shape} classes={num_classes}", flush=True)


if __name__ == "__main__":
    main()