mirror of
https://github.com/getcompanion-ai/co-mono.git
synced 2026-04-21 13:00:33 +00:00
refactor(oauth): add provider registry
This commit is contained in:
parent
89636cfe6e
commit
3256d3c083
19 changed files with 655 additions and 291 deletions
|
|
@ -2,13 +2,8 @@
|
||||||
|
|
||||||
import { existsSync, readFileSync, writeFileSync } from "fs";
|
import { existsSync, readFileSync, writeFileSync } from "fs";
|
||||||
import { createInterface } from "readline";
|
import { createInterface } from "readline";
|
||||||
import { loginAnthropic } from "./utils/oauth/anthropic.js";
|
import { getOAuthProvider, getOAuthProviders } from "./utils/oauth/index.js";
|
||||||
import { loginGitHubCopilot } from "./utils/oauth/github-copilot.js";
|
import type { OAuthCredentials, OAuthProviderId } from "./utils/oauth/types.js";
|
||||||
import { loginAntigravity } from "./utils/oauth/google-antigravity.js";
|
|
||||||
import { loginGeminiCli } from "./utils/oauth/google-gemini-cli.js";
|
|
||||||
import { getOAuthProviders } from "./utils/oauth/index.js";
|
|
||||||
import { loginOpenAICodex } from "./utils/oauth/openai-codex.js";
|
|
||||||
import type { OAuthCredentials, OAuthProvider } from "./utils/oauth/types.js";
|
|
||||||
|
|
||||||
const AUTH_FILE = "auth.json";
|
const AUTH_FILE = "auth.json";
|
||||||
const PROVIDERS = getOAuthProviders();
|
const PROVIDERS = getOAuthProviders();
|
||||||
|
|
@ -30,78 +25,31 @@ function saveAuth(auth: Record<string, { type: "oauth" } & OAuthCredentials>): v
|
||||||
writeFileSync(AUTH_FILE, JSON.stringify(auth, null, 2), "utf-8");
|
writeFileSync(AUTH_FILE, JSON.stringify(auth, null, 2), "utf-8");
|
||||||
}
|
}
|
||||||
|
|
||||||
async function login(provider: OAuthProvider): Promise<void> {
|
async function login(providerId: OAuthProviderId): Promise<void> {
|
||||||
const rl = createInterface({ input: process.stdin, output: process.stdout });
|
const provider = getOAuthProvider(providerId);
|
||||||
|
if (!provider) {
|
||||||
|
console.error(`Unknown provider: ${providerId}`);
|
||||||
|
process.exit(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
const rl = createInterface({ input: process.stdin, output: process.stdout });
|
||||||
const promptFn = (msg: string) => prompt(rl, `${msg} `);
|
const promptFn = (msg: string) => prompt(rl, `${msg} `);
|
||||||
|
|
||||||
try {
|
try {
|
||||||
let credentials: OAuthCredentials;
|
const credentials = await provider.login({
|
||||||
|
onAuth: (info) => {
|
||||||
switch (provider) {
|
console.log(`\nOpen this URL in your browser:\n${info.url}`);
|
||||||
case "anthropic":
|
if (info.instructions) console.log(info.instructions);
|
||||||
credentials = await loginAnthropic(
|
console.log();
|
||||||
(url) => {
|
},
|
||||||
console.log(`\nOpen this URL in your browser:\n${url}\n`);
|
onPrompt: async (p) => {
|
||||||
},
|
return await promptFn(`${p.message}${p.placeholder ? ` (${p.placeholder})` : ""}:`);
|
||||||
async () => {
|
},
|
||||||
return await promptFn("Paste the authorization code:");
|
onProgress: (msg) => console.log(msg),
|
||||||
},
|
});
|
||||||
);
|
|
||||||
break;
|
|
||||||
|
|
||||||
case "github-copilot":
|
|
||||||
credentials = await loginGitHubCopilot({
|
|
||||||
onAuth: (url, instructions) => {
|
|
||||||
console.log(`\nOpen this URL in your browser:\n${url}`);
|
|
||||||
if (instructions) console.log(instructions);
|
|
||||||
console.log();
|
|
||||||
},
|
|
||||||
onPrompt: async (p) => {
|
|
||||||
return await promptFn(`${p.message}${p.placeholder ? ` (${p.placeholder})` : ""}:`);
|
|
||||||
},
|
|
||||||
onProgress: (msg) => console.log(msg),
|
|
||||||
});
|
|
||||||
break;
|
|
||||||
|
|
||||||
case "google-gemini-cli":
|
|
||||||
credentials = await loginGeminiCli(
|
|
||||||
(info) => {
|
|
||||||
console.log(`\nOpen this URL in your browser:\n${info.url}`);
|
|
||||||
if (info.instructions) console.log(info.instructions);
|
|
||||||
console.log();
|
|
||||||
},
|
|
||||||
(msg) => console.log(msg),
|
|
||||||
);
|
|
||||||
break;
|
|
||||||
|
|
||||||
case "google-antigravity":
|
|
||||||
credentials = await loginAntigravity(
|
|
||||||
(info) => {
|
|
||||||
console.log(`\nOpen this URL in your browser:\n${info.url}`);
|
|
||||||
if (info.instructions) console.log(info.instructions);
|
|
||||||
console.log();
|
|
||||||
},
|
|
||||||
(msg) => console.log(msg),
|
|
||||||
);
|
|
||||||
break;
|
|
||||||
case "openai-codex":
|
|
||||||
credentials = await loginOpenAICodex({
|
|
||||||
onAuth: (info) => {
|
|
||||||
console.log(`\nOpen this URL in your browser:\n${info.url}`);
|
|
||||||
if (info.instructions) console.log(info.instructions);
|
|
||||||
console.log();
|
|
||||||
},
|
|
||||||
onPrompt: async (p) => {
|
|
||||||
return await promptFn(`${p.message}${p.placeholder ? ` (${p.placeholder})` : ""}:`);
|
|
||||||
},
|
|
||||||
onProgress: (msg) => console.log(msg),
|
|
||||||
});
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
const auth = loadAuth();
|
const auth = loadAuth();
|
||||||
auth[provider] = { type: "oauth", ...credentials };
|
auth[providerId] = { type: "oauth", ...credentials };
|
||||||
saveAuth(auth);
|
saveAuth(auth);
|
||||||
|
|
||||||
console.log(`\nCredentials saved to ${AUTH_FILE}`);
|
console.log(`\nCredentials saved to ${AUTH_FILE}`);
|
||||||
|
|
@ -115,6 +63,7 @@ async function main(): Promise<void> {
|
||||||
const command = args[0];
|
const command = args[0];
|
||||||
|
|
||||||
if (!command || command === "help" || command === "--help" || command === "-h") {
|
if (!command || command === "help" || command === "--help" || command === "-h") {
|
||||||
|
const providerList = PROVIDERS.map((p) => ` ${p.id.padEnd(20)} ${p.name}`).join("\n");
|
||||||
console.log(`Usage: npx @mariozechner/pi-ai <command> [provider]
|
console.log(`Usage: npx @mariozechner/pi-ai <command> [provider]
|
||||||
|
|
||||||
Commands:
|
Commands:
|
||||||
|
|
@ -122,11 +71,7 @@ Commands:
|
||||||
list List available providers
|
list List available providers
|
||||||
|
|
||||||
Providers:
|
Providers:
|
||||||
anthropic Anthropic (Claude Pro/Max)
|
${providerList}
|
||||||
github-copilot GitHub Copilot
|
|
||||||
google-gemini-cli Google Gemini CLI
|
|
||||||
google-antigravity Antigravity (Gemini 3, Claude, GPT-OSS)
|
|
||||||
openai-codex OpenAI Codex (ChatGPT Plus/Pro)
|
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
npx @mariozechner/pi-ai login # interactive provider selection
|
npx @mariozechner/pi-ai login # interactive provider selection
|
||||||
|
|
@ -145,7 +90,7 @@ Examples:
|
||||||
}
|
}
|
||||||
|
|
||||||
if (command === "login") {
|
if (command === "login") {
|
||||||
let provider = args[1] as OAuthProvider | undefined;
|
let provider = args[1] as OAuthProviderId | undefined;
|
||||||
|
|
||||||
if (!provider) {
|
if (!provider) {
|
||||||
const rl = createInterface({ input: process.stdin, output: process.stdout });
|
const rl = createInterface({ input: process.stdin, output: process.stdout });
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
import { generatePKCE } from "./pkce.js";
|
import { generatePKCE } from "./pkce.js";
|
||||||
import type { OAuthCredentials } from "./types.js";
|
import type { OAuthCredentials, OAuthLoginCallbacks, OAuthProviderInterface } from "./types.js";
|
||||||
|
|
||||||
const decode = (s: string) => atob(s);
|
const decode = (s: string) => atob(s);
|
||||||
const CLIENT_ID = decode("OWQxYzI1MGEtZTYxYi00NGQ5LTg4ZWQtNTk0NGQxOTYyZjVl");
|
const CLIENT_ID = decode("OWQxYzI1MGEtZTYxYi00NGQ5LTg4ZWQtNTk0NGQxOTYyZjVl");
|
||||||
|
|
@ -116,3 +116,23 @@ export async function refreshAnthropicToken(refreshToken: string): Promise<OAuth
|
||||||
expires: Date.now() + data.expires_in * 1000 - 5 * 60 * 1000,
|
expires: Date.now() + data.expires_in * 1000 - 5 * 60 * 1000,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export const anthropicOAuthProvider: OAuthProviderInterface = {
|
||||||
|
id: "anthropic",
|
||||||
|
name: "Anthropic (Claude Pro/Max)",
|
||||||
|
|
||||||
|
async login(callbacks: OAuthLoginCallbacks): Promise<OAuthCredentials> {
|
||||||
|
return loginAnthropic(
|
||||||
|
(url) => callbacks.onAuth({ url }),
|
||||||
|
() => callbacks.onPrompt({ message: "Paste the authorization code:" }),
|
||||||
|
);
|
||||||
|
},
|
||||||
|
|
||||||
|
async refreshToken(credentials: OAuthCredentials): Promise<OAuthCredentials> {
|
||||||
|
return refreshAnthropicToken(credentials.refresh);
|
||||||
|
},
|
||||||
|
|
||||||
|
getApiKey(credentials: OAuthCredentials): string {
|
||||||
|
return credentials.access;
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,12 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
import { getModels } from "../../models.js";
|
import { getModels } from "../../models.js";
|
||||||
import type { OAuthCredentials } from "./types.js";
|
import type { Api, Model } from "../../types.js";
|
||||||
|
import type { OAuthCredentials, OAuthLoginCallbacks, OAuthProviderInterface } from "./types.js";
|
||||||
|
|
||||||
|
type CopilotCredentials = OAuthCredentials & {
|
||||||
|
enterpriseUrl?: string;
|
||||||
|
};
|
||||||
|
|
||||||
const decode = (s: string) => atob(s);
|
const decode = (s: string) => atob(s);
|
||||||
const CLIENT_ID = decode("SXYxLmI1MDdhMDhjODdlY2ZlOTg=");
|
const CLIENT_ID = decode("SXYxLmI1MDdhMDhjODdlY2ZlOTg=");
|
||||||
|
|
@ -344,3 +349,33 @@ export async function loginGitHubCopilot(options: {
|
||||||
await enableAllGitHubCopilotModels(credentials.access, enterpriseDomain ?? undefined);
|
await enableAllGitHubCopilotModels(credentials.access, enterpriseDomain ?? undefined);
|
||||||
return credentials;
|
return credentials;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export const githubCopilotOAuthProvider: OAuthProviderInterface = {
|
||||||
|
id: "github-copilot",
|
||||||
|
name: "GitHub Copilot",
|
||||||
|
|
||||||
|
async login(callbacks: OAuthLoginCallbacks): Promise<OAuthCredentials> {
|
||||||
|
return loginGitHubCopilot({
|
||||||
|
onAuth: (url, instructions) => callbacks.onAuth({ url, instructions }),
|
||||||
|
onPrompt: callbacks.onPrompt,
|
||||||
|
onProgress: callbacks.onProgress,
|
||||||
|
signal: callbacks.signal,
|
||||||
|
});
|
||||||
|
},
|
||||||
|
|
||||||
|
async refreshToken(credentials: OAuthCredentials): Promise<OAuthCredentials> {
|
||||||
|
const creds = credentials as CopilotCredentials;
|
||||||
|
return refreshGitHubCopilotToken(creds.refresh, creds.enterpriseUrl);
|
||||||
|
},
|
||||||
|
|
||||||
|
getApiKey(credentials: OAuthCredentials): string {
|
||||||
|
return credentials.access;
|
||||||
|
},
|
||||||
|
|
||||||
|
modifyModels(models: Model<Api>[], credentials: OAuthCredentials): Model<Api>[] {
|
||||||
|
const creds = credentials as CopilotCredentials;
|
||||||
|
const domain = creds.enterpriseUrl ? (normalizeDomain(creds.enterpriseUrl) ?? undefined) : undefined;
|
||||||
|
const baseUrl = getGitHubCopilotBaseUrl(creds.access, domain);
|
||||||
|
return models.map((m) => (m.provider === "github-copilot" ? { ...m, baseUrl } : m));
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,11 @@
|
||||||
|
|
||||||
import type { Server } from "http";
|
import type { Server } from "http";
|
||||||
import { generatePKCE } from "./pkce.js";
|
import { generatePKCE } from "./pkce.js";
|
||||||
import type { OAuthCredentials } from "./types.js";
|
import type { OAuthCredentials, OAuthLoginCallbacks, OAuthProviderInterface } from "./types.js";
|
||||||
|
|
||||||
|
type AntigravityCredentials = OAuthCredentials & {
|
||||||
|
projectId: string;
|
||||||
|
};
|
||||||
|
|
||||||
// Antigravity OAuth credentials (different from Gemini CLI)
|
// Antigravity OAuth credentials (different from Gemini CLI)
|
||||||
const decode = (s: string) => atob(s);
|
const decode = (s: string) => atob(s);
|
||||||
|
|
@ -411,3 +415,26 @@ export async function loginAntigravity(
|
||||||
server.server.close();
|
server.server.close();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export const antigravityOAuthProvider: OAuthProviderInterface = {
|
||||||
|
id: "google-antigravity",
|
||||||
|
name: "Antigravity (Gemini 3, Claude, GPT-OSS)",
|
||||||
|
usesCallbackServer: true,
|
||||||
|
|
||||||
|
async login(callbacks: OAuthLoginCallbacks): Promise<OAuthCredentials> {
|
||||||
|
return loginAntigravity(callbacks.onAuth, callbacks.onProgress, callbacks.onManualCodeInput);
|
||||||
|
},
|
||||||
|
|
||||||
|
async refreshToken(credentials: OAuthCredentials): Promise<OAuthCredentials> {
|
||||||
|
const creds = credentials as AntigravityCredentials;
|
||||||
|
if (!creds.projectId) {
|
||||||
|
throw new Error("Antigravity credentials missing projectId");
|
||||||
|
}
|
||||||
|
return refreshAntigravityToken(creds.refresh, creds.projectId);
|
||||||
|
},
|
||||||
|
|
||||||
|
getApiKey(credentials: OAuthCredentials): string {
|
||||||
|
const creds = credentials as AntigravityCredentials;
|
||||||
|
return JSON.stringify({ token: creds.access, projectId: creds.projectId });
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,11 @@
|
||||||
|
|
||||||
import type { Server } from "http";
|
import type { Server } from "http";
|
||||||
import { generatePKCE } from "./pkce.js";
|
import { generatePKCE } from "./pkce.js";
|
||||||
import type { OAuthCredentials } from "./types.js";
|
import type { OAuthCredentials, OAuthLoginCallbacks, OAuthProviderInterface } from "./types.js";
|
||||||
|
|
||||||
|
type GeminiCredentials = OAuthCredentials & {
|
||||||
|
projectId: string;
|
||||||
|
};
|
||||||
|
|
||||||
const decode = (s: string) => atob(s);
|
const decode = (s: string) => atob(s);
|
||||||
const CLIENT_ID = decode(
|
const CLIENT_ID = decode(
|
||||||
|
|
@ -553,3 +557,26 @@ export async function loginGeminiCli(
|
||||||
server.server.close();
|
server.server.close();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export const geminiCliOAuthProvider: OAuthProviderInterface = {
|
||||||
|
id: "google-gemini-cli",
|
||||||
|
name: "Google Cloud Code Assist (Gemini CLI)",
|
||||||
|
usesCallbackServer: true,
|
||||||
|
|
||||||
|
async login(callbacks: OAuthLoginCallbacks): Promise<OAuthCredentials> {
|
||||||
|
return loginGeminiCli(callbacks.onAuth, callbacks.onProgress, callbacks.onManualCodeInput);
|
||||||
|
},
|
||||||
|
|
||||||
|
async refreshToken(credentials: OAuthCredentials): Promise<OAuthCredentials> {
|
||||||
|
const creds = credentials as GeminiCredentials;
|
||||||
|
if (!creds.projectId) {
|
||||||
|
throw new Error("Google Cloud credentials missing projectId");
|
||||||
|
}
|
||||||
|
return refreshGoogleCloudToken(creds.refresh, creds.projectId);
|
||||||
|
},
|
||||||
|
|
||||||
|
getApiKey(credentials: OAuthCredentials): string {
|
||||||
|
const creds = credentials as GeminiCredentials;
|
||||||
|
return JSON.stringify({ token: creds.access, projectId: creds.projectId });
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
|
||||||
|
|
@ -10,100 +10,111 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
// Anthropic
|
// Anthropic
|
||||||
export { loginAnthropic, refreshAnthropicToken } from "./anthropic.js";
|
export { anthropicOAuthProvider, loginAnthropic, refreshAnthropicToken } from "./anthropic.js";
|
||||||
// GitHub Copilot
|
// GitHub Copilot
|
||||||
export {
|
export {
|
||||||
getGitHubCopilotBaseUrl,
|
getGitHubCopilotBaseUrl,
|
||||||
|
githubCopilotOAuthProvider,
|
||||||
loginGitHubCopilot,
|
loginGitHubCopilot,
|
||||||
normalizeDomain,
|
normalizeDomain,
|
||||||
refreshGitHubCopilotToken,
|
refreshGitHubCopilotToken,
|
||||||
} from "./github-copilot.js";
|
} from "./github-copilot.js";
|
||||||
// Google Antigravity
|
// Google Antigravity
|
||||||
export {
|
export { antigravityOAuthProvider, loginAntigravity, refreshAntigravityToken } from "./google-antigravity.js";
|
||||||
loginAntigravity,
|
|
||||||
refreshAntigravityToken,
|
|
||||||
} from "./google-antigravity.js";
|
|
||||||
// Google Gemini CLI
|
// Google Gemini CLI
|
||||||
export {
|
export { geminiCliOAuthProvider, loginGeminiCli, refreshGoogleCloudToken } from "./google-gemini-cli.js";
|
||||||
loginGeminiCli,
|
|
||||||
refreshGoogleCloudToken,
|
|
||||||
} from "./google-gemini-cli.js";
|
|
||||||
// OpenAI Codex (ChatGPT OAuth)
|
// OpenAI Codex (ChatGPT OAuth)
|
||||||
export {
|
export { loginOpenAICodex, openaiCodexOAuthProvider, refreshOpenAICodexToken } from "./openai-codex.js";
|
||||||
loginOpenAICodex,
|
|
||||||
refreshOpenAICodexToken,
|
|
||||||
} from "./openai-codex.js";
|
|
||||||
|
|
||||||
export * from "./types.js";
|
export * from "./types.js";
|
||||||
|
|
||||||
// ============================================================================
|
// ============================================================================
|
||||||
// High-level API
|
// Provider Registry
|
||||||
// ============================================================================
|
// ============================================================================
|
||||||
|
|
||||||
import { refreshAnthropicToken } from "./anthropic.js";
|
import { anthropicOAuthProvider } from "./anthropic.js";
|
||||||
import { refreshGitHubCopilotToken } from "./github-copilot.js";
|
import { githubCopilotOAuthProvider } from "./github-copilot.js";
|
||||||
import { refreshAntigravityToken } from "./google-antigravity.js";
|
import { antigravityOAuthProvider } from "./google-antigravity.js";
|
||||||
import { refreshGoogleCloudToken } from "./google-gemini-cli.js";
|
import { geminiCliOAuthProvider } from "./google-gemini-cli.js";
|
||||||
import { refreshOpenAICodexToken } from "./openai-codex.js";
|
import { openaiCodexOAuthProvider } from "./openai-codex.js";
|
||||||
import type { OAuthCredentials, OAuthProvider, OAuthProviderInfo } from "./types.js";
|
import type { OAuthCredentials, OAuthProviderId, OAuthProviderInfo, OAuthProviderInterface } from "./types.js";
|
||||||
|
|
||||||
|
const oauthProviderRegistry = new Map<string, OAuthProviderInterface>([
|
||||||
|
[anthropicOAuthProvider.id, anthropicOAuthProvider],
|
||||||
|
[githubCopilotOAuthProvider.id, githubCopilotOAuthProvider],
|
||||||
|
[geminiCliOAuthProvider.id, geminiCliOAuthProvider],
|
||||||
|
[antigravityOAuthProvider.id, antigravityOAuthProvider],
|
||||||
|
[openaiCodexOAuthProvider.id, openaiCodexOAuthProvider],
|
||||||
|
]);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get an OAuth provider by ID
|
||||||
|
*/
|
||||||
|
export function getOAuthProvider(id: OAuthProviderId): OAuthProviderInterface | undefined {
|
||||||
|
return oauthProviderRegistry.get(id);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Register a custom OAuth provider
|
||||||
|
*/
|
||||||
|
export function registerOAuthProvider(provider: OAuthProviderInterface): void {
|
||||||
|
oauthProviderRegistry.set(provider.id, provider);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get all registered OAuth providers
|
||||||
|
*/
|
||||||
|
export function getOAuthProviders(): OAuthProviderInterface[] {
|
||||||
|
return Array.from(oauthProviderRegistry.values());
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @deprecated Use getOAuthProviders() which returns OAuthProviderInterface[]
|
||||||
|
*/
|
||||||
|
export function getOAuthProviderInfoList(): OAuthProviderInfo[] {
|
||||||
|
return getOAuthProviders().map((p) => ({
|
||||||
|
id: p.id,
|
||||||
|
name: p.name,
|
||||||
|
available: true,
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// High-level API (uses provider registry)
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Refresh token for any OAuth provider.
|
* Refresh token for any OAuth provider.
|
||||||
* Saves the new credentials and returns the new access token.
|
* @deprecated Use getOAuthProvider(id).refreshToken() instead
|
||||||
*/
|
*/
|
||||||
export async function refreshOAuthToken(
|
export async function refreshOAuthToken(
|
||||||
provider: OAuthProvider,
|
providerId: OAuthProviderId,
|
||||||
credentials: OAuthCredentials,
|
credentials: OAuthCredentials,
|
||||||
): Promise<OAuthCredentials> {
|
): Promise<OAuthCredentials> {
|
||||||
if (!credentials) {
|
const provider = getOAuthProvider(providerId);
|
||||||
throw new Error(`No OAuth credentials found for ${provider}`);
|
if (!provider) {
|
||||||
|
throw new Error(`Unknown OAuth provider: ${providerId}`);
|
||||||
}
|
}
|
||||||
|
return provider.refreshToken(credentials);
|
||||||
let newCredentials: OAuthCredentials;
|
|
||||||
|
|
||||||
switch (provider) {
|
|
||||||
case "anthropic":
|
|
||||||
newCredentials = await refreshAnthropicToken(credentials.refresh);
|
|
||||||
break;
|
|
||||||
case "github-copilot":
|
|
||||||
newCredentials = await refreshGitHubCopilotToken(credentials.refresh, credentials.enterpriseUrl);
|
|
||||||
break;
|
|
||||||
case "google-gemini-cli":
|
|
||||||
if (!credentials.projectId) {
|
|
||||||
throw new Error("Google Cloud credentials missing projectId");
|
|
||||||
}
|
|
||||||
newCredentials = await refreshGoogleCloudToken(credentials.refresh, credentials.projectId);
|
|
||||||
break;
|
|
||||||
case "google-antigravity":
|
|
||||||
if (!credentials.projectId) {
|
|
||||||
throw new Error("Antigravity credentials missing projectId");
|
|
||||||
}
|
|
||||||
newCredentials = await refreshAntigravityToken(credentials.refresh, credentials.projectId);
|
|
||||||
break;
|
|
||||||
case "openai-codex":
|
|
||||||
newCredentials = await refreshOpenAICodexToken(credentials.refresh);
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
throw new Error(`Unknown OAuth provider: ${provider}`);
|
|
||||||
}
|
|
||||||
|
|
||||||
return newCredentials;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get API key for a provider from OAuth credentials.
|
* Get API key for a provider from OAuth credentials.
|
||||||
* Automatically refreshes expired tokens.
|
* Automatically refreshes expired tokens.
|
||||||
*
|
*
|
||||||
* For google-gemini-cli and antigravity, returns JSON-encoded { token, projectId }
|
* @returns API key string and updated credentials, or null if no credentials
|
||||||
*
|
|
||||||
* @returns API key string, or null if no credentials
|
|
||||||
* @throws Error if refresh fails
|
* @throws Error if refresh fails
|
||||||
*/
|
*/
|
||||||
export async function getOAuthApiKey(
|
export async function getOAuthApiKey(
|
||||||
provider: OAuthProvider,
|
providerId: OAuthProviderId,
|
||||||
credentials: Record<string, OAuthCredentials>,
|
credentials: Record<string, OAuthCredentials>,
|
||||||
): Promise<{ newCredentials: OAuthCredentials; apiKey: string } | null> {
|
): Promise<{ newCredentials: OAuthCredentials; apiKey: string } | null> {
|
||||||
let creds = credentials[provider];
|
const provider = getOAuthProvider(providerId);
|
||||||
|
if (!provider) {
|
||||||
|
throw new Error(`Unknown OAuth provider: ${providerId}`);
|
||||||
|
}
|
||||||
|
|
||||||
|
let creds = credentials[providerId];
|
||||||
if (!creds) {
|
if (!creds) {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
@ -111,47 +122,12 @@ export async function getOAuthApiKey(
|
||||||
// Refresh if expired
|
// Refresh if expired
|
||||||
if (Date.now() >= creds.expires) {
|
if (Date.now() >= creds.expires) {
|
||||||
try {
|
try {
|
||||||
creds = await refreshOAuthToken(provider, creds);
|
creds = await provider.refreshToken(creds);
|
||||||
} catch (_error) {
|
} catch (_error) {
|
||||||
throw new Error(`Failed to refresh OAuth token for ${provider}`);
|
throw new Error(`Failed to refresh OAuth token for ${providerId}`);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// For providers that need projectId, return JSON
|
const apiKey = provider.getApiKey(creds);
|
||||||
const needsProjectId = provider === "google-gemini-cli" || provider === "google-antigravity";
|
|
||||||
const apiKey = needsProjectId ? JSON.stringify({ token: creds.access, projectId: creds.projectId }) : creds.access;
|
|
||||||
return { newCredentials: creds, apiKey };
|
return { newCredentials: creds, apiKey };
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Get list of OAuth providers
|
|
||||||
*/
|
|
||||||
export function getOAuthProviders(): OAuthProviderInfo[] {
|
|
||||||
return [
|
|
||||||
{
|
|
||||||
id: "anthropic",
|
|
||||||
name: "Anthropic (Claude Pro/Max)",
|
|
||||||
available: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
id: "openai-codex",
|
|
||||||
name: "ChatGPT Plus/Pro (Codex Subscription)",
|
|
||||||
available: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
id: "github-copilot",
|
|
||||||
name: "GitHub Copilot",
|
|
||||||
available: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
id: "google-gemini-cli",
|
|
||||||
name: "Google Cloud Code Assist (Gemini CLI)",
|
|
||||||
available: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
id: "google-antigravity",
|
|
||||||
name: "Antigravity (Gemini 3, Claude, GPT-OSS)",
|
|
||||||
available: true,
|
|
||||||
},
|
|
||||||
];
|
|
||||||
}
|
|
||||||
|
|
|
||||||
|
|
@ -18,7 +18,7 @@ if (typeof process !== "undefined" && (process.versions?.node || process.version
|
||||||
}
|
}
|
||||||
|
|
||||||
import { generatePKCE } from "./pkce.js";
|
import { generatePKCE } from "./pkce.js";
|
||||||
import type { OAuthCredentials, OAuthPrompt } from "./types.js";
|
import type { OAuthCredentials, OAuthLoginCallbacks, OAuthPrompt, OAuthProviderInterface } from "./types.js";
|
||||||
|
|
||||||
const CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann";
|
const CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann";
|
||||||
const AUTHORIZE_URL = "https://auth.openai.com/oauth/authorize";
|
const AUTHORIZE_URL = "https://auth.openai.com/oauth/authorize";
|
||||||
|
|
@ -430,3 +430,26 @@ export async function refreshOpenAICodexToken(refreshToken: string): Promise<OAu
|
||||||
accountId,
|
accountId,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export const openaiCodexOAuthProvider: OAuthProviderInterface = {
|
||||||
|
id: "openai-codex",
|
||||||
|
name: "ChatGPT Plus/Pro (Codex Subscription)",
|
||||||
|
usesCallbackServer: true,
|
||||||
|
|
||||||
|
async login(callbacks: OAuthLoginCallbacks): Promise<OAuthCredentials> {
|
||||||
|
return loginOpenAICodex({
|
||||||
|
onAuth: callbacks.onAuth,
|
||||||
|
onPrompt: callbacks.onPrompt,
|
||||||
|
onProgress: callbacks.onProgress,
|
||||||
|
onManualCodeInput: callbacks.onManualCodeInput,
|
||||||
|
});
|
||||||
|
},
|
||||||
|
|
||||||
|
async refreshToken(credentials: OAuthCredentials): Promise<OAuthCredentials> {
|
||||||
|
return refreshOpenAICodexToken(credentials.refresh);
|
||||||
|
},
|
||||||
|
|
||||||
|
getApiKey(credentials: OAuthCredentials): string {
|
||||||
|
return credentials.access;
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
|
||||||
|
|
@ -1,19 +1,16 @@
|
||||||
|
import type { Api, Model } from "../../types.js";
|
||||||
|
|
||||||
export type OAuthCredentials = {
|
export type OAuthCredentials = {
|
||||||
refresh: string;
|
refresh: string;
|
||||||
access: string;
|
access: string;
|
||||||
expires: number;
|
expires: number;
|
||||||
enterpriseUrl?: string;
|
[key: string]: unknown;
|
||||||
projectId?: string;
|
|
||||||
email?: string;
|
|
||||||
accountId?: string;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
export type OAuthProvider =
|
export type OAuthProviderId = string;
|
||||||
| "anthropic"
|
|
||||||
| "github-copilot"
|
/** @deprecated Use OAuthProviderId instead */
|
||||||
| "google-gemini-cli"
|
export type OAuthProvider = OAuthProviderId;
|
||||||
| "google-antigravity"
|
|
||||||
| "openai-codex";
|
|
||||||
|
|
||||||
export type OAuthPrompt = {
|
export type OAuthPrompt = {
|
||||||
message: string;
|
message: string;
|
||||||
|
|
@ -26,8 +23,37 @@ export type OAuthAuthInfo = {
|
||||||
instructions?: string;
|
instructions?: string;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export interface OAuthLoginCallbacks {
|
||||||
|
onAuth: (info: OAuthAuthInfo) => void;
|
||||||
|
onPrompt: (prompt: OAuthPrompt) => Promise<string>;
|
||||||
|
onProgress?: (message: string) => void;
|
||||||
|
onManualCodeInput?: () => Promise<string>;
|
||||||
|
signal?: AbortSignal;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface OAuthProviderInterface {
|
||||||
|
readonly id: OAuthProviderId;
|
||||||
|
readonly name: string;
|
||||||
|
|
||||||
|
/** Run the login flow, return credentials to persist */
|
||||||
|
login(callbacks: OAuthLoginCallbacks): Promise<OAuthCredentials>;
|
||||||
|
|
||||||
|
/** Whether login uses a local callback server and supports manual code input. */
|
||||||
|
usesCallbackServer?: boolean;
|
||||||
|
|
||||||
|
/** Refresh expired credentials, return updated credentials to persist */
|
||||||
|
refreshToken(credentials: OAuthCredentials): Promise<OAuthCredentials>;
|
||||||
|
|
||||||
|
/** Convert credentials to API key string for the provider */
|
||||||
|
getApiKey(credentials: OAuthCredentials): string;
|
||||||
|
|
||||||
|
/** Optional: modify models for this provider (e.g., update baseUrl) */
|
||||||
|
modifyModels?(models: Model<Api>[], credentials: OAuthCredentials): Model<Api>[];
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @deprecated Use OAuthProviderInterface instead */
|
||||||
export interface OAuthProviderInfo {
|
export interface OAuthProviderInfo {
|
||||||
id: OAuthProvider;
|
id: OAuthProviderId;
|
||||||
name: string;
|
name: string;
|
||||||
available: boolean;
|
available: boolean;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1146,6 +1146,67 @@ pi.events.on("my:event", (data) => { ... });
|
||||||
pi.events.emit("my:event", { ... });
|
pi.events.emit("my:event", { ... });
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### pi.registerProvider(name, config)
|
||||||
|
|
||||||
|
Register or override a model provider dynamically. Useful for proxies, custom endpoints, or team-wide model configurations.
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// Register a new provider with custom models
|
||||||
|
pi.registerProvider("my-proxy", {
|
||||||
|
baseUrl: "https://proxy.example.com",
|
||||||
|
apiKey: "PROXY_API_KEY", // env var name or literal
|
||||||
|
api: "anthropic-messages",
|
||||||
|
models: [
|
||||||
|
{
|
||||||
|
id: "claude-sonnet-4-20250514",
|
||||||
|
name: "Claude 4 Sonnet (proxy)",
|
||||||
|
reasoning: false,
|
||||||
|
input: ["text", "image"],
|
||||||
|
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
|
||||||
|
contextWindow: 200000,
|
||||||
|
maxTokens: 16384
|
||||||
|
}
|
||||||
|
]
|
||||||
|
});
|
||||||
|
|
||||||
|
// Override baseUrl for an existing provider (keeps all models)
|
||||||
|
pi.registerProvider("anthropic", {
|
||||||
|
baseUrl: "https://proxy.example.com"
|
||||||
|
});
|
||||||
|
|
||||||
|
// Register provider with OAuth support for /login
|
||||||
|
pi.registerProvider("corporate-ai", {
|
||||||
|
baseUrl: "https://ai.corp.com",
|
||||||
|
api: "openai-responses",
|
||||||
|
models: [...],
|
||||||
|
oauth: {
|
||||||
|
name: "Corporate AI (SSO)",
|
||||||
|
async login(callbacks) {
|
||||||
|
// Custom OAuth flow
|
||||||
|
callbacks.onAuth({ url: "https://sso.corp.com/..." });
|
||||||
|
const code = await callbacks.onPrompt({ message: "Enter code:" });
|
||||||
|
return { refresh: code, access: code, expires: Date.now() + 3600000 };
|
||||||
|
},
|
||||||
|
async refreshToken(credentials) {
|
||||||
|
// Refresh logic
|
||||||
|
return credentials;
|
||||||
|
},
|
||||||
|
getApiKey(credentials) {
|
||||||
|
return credentials.access;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
```
|
||||||
|
|
||||||
|
**Config options:**
|
||||||
|
- `baseUrl` - API endpoint URL. Required when defining models.
|
||||||
|
- `apiKey` - API key or environment variable name. Required when defining models (unless `oauth` provided).
|
||||||
|
- `api` - API type: `"anthropic-messages"`, `"openai-completions"`, `"openai-responses"`, etc.
|
||||||
|
- `headers` - Custom headers to include in requests.
|
||||||
|
- `authHeader` - If true, adds `Authorization: Bearer` header automatically.
|
||||||
|
- `models` - Array of model definitions. If provided, replaces all existing models for this provider.
|
||||||
|
- `oauth` - OAuth provider config for `/login` support. When provided, the provider appears in the login menu.
|
||||||
|
|
||||||
## State Management
|
## State Management
|
||||||
|
|
||||||
Extensions with state should store it in tool result `details` for proper branching support:
|
Extensions with state should store it in tool result `details` for proper branching support:
|
||||||
|
|
|
||||||
|
|
@ -9,13 +9,11 @@
|
||||||
import {
|
import {
|
||||||
getEnvApiKey,
|
getEnvApiKey,
|
||||||
getOAuthApiKey,
|
getOAuthApiKey,
|
||||||
loginAnthropic,
|
getOAuthProvider,
|
||||||
loginAntigravity,
|
getOAuthProviders,
|
||||||
loginGeminiCli,
|
|
||||||
loginGitHubCopilot,
|
|
||||||
loginOpenAICodex,
|
|
||||||
type OAuthCredentials,
|
type OAuthCredentials,
|
||||||
type OAuthProvider,
|
type OAuthLoginCallbacks,
|
||||||
|
type OAuthProviderId,
|
||||||
} from "@mariozechner/pi-ai";
|
} from "@mariozechner/pi-ai";
|
||||||
import { chmodSync, existsSync, mkdirSync, readFileSync, writeFileSync } from "fs";
|
import { chmodSync, existsSync, mkdirSync, readFileSync, writeFileSync } from "fs";
|
||||||
import { dirname, join } from "path";
|
import { dirname, join } from "path";
|
||||||
|
|
@ -156,54 +154,14 @@ export class AuthStorage {
|
||||||
/**
|
/**
|
||||||
* Login to an OAuth provider.
|
* Login to an OAuth provider.
|
||||||
*/
|
*/
|
||||||
async login(
|
async login(providerId: OAuthProviderId, callbacks: OAuthLoginCallbacks): Promise<void> {
|
||||||
provider: OAuthProvider,
|
const provider = getOAuthProvider(providerId);
|
||||||
callbacks: {
|
if (!provider) {
|
||||||
onAuth: (info: { url: string; instructions?: string }) => void;
|
throw new Error(`Unknown OAuth provider: ${providerId}`);
|
||||||
onPrompt: (prompt: { message: string; placeholder?: string }) => Promise<string>;
|
|
||||||
onProgress?: (message: string) => void;
|
|
||||||
/** For providers with local callback servers (e.g., openai-codex), races with browser callback */
|
|
||||||
onManualCodeInput?: () => Promise<string>;
|
|
||||||
/** For cancellation support (e.g., github-copilot polling) */
|
|
||||||
signal?: AbortSignal;
|
|
||||||
},
|
|
||||||
): Promise<void> {
|
|
||||||
let credentials: OAuthCredentials;
|
|
||||||
|
|
||||||
switch (provider) {
|
|
||||||
case "anthropic":
|
|
||||||
credentials = await loginAnthropic(
|
|
||||||
(url) => callbacks.onAuth({ url }),
|
|
||||||
() => callbacks.onPrompt({ message: "Paste the authorization code:" }),
|
|
||||||
);
|
|
||||||
break;
|
|
||||||
case "github-copilot":
|
|
||||||
credentials = await loginGitHubCopilot({
|
|
||||||
onAuth: (url, instructions) => callbacks.onAuth({ url, instructions }),
|
|
||||||
onPrompt: callbacks.onPrompt,
|
|
||||||
onProgress: callbacks.onProgress,
|
|
||||||
signal: callbacks.signal,
|
|
||||||
});
|
|
||||||
break;
|
|
||||||
case "google-gemini-cli":
|
|
||||||
credentials = await loginGeminiCli(callbacks.onAuth, callbacks.onProgress, callbacks.onManualCodeInput);
|
|
||||||
break;
|
|
||||||
case "google-antigravity":
|
|
||||||
credentials = await loginAntigravity(callbacks.onAuth, callbacks.onProgress, callbacks.onManualCodeInput);
|
|
||||||
break;
|
|
||||||
case "openai-codex":
|
|
||||||
credentials = await loginOpenAICodex({
|
|
||||||
onAuth: callbacks.onAuth,
|
|
||||||
onPrompt: callbacks.onPrompt,
|
|
||||||
onProgress: callbacks.onProgress,
|
|
||||||
onManualCodeInput: callbacks.onManualCodeInput,
|
|
||||||
});
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
throw new Error(`Unknown OAuth provider: ${provider}`);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
this.set(provider, { type: "oauth", ...credentials });
|
const credentials = await provider.login(callbacks);
|
||||||
|
this.set(providerId, { type: "oauth", ...credentials });
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
@ -219,8 +177,13 @@ export class AuthStorage {
|
||||||
* This ensures only one instance refreshes while others wait and use the result.
|
* This ensures only one instance refreshes while others wait and use the result.
|
||||||
*/
|
*/
|
||||||
private async refreshOAuthTokenWithLock(
|
private async refreshOAuthTokenWithLock(
|
||||||
provider: OAuthProvider,
|
providerId: OAuthProviderId,
|
||||||
): Promise<{ apiKey: string; newCredentials: OAuthCredentials } | null> {
|
): Promise<{ apiKey: string; newCredentials: OAuthCredentials } | null> {
|
||||||
|
const provider = getOAuthProvider(providerId);
|
||||||
|
if (!provider) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
// Ensure auth file exists for locking
|
// Ensure auth file exists for locking
|
||||||
if (!existsSync(this.authPath)) {
|
if (!existsSync(this.authPath)) {
|
||||||
const dir = dirname(this.authPath);
|
const dir = dirname(this.authPath);
|
||||||
|
|
@ -250,7 +213,7 @@ export class AuthStorage {
|
||||||
// Re-read file after acquiring lock - another instance may have refreshed
|
// Re-read file after acquiring lock - another instance may have refreshed
|
||||||
this.reload();
|
this.reload();
|
||||||
|
|
||||||
const cred = this.data[provider];
|
const cred = this.data[providerId];
|
||||||
if (cred?.type !== "oauth") {
|
if (cred?.type !== "oauth") {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
@ -259,10 +222,7 @@ export class AuthStorage {
|
||||||
// (another instance may have already refreshed it)
|
// (another instance may have already refreshed it)
|
||||||
if (Date.now() < cred.expires) {
|
if (Date.now() < cred.expires) {
|
||||||
// Token is now valid - another instance refreshed it
|
// Token is now valid - another instance refreshed it
|
||||||
const needsProjectId = provider === "google-gemini-cli" || provider === "google-antigravity";
|
const apiKey = provider.getApiKey(cred);
|
||||||
const apiKey = needsProjectId
|
|
||||||
? JSON.stringify({ token: cred.access, projectId: cred.projectId })
|
|
||||||
: cred.access;
|
|
||||||
return { apiKey, newCredentials: cred };
|
return { apiKey, newCredentials: cred };
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -274,9 +234,9 @@ export class AuthStorage {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const result = await getOAuthApiKey(provider, oauthCreds);
|
const result = await getOAuthApiKey(providerId, oauthCreds);
|
||||||
if (result) {
|
if (result) {
|
||||||
this.data[provider] = { type: "oauth", ...result.newCredentials };
|
this.data[providerId] = { type: "oauth", ...result.newCredentials };
|
||||||
this.save();
|
this.save();
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
@ -303,41 +263,44 @@ export class AuthStorage {
|
||||||
* 4. Environment variable
|
* 4. Environment variable
|
||||||
* 5. Fallback resolver (models.json custom providers)
|
* 5. Fallback resolver (models.json custom providers)
|
||||||
*/
|
*/
|
||||||
async getApiKey(provider: string): Promise<string | undefined> {
|
async getApiKey(providerId: string): Promise<string | undefined> {
|
||||||
// Runtime override takes highest priority
|
// Runtime override takes highest priority
|
||||||
const runtimeKey = this.runtimeOverrides.get(provider);
|
const runtimeKey = this.runtimeOverrides.get(providerId);
|
||||||
if (runtimeKey) {
|
if (runtimeKey) {
|
||||||
return runtimeKey;
|
return runtimeKey;
|
||||||
}
|
}
|
||||||
|
|
||||||
const cred = this.data[provider];
|
const cred = this.data[providerId];
|
||||||
|
|
||||||
if (cred?.type === "api_key") {
|
if (cred?.type === "api_key") {
|
||||||
return cred.key;
|
return cred.key;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (cred?.type === "oauth") {
|
if (cred?.type === "oauth") {
|
||||||
|
const provider = getOAuthProvider(providerId);
|
||||||
|
if (!provider) {
|
||||||
|
// Unknown OAuth provider, can't get API key
|
||||||
|
return undefined;
|
||||||
|
}
|
||||||
|
|
||||||
// Check if token needs refresh
|
// Check if token needs refresh
|
||||||
const needsRefresh = Date.now() >= cred.expires;
|
const needsRefresh = Date.now() >= cred.expires;
|
||||||
|
|
||||||
if (needsRefresh) {
|
if (needsRefresh) {
|
||||||
// Use locked refresh to prevent race conditions
|
// Use locked refresh to prevent race conditions
|
||||||
try {
|
try {
|
||||||
const result = await this.refreshOAuthTokenWithLock(provider as OAuthProvider);
|
const result = await this.refreshOAuthTokenWithLock(providerId);
|
||||||
if (result) {
|
if (result) {
|
||||||
return result.apiKey;
|
return result.apiKey;
|
||||||
}
|
}
|
||||||
} catch {
|
} catch {
|
||||||
// Refresh failed - re-read file to check if another instance succeeded
|
// Refresh failed - re-read file to check if another instance succeeded
|
||||||
this.reload();
|
this.reload();
|
||||||
const updatedCred = this.data[provider];
|
const updatedCred = this.data[providerId];
|
||||||
|
|
||||||
if (updatedCred?.type === "oauth" && Date.now() < updatedCred.expires) {
|
if (updatedCred?.type === "oauth" && Date.now() < updatedCred.expires) {
|
||||||
// Another instance refreshed successfully, use those credentials
|
// Another instance refreshed successfully, use those credentials
|
||||||
const needsProjectId = provider === "google-gemini-cli" || provider === "google-antigravity";
|
return provider.getApiKey(updatedCred);
|
||||||
return needsProjectId
|
|
||||||
? JSON.stringify({ token: updatedCred.access, projectId: updatedCred.projectId })
|
|
||||||
: updatedCred.access;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Refresh truly failed - return undefined so model discovery skips this provider
|
// Refresh truly failed - return undefined so model discovery skips this provider
|
||||||
|
|
@ -346,16 +309,22 @@ export class AuthStorage {
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Token not expired, use current access token
|
// Token not expired, use current access token
|
||||||
const needsProjectId = provider === "google-gemini-cli" || provider === "google-antigravity";
|
return provider.getApiKey(cred);
|
||||||
return needsProjectId ? JSON.stringify({ token: cred.access, projectId: cred.projectId }) : cred.access;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fall back to environment variable
|
// Fall back to environment variable
|
||||||
const envKey = getEnvApiKey(provider);
|
const envKey = getEnvApiKey(providerId);
|
||||||
if (envKey) return envKey;
|
if (envKey) return envKey;
|
||||||
|
|
||||||
// Fall back to custom resolver (e.g., models.json custom providers)
|
// Fall back to custom resolver (e.g., models.json custom providers)
|
||||||
return this.fallbackResolver?.(provider) ?? undefined;
|
return this.fallbackResolver?.(providerId) ?? undefined;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get all registered OAuth providers
|
||||||
|
*/
|
||||||
|
getOAuthProviders() {
|
||||||
|
return getOAuthProviders();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -76,6 +76,9 @@ export type {
|
||||||
MessageRenderOptions,
|
MessageRenderOptions,
|
||||||
ModelSelectEvent,
|
ModelSelectEvent,
|
||||||
ModelSelectSource,
|
ModelSelectSource,
|
||||||
|
// Provider Registration
|
||||||
|
ProviderConfig,
|
||||||
|
ProviderModelConfig,
|
||||||
ReadToolResultEvent,
|
ReadToolResultEvent,
|
||||||
// Commands
|
// Commands
|
||||||
RegisteredCommand,
|
RegisteredCommand,
|
||||||
|
|
|
||||||
|
|
@ -32,6 +32,7 @@ import type {
|
||||||
ExtensionRuntime,
|
ExtensionRuntime,
|
||||||
LoadExtensionsResult,
|
LoadExtensionsResult,
|
||||||
MessageRenderer,
|
MessageRenderer,
|
||||||
|
ProviderConfig,
|
||||||
RegisteredCommand,
|
RegisteredCommand,
|
||||||
ToolDefinition,
|
ToolDefinition,
|
||||||
} from "./types.js";
|
} from "./types.js";
|
||||||
|
|
@ -122,6 +123,7 @@ export function createExtensionRuntime(): ExtensionRuntime {
|
||||||
getThinkingLevel: notInitialized,
|
getThinkingLevel: notInitialized,
|
||||||
setThinkingLevel: notInitialized,
|
setThinkingLevel: notInitialized,
|
||||||
flagValues: new Map(),
|
flagValues: new Map(),
|
||||||
|
pendingProviderRegistrations: [],
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -238,6 +240,10 @@ function createExtensionAPI(
|
||||||
runtime.setThinkingLevel(level);
|
runtime.setThinkingLevel(level);
|
||||||
},
|
},
|
||||||
|
|
||||||
|
registerProvider(name: string, config: ProviderConfig) {
|
||||||
|
runtime.pendingProviderRegistrations.push({ name, config });
|
||||||
|
},
|
||||||
|
|
||||||
events: eventBus,
|
events: eventBus,
|
||||||
} as ExtensionAPI;
|
} as ExtensionAPI;
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -203,6 +203,12 @@ export class ExtensionRunner {
|
||||||
this.shutdownHandler = contextActions.shutdown;
|
this.shutdownHandler = contextActions.shutdown;
|
||||||
this.getContextUsageFn = contextActions.getContextUsage;
|
this.getContextUsageFn = contextActions.getContextUsage;
|
||||||
this.compactFn = contextActions.compact;
|
this.compactFn = contextActions.compact;
|
||||||
|
|
||||||
|
// Process provider registrations queued during extension loading
|
||||||
|
for (const { name, config } of this.runtime.pendingProviderRegistrations) {
|
||||||
|
this.modelRegistry.registerProvider(name, config);
|
||||||
|
}
|
||||||
|
this.runtime.pendingProviderRegistrations = [];
|
||||||
}
|
}
|
||||||
|
|
||||||
bindCommandContext(actions?: ExtensionCommandContextActions): void {
|
bindCommandContext(actions?: ExtensionCommandContextActions): void {
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,15 @@ import type {
|
||||||
AgentToolUpdateCallback,
|
AgentToolUpdateCallback,
|
||||||
ThinkingLevel,
|
ThinkingLevel,
|
||||||
} from "@mariozechner/pi-agent-core";
|
} from "@mariozechner/pi-agent-core";
|
||||||
import type { ImageContent, Model, TextContent, ToolResultMessage } from "@mariozechner/pi-ai";
|
import type {
|
||||||
|
Api,
|
||||||
|
ImageContent,
|
||||||
|
Model,
|
||||||
|
OAuthCredentials,
|
||||||
|
OAuthLoginCallbacks,
|
||||||
|
TextContent,
|
||||||
|
ToolResultMessage,
|
||||||
|
} from "@mariozechner/pi-ai";
|
||||||
import type {
|
import type {
|
||||||
AutocompleteItem,
|
AutocompleteItem,
|
||||||
Component,
|
Component,
|
||||||
|
|
@ -854,10 +862,119 @@ export interface ExtensionAPI {
|
||||||
/** Set thinking level (clamped to model capabilities). */
|
/** Set thinking level (clamped to model capabilities). */
|
||||||
setThinkingLevel(level: ThinkingLevel): void;
|
setThinkingLevel(level: ThinkingLevel): void;
|
||||||
|
|
||||||
|
// =========================================================================
|
||||||
|
// Provider Registration
|
||||||
|
// =========================================================================
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Register or override a model provider.
|
||||||
|
*
|
||||||
|
* If `models` is provided: replaces all existing models for this provider.
|
||||||
|
* If only `baseUrl` is provided: overrides the URL for existing models.
|
||||||
|
* If `oauth` is provided: registers OAuth provider for /login support.
|
||||||
|
*
|
||||||
|
* @example
|
||||||
|
* // Register a new provider with custom models
|
||||||
|
* pi.registerProvider("my-proxy", {
|
||||||
|
* baseUrl: "https://proxy.example.com",
|
||||||
|
* apiKey: "PROXY_API_KEY",
|
||||||
|
* api: "anthropic-messages",
|
||||||
|
* models: [
|
||||||
|
* {
|
||||||
|
* id: "claude-sonnet-4-20250514",
|
||||||
|
* name: "Claude 4 Sonnet (proxy)",
|
||||||
|
* reasoning: false,
|
||||||
|
* input: ["text", "image"],
|
||||||
|
* cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
|
||||||
|
* contextWindow: 200000,
|
||||||
|
* maxTokens: 16384
|
||||||
|
* }
|
||||||
|
* ]
|
||||||
|
* });
|
||||||
|
*
|
||||||
|
* @example
|
||||||
|
* // Override baseUrl for an existing provider
|
||||||
|
* pi.registerProvider("anthropic", {
|
||||||
|
* baseUrl: "https://proxy.example.com"
|
||||||
|
* });
|
||||||
|
*
|
||||||
|
* @example
|
||||||
|
* // Register provider with OAuth support
|
||||||
|
* pi.registerProvider("corporate-ai", {
|
||||||
|
* baseUrl: "https://ai.corp.com",
|
||||||
|
* api: "openai-responses",
|
||||||
|
* models: [...],
|
||||||
|
* oauth: {
|
||||||
|
* name: "Corporate AI (SSO)",
|
||||||
|
* async login(callbacks) { ... },
|
||||||
|
* async refreshToken(credentials) { ... },
|
||||||
|
* getApiKey(credentials) { return credentials.access; }
|
||||||
|
* }
|
||||||
|
* });
|
||||||
|
*/
|
||||||
|
registerProvider(name: string, config: ProviderConfig): void;
|
||||||
|
|
||||||
/** Shared event bus for extension communication. */
|
/** Shared event bus for extension communication. */
|
||||||
events: EventBus;
|
events: EventBus;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// Provider Registration Types
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
/** Configuration for registering a provider via pi.registerProvider(). */
|
||||||
|
export interface ProviderConfig {
|
||||||
|
/** Base URL for the API endpoint. Required when defining models. */
|
||||||
|
baseUrl?: string;
|
||||||
|
/** API key or environment variable name. Required when defining models (unless oauth provided). */
|
||||||
|
apiKey?: string;
|
||||||
|
/** API type. Required at provider or model level when defining models. */
|
||||||
|
api?: Api;
|
||||||
|
/** Custom headers to include in requests. */
|
||||||
|
headers?: Record<string, string>;
|
||||||
|
/** If true, adds Authorization: Bearer header with the resolved API key. */
|
||||||
|
authHeader?: boolean;
|
||||||
|
/** Models to register. If provided, replaces all existing models for this provider. */
|
||||||
|
models?: ProviderModelConfig[];
|
||||||
|
/** OAuth provider for /login support. The `id` is set automatically from the provider name. */
|
||||||
|
oauth?: {
|
||||||
|
/** Display name for the provider in login UI. */
|
||||||
|
name: string;
|
||||||
|
/** Run the login flow, return credentials to persist. */
|
||||||
|
login(callbacks: OAuthLoginCallbacks): Promise<OAuthCredentials>;
|
||||||
|
/** Refresh expired credentials, return updated credentials to persist. */
|
||||||
|
refreshToken(credentials: OAuthCredentials): Promise<OAuthCredentials>;
|
||||||
|
/** Convert credentials to API key string for the provider. */
|
||||||
|
getApiKey(credentials: OAuthCredentials): string;
|
||||||
|
/** Optional: modify models for this provider (e.g., update baseUrl based on credentials). */
|
||||||
|
modifyModels?(models: Model<Api>[], credentials: OAuthCredentials): Model<Api>[];
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Configuration for a model within a provider. */
|
||||||
|
export interface ProviderModelConfig {
|
||||||
|
/** Model ID (e.g., "claude-sonnet-4-20250514"). */
|
||||||
|
id: string;
|
||||||
|
/** Display name (e.g., "Claude 4 Sonnet"). */
|
||||||
|
name: string;
|
||||||
|
/** API type override for this model. */
|
||||||
|
api?: Api;
|
||||||
|
/** Whether the model supports extended thinking. */
|
||||||
|
reasoning: boolean;
|
||||||
|
/** Supported input types. */
|
||||||
|
input: ("text" | "image")[];
|
||||||
|
/** Cost per token (for tracking, can be 0). */
|
||||||
|
cost: { input: number; output: number; cacheRead: number; cacheWrite: number };
|
||||||
|
/** Maximum context window size in tokens. */
|
||||||
|
contextWindow: number;
|
||||||
|
/** Maximum output tokens. */
|
||||||
|
maxTokens: number;
|
||||||
|
/** Custom headers for this model. */
|
||||||
|
headers?: Record<string, string>;
|
||||||
|
/** OpenAI compatibility settings. */
|
||||||
|
compat?: Model<Api>["compat"];
|
||||||
|
}
|
||||||
|
|
||||||
/** Extension factory function type. Supports both sync and async initialization. */
|
/** Extension factory function type. Supports both sync and async initialization. */
|
||||||
export type ExtensionFactory = (pi: ExtensionAPI) => void | Promise<void>;
|
export type ExtensionFactory = (pi: ExtensionAPI) => void | Promise<void>;
|
||||||
|
|
||||||
|
|
@ -926,6 +1043,8 @@ export type SetLabelHandler = (entryId: string, label: string | undefined) => vo
|
||||||
*/
|
*/
|
||||||
export interface ExtensionRuntimeState {
|
export interface ExtensionRuntimeState {
|
||||||
flagValues: Map<string, boolean | string>;
|
flagValues: Map<string, boolean | string>;
|
||||||
|
/** Provider registrations queued during extension loading, processed when runner binds */
|
||||||
|
pendingProviderRegistrations: Array<{ name: string; config: ProviderConfig }>;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
||||||
|
|
@ -4,12 +4,12 @@
|
||||||
|
|
||||||
import {
|
import {
|
||||||
type Api,
|
type Api,
|
||||||
getGitHubCopilotBaseUrl,
|
|
||||||
getModels,
|
getModels,
|
||||||
getProviders,
|
getProviders,
|
||||||
type KnownProvider,
|
type KnownProvider,
|
||||||
type Model,
|
type Model,
|
||||||
normalizeDomain,
|
type OAuthProviderInterface,
|
||||||
|
registerOAuthProvider,
|
||||||
} from "@mariozechner/pi-ai";
|
} from "@mariozechner/pi-ai";
|
||||||
import { type Static, Type } from "@sinclair/typebox";
|
import { type Static, Type } from "@sinclair/typebox";
|
||||||
import AjvModule from "ajv";
|
import AjvModule from "ajv";
|
||||||
|
|
@ -26,7 +26,13 @@ const OpenAICompletionsCompatSchema = Type.Object({
|
||||||
supportsStore: Type.Optional(Type.Boolean()),
|
supportsStore: Type.Optional(Type.Boolean()),
|
||||||
supportsDeveloperRole: Type.Optional(Type.Boolean()),
|
supportsDeveloperRole: Type.Optional(Type.Boolean()),
|
||||||
supportsReasoningEffort: Type.Optional(Type.Boolean()),
|
supportsReasoningEffort: Type.Optional(Type.Boolean()),
|
||||||
|
supportsUsageInStreaming: Type.Optional(Type.Boolean()),
|
||||||
maxTokensField: Type.Optional(Type.Union([Type.Literal("max_completion_tokens"), Type.Literal("max_tokens")])),
|
maxTokensField: Type.Optional(Type.Union([Type.Literal("max_completion_tokens"), Type.Literal("max_tokens")])),
|
||||||
|
requiresToolResultName: Type.Optional(Type.Boolean()),
|
||||||
|
requiresAssistantAfterToolResult: Type.Optional(Type.Boolean()),
|
||||||
|
requiresThinkingAsText: Type.Optional(Type.Boolean()),
|
||||||
|
requiresMistralToolIds: Type.Optional(Type.Boolean()),
|
||||||
|
thinkingFormat: Type.Optional(Type.Union([Type.Literal("openai"), Type.Literal("zai")])),
|
||||||
});
|
});
|
||||||
|
|
||||||
const OpenAIResponsesCompatSchema = Type.Object({
|
const OpenAIResponsesCompatSchema = Type.Object({
|
||||||
|
|
@ -174,6 +180,7 @@ export function clearApiKeyCache(): void {
|
||||||
export class ModelRegistry {
|
export class ModelRegistry {
|
||||||
private models: Model<Api>[] = [];
|
private models: Model<Api>[] = [];
|
||||||
private customProviderApiKeys: Map<string, string> = new Map();
|
private customProviderApiKeys: Map<string, string> = new Map();
|
||||||
|
private registeredProviders: Map<string, ProviderConfigInput> = new Map();
|
||||||
private loadError: string | undefined = undefined;
|
private loadError: string | undefined = undefined;
|
||||||
|
|
||||||
constructor(
|
constructor(
|
||||||
|
|
@ -200,6 +207,10 @@ export class ModelRegistry {
|
||||||
this.customProviderApiKeys.clear();
|
this.customProviderApiKeys.clear();
|
||||||
this.loadError = undefined;
|
this.loadError = undefined;
|
||||||
this.loadModels();
|
this.loadModels();
|
||||||
|
|
||||||
|
for (const [providerName, config] of this.registeredProviders.entries()) {
|
||||||
|
this.applyProviderConfig(providerName, config);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
@ -224,19 +235,17 @@ export class ModelRegistry {
|
||||||
}
|
}
|
||||||
|
|
||||||
const builtInModels = this.loadBuiltInModels(replacedProviders, overrides);
|
const builtInModels = this.loadBuiltInModels(replacedProviders, overrides);
|
||||||
const combined = [...builtInModels, ...customModels];
|
let combined = [...builtInModels, ...customModels];
|
||||||
|
|
||||||
// Update github-copilot base URL based on OAuth credentials
|
// Let OAuth providers modify their models (e.g., update baseUrl)
|
||||||
const copilotCred = this.authStorage.get("github-copilot");
|
for (const oauthProvider of this.authStorage.getOAuthProviders()) {
|
||||||
if (copilotCred?.type === "oauth") {
|
const cred = this.authStorage.get(oauthProvider.id);
|
||||||
const domain = copilotCred.enterpriseUrl
|
if (cred?.type === "oauth" && oauthProvider.modifyModels) {
|
||||||
? (normalizeDomain(copilotCred.enterpriseUrl) ?? undefined)
|
combined = oauthProvider.modifyModels(combined, cred);
|
||||||
: undefined;
|
}
|
||||||
const baseUrl = getGitHubCopilotBaseUrl(copilotCred.access, domain);
|
|
||||||
this.models = combined.map((m) => (m.provider === "github-copilot" ? { ...m, baseUrl } : m));
|
|
||||||
} else {
|
|
||||||
this.models = combined;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
this.models = combined;
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Load built-in models, skipping replaced providers and applying overrides */
|
/** Load built-in models, skipping replaced providers and applying overrides */
|
||||||
|
|
@ -449,4 +458,118 @@ export class ModelRegistry {
|
||||||
const cred = this.authStorage.get(model.provider);
|
const cred = this.authStorage.get(model.provider);
|
||||||
return cred?.type === "oauth";
|
return cred?.type === "oauth";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Register a provider dynamically (from extensions).
|
||||||
|
*
|
||||||
|
* If provider has models: replaces all existing models for this provider.
|
||||||
|
* If provider has only baseUrl/headers: overrides existing models' URLs.
|
||||||
|
* If provider has oauth: registers OAuth provider for /login support.
|
||||||
|
*/
|
||||||
|
registerProvider(providerName: string, config: ProviderConfigInput): void {
|
||||||
|
this.registeredProviders.set(providerName, config);
|
||||||
|
this.applyProviderConfig(providerName, config);
|
||||||
|
}
|
||||||
|
|
||||||
|
private applyProviderConfig(providerName: string, config: ProviderConfigInput): void {
|
||||||
|
// Register OAuth provider if provided
|
||||||
|
if (config.oauth) {
|
||||||
|
// Ensure the OAuth provider ID matches the provider name
|
||||||
|
const oauthProvider: OAuthProviderInterface = {
|
||||||
|
...config.oauth,
|
||||||
|
id: providerName,
|
||||||
|
};
|
||||||
|
registerOAuthProvider(oauthProvider);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store API key for auth resolution
|
||||||
|
if (config.apiKey) {
|
||||||
|
this.customProviderApiKeys.set(providerName, config.apiKey);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (config.models && config.models.length > 0) {
|
||||||
|
// Full replacement: remove existing models for this provider
|
||||||
|
this.models = this.models.filter((m) => m.provider !== providerName);
|
||||||
|
|
||||||
|
// Validate required fields
|
||||||
|
if (!config.baseUrl) {
|
||||||
|
throw new Error(`Provider ${providerName}: "baseUrl" is required when defining models.`);
|
||||||
|
}
|
||||||
|
if (!config.apiKey && !config.oauth) {
|
||||||
|
throw new Error(`Provider ${providerName}: "apiKey" or "oauth" is required when defining models.`);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse and add new models
|
||||||
|
for (const modelDef of config.models) {
|
||||||
|
const api = modelDef.api || config.api;
|
||||||
|
if (!api) {
|
||||||
|
throw new Error(`Provider ${providerName}, model ${modelDef.id}: no "api" specified.`);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Merge headers
|
||||||
|
const providerHeaders = resolveHeaders(config.headers);
|
||||||
|
const modelHeaders = resolveHeaders(modelDef.headers);
|
||||||
|
let headers = providerHeaders || modelHeaders ? { ...providerHeaders, ...modelHeaders } : undefined;
|
||||||
|
|
||||||
|
// If authHeader is true, add Authorization header
|
||||||
|
if (config.authHeader && config.apiKey) {
|
||||||
|
const resolvedKey = resolveConfigValue(config.apiKey);
|
||||||
|
if (resolvedKey) {
|
||||||
|
headers = { ...headers, Authorization: `Bearer ${resolvedKey}` };
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
this.models.push({
|
||||||
|
id: modelDef.id,
|
||||||
|
name: modelDef.name,
|
||||||
|
api: api as Api,
|
||||||
|
provider: providerName,
|
||||||
|
baseUrl: config.baseUrl,
|
||||||
|
reasoning: modelDef.reasoning,
|
||||||
|
input: modelDef.input as ("text" | "image")[],
|
||||||
|
cost: modelDef.cost,
|
||||||
|
contextWindow: modelDef.contextWindow,
|
||||||
|
maxTokens: modelDef.maxTokens,
|
||||||
|
headers,
|
||||||
|
compat: modelDef.compat,
|
||||||
|
} as Model<Api>);
|
||||||
|
}
|
||||||
|
} else if (config.baseUrl) {
|
||||||
|
// Override-only: update baseUrl/headers for existing models
|
||||||
|
const resolvedHeaders = resolveHeaders(config.headers);
|
||||||
|
this.models = this.models.map((m) => {
|
||||||
|
if (m.provider !== providerName) return m;
|
||||||
|
return {
|
||||||
|
...m,
|
||||||
|
baseUrl: config.baseUrl ?? m.baseUrl,
|
||||||
|
headers: resolvedHeaders ? { ...m.headers, ...resolvedHeaders } : m.headers,
|
||||||
|
};
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Input type for registerProvider API.
|
||||||
|
*/
|
||||||
|
export interface ProviderConfigInput {
|
||||||
|
baseUrl?: string;
|
||||||
|
apiKey?: string;
|
||||||
|
api?: Api;
|
||||||
|
headers?: Record<string, string>;
|
||||||
|
authHeader?: boolean;
|
||||||
|
/** OAuth provider for /login support */
|
||||||
|
oauth?: Omit<OAuthProviderInterface, "id">;
|
||||||
|
models?: Array<{
|
||||||
|
id: string;
|
||||||
|
name: string;
|
||||||
|
api?: Api;
|
||||||
|
reasoning: boolean;
|
||||||
|
input: ("text" | "image")[];
|
||||||
|
cost: { input: number; output: number; cacheRead: number; cacheWrite: number };
|
||||||
|
contextWindow: number;
|
||||||
|
maxTokens: number;
|
||||||
|
headers?: Record<string, string>;
|
||||||
|
compat?: Model<Api>["compat"];
|
||||||
|
}>;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -76,6 +76,8 @@ export type {
|
||||||
LoadExtensionsResult,
|
LoadExtensionsResult,
|
||||||
MessageRenderer,
|
MessageRenderer,
|
||||||
MessageRenderOptions,
|
MessageRenderOptions,
|
||||||
|
ProviderConfig,
|
||||||
|
ProviderModelConfig,
|
||||||
RegisteredCommand,
|
RegisteredCommand,
|
||||||
RegisteredTool,
|
RegisteredTool,
|
||||||
SessionBeforeCompactEvent,
|
SessionBeforeCompactEvent,
|
||||||
|
|
|
||||||
|
|
@ -432,10 +432,6 @@ export async function main(args: string[]) {
|
||||||
// Run migrations (pass cwd for project-local migrations)
|
// Run migrations (pass cwd for project-local migrations)
|
||||||
const { migratedAuthProviders: migratedProviders, deprecationWarnings } = runMigrations(process.cwd());
|
const { migratedAuthProviders: migratedProviders, deprecationWarnings } = runMigrations(process.cwd());
|
||||||
|
|
||||||
// Create AuthStorage and ModelRegistry upfront
|
|
||||||
const authStorage = new AuthStorage();
|
|
||||||
const modelRegistry = new ModelRegistry(authStorage);
|
|
||||||
|
|
||||||
// First pass: parse args to get --extension paths
|
// First pass: parse args to get --extension paths
|
||||||
const firstPass = parseArgs(args);
|
const firstPass = parseArgs(args);
|
||||||
|
|
||||||
|
|
@ -443,6 +439,8 @@ export async function main(args: string[]) {
|
||||||
const cwd = process.cwd();
|
const cwd = process.cwd();
|
||||||
const agentDir = getAgentDir();
|
const agentDir = getAgentDir();
|
||||||
const settingsManager = SettingsManager.create(cwd, agentDir);
|
const settingsManager = SettingsManager.create(cwd, agentDir);
|
||||||
|
const authStorage = new AuthStorage();
|
||||||
|
const modelRegistry = new ModelRegistry(authStorage, getModelsPath());
|
||||||
|
|
||||||
const resourceLoader = new DefaultResourceLoader({
|
const resourceLoader = new DefaultResourceLoader({
|
||||||
cwd,
|
cwd,
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
import { getOAuthProviders, type OAuthProviderInfo } from "@mariozechner/pi-ai";
|
import { getOAuthProviders, type OAuthProviderInterface } from "@mariozechner/pi-ai";
|
||||||
import { Container, getEditorKeybindings, Spacer, TruncatedText } from "@mariozechner/pi-tui";
|
import { Container, getEditorKeybindings, Spacer, TruncatedText } from "@mariozechner/pi-tui";
|
||||||
import type { AuthStorage } from "../../../core/auth-storage.js";
|
import type { AuthStorage } from "../../../core/auth-storage.js";
|
||||||
import { theme } from "../theme/theme.js";
|
import { theme } from "../theme/theme.js";
|
||||||
|
|
@ -9,7 +9,7 @@ import { DynamicBorder } from "./dynamic-border.js";
|
||||||
*/
|
*/
|
||||||
export class OAuthSelectorComponent extends Container {
|
export class OAuthSelectorComponent extends Container {
|
||||||
private listContainer: Container;
|
private listContainer: Container;
|
||||||
private allProviders: OAuthProviderInfo[] = [];
|
private allProviders: OAuthProviderInterface[] = [];
|
||||||
private selectedIndex: number = 0;
|
private selectedIndex: number = 0;
|
||||||
private mode: "login" | "logout";
|
private mode: "login" | "logout";
|
||||||
private authStorage: AuthStorage;
|
private authStorage: AuthStorage;
|
||||||
|
|
@ -66,7 +66,6 @@ export class OAuthSelectorComponent extends Container {
|
||||||
if (!provider) continue;
|
if (!provider) continue;
|
||||||
|
|
||||||
const isSelected = i === this.selectedIndex;
|
const isSelected = i === this.selectedIndex;
|
||||||
const isAvailable = provider.available;
|
|
||||||
|
|
||||||
// Check if user is logged in for this provider
|
// Check if user is logged in for this provider
|
||||||
const credentials = this.authStorage.get(provider.id);
|
const credentials = this.authStorage.get(provider.id);
|
||||||
|
|
@ -76,10 +75,10 @@ export class OAuthSelectorComponent extends Container {
|
||||||
let line = "";
|
let line = "";
|
||||||
if (isSelected) {
|
if (isSelected) {
|
||||||
const prefix = theme.fg("accent", "→ ");
|
const prefix = theme.fg("accent", "→ ");
|
||||||
const text = isAvailable ? theme.fg("accent", provider.name) : theme.fg("dim", provider.name);
|
const text = theme.fg("accent", provider.name);
|
||||||
line = prefix + text + statusIndicator;
|
line = prefix + text + statusIndicator;
|
||||||
} else {
|
} else {
|
||||||
const text = isAvailable ? ` ${provider.name}` : theme.fg("dim", ` ${provider.name}`);
|
const text = ` ${provider.name}`;
|
||||||
line = text + statusIndicator;
|
line = text + statusIndicator;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -109,7 +108,7 @@ export class OAuthSelectorComponent extends Container {
|
||||||
// Enter
|
// Enter
|
||||||
else if (kb.matches(keyData, "selectConfirm")) {
|
else if (kb.matches(keyData, "selectConfirm")) {
|
||||||
const selectedProvider = this.allProviders[this.selectedIndex];
|
const selectedProvider = this.allProviders[this.selectedIndex];
|
||||||
if (selectedProvider?.available) {
|
if (selectedProvider) {
|
||||||
this.onSelectCallback(selectedProvider.id);
|
this.onSelectCallback(selectedProvider.id);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -3529,8 +3529,7 @@ export class InteractiveMode {
|
||||||
const providerName = providerInfo?.name || providerId;
|
const providerName = providerInfo?.name || providerId;
|
||||||
|
|
||||||
// Providers that use callback servers (can paste redirect URL)
|
// Providers that use callback servers (can paste redirect URL)
|
||||||
const usesCallbackServer =
|
const usesCallbackServer = providerInfo?.usesCallbackServer ?? false;
|
||||||
providerId === "openai-codex" || providerId === "google-gemini-cli" || providerId === "google-antigravity";
|
|
||||||
|
|
||||||
// Create login dialog component
|
// Create login dialog component
|
||||||
const dialog = new LoginDialogComponent(this.ui, providerId, (_success, _message) => {
|
const dialog = new LoginDialogComponent(this.ui, providerId, (_success, _message) => {
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue