Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 84 additions & 0 deletions js/llm.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import {
buildClassificationTools,
LLMClassifierFromTemplate,
OpenAIClassifier,
templateUsesThreadVariables,
} from "../js/llm";
import {
openaiClassifierShouldEvaluateArithmeticExpressions,
Expand Down Expand Up @@ -64,6 +65,15 @@ afterAll(() => {
});

describe("LLM Tests", () => {
test("templateUsesThreadVariables recognizes thread_with_system", () => {
expect(templateUsesThreadVariables("{{thread_with_system}}")).toBe(true);
expect(
templateUsesThreadVariables(
"Full thread: {{thread_with_system.0.content}}",
),
).toBe(true);
});

test("openai classifier should evaluate titles", async () => {
let callCount = -1;
server.use(
Expand Down Expand Up @@ -342,6 +352,80 @@ Issue Description: {{page_content}}
expect(capturedRequestBody.reasoning_effort).toBeUndefined();
});

test("LLMClassifierFromTemplate keeps thread filtered while exposing thread_with_system", async () => {
let capturedRequestBody: unknown;
const systemMarker = "TRACE_SYSTEM_MESSAGE";

server.use(
http.post("https://api.openai.com/v1/responses", async ({ request }) => {
capturedRequestBody = await request.json();

return HttpResponse.json({
id: "resp-test",
object: "response",
created: 1234567890,
model: "gpt-5-mini",
output: [
{
type: "function_call",
call_id: "call_test",
name: "select_choice",
arguments: JSON.stringify({ choice: "1" }),
},
],
});
}),
);

const classifier = LLMClassifierFromTemplate({
name: "thread-template",
promptTemplate:
"Filtered thread:\n{{thread}}\n\nFull thread:\n{{thread_with_system}}",
choiceScores: { "1": 1, "2": 0 },
useCoT: false,
});

await classifier({
output: "",
expected: "",
trace: {
async getThread() {
return [
{ role: "system", content: systemMarker },
{ role: "user", content: "Hello" },
{ role: "assistant", content: "Hi there" },
];
},
},
});

if (
!capturedRequestBody ||
typeof capturedRequestBody !== "object" ||
!("input" in capturedRequestBody) ||
!Array.isArray(capturedRequestBody.input)
) {
throw new Error("Unexpected request body shape");
}

const firstInput = capturedRequestBody.input[0];
if (
!firstInput ||
typeof firstInput !== "object" ||
!("content" in firstInput) ||
typeof firstInput.content !== "string"
) {
throw new Error("Unexpected request input shape");
}

const [filteredThread, fullThread] =
firstInput.content.split("\n\nFull thread:\n");
expect(filteredThread).toContain("Hello");
expect(filteredThread).toContain("Hi there");
expect(filteredThread).not.toContain(systemMarker);
expect(fullThread).toContain(systemMarker);
});

test("useResponsesApi forces the Responses API for a non-gpt-5 model", async () => {
let responsesHit = false;
let chatCompletionsHit = false;
Expand Down
3 changes: 2 additions & 1 deletion js/llm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ export interface TraceForScorer {
// Thread-related template variable names that require preprocessor invocation
export const THREAD_VARIABLE_NAMES = [
"thread",
"thread_with_system",
"thread_count",
"first_message",
"last_message",
Expand Down Expand Up @@ -335,7 +336,7 @@ export function LLMClassifierFromTemplate<RenderArgs>({
if (runtimeArgs.trace && templateUsesThreadVariables(promptTemplate)) {
const thread = await runtimeArgs.trace.getThread();
const scorerThread = filterSystemMessagesFromThread(thread);
const computed = computeThreadTemplateVars(scorerThread);
const computed = computeThreadTemplateVars(scorerThread, thread);
// Build threadVars from THREAD_VARIABLE_NAMES to keep in sync with the pattern
for (const name of THREAD_VARIABLE_NAMES) {
threadVars[name] = computed[name as keyof ThreadTemplateVars];
Expand Down
39 changes: 39 additions & 0 deletions js/render-messages.test.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { describe, expect, it } from "vitest";
import { renderMessages } from "./render-messages";
import { ChatCompletionMessageParam } from "openai/resources";
import { computeThreadTemplateVars } from "./thread-utils";

describe("renderMessages", () => {
it("should never HTML-escape values, regardless of mustache syntax", () => {
Expand Down Expand Up @@ -182,4 +183,42 @@ describe("renderMessages with thread variables", () => {
expect(rendered[0].content).toContain("Assistant:");
expect(rendered[0].content).toContain("Simple response");
});

it("computeThreadTemplateVars can expose thread_with_system separately", () => {
const fullThread = [
{ role: "system", content: "You are a helpful assistant." },
...sampleThread,
];

const renderedVars = computeThreadTemplateVars(sampleThread, fullThread);

expect(renderedVars.thread).toEqual(sampleThread);
expect(renderedVars.thread_with_system).toEqual(fullThread);
expect(renderedVars.thread_count).toBe(sampleThread.length);
expect(renderedVars.first_message).toEqual(sampleThread[0]);
});

it("{{thread_with_system}} renders full conversation and supports indexing", () => {
const fullThread = [
{ role: "system", content: "You are a helpful assistant." },
...sampleThread,
];
const messages: ChatCompletionMessageParam[] = [
{
role: "user",
content:
"Full thread: {{thread_with_system}}\n\nFirst full: {{thread_with_system.0}}",
},
];
const rendered = renderMessages(
messages,
computeThreadTemplateVars(sampleThread, fullThread),
);

expect(rendered[0].content).toContain("System:");
expect(rendered[0].content).toContain("You are a helpful assistant.");
expect(rendered[0].content).toContain(
"First full: system: You are a helpful assistant.",
);
});
});
3 changes: 3 additions & 0 deletions js/thread-utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@ export function formatMessageArrayAsText(messages: LLMMessage[]): string {
*/
export interface ThreadTemplateVars {
thread: unknown[];
thread_with_system: unknown[];
thread_count: number;
first_message: unknown | null;
last_message: unknown | null;
Expand All @@ -270,6 +271,7 @@ export interface ThreadTemplateVars {
*/
export function computeThreadTemplateVars(
thread: unknown[],
threadWithSystem: unknown[] = thread,
): ThreadTemplateVars {
let _user_messages: unknown[] | undefined;
let _assistant_messages: unknown[] | undefined;
Expand All @@ -279,6 +281,7 @@ export function computeThreadTemplateVars(

return {
thread,
thread_with_system: threadWithSystem,
thread_count: thread.length,

get first_message(): unknown | null {
Expand Down
4 changes: 2 additions & 2 deletions py/autoevals/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,7 @@ def _compute_thread_vars_sync(self, trace) -> dict[str, object]:
if not isinstance(thread, list):
thread = list(thread)

computed = compute_thread_template_vars(filter_system_messages_from_thread(thread))
computed = compute_thread_template_vars(filter_system_messages_from_thread(thread), thread)
return {name: computed[name] for name in self._thread_variable_names}

async def _compute_thread_vars_async(self, trace) -> dict[str, object]:
Expand All @@ -450,7 +450,7 @@ async def _compute_thread_vars_async(self, trace) -> dict[str, object]:
if not isinstance(thread, list):
thread = list(thread)

computed = compute_thread_template_vars(filter_system_messages_from_thread(thread))
computed = compute_thread_template_vars(filter_system_messages_from_thread(thread), thread)
return {name: computed[name] for name in self._thread_variable_names}

def _request_args(self, output, expected, **kwargs):
Expand Down
83 changes: 82 additions & 1 deletion py/autoevals/test_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from autoevals import init
from autoevals.llm import Battle, Factuality, LLMClassifier, OpenAILLMClassifier, build_classification_tools
from autoevals.oai import OpenAIV1Module, get_default_model
from autoevals.thread_utils import compute_thread_template_vars
from autoevals.thread_utils import compute_thread_template_vars, template_uses_thread_variables


class TestModel(BaseModel):
Expand Down Expand Up @@ -96,6 +96,87 @@ def test_render_messages_with_thread_variables():
assert rendered[6]["content"].startswith("Messages:\n- user: Hello, how are you?")


def test_thread_template_detection_and_split_thread_vars():
assert template_uses_thread_variables("{{thread_with_system}}")

full_thread = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello, how are you?"},
{"role": "assistant", "content": "I am doing well, thank you!"},
]
filtered_thread = full_thread[1:]

thread_vars = compute_thread_template_vars(filtered_thread, full_thread)

assert len(thread_vars["thread"]) == 2
assert len(thread_vars["thread_with_system"]) == 3
assert str(thread_vars["thread"][0]) == "user: Hello, how are you?"
assert str(thread_vars["thread_with_system"][0]) == "system: You are a helpful assistant."
assert thread_vars["thread_count"] == 2


class _FakeTrace:
def __init__(self, thread):
self._thread = thread

async def get_thread(self, options=None):
del options
return self._thread


def test_llm_classifier_request_args_keep_thread_filtered_and_thread_with_system_unfiltered():
system_marker = "PY_AUTOEVALS_SYSTEM_MARKER"
trace = _FakeTrace(
[
{"role": "system", "content": system_marker},
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there"},
]
)
classifier = LLMClassifier(
"test",
"Filtered thread:\n{{thread}}\n\nFull thread:\n{{thread_with_system}}",
{"Yes": 1, "No": 0},
use_cot=False,
)

request_args = classifier._request_args(output="", expected="", trace=trace)
rendered_prompt = request_args["messages"][0]["content"]

filtered_thread, full_thread = rendered_prompt.split("\n\nFull thread:\n", 1)
assert "Hello" in filtered_thread
assert "Hi there" in filtered_thread
assert system_marker not in filtered_thread
assert system_marker in full_thread


@pytest.mark.asyncio
async def test_llm_classifier_request_args_async_keep_thread_filtered_and_thread_with_system_unfiltered():
system_marker = "PY_AUTOEVALS_SYSTEM_MARKER_ASYNC"
trace = _FakeTrace(
[
{"role": "system", "content": system_marker},
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there"},
]
)
classifier = LLMClassifier(
"test",
"Filtered thread:\n{{thread}}\n\nFull thread:\n{{thread_with_system}}",
{"Yes": 1, "No": 0},
use_cot=False,
)

request_args = await classifier._request_args_async(output="", expected="", trace=trace)
rendered_prompt = request_args["messages"][0]["content"]

filtered_thread, full_thread = rendered_prompt.split("\n\nFull thread:\n", 1)
assert "Hello" in filtered_thread
assert "Hi there" in filtered_thread
assert system_marker not in filtered_thread
assert system_marker in full_thread


def test_openai():
e = OpenAILLMClassifier(
"title",
Expand Down
11 changes: 10 additions & 1 deletion py/autoevals/thread_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

THREAD_VARIABLE_NAMES = [
"thread",
"thread_with_system",
"thread_count",
"first_message",
"last_message",
Expand Down Expand Up @@ -245,8 +246,15 @@ def _to_renderable_message_array(messages: list[Any]) -> RenderableMessageArray:
return RenderableMessageArray(wrapped)


def compute_thread_template_vars(thread: list[Any]) -> dict[str, Any]:
def compute_thread_template_vars(thread: list[Any], thread_with_system: list[Any] | None = None) -> dict[str, Any]:
renderable_thread = _to_renderable_message_array(thread) if is_llm_message_array(thread) else thread
if thread_with_system is None:
thread_with_system = thread
renderable_thread_with_system = (
_to_renderable_message_array(thread_with_system)
if is_llm_message_array(thread_with_system)
else thread_with_system
)

first_message = renderable_thread[0] if len(renderable_thread) > 0 else None
last_message = renderable_thread[-1] if len(renderable_thread) > 0 else None
Expand All @@ -264,6 +272,7 @@ def compute_thread_template_vars(thread: list[Any]) -> dict[str, Any]:

return {
"thread": renderable_thread,
"thread_with_system": renderable_thread_with_system,
"thread_count": len(thread),
"first_message": first_message,
"last_message": last_message,
Expand Down
Loading