diff --git a/apps/mobile/src/features/auth/hooks/useProjectsQuery.ts b/apps/mobile/src/features/auth/hooks/useProjectsQuery.ts index 999d9a26fc..43a1aac114 100644 --- a/apps/mobile/src/features/auth/hooks/useProjectsQuery.ts +++ b/apps/mobile/src/features/auth/hooks/useProjectsQuery.ts @@ -1,4 +1,5 @@ import { useQuery } from "@tanstack/react-query"; +import { authedFetch, getBaseUrl } from "@/lib/api"; import { useAuthStore } from "../stores/authStore"; export interface ProjectSummary { @@ -14,21 +15,19 @@ export interface ProjectSummary { * rather than dropping the project from the list. */ export function useProjectsQuery() { - const { cloudRegion, oauthAccessToken, scopedTeams, getCloudUrlFromRegion } = - useAuthStore(); + const { cloudRegion, oauthAccessToken, scopedTeams } = useAuthStore(); return useQuery({ queryKey: ["projects", cloudRegion, scopedTeams], queryFn: async (): Promise => { - if (!cloudRegion) throw new Error("No cloud region"); - const baseUrl = getCloudUrlFromRegion(cloudRegion); + const baseUrl = getBaseUrl(); return Promise.all( scopedTeams.map(async (id): Promise => { try { - const response = await fetch(`${baseUrl}/api/projects/${id}/`, { - headers: { Authorization: `Bearer ${oauthAccessToken}` }, - }); + const response = await authedFetch( + `${baseUrl}/api/projects/${id}/`, + ); if (!response.ok) return { id, name: `Project ${id}` }; const data: { name?: string } = await response.json(); return { id, name: data.name || `Project ${id}` }; diff --git a/apps/mobile/src/features/auth/hooks/useUserQuery.ts b/apps/mobile/src/features/auth/hooks/useUserQuery.ts index 5640928548..167d8ab45e 100644 --- a/apps/mobile/src/features/auth/hooks/useUserQuery.ts +++ b/apps/mobile/src/features/auth/hooks/useUserQuery.ts @@ -1,4 +1,5 @@ import { useQuery } from "@tanstack/react-query"; +import { authedFetch, getBaseUrl } from "@/lib/api"; import { useAuthStore } from "../stores/authStore"; export interface UserData { @@ -19,19 +20,12 @@ export interface UserData { } export function useUserQuery() { - const { cloudRegion, oauthAccessToken, getCloudUrlFromRegion } = - useAuthStore(); + const { cloudRegion, oauthAccessToken } = useAuthStore(); return useQuery({ queryKey: ["user", "me"], queryFn: async (): Promise => { - if (!cloudRegion) throw new Error("No cloud region"); - const baseUrl = getCloudUrlFromRegion(cloudRegion); - const response = await fetch(`${baseUrl}/api/users/@me/`, { - headers: { - Authorization: `Bearer ${oauthAccessToken}`, - }, - }); + const response = await authedFetch(`${getBaseUrl()}/api/users/@me/`); if (!response.ok) { throw new Error(`Failed to fetch user: ${response.statusText}`); diff --git a/apps/mobile/src/features/inbox/api.ts b/apps/mobile/src/features/inbox/api.ts index 7838f6dc96..880253e15f 100644 --- a/apps/mobile/src/features/inbox/api.ts +++ b/apps/mobile/src/features/inbox/api.ts @@ -1,6 +1,5 @@ -import { fetch } from "expo/fetch"; import { HttpError } from "@/features/tasks/api"; -import { getBaseUrl, getHeaders, getProjectId } from "@/lib/api"; +import { authedFetch, getBaseUrl, getProjectId } from "@/lib/api"; import { logger } from "@/lib/logger"; import type { DismissalReasonOptionValue } from "./constants"; @@ -24,7 +23,6 @@ export async function getSignalReports( ): Promise { const baseUrl = getBaseUrl(); const projectId = getProjectId(); - const headers = getHeaders(); const url = new URL(`${baseUrl}/api/projects/${projectId}/signals/reports/`); @@ -47,7 +45,7 @@ export async function getSignalReports( url.searchParams.set("suggested_reviewers", params.suggested_reviewers); } - const response = await fetch(url.toString(), { headers }); + const response = await authedFetch(url.toString()); if (!response.ok) { throw new HttpError( @@ -69,11 +67,9 @@ export async function getSignalReport( ): Promise { const baseUrl = getBaseUrl(); const projectId = getProjectId(); - const headers = getHeaders(); - const response = await fetch( + const response = await authedFetch( `${baseUrl}/api/projects/${projectId}/signals/reports/${reportId}/`, - { headers }, ); if (response.status === 404 || response.status === 403) { @@ -94,11 +90,9 @@ export async function getSignalReport( export async function getSignalProcessingState(): Promise { const baseUrl = getBaseUrl(); const projectId = getProjectId(); - const headers = getHeaders(); - const response = await fetch( + const response = await authedFetch( `${baseUrl}/api/projects/${projectId}/signals/processing_state/`, - { headers }, ); if (!response.ok) { @@ -117,7 +111,6 @@ export async function getAvailableSuggestedReviewers( ): Promise { const baseUrl = getBaseUrl(); const projectId = getProjectId(); - const headers = getHeaders(); const url = new URL( `${baseUrl}/api/projects/${projectId}/signals/reports/available_reviewers/`, @@ -127,7 +120,7 @@ export async function getAvailableSuggestedReviewers( url.searchParams.set("query", query.trim()); } - const response = await fetch(url.toString(), { headers }); + const response = await authedFetch(url.toString()); if (!response.ok) { throw new HttpError( @@ -160,11 +153,9 @@ export async function getSignalReportTasks( ): Promise { const baseUrl = getBaseUrl(); const projectId = getProjectId(); - const headers = getHeaders(); - const response = await fetch( + const response = await authedFetch( `${baseUrl}/api/projects/${projectId}/signals/reports/${reportId}/tasks/`, - { headers }, ); if (!response.ok) { @@ -184,11 +175,9 @@ export async function getSignalReportArtefacts( ): Promise { const baseUrl = getBaseUrl(); const projectId = getProjectId(); - const headers = getHeaders(); - const response = await fetch( + const response = await authedFetch( `${baseUrl}/api/projects/${projectId}/signals/reports/${reportId}/artefacts/`, - { headers }, ); if (!response.ok) { @@ -211,11 +200,9 @@ export async function getSignalReportSignals( ): Promise { const baseUrl = getBaseUrl(); const projectId = getProjectId(); - const headers = getHeaders(); - const response = await fetch( + const response = await authedFetch( `${baseUrl}/api/projects/${projectId}/signals/reports/${reportId}/signals/`, - { headers }, ); if (!response.ok) { @@ -268,13 +255,11 @@ export async function dismissSignalReport( ): Promise { const baseUrl = getBaseUrl(); const projectId = getProjectId(); - const headers = getHeaders(); - const response = await fetch( + const response = await authedFetch( `${baseUrl}/api/projects/${projectId}/signals/reports/${reportId}/state/`, { method: "POST", - headers: { ...headers, "Content-Type": "application/json" }, body: JSON.stringify({ state: "suppressed", dismissal_reason: input.reason, diff --git a/apps/mobile/src/features/mcp/api.ts b/apps/mobile/src/features/mcp/api.ts index ee783827b5..f7fa0a828c 100644 --- a/apps/mobile/src/features/mcp/api.ts +++ b/apps/mobile/src/features/mcp/api.ts @@ -1,5 +1,4 @@ -import { fetch } from "expo/fetch"; -import { getBaseUrl, getHeaders, getProjectId } from "@/lib/api"; +import { authedFetch, getBaseUrl, getProjectId } from "@/lib/api"; import type { InstallCustomMcpServerOptions, InstallMcpTemplateOptions, @@ -37,9 +36,8 @@ export async function getMcpRecommendedServers(): Promise< > { const base = getBaseUrl(); const projectId = getProjectId(); - const response = await fetch( + const response = await authedFetch( `${base}/api/environments/${projectId}/mcp_servers/`, - { headers: getHeaders() }, ); const data = await readJsonOrThrow< McpRecommendedServer[] | { results?: McpRecommendedServer[] } @@ -51,7 +49,7 @@ export async function getMcpRecommendedServers(): Promise< export async function getMcpServerInstallations(): Promise< McpServerInstallation[] > { - const response = await fetch(`${mcpBaseUrl()}/`, { headers: getHeaders() }); + const response = await authedFetch(`${mcpBaseUrl()}/`); const data = await readJsonOrThrow< McpServerInstallation[] | { results?: McpServerInstallation[] } >(response, "Failed to fetch MCP server installations"); @@ -62,9 +60,8 @@ export async function getMcpServerInstallations(): Promise< export async function installCustomMcpServer( options: InstallCustomMcpServerOptions, ): Promise { - const response = await fetch(`${mcpBaseUrl()}/install_custom/`, { + const response = await authedFetch(`${mcpBaseUrl()}/install_custom/`, { method: "POST", - headers: getHeaders(), body: JSON.stringify(options), }); return readJsonOrThrow( @@ -77,9 +74,8 @@ export async function installCustomMcpServer( export async function installMcpTemplate( options: InstallMcpTemplateOptions, ): Promise { - const response = await fetch(`${mcpBaseUrl()}/install_template/`, { + const response = await authedFetch(`${mcpBaseUrl()}/install_template/`, { method: "POST", - headers: getHeaders(), body: JSON.stringify(options), }); return readJsonOrThrow( @@ -93,9 +89,8 @@ export async function updateMcpServerInstallation( installationId: string, updates: UpdateMcpServerInstallationOptions, ): Promise { - const response = await fetch(`${mcpBaseUrl()}/${installationId}/`, { + const response = await authedFetch(`${mcpBaseUrl()}/${installationId}/`, { method: "PATCH", - headers: getHeaders(), body: JSON.stringify(updates), }); return readJsonOrThrow( @@ -108,9 +103,8 @@ export async function updateMcpServerInstallation( export async function uninstallMcpServer( installationId: string, ): Promise { - const response = await fetch(`${mcpBaseUrl()}/${installationId}/`, { + const response = await authedFetch(`${mcpBaseUrl()}/${installationId}/`, { method: "DELETE", - headers: getHeaders(), }); if (!response.ok && response.status !== 204) { throw new Error(`Failed to uninstall MCP server: ${response.statusText}`); @@ -131,9 +125,8 @@ export async function authorizeMcpInstallation(options: { if (options.posthog_code_callback_url) { params.set("posthog_code_callback_url", options.posthog_code_callback_url); } - const response = await fetch( + const response = await authedFetch( `${mcpBaseUrl()}/authorize/?${params.toString()}`, - { headers: getHeaders() }, ); return readJsonOrThrow( response, @@ -149,9 +142,8 @@ export async function getMcpInstallationTools( const params = new URLSearchParams(); if (options.includeRemoved) params.set("include_removed", "1"); const query = params.toString(); - const response = await fetch( + const response = await authedFetch( `${mcpBaseUrl()}/${installationId}/tools/${query ? `?${query}` : ""}`, - { headers: getHeaders() }, ); const data = await readJsonOrThrow< McpInstallationTool[] | { results?: McpInstallationTool[] } @@ -165,11 +157,10 @@ export async function updateMcpToolApproval( toolName: string, approval_state: McpApprovalState, ): Promise { - const response = await fetch( + const response = await authedFetch( `${mcpBaseUrl()}/${installationId}/tools/${encodeURIComponent(toolName)}/`, { method: "PATCH", - headers: getHeaders(), body: JSON.stringify({ approval_state }), }, ); @@ -183,9 +174,9 @@ export async function updateMcpToolApproval( export async function refreshMcpInstallationTools( installationId: string, ): Promise { - const response = await fetch( + const response = await authedFetch( `${mcpBaseUrl()}/${installationId}/tools/refresh/`, - { method: "POST", headers: getHeaders() }, + { method: "POST" }, ); const data = await readJsonOrThrow< McpInstallationTool[] | { results?: McpInstallationTool[] } diff --git a/apps/mobile/src/features/tasks/api.automations.test.ts b/apps/mobile/src/features/tasks/api.automations.test.ts index ea9a06834d..c4390a5d14 100644 --- a/apps/mobile/src/features/tasks/api.automations.test.ts +++ b/apps/mobile/src/features/tasks/api.automations.test.ts @@ -10,11 +10,16 @@ vi.mock("expo/fetch", () => ({ vi.mock("@/lib/api", () => ({ getBaseUrl: () => "https://app.posthog.test", - getHeaders: () => ({ - Authorization: "Bearer token", - "Content-Type": "application/json", - }), getProjectId: () => 42, + authedFetch: (url: string, init?: RequestInit) => + mockFetch(url, { + ...init, + headers: { + Authorization: "Bearer token", + "Content-Type": "application/json", + ...((init?.headers as Record | undefined) ?? {}), + }, + }), })); import { diff --git a/apps/mobile/src/features/tasks/api.ts b/apps/mobile/src/features/tasks/api.ts index d09d2afc50..aa6e2519b7 100644 --- a/apps/mobile/src/features/tasks/api.ts +++ b/apps/mobile/src/features/tasks/api.ts @@ -1,9 +1,9 @@ import { fetch } from "expo/fetch"; import { + authedFetch, createTimeoutSignal, getAccessToken, getBaseUrl, - getHeaders, getProjectId, } from "@/lib/api"; import { logger } from "@/lib/logger"; @@ -127,7 +127,6 @@ export async function getTasks(filters?: { }): Promise { const baseUrl = getBaseUrl(); const projectId = getProjectId(); - const headers = getHeaders(); const params = new URLSearchParams({ limit: "500" }); if (filters?.repository) { @@ -140,9 +139,8 @@ export async function getTasks(filters?: { params.set("origin_product", filters.originProduct); } - const response = await fetch( + const response = await authedFetch( `${baseUrl}/api/projects/${projectId}/tasks/?${params}`, - { headers }, ); if (!response.ok) { @@ -160,11 +158,9 @@ export async function getTasks(filters?: { export async function getTask(taskId: string): Promise { const baseUrl = getBaseUrl(); const projectId = getProjectId(); - const headers = getHeaders(); - const response = await fetch( + const response = await authedFetch( `${baseUrl}/api/projects/${projectId}/tasks/${taskId}/`, - { headers }, ); if (!response.ok) { @@ -181,11 +177,9 @@ export async function getTask(taskId: string): Promise { export async function getTaskAutomations(): Promise { const baseUrl = getBaseUrl(); const projectId = getProjectId(); - const headers = getHeaders(); - const response = await fetch( + const response = await authedFetch( `${baseUrl}/api/projects/${projectId}/task_automations/?limit=500`, - { headers }, ); if (!response.ok) { @@ -207,11 +201,9 @@ export async function getTaskAutomation( ): Promise { const baseUrl = getBaseUrl(); const projectId = getProjectId(); - const headers = getHeaders(); - const response = await fetch( + const response = await authedFetch( `${baseUrl}/api/projects/${projectId}/task_automations/${automationId}/`, - { headers }, ); if (!response.ok) { @@ -230,13 +222,11 @@ export async function createTaskAutomation( ): Promise { const baseUrl = getBaseUrl(); const projectId = getProjectId(); - const headers = getHeaders(); - const response = await fetch( + const response = await authedFetch( `${baseUrl}/api/projects/${projectId}/task_automations/`, { method: "POST", - headers, body: JSON.stringify(options), }, ); @@ -254,13 +244,11 @@ export async function updateTaskAutomation( ): Promise { const baseUrl = getBaseUrl(); const projectId = getProjectId(); - const headers = getHeaders(); - const response = await fetch( + const response = await authedFetch( `${baseUrl}/api/projects/${projectId}/task_automations/${automationId}/`, { method: "PATCH", - headers, body: JSON.stringify(updates), }, ); @@ -277,14 +265,10 @@ export async function deleteTaskAutomation( ): Promise { const baseUrl = getBaseUrl(); const projectId = getProjectId(); - const headers = getHeaders(); - const response = await fetch( + const response = await authedFetch( `${baseUrl}/api/projects/${projectId}/task_automations/${automationId}/`, - { - method: "DELETE", - headers, - }, + { method: "DELETE" }, ); if (!response.ok) { @@ -301,14 +285,10 @@ export async function runTaskAutomation( ): Promise { const baseUrl = getBaseUrl(); const projectId = getProjectId(); - const headers = getHeaders(); - const response = await fetch( + const response = await authedFetch( `${baseUrl}/api/projects/${projectId}/task_automations/${automationId}/run/`, - { - method: "POST", - headers, - }, + { method: "POST" }, ); if (!response.ok) { @@ -325,16 +305,17 @@ export async function runTaskAutomation( export async function createTask(options: CreateTaskOptions): Promise { const baseUrl = getBaseUrl(); const projectId = getProjectId(); - const headers = getHeaders(); - const response = await fetch(`${baseUrl}/api/projects/${projectId}/tasks/`, { - method: "POST", - headers, - body: JSON.stringify({ - origin_product: "user_created", - ...options, - }), - }); + const response = await authedFetch( + `${baseUrl}/api/projects/${projectId}/tasks/`, + { + method: "POST", + body: JSON.stringify({ + origin_product: "user_created", + ...options, + }), + }, + ); if (!response.ok) { const errorText = await response.text(); @@ -355,13 +336,11 @@ export async function updateTask( ): Promise { const baseUrl = getBaseUrl(); const projectId = getProjectId(); - const headers = getHeaders(); - const response = await fetch( + const response = await authedFetch( `${baseUrl}/api/projects/${projectId}/tasks/${taskId}/`, { method: "PATCH", - headers, body: JSON.stringify(updates), }, ); @@ -380,14 +359,10 @@ export async function updateTask( export async function deleteTask(taskId: string): Promise { const baseUrl = getBaseUrl(); const projectId = getProjectId(); - const headers = getHeaders(); - const response = await fetch( + const response = await authedFetch( `${baseUrl}/api/projects/${projectId}/tasks/${taskId}/`, - { - method: "DELETE", - headers, - }, + { method: "DELETE" }, ); if (!response.ok) { @@ -424,7 +399,6 @@ export async function runTaskInCloud( ): Promise { const baseUrl = getBaseUrl(); const projectId = getProjectId(); - const headers = getHeaders(); // Only serialize a body when we have options to send. Sending an empty // or minimal body on the initial run historically changed backend @@ -470,11 +444,10 @@ export async function runTaskInCloud( body = JSON.stringify(payload); } - const response = await fetch( + const response = await authedFetch( `${baseUrl}/api/projects/${projectId}/tasks/${taskId}/run/`, { method: "POST", - headers, body, }, ); @@ -496,11 +469,9 @@ export async function getTaskRun( ): Promise { const baseUrl = getBaseUrl(); const projectId = getProjectId(); - const headers = getHeaders(); - const response = await fetch( + const response = await authedFetch( `${baseUrl}/api/projects/${projectId}/tasks/${taskId}/runs/${runId}/`, - { headers }, ); if (!response.ok) { @@ -523,13 +494,11 @@ export async function appendTaskRunLog( async () => { const baseUrl = getBaseUrl(); const projectId = getProjectId(); - const headers = getHeaders(); - const response = await fetch( + const response = await authedFetch( `${baseUrl}/api/projects/${projectId}/tasks/${taskId}/runs/${runId}/append_log/`, { method: "POST", - headers, body: JSON.stringify({ entries }), }, ); @@ -593,7 +562,6 @@ export async function sendCloudCommand( ): Promise { const baseUrl = getBaseUrl(); const projectId = getProjectId(); - const headers = getHeaders(); const body = { jsonrpc: "2.0", @@ -602,11 +570,10 @@ export async function sendCloudCommand( id: `posthog-mobile-${Date.now()}`, }; - const response = await fetch( + const response = await authedFetch( `${baseUrl}/api/projects/${projectId}/tasks/${taskId}/runs/${runId}/command/`, { method: "POST", - headers, body: JSON.stringify(body), }, ); @@ -661,16 +628,15 @@ export async function fetchSessionLogs( async () => { const baseUrl = getBaseUrl(); const projectId = getProjectId(); - const headers = getHeaders(); const params = new URLSearchParams({ limit: String(options.limit ?? 5000), offset: String(options.offset ?? 0), }); - const response = await fetch( + const response = await authedFetch( `${baseUrl}/api/projects/${projectId}/tasks/${taskId}/runs/${runId}/session_logs/?${params}`, - { headers, signal: createTimeoutSignal(10_000) }, + { signal: createTimeoutSignal(10_000) }, ); if (!response.ok) { @@ -731,11 +697,9 @@ export async function streamCloudTask( export async function getIntegrations(): Promise { const baseUrl = getBaseUrl(); const projectId = getProjectId(); - const headers = getHeaders(); - const response = await fetch( + const response = await authedFetch( `${baseUrl}/api/environments/${projectId}/integrations/`, - { headers }, ); if (!response.ok) { @@ -762,7 +726,6 @@ export async function getGithubRepositories( ): Promise { const baseUrl = getBaseUrl(); const projectId = getProjectId(); - const headers = getHeaders(); const allRepos: string[] = []; let offset = 0; @@ -772,9 +735,8 @@ export async function getGithubRepositories( limit: String(GITHUB_REPOS_PAGE_SIZE), offset: String(offset), }); - const response = await fetch( + const response = await authedFetch( `${baseUrl}/api/environments/${projectId}/integrations/${integrationId}/github_repos/?${params}`, - { headers }, ); if (!response.ok) { @@ -822,13 +784,11 @@ export interface GithubUserConnectResult { export async function startGithubUserIntegrationConnect(): Promise { const baseUrl = getBaseUrl(); const projectId = getProjectId(); - const headers = getHeaders(); - const response = await fetch( + const response = await authedFetch( `${baseUrl}/api/users/@me/integrations/github/start/`, { method: "POST", - headers, body: JSON.stringify({ team_id: projectId, connect_from: "posthog_mobile", @@ -854,11 +814,8 @@ export async function getUserGithubIntegrations(): Promise< UserGithubIntegration[] > { const baseUrl = getBaseUrl(); - const headers = getHeaders(); - const response = await fetch(`${baseUrl}/api/users/@me/integrations/`, { - headers, - }); + const response = await authedFetch(`${baseUrl}/api/users/@me/integrations/`); if (!response.ok) { throw new HttpError( @@ -878,7 +835,6 @@ export async function getUserGithubRepositories( installationId: string, ): Promise { const baseUrl = getBaseUrl(); - const headers = getHeaders(); const allRepos: string[] = []; let offset = 0; @@ -888,9 +844,8 @@ export async function getUserGithubRepositories( limit: String(GITHUB_REPOS_PAGE_SIZE), offset: String(offset), }); - const response = await fetch( + const response = await authedFetch( `${baseUrl}/api/users/@me/integrations/github/${installationId}/repos/?${params}`, - { headers }, ); if (!response.ok) { diff --git a/apps/mobile/src/features/tasks/skills/api.test.ts b/apps/mobile/src/features/tasks/skills/api.test.ts index 6c1338ae5a..9c6a13f2a2 100644 --- a/apps/mobile/src/features/tasks/skills/api.test.ts +++ b/apps/mobile/src/features/tasks/skills/api.test.ts @@ -10,11 +10,16 @@ vi.mock("expo/fetch", () => ({ vi.mock("@/lib/api", () => ({ getBaseUrl: () => "https://app.posthog.test", - getHeaders: () => ({ - Authorization: "Bearer token", - "Content-Type": "application/json", - }), getProjectId: () => 42, + authedFetch: (url: string, init?: RequestInit) => + mockFetch(url, { + ...init, + headers: { + Authorization: "Bearer token", + "Content-Type": "application/json", + ...((init?.headers as Record | undefined) ?? {}), + }, + }), })); import { getSkillStoreSkill, getSkillStoreSkills } from "./api"; diff --git a/apps/mobile/src/features/tasks/skills/api.ts b/apps/mobile/src/features/tasks/skills/api.ts index 0fc2e1b02e..46584a4241 100644 --- a/apps/mobile/src/features/tasks/skills/api.ts +++ b/apps/mobile/src/features/tasks/skills/api.ts @@ -1,5 +1,4 @@ -import { fetch } from "expo/fetch"; -import { getBaseUrl, getHeaders, getProjectId } from "@/lib/api"; +import { authedFetch, getBaseUrl, getProjectId } from "@/lib/api"; import type { SkillStoreListEntry, SkillStoreSkill } from "./types"; function skillStoreBaseUrl(): string { @@ -23,9 +22,7 @@ async function readJsonOrThrow( } export async function getSkillStoreSkills(): Promise { - const response = await fetch(`${skillStoreBaseUrl()}/`, { - headers: getHeaders(), - }); + const response = await authedFetch(`${skillStoreBaseUrl()}/`); const data = await readJsonOrThrow< SkillStoreListEntry[] | { results?: SkillStoreListEntry[] } @@ -37,11 +34,8 @@ export async function getSkillStoreSkills(): Promise { export async function getSkillStoreSkill( skillName: string, ): Promise { - const response = await fetch( + const response = await authedFetch( `${skillStoreBaseUrl()}/name/${encodeURIComponent(skillName)}/`, - { - headers: getHeaders(), - }, ); return readJsonOrThrow(response, "Failed to fetch skill"); diff --git a/apps/mobile/src/lib/api.test.ts b/apps/mobile/src/lib/api.test.ts new file mode 100644 index 0000000000..5cd7823ff5 --- /dev/null +++ b/apps/mobile/src/lib/api.test.ts @@ -0,0 +1,216 @@ +import { beforeEach, describe, expect, it, vi } from "vitest"; + +const { mockFetch, mockRefreshAccessToken, mockGetState } = vi.hoisted(() => ({ + mockFetch: vi.fn(), + mockRefreshAccessToken: vi.fn(), + mockGetState: vi.fn(), +})); + +vi.mock("expo/fetch", () => ({ + fetch: mockFetch, +})); + +vi.mock("expo-constants", () => ({ + default: { expoConfig: { version: "0.0.0-test" } }, +})); + +vi.mock("@/features/auth", () => ({ + useAuthStore: { + getState: mockGetState, + }, +})); + +import { authedFetch } from "./api"; + +const url = "https://app.posthog.test/api/projects/1/tasks/"; +const ok = (data: unknown = {}) => + new Response(JSON.stringify(data), { + status: 200, + headers: { "Content-Type": "application/json" }, + }); +const err = (status: number, body: unknown = { error: status }) => + new Response(JSON.stringify(body), { + status, + statusText: `Error ${status}`, + headers: { "Content-Type": "application/json" }, + }); + +function setupTokens(initial = "old-token", refreshed = "new-token") { + let current = initial; + mockRefreshAccessToken.mockImplementation(async () => { + current = refreshed; + }); + mockGetState.mockImplementation(() => ({ + oauthAccessToken: current, + refreshAccessToken: mockRefreshAccessToken, + })); +} + +describe("authedFetch", () => { + beforeEach(() => { + mockFetch.mockReset(); + mockRefreshAccessToken.mockReset(); + mockGetState.mockReset(); + }); + + it("attaches the bearer token from the auth store", async () => { + setupTokens("my-token"); + mockFetch.mockResolvedValueOnce(ok()); + + await authedFetch(url); + + expect(mockFetch).toHaveBeenCalledTimes(1); + expect(mockRefreshAccessToken).not.toHaveBeenCalled(); + expect(mockFetch.mock.calls[0][1].headers.Authorization).toBe( + "Bearer my-token", + ); + }); + + it.each([ + { + name: "401", + failure: () => err(401), + }, + { + name: "403 with authentication_failed body", + failure: () => + err(403, { + type: "authentication_error", + code: "authentication_failed", + detail: "Invalid access token.", + }), + }, + ])( + "retries once with a freshly fetched token on $name", + async ({ failure }) => { + setupTokens("old-token", "new-token"); + mockFetch.mockResolvedValueOnce(failure()).mockResolvedValueOnce(ok()); + + const response = await authedFetch(url); + + expect(response.ok).toBe(true); + expect(mockRefreshAccessToken).toHaveBeenCalledTimes(1); + expect(mockFetch).toHaveBeenCalledTimes(2); + expect(mockFetch.mock.calls[0][1].headers.Authorization).toBe( + "Bearer old-token", + ); + expect(mockFetch.mock.calls[1][1].headers.Authorization).toBe( + "Bearer new-token", + ); + }, + ); + + it.each([ + { + name: "403 without authentication_failed body", + response: () => err(403, { detail: "Permission denied." }), + expectedStatus: 403, + }, + { + name: "400 bad request", + response: () => err(400, { detail: "Bad request." }), + expectedStatus: 400, + }, + ])( + "does not retry on $name", + async ({ response: makeResponse, expectedStatus }) => { + setupTokens("token"); + mockFetch.mockResolvedValueOnce(makeResponse()); + + const response = await authedFetch(url); + + expect(response.status).toBe(expectedStatus); + expect(mockRefreshAccessToken).not.toHaveBeenCalled(); + expect(mockFetch).toHaveBeenCalledTimes(1); + }, + ); + + it("returns the failed response when the retry still 401s", async () => { + setupTokens("token-1", "token-2"); + mockFetch.mockResolvedValueOnce(err(401)).mockResolvedValueOnce(err(401)); + + const response = await authedFetch(url); + + expect(response.status).toBe(401); + expect(mockRefreshAccessToken).toHaveBeenCalledTimes(1); + expect(mockFetch).toHaveBeenCalledTimes(2); + }); + + it("falls through with the original 401 when token refresh itself fails", async () => { + mockGetState.mockReturnValue({ + oauthAccessToken: "token", + refreshAccessToken: mockRefreshAccessToken, + }); + mockRefreshAccessToken.mockRejectedValueOnce(new Error("refresh failed")); + mockFetch.mockResolvedValueOnce(err(401)); + + const response = await authedFetch(url); + + expect(response.status).toBe(401); + expect(mockRefreshAccessToken).toHaveBeenCalledTimes(1); + expect(mockFetch).toHaveBeenCalledTimes(1); + }); + + it("propagates network errors from the underlying fetch", async () => { + setupTokens("token"); + mockFetch.mockRejectedValueOnce(new Error("Network failure")); + + await expect(authedFetch(url)).rejects.toThrow("Network failure"); + }); + + it("merges caller-provided headers with the auth headers", async () => { + setupTokens("my-token"); + mockFetch.mockResolvedValueOnce(ok()); + + await authedFetch(url, { + method: "POST", + headers: { "X-Custom": "value" }, + body: "{}", + }); + + const init = mockFetch.mock.calls[0][1]; + expect(init.method).toBe("POST"); + expect(init.body).toBe("{}"); + expect(init.headers.Authorization).toBe("Bearer my-token"); + expect(init.headers["X-Custom"]).toBe("value"); + }); + + it("dedups concurrent refreshes so only one fires on a 401 stampede", async () => { + let current = "old-token"; + let resolveRefresh: () => void = () => {}; + const refreshPromise = new Promise((resolve) => { + resolveRefresh = () => { + current = "new-token"; + resolve(); + }; + }); + mockRefreshAccessToken.mockImplementation(() => refreshPromise); + mockGetState.mockImplementation(() => ({ + oauthAccessToken: current, + refreshAccessToken: mockRefreshAccessToken, + })); + + mockFetch + .mockResolvedValueOnce(err(401)) + .mockResolvedValueOnce(err(401)) + .mockResolvedValueOnce(err(401)) + .mockResolvedValueOnce(ok({ n: 1 })) + .mockResolvedValueOnce(ok({ n: 2 })) + .mockResolvedValueOnce(ok({ n: 3 })); + + const pending = Promise.all([ + authedFetch(url), + authedFetch(url), + authedFetch(url), + ]); + + // Drain microtasks until all three callers have parked on the shared + // refresh, then release it and let the retries complete. + for (let i = 0; i < 20; i++) await Promise.resolve(); + resolveRefresh(); + const responses = await pending; + + expect(responses.every((r) => r.ok)).toBe(true); + expect(mockRefreshAccessToken).toHaveBeenCalledTimes(1); + }); +}); diff --git a/apps/mobile/src/lib/api.ts b/apps/mobile/src/lib/api.ts index 0aaa6d5dc7..58603788d3 100644 --- a/apps/mobile/src/lib/api.ts +++ b/apps/mobile/src/lib/api.ts @@ -3,6 +3,10 @@ import Constants from "expo-constants"; import { useAuthStore } from "@/features/auth"; import { logger } from "@/lib/logger"; +// Derive the init shape directly from expo/fetch so we don't import from +// expo's internal build output (which can move between versions). +type FetchInit = NonNullable[1]>; + const log = logger.scope("api"); const USER_AGENT = `posthog/mobile.hog.dev; version: ${Constants.expoConfig?.version ?? "unknown"}`; @@ -57,19 +61,112 @@ export function createTimeoutSignal(ms: number): AbortSignal { return controller.signal; } +// Concurrent 401s would otherwise stampede the refresh endpoint and have the +// in-flight responses invalidate each other's new tokens. Share a single +// pending refresh across all callers and reset it once it settles. +let pendingRefresh: Promise | null = null; + +async function refreshAccessTokenOnce(): Promise { + if (pendingRefresh) return pendingRefresh; + const promise = useAuthStore + .getState() + .refreshAccessToken() + .finally(() => { + if (pendingRefresh === promise) { + pendingRefresh = null; + } + }); + pendingRefresh = promise; + return promise; +} + +async function isAuthFailureResponse(response: Response): Promise { + if (response.status === 401) return true; + if (response.status !== 403) return false; + try { + const body = await response.clone().json(); + return ( + body?.code === "authentication_failed" || + body?.type === "authentication_error" + ); + } catch { + return false; + } +} + +function mergeHeaders( + base: Record, + override: HeadersInit | undefined, +): Record { + if (!override) return base; + const merged: Record = { ...base }; + if (override instanceof Headers) { + override.forEach((value, key) => { + merged[key] = value; + }); + return merged; + } + if (Array.isArray(override)) { + for (const [key, value] of override) { + merged[key] = value; + } + return merged; + } + for (const [key, value] of Object.entries(override)) { + merged[key] = value; + } + return merged; +} + +/** + * `fetch` against the PostHog API with automatic token refresh on auth + * failure. On a 401 — or a 403 whose JSON body looks like an authentication + * failure (`code: "authentication_failed"` / `type: "authentication_error"`) — + * triggers a single shared token refresh and retries the request once. If the + * refresh itself fails, the original response is returned so callers fall + * through to their existing error-handling and sign-out flows. + * + * Mirrors the desktop fetcher's retry semantics + * (apps/code/src/renderer/api/fetcher.ts). + */ +export async function authedFetch( + url: string, + init: FetchInit = {}, +): Promise { + const headers = mergeHeaders(getHeaders(), init.headers); + let response: Response = await fetch(url, { ...init, headers }); + + if (response.ok || !(await isAuthFailureResponse(response))) { + return response; + } + + try { + await refreshAccessTokenOnce(); + } catch (err) { + log.warn("Token refresh on auth failure failed", { + url, + status: response.status, + error: err instanceof Error ? err.message : String(err), + }); + return response; + } + + const retryHeaders = mergeHeaders(getHeaders(), init.headers); + response = await fetch(url, { ...init, headers: retryHeaders }); + return response; +} + export async function registerPushToken(args: { token: string; platform: string; }): Promise { const baseUrl = getBaseUrl(); - const headers = getHeaders(); // Push tokens are per-user, not per-project — endpoint lives under // /api/users/@me/ alongside the other user-scoped APIs. const url = `${baseUrl}/api/users/@me/push_tokens/`; - const response = await fetch(url, { + const response = await authedFetch(url, { method: "POST", - headers, body: JSON.stringify(args), }); @@ -89,15 +186,13 @@ export async function registerPushToken(args: { export async function deletePushToken(args: { token: string }): Promise { const baseUrl = getBaseUrl(); - const headers = getHeaders(); // Unregister is a POST sub-action (not DELETE) because some clients and // proxies strip request bodies on DELETE. - const response = await fetch( + const response = await authedFetch( `${baseUrl}/api/users/@me/push_tokens/unregister/`, { method: "POST", - headers, body: JSON.stringify(args), }, );