Custom provider WIP

This commit is contained in:
Mario Zechner 2025-11-10 21:47:21 +01:00
parent 389c80d7a8
commit 1f9a3a00cc
17 changed files with 1185 additions and 107 deletions

View file

@ -1,14 +1,15 @@
import { Badge, Button, DialogBase, DialogHeader, html, icon, type TemplateResult } from "@mariozechner/mini-lit";
import type { Model } from "@mariozechner/pi-ai";
import { MODELS } from "@mariozechner/pi-ai/dist/models.generated.js";
import { getModels, getProviders, type Model } from "@mariozechner/pi-ai";
import type { PropertyValues } from "lit";
import { customElement, state } from "lit/decorators.js";
import { createRef, ref } from "lit/directives/ref.js";
import { Brain, Image as ImageIcon } from "lucide";
import { Ollama } from "ollama/dist/browser.mjs";
import { Input } from "../components/Input.js";
import { getAppStorage } from "../storage/app-storage.js";
import type { AutoDiscoveryProviderType } from "../storage/stores/custom-providers-store.js";
import { formatModelCost } from "../utils/format.js";
import { i18n } from "../utils/i18n.js";
import { discoverModels } from "../utils/model-discovery.js";
@customElement("agent-model-selector")
export class ModelSelector extends DialogBase {
@ -16,10 +17,10 @@ export class ModelSelector extends DialogBase {
@state() searchQuery = "";
@state() filterThinking = false;
@state() filterVision = false;
@state() ollamaModels: Model<any>[] = [];
@state() ollamaError: string | null = null;
@state() customProvidersLoading = false;
@state() selectedIndex = 0;
@state() private navigationMode: "mouse" | "keyboard" = "mouse";
@state() private customProviderModels: Model<any>[] = [];
private onSelectCallback?: (model: Model<any>) => void;
private scrollContainerRef = createRef<HTMLDivElement>();
@ -33,7 +34,7 @@ export class ModelSelector extends DialogBase {
selector.currentModel = currentModel;
selector.onSelectCallback = onSelect;
selector.open();
selector.fetchOllamaModels();
selector.loadCustomProviders();
}
override async firstUpdated(changedProperties: PropertyValues): Promise<void> {
@ -91,67 +92,50 @@ export class ModelSelector extends DialogBase {
});
}
private async fetchOllamaModels() {
private async loadCustomProviders() {
this.customProvidersLoading = true;
const allCustomModels: Model<any>[] = [];
try {
// Create Ollama client
const ollama = new Ollama({ host: "http://localhost:11434" });
const storage = getAppStorage();
const customProviders = await storage.customProviders.getAll();
// Get list of available models
const { models } = await ollama.list();
// Load models from custom providers
for (const provider of customProviders) {
const isAutoDiscovery: boolean =
provider.type === "ollama" ||
provider.type === "llama.cpp" ||
provider.type === "vllm" ||
provider.type === "lmstudio";
// Fetch details for each model and convert to Model format
const ollamaModelPromises: Promise<Model<any> | null>[] = models
.map(async (model: any) => {
if (isAutoDiscovery) {
try {
// Get model details
const details = await ollama.show({
model: model.name,
});
const models = await discoverModels(
provider.type as AutoDiscoveryProviderType,
provider.baseUrl,
provider.apiKey,
);
// Some Ollama servers don't report capabilities; don't filter on them
const modelsWithProvider = models.map((model) => ({
...model,
provider: provider.name,
}));
// Extract model info
const modelInfo: any = details.model_info || {};
// Get context window size - look for architecture-specific keys
const architecture = modelInfo["general.architecture"] || "";
const contextKey = `${architecture}.context_length`;
const contextWindow = parseInt(modelInfo[contextKey] || "8192", 10);
const maxTokens = 4096; // Default max output tokens
// Create Model object manually since ollama models aren't in MODELS constant
const ollamaModel: Model<any> = {
id: model.name,
name: model.name,
api: "openai-completions" as any,
provider: "ollama",
baseUrl: "http://localhost:11434/v1",
reasoning: false,
input: ["text"],
cost: {
input: 0,
output: 0,
cacheRead: 0,
cacheWrite: 0,
},
contextWindow: contextWindow,
maxTokens: maxTokens,
};
return ollamaModel;
} catch (err) {
console.error(`Failed to fetch details for model ${model.name}:`, err);
return null;
allCustomModels.push(...modelsWithProvider);
} catch (error) {
console.debug(`Failed to load models from ${provider.name}:`, error);
}
})
.filter((m: any) => m !== null);
const results = await Promise.all(ollamaModelPromises);
this.ollamaModels = results.filter((m): m is Model<any> => m !== null);
} catch (err) {
// Ollama not available or other error - silently ignore
console.debug("Ollama not available:", err);
this.ollamaError = err instanceof Error ? err.message : String(err);
} else if (provider.models) {
// Manual provider - models already defined
allCustomModels.push(...provider.models);
}
}
} catch (error) {
console.error("Failed to load custom providers:", error);
} finally {
this.customProviderModels = allCustomModels;
this.customProvidersLoading = false;
this.requestUpdate();
}
}
@ -169,21 +153,20 @@ export class ModelSelector extends DialogBase {
}
private getFilteredModels(): Array<{ provider: string; id: string; model: any }> {
// Collect all models from all providers
// Collect all models from known providers
const allModels: Array<{ provider: string; id: string; model: any }> = [];
for (const [provider, providerData] of Object.entries(MODELS)) {
for (const [modelId, model] of Object.entries(providerData)) {
allModels.push({ provider, id: modelId, model });
const knownProviders = getProviders();
for (const provider of knownProviders) {
const models = getModels(provider as any);
for (const model of models) {
allModels.push({ provider, id: model.id, model });
}
}
// Add Ollama models
for (const ollamaModel of this.ollamaModels) {
allModels.push({
id: ollamaModel.id,
provider: "ollama",
model: ollamaModel,
});
// Add custom provider models
for (const model of this.customProviderModels) {
allModels.push({ provider: model.provider, id: model.id, model });
}
// Filter models based on search and capability filters
@ -283,8 +266,7 @@ export class ModelSelector extends DialogBase {
<!-- Scrollable model list -->
<div class="flex-1 overflow-y-auto" ${ref(this.scrollContainerRef)}>
${filteredModels.map(({ provider, id, model }, index) => {
// Check if this is the current model by comparing IDs
const isCurrent = this.currentModel?.id === model.id;
const isCurrent = this.currentModel?.id === model.id && this.currentModel?.provider === model.provider;
const isSelected = index === this.selectedIndex;
return html`
<div