diff --git a/alembic/versions/20260504_0013_add_regulation_chunk_search_tsvector.py b/alembic/versions/20260504_0013_add_regulation_chunk_search_tsvector.py index 4492208..221f0df 100644 --- a/alembic/versions/20260504_0013_add_regulation_chunk_search_tsvector.py +++ b/alembic/versions/20260504_0013_add_regulation_chunk_search_tsvector.py @@ -15,6 +15,8 @@ branch_labels = None depends_on = None +BACKFILL_BATCH_SIZE = 1000 + def upgrade() -> None: op.add_column( @@ -22,19 +24,53 @@ def upgrade() -> None: sa.Column("search_tsvector", postgresql.TSVECTOR(), nullable=True), ) - op.execute( - """ - UPDATE regulation_chunk AS rc - SET search_tsvector = to_tsvector( - 'simple', - COALESCE(rc.chunk_text, '') || ' ' || - COALESCE(rd.content, '') || ' ' || - COALESCE(rc.keywords::text, '') + bind = op.get_bind() + last_processed_id = 0 + while True: + result = bind.execute( + sa.text( + """ + WITH target_chunks AS ( + SELECT regulation_chunk_id + FROM regulation_chunk + WHERE regulation_chunk_id > :last_processed_id + ORDER BY regulation_chunk_id + LIMIT :batch_size + ), + updated_chunks AS ( + UPDATE regulation_chunk AS rc + SET search_tsvector = to_tsvector( + 'simple', + COALESCE(rc.chunk_text, '') || ' ' || + COALESCE(rd.content, '') || ' ' || + COALESCE( + ( + SELECT string_agg(keyword.value, ' ') + FROM jsonb_array_elements_text(rc.keywords) AS keyword(value) + ), + '' + ) + ) + FROM regulation_document AS rd + WHERE rd.regulation_document_id = rc.regulation_document_id + AND rc.regulation_chunk_id IN ( + SELECT regulation_chunk_id FROM target_chunks + ) + RETURNING rc.regulation_chunk_id + ) + SELECT max(regulation_chunk_id) AS last_processed_id + FROM updated_chunks + """ + ), + { + "batch_size": BACKFILL_BATCH_SIZE, + "last_processed_id": last_processed_id, + }, ) - FROM regulation_document AS rd - WHERE rd.regulation_document_id = rc.regulation_document_id - """ - ) + next_last_processed_id = result.scalar() + if next_last_processed_id is None: + break + last_processed_id = next_last_processed_id op.create_index( "idx_regulation_chunk_search_tsvector", diff --git a/app/repositories/regulation_chunk_repository.py b/app/repositories/regulation_chunk_repository.py index 76b3795..254ebdf 100644 --- a/app/repositories/regulation_chunk_repository.py +++ b/app/repositories/regulation_chunk_repository.py @@ -3,6 +3,7 @@ import hashlib from datetime import datetime from datetime import timezone +from typing import Optional from uuid import uuid4 from sqlalchemy import func @@ -18,6 +19,23 @@ from app.db.models.regulation_chunk import RegulationChunk +def _build_search_vector_text( + chunk_text: Optional[str], + document_content: Optional[str], + keywords: Optional[list[str]], +) -> str: + keyword_text = " ".join(keywords or []) + return " ".join( + part + for part in [ + chunk_text or "", + document_content or "", + keyword_text, + ] + if part + ) + + def create_regulation_chunks_for_document( db: Session, regulation_document: RegulationDocument, @@ -44,6 +62,14 @@ def create_regulation_chunks_for_document( chunk_index=index - 1, chunk_text=chunk_text, keywords=regulation_document.keywords, + search_tsvector=func.to_tsvector( + "simple", + _build_search_vector_text( + chunk_text=chunk_text, + document_content=regulation_document.content, + keywords=regulation_document.keywords, + ), + ), chunk_hash=hashlib.sha256(chunk_text.encode("utf-8")).hexdigest(), embedding_model=settings.openai_embedding_model, embedding=embedding, @@ -53,45 +79,11 @@ def create_regulation_chunks_for_document( created_chunks.append(regulation_chunk) db.flush() - refresh_search_vectors_for_chunks( - db, - [ - regulation_chunk.regulation_chunk_id - for regulation_chunk in created_chunks - if regulation_chunk.regulation_chunk_id is not None - ], - ) for regulation_chunk in created_chunks: db.refresh(regulation_chunk) return created_chunks -def refresh_search_vectors_for_chunks(db: Session, regulation_chunk_ids: list[int]) -> int: - """저장된 청크 검색 텍스트를 tsvector 컬럼에 반영합니다.""" - - if not regulation_chunk_ids: - return 0 - - result = db.execute( - text( - """ - UPDATE regulation_chunk AS rc - SET search_tsvector = to_tsvector( - 'simple', - COALESCE(rc.chunk_text, '') || ' ' || - COALESCE(rd.content, '') || ' ' || - COALESCE(rc.keywords::text, '') - ) - FROM regulation_document AS rd - WHERE rd.regulation_document_id = rc.regulation_document_id - AND rc.regulation_chunk_id = ANY(:regulation_chunk_ids) - """ - ), - {"regulation_chunk_ids": regulation_chunk_ids}, - ) - return result.rowcount or 0 - - def deactivate_chunks_for_document(db: Session, regulation_document_id: int) -> int: statement = ( update(RegulationChunk) diff --git a/app/services/chat_service.py b/app/services/chat_service.py index c1e52cf..660c9f2 100644 --- a/app/services/chat_service.py +++ b/app/services/chat_service.py @@ -285,7 +285,10 @@ def _answer_single_dormitory_chat( ) if expanded_query != question: - expanded_query_embedding = create_query_embedding(expanded_query) + if expanded_query == retrieval_query: + expanded_query_embedding = query_embedding + else: + expanded_query_embedding = create_query_embedding(expanded_query) rewritten_query = expanded_query # 1차: 확장 query로 사용자 dormitory + 공통 문서 검색 @@ -485,7 +488,10 @@ def _answer_unspecified_dormitory_chat( if expanded_query != question: - expanded_query_embedding = create_query_embedding(expanded_query) + if expanded_query == retrieval_query: + expanded_query_embedding = query_embedding + else: + expanded_query_embedding = create_query_embedding(expanded_query) rewritten_query = expanded_query expanded_chunks = search_hybrid_chunks_for_dormitories( diff --git a/tests/test_chat_service.py b/tests/test_chat_service.py index fcfbf6e..e14e742 100644 --- a/tests/test_chat_service.py +++ b/tests/test_chat_service.py @@ -265,6 +265,86 @@ def fake_generate_answer(_question, chunks): assert chat_log.retrieval_version == settings.chat_retrieval_version_grouped +def test_answer_chat_question_reuses_pre_expanded_embedding_for_grouped_fallback( + monkeypatch: pytest.MonkeyPatch, +) -> None: + db = FakeSession() + finalize_db = FakeSession() + chat_session = _build_chat_session() + chat_log = _build_chat_log() + embedding_queries: list[str] = [] + fallback_embeddings: list[list[float]] = [] + generate_call_count = 0 + + monkeypatch.setattr(chat_service, "get_chat_session", lambda *_args, **_kwargs: chat_session) + monkeypatch.setattr(chat_service, "create_chat_log", lambda *_args, **_kwargs: chat_log) + monkeypatch.setattr(chat_service, "get_chat_log_by_id", lambda *_args, **_kwargs: chat_log) + monkeypatch.setattr(chat_service, "touch_chat_session_activity", lambda *_args, **_kwargs: chat_session) + monkeypatch.setattr(chat_service, "get_session_factory", lambda: (lambda: finalize_db)) + monkeypatch.setattr(chat_service, "validate_question", lambda *_args, **_kwargs: (True, "통금")) + monkeypatch.setattr(chat_service, "expand_query_for_retrieval", lambda *_args, **_kwargs: "통금 시간 귀가 제한") + + def fake_create_query_embedding(query: str) -> list[float]: + embedding_queries.append(query) + return [0.4, 0.5, 0.6] + + def fake_generate_answer(_question, _chunks): + nonlocal generate_call_count + generate_call_count += 1 + if generate_call_count == 1: + return AnswerGenerationResult( + answer="관련 정보를 찾을 수 없습니다.", + source_url="", + cited_regulation_chunk_ids=[], + ) + return AnswerGenerationResult( + answer="통금은 생활관별 규정을 확인해야 합니다.", + source_url="https://example.com/rules", + cited_regulation_chunk_ids=[1001], + ) + + monkeypatch.setattr(chat_service, "create_query_embedding", fake_create_query_embedding) + + def fake_search_hybrid_chunks_for_dormitories(*_args, **kwargs): + if kwargs["query_text"] != "통금": + fallback_embeddings.append(kwargs["query_embedding"]) + return [ + { + "regulation_chunk_id": 1001, + "document_id": "curfew", + "document_version": "v1", + "chunk_id": "chunk-1", + "content": "통금 시간 안내", + "source": "생활관 규정집", + "source_url": "https://example.com/rules", + "similarity": 0.6, + "retrieval_group": "제1학생생활관", + } + ] + + monkeypatch.setattr( + chat_service, + "search_hybrid_chunks_for_dormitories", + fake_search_hybrid_chunks_for_dormitories, + ) + monkeypatch.setattr(chat_service, "generate_answer", fake_generate_answer) + monkeypatch.setattr(chat_service, "create_chat_retrieval_results", lambda *_args, **_kwargs: []) + monkeypatch.setattr(chat_service, "mark_chat_retrieval_results_used_in_answer", lambda *_args, **_kwargs: []) + + response = chat_service.answer_chat_question( + db, + ChatRequest( + session_id="session-123", + question="통금", + ), + ) + + assert response.answer_status == "SUCCESS" + assert embedding_queries == ["통금 시간 귀가 제한"] + assert fallback_embeddings == [[0.4, 0.5, 0.6]] + assert chat_log.rewritten_query == "통금 시간 귀가 제한" + + def test_answer_chat_question_returns_room_floor_without_retrieval(monkeypatch: pytest.MonkeyPatch) -> None: db = FakeSession() chat_session = _build_chat_session() diff --git a/tests/test_regulation_chunk_repository.py b/tests/test_regulation_chunk_repository.py index cb780c7..4f1dd41 100644 --- a/tests/test_regulation_chunk_repository.py +++ b/tests/test_regulation_chunk_repository.py @@ -42,6 +42,7 @@ def test_create_regulation_chunks_for_document_maps_document_fields_to_model() - regulation_document_id=7, document_id="dorm-rule-001", document_version="2026.04", + content="생활관 외박 신청은 통합 포털에서 처리한다.", keywords=["외박", "외출"], ) regulation_chunk_repository.get_settings = lambda: SimpleNamespace( @@ -61,10 +62,11 @@ def test_create_regulation_chunks_for_document_maps_document_fields_to_model() - assert regulation_chunk.regulation_document_id == 7 assert regulation_chunk.document_version == "2026.04" assert regulation_chunk.keywords == ["외박", "외출"] + assert "to_tsvector" in str(regulation_chunk.search_tsvector) assert regulation_chunk.embedding_model == "text-embedding-3-small" assert db.flush_called is True - assert db.executed_params == [{"regulation_chunk_ids": [1]}] - assert "search_tsvector = to_tsvector" in str(db.executed_statements[0]) + assert db.executed_params == [] + assert db.executed_statements == [] assert db.refresh_called_values == [regulation_chunk] @@ -74,6 +76,7 @@ def test_create_regulation_chunks_for_document_rejects_embedding_count_mismatch( regulation_document_id=7, document_id="dorm-rule-001", document_version="2026.04", + content="생활관 외박 신청은 통합 포털에서 처리한다.", keywords=["외박", "외출"], ) regulation_chunk_repository.get_settings = lambda: SimpleNamespace( @@ -188,24 +191,13 @@ def execute(self, _statement, params): assert executed_params[0]["dormitories"] == ["제1학생생활관", "제2학생생활관"] -def test_refresh_search_vectors_for_chunks_updates_tsvector_from_document_content() -> None: - executed_statements: list[object] = [] - executed_params: list[dict] = [] - - class RefreshSession: - def execute(self, statement, params): - executed_statements.append(statement) - executed_params.append(params) - return SimpleNamespace(rowcount=2) - - result = regulation_chunk_repository.refresh_search_vectors_for_chunks( - RefreshSession(), - [10, 11], +def test_build_search_vector_text_joins_keywords_without_json_syntax() -> None: + result = regulation_chunk_repository._build_search_vector_text( + chunk_text="외박 신청", + document_content="통합 포털에서 신청합니다.", + keywords=["외박", "외출"], ) - executed_sql = str(executed_statements[0]) - assert result == 2 - assert "UPDATE regulation_chunk AS rc" in executed_sql - assert "search_tsvector = to_tsvector" in executed_sql - assert "COALESCE(rd.content, '')" in executed_sql - assert executed_params == [{"regulation_chunk_ids": [10, 11]}] + assert result == "외박 신청 통합 포털에서 신청합니다. 외박 외출" + assert "[" not in result + assert "]" not in result