Add model switching and model list commands to pi-daemon

This commit is contained in:
Advait Paliwal 2026-03-05 17:29:41 -08:00
parent c46fb9bb16
commit 4e7ae2dad6

View file

@ -12,6 +12,7 @@ const PORT = Number(process.env.PORT ?? 4567);
const CWD = process.argv[2] || process.cwd();
let session: AgentSession | null = null;
let modelRegistry: ModelRegistry | null = null;
const clients = new Set<WebSocket>();
function broadcast(data: unknown) {
@ -23,9 +24,50 @@ function broadcast(data: unknown) {
}
}
function getCurrentModel(): string | null {
if (!session?.model) return null;
return `${session.model.provider}/${session.model.id}`;
}
function getAvailableModels(refresh = false) {
if (!modelRegistry) return [];
if (refresh) modelRegistry.refresh();
return modelRegistry.getAvailable().map((model) => ({
provider: model.provider,
id: model.id,
name: model.name,
api: model.api,
}));
}
function sendConnectionPayload(ws: WebSocket) {
ws.send(
JSON.stringify({
type: "connected",
cwd: CWD,
model: getCurrentModel(),
modelName: session?.model?.name ?? null,
isStreaming: session?.isStreaming ?? false,
models: getAvailableModels(),
}),
);
}
function sendModelsPayload(ws: WebSocket, refresh = false) {
ws.send(
JSON.stringify({
type: "models",
model: getCurrentModel(),
modelName: session?.model?.name ?? null,
models: getAvailableModels(refresh),
}),
);
}
async function initSession() {
const authStorage = AuthStorage.create();
const modelRegistry = new ModelRegistry(authStorage);
modelRegistry = new ModelRegistry(authStorage);
const result = await createAgentSession({
cwd: CWD,
@ -61,7 +103,8 @@ const httpServer = createServer((req, res) => {
JSON.stringify({
status: "ok",
cwd: CWD,
model: session?.model?.name ?? null,
model: getCurrentModel(),
modelName: session?.model?.name ?? null,
isStreaming: session?.isStreaming ?? false,
clients: clients.size,
}),
@ -79,14 +122,7 @@ wss.on("connection", (ws) => {
clients.add(ws);
console.log(`Client connected (${clients.size} total)`);
ws.send(
JSON.stringify({
type: "connected",
cwd: CWD,
model: session?.model?.name ?? null,
isStreaming: session?.isStreaming ?? false,
}),
);
sendConnectionPayload(ws);
ws.on("message", async (raw) => {
if (!session) {
@ -138,6 +174,68 @@ wss.on("connection", (ws) => {
break;
}
case "get_models": {
sendModelsPayload(ws, true);
break;
}
case "set_model": {
if (session.isStreaming) {
ws.send(
JSON.stringify({
type: "error",
error: "Cannot change model while a response is streaming.",
}),
);
break;
}
const modelKey = (msg.model as string | undefined)?.trim();
if (!modelKey) {
ws.send(JSON.stringify({ type: "error", error: "Missing model key." }));
break;
}
const split = modelKey.split("/");
if (split.length < 2) {
ws.send(
JSON.stringify({
type: "error",
error: `Invalid model key "${modelKey}". Expected provider/model.`,
}),
);
break;
}
const provider = split[0];
const modelId = split.slice(1).join("/");
const availableModels = getAvailableModels(true);
const nextModel = modelRegistry
?.getAvailable()
.find((model) => model.provider === provider && model.id === modelId);
if (!nextModel) {
ws.send(
JSON.stringify({
type: "error",
error: `Model "${modelKey}" is not available with current auth.`,
models: availableModels,
}),
);
sendModelsPayload(ws);
break;
}
await session.setModel(nextModel);
broadcast({
type: "model_changed",
model: `${nextModel.provider}/${nextModel.id}`,
modelName: nextModel.name,
});
break;
}
default:
ws.send(JSON.stringify({ type: "error", error: `Unknown command: ${msg.type}` }));
}