mirror of
https://github.com/getcompanion-ai/co-mono.git
synced 2026-04-17 08:00:59 +00:00
Add model selector TUI and update session management
This commit is contained in:
parent
bf5f4b17c0
commit
95d040195c
4 changed files with 355 additions and 23 deletions
|
|
@ -142,8 +142,8 @@ async function runInteractiveMode(agent: Agent, sessionManager: SessionManager,
|
||||||
try {
|
try {
|
||||||
await agent.prompt(userInput);
|
await agent.prompt(userInput);
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
// Error handling - errors should be in agent state
|
// Display error in the TUI by adding an error message to the chat
|
||||||
console.error("Error:", error.message);
|
renderer.showError(error.message || "Unknown error occurred");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -182,26 +182,55 @@ export async function main(args: string[]) {
|
||||||
const provider = (parsed.provider || "anthropic") as any;
|
const provider = (parsed.provider || "anthropic") as any;
|
||||||
const modelId = parsed.model || "claude-sonnet-4-5";
|
const modelId = parsed.model || "claude-sonnet-4-5";
|
||||||
|
|
||||||
// Get API key
|
// Helper function to get API key for a provider
|
||||||
let apiKey = parsed.apiKey;
|
const getApiKeyForProvider = (providerName: string): string | undefined => {
|
||||||
if (!apiKey) {
|
// Check if API key was provided via command line
|
||||||
const envVarMap: Record<string, string> = {
|
if (parsed.apiKey) {
|
||||||
google: "GEMINI_API_KEY",
|
return parsed.apiKey;
|
||||||
openai: "OPENAI_API_KEY",
|
|
||||||
anthropic: "ANTHROPIC_OAUTH_TOKEN",
|
|
||||||
xai: "XAI_API_KEY",
|
|
||||||
groq: "GROQ_API_KEY",
|
|
||||||
cerebras: "CEREBRAS_API_KEY",
|
|
||||||
zai: "ZAI_API_KEY",
|
|
||||||
};
|
|
||||||
const envVar = envVarMap[provider] || `${provider.toUpperCase()}_API_KEY`;
|
|
||||||
apiKey = process.env[envVar];
|
|
||||||
|
|
||||||
if (!apiKey) {
|
|
||||||
console.error(chalk.red(`Error: No API key found for provider "${provider}"`));
|
|
||||||
console.error(chalk.dim(`Set ${envVar} environment variable or use --api-key flag`));
|
|
||||||
process.exit(1);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const envVarMap: Record<string, string[]> = {
|
||||||
|
google: ["GEMINI_API_KEY"],
|
||||||
|
openai: ["OPENAI_API_KEY"],
|
||||||
|
anthropic: ["ANTHROPIC_OAUTH_TOKEN", "ANTHROPIC_API_KEY"],
|
||||||
|
xai: ["XAI_API_KEY"],
|
||||||
|
groq: ["GROQ_API_KEY"],
|
||||||
|
cerebras: ["CEREBRAS_API_KEY"],
|
||||||
|
zai: ["ZAI_API_KEY"],
|
||||||
|
moonshotai: ["MOONSHOT_API_KEY"],
|
||||||
|
};
|
||||||
|
|
||||||
|
const envVars = envVarMap[providerName] || [`${providerName.toUpperCase()}_API_KEY`];
|
||||||
|
|
||||||
|
// Check each environment variable in priority order
|
||||||
|
for (const envVar of envVars) {
|
||||||
|
const key = process.env[envVar];
|
||||||
|
if (key) {
|
||||||
|
return key;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return undefined;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Get initial API key
|
||||||
|
const initialApiKey = getApiKeyForProvider(provider);
|
||||||
|
if (!initialApiKey) {
|
||||||
|
const envVarMap: Record<string, string[]> = {
|
||||||
|
google: ["GEMINI_API_KEY"],
|
||||||
|
openai: ["OPENAI_API_KEY"],
|
||||||
|
anthropic: ["ANTHROPIC_OAUTH_TOKEN", "ANTHROPIC_API_KEY"],
|
||||||
|
xai: ["XAI_API_KEY"],
|
||||||
|
groq: ["GROQ_API_KEY"],
|
||||||
|
cerebras: ["CEREBRAS_API_KEY"],
|
||||||
|
zai: ["ZAI_API_KEY"],
|
||||||
|
moonshotai: ["MOONSHOT_API_KEY"],
|
||||||
|
};
|
||||||
|
const envVars = envVarMap[provider] || [`${provider.toUpperCase()}_API_KEY`];
|
||||||
|
const envVarList = envVars.join(" or ");
|
||||||
|
console.error(chalk.red(`Error: No API key found for provider "${provider}"`));
|
||||||
|
console.error(chalk.dim(`Set ${envVarList} environment variable or use --api-key flag`));
|
||||||
|
process.exit(1);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create agent
|
// Create agent
|
||||||
|
|
@ -216,7 +245,17 @@ export async function main(args: string[]) {
|
||||||
tools: codingTools,
|
tools: codingTools,
|
||||||
},
|
},
|
||||||
transport: new ProviderTransport({
|
transport: new ProviderTransport({
|
||||||
getApiKey: async () => apiKey!,
|
// Dynamic API key lookup based on current model's provider
|
||||||
|
getApiKey: async () => {
|
||||||
|
const currentProvider = agent.state.model.provider;
|
||||||
|
const key = getApiKeyForProvider(currentProvider);
|
||||||
|
if (!key) {
|
||||||
|
throw new Error(
|
||||||
|
`No API key found for provider "${currentProvider}". Please set the appropriate environment variable.`,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
return key;
|
||||||
|
},
|
||||||
}),
|
}),
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
@ -228,6 +267,22 @@ export async function main(args: string[]) {
|
||||||
agent.replaceMessages(messages);
|
agent.replaceMessages(messages);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Load and restore model
|
||||||
|
const savedModel = sessionManager.loadModel();
|
||||||
|
if (savedModel) {
|
||||||
|
// Parse provider/modelId from saved model string (format: "provider/modelId")
|
||||||
|
const [savedProvider, savedModelId] = savedModel.split("/");
|
||||||
|
if (savedProvider && savedModelId) {
|
||||||
|
try {
|
||||||
|
const restoredModel = getModel(savedProvider as any, savedModelId);
|
||||||
|
agent.setModel(restoredModel);
|
||||||
|
console.log(chalk.dim(`Restored model: ${savedModel}`));
|
||||||
|
} catch (error: any) {
|
||||||
|
console.error(chalk.yellow(`Warning: Could not restore model ${savedModel}: ${error.message}`));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Load and restore thinking level
|
// Load and restore thinking level
|
||||||
const thinkingLevel = sessionManager.loadThinkingLevel() as ThinkingLevel;
|
const thinkingLevel = sessionManager.loadThinkingLevel() as ThinkingLevel;
|
||||||
if (thinkingLevel) {
|
if (thinkingLevel) {
|
||||||
|
|
|
||||||
|
|
@ -212,6 +212,29 @@ export class SessionManager {
|
||||||
return lastThinkingLevel;
|
return lastThinkingLevel;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
loadModel(): string | null {
|
||||||
|
if (!existsSync(this.sessionFile)) return null;
|
||||||
|
|
||||||
|
const lines = readFileSync(this.sessionFile, "utf8").trim().split("\n");
|
||||||
|
|
||||||
|
// Find the most recent model (from session header or change event)
|
||||||
|
let lastModel: string | null = null;
|
||||||
|
for (const line of lines) {
|
||||||
|
try {
|
||||||
|
const entry = JSON.parse(line);
|
||||||
|
if (entry.type === "session" && entry.model) {
|
||||||
|
lastModel = entry.model;
|
||||||
|
} else if (entry.type === "model_change" && entry.model) {
|
||||||
|
lastModel = entry.model;
|
||||||
|
}
|
||||||
|
} catch {
|
||||||
|
// Skip malformed lines
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return lastModel;
|
||||||
|
}
|
||||||
|
|
||||||
getSessionId(): string {
|
getSessionId(): string {
|
||||||
return this.sessionId;
|
return this.sessionId;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
188
packages/coding-agent/src/tui/model-selector.ts
Normal file
188
packages/coding-agent/src/tui/model-selector.ts
Normal file
|
|
@ -0,0 +1,188 @@
|
||||||
|
import { getModels, getProviders, type Model } from "@mariozechner/pi-ai";
|
||||||
|
import { Container, Input, Spacer, Text } from "@mariozechner/pi-tui";
|
||||||
|
import chalk from "chalk";
|
||||||
|
|
||||||
|
interface ModelItem {
|
||||||
|
provider: string;
|
||||||
|
id: string;
|
||||||
|
model: Model<any>;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Component that renders a model selector with search
|
||||||
|
*/
|
||||||
|
export class ModelSelectorComponent extends Container {
|
||||||
|
private searchInput: Input;
|
||||||
|
private listContainer: Container;
|
||||||
|
private allModels: ModelItem[] = [];
|
||||||
|
private filteredModels: ModelItem[] = [];
|
||||||
|
private selectedIndex: number = 0;
|
||||||
|
private currentModel: Model<any>;
|
||||||
|
private onSelectCallback: (model: Model<any>) => void;
|
||||||
|
private onCancelCallback: () => void;
|
||||||
|
|
||||||
|
constructor(currentModel: Model<any>, onSelect: (model: Model<any>) => void, onCancel: () => void) {
|
||||||
|
super();
|
||||||
|
|
||||||
|
this.currentModel = currentModel;
|
||||||
|
this.onSelectCallback = onSelect;
|
||||||
|
this.onCancelCallback = onCancel;
|
||||||
|
|
||||||
|
// Load all models
|
||||||
|
this.loadModels();
|
||||||
|
|
||||||
|
// Add top border
|
||||||
|
this.addChild(new Text(chalk.blue("─".repeat(80)), 0, 0));
|
||||||
|
this.addChild(new Spacer(1));
|
||||||
|
|
||||||
|
// Create search input
|
||||||
|
this.searchInput = new Input();
|
||||||
|
this.searchInput.onSubmit = () => {
|
||||||
|
// Enter on search input selects the first filtered item
|
||||||
|
if (this.filteredModels[this.selectedIndex]) {
|
||||||
|
this.handleSelect(this.filteredModels[this.selectedIndex].model);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
this.addChild(this.searchInput);
|
||||||
|
|
||||||
|
this.addChild(new Spacer(1));
|
||||||
|
|
||||||
|
// Create list container
|
||||||
|
this.listContainer = new Container();
|
||||||
|
this.addChild(this.listContainer);
|
||||||
|
|
||||||
|
this.addChild(new Spacer(1));
|
||||||
|
|
||||||
|
// Add bottom border
|
||||||
|
this.addChild(new Text(chalk.blue("─".repeat(80)), 0, 0));
|
||||||
|
|
||||||
|
// Initial render
|
||||||
|
this.updateList();
|
||||||
|
}
|
||||||
|
|
||||||
|
private loadModels(): void {
|
||||||
|
const models: ModelItem[] = [];
|
||||||
|
const providers = getProviders();
|
||||||
|
|
||||||
|
for (const provider of providers) {
|
||||||
|
const providerModels = getModels(provider as any);
|
||||||
|
for (const model of providerModels) {
|
||||||
|
models.push({ provider, id: model.id, model });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sort: current model first, then by provider
|
||||||
|
models.sort((a, b) => {
|
||||||
|
const aIsCurrent = this.currentModel?.id === a.model.id;
|
||||||
|
const bIsCurrent = this.currentModel?.id === b.model.id;
|
||||||
|
if (aIsCurrent && !bIsCurrent) return -1;
|
||||||
|
if (!aIsCurrent && bIsCurrent) return 1;
|
||||||
|
return a.provider.localeCompare(b.provider);
|
||||||
|
});
|
||||||
|
|
||||||
|
this.allModels = models;
|
||||||
|
this.filteredModels = models;
|
||||||
|
}
|
||||||
|
|
||||||
|
private filterModels(query: string): void {
|
||||||
|
if (!query.trim()) {
|
||||||
|
this.filteredModels = this.allModels;
|
||||||
|
} else {
|
||||||
|
const searchTokens = query
|
||||||
|
.toLowerCase()
|
||||||
|
.split(/\s+/)
|
||||||
|
.filter((t) => t);
|
||||||
|
this.filteredModels = this.allModels.filter(({ provider, id, model }) => {
|
||||||
|
const searchText = `${provider} ${id} ${model.name}`.toLowerCase();
|
||||||
|
return searchTokens.every((token) => searchText.includes(token));
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
this.selectedIndex = Math.min(this.selectedIndex, Math.max(0, this.filteredModels.length - 1));
|
||||||
|
this.updateList();
|
||||||
|
}
|
||||||
|
|
||||||
|
private updateList(): void {
|
||||||
|
this.listContainer.clear();
|
||||||
|
|
||||||
|
const maxVisible = 10;
|
||||||
|
const startIndex = Math.max(
|
||||||
|
0,
|
||||||
|
Math.min(this.selectedIndex - Math.floor(maxVisible / 2), this.filteredModels.length - maxVisible),
|
||||||
|
);
|
||||||
|
const endIndex = Math.min(startIndex + maxVisible, this.filteredModels.length);
|
||||||
|
|
||||||
|
// Show visible slice of filtered models
|
||||||
|
for (let i = startIndex; i < endIndex; i++) {
|
||||||
|
const item = this.filteredModels[i];
|
||||||
|
if (!item) continue;
|
||||||
|
|
||||||
|
const isSelected = i === this.selectedIndex;
|
||||||
|
const isCurrent = this.currentModel?.id === item.model.id;
|
||||||
|
|
||||||
|
let line = "";
|
||||||
|
if (isSelected) {
|
||||||
|
const prefix = chalk.blue("→ ");
|
||||||
|
const modelText = `${item.id}`;
|
||||||
|
const providerBadge = chalk.gray(`[${item.provider}]`);
|
||||||
|
const checkmark = isCurrent ? chalk.green(" ✓") : "";
|
||||||
|
line = prefix + chalk.blue(modelText) + " " + providerBadge + checkmark;
|
||||||
|
} else {
|
||||||
|
const modelText = ` ${item.id}`;
|
||||||
|
const providerBadge = chalk.gray(`[${item.provider}]`);
|
||||||
|
const checkmark = isCurrent ? chalk.green(" ✓") : "";
|
||||||
|
line = modelText + " " + providerBadge + checkmark;
|
||||||
|
}
|
||||||
|
|
||||||
|
this.listContainer.addChild(new Text(line, 0, 0));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add scroll indicator if needed
|
||||||
|
if (startIndex > 0 || endIndex < this.filteredModels.length) {
|
||||||
|
const scrollInfo = chalk.gray(` (${this.selectedIndex + 1}/${this.filteredModels.length})`);
|
||||||
|
this.listContainer.addChild(new Text(scrollInfo, 0, 0));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Show "no results" if empty
|
||||||
|
if (this.filteredModels.length === 0) {
|
||||||
|
this.listContainer.addChild(new Text(chalk.gray(" No matching models"), 0, 0));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
handleInput(keyData: string): void {
|
||||||
|
// Up arrow
|
||||||
|
if (keyData === "\x1b[A") {
|
||||||
|
this.selectedIndex = Math.max(0, this.selectedIndex - 1);
|
||||||
|
this.updateList();
|
||||||
|
}
|
||||||
|
// Down arrow
|
||||||
|
else if (keyData === "\x1b[B") {
|
||||||
|
this.selectedIndex = Math.min(this.filteredModels.length - 1, this.selectedIndex + 1);
|
||||||
|
this.updateList();
|
||||||
|
}
|
||||||
|
// Enter
|
||||||
|
else if (keyData === "\r") {
|
||||||
|
const selectedModel = this.filteredModels[this.selectedIndex];
|
||||||
|
if (selectedModel) {
|
||||||
|
this.handleSelect(selectedModel.model);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Escape
|
||||||
|
else if (keyData === "\x1b") {
|
||||||
|
this.onCancelCallback();
|
||||||
|
}
|
||||||
|
// Pass everything else to search input
|
||||||
|
else {
|
||||||
|
this.searchInput.handleInput(keyData);
|
||||||
|
this.filterModels(this.searchInput.getValue());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private handleSelect(model: Model<any>): void {
|
||||||
|
this.onSelectCallback(model);
|
||||||
|
}
|
||||||
|
|
||||||
|
getSearchInput(): Input {
|
||||||
|
return this.searchInput;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -15,6 +15,7 @@ import type { SessionManager } from "../session-manager.js";
|
||||||
import { AssistantMessageComponent } from "./assistant-message.js";
|
import { AssistantMessageComponent } from "./assistant-message.js";
|
||||||
import { CustomEditor } from "./custom-editor.js";
|
import { CustomEditor } from "./custom-editor.js";
|
||||||
import { FooterComponent } from "./footer.js";
|
import { FooterComponent } from "./footer.js";
|
||||||
|
import { ModelSelectorComponent } from "./model-selector.js";
|
||||||
import { ThinkingSelectorComponent } from "./thinking-selector.js";
|
import { ThinkingSelectorComponent } from "./thinking-selector.js";
|
||||||
import { ToolExecutionComponent } from "./tool-execution.js";
|
import { ToolExecutionComponent } from "./tool-execution.js";
|
||||||
import { UserMessageComponent } from "./user-message.js";
|
import { UserMessageComponent } from "./user-message.js";
|
||||||
|
|
@ -47,6 +48,9 @@ export class TuiRenderer {
|
||||||
// Thinking level selector
|
// Thinking level selector
|
||||||
private thinkingSelector: ThinkingSelectorComponent | null = null;
|
private thinkingSelector: ThinkingSelectorComponent | null = null;
|
||||||
|
|
||||||
|
// Model selector
|
||||||
|
private modelSelector: ModelSelectorComponent | null = null;
|
||||||
|
|
||||||
// Track if this is the first user message (to skip spacer)
|
// Track if this is the first user message (to skip spacer)
|
||||||
private isFirstUserMessage = true;
|
private isFirstUserMessage = true;
|
||||||
|
|
||||||
|
|
@ -68,8 +72,13 @@ export class TuiRenderer {
|
||||||
description: "Select reasoning level (opens selector UI)",
|
description: "Select reasoning level (opens selector UI)",
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const modelCommand: SlashCommand = {
|
||||||
|
name: "model",
|
||||||
|
description: "Select model (opens selector UI)",
|
||||||
|
};
|
||||||
|
|
||||||
// Setup autocomplete for file paths and slash commands
|
// Setup autocomplete for file paths and slash commands
|
||||||
const autocompleteProvider = new CombinedAutocompleteProvider([thinkingCommand], process.cwd());
|
const autocompleteProvider = new CombinedAutocompleteProvider([thinkingCommand, modelCommand], process.cwd());
|
||||||
this.editor.setAutocompleteProvider(autocompleteProvider);
|
this.editor.setAutocompleteProvider(autocompleteProvider);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -134,6 +143,14 @@ export class TuiRenderer {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check for /model command
|
||||||
|
if (text === "/model") {
|
||||||
|
// Show model selector
|
||||||
|
this.showModelSelector();
|
||||||
|
this.editor.setText("");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
if (this.onInputCallback) {
|
if (this.onInputCallback) {
|
||||||
this.onInputCallback(text);
|
this.onInputCallback(text);
|
||||||
}
|
}
|
||||||
|
|
@ -404,6 +421,13 @@ export class TuiRenderer {
|
||||||
this.ui.requestRender();
|
this.ui.requestRender();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
showError(errorMessage: string): void {
|
||||||
|
// Show error message in the chat
|
||||||
|
this.chatContainer.addChild(new Spacer(1));
|
||||||
|
this.chatContainer.addChild(new Text(chalk.red(`Error: ${errorMessage}`), 1, 0));
|
||||||
|
this.ui.requestRender();
|
||||||
|
}
|
||||||
|
|
||||||
private showThinkingSelector(): void {
|
private showThinkingSelector(): void {
|
||||||
// Create thinking selector with current level
|
// Create thinking selector with current level
|
||||||
this.thinkingSelector = new ThinkingSelectorComponent(
|
this.thinkingSelector = new ThinkingSelectorComponent(
|
||||||
|
|
@ -446,6 +470,48 @@ export class TuiRenderer {
|
||||||
this.ui.setFocus(this.editor);
|
this.ui.setFocus(this.editor);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private showModelSelector(): void {
|
||||||
|
// Create model selector with current model
|
||||||
|
this.modelSelector = new ModelSelectorComponent(
|
||||||
|
this.agent.state.model,
|
||||||
|
(model) => {
|
||||||
|
// Apply the selected model
|
||||||
|
this.agent.setModel(model);
|
||||||
|
|
||||||
|
// Save model change to session
|
||||||
|
this.sessionManager.saveModelChange(`${model.provider}/${model.id}`);
|
||||||
|
|
||||||
|
// Show confirmation message with proper spacing
|
||||||
|
this.chatContainer.addChild(new Spacer(1));
|
||||||
|
const confirmText = new Text(chalk.dim(`Model: ${model.id}`), 1, 0);
|
||||||
|
this.chatContainer.addChild(confirmText);
|
||||||
|
|
||||||
|
// Hide selector and show editor again
|
||||||
|
this.hideModelSelector();
|
||||||
|
this.ui.requestRender();
|
||||||
|
},
|
||||||
|
() => {
|
||||||
|
// Just hide the selector
|
||||||
|
this.hideModelSelector();
|
||||||
|
this.ui.requestRender();
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
// Replace editor with selector
|
||||||
|
this.editorContainer.clear();
|
||||||
|
this.editorContainer.addChild(this.modelSelector);
|
||||||
|
this.ui.setFocus(this.modelSelector);
|
||||||
|
this.ui.requestRender();
|
||||||
|
}
|
||||||
|
|
||||||
|
private hideModelSelector(): void {
|
||||||
|
// Replace selector with editor in the container
|
||||||
|
this.editorContainer.clear();
|
||||||
|
this.editorContainer.addChild(this.editor);
|
||||||
|
this.modelSelector = null;
|
||||||
|
this.ui.setFocus(this.editor);
|
||||||
|
}
|
||||||
|
|
||||||
stop(): void {
|
stop(): void {
|
||||||
if (this.loadingAnimation) {
|
if (this.loadingAnimation) {
|
||||||
this.loadingAnimation.stop();
|
this.loadingAnimation.stop();
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue