mirror of
https://github.com/getcompanion-ai/co-mono.git
synced 2026-04-18 07:01:30 +00:00
Add model switching and model list commands to pi-daemon
This commit is contained in:
parent
c46fb9bb16
commit
4e7ae2dad6
1 changed files with 108 additions and 10 deletions
|
|
@ -12,6 +12,7 @@ const PORT = Number(process.env.PORT ?? 4567);
|
||||||
const CWD = process.argv[2] || process.cwd();
|
const CWD = process.argv[2] || process.cwd();
|
||||||
|
|
||||||
let session: AgentSession | null = null;
|
let session: AgentSession | null = null;
|
||||||
|
let modelRegistry: ModelRegistry | null = null;
|
||||||
const clients = new Set<WebSocket>();
|
const clients = new Set<WebSocket>();
|
||||||
|
|
||||||
function broadcast(data: unknown) {
|
function broadcast(data: unknown) {
|
||||||
|
|
@ -23,9 +24,50 @@ function broadcast(data: unknown) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function getCurrentModel(): string | null {
|
||||||
|
if (!session?.model) return null;
|
||||||
|
return `${session.model.provider}/${session.model.id}`;
|
||||||
|
}
|
||||||
|
|
||||||
|
function getAvailableModels(refresh = false) {
|
||||||
|
if (!modelRegistry) return [];
|
||||||
|
if (refresh) modelRegistry.refresh();
|
||||||
|
|
||||||
|
return modelRegistry.getAvailable().map((model) => ({
|
||||||
|
provider: model.provider,
|
||||||
|
id: model.id,
|
||||||
|
name: model.name,
|
||||||
|
api: model.api,
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
function sendConnectionPayload(ws: WebSocket) {
|
||||||
|
ws.send(
|
||||||
|
JSON.stringify({
|
||||||
|
type: "connected",
|
||||||
|
cwd: CWD,
|
||||||
|
model: getCurrentModel(),
|
||||||
|
modelName: session?.model?.name ?? null,
|
||||||
|
isStreaming: session?.isStreaming ?? false,
|
||||||
|
models: getAvailableModels(),
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
function sendModelsPayload(ws: WebSocket, refresh = false) {
|
||||||
|
ws.send(
|
||||||
|
JSON.stringify({
|
||||||
|
type: "models",
|
||||||
|
model: getCurrentModel(),
|
||||||
|
modelName: session?.model?.name ?? null,
|
||||||
|
models: getAvailableModels(refresh),
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
async function initSession() {
|
async function initSession() {
|
||||||
const authStorage = AuthStorage.create();
|
const authStorage = AuthStorage.create();
|
||||||
const modelRegistry = new ModelRegistry(authStorage);
|
modelRegistry = new ModelRegistry(authStorage);
|
||||||
|
|
||||||
const result = await createAgentSession({
|
const result = await createAgentSession({
|
||||||
cwd: CWD,
|
cwd: CWD,
|
||||||
|
|
@ -61,7 +103,8 @@ const httpServer = createServer((req, res) => {
|
||||||
JSON.stringify({
|
JSON.stringify({
|
||||||
status: "ok",
|
status: "ok",
|
||||||
cwd: CWD,
|
cwd: CWD,
|
||||||
model: session?.model?.name ?? null,
|
model: getCurrentModel(),
|
||||||
|
modelName: session?.model?.name ?? null,
|
||||||
isStreaming: session?.isStreaming ?? false,
|
isStreaming: session?.isStreaming ?? false,
|
||||||
clients: clients.size,
|
clients: clients.size,
|
||||||
}),
|
}),
|
||||||
|
|
@ -79,14 +122,7 @@ wss.on("connection", (ws) => {
|
||||||
clients.add(ws);
|
clients.add(ws);
|
||||||
console.log(`Client connected (${clients.size} total)`);
|
console.log(`Client connected (${clients.size} total)`);
|
||||||
|
|
||||||
ws.send(
|
sendConnectionPayload(ws);
|
||||||
JSON.stringify({
|
|
||||||
type: "connected",
|
|
||||||
cwd: CWD,
|
|
||||||
model: session?.model?.name ?? null,
|
|
||||||
isStreaming: session?.isStreaming ?? false,
|
|
||||||
}),
|
|
||||||
);
|
|
||||||
|
|
||||||
ws.on("message", async (raw) => {
|
ws.on("message", async (raw) => {
|
||||||
if (!session) {
|
if (!session) {
|
||||||
|
|
@ -138,6 +174,68 @@ wss.on("connection", (ws) => {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
case "get_models": {
|
||||||
|
sendModelsPayload(ws, true);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
case "set_model": {
|
||||||
|
if (session.isStreaming) {
|
||||||
|
ws.send(
|
||||||
|
JSON.stringify({
|
||||||
|
type: "error",
|
||||||
|
error: "Cannot change model while a response is streaming.",
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
const modelKey = (msg.model as string | undefined)?.trim();
|
||||||
|
if (!modelKey) {
|
||||||
|
ws.send(JSON.stringify({ type: "error", error: "Missing model key." }));
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
const split = modelKey.split("/");
|
||||||
|
if (split.length < 2) {
|
||||||
|
ws.send(
|
||||||
|
JSON.stringify({
|
||||||
|
type: "error",
|
||||||
|
error: `Invalid model key "${modelKey}". Expected provider/model.`,
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
const provider = split[0];
|
||||||
|
const modelId = split.slice(1).join("/");
|
||||||
|
const availableModels = getAvailableModels(true);
|
||||||
|
const nextModel = modelRegistry
|
||||||
|
?.getAvailable()
|
||||||
|
.find((model) => model.provider === provider && model.id === modelId);
|
||||||
|
|
||||||
|
if (!nextModel) {
|
||||||
|
ws.send(
|
||||||
|
JSON.stringify({
|
||||||
|
type: "error",
|
||||||
|
error: `Model "${modelKey}" is not available with current auth.`,
|
||||||
|
models: availableModels,
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
sendModelsPayload(ws);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
await session.setModel(nextModel);
|
||||||
|
|
||||||
|
broadcast({
|
||||||
|
type: "model_changed",
|
||||||
|
model: `${nextModel.provider}/${nextModel.id}`,
|
||||||
|
modelName: nextModel.name,
|
||||||
|
});
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
default:
|
default:
|
||||||
ws.send(JSON.stringify({ type: "error", error: `Unknown command: ${msg.type}` }));
|
ws.send(JSON.stringify({ type: "error", error: `Unknown command: ${msg.type}` }));
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue