diff --git a/tests/test_function_app_endpoints.py b/tests/test_function_app_endpoints.py index b72f519ff..06f20c693 100644 --- a/tests/test_function_app_endpoints.py +++ b/tests/test_function_app_endpoints.py @@ -46,6 +46,26 @@ def _mock_request( return req +def _capture_sse_http_response(monkeypatch, app_module, captured: dict) -> None: + """Patch HttpResponse so SSE bodies are captured as bytes or generators.""" + import inspect + + import azure.functions as _af + + _real_HttpResponse = _af.HttpResponse + + def _capturing_HttpResponse(body=None, **kwargs): + if body is not None and inspect.isgenerator(body): + consumed = b"".join(body) + captured["sse_body"] = consumed + return _real_HttpResponse(consumed, **kwargs) + if isinstance(body, (bytes, bytearray)): + captured["sse_body"] = bytes(body) + return _real_HttpResponse(body, **kwargs) + + monkeypatch.setattr(app_module.func, "HttpResponse", _capturing_HttpResponse) + + def _install_fake_quantum_trainer_module( monkeypatch: pytest.MonkeyPatch, capture: dict | None = None, @@ -526,22 +546,8 @@ def test_chat_stream_whitespace_only_input_text_block_message(self, app_module): def test_chat_stream_guardrail_blocks_prompt_injection(self, app_module, monkeypatch): """POST /api/chat/stream should emit safe fallback SSE when prompt is blocked.""" - import inspect - - import azure.functions as _af - captured: dict = {"sse_body": b""} - _real_HttpResponse = _af.HttpResponse - - def _capturing_HttpResponse(body=None, **kwargs): - if body is not None and inspect.isgenerator(body): - consumed = b"".join(body) - captured["sse_body"] = consumed - return _real_HttpResponse(consumed, **kwargs) - return _real_HttpResponse(body, **kwargs) - - monkeypatch.setattr(app_module.func, "HttpResponse", - _capturing_HttpResponse) + _capture_sse_http_response(monkeypatch, app_module, captured) req = _mock_request( "POST", @@ -562,10 +568,6 @@ def _capturing_HttpResponse(body=None, **kwargs): def test_chat_stream_memory_injection(self, app_module, monkeypatch): """POST /api/chat/stream should call memory helpers and include count in meta SSE event.""" - import inspect - - import azure.functions as _af - captured: dict = {"embedding": None, "session_id": None, "sse_body": b""} @@ -577,18 +579,7 @@ def _fake_similar(query_emb, top_k=5, session_id=None, min_similarity=0.0): captured["session_id"] = session_id return [{"content": "Previous answer about widgets", "similarity": 0.88}] - # Patch func.HttpResponse inside function_app so streaming body (generator) is consumed - _real_HttpResponse = _af.HttpResponse - - def _capturing_HttpResponse(body=None, **kwargs): - if body is not None and inspect.isgenerator(body): - consumed = b"".join(body) - captured["sse_body"] = consumed - return _real_HttpResponse(consumed, **kwargs) - return _real_HttpResponse(body, **kwargs) - - monkeypatch.setattr(app_module.func, "HttpResponse", - _capturing_HttpResponse) + _capture_sse_http_response(monkeypatch, app_module, captured) monkeypatch.setattr(app_module, "generate_embedding", _fake_embedding) monkeypatch.setattr( app_module, "fetch_similar_messages", _fake_similar) @@ -623,10 +614,6 @@ def _capturing_HttpResponse(body=None, **kwargs): def test_chat_stream_emits_done_sentinel(self, app_module, monkeypatch): """POST /api/chat/stream should terminate SSE with data: [DONE].""" - import inspect - - import azure.functions as _af - captured: dict = {"sse_body": b""} class _FakeProvider: @@ -647,20 +634,10 @@ def complete(self, messages, stream=False): monkeypatch.setattr( app_module, "fetch_similar_messages", - lambda query_emb, top_k=5, session_id=None: [], + lambda query_emb, top_k=5, session_id=None, min_similarity=0.0: [], ) - _real_HttpResponse = _af.HttpResponse - - def _capturing_HttpResponse(body=None, **kwargs): - if body is not None and inspect.isgenerator(body): - consumed = b"".join(body) - captured["sse_body"] = consumed - return _real_HttpResponse(consumed, **kwargs) - return _real_HttpResponse(body, **kwargs) - - monkeypatch.setattr(app_module.func, "HttpResponse", - _capturing_HttpResponse) + _capture_sse_http_response(monkeypatch, app_module, captured) req = _mock_request( "POST", @@ -829,10 +806,6 @@ def test_agi_reason_validation_error_when_missing_input(self, app_module): assert "validation error" in data["error"].lower() def test_agi_stream_emits_done_sentinel(self, app_module, monkeypatch): - import inspect - - import azure.functions as _af - captured: dict = {"sse_body": b""} class _FakeAgiProvider: @@ -853,17 +826,7 @@ def set_goal(self, _goal: str): ), ) - _real_HttpResponse = _af.HttpResponse - - def _capturing_HttpResponse(body=None, **kwargs): - if body is not None and inspect.isgenerator(body): - consumed = b"".join(body) - captured["sse_body"] = consumed - return _real_HttpResponse(consumed, **kwargs) - return _real_HttpResponse(body, **kwargs) - - monkeypatch.setattr(app_module.func, "HttpResponse", - _capturing_HttpResponse) + _capture_sse_http_response(monkeypatch, app_module, captured) req = _mock_request( "POST", diff --git a/tests/test_quantum_integration.py b/tests/test_quantum_integration.py index bb4d331b7..74949c5ab 100644 --- a/tests/test_quantum_integration.py +++ b/tests/test_quantum_integration.py @@ -33,8 +33,8 @@ def _stub_subprocess(returncode: int = 0, stdout: str = "ok", stderr: str = ""): @pytest.mark.parametrize( "jobs_payload,expected_preset", [ - ([{"name": "another", "preset": "heart", "status": "completed"}], "heart"), - ({"another": {"name": "another", "preset": "heart", "status": "completed"}}, "heart"), + ([{"name": "smoke", "preset": "heart", "status": "completed"}], "heart"), + ({"smoke": {"name": "smoke", "preset": "heart", "status": "completed"}}, "heart"), ], ) def test_run_autorun_job_reads_status_for_list_and_dict_shapes(tmp_path: Path, monkeypatch, jobs_payload, expected_preset): @@ -49,10 +49,10 @@ def test_run_autorun_job_reads_status_for_list_and_dict_shapes(tmp_path: Path, m monkeypatch.setattr("mount.quantum_integration.subprocess.run", lambda *args, **kwargs: _stub_subprocess()) - result = asyncio.run(integration.run_autorun_job("another", dry_run=True)) + result = asyncio.run(integration.run_autorun_job("smoke", dry_run=True)) assert result["success"] is True - assert result["job_name"] == "another" + assert result["job_name"] == "smoke" assert result["dry_run"] is True assert result["status"]["preset"] == expected_preset - assert result["status"]["name"] == "another" + assert result["status"]["name"] == "smoke"