Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 31 additions & 25 deletions agent/middleware/apiBasedTools.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { ToolMessage } from "@langchain/core/messages";
import { isGraphInterrupt } from "@langchain/langgraph";
import { createMiddleware } from "langchain";
import { logger, type AdminUser, type IAdminForth } from "adminforth";
import {
Expand All @@ -13,6 +14,7 @@ import { ALWAYS_AVAILABLE_API_TOOL_NAMES } from "../tools/index.js";
import { createApiTool } from "../tools/apiTool.js";
import type { AgentEventEmitter } from "../../agentEvents.js";
import type { SequenceDebugCollector } from "./sequenceDebug.js";
import { isAbortError } from "../../errors.js";

function getEnabledApiToolNames(messages: unknown[]) {
const enabledToolNames = new Set<string>();
Expand Down Expand Up @@ -82,8 +84,14 @@ export function createApiBasedToolsMiddleware(
async wrapToolCall(request, handler) {
const startedAt = Date.now();
const toolInput = JSON.stringify(request.toolCall.args ?? {});
const { adminUser, emit, sequenceDebugSink, userTimeZone } = request.runtime.context as {
if (!request.toolCall.id) {
throw new Error(`Tool call "${request.toolCall.name}" has no id.`);
}

const toolCallId = request.toolCall.id;
const { adminUser, abortSignal, emit, sequenceDebugSink, userTimeZone } = request.runtime.context as {
adminUser: AdminUser;
abortSignal?: AbortSignal;
emit?: AgentEventEmitter;
sequenceDebugSink: SequenceDebugCollector;
userTimeZone: string;
Expand Down Expand Up @@ -113,7 +121,7 @@ export function createApiBasedToolsMiddleware(
}
const toolCallTracker = createToolCallTracker({
emit: emitToolCall,
toolCallId: request.toolCall.id,
toolCallId,
toolName: request.toolCall.name,
toolInfo,
input: toolArgs,
Expand All @@ -125,39 +133,37 @@ export function createApiBasedToolsMiddleware(
);

try {
let result;

if (request.tool) {
result = await handler(request);
} else {
const enabledApiToolNames = getEnabledApiToolNames(request.state.messages);

if (enabledApiToolNames.has(request.toolCall.name)) {
result = await handler({
const result = getEnabledApiToolNames(request.state.messages).has(request.toolCall.name)
? await handler({
...request,
tool: dynamicTools[request.toolCall.name],
});
} else {
result = new ToolMessage({
content: `Tool "${request.toolCall.name}" is not loaded. Call fetch_tool_schema first.`,
tool_call_id: request.toolCall.id ?? "",
name: request.toolCall.name,
status: "error",
});
}
}
})
: await handler(request);

toolCallTracker.finishSuccess(result);
return result;
} catch (error) {
const errorDetails =
error instanceof Error ? error.stack ?? error.message : String(error);
if (
isGraphInterrupt(error)
|| abortSignal?.aborted
|| isAbortError(error)
) {
throw error;
}

const message = error instanceof Error ? error.message : String(error);

logger.error(
`Tool "${request.toolCall.name}" failed after ${Date.now() - startedAt}ms with input: ${toolInput}\n${errorDetails}`,
`Error calling tool "${request.toolCall.name}": ${error instanceof Error ? error.stack ?? error.message : String(error)}`,
);
toolCallTracker.finishError(error);
throw error;
toolCallTracker.finishError(`Error: ${message}`);
return new ToolMessage({
name: request.toolCall.name,
tool_call_id: toolCallId,
status: "error",
content: `Error: ${message}`,
})
} finally {
logger.info(
`Tool "${request.toolCall.name}" finished in ${Date.now() - startedAt}ms`,
Expand Down
27 changes: 24 additions & 3 deletions agent/runtime/AgentRuntime.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,28 @@
import type { IAdminForth } from "adminforth";
import { createAgent, summarizationMiddleware } from "langchain";
import { createAgent, summarizationMiddleware, humanInTheLoopMiddleware } from "langchain";
import type { BaseCheckpointSaver } from "@langchain/langgraph";
import { createApiBasedToolsMiddleware } from "../middleware/apiBasedTools.js";
import { createSequenceDebugMiddleware } from "../middleware/sequenceDebug.js";
import { createAgentLlmMetricsLogger } from "../simpleAgent.js";
import type { AgentToolProvider } from "../tools/AgentToolProvider.js";
import type { AgentRuntimeRunInput } from "../turn/turnTypes.js";
import { contextSchema, toLangchainAgentContext } from "./AgentContext.js";
import type { ApiBasedTool } from "../../apiBasedTools.js";

function createHumanInTheLoopInterrupts(
apiBasedTools: Record<string, ApiBasedTool>,
): Record<string, { allowedDecisions: ("approve" | "reject" | "edit")[] }> {
return Object.fromEntries(
Object.entries(apiBasedTools)
.filter(([, apiBasedTool]) => apiBasedTool.agent?.isDangerous === true)
.map(([toolName]) => [
toolName,
{
allowedDecisions: ["approve", "reject"],
},
]),
);
}

export type AgentRuntimeOptions = {
name: string;
Expand All @@ -29,8 +45,13 @@ export class AgentRuntime {
const sequenceDebugMiddleware = createSequenceDebugMiddleware(
input.observability.sequenceDebugSink,
);
const hitlMiddleware = humanInTheLoopMiddleware({
interruptOn: createHumanInTheLoopInterrupts(apiBasedTools),
descriptionPrefix: "Tool execution pending approval",
});
const middleware = [
apiBasedToolsMiddleware,
hitlMiddleware,
...(input.models.modelMiddleware ?? []),
sequenceDebugMiddleware,
summarizationMiddleware({
Expand All @@ -49,8 +70,8 @@ export class AgentRuntime {
middleware,
});

return agent.stream({ messages: input.messages } as any, {
streamMode: "messages",
return agent.stream(input.input as any, {
streamMode: ["messages", "updates"],
recursionLimit: 100,
callbacks: [createAgentLlmMetricsLogger()],
signal: input.context.abortSignal,
Expand Down
9 changes: 3 additions & 6 deletions agent/systemPrompt.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,8 @@ export const DEFAULT_AGENT_SYSTEM_PROMPT = [
"Do not add extra explanations or suggestions unless the user asks.",
"Adapt to the user's tone and style of speaking, mirroring their vibe and wording.",
"if the user speaks casually, you should respond casually too",
"Never mutate data without user confirmation for a clearly described mutation plan.",
"One confirmation may cover one mutation or one explicitly described batch/sequence of related mutations.",
"If the confirmed plan has multiple steps, you may execute the whole confirmed plan without asking again between those steps.",
"If the plan changes, expands, or you want to do anything beyond the confirmed plan, ask for confirmation again.",
"Do not reuse an old confirmation for a new mutation plan.",
"Before calling a dangerous tool, briefly describe the exact action, target, and important changes in chat.",
"Do not ask the user for textual confirmation; dangerous tools are approved by the runtime approval UI.",
].join(" ");

export function appendCustomSystemPrompt(
Expand Down Expand Up @@ -124,7 +121,7 @@ export async function buildAgentSystemPrompt(
"If the user wants to fetch records, load fetch_data first. If the user wants analytics or charts, load analyze_data first.",
"Only call fetch_tool_schema for tool names that are explicitly mentioned in a fetched skill and are not already available as base tools.",
"If a fetched skill lists a non-base tool you need, call fetch_tool_schema for it immediately instead of telling the user the tool is unavailable.",
"For example: for record creation load mutate_data, read its tool list, call fetch_tool_schema for create_record, and then use create_record after confirmation.",
"For example: for record creation load mutate_data, read its tool list, call fetch_tool_schema for create_record, describe the planned record, and then use create_record.",
"When fetch_tool_schema succeeds, that tool becomes available on the next step.",
"All admin links must be root-relative and start with '/'.",
"Build record links as '/resource/{resourceId}/show/{primary key}'. Never use bare 'resource/{resourceId}/show/{primary key}' without the leading slash.",
Expand Down
18 changes: 18 additions & 0 deletions agent/turn/TurnLifecycleService.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import type { AgentSessionStore } from "../../sessionStore.js";
import type { PluginOptions } from "../../types.js";
import type { BaseAgentTurnInput } from "./turnTypes.js";
import { TurnPersistenceService } from "./TurnPersistenceService.js";

export class TurnLifecycleService {
constructor(
private readonly sessionStore: AgentSessionStore,
private readonly persistence: TurnPersistenceService,
private readonly options: PluginOptions,
) {}

async start(input: BaseAgentTurnInput) {
Expand All @@ -19,6 +21,22 @@ export class TurnLifecycleService {
};
}

async resume(input: BaseAgentTurnInput) {
const latestTurn = await this.sessionStore.getLatestTurn(input.sessionId);

if (!latestTurn) {
throw new Error(`No agent turn found for session "${input.sessionId}".`);
}

return {
turnId: latestTurn[this.options.turnResource.idField],
previousUserMessages: await this.sessionStore.getPreviousUserMessages(input.sessionId),
initialResponse: latestTurn[this.options.turnResource.responseField] === "not_finished"
? ""
: String(latestTurn[this.options.turnResource.responseField]),
};
}

async finish(input: {
turnId: string;
responseText: string;
Expand Down
14 changes: 11 additions & 3 deletions agent/turn/TurnStreamConsumer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,27 @@ import { VegaLiteStreamBuffer } from "./VegaLiteStreamBuffer.js";

export class TurnStreamConsumer {
async consume(input: {
stream: AsyncIterable<[any, any]>;
stream: AsyncIterable<["messages", [any, any]] | ["updates", Record<string, any>]>;
abortSignal?: AbortSignal;
emit?: AgentEventEmitter;
onInterrupt?: (interrupt: unknown) => void | Promise<void>;
}) {
let fullResponse = "";
const textBuffer = new VegaLiteStreamBuffer();

for await (const rawChunk of input.stream) {
for await (const [mode, chunk] of input.stream) {
if (input.abortSignal?.aborted) {
throw new DOMException("This operation was aborted", "AbortError");
}

const [token, metadata] = rawChunk;
if (mode === "updates") {
if ("__interrupt__" in chunk) {
await input.onInterrupt?.(chunk.__interrupt__);
}
continue;
}

const [token, metadata] = chunk;
const nodeName =
typeof metadata?.langgraph_node === "string"
? metadata.langgraph_node
Expand Down
10 changes: 9 additions & 1 deletion agent/turn/turnTypes.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import type { AdminUser, AudioAdapter } from "adminforth";
import type { Messages } from "@langchain/langgraph";
import type { Command } from "@langchain/langgraph";
import type { AgentChatModel, AgentMiddleware } from "../simpleAgent.js";
import type { SequenceDebugCollector } from "../middleware/sequenceDebug.js";
import type { PreviousUserMessage } from "../languageDetect.js";
Expand All @@ -20,6 +21,7 @@ export type BaseAgentTurnInput = {

export type TextAgentTurnInput = BaseAgentTurnInput & {
emit: AgentEventEmitter;
approvalDecision?: "approve" | "reject";
failureLogMessage?: string;
abortLogMessage?: string;
};
Expand Down Expand Up @@ -60,6 +62,11 @@ export type PreparedAgentTurn = {
modeName?: string | null;
context: AgentTurnContext;
observability: AgentTurnObservability;
resume?: {
decision: "approve" | "reject";
interrupts?: { id: string; count: number }[];
};
initialResponse?: string;
};

export type AgentTurnModels = {
Expand All @@ -70,13 +77,14 @@ export type AgentTurnModels = {

export type AgentRuntimeRunInput = {
models: AgentTurnModels;
messages: Messages;
input: { messages: Messages } | Command;
context: AgentTurnContext;
observability: AgentTurnObservability;
};

export type RunAndPersistAgentResponseInput = BaseAgentTurnInput & {
emit?: AgentEventEmitter;
approvalDecision?: "approve" | "reject";
failureLogMessage: string;
abortLogMessage: string;
};
Expand Down
5 changes: 5 additions & 0 deletions agentEvents.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ export type AgentEvent =
phase: "start" | "end";
label: string;
}
| {
type: "interrupt";
sessionId: string;
interrupt: unknown;
}
| {
type: "open-page";
targetPath: string;
Expand Down
Loading