mirror of
https://github.com/getcompanion-ai/co-mono.git
synced 2026-04-16 03:01:56 +00:00
401 lines
13 KiB
TypeScript
401 lines
13 KiB
TypeScript
/**
|
|
* Model resolution, scoping, and initial selection
|
|
*/
|
|
|
|
import type { ThinkingLevel } from "@mariozechner/pi-agent-core";
|
|
import { type Api, type KnownProvider, type Model, modelsAreEqual } from "@mariozechner/pi-ai";
|
|
import chalk from "chalk";
|
|
import { minimatch } from "minimatch";
|
|
import { isValidThinkingLevel } from "../cli/args.js";
|
|
import type { ModelRegistry } from "./model-registry.js";
|
|
|
|
/** Default model IDs for each known provider */
|
|
export const defaultModelPerProvider: Record<KnownProvider, string> = {
|
|
"amazon-bedrock": "global.anthropic.claude-sonnet-4-5-20250929-v1:0",
|
|
anthropic: "claude-sonnet-4-5",
|
|
openai: "gpt-5.1-codex",
|
|
"openai-codex": "gpt-5.2-codex",
|
|
google: "gemini-2.5-pro",
|
|
"google-gemini-cli": "gemini-2.5-pro",
|
|
"google-antigravity": "gemini-3-pro-high",
|
|
"google-vertex": "gemini-3-pro-preview",
|
|
"github-copilot": "gpt-4o",
|
|
openrouter: "openai/gpt-5.1-codex",
|
|
"vercel-ai-gateway": "anthropic/claude-opus-4.5",
|
|
xai: "grok-4-fast-non-reasoning",
|
|
groq: "openai/gpt-oss-120b",
|
|
cerebras: "zai-glm-4.6",
|
|
zai: "glm-4.6",
|
|
mistral: "devstral-medium-latest",
|
|
minimax: "MiniMax-M2.1",
|
|
"minimax-cn": "MiniMax-M2.1",
|
|
opencode: "claude-opus-4-5",
|
|
};
|
|
|
|
export interface ScopedModel {
|
|
model: Model<Api>;
|
|
/** Thinking level if explicitly specified in pattern (e.g., "model:high"), undefined otherwise */
|
|
thinkingLevel?: ThinkingLevel;
|
|
}
|
|
|
|
/**
|
|
* Helper to check if a model ID looks like an alias (no date suffix)
|
|
* Dates are typically in format: -20241022 or -20250929
|
|
*/
|
|
function isAlias(id: string): boolean {
|
|
// Check if ID ends with -latest
|
|
if (id.endsWith("-latest")) return true;
|
|
|
|
// Check if ID ends with a date pattern (-YYYYMMDD)
|
|
const datePattern = /-\d{8}$/;
|
|
return !datePattern.test(id);
|
|
}
|
|
|
|
/**
|
|
* Try to match a pattern to a model from the available models list.
|
|
* Returns the matched model or undefined if no match found.
|
|
*/
|
|
function tryMatchModel(modelPattern: string, availableModels: Model<Api>[]): Model<Api> | undefined {
|
|
// Check for provider/modelId format (provider is everything before the first /)
|
|
const slashIndex = modelPattern.indexOf("/");
|
|
if (slashIndex !== -1) {
|
|
const provider = modelPattern.substring(0, slashIndex);
|
|
const modelId = modelPattern.substring(slashIndex + 1);
|
|
const providerMatch = availableModels.find(
|
|
(m) => m.provider.toLowerCase() === provider.toLowerCase() && m.id.toLowerCase() === modelId.toLowerCase(),
|
|
);
|
|
if (providerMatch) {
|
|
return providerMatch;
|
|
}
|
|
// No exact provider/model match - fall through to other matching
|
|
}
|
|
|
|
// Check for exact ID match (case-insensitive)
|
|
const exactMatch = availableModels.find((m) => m.id.toLowerCase() === modelPattern.toLowerCase());
|
|
if (exactMatch) {
|
|
return exactMatch;
|
|
}
|
|
|
|
// No exact match - fall back to partial matching
|
|
const matches = availableModels.filter(
|
|
(m) =>
|
|
m.id.toLowerCase().includes(modelPattern.toLowerCase()) ||
|
|
m.name?.toLowerCase().includes(modelPattern.toLowerCase()),
|
|
);
|
|
|
|
if (matches.length === 0) {
|
|
return undefined;
|
|
}
|
|
|
|
// Separate into aliases and dated versions
|
|
const aliases = matches.filter((m) => isAlias(m.id));
|
|
const datedVersions = matches.filter((m) => !isAlias(m.id));
|
|
|
|
if (aliases.length > 0) {
|
|
// Prefer alias - if multiple aliases, pick the one that sorts highest
|
|
aliases.sort((a, b) => b.id.localeCompare(a.id));
|
|
return aliases[0];
|
|
} else {
|
|
// No alias found, pick latest dated version
|
|
datedVersions.sort((a, b) => b.id.localeCompare(a.id));
|
|
return datedVersions[0];
|
|
}
|
|
}
|
|
|
|
export interface ParsedModelResult {
|
|
model: Model<Api> | undefined;
|
|
/** Thinking level if explicitly specified in pattern, undefined otherwise */
|
|
thinkingLevel?: ThinkingLevel;
|
|
warning: string | undefined;
|
|
}
|
|
|
|
/**
|
|
* Parse a pattern to extract model and thinking level.
|
|
* Handles models with colons in their IDs (e.g., OpenRouter's :exacto suffix).
|
|
*
|
|
* Algorithm:
|
|
* 1. Try to match full pattern as a model
|
|
* 2. If found, return it with "off" thinking level
|
|
* 3. If not found and has colons, split on last colon:
|
|
* - If suffix is valid thinking level, use it and recurse on prefix
|
|
* - If suffix is invalid, warn and recurse on prefix with "off"
|
|
*
|
|
* @internal Exported for testing
|
|
*/
|
|
export function parseModelPattern(pattern: string, availableModels: Model<Api>[]): ParsedModelResult {
|
|
// Try exact match first
|
|
const exactMatch = tryMatchModel(pattern, availableModels);
|
|
if (exactMatch) {
|
|
return { model: exactMatch, thinkingLevel: undefined, warning: undefined };
|
|
}
|
|
|
|
// No match - try splitting on last colon if present
|
|
const lastColonIndex = pattern.lastIndexOf(":");
|
|
if (lastColonIndex === -1) {
|
|
// No colons, pattern simply doesn't match any model
|
|
return { model: undefined, thinkingLevel: undefined, warning: undefined };
|
|
}
|
|
|
|
const prefix = pattern.substring(0, lastColonIndex);
|
|
const suffix = pattern.substring(lastColonIndex + 1);
|
|
|
|
if (isValidThinkingLevel(suffix)) {
|
|
// Valid thinking level - recurse on prefix and use this level
|
|
const result = parseModelPattern(prefix, availableModels);
|
|
if (result.model) {
|
|
// Only use this thinking level if no warning from inner recursion
|
|
return {
|
|
model: result.model,
|
|
thinkingLevel: result.warning ? undefined : suffix,
|
|
warning: result.warning,
|
|
};
|
|
}
|
|
return result;
|
|
} else {
|
|
// Invalid suffix - recurse on prefix and warn
|
|
const result = parseModelPattern(prefix, availableModels);
|
|
if (result.model) {
|
|
return {
|
|
model: result.model,
|
|
thinkingLevel: undefined,
|
|
warning: `Invalid thinking level "${suffix}" in pattern "${pattern}". Using default instead.`,
|
|
};
|
|
}
|
|
return result;
|
|
}
|
|
}
|
|
|
|
/**
|
|
* Resolve model patterns to actual Model objects with optional thinking levels
|
|
* Format: "pattern:level" where :level is optional
|
|
* For each pattern, finds all matching models and picks the best version:
|
|
* 1. Prefer alias (e.g., claude-sonnet-4-5) over dated versions (claude-sonnet-4-5-20250929)
|
|
* 2. If no alias, pick the latest dated version
|
|
*
|
|
* Supports models with colons in their IDs (e.g., OpenRouter's model:exacto).
|
|
* The algorithm tries to match the full pattern first, then progressively
|
|
* strips colon-suffixes to find a match.
|
|
*/
|
|
export async function resolveModelScope(patterns: string[], modelRegistry: ModelRegistry): Promise<ScopedModel[]> {
|
|
const availableModels = await modelRegistry.getAvailable();
|
|
const scopedModels: ScopedModel[] = [];
|
|
|
|
for (const pattern of patterns) {
|
|
// Check if pattern contains glob characters
|
|
if (pattern.includes("*") || pattern.includes("?") || pattern.includes("[")) {
|
|
// Extract optional thinking level suffix (e.g., "provider/*:high")
|
|
const colonIdx = pattern.lastIndexOf(":");
|
|
let globPattern = pattern;
|
|
let thinkingLevel: ThinkingLevel | undefined;
|
|
|
|
if (colonIdx !== -1) {
|
|
const suffix = pattern.substring(colonIdx + 1);
|
|
if (isValidThinkingLevel(suffix)) {
|
|
thinkingLevel = suffix;
|
|
globPattern = pattern.substring(0, colonIdx);
|
|
}
|
|
}
|
|
|
|
// Match against "provider/modelId" format OR just model ID
|
|
// This allows "*sonnet*" to match without requiring "anthropic/*sonnet*"
|
|
const matchingModels = availableModels.filter((m) => {
|
|
const fullId = `${m.provider}/${m.id}`;
|
|
return minimatch(fullId, globPattern, { nocase: true }) || minimatch(m.id, globPattern, { nocase: true });
|
|
});
|
|
|
|
if (matchingModels.length === 0) {
|
|
console.warn(chalk.yellow(`Warning: No models match pattern "${pattern}"`));
|
|
continue;
|
|
}
|
|
|
|
for (const model of matchingModels) {
|
|
if (!scopedModels.find((sm) => modelsAreEqual(sm.model, model))) {
|
|
scopedModels.push({ model, thinkingLevel });
|
|
}
|
|
}
|
|
continue;
|
|
}
|
|
|
|
const { model, thinkingLevel, warning } = parseModelPattern(pattern, availableModels);
|
|
|
|
if (warning) {
|
|
console.warn(chalk.yellow(`Warning: ${warning}`));
|
|
}
|
|
|
|
if (!model) {
|
|
console.warn(chalk.yellow(`Warning: No models match pattern "${pattern}"`));
|
|
continue;
|
|
}
|
|
|
|
// Avoid duplicates
|
|
if (!scopedModels.find((sm) => modelsAreEqual(sm.model, model))) {
|
|
scopedModels.push({ model, thinkingLevel });
|
|
}
|
|
}
|
|
|
|
return scopedModels;
|
|
}
|
|
|
|
export interface InitialModelResult {
|
|
model: Model<Api> | undefined;
|
|
thinkingLevel: ThinkingLevel;
|
|
fallbackMessage: string | undefined;
|
|
}
|
|
|
|
/**
|
|
* Find the initial model to use based on priority:
|
|
* 1. CLI args (provider + model)
|
|
* 2. First model from scoped models (if not continuing/resuming)
|
|
* 3. Restored from session (if continuing/resuming)
|
|
* 4. Saved default from settings
|
|
* 5. First available model with valid API key
|
|
*/
|
|
export async function findInitialModel(options: {
|
|
cliProvider?: string;
|
|
cliModel?: string;
|
|
scopedModels: ScopedModel[];
|
|
isContinuing: boolean;
|
|
defaultProvider?: string;
|
|
defaultModelId?: string;
|
|
defaultThinkingLevel?: ThinkingLevel;
|
|
modelRegistry: ModelRegistry;
|
|
}): Promise<InitialModelResult> {
|
|
const {
|
|
cliProvider,
|
|
cliModel,
|
|
scopedModels,
|
|
isContinuing,
|
|
defaultProvider,
|
|
defaultModelId,
|
|
defaultThinkingLevel,
|
|
modelRegistry,
|
|
} = options;
|
|
|
|
let model: Model<Api> | undefined;
|
|
let thinkingLevel: ThinkingLevel = "off";
|
|
|
|
// 1. CLI args take priority
|
|
if (cliProvider && cliModel) {
|
|
const found = modelRegistry.find(cliProvider, cliModel);
|
|
if (!found) {
|
|
console.error(chalk.red(`Model ${cliProvider}/${cliModel} not found`));
|
|
process.exit(1);
|
|
}
|
|
return { model: found, thinkingLevel: "off", fallbackMessage: undefined };
|
|
}
|
|
|
|
// 2. Use first model from scoped models (skip if continuing/resuming)
|
|
if (scopedModels.length > 0 && !isContinuing) {
|
|
return {
|
|
model: scopedModels[0].model,
|
|
thinkingLevel: scopedModels[0].thinkingLevel ?? defaultThinkingLevel ?? "off",
|
|
fallbackMessage: undefined,
|
|
};
|
|
}
|
|
|
|
// 3. Try saved default from settings
|
|
if (defaultProvider && defaultModelId) {
|
|
const found = modelRegistry.find(defaultProvider, defaultModelId);
|
|
if (found) {
|
|
model = found;
|
|
if (defaultThinkingLevel) {
|
|
thinkingLevel = defaultThinkingLevel;
|
|
}
|
|
return { model, thinkingLevel, fallbackMessage: undefined };
|
|
}
|
|
}
|
|
|
|
// 4. Try first available model with valid API key
|
|
const availableModels = await modelRegistry.getAvailable();
|
|
|
|
if (availableModels.length > 0) {
|
|
// Try to find a default model from known providers
|
|
for (const provider of Object.keys(defaultModelPerProvider) as KnownProvider[]) {
|
|
const defaultId = defaultModelPerProvider[provider];
|
|
const match = availableModels.find((m) => m.provider === provider && m.id === defaultId);
|
|
if (match) {
|
|
return { model: match, thinkingLevel: "off", fallbackMessage: undefined };
|
|
}
|
|
}
|
|
|
|
// If no default found, use first available
|
|
return { model: availableModels[0], thinkingLevel: "off", fallbackMessage: undefined };
|
|
}
|
|
|
|
// 5. No model found
|
|
return { model: undefined, thinkingLevel: "off", fallbackMessage: undefined };
|
|
}
|
|
|
|
/**
|
|
* Restore model from session, with fallback to available models
|
|
*/
|
|
export async function restoreModelFromSession(
|
|
savedProvider: string,
|
|
savedModelId: string,
|
|
currentModel: Model<Api> | undefined,
|
|
shouldPrintMessages: boolean,
|
|
modelRegistry: ModelRegistry,
|
|
): Promise<{ model: Model<Api> | undefined; fallbackMessage: string | undefined }> {
|
|
const restoredModel = modelRegistry.find(savedProvider, savedModelId);
|
|
|
|
// Check if restored model exists and has a valid API key
|
|
const hasApiKey = restoredModel ? !!(await modelRegistry.getApiKey(restoredModel)) : false;
|
|
|
|
if (restoredModel && hasApiKey) {
|
|
if (shouldPrintMessages) {
|
|
console.log(chalk.dim(`Restored model: ${savedProvider}/${savedModelId}`));
|
|
}
|
|
return { model: restoredModel, fallbackMessage: undefined };
|
|
}
|
|
|
|
// Model not found or no API key - fall back
|
|
const reason = !restoredModel ? "model no longer exists" : "no API key available";
|
|
|
|
if (shouldPrintMessages) {
|
|
console.error(chalk.yellow(`Warning: Could not restore model ${savedProvider}/${savedModelId} (${reason}).`));
|
|
}
|
|
|
|
// If we already have a model, use it as fallback
|
|
if (currentModel) {
|
|
if (shouldPrintMessages) {
|
|
console.log(chalk.dim(`Falling back to: ${currentModel.provider}/${currentModel.id}`));
|
|
}
|
|
return {
|
|
model: currentModel,
|
|
fallbackMessage: `Could not restore model ${savedProvider}/${savedModelId} (${reason}). Using ${currentModel.provider}/${currentModel.id}.`,
|
|
};
|
|
}
|
|
|
|
// Try to find any available model
|
|
const availableModels = await modelRegistry.getAvailable();
|
|
|
|
if (availableModels.length > 0) {
|
|
// Try to find a default model from known providers
|
|
let fallbackModel: Model<Api> | undefined;
|
|
for (const provider of Object.keys(defaultModelPerProvider) as KnownProvider[]) {
|
|
const defaultId = defaultModelPerProvider[provider];
|
|
const match = availableModels.find((m) => m.provider === provider && m.id === defaultId);
|
|
if (match) {
|
|
fallbackModel = match;
|
|
break;
|
|
}
|
|
}
|
|
|
|
// If no default found, use first available
|
|
if (!fallbackModel) {
|
|
fallbackModel = availableModels[0];
|
|
}
|
|
|
|
if (shouldPrintMessages) {
|
|
console.log(chalk.dim(`Falling back to: ${fallbackModel.provider}/${fallbackModel.id}`));
|
|
}
|
|
|
|
return {
|
|
model: fallbackModel,
|
|
fallbackMessage: `Could not restore model ${savedProvider}/${savedModelId} (${reason}). Using ${fallbackModel.provider}/${fallbackModel.id}.`,
|
|
};
|
|
}
|
|
|
|
// No models available
|
|
return { model: undefined, fallbackMessage: undefined };
|
|
}
|