Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.regex.Matcher;
Expand Down Expand Up @@ -151,13 +152,22 @@ public <ResponseType> ResponseType startQuery(
? modelOverride
: this.config.getModel();

Map<String, Object> responseFormat = Map.of(
"type", "json_schema",
"json_schema", Map.of(
"name", "structured_output",
"schema", schemaObject,
"strict", true
)
);

log.info("Starting LLM query. Model: {}", modelToUse);

OllamaRequest request = new OllamaRequest(
modelToUse,
filledPrompt,
false,
schemaObject
modelToUse,
List.of(new OllamaRequest.Message("user", filledPrompt)),
0.0,
responseFormat
);

final OllamaResponse response = queryLLM(request);
Expand Down Expand Up @@ -215,12 +225,13 @@ private OllamaResponse queryLLM(final OllamaRequest request) throws IOException,

final HttpResponse<String> response = client.send(reqBuilder.build(), HttpResponse.BodyHandlers.ofString());

log.info("RAW OLLAMA HTTP RESPONSE BODY: {}", response.body());
log.debug("RAW OLLAMA HTTP RESPONSE BODY: {}", response.body());

final OllamaResponse result = jsonMapper.readValue(response.body(), OllamaResponse.class);

if (result.getError() != null) {
throw new RuntimeException("Ollama returned error: " + result.getError());
if (response.statusCode() >= 400 || result.getErrorMessage() != null) {
throw new RuntimeException("LLM returned error: " +
(result.getErrorMessage() != null ? result.getErrorMessage() : response.body()));
}

return result;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
package de.unistuttgart.iste.meitrex.common.ollama;

import com.fasterxml.jackson.annotation.JsonProperty;

import java.util.List;
import java.util.Map;

public record OllamaRequest(
@JsonProperty("model") String model,
@JsonProperty("prompt") String prompt,
@JsonProperty("stream") boolean stream,
@JsonProperty("format") Map<String, Object> format
) {}
@JsonProperty("messages") List<Message> messages,
@JsonProperty("temperature") double temperature,
@JsonProperty("response_format") Map<String, Object> responseFormat
) {
public record Message(
@JsonProperty("role") String role,
@JsonProperty("content") String content
) {}
}

Original file line number Diff line number Diff line change
@@ -1,139 +1,34 @@
package de.unistuttgart.iste.meitrex.common.ollama;

import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.Getter;
import lombok.Setter;
import java.util.List;

@Setter
@Getter
public class OllamaResponse {
@JsonIgnoreProperties(ignoreUnknown = true)
public record OllamaResponse(
@JsonProperty("choices") List<Choice> choices,
@JsonProperty("error") OpenAiError error
) {
@JsonIgnoreProperties(ignoreUnknown = true)
public record Choice(@JsonProperty("message") Message message) {}

@JsonProperty("total_duration")
private long totalDuration;
@JsonProperty("load_duration")
private long loadDuration;
@JsonProperty("prompt_eval_count")
private long promptEvalCount;
@JsonProperty("prompt_eval_duration")
private long promptEvalDuration;
@JsonProperty("eval_count")
private long evalCount;
@JsonProperty("eval_duration")
private long evalDuration;
@JsonProperty("model")
private String model;
@JsonProperty("created_at")
private String createdAt;
@JsonProperty("response")
private String response;
@JsonProperty("done")
private boolean done;
@JsonProperty("done_reason")
private String doneReason;
@JsonProperty("context")
private long[] context;
@JsonProperty("error")
private String error;
@JsonIgnoreProperties(ignoreUnknown = true)
public record Message(@JsonProperty("content") String content) {}

/**
* @return The total duration of the request in milliseconds.
*/
@JsonIgnore
public long getTotalDuration() {
return totalDuration;
}

/**
* @return The duration of the model loading in milliseconds.
*/
@JsonIgnore
public long getLoadDuration() {
return loadDuration;
}

/**
* @return The number of prompt evaluations.
*/
@JsonIgnore
public long getPromptEvalCount() {
return promptEvalCount;
}

/**
* @return The duration of the prompt evaluation in milliseconds.
*/
@JsonIgnore
public long getPromptEvalDuration() {
return promptEvalDuration;
}

/**
* @return The number of evaluations.
*/
@JsonIgnore
public long getEvalCount() {
return evalCount;
}

/**
* @return The duration of the evaluation in milliseconds.
*/
@JsonIgnore
public long getEvalDuration() {
return evalDuration;
}

/**
* @return The model used for the request.
*/
@JsonIgnore
public String getModel() {
return model;
}
@JsonIgnoreProperties(ignoreUnknown = true)
public record OpenAiError(@JsonProperty("message") String message) {}

/**
* @return The creation time of the request in ISO 8601 format.
*/
@JsonIgnore
public String getCreatedAt() {
return createdAt;
}

/**
* @return The response from the model.
*/
@JsonIgnore
public String getResponse() {
return response;
}

/**
* @return Whether the request is done or not.
*/
@JsonIgnore
public boolean isDone() {
return done;
}

/**
* @return The reason why the request is done.
*/
@JsonIgnore
public String getDoneReason() {
return doneReason;
}

/**
* @return The context of the request.
*/
@JsonIgnore
public long[] getContext() {
return context;
if (choices != null && !choices.isEmpty() && choices.get(0).message() != null) {
return choices.get(0).message().content();
}
return null;
}

@JsonIgnore
public String getError() {
return error;
public String getErrorMessage() {
return error != null ? error.message() : null;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,14 @@ void testStartQuerySuccess() throws Exception {
String ollamaJsonResponse = """
{
"model": "mixtral:8x22b",
"response": "{\\"result\\": 2}",
"done": true,
"done_reason": "stop"
"choices": [
{
"message": {
"role": "assistant",
"content": "{\\"result\\": 2}"
}
}
]
}
""";

Expand Down Expand Up @@ -159,7 +164,14 @@ void testStartQueryHandlesOllamaError() throws Exception {

when(jsonSchemaService.getJsonSchema(any())).thenReturn("{\"properties\":{}}");

String errorJson = "{\"error\": \"Authentication Error\"}";
String errorJson = """
{
"error": {
"message": "Authentication Error",
"type": "invalid_request_error"
}
}
""";

@SuppressWarnings("unchecked")
HttpResponse<String> mockHttpResponse = mock(HttpResponse.class);
Expand Down Expand Up @@ -194,9 +206,14 @@ void testStartQueryHandlesInvalidContentJson() throws Exception {
String brokenContentJsonResponse = """
{
"model": "mixtral:8x22b",
"response": "This is not valid JSON text",
"done": true,
"done_reason": "stop"
"choices": [
{
"message": {
"role": "assistant",
"content": "This is not valid JSON text"
}
}
]
}
""";

Expand Down
Loading