diff --git a/app/src/main/java/com/ryoncook/glassesai/ConfirmationParser.kt b/app/src/main/java/com/ryoncook/glassesai/ConfirmationParser.kt new file mode 100644 index 0000000..f790c44 --- /dev/null +++ b/app/src/main/java/com/ryoncook/glassesai/ConfirmationParser.kt @@ -0,0 +1,29 @@ +package com.ryoncook.glassesai + +enum class ConfirmationResult { AFFIRM, DENY } + +object ConfirmationParser { + + private val AFFIRMATIVES = setOf( + "yes", "yeah", "yep", "yup", "sure", "okay", "ok", + "do it", "go ahead", "go for it", "confirm", "absolutely", "definitely" + ) + + private val NEGATIVES = setOf( + "no", "nope", "nah", "cancel", "stop", "never mind", "nevermind", + "don't", "dont", "abort", "forget it", "negative" + ) + + fun parse(transcript: String): ConfirmationResult { + val lower = transcript.trim().lowercase().replace(Regex("[.,!?]+"), " ").trim() + if (AFFIRMATIVES.any { lower == it || lower.startsWith("$it ") }) return ConfirmationResult.AFFIRM + if (NEGATIVES.any { lower == it || lower.startsWith("$it ") || lower.contains(it) }) return ConfirmationResult.DENY + return ConfirmationResult.DENY + } + + fun questionFor(action: Action): String = when (action) { + is Action.Call -> "Call ${action.contact}?" + is Action.Sms -> "Text ${action.contact}?" + else -> "" + } +} diff --git a/app/src/main/java/com/ryoncook/glassesai/ConversationContext.kt b/app/src/main/java/com/ryoncook/glassesai/ConversationContext.kt new file mode 100644 index 0000000..338f1c0 --- /dev/null +++ b/app/src/main/java/com/ryoncook/glassesai/ConversationContext.kt @@ -0,0 +1,17 @@ +package com.ryoncook.glassesai + +class ConversationContext(val maxTurns: Int = 3) { + + private val _turns = ArrayDeque>() + + val turns: List> get() = _turns.toList() + + val isEmpty: Boolean get() = _turns.isEmpty() + + fun add(userText: String, assistantText: String) { + if (_turns.size >= maxTurns) _turns.removeFirst() + _turns.addLast(userText to assistantText) + } + + fun clear() = _turns.clear() +} diff --git a/app/src/main/java/com/ryoncook/glassesai/GlassesAIService.kt b/app/src/main/java/com/ryoncook/glassesai/GlassesAIService.kt index 619b5cf..eb16248 100644 --- a/app/src/main/java/com/ryoncook/glassesai/GlassesAIService.kt +++ b/app/src/main/java/com/ryoncook/glassesai/GlassesAIService.kt @@ -73,9 +73,12 @@ class GlassesAIService : Service() { private enum class State { IDLE, MODEL_LOADING, WAKE_LISTENING, CALIBRATING, SCO_CONNECTING, STREAMING, RESPONDING, - TRANSCRIBING, INFERRING, SYNTHESIZING + TRANSCRIBING, INFERRING, SYNTHESIZING, + AWAITING_CONFIRMATION, FOLLOW_UP_LISTENING } + private enum class ReListenMode { NONE, CONFIRMATION, FOLLOW_UP } + @Volatile private var serviceState = State.IDLE @Volatile private var isDestroyed = false @Volatile private var appInForeground = false @@ -122,6 +125,11 @@ class GlassesAIService : Service() { @Volatile private var lastPrompt = "" @Volatile private var ttsVolumeGain = 1.0f + private var pendingAction: Action? = null + private var reListenMode = ReListenMode.NONE + private val conversationContext = ConversationContext(maxTurns = 3) + private val reListenTimeoutRunnable = Runnable { onReListenTimeout() } + private var voskModel: Model? = null private var wakeRecognizer: Recognizer? = null @Volatile private var serverUrl = Config.DEFAULT_SERVER_URL @@ -141,7 +149,8 @@ class GlassesAIService : Service() { AudioManager.SCO_AUDIO_STATE_DISCONNECTED -> { when (serviceState) { State.STREAMING, State.RESPONDING, - State.TRANSCRIBING, State.INFERRING, State.SYNTHESIZING -> { + State.TRANSCRIBING, State.INFERRING, State.SYNTHESIZING, + State.AWAITING_CONFIRMATION, State.FOLLOW_UP_LISTENING -> { // Unexpected drop during conversation — reconnect Log.w(TAG, "SCO dropped mid-conversation, reconnecting") audioManager.startBluetoothSco() @@ -493,6 +502,7 @@ class GlassesAIService : Service() { Log.d(TAG, "Wake word detected — starting SCO") serviceState = State.SCO_CONNECTING updateNotif("Activating...") + conversationContext.clear() pauseMedia() stopOpenWakeWord() stopPhoneMicRecording() @@ -692,8 +702,24 @@ class GlassesAIService : Service() { } private fun onUtteranceComplete(transcript: String) { - if (serviceState != State.TRANSCRIBING) return + mainHandler.removeCallbacks(reListenTimeoutRunnable) Log.d(TAG, "Transcript: $transcript") + + when (serviceState) { + State.AWAITING_CONFIRMATION -> { + glassesRecording = false + onConfirmationResponse(transcript) + return + } + State.FOLLOW_UP_LISTENING -> { + glassesRecording = false + onFollowUpUtterance(transcript) + return + } + State.TRANSCRIBING -> { /* normal path below */ } + else -> return + } + if (TranscriptValidator.isHallucination(transcript)) { Log.w(TAG, "Whisper hallucination detected — discarding transcript") glassesRecording = false @@ -705,12 +731,63 @@ class GlassesAIService : Service() { serviceState = State.INFERRING updateNotif("Thinking...") - inferenceManager?.infer(transcript) { response -> + inferenceManager?.infer(transcript, conversationContext.turns) { response -> + Handler(Looper.getMainLooper()).post { speakResponse(response) } + } ?: afterOnDeviceTurnComplete() + } + + private fun onConfirmationResponse(transcript: String) { + val action = pendingAction ?: run { afterOnDeviceTurnComplete(); return } + pendingAction = null + when (ConfirmationParser.parse(transcript)) { + ConfirmationResult.AFFIRM -> { + Log.d(TAG, "Confirmation: affirmed — executing $action") + val confirm = executeAction(action) ?: action.confirm + speakResponse(confirm, allowFollowUp = false) + } + ConfirmationResult.DENY -> { + Log.d(TAG, "Confirmation: denied") + conversationContext.clear() + speakResponse("Cancelled.", allowFollowUp = false) + } + } + } + + private fun onFollowUpUtterance(transcript: String) { + if (TranscriptValidator.isHallucination(transcript)) { + Log.w(TAG, "Follow-up hallucination — discarding") + afterOnDeviceTurnComplete() + return + } + lastPrompt = transcript + serviceState = State.INFERRING + updateNotif("Thinking...") + inferenceManager?.infer(transcript, conversationContext.turns) { response -> Handler(Looper.getMainLooper()).post { speakResponse(response) } } ?: afterOnDeviceTurnComplete() } - private fun speakResponse(response: String) { + private fun onReListenTimeout() { + when (serviceState) { + State.AWAITING_CONFIRMATION -> { + Log.d(TAG, "Confirmation timeout — cancelling") + pendingAction = null + reListenMode = ReListenMode.NONE + conversationContext.clear() + glassesRecording = false + speakResponse("Cancelled.", allowFollowUp = false) + } + State.FOLLOW_UP_LISTENING -> { + Log.d(TAG, "Follow-up timeout — returning to wake listening") + conversationContext.clear() + glassesRecording = false + afterOnDeviceTurnComplete() + } + else -> {} + } + } + + private fun speakResponse(response: String, allowFollowUp: Boolean = true) { serviceState = State.SYNTHESIZING val cleaned = ResponseParser.cleanResponse(response) @@ -720,12 +797,23 @@ class GlassesAIService : Service() { if (json != null && json.has("action")) { val action = ActionParser.parse(jsonStr!!) ?.let { ActionOverride.apply(it, lastPrompt) } - val override = if (action != null) executeAction(action) else null when { - override != null -> override // feature blocked or error - else -> action?.confirm?.ifBlank { "" } ?: "" + action is Action.Call || action is Action.Sms -> { + // Hold execution — ask for verbal confirmation first + pendingAction = action + reListenMode = ReListenMode.CONFIRMATION + ConfirmationParser.questionFor(action) + } + else -> { + val override = if (action != null) executeAction(action) else null + when { + override != null -> override + else -> action?.confirm?.ifBlank { "" } ?: "" + } + } } } else { + if (allowFollowUp) reListenMode = ReListenMode.FOLLOW_UP cleaned } } catch (e: Exception) { @@ -765,7 +853,10 @@ class GlassesAIService : Service() { } updateNotif("Responding...") + val responseToStore = textToSpeak + val capturedLastPrompt = lastPrompt Thread { + audioTrack?.play() // resume if paused from a previous re-listen window try { val pcm = tm.synthesize(textToSpeak) if (pcm != null) { @@ -773,13 +864,15 @@ class GlassesAIService : Service() { val durationMs = pcm.size.toLong() * 1000 / 2 / tm.sampleRate() Thread.sleep(durationMs + 200) } - val beep = generateDescendingBeepPcm(tm.sampleRate()) - audioTrack?.write(beep, 0, beep.size) - Thread.sleep(beep.size.toLong() * 1000 / 2 / tm.sampleRate() + 50) } catch (e: Exception) { Log.e(TAG, "TTS synthesis error: ${e.message}") } - Handler(Looper.getMainLooper()).post { afterOnDeviceTurnComplete() } + Handler(Looper.getMainLooper()).post { + if (reListenMode == ReListenMode.FOLLOW_UP && capturedLastPrompt.isNotBlank()) { + conversationContext.add(capturedLastPrompt, responseToStore) + } + afterOnDeviceTurnComplete() + } }.start() } @@ -787,12 +880,28 @@ class GlassesAIService : Service() { glassesRecording = false glassesMicThread?.join(500) glassesMicThread = null + + val mode = reListenMode + reListenMode = ReListenMode.NONE + + if (mode == ReListenMode.CONFIRMATION || mode == ReListenMode.FOLLOW_UP) { + // Reuse glassesMicRecord without stop/restart — stopping and recreating + // the SCO mic between turns causes audible hardware artifacts on the glasses. + // Pause the audioTrack so an empty buffer in PLAY state doesn't produce + // underrun noise on the SCO line during the listening window. + audioTrack?.pause() + startReListenRecording(mode) + return + } + glassesMicRecord?.stop() glassesMicRecord?.release() glassesMicRecord = null val drainTrack = audioTrack audioTrack = null + val endBeep = generateDescendingBeepPcm(ttsManager?.sampleRate() ?: Config.TTS_SAMPLE_RATE) + drainTrack?.write(endBeep, 0, endBeep.size) drainTrack?.stop() serviceState = State.IDLE @@ -809,6 +918,76 @@ class GlassesAIService : Service() { }, 500) } + private fun startReListenRecording(mode: ReListenMode) { + serviceState = if (mode == ReListenMode.CONFIRMATION) State.AWAITING_CONFIRMATION + else State.FOLLOW_UP_LISTENING + updateNotif(if (mode == ReListenMode.CONFIRMATION) "Waiting for confirmation..." else "Listening...") + + val timeoutMs = if (mode == ReListenMode.CONFIRMATION) 8000L else 5000L + mainHandler.postDelayed(reListenTimeoutRunnable, timeoutMs) + + // glassesMicRecord is reused from the previous turn — no stop/restart. + // If it was released (e.g. after a timeout teardown), create a fresh one. + if (glassesMicRecord == null) { + val recBuf = AudioRecord.getMinBufferSize( + Config.INPUT_SAMPLE_RATE, + AudioFormat.CHANNEL_IN_MONO, + AudioFormat.ENCODING_PCM_16BIT + ).coerceAtLeast(6400) + @Suppress("MissingPermission") + glassesMicRecord = AudioRecord( + MediaRecorder.AudioSource.VOICE_COMMUNICATION, + Config.INPUT_SAMPLE_RATE, + AudioFormat.CHANNEL_IN_MONO, + AudioFormat.ENCODING_PCM_16BIT, + recBuf + ) + glassesMicRecord?.startRecording() + } + glassesRecording = true + + val silenceEndChunks = 9 + val maxChunks = if (mode == ReListenMode.CONFIRMATION) 80 else 50 + + glassesMicThread = Thread({ + // Discard audio buffered during TTS playback so TTS echo doesn't + // register as speech and trigger a false DENY before the user speaks. + val discardBuf = ByteArray(3200) + repeat(4) { if (glassesRecording) glassesMicRecord?.read(discardBuf, 0, discardBuf.size) } + + val chunk = ByteArray(3200) + val buffer = java.io.ByteArrayOutputStream() + var silentChunks = 0 + var speechStarted = false + var totalChunks = 0 + + while (glassesRecording && + (serviceState == State.AWAITING_CONFIRMATION || serviceState == State.FOLLOW_UP_LISTENING)) { + val read = glassesMicRecord?.read(chunk, 0, chunk.size) ?: break + if (read <= 0) continue + buffer.write(chunk, 0, read) + totalChunks++ + if (SilenceDetector.isSilence(chunk.copyOf(read))) { + if (speechStarted) silentChunks++ + } else { + speechStarted = true + silentChunks = 0 + } + if ((speechStarted && silentChunks >= silenceEndChunks) || totalChunks >= maxChunks) { + glassesRecording = false + } + } + + val audio = buffer.toByteArray() + if (audio.isNotEmpty() && + (serviceState == State.AWAITING_CONFIRMATION || serviceState == State.FOLLOW_UP_LISTENING)) { + val text = whisperTranscriber?.transcribe(audio) ?: "" + Log.d(TAG, "Re-listen transcript: \"$text\"") + if (text.isNotBlank()) mainHandler.post { onUtteranceComplete(text) } + } + }, "glasses-mic-relisten").also { it.start() } + } + private fun startPlaybackMonitor() { val track = audioTrack ?: return lastPlaybackHeadPosition = track.playbackHeadPosition diff --git a/app/src/main/java/com/ryoncook/glassesai/InferenceManager.kt b/app/src/main/java/com/ryoncook/glassesai/InferenceManager.kt index 9714fe8..2f6a829 100644 --- a/app/src/main/java/com/ryoncook/glassesai/InferenceManager.kt +++ b/app/src/main/java/com/ryoncook/glassesai/InferenceManager.kt @@ -32,11 +32,11 @@ class InferenceManager(private val context: Context) { inference = null } - fun infer(userText: String, onDone: (String) -> Unit) { + fun infer(userText: String, history: List> = emptyList(), onDone: (String) -> Unit) { val llm = inference ?: run { onDone(""); return } Thread { try { - val response = llm.generateResponse(buildPrompt(userText)) + val response = llm.generateResponse(buildPrompt(userText, history)) onDone(response.trim()) } catch (e: Exception) { Log.e(TAG, "Inference failed: ${e.message}") @@ -48,8 +48,7 @@ class InferenceManager(private val context: Context) { companion object { private const val TAG = "InferenceManager" - fun buildPrompt(userText: String) = - "user\n" + + private const val SYSTEM_PROMPT = "You are Prism, a voice assistant in smart glasses. Responses are spoken aloud — plain text only, no markdown.\n\n" + "When the user wants to make a call, send a text, set a timer, set an alarm, or change volume, output ONLY the matching JSON with no other text:\n" + "{\"action\":\"call\",\"contact\":\"Name\",\"confirm\":\"Calling Name\"}\n" + @@ -57,8 +56,33 @@ class InferenceManager(private val context: Context) { "{\"action\":\"timer\",\"amount\":3,\"unit\":\"minutes\",\"confirm\":\"Setting a 3-minute timer\"}\n" + "{\"action\":\"alarm\",\"hour\":7,\"minute\":0,\"label\":\"\",\"confirm\":\"Alarm set for 7 AM\"}\n" + "{\"action\":\"volume\",\"level\":30,\"confirm\":\"Volume set to 30 percent\"}\n" + - "Use the EXACT number the user states. For everything else, answer in one sentence, 15 words or fewer.\n\n" + - "$userText\n" + - "model\n" + "Use the EXACT number the user states. For everything else, answer in one sentence, 15 words or fewer." + + fun buildPrompt(userText: String, history: List> = emptyList()): String { + val sb = StringBuilder() + if (history.isEmpty()) { + sb.append("user\n") + sb.append("$SYSTEM_PROMPT\n\n") + sb.append("$userText\n") + } else { + val (firstUser, firstAssistant) = history[0] + sb.append("user\n") + sb.append("$SYSTEM_PROMPT\n\n") + sb.append("$firstUser\n") + sb.append("model\n") + sb.append("$firstAssistant\n") + for (i in 1 until history.size) { + val (u, a) = history[i] + sb.append("user\n") + sb.append("$u\n") + sb.append("model\n") + sb.append("$a\n") + } + sb.append("user\n") + sb.append("$userText\n") + } + sb.append("model\n") + return sb.toString() + } } } diff --git a/app/src/test/java/com/ryoncook/glassesai/ConfirmationParserTest.kt b/app/src/test/java/com/ryoncook/glassesai/ConfirmationParserTest.kt new file mode 100644 index 0000000..0a9d3fd --- /dev/null +++ b/app/src/test/java/com/ryoncook/glassesai/ConfirmationParserTest.kt @@ -0,0 +1,68 @@ +package com.ryoncook.glassesai + +import org.junit.Assert.* +import org.junit.Test + +class ConfirmationParserTest { + + // ── Affirmatives ───────────────────────────────────────────────────────── + + @Test fun `yes is affirm`() { assertEquals(ConfirmationResult.AFFIRM, ConfirmationParser.parse("yes")) } + @Test fun `yeah is affirm`() { assertEquals(ConfirmationResult.AFFIRM, ConfirmationParser.parse("yeah")) } + @Test fun `yep is affirm`() { assertEquals(ConfirmationResult.AFFIRM, ConfirmationParser.parse("yep")) } + @Test fun `sure is affirm`() { assertEquals(ConfirmationResult.AFFIRM, ConfirmationParser.parse("sure")) } + @Test fun `okay is affirm`() { assertEquals(ConfirmationResult.AFFIRM, ConfirmationParser.parse("okay")) } + @Test fun `ok is affirm`() { assertEquals(ConfirmationResult.AFFIRM, ConfirmationParser.parse("ok")) } + @Test fun `do it is affirm`() { assertEquals(ConfirmationResult.AFFIRM, ConfirmationParser.parse("do it")) } + @Test fun `go ahead is affirm`() { assertEquals(ConfirmationResult.AFFIRM, ConfirmationParser.parse("go ahead")) } + @Test fun `confirm is affirm`() { assertEquals(ConfirmationResult.AFFIRM, ConfirmationParser.parse("confirm")) } + + @Test fun `yes please is affirm`() { assertEquals(ConfirmationResult.AFFIRM, ConfirmationParser.parse("yes please")) } + @Test fun `yeah do it is affirm`() { assertEquals(ConfirmationResult.AFFIRM, ConfirmationParser.parse("yeah do it")) } + + // ── Negatives ──────────────────────────────────────────────────────────── + + @Test fun `no is deny`() { assertEquals(ConfirmationResult.DENY, ConfirmationParser.parse("no")) } + @Test fun `nope is deny`() { assertEquals(ConfirmationResult.DENY, ConfirmationParser.parse("nope")) } + @Test fun `cancel is deny`() { assertEquals(ConfirmationResult.DENY, ConfirmationParser.parse("cancel")) } + @Test fun `stop is deny`() { assertEquals(ConfirmationResult.DENY, ConfirmationParser.parse("stop")) } + @Test fun `never mind is deny`() { assertEquals(ConfirmationResult.DENY, ConfirmationParser.parse("never mind")) } + @Test fun `dont is deny`() { assertEquals(ConfirmationResult.DENY, ConfirmationParser.parse("don't")) } + @Test fun `no thanks is deny`() { assertEquals(ConfirmationResult.DENY, ConfirmationParser.parse("no thanks")) } + @Test fun `actually no is deny`() { assertEquals(ConfirmationResult.DENY, ConfirmationParser.parse("actually no")) } + + // ── Unrecognised → safe default is deny ────────────────────────────────── + + @Test fun `unrelated question is deny`() { + assertEquals(ConfirmationResult.DENY, ConfirmationParser.parse("what time is it")) + } + + @Test fun `empty string is deny`() { + assertEquals(ConfirmationResult.DENY, ConfirmationParser.parse("")) + } + + // ── Case insensitive ───────────────────────────────────────────────────── + + @Test fun `YES uppercase is affirm`() { assertEquals(ConfirmationResult.AFFIRM, ConfirmationParser.parse("YES")) } + @Test fun `NO uppercase is deny`() { assertEquals(ConfirmationResult.DENY, ConfirmationParser.parse("NO")) } + + // ── Whisper punctuation (returns "Yes." / "No." etc) ───────────────────── + + @Test fun `yes period is affirm`() { assertEquals(ConfirmationResult.AFFIRM, ConfirmationParser.parse("yes.")) } + @Test fun `Yeah period is affirm`() { assertEquals(ConfirmationResult.AFFIRM, ConfirmationParser.parse("Yeah.")) } + @Test fun `no period is deny`() { assertEquals(ConfirmationResult.DENY, ConfirmationParser.parse("no.")) } + @Test fun `No exclamation is deny`() { assertEquals(ConfirmationResult.DENY, ConfirmationParser.parse("No!")) } + @Test fun `yes comma more is affirm`() { assertEquals(ConfirmationResult.AFFIRM, ConfirmationParser.parse("yes, go ahead")) } + + // ── Confirmation question builder ──────────────────────────────────────── + + @Test fun `call action produces call question`() { + val action = Action.Call(contact = "Mom", confirm = "Calling Mom") + assertEquals("Call Mom?", ConfirmationParser.questionFor(action)) + } + + @Test fun `sms action produces text question`() { + val action = Action.Sms(contact = "John", message = "hi", confirm = "Texting John") + assertEquals("Text John?", ConfirmationParser.questionFor(action)) + } +} diff --git a/app/src/test/java/com/ryoncook/glassesai/ConversationContextTest.kt b/app/src/test/java/com/ryoncook/glassesai/ConversationContextTest.kt new file mode 100644 index 0000000..e8179da --- /dev/null +++ b/app/src/test/java/com/ryoncook/glassesai/ConversationContextTest.kt @@ -0,0 +1,91 @@ +package com.ryoncook.glassesai + +import org.junit.Assert.* +import org.junit.Test + +class ConversationContextTest { + + private fun ctx(max: Int = 3) = ConversationContext(maxTurns = max) + + // ── Empty state ─────────────────────────────────────────────────────────── + + @Test + fun `new context is empty`() { + assertTrue(ctx().isEmpty) + } + + @Test + fun `new context has no turns`() { + assertEquals(0, ctx().turns.size) + } + + // ── Adding turns ────────────────────────────────────────────────────────── + + @Test + fun `add stores user and assistant text`() { + val c = ctx() + c.add("what time is it", "It's 3 PM.") + assertEquals(1, c.turns.size) + assertEquals("what time is it", c.turns[0].first) + assertEquals("It's 3 PM.", c.turns[0].second) + } + + @Test + fun `add multiple turns preserves order oldest first`() { + val c = ctx() + c.add("first question", "first answer") + c.add("second question", "second answer") + assertEquals("first question", c.turns[0].first) + assertEquals("second question", c.turns[1].first) + } + + @Test + fun `isEmpty false after add`() { + val c = ctx() + c.add("hi", "hello") + assertFalse(c.isEmpty) + } + + // ── Capacity cap ───────────────────────────────────────────────────────── + + @Test + fun `adding beyond maxTurns drops oldest turn`() { + val c = ctx(max = 2) + c.add("turn 1 user", "turn 1 assistant") + c.add("turn 2 user", "turn 2 assistant") + c.add("turn 3 user", "turn 3 assistant") + assertEquals(2, c.turns.size) + assertEquals("turn 2 user", c.turns[0].first) + assertEquals("turn 3 user", c.turns[1].first) + } + + @Test + fun `maxTurns of 3 keeps only last 3`() { + val c = ctx(max = 3) + repeat(5) { i -> c.add("user $i", "assistant $i") } + assertEquals(3, c.turns.size) + assertEquals("user 2", c.turns[0].first) + assertEquals("user 4", c.turns[2].first) + } + + // ── Clear ───────────────────────────────────────────────────────────────── + + @Test + fun `clear empties all turns`() { + val c = ctx() + c.add("question", "answer") + c.clear() + assertTrue(c.isEmpty) + assertEquals(0, c.turns.size) + } + + @Test + fun `can add after clear`() { + val c = ctx() + c.add("first", "answer") + c.clear() + c.add("second", "answer2") + assertEquals(1, c.turns.size) + assertEquals("second", c.turns[0].first) + } +} diff --git a/app/src/test/java/com/ryoncook/glassesai/InferenceManagerPromptTest.kt b/app/src/test/java/com/ryoncook/glassesai/InferenceManagerPromptTest.kt index cdcdd69..2b1a071 100644 --- a/app/src/test/java/com/ryoncook/glassesai/InferenceManagerPromptTest.kt +++ b/app/src/test/java/com/ryoncook/glassesai/InferenceManagerPromptTest.kt @@ -121,4 +121,58 @@ class InferenceManagerPromptTest { val p = prompt(text) assertTrue(p.endsWith("$text\nmodel\n")) } + + // ── Multi-turn history ──────────────────────────────────────────────────── + + private fun promptWithHistory(text: String, history: List>) = + InferenceManager.buildPrompt(text, history) + + @Test + fun `no history produces same prompt as single-arg overload`() { + assertEquals(prompt("hello"), promptWithHistory("hello", emptyList())) + } + + @Test + fun `history turn appears before current user turn`() { + val p = promptWithHistory("follow up", listOf("first question" to "first answer")) + val historyPos = p.indexOf("first question") + val currentPos = p.indexOf("follow up") + assertTrue(historyPos < currentPos) + } + + @Test + fun `history assistant turn appears in model turn`() { + val p = promptWithHistory("follow up", listOf("first question" to "first answer")) + assertTrue(p.contains("first answer")) + } + + @Test + fun `system prompt only appears once even with history`() { + val p = promptWithHistory("q", listOf("q1" to "a1", "q2" to "a2")) + val count = p.split("You are Prism").size - 1 + assertEquals(1, count) + } + + @Test + fun `prompt still ends with model turn open when history provided`() { + val p = promptWithHistory("hello", listOf("prev" to "answer")) + assertTrue(p.endsWith("model\n")) + } + + @Test + fun `current user text is in final user turn with history`() { + val text = "unique_follow_xyz" + val p = promptWithHistory(text, listOf("earlier" to "response")) + assertTrue(p.endsWith("$text\nmodel\n")) + } + + @Test + fun `two history turns both appear in prompt in order`() { + val p = promptWithHistory("third", listOf("first" to "ans1", "second" to "ans2")) + val pos1 = p.indexOf("first") + val pos2 = p.indexOf("second") + val pos3 = p.indexOf("third") + assertTrue(pos1 < pos2) + assertTrue(pos2 < pos3) + } }