mirror of
https://github.com/getcompanion-ai/co-mono.git
synced 2026-04-15 06:04:40 +00:00
fix(coding-agent): honor --model selection, thinking, and --api-key
This commit is contained in:
parent
7ccf809a5d
commit
56342258e1
4 changed files with 286 additions and 21 deletions
|
|
@ -189,7 +189,7 @@ ${chalk.bold("Commands:")}
|
|||
|
||||
${chalk.bold("Options:")}
|
||||
--provider <name> Provider name (default: google)
|
||||
--model <id> Model ID (default: gemini-2.5-flash)
|
||||
--model <pattern> Model pattern or ID (supports "provider/id" and optional ":<thinking>")
|
||||
--api-key <key> API key (defaults to env vars)
|
||||
--system-prompt <text> System prompt (default: coding assistant prompt)
|
||||
--append-system-prompt <text> Append text or file contents to the system prompt
|
||||
|
|
|
|||
|
|
@ -126,7 +126,11 @@ export interface ParsedModelResult {
|
|||
*
|
||||
* @internal Exported for testing
|
||||
*/
|
||||
export function parseModelPattern(pattern: string, availableModels: Model<Api>[]): ParsedModelResult {
|
||||
export function parseModelPattern(
|
||||
pattern: string,
|
||||
availableModels: Model<Api>[],
|
||||
options?: { allowInvalidThinkingLevelFallback?: boolean },
|
||||
): ParsedModelResult {
|
||||
// Try exact match first
|
||||
const exactMatch = tryMatchModel(pattern, availableModels);
|
||||
if (exactMatch) {
|
||||
|
|
@ -145,7 +149,7 @@ export function parseModelPattern(pattern: string, availableModels: Model<Api>[]
|
|||
|
||||
if (isValidThinkingLevel(suffix)) {
|
||||
// Valid thinking level - recurse on prefix and use this level
|
||||
const result = parseModelPattern(prefix, availableModels);
|
||||
const result = parseModelPattern(prefix, availableModels, options);
|
||||
if (result.model) {
|
||||
// Only use this thinking level if no warning from inner recursion
|
||||
return {
|
||||
|
|
@ -156,8 +160,16 @@ export function parseModelPattern(pattern: string, availableModels: Model<Api>[]
|
|||
}
|
||||
return result;
|
||||
} else {
|
||||
// Invalid suffix - recurse on prefix and warn
|
||||
const result = parseModelPattern(prefix, availableModels);
|
||||
// Invalid suffix
|
||||
const allowFallback = options?.allowInvalidThinkingLevelFallback ?? true;
|
||||
if (!allowFallback) {
|
||||
// In strict mode (CLI --model parsing), treat it as part of the model id and fail.
|
||||
// This avoids accidentally resolving to a different model.
|
||||
return { model: undefined, thinkingLevel: undefined, warning: undefined };
|
||||
}
|
||||
|
||||
// Scope mode: recurse on prefix and warn
|
||||
const result = parseModelPattern(prefix, availableModels, options);
|
||||
if (result.model) {
|
||||
return {
|
||||
model: result.model,
|
||||
|
|
@ -240,6 +252,116 @@ export async function resolveModelScope(patterns: string[], modelRegistry: Model
|
|||
return scopedModels;
|
||||
}
|
||||
|
||||
/**
 * Outcome of resolving the model requested through CLI flags.
 *
 * Either `model` is populated (optionally together with a parsed thinking
 * level), or `error` carries a message explaining why nothing was resolved.
 */
export interface ResolveCliModelResult {
  /** The resolved model, or `undefined` when nothing was requested or resolution failed. */
  model: Model<Api> | undefined;

  /** Thinking level parsed from a "<pattern>:<thinking>" value, when one was present. */
  thinkingLevel?: ThinkingLevel;

  /**
   * Error message suitable for CLI display.
   * When set, model will be undefined.
   */
  error: string | undefined;

  /** Non-fatal message produced while parsing the pattern, if any. */
  warning: string | undefined;
}
|
||||
|
||||
/**
 * Resolve a single model from CLI flags.
 *
 * Supports:
 * - --provider <provider> --model <pattern>
 * - --model <provider>/<pattern>
 * - --model <id> where the id itself contains slashes (exact matches are tried first)
 * - Fuzzy matching (same rules as model scoping: exact id, then partial id/name)
 *
 * Resolution order when --provider is NOT given:
 * 1. Case-insensitive exact match of the whole value against `id` or `provider/id`.
 * 2. If the text before the first "/" names a known provider, match the remainder
 *    within that provider only.
 * 3. If step 2 inferred a provider but found nothing, retry the untouched value
 *    against every model. This covers ids such as "openai/some-model:high" that
 *    live under a different provider and only *look* provider-prefixed.
 *
 * When --provider IS given, matching is always restricted to that provider; a
 * redundant "<provider>/" prefix on --model is tolerated and stripped.
 *
 * Note: This does not apply the thinking level by itself, but it may *parse* and
 * return a thinking level from "<pattern>:<thinking>" so the caller can apply it.
 *
 * @param options.cliProvider Value of --provider, if any (matched case-insensitively).
 * @param options.cliModel Value of --model, if any.
 * @param options.modelRegistry Registry whose `getAll()` supplies the candidate models.
 * @returns The resolved model (plus optional thinking level / warning), or an `error`
 *          message. All fields are undefined when no --model was supplied.
 */
export async function resolveCliModel(options: {
  cliProvider?: string;
  cliModel?: string;
  modelRegistry: ModelRegistry;
}): Promise<ResolveCliModelResult> {
  const { cliProvider, cliModel, modelRegistry } = options;

  if (!cliModel) {
    return { model: undefined, warning: undefined, error: undefined };
  }

  // Important: use *all* models here, not just models with pre-configured auth.
  // This allows "--api-key" to be used for first-time setup.
  const availableModels = modelRegistry.getAll();
  if (availableModels.length === 0) {
    return {
      model: undefined,
      warning: undefined,
      error: "No models available. Check your installation or add models to models.json.",
    };
  }

  // Build canonical provider lookup (case-insensitive)
  const providerMap = new Map<string, string>();
  for (const m of availableModels) {
    providerMap.set(m.provider.toLowerCase(), m.provider);
  }

  let provider = cliProvider ? providerMap.get(cliProvider.toLowerCase()) : undefined;
  if (cliProvider && !provider) {
    return {
      model: undefined,
      warning: undefined,
      error: `Unknown provider "${cliProvider}". Use --list-models to see available providers/models.`,
    };
  }

  let pattern = cliModel;
  // True only when the provider came from splitting --model on "/", i.e. it is a guess.
  let providerInferred = false;

  if (!provider) {
    // First try exact matches without any provider inference.
    // This avoids misinterpreting model IDs that themselves contain slashes (e.g. OpenRouter-style IDs).
    const lower = cliModel.toLowerCase();
    const exact = availableModels.find(
      (m) => m.id.toLowerCase() === lower || `${m.provider}/${m.id}`.toLowerCase() === lower,
    );
    if (exact) {
      return { model: exact, warning: undefined, thinkingLevel: undefined, error: undefined };
    }

    // Allow --model provider/<pattern>
    const slashIndex = cliModel.indexOf("/");
    if (slashIndex !== -1) {
      const canonical = providerMap.get(cliModel.substring(0, slashIndex).toLowerCase());
      if (canonical) {
        provider = canonical;
        pattern = cliModel.substring(slashIndex + 1);
        providerInferred = true;
      }
    }
  } else {
    // If both were provided, tolerate --model <provider>/<pattern> by stripping the provider prefix
    const prefix = `${provider}/`;
    if (cliModel.toLowerCase().startsWith(prefix.toLowerCase())) {
      pattern = cliModel.substring(prefix.length);
    }
  }

  // Strict mode: an unrecognised ":suffix" is never silently dropped for --model.
  const strict = { allowInvalidThinkingLevelFallback: false };

  const candidates = provider ? availableModels.filter((m) => m.provider === provider) : availableModels;
  const scoped = parseModelPattern(pattern, candidates, strict);
  if (scoped.model) {
    return {
      model: scoped.model,
      thinkingLevel: scoped.thinkingLevel,
      warning: scoped.warning,
      error: undefined,
    };
  }

  // The inferred provider was only a guess. Before giving up, match the complete
  // value (prefix included) across every provider so that slash-containing ids
  // carrying a ":<thinking>" suffix still resolve. An explicit --provider is
  // never overridden this way.
  if (providerInferred) {
    const unscoped = parseModelPattern(cliModel, availableModels, strict);
    if (unscoped.model) {
      return {
        model: unscoped.model,
        thinkingLevel: unscoped.thinkingLevel,
        warning: unscoped.warning,
        error: undefined,
      };
    }
  }

  const display = provider ? `${provider}/${pattern}` : cliModel;
  return {
    model: undefined,
    thinkingLevel: undefined,
    warning: scoped.warning,
    error: `Model "${display}" not found. Use --list-models to see available models.`,
  };
}
|
||||
|
||||
export interface InitialModelResult {
|
||||
model: Model<Api> | undefined;
|
||||
thinkingLevel: ThinkingLevel;
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ import { exportFromFile } from "./core/export-html/index.js";
|
|||
import type { LoadExtensionsResult } from "./core/extensions/index.js";
|
||||
import { KeybindingsManager } from "./core/keybindings.js";
|
||||
import { ModelRegistry } from "./core/model-registry.js";
|
||||
import { resolveModelScope, type ScopedModel } from "./core/model-resolver.js";
|
||||
import { resolveCliModel, resolveModelScope, type ScopedModel } from "./core/model-resolver.js";
|
||||
import { DefaultPackageManager } from "./core/package-manager.js";
|
||||
import { DefaultResourceLoader } from "./core/resource-loader.js";
|
||||
import { type CreateAgentSessionOptions, createAgentSession } from "./core/sdk.js";
|
||||
|
|
@ -403,28 +403,48 @@ async function createSessionManager(parsed: Args, cwd: string): Promise<SessionM
|
|||
return undefined;
|
||||
}
|
||||
|
||||
function buildSessionOptions(
|
||||
async function buildSessionOptions(
|
||||
parsed: Args,
|
||||
scopedModels: ScopedModel[],
|
||||
sessionManager: SessionManager | undefined,
|
||||
modelRegistry: ModelRegistry,
|
||||
settingsManager: SettingsManager,
|
||||
): CreateAgentSessionOptions {
|
||||
): Promise<{ options: CreateAgentSessionOptions; cliThinkingFromModel: boolean }> {
|
||||
const options: CreateAgentSessionOptions = {};
|
||||
let cliThinkingFromModel = false;
|
||||
|
||||
if (sessionManager) {
|
||||
options.sessionManager = sessionManager;
|
||||
}
|
||||
|
||||
// Model from CLI
|
||||
if (parsed.provider && parsed.model) {
|
||||
const model = modelRegistry.find(parsed.provider, parsed.model);
|
||||
if (!model) {
|
||||
console.error(chalk.red(`Model ${parsed.provider}/${parsed.model} not found`));
|
||||
// - supports --provider <name> --model <pattern>
|
||||
// - supports --model <provider>/<pattern>
|
||||
if (parsed.model) {
|
||||
const resolved = await resolveCliModel({
|
||||
cliProvider: parsed.provider,
|
||||
cliModel: parsed.model,
|
||||
modelRegistry,
|
||||
});
|
||||
if (resolved.warning) {
|
||||
console.warn(chalk.yellow(`Warning: ${resolved.warning}`));
|
||||
}
|
||||
if (resolved.error) {
|
||||
console.error(chalk.red(resolved.error));
|
||||
process.exit(1);
|
||||
}
|
||||
options.model = model;
|
||||
} else if (scopedModels.length > 0 && !parsed.continue && !parsed.resume) {
|
||||
if (resolved.model) {
|
||||
options.model = resolved.model;
|
||||
// Allow "--model <pattern>:<thinking>" as a shorthand.
|
||||
// Explicit --thinking still takes precedence (applied later).
|
||||
if (!parsed.thinking && resolved.thinkingLevel) {
|
||||
options.thinkingLevel = resolved.thinkingLevel;
|
||||
cliThinkingFromModel = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!options.model && scopedModels.length > 0 && !parsed.continue && !parsed.resume) {
|
||||
// Check if saved default is in scoped models - use it if so, otherwise first scoped model
|
||||
const savedProvider = settingsManager.getDefaultProvider();
|
||||
const savedModelId = settingsManager.getDefaultModel();
|
||||
|
|
@ -476,7 +496,7 @@ function buildSessionOptions(
|
|||
options.tools = parsed.tools.map((name) => allTools[name]);
|
||||
}
|
||||
|
||||
return options;
|
||||
return { options, cliThinkingFromModel };
|
||||
}
|
||||
|
||||
async function handleConfigCommand(args: string[]): Promise<boolean> {
|
||||
|
|
@ -650,7 +670,13 @@ export async function main(args: string[]) {
|
|||
sessionManager = SessionManager.open(selectedPath);
|
||||
}
|
||||
|
||||
const sessionOptions = buildSessionOptions(parsed, scopedModels, sessionManager, modelRegistry, settingsManager);
|
||||
const { options: sessionOptions, cliThinkingFromModel } = await buildSessionOptions(
|
||||
parsed,
|
||||
scopedModels,
|
||||
sessionManager,
|
||||
modelRegistry,
|
||||
settingsManager,
|
||||
);
|
||||
sessionOptions.authStorage = authStorage;
|
||||
sessionOptions.modelRegistry = modelRegistry;
|
||||
sessionOptions.resourceLoader = resourceLoader;
|
||||
|
|
@ -658,7 +684,9 @@ export async function main(args: string[]) {
|
|||
// Handle CLI --api-key as runtime override (not persisted)
|
||||
if (parsed.apiKey) {
|
||||
if (!sessionOptions.model) {
|
||||
console.error(chalk.red("--api-key requires a model to be specified via --provider/--model or -m/--models"));
|
||||
console.error(
|
||||
chalk.red("--api-key requires a model to be specified via --model, --provider/--model, or --models"),
|
||||
);
|
||||
process.exit(1);
|
||||
}
|
||||
authStorage.setRuntimeApiKey(sessionOptions.model.provider, parsed.apiKey);
|
||||
|
|
@ -674,9 +702,11 @@ export async function main(args: string[]) {
|
|||
process.exit(1);
|
||||
}
|
||||
|
||||
// Clamp thinking level to model capabilities (for CLI override case)
|
||||
if (session.model && parsed.thinking) {
|
||||
let effectiveThinking = parsed.thinking;
|
||||
// Clamp thinking level to model capabilities for CLI-provided thinking levels.
|
||||
// This covers both --thinking <level> and --model <pattern>:<thinking>.
|
||||
const cliThinkingOverride = parsed.thinking !== undefined || cliThinkingFromModel;
|
||||
if (session.model && cliThinkingOverride) {
|
||||
let effectiveThinking = session.thinkingLevel;
|
||||
if (!session.model.reasoning) {
|
||||
effectiveThinking = "off";
|
||||
} else if (effectiveThinking === "xhigh" && !supportsXhigh(session.model)) {
|
||||
|
|
|
|||
|
|
@ -1,6 +1,11 @@
|
|||
import type { Model } from "@mariozechner/pi-ai";
|
||||
import { describe, expect, test } from "vitest";
|
||||
import { defaultModelPerProvider, findInitialModel, parseModelPattern } from "../src/core/model-resolver.js";
|
||||
import {
|
||||
defaultModelPerProvider,
|
||||
findInitialModel,
|
||||
parseModelPattern,
|
||||
resolveCliModel,
|
||||
} from "../src/core/model-resolver.js";
|
||||
|
||||
// Mock models for testing
|
||||
const mockModels: Model<"anthropic-messages">[] = [
|
||||
|
|
@ -201,6 +206,114 @@ describe("parseModelPattern", () => {
|
|||
});
|
||||
});
|
||||
|
||||
describe("resolveCliModel", () => {
  type Registry = Parameters<typeof resolveCliModel>[0]["modelRegistry"];

  // Minimal registry stub: resolveCliModel only calls getAll() on the registry.
  const registryWith = (models: readonly unknown[]): Registry =>
    ({ getAll: () => models }) as unknown as Registry;

  test("resolves --model provider/id without --provider", async () => {
    const outcome = await resolveCliModel({
      cliModel: "openai/gpt-4o",
      modelRegistry: registryWith(allModels),
    });

    expect(outcome.error).toBeUndefined();
    expect(outcome.model?.provider).toBe("openai");
    expect(outcome.model?.id).toBe("gpt-4o");
  });

  test("resolves fuzzy patterns within an explicit provider", async () => {
    const outcome = await resolveCliModel({
      cliProvider: "openai",
      cliModel: "4o",
      modelRegistry: registryWith(allModels),
    });

    expect(outcome.error).toBeUndefined();
    expect(outcome.model?.provider).toBe("openai");
    expect(outcome.model?.id).toBe("gpt-4o");
  });

  test("supports --model <pattern>:<thinking> (without explicit --thinking)", async () => {
    const outcome = await resolveCliModel({
      cliModel: "sonnet:high",
      modelRegistry: registryWith(allModels),
    });

    expect(outcome.error).toBeUndefined();
    expect(outcome.model?.id).toBe("claude-sonnet-4-5");
    expect(outcome.thinkingLevel).toBe("high");
  });

  test("prefers exact model id match over provider inference (OpenRouter-style ids)", async () => {
    const outcome = await resolveCliModel({
      cliModel: "openai/gpt-4o:extended",
      modelRegistry: registryWith(allModels),
    });

    expect(outcome.error).toBeUndefined();
    expect(outcome.model?.provider).toBe("openrouter");
    expect(outcome.model?.id).toBe("openai/gpt-4o:extended");
  });

  test("does not strip invalid :suffix as thinking level in --model (fail fast)", async () => {
    const outcome = await resolveCliModel({
      cliProvider: "openai",
      cliModel: "gpt-4o:extended",
      modelRegistry: registryWith(allModels),
    });

    expect(outcome.model).toBeUndefined();
    expect(outcome.error).toContain("not found");
  });

  test("returns a clear error when there are no models", async () => {
    const outcome = await resolveCliModel({
      cliProvider: "openai",
      cliModel: "gpt-4o",
      modelRegistry: registryWith([]),
    });

    expect(outcome.model).toBeUndefined();
    expect(outcome.error).toContain("No models available");
  });

  test("resolves provider-prefixed fuzzy patterns (openrouter/qwen -> openrouter model)", async () => {
    const outcome = await resolveCliModel({
      cliModel: "openrouter/qwen",
      modelRegistry: registryWith(allModels),
    });

    expect(outcome.error).toBeUndefined();
    expect(outcome.model?.provider).toBe("openrouter");
    expect(outcome.model?.id).toBe("qwen/qwen3-coder:exacto");
  });
});
|
||||
|
||||
describe("default model selection", () => {
|
||||
test("ai-gateway default is opus 4.6", () => {
|
||||
expect(defaultModelPerProvider["vercel-ai-gateway"]).toBe("anthropic/claude-opus-4-6");
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue