diff --git a/packages/coding-agent/src/main.ts b/packages/coding-agent/src/main.ts index 83931196..0e61fd50 100644 --- a/packages/coding-agent/src/main.ts +++ b/packages/coding-agent/src/main.ts @@ -142,8 +142,8 @@ async function runInteractiveMode(agent: Agent, sessionManager: SessionManager, try { await agent.prompt(userInput); } catch (error: any) { - // Error handling - errors should be in agent state - console.error("Error:", error.message); + // Display error in the TUI by adding an error message to the chat + renderer.showError(error.message || "Unknown error occurred"); } } } @@ -182,26 +182,55 @@ export async function main(args: string[]) { const provider = (parsed.provider || "anthropic") as any; const modelId = parsed.model || "claude-sonnet-4-5"; - // Get API key - let apiKey = parsed.apiKey; - if (!apiKey) { - const envVarMap: Record = { - google: "GEMINI_API_KEY", - openai: "OPENAI_API_KEY", - anthropic: "ANTHROPIC_OAUTH_TOKEN", - xai: "XAI_API_KEY", - groq: "GROQ_API_KEY", - cerebras: "CEREBRAS_API_KEY", - zai: "ZAI_API_KEY", - }; - const envVar = envVarMap[provider] || `${provider.toUpperCase()}_API_KEY`; - apiKey = process.env[envVar]; - - if (!apiKey) { - console.error(chalk.red(`Error: No API key found for provider "${provider}"`)); - console.error(chalk.dim(`Set ${envVar} environment variable or use --api-key flag`)); - process.exit(1); + // Helper function to get API key for a provider + const getApiKeyForProvider = (providerName: string): string | undefined => { + // Check if API key was provided via command line + if (parsed.apiKey) { + return parsed.apiKey; } + + const envVarMap: Record = { + google: ["GEMINI_API_KEY"], + openai: ["OPENAI_API_KEY"], + anthropic: ["ANTHROPIC_OAUTH_TOKEN", "ANTHROPIC_API_KEY"], + xai: ["XAI_API_KEY"], + groq: ["GROQ_API_KEY"], + cerebras: ["CEREBRAS_API_KEY"], + zai: ["ZAI_API_KEY"], + moonshotai: ["MOONSHOT_API_KEY"], + }; + + const envVars = envVarMap[providerName] || [`${providerName.toUpperCase()}_API_KEY`]; + + // Check each environment variable in priority order + for (const envVar of envVars) { + const key = process.env[envVar]; + if (key) { + return key; + } + } + + return undefined; + }; + + // Get initial API key + const initialApiKey = getApiKeyForProvider(provider); + if (!initialApiKey) { + const envVarMap: Record = { + google: ["GEMINI_API_KEY"], + openai: ["OPENAI_API_KEY"], + anthropic: ["ANTHROPIC_OAUTH_TOKEN", "ANTHROPIC_API_KEY"], + xai: ["XAI_API_KEY"], + groq: ["GROQ_API_KEY"], + cerebras: ["CEREBRAS_API_KEY"], + zai: ["ZAI_API_KEY"], + moonshotai: ["MOONSHOT_API_KEY"], + }; + const envVars = envVarMap[provider] || [`${provider.toUpperCase()}_API_KEY`]; + const envVarList = envVars.join(" or "); + console.error(chalk.red(`Error: No API key found for provider "${provider}"`)); + console.error(chalk.dim(`Set ${envVarList} environment variable or use --api-key flag`)); + process.exit(1); } // Create agent @@ -216,7 +245,17 @@ export async function main(args: string[]) { tools: codingTools, }, transport: new ProviderTransport({ - getApiKey: async () => apiKey!, + // Dynamic API key lookup based on current model's provider + getApiKey: async () => { + const currentProvider = agent.state.model.provider; + const key = getApiKeyForProvider(currentProvider); + if (!key) { + throw new Error( + `No API key found for provider "${currentProvider}". Please set the appropriate environment variable.`, + ); + } + return key; + }, }), }); @@ -228,6 +267,22 @@ export async function main(args: string[]) { agent.replaceMessages(messages); } + // Load and restore model + const savedModel = sessionManager.loadModel(); + if (savedModel) { + // Parse provider/modelId from saved model string (format: "provider/modelId") + const [savedProvider, savedModelId] = savedModel.split("/"); + if (savedProvider && savedModelId) { + try { + const restoredModel = getModel(savedProvider as any, savedModelId); + agent.setModel(restoredModel); + console.log(chalk.dim(`Restored model: ${savedModel}`)); + } catch (error: any) { + console.error(chalk.yellow(`Warning: Could not restore model ${savedModel}: ${error.message}`)); + } + } + } + // Load and restore thinking level const thinkingLevel = sessionManager.loadThinkingLevel() as ThinkingLevel; if (thinkingLevel) { diff --git a/packages/coding-agent/src/session-manager.ts b/packages/coding-agent/src/session-manager.ts index e0219c86..e984b0ec 100644 --- a/packages/coding-agent/src/session-manager.ts +++ b/packages/coding-agent/src/session-manager.ts @@ -212,6 +212,29 @@ export class SessionManager { return lastThinkingLevel; } + loadModel(): string | null { + if (!existsSync(this.sessionFile)) return null; + + const lines = readFileSync(this.sessionFile, "utf8").trim().split("\n"); + + // Find the most recent model (from session header or change event) + let lastModel: string | null = null; + for (const line of lines) { + try { + const entry = JSON.parse(line); + if (entry.type === "session" && entry.model) { + lastModel = entry.model; + } else if (entry.type === "model_change" && entry.model) { + lastModel = entry.model; + } + } catch { + // Skip malformed lines + } + } + + return lastModel; + } + getSessionId(): string { return this.sessionId; } diff --git a/packages/coding-agent/src/tui/model-selector.ts b/packages/coding-agent/src/tui/model-selector.ts new file mode 100644 index 00000000..8e1c5daa --- /dev/null +++ b/packages/coding-agent/src/tui/model-selector.ts @@ -0,0 +1,188 @@ +import { getModels, getProviders, type Model } from "@mariozechner/pi-ai"; +import { Container, Input, Spacer, Text } from "@mariozechner/pi-tui"; +import chalk from "chalk"; + +interface ModelItem { + provider: string; + id: string; + model: Model; +} + +/** + * Component that renders a model selector with search + */ +export class ModelSelectorComponent extends Container { + private searchInput: Input; + private listContainer: Container; + private allModels: ModelItem[] = []; + private filteredModels: ModelItem[] = []; + private selectedIndex: number = 0; + private currentModel: Model; + private onSelectCallback: (model: Model) => void; + private onCancelCallback: () => void; + + constructor(currentModel: Model, onSelect: (model: Model) => void, onCancel: () => void) { + super(); + + this.currentModel = currentModel; + this.onSelectCallback = onSelect; + this.onCancelCallback = onCancel; + + // Load all models + this.loadModels(); + + // Add top border + this.addChild(new Text(chalk.blue("─".repeat(80)), 0, 0)); + this.addChild(new Spacer(1)); + + // Create search input + this.searchInput = new Input(); + this.searchInput.onSubmit = () => { + // Enter on search input selects the first filtered item + if (this.filteredModels[this.selectedIndex]) { + this.handleSelect(this.filteredModels[this.selectedIndex].model); + } + }; + this.addChild(this.searchInput); + + this.addChild(new Spacer(1)); + + // Create list container + this.listContainer = new Container(); + this.addChild(this.listContainer); + + this.addChild(new Spacer(1)); + + // Add bottom border + this.addChild(new Text(chalk.blue("─".repeat(80)), 0, 0)); + + // Initial render + this.updateList(); + } + + private loadModels(): void { + const models: ModelItem[] = []; + const providers = getProviders(); + + for (const provider of providers) { + const providerModels = getModels(provider as any); + for (const model of providerModels) { + models.push({ provider, id: model.id, model }); + } + } + + // Sort: current model first, then by provider + models.sort((a, b) => { + const aIsCurrent = this.currentModel?.id === a.model.id; + const bIsCurrent = this.currentModel?.id === b.model.id; + if (aIsCurrent && !bIsCurrent) return -1; + if (!aIsCurrent && bIsCurrent) return 1; + return a.provider.localeCompare(b.provider); + }); + + this.allModels = models; + this.filteredModels = models; + } + + private filterModels(query: string): void { + if (!query.trim()) { + this.filteredModels = this.allModels; + } else { + const searchTokens = query + .toLowerCase() + .split(/\s+/) + .filter((t) => t); + this.filteredModels = this.allModels.filter(({ provider, id, model }) => { + const searchText = `${provider} ${id} ${model.name}`.toLowerCase(); + return searchTokens.every((token) => searchText.includes(token)); + }); + } + + this.selectedIndex = Math.min(this.selectedIndex, Math.max(0, this.filteredModels.length - 1)); + this.updateList(); + } + + private updateList(): void { + this.listContainer.clear(); + + const maxVisible = 10; + const startIndex = Math.max( + 0, + Math.min(this.selectedIndex - Math.floor(maxVisible / 2), this.filteredModels.length - maxVisible), + ); + const endIndex = Math.min(startIndex + maxVisible, this.filteredModels.length); + + // Show visible slice of filtered models + for (let i = startIndex; i < endIndex; i++) { + const item = this.filteredModels[i]; + if (!item) continue; + + const isSelected = i === this.selectedIndex; + const isCurrent = this.currentModel?.id === item.model.id; + + let line = ""; + if (isSelected) { + const prefix = chalk.blue("→ "); + const modelText = `${item.id}`; + const providerBadge = chalk.gray(`[${item.provider}]`); + const checkmark = isCurrent ? chalk.green(" ✓") : ""; + line = prefix + chalk.blue(modelText) + " " + providerBadge + checkmark; + } else { + const modelText = ` ${item.id}`; + const providerBadge = chalk.gray(`[${item.provider}]`); + const checkmark = isCurrent ? chalk.green(" ✓") : ""; + line = modelText + " " + providerBadge + checkmark; + } + + this.listContainer.addChild(new Text(line, 0, 0)); + } + + // Add scroll indicator if needed + if (startIndex > 0 || endIndex < this.filteredModels.length) { + const scrollInfo = chalk.gray(` (${this.selectedIndex + 1}/${this.filteredModels.length})`); + this.listContainer.addChild(new Text(scrollInfo, 0, 0)); + } + + // Show "no results" if empty + if (this.filteredModels.length === 0) { + this.listContainer.addChild(new Text(chalk.gray(" No matching models"), 0, 0)); + } + } + + handleInput(keyData: string): void { + // Up arrow + if (keyData === "\x1b[A") { + this.selectedIndex = Math.max(0, this.selectedIndex - 1); + this.updateList(); + } + // Down arrow + else if (keyData === "\x1b[B") { + this.selectedIndex = Math.min(this.filteredModels.length - 1, this.selectedIndex + 1); + this.updateList(); + } + // Enter + else if (keyData === "\r") { + const selectedModel = this.filteredModels[this.selectedIndex]; + if (selectedModel) { + this.handleSelect(selectedModel.model); + } + } + // Escape + else if (keyData === "\x1b") { + this.onCancelCallback(); + } + // Pass everything else to search input + else { + this.searchInput.handleInput(keyData); + this.filterModels(this.searchInput.getValue()); + } + } + + private handleSelect(model: Model): void { + this.onSelectCallback(model); + } + + getSearchInput(): Input { + return this.searchInput; + } +} diff --git a/packages/coding-agent/src/tui/tui-renderer.ts b/packages/coding-agent/src/tui/tui-renderer.ts index 6e0bac2e..b6f32574 100644 --- a/packages/coding-agent/src/tui/tui-renderer.ts +++ b/packages/coding-agent/src/tui/tui-renderer.ts @@ -15,6 +15,7 @@ import type { SessionManager } from "../session-manager.js"; import { AssistantMessageComponent } from "./assistant-message.js"; import { CustomEditor } from "./custom-editor.js"; import { FooterComponent } from "./footer.js"; +import { ModelSelectorComponent } from "./model-selector.js"; import { ThinkingSelectorComponent } from "./thinking-selector.js"; import { ToolExecutionComponent } from "./tool-execution.js"; import { UserMessageComponent } from "./user-message.js"; @@ -47,6 +48,9 @@ export class TuiRenderer { // Thinking level selector private thinkingSelector: ThinkingSelectorComponent | null = null; + // Model selector + private modelSelector: ModelSelectorComponent | null = null; + // Track if this is the first user message (to skip spacer) private isFirstUserMessage = true; @@ -68,8 +72,13 @@ export class TuiRenderer { description: "Select reasoning level (opens selector UI)", }; + const modelCommand: SlashCommand = { + name: "model", + description: "Select model (opens selector UI)", + }; + // Setup autocomplete for file paths and slash commands - const autocompleteProvider = new CombinedAutocompleteProvider([thinkingCommand], process.cwd()); + const autocompleteProvider = new CombinedAutocompleteProvider([thinkingCommand, modelCommand], process.cwd()); this.editor.setAutocompleteProvider(autocompleteProvider); } @@ -134,6 +143,14 @@ export class TuiRenderer { return; } + // Check for /model command + if (text === "/model") { + // Show model selector + this.showModelSelector(); + this.editor.setText(""); + return; + } + if (this.onInputCallback) { this.onInputCallback(text); } @@ -404,6 +421,13 @@ export class TuiRenderer { this.ui.requestRender(); } + showError(errorMessage: string): void { + // Show error message in the chat + this.chatContainer.addChild(new Spacer(1)); + this.chatContainer.addChild(new Text(chalk.red(`Error: ${errorMessage}`), 1, 0)); + this.ui.requestRender(); + } + private showThinkingSelector(): void { // Create thinking selector with current level this.thinkingSelector = new ThinkingSelectorComponent( @@ -446,6 +470,48 @@ export class TuiRenderer { this.ui.setFocus(this.editor); } + private showModelSelector(): void { + // Create model selector with current model + this.modelSelector = new ModelSelectorComponent( + this.agent.state.model, + (model) => { + // Apply the selected model + this.agent.setModel(model); + + // Save model change to session + this.sessionManager.saveModelChange(`${model.provider}/${model.id}`); + + // Show confirmation message with proper spacing + this.chatContainer.addChild(new Spacer(1)); + const confirmText = new Text(chalk.dim(`Model: ${model.id}`), 1, 0); + this.chatContainer.addChild(confirmText); + + // Hide selector and show editor again + this.hideModelSelector(); + this.ui.requestRender(); + }, + () => { + // Just hide the selector + this.hideModelSelector(); + this.ui.requestRender(); + }, + ); + + // Replace editor with selector + this.editorContainer.clear(); + this.editorContainer.addChild(this.modelSelector); + this.ui.setFocus(this.modelSelector); + this.ui.requestRender(); + } + + private hideModelSelector(): void { + // Replace selector with editor in the container + this.editorContainer.clear(); + this.editorContainer.addChild(this.editor); + this.modelSelector = null; + this.ui.setFocus(this.editor); + } + stop(): void { if (this.loadingAnimation) { this.loadingAnimation.stop();