Complete Foundry refactor checklist

This commit is contained in:
Nathan Flurry 2026-03-15 13:38:51 -07:00 committed by Nathan Flurry
parent 40bed3b0a1
commit 13fc9cb318
91 changed files with 5091 additions and 4108 deletions

View file

@ -1,10 +1,11 @@
// @ts-nocheck
import { and, desc, eq } from "drizzle-orm";
import { actor, queue } from "rivetkit";
import { Loop, workflow } from "rivetkit/workflow";
import { workflow } from "rivetkit/workflow";
import type { AuditLogEvent } from "@sandbox-agent/foundry-shared";
import { auditLogDb } from "./db/db.js";
import { events } from "./db/schema.js";
import { AUDIT_LOG_QUEUE_NAMES, runAuditLogWorkflow } from "./workflow.js";
export interface AuditLogInput {
organizationId: string;
@ -24,46 +25,9 @@ export interface ListAuditLogParams {
limit?: number;
}
export const AUDIT_LOG_QUEUE_NAMES = ["auditLog.command.append"] as const;
async function appendAuditLogRow(loopCtx: any, body: AppendAuditLogCommand): Promise<void> {
const now = Date.now();
await loopCtx.db
.insert(events)
.values({
taskId: body.taskId ?? null,
branchName: body.branchName ?? null,
kind: body.kind,
payloadJson: JSON.stringify(body.payload),
createdAt: now,
})
.run();
}
async function runAuditLogWorkflow(ctx: any): Promise<void> {
await ctx.loop("audit-log-command-loop", async (loopCtx: any) => {
const msg = await loopCtx.queue.next("next-audit-log-command", {
names: [...AUDIT_LOG_QUEUE_NAMES],
completable: true,
});
if (!msg) {
return Loop.continue(undefined);
}
if (msg.name === "auditLog.command.append") {
await loopCtx.step("append-audit-log-row", async () => appendAuditLogRow(loopCtx, msg.body as AppendAuditLogCommand));
await msg.complete({ ok: true });
}
return Loop.continue(undefined);
});
}
export const auditLog = actor({
db: auditLogDb,
queues: {
"auditLog.command.append": queue(),
},
queues: Object.fromEntries(AUDIT_LOG_QUEUE_NAMES.map((name) => [name, queue()])),
options: {
name: "Audit Log",
icon: "database",

View file

@ -0,0 +1,39 @@
// @ts-nocheck
import { Loop } from "rivetkit/workflow";
import { events } from "./db/schema.js";
import type { AppendAuditLogCommand } from "./index.js";
export const AUDIT_LOG_QUEUE_NAMES = ["auditLog.command.append"] as const;
async function appendAuditLogRow(loopCtx: any, body: AppendAuditLogCommand): Promise<void> {
const now = Date.now();
await loopCtx.db
.insert(events)
.values({
taskId: body.taskId ?? null,
branchName: body.branchName ?? null,
kind: body.kind,
payloadJson: JSON.stringify(body.payload),
createdAt: now,
})
.run();
}
export async function runAuditLogWorkflow(ctx: any): Promise<void> {
await ctx.loop("audit-log-command-loop", async (loopCtx: any) => {
const msg = await loopCtx.queue.next("next-audit-log-command", {
names: [...AUDIT_LOG_QUEUE_NAMES],
completable: true,
});
if (!msg) {
return Loop.continue(undefined);
}
if (msg.name === "auditLog.command.append") {
await loopCtx.step("append-audit-log-row", async () => appendAuditLogRow(loopCtx, msg.body as AppendAuditLogCommand));
await msg.complete({ ok: true });
}
return Loop.continue(undefined);
});
}

View file

@ -1,104 +0,0 @@
import type { TaskStatus, SandboxProviderId } from "@sandbox-agent/foundry-shared";
export interface TaskCreatedEvent {
organizationId: string;
repoId: string;
taskId: string;
sandboxProviderId: SandboxProviderId;
branchName: string;
title: string;
}
export interface TaskStatusEvent {
organizationId: string;
repoId: string;
taskId: string;
status: TaskStatus;
message: string;
}
export interface RepositorySnapshotEvent {
organizationId: string;
repoId: string;
updatedAt: number;
}
export interface AgentStartedEvent {
organizationId: string;
repoId: string;
taskId: string;
sessionId: string;
}
export interface AgentIdleEvent {
organizationId: string;
repoId: string;
taskId: string;
sessionId: string;
}
export interface AgentErrorEvent {
organizationId: string;
repoId: string;
taskId: string;
message: string;
}
export interface PrCreatedEvent {
organizationId: string;
repoId: string;
taskId: string;
prNumber: number;
url: string;
}
export interface PrClosedEvent {
organizationId: string;
repoId: string;
taskId: string;
prNumber: number;
merged: boolean;
}
export interface PrReviewEvent {
organizationId: string;
repoId: string;
taskId: string;
prNumber: number;
reviewer: string;
status: string;
}
export interface CiStatusChangedEvent {
organizationId: string;
repoId: string;
taskId: string;
prNumber: number;
status: string;
}
export type TaskStepName = "auto_commit" | "push" | "pr_submit";
export type TaskStepStatus = "started" | "completed" | "skipped" | "failed";
export interface TaskStepEvent {
organizationId: string;
repoId: string;
taskId: string;
step: TaskStepName;
status: TaskStepStatus;
message: string;
}
export interface BranchSwitchedEvent {
organizationId: string;
repoId: string;
taskId: string;
branchName: string;
}
export interface SessionAttachedEvent {
organizationId: string;
repoId: string;
taskId: string;
sessionId: string;
}

View file

@ -18,6 +18,12 @@ const journal = {
tag: "0002_github_branches",
breakpoints: true,
},
{
idx: 3,
when: 1773907200000,
tag: "0003_sync_progress",
breakpoints: true,
},
],
} as const;
@ -79,6 +85,22 @@ CREATE TABLE \`github_pull_requests\` (
\`commit_sha\` text NOT NULL,
\`updated_at\` integer NOT NULL
);
`,
m0003: `ALTER TABLE \`github_meta\` ADD \`sync_generation\` integer NOT NULL DEFAULT 0;
--> statement-breakpoint
ALTER TABLE \`github_meta\` ADD \`sync_phase\` text;
--> statement-breakpoint
ALTER TABLE \`github_meta\` ADD \`processed_repository_count\` integer NOT NULL DEFAULT 0;
--> statement-breakpoint
ALTER TABLE \`github_meta\` ADD \`total_repository_count\` integer NOT NULL DEFAULT 0;
--> statement-breakpoint
ALTER TABLE \`github_repositories\` ADD \`sync_generation\` integer NOT NULL DEFAULT 0;
--> statement-breakpoint
ALTER TABLE \`github_members\` ADD \`sync_generation\` integer NOT NULL DEFAULT 0;
--> statement-breakpoint
ALTER TABLE \`github_pull_requests\` ADD \`sync_generation\` integer NOT NULL DEFAULT 0;
--> statement-breakpoint
ALTER TABLE \`github_branches\` ADD \`sync_generation\` integer NOT NULL DEFAULT 0;
`,
} as const,
};

View file

@ -11,6 +11,10 @@ export const githubMeta = sqliteTable(
installationId: integer("installation_id"),
lastSyncLabel: text("last_sync_label").notNull(),
lastSyncAt: integer("last_sync_at"),
syncGeneration: integer("sync_generation").notNull(),
syncPhase: text("sync_phase"),
processedRepositoryCount: integer("processed_repository_count").notNull(),
totalRepositoryCount: integer("total_repository_count").notNull(),
updatedAt: integer("updated_at").notNull(),
},
(table) => [check("github_meta_singleton_id_check", sql`${table.id} = 1`)],
@ -22,6 +26,7 @@ export const githubRepositories = sqliteTable("github_repositories", {
cloneUrl: text("clone_url").notNull(),
private: integer("private").notNull(),
defaultBranch: text("default_branch").notNull(),
syncGeneration: integer("sync_generation").notNull(),
updatedAt: integer("updated_at").notNull(),
});
@ -30,6 +35,7 @@ export const githubBranches = sqliteTable("github_branches", {
repoId: text("repo_id").notNull(),
branchName: text("branch_name").notNull(),
commitSha: text("commit_sha").notNull(),
syncGeneration: integer("sync_generation").notNull(),
updatedAt: integer("updated_at").notNull(),
});
@ -40,6 +46,7 @@ export const githubMembers = sqliteTable("github_members", {
email: text("email"),
role: text("role"),
state: text("state").notNull(),
syncGeneration: integer("sync_generation").notNull(),
updatedAt: integer("updated_at").notNull(),
});
@ -56,5 +63,6 @@ export const githubPullRequests = sqliteTable("github_pull_requests", {
baseRefName: text("base_ref_name").notNull(),
authorLogin: text("author_login"),
isDraft: integer("is_draft").notNull(),
syncGeneration: integer("sync_generation").notNull(),
updatedAt: integer("updated_at").notNull(),
});

View file

@ -1,16 +1,29 @@
// @ts-nocheck
import { eq } from "drizzle-orm";
import { actor, queue } from "rivetkit";
import { workflow, Loop } from "rivetkit/workflow";
import { workflow } from "rivetkit/workflow";
import type { FoundryOrganization } from "@sandbox-agent/foundry-shared";
import { getActorRuntimeContext } from "../context.js";
import { getOrCreateOrganization, getOrCreateRepository, getTask } from "../handles.js";
import { repoIdFromRemote } from "../../services/repo.js";
import { resolveOrganizationGithubAuth } from "../../services/github-auth.js";
import { expectQueueResponse } from "../../services/queue.js";
import { organizationWorkflowQueueName } from "../organization/queues.js";
import { repositoryWorkflowQueueName } from "../repository/workflow.js";
import { taskWorkflowQueueName } from "../task/workflow/index.js";
import { githubDataDb } from "./db/db.js";
import { githubBranches, githubMembers, githubMeta, githubPullRequests, githubRepositories } from "./db/schema.js";
import { GITHUB_DATA_QUEUE_NAMES, runGithubDataWorkflow } from "./workflow.js";
const META_ROW_ID = 1;
const SYNC_REPOSITORY_BATCH_SIZE = 10;
type GithubSyncPhase =
| "discovering_repositories"
| "syncing_repositories"
| "syncing_branches"
| "syncing_members"
| "syncing_pull_requests";
interface GithubDataInput {
organizationId: string;
@ -70,6 +83,12 @@ interface ClearStateInput {
label: string;
}
async function sendOrganizationCommand(organization: any, name: Parameters<typeof organizationWorkflowQueueName>[0], body: unknown): Promise<void> {
await expectQueueResponse<{ ok: true }>(
await organization.send(organizationWorkflowQueueName(name), body, { wait: true, timeout: 60_000 }),
);
}
interface PullRequestWebhookInput {
connectedAccount: string;
installationStatus: FoundryOrganization["github"]["installationStatus"];
@ -93,6 +112,19 @@ interface PullRequestWebhookInput {
};
}
interface GithubMetaState {
connectedAccount: string;
installationStatus: FoundryOrganization["github"]["installationStatus"];
syncStatus: FoundryOrganization["github"]["syncStatus"];
installationId: number | null;
lastSyncLabel: string;
lastSyncAt: number | null;
syncGeneration: number;
syncPhase: GithubSyncPhase | null;
processedRepositoryCount: number;
totalRepositoryCount: number;
}
function normalizePrStatus(input: { state: string; isDraft?: boolean; merged?: boolean }): "OPEN" | "DRAFT" | "CLOSED" | "MERGED" {
const state = input.state.trim().toUpperCase();
if (input.merged || state === "MERGED") return "MERGED";
@ -117,7 +149,18 @@ function pullRequestSummaryFromRow(row: any) {
};
}
async function readMeta(c: any) {
function chunkItems<T>(items: T[], size: number): T[][] {
if (items.length === 0) {
return [];
}
const chunks: T[][] = [];
for (let index = 0; index < items.length; index += size) {
chunks.push(items.slice(index, index + size));
}
return chunks;
}
export async function readMeta(c: any): Promise<GithubMetaState> {
const row = await c.db.select().from(githubMeta).where(eq(githubMeta.id, META_ROW_ID)).get();
return {
connectedAccount: row?.connectedAccount ?? "",
@ -126,10 +169,14 @@ async function readMeta(c: any) {
installationId: row?.installationId ?? null,
lastSyncLabel: row?.lastSyncLabel ?? "Waiting for first import",
lastSyncAt: row?.lastSyncAt ?? null,
syncGeneration: row?.syncGeneration ?? 0,
syncPhase: (row?.syncPhase ?? null) as GithubSyncPhase | null,
processedRepositoryCount: row?.processedRepositoryCount ?? 0,
totalRepositoryCount: row?.totalRepositoryCount ?? 0,
};
}
async function writeMeta(c: any, patch: Partial<Awaited<ReturnType<typeof readMeta>>>) {
async function writeMeta(c: any, patch: Partial<GithubMetaState>) {
const current = await readMeta(c);
const next = {
...current,
@ -145,6 +192,10 @@ async function writeMeta(c: any, patch: Partial<Awaited<ReturnType<typeof readMe
installationId: next.installationId,
lastSyncLabel: next.lastSyncLabel,
lastSyncAt: next.lastSyncAt,
syncGeneration: next.syncGeneration,
syncPhase: next.syncPhase,
processedRepositoryCount: next.processedRepositoryCount,
totalRepositoryCount: next.totalRepositoryCount,
updatedAt: Date.now(),
})
.onConflictDoUpdate({
@ -156,6 +207,10 @@ async function writeMeta(c: any, patch: Partial<Awaited<ReturnType<typeof readMe
installationId: next.installationId,
lastSyncLabel: next.lastSyncLabel,
lastSyncAt: next.lastSyncAt,
syncGeneration: next.syncGeneration,
syncPhase: next.syncPhase,
processedRepositoryCount: next.processedRepositoryCount,
totalRepositoryCount: next.totalRepositoryCount,
updatedAt: Date.now(),
},
})
@ -163,6 +218,35 @@ async function writeMeta(c: any, patch: Partial<Awaited<ReturnType<typeof readMe
return next;
}
async function publishSyncProgress(c: any, patch: Partial<GithubMetaState>): Promise<GithubMetaState> {
const meta = await writeMeta(c, patch);
const organization = await getOrCreateOrganization(c, c.state.organizationId);
await sendOrganizationCommand(organization, "organization.command.github.sync_progress.apply", {
connectedAccount: meta.connectedAccount,
installationStatus: meta.installationStatus,
installationId: meta.installationId,
syncStatus: meta.syncStatus,
lastSyncLabel: meta.lastSyncLabel,
lastSyncAt: meta.lastSyncAt,
syncGeneration: meta.syncGeneration,
syncPhase: meta.syncPhase,
processedRepositoryCount: meta.processedRepositoryCount,
totalRepositoryCount: meta.totalRepositoryCount,
});
return meta;
}
async function runSyncStep<T>(c: any, name: string, run: () => Promise<T>): Promise<T> {
if (typeof c.step !== "function") {
return await run();
}
return await c.step({
name,
timeout: 90_000,
run,
});
}
async function getOrganizationContext(c: any, overrides?: FullSyncInput) {
const organizationHandle = await getOrCreateOrganization(c, c.state.organizationId);
const organizationState = await organizationHandle.getOrganizationShellStateIfInitialized({});
@ -183,8 +267,7 @@ async function getOrganizationContext(c: any, overrides?: FullSyncInput) {
};
}
async function replaceRepositories(c: any, repositories: GithubRepositoryRecord[], updatedAt: number) {
await c.db.delete(githubRepositories).run();
async function upsertRepositories(c: any, repositories: GithubRepositoryRecord[], updatedAt: number, syncGeneration: number) {
for (const repository of repositories) {
await c.db
.insert(githubRepositories)
@ -194,14 +277,35 @@ async function replaceRepositories(c: any, repositories: GithubRepositoryRecord[
cloneUrl: repository.cloneUrl,
private: repository.private ? 1 : 0,
defaultBranch: repository.defaultBranch,
syncGeneration,
updatedAt,
})
.onConflictDoUpdate({
target: githubRepositories.repoId,
set: {
fullName: repository.fullName,
cloneUrl: repository.cloneUrl,
private: repository.private ? 1 : 0,
defaultBranch: repository.defaultBranch,
syncGeneration,
updatedAt,
},
})
.run();
}
}
async function replaceBranches(c: any, branches: GithubBranchRecord[], updatedAt: number) {
await c.db.delete(githubBranches).run();
async function sweepRepositories(c: any, syncGeneration: number) {
const rows = await c.db.select({ repoId: githubRepositories.repoId, syncGeneration: githubRepositories.syncGeneration }).from(githubRepositories).all();
for (const row of rows) {
if (row.syncGeneration === syncGeneration) {
continue;
}
await c.db.delete(githubRepositories).where(eq(githubRepositories.repoId, row.repoId)).run();
}
}
async function upsertBranches(c: any, branches: GithubBranchRecord[], updatedAt: number, syncGeneration: number) {
for (const branch of branches) {
await c.db
.insert(githubBranches)
@ -210,14 +314,34 @@ async function replaceBranches(c: any, branches: GithubBranchRecord[], updatedAt
repoId: branch.repoId,
branchName: branch.branchName,
commitSha: branch.commitSha,
syncGeneration,
updatedAt,
})
.onConflictDoUpdate({
target: githubBranches.branchId,
set: {
repoId: branch.repoId,
branchName: branch.branchName,
commitSha: branch.commitSha,
syncGeneration,
updatedAt,
},
})
.run();
}
}
async function replaceMembers(c: any, members: GithubMemberRecord[], updatedAt: number) {
await c.db.delete(githubMembers).run();
async function sweepBranches(c: any, syncGeneration: number) {
const rows = await c.db.select({ branchId: githubBranches.branchId, syncGeneration: githubBranches.syncGeneration }).from(githubBranches).all();
for (const row of rows) {
if (row.syncGeneration === syncGeneration) {
continue;
}
await c.db.delete(githubBranches).where(eq(githubBranches.branchId, row.branchId)).run();
}
}
async function upsertMembers(c: any, members: GithubMemberRecord[], updatedAt: number, syncGeneration: number) {
for (const member of members) {
await c.db
.insert(githubMembers)
@ -228,14 +352,36 @@ async function replaceMembers(c: any, members: GithubMemberRecord[], updatedAt:
email: member.email ?? null,
role: member.role ?? null,
state: member.state ?? "active",
syncGeneration,
updatedAt,
})
.onConflictDoUpdate({
target: githubMembers.memberId,
set: {
login: member.login,
displayName: member.name || member.login,
email: member.email ?? null,
role: member.role ?? null,
state: member.state ?? "active",
syncGeneration,
updatedAt,
},
})
.run();
}
}
async function replacePullRequests(c: any, pullRequests: GithubPullRequestRecord[]) {
await c.db.delete(githubPullRequests).run();
async function sweepMembers(c: any, syncGeneration: number) {
const rows = await c.db.select({ memberId: githubMembers.memberId, syncGeneration: githubMembers.syncGeneration }).from(githubMembers).all();
for (const row of rows) {
if (row.syncGeneration === syncGeneration) {
continue;
}
await c.db.delete(githubMembers).where(eq(githubMembers.memberId, row.memberId)).run();
}
}
async function upsertPullRequests(c: any, pullRequests: GithubPullRequestRecord[], syncGeneration: number) {
for (const pullRequest of pullRequests) {
await c.db
.insert(githubPullRequests)
@ -252,19 +398,54 @@ async function replacePullRequests(c: any, pullRequests: GithubPullRequestRecord
baseRefName: pullRequest.baseRefName,
authorLogin: pullRequest.authorLogin ?? null,
isDraft: pullRequest.isDraft ? 1 : 0,
syncGeneration,
updatedAt: pullRequest.updatedAt,
})
.onConflictDoUpdate({
target: githubPullRequests.prId,
set: {
repoId: pullRequest.repoId,
repoFullName: pullRequest.repoFullName,
number: pullRequest.number,
title: pullRequest.title,
body: pullRequest.body ?? null,
state: pullRequest.state,
url: pullRequest.url,
headRefName: pullRequest.headRefName,
baseRefName: pullRequest.baseRefName,
authorLogin: pullRequest.authorLogin ?? null,
isDraft: pullRequest.isDraft ? 1 : 0,
syncGeneration,
updatedAt: pullRequest.updatedAt,
},
})
.run();
}
}
async function refreshTaskSummaryForBranch(c: any, repoId: string, branchName: string) {
async function sweepPullRequests(c: any, syncGeneration: number) {
const rows = await c.db.select({ prId: githubPullRequests.prId, syncGeneration: githubPullRequests.syncGeneration }).from(githubPullRequests).all();
for (const row of rows) {
if (row.syncGeneration === syncGeneration) {
continue;
}
await c.db.delete(githubPullRequests).where(eq(githubPullRequests.prId, row.prId)).run();
}
}
async function refreshTaskSummaryForBranch(c: any, repoId: string, branchName: string, pullRequest: ReturnType<typeof pullRequestSummaryFromRow> | null) {
const repositoryRecord = await c.db.select().from(githubRepositories).where(eq(githubRepositories.repoId, repoId)).get();
if (!repositoryRecord) {
return;
}
const repository = await getOrCreateRepository(c, c.state.organizationId, repoId, repositoryRecord.cloneUrl);
await repository.refreshTaskSummaryForBranch({ branchName });
const repository = await getOrCreateRepository(c, c.state.organizationId, repoId);
await expectQueueResponse<{ ok: true }>(
await repository.send(
repositoryWorkflowQueueName("repository.command.refreshTaskSummaryForBranch"),
{ branchName, pullRequest },
{ wait: true, timeout: 10_000 },
),
);
}
async function emitPullRequestChangeEvents(c: any, beforeRows: any[], afterRows: any[]) {
@ -286,14 +467,14 @@ async function emitPullRequestChangeEvents(c: any, beforeRows: any[], afterRows:
if (!changed) {
continue;
}
await refreshTaskSummaryForBranch(c, row.repoId, row.headRefName);
await refreshTaskSummaryForBranch(c, row.repoId, row.headRefName, pullRequestSummaryFromRow(row));
}
for (const [prId, row] of beforeById) {
if (afterById.has(prId)) {
continue;
}
await refreshTaskSummaryForBranch(c, row.repoId, row.headRefName);
await refreshTaskSummaryForBranch(c, row.repoId, row.headRefName, null);
}
}
@ -302,7 +483,7 @@ async function autoArchiveTaskForClosedPullRequest(c: any, row: any) {
if (!repositoryRecord) {
return;
}
const repository = await getOrCreateRepository(c, c.state.organizationId, row.repoId, repositoryRecord.cloneUrl);
const repository = await getOrCreateRepository(c, c.state.organizationId, row.repoId);
const match = await repository.findTaskForBranch({
branchName: row.headRefName,
});
@ -311,7 +492,7 @@ async function autoArchiveTaskForClosedPullRequest(c: any, row: any) {
}
try {
const task = getTask(c, c.state.organizationId, row.repoId, match.taskId);
await task.archive({ reason: `PR ${String(row.state).toLowerCase()}` });
await task.send(taskWorkflowQueueName("task.command.archive"), { reason: `PR ${String(row.state).toLowerCase()}` }, { wait: false });
} catch {
// Best-effort only. Task summary refresh will still clear the PR state.
}
@ -363,8 +544,7 @@ async function resolveMembers(c: any, context: Awaited<ReturnType<typeof getOrga
return await appShell.github.listOrganizationMembers(context.accessToken, context.githubLogin);
}
async function resolvePullRequests(
c: any,
async function listPullRequestsForRepositories(
context: Awaited<ReturnType<typeof getOrganizationContext>>,
repositories: GithubRepositoryRecord[],
): Promise<GithubPullRequestRecord[]> {
@ -448,11 +628,51 @@ async function listRepositoryBranchesForContext(
}
async function resolveBranches(
_c: any,
c: any,
context: Awaited<ReturnType<typeof getOrganizationContext>>,
repositories: GithubRepositoryRecord[],
): Promise<GithubBranchRecord[]> {
return (await Promise.all(repositories.map((repository) => listRepositoryBranchesForContext(context, repository)))).flat();
onBatch?: (branches: GithubBranchRecord[]) => Promise<void>,
onProgress?: (processedRepositoryCount: number, totalRepositoryCount: number) => Promise<void>,
): Promise<void> {
const batches = chunkItems(repositories, SYNC_REPOSITORY_BATCH_SIZE);
let processedRepositoryCount = 0;
for (const batch of batches) {
const batchBranches = await runSyncStep(c, `github-sync-branches-${processedRepositoryCount / SYNC_REPOSITORY_BATCH_SIZE + 1}`, async () =>
(await Promise.all(batch.map((repository) => listRepositoryBranchesForContext(context, repository)))).flat(),
);
if (onBatch) {
await onBatch(batchBranches);
}
processedRepositoryCount += batch.length;
if (onProgress) {
await onProgress(processedRepositoryCount, repositories.length);
}
}
}
async function resolvePullRequests(
c: any,
context: Awaited<ReturnType<typeof getOrganizationContext>>,
repositories: GithubRepositoryRecord[],
onBatch?: (pullRequests: GithubPullRequestRecord[]) => Promise<void>,
onProgress?: (processedRepositoryCount: number, totalRepositoryCount: number) => Promise<void>,
): Promise<void> {
const batches = chunkItems(repositories, SYNC_REPOSITORY_BATCH_SIZE);
let processedRepositoryCount = 0;
for (const batch of batches) {
const batchPullRequests = await runSyncStep(c, `github-sync-pull-requests-${processedRepositoryCount / SYNC_REPOSITORY_BATCH_SIZE + 1}`, async () =>
listPullRequestsForRepositories(context, batch),
);
if (onBatch) {
await onBatch(batchPullRequests);
}
processedRepositoryCount += batch.length;
if (onProgress) {
await onProgress(processedRepositoryCount, repositories.length);
}
}
}
async function refreshRepositoryBranches(
@ -461,6 +681,7 @@ async function refreshRepositoryBranches(
repository: GithubRepositoryRecord,
updatedAt: number,
): Promise<void> {
const currentMeta = await readMeta(c);
const nextBranches = await listRepositoryBranchesForContext(context, repository);
await c.db
.delete(githubBranches)
@ -475,6 +696,7 @@ async function refreshRepositoryBranches(
repoId: branch.repoId,
branchName: branch.branchName,
commitSha: branch.commitSha,
syncGeneration: currentMeta.syncGeneration,
updatedAt,
})
.run();
@ -485,118 +707,176 @@ async function readAllPullRequestRows(c: any) {
return await c.db.select().from(githubPullRequests).all();
}
async function runFullSync(c: any, input: FullSyncInput = {}) {
export async function runFullSync(c: any, input: FullSyncInput = {}) {
const startedAt = Date.now();
const beforeRows = await readAllPullRequestRows(c);
const context = await getOrganizationContext(c, input);
const currentMeta = await readMeta(c);
let context: Awaited<ReturnType<typeof getOrganizationContext>> | null = null;
let syncGeneration = currentMeta.syncGeneration + 1;
await writeMeta(c, {
connectedAccount: context.connectedAccount,
installationStatus: context.installationStatus,
installationId: context.installationId,
syncStatus: "syncing",
lastSyncLabel: input.label?.trim() || "Syncing GitHub data...",
});
try {
context = await getOrganizationContext(c, input);
syncGeneration = currentMeta.syncGeneration + 1;
const repositories = await resolveRepositories(c, context);
const branches = await resolveBranches(c, context, repositories);
const members = await resolveMembers(c, context);
const pullRequests = await resolvePullRequests(c, context, repositories);
await replaceRepositories(c, repositories, startedAt);
await replaceBranches(c, branches, startedAt);
await replaceMembers(c, members, startedAt);
await replacePullRequests(c, pullRequests);
const organization = await getOrCreateOrganization(c, c.state.organizationId);
await organization.applyGithubDataProjection({
connectedAccount: context.connectedAccount,
installationStatus: context.installationStatus,
installationId: context.installationId,
syncStatus: "synced",
lastSyncLabel: repositories.length > 0 ? `Synced ${repositories.length} repositories` : "No repositories available",
lastSyncAt: startedAt,
repositories,
});
const meta = await writeMeta(c, {
connectedAccount: context.connectedAccount,
installationStatus: context.installationStatus,
installationId: context.installationId,
syncStatus: "synced",
lastSyncLabel: repositories.length > 0 ? `Synced ${repositories.length} repositories` : "No repositories available",
lastSyncAt: startedAt,
});
const afterRows = await readAllPullRequestRows(c);
await emitPullRequestChangeEvents(c, beforeRows, afterRows);
return {
...meta,
repositoryCount: repositories.length,
memberCount: members.length,
pullRequestCount: afterRows.length,
};
}
const GITHUB_DATA_QUEUE_NAMES = ["githubData.command.syncRepos"] as const;
async function runGithubDataWorkflow(ctx: any): Promise<void> {
// Initial sync: if this actor was just created and has never synced,
// kick off the first full sync automatically.
await ctx.step({
name: "github-data-initial-sync",
timeout: 5 * 60_000,
run: async () => {
const meta = await readMeta(ctx);
if (meta.syncStatus !== "pending") {
return; // Already synced or syncing — skip initial sync
}
try {
await runFullSync(ctx, { label: "Importing repository catalog..." });
} catch (error) {
// Best-effort initial sync. Write the error to meta so the client
// sees the failure and can trigger a manual retry.
const currentMeta = await readMeta(ctx);
const organization = await getOrCreateOrganization(ctx, ctx.state.organizationId);
await organization.markOrganizationSyncFailed({
message: error instanceof Error ? error.message : "GitHub import failed",
installationStatus: currentMeta.installationStatus,
});
}
},
});
// Command loop for explicit sync requests (reload, re-import, etc.)
await ctx.loop("github-data-command-loop", async (loopCtx: any) => {
const msg = await loopCtx.queue.next("next-github-data-command", {
names: [...GITHUB_DATA_QUEUE_NAMES],
completable: true,
await publishSyncProgress(c, {
connectedAccount: context.connectedAccount,
installationStatus: context.installationStatus,
installationId: context.installationId,
syncStatus: "syncing",
lastSyncLabel: input.label?.trim() || "Syncing GitHub data...",
syncGeneration,
syncPhase: "discovering_repositories",
processedRepositoryCount: 0,
totalRepositoryCount: 0,
});
if (!msg) {
return Loop.continue(undefined);
}
try {
if (msg.name === "githubData.command.syncRepos") {
await loopCtx.step({
name: "github-data-sync-repos",
timeout: 5 * 60_000,
run: async () => {
const body = msg.body as FullSyncInput;
await runFullSync(loopCtx, body);
},
const repositories = await runSyncStep(c, "github-sync-repositories", async () => resolveRepositories(c, context));
const totalRepositoryCount = repositories.length;
await publishSyncProgress(c, {
connectedAccount: context.connectedAccount,
installationStatus: context.installationStatus,
installationId: context.installationId,
syncStatus: "syncing",
lastSyncLabel: totalRepositoryCount > 0 ? `Importing ${totalRepositoryCount} repositories...` : "No repositories available",
syncGeneration,
syncPhase: "syncing_repositories",
processedRepositoryCount: totalRepositoryCount,
totalRepositoryCount,
});
await upsertRepositories(c, repositories, startedAt, syncGeneration);
const organization = await getOrCreateOrganization(c, c.state.organizationId);
await sendOrganizationCommand(organization, "organization.command.github.data_projection.apply", {
connectedAccount: context.connectedAccount,
installationStatus: context.installationStatus,
installationId: context.installationId,
syncStatus: "syncing",
lastSyncLabel: totalRepositoryCount > 0 ? `Imported ${totalRepositoryCount} repositories` : "No repositories available",
lastSyncAt: currentMeta.lastSyncAt,
syncGeneration,
syncPhase: totalRepositoryCount > 0 ? "syncing_branches" : null,
processedRepositoryCount: 0,
totalRepositoryCount,
repositories,
});
await resolveBranches(
c,
context,
repositories,
async (batchBranches) => {
await upsertBranches(c, batchBranches, startedAt, syncGeneration);
},
async (processedRepositoryCount, repositoryCount) => {
await publishSyncProgress(c, {
connectedAccount: context.connectedAccount,
installationStatus: context.installationStatus,
installationId: context.installationId,
syncStatus: "syncing",
lastSyncLabel: `Synced branches for ${processedRepositoryCount} of ${repositoryCount} repositories`,
syncGeneration,
syncPhase: "syncing_branches",
processedRepositoryCount,
totalRepositoryCount: repositoryCount,
});
await msg.complete({ ok: true });
return Loop.continue(undefined);
}
} catch (error) {
const message = error instanceof Error ? error.message : String(error);
await msg.complete({ error: message }).catch(() => {});
}
},
);
return Loop.continue(undefined);
});
await publishSyncProgress(c, {
connectedAccount: context.connectedAccount,
installationStatus: context.installationStatus,
installationId: context.installationId,
syncStatus: "syncing",
lastSyncLabel: "Syncing GitHub members...",
syncGeneration,
syncPhase: "syncing_members",
processedRepositoryCount: totalRepositoryCount,
totalRepositoryCount,
});
const members = await runSyncStep(c, "github-sync-members", async () => resolveMembers(c, context));
await upsertMembers(c, members, startedAt, syncGeneration);
await sweepMembers(c, syncGeneration);
await resolvePullRequests(
c,
context,
repositories,
async (batchPullRequests) => {
await upsertPullRequests(c, batchPullRequests, syncGeneration);
},
async (processedRepositoryCount, repositoryCount) => {
await publishSyncProgress(c, {
connectedAccount: context.connectedAccount,
installationStatus: context.installationStatus,
installationId: context.installationId,
syncStatus: "syncing",
lastSyncLabel: `Synced pull requests for ${processedRepositoryCount} of ${repositoryCount} repositories`,
syncGeneration,
syncPhase: "syncing_pull_requests",
processedRepositoryCount,
totalRepositoryCount: repositoryCount,
});
},
);
await sweepBranches(c, syncGeneration);
await sweepPullRequests(c, syncGeneration);
await sweepRepositories(c, syncGeneration);
await sendOrganizationCommand(organization, "organization.command.github.data_projection.apply", {
connectedAccount: context.connectedAccount,
installationStatus: context.installationStatus,
installationId: context.installationId,
syncStatus: "synced",
lastSyncLabel: totalRepositoryCount > 0 ? `Synced ${totalRepositoryCount} repositories` : "No repositories available",
lastSyncAt: startedAt,
syncGeneration,
syncPhase: null,
processedRepositoryCount: totalRepositoryCount,
totalRepositoryCount,
repositories,
});
const meta = await writeMeta(c, {
connectedAccount: context.connectedAccount,
installationStatus: context.installationStatus,
installationId: context.installationId,
syncStatus: "synced",
lastSyncLabel: totalRepositoryCount > 0 ? `Synced ${totalRepositoryCount} repositories` : "No repositories available",
lastSyncAt: startedAt,
syncGeneration,
syncPhase: null,
processedRepositoryCount: totalRepositoryCount,
totalRepositoryCount,
});
const afterRows = await readAllPullRequestRows(c);
await emitPullRequestChangeEvents(c, beforeRows, afterRows);
return {
...meta,
repositoryCount: repositories.length,
memberCount: members.length,
pullRequestCount: afterRows.length,
};
} catch (error) {
const message = error instanceof Error ? error.message : "GitHub import failed";
await publishSyncProgress(c, {
connectedAccount: context?.connectedAccount ?? currentMeta.connectedAccount,
installationStatus: context?.installationStatus ?? currentMeta.installationStatus,
installationId: context?.installationId ?? currentMeta.installationId,
syncStatus: "error",
lastSyncLabel: message,
syncGeneration,
syncPhase: null,
processedRepositoryCount: 0,
totalRepositoryCount: 0,
});
throw error;
}
}
export const githubData = actor({
@ -651,11 +931,6 @@ export const githubData = actor({
};
},
async listPullRequestsForRepository(c, input: { repoId: string }) {
const rows = await c.db.select().from(githubPullRequests).where(eq(githubPullRequests.repoId, input.repoId)).all();
return rows.map(pullRequestSummaryFromRow);
},
async listBranchesForRepository(c, input: { repoId: string }) {
const rows = await c.db.select().from(githubBranches).where(eq(githubBranches.repoId, input.repoId)).all();
return rows
@ -666,36 +941,10 @@ export const githubData = actor({
.sort((left, right) => left.branchName.localeCompare(right.branchName));
},
async listOpenPullRequests(c) {
const rows = await c.db.select().from(githubPullRequests).all();
return rows.map(pullRequestSummaryFromRow).sort((left, right) => right.updatedAtMs - left.updatedAtMs);
},
},
});
async getPullRequestForBranch(c, input: { repoId: string; branchName: string }) {
const rows = await c.db.select().from(githubPullRequests).where(eq(githubPullRequests.repoId, input.repoId)).all();
const match = rows.find((candidate) => candidate.headRefName === input.branchName) ?? null;
if (!match) {
return null;
}
return {
number: match.number,
status: match.isDraft ? ("draft" as const) : ("ready" as const),
};
},
async adminFullSync(c, input: FullSyncInput = {}) {
return await runFullSync(c, input);
},
async adminReloadOrganization(c) {
return await runFullSync(c, { label: "Reloading GitHub organization..." });
},
async adminReloadAllPullRequests(c) {
return await runFullSync(c, { label: "Reloading GitHub pull requests..." });
},
async reloadRepository(c, input: { repoId: string }) {
export async function reloadRepositoryMutation(c: any, input: { repoId: string }) {
const context = await getOrganizationContext(c);
const current = await c.db.select().from(githubRepositories).where(eq(githubRepositories.repoId, input.repoId)).get();
if (!current) {
@ -713,6 +962,7 @@ export const githubData = actor({
}
const updatedAt = Date.now();
const currentMeta = await readMeta(c);
await c.db
.insert(githubRepositories)
.values({
@ -721,6 +971,7 @@ export const githubData = actor({
cloneUrl: repository.cloneUrl,
private: repository.private ? 1 : 0,
defaultBranch: repository.defaultBranch,
syncGeneration: currentMeta.syncGeneration,
updatedAt,
})
.onConflictDoUpdate({
@ -730,6 +981,7 @@ export const githubData = actor({
cloneUrl: repository.cloneUrl,
private: repository.private ? 1 : 0,
defaultBranch: repository.defaultBranch,
syncGeneration: currentMeta.syncGeneration,
updatedAt,
},
})
@ -747,7 +999,7 @@ export const githubData = actor({
);
const organization = await getOrCreateOrganization(c, c.state.organizationId);
await organization.applyGithubRepositoryProjection({
await sendOrganizationCommand(organization, "organization.command.github.repository_projection.apply", {
repoId: input.repoId,
remoteUrl: repository.cloneUrl,
});
@ -758,98 +1010,11 @@ export const githubData = actor({
private: repository.private,
defaultBranch: repository.defaultBranch,
};
},
async reloadPullRequest(c, input: { repoId: string; prNumber: number }) {
const repository = await c.db.select().from(githubRepositories).where(eq(githubRepositories.repoId, input.repoId)).get();
if (!repository) {
throw new Error(`Unknown GitHub repository: ${input.repoId}`);
}
const context = await getOrganizationContext(c);
const { appShell } = getActorRuntimeContext();
const pullRequest =
context.installationId != null
? await appShell.github.getInstallationPullRequest(context.installationId, repository.fullName, input.prNumber)
: context.accessToken
? await appShell.github.getUserPullRequest(context.accessToken, repository.fullName, input.prNumber)
: null;
if (!pullRequest) {
throw new Error(`Unable to reload pull request #${input.prNumber} for ${repository.fullName}`);
}
}
export async function clearStateMutation(c: any, input: ClearStateInput) {
const beforeRows = await readAllPullRequestRows(c);
const updatedAt = Date.now();
const nextState = normalizePrStatus(pullRequest);
const prId = `${input.repoId}#${input.prNumber}`;
if (nextState === "CLOSED" || nextState === "MERGED") {
await c.db.delete(githubPullRequests).where(eq(githubPullRequests.prId, prId)).run();
} else {
await c.db
.insert(githubPullRequests)
.values({
prId,
repoId: input.repoId,
repoFullName: repository.fullName,
number: pullRequest.number,
title: pullRequest.title,
body: pullRequest.body ?? null,
state: nextState,
url: pullRequest.url,
headRefName: pullRequest.headRefName,
baseRefName: pullRequest.baseRefName,
authorLogin: pullRequest.authorLogin ?? null,
isDraft: pullRequest.isDraft ? 1 : 0,
updatedAt,
})
.onConflictDoUpdate({
target: githubPullRequests.prId,
set: {
title: pullRequest.title,
body: pullRequest.body ?? null,
state: nextState,
url: pullRequest.url,
headRefName: pullRequest.headRefName,
baseRefName: pullRequest.baseRefName,
authorLogin: pullRequest.authorLogin ?? null,
isDraft: pullRequest.isDraft ? 1 : 0,
updatedAt,
},
})
.run();
}
const afterRows = await readAllPullRequestRows(c);
await emitPullRequestChangeEvents(c, beforeRows, afterRows);
const closed = afterRows.find((row) => row.prId === prId);
if (!closed && (nextState === "CLOSED" || nextState === "MERGED")) {
const previous = beforeRows.find((row) => row.prId === prId);
if (previous) {
await autoArchiveTaskForClosedPullRequest(c, {
...previous,
state: nextState,
});
}
}
return pullRequestSummaryFromRow(
afterRows.find((row) => row.prId === prId) ?? {
prId,
repoId: input.repoId,
repoFullName: repository.fullName,
number: input.prNumber,
title: pullRequest.title,
state: nextState,
url: pullRequest.url,
headRefName: pullRequest.headRefName,
baseRefName: pullRequest.baseRefName,
authorLogin: pullRequest.authorLogin ?? null,
isDraft: pullRequest.isDraft ? 1 : 0,
updatedAt,
},
);
},
async adminClearState(c, input: ClearStateInput) {
const beforeRows = await readAllPullRequestRows(c);
const currentMeta = await readMeta(c);
await c.db.delete(githubPullRequests).run();
await c.db.delete(githubBranches).run();
await c.db.delete(githubRepositories).run();
@ -861,26 +1026,35 @@ export const githubData = actor({
syncStatus: "pending",
lastSyncLabel: input.label,
lastSyncAt: null,
syncGeneration: currentMeta.syncGeneration,
syncPhase: null,
processedRepositoryCount: 0,
totalRepositoryCount: 0,
});
const organization = await getOrCreateOrganization(c, c.state.organizationId);
await organization.applyGithubDataProjection({
await sendOrganizationCommand(organization, "organization.command.github.data_projection.apply", {
connectedAccount: input.connectedAccount,
installationStatus: input.installationStatus,
installationId: input.installationId,
syncStatus: "pending",
lastSyncLabel: input.label,
lastSyncAt: null,
syncGeneration: currentMeta.syncGeneration,
syncPhase: null,
processedRepositoryCount: 0,
totalRepositoryCount: 0,
repositories: [],
});
await emitPullRequestChangeEvents(c, beforeRows, []);
},
}
async handlePullRequestWebhook(c, input: PullRequestWebhookInput) {
export async function handlePullRequestWebhookMutation(c: any, input: PullRequestWebhookInput) {
const beforeRows = await readAllPullRequestRows(c);
const repoId = repoIdFromRemote(input.repository.cloneUrl);
const currentRepository = await c.db.select().from(githubRepositories).where(eq(githubRepositories.repoId, repoId)).get();
const updatedAt = Date.now();
const currentMeta = await readMeta(c);
const state = normalizePrStatus(input.pullRequest);
const prId = `${repoId}#${input.pullRequest.number}`;
@ -892,6 +1066,7 @@ export const githubData = actor({
cloneUrl: input.repository.cloneUrl,
private: input.repository.private ? 1 : 0,
defaultBranch: currentRepository?.defaultBranch ?? input.pullRequest.baseRefName ?? "main",
syncGeneration: currentMeta.syncGeneration,
updatedAt,
})
.onConflictDoUpdate({
@ -901,6 +1076,7 @@ export const githubData = actor({
cloneUrl: input.repository.cloneUrl,
private: input.repository.private ? 1 : 0,
defaultBranch: currentRepository?.defaultBranch ?? input.pullRequest.baseRefName ?? "main",
syncGeneration: currentMeta.syncGeneration,
updatedAt,
},
})
@ -924,6 +1100,7 @@ export const githubData = actor({
baseRefName: input.pullRequest.baseRefName,
authorLogin: input.pullRequest.authorLogin ?? null,
isDraft: input.pullRequest.isDraft ? 1 : 0,
syncGeneration: currentMeta.syncGeneration,
updatedAt,
})
.onConflictDoUpdate({
@ -937,23 +1114,27 @@ export const githubData = actor({
baseRefName: input.pullRequest.baseRefName,
authorLogin: input.pullRequest.authorLogin ?? null,
isDraft: input.pullRequest.isDraft ? 1 : 0,
syncGeneration: currentMeta.syncGeneration,
updatedAt,
},
})
.run();
}
await writeMeta(c, {
await publishSyncProgress(c, {
connectedAccount: input.connectedAccount,
installationStatus: input.installationStatus,
installationId: input.installationId,
syncStatus: "synced",
lastSyncLabel: "GitHub webhook received",
lastSyncAt: updatedAt,
syncPhase: null,
processedRepositoryCount: 0,
totalRepositoryCount: 0,
});
const organization = await getOrCreateOrganization(c, c.state.organizationId);
await organization.applyGithubRepositoryProjection({
await sendOrganizationCommand(organization, "organization.command.github.repository_projection.apply", {
repoId,
remoteUrl: input.repository.cloneUrl,
});
@ -969,6 +1150,4 @@ export const githubData = actor({
});
}
}
},
},
});
}

View file

@ -0,0 +1,76 @@
// @ts-nocheck
import { Loop } from "rivetkit/workflow";
import { clearStateMutation, handlePullRequestWebhookMutation, reloadRepositoryMutation, runFullSync } from "./index.js";
export const GITHUB_DATA_QUEUE_NAMES = [
"githubData.command.syncRepos",
"githubData.command.reloadRepository",
"githubData.command.clearState",
"githubData.command.handlePullRequestWebhook",
] as const;
export type GithubDataQueueName = (typeof GITHUB_DATA_QUEUE_NAMES)[number];
export function githubDataWorkflowQueueName(name: GithubDataQueueName): GithubDataQueueName {
return name;
}
export async function runGithubDataWorkflow(ctx: any): Promise<void> {
const meta = await ctx.step({
name: "github-data-read-meta",
timeout: 30_000,
run: async () => {
const { readMeta } = await import("./index.js");
return await readMeta(ctx);
},
});
if (meta.syncStatus === "pending") {
try {
await runFullSync(ctx, { label: "Importing repository catalog..." });
} catch {
// Best-effort initial sync. runFullSync persists the failure state.
}
}
await ctx.loop("github-data-command-loop", async (loopCtx: any) => {
const msg = await loopCtx.queue.next("next-github-data-command", {
names: [...GITHUB_DATA_QUEUE_NAMES],
completable: true,
});
if (!msg) {
return Loop.continue(undefined);
}
try {
if (msg.name === "githubData.command.syncRepos") {
await runFullSync(loopCtx, msg.body);
await msg.complete({ ok: true });
return Loop.continue(undefined);
}
if (msg.name === "githubData.command.reloadRepository") {
const result = await reloadRepositoryMutation(loopCtx, msg.body);
await msg.complete(result);
return Loop.continue(undefined);
}
if (msg.name === "githubData.command.clearState") {
await clearStateMutation(loopCtx, msg.body);
await msg.complete({ ok: true });
return Loop.continue(undefined);
}
if (msg.name === "githubData.command.handlePullRequestWebhook") {
await handlePullRequestWebhookMutation(loopCtx, msg.body);
await msg.complete({ ok: true });
return Loop.continue(undefined);
}
} catch (error) {
const message = error instanceof Error ? error.message : String(error);
await msg.complete({ error: message }).catch(() => {});
}
return Loop.continue(undefined);
});
}

View file

@ -20,12 +20,11 @@ export function getUser(c: any, userId: string) {
return actorClient(c).user.get(userKey(userId));
}
export async function getOrCreateRepository(c: any, organizationId: string, repoId: string, remoteUrl: string) {
export async function getOrCreateRepository(c: any, organizationId: string, repoId: string) {
return await actorClient(c).repository.getOrCreate(repositoryKey(organizationId, repoId), {
createWithInput: {
organizationId,
repoId,
remoteUrl,
},
});
}

View file

@ -32,7 +32,6 @@ export const registry = setup({
});
export * from "./context.js";
export * from "./events.js";
export * from "./audit-log/index.js";
export * from "./user/index.js";
export * from "./github-data/index.js";

View file

@ -1,76 +1,31 @@
// @ts-nocheck
import { desc, eq } from "drizzle-orm";
import { Loop } from "rivetkit/workflow";
import type {
CreateTaskInput,
AuditLogEvent,
HistoryQueryInput,
ListTasksInput,
SandboxProviderId,
RepoOverview,
RepoRecord,
StarSandboxAgentRepoInput,
StarSandboxAgentRepoResult,
SwitchResult,
TaskRecord,
TaskSummary,
TaskWorkspaceChangeModelInput,
TaskWorkspaceCreateTaskInput,
TaskWorkspaceDiffInput,
TaskWorkspaceRenameInput,
TaskWorkspaceRenameSessionInput,
TaskWorkspaceSelectInput,
TaskWorkspaceSetSessionUnreadInput,
TaskWorkspaceSendMessageInput,
TaskWorkspaceSessionInput,
TaskWorkspaceUpdateDraftInput,
WorkspaceRepositorySummary,
WorkspaceTaskSummary,
OrganizationEvent,
OrganizationGithubSummary,
OrganizationSummarySnapshot,
OrganizationUseInput,
} from "@sandbox-agent/foundry-shared";
import { getActorRuntimeContext } from "../context.js";
import { getOrCreateAuditLog, getOrCreateGithubData, getTask as getTaskHandle, getOrCreateRepository, selfOrganization } from "../handles.js";
import { getOrCreateRepository } from "../handles.js";
import { logActorWarning, resolveErrorMessage } from "../logging.js";
import { defaultSandboxProviderId } from "../../sandbox-config.js";
import { repoIdFromRemote } from "../../services/repo.js";
import { resolveOrganizationGithubAuth } from "../../services/github-auth.js";
import { organizationProfile, repos } from "./db/schema.js";
import { agentTypeForModel } from "../task/workspace.js";
import { expectQueueResponse } from "../../services/queue.js";
import { organizationAppActions } from "./app-shell.js";
import { organizationAppActions } from "./actions/app.js";
import { organizationBetterAuthActions } from "./actions/better-auth.js";
import { organizationOnboardingActions } from "./actions/onboarding.js";
import { organizationGithubActions } from "./actions/github.js";
import { organizationShellActions } from "./actions/organization.js";
import { organizationTaskActions } from "./actions/tasks.js";
export { createTaskMutation } from "./actions/tasks.js";
interface OrganizationState {
organizationId: string;
}
interface GetTaskInput {
organizationId: string;
repoId?: string;
taskId: string;
}
interface TaskProxyActionInput extends GetTaskInput {
reason?: string;
}
interface RepoOverviewInput {
organizationId: string;
repoId: string;
}
const ORGANIZATION_QUEUE_NAMES = ["organization.command.createTask", "organization.command.syncGithubSession"] as const;
const SANDBOX_AGENT_REPO = "rivet-dev/sandbox-agent";
type OrganizationQueueName = (typeof ORGANIZATION_QUEUE_NAMES)[number];
export { ORGANIZATION_QUEUE_NAMES };
export function organizationWorkflowQueueName(name: OrganizationQueueName): OrganizationQueueName {
return name;
}
const ORGANIZATION_PROFILE_ROW_ID = 1;
function assertOrganization(c: { state: OrganizationState }, organizationId: string): void {
@ -79,28 +34,6 @@ function assertOrganization(c: { state: OrganizationState }, organizationId: str
}
}
async function collectAllTaskSummaries(c: any): Promise<TaskSummary[]> {
const repoRows = await c.db.select({ repoId: repos.repoId, remoteUrl: repos.remoteUrl }).from(repos).orderBy(desc(repos.updatedAt)).all();
const all: TaskSummary[] = [];
for (const row of repoRows) {
try {
const repository = await getOrCreateRepository(c, c.state.organizationId, row.repoId, row.remoteUrl);
const snapshot = await repository.listTaskSummaries({ includeArchived: true });
all.push(...snapshot);
} catch (error) {
logActorWarning("organization", "failed collecting tasks for repo", {
organizationId: c.state.organizationId,
repoId: row.repoId,
error: resolveErrorMessage(error),
});
}
}
all.sort((a, b) => b.updatedAt - a.updatedAt);
return all;
}
function repoLabelFromRemote(remoteUrl: string): string {
try {
const url = new URL(remoteUrl.startsWith("http") ? remoteUrl : `https://${remoteUrl}`);
@ -127,67 +60,30 @@ function buildRepoSummary(repoRow: { repoId: string; remoteUrl: string; updatedA
};
}
async function resolveRepositoryForTask(c: any, taskId: string, repoId?: string | null) {
if (repoId) {
const repoRow = await c.db.select({ remoteUrl: repos.remoteUrl }).from(repos).where(eq(repos.repoId, repoId)).get();
if (!repoRow) {
throw new Error(`Unknown repo: ${repoId}`);
}
const repository = await getOrCreateRepository(c, c.state.organizationId, repoId, repoRow.remoteUrl);
return { repoId, repository };
}
const repoRows = await c.db.select({ repoId: repos.repoId, remoteUrl: repos.remoteUrl }).from(repos).orderBy(desc(repos.updatedAt)).all();
for (const row of repoRows) {
const repository = await getOrCreateRepository(c, c.state.organizationId, row.repoId, row.remoteUrl);
const summaries = await repository.listTaskSummaries({ includeArchived: true });
if (summaries.some((summary: TaskSummary) => summary.taskId === taskId)) {
return { repoId: row.repoId, repository };
}
}
throw new Error(`Unknown task: ${taskId}`);
}
async function reconcileWorkspaceProjection(c: any): Promise<OrganizationSummarySnapshot> {
const repoRows = await c.db
.select({ repoId: repos.repoId, remoteUrl: repos.remoteUrl, updatedAt: repos.updatedAt })
.from(repos)
.orderBy(desc(repos.updatedAt))
.all();
const taskRows: WorkspaceTaskSummary[] = [];
for (const row of repoRows) {
try {
const repository = await getOrCreateRepository(c, c.state.organizationId, row.repoId, row.remoteUrl);
taskRows.push(...(await repository.listWorkspaceTaskSummaries({})));
} catch (error) {
logActorWarning("organization", "failed collecting repo during workspace reconciliation", {
organizationId: c.state.organizationId,
repoId: row.repoId,
error: resolveErrorMessage(error),
});
}
}
taskRows.sort((left, right) => right.updatedAtMs - left.updatedAtMs);
function buildGithubSummary(profile: any, importedRepoCount: number): OrganizationGithubSummary {
return {
organizationId: c.state.organizationId,
repos: repoRows.map((row) => buildRepoSummary(row, taskRows)).sort((left, right) => right.latestActivityMs - left.latestActivityMs),
taskSummaries: taskRows,
connectedAccount: profile?.githubConnectedAccount ?? "",
installationStatus: profile?.githubInstallationStatus ?? "install_required",
syncStatus: profile?.githubSyncStatus ?? "pending",
importedRepoCount,
lastSyncLabel: profile?.githubLastSyncLabel ?? "Waiting for first import",
lastSyncAt: profile?.githubLastSyncAt ?? null,
lastWebhookAt: profile?.githubLastWebhookAt ?? null,
lastWebhookEvent: profile?.githubLastWebhookEvent ?? "",
syncGeneration: profile?.githubSyncGeneration ?? 0,
syncPhase: profile?.githubSyncPhase ?? null,
processedRepositoryCount: profile?.githubProcessedRepositoryCount ?? 0,
totalRepositoryCount: profile?.githubTotalRepositoryCount ?? 0,
};
}
async function requireWorkspaceTask(c: any, repoId: string, taskId: string) {
return getTaskHandle(c, c.state.organizationId, repoId, taskId);
}
/**
* Reads the organization sidebar snapshot by fanning out one level to the
* repository coordinators. Task summaries are repository-owned; organization
* only aggregates them.
*/
async function getOrganizationSummarySnapshot(c: any): Promise<OrganizationSummarySnapshot> {
const profile = await c.db.select().from(organizationProfile).where(eq(organizationProfile.id, ORGANIZATION_PROFILE_ROW_ID)).get();
const repoRows = await c.db
.select({
repoId: repos.repoId,
@ -200,7 +96,7 @@ async function getOrganizationSummarySnapshot(c: any): Promise<OrganizationSumma
const summaries: WorkspaceTaskSummary[] = [];
for (const row of repoRows) {
try {
const repository = await getOrCreateRepository(c, c.state.organizationId, row.repoId, row.remoteUrl);
const repository = await getOrCreateRepository(c, c.state.organizationId, row.repoId);
summaries.push(...(await repository.listWorkspaceTaskSummaries({})));
} catch (error) {
logActorWarning("organization", "failed reading repository task projection", {
@ -214,98 +110,26 @@ async function getOrganizationSummarySnapshot(c: any): Promise<OrganizationSumma
return {
organizationId: c.state.organizationId,
github: buildGithubSummary(profile, repoRows.length),
repos: repoRows.map((row) => buildRepoSummary(row, summaries)).sort((left, right) => right.latestActivityMs - left.latestActivityMs),
taskSummaries: summaries,
};
}
async function broadcastOrganizationSnapshot(c: any): Promise<void> {
export async function refreshOrganizationSnapshotMutation(c: any): Promise<void> {
c.broadcast("organizationUpdated", {
type: "organizationUpdated",
snapshot: await getOrganizationSummarySnapshot(c),
} satisfies OrganizationEvent);
}
async function createTaskMutation(c: any, input: CreateTaskInput): Promise<TaskRecord> {
assertOrganization(c, input.organizationId);
const { config } = getActorRuntimeContext();
const sandboxProviderId = input.sandboxProviderId ?? defaultSandboxProviderId(config);
const repoId = input.repoId;
const repoRow = await c.db.select({ remoteUrl: repos.remoteUrl }).from(repos).where(eq(repos.repoId, repoId)).get();
if (!repoRow) {
throw new Error(`Unknown repo: ${repoId}`);
}
const remoteUrl = repoRow.remoteUrl;
const repository = await getOrCreateRepository(c, c.state.organizationId, repoId, remoteUrl);
const created = await repository.createTask({
task: input.task,
sandboxProviderId,
agentType: input.agentType ?? null,
explicitTitle: input.explicitTitle ?? null,
explicitBranchName: input.explicitBranchName ?? null,
onBranch: input.onBranch ?? null,
});
return created;
}
export async function runOrganizationWorkflow(ctx: any): Promise<void> {
await ctx.loop("organization-command-loop", async (loopCtx: any) => {
const msg = await loopCtx.queue.next("next-organization-command", {
names: [...ORGANIZATION_QUEUE_NAMES],
completable: true,
});
if (!msg) {
return Loop.continue(undefined);
}
try {
if (msg.name === "organization.command.createTask") {
const result = await loopCtx.step({
name: "organization-create-task",
timeout: 5 * 60_000,
run: async () => createTaskMutation(loopCtx, msg.body as CreateTaskInput),
});
await msg.complete(result);
return Loop.continue(undefined);
}
if (msg.name === "organization.command.syncGithubSession") {
await loopCtx.step({
name: "organization-sync-github-session",
timeout: 60_000,
run: async () => {
const { syncGithubOrganizations } = await import("./app-shell.js");
await syncGithubOrganizations(loopCtx, msg.body as { sessionId: string; accessToken: string });
},
});
await msg.complete({ ok: true });
return Loop.continue(undefined);
}
} catch (error) {
const message = resolveErrorMessage(error);
logActorWarning("organization", "organization workflow command failed", {
queueName: msg.name,
error: message,
});
await msg.complete({ error: message }).catch((completeError: unknown) => {
logActorWarning("organization", "organization workflow failed completing error response", {
queueName: msg.name,
error: resolveErrorMessage(completeError),
});
});
}
return Loop.continue(undefined);
});
}
export const organizationActions = {
...organizationBetterAuthActions,
...organizationGithubActions,
...organizationOnboardingActions,
...organizationShellActions,
...organizationAppActions,
...organizationTaskActions,
async useOrganization(c: any, input: OrganizationUseInput): Promise<{ organizationId: string }> {
assertOrganization(c, input.organizationId);
return { organizationId: c.state.organizationId };
@ -334,381 +158,180 @@ export const organizationActions = {
}));
},
async createTask(c: any, input: CreateTaskInput): Promise<TaskRecord> {
const self = selfOrganization(c);
return expectQueueResponse<TaskRecord>(
await self.send(organizationWorkflowQueueName("organization.command.createTask"), input, {
wait: true,
timeout: 10_000,
}),
);
},
async starSandboxAgentRepo(c: any, input: StarSandboxAgentRepoInput): Promise<StarSandboxAgentRepoResult> {
async getOrganizationSummary(c: any, input: OrganizationUseInput): Promise<OrganizationSummarySnapshot> {
assertOrganization(c, input.organizationId);
const { driver } = getActorRuntimeContext();
const auth = await resolveOrganizationGithubAuth(c, c.state.organizationId);
await driver.github.starRepository(SANDBOX_AGENT_REPO, {
githubToken: auth?.githubToken ?? null,
});
return {
repo: SANDBOX_AGENT_REPO,
starredAt: Date.now(),
};
return await getOrganizationSummarySnapshot(c);
},
};
async refreshOrganizationSnapshot(c: any): Promise<void> {
await broadcastOrganizationSnapshot(c);
export async function applyGithubRepositoryProjectionMutation(c: any, input: { repoId: string; remoteUrl: string }): Promise<void> {
const now = Date.now();
await c.db
.insert(repos)
.values({
repoId: input.repoId,
remoteUrl: input.remoteUrl,
createdAt: now,
updatedAt: now,
})
.onConflictDoUpdate({
target: repos.repoId,
set: {
remoteUrl: input.remoteUrl,
updatedAt: now,
},
})
.run();
await refreshOrganizationSnapshotMutation(c);
}
export async function applyGithubDataProjectionMutation(
c: any,
input: {
connectedAccount: string;
installationStatus: string;
installationId: number | null;
syncStatus: string;
lastSyncLabel: string;
lastSyncAt: number | null;
syncGeneration: number;
syncPhase: string | null;
processedRepositoryCount: number;
totalRepositoryCount: number;
repositories: Array<{ fullName: string; cloneUrl: string; private: boolean }>;
},
): Promise<void> {
const existingRepos = await c.db.select({ repoId: repos.repoId }).from(repos).all();
const nextRepoIds = new Set<string>();
const now = Date.now();
async applyGithubRepositoryProjection(c: any, input: { repoId: string; remoteUrl: string }): Promise<void> {
const now = Date.now();
const existing = await c.db.select({ repoId: repos.repoId }).from(repos).where(eq(repos.repoId, input.repoId)).get();
const profile = await c.db
.select({ id: organizationProfile.id })
.from(organizationProfile)
.where(eq(organizationProfile.id, ORGANIZATION_PROFILE_ROW_ID))
.get();
if (profile) {
await c.db
.update(organizationProfile)
.set({
githubConnectedAccount: input.connectedAccount,
githubInstallationStatus: input.installationStatus,
githubSyncStatus: input.syncStatus,
githubInstallationId: input.installationId,
githubLastSyncLabel: input.lastSyncLabel,
githubLastSyncAt: input.lastSyncAt,
githubSyncGeneration: input.syncGeneration,
githubSyncPhase: input.syncPhase,
githubProcessedRepositoryCount: input.processedRepositoryCount,
githubTotalRepositoryCount: input.totalRepositoryCount,
updatedAt: now,
})
.where(eq(organizationProfile.id, ORGANIZATION_PROFILE_ROW_ID))
.run();
}
for (const repository of input.repositories) {
const repoId = repoIdFromRemote(repository.cloneUrl);
nextRepoIds.add(repoId);
await c.db
.insert(repos)
.values({
repoId: input.repoId,
remoteUrl: input.remoteUrl,
repoId,
remoteUrl: repository.cloneUrl,
createdAt: now,
updatedAt: now,
})
.onConflictDoUpdate({
target: repos.repoId,
set: {
remoteUrl: input.remoteUrl,
remoteUrl: repository.cloneUrl,
updatedAt: now,
},
})
.run();
await broadcastOrganizationSnapshot(c);
},
}
async applyGithubDataProjection(
c: any,
input: {
connectedAccount: string;
installationStatus: string;
installationId: number | null;
syncStatus: string;
lastSyncLabel: string;
lastSyncAt: number | null;
repositories: Array<{ fullName: string; cloneUrl: string; private: boolean }>;
},
): Promise<void> {
const existingRepos = await c.db.select({ repoId: repos.repoId, remoteUrl: repos.remoteUrl, updatedAt: repos.updatedAt }).from(repos).all();
const existingById = new Map(existingRepos.map((repo) => [repo.repoId, repo]));
const nextRepoIds = new Set<string>();
const now = Date.now();
for (const repository of input.repositories) {
const repoId = repoIdFromRemote(repository.cloneUrl);
nextRepoIds.add(repoId);
await c.db
.insert(repos)
.values({
repoId,
remoteUrl: repository.cloneUrl,
createdAt: now,
updatedAt: now,
})
.onConflictDoUpdate({
target: repos.repoId,
set: {
remoteUrl: repository.cloneUrl,
updatedAt: now,
},
})
.run();
await broadcastOrganizationSnapshot(c);
for (const repo of existingRepos) {
if (nextRepoIds.has(repo.repoId)) {
continue;
}
await c.db.delete(repos).where(eq(repos.repoId, repo.repoId)).run();
}
for (const repo of existingRepos) {
if (nextRepoIds.has(repo.repoId)) {
continue;
}
await c.db.delete(repos).where(eq(repos.repoId, repo.repoId)).run();
await broadcastOrganizationSnapshot(c);
}
await refreshOrganizationSnapshotMutation(c);
}
const profile = await c.db
.select({ id: organizationProfile.id })
.from(organizationProfile)
.where(eq(organizationProfile.id, ORGANIZATION_PROFILE_ROW_ID))
.get();
if (profile) {
await c.db
.update(organizationProfile)
.set({
githubConnectedAccount: input.connectedAccount,
githubInstallationStatus: input.installationStatus,
githubSyncStatus: input.syncStatus,
githubInstallationId: input.installationId,
githubLastSyncLabel: input.lastSyncLabel,
githubLastSyncAt: input.lastSyncAt,
updatedAt: now,
})
.where(eq(organizationProfile.id, ORGANIZATION_PROFILE_ROW_ID))
.run();
}
export async function applyGithubSyncProgressMutation(
c: any,
input: {
connectedAccount: string;
installationStatus: string;
installationId: number | null;
syncStatus: string;
lastSyncLabel: string;
lastSyncAt: number | null;
syncGeneration: number;
syncPhase: string | null;
processedRepositoryCount: number;
totalRepositoryCount: number;
},
): Promise<void> {
const profile = await c.db
.select({ id: organizationProfile.id })
.from(organizationProfile)
.where(eq(organizationProfile.id, ORGANIZATION_PROFILE_ROW_ID))
.get();
if (!profile) {
return;
}
async recordGithubWebhookReceipt(
c: any,
input: {
organizationId: string;
event: string;
action?: string | null;
receivedAt?: number;
},
): Promise<void> {
assertOrganization(c, input.organizationId);
await c.db
.update(organizationProfile)
.set({
githubConnectedAccount: input.connectedAccount,
githubInstallationStatus: input.installationStatus,
githubSyncStatus: input.syncStatus,
githubInstallationId: input.installationId,
githubLastSyncLabel: input.lastSyncLabel,
githubLastSyncAt: input.lastSyncAt,
githubSyncGeneration: input.syncGeneration,
githubSyncPhase: input.syncPhase,
githubProcessedRepositoryCount: input.processedRepositoryCount,
githubTotalRepositoryCount: input.totalRepositoryCount,
updatedAt: Date.now(),
})
.where(eq(organizationProfile.id, ORGANIZATION_PROFILE_ROW_ID))
.run();
const profile = await c.db
.select({ id: organizationProfile.id })
.from(organizationProfile)
.where(eq(organizationProfile.id, ORGANIZATION_PROFILE_ROW_ID))
.get();
if (!profile) {
return;
}
await refreshOrganizationSnapshotMutation(c);
}
await c.db
.update(organizationProfile)
.set({
githubLastWebhookAt: input.receivedAt ?? Date.now(),
githubLastWebhookEvent: input.action ? `${input.event}.${input.action}` : input.event,
})
.where(eq(organizationProfile.id, ORGANIZATION_PROFILE_ROW_ID))
.run();
export async function recordGithubWebhookReceiptMutation(
c: any,
input: {
organizationId: string;
event: string;
action?: string | null;
receivedAt?: number;
},
): Promise<void> {
assertOrganization(c, input.organizationId);
async getOrganizationSummary(c: any, input: OrganizationUseInput): Promise<OrganizationSummarySnapshot> {
assertOrganization(c, input.organizationId);
return await getOrganizationSummarySnapshot(c);
},
const profile = await c.db
.select({ id: organizationProfile.id })
.from(organizationProfile)
.where(eq(organizationProfile.id, ORGANIZATION_PROFILE_ROW_ID))
.get();
if (!profile) {
return;
}
async adminReconcileWorkspaceState(c: any, input: OrganizationUseInput): Promise<OrganizationSummarySnapshot> {
assertOrganization(c, input.organizationId);
return await reconcileWorkspaceProjection(c);
},
async createWorkspaceTask(c: any, input: TaskWorkspaceCreateTaskInput): Promise<{ taskId: string; sessionId?: string }> {
// Step 1: Create the task record (wait: true — local state mutations only).
const created = await organizationActions.createTask(c, {
organizationId: c.state.organizationId,
repoId: input.repoId,
task: input.task,
...(input.title ? { explicitTitle: input.title } : {}),
...(input.onBranch ? { onBranch: input.onBranch } : input.branch ? { explicitBranchName: input.branch } : {}),
...(input.model ? { agentType: agentTypeForModel(input.model) } : {}),
});
// Step 2: Enqueue session creation + initial message (wait: false).
// The task workflow creates the session record and sends the message in
// the background. The client observes progress via push events on the
// task subscription topic.
const task = await requireWorkspaceTask(c, input.repoId, created.taskId);
await task.createWorkspaceSessionAndSend({
model: input.model,
text: input.task,
});
return { taskId: created.taskId };
},
async markWorkspaceUnread(c: any, input: TaskWorkspaceSelectInput): Promise<void> {
const task = await requireWorkspaceTask(c, input.repoId, input.taskId);
await task.markWorkspaceUnread({});
},
async renameWorkspaceTask(c: any, input: TaskWorkspaceRenameInput): Promise<void> {
const task = await requireWorkspaceTask(c, input.repoId, input.taskId);
await task.renameWorkspaceTask(input);
},
async createWorkspaceSession(c: any, input: TaskWorkspaceSelectInput & { model?: string }): Promise<{ sessionId: string }> {
const task = await requireWorkspaceTask(c, input.repoId, input.taskId);
return await task.createWorkspaceSession({ ...(input.model ? { model: input.model } : {}) });
},
async renameWorkspaceSession(c: any, input: TaskWorkspaceRenameSessionInput): Promise<void> {
const task = await requireWorkspaceTask(c, input.repoId, input.taskId);
await task.renameWorkspaceSession(input);
},
async setWorkspaceSessionUnread(c: any, input: TaskWorkspaceSetSessionUnreadInput): Promise<void> {
const task = await requireWorkspaceTask(c, input.repoId, input.taskId);
await task.setWorkspaceSessionUnread(input);
},
async updateWorkspaceDraft(c: any, input: TaskWorkspaceUpdateDraftInput): Promise<void> {
const task = await requireWorkspaceTask(c, input.repoId, input.taskId);
await task.updateWorkspaceDraft(input);
},
async changeWorkspaceModel(c: any, input: TaskWorkspaceChangeModelInput): Promise<void> {
const task = await requireWorkspaceTask(c, input.repoId, input.taskId);
await task.changeWorkspaceModel(input);
},
async sendWorkspaceMessage(c: any, input: TaskWorkspaceSendMessageInput): Promise<void> {
const task = await requireWorkspaceTask(c, input.repoId, input.taskId);
await task.sendWorkspaceMessage(input);
},
async stopWorkspaceSession(c: any, input: TaskWorkspaceSessionInput): Promise<void> {
const task = await requireWorkspaceTask(c, input.repoId, input.taskId);
await task.stopWorkspaceSession(input);
},
async closeWorkspaceSession(c: any, input: TaskWorkspaceSessionInput): Promise<void> {
const task = await requireWorkspaceTask(c, input.repoId, input.taskId);
await task.closeWorkspaceSession(input);
},
async publishWorkspacePr(c: any, input: TaskWorkspaceSelectInput): Promise<void> {
const task = await requireWorkspaceTask(c, input.repoId, input.taskId);
await task.publishWorkspacePr({});
},
async revertWorkspaceFile(c: any, input: TaskWorkspaceDiffInput): Promise<void> {
const task = await requireWorkspaceTask(c, input.repoId, input.taskId);
await task.revertWorkspaceFile(input);
},
async adminReloadGithubOrganization(c: any): Promise<void> {
await getOrCreateGithubData(c, c.state.organizationId).adminReloadOrganization({});
},
async adminReloadGithubPullRequests(c: any): Promise<void> {
await getOrCreateGithubData(c, c.state.organizationId).adminReloadAllPullRequests({});
},
async adminReloadGithubRepository(c: any, input: { repoId: string }): Promise<void> {
await getOrCreateGithubData(c, c.state.organizationId).reloadRepository(input);
},
async adminReloadGithubPullRequest(c: any, input: { repoId: string; prNumber: number }): Promise<void> {
await getOrCreateGithubData(c, c.state.organizationId).reloadPullRequest(input);
},
async listTasks(c: any, input: ListTasksInput): Promise<TaskSummary[]> {
assertOrganization(c, input.organizationId);
if (input.repoId) {
const repoRow = await c.db.select({ remoteUrl: repos.remoteUrl }).from(repos).where(eq(repos.repoId, input.repoId)).get();
if (!repoRow) {
throw new Error(`Unknown repo: ${input.repoId}`);
}
const repository = await getOrCreateRepository(c, c.state.organizationId, input.repoId, repoRow.remoteUrl);
return await repository.listTaskSummaries({ includeArchived: true });
}
return await collectAllTaskSummaries(c);
},
async getRepoOverview(c: any, input: RepoOverviewInput): Promise<RepoOverview> {
assertOrganization(c, input.organizationId);
const repoRow = await c.db.select({ remoteUrl: repos.remoteUrl }).from(repos).where(eq(repos.repoId, input.repoId)).get();
if (!repoRow) {
throw new Error(`Unknown repo: ${input.repoId}`);
}
const repository = await getOrCreateRepository(c, c.state.organizationId, input.repoId, repoRow.remoteUrl);
return await repository.getRepoOverview({});
},
async switchTask(c: any, input: { repoId?: string; taskId: string }): Promise<SwitchResult> {
const { repoId } = await resolveRepositoryForTask(c, input.taskId, input.repoId);
const h = getTaskHandle(c, c.state.organizationId, repoId, input.taskId);
const record = await h.get();
const switched = await h.switch();
return {
organizationId: c.state.organizationId,
taskId: input.taskId,
sandboxProviderId: record.sandboxProviderId,
switchTarget: switched.switchTarget,
};
},
async auditLog(c: any, input: HistoryQueryInput): Promise<AuditLogEvent[]> {
assertOrganization(c, input.organizationId);
const limit = input.limit ?? 20;
const repoRows = await c.db.select({ repoId: repos.repoId }).from(repos).all();
const allEvents: AuditLogEvent[] = [];
for (const row of repoRows) {
try {
const auditLog = await getOrCreateAuditLog(c, c.state.organizationId, row.repoId);
const items = await auditLog.list({
branch: input.branch,
taskId: input.taskId,
limit,
});
allEvents.push(...items);
} catch (error) {
logActorWarning("organization", "audit log lookup failed for repo", {
organizationId: c.state.organizationId,
repoId: row.repoId,
error: resolveErrorMessage(error),
});
}
}
allEvents.sort((a, b) => b.createdAt - a.createdAt);
return allEvents.slice(0, limit);
},
async getTask(c: any, input: GetTaskInput): Promise<TaskRecord> {
assertOrganization(c, input.organizationId);
const { repoId } = await resolveRepositoryForTask(c, input.taskId, input.repoId);
return await getTaskHandle(c, c.state.organizationId, repoId, input.taskId).get();
},
async attachTask(c: any, input: TaskProxyActionInput): Promise<{ target: string; sessionId: string | null }> {
assertOrganization(c, input.organizationId);
const { repoId } = await resolveRepositoryForTask(c, input.taskId, input.repoId);
const h = getTaskHandle(c, c.state.organizationId, repoId, input.taskId);
return await h.attach({ reason: input.reason });
},
async pushTask(c: any, input: TaskProxyActionInput): Promise<void> {
assertOrganization(c, input.organizationId);
const { repoId } = await resolveRepositoryForTask(c, input.taskId, input.repoId);
const h = getTaskHandle(c, c.state.organizationId, repoId, input.taskId);
await h.push({ reason: input.reason });
},
async syncTask(c: any, input: TaskProxyActionInput): Promise<void> {
assertOrganization(c, input.organizationId);
const { repoId } = await resolveRepositoryForTask(c, input.taskId, input.repoId);
const h = getTaskHandle(c, c.state.organizationId, repoId, input.taskId);
await h.sync({ reason: input.reason });
},
async mergeTask(c: any, input: TaskProxyActionInput): Promise<void> {
assertOrganization(c, input.organizationId);
const { repoId } = await resolveRepositoryForTask(c, input.taskId, input.repoId);
const h = getTaskHandle(c, c.state.organizationId, repoId, input.taskId);
await h.merge({ reason: input.reason });
},
async archiveTask(c: any, input: TaskProxyActionInput): Promise<void> {
assertOrganization(c, input.organizationId);
const { repoId } = await resolveRepositoryForTask(c, input.taskId, input.repoId);
const h = getTaskHandle(c, c.state.organizationId, repoId, input.taskId);
await h.archive({ reason: input.reason });
},
async killTask(c: any, input: TaskProxyActionInput): Promise<void> {
assertOrganization(c, input.organizationId);
const { repoId } = await resolveRepositoryForTask(c, input.taskId, input.repoId);
const h = getTaskHandle(c, c.state.organizationId, repoId, input.taskId);
await h.kill({ reason: input.reason });
},
};
await c.db
.update(organizationProfile)
.set({
githubLastWebhookAt: input.receivedAt ?? Date.now(),
githubLastWebhookEvent: input.action ? `${input.event}.${input.action}` : input.event,
})
.where(eq(organizationProfile.id, ORGANIZATION_PROFILE_ROW_ID))
.run();
}

View file

@ -0,0 +1 @@
export { organizationAppActions } from "../app-shell.js";

View file

@ -0,0 +1,323 @@
import {
and,
asc,
count as sqlCount,
desc,
eq,
gt,
gte,
inArray,
isNotNull,
isNull,
like,
lt,
lte,
ne,
notInArray,
or,
} from "drizzle-orm";
import { authAccountIndex, authEmailIndex, authSessionIndex, authVerification } from "../db/schema.js";
import { APP_SHELL_ORGANIZATION_ID } from "../constants.js";
function assertAppOrganization(c: any): void {
if (c.state.organizationId !== APP_SHELL_ORGANIZATION_ID) {
throw new Error(`App shell action requires organization ${APP_SHELL_ORGANIZATION_ID}, got ${c.state.organizationId}`);
}
}
function organizationAuthColumn(table: any, field: string): any {
const column = table[field];
if (!column) {
throw new Error(`Unknown auth table field: ${field}`);
}
return column;
}
function normalizeAuthValue(value: unknown): unknown {
if (value instanceof Date) {
return value.getTime();
}
if (Array.isArray(value)) {
return value.map((entry) => normalizeAuthValue(entry));
}
return value;
}
function organizationAuthClause(table: any, clause: { field: string; value: unknown; operator?: string }): any {
const column = organizationAuthColumn(table, clause.field);
const value = normalizeAuthValue(clause.value);
switch (clause.operator) {
case "ne":
return value === null ? isNotNull(column) : ne(column, value as any);
case "lt":
return lt(column, value as any);
case "lte":
return lte(column, value as any);
case "gt":
return gt(column, value as any);
case "gte":
return gte(column, value as any);
case "in":
return inArray(column, Array.isArray(value) ? (value as any[]) : [value as any]);
case "not_in":
return notInArray(column, Array.isArray(value) ? (value as any[]) : [value as any]);
case "contains":
return like(column, `%${String(value ?? "")}%`);
case "starts_with":
return like(column, `${String(value ?? "")}%`);
case "ends_with":
return like(column, `%${String(value ?? "")}`);
case "eq":
default:
return value === null ? isNull(column) : eq(column, value as any);
}
}
function organizationBetterAuthWhere(table: any, clauses: any[] | undefined): any {
if (!clauses || clauses.length === 0) {
return undefined;
}
let expr = organizationAuthClause(table, clauses[0]);
for (const clause of clauses.slice(1)) {
const next = organizationAuthClause(table, clause);
expr = clause.connector === "OR" ? or(expr, next) : and(expr, next);
}
return expr;
}
export async function betterAuthUpsertSessionIndexMutation(c: any, input: { sessionId: string; sessionToken: string; userId: string }) {
assertAppOrganization(c);
const now = Date.now();
await c.db
.insert(authSessionIndex)
.values({
sessionId: input.sessionId,
sessionToken: input.sessionToken,
userId: input.userId,
createdAt: now,
updatedAt: now,
})
.onConflictDoUpdate({
target: authSessionIndex.sessionId,
set: {
sessionToken: input.sessionToken,
userId: input.userId,
updatedAt: now,
},
})
.run();
return await c.db.select().from(authSessionIndex).where(eq(authSessionIndex.sessionId, input.sessionId)).get();
}
export async function betterAuthDeleteSessionIndexMutation(c: any, input: { sessionId?: string; sessionToken?: string }) {
assertAppOrganization(c);
const clauses = [
...(input.sessionId ? [{ field: "sessionId", value: input.sessionId }] : []),
...(input.sessionToken ? [{ field: "sessionToken", value: input.sessionToken }] : []),
];
if (clauses.length === 0) {
return;
}
const predicate = organizationBetterAuthWhere(authSessionIndex, clauses);
await c.db.delete(authSessionIndex).where(predicate!).run();
}
export async function betterAuthUpsertEmailIndexMutation(c: any, input: { email: string; userId: string }) {
assertAppOrganization(c);
const now = Date.now();
await c.db
.insert(authEmailIndex)
.values({
email: input.email,
userId: input.userId,
updatedAt: now,
})
.onConflictDoUpdate({
target: authEmailIndex.email,
set: {
userId: input.userId,
updatedAt: now,
},
})
.run();
return await c.db.select().from(authEmailIndex).where(eq(authEmailIndex.email, input.email)).get();
}
export async function betterAuthDeleteEmailIndexMutation(c: any, input: { email: string }) {
assertAppOrganization(c);
await c.db.delete(authEmailIndex).where(eq(authEmailIndex.email, input.email)).run();
}
export async function betterAuthUpsertAccountIndexMutation(
c: any,
input: { id: string; providerId: string; accountId: string; userId: string },
) {
assertAppOrganization(c);
const now = Date.now();
await c.db
.insert(authAccountIndex)
.values({
id: input.id,
providerId: input.providerId,
accountId: input.accountId,
userId: input.userId,
updatedAt: now,
})
.onConflictDoUpdate({
target: authAccountIndex.id,
set: {
providerId: input.providerId,
accountId: input.accountId,
userId: input.userId,
updatedAt: now,
},
})
.run();
return await c.db.select().from(authAccountIndex).where(eq(authAccountIndex.id, input.id)).get();
}
export async function betterAuthDeleteAccountIndexMutation(c: any, input: { id?: string; providerId?: string; accountId?: string }) {
assertAppOrganization(c);
if (input.id) {
await c.db.delete(authAccountIndex).where(eq(authAccountIndex.id, input.id)).run();
return;
}
if (input.providerId && input.accountId) {
await c.db
.delete(authAccountIndex)
.where(and(eq(authAccountIndex.providerId, input.providerId), eq(authAccountIndex.accountId, input.accountId)))
.run();
}
}
export async function betterAuthCreateVerificationMutation(c: any, input: { data: Record<string, unknown> }) {
assertAppOrganization(c);
await c.db.insert(authVerification).values(input.data as any).run();
return await c.db.select().from(authVerification).where(eq(authVerification.id, input.data.id as string)).get();
}
export async function betterAuthUpdateVerificationMutation(c: any, input: { where: any[]; update: Record<string, unknown> }) {
assertAppOrganization(c);
const predicate = organizationBetterAuthWhere(authVerification, input.where);
if (!predicate) {
return null;
}
await c.db.update(authVerification).set(input.update as any).where(predicate).run();
return await c.db.select().from(authVerification).where(predicate).get();
}
export async function betterAuthUpdateManyVerificationMutation(c: any, input: { where: any[]; update: Record<string, unknown> }) {
assertAppOrganization(c);
const predicate = organizationBetterAuthWhere(authVerification, input.where);
if (!predicate) {
return 0;
}
await c.db.update(authVerification).set(input.update as any).where(predicate).run();
const row = await c.db.select({ value: sqlCount() }).from(authVerification).where(predicate).get();
return row?.value ?? 0;
}
export async function betterAuthDeleteVerificationMutation(c: any, input: { where: any[] }) {
assertAppOrganization(c);
const predicate = organizationBetterAuthWhere(authVerification, input.where);
if (!predicate) {
return;
}
await c.db.delete(authVerification).where(predicate).run();
}
export async function betterAuthDeleteManyVerificationMutation(c: any, input: { where: any[] }) {
assertAppOrganization(c);
const predicate = organizationBetterAuthWhere(authVerification, input.where);
if (!predicate) {
return 0;
}
const rows = await c.db.select().from(authVerification).where(predicate).all();
await c.db.delete(authVerification).where(predicate).run();
return rows.length;
}
export const organizationBetterAuthActions = {
async betterAuthFindSessionIndex(c: any, input: { sessionId?: string; sessionToken?: string }) {
assertAppOrganization(c);
const clauses = [
...(input.sessionId ? [{ field: "sessionId", value: input.sessionId }] : []),
...(input.sessionToken ? [{ field: "sessionToken", value: input.sessionToken }] : []),
];
if (clauses.length === 0) {
return null;
}
const predicate = organizationBetterAuthWhere(authSessionIndex, clauses);
return await c.db.select().from(authSessionIndex).where(predicate!).get();
},
async betterAuthFindEmailIndex(c: any, input: { email: string }) {
assertAppOrganization(c);
return await c.db.select().from(authEmailIndex).where(eq(authEmailIndex.email, input.email)).get();
},
async betterAuthFindAccountIndex(c: any, input: { id?: string; providerId?: string; accountId?: string }) {
assertAppOrganization(c);
if (input.id) {
return await c.db.select().from(authAccountIndex).where(eq(authAccountIndex.id, input.id)).get();
}
if (!input.providerId || !input.accountId) {
return null;
}
return await c.db
.select()
.from(authAccountIndex)
.where(and(eq(authAccountIndex.providerId, input.providerId), eq(authAccountIndex.accountId, input.accountId)))
.get();
},
async betterAuthFindOneVerification(c: any, input: { where: any[] }) {
assertAppOrganization(c);
const predicate = organizationBetterAuthWhere(authVerification, input.where);
return predicate ? await c.db.select().from(authVerification).where(predicate).get() : null;
},
async betterAuthFindManyVerification(c: any, input: { where?: any[]; limit?: number; sortBy?: any; offset?: number }) {
assertAppOrganization(c);
const predicate = organizationBetterAuthWhere(authVerification, input.where);
let query = c.db.select().from(authVerification);
if (predicate) {
query = query.where(predicate);
}
if (input.sortBy?.field) {
const column = organizationAuthColumn(authVerification, input.sortBy.field);
query = query.orderBy(input.sortBy.direction === "asc" ? asc(column) : desc(column));
}
if (typeof input.limit === "number") {
query = query.limit(input.limit);
}
if (typeof input.offset === "number") {
query = query.offset(input.offset);
}
return await query.all();
},
async betterAuthCountVerification(c: any, input: { where?: any[] }) {
assertAppOrganization(c);
const predicate = organizationBetterAuthWhere(authVerification, input.where);
const row = predicate
? await c.db.select({ value: sqlCount() }).from(authVerification).where(predicate).get()
: await c.db.select({ value: sqlCount() }).from(authVerification).get();
return row?.value ?? 0;
},
};

View file

@ -0,0 +1,91 @@
import { desc } from "drizzle-orm";
import type { FoundryAppSnapshot } from "@sandbox-agent/foundry-shared";
import { getOrCreateGithubData, getOrCreateOrganization } from "../../handles.js";
import { authSessionIndex } from "../db/schema.js";
import { githubDataWorkflowQueueName } from "../../github-data/workflow.js";
import {
assertAppOrganization,
buildAppSnapshot,
requireEligibleOrganization,
requireSignedInSession,
} from "../app-shell.js";
import { getBetterAuthService } from "../../../services/better-auth.js";
import { expectQueueResponse } from "../../../services/queue.js";
import { organizationWorkflowQueueName } from "../queues.js";
export const organizationGithubActions = {
async resolveAppGithubToken(
c: any,
input: { organizationId: string; requireRepoScope?: boolean },
): Promise<{ accessToken: string; scopes: string[] } | null> {
assertAppOrganization(c);
const auth = getBetterAuthService();
const rows = await c.db.select().from(authSessionIndex).orderBy(desc(authSessionIndex.updatedAt)).all();
for (const row of rows) {
const authState = await auth.getAuthState(row.sessionId);
if (authState?.sessionState?.activeOrganizationId !== input.organizationId) {
continue;
}
const token = await auth.getAccessTokenForSession(row.sessionId);
if (!token?.accessToken) {
continue;
}
const scopes = token.scopes;
if (input.requireRepoScope !== false && scopes.length > 0 && !scopes.some((scope) => scope === "repo" || scope.startsWith("repo:"))) {
continue;
}
return {
accessToken: token.accessToken,
scopes,
};
}
return null;
},
async triggerAppRepoImport(c: any, input: { sessionId: string; organizationId: string }): Promise<FoundryAppSnapshot> {
assertAppOrganization(c);
const session = await requireSignedInSession(c, input.sessionId);
requireEligibleOrganization(session, input.organizationId);
const githubData = await getOrCreateGithubData(c, input.organizationId);
const summary = await githubData.getSummary({});
if (summary.syncStatus === "syncing") {
return await buildAppSnapshot(c, input.sessionId);
}
const organizationHandle = await getOrCreateOrganization(c, input.organizationId);
await expectQueueResponse<{ ok: true }>(
await organizationHandle.send(
organizationWorkflowQueueName("organization.command.shell.sync_started.mark"),
{ label: "Importing repository catalog..." },
{ wait: true, timeout: 10_000 },
),
);
await expectQueueResponse<{ ok: true }>(
await organizationHandle.send(organizationWorkflowQueueName("organization.command.snapshot.broadcast"), {}, { wait: true, timeout: 10_000 }),
);
await githubData.send("githubData.command.syncRepos", { label: "Importing repository catalog..." }, { wait: false });
return await buildAppSnapshot(c, input.sessionId);
},
async adminReloadGithubOrganization(c: any): Promise<void> {
const githubData = await getOrCreateGithubData(c, c.state.organizationId);
await expectQueueResponse<{ ok: true }>(
await githubData.send(githubDataWorkflowQueueName("githubData.command.syncRepos"), { label: "Reloading GitHub organization..." }, { wait: true, timeout: 10_000 }),
);
},
async adminReloadGithubRepository(c: any, input: { repoId: string }): Promise<void> {
const githubData = await getOrCreateGithubData(c, c.state.organizationId);
await expectQueueResponse<unknown>(
await githubData.send(githubDataWorkflowQueueName("githubData.command.reloadRepository"), input, { wait: true, timeout: 10_000 }),
);
},
};

View file

@ -0,0 +1,82 @@
import { randomUUID } from "node:crypto";
import type { FoundryAppSnapshot, StarSandboxAgentRepoInput, StarSandboxAgentRepoResult } from "@sandbox-agent/foundry-shared";
import { getOrCreateGithubData, getOrCreateOrganization } from "../../handles.js";
import {
assertAppOrganization,
buildAppSnapshot,
getOrganizationState,
requireEligibleOrganization,
requireSignedInSession,
} from "../app-shell.js";
import { getBetterAuthService } from "../../../services/better-auth.js";
import { getActorRuntimeContext } from "../../context.js";
import { resolveOrganizationGithubAuth } from "../../../services/github-auth.js";
const SANDBOX_AGENT_REPO = "rivet-dev/sandbox-agent";
export const organizationOnboardingActions = {
async skipAppStarterRepo(c: any, input: { sessionId: string }): Promise<FoundryAppSnapshot> {
assertAppOrganization(c);
const session = await requireSignedInSession(c, input.sessionId);
await getBetterAuthService().upsertUserProfile(session.authUserId, {
starterRepoStatus: "skipped",
starterRepoSkippedAt: Date.now(),
starterRepoStarredAt: null,
});
return await buildAppSnapshot(c, input.sessionId);
},
async starAppStarterRepo(c: any, input: { sessionId: string; organizationId: string }): Promise<FoundryAppSnapshot> {
assertAppOrganization(c);
const session = await requireSignedInSession(c, input.sessionId);
requireEligibleOrganization(session, input.organizationId);
const organization = await getOrCreateOrganization(c, input.organizationId);
await organization.starSandboxAgentRepo({
organizationId: input.organizationId,
});
await getBetterAuthService().upsertUserProfile(session.authUserId, {
starterRepoStatus: "starred",
starterRepoStarredAt: Date.now(),
starterRepoSkippedAt: null,
});
return await buildAppSnapshot(c, input.sessionId);
},
async selectAppOrganization(c: any, input: { sessionId: string; organizationId: string }): Promise<FoundryAppSnapshot> {
assertAppOrganization(c);
const session = await requireSignedInSession(c, input.sessionId);
requireEligibleOrganization(session, input.organizationId);
await getBetterAuthService().setActiveOrganization(input.sessionId, input.organizationId);
await getOrCreateGithubData(c, input.organizationId);
return await buildAppSnapshot(c, input.sessionId);
},
async beginAppGithubInstall(c: any, input: { sessionId: string; organizationId: string }): Promise<{ url: string }> {
assertAppOrganization(c);
const session = await requireSignedInSession(c, input.sessionId);
requireEligibleOrganization(session, input.organizationId);
const { appShell } = getActorRuntimeContext();
const organizationHandle = await getOrCreateOrganization(c, input.organizationId);
const organizationState = await getOrganizationState(organizationHandle);
if (organizationState.snapshot.kind !== "organization") {
return {
url: `${appShell.appUrl}/organizations/${input.organizationId}`,
};
}
return {
url: await appShell.github.buildInstallationUrl(organizationState.githubLogin, randomUUID()),
};
},
async starSandboxAgentRepo(c: any, input: StarSandboxAgentRepoInput): Promise<StarSandboxAgentRepoResult> {
const { driver } = getActorRuntimeContext();
const auth = await resolveOrganizationGithubAuth(c, c.state.organizationId);
await driver.github.starRepository(SANDBOX_AGENT_REPO, {
githubToken: auth?.githubToken ?? null,
});
return {
repo: SANDBOX_AGENT_REPO,
starredAt: Date.now(),
};
},
};

View file

@ -0,0 +1,61 @@
import type { FoundryAppSnapshot, UpdateFoundryOrganizationProfileInput, WorkspaceModelId } from "@sandbox-agent/foundry-shared";
import { getBetterAuthService } from "../../../services/better-auth.js";
import { getOrCreateOrganization } from "../../handles.js";
import { expectQueueResponse } from "../../../services/queue.js";
import {
assertAppOrganization,
assertOrganizationShell,
buildAppSnapshot,
buildOrganizationState,
buildOrganizationStateIfInitialized,
requireEligibleOrganization,
requireSignedInSession,
} from "../app-shell.js";
import { organizationWorkflowQueueName } from "../queues.js";
export const organizationShellActions = {
async getAppSnapshot(c: any, input: { sessionId: string }): Promise<FoundryAppSnapshot> {
return await buildAppSnapshot(c, input.sessionId);
},
async setAppDefaultModel(c: any, input: { sessionId: string; defaultModel: WorkspaceModelId }): Promise<FoundryAppSnapshot> {
assertAppOrganization(c);
const session = await requireSignedInSession(c, input.sessionId);
await getBetterAuthService().upsertUserProfile(session.authUserId, {
defaultModel: input.defaultModel,
});
return await buildAppSnapshot(c, input.sessionId);
},
async updateAppOrganizationProfile(
c: any,
input: { sessionId: string; organizationId: string } & UpdateFoundryOrganizationProfileInput,
): Promise<FoundryAppSnapshot> {
assertAppOrganization(c);
const session = await requireSignedInSession(c, input.sessionId);
requireEligibleOrganization(session, input.organizationId);
const organization = await getOrCreateOrganization(c, input.organizationId);
await expectQueueResponse<{ ok: true }>(
await organization.send(
organizationWorkflowQueueName("organization.command.shell.profile.update"),
{
displayName: input.displayName,
slug: input.slug,
primaryDomain: input.primaryDomain,
},
{ wait: true, timeout: 10_000 },
),
);
return await buildAppSnapshot(c, input.sessionId);
},
async getOrganizationShellState(c: any): Promise<any> {
assertOrganizationShell(c);
return await buildOrganizationState(c);
},
async getOrganizationShellStateIfInitialized(c: any): Promise<any | null> {
assertOrganizationShell(c);
return await buildOrganizationStateIfInitialized(c);
},
};

View file

@ -0,0 +1,387 @@
// @ts-nocheck
import { desc, eq } from "drizzle-orm";
import type {
AuditLogEvent,
CreateTaskInput,
HistoryQueryInput,
ListTasksInput,
RepoOverview,
SwitchResult,
TaskRecord,
TaskSummary,
TaskWorkspaceChangeModelInput,
TaskWorkspaceCreateTaskInput,
TaskWorkspaceDiffInput,
TaskWorkspaceRenameInput,
TaskWorkspaceRenameSessionInput,
TaskWorkspaceSelectInput,
TaskWorkspaceSetSessionUnreadInput,
TaskWorkspaceSendMessageInput,
TaskWorkspaceSessionInput,
TaskWorkspaceUpdateDraftInput,
} from "@sandbox-agent/foundry-shared";
import { getActorRuntimeContext } from "../../context.js";
import { getOrCreateAuditLog, getOrCreateRepository, getTask as getTaskHandle, selfOrganization } from "../../handles.js";
import { defaultSandboxProviderId } from "../../../sandbox-config.js";
import { expectQueueResponse } from "../../../services/queue.js";
import { logActorWarning, resolveErrorMessage } from "../../logging.js";
import { repositoryWorkflowQueueName } from "../../repository/workflow.js";
import { taskWorkflowQueueName } from "../../task/workflow/index.js";
import { repos } from "../db/schema.js";
import { organizationWorkflowQueueName } from "../queues.js";
function assertOrganization(c: { state: { organizationId: string } }, organizationId: string): void {
if (organizationId !== c.state.organizationId) {
throw new Error(`Organization actor mismatch: actor=${c.state.organizationId} command=${organizationId}`);
}
}
async function requireRepositoryForTask(c: any, repoId: string) {
const repoRow = await c.db.select({ repoId: repos.repoId }).from(repos).where(eq(repos.repoId, repoId)).get();
if (!repoRow) {
throw new Error(`Unknown repo: ${repoId}`);
}
return await getOrCreateRepository(c, c.state.organizationId, repoId);
}
async function requireWorkspaceTask(c: any, repoId: string, taskId: string) {
return getTaskHandle(c, c.state.organizationId, repoId, taskId);
}
async function collectAllTaskSummaries(c: any): Promise<TaskSummary[]> {
const repoRows = await c.db.select({ repoId: repos.repoId, remoteUrl: repos.remoteUrl }).from(repos).orderBy(desc(repos.updatedAt)).all();
const all: TaskSummary[] = [];
for (const row of repoRows) {
try {
const repository = await getOrCreateRepository(c, c.state.organizationId, row.repoId);
const snapshot = await repository.listTaskSummaries({ includeArchived: true });
all.push(...snapshot);
} catch (error) {
logActorWarning("organization", "failed collecting tasks for repo", {
organizationId: c.state.organizationId,
repoId: row.repoId,
error: resolveErrorMessage(error),
});
}
}
all.sort((a, b) => b.updatedAt - a.updatedAt);
return all;
}
interface GetTaskInput {
organizationId: string;
repoId: string;
taskId: string;
}
interface TaskProxyActionInput extends GetTaskInput {
reason?: string;
}
interface RepoOverviewInput {
organizationId: string;
repoId: string;
}
export async function createTaskMutation(c: any, input: CreateTaskInput): Promise<TaskRecord> {
assertOrganization(c, input.organizationId);
const { config } = getActorRuntimeContext();
const sandboxProviderId = input.sandboxProviderId ?? defaultSandboxProviderId(config);
await requireRepositoryForTask(c, input.repoId);
const repository = await getOrCreateRepository(c, c.state.organizationId, input.repoId);
return expectQueueResponse<TaskRecord>(
await repository.send(
repositoryWorkflowQueueName("repository.command.createTask"),
{
task: input.task,
sandboxProviderId,
explicitTitle: input.explicitTitle ?? null,
explicitBranchName: input.explicitBranchName ?? null,
onBranch: input.onBranch ?? null,
},
{
wait: true,
timeout: 10_000,
},
),
);
}
export const organizationTaskActions = {
async createTask(c: any, input: CreateTaskInput): Promise<TaskRecord> {
const self = selfOrganization(c);
return expectQueueResponse<TaskRecord>(
await self.send(organizationWorkflowQueueName("organization.command.createTask"), input, {
wait: true,
timeout: 10_000,
}),
);
},
async createWorkspaceTask(c: any, input: TaskWorkspaceCreateTaskInput): Promise<{ taskId: string; sessionId?: string }> {
const created = await organizationTaskActions.createTask(c, {
organizationId: c.state.organizationId,
repoId: input.repoId,
task: input.task,
...(input.title ? { explicitTitle: input.title } : {}),
...(input.onBranch ? { onBranch: input.onBranch } : input.branch ? { explicitBranchName: input.branch } : {}),
});
const task = await requireWorkspaceTask(c, input.repoId, created.taskId);
await task.send(
taskWorkflowQueueName("task.command.workspace.create_session_and_send"),
{
model: input.model,
text: input.task,
authSessionId: input.authSessionId,
},
{ wait: false },
);
return { taskId: created.taskId };
},
async markWorkspaceUnread(c: any, input: TaskWorkspaceSelectInput): Promise<void> {
const task = await requireWorkspaceTask(c, input.repoId, input.taskId);
await expectQueueResponse<{ ok: true }>(
await task.send(taskWorkflowQueueName("task.command.workspace.mark_unread"), { authSessionId: input.authSessionId }, { wait: true, timeout: 10_000 }),
);
},
async renameWorkspaceTask(c: any, input: TaskWorkspaceRenameInput): Promise<void> {
const task = await requireWorkspaceTask(c, input.repoId, input.taskId);
await expectQueueResponse<{ ok: true }>(
await task.send(taskWorkflowQueueName("task.command.workspace.rename_task"), { value: input.value }, { wait: true, timeout: 20_000 }),
);
},
async createWorkspaceSession(c: any, input: TaskWorkspaceSelectInput & { model?: string }): Promise<{ sessionId: string }> {
const task = await requireWorkspaceTask(c, input.repoId, input.taskId);
return await expectQueueResponse<{ sessionId: string }>(
await task.send(
taskWorkflowQueueName("task.command.workspace.create_session"),
{
...(input.model ? { model: input.model } : {}),
...(input.authSessionId ? { authSessionId: input.authSessionId } : {}),
},
{ wait: true, timeout: 10_000 },
),
);
},
async renameWorkspaceSession(c: any, input: TaskWorkspaceRenameSessionInput): Promise<void> {
const task = await requireWorkspaceTask(c, input.repoId, input.taskId);
await expectQueueResponse<{ ok: true }>(
await task.send(
taskWorkflowQueueName("task.command.workspace.rename_session"),
{ sessionId: input.sessionId, title: input.title, authSessionId: input.authSessionId },
{ wait: true, timeout: 10_000 },
),
);
},
async selectWorkspaceSession(c: any, input: TaskWorkspaceSessionInput): Promise<void> {
const task = await requireWorkspaceTask(c, input.repoId, input.taskId);
await expectQueueResponse<{ ok: true }>(
await task.send(
taskWorkflowQueueName("task.command.workspace.select_session"),
{ sessionId: input.sessionId, authSessionId: input.authSessionId },
{ wait: true, timeout: 10_000 },
),
);
},
async setWorkspaceSessionUnread(c: any, input: TaskWorkspaceSetSessionUnreadInput): Promise<void> {
const task = await requireWorkspaceTask(c, input.repoId, input.taskId);
await expectQueueResponse<{ ok: true }>(
await task.send(
taskWorkflowQueueName("task.command.workspace.set_session_unread"),
{ sessionId: input.sessionId, unread: input.unread, authSessionId: input.authSessionId },
{ wait: true, timeout: 10_000 },
),
);
},
async updateWorkspaceDraft(c: any, input: TaskWorkspaceUpdateDraftInput): Promise<void> {
const task = await requireWorkspaceTask(c, input.repoId, input.taskId);
await task.send(
taskWorkflowQueueName("task.command.workspace.update_draft"),
{
sessionId: input.sessionId,
text: input.text,
attachments: input.attachments,
authSessionId: input.authSessionId,
},
{ wait: false },
);
},
async changeWorkspaceModel(c: any, input: TaskWorkspaceChangeModelInput): Promise<void> {
const task = await requireWorkspaceTask(c, input.repoId, input.taskId);
await expectQueueResponse<{ ok: true }>(
await task.send(
taskWorkflowQueueName("task.command.workspace.change_model"),
{ sessionId: input.sessionId, model: input.model, authSessionId: input.authSessionId },
{ wait: true, timeout: 10_000 },
),
);
},
async sendWorkspaceMessage(c: any, input: TaskWorkspaceSendMessageInput): Promise<void> {
const task = await requireWorkspaceTask(c, input.repoId, input.taskId);
await task.send(
taskWorkflowQueueName("task.command.workspace.send_message"),
{
sessionId: input.sessionId,
text: input.text,
attachments: input.attachments,
authSessionId: input.authSessionId,
},
{ wait: false },
);
},
async stopWorkspaceSession(c: any, input: TaskWorkspaceSessionInput): Promise<void> {
const task = await requireWorkspaceTask(c, input.repoId, input.taskId);
await task.send(
taskWorkflowQueueName("task.command.workspace.stop_session"),
{ sessionId: input.sessionId, authSessionId: input.authSessionId },
{ wait: false },
);
},
async closeWorkspaceSession(c: any, input: TaskWorkspaceSessionInput): Promise<void> {
const task = await requireWorkspaceTask(c, input.repoId, input.taskId);
await task.send(
taskWorkflowQueueName("task.command.workspace.close_session"),
{ sessionId: input.sessionId, authSessionId: input.authSessionId },
{ wait: false },
);
},
async publishWorkspacePr(c: any, input: TaskWorkspaceSelectInput): Promise<void> {
const task = await requireWorkspaceTask(c, input.repoId, input.taskId);
await task.send(taskWorkflowQueueName("task.command.workspace.publish_pr"), {}, { wait: false });
},
async revertWorkspaceFile(c: any, input: TaskWorkspaceDiffInput): Promise<void> {
const task = await requireWorkspaceTask(c, input.repoId, input.taskId);
await task.send(taskWorkflowQueueName("task.command.workspace.revert_file"), input, { wait: false });
},
async getRepoOverview(c: any, input: RepoOverviewInput): Promise<RepoOverview> {
assertOrganization(c, input.organizationId);
const repository = await requireRepositoryForTask(c, input.repoId);
return await repository.getRepoOverview({});
},
async listTasks(c: any, input: ListTasksInput): Promise<TaskSummary[]> {
assertOrganization(c, input.organizationId);
if (input.repoId) {
const repository = await requireRepositoryForTask(c, input.repoId);
return await repository.listTaskSummaries({ includeArchived: true });
}
return await collectAllTaskSummaries(c);
},
async switchTask(c: any, input: { repoId: string; taskId: string }): Promise<SwitchResult> {
await requireRepositoryForTask(c, input.repoId);
const h = getTaskHandle(c, c.state.organizationId, input.repoId, input.taskId);
const record = await h.get();
const switched = await expectQueueResponse<{ switchTarget: string }>(
await h.send(taskWorkflowQueueName("task.command.switch"), {}, { wait: true, timeout: 10_000 }),
);
return {
organizationId: c.state.organizationId,
taskId: input.taskId,
sandboxProviderId: record.sandboxProviderId,
switchTarget: switched.switchTarget,
};
},
async auditLog(c: any, input: HistoryQueryInput): Promise<AuditLogEvent[]> {
assertOrganization(c, input.organizationId);
const limit = input.limit ?? 20;
const repoRows = await c.db.select({ repoId: repos.repoId }).from(repos).orderBy(desc(repos.updatedAt)).all();
const allEvents: AuditLogEvent[] = [];
for (const row of repoRows) {
try {
const auditLog = await getOrCreateAuditLog(c, c.state.organizationId, row.repoId);
const items = await auditLog.list({
branch: input.branch,
taskId: input.taskId,
limit,
});
allEvents.push(...items);
} catch (error) {
logActorWarning("organization", "audit log lookup failed for repo", {
organizationId: c.state.organizationId,
repoId: row.repoId,
error: resolveErrorMessage(error),
});
}
}
allEvents.sort((a, b) => b.createdAt - a.createdAt);
return allEvents.slice(0, limit);
},
async getTask(c: any, input: GetTaskInput): Promise<TaskRecord> {
assertOrganization(c, input.organizationId);
await requireRepositoryForTask(c, input.repoId);
return await getTaskHandle(c, c.state.organizationId, input.repoId, input.taskId).get();
},
async attachTask(c: any, input: TaskProxyActionInput): Promise<{ target: string; sessionId: string | null }> {
assertOrganization(c, input.organizationId);
await requireRepositoryForTask(c, input.repoId);
const h = getTaskHandle(c, c.state.organizationId, input.repoId, input.taskId);
return await expectQueueResponse<{ target: string; sessionId: string | null }>(
await h.send(taskWorkflowQueueName("task.command.attach"), { reason: input.reason }, { wait: true, timeout: 10_000 }),
);
},
async pushTask(c: any, input: TaskProxyActionInput): Promise<void> {
assertOrganization(c, input.organizationId);
await requireRepositoryForTask(c, input.repoId);
const h = getTaskHandle(c, c.state.organizationId, input.repoId, input.taskId);
await h.send(taskWorkflowQueueName("task.command.push"), { reason: input.reason }, { wait: false });
},
async syncTask(c: any, input: TaskProxyActionInput): Promise<void> {
assertOrganization(c, input.organizationId);
await requireRepositoryForTask(c, input.repoId);
const h = getTaskHandle(c, c.state.organizationId, input.repoId, input.taskId);
await h.send(taskWorkflowQueueName("task.command.sync"), { reason: input.reason }, { wait: false });
},
async mergeTask(c: any, input: TaskProxyActionInput): Promise<void> {
assertOrganization(c, input.organizationId);
await requireRepositoryForTask(c, input.repoId);
const h = getTaskHandle(c, c.state.organizationId, input.repoId, input.taskId);
await h.send(taskWorkflowQueueName("task.command.merge"), { reason: input.reason }, { wait: false });
},
async archiveTask(c: any, input: TaskProxyActionInput): Promise<void> {
assertOrganization(c, input.organizationId);
await requireRepositoryForTask(c, input.repoId);
const h = getTaskHandle(c, c.state.organizationId, input.repoId, input.taskId);
await h.send(taskWorkflowQueueName("task.command.archive"), { reason: input.reason }, { wait: false });
},
async killTask(c: any, input: TaskProxyActionInput): Promise<void> {
assertOrganization(c, input.organizationId);
await requireRepositoryForTask(c, input.repoId);
const h = getTaskHandle(c, c.state.organizationId, input.repoId, input.taskId);
await h.send(taskWorkflowQueueName("task.command.kill"), { reason: input.reason }, { wait: false });
},
};

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1 @@
export const APP_SHELL_ORGANIZATION_ID = "app";

View file

@ -56,6 +56,10 @@ CREATE TABLE `organization_profile` (
`github_last_sync_at` integer,
`github_last_webhook_at` integer,
`github_last_webhook_event` text,
`github_sync_generation` integer NOT NULL,
`github_sync_phase` text,
`github_processed_repository_count` integer NOT NULL,
`github_total_repository_count` integer NOT NULL,
`stripe_customer_id` text,
`stripe_subscription_id` text,
`stripe_price_id` text,
@ -86,8 +90,3 @@ CREATE TABLE `stripe_lookup` (
`organization_id` text NOT NULL,
`updated_at` integer NOT NULL
);
--> statement-breakpoint
CREATE TABLE `task_lookup` (
`task_id` text PRIMARY KEY NOT NULL,
`repo_id` text NOT NULL
);

View file

@ -373,6 +373,34 @@
"notNull": false,
"autoincrement": false
},
"github_sync_generation": {
"name": "github_sync_generation",
"type": "integer",
"primaryKey": false,
"notNull": true,
"autoincrement": false
},
"github_sync_phase": {
"name": "github_sync_phase",
"type": "text",
"primaryKey": false,
"notNull": false,
"autoincrement": false
},
"github_processed_repository_count": {
"name": "github_processed_repository_count",
"type": "integer",
"primaryKey": false,
"notNull": true,
"autoincrement": false
},
"github_total_repository_count": {
"name": "github_total_repository_count",
"type": "integer",
"primaryKey": false,
"notNull": true,
"autoincrement": false
},
"stripe_customer_id": {
"name": "stripe_customer_id",
"type": "text",
@ -549,30 +577,6 @@
"compositePrimaryKeys": {},
"uniqueConstraints": {},
"checkConstraints": {}
},
"task_lookup": {
"name": "task_lookup",
"columns": {
"task_id": {
"name": "task_id",
"type": "text",
"primaryKey": true,
"notNull": true,
"autoincrement": false
},
"repo_id": {
"name": "repo_id",
"type": "text",
"primaryKey": false,
"notNull": true,
"autoincrement": false
}
},
"indexes": {},
"foreignKeys": {},
"compositePrimaryKeys": {},
"uniqueConstraints": {},
"checkConstraints": {}
}
},
"views": {},

View file

@ -10,6 +10,12 @@ const journal = {
tag: "0000_melted_viper",
breakpoints: true,
},
{
idx: 1,
when: 1773907201000,
tag: "0001_github_sync_progress",
breakpoints: true,
},
],
} as const;
@ -104,6 +110,14 @@ CREATE TABLE \`stripe_lookup\` (
\`organization_id\` text NOT NULL,
\`updated_at\` integer NOT NULL
);
`,
m0001: `ALTER TABLE \`organization_profile\` ADD \`github_sync_generation\` integer NOT NULL DEFAULT 0;
--> statement-breakpoint
ALTER TABLE \`organization_profile\` ADD \`github_sync_phase\` text;
--> statement-breakpoint
ALTER TABLE \`organization_profile\` ADD \`github_processed_repository_count\` integer NOT NULL DEFAULT 0;
--> statement-breakpoint
ALTER TABLE \`organization_profile\` ADD \`github_total_repository_count\` integer NOT NULL DEFAULT 0;
`,
} as const,
};

View file

@ -36,6 +36,10 @@ export const organizationProfile = sqliteTable(
githubLastSyncAt: integer("github_last_sync_at"),
githubLastWebhookAt: integer("github_last_webhook_at"),
githubLastWebhookEvent: text("github_last_webhook_event"),
githubSyncGeneration: integer("github_sync_generation").notNull(),
githubSyncPhase: text("github_sync_phase"),
githubProcessedRepositoryCount: integer("github_processed_repository_count").notNull(),
githubTotalRepositoryCount: integer("github_total_repository_count").notNull(),
stripeCustomerId: text("stripe_customer_id"),
stripeSubscriptionId: text("stripe_subscription_id"),
stripePriceId: text("stripe_price_id"),

View file

@ -1,7 +1,9 @@
import { actor, queue } from "rivetkit";
import { workflow } from "rivetkit/workflow";
import { organizationDb } from "./db/db.js";
import { runOrganizationWorkflow, ORGANIZATION_QUEUE_NAMES, organizationActions } from "./actions.js";
import { organizationActions } from "./actions.js";
import { ORGANIZATION_QUEUE_NAMES } from "./queues.js";
import { runOrganizationWorkflow } from "./workflow.js";
export const organization = actor({
db: organizationDb,

View file

@ -0,0 +1,36 @@
export const ORGANIZATION_QUEUE_NAMES = [
"organization.command.createTask",
"organization.command.snapshot.broadcast",
"organization.command.syncGithubSession",
"organization.command.better_auth.session_index.upsert",
"organization.command.better_auth.session_index.delete",
"organization.command.better_auth.email_index.upsert",
"organization.command.better_auth.email_index.delete",
"organization.command.better_auth.account_index.upsert",
"organization.command.better_auth.account_index.delete",
"organization.command.better_auth.verification.create",
"organization.command.better_auth.verification.update",
"organization.command.better_auth.verification.update_many",
"organization.command.better_auth.verification.delete",
"organization.command.better_auth.verification.delete_many",
"organization.command.github.repository_projection.apply",
"organization.command.github.data_projection.apply",
"organization.command.github.sync_progress.apply",
"organization.command.github.webhook_receipt.record",
"organization.command.github.organization_shell.sync_from_github",
"organization.command.shell.profile.update",
"organization.command.shell.sync_started.mark",
"organization.command.billing.stripe_customer.apply",
"organization.command.billing.stripe_subscription.apply",
"organization.command.billing.free_plan.apply",
"organization.command.billing.payment_method.set",
"organization.command.billing.status.set",
"organization.command.billing.invoice.upsert",
"organization.command.billing.seat_usage.record",
] as const;
export type OrganizationQueueName = (typeof ORGANIZATION_QUEUE_NAMES)[number];
export function organizationWorkflowQueueName(name: OrganizationQueueName): OrganizationQueueName {
return name;
}

View file

@ -0,0 +1,349 @@
// @ts-nocheck
import { Loop } from "rivetkit/workflow";
import { logActorWarning, resolveErrorMessage } from "../logging.js";
import type { CreateTaskInput } from "@sandbox-agent/foundry-shared";
import {
applyGithubDataProjectionMutation,
applyGithubRepositoryProjectionMutation,
applyGithubSyncProgressMutation,
createTaskMutation,
recordGithubWebhookReceiptMutation,
refreshOrganizationSnapshotMutation,
} from "./actions.js";
import {
betterAuthCreateVerificationMutation,
betterAuthDeleteAccountIndexMutation,
betterAuthDeleteEmailIndexMutation,
betterAuthDeleteManyVerificationMutation,
betterAuthDeleteSessionIndexMutation,
betterAuthDeleteVerificationMutation,
betterAuthUpdateManyVerificationMutation,
betterAuthUpdateVerificationMutation,
betterAuthUpsertAccountIndexMutation,
betterAuthUpsertEmailIndexMutation,
betterAuthUpsertSessionIndexMutation,
} from "./actions/better-auth.js";
import {
applyOrganizationFreePlanMutation,
applyOrganizationStripeCustomerMutation,
applyOrganizationStripeSubscriptionMutation,
markOrganizationSyncStartedMutation,
recordOrganizationSeatUsageMutation,
setOrganizationBillingPaymentMethodMutation,
setOrganizationBillingStatusMutation,
syncOrganizationShellFromGithubMutation,
updateOrganizationShellProfileMutation,
upsertOrganizationInvoiceMutation,
} from "./app-shell.js";
import { ORGANIZATION_QUEUE_NAMES } from "./queues.js";
export async function runOrganizationWorkflow(ctx: any): Promise<void> {
await ctx.loop("organization-command-loop", async (loopCtx: any) => {
const msg = await loopCtx.queue.next("next-organization-command", {
names: [...ORGANIZATION_QUEUE_NAMES],
completable: true,
});
if (!msg) {
return Loop.continue(undefined);
}
try {
if (msg.name === "organization.command.createTask") {
const result = await loopCtx.step({
name: "organization-create-task",
timeout: 5 * 60_000,
run: async () => createTaskMutation(loopCtx, msg.body as CreateTaskInput),
});
await msg.complete(result);
return Loop.continue(undefined);
}
if (msg.name === "organization.command.snapshot.broadcast") {
await loopCtx.step({
name: "organization-snapshot-broadcast",
timeout: 60_000,
run: async () => refreshOrganizationSnapshotMutation(loopCtx),
});
await msg.complete({ ok: true });
return Loop.continue(undefined);
}
if (msg.name === "organization.command.syncGithubSession") {
await loopCtx.step({
name: "organization-sync-github-session",
timeout: 60_000,
run: async () => {
const { syncGithubOrganizations } = await import("./app-shell.js");
await syncGithubOrganizations(loopCtx, msg.body as { sessionId: string; accessToken: string });
},
});
await msg.complete({ ok: true });
return Loop.continue(undefined);
}
if (msg.name === "organization.command.better_auth.session_index.upsert") {
const result = await loopCtx.step({
name: "organization-better-auth-session-index-upsert",
timeout: 60_000,
run: async () => betterAuthUpsertSessionIndexMutation(loopCtx, msg.body),
});
await msg.complete(result);
return Loop.continue(undefined);
}
if (msg.name === "organization.command.better_auth.session_index.delete") {
await loopCtx.step({
name: "organization-better-auth-session-index-delete",
timeout: 60_000,
run: async () => betterAuthDeleteSessionIndexMutation(loopCtx, msg.body),
});
await msg.complete({ ok: true });
return Loop.continue(undefined);
}
if (msg.name === "organization.command.better_auth.email_index.upsert") {
const result = await loopCtx.step({
name: "organization-better-auth-email-index-upsert",
timeout: 60_000,
run: async () => betterAuthUpsertEmailIndexMutation(loopCtx, msg.body),
});
await msg.complete(result);
return Loop.continue(undefined);
}
if (msg.name === "organization.command.better_auth.email_index.delete") {
await loopCtx.step({
name: "organization-better-auth-email-index-delete",
timeout: 60_000,
run: async () => betterAuthDeleteEmailIndexMutation(loopCtx, msg.body),
});
await msg.complete({ ok: true });
return Loop.continue(undefined);
}
if (msg.name === "organization.command.better_auth.account_index.upsert") {
const result = await loopCtx.step({
name: "organization-better-auth-account-index-upsert",
timeout: 60_000,
run: async () => betterAuthUpsertAccountIndexMutation(loopCtx, msg.body),
});
await msg.complete(result);
return Loop.continue(undefined);
}
if (msg.name === "organization.command.better_auth.account_index.delete") {
await loopCtx.step({
name: "organization-better-auth-account-index-delete",
timeout: 60_000,
run: async () => betterAuthDeleteAccountIndexMutation(loopCtx, msg.body),
});
await msg.complete({ ok: true });
return Loop.continue(undefined);
}
if (msg.name === "organization.command.better_auth.verification.create") {
const result = await loopCtx.step({
name: "organization-better-auth-verification-create",
timeout: 60_000,
run: async () => betterAuthCreateVerificationMutation(loopCtx, msg.body),
});
await msg.complete(result);
return Loop.continue(undefined);
}
if (msg.name === "organization.command.better_auth.verification.update") {
const result = await loopCtx.step({
name: "organization-better-auth-verification-update",
timeout: 60_000,
run: async () => betterAuthUpdateVerificationMutation(loopCtx, msg.body),
});
await msg.complete(result);
return Loop.continue(undefined);
}
if (msg.name === "organization.command.better_auth.verification.update_many") {
const result = await loopCtx.step({
name: "organization-better-auth-verification-update-many",
timeout: 60_000,
run: async () => betterAuthUpdateManyVerificationMutation(loopCtx, msg.body),
});
await msg.complete(result);
return Loop.continue(undefined);
}
if (msg.name === "organization.command.better_auth.verification.delete") {
await loopCtx.step({
name: "organization-better-auth-verification-delete",
timeout: 60_000,
run: async () => betterAuthDeleteVerificationMutation(loopCtx, msg.body),
});
await msg.complete({ ok: true });
return Loop.continue(undefined);
}
if (msg.name === "organization.command.better_auth.verification.delete_many") {
const result = await loopCtx.step({
name: "organization-better-auth-verification-delete-many",
timeout: 60_000,
run: async () => betterAuthDeleteManyVerificationMutation(loopCtx, msg.body),
});
await msg.complete(result);
return Loop.continue(undefined);
}
if (msg.name === "organization.command.github.repository_projection.apply") {
await loopCtx.step({
name: "organization-github-repository-projection-apply",
timeout: 60_000,
run: async () => applyGithubRepositoryProjectionMutation(loopCtx, msg.body),
});
await msg.complete({ ok: true });
return Loop.continue(undefined);
}
if (msg.name === "organization.command.github.data_projection.apply") {
await loopCtx.step({
name: "organization-github-data-projection-apply",
timeout: 60_000,
run: async () => applyGithubDataProjectionMutation(loopCtx, msg.body),
});
await msg.complete({ ok: true });
return Loop.continue(undefined);
}
if (msg.name === "organization.command.github.sync_progress.apply") {
await loopCtx.step({
name: "organization-github-sync-progress-apply",
timeout: 60_000,
run: async () => applyGithubSyncProgressMutation(loopCtx, msg.body),
});
await msg.complete({ ok: true });
return Loop.continue(undefined);
}
if (msg.name === "organization.command.github.webhook_receipt.record") {
await loopCtx.step({
name: "organization-github-webhook-receipt-record",
timeout: 60_000,
run: async () => recordGithubWebhookReceiptMutation(loopCtx, msg.body),
});
await msg.complete({ ok: true });
return Loop.continue(undefined);
}
if (msg.name === "organization.command.github.organization_shell.sync_from_github") {
const result = await loopCtx.step({
name: "organization-github-organization-shell-sync-from-github",
timeout: 60_000,
run: async () => syncOrganizationShellFromGithubMutation(loopCtx, msg.body),
});
await msg.complete(result);
return Loop.continue(undefined);
}
if (msg.name === "organization.command.shell.profile.update") {
await loopCtx.step({
name: "organization-shell-profile-update",
timeout: 60_000,
run: async () => updateOrganizationShellProfileMutation(loopCtx, msg.body),
});
await msg.complete({ ok: true });
return Loop.continue(undefined);
}
if (msg.name === "organization.command.shell.sync_started.mark") {
await loopCtx.step({
name: "organization-shell-sync-started-mark",
timeout: 60_000,
run: async () => markOrganizationSyncStartedMutation(loopCtx, msg.body),
});
await msg.complete({ ok: true });
return Loop.continue(undefined);
}
if (msg.name === "organization.command.billing.stripe_customer.apply") {
await loopCtx.step({
name: "organization-billing-stripe-customer-apply",
timeout: 60_000,
run: async () => applyOrganizationStripeCustomerMutation(loopCtx, msg.body),
});
await msg.complete({ ok: true });
return Loop.continue(undefined);
}
if (msg.name === "organization.command.billing.stripe_subscription.apply") {
await loopCtx.step({
name: "organization-billing-stripe-subscription-apply",
timeout: 60_000,
run: async () => applyOrganizationStripeSubscriptionMutation(loopCtx, msg.body),
});
await msg.complete({ ok: true });
return Loop.continue(undefined);
}
if (msg.name === "organization.command.billing.free_plan.apply") {
await loopCtx.step({
name: "organization-billing-free-plan-apply",
timeout: 60_000,
run: async () => applyOrganizationFreePlanMutation(loopCtx, msg.body),
});
await msg.complete({ ok: true });
return Loop.continue(undefined);
}
if (msg.name === "organization.command.billing.payment_method.set") {
await loopCtx.step({
name: "organization-billing-payment-method-set",
timeout: 60_000,
run: async () => setOrganizationBillingPaymentMethodMutation(loopCtx, msg.body),
});
await msg.complete({ ok: true });
return Loop.continue(undefined);
}
if (msg.name === "organization.command.billing.status.set") {
await loopCtx.step({
name: "organization-billing-status-set",
timeout: 60_000,
run: async () => setOrganizationBillingStatusMutation(loopCtx, msg.body),
});
await msg.complete({ ok: true });
return Loop.continue(undefined);
}
if (msg.name === "organization.command.billing.invoice.upsert") {
await loopCtx.step({
name: "organization-billing-invoice-upsert",
timeout: 60_000,
run: async () => upsertOrganizationInvoiceMutation(loopCtx, msg.body),
});
await msg.complete({ ok: true });
return Loop.continue(undefined);
}
if (msg.name === "organization.command.billing.seat_usage.record") {
await loopCtx.step({
name: "organization-billing-seat-usage-record",
timeout: 60_000,
run: async () => recordOrganizationSeatUsageMutation(loopCtx, msg.body),
});
await msg.complete({ ok: true });
return Loop.continue(undefined);
}
} catch (error) {
const message = resolveErrorMessage(error);
logActorWarning("organization", "organization workflow command failed", {
queueName: msg.name,
error: message,
});
await msg.complete({ error: message }).catch((completeError: unknown) => {
logActorWarning("organization", "organization workflow failed completing error response", {
queueName: msg.name,
error: resolveErrorMessage(completeError),
});
});
}
return Loop.continue(undefined);
});
}

View file

@ -1,9 +1,7 @@
// @ts-nocheck
import { randomUUID } from "node:crypto";
import { and, desc, eq, isNotNull, ne } from "drizzle-orm";
import { Loop } from "rivetkit/workflow";
import type {
AgentType,
RepoOverview,
SandboxProviderId,
TaskRecord,
@ -12,19 +10,21 @@ import type {
WorkspaceSessionSummary,
WorkspaceTaskSummary,
} from "@sandbox-agent/foundry-shared";
import { getGithubData, getOrCreateAuditLog, getOrCreateOrganization, getOrCreateTask, getTask, selfRepository } from "../handles.js";
import { getActorRuntimeContext } from "../context.js";
import { getOrCreateAuditLog, getOrCreateOrganization, getOrCreateTask, getTask } from "../handles.js";
import { organizationWorkflowQueueName } from "../organization/queues.js";
import { taskWorkflowQueueName } from "../task/workflow/index.js";
import { deriveFallbackTitle, resolveCreateFlowDecision } from "../../services/create-flow.js";
import { expectQueueResponse } from "../../services/queue.js";
import { isActorNotFoundError, logActorWarning, resolveErrorMessage } from "../logging.js";
import { defaultSandboxProviderId } from "../../sandbox-config.js";
import { repoMeta, taskIndex, tasks } from "./db/schema.js";
interface CreateTaskCommand {
task: string;
sandboxProviderId: SandboxProviderId;
agentType: AgentType | null;
explicitTitle: string | null;
explicitBranchName: string | null;
initialPrompt: string | null;
onBranch: string | null;
}
@ -38,18 +38,8 @@ interface ListTaskSummariesCommand {
includeArchived?: boolean;
}
interface GetPullRequestForBranchCommand {
branchName: string;
}
const REPOSITORY_QUEUE_NAMES = ["repository.command.createTask", "repository.command.registerTaskBranch"] as const;
type RepositoryQueueName = (typeof REPOSITORY_QUEUE_NAMES)[number];
export { REPOSITORY_QUEUE_NAMES };
export function repositoryWorkflowQueueName(name: RepositoryQueueName): RepositoryQueueName {
return name;
interface GetProjectedTaskSummaryCommand {
taskId: string;
}
function isStaleTaskReferenceError(error: unknown): boolean {
@ -109,26 +99,14 @@ async function upsertTaskSummary(c: any, taskSummary: WorkspaceTaskSummary): Pro
async function notifyOrganizationSnapshotChanged(c: any): Promise<void> {
const organization = await getOrCreateOrganization(c, c.state.organizationId);
await organization.refreshOrganizationSnapshot({});
await expectQueueResponse<{ ok: true }>(
await organization.send(organizationWorkflowQueueName("organization.command.snapshot.broadcast"), {}, { wait: true, timeout: 10_000 }),
);
}
async function persistRemoteUrl(c: any, remoteUrl: string): Promise<void> {
c.state.remoteUrl = remoteUrl;
await c.db
.insert(repoMeta)
.values({
id: 1,
remoteUrl,
updatedAt: Date.now(),
})
.onConflictDoUpdate({
target: repoMeta.id,
set: {
remoteUrl,
updatedAt: Date.now(),
},
})
.run();
async function readStoredRemoteUrl(c: any): Promise<string | null> {
const row = await c.db.select({ remoteUrl: repoMeta.remoteUrl }).from(repoMeta).where(eq(repoMeta.id, 1)).get();
return row?.remoteUrl ?? null;
}
async function deleteStaleTaskIndexRow(c: any, taskId: string): Promise<void> {
@ -164,31 +142,6 @@ async function listKnownTaskBranches(c: any): Promise<string[]> {
return rows.map((row) => row.branchName).filter((value): value is string => typeof value === "string" && value.trim().length > 0);
}
function parseJsonValue<T>(value: string | null | undefined, fallback: T): T {
if (!value) {
return fallback;
}
try {
return JSON.parse(value) as T;
} catch {
return fallback;
}
}
function taskSummaryRowFromSummary(taskSummary: WorkspaceTaskSummary) {
return {
taskId: taskSummary.id,
title: taskSummary.title,
status: taskSummary.status,
repoName: taskSummary.repoName,
updatedAtMs: taskSummary.updatedAtMs,
branch: taskSummary.branch,
pullRequestJson: JSON.stringify(taskSummary.pullRequest),
sessionsSummaryJson: JSON.stringify(taskSummary.sessionsSummary),
};
}
async function resolveGitHubRepository(c: any) {
const githubData = getGithubData(c, c.state.organizationId);
return await githubData.getRepository({ repoId: c.state.repoId }).catch(() => null);
@ -199,17 +152,29 @@ async function listGitHubBranches(c: any): Promise<Array<{ branchName: string; c
return await githubData.listBranchesForRepository({ repoId: c.state.repoId }).catch(() => []);
}
async function createTaskMutation(c: any, cmd: CreateTaskCommand): Promise<TaskRecord> {
async function resolveRepositoryRemoteUrl(c: any): Promise<string> {
const storedRemoteUrl = await readStoredRemoteUrl(c);
if (storedRemoteUrl) {
return storedRemoteUrl;
}
const repository = await resolveGitHubRepository(c);
const remoteUrl = repository?.cloneUrl?.trim();
if (!remoteUrl) {
throw new Error(`Missing remote URL for repo ${c.state.repoId}`);
}
return remoteUrl;
}
export async function createTaskMutation(c: any, cmd: CreateTaskCommand): Promise<TaskRecord> {
const organizationId = c.state.organizationId;
const repoId = c.state.repoId;
const repoRemote = c.state.remoteUrl;
await resolveRepositoryRemoteUrl(c);
const onBranch = cmd.onBranch?.trim() || null;
const taskId = randomUUID();
let initialBranchName: string | null = null;
let initialTitle: string | null = null;
await persistRemoteUrl(c, repoRemote);
if (onBranch) {
initialBranchName = onBranch;
initialTitle = deriveFallbackTitle(cmd.task, cmd.explicitTitle ?? undefined);
@ -251,15 +216,6 @@ async function createTaskMutation(c: any, cmd: CreateTaskCommand): Promise<TaskR
organizationId,
repoId,
taskId,
repoRemote,
branchName: initialBranchName,
title: initialTitle,
task: cmd.task,
sandboxProviderId: cmd.sandboxProviderId,
agentType: cmd.agentType,
explicitTitle: null,
explicitBranchName: null,
initialPrompt: cmd.initialPrompt,
});
} catch (error) {
if (initialBranchName) {
@ -268,7 +224,21 @@ async function createTaskMutation(c: any, cmd: CreateTaskCommand): Promise<TaskR
throw error;
}
const created = await taskHandle.initialize({ sandboxProviderId: cmd.sandboxProviderId });
const created = await expectQueueResponse<TaskRecord>(
await taskHandle.send(
taskWorkflowQueueName("task.command.initialize"),
{
sandboxProviderId: cmd.sandboxProviderId,
branchName: initialBranchName,
title: initialTitle,
task: cmd.task,
},
{
wait: true,
timeout: 10_000,
},
),
);
try {
await upsertTaskSummary(c, await taskHandle.getTaskSummary({}));
@ -313,25 +283,12 @@ async function createTaskMutation(c: any, cmd: CreateTaskCommand): Promise<TaskR
return created;
}
async function upsertTaskSummary(c: any, taskSummary: WorkspaceTaskSummary): Promise<void> {
await c.db
.insert(tasks)
.values(taskSummaryRowFromSummary(taskSummary))
.onConflictDoUpdate({
target: tasks.taskId,
set: taskSummaryRowFromSummary(taskSummary),
})
.run();
}
async function registerTaskBranchMutation(c: any, cmd: RegisterTaskBranchCommand): Promise<{ branchName: string; headSha: string }> {
export async function registerTaskBranchMutation(c: any, cmd: RegisterTaskBranchCommand): Promise<{ branchName: string; headSha: string }> {
const branchName = cmd.branchName.trim();
if (!branchName) {
throw new Error("branchName is required");
}
await persistRemoteUrl(c, c.state.remoteUrl);
const existingOwner = await c.db
.select({ taskId: taskIndex.taskId })
.from(taskIndex)
@ -397,6 +354,7 @@ async function listTaskSummaries(c: any, includeArchived = false): Promise<TaskS
title: row.title,
status: row.status,
updatedAt: row.updatedAtMs,
pullRequest: parseJsonValue<WorkspacePullRequestSummary | null>(row.pullRequestJson, null),
}))
.filter((row) => includeArchived || row.status !== "archived");
}
@ -413,12 +371,8 @@ function sortOverviewBranches(
taskId: string | null;
taskTitle: string | null;
taskStatus: TaskRecord["status"] | null;
prNumber: number | null;
prState: string | null;
prUrl: string | null;
pullRequest: WorkspacePullRequestSummary | null;
ciStatus: string | null;
reviewStatus: string | null;
reviewer: string | null;
updatedAt: number;
}>,
defaultBranch: string | null,
@ -438,60 +392,59 @@ function sortOverviewBranches(
});
}
export async function runRepositoryWorkflow(ctx: any): Promise<void> {
await ctx.loop("repository-command-loop", async (loopCtx: any) => {
const msg = await loopCtx.queue.next("next-repository-command", {
names: [...REPOSITORY_QUEUE_NAMES],
completable: true,
export async function applyTaskSummaryUpdateMutation(c: any, input: { taskSummary: WorkspaceTaskSummary }): Promise<void> {
await upsertTaskSummary(c, input.taskSummary);
await notifyOrganizationSnapshotChanged(c);
}
export async function removeTaskSummaryMutation(c: any, input: { taskId: string }): Promise<void> {
await c.db.delete(tasks).where(eq(tasks.taskId, input.taskId)).run();
await notifyOrganizationSnapshotChanged(c);
}
export async function refreshTaskSummaryForBranchMutation(
c: any,
input: { branchName: string; pullRequest?: WorkspacePullRequestSummary | null },
): Promise<void> {
const pullRequest = input.pullRequest ?? null;
let rows = await c.db.select({ taskId: tasks.taskId }).from(tasks).where(eq(tasks.branch, input.branchName)).all();
if (rows.length === 0 && pullRequest) {
const { config } = getActorRuntimeContext();
const created = await createTaskMutation(c, {
task: pullRequest.title?.trim() || `Review ${input.branchName}`,
sandboxProviderId: defaultSandboxProviderId(config),
explicitTitle: pullRequest.title?.trim() || input.branchName,
explicitBranchName: null,
onBranch: input.branchName,
});
if (!msg) {
return Loop.continue(undefined);
}
rows = [{ taskId: created.taskId }];
}
for (const row of rows) {
try {
if (msg.name === "repository.command.createTask") {
const result = await loopCtx.step({
name: "repository-create-task",
timeout: 5 * 60_000,
run: async () => createTaskMutation(loopCtx, msg.body as CreateTaskCommand),
});
await msg.complete(result);
return Loop.continue(undefined);
}
if (msg.name === "repository.command.registerTaskBranch") {
const result = await loopCtx.step({
name: "repository-register-task-branch",
timeout: 60_000,
run: async () => registerTaskBranchMutation(loopCtx, msg.body as RegisterTaskBranchCommand),
});
await msg.complete(result);
return Loop.continue(undefined);
}
const task = getTask(c, c.state.organizationId, c.state.repoId, row.taskId);
await expectQueueResponse<{ ok: true }>(
await task.send(
taskWorkflowQueueName("task.command.pull_request.sync"),
{ pullRequest },
{ wait: true, timeout: 10_000 },
),
);
} catch (error) {
const message = resolveErrorMessage(error);
logActorWarning("repository", "repository workflow command failed", {
queueName: msg.name,
error: message,
logActorWarning("repository", "failed refreshing task summary for branch", {
organizationId: c.state.organizationId,
repoId: c.state.repoId,
branchName: input.branchName,
taskId: row.taskId,
error: resolveErrorMessage(error),
});
await msg.complete({ error: message }).catch(() => {});
}
}
return Loop.continue(undefined);
});
}
export const repositoryActions = {
async createTask(c: any, cmd: CreateTaskCommand): Promise<TaskRecord> {
const self = selfRepository(c);
return expectQueueResponse<TaskRecord>(
await self.send(repositoryWorkflowQueueName("repository.command.createTask"), cmd, {
wait: true,
timeout: 10_000,
}),
);
},
async listReservedBranches(c: any): Promise<string[]> {
return await listKnownTaskBranches(c);
},
@ -506,23 +459,19 @@ export const repositoryActions = {
async getRepositoryMetadata(c: any): Promise<{ defaultBranch: string | null; fullName: string | null; remoteUrl: string }> {
const repository = await resolveGitHubRepository(c);
const remoteUrl = await resolveRepositoryRemoteUrl(c);
return {
defaultBranch: repository?.defaultBranch ?? null,
fullName: repository?.fullName ?? null,
remoteUrl: c.state.remoteUrl,
remoteUrl,
};
},
async getRepoOverview(c: any): Promise<RepoOverview> {
await persistRemoteUrl(c, c.state.remoteUrl);
const now = Date.now();
const repository = await resolveGitHubRepository(c);
const remoteUrl = await resolveRepositoryRemoteUrl(c);
const githubBranches = await listGitHubBranches(c).catch(() => []);
const githubData = getGithubData(c, c.state.organizationId);
const prRows = await githubData.listPullRequestsForRepository({ repoId: c.state.repoId }).catch(() => []);
const prByBranch = new Map(prRows.map((row) => [row.headRefName, row]));
const taskRows = await c.db.select().from(tasks).all();
const taskMetaByBranch = new Map<
@ -558,19 +507,15 @@ export const repositoryActions = {
const branches = sortOverviewBranches(
[...branchMap.values()].map((branch) => {
const taskMeta = taskMetaByBranch.get(branch.branchName);
const pr = taskMeta?.pullRequest ?? prByBranch.get(branch.branchName) ?? null;
const pr = taskMeta?.pullRequest ?? null;
return {
branchName: branch.branchName,
commitSha: branch.commitSha,
taskId: taskMeta?.taskId ?? null,
taskTitle: taskMeta?.title ?? null,
taskStatus: taskMeta?.status ?? null,
prNumber: pr?.number ?? null,
prState: "state" in (pr ?? {}) ? pr.state : null,
prUrl: "url" in (pr ?? {}) ? pr.url : null,
pullRequest: pr,
ciStatus: null,
reviewStatus: pr && "isDraft" in pr ? (pr.isDraft ? "draft" : "ready") : null,
reviewer: pr?.authorLogin ?? null,
updatedAt: Math.max(taskMeta?.updatedAt ?? 0, pr?.updatedAtMs ?? 0, now),
};
}),
@ -580,58 +525,24 @@ export const repositoryActions = {
return {
organizationId: c.state.organizationId,
repoId: c.state.repoId,
remoteUrl: c.state.remoteUrl,
remoteUrl,
baseRef: repository?.defaultBranch ?? null,
fetchedAt: now,
branches,
};
},
async applyTaskSummaryUpdate(c: any, input: { taskSummary: WorkspaceTaskSummary }): Promise<void> {
await upsertTaskSummary(c, input.taskSummary);
await notifyOrganizationSnapshotChanged(c);
},
async removeTaskSummary(c: any, input: { taskId: string }): Promise<void> {
await c.db.delete(tasks).where(eq(tasks.taskId, input.taskId)).run();
await notifyOrganizationSnapshotChanged(c);
},
async findTaskForBranch(c: any, input: { branchName: string }): Promise<{ taskId: string | null }> {
const row = await c.db.select({ taskId: tasks.taskId }).from(tasks).where(eq(tasks.branch, input.branchName)).get();
return { taskId: row?.taskId ?? null };
},
async refreshTaskSummaryForBranch(c: any, input: { branchName: string }): Promise<void> {
const rows = await c.db.select({ taskId: tasks.taskId }).from(tasks).where(eq(tasks.branch, input.branchName)).all();
for (const row of rows) {
try {
const task = getTask(c, c.state.organizationId, c.state.repoId, row.taskId);
await upsertTaskSummary(c, await task.getTaskSummary({}));
} catch (error) {
logActorWarning("repository", "failed refreshing task summary for branch", {
organizationId: c.state.organizationId,
repoId: c.state.repoId,
branchName: input.branchName,
taskId: row.taskId,
error: resolveErrorMessage(error),
});
}
}
await notifyOrganizationSnapshotChanged(c);
},
async getPullRequestForBranch(c: any, cmd: GetPullRequestForBranchCommand): Promise<WorkspacePullRequestSummary | null> {
const branchName = cmd.branchName?.trim();
if (!branchName) {
async getProjectedTaskSummary(c: any, input: GetProjectedTaskSummaryCommand): Promise<WorkspaceTaskSummary | null> {
const taskId = input.taskId?.trim();
if (!taskId) {
return null;
}
const githubData = getGithubData(c, c.state.organizationId);
const rows = await githubData.listPullRequestsForRepository({
repoId: c.state.repoId,
});
return rows.find((candidate: WorkspacePullRequestSummary) => candidate.headRefName === branchName) ?? null;
const row = await c.db.select().from(tasks).where(eq(tasks.taskId, taskId)).get();
return row ? taskSummaryFromRow(c, row) : null;
},
};

View file

@ -1,12 +1,12 @@
import { actor, queue } from "rivetkit";
import { workflow } from "rivetkit/workflow";
import { repositoryDb } from "./db/db.js";
import { REPOSITORY_QUEUE_NAMES, repositoryActions, runRepositoryWorkflow } from "./actions.js";
import { repositoryActions } from "./actions.js";
import { REPOSITORY_QUEUE_NAMES, runRepositoryWorkflow } from "./workflow.js";
export interface RepositoryInput {
organizationId: string;
repoId: string;
remoteUrl: string;
}
export const repository = actor({
@ -20,7 +20,6 @@ export const repository = actor({
createState: (_c, input: RepositoryInput) => ({
organizationId: input.organizationId,
repoId: input.repoId,
remoteUrl: input.remoteUrl,
}),
actions: repositoryActions,
run: workflow(runRepositoryWorkflow),

View file

@ -0,0 +1,97 @@
// @ts-nocheck
import { Loop } from "rivetkit/workflow";
import { logActorWarning, resolveErrorMessage } from "../logging.js";
import {
applyTaskSummaryUpdateMutation,
createTaskMutation,
refreshTaskSummaryForBranchMutation,
registerTaskBranchMutation,
removeTaskSummaryMutation,
} from "./actions.js";
export const REPOSITORY_QUEUE_NAMES = [
"repository.command.createTask",
"repository.command.registerTaskBranch",
"repository.command.applyTaskSummaryUpdate",
"repository.command.removeTaskSummary",
"repository.command.refreshTaskSummaryForBranch",
] as const;
export type RepositoryQueueName = (typeof REPOSITORY_QUEUE_NAMES)[number];
export function repositoryWorkflowQueueName(name: RepositoryQueueName): RepositoryQueueName {
return name;
}
export async function runRepositoryWorkflow(ctx: any): Promise<void> {
await ctx.loop("repository-command-loop", async (loopCtx: any) => {
const msg = await loopCtx.queue.next("next-repository-command", {
names: [...REPOSITORY_QUEUE_NAMES],
completable: true,
});
if (!msg) {
return Loop.continue(undefined);
}
try {
if (msg.name === "repository.command.createTask") {
const result = await loopCtx.step({
name: "repository-create-task",
timeout: 5 * 60_000,
run: async () => createTaskMutation(loopCtx, msg.body),
});
await msg.complete(result);
return Loop.continue(undefined);
}
if (msg.name === "repository.command.registerTaskBranch") {
const result = await loopCtx.step({
name: "repository-register-task-branch",
timeout: 60_000,
run: async () => registerTaskBranchMutation(loopCtx, msg.body),
});
await msg.complete(result);
return Loop.continue(undefined);
}
if (msg.name === "repository.command.applyTaskSummaryUpdate") {
await loopCtx.step({
name: "repository-apply-task-summary-update",
timeout: 30_000,
run: async () => applyTaskSummaryUpdateMutation(loopCtx, msg.body),
});
await msg.complete({ ok: true });
return Loop.continue(undefined);
}
if (msg.name === "repository.command.removeTaskSummary") {
await loopCtx.step({
name: "repository-remove-task-summary",
timeout: 30_000,
run: async () => removeTaskSummaryMutation(loopCtx, msg.body),
});
await msg.complete({ ok: true });
return Loop.continue(undefined);
}
if (msg.name === "repository.command.refreshTaskSummaryForBranch") {
await loopCtx.step({
name: "repository-refresh-task-summary-for-branch",
timeout: 60_000,
run: async () => refreshTaskSummaryForBranchMutation(loopCtx, msg.body),
});
await msg.complete({ ok: true });
return Loop.continue(undefined);
}
} catch (error) {
const message = resolveErrorMessage(error);
logActorWarning("repository", "repository workflow command failed", {
queueName: msg.name,
error: message,
});
await msg.complete({ error: message }).catch(() => {});
}
return Loop.continue(undefined);
});
}

View file

@ -2,6 +2,7 @@ import { actor } from "rivetkit";
import { e2b, sandboxActor } from "rivetkit/sandbox";
import { existsSync } from "node:fs";
import Dockerode from "dockerode";
import { DEFAULT_WORKSPACE_MODEL_GROUPS, workspaceModelGroupsFromSandboxAgents, type WorkspaceModelGroup } from "@sandbox-agent/foundry-shared";
import { SandboxAgent } from "sandbox-agent";
import { getActorRuntimeContext } from "../context.js";
import { organizationKey } from "../keys.js";
@ -258,6 +259,26 @@ async function providerForConnection(c: any): Promise<any | null> {
return provider;
}
async function listWorkspaceModelGroupsForSandbox(c: any): Promise<WorkspaceModelGroup[]> {
const provider = await providerForConnection(c);
if (!provider || !c.state.sandboxId || typeof provider.connectAgent !== "function") {
return DEFAULT_WORKSPACE_MODEL_GROUPS;
}
try {
const client = await provider.connectAgent(c.state.sandboxId, {
waitForHealth: {
timeoutMs: 15_000,
},
});
const listed = await client.listAgents({ config: true });
const groups = workspaceModelGroupsFromSandboxAgents(Array.isArray(listed?.agents) ? listed.agents : []);
return groups.length > 0 ? groups : DEFAULT_WORKSPACE_MODEL_GROUPS;
} catch {
return DEFAULT_WORKSPACE_MODEL_GROUPS;
}
}
const baseActions = baseTaskSandbox.config.actions as Record<string, (c: any, ...args: any[]) => Promise<any>>;
export const taskSandbox = actor({
@ -360,6 +381,10 @@ export const taskSandbox = actor({
}
},
async listWorkspaceModelGroups(c: any): Promise<WorkspaceModelGroup[]> {
return await listWorkspaceModelGroupsForSandbox(c);
},
async providerState(c: any): Promise<{ sandboxProviderId: "e2b" | "local"; sandboxId: string; state: string; at: number }> {
const { config } = getActorRuntimeContext();
const { taskId } = parseTaskSandboxKey(c.key);

View file

@ -5,8 +5,7 @@ CREATE TABLE `task` (
`task` text NOT NULL,
`sandbox_provider_id` text NOT NULL,
`status` text NOT NULL,
`agent_type` text DEFAULT 'claude',
`pr_submitted` integer DEFAULT 0,
`pull_request_json` text,
`created_at` integer NOT NULL,
`updated_at` integer NOT NULL,
CONSTRAINT "task_singleton_id_check" CHECK("task"."id" = 1)
@ -15,14 +14,10 @@ CREATE TABLE `task` (
CREATE TABLE `task_runtime` (
`id` integer PRIMARY KEY NOT NULL,
`active_sandbox_id` text,
`active_session_id` text,
`active_switch_target` text,
`active_cwd` text,
`status_message` text,
`git_state_json` text,
`git_state_updated_at` integer,
`provision_stage` text,
`provision_stage_updated_at` integer,
`updated_at` integer NOT NULL,
CONSTRAINT "task_runtime_singleton_id_check" CHECK("task_runtime"."id" = 1)
);
@ -33,7 +28,6 @@ CREATE TABLE `task_sandboxes` (
`sandbox_actor_id` text,
`switch_target` text NOT NULL,
`cwd` text,
`status_message` text,
`created_at` integer NOT NULL,
`updated_at` integer NOT NULL
);
@ -47,10 +41,6 @@ CREATE TABLE `task_workspace_sessions` (
`error_message` text,
`transcript_json` text DEFAULT '[]' NOT NULL,
`transcript_updated_at` integer,
`unread` integer DEFAULT 0 NOT NULL,
`draft_text` text DEFAULT '' NOT NULL,
`draft_attachments_json` text DEFAULT '[]' NOT NULL,
`draft_updated_at` integer,
`created` integer DEFAULT 1 NOT NULL,
`closed` integer DEFAULT 0 NOT NULL,
`thinking_since_ms` integer,

View file

@ -35,8 +35,8 @@
"notNull": true,
"autoincrement": false
},
"provider_id": {
"name": "provider_id",
"sandbox_provider_id": {
"name": "sandbox_provider_id",
"type": "text",
"primaryKey": false,
"notNull": true,
@ -49,21 +49,12 @@
"notNull": true,
"autoincrement": false
},
"agent_type": {
"name": "agent_type",
"pull_request_json": {
"name": "pull_request_json",
"type": "text",
"primaryKey": false,
"notNull": false,
"autoincrement": false,
"default": "'claude'"
},
"pr_submitted": {
"name": "pr_submitted",
"type": "integer",
"primaryKey": false,
"notNull": false,
"autoincrement": false,
"default": 0
"autoincrement": false
},
"created_at": {
"name": "created_at",
@ -108,13 +99,6 @@
"notNull": false,
"autoincrement": false
},
"active_session_id": {
"name": "active_session_id",
"type": "text",
"primaryKey": false,
"notNull": false,
"autoincrement": false
},
"active_switch_target": {
"name": "active_switch_target",
"type": "text",
@ -129,13 +113,20 @@
"notNull": false,
"autoincrement": false
},
"status_message": {
"name": "status_message",
"git_state_json": {
"name": "git_state_json",
"type": "text",
"primaryKey": false,
"notNull": false,
"autoincrement": false
},
"git_state_updated_at": {
"name": "git_state_updated_at",
"type": "integer",
"primaryKey": false,
"notNull": false,
"autoincrement": false
},
"updated_at": {
"name": "updated_at",
"type": "integer",
@ -165,8 +156,8 @@
"notNull": true,
"autoincrement": false
},
"provider_id": {
"name": "provider_id",
"sandbox_provider_id": {
"name": "sandbox_provider_id",
"type": "text",
"primaryKey": false,
"notNull": true,
@ -193,13 +184,6 @@
"notNull": false,
"autoincrement": false
},
"status_message": {
"name": "status_message",
"type": "text",
"primaryKey": false,
"notNull": false,
"autoincrement": false
},
"created_at": {
"name": "created_at",
"type": "integer",
@ -231,6 +215,13 @@
"notNull": true,
"autoincrement": false
},
"sandbox_session_id": {
"name": "sandbox_session_id",
"type": "text",
"primaryKey": false,
"notNull": false,
"autoincrement": false
},
"session_name": {
"name": "session_name",
"type": "text",
@ -245,32 +236,31 @@
"notNull": true,
"autoincrement": false
},
"unread": {
"name": "unread",
"type": "integer",
"primaryKey": false,
"notNull": true,
"autoincrement": false,
"default": 0
},
"draft_text": {
"name": "draft_text",
"status": {
"name": "status",
"type": "text",
"primaryKey": false,
"notNull": true,
"autoincrement": false,
"default": "''"
"default": "'ready'"
},
"draft_attachments_json": {
"name": "draft_attachments_json",
"error_message": {
"name": "error_message",
"type": "text",
"primaryKey": false,
"notNull": false,
"autoincrement": false
},
"transcript_json": {
"name": "transcript_json",
"type": "text",
"primaryKey": false,
"notNull": true,
"autoincrement": false,
"default": "'[]'"
},
"draft_updated_at": {
"name": "draft_updated_at",
"transcript_updated_at": {
"name": "transcript_updated_at",
"type": "integer",
"primaryKey": false,
"notNull": false,

View file

@ -11,6 +11,7 @@ export const task = sqliteTable(
task: text("task").notNull(),
sandboxProviderId: text("sandbox_provider_id").notNull(),
status: text("status").notNull(),
pullRequestJson: text("pull_request_json"),
createdAt: integer("created_at").notNull(),
updatedAt: integer("updated_at").notNull(),
},
@ -42,7 +43,6 @@ export const taskSandboxes = sqliteTable("task_sandboxes", {
sandboxActorId: text("sandbox_actor_id"),
switchTarget: text("switch_target").notNull(),
cwd: text("cwd"),
statusMessage: text("status_message"),
createdAt: integer("created_at").notNull(),
updatedAt: integer("updated_at").notNull(),
});

View file

@ -1,122 +1,15 @@
import { actor, queue } from "rivetkit";
import { workflow } from "rivetkit/workflow";
import type {
TaskRecord,
TaskWorkspaceChangeModelInput,
TaskWorkspaceRenameInput,
TaskWorkspaceRenameSessionInput,
TaskWorkspaceSetSessionUnreadInput,
TaskWorkspaceSendMessageInput,
TaskWorkspaceUpdateDraftInput,
SandboxProviderId,
} from "@sandbox-agent/foundry-shared";
import { expectQueueResponse } from "../../services/queue.js";
import { selfTask } from "../handles.js";
import type { TaskRecord } from "@sandbox-agent/foundry-shared";
import { taskDb } from "./db/db.js";
import { getCurrentRecord } from "./workflow/common.js";
import {
changeWorkspaceModel,
closeWorkspaceSession,
createWorkspaceSession,
getSessionDetail,
getTaskDetail,
getTaskSummary,
markWorkspaceUnread,
publishWorkspacePr,
renameWorkspaceTask,
renameWorkspaceSession,
revertWorkspaceFile,
sendWorkspaceMessage,
syncWorkspaceSessionStatus,
setWorkspaceSessionUnread,
stopWorkspaceSession,
updateWorkspaceDraft,
} from "./workspace.js";
import { TASK_QUEUE_NAMES, taskWorkflowQueueName, runTaskWorkflow } from "./workflow/index.js";
import { getSessionDetail, getTaskDetail, getTaskSummary } from "./workspace.js";
import { TASK_QUEUE_NAMES, runTaskWorkflow } from "./workflow/index.js";
export interface TaskInput {
organizationId: string;
repoId: string;
taskId: string;
repoRemote: string;
branchName: string | null;
title: string | null;
task: string;
sandboxProviderId: SandboxProviderId;
explicitTitle: string | null;
explicitBranchName: string | null;
}
interface InitializeCommand {
sandboxProviderId?: SandboxProviderId;
}
interface TaskActionCommand {
reason?: string;
}
interface TaskSessionCommand {
sessionId: string;
authSessionId?: string;
}
interface TaskStatusSyncCommand {
sessionId: string;
status: "running" | "idle" | "error";
at: number;
}
interface TaskWorkspaceValueCommand {
value: string;
authSessionId?: string;
}
interface TaskWorkspaceSessionTitleCommand {
sessionId: string;
title: string;
authSessionId?: string;
}
interface TaskWorkspaceSessionUnreadCommand {
sessionId: string;
unread: boolean;
authSessionId?: string;
}
interface TaskWorkspaceUpdateDraftCommand {
sessionId: string;
text: string;
attachments: Array<any>;
authSessionId?: string;
}
interface TaskWorkspaceChangeModelCommand {
sessionId: string;
model: string;
authSessionId?: string;
}
interface TaskWorkspaceSendMessageCommand {
sessionId: string;
text: string;
attachments: Array<any>;
authSessionId?: string;
}
interface TaskWorkspaceCreateSessionCommand {
model?: string;
authSessionId?: string;
}
interface TaskWorkspaceCreateSessionAndSendCommand {
model?: string;
text: string;
authSessionId?: string;
}
interface TaskWorkspaceSessionCommand {
sessionId: string;
authSessionId?: string;
}
export const task = actor({
@ -131,85 +24,10 @@ export const task = actor({
organizationId: input.organizationId,
repoId: input.repoId,
taskId: input.taskId,
repoRemote: input.repoRemote,
}),
actions: {
async initialize(c, cmd: InitializeCommand): Promise<TaskRecord> {
const self = selfTask(c);
const result = await self.send(taskWorkflowQueueName("task.command.initialize"), cmd ?? {}, {
wait: true,
timeout: 10_000,
});
return expectQueueResponse<TaskRecord>(result);
},
async provision(c, cmd: InitializeCommand): Promise<{ ok: true }> {
const self = selfTask(c);
await self.send(taskWorkflowQueueName("task.command.provision"), cmd ?? {}, {
wait: false,
});
return { ok: true };
},
async attach(c, cmd?: TaskActionCommand): Promise<{ target: string; sessionId: string | null }> {
const self = selfTask(c);
const result = await self.send(taskWorkflowQueueName("task.command.attach"), cmd ?? {}, {
wait: true,
timeout: 10_000,
});
return expectQueueResponse<{ target: string; sessionId: string | null }>(result);
},
async switch(c): Promise<{ switchTarget: string }> {
const self = selfTask(c);
const result = await self.send(
taskWorkflowQueueName("task.command.switch"),
{},
{
wait: true,
timeout: 10_000,
},
);
return expectQueueResponse<{ switchTarget: string }>(result);
},
async push(c, cmd?: TaskActionCommand): Promise<void> {
const self = selfTask(c);
await self.send(taskWorkflowQueueName("task.command.push"), cmd ?? {}, {
wait: false,
});
},
async sync(c, cmd?: TaskActionCommand): Promise<void> {
const self = selfTask(c);
await self.send(taskWorkflowQueueName("task.command.sync"), cmd ?? {}, {
wait: false,
});
},
async merge(c, cmd?: TaskActionCommand): Promise<void> {
const self = selfTask(c);
await self.send(taskWorkflowQueueName("task.command.merge"), cmd ?? {}, {
wait: false,
});
},
async archive(c, cmd?: TaskActionCommand): Promise<void> {
const self = selfTask(c);
await self.send(taskWorkflowQueueName("task.command.archive"), cmd ?? {}, {
wait: false,
});
},
async kill(c, cmd?: TaskActionCommand): Promise<void> {
const self = selfTask(c);
await self.send(taskWorkflowQueueName("task.command.kill"), cmd ?? {}, {
wait: false,
});
},
async get(c): Promise<TaskRecord> {
return await getCurrentRecord({ db: c.db, state: c.state });
return await getCurrentRecord(c);
},
async getTaskSummary(c) {
@ -223,175 +41,6 @@ export const task = actor({
async getSessionDetail(c, input: { sessionId: string; authSessionId?: string }) {
return await getSessionDetail(c, input.sessionId, input.authSessionId);
},
async markWorkspaceUnread(c, input?: { authSessionId?: string }): Promise<void> {
const self = selfTask(c);
await self.send(
taskWorkflowQueueName("task.command.workspace.mark_unread"),
{ authSessionId: input?.authSessionId },
{
wait: true,
timeout: 10_000,
},
);
},
async renameWorkspaceTask(c, input: TaskWorkspaceRenameInput): Promise<void> {
const self = selfTask(c);
await self.send(
taskWorkflowQueueName("task.command.workspace.rename_task"),
{ value: input.value, authSessionId: input.authSessionId } satisfies TaskWorkspaceValueCommand,
{
wait: true,
timeout: 20_000,
},
);
},
async createWorkspaceSession(c, input?: { model?: string; authSessionId?: string }): Promise<{ sessionId: string }> {
const self = selfTask(c);
const result = await self.send(
taskWorkflowQueueName("task.command.workspace.create_session"),
{
...(input?.model ? { model: input.model } : {}),
...(input?.authSessionId ? { authSessionId: input.authSessionId } : {}),
} satisfies TaskWorkspaceCreateSessionCommand,
{
wait: true,
timeout: 10_000,
},
);
return expectQueueResponse<{ sessionId: string }>(result);
},
/**
* Fire-and-forget: creates a session and sends the initial message.
* Used by createWorkspaceTask so the caller doesn't block on session creation.
*/
async createWorkspaceSessionAndSend(c, input: { model?: string; text: string; authSessionId?: string }): Promise<void> {
const self = selfTask(c);
await self.send(
taskWorkflowQueueName("task.command.workspace.create_session_and_send"),
{ model: input.model, text: input.text, authSessionId: input.authSessionId } satisfies TaskWorkspaceCreateSessionAndSendCommand,
{ wait: false },
);
},
async renameWorkspaceSession(c, input: TaskWorkspaceRenameSessionInput): Promise<void> {
const self = selfTask(c);
await self.send(
taskWorkflowQueueName("task.command.workspace.rename_session"),
{ sessionId: input.sessionId, title: input.title, authSessionId: input.authSessionId } satisfies TaskWorkspaceSessionTitleCommand,
{
wait: true,
timeout: 10_000,
},
);
},
async setWorkspaceSessionUnread(c, input: TaskWorkspaceSetSessionUnreadInput): Promise<void> {
const self = selfTask(c);
await self.send(
taskWorkflowQueueName("task.command.workspace.set_session_unread"),
{ sessionId: input.sessionId, unread: input.unread, authSessionId: input.authSessionId } satisfies TaskWorkspaceSessionUnreadCommand,
{
wait: true,
timeout: 10_000,
},
);
},
async updateWorkspaceDraft(c, input: TaskWorkspaceUpdateDraftInput): Promise<void> {
const self = selfTask(c);
await self.send(
taskWorkflowQueueName("task.command.workspace.update_draft"),
{
sessionId: input.sessionId,
text: input.text,
attachments: input.attachments,
authSessionId: input.authSessionId,
} satisfies TaskWorkspaceUpdateDraftCommand,
{
wait: false,
},
);
},
async changeWorkspaceModel(c, input: TaskWorkspaceChangeModelInput): Promise<void> {
const self = selfTask(c);
await self.send(
taskWorkflowQueueName("task.command.workspace.change_model"),
{ sessionId: input.sessionId, model: input.model, authSessionId: input.authSessionId } satisfies TaskWorkspaceChangeModelCommand,
{
wait: true,
timeout: 10_000,
},
);
},
async sendWorkspaceMessage(c, input: TaskWorkspaceSendMessageInput): Promise<void> {
const self = selfTask(c);
await self.send(
taskWorkflowQueueName("task.command.workspace.send_message"),
{
sessionId: input.sessionId,
text: input.text,
attachments: input.attachments,
authSessionId: input.authSessionId,
} satisfies TaskWorkspaceSendMessageCommand,
{
wait: false,
},
);
},
async stopWorkspaceSession(c, input: TaskSessionCommand): Promise<void> {
const self = selfTask(c);
await self.send(
taskWorkflowQueueName("task.command.workspace.stop_session"),
{ sessionId: input.sessionId, authSessionId: input.authSessionId } satisfies TaskWorkspaceSessionCommand,
{
wait: false,
},
);
},
async syncWorkspaceSessionStatus(c, input: TaskStatusSyncCommand): Promise<void> {
const self = selfTask(c);
await self.send(taskWorkflowQueueName("task.command.workspace.sync_session_status"), input, {
wait: true,
timeout: 20_000,
});
},
async closeWorkspaceSession(c, input: TaskSessionCommand): Promise<void> {
const self = selfTask(c);
await self.send(
taskWorkflowQueueName("task.command.workspace.close_session"),
{ sessionId: input.sessionId, authSessionId: input.authSessionId } satisfies TaskWorkspaceSessionCommand,
{
wait: false,
},
);
},
async publishWorkspacePr(c): Promise<void> {
const self = selfTask(c);
await self.send(
taskWorkflowQueueName("task.command.workspace.publish_pr"),
{},
{
wait: false,
},
);
},
async revertWorkspaceFile(c, input: { path: string }): Promise<void> {
const self = selfTask(c);
await self.send(taskWorkflowQueueName("task.command.workspace.revert_file"), input, {
wait: false,
});
},
},
run: workflow(runTaskWorkflow),
});

View file

@ -65,7 +65,7 @@ export async function handlePushActivity(loopCtx: any, msg: any): Promise<void>
await msg.complete({ ok: true });
}
export async function handleSimpleCommandActivity(loopCtx: any, msg: any, _statusMessage: string, historyKind: string): Promise<void> {
export async function handleSimpleCommandActivity(loopCtx: any, msg: any, historyKind: string): Promise<void> {
await appendAuditLog(loopCtx, historyKind, { reason: msg.body?.reason ?? null });
await msg.complete({ ok: true });
}

View file

@ -2,7 +2,7 @@
import { eq } from "drizzle-orm";
import type { TaskRecord, TaskStatus } from "@sandbox-agent/foundry-shared";
import { task as taskTable, taskRuntime, taskSandboxes } from "../db/schema.js";
import { getOrCreateAuditLog } from "../../handles.js";
import { getOrCreateAuditLog, getOrCreateRepository } from "../../handles.js";
import { broadcastTaskUpdate } from "../workspace.js";
export const TASK_ROW_ID = 1;
@ -66,6 +66,7 @@ export async function setTaskState(ctx: any, status: TaskStatus): Promise<void>
export async function getCurrentRecord(ctx: any): Promise<TaskRecord> {
const db = ctx.db;
const repository = await getOrCreateRepository(ctx, ctx.state.organizationId, ctx.state.repoId);
const row = await db
.select({
branchName: taskTable.branchName,
@ -73,6 +74,7 @@ export async function getCurrentRecord(ctx: any): Promise<TaskRecord> {
task: taskTable.task,
sandboxProviderId: taskTable.sandboxProviderId,
status: taskTable.status,
pullRequestJson: taskTable.pullRequestJson,
activeSandboxId: taskRuntime.activeSandboxId,
createdAt: taskTable.createdAt,
updatedAt: taskTable.updatedAt,
@ -86,6 +88,16 @@ export async function getCurrentRecord(ctx: any): Promise<TaskRecord> {
throw new Error(`Task not found: ${ctx.state.taskId}`);
}
const repositoryMetadata = await repository.getRepositoryMetadata({});
let pullRequest = null;
if (row.pullRequestJson) {
try {
pullRequest = JSON.parse(row.pullRequestJson);
} catch {
pullRequest = null;
}
}
const sandboxes = await db
.select({
sandboxId: taskSandboxes.sandboxId,
@ -102,7 +114,7 @@ export async function getCurrentRecord(ctx: any): Promise<TaskRecord> {
return {
organizationId: ctx.state.organizationId,
repoId: ctx.state.repoId,
repoRemote: ctx.state.repoRemote,
repoRemote: repositoryMetadata.remoteUrl,
taskId: ctx.state.taskId,
branchName: row.branchName,
title: row.title,
@ -110,6 +122,7 @@ export async function getCurrentRecord(ctx: any): Promise<TaskRecord> {
sandboxProviderId: row.sandboxProviderId,
status: row.status,
activeSandboxId: row.activeSandboxId ?? null,
pullRequest,
sandboxes: sandboxes.map((sb) => ({
sandboxId: sb.sandboxId,
sandboxProviderId: sb.sandboxProviderId,
@ -119,25 +132,20 @@ export async function getCurrentRecord(ctx: any): Promise<TaskRecord> {
createdAt: sb.createdAt,
updatedAt: sb.updatedAt,
})),
diffStat: null,
prUrl: null,
prAuthor: null,
ciStatus: null,
reviewStatus: null,
reviewer: null,
createdAt: row.createdAt,
updatedAt: row.updatedAt,
} as TaskRecord;
}
export async function appendAuditLog(ctx: any, kind: string, payload: Record<string, unknown>): Promise<void> {
const row = await ctx.db.select({ branchName: taskTable.branchName }).from(taskTable).where(eq(taskTable.id, TASK_ROW_ID)).get();
const auditLog = await getOrCreateAuditLog(ctx, ctx.state.organizationId, ctx.state.repoId);
await auditLog.send(
"auditLog.command.append",
{
kind,
taskId: ctx.state.taskId,
branchName: ctx.state.branchName,
branchName: row?.branchName ?? null,
payload,
},
{

View file

@ -24,10 +24,12 @@ import {
publishWorkspacePr,
renameWorkspaceTask,
renameWorkspaceSession,
selectWorkspaceSession,
revertWorkspaceFile,
sendWorkspaceMessage,
setWorkspaceSessionUnread,
stopWorkspaceSession,
syncTaskPullRequest,
syncWorkspaceSessionStatus,
updateWorkspaceDraft,
} from "../workspace.js";
@ -71,7 +73,7 @@ const commandHandlers: Record<TaskQueueName, WorkflowHandler> = {
await loopCtx.step("init-complete", async () => initCompleteActivity(loopCtx, msg.body));
await msg.complete({ ok: true });
} catch (error) {
await loopCtx.step("init-failed-v3", async () => initFailedActivity(loopCtx, error));
await loopCtx.step("init-failed-v3", async () => initFailedActivity(loopCtx, error, msg.body));
await msg.complete({
ok: false,
error: resolveErrorMessage(error),
@ -92,11 +94,11 @@ const commandHandlers: Record<TaskQueueName, WorkflowHandler> = {
},
"task.command.sync": async (loopCtx, msg) => {
await loopCtx.step("handle-sync", async () => handleSimpleCommandActivity(loopCtx, msg, "sync requested", "task.sync"));
await loopCtx.step("handle-sync", async () => handleSimpleCommandActivity(loopCtx, msg, "task.sync"));
},
"task.command.merge": async (loopCtx, msg) => {
await loopCtx.step("handle-merge", async () => handleSimpleCommandActivity(loopCtx, msg, "merge requested", "task.merge"));
await loopCtx.step("handle-merge", async () => handleSimpleCommandActivity(loopCtx, msg, "task.merge"));
},
"task.command.archive": async (loopCtx, msg) => {
@ -112,6 +114,11 @@ const commandHandlers: Record<TaskQueueName, WorkflowHandler> = {
await loopCtx.step("handle-get", async () => handleGetActivity(loopCtx, msg));
},
"task.command.pull_request.sync": async (loopCtx, msg) => {
await loopCtx.step("task-pull-request-sync", async () => syncTaskPullRequest(loopCtx, msg.body?.pullRequest ?? null));
await msg.complete({ ok: true });
},
"task.command.workspace.mark_unread": async (loopCtx, msg) => {
await loopCtx.step("workspace-mark-unread", async () => markWorkspaceUnread(loopCtx, msg.body?.authSessionId));
await msg.complete({ ok: true });
@ -169,22 +176,23 @@ const commandHandlers: Record<TaskQueueName, WorkflowHandler> = {
await msg.complete({ ok: true });
},
"task.command.workspace.select_session": async (loopCtx, msg) => {
await loopCtx.step("workspace-select-session", async () => selectWorkspaceSession(loopCtx, msg.body.sessionId, msg.body?.authSessionId));
await msg.complete({ ok: true });
},
"task.command.workspace.set_session_unread": async (loopCtx, msg) => {
await loopCtx.step("workspace-set-session-unread", async () =>
setWorkspaceSessionUnread(loopCtx, msg.body.sessionId, msg.body.unread, msg.body?.authSessionId),
);
await loopCtx.step("workspace-set-session-unread", async () => setWorkspaceSessionUnread(loopCtx, msg.body.sessionId, msg.body.unread, msg.body?.authSessionId));
await msg.complete({ ok: true });
},
"task.command.workspace.update_draft": async (loopCtx, msg) => {
await loopCtx.step("workspace-update-draft", async () =>
updateWorkspaceDraft(loopCtx, msg.body.sessionId, msg.body.text, msg.body.attachments, msg.body?.authSessionId),
);
await loopCtx.step("workspace-update-draft", async () => updateWorkspaceDraft(loopCtx, msg.body.sessionId, msg.body.text, msg.body.attachments, msg.body?.authSessionId));
await msg.complete({ ok: true });
},
"task.command.workspace.change_model": async (loopCtx, msg) => {
await loopCtx.step("workspace-change-model", async () => changeWorkspaceModel(loopCtx, msg.body.sessionId, msg.body.model));
await loopCtx.step("workspace-change-model", async () => changeWorkspaceModel(loopCtx, msg.body.sessionId, msg.body.model, msg.body?.authSessionId));
await msg.complete({ ok: true });
},

View file

@ -11,28 +11,34 @@ import { taskWorkflowQueueName } from "./queue.js";
export async function initBootstrapDbActivity(loopCtx: any, body: any): Promise<void> {
const { config } = getActorRuntimeContext();
const sandboxProviderId = body?.sandboxProviderId ?? defaultSandboxProviderId(config);
const task = body?.task;
if (typeof task !== "string" || task.trim().length === 0) {
throw new Error("task initialize requires the task prompt");
}
const now = Date.now();
await loopCtx.db
.insert(taskTable)
.values({
id: TASK_ROW_ID,
branchName: loopCtx.state.branchName,
title: loopCtx.state.title,
task: loopCtx.state.task,
branchName: body?.branchName ?? null,
title: body?.title ?? null,
task,
sandboxProviderId,
status: "init_bootstrap_db",
pullRequestJson: null,
createdAt: now,
updatedAt: now,
})
.onConflictDoUpdate({
target: taskTable.id,
set: {
branchName: loopCtx.state.branchName,
title: loopCtx.state.title,
task: loopCtx.state.task,
branchName: body?.branchName ?? null,
title: body?.title ?? null,
task,
sandboxProviderId,
status: "init_bootstrap_db",
pullRequestJson: null,
updatedAt: now,
},
})
@ -99,33 +105,36 @@ export async function initCompleteActivity(loopCtx: any, body: any): Promise<voi
});
}
export async function initFailedActivity(loopCtx: any, error: unknown): Promise<void> {
export async function initFailedActivity(loopCtx: any, error: unknown, body?: any): Promise<void> {
const now = Date.now();
const detail = resolveErrorDetail(error);
const messages = collectErrorMessages(error);
const { config } = getActorRuntimeContext();
const sandboxProviderId = defaultSandboxProviderId(config);
const task = typeof body?.task === "string" ? body.task : null;
await loopCtx.db
.insert(taskTable)
.values({
id: TASK_ROW_ID,
branchName: loopCtx.state.branchName ?? null,
title: loopCtx.state.title ?? null,
task: loopCtx.state.task,
branchName: body?.branchName ?? null,
title: body?.title ?? null,
task: task ?? detail,
sandboxProviderId,
status: "error",
pullRequestJson: null,
createdAt: now,
updatedAt: now,
})
.onConflictDoUpdate({
target: taskTable.id,
set: {
branchName: loopCtx.state.branchName ?? null,
title: loopCtx.state.title ?? null,
task: loopCtx.state.task,
branchName: body?.branchName ?? null,
title: body?.title ?? null,
task: task ?? detail,
sandboxProviderId,
status: "error",
pullRequestJson: null,
updatedAt: now,
},
})

View file

@ -1,8 +1,6 @@
// @ts-nocheck
import { eq } from "drizzle-orm";
import { getTaskSandbox } from "../../handles.js";
import { resolveOrganizationGithubAuth } from "../../../services/github-auth.js";
import { taskSandboxes } from "../db/schema.js";
import { appendAuditLog, getCurrentRecord } from "./common.js";
export interface PushActiveBranchOptions {
@ -13,7 +11,7 @@ export interface PushActiveBranchOptions {
export async function pushActiveBranchActivity(loopCtx: any, options: PushActiveBranchOptions = {}): Promise<void> {
const record = await getCurrentRecord(loopCtx);
const activeSandboxId = record.activeSandboxId;
const branchName = loopCtx.state.branchName ?? record.branchName;
const branchName = record.branchName;
if (!activeSandboxId) {
throw new Error("cannot push: no active sandbox");
@ -28,13 +26,6 @@ export async function pushActiveBranchActivity(loopCtx: any, options: PushActive
throw new Error("cannot push: active sandbox cwd is not set");
}
const now = Date.now();
await loopCtx.db
.update(taskSandboxes)
.set({ statusMessage: `pushing branch ${branchName}`, updatedAt: now })
.where(eq(taskSandboxes.sandboxId, activeSandboxId))
.run();
const script = [
"set -euo pipefail",
`cd ${JSON.stringify(cwd)}`,
@ -62,13 +53,6 @@ export async function pushActiveBranchActivity(loopCtx: any, options: PushActive
throw new Error(`git push failed (${result.exitCode ?? 1}): ${[result.stdout, result.stderr].filter(Boolean).join("")}`);
}
const updatedAt = Date.now();
await loopCtx.db
.update(taskSandboxes)
.set({ statusMessage: `push complete for ${branchName}`, updatedAt })
.where(eq(taskSandboxes.sandboxId, activeSandboxId))
.run();
await appendAuditLog(loopCtx, options.historyKind ?? "task.push", {
reason: options.reason ?? null,
branchName,

View file

@ -9,12 +9,14 @@ export const TASK_QUEUE_NAMES = [
"task.command.archive",
"task.command.kill",
"task.command.get",
"task.command.pull_request.sync",
"task.command.workspace.mark_unread",
"task.command.workspace.rename_task",
"task.command.workspace.create_session",
"task.command.workspace.create_session_and_send",
"task.command.workspace.ensure_session",
"task.command.workspace.rename_session",
"task.command.workspace.select_session",
"task.command.workspace.set_session_unread",
"task.command.workspace.update_draft",
"task.command.workspace.change_model",

View file

@ -2,13 +2,17 @@
import { randomUUID } from "node:crypto";
import { basename, dirname } from "node:path";
import { asc, eq } from "drizzle-orm";
import { DEFAULT_WORKSPACE_MODEL_GROUPS, DEFAULT_WORKSPACE_MODEL_ID, workspaceAgentForModel, workspaceSandboxAgentIdForModel } from "@sandbox-agent/foundry-shared";
import { getActorRuntimeContext } from "../context.js";
import { getOrCreateRepository, getOrCreateTaskSandbox, getOrCreateUser, getTaskSandbox, selfTask } from "../handles.js";
import { SANDBOX_REPO_CWD } from "../sandbox/index.js";
import { resolveSandboxProviderId } from "../../sandbox-config.js";
import { getBetterAuthService } from "../../services/better-auth.js";
import { expectQueueResponse } from "../../services/queue.js";
import { resolveOrganizationGithubAuth } from "../../services/github-auth.js";
import { githubRepoFullNameFromRemote } from "../../services/repo.js";
import { repositoryWorkflowQueueName } from "../repository/workflow.js";
import { userWorkflowQueueName } from "../user/workflow.js";
import { task as taskTable, taskRuntime, taskSandboxes, taskWorkspaceSessions } from "./db/schema.js";
import { getCurrentRecord } from "./workflow/common.js";
@ -21,24 +25,29 @@ function emptyGitState() {
};
}
const FALLBACK_MODEL = "claude-sonnet-4";
function isCodexModel(model: string) {
return model.startsWith("gpt-") || model.startsWith("o");
}
const FALLBACK_MODEL = DEFAULT_WORKSPACE_MODEL_ID;
function agentKindForModel(model: string) {
if (isCodexModel(model)) {
return "Codex";
}
return "Claude";
return workspaceAgentForModel(model);
}
export function agentTypeForModel(model: string) {
if (isCodexModel(model)) {
return "codex";
export function sandboxAgentIdForModel(model: string) {
return workspaceSandboxAgentIdForModel(model);
}
async function resolveWorkspaceModelGroups(c: any): Promise<any[]> {
try {
const sandbox = await getOrCreateTaskSandbox(c, c.state.organizationId, stableSandboxId(c));
const groups = await sandbox.listWorkspaceModelGroups();
return Array.isArray(groups) && groups.length > 0 ? groups : DEFAULT_WORKSPACE_MODEL_GROUPS;
} catch {
return DEFAULT_WORKSPACE_MODEL_GROUPS;
}
return "claude";
}
async function resolveSandboxAgentForModel(c: any, model: string): Promise<string> {
const groups = await resolveWorkspaceModelGroups(c);
return workspaceSandboxAgentIdForModel(model, groups);
}
function repoLabelFromRemote(remoteUrl: string): string {
@ -56,6 +65,11 @@ function repoLabelFromRemote(remoteUrl: string): string {
return basename(trimmed.replace(/\.git$/, ""));
}
async function getRepositoryMetadata(c: any): Promise<{ defaultBranch: string | null; fullName: string | null; remoteUrl: string }> {
const repository = await getOrCreateRepository(c, c.state.organizationId, c.state.repoId);
return await repository.getRepositoryMetadata({});
}
function parseDraftAttachments(value: string | null | undefined): Array<any> {
if (!value) {
return [];
@ -220,11 +234,17 @@ async function upsertUserTaskState(c: any, authSessionId: string | null | undefi
}
const user = await getOrCreateUser(c, userId);
await user.upsertTaskState({
taskId: c.state.taskId,
sessionId,
patch,
});
expectQueueResponse(
await user.send(
userWorkflowQueueName("user.command.task_state.upsert"),
{
taskId: c.state.taskId,
sessionId,
patch,
},
{ wait: true, timeout: 60_000 },
),
);
}
async function deleteUserTaskState(c: any, authSessionId: string | null | undefined, sessionId: string): Promise<void> {
@ -239,10 +259,16 @@ async function deleteUserTaskState(c: any, authSessionId: string | null | undefi
}
const user = await getOrCreateUser(c, userId);
await user.deleteTaskState({
taskId: c.state.taskId,
sessionId,
});
expectQueueResponse(
await user.send(
userWorkflowQueueName("user.command.task_state.delete"),
{
taskId: c.state.taskId,
sessionId,
},
{ wait: true, timeout: 60_000 },
),
);
}
async function resolveDefaultModel(c: any, authSessionId?: string | null): Promise<string> {
@ -367,7 +393,7 @@ async function getTaskSandboxRuntime(
}> {
const { config } = getActorRuntimeContext();
const sandboxId = stableSandboxId(c);
const sandboxProviderId = resolveSandboxProviderId(config, record.sandboxProviderId ?? c.state.sandboxProviderId ?? null);
const sandboxProviderId = resolveSandboxProviderId(config, record.sandboxProviderId ?? null);
const sandbox = await getOrCreateTaskSandbox(c, c.state.organizationId, sandboxId, {});
const actorId = typeof sandbox.resolve === "function" ? await sandbox.resolve().catch(() => null) : null;
const switchTarget = sandboxProviderId === "local" ? `sandbox://local/${sandboxId}` : `sandbox://e2b/${sandboxId}`;
@ -381,7 +407,6 @@ async function getTaskSandboxRuntime(
sandboxActorId: typeof actorId === "string" ? actorId : null,
switchTarget,
cwd: SANDBOX_REPO_CWD,
statusMessage: "sandbox ready",
createdAt: now,
updatedAt: now,
})
@ -436,8 +461,7 @@ async function ensureSandboxRepo(c: any, sandbox: any, record: any, opts?: { ski
}
const auth = await resolveOrganizationGithubAuth(c, c.state.organizationId);
const repository = await getOrCreateRepository(c, c.state.organizationId, c.state.repoId, c.state.repoRemote);
const metadata = await repository.getRepositoryMetadata({});
const metadata = await getRepositoryMetadata(c);
const baseRef = metadata.defaultBranch ?? "main";
const sandboxRepoRoot = dirname(SANDBOX_REPO_CWD);
const script = [
@ -445,7 +469,7 @@ async function ensureSandboxRepo(c: any, sandbox: any, record: any, opts?: { ski
`mkdir -p ${JSON.stringify(sandboxRepoRoot)}`,
"git config --global credential.helper '!f() { echo username=x-access-token; echo password=${GH_TOKEN:-$GITHUB_TOKEN}; }; f'",
`if [ ! -d ${JSON.stringify(`${SANDBOX_REPO_CWD}/.git`)} ]; then rm -rf ${JSON.stringify(SANDBOX_REPO_CWD)} && git clone ${JSON.stringify(
c.state.repoRemote,
metadata.remoteUrl,
)} ${JSON.stringify(SANDBOX_REPO_CWD)}; fi`,
`cd ${JSON.stringify(SANDBOX_REPO_CWD)}`,
"git fetch origin --prune",
@ -774,21 +798,8 @@ function computeWorkspaceTaskStatus(record: any, sessions: Array<any>) {
return "idle";
}
async function readPullRequestSummary(c: any, branchName: string | null) {
if (!branchName) {
return null;
}
try {
const repository = await getOrCreateRepository(c, c.state.organizationId, c.state.repoId, c.state.repoRemote);
return await repository.getPullRequestForBranch({ branchName });
} catch {
return null;
}
}
export async function ensureWorkspaceSeeded(c: any): Promise<any> {
return await getCurrentRecord({ db: c.db, state: c.state });
return await getCurrentRecord(c);
}
function buildSessionSummary(meta: any, userState?: any): any {
@ -853,20 +864,24 @@ function buildSessionDetailFromMeta(meta: any, userState?: any): any {
*/
export async function buildTaskSummary(c: any, authSessionId?: string | null): Promise<any> {
const record = await ensureWorkspaceSeeded(c);
const repositoryMetadata = await getRepositoryMetadata(c);
const sessions = await listSessionMetaRows(c);
await maybeScheduleWorkspaceRefreshes(c, record, sessions);
const userTaskState = await getUserTaskState(c, authSessionId);
const taskStatus = computeWorkspaceTaskStatus(record, sessions);
const activeSessionId =
userTaskState.activeSessionId && sessions.some((meta) => meta.sessionId === userTaskState.activeSessionId) ? userTaskState.activeSessionId : null;
return {
id: c.state.taskId,
repoId: c.state.repoId,
title: record.title ?? "New Task",
status: taskStatus ?? "new",
repoName: repoLabelFromRemote(c.state.repoRemote),
status: taskStatus,
repoName: repoLabelFromRemote(repositoryMetadata.remoteUrl),
updatedAtMs: record.updatedAt,
branch: record.branchName,
pullRequest: await readPullRequestSummary(c, record.branchName),
pullRequest: record.pullRequest ?? null,
activeSessionId,
sessionsSummary: sessions.map((meta) => buildSessionSummary(meta, userTaskState.bySessionId.get(meta.sessionId))),
};
}
@ -885,10 +900,6 @@ export async function buildTaskDetail(c: any, authSessionId?: string | null): Pr
return {
...summary,
task: record.task,
runtimeStatus: summary.status,
diffStat: record.diffStat ?? null,
prUrl: record.prUrl ?? null,
reviewStatus: record.reviewStatus ?? null,
fileChanges: gitState.fileChanges,
diffs: gitState.diffs,
fileTree: gitState.fileTree,
@ -959,8 +970,14 @@ export async function getSessionDetail(c: any, sessionId: string, authSessionId?
* - Broadcast full detail/session payloads down to direct task subscribers.
*/
export async function broadcastTaskUpdate(c: any, options?: { sessionId?: string }): Promise<void> {
const repository = await getOrCreateRepository(c, c.state.organizationId, c.state.repoId, c.state.repoRemote);
await repository.applyTaskSummaryUpdate({ taskSummary: await buildTaskSummary(c) });
const repository = await getOrCreateRepository(c, c.state.organizationId, c.state.repoId);
await expectQueueResponse<{ ok: true }>(
await repository.send(
repositoryWorkflowQueueName("repository.command.applyTaskSummaryUpdate"),
{ taskSummary: await buildTaskSummary(c) },
{ wait: true, timeout: 10_000 },
),
);
c.broadcast("taskUpdated", {
type: "taskUpdated",
detail: await buildTaskDetail(c),
@ -1010,6 +1027,19 @@ export async function renameWorkspaceTask(c: any, value: string): Promise<void>
await broadcastTaskUpdate(c);
}
export async function syncTaskPullRequest(c: any, pullRequest: any): Promise<void> {
const now = pullRequest?.updatedAtMs ?? Date.now();
await c.db
.update(taskTable)
.set({
pullRequestJson: pullRequest ? JSON.stringify(pullRequest) : null,
updatedAt: now,
})
.where(eq(taskTable.id, 1))
.run();
await broadcastTaskUpdate(c);
}
export async function createWorkspaceSession(c: any, model?: string, authSessionId?: string): Promise<{ sessionId: string }> {
const sessionId = `session-${randomUUID()}`;
const record = await ensureWorkspaceSeeded(c);
@ -1055,9 +1085,10 @@ export async function ensureWorkspaceSession(c: any, sessionId: string, model?:
const runtime = await getTaskSandboxRuntime(c, record);
await ensureSandboxRepo(c, runtime.sandbox, record);
const resolvedModel = model ?? meta.model ?? (await resolveDefaultModel(c, authSessionId));
const resolvedAgent = await resolveSandboxAgentForModel(c, resolvedModel);
await runtime.sandbox.createSession({
id: meta.sandboxSessionId ?? sessionId,
agent: agentTypeForModel(resolvedModel),
agent: resolvedAgent,
model: resolvedModel,
sessionInit: {
cwd: runtime.cwd,
@ -1113,6 +1144,17 @@ export async function renameWorkspaceSession(c: any, sessionId: string, title: s
await broadcastTaskUpdate(c, { sessionId });
}
export async function selectWorkspaceSession(c: any, sessionId: string, authSessionId?: string): Promise<void> {
const meta = await readSessionMeta(c, sessionId);
if (!meta || meta.closed) {
return;
}
await upsertUserTaskState(c, authSessionId, sessionId, {
activeSessionId: sessionId,
});
await broadcastTaskUpdate(c, { sessionId });
}
export async function setWorkspaceSessionUnread(c: any, sessionId: string, unread: boolean, authSessionId?: string): Promise<void> {
await upsertUserTaskState(c, authSessionId, sessionId, {
unread,
@ -1129,7 +1171,7 @@ export async function updateWorkspaceDraft(c: any, sessionId: string, text: stri
await broadcastTaskUpdate(c, { sessionId });
}
export async function changeWorkspaceModel(c: any, sessionId: string, model: string): Promise<void> {
export async function changeWorkspaceModel(c: any, sessionId: string, model: string, _authSessionId?: string): Promise<void> {
const meta = await readSessionMeta(c, sessionId);
if (!meta || meta.closed) {
return;
@ -1295,6 +1337,13 @@ export async function closeWorkspaceSession(c: any, sessionId: string, authSessi
closed: 1,
thinkingSinceMs: null,
});
const remainingSessions = sessions.filter((candidate) => candidate.sessionId !== sessionId && candidate.closed !== true);
const userTaskState = await getUserTaskState(c, authSessionId);
if (userTaskState.activeSessionId === sessionId && remainingSessions[0]) {
await upsertUserTaskState(c, authSessionId, remainingSessions[0].sessionId, {
activeSessionId: remainingSessions[0].sessionId,
});
}
await deleteUserTaskState(c, authSessionId, sessionId);
await broadcastTaskUpdate(c);
}
@ -1316,19 +1365,30 @@ export async function publishWorkspacePr(c: any): Promise<void> {
if (!record.branchName) {
throw new Error("cannot publish PR without a branch");
}
const repository = await getOrCreateRepository(c, c.state.organizationId, c.state.repoId, c.state.repoRemote);
const metadata = await repository.getRepositoryMetadata({});
const repoFullName = metadata.fullName ?? githubRepoFullNameFromRemote(c.state.repoRemote);
const metadata = await getRepositoryMetadata(c);
const repoFullName = metadata.fullName ?? githubRepoFullNameFromRemote(metadata.remoteUrl);
if (!repoFullName) {
throw new Error(`Unable to resolve GitHub repository for ${c.state.repoRemote}`);
throw new Error(`Unable to resolve GitHub repository for ${metadata.remoteUrl}`);
}
const { driver } = getActorRuntimeContext();
const auth = await resolveOrganizationGithubAuth(c, c.state.organizationId);
await driver.github.createPr(repoFullName, record.branchName, record.title ?? c.state.task, undefined, {
const created = await driver.github.createPr(repoFullName, record.branchName, record.title ?? record.task, undefined, {
githubToken: auth?.githubToken ?? null,
baseBranch: metadata.defaultBranch ?? undefined,
});
await broadcastTaskUpdate(c);
await syncTaskPullRequest(c, {
number: created.number,
title: record.title ?? record.task,
body: null,
state: "open",
url: created.url,
headRefName: record.branchName,
baseRefName: metadata.defaultBranch ?? "main",
authorLogin: null,
isDraft: false,
merged: false,
updatedAtMs: Date.now(),
});
}
export async function revertWorkspaceFile(c: any, path: string): Promise<void> {

View file

@ -0,0 +1,47 @@
import { asc, count as sqlCount, desc } from "drizzle-orm";
import { applyJoinToRow, applyJoinToRows, buildWhere, columnFor, tableFor } from "../query-helpers.js";
export const betterAuthActions = {
// Better Auth adapter action — called by the Better Auth adapter in better-auth.ts.
// Schema and behavior are constrained by Better Auth.
async betterAuthFindOneRecord(c, input: { model: string; where: any[]; join?: any }) {
const table = tableFor(input.model);
const predicate = buildWhere(table, input.where);
const row = predicate ? await c.db.select().from(table).where(predicate).get() : await c.db.select().from(table).get();
return await applyJoinToRow(c, input.model, row ?? null, input.join);
},
// Better Auth adapter action — called by the Better Auth adapter in better-auth.ts.
// Schema and behavior are constrained by Better Auth.
async betterAuthFindManyRecords(c, input: { model: string; where?: any[]; limit?: number; offset?: number; sortBy?: any; join?: any }) {
const table = tableFor(input.model);
const predicate = buildWhere(table, input.where);
let query: any = c.db.select().from(table);
if (predicate) {
query = query.where(predicate);
}
if (input.sortBy?.field) {
const column = columnFor(input.model, table, input.sortBy.field);
query = query.orderBy(input.sortBy.direction === "asc" ? asc(column) : desc(column));
}
if (typeof input.limit === "number") {
query = query.limit(input.limit);
}
if (typeof input.offset === "number") {
query = query.offset(input.offset);
}
const rows = await query.all();
return await applyJoinToRows(c, input.model, rows, input.join);
},
// Better Auth adapter action — called by the Better Auth adapter in better-auth.ts.
// Schema and behavior are constrained by Better Auth.
async betterAuthCountRecords(c, input: { model: string; where?: any[] }) {
const table = tableFor(input.model);
const predicate = buildWhere(table, input.where);
const row = predicate
? await c.db.select({ value: sqlCount() }).from(table).where(predicate).get()
: await c.db.select({ value: sqlCount() }).from(table).get();
return row?.value ?? 0;
},
};

View file

@ -0,0 +1,44 @@
import { eq } from "drizzle-orm";
import { authAccounts, authSessions, authUsers, sessionState, userProfiles, userTaskState } from "../db/schema.js";
import { materializeRow } from "../query-helpers.js";
export const userActions = {
// Custom Foundry action — not part of Better Auth.
async getAppAuthState(c, input: { sessionId: string }) {
const session = await c.db.select().from(authSessions).where(eq(authSessions.id, input.sessionId)).get();
if (!session) {
return null;
}
const [user, profile, currentSessionState, accounts] = await Promise.all([
c.db.select().from(authUsers).where(eq(authUsers.authUserId, session.userId)).get(),
c.db.select().from(userProfiles).where(eq(userProfiles.userId, session.userId)).get(),
c.db.select().from(sessionState).where(eq(sessionState.sessionId, input.sessionId)).get(),
c.db.select().from(authAccounts).where(eq(authAccounts.userId, session.userId)).all(),
]);
return {
session,
user: materializeRow("user", user),
profile: profile ?? null,
sessionState: currentSessionState ?? null,
accounts,
};
},
// Custom Foundry action — not part of Better Auth.
async getTaskState(c, input: { taskId: string }) {
const rows = await c.db.select().from(userTaskState).where(eq(userTaskState.taskId, input.taskId)).all();
const activeSessionId = rows.find((row) => typeof row.activeSessionId === "string" && row.activeSessionId.length > 0)?.activeSessionId ?? null;
return {
taskId: input.taskId,
activeSessionId,
sessions: rows.map((row) => ({
sessionId: row.sessionId,
unread: row.unread === 1,
draftText: row.draftText,
draftAttachmentsJson: row.draftAttachmentsJson,
draftUpdatedAt: row.draftUpdatedAt ?? null,
updatedAt: row.updatedAt,
})),
};
},
};

View file

@ -23,15 +23,19 @@ export default {
journal,
migrations: {
m0000: `CREATE TABLE \`user\` (
\`id\` text PRIMARY KEY NOT NULL,
\`id\` integer PRIMARY KEY NOT NULL,
\`auth_user_id\` text NOT NULL,
\`name\` text NOT NULL,
\`email\` text NOT NULL,
\`email_verified\` integer NOT NULL,
\`image\` text,
\`created_at\` integer NOT NULL,
\`updated_at\` integer NOT NULL
\`updated_at\` integer NOT NULL,
CONSTRAINT \`user_singleton_id_check\` CHECK(\`id\` = 1)
);
--> statement-breakpoint
CREATE UNIQUE INDEX \`user_auth_user_id_idx\` ON \`user\` (\`auth_user_id\`);
--> statement-breakpoint
CREATE TABLE \`session\` (
\`id\` text PRIMARY KEY NOT NULL,
\`token\` text NOT NULL,
@ -69,7 +73,7 @@ CREATE TABLE \`user_profiles\` (
\`github_account_id\` text,
\`github_login\` text,
\`role_label\` text NOT NULL,
\`default_model\` text DEFAULT 'claude-sonnet-4' NOT NULL,
\`default_model\` text DEFAULT 'gpt-5.3-codex' NOT NULL,
\`eligible_organization_ids_json\` text NOT NULL,
\`starter_repo_status\` text NOT NULL,
\`starter_repo_starred_at\` integer,

View file

@ -1,16 +1,25 @@
import { check, integer, primaryKey, sqliteTable, text, uniqueIndex } from "drizzle-orm/sqlite-core";
import { sql } from "drizzle-orm";
import { DEFAULT_WORKSPACE_MODEL_ID } from "@sandbox-agent/foundry-shared";
/** Better Auth core model — schema defined at https://better-auth.com/docs/concepts/database */
export const authUsers = sqliteTable("user", {
id: text("id").notNull().primaryKey(),
name: text("name").notNull(),
email: text("email").notNull(),
emailVerified: integer("email_verified").notNull(),
image: text("image"),
createdAt: integer("created_at").notNull(),
updatedAt: integer("updated_at").notNull(),
});
export const authUsers = sqliteTable(
"user",
{
id: integer("id").primaryKey(),
authUserId: text("auth_user_id").notNull(),
name: text("name").notNull(),
email: text("email").notNull(),
emailVerified: integer("email_verified").notNull(),
image: text("image"),
createdAt: integer("created_at").notNull(),
updatedAt: integer("updated_at").notNull(),
},
(table) => ({
authUserIdIdx: uniqueIndex("user_auth_user_id_idx").on(table.authUserId),
singletonCheck: check("user_singleton_id_check", sql`${table.id} = 1`),
}),
);
/** Better Auth core model — schema defined at https://better-auth.com/docs/concepts/database */
export const authSessions = sqliteTable(
@ -62,7 +71,7 @@ export const userProfiles = sqliteTable(
githubAccountId: text("github_account_id"),
githubLogin: text("github_login"),
roleLabel: text("role_label").notNull(),
defaultModel: text("default_model").notNull().default("claude-sonnet-4"),
defaultModel: text("default_model").notNull().default(DEFAULT_WORKSPACE_MODEL_ID),
eligibleOrganizationIdsJson: text("eligible_organization_ids_json").notNull(),
starterRepoStatus: text("starter_repo_status").notNull(),
starterRepoStarredAt: integer("starter_repo_starred_at"),

View file

@ -1,158 +1,13 @@
import { and, asc, count as sqlCount, desc, eq, gt, gte, inArray, isNotNull, isNull, like, lt, lte, ne, notInArray, or } from "drizzle-orm";
import { actor } from "rivetkit";
import { actor, queue } from "rivetkit";
import { workflow } from "rivetkit/workflow";
import { userDb } from "./db/db.js";
import { authAccounts, authSessions, authUsers, sessionState, userProfiles, userTaskState } from "./db/schema.js";
const tables = {
user: authUsers,
session: authSessions,
account: authAccounts,
userProfiles,
sessionState,
userTaskState,
} as const;
function tableFor(model: string) {
const table = tables[model as keyof typeof tables];
if (!table) {
throw new Error(`Unsupported user model: ${model}`);
}
return table as any;
}
function columnFor(table: any, field: string) {
const column = table[field];
if (!column) {
throw new Error(`Unsupported user field: ${field}`);
}
return column;
}
function normalizeValue(value: unknown): unknown {
if (value instanceof Date) {
return value.getTime();
}
if (Array.isArray(value)) {
return value.map((entry) => normalizeValue(entry));
}
return value;
}
function clauseToExpr(table: any, clause: any) {
const column = columnFor(table, clause.field);
const value = normalizeValue(clause.value);
switch (clause.operator) {
case "ne":
return value === null ? isNotNull(column) : ne(column, value as any);
case "lt":
return lt(column, value as any);
case "lte":
return lte(column, value as any);
case "gt":
return gt(column, value as any);
case "gte":
return gte(column, value as any);
case "in":
return inArray(column, Array.isArray(value) ? (value as any[]) : [value as any]);
case "not_in":
return notInArray(column, Array.isArray(value) ? (value as any[]) : [value as any]);
case "contains":
return like(column, `%${String(value ?? "")}%`);
case "starts_with":
return like(column, `${String(value ?? "")}%`);
case "ends_with":
return like(column, `%${String(value ?? "")}`);
case "eq":
default:
return value === null ? isNull(column) : eq(column, value as any);
}
}
function buildWhere(table: any, where: any[] | undefined) {
if (!where || where.length === 0) {
return undefined;
}
let expr = clauseToExpr(table, where[0]);
for (const clause of where.slice(1)) {
const next = clauseToExpr(table, clause);
expr = clause.connector === "OR" ? or(expr, next) : and(expr, next);
}
return expr;
}
function applyJoinToRow(c: any, model: string, row: any, join: any) {
if (!row || !join) {
return row;
}
if (model === "session" && join.user) {
return c.db
.select()
.from(authUsers)
.where(eq(authUsers.id, row.userId))
.get()
.then((user: any) => ({ ...row, user: user ?? null }));
}
if (model === "account" && join.user) {
return c.db
.select()
.from(authUsers)
.where(eq(authUsers.id, row.userId))
.get()
.then((user: any) => ({ ...row, user: user ?? null }));
}
if (model === "user" && join.account) {
return c.db
.select()
.from(authAccounts)
.where(eq(authAccounts.userId, row.id))
.all()
.then((accounts: any[]) => ({ ...row, account: accounts }));
}
return Promise.resolve(row);
}
async function applyJoinToRows(c: any, model: string, rows: any[], join: any) {
if (!join || rows.length === 0) {
return rows;
}
if (model === "session" && join.user) {
const userIds = [...new Set(rows.map((row) => row.userId).filter(Boolean))];
const users = userIds.length > 0 ? await c.db.select().from(authUsers).where(inArray(authUsers.id, userIds)).all() : [];
const userMap = new Map(users.map((user: any) => [user.id, user]));
return rows.map((row) => ({ ...row, user: userMap.get(row.userId) ?? null }));
}
if (model === "account" && join.user) {
const userIds = [...new Set(rows.map((row) => row.userId).filter(Boolean))];
const users = userIds.length > 0 ? await c.db.select().from(authUsers).where(inArray(authUsers.id, userIds)).all() : [];
const userMap = new Map(users.map((user: any) => [user.id, user]));
return rows.map((row) => ({ ...row, user: userMap.get(row.userId) ?? null }));
}
if (model === "user" && join.account) {
const userIds = rows.map((row) => row.id);
const accounts = userIds.length > 0 ? await c.db.select().from(authAccounts).where(inArray(authAccounts.userId, userIds)).all() : [];
const accountsByUserId = new Map<string, any[]>();
for (const account of accounts) {
const entries = accountsByUserId.get(account.userId) ?? [];
entries.push(account);
accountsByUserId.set(account.userId, entries);
}
return rows.map((row) => ({ ...row, account: accountsByUserId.get(row.id) ?? [] }));
}
return rows;
}
import { betterAuthActions } from "./actions/better-auth.js";
import { userActions } from "./actions/user.js";
import { USER_QUEUE_NAMES, runUserWorkflow } from "./workflow.js";
export const user = actor({
db: userDb,
queues: Object.fromEntries(USER_QUEUE_NAMES.map((name) => [name, queue()])),
options: {
name: "User",
icon: "shield",
@ -162,312 +17,8 @@ export const user = actor({
userId: input.userId,
}),
actions: {
// Better Auth adapter action — called by the Better Auth adapter in better-auth.ts.
// Schema and behavior are constrained by Better Auth.
async createAuthRecord(c, input: { model: string; data: Record<string, unknown> }) {
const table = tableFor(input.model);
await c.db
.insert(table)
.values(input.data as any)
.run();
return await c.db
.select()
.from(table)
.where(eq(columnFor(table, "id"), input.data.id as any))
.get();
},
// Better Auth adapter action — called by the Better Auth adapter in better-auth.ts.
// Schema and behavior are constrained by Better Auth.
async findOneAuthRecord(c, input: { model: string; where: any[]; join?: any }) {
const table = tableFor(input.model);
const predicate = buildWhere(table, input.where);
const row = predicate ? await c.db.select().from(table).where(predicate).get() : await c.db.select().from(table).get();
return await applyJoinToRow(c, input.model, row ?? null, input.join);
},
// Better Auth adapter action — called by the Better Auth adapter in better-auth.ts.
// Schema and behavior are constrained by Better Auth.
async findManyAuthRecords(c, input: { model: string; where?: any[]; limit?: number; offset?: number; sortBy?: any; join?: any }) {
const table = tableFor(input.model);
const predicate = buildWhere(table, input.where);
let query: any = c.db.select().from(table);
if (predicate) {
query = query.where(predicate);
}
if (input.sortBy?.field) {
const column = columnFor(table, input.sortBy.field);
query = query.orderBy(input.sortBy.direction === "asc" ? asc(column) : desc(column));
}
if (typeof input.limit === "number") {
query = query.limit(input.limit);
}
if (typeof input.offset === "number") {
query = query.offset(input.offset);
}
const rows = await query.all();
return await applyJoinToRows(c, input.model, rows, input.join);
},
// Better Auth adapter action — called by the Better Auth adapter in better-auth.ts.
// Schema and behavior are constrained by Better Auth.
async updateAuthRecord(c, input: { model: string; where: any[]; update: Record<string, unknown> }) {
const table = tableFor(input.model);
const predicate = buildWhere(table, input.where);
if (!predicate) {
throw new Error("updateAuthRecord requires a where clause");
}
await c.db
.update(table)
.set(input.update as any)
.where(predicate)
.run();
return await c.db.select().from(table).where(predicate).get();
},
// Better Auth adapter action — called by the Better Auth adapter in better-auth.ts.
// Schema and behavior are constrained by Better Auth.
async updateManyAuthRecords(c, input: { model: string; where: any[]; update: Record<string, unknown> }) {
const table = tableFor(input.model);
const predicate = buildWhere(table, input.where);
if (!predicate) {
throw new Error("updateManyAuthRecords requires a where clause");
}
await c.db
.update(table)
.set(input.update as any)
.where(predicate)
.run();
const row = await c.db.select({ value: sqlCount() }).from(table).where(predicate).get();
return row?.value ?? 0;
},
// Better Auth adapter action — called by the Better Auth adapter in better-auth.ts.
// Schema and behavior are constrained by Better Auth.
async deleteAuthRecord(c, input: { model: string; where: any[] }) {
const table = tableFor(input.model);
const predicate = buildWhere(table, input.where);
if (!predicate) {
throw new Error("deleteAuthRecord requires a where clause");
}
await c.db.delete(table).where(predicate).run();
},
// Better Auth adapter action — called by the Better Auth adapter in better-auth.ts.
// Schema and behavior are constrained by Better Auth.
async deleteManyAuthRecords(c, input: { model: string; where: any[] }) {
const table = tableFor(input.model);
const predicate = buildWhere(table, input.where);
if (!predicate) {
throw new Error("deleteManyAuthRecords requires a where clause");
}
const rows = await c.db.select().from(table).where(predicate).all();
await c.db.delete(table).where(predicate).run();
return rows.length;
},
// Better Auth adapter action — called by the Better Auth adapter in better-auth.ts.
// Schema and behavior are constrained by Better Auth.
async countAuthRecords(c, input: { model: string; where?: any[] }) {
const table = tableFor(input.model);
const predicate = buildWhere(table, input.where);
const row = predicate
? await c.db.select({ value: sqlCount() }).from(table).where(predicate).get()
: await c.db.select({ value: sqlCount() }).from(table).get();
return row?.value ?? 0;
},
// Custom Foundry action — not part of Better Auth.
async getAppAuthState(c, input: { sessionId: string }) {
const session = await c.db.select().from(authSessions).where(eq(authSessions.id, input.sessionId)).get();
if (!session) {
return null;
}
const [user, profile, currentSessionState, accounts] = await Promise.all([
c.db.select().from(authUsers).where(eq(authUsers.id, session.userId)).get(),
c.db.select().from(userProfiles).where(eq(userProfiles.userId, session.userId)).get(),
c.db.select().from(sessionState).where(eq(sessionState.sessionId, input.sessionId)).get(),
c.db.select().from(authAccounts).where(eq(authAccounts.userId, session.userId)).all(),
]);
return {
session,
user,
profile: profile ?? null,
sessionState: currentSessionState ?? null,
accounts,
};
},
// Custom Foundry action — not part of Better Auth.
async upsertUserProfile(
c,
input: {
userId: string;
patch: {
githubAccountId?: string | null;
githubLogin?: string | null;
roleLabel?: string;
defaultModel?: string;
eligibleOrganizationIdsJson?: string;
starterRepoStatus?: string;
starterRepoStarredAt?: number | null;
starterRepoSkippedAt?: number | null;
};
},
) {
const now = Date.now();
await c.db
.insert(userProfiles)
.values({
id: 1,
userId: input.userId,
githubAccountId: input.patch.githubAccountId ?? null,
githubLogin: input.patch.githubLogin ?? null,
roleLabel: input.patch.roleLabel ?? "GitHub user",
defaultModel: input.patch.defaultModel ?? "claude-sonnet-4",
eligibleOrganizationIdsJson: input.patch.eligibleOrganizationIdsJson ?? "[]",
starterRepoStatus: input.patch.starterRepoStatus ?? "pending",
starterRepoStarredAt: input.patch.starterRepoStarredAt ?? null,
starterRepoSkippedAt: input.patch.starterRepoSkippedAt ?? null,
createdAt: now,
updatedAt: now,
})
.onConflictDoUpdate({
target: userProfiles.userId,
set: {
...(input.patch.githubAccountId !== undefined ? { githubAccountId: input.patch.githubAccountId } : {}),
...(input.patch.githubLogin !== undefined ? { githubLogin: input.patch.githubLogin } : {}),
...(input.patch.roleLabel !== undefined ? { roleLabel: input.patch.roleLabel } : {}),
...(input.patch.defaultModel !== undefined ? { defaultModel: input.patch.defaultModel } : {}),
...(input.patch.eligibleOrganizationIdsJson !== undefined ? { eligibleOrganizationIdsJson: input.patch.eligibleOrganizationIdsJson } : {}),
...(input.patch.starterRepoStatus !== undefined ? { starterRepoStatus: input.patch.starterRepoStatus } : {}),
...(input.patch.starterRepoStarredAt !== undefined ? { starterRepoStarredAt: input.patch.starterRepoStarredAt } : {}),
...(input.patch.starterRepoSkippedAt !== undefined ? { starterRepoSkippedAt: input.patch.starterRepoSkippedAt } : {}),
updatedAt: now,
},
})
.run();
return await c.db.select().from(userProfiles).where(eq(userProfiles.userId, input.userId)).get();
},
// Custom Foundry action — not part of Better Auth.
async upsertSessionState(c, input: { sessionId: string; activeOrganizationId: string | null }) {
const now = Date.now();
await c.db
.insert(sessionState)
.values({
sessionId: input.sessionId,
activeOrganizationId: input.activeOrganizationId,
createdAt: now,
updatedAt: now,
})
.onConflictDoUpdate({
target: sessionState.sessionId,
set: {
activeOrganizationId: input.activeOrganizationId,
updatedAt: now,
},
})
.run();
return await c.db.select().from(sessionState).where(eq(sessionState.sessionId, input.sessionId)).get();
},
// Custom Foundry action — not part of Better Auth.
async getTaskState(c, input: { taskId: string }) {
const rows = await c.db.select().from(userTaskState).where(eq(userTaskState.taskId, input.taskId)).all();
const activeSessionId = rows.find((row) => typeof row.activeSessionId === "string" && row.activeSessionId.length > 0)?.activeSessionId ?? null;
return {
taskId: input.taskId,
activeSessionId,
sessions: rows.map((row) => ({
sessionId: row.sessionId,
unread: row.unread === 1,
draftText: row.draftText,
draftAttachmentsJson: row.draftAttachmentsJson,
draftUpdatedAt: row.draftUpdatedAt ?? null,
updatedAt: row.updatedAt,
})),
};
},
// Custom Foundry action — not part of Better Auth.
async upsertTaskState(
c,
input: {
taskId: string;
sessionId: string;
patch: {
activeSessionId?: string | null;
unread?: boolean;
draftText?: string;
draftAttachmentsJson?: string;
draftUpdatedAt?: number | null;
};
},
) {
const now = Date.now();
const existing = await c.db
.select()
.from(userTaskState)
.where(and(eq(userTaskState.taskId, input.taskId), eq(userTaskState.sessionId, input.sessionId)))
.get();
if (input.patch.activeSessionId !== undefined) {
await c.db
.update(userTaskState)
.set({
activeSessionId: input.patch.activeSessionId,
updatedAt: now,
})
.where(eq(userTaskState.taskId, input.taskId))
.run();
}
await c.db
.insert(userTaskState)
.values({
taskId: input.taskId,
sessionId: input.sessionId,
activeSessionId: input.patch.activeSessionId ?? existing?.activeSessionId ?? null,
unread: input.patch.unread !== undefined ? (input.patch.unread ? 1 : 0) : (existing?.unread ?? 0),
draftText: input.patch.draftText ?? existing?.draftText ?? "",
draftAttachmentsJson: input.patch.draftAttachmentsJson ?? existing?.draftAttachmentsJson ?? "[]",
draftUpdatedAt: input.patch.draftUpdatedAt === undefined ? (existing?.draftUpdatedAt ?? null) : input.patch.draftUpdatedAt,
updatedAt: now,
})
.onConflictDoUpdate({
target: [userTaskState.taskId, userTaskState.sessionId],
set: {
...(input.patch.activeSessionId !== undefined ? { activeSessionId: input.patch.activeSessionId } : {}),
...(input.patch.unread !== undefined ? { unread: input.patch.unread ? 1 : 0 } : {}),
...(input.patch.draftText !== undefined ? { draftText: input.patch.draftText } : {}),
...(input.patch.draftAttachmentsJson !== undefined ? { draftAttachmentsJson: input.patch.draftAttachmentsJson } : {}),
...(input.patch.draftUpdatedAt !== undefined ? { draftUpdatedAt: input.patch.draftUpdatedAt } : {}),
updatedAt: now,
},
})
.run();
return await c.db
.select()
.from(userTaskState)
.where(and(eq(userTaskState.taskId, input.taskId), eq(userTaskState.sessionId, input.sessionId)))
.get();
},
// Custom Foundry action — not part of Better Auth.
async deleteTaskState(c, input: { taskId: string; sessionId?: string }) {
if (input.sessionId) {
await c.db
.delete(userTaskState)
.where(and(eq(userTaskState.taskId, input.taskId), eq(userTaskState.sessionId, input.sessionId)))
.run();
return;
}
await c.db.delete(userTaskState).where(eq(userTaskState.taskId, input.taskId)).run();
},
...betterAuthActions,
...userActions,
},
run: workflow(runUserWorkflow),
});

View file

@ -0,0 +1,197 @@
import { and, eq, inArray, isNotNull, isNull, like, lt, lte, gt, gte, ne, notInArray, or } from "drizzle-orm";
import { authAccounts, authSessions, authUsers, sessionState, userProfiles, userTaskState } from "./db/schema.js";
export const userTables = {
user: authUsers,
session: authSessions,
account: authAccounts,
userProfiles,
sessionState,
userTaskState,
} as const;
export function tableFor(model: string) {
const table = userTables[model as keyof typeof userTables];
if (!table) {
throw new Error(`Unsupported user model: ${model}`);
}
return table as any;
}
function dbFieldFor(model: string, field: string): string {
if (model === "user" && field === "id") {
return "authUserId";
}
return field;
}
export function materializeRow(model: string, row: any) {
if (!row || model !== "user") {
return row;
}
const { id: _singletonId, authUserId, ...rest } = row;
return {
id: authUserId,
...rest,
};
}
export function persistInput(model: string, data: Record<string, unknown>) {
if (model !== "user") {
return data;
}
const { id, ...rest } = data;
return {
id: 1,
authUserId: id,
...rest,
};
}
export function persistPatch(model: string, data: Record<string, unknown>) {
if (model !== "user") {
return data;
}
const { id, ...rest } = data;
return {
...(id !== undefined ? { authUserId: id } : {}),
...rest,
};
}
export function columnFor(model: string, table: any, field: string) {
const column = table[dbFieldFor(model, field)];
if (!column) {
throw new Error(`Unsupported user field: ${model}.${field}`);
}
return column;
}
export function normalizeValue(value: unknown): unknown {
if (value instanceof Date) {
return value.getTime();
}
if (Array.isArray(value)) {
return value.map((entry) => normalizeValue(entry));
}
return value;
}
export function clauseToExpr(table: any, clause: any) {
const model = table === authUsers ? "user" : table === authSessions ? "session" : table === authAccounts ? "account" : "";
const column = columnFor(model, table, clause.field);
const value = normalizeValue(clause.value);
switch (clause.operator) {
case "ne":
return value === null ? isNotNull(column) : ne(column, value as any);
case "lt":
return lt(column, value as any);
case "lte":
return lte(column, value as any);
case "gt":
return gt(column, value as any);
case "gte":
return gte(column, value as any);
case "in":
return inArray(column, Array.isArray(value) ? (value as any[]) : [value as any]);
case "not_in":
return notInArray(column, Array.isArray(value) ? (value as any[]) : [value as any]);
case "contains":
return like(column, `%${String(value ?? "")}%`);
case "starts_with":
return like(column, `${String(value ?? "")}%`);
case "ends_with":
return like(column, `%${String(value ?? "")}`);
case "eq":
default:
return value === null ? isNull(column) : eq(column, value as any);
}
}
export function buildWhere(table: any, where: any[] | undefined) {
if (!where || where.length === 0) {
return undefined;
}
let expr = clauseToExpr(table, where[0]);
for (const clause of where.slice(1)) {
const next = clauseToExpr(table, clause);
expr = clause.connector === "OR" ? or(expr, next) : and(expr, next);
}
return expr;
}
export function applyJoinToRow(c: any, model: string, row: any, join: any) {
const materialized = materializeRow(model, row);
if (!materialized || !join) {
return materialized;
}
if (model === "session" && join.user) {
return c.db
.select()
.from(authUsers)
.where(eq(authUsers.authUserId, materialized.userId))
.get()
.then((user: any) => ({ ...materialized, user: materializeRow("user", user) ?? null }));
}
if (model === "account" && join.user) {
return c.db
.select()
.from(authUsers)
.where(eq(authUsers.authUserId, materialized.userId))
.get()
.then((user: any) => ({ ...materialized, user: materializeRow("user", user) ?? null }));
}
if (model === "user" && join.account) {
return c.db
.select()
.from(authAccounts)
.where(eq(authAccounts.userId, materialized.id))
.all()
.then((accounts: any[]) => ({ ...materialized, account: accounts }));
}
return Promise.resolve(materialized);
}
export async function applyJoinToRows(c: any, model: string, rows: any[], join: any) {
if (!join || rows.length === 0) {
return rows.map((row) => materializeRow(model, row));
}
if (model === "session" && join.user) {
const userIds = [...new Set(rows.map((row) => row.userId).filter(Boolean))];
const users = userIds.length > 0 ? await c.db.select().from(authUsers).where(inArray(authUsers.authUserId, userIds)).all() : [];
const userMap = new Map(users.map((user: any) => [user.authUserId, materializeRow("user", user)]));
return rows.map((row) => ({ ...row, user: userMap.get(row.userId) ?? null }));
}
if (model === "account" && join.user) {
const userIds = [...new Set(rows.map((row) => row.userId).filter(Boolean))];
const users = userIds.length > 0 ? await c.db.select().from(authUsers).where(inArray(authUsers.authUserId, userIds)).all() : [];
const userMap = new Map(users.map((user: any) => [user.authUserId, materializeRow("user", user)]));
return rows.map((row) => ({ ...row, user: userMap.get(row.userId) ?? null }));
}
if (model === "user" && join.account) {
const materializedRows = rows.map((row) => materializeRow("user", row));
const userIds = materializedRows.map((row) => row.id);
const accounts = userIds.length > 0 ? await c.db.select().from(authAccounts).where(inArray(authAccounts.userId, userIds)).all() : [];
const accountsByUserId = new Map<string, any[]>();
for (const account of accounts) {
const entries = accountsByUserId.get(account.userId) ?? [];
entries.push(account);
accountsByUserId.set(account.userId, entries);
}
return materializedRows.map((row) => ({ ...row, account: accountsByUserId.get(row.id) ?? [] }));
}
return rows.map((row) => materializeRow(model, row));
}

View file

@ -0,0 +1,281 @@
import { eq, count as sqlCount, and } from "drizzle-orm";
import { Loop } from "rivetkit/workflow";
import { DEFAULT_WORKSPACE_MODEL_ID } from "@sandbox-agent/foundry-shared";
import { logActorWarning, resolveErrorMessage } from "../logging.js";
import { authUsers, sessionState, userProfiles, userTaskState } from "./db/schema.js";
import { buildWhere, columnFor, materializeRow, persistInput, persistPatch, tableFor } from "./query-helpers.js";
export const USER_QUEUE_NAMES = [
"user.command.auth.create",
"user.command.auth.update",
"user.command.auth.update_many",
"user.command.auth.delete",
"user.command.auth.delete_many",
"user.command.profile.upsert",
"user.command.session_state.upsert",
"user.command.task_state.upsert",
"user.command.task_state.delete",
] as const;
export type UserQueueName = (typeof USER_QUEUE_NAMES)[number];
export function userWorkflowQueueName(name: UserQueueName): UserQueueName {
return name;
}
async function createAuthRecordMutation(c: any, input: { model: string; data: Record<string, unknown> }) {
const table = tableFor(input.model);
const persisted = persistInput(input.model, input.data);
await c.db.insert(table).values(persisted as any).run();
const row = await c.db.select().from(table).where(eq(columnFor(input.model, table, "id"), input.data.id as any)).get();
return materializeRow(input.model, row);
}
async function updateAuthRecordMutation(c: any, input: { model: string; where: any[]; update: Record<string, unknown> }) {
const table = tableFor(input.model);
const predicate = buildWhere(table, input.where);
if (!predicate) {
throw new Error("updateAuthRecord requires a where clause");
}
await c.db.update(table).set(persistPatch(input.model, input.update) as any).where(predicate).run();
return materializeRow(input.model, await c.db.select().from(table).where(predicate).get());
}
async function updateManyAuthRecordsMutation(c: any, input: { model: string; where: any[]; update: Record<string, unknown> }) {
const table = tableFor(input.model);
const predicate = buildWhere(table, input.where);
if (!predicate) {
throw new Error("updateManyAuthRecords requires a where clause");
}
await c.db.update(table).set(persistPatch(input.model, input.update) as any).where(predicate).run();
const row = await c.db.select({ value: sqlCount() }).from(table).where(predicate).get();
return row?.value ?? 0;
}
async function deleteAuthRecordMutation(c: any, input: { model: string; where: any[] }) {
const table = tableFor(input.model);
const predicate = buildWhere(table, input.where);
if (!predicate) {
throw new Error("deleteAuthRecord requires a where clause");
}
await c.db.delete(table).where(predicate).run();
}
async function deleteManyAuthRecordsMutation(c: any, input: { model: string; where: any[] }) {
const table = tableFor(input.model);
const predicate = buildWhere(table, input.where);
if (!predicate) {
throw new Error("deleteManyAuthRecords requires a where clause");
}
const rows = await c.db.select().from(table).where(predicate).all();
await c.db.delete(table).where(predicate).run();
return rows.length;
}
async function upsertUserProfileMutation(
c: any,
input: {
userId: string;
patch: {
githubAccountId?: string | null;
githubLogin?: string | null;
roleLabel?: string;
defaultModel?: string;
eligibleOrganizationIdsJson?: string;
starterRepoStatus?: string;
starterRepoStarredAt?: number | null;
starterRepoSkippedAt?: number | null;
};
},
) {
const now = Date.now();
await c.db
.insert(userProfiles)
.values({
id: 1,
userId: input.userId,
githubAccountId: input.patch.githubAccountId ?? null,
githubLogin: input.patch.githubLogin ?? null,
roleLabel: input.patch.roleLabel ?? "GitHub user",
defaultModel: input.patch.defaultModel ?? DEFAULT_WORKSPACE_MODEL_ID,
eligibleOrganizationIdsJson: input.patch.eligibleOrganizationIdsJson ?? "[]",
starterRepoStatus: input.patch.starterRepoStatus ?? "pending",
starterRepoStarredAt: input.patch.starterRepoStarredAt ?? null,
starterRepoSkippedAt: input.patch.starterRepoSkippedAt ?? null,
createdAt: now,
updatedAt: now,
})
.onConflictDoUpdate({
target: userProfiles.userId,
set: {
...(input.patch.githubAccountId !== undefined ? { githubAccountId: input.patch.githubAccountId } : {}),
...(input.patch.githubLogin !== undefined ? { githubLogin: input.patch.githubLogin } : {}),
...(input.patch.roleLabel !== undefined ? { roleLabel: input.patch.roleLabel } : {}),
...(input.patch.defaultModel !== undefined ? { defaultModel: input.patch.defaultModel } : {}),
...(input.patch.eligibleOrganizationIdsJson !== undefined ? { eligibleOrganizationIdsJson: input.patch.eligibleOrganizationIdsJson } : {}),
...(input.patch.starterRepoStatus !== undefined ? { starterRepoStatus: input.patch.starterRepoStatus } : {}),
...(input.patch.starterRepoStarredAt !== undefined ? { starterRepoStarredAt: input.patch.starterRepoStarredAt } : {}),
...(input.patch.starterRepoSkippedAt !== undefined ? { starterRepoSkippedAt: input.patch.starterRepoSkippedAt } : {}),
updatedAt: now,
},
})
.run();
return await c.db.select().from(userProfiles).where(eq(userProfiles.userId, input.userId)).get();
}
async function upsertSessionStateMutation(c: any, input: { sessionId: string; activeOrganizationId: string | null }) {
const now = Date.now();
await c.db
.insert(sessionState)
.values({
sessionId: input.sessionId,
activeOrganizationId: input.activeOrganizationId,
createdAt: now,
updatedAt: now,
})
.onConflictDoUpdate({
target: sessionState.sessionId,
set: {
activeOrganizationId: input.activeOrganizationId,
updatedAt: now,
},
})
.run();
return await c.db.select().from(sessionState).where(eq(sessionState.sessionId, input.sessionId)).get();
}
async function upsertTaskStateMutation(
c: any,
input: {
taskId: string;
sessionId: string;
patch: {
activeSessionId?: string | null;
unread?: boolean;
draftText?: string;
draftAttachmentsJson?: string;
draftUpdatedAt?: number | null;
};
},
) {
const now = Date.now();
const existing = await c.db
.select()
.from(userTaskState)
.where(and(eq(userTaskState.taskId, input.taskId), eq(userTaskState.sessionId, input.sessionId)))
.get();
if (input.patch.activeSessionId !== undefined) {
await c.db
.update(userTaskState)
.set({
activeSessionId: input.patch.activeSessionId,
updatedAt: now,
})
.where(eq(userTaskState.taskId, input.taskId))
.run();
}
await c.db
.insert(userTaskState)
.values({
taskId: input.taskId,
sessionId: input.sessionId,
activeSessionId: input.patch.activeSessionId ?? existing?.activeSessionId ?? null,
unread: input.patch.unread !== undefined ? (input.patch.unread ? 1 : 0) : (existing?.unread ?? 0),
draftText: input.patch.draftText ?? existing?.draftText ?? "",
draftAttachmentsJson: input.patch.draftAttachmentsJson ?? existing?.draftAttachmentsJson ?? "[]",
draftUpdatedAt: input.patch.draftUpdatedAt === undefined ? (existing?.draftUpdatedAt ?? null) : input.patch.draftUpdatedAt,
updatedAt: now,
})
.onConflictDoUpdate({
target: [userTaskState.taskId, userTaskState.sessionId],
set: {
...(input.patch.activeSessionId !== undefined ? { activeSessionId: input.patch.activeSessionId } : {}),
...(input.patch.unread !== undefined ? { unread: input.patch.unread ? 1 : 0 } : {}),
...(input.patch.draftText !== undefined ? { draftText: input.patch.draftText } : {}),
...(input.patch.draftAttachmentsJson !== undefined ? { draftAttachmentsJson: input.patch.draftAttachmentsJson } : {}),
...(input.patch.draftUpdatedAt !== undefined ? { draftUpdatedAt: input.patch.draftUpdatedAt } : {}),
updatedAt: now,
},
})
.run();
return await c.db
.select()
.from(userTaskState)
.where(and(eq(userTaskState.taskId, input.taskId), eq(userTaskState.sessionId, input.sessionId)))
.get();
}
async function deleteTaskStateMutation(c: any, input: { taskId: string; sessionId?: string }) {
if (input.sessionId) {
await c.db
.delete(userTaskState)
.where(and(eq(userTaskState.taskId, input.taskId), eq(userTaskState.sessionId, input.sessionId)))
.run();
return;
}
await c.db.delete(userTaskState).where(eq(userTaskState.taskId, input.taskId)).run();
}
export async function runUserWorkflow(ctx: any): Promise<void> {
await ctx.loop("user-command-loop", async (loopCtx: any) => {
const msg = await loopCtx.queue.next("next-user-command", {
names: [...USER_QUEUE_NAMES],
completable: true,
});
if (!msg) {
return Loop.continue(undefined);
}
try {
let result: unknown;
switch (msg.name) {
case "user.command.auth.create":
result = await loopCtx.step({ name: "user-auth-create", timeout: 60_000, run: async () => createAuthRecordMutation(loopCtx, msg.body) });
break;
case "user.command.auth.update":
result = await loopCtx.step({ name: "user-auth-update", timeout: 60_000, run: async () => updateAuthRecordMutation(loopCtx, msg.body) });
break;
case "user.command.auth.update_many":
result = await loopCtx.step({ name: "user-auth-update-many", timeout: 60_000, run: async () => updateManyAuthRecordsMutation(loopCtx, msg.body) });
break;
case "user.command.auth.delete":
result = await loopCtx.step({ name: "user-auth-delete", timeout: 60_000, run: async () => deleteAuthRecordMutation(loopCtx, msg.body) });
break;
case "user.command.auth.delete_many":
result = await loopCtx.step({ name: "user-auth-delete-many", timeout: 60_000, run: async () => deleteManyAuthRecordsMutation(loopCtx, msg.body) });
break;
case "user.command.profile.upsert":
result = await loopCtx.step({ name: "user-profile-upsert", timeout: 60_000, run: async () => upsertUserProfileMutation(loopCtx, msg.body) });
break;
case "user.command.session_state.upsert":
result = await loopCtx.step({ name: "user-session-state-upsert", timeout: 60_000, run: async () => upsertSessionStateMutation(loopCtx, msg.body) });
break;
case "user.command.task_state.upsert":
result = await loopCtx.step({ name: "user-task-state-upsert", timeout: 60_000, run: async () => upsertTaskStateMutation(loopCtx, msg.body) });
break;
case "user.command.task_state.delete":
result = await loopCtx.step({ name: "user-task-state-delete", timeout: 60_000, run: async () => deleteTaskStateMutation(loopCtx, msg.body) });
break;
default:
return Loop.continue(undefined);
}
await msg.complete(result);
} catch (error) {
const message = resolveErrorMessage(error);
logActorWarning("user", "user workflow command failed", {
queueName: msg.name,
error: message,
});
await msg.complete({ error: message }).catch(() => {});
}
return Loop.continue(undefined);
});
}

View file

@ -10,7 +10,7 @@ import { createDefaultDriver } from "./driver.js";
import { createClient } from "rivetkit/client";
import { initBetterAuthService } from "./services/better-auth.js";
import { createDefaultAppShellServices } from "./services/app-shell-runtime.js";
import { APP_SHELL_ORGANIZATION_ID } from "./actors/organization/app-shell.js";
import { APP_SHELL_ORGANIZATION_ID } from "./actors/organization/constants.js";
import { logger } from "./logging.js";
export interface BackendStartOptions {

View file

@ -1,8 +1,11 @@
import { betterAuth } from "better-auth";
import { createAdapterFactory } from "better-auth/adapters";
import { APP_SHELL_ORGANIZATION_ID } from "../actors/organization/app-shell.js";
import { APP_SHELL_ORGANIZATION_ID } from "../actors/organization/constants.js";
import { organizationWorkflowQueueName } from "../actors/organization/queues.js";
import { userWorkflowQueueName } from "../actors/user/workflow.js";
import { organizationKey, userKey } from "../actors/keys.js";
import { logger } from "../logging.js";
import { expectQueueResponse } from "./queue.js";
const AUTH_BASE_PATH = "/v1/auth";
const SESSION_COOKIE = "better-auth.session_token";
@ -59,6 +62,12 @@ function resolveRouteUserId(organization: any, resolved: any): string | null {
return null;
}
async function sendOrganizationCommand<TResponse>(organization: any, name: Parameters<typeof organizationWorkflowQueueName>[0], body: unknown): Promise<TResponse> {
return expectQueueResponse<TResponse>(
await organization.send(organizationWorkflowQueueName(name), body, { wait: true, timeout: 60_000 }),
);
}
export interface BetterAuthService {
auth: any;
resolveSession(headers: Headers): Promise<{ session: any; user: any } | null>;
@ -110,7 +119,7 @@ export function initBetterAuthService(actorClient: any, options: { apiUrl: strin
const email = direct("email");
if (typeof email === "string" && email.length > 0) {
const organization = await appOrganization();
const resolved = await organization.authFindEmailIndex({ email: email.toLowerCase() });
const resolved = await organization.betterAuthFindEmailIndex({ email: email.toLowerCase() });
return resolveRouteUserId(organization, resolved);
}
return null;
@ -125,7 +134,7 @@ export function initBetterAuthService(actorClient: any, options: { apiUrl: strin
const sessionToken = direct("token") ?? data?.token;
if (typeof sessionId === "string" || typeof sessionToken === "string") {
const organization = await appOrganization();
const resolved = await organization.authFindSessionIndex({
const resolved = await organization.betterAuthFindSessionIndex({
...(typeof sessionId === "string" ? { sessionId } : {}),
...(typeof sessionToken === "string" ? { sessionToken } : {}),
});
@ -144,11 +153,11 @@ export function initBetterAuthService(actorClient: any, options: { apiUrl: strin
const accountId = direct("accountId") ?? data?.accountId;
const organization = await appOrganization();
if (typeof accountRecordId === "string" && accountRecordId.length > 0) {
const resolved = await organization.authFindAccountIndex({ id: accountRecordId });
const resolved = await organization.betterAuthFindAccountIndex({ id: accountRecordId });
return resolveRouteUserId(organization, resolved);
}
if (typeof providerId === "string" && providerId.length > 0 && typeof accountId === "string" && accountId.length > 0) {
const resolved = await organization.authFindAccountIndex({ providerId, accountId });
const resolved = await organization.betterAuthFindAccountIndex({ providerId, accountId });
return resolveRouteUserId(organization, resolved);
}
return null;
@ -157,9 +166,9 @@ export function initBetterAuthService(actorClient: any, options: { apiUrl: strin
return null;
};
const ensureOrganizationVerification = async (method: string, payload: Record<string, unknown>) => {
const ensureOrganizationVerification = async <TResponse>(method: Parameters<typeof organizationWorkflowQueueName>[0], payload: Record<string, unknown>) => {
const organization = await appOrganization();
return await organization[method](payload);
return await sendOrganizationCommand<TResponse>(organization, method, payload);
};
return {
@ -170,7 +179,7 @@ export function initBetterAuthService(actorClient: any, options: { apiUrl: strin
create: async ({ model, data }) => {
const transformed = await transformInput(data, model, "create", true);
if (model === "verification") {
return await ensureOrganizationVerification("authCreateVerification", { data: transformed });
return await ensureOrganizationVerification<any>("organization.command.better_auth.verification.create", { data: transformed });
}
const userId = await resolveUserIdForQuery(model, undefined, transformed);
@ -179,18 +188,20 @@ export function initBetterAuthService(actorClient: any, options: { apiUrl: strin
}
const userActor = await getUser(userId);
const created = await userActor.createAuthRecord({ model, data: transformed });
const created = expectQueueResponse<any>(
await userActor.send(userWorkflowQueueName("user.command.auth.create"), { model, data: transformed }, { wait: true, timeout: 60_000 }),
);
const organization = await appOrganization();
if (model === "user" && typeof transformed.email === "string" && transformed.email.length > 0) {
await organization.authUpsertEmailIndex({
await sendOrganizationCommand(organization, "organization.command.better_auth.email_index.upsert", {
email: transformed.email.toLowerCase(),
userId,
});
}
if (model === "session") {
await organization.authUpsertSessionIndex({
await sendOrganizationCommand(organization, "organization.command.better_auth.session_index.upsert", {
sessionId: String(created.id),
sessionToken: String(created.token),
userId,
@ -198,7 +209,7 @@ export function initBetterAuthService(actorClient: any, options: { apiUrl: strin
}
if (model === "account") {
await organization.authUpsertAccountIndex({
await sendOrganizationCommand(organization, "organization.command.better_auth.account_index.upsert", {
id: String(created.id),
providerId: String(created.providerId),
accountId: String(created.accountId),
@ -212,7 +223,8 @@ export function initBetterAuthService(actorClient: any, options: { apiUrl: strin
findOne: async ({ model, where, join }) => {
const transformedWhere = transformWhereClause({ model, where, action: "findOne" });
if (model === "verification") {
return await ensureOrganizationVerification("authFindOneVerification", { where: transformedWhere, join });
const organization = await appOrganization();
return await organization.betterAuthFindOneVerification({ where: transformedWhere, join });
}
const userId = await resolveUserIdForQuery(model, transformedWhere);
@ -221,14 +233,15 @@ export function initBetterAuthService(actorClient: any, options: { apiUrl: strin
}
const userActor = await getUser(userId);
const found = await userActor.findOneAuthRecord({ model, where: transformedWhere, join });
const found = await userActor.betterAuthFindOneRecord({ model, where: transformedWhere, join });
return found ? ((await transformOutput(found, model, undefined, join)) as any) : null;
},
findMany: async ({ model, where, limit, sortBy, offset, join }) => {
const transformedWhere = transformWhereClause({ model, where, action: "findMany" });
if (model === "verification") {
return await ensureOrganizationVerification("authFindManyVerification", {
const organization = await appOrganization();
return await organization.betterAuthFindManyVerification({
where: transformedWhere,
limit,
sortBy,
@ -244,7 +257,7 @@ export function initBetterAuthService(actorClient: any, options: { apiUrl: strin
const resolved = await Promise.all(
(tokenClause.value as string[]).map(async (sessionToken: string) => ({
sessionToken,
route: await organization.authFindSessionIndex({ sessionToken }),
route: await organization.betterAuthFindSessionIndex({ sessionToken }),
})),
);
const byUser = new Map<string, string[]>();
@ -263,7 +276,7 @@ export function initBetterAuthService(actorClient: any, options: { apiUrl: strin
const scopedWhere = transformedWhere.map((entry: any) =>
entry.field === "token" && entry.operator === "in" ? { ...entry, value: tokens } : entry,
);
const found = await userActor.findManyAuthRecords({ model, where: scopedWhere, limit, sortBy, offset, join });
const found = await userActor.betterAuthFindManyRecords({ model, where: scopedWhere, limit, sortBy, offset, join });
rows.push(...found);
}
return await Promise.all(rows.map(async (row: any) => await transformOutput(row, model, undefined, join)));
@ -276,7 +289,7 @@ export function initBetterAuthService(actorClient: any, options: { apiUrl: strin
}
const userActor = await getUser(userId);
const found = await userActor.findManyAuthRecords({ model, where: transformedWhere, limit, sortBy, offset, join });
const found = await userActor.betterAuthFindManyRecords({ model, where: transformedWhere, limit, sortBy, offset, join });
return await Promise.all(found.map(async (row: any) => await transformOutput(row, model, undefined, join)));
},
@ -284,7 +297,10 @@ export function initBetterAuthService(actorClient: any, options: { apiUrl: strin
const transformedWhere = transformWhereClause({ model, where, action: "update" });
const transformedUpdate = (await transformInput(update as Record<string, unknown>, model, "update", true)) as Record<string, unknown>;
if (model === "verification") {
return await ensureOrganizationVerification("authUpdateVerification", { where: transformedWhere, update: transformedUpdate });
return await ensureOrganizationVerification<any>("organization.command.better_auth.verification.update", {
where: transformedWhere,
update: transformedUpdate,
});
}
const userId = await resolveUserIdForQuery(model, transformedWhere, transformedUpdate);
@ -295,26 +311,37 @@ export function initBetterAuthService(actorClient: any, options: { apiUrl: strin
const userActor = await getUser(userId);
const before =
model === "user"
? await userActor.findOneAuthRecord({ model, where: transformedWhere })
? await userActor.betterAuthFindOneRecord({ model, where: transformedWhere })
: model === "account"
? await userActor.findOneAuthRecord({ model, where: transformedWhere })
? await userActor.betterAuthFindOneRecord({ model, where: transformedWhere })
: model === "session"
? await userActor.findOneAuthRecord({ model, where: transformedWhere })
? await userActor.betterAuthFindOneRecord({ model, where: transformedWhere })
: null;
const updated = await userActor.updateAuthRecord({ model, where: transformedWhere, update: transformedUpdate });
const updated = expectQueueResponse<any>(
await userActor.send(
userWorkflowQueueName("user.command.auth.update"),
{ model, where: transformedWhere, update: transformedUpdate },
{ wait: true, timeout: 60_000 },
),
);
const organization = await appOrganization();
if (model === "user" && updated) {
if (before?.email && before.email !== updated.email) {
await organization.authDeleteEmailIndex({ email: before.email.toLowerCase() });
await sendOrganizationCommand(organization, "organization.command.better_auth.email_index.delete", {
email: before.email.toLowerCase(),
});
}
if (updated.email) {
await organization.authUpsertEmailIndex({ email: updated.email.toLowerCase(), userId });
await sendOrganizationCommand(organization, "organization.command.better_auth.email_index.upsert", {
email: updated.email.toLowerCase(),
userId,
});
}
}
if (model === "session" && updated) {
await organization.authUpsertSessionIndex({
await sendOrganizationCommand(organization, "organization.command.better_auth.session_index.upsert", {
sessionId: String(updated.id),
sessionToken: String(updated.token),
userId,
@ -322,7 +349,7 @@ export function initBetterAuthService(actorClient: any, options: { apiUrl: strin
}
if (model === "account" && updated) {
await organization.authUpsertAccountIndex({
await sendOrganizationCommand(organization, "organization.command.better_auth.account_index.upsert", {
id: String(updated.id),
providerId: String(updated.providerId),
accountId: String(updated.accountId),
@ -337,7 +364,10 @@ export function initBetterAuthService(actorClient: any, options: { apiUrl: strin
const transformedWhere = transformWhereClause({ model, where, action: "updateMany" });
const transformedUpdate = (await transformInput(update as Record<string, unknown>, model, "update", true)) as Record<string, unknown>;
if (model === "verification") {
return await ensureOrganizationVerification("authUpdateManyVerification", { where: transformedWhere, update: transformedUpdate });
return await ensureOrganizationVerification<number>("organization.command.better_auth.verification.update_many", {
where: transformedWhere,
update: transformedUpdate,
});
}
const userId = await resolveUserIdForQuery(model, transformedWhere, transformedUpdate);
@ -346,13 +376,20 @@ export function initBetterAuthService(actorClient: any, options: { apiUrl: strin
}
const userActor = await getUser(userId);
return await userActor.updateManyAuthRecords({ model, where: transformedWhere, update: transformedUpdate });
return expectQueueResponse<number>(
await userActor.send(
userWorkflowQueueName("user.command.auth.update_many"),
{ model, where: transformedWhere, update: transformedUpdate },
{ wait: true, timeout: 60_000 },
),
);
},
delete: async ({ model, where }) => {
const transformedWhere = transformWhereClause({ model, where, action: "delete" });
if (model === "verification") {
await ensureOrganizationVerification("authDeleteVerification", { where: transformedWhere });
const organization = await appOrganization();
await sendOrganizationCommand(organization, "organization.command.better_auth.verification.delete", { where: transformedWhere });
return;
}
@ -363,18 +400,20 @@ export function initBetterAuthService(actorClient: any, options: { apiUrl: strin
const userActor = await getUser(userId);
const organization = await appOrganization();
const before = await userActor.findOneAuthRecord({ model, where: transformedWhere });
await userActor.deleteAuthRecord({ model, where: transformedWhere });
const before = await userActor.betterAuthFindOneRecord({ model, where: transformedWhere });
expectQueueResponse<void>(
await userActor.send(userWorkflowQueueName("user.command.auth.delete"), { model, where: transformedWhere }, { wait: true, timeout: 60_000 }),
);
if (model === "session" && before) {
await organization.authDeleteSessionIndex({
await sendOrganizationCommand(organization, "organization.command.better_auth.session_index.delete", {
sessionId: before.id,
sessionToken: before.token,
});
}
if (model === "account" && before) {
await organization.authDeleteAccountIndex({
await sendOrganizationCommand(organization, "organization.command.better_auth.account_index.delete", {
id: before.id,
providerId: before.providerId,
accountId: before.accountId,
@ -382,14 +421,16 @@ export function initBetterAuthService(actorClient: any, options: { apiUrl: strin
}
if (model === "user" && before?.email) {
await organization.authDeleteEmailIndex({ email: before.email.toLowerCase() });
await sendOrganizationCommand(organization, "organization.command.better_auth.email_index.delete", {
email: before.email.toLowerCase(),
});
}
},
deleteMany: async ({ model, where }) => {
const transformedWhere = transformWhereClause({ model, where, action: "deleteMany" });
if (model === "verification") {
return await ensureOrganizationVerification("authDeleteManyVerification", { where: transformedWhere });
return await ensureOrganizationVerification<number>("organization.command.better_auth.verification.delete_many", { where: transformedWhere });
}
if (model === "session") {
@ -399,10 +440,12 @@ export function initBetterAuthService(actorClient: any, options: { apiUrl: strin
}
const userActor = await getUser(userId);
const organization = await appOrganization();
const sessions = await userActor.findManyAuthRecords({ model, where: transformedWhere, limit: 5000 });
const deleted = await userActor.deleteManyAuthRecords({ model, where: transformedWhere });
const sessions = await userActor.betterAuthFindManyRecords({ model, where: transformedWhere, limit: 5000 });
const deleted = expectQueueResponse<number>(
await userActor.send(userWorkflowQueueName("user.command.auth.delete_many"), { model, where: transformedWhere }, { wait: true, timeout: 60_000 }),
);
for (const session of sessions) {
await organization.authDeleteSessionIndex({
await sendOrganizationCommand(organization, "organization.command.better_auth.session_index.delete", {
sessionId: session.id,
sessionToken: session.token,
});
@ -416,14 +459,17 @@ export function initBetterAuthService(actorClient: any, options: { apiUrl: strin
}
const userActor = await getUser(userId);
const deleted = await userActor.deleteManyAuthRecords({ model, where: transformedWhere });
const deleted = expectQueueResponse<number>(
await userActor.send(userWorkflowQueueName("user.command.auth.delete_many"), { model, where: transformedWhere }, { wait: true, timeout: 60_000 }),
);
return deleted;
},
count: async ({ model, where }) => {
const transformedWhere = transformWhereClause({ model, where, action: "count" });
if (model === "verification") {
return await ensureOrganizationVerification("authCountVerification", { where: transformedWhere });
const organization = await appOrganization();
return await organization.betterAuthCountVerification({ where: transformedWhere });
}
const userId = await resolveUserIdForQuery(model, transformedWhere);
@ -432,7 +478,7 @@ export function initBetterAuthService(actorClient: any, options: { apiUrl: strin
}
const userActor = await getUser(userId);
return await userActor.countAuthRecords({ model, where: transformedWhere });
return await userActor.betterAuthCountRecords({ model, where: transformedWhere });
},
};
},
@ -477,7 +523,7 @@ export function initBetterAuthService(actorClient: any, options: { apiUrl: strin
async getAuthState(sessionId: string) {
const organization = await appOrganization();
const route = await organization.authFindSessionIndex({ sessionId });
const route = await organization.betterAuthFindSessionIndex({ sessionId });
if (!route?.userId) {
return null;
}
@ -487,7 +533,9 @@ export function initBetterAuthService(actorClient: any, options: { apiUrl: strin
async upsertUserProfile(userId: string, patch: Record<string, unknown>) {
const userActor = await getUser(userId);
return await userActor.upsertUserProfile({ userId, patch });
return expectQueueResponse(
await userActor.send(userWorkflowQueueName("user.command.profile.upsert"), { userId, patch }, { wait: true, timeout: 60_000 }),
);
},
async setActiveOrganization(sessionId: string, activeOrganizationId: string | null) {
@ -496,7 +544,13 @@ export function initBetterAuthService(actorClient: any, options: { apiUrl: strin
throw new Error(`Unknown auth session ${sessionId}`);
}
const userActor = await getUser(authState.user.id);
return await userActor.upsertSessionState({ sessionId, activeOrganizationId });
return expectQueueResponse(
await userActor.send(
userWorkflowQueueName("user.command.session_state.upsert"),
{ sessionId, activeOrganizationId },
{ wait: true, timeout: 60_000 },
),
);
},
async getAccessTokenForSession(sessionId: string) {

View file

@ -1,5 +1,5 @@
import { getOrCreateOrganization } from "../actors/handles.js";
import { APP_SHELL_ORGANIZATION_ID } from "../actors/organization/app-shell.js";
import { APP_SHELL_ORGANIZATION_ID } from "../actors/organization/constants.js";
export interface ResolvedGithubAuth {
githubToken: string;