Skip to content

nanzhijin/CNN-Analize

Repository files navigation

CnnMusic — Multi-Modal Content Recall System

Python PyTorch Spark FAISS License

CNN (Audio) + NLP (Text) → Multi-Modal Embedding → FAISS Content Recall

内容召回通道 · 播客/音乐推荐系统的内容理解基础设施

GNN 项目 形成「协同召回 + 内容召回」双通道混合召回体系


为什么需要这个项目

GNN 项目已经解决了协同召回…

我的 GNN 项目(CAAI-BDSC2023 社交图谱链接预测)基于用户行为图谱做协同召回:

  • LightGBM MRR@5 = 0.56(好友场景) / GNN MRR@5 = 0.33(冷启动场景)

…但诊断出一个关键缺口

  1. 没有 per-(item, user) 内容匹配特征 — 模型不知道「这篇文章/播客讲的是什么」,只知道「谁分享给了谁」
  2. ItemEncoder 对所有候选输出相同的嵌入 — 商品信息对排序贡献为零
  3. 冷启动 item 完全不可见 — 没有互动历史的新内容无法被协同通道召回

CnnMusic 填补这个缺口

┌─────────────────────────────────────────────────────────────────┐
│              Multi-Channel Recall System                        │
│                                                                 │
│  ┌─────────────────────────┐    ┌─────────────────────────┐    │
│  │  Collaborative Recall   │    │  Content Recall          │    │
│  │  (GNN Project)          │    │  (This Project)          │    │
│  │                         │    │                          │    │
│  │  用户行为图谱            │    │  音频 + 文本              │    │
│  │  → LightGBM / GNN       │    │  → CNN + MiniLM          │    │
│  │  → 社交信号召回          │    │  → 内容信号召回           │    │
│  │                         │    │                          │    │
│  │  MRR@5: 0.56            │    │  覆盖冷启动 & 长尾内容    │    │
│  └─────────────────────────┘    └─────────────────────────┘    │
│                    \                      /                      │
│                     └──────────┬──────────┘                     │
│                                ▼                                │
│                     Multi-Channel Fusion                        │
│                     → Ranking Stage                             │
└─────────────────────────────────────────────────────────────────┘

系统架构

                     Input: Audio File (.wav/.mp3)
                                │
              ┌─────────────────┴─────────────────┐
              ▼                                   ▼
    ┌─────────────────┐                 ┌─────────────────┐
    │  Audio Branch    │                 │  Text Branch     │
    │  (CNN · PyTorch) │                 │  (s-bert · NLP)  │
    │                  │                 │                  │
    │  梅尔频谱图       │                 │  歌词/标题/简介   │
    │  → 4层 CNN       │                 │  → MiniLM-L6     │
    │  → 128d 嵌入     │                 │  → 384d 嵌入     │
    │  (~1.1M params)  │                 │  (zero-training) │
    └────────┬─────────┘                 └────────┬─────────┘
             │                                    │
             └────────────┬───────────────────────┘
                          ▼
              ┌──────────────────────┐
              │   Joint Embedding    │
              │   concat[128 | 384]  │
              │   = 512d, L2-norm    │
              └──────────┬───────────┘
                          ▼
              ┌──────────────────────┐
              │   FAISS Index x3      │
              │   Cosine Similarity   │
              │                      │
              │  Audio (128d)        │
              │  Text  (384d)        │
              │  Joint (512d)        │
              └──────────┬───────────┘
                          ▼
                 Top-K Retrieval
                          
    ┌──────────┬──────────┬──────────┐
    │  Audio   │  Text    │  Joint   │  ← 三路对比召回
    │  Recall  │  Recall  │  Recall  │
    └──────────┴──────────┴──────────┘

关键设计决策:

决策 理由
CNN 不纯做分类,做嵌入提取 分类是辅助任务,嵌入是核心产出
NLP 用 MiniLM-L6 不训练 性价比最高,384d 足够区分文本主题
双模态 concat 而非加权融合 FAISS cosine 对每维独立,scale 差异不敏感
离线提取嵌入 + 在线 FAISS 检索 嵌入提取是批量作业,在线只做 <1ms 的向量检索
三路索引并存 单模态兜底:无文本用 Audio,无音频用 Text,都有用 Joint

场景覆盖

场景 CNN (音频) NLP (文本) 双模态 典型应用
流行歌 + 有歌词 ✅ 音色/编曲 ✅ 歌词语义 ⭐ 最强 音乐推荐
纯音乐/器乐 ✅ 乐器/节奏 ⚠️ 仅曲名 ✅ CNN 兜底 古典/爵士推荐
播客访谈 ✅ 语调/语速 ✅ ASR 语义 ⭐ 最强 播客推荐
播客无文本元数据 ✅ 音频特征 ❌ 无文本 ✅ CNN 兜底 冷启动播客
冷启动新内容 (无音频) ❌ 无音频 ✅ 文本描述 ✅ NLP 兜底 新上架内容

模型能力测量

"我没有直接设计最终架构,而是先测了三个问题:CNN 精度够吗?文本嵌入能区分主题吗?双模态比单模态好多少?数字驱动设计。"

测量 1: CNN 分类精度对比

模型 参数量 Test Accuracy 模型大小 Epochs 决策
MelSpectrogramCNN (标准) 656K 66.0% ~2.6 MB 40 (早停) --
LightweightMelCNN (轻量) 157K 68.0% ~0.6 MB 30 (早停) ✅ 选中

决策:Lightweight 参数少 4×,精度反而高 2pp → 选轻量版作为默认模型

测量 2: 文本嵌入区分度

测试 结果 判定
同类内 cosine (同流派描述相同 → 1.0) 1.0 ✅ 粗粒度召回完美
异类间 cosine 均值 0.377 ✅ 良好区分
最高跨类相似度 (metal ↔ rock) 0.568 ✅ 合理(二者确实相似)
跨类相似度 > 0.60 对数 0/90 ✅ 无混淆对

结论:文本嵌入能有效区分流派。但同流派内所有 item 共享相同描述 → 文本召回在流派级别完美(Recall=1.0),但在流派内部无法区分。这是预期行为:文本做粗召回,音频做细召回。

测量 3: 三路召回对比

模态 Recall@5 Recall@10 MRR@5 Weighted Recall@10 备注
Audio (128d) 0.9585 0.9728 0.8904 0.9729 CNN 嵌入
Text (384d) 1.0000 1.0000 1.0000 1.0000 粗粒度完美,无细粒度
Joint (512d) 0.9614 0.9757 0.8999 0.9757 双模态融合
Random Baseline ~0.41 ~0.65 ~0.20 ~0.65 随机检索期望值

关键发现

  • Audio 已经很强大(Recall@5 = 95.85%),比随机基线高 2.3×
  • Text=1.0 是因为同流派 item 共享完全相同的文本描述 → 流派级检索完美但流派内无法区分
  • Joint 比 Audio 略微提升 (+0.29pp Recall@10),但差值很小

为什么 Recall 95.85% 但 Classification Accuracy 只有 68%?

这是两个完全不同的任务,不矛盾:

分类 (Accuracy = 68%):        召回 (Recall@5 = 95.85%):

  频谱 → CNN → 128d →          频谱 → CNN → 128d → FAISS检索
          ↓                            ↓
  线性分类头 (128→10)           余弦相似度 → Top-5 近邻
          ↓                            ↓
  预测: "这是爵士"              结果: 近邻里有没有同流派的?
          ↓                            ↓
  边界模糊 → 分错              空间靠近 → 命中

逐流派对比揭示本质:

流派 分类 Precision 分类 Recall 召回 Recall@5 诊断
classical 0.93 0.87 ~1.0 嵌入好 + 边界好 ✅
metal 0.86 0.80 ~1.0 嵌入好 + 边界好 ✅
rock 0.27 0.20 ~0.95 嵌入还行,边界崩溃 ⚠️
country 0.67 0.53 ~0.95 分类头混淆 country/rock/pop
disco 0.88 0.47 ~0.95 分类头不敢判 disco

Rock 是最典型的案例:嵌入空间里 rock 样本聚集在一起(召回能找得到),但线性分类头没法在 128 维空间里精准画出 rock vs country vs pop 的决策边界(分类大量错判)。

核心洞察:

"CNN backbone 产出的嵌入已经很好地把同流派音频聚在一起了(Recall@5=95.85%),但最后一层的线性分类器面对边界模糊的流派(rock/country/pop 的确听起来很像)无能为力。

这在真实推荐系统里不是问题——召回阶段只需要「相似的东西在一起」,不需要「这个东西叫什么名字」。分类头只是辅助训练 backbone 的手段。嵌入质量才是核心交付物,分类精度只是辅助信号。"

面试追问预案:

追问 回答
"那你怎么知道嵌入真的好?" 用 Recall@K 评估——留一法下 95.85% 的近邻命中同流派。分类头在 10 类上混淆不代表嵌入不好。
"分类头不准能提升吗?" 可以——换更强的分类器(MLP 替代单层 Linear)、加 metric learning 损失、或放弃分类直接做对比学习。但这些都不改变召回的核心价值
"召回 95% 是不是过拟合?" 不是——GTZAN 只有 1000 条,留一法评估的是泛化能力。同流派音频的频谱确实相似,这是物理事实不是数据泄露。
"生产环境 Recall 也会这么高吗?" 不会——GTZAN 的 10 类边界比较清晰。真实播客的音频相似度会更模糊(两个不同话题的访谈节目频谱可以很像)。但评估框架(Recall@K + 留一法)是通用的,换任何数据集都能测出真实水平。

A/B 实验结果

指标 A (Audio-only, 128d) B (Joint, 512d) Delta
Weighted Recall@10 (北极星) 0.9729 0.9757 +0.0029
MRR@5 0.8904 0.8999 +0.0095
  • p-value: < 0.01 (Bootstrap Permutation Test, N=10,000) — 统计显著
  • 效应量: 极小 (+0.29pp)
  • Index Size Guardrail: 4.0× > 2.0 threshold — 未通过 (512d vs 128d = 4×存储)

业务决策:虽然双模态统计显著,但效应量极小 (+0.29pp) 且存储成本翻 4 倍。建议首期上线纯音频召回 (128d),待获得逐 item 的细粒度文本(如 ASR 转录、歌词)后,再评估文本分支的价值。

"A/B 实验不仅验证了双模态的统计显著性,更重要的是给出了'现在不该上'的商业决策——这就是面试官想看到的工程判断力。"


北极星指标 & A/B 实验

北极星: Weighted Recall@10

Weighted Recall@10 = Σ(w_g × Recall@10_g) / Σ w_g
其中 w_g = 1 / frequency_g(逆流派频率加权)

为什么选这个?

考量 答案
Recall 而非 MRR? 召回只管「有没有」,排序是排序阶段的事
@10 而非 @5? 给排序阶段留 buffer,10 候选够筛
为什么加权? 防止靠 easy genres 刷分,长尾品类才是真本领
业务翻译 播客 App 用户偏好多样,加权度量均等服务质量

A/B 实验设计

  • Control (A): 纯音频 CNN 召回 (128d)
  • Treatment (B): 双模态 CNN+NLP 召回 (512d)
  • Primary Metric: Weighted Recall@10
  • 统计检验: Bootstrap Permutation Test (10,000 次), α=0.05, Bonferroni 校正
  • Guardrails: 嵌入提取延迟、索引大小、查询延迟

详见 recall/ab_test.py


快速开始

1. 环境

pip install -r requirements.txt

2. 下载数据集

# GTZAN — 10 流派 × 100 条 = 1000 条 30s 音频片段
# 已验证下载源: HuggingFace mirror
python -c "
import requests, tarfile, os
url = 'https://huggingface.co/datasets/marsyas/gtzan/resolve/main/data/genres.tar.gz'
r = requests.get(url, stream=True)
with open('data/genres.tar.gz', 'wb') as f:
    for chunk in r.iter_content(8192): f.write(chunk)
with tarfile.open('data/genres.tar.gz', 'r:gz') as tar:
    tar.extractall('data/')
os.remove('data/genres.tar.gz')
print('Done!')
"

3. 模型训练

# 标准 CNN
python train.py --source online --model standard --epochs 50

# 轻量 CNN (性价比对比)
python train.py --model lightweight --epochs 50

4. 嵌入提取 & 建索引

# 提取三路嵌入
python -c "
from recall.audio_embedder import AudioEmbedder
from recall.text_embedder import TextEmbedder
from recall.faiss_index import build_all_indexes
from utils.dataset import OnlineMelDataset
import pandas as pd, numpy as np, os

# 加载数据
ds = OnlineMelDataset(split='train')
val_ds = OnlineMelDataset(split='val')
test_ds = OnlineMelDataset(split='test')
all_ds = OnlineMelDataset.__new__(OnlineMelDataset)

print('Extracting audio embeddings...')
audio = AudioEmbedder('models/checkpoints/best_model_standard.pt')
audio_embs, labels = audio.extract_dataset(ds)

# 构建 metadata
paths = [ds.records[ds._indices[i]][0] for i in range(len(ds))]
genres = [os.path.basename(os.path.dirname(p)) for p in paths]
meta = pd.DataFrame({'file_path': paths, 'genre': genres})

print('Extracting text embeddings...')
text = TextEmbedder()
text_embs = text.encode_from_audio_list(paths, genres)

print('Building joint embeddings...')
joint_embs = np.concatenate([audio_embs, text_embs], axis=1)
norms = np.linalg.norm(joint_embs, axis=1, keepdims=True) + 1e-8
joint_embs = joint_embs / norms
joint_embs = joint_embs.astype(np.float32)

print('Building FAISS indexes...')
build_all_indexes(audio_embs, text_embs, joint_embs, meta)

# Save for A/B test
np.save('data/features/audio_embeddings.npy', audio_embs)
np.save('data/features/text_embeddings.npy', text_embs)
np.save('data/features/joint_embeddings.npy', joint_embs)
np.save('data/features/recall_labels.npy', labels)
print('All done!')
"

5. 召回评估

# 三路对比
python -c "
from recall.recall_evaluate import compare_modalities
import numpy as np
audio = np.load('data/features/audio_embeddings.npy')
text = np.load('data/features/text_embeddings.npy')
joint = np.load('data/features/joint_embeddings.npy')
labels = np.load('data/features/recall_labels.npy')
compare_modalities(audio, text, joint, labels)
"

6. A/B 实验

python recall/ab_test.py

7. 端到端 Demo

# 双模态查询
python recall/recall_demo.py data/genres/blues/blues.00000.wav --k 5 --modality all

# 纯文本查询(不需要音频)
python recall/recall_demo.py --text "heavy distorted guitar with fast drums and aggressive vocals" --k 5

# 仅音频查询
python recall/recall_demo.py data/genres/jazz/jazz.00000.wav --k 5 --modality audio

项目结构

CnnMusic/
├── config.py                         # 全局配置 (含嵌入维度/K值/NLP模型)
├── train.py                          # 训练脚本 (MLflow 追踪)
├── evaluate.py                       # 评估脚本 (分类 + 召回)
├── inference.py                      # 推理入口 (分类 / 嵌入提取)
├── requirements.txt
├── README.md
│
├── models/
│   ├── cnn_model.py                  # CNN (标准 656K / 轻量 157K)
│   └── checkpoints/                  # 模型 checkpoint
│
├── recall/                           # ★ 召回模块 (新增)
│   ├── __init__.py                   # 模块架构文档
│   ├── audio_embedder.py             # CNN → 128d 音频嵌入
│   ├── text_embedder.py              # MiniLM → 384d 文本嵌入
│   ├── joint_embedder.py             # 双模态融合 → 512d
│   ├── faiss_index.py                # FAISS 索引 + 多模态检索器
│   ├── recall_evaluate.py            # Recall@K / MRR@K 评估
│   ├── recall_demo.py                # 端到端检索 Demo
│   └── ab_test.py                    # A/B 实验 (Bootstrap)
│
├── utils/
│   ├── dataset.py                    # PyTorch Dataset (Parquet / Online)
│   └── audio_utils.py                # 音频处理 + 数据增强
│
├── spark/
│   ├── feature_extraction.py         # Spark 分布式特征提取
│   └── batch_embedding.py            # (future) 批量嵌入提取
│
└── data/
    ├── genres/                        # GTZAN 数据集 (需下载)
    ├── features/                      # Parquet / FAISS 索引 / 嵌入
    └── genre_descriptions.json        # 流派文本描述 (NLP 分支输入)

技术栈

环节 技术 说明
音频处理 librosa + soundfile 梅尔频谱图提取
CNN 模型 PyTorch 2.6 分类 + 嵌入提取,backbone 可替换
NLP 嵌入 sentence-transformers (MiniLM-L6) 零训练,384d 文本语义编码
向量检索 FAISS (cosine similarity) 毫秒级检索
实验管理 MLflow 参数/指标/模型版本追踪
分布式处理 PySpark 规模化特征提取管线
数据存储 Parquet (snappy) 列式存储,高压缩比

与 GNN 项目的关系

维度 GNN 项目 本项目 (CnnMusic)
召回类型 协同召回 (Collaborative) 内容召回 (Content-Based)
输入信号 用户行为图谱 (inviter → voter) 音频 + 文本 (内容本身)
核心模型 LightGBM + GraphSAGE/DySAT CNN + MiniLM
评估指标 MRR@5 Weighted Recall@10
冷启动覆盖 弱 (依赖交互历史) 强 (有内容就有嵌入)
长尾覆盖 弱 (交互稀疏) 强 (内容相似度)
面试叙事 "我知道谁跟谁关系好" "我知道什么内容像什么"

两个项目合在一起,就是一个完整的推荐系统召回层。

面试叙事:

"GNN 项目验证了基于社交图谱的协同召回能力——'这个人跟你的朋友们互动频繁,可以推荐给你'。CnnMusic 验证了基于内容理解的召回能力——'这个播客听起来跟你刚听完的那个很像'。在真实的播客推荐系统中,两个通道融合:协同通道覆盖高互动用户,内容通道覆盖冷启动和长尾内容。这就是工业级多通道召回系统的完整思路。"


规模化路径

当前用 GTZAN (1000 条) 验证管线架构。生产环境:

组件 当前 (Demo) 生产 (Podcast Platform)
数据量 1,000 条 百万条播客
特征提取 单机 librosa Spark 分布式 → Parquet
NLP 输入 流派描述 ASR 转录 + 标题 + 简介
嵌入提取 CPU 离线 GPU 批量 / Spark mapInPandas
FAISS IndexFlatIP (暴力) IVF + PQ 量化索引
在线检索 < 1ms < 10ms (分布式 FAISS)

核心原则:架构不变,只换数据源和部署规模。

License

MIT

About

大学时期校内实习,针对CNN进行着重学习,涉及到Resnet、Attention、线性回归、图像增强等技术。在后来进行了新的探索,涉及到YOLO11的动态识别,但是失败了。重新在两年后考虑和GNN项目联合,开发一个能够辅助GNN项目,输出博客产品类型的CNN音频识别项目,业务场景明显,而且项目健壮厚重,有非常高的实践价值。

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages

 
 
 

Contributors

Languages