From 4e7ae2dad654a020e3831bcbd039e7d9867c139a Mon Sep 17 00:00:00 2001 From: Advait Paliwal Date: Thu, 5 Mar 2026 17:29:41 -0800 Subject: [PATCH] Add model switching and model list commands to pi-daemon --- packages/pi-daemon/src/main.ts | 118 ++++++++++++++++++++++++++++++--- 1 file changed, 108 insertions(+), 10 deletions(-) diff --git a/packages/pi-daemon/src/main.ts b/packages/pi-daemon/src/main.ts index a6c16d91..a71c6af2 100644 --- a/packages/pi-daemon/src/main.ts +++ b/packages/pi-daemon/src/main.ts @@ -12,6 +12,7 @@ const PORT = Number(process.env.PORT ?? 4567); const CWD = process.argv[2] || process.cwd(); let session: AgentSession | null = null; +let modelRegistry: ModelRegistry | null = null; const clients = new Set(); function broadcast(data: unknown) { @@ -23,9 +24,50 @@ function broadcast(data: unknown) { } } +function getCurrentModel(): string | null { + if (!session?.model) return null; + return `${session.model.provider}/${session.model.id}`; +} + +function getAvailableModels(refresh = false) { + if (!modelRegistry) return []; + if (refresh) modelRegistry.refresh(); + + return modelRegistry.getAvailable().map((model) => ({ + provider: model.provider, + id: model.id, + name: model.name, + api: model.api, + })); +} + +function sendConnectionPayload(ws: WebSocket) { + ws.send( + JSON.stringify({ + type: "connected", + cwd: CWD, + model: getCurrentModel(), + modelName: session?.model?.name ?? null, + isStreaming: session?.isStreaming ?? false, + models: getAvailableModels(), + }), + ); +} + +function sendModelsPayload(ws: WebSocket, refresh = false) { + ws.send( + JSON.stringify({ + type: "models", + model: getCurrentModel(), + modelName: session?.model?.name ?? null, + models: getAvailableModels(refresh), + }), + ); +} + async function initSession() { const authStorage = AuthStorage.create(); - const modelRegistry = new ModelRegistry(authStorage); + modelRegistry = new ModelRegistry(authStorage); const result = await createAgentSession({ cwd: CWD, @@ -61,7 +103,8 @@ const httpServer = createServer((req, res) => { JSON.stringify({ status: "ok", cwd: CWD, - model: session?.model?.name ?? null, + model: getCurrentModel(), + modelName: session?.model?.name ?? null, isStreaming: session?.isStreaming ?? false, clients: clients.size, }), @@ -79,14 +122,7 @@ wss.on("connection", (ws) => { clients.add(ws); console.log(`Client connected (${clients.size} total)`); - ws.send( - JSON.stringify({ - type: "connected", - cwd: CWD, - model: session?.model?.name ?? null, - isStreaming: session?.isStreaming ?? false, - }), - ); + sendConnectionPayload(ws); ws.on("message", async (raw) => { if (!session) { @@ -138,6 +174,68 @@ wss.on("connection", (ws) => { break; } + case "get_models": { + sendModelsPayload(ws, true); + break; + } + + case "set_model": { + if (session.isStreaming) { + ws.send( + JSON.stringify({ + type: "error", + error: "Cannot change model while a response is streaming.", + }), + ); + break; + } + + const modelKey = (msg.model as string | undefined)?.trim(); + if (!modelKey) { + ws.send(JSON.stringify({ type: "error", error: "Missing model key." })); + break; + } + + const split = modelKey.split("/"); + if (split.length < 2) { + ws.send( + JSON.stringify({ + type: "error", + error: `Invalid model key "${modelKey}". Expected provider/model.`, + }), + ); + break; + } + + const provider = split[0]; + const modelId = split.slice(1).join("/"); + const availableModels = getAvailableModels(true); + const nextModel = modelRegistry + ?.getAvailable() + .find((model) => model.provider === provider && model.id === modelId); + + if (!nextModel) { + ws.send( + JSON.stringify({ + type: "error", + error: `Model "${modelKey}" is not available with current auth.`, + models: availableModels, + }), + ); + sendModelsPayload(ws); + break; + } + + await session.setModel(nextModel); + + broadcast({ + type: "model_changed", + model: `${nextModel.provider}/${nextModel.id}`, + modelName: nextModel.name, + }); + break; + } + default: ws.send(JSON.stringify({ type: "error", error: `Unknown command: ${msg.type}` })); }