Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,26 +15,62 @@
branch_labels = None
depends_on = None

BACKFILL_BATCH_SIZE = 1000


def upgrade() -> None:
op.add_column(
"regulation_chunk",
sa.Column("search_tsvector", postgresql.TSVECTOR(), nullable=True),
)

op.execute(
"""
UPDATE regulation_chunk AS rc
SET search_tsvector = to_tsvector(
'simple',
COALESCE(rc.chunk_text, '') || ' ' ||
COALESCE(rd.content, '') || ' ' ||
COALESCE(rc.keywords::text, '')
bind = op.get_bind()
last_processed_id = 0
while True:
result = bind.execute(
sa.text(
"""
WITH target_chunks AS (
SELECT regulation_chunk_id
FROM regulation_chunk
WHERE regulation_chunk_id > :last_processed_id
ORDER BY regulation_chunk_id
LIMIT :batch_size
),
updated_chunks AS (
UPDATE regulation_chunk AS rc
SET search_tsvector = to_tsvector(
'simple',
COALESCE(rc.chunk_text, '') || ' ' ||
COALESCE(rd.content, '') || ' ' ||
COALESCE(
(
SELECT string_agg(keyword.value, ' ')
FROM jsonb_array_elements_text(rc.keywords) AS keyword(value)
),
''
)
)
FROM regulation_document AS rd
WHERE rd.regulation_document_id = rc.regulation_document_id
AND rc.regulation_chunk_id IN (
SELECT regulation_chunk_id FROM target_chunks
)
RETURNING rc.regulation_chunk_id
)
SELECT max(regulation_chunk_id) AS last_processed_id
FROM updated_chunks
"""
),
{
"batch_size": BACKFILL_BATCH_SIZE,
"last_processed_id": last_processed_id,
},
)
FROM regulation_document AS rd
WHERE rd.regulation_document_id = rc.regulation_document_id
"""
)
next_last_processed_id = result.scalar()
if next_last_processed_id is None:
break
last_processed_id = next_last_processed_id

op.create_index(
"idx_regulation_chunk_search_tsvector",
Expand Down
60 changes: 26 additions & 34 deletions app/repositories/regulation_chunk_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import hashlib
from datetime import datetime
from datetime import timezone
from typing import Optional
from uuid import uuid4

from sqlalchemy import func
Expand All @@ -18,6 +19,23 @@
from app.db.models.regulation_chunk import RegulationChunk


def _build_search_vector_text(
chunk_text: Optional[str],
document_content: Optional[str],
keywords: Optional[list[str]],
) -> str:
keyword_text = " ".join(keywords or [])
return " ".join(
part
for part in [
chunk_text or "",
document_content or "",
keyword_text,
]
if part
)


def create_regulation_chunks_for_document(
db: Session,
regulation_document: RegulationDocument,
Expand All @@ -44,6 +62,14 @@ def create_regulation_chunks_for_document(
chunk_index=index - 1,
chunk_text=chunk_text,
keywords=regulation_document.keywords,
search_tsvector=func.to_tsvector(
"simple",
_build_search_vector_text(
chunk_text=chunk_text,
document_content=regulation_document.content,
keywords=regulation_document.keywords,
),
),
chunk_hash=hashlib.sha256(chunk_text.encode("utf-8")).hexdigest(),
embedding_model=settings.openai_embedding_model,
embedding=embedding,
Expand All @@ -53,45 +79,11 @@ def create_regulation_chunks_for_document(
created_chunks.append(regulation_chunk)

db.flush()
refresh_search_vectors_for_chunks(
db,
[
regulation_chunk.regulation_chunk_id
for regulation_chunk in created_chunks
if regulation_chunk.regulation_chunk_id is not None
],
)
for regulation_chunk in created_chunks:
db.refresh(regulation_chunk)
return created_chunks


def refresh_search_vectors_for_chunks(db: Session, regulation_chunk_ids: list[int]) -> int:
    """Recompute the ``search_tsvector`` column for the given stored chunks.

    The vector is built from the chunk text, the content of the owning
    regulation document and the chunk keywords. Keywords are expanded to
    their individual values before indexing so that JSON punctuation
    (brackets, quotes, commas) from a ``::text`` cast never reaches
    ``to_tsvector``.

    Args:
        db: Session used to execute the UPDATE statement.
        regulation_chunk_ids: Primary keys of the chunks to refresh.

    Returns:
        The row count reported for the UPDATE, or ``0`` when no ids were
        given or the driver reports no row count.
    """
    if not regulation_chunk_ids:
        # Nothing to refresh; avoid a round trip to the database.
        return 0

    # NOTE(review): assumes regulation_chunk.keywords is a jsonb array (or
    # NULL), matching the expression used by the backfill migration -- confirm.
    result = db.execute(
        text(
            """
            UPDATE regulation_chunk AS rc
            SET search_tsvector = to_tsvector(
                'simple',
                COALESCE(rc.chunk_text, '') || ' ' ||
                COALESCE(rd.content, '') || ' ' ||
                COALESCE(
                    (
                        SELECT string_agg(keyword.value, ' ')
                        FROM jsonb_array_elements_text(rc.keywords) AS keyword(value)
                    ),
                    ''
                )
            )
            FROM regulation_document AS rd
            WHERE rd.regulation_document_id = rc.regulation_document_id
              AND rc.regulation_chunk_id = ANY(:regulation_chunk_ids)
            """
        ),
        {"regulation_chunk_ids": regulation_chunk_ids},
    )
    return result.rowcount or 0


def deactivate_chunks_for_document(db: Session, regulation_document_id: int) -> int:
statement = (
update(RegulationChunk)
Expand Down
10 changes: 8 additions & 2 deletions app/services/chat_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,10 @@ def _answer_single_dormitory_chat(
)

if expanded_query != question:
expanded_query_embedding = create_query_embedding(expanded_query)
if expanded_query == retrieval_query:
expanded_query_embedding = query_embedding
else:
expanded_query_embedding = create_query_embedding(expanded_query)
rewritten_query = expanded_query

# 1차: 확장 query로 사용자 dormitory + 공통 문서 검색
Expand Down Expand Up @@ -485,7 +488,10 @@ def _answer_unspecified_dormitory_chat(


if expanded_query != question:
expanded_query_embedding = create_query_embedding(expanded_query)
if expanded_query == retrieval_query:
expanded_query_embedding = query_embedding
else:
expanded_query_embedding = create_query_embedding(expanded_query)
rewritten_query = expanded_query

expanded_chunks = search_hybrid_chunks_for_dormitories(
Expand Down
80 changes: 80 additions & 0 deletions tests/test_chat_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,86 @@ def fake_generate_answer(_question, chunks):
assert chat_log.retrieval_version == settings.chat_retrieval_version_grouped


def test_answer_chat_question_reuses_pre_expanded_embedding_for_grouped_fallback(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Check the embedding calls made when the first answer finds nothing.

    The first generated answer reports that no information was found and the
    second one succeeds. The test asserts that only the expanded query was
    embedded and that the search issued with a non-original query text
    received that same embedding.
    """
    db = FakeSession()
    finalize_db = FakeSession()
    chat_session = _build_chat_session()
    chat_log = _build_chat_log()
    embedded_queries: list[str] = []
    fallback_query_embeddings: list[list[float]] = []
    answer_calls: list[int] = []

    def fake_create_query_embedding(query: str) -> list[float]:
        # Record every query that is embedded; the vector itself is fixed.
        embedded_queries.append(query)
        return [0.4, 0.5, 0.6]

    def fake_generate_answer(_question, _chunks):
        # Only the very first call yields the "not found" answer.
        answer_calls.append(len(answer_calls) + 1)
        if len(answer_calls) > 1:
            return AnswerGenerationResult(
                answer="통금은 생활관별 규정을 확인해야 합니다.",
                source_url="https://example.com/rules",
                cited_regulation_chunk_ids=[1001],
            )
        return AnswerGenerationResult(
            answer="관련 정보를 찾을 수 없습니다.",
            source_url="",
            cited_regulation_chunk_ids=[],
        )

    def fake_search_hybrid_chunks_for_dormitories(*_args, **kwargs):
        # Capture the embedding of any search not using the original question.
        is_original_question = kwargs["query_text"] == "통금"
        if not is_original_question:
            fallback_query_embeddings.append(kwargs["query_embedding"])
        chunk = {
            "regulation_chunk_id": 1001,
            "document_id": "curfew",
            "document_version": "v1",
            "chunk_id": "chunk-1",
            "content": "통금 시간 안내",
            "source": "생활관 규정집",
            "source_url": "https://example.com/rules",
            "similarity": 0.6,
            "retrieval_group": "제1학생생활관",
        }
        return [chunk]

    # Dict order mirrors the order in which the attributes are patched.
    stubs = {
        "get_chat_session": lambda *_args, **_kwargs: chat_session,
        "create_chat_log": lambda *_args, **_kwargs: chat_log,
        "get_chat_log_by_id": lambda *_args, **_kwargs: chat_log,
        "touch_chat_session_activity": lambda *_args, **_kwargs: chat_session,
        "get_session_factory": lambda: (lambda: finalize_db),
        "validate_question": lambda *_args, **_kwargs: (True, "통금"),
        "expand_query_for_retrieval": lambda *_args, **_kwargs: "통금 시간 귀가 제한",
        "create_query_embedding": fake_create_query_embedding,
        "search_hybrid_chunks_for_dormitories": fake_search_hybrid_chunks_for_dormitories,
        "generate_answer": fake_generate_answer,
        "create_chat_retrieval_results": lambda *_args, **_kwargs: [],
        "mark_chat_retrieval_results_used_in_answer": lambda *_args, **_kwargs: [],
    }
    for attribute_name, stub in stubs.items():
        monkeypatch.setattr(chat_service, attribute_name, stub)

    request = ChatRequest(
        session_id="session-123",
        question="통금",
    )
    response = chat_service.answer_chat_question(db, request)

    assert response.answer_status == "SUCCESS"
    assert embedded_queries == ["통금 시간 귀가 제한"]
    assert fallback_query_embeddings == [[0.4, 0.5, 0.6]]
    assert chat_log.rewritten_query == "통금 시간 귀가 제한"


def test_answer_chat_question_returns_room_floor_without_retrieval(monkeypatch: pytest.MonkeyPatch) -> None:
db = FakeSession()
chat_session = _build_chat_session()
Expand Down
34 changes: 13 additions & 21 deletions tests/test_regulation_chunk_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def test_create_regulation_chunks_for_document_maps_document_fields_to_model() -
regulation_document_id=7,
document_id="dorm-rule-001",
document_version="2026.04",
content="생활관 외박 신청은 통합 포털에서 처리한다.",
keywords=["외박", "외출"],
)
regulation_chunk_repository.get_settings = lambda: SimpleNamespace(
Expand All @@ -61,10 +62,11 @@ def test_create_regulation_chunks_for_document_maps_document_fields_to_model() -
assert regulation_chunk.regulation_document_id == 7
assert regulation_chunk.document_version == "2026.04"
assert regulation_chunk.keywords == ["외박", "외출"]
assert "to_tsvector" in str(regulation_chunk.search_tsvector)
assert regulation_chunk.embedding_model == "text-embedding-3-small"
assert db.flush_called is True
assert db.executed_params == [{"regulation_chunk_ids": [1]}]
assert "search_tsvector = to_tsvector" in str(db.executed_statements[0])
assert db.executed_params == []
assert db.executed_statements == []
assert db.refresh_called_values == [regulation_chunk]


Expand All @@ -74,6 +76,7 @@ def test_create_regulation_chunks_for_document_rejects_embedding_count_mismatch(
regulation_document_id=7,
document_id="dorm-rule-001",
document_version="2026.04",
content="생활관 외박 신청은 통합 포털에서 처리한다.",
keywords=["외박", "외출"],
)
regulation_chunk_repository.get_settings = lambda: SimpleNamespace(
Expand Down Expand Up @@ -188,24 +191,13 @@ def execute(self, _statement, params):
assert executed_params[0]["dormitories"] == ["제1학생생활관", "제2학생생활관"]


def test_refresh_search_vectors_for_chunks_updates_tsvector_from_document_content() -> None:
executed_statements: list[object] = []
executed_params: list[dict] = []

class RefreshSession:
def execute(self, statement, params):
executed_statements.append(statement)
executed_params.append(params)
return SimpleNamespace(rowcount=2)

result = regulation_chunk_repository.refresh_search_vectors_for_chunks(
RefreshSession(),
[10, 11],
def test_build_search_vector_text_joins_keywords_without_json_syntax() -> None:
result = regulation_chunk_repository._build_search_vector_text(
chunk_text="외박 신청",
document_content="통합 포털에서 신청합니다.",
keywords=["외박", "외출"],
)

executed_sql = str(executed_statements[0])
assert result == 2
assert "UPDATE regulation_chunk AS rc" in executed_sql
assert "search_tsvector = to_tsvector" in executed_sql
assert "COALESCE(rd.content, '')" in executed_sql
assert executed_params == [{"regulation_chunk_ids": [10, 11]}]
assert result == "외박 신청 통합 포털에서 신청합니다. 외박 외출"
assert "[" not in result
assert "]" not in result
Loading