Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -47,4 +47,26 @@ public class AgentExecutionOptions {

public static final ConfigOption<Boolean> RAG_ASYNC =
new ConfigOption<>("rag.async", Boolean.class, true);

/** Set to a positive value in milliseconds to enable short-term memory TTL; 0 disables it. */
public static final ConfigOption<Long> SHORT_TERM_MEMORY_STATE_TTL_MS =
new ConfigOption<>("short-term-memory.state-ttl.ms", Long.class, 0L);
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: 0L doubles as the "TTL disabled" sentinel, but that contract only lives in OperatorStateManager.maybeEnableShortTermMemoryTTL. A one-line javadoc here — e.g. "Set to a positive value in milliseconds to enable TTL; 0 (the default) disables it" — would spare future readers a trip into the runtime. Same applies to the two enum options below: worth noting they're only consulted when TTL_MS > 0.


/** Update policy for short-term memory TTL, consulted only when TTL is enabled. */
public static final ConfigOption<ShortTermMemoryTtlUpdate>
SHORT_TERM_MEMORY_STATE_TTL_UPDATE_TYPE =
new ConfigOption<>(
"short-term-memory.state-ttl.update-type",
ShortTermMemoryTtlUpdate.class,
ShortTermMemoryTtlUpdate.ON_READ_AND_WRITE);

/**
* Visibility policy for expired short-term memory state, consulted only when TTL is enabled.
*/
public static final ConfigOption<ShortTermMemoryTtlVisibility>
SHORT_TERM_MEMORY_STATE_TTL_VISIBILITY =
new ConfigOption<>(
"short-term-memory.state-ttl.visibility",
ShortTermMemoryTtlVisibility.class,
ShortTermMemoryTtlVisibility.NEVER_RETURN_EXPIRED);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.agents.api.agents;

/** Defines when short-term memory state TTL is refreshed. */
public enum ShortTermMemoryTtlUpdate {
ON_CREATE_AND_WRITE,
ON_READ_AND_WRITE
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.agents.api.agents;

/** Defines whether expired short-term memory state can be returned before cleanup. */
public enum ShortTermMemoryTtlVisibility {
NEVER_RETURN_EXPIRED,
RETURN_EXPIRED_IF_NOT_CLEANED_UP
}
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
* public static ResourceDesc openAIResponses() {
* return ResourceDescriptor.Builder.newBuilder(OpenAIResponsesModelConnection.class.getName())
* .addInitialArgument("api_key", System.getenv("OPENAI_API_KEY"))
* .addInitialArgument("api_base_url", System.getenv("OPENAI_API_URL"))
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change is unrelated to TTL. It should be placed in a separate commit. Usually, in such cases, we add [hotfix] commits for fixing existing issues before commits of the actual PR changes.

* .addInitialArgument("timeout", 120)
* .addInitialArgument("max_retries", 3)
* .build();
Expand Down
37 changes: 37 additions & 0 deletions python/flink_agents/api/core_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,20 @@ class ErrorHandlingStrategy(Enum):
IGNORE = "ignore"


class ShortTermMemoryTtlUpdate(Enum):
"""Update policy for short-term memory TTL."""

ON_CREATE_AND_WRITE = "ON_CREATE_AND_WRITE"
ON_READ_AND_WRITE = "ON_READ_AND_WRITE"


class ShortTermMemoryTtlVisibility(Enum):
"""Visibility policy for expired short-term memory state."""

NEVER_RETURN_EXPIRED = "NEVER_RETURN_EXPIRED"
RETURN_EXPIRED_IF_NOT_CLEANED_UP = "RETURN_EXPIRED_IF_NOT_CLEANED_UP"


class LoggerType(Enum):
"""Built-in event logger types.

Expand Down Expand Up @@ -179,3 +193,26 @@ class AgentExecutionOptions:
config_type=bool,
default=True,
)

# Set to a positive value in milliseconds to enable short-term memory TTL;
# 0 disables it.
SHORT_TERM_MEMORY_STATE_TTL_MS = ConfigOption(
key="short-term-memory.state-ttl.ms",
config_type=int,
default=0,
)

# Update policy for short-term memory TTL, consulted only when TTL is enabled.
SHORT_TERM_MEMORY_STATE_TTL_UPDATE_TYPE = ConfigOption(
key="short-term-memory.state-ttl.update-type",
config_type=ShortTermMemoryTtlUpdate,
default=ShortTermMemoryTtlUpdate.ON_READ_AND_WRITE,
)

# Visibility policy for expired short-term memory state, consulted only when TTL
# is enabled.
SHORT_TERM_MEMORY_STATE_TTL_VISIBILITY = ConfigOption(
key="short-term-memory.state-ttl.visibility",
config_type=ShortTermMemoryTtlVisibility,
default=ShortTermMemoryTtlVisibility.NEVER_RETURN_EXPIRED,
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#################################################################################
import os
import sysconfig
import time
from pathlib import Path
from typing import Any

from pydantic import BaseModel
from pyflink.common import Configuration
from pyflink.datastream import KeySelector, StreamExecutionEnvironment

from flink_agents.api.agents.agent import Agent
from flink_agents.api.core_options import (
AgentExecutionOptions,
ShortTermMemoryTtlUpdate,
ShortTermMemoryTtlVisibility,
)
from flink_agents.api.decorators import action
from flink_agents.api.events.event import Event, InputEvent, OutputEvent
from flink_agents.api.execution_environment import AgentsExecutionEnvironment
from flink_agents.api.runner_context import RunnerContext

current_dir = Path(__file__).parent
os.environ["PYTHONPATH"] = (
f"{current_dir.parent.parent.parent}:{sysconfig.get_paths()['purelib']}"
)


class TtlTestInput(BaseModel):
event_key: str
sleep_ms: int


class TtlTestKeySelector(KeySelector):
def get_key(self, value: TtlTestInput) -> str:
return "test_key"


class ShortTermMemoryTtlTestAgent(Agent):
@action(InputEvent.EVENT_TYPE)
@staticmethod
def input(event: Event, ctx: RunnerContext) -> None:
input_data = TtlTestInput.model_validate(InputEvent.from_event(event).input)

short_term_memory = ctx.short_term_memory
existing_value = short_term_memory.get(input_data.event_key)
current_count = 0
if isinstance(existing_value, int):
current_count = existing_value
elif isinstance(existing_value, float):
current_count = int(existing_value)

short_term_memory.set(input_data.event_key, current_count + 1)
time.sleep(input_data.sleep_ms / 1000)
ctx.send_event(
OutputEvent(
output=(
f"{input_data.event_key}|"
f"{'NEW' if existing_value is None else 'EXISTING'}"
)
)
)


def run_scenario(
ttl_ms: int,
sleep_ms: int,
*,
configure_ttl_ms: bool,
configure_ttl_options: bool,
) -> list[Any]:
config = Configuration()
config.set_string("python.pythonpath", os.environ["PYTHONPATH"])
env = StreamExecutionEnvironment.get_execution_environment(config)
env.set_parallelism(1)

input_stream = env.from_collection(
[
TtlTestInput(event_key="event1", sleep_ms=sleep_ms),
TtlTestInput(event_key="event2", sleep_ms=sleep_ms),
TtlTestInput(event_key="event1", sleep_ms=sleep_ms),
]
)

agents_env = AgentsExecutionEnvironment.get_execution_environment(env=env)
agents_config = agents_env.get_config()
if configure_ttl_ms:
agents_config.set(AgentExecutionOptions.SHORT_TERM_MEMORY_STATE_TTL_MS, ttl_ms)
if configure_ttl_options:
agents_config.set(
AgentExecutionOptions.SHORT_TERM_MEMORY_STATE_TTL_UPDATE_TYPE,
ShortTermMemoryTtlUpdate.ON_CREATE_AND_WRITE,
)
agents_config.set(
AgentExecutionOptions.SHORT_TERM_MEMORY_STATE_TTL_VISIBILITY,
ShortTermMemoryTtlVisibility.NEVER_RETURN_EXPIRED,
)

output_datastream = (
agents_env.from_datastream(
input=input_stream, key_selector=TtlTestKeySelector()
)
.apply(ShortTermMemoryTtlTestAgent())
.to_datastream()
)

return list(output_datastream.execute_and_collect())


def test_value_still_visible_before_ttl_expiry() -> None:
results = run_scenario(
1000,
0,
configure_ttl_ms=True,
configure_ttl_options=True,
)

assert results == ["event1|NEW", "event2|NEW", "event1|EXISTING"]


def test_ttl_configuration_disabled_with_zero_ttl() -> None:
results = run_scenario(
0,
2000,
configure_ttl_ms=True,
configure_ttl_options=True,
)

assert results == ["event1|NEW", "event2|NEW", "event1|EXISTING"]


def test_ttl_configuration_disabled_by_default() -> None:
results = run_scenario(
0,
2000,
configure_ttl_ms=False,
configure_ttl_options=True,
)

assert results == ["event1|NEW", "event2|NEW", "event1|EXISTING"]


def test_value_expires_after_ttl() -> None:
results = run_scenario(
1000,
2000,
configure_ttl_ms=True,
configure_ttl_options=True,
)

assert results == ["event1|NEW", "event2|NEW", "event1|NEW"]


def test_ttl_configuration_applied_with_default_update_type_and_visibility() -> None:
results = run_scenario(
1000,
2000,
configure_ttl_ms=True,
configure_ttl_options=False,
)

assert results == ["event1|NEW", "event2|NEW", "event1|NEW"]
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,12 @@

from pyflink.util.java_utils import add_jars_to_context_class_loader

from flink_agents.api.core_options import AgentConfigOptions
from flink_agents.api.core_options import (
AgentConfigOptions,
AgentExecutionOptions,
ShortTermMemoryTtlUpdate,
ShortTermMemoryTtlVisibility,
)

# This script is used to verify that Java-defined configuration options
# (e.g., AgentConfigOptions) are correctly exposed and accessible in the
Expand All @@ -39,3 +44,36 @@
assert AgentConfigOptions.BASE_LOG_DIR.get_key() == "baseLogDir"
assert AgentConfigOptions.BASE_LOG_DIR.get_type() is str
assert AgentConfigOptions.BASE_LOG_DIR.get_default_value() is None

assert (
AgentExecutionOptions.SHORT_TERM_MEMORY_STATE_TTL_MS.get_key()
== "short-term-memory.state-ttl.ms"
)
assert AgentExecutionOptions.SHORT_TERM_MEMORY_STATE_TTL_MS.get_type() is int
assert AgentExecutionOptions.SHORT_TERM_MEMORY_STATE_TTL_MS.get_default_value() == 0

assert (
AgentExecutionOptions.SHORT_TERM_MEMORY_STATE_TTL_UPDATE_TYPE.get_key()
== "short-term-memory.state-ttl.update-type"
)
assert (
AgentExecutionOptions.SHORT_TERM_MEMORY_STATE_TTL_UPDATE_TYPE.get_type()
is ShortTermMemoryTtlUpdate
)
assert (
AgentExecutionOptions.SHORT_TERM_MEMORY_STATE_TTL_UPDATE_TYPE.get_default_value()
is ShortTermMemoryTtlUpdate.ON_READ_AND_WRITE
)

assert (
AgentExecutionOptions.SHORT_TERM_MEMORY_STATE_TTL_VISIBILITY.get_key()
== "short-term-memory.state-ttl.visibility"
)
assert (
AgentExecutionOptions.SHORT_TERM_MEMORY_STATE_TTL_VISIBILITY.get_type()
is ShortTermMemoryTtlVisibility
)
assert (
AgentExecutionOptions.SHORT_TERM_MEMORY_STATE_TTL_VISIBILITY.get_default_value()
is ShortTermMemoryTtlVisibility.NEVER_RETURN_EXPIRED
)
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ public void setup(
public void open() throws Exception {
super.open();

stateManager.initializeKeyedStates(getRuntimeContext());
stateManager.initializeKeyedStates(getRuntimeContext(), agentPlan);
stateManager.initializeOperatorStates(getOperatorStateBackend());

// ResourceCache constructs its own long-lived ResourceContextImpl internally; on
Expand Down
Loading
Loading