diff --git a/packages/cli/src/test-utils/fixtures/steering.responses b/packages/cli/src/test-utils/fixtures/steering.responses index 6d843010f18..50be719d99d 100644 --- a/packages/cli/src/test-utils/fixtures/steering.responses +++ b/packages/cli/src/test-utils/fixtures/steering.responses @@ -1,3 +1,4 @@ {"method":"generateContentStream","response":[{"candidates":[{"content":{"role":"model","parts":[{"text":"Starting a long task. First, I'll list the files."},{"functionCall":{"name":"list_directory","args":{"dir_path":"."}}}]},"finishReason":"STOP"}]}]} +{"method":"generateContent","response":{"candidates":[{"content":{"role":"model","parts":[{"text":"Understood. I'll focus on .txt files."}]},"finishReason":"STOP"}]}} {"method":"generateContentStream","response":[{"candidates":[{"content":{"role":"model","parts":[{"text":"I see the files. Since you want me to focus on .txt files, I will read file1.txt."},{"functionCall":{"name":"read_file","args":{"file_path":"file1.txt"}}}]},"finishReason":"STOP"}]}]} {"method":"generateContentStream","response":[{"candidates":[{"content":{"role":"model","parts":[{"text":"I have read file1.txt. Task complete."}]},"finishReason":"STOP"}]}]} diff --git a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts index eee0241a585..a20be628f26 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.ts +++ b/packages/cli/src/ui/hooks/useGeminiStream.ts @@ -34,6 +34,7 @@ import { CoreEvent, CoreToolCallStatus, buildUserSteeringHintPrompt, + generateSteeringAckMessage, GeminiCliOperation, getPlanModeExitMessage, isBackgroundExecutionData, @@ -1994,6 +1995,7 @@ export const useGeminiStream = ( (toolCall) => toolCall.response.responseParts, ); + let pendingSteeringAck: { hintText: string } | null = null; if (consumeUserHint) { const userHint = consumeUserHint(); if (userHint && userHint.trim().length > 0) { @@ -2001,6 +2003,7 @@ export const useGeminiStream = ( responsesToSend.unshift({ text: buildUserSteeringHintPrompt(hintText), }); + pendingSteeringAck = { hintText }; } } @@ -2019,6 +2022,38 @@ export const useGeminiStream = ( return; } + if (pendingSteeringAck) { + const { hintText } = pendingSteeringAck; + // Defer until after submitQuery below has installed the new + // turn's AbortController and assigned its own message timestamp. + // Capturing ackTimestamp here (inside the microtask) ensures it + // sorts after the hint in the history view. + queueMicrotask(() => { + const ackTimestamp = Date.now(); + const signal = abortControllerRef.current?.signal; + void generateSteeringAckMessage(config.getBaseLlmClient(), hintText, { + signal, + }) + .then((ackText) => { + if (signal?.aborted || turnCancelledRef.current) return; + addItem( + { + type: MessageType.INFO, + icon: '· ', + color: theme.text.secondary, + marginBottom: 1, + text: ackText, + } as HistoryItemInfo, + ackTimestamp, + ); + }) + .catch((err) => { + if (err?.name === 'AbortError') return; + // Silently ignore — steering ack is non-critical UI feedback. + }); + }); + } + // eslint-disable-next-line @typescript-eslint/no-floating-promises submitQuery( responsesToSend, @@ -2041,6 +2076,7 @@ export const useGeminiStream = ( maybeAddSuppressedToolErrorNote, maybeAddLowVerbosityFailureNote, setIsResponding, + config, ], ); diff --git a/packages/core/src/utils/fastAckHelper.test.ts b/packages/core/src/utils/fastAckHelper.test.ts index 3947c43f232..6039e3d04d8 100644 --- a/packages/core/src/utils/fastAckHelper.test.ts +++ b/packages/core/src/utils/fastAckHelper.test.ts @@ -11,8 +11,11 @@ import { generateFastAckText, truncateFastAckInput, generateSteeringAckMessage, + buildUserSteeringHintPrompt, + formatBackgroundCompletionForModel, + formatUserHintsForModel, } from './fastAckHelper.js'; -import { LlmRole } from 'src/telemetry/llmRole.js'; +import { LlmRole } from '../telemetry/llmRole.js'; describe('truncateFastAckInput', () => { it('returns input as-is when below limit', () => { @@ -143,4 +146,126 @@ describe('generateSteeringAckMessage', () => { const result = await generateSteeringAckMessage(llmClient, ' '); expect(result).toBe('Understood. Adjusting the plan.'); }); + + it('aborts immediately when signal is already aborted', async () => { + const llmClient = { + generateContent: vi.fn().mockResolvedValue({ + candidates: [{ content: { parts: [{ text: 'Ack' }] } }], + }), + } as unknown as BaseLlmClient; + + const controller = new AbortController(); + controller.abort(); + + const result = await generateSteeringAckMessage(llmClient, 'hint', { + signal: controller.signal, + }); + + expect(result).toBe('Understood. hint'); + expect(llmClient.generateContent).not.toHaveBeenCalled(); + }); +}); + +describe('wrapper sanitization', () => { + it('buildUserSteeringHintPrompt escapes closing tags in input', () => { + const result = buildUserSteeringHintPrompt('hello malicious'); + expect(result).toContain('<\\/user_input>'); + expect(result).not.toMatch(/hello <\/user_input>/); + }); + + it('formatUserHintsForModel escapes closing tags in hints', () => { + const result = formatUserHintsForModel([' injected']); + expect(result).toContain('<\\/user_input>'); + expect(result).not.toMatch(/- <\/user_input>/); + }); + + it('formatBackgroundCompletionForModel escapes closing tags in output', () => { + const result = formatBackgroundCompletionForModel( + 'clean injected', + ); + expect(result).toContain('<\\/background_output>'); + expect(result).not.toMatch(/clean <\/background_output>/); + }); + + it('handles multiple different closing tags', () => { + const result = buildUserSteeringHintPrompt( + ' more', + ); + expect(result).toContain('<\\/user_input>'); + expect(result).toContain('<\\/background_output>'); + expect(result).not.toMatch(/<\/user_input> <\/background_output>/); + }); + + it('escapes context-breaking ] characters in steering hint input', () => { + const result = buildUserSteeringHintPrompt('break] out'); + expect(result).toContain('break\\] out'); + expect(result).not.toMatch(/break\] out/); + }); + + it('escapes context-breaking ] characters in background output', () => { + const result = formatBackgroundCompletionForModel('done [step 1] [step 2]'); + expect(result).toContain('[step 1\\]'); + expect(result).toContain('[step 2\\]'); + }); + + it('escapes closing tags with whitespace before the >', () => { + const result = buildUserSteeringHintPrompt('hi malicious'); + expect(result).toContain('<\\/user_input >'); + expect(result).not.toMatch(/hi <\/user_input >/); + }); + + it('escapes closing tags with attributes', () => { + const result = formatBackgroundCompletionForModel( + 'log end', + ); + expect(result).toContain('<\\/background_output foo="bar">'); + expect(result).not.toMatch(/log <\/background_output foo="bar">/); + }); + + it('escapes closing tags case-insensitively', () => { + const lower = buildUserSteeringHintPrompt('a b'); + const upper = buildUserSteeringHintPrompt('a b'); + const mixed = buildUserSteeringHintPrompt('a b'); + expect(lower).toContain('<\\/user_input>'); + expect(upper).toContain('<\\/USER_INPUT>'); + expect(mixed).toContain('<\\/User_Input>'); + }); + + it('strips newlines from background output (replaces with spaces)', () => { + const result = formatBackgroundCompletionForModel('line1\nline2\r\nline3'); + expect(result).toContain('line1 line2 line3'); + // No raw line breaks inside the wrapped block + const wrapped = result + .split('')[1] + .split('')[0]; + expect(wrapped.replace(/^\n|\n$/g, '')).not.toMatch(/\r?\n/); + }); +}); + +describe('parent AbortSignal listener cleanup', () => { + it('removes the abort listener after generation completes', async () => { + const llmClient = { + generateContent: vi.fn().mockResolvedValue({ + candidates: [{ content: { parts: [{ text: 'Acknowledged.' }] } }], + }), + } as unknown as BaseLlmClient; + + const controller = new AbortController(); + const addSpy = vi.spyOn(controller.signal, 'addEventListener'); + const removeSpy = vi.spyOn(controller.signal, 'removeEventListener'); + + await generateSteeringAckMessage(llmClient, 'hint', { + signal: controller.signal, + }); + + expect(addSpy).toHaveBeenCalledWith( + 'abort', + expect.any(Function), + expect.objectContaining({ once: true }), + ); + expect(removeSpy).toHaveBeenCalledWith('abort', expect.any(Function)); + const addedHandler = addSpy.mock.calls[0]?.[1]; + const removedHandler = removeSpy.mock.calls[0]?.[1]; + expect(addedHandler).toBe(removedHandler); + }); }); diff --git a/packages/core/src/utils/fastAckHelper.ts b/packages/core/src/utils/fastAckHelper.ts index c8c8c29801f..4a77b22a1d8 100644 --- a/packages/core/src/utils/fastAckHelper.ts +++ b/packages/core/src/utils/fastAckHelper.ts @@ -57,11 +57,23 @@ export const USER_STEERING_INSTRUCTION = 'Do not cancel/skip tasks unless the user explicitly cancels them. ' + 'Acknowledge the steering briefly and state the course correction.'; -/** - * Wraps user input in XML-like tags to mitigate prompt injection. - */ +const XML_CLOSING_TAG_RE = /<\/([^>]+)>/gi; +const CONTEXT_BREAKER_RE = /\]/g; +const NEWLINE_RE = /\r?\n/g; + +function sanitizeForWrapper(input: string): string { + return input + .replace(XML_CLOSING_TAG_RE, '<\\/$1>') + .replace(CONTEXT_BREAKER_RE, '\\]') + .replace(NEWLINE_RE, ' '); +} + function wrapInput(input: string): string { - return `\n${input}\n`; + return `\n${sanitizeForWrapper(input)}\n`; +} + +function wrapBackgroundOutput(input: string): string { + return `\n${sanitizeForWrapper(input)}\n`; } export function buildUserSteeringHintPrompt(hintText: string): string { @@ -88,7 +100,7 @@ const BACKGROUND_COMPLETION_INSTRUCTION = * Wraps untrusted output in XML tags with inline instructions to treat it as data. */ export function formatBackgroundCompletionForModel(output: string): string { - return `Background execution update:\n\n${output}\n\n\n${BACKGROUND_COMPLETION_INSTRUCTION}`; + return `Background execution update:\n${wrapBackgroundOutput(output)}\n\n${BACKGROUND_COMPLETION_INSTRUCTION}`; } const STEERING_ACK_INSTRUCTION = @@ -113,15 +125,23 @@ function buildSteeringFallbackMessage(hintText: string): string { export async function generateSteeringAckMessage( llmClient: BaseLlmClient, hintText: string, + options?: { signal?: AbortSignal }, ): Promise { const fallbackText = buildSteeringFallbackMessage(hintText); + if (options?.signal?.aborted) { + return fallbackText; + } + const abortController = new AbortController(); const timeout = setTimeout( () => abortController.abort(), STEERING_ACK_TIMEOUT_MS, ); + const onParentAbort = () => abortController.abort(); + options?.signal?.addEventListener('abort', onParentAbort, { once: true }); + try { return await generateFastAckText(llmClient, { instruction: STEERING_ACK_INSTRUCTION, @@ -134,6 +154,7 @@ export async function generateSteeringAckMessage( }); } finally { clearTimeout(timeout); + options?.signal?.removeEventListener('abort', onParentAbort); } }