diff --git a/lib/rag/chunking.test.ts b/lib/rag/chunking.test.ts new file mode 100644 index 0000000..2efb602 --- /dev/null +++ b/lib/rag/chunking.test.ts @@ -0,0 +1,259 @@ +import { describe, expect, it, vi } from "vitest"; +import { chunkNoteDocument } from "@/lib/rag/chunking"; +import type { NoteDocument, NoteListItem } from "@/lib/notes/types"; + +const countWords = (value: string) => value.match(/[A-Za-z0-9_]+/g)?.length ?? 0; + +function paragraph(id: string, text: string): NoteDocument["blocks"][number] { + return { + id, + type: "paragraph", + data: { + text, + }, + }; +} + +function header( + id: string, + level: 1 | 2 | 3 | 4 | 5 | 6, + text: string, +): NoteDocument["blocks"][number] { + return { + id, + type: "header", + data: { + level, + text, + }, + }; +} + +function code(id: string, source: string): NoteDocument["blocks"][number] { + return { + id, + type: "code", + data: { + code: source, + }, + }; +} + +function nestedListItem(depth: number): NoteListItem { + if (depth === 0) { + return { + content: "leaf detail", + items: [], + }; + } + + return { + content: `depth ${depth}`, + items: [nestedListItem(depth - 1)], + }; +} + +describe("chunkNoteDocument", () => { + it("creates semantic chunks from H1-H3 sections and carries heading context", () => { + const document: NoteDocument = { + time: 1, + blocks: [ + header("h1", 1, "Biology"), + paragraph("p1", "Cells regulate transport."), + header("h2", 2, "Photosynthesis"), + paragraph("p2", "Chloroplasts convert light energy."), + header("h3", 3, "Light reactions"), + paragraph("p3", "Photosystems split water."), + header("h4", 4, "Tiny detail"), + paragraph("p4", "H4 stays inside the active H3 chunk."), + ], + }; + + const chunks = chunkNoteDocument(document, { + maxTokens: 80, + countTokens: countWords, + noteId: "note-1", + }); + + expect(chunks).toHaveLength(3); + expect(chunks.map((chunk) => chunk.headingPath.map((heading) => heading.text))).toEqual([ + ["Biology"], + ["Biology", "Photosynthesis"], + ["Biology", "Photosynthesis", "Light reactions"], + ]); + expect(chunks.map((chunk) => chunk.sourceBlockIds)).toEqual([ + ["p1"], + ["p2"], + ["p3", "h4", "p4"], + ]); + expect(chunks[0]).toMatchObject({ + id: "note-1:chunk:0", + content: "# Biology\n\nCells regulate transport.", + tokenCount: 4, + isOversized: false, + }); + expect(chunks[2].content).toBe( + [ + "# Biology", + "## Photosynthesis", + "### Light reactions", + "", + "Photosystems split water.", + "", + "#### Tiny detail", + "", + "H4 stays inside the active H3 chunk.", + ].join("\n"), + ); + }); + + it("splits chunks within a section when adding another block would exceed the token limit", () => { + const document: NoteDocument = { + time: 1, + blocks: [ + header("h1", 1, "RAG"), + header("h2", 2, "Retrieval"), + paragraph("p1", "semantic search ranks context"), + paragraph("p2", "query embeddings guide generation"), + ], + }; + + const chunks = chunkNoteDocument(document, { + maxTokens: 6, + countTokens: countWords, + }); + + expect(chunks).toHaveLength(2); + expect(chunks.map((chunk) => chunk.content)).toEqual([ + "# RAG\n## Retrieval\n\nsemantic search ranks context", + "# RAG\n## Retrieval\n\nquery embeddings guide generation", + ]); + expect(chunks.map((chunk) => chunk.tokenCount)).toEqual([6, 6]); + expect(chunks.every((chunk) => chunk.headingPath.map((heading) => heading.text).join(" > ") === "RAG > Retrieval")).toBe( + true, + ); + }); + + it("serializes deeply nested lists in document order without dropping nested content", () => { + const document: NoteDocument = { + time: 1, + blocks: [ + header("h1", 1, "Algorithms"), + { + id: "list-1", + type: "list", + data: { + style: "unordered", + items: [nestedListItem(8)], + }, + }, + ], + }; + + const chunks = chunkNoteDocument(document, { + maxTokens: 80, + countTokens: countWords, + }); + + expect(chunks).toHaveLength(1); + expect(chunks[0].sourceBlockIds).toEqual(["list-1"]); + expect(chunks[0].content).toContain("- depth 8"); + expect(chunks[0].content).toContain(" - leaf detail"); + expect(chunks[0].tokenCount).toBe(19); + }); + + it("skips empty headings, empty blocks, and blank list items without creating empty chunks", () => { + const document: NoteDocument = { + time: 1, + blocks: [ + header("empty-heading", 1, " "), + paragraph("empty-paragraph", "
"), + { + id: "empty-list", + type: "list", + data: { + style: "unordered", + items: [ + { + content: " ", + items: [], + }, + ], + }, + }, + ], + }; + + expect( + chunkNoteDocument(document, { + maxTokens: 20, + countTokens: countWords, + }), + ).toEqual([]); + }); + + it("keeps an oversized code block intact and isolates following content into a new chunk", () => { + const document: NoteDocument = { + time: 1, + blocks: [ + header("h1", 1, "Runtime"), + code("code-1", "const alpha = 1;\nconst beta = 2;\nconst gamma = alpha + beta;"), + paragraph("p1", "The code initializes values."), + ], + }; + + const chunks = chunkNoteDocument(document, { + maxTokens: 5, + countTokens: countWords, + }); + + expect(chunks).toHaveLength(2); + expect(chunks[0]).toMatchObject({ + sourceBlockIds: ["code-1"], + isOversized: true, + splitReason: "oversized-block", + }); + expect(chunks[0].content).toBe( + [ + "# Runtime", + "", + "```", + "const alpha = 1;", + "const beta = 2;", + "const gamma = alpha + beta;", + "```", + ].join("\n"), + ); + expect(chunks[1]).toMatchObject({ + sourceBlockIds: ["p1"], + content: "# Runtime\n\nThe code initializes values.", + isOversized: false, + }); + }); + + it("counts each non-empty heading or block once so token accounting stays linear", () => { + const countTokens = vi.fn(countWords); + const document: NoteDocument = { + time: 1, + blocks: [ + header("h1", 1, "Linear accounting"), + paragraph("p1", "first block"), + paragraph("p2", "second block"), + paragraph("p3", "third block"), + ], + }; + + chunkNoteDocument(document, { + maxTokens: 6, + countTokens, + }); + + expect(countTokens).toHaveBeenCalledTimes(4); + expect(countTokens.mock.calls.map(([value]) => value)).toEqual([ + "Linear accounting", + "first block", + "second block", + "third block", + ]); + }); +}); diff --git a/lib/rag/chunking.ts b/lib/rag/chunking.ts new file mode 100644 index 0000000..a8340cc --- /dev/null +++ b/lib/rag/chunking.ts @@ -0,0 +1,302 @@ +import TurndownService from "turndown"; +import { gfm } from "turndown-plugin-gfm"; +import type { NoteBlock, NoteDocument, NoteListBlockData, NoteListItem } from "@/lib/notes/types"; + +const DEFAULT_MAX_TOKENS = 700; + +const turndown = new TurndownService({ + bulletListMarker: "-", + codeBlockStyle: "fenced", + emDelimiter: "*", + headingStyle: "atx", + linkStyle: "inlined", + strongDelimiter: "**", +}); + +turndown.use(gfm); + +export type ChunkHeading = { + level: 1 | 2 | 3; + text: string; + blockId?: string; +}; + +export type ChunkSplitReason = "heading-boundary" | "token-limit" | "oversized-block" | "end-of-document"; + +export type NoteChunk = { + id: string; + chunkIndex: number; + headingPath: ChunkHeading[]; + sourceBlockIds: string[]; + content: string; + tokenCount: number; + isOversized: boolean; + splitReason: ChunkSplitReason; +}; + +export type ChunkNoteDocumentOptions = { + maxTokens?: number; + noteId?: string; + countTokens?: (value: string) => number; +}; + +type SerializedBlock = { + blockId: string; + markdown: string; + tokenText: string; +}; + +type PendingChunk = { + bodyBlocks: string[]; + sourceBlockIds: string[]; + bodyTokenCount: number; + isOversized: boolean; +}; + +const emptyPendingChunk = (): PendingChunk => ({ + bodyBlocks: [], + sourceBlockIds: [], + bodyTokenCount: 0, + isOversized: false, +}); + +function defaultCountTokens(value: string) { + return value.trim().split(/\s+/).filter(Boolean).length; +} + +function htmlToMarkdown(value: string) { + const normalized = value.replace(/\n{3,}/g, "\n\n").trim(); + + if (!/<\/?[a-z][\s\S]*>/i.test(normalized) && !/&[a-z#0-9]+;/i.test(normalized)) { + return normalized; + } + + return turndown + .turndown(normalized) + .replace(/\n{3,}/g, "\n\n") + .trim(); +} + +function htmlToPlainText(value: string) { + return value + .replace(//gi, "\n") + .replace(/<\/(p|div|li|blockquote|h[1-6])>/gi, "\n") + .replace(/<[^>]+>/g, "") + .replace(/ /g, " ") + .replace(/&/g, "&") + .replace(/</g, "<") + .replace(/>/g, ">") + .replace(/"/g, '"') + .replace(/'/g, "'") + .replace(/\n{3,}/g, "\n\n") + .trim(); +} + +function createCodeFence(code: string) { + const matches = code.match(/`+/g) ?? []; + const longestFence = matches.reduce((longest, match) => Math.max(longest, match.length), 0); + + return "`".repeat(Math.max(3, longestFence + 1)); +} + +function prefixLines(value: string, prefix: string) { + return value + .split("\n") + .map((line) => `${prefix}${line}`) + .join("\n"); +} + +function serializeListItems(items: NoteListItem[], style: NoteListBlockData["style"], start = 1) { + const lines: string[] = []; + const stack = items + .map((item, index) => ({ item, depth: 0, index, start })) + .reverse(); + + while (stack.length > 0) { + const { item, depth, index, start: itemStart } = stack.pop()!; + const content = htmlToMarkdown(item.content); + const hasContent = content.length > 0; + + if (hasContent) { + const indent = " ".repeat(depth); + const marker = + style === "checklist" + ? `- [${item.meta?.checked ? "x" : " "}]` + : style === "ordered" + ? `${itemStart + index}.` + : "-"; + lines.push(`${indent}${marker} ${content}`); + } + + const childDepth = hasContent ? depth + 1 : depth; + for (let childIndex = item.items.length - 1; childIndex >= 0; childIndex -= 1) { + stack.push({ + item: item.items[childIndex], + depth: childDepth, + index: childIndex, + start: 1, + }); + } + } + + return lines.join("\n"); +} + +function serializeBlock(block: NoteBlock, fallbackId: string): SerializedBlock | null { + switch (block.type) { + case "paragraph": { + const markdown = htmlToMarkdown(block.data.text); + return markdown ? { blockId: block.id ?? fallbackId, markdown, tokenText: markdown } : null; + } + case "header": { + const content = htmlToMarkdown(block.data.text); + if (!content) { + return null; + } + + const level = Math.min(Math.max(block.data.level, 1), 6); + return { + blockId: block.id ?? fallbackId, + markdown: `${"#".repeat(level)} ${content}`, + tokenText: content, + }; + } + case "list": { + const start = typeof block.data.meta?.start === "number" ? block.data.meta.start : 1; + const markdown = serializeListItems(block.data.items, block.data.style, start); + return markdown ? { blockId: block.id ?? fallbackId, markdown, tokenText: markdown } : null; + } + case "quote": { + const quoteText = htmlToMarkdown(block.data.text); + const caption = htmlToMarkdown(block.data.caption); + const markdown = [quoteText, caption] + .filter(Boolean) + .map((section) => prefixLines(section, "> ")) + .join("\n>\n"); + return markdown ? { blockId: block.id ?? fallbackId, markdown, tokenText: markdown } : null; + } + case "code": { + if (!block.data.code.trim()) { + return null; + } + + const fence = createCodeFence(block.data.code); + return { + blockId: block.id ?? fallbackId, + markdown: `${fence}\n${block.data.code}\n${fence}`, + tokenText: block.data.code, + }; + } + case "image": { + const altText = htmlToPlainText(block.data.caption) || "Image"; + const caption = htmlToMarkdown(block.data.caption); + const imageLine = `![${altText}](${block.data.file.url})`; + const markdown = caption ? `${imageLine}\n\n${caption}` : imageLine; + return { blockId: block.id ?? fallbackId, markdown, tokenText: caption || altText }; + } + case "math": { + if (!block.data.latex.trim()) { + return null; + } + + return { + blockId: block.id ?? fallbackId, + markdown: `$$\n${block.data.latex}\n$$`, + tokenText: block.data.latex, + }; + } + default: + return null; + } +} + +function renderHeadingPath(headingPath: ChunkHeading[]) { + return headingPath.map((heading) => `${"#".repeat(heading.level)} ${heading.text}`).join("\n"); +} + +function renderChunkContent(headingPath: ChunkHeading[], bodyBlocks: string[]) { + return [renderHeadingPath(headingPath), bodyBlocks.join("\n\n")].filter(Boolean).join("\n\n"); +} + +function getHeadingTokenCount(headingPath: ChunkHeading[], countTokens: (value: string) => number) { + return headingPath.reduce((total, heading) => total + countTokens(heading.text), 0); +} + +function isChunkBoundaryLevel(level: number): level is 1 | 2 | 3 { + return level === 1 || level === 2 || level === 3; +} + +export function chunkNoteDocument(document: NoteDocument, options: ChunkNoteDocumentOptions = {}): NoteChunk[] { + const maxTokens = Math.max(1, options.maxTokens ?? DEFAULT_MAX_TOKENS); + const countTokens = options.countTokens ?? defaultCountTokens; + const noteId = options.noteId ?? "note"; + const chunks: NoteChunk[] = []; + let headingPath: ChunkHeading[] = []; + let headingTokenCount = 0; + let pending = emptyPendingChunk(); + + const flush = (splitReason: ChunkSplitReason) => { + if (pending.bodyBlocks.length === 0) { + return; + } + + const chunkIndex = chunks.length; + chunks.push({ + id: `${noteId}:chunk:${chunkIndex}`, + chunkIndex, + headingPath: headingPath.map((heading) => ({ ...heading })), + sourceBlockIds: [...pending.sourceBlockIds], + content: renderChunkContent(headingPath, pending.bodyBlocks), + tokenCount: headingTokenCount + pending.bodyTokenCount, + isOversized: pending.isOversized, + splitReason, + }); + pending = emptyPendingChunk(); + }; + + for (const [blockIndex, block] of document.blocks.entries()) { + const fallbackId = `block:${blockIndex}`; + + if (block.type === "header" && isChunkBoundaryLevel(block.data.level)) { + const headingText = htmlToMarkdown(block.data.text); + if (!headingText) { + continue; + } + + flush("heading-boundary"); + headingPath = headingPath.filter((heading) => heading.level < block.data.level); + headingPath.push({ + level: block.data.level, + text: headingText, + blockId: block.id, + }); + headingTokenCount = getHeadingTokenCount(headingPath, countTokens); + continue; + } + + const serialized = serializeBlock(block, fallbackId); + if (!serialized) { + continue; + } + + const blockTokenCount = countTokens(serialized.tokenText); + const projectedTokenCount = headingTokenCount + pending.bodyTokenCount + blockTokenCount; + if (pending.bodyBlocks.length > 0 && projectedTokenCount > maxTokens) { + flush("token-limit"); + } + + const tokenCount = headingTokenCount + blockTokenCount; + pending.bodyBlocks.push(serialized.markdown); + pending.sourceBlockIds.push(serialized.blockId); + pending.bodyTokenCount += blockTokenCount; + + if (tokenCount > maxTokens) { + pending.isOversized = true; + flush("oversized-block"); + } + } + + flush("end-of-document"); + + return chunks; +}