Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions src/cell_annotator/_response_formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,13 @@ class BaseOutput(BaseModel):

@classmethod
def default_failure(cls: type["BaseOutput"], failure_reason: str = "Manual fallback due to model failure."):
"""Return a default output in case of failure, with a custom failure reason."""
return cls(reason_for_failure=failure_reason)
"""Return a default output in case of failure, with a custom failure reason.

Uses ``model_construct`` so it cannot mask the upstream error with a
Pydantic ``ValidationError`` if a subclass declares a required field
without a default. Defaulted fields are still populated.
"""
return cls.model_construct(reason_for_failure=failure_reason)


class CellTypeColor(BaseOutput):
Expand Down
92 changes: 55 additions & 37 deletions src/cell_annotator/model/_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,28 @@ def list_available_models(self) -> list[str]:
def _list_models_impl(self) -> list[str]:
"""Provider-specific implementation for listing models."""

def _fail(
    self,
    response_format: type[BaseOutput],
    failure_reason: str,
    *,
    exc: BaseException | None = None,
) -> BaseOutput:
    """
    Build a failure response with consistent logging across providers.

    Call this at every provider error boundary instead of
    ``response_format.default_failure(...)`` directly.

    Parameters
    ----------
    response_format
        Output class whose ``default_failure`` classmethod builds the
        returned object.
    failure_reason
        Human-readable explanation. It is logged (prefixed with the
        provider class name) and forwarded to ``default_failure``.
    exc
        Exception that caused the failure, if any. When given, the message
        is logged at ERROR level together with that exception's traceback
        so CI logs contain enough context to diagnose upstream API errors.
        When omitted, the message is logged at WARNING level.

    Returns
    -------
    The result of ``response_format.default_failure(failure_reason=...)``.
    """
    provider_name = self.__class__.__name__
    if exc is not None:
        # Hand the exception instance to the logger instead of ``True``:
        # ``exc_info=True`` reads ``sys.exc_info()``, which is empty when the
        # caller is not inside an active ``except`` block and may refer to a
        # different exception than the one passed in.
        logger.error("[%s] %s", provider_name, failure_reason, exc_info=exc)
    else:
        logger.warning("[%s] %s", provider_name, failure_reason)
    return response_format.default_failure(failure_reason=failure_reason)


class OpenAIProvider(LLMProvider):
"""OpenAI provider implementation."""
Expand Down Expand Up @@ -184,17 +206,15 @@ def query(
if response.parsed:
return response.parsed
elif response.refusal:
failure_reason = "Model refused to respond: %s"
logger.warning(failure_reason, response.refusal)
return response_format.default_failure(failure_reason=failure_reason % response.refusal)
return self._fail(response_format, f"Model refused to respond: {response.refusal}")
else:
failure_reason = "Unknown model failure."
logger.warning(failure_reason)
return response_format.default_failure(failure_reason=failure_reason)
except openai.LengthFinishReasonError:
failure_reason = "Maximum number of tokens exceeded. Try increasing `max_completion_tokens`."
logger.warning(failure_reason)
return response_format.default_failure(failure_reason=failure_reason)
return self._fail(response_format, "Unknown model failure.")
except openai.LengthFinishReasonError as e:
return self._fail(
response_format,
"Maximum number of tokens exceeded. Try increasing `max_completion_tokens`.",
exc=e,
)
except openai.OpenAIError as e:
logger.warning(
"Structured parse failed for model '%s'. Falling back to JSON-mode query. Error: %s", model, str(e)
Expand Down Expand Up @@ -297,10 +317,9 @@ def _query_with_json_fallback(
raw_content = completion.choices[0].message.content
text = self._coerce_text_content(raw_content)
if not text:
return response_format.default_failure(
failure_reason=(
f"Model returned empty content during JSON fallback. Original parse error: {fallback_error}"
)
return self._fail(
response_format,
f"Model returned empty content during JSON fallback. Original parse error: {fallback_error}",
)

# Strict JSON parsing first.
Expand Down Expand Up @@ -341,19 +360,17 @@ def _query_with_json_fallback(
except Exception: # noqa: BLE001
pass

return response_format.default_failure(
failure_reason=(
"Could not parse structured JSON response from model output. "
f"Original parse error: {fallback_error}"
)
return self._fail(
response_format,
f"Could not parse structured JSON response from model output. Original parse error: {fallback_error}",
)
except Exception as fallback_exception: # noqa: BLE001
return response_format.default_failure(
failure_reason=(
"Fallback JSON query failed. "
f"Original parse error: {fallback_error}. "
f"Fallback error: {str(fallback_exception)}"
)
return self._fail(
response_format,
"Fallback JSON query failed. "
f"Original parse error: {fallback_error}. "
f"Fallback error: {str(fallback_exception)}",
exc=fallback_exception,
)

def _coerce_text_content(self, content) -> str:
Expand Down Expand Up @@ -564,13 +581,10 @@ def query(
if hasattr(response, "parsed") and response.parsed:
return response.parsed
else:
# Fallback if parsing fails
failure_reason = "Gemini failed to parse structured response"
return response_format.default_failure(failure_reason=failure_reason)
return self._fail(response_format, "Gemini failed to parse structured response")

except (ValueError, TypeError, KeyError) as e:
failure_reason = f"Gemini API error: {str(e)}"
return response_format.default_failure(failure_reason=failure_reason)
return self._fail(response_format, f"Gemini API error: {str(e)}", exc=e)


class AnthropicProvider(LLMProvider):
Expand Down Expand Up @@ -681,16 +695,20 @@ def query(
if isinstance(input_data, dict):
return response_format(**input_data)

# Fallback if no tool use found
failure_reason = "No structured response found in tool use output"
return response_format.default_failure(failure_reason=failure_reason)
# No tool_use block — Claude either refused or returned plain text.
# Capture stop_reason and content-block types so CI logs can tell the difference.
stop_reason = getattr(response, "stop_reason", "<unknown>")
block_types = [getattr(b, "type", "<unknown>") for b in (response.content or [])]
return self._fail(
response_format,
f"No structured response found in tool use output "
f"(stop_reason={stop_reason!r}, content_block_types={block_types})",
)

except anthropic.AnthropicError as e:
failure_reason = f"Anthropic API error: {str(e)}"
return response_format.default_failure(failure_reason=failure_reason)
return self._fail(response_format, f"Anthropic API error: {str(e)}", exc=e)
except (ValueError, TypeError, KeyError) as e:
failure_reason = f"Response parsing error: {str(e)}"
return response_format.default_failure(failure_reason=failure_reason)
return self._fail(response_format, f"Response parsing error: {str(e)}", exc=e)


# Provider registry - initialize lazily to avoid import errors
Expand Down
Loading