Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -108,10 +108,13 @@ public void open() throws Exception {
public abstract Map<String, Object> getParameters();

public ChatMessage chat(List<ChatMessage> messages) {
return this.chat(messages, Collections.emptyMap());
return this.chat(messages, Collections.emptyMap(), Collections.emptyMap());
}

public ChatMessage chat(List<ChatMessage> messages, Map<String, Object> parameters) {
public ChatMessage chat(
List<ChatMessage> messages,
Map<String, Object> promptArgs,
Map<String, Object> modelParams) {
Preconditions.checkNotNull(
connection,
"Connection is not initialized. Ensure open() is called before chat().");
Expand All @@ -124,15 +127,17 @@ public ChatMessage chat(List<ChatMessage> messages, Map<String, Object> paramete
prompt instanceof Prompt,
"Prompt is not initialized. Ensure open() is called before chat().");
Prompt prompt = (Prompt) this.prompt;
Map<String, String> arguments = new HashMap<>();
for (ChatMessage message : messages) {
for (Map.Entry<String, Object> entry : message.getExtraArgs().entrySet()) {
arguments.put(entry.getKey(), entry.getValue().toString());
Map<String, String> stringified = new HashMap<>();
if (promptArgs != null) {
for (Map.Entry<String, Object> entry : promptArgs.entrySet()) {
stringified.put(
entry.getKey(),
entry.getValue() != null ? entry.getValue().toString() : "");
}
}

// append meaningful messages
List<ChatMessage> promptMessages = prompt.formatMessages(MessageRole.USER, arguments);
List<ChatMessage> promptMessages = prompt.formatMessages(MessageRole.USER, stringified);
for (ChatMessage message : messages) {
if ((message.getContent() != null && !message.getContent().isEmpty())
|| message.getRole() == MessageRole.ASSISTANT) {
Expand All @@ -150,7 +155,9 @@ public ChatMessage chat(List<ChatMessage> messages, Map<String, Object> paramete
}

Map<String, Object> params = this.getParameters();
params.putAll(parameters);
if (modelParams != null) {
params.putAll(modelParams);
}
return connection.chat(messages, tools, params);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,19 +62,23 @@ public void open() {
}

@Override
public ChatMessage chat(List<ChatMessage> messages, Map<String, Object> parameters) {
public ChatMessage chat(
List<ChatMessage> messages,
Map<String, Object> promptArgs,
Map<String, Object> modelParams) {
checkState(
chatModelSetup != null,
"ChatModelSetup is not initialized. Cannot perform chat operation.");

Map<String, Object> kwargs = new HashMap<>(parameters);
Map<String, Object> kwargs = new HashMap<>(modelParams);

List<Object> pythonMessages = new ArrayList<>();
for (ChatMessage message : messages) {
pythonMessages.add(adapter.toPythonChatMessage(message));
}

kwargs.put("messages", pythonMessages);
kwargs.put("prompt_args", promptArgs != null ? promptArgs : Collections.emptyMap());

Object pythonMessageResponse = adapter.callMethod(chatModelSetup, "chat", kwargs);
return adapter.fromPythonChatMessage(pythonMessageResponse);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import javax.annotation.Nullable;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand All @@ -39,17 +40,26 @@ public class ChatRequestEvent extends Event {
private static final ObjectMapper MAPPER = new ObjectMapper();

public ChatRequestEvent(
String model, List<ChatMessage> messages, @Nullable Object outputSchema) {
String model,
List<ChatMessage> messages,
@Nullable Map<String, Object> promptArgs,
@Nullable Object outputSchema) {
super(EVENT_TYPE);
setAttr("model", model);
setAttr("messages", new ArrayList<>(messages));
setAttr("prompt_args", promptArgs != null ? promptArgs : Collections.emptyMap());
if (outputSchema != null) {
setAttr("output_schema", outputSchema);
}
}

/**
 * Creates a chat request that carries no prompt arguments.
 *
 * <p>Delegates to the four-argument constructor, passing {@code null} for the prompt arguments.
 *
 * @param model value forwarded as the {@code model} argument
 * @param messages messages forwarded as the {@code messages} argument
 * @param outputSchema optional output schema; may be {@code null}
 */
public ChatRequestEvent(
        String model, List<ChatMessage> messages, @Nullable Object outputSchema) {
    // The typed null makes it explicit which slot (prompt arguments) is left empty.
    this(model, messages, (Map<String, Object>) null, outputSchema);
}

public ChatRequestEvent(String model, List<ChatMessage> messages) {
this(model, messages, null);
this(model, messages, null, null);
}

public ChatRequestEvent(UUID id, Map<String, Object> attributes) {
Expand Down Expand Up @@ -100,4 +110,11 @@ public List<ChatMessage> getMessages() {
public Object getOutputSchema() {
return getAttr("output_schema");
}

/**
 * Returns the value stored under the {@code prompt_args} attribute.
 *
 * @return the stored map, or an empty map when the attribute is missing or {@code null}
 */
@JsonIgnore
@SuppressWarnings("unchecked")
public Map<String, Object> getPromptArgs() {
    Object stored = getAttr("prompt_args");
    if (stored == null) {
        return Collections.emptyMap();
    }
    // Unchecked: the attribute is held untyped, so the element types cannot be verified here.
    return (Map<String, Object>) stored;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,12 @@
import org.apache.flink.agents.api.resource.ResourceContext;
import org.apache.flink.agents.api.resource.ResourceDescriptor;
import org.apache.flink.agents.api.resource.ResourceType;
import org.apache.flink.agents.api.tools.Tool;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
Expand Down Expand Up @@ -60,7 +62,10 @@ public Map<String, Object> getParameters() {
}

@Override
public ChatMessage chat(List<ChatMessage> messages, Map<String, Object> parameters) {
public ChatMessage chat(
List<ChatMessage> messages,
Map<String, Object> promptArgs,
Map<String, Object> modelParams) {
// Simple test implementation that echoes the last user message

String lastUserContent = "";
Expand Down Expand Up @@ -229,6 +234,92 @@ void testChatResponseFormat() {
assertTrue(response.getContent().length() > 0);
}

/**
 * Test double for {@link BaseChatModelConnection}: keeps a copy of the message list from the
 * most recent {@code chat} call and always replies with a fixed assistant message.
 */
private static class RecordingConnection extends BaseChatModelConnection {
    /** Copy of the messages received by the latest {@code chat} call; unset before that. */
    List<ChatMessage> capturedMessages;

    RecordingConnection() {
        super(selfDescriptor(), null);
    }

    /** Builds the descriptor passed to the base class: this class's name and no arguments. */
    private static ResourceDescriptor selfDescriptor() {
        return new ResourceDescriptor(
                RecordingConnection.class.getName(), Collections.emptyMap());
    }

    @Override
    public ChatMessage chat(
            List<ChatMessage> messages, List<Tool> tools, Map<String, Object> arguments) {
        // Copy the list so the recording is independent of the caller's list instance.
        List<ChatMessage> snapshot = new ArrayList<>(messages);
        this.capturedMessages = snapshot;
        return new ChatMessage(MessageRole.ASSISTANT, "ok");
    }
}

/**
 * Test-only {@link BaseChatModelSetup} that receives its connection and prompt through the
 * constructor and assigns them directly to the inherited fields.
 */
private static class RecordingChatModelSetup extends BaseChatModelSetup {
    RecordingChatModelSetup(BaseChatModelConnection connection, Prompt prompt) {
        super(selfDescriptor(), null);
        this.prompt = prompt;
        this.connection = connection;
    }

    /** Builds the descriptor passed to the base class: this class's name and no arguments. */
    private static ResourceDescriptor selfDescriptor() {
        return new ResourceDescriptor(
                RecordingChatModelSetup.class.getName(), Collections.emptyMap());
    }

    @Override
    public Map<String, Object> getParameters() {
        // A new, empty, mutable map on every call.
        Map<String, Object> noDefaults = new HashMap<>();
        return noDefaults;
    }
}

@Test
@DisplayName("chat() fills prompt template from promptArgs parameter")
void testChatFillsTemplateFromPromptArgsParameter() {
    // One-placeholder template; no chat messages are passed, only prompt arguments.
    Prompt template = Prompt.fromText("Task: {key}");
    RecordingConnection recorder = new RecordingConnection();
    RecordingChatModelSetup chatModel = new RecordingChatModelSetup(recorder, template);

    Map<String, Object> promptArgs = Map.of("key", "value");
    chatModel.chat(Collections.emptyList(), promptArgs, Map.of());

    // Exactly one message is expected to reach the connection: the filled-in template.
    List<ChatMessage> sent = recorder.capturedMessages;
    assertNotNull(sent);
    assertEquals(1, sent.size());
    assertEquals("Task: value", sent.get(0).getContent());
}

@Test
@DisplayName("chat() does not read template vars from ChatMessage.extraArgs")
void testChatDoesNotReadTemplateVarsFromExtraArgs() {
    Prompt template = Prompt.fromText("Task: {key}");
    RecordingConnection recorder = new RecordingConnection();
    RecordingChatModelSetup chatModel = new RecordingChatModelSetup(recorder, template);

    // The placeholder name appears only in the message's extra args; promptArgs stays empty.
    ChatMessage greeting = new ChatMessage(MessageRole.USER, "hello", Map.of("key", "value"));
    chatModel.chat(List.of(greeting), Map.of(), Map.of());

    List<ChatMessage> sent = recorder.capturedMessages;
    assertNotNull(sent);
    assertEquals(2, sent.size());
    // The template message comes first with its placeholder untouched, then the user message.
    assertEquals("Task: {key}", sent.get(0).getContent());
    assertEquals("hello", sent.get(1).getContent());
}

@Test
@DisplayName("chat() re-fills prompt template on subsequent invocations when args supplied")
void testChatRefillsTemplateOnSubsequentInvocations() {
    Prompt template = Prompt.fromText("Task: {key}");
    RecordingConnection recorder = new RecordingConnection();
    RecordingChatModelSetup chatModel = new RecordingChatModelSetup(recorder, template);

    // First invocation: prompt arguments only, no chat messages.
    chatModel.chat(Collections.emptyList(), Map.of("key", "v1"), Map.of());
    List<ChatMessage> firstCall = recorder.capturedMessages;
    assertNotNull(firstCall);
    assertEquals(1, firstCall.size());
    assertEquals("Task: v1", firstCall.get(0).getContent());

    // Second invocation: a tool message plus the same prompt arguments supplied again.
    ChatMessage toolReply = new ChatMessage(MessageRole.TOOL, "tool result");
    chatModel.chat(List.of(toolReply), Map.of("key", "v1"), Map.of());
    List<ChatMessage> secondCall = recorder.capturedMessages;
    assertEquals(2, secondCall.size());
    assertEquals("Task: v1", secondCall.get(0).getContent());
    assertEquals("tool result", secondCall.get(1).getContent());
}

@Test
@DisplayName("Test chat with long input")
void testChatWithLongInput() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,11 @@ void testChat() {
ChatMessage inputMessage = mock(ChatMessage.class);
ChatMessage outputMessage = mock(ChatMessage.class);
List<ChatMessage> messages = Collections.singletonList(inputMessage);
Map<String, Object> parameters = new HashMap<>();
parameters.put("temperature", 0.7);
parameters.put("max_tokens", 100);
Map<String, Object> promptArgs = new HashMap<>();
promptArgs.put("input", "value");
Map<String, Object> modelParams = new HashMap<>();
modelParams.put("temperature", 0.7);
modelParams.put("max_tokens", 100);

Object pythonInputMessage = new Object();
Object pythonOutputMessage = new Object();
Expand All @@ -105,7 +107,7 @@ void testChat() {
.thenReturn(pythonOutputMessage);
when(mockAdapter.fromPythonChatMessage(pythonOutputMessage)).thenReturn(outputMessage);

ChatMessage result = pythonChatModelSetup.chat(messages, parameters);
ChatMessage result = pythonChatModelSetup.chat(messages, promptArgs, modelParams);

assertThat(result).isEqualTo(outputMessage);

Expand All @@ -117,8 +119,10 @@ void testChat() {
argThat(
kwargs -> {
assertThat(kwargs).containsKey("messages");
assertThat(kwargs).containsKey("prompt_args");
assertThat(kwargs).containsKey("temperature");
assertThat(kwargs).containsKey("max_tokens");
assertThat(kwargs.get("prompt_args")).isEqualTo(promptArgs);
assertThat(kwargs.get("temperature")).isEqualTo(0.7);
assertThat(kwargs.get("max_tokens")).isEqualTo(100);
List<?> pythonMessages = (List<?>) kwargs.get("messages");
Expand All @@ -136,9 +140,10 @@ void testChatWithNullChatModelSetupThrowsException() {

ChatMessage inputMessage = mock(ChatMessage.class);
List<ChatMessage> messages = Collections.singletonList(inputMessage);
Map<String, Object> parameters = new HashMap<>();
Map<String, Object> promptArgs = new HashMap<>();
Map<String, Object> modelParams = new HashMap<>();

assertThatThrownBy(() -> setupWithNullModel.chat(messages, parameters))
assertThatThrownBy(() -> setupWithNullModel.chat(messages, promptArgs, modelParams))
.isInstanceOf(IllegalStateException.class)
.hasMessageContaining("ChatModelSetup is not initialized")
.hasMessageContaining("Cannot perform chat operation");
Expand Down
18 changes: 13 additions & 5 deletions docs/content/docs/development/prompts.md
Original file line number Diff line number Diff line change
Expand Up @@ -255,8 +255,14 @@ class ReviewAnalysisAgent(Agent):
"id": {input.id},
"review": {input.review}
"""
msg = ChatMessage(role=MessageRole.USER, extra_args={"input": content})
ctx.send_event(ChatRequestEvent(model="review_analysis_model", messages=[msg]))
msg = ChatMessage(role=MessageRole.USER)
ctx.send_event(
ChatRequestEvent(
model="review_analysis_model",
messages=[msg],
prompt_args={"input": content},
)
)
```
{{< /tab >}}

Expand Down Expand Up @@ -316,9 +322,11 @@ public class ReviewAnalysisAgent extends Agent {
String.format(
"{\n" + "\"id\": %s,\n" + "\"review\": \"%s\"\n" + "}",
inputObj.getId(), inputObj.getReview());
ChatMessage msg = new ChatMessage(MessageRole.USER, "", Map.of("input", content));
ChatMessage msg = new ChatMessage(MessageRole.USER, "");

ctx.sendEvent(new ChatRequestEvent("reviewAnalysisModel", List.of(msg)));
ctx.sendEvent(
new ChatRequestEvent(
"reviewAnalysisModel", List.of(msg), Map.of("input", content), null));
}
}

Expand All @@ -327,4 +335,4 @@ public class ReviewAnalysisAgent extends Agent {

{{< /tabs >}}

Prompts use `{variable_name}` syntax for template variables. Variables are filled from `ChatMessage.extra_args`. The prompt is automatically applied when the chat model is invoked.
Prompts use `{variable_name}` syntax for template variables. Variables are filled from the prompt arguments supplied on the `ChatRequestEvent`: the `prompt_args` argument in Python, or the `promptArgs` constructor argument in Java. The prompt is automatically applied when the chat model is invoked.
16 changes: 12 additions & 4 deletions docs/content/docs/development/workflow_agent.md
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,14 @@ class ReviewAnalysisAgent(Agent):
"id": {input.id},
"review": {input.review}
"""
msg = ChatMessage(role=MessageRole.USER, extra_args={"input": content})
ctx.send_event(ChatRequestEvent(model="review_analysis_model", messages=[msg]))
msg = ChatMessage(role=MessageRole.USER)
ctx.send_event(
ChatRequestEvent(
model="review_analysis_model",
messages=[msg],
prompt_args={"input": content},
)
)

@action(ChatResponseEvent.EVENT_TYPE)
@staticmethod
Expand Down Expand Up @@ -184,9 +190,11 @@ public class ReviewAnalysisAgent extends Agent {
String.format(
"{\n" + "\"id\": %s,\n" + "\"review\": \"%s\"\n" + "}",
inputObj.getId(), inputObj.getReview());
ChatMessage msg = new ChatMessage(MessageRole.USER, "", Map.of("input", content));
ChatMessage msg = new ChatMessage(MessageRole.USER, "");

ctx.sendEvent(new ChatRequestEvent("reviewAnalysisModel", List.of(msg)));
ctx.sendEvent(
new ChatRequestEvent(
"reviewAnalysisModel", List.of(msg), Map.of("input", content), null));
}

@Action(listenEventTypes = {ChatResponseEvent.EVENT_TYPE})
Expand Down
16 changes: 12 additions & 4 deletions docs/content/docs/get-started/quickstart/workflow_agent.md
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,14 @@ class ReviewAnalysisAgent(Agent):
"id": {input.id},
"review": {input.review}
"""
msg = ChatMessage(role=MessageRole.USER, extra_args={"input": content})
ctx.send_event(ChatRequestEvent(model="review_analysis_model", messages=[msg]))
msg = ChatMessage(role=MessageRole.USER)
ctx.send_event(
ChatRequestEvent(
model="review_analysis_model",
messages=[msg],
prompt_args={"input": content},
)
)

@action(ChatResponseEvent.EVENT_TYPE)
@staticmethod
Expand Down Expand Up @@ -227,9 +233,11 @@ public class ReviewAnalysisAgent extends Agent {
String.format(
"{\n" + "\"id\": %s,\n" + "\"review\": \"%s\"\n" + "}",
inputObj.getId(), inputObj.getReview());
ChatMessage msg = new ChatMessage(MessageRole.USER, "", Map.of("input", content));
ChatMessage msg = new ChatMessage(MessageRole.USER, "");

ctx.sendEvent(new ChatRequestEvent("reviewAnalysisModel", List.of(msg)));
ctx.sendEvent(
new ChatRequestEvent(
"reviewAnalysisModel", List.of(msg), Map.of("input", content), null));
}

@Action(listenEventTypes = {ChatResponseEvent.EVENT_TYPE})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
"extra_args": {}
}
],
"prompt_args": {},
"output_schema": null
}
}
Loading
Loading