diff --git a/orchestrator/src/server/repositories/ghostwriter.ts b/orchestrator/src/server/repositories/ghostwriter.ts index 47ab607..fec863d 100644 --- a/orchestrator/src/server/repositories/ghostwriter.ts +++ b/orchestrator/src/server/repositories/ghostwriter.ts @@ -341,3 +341,28 @@ export async function completeRun( return getRunById(runId); } + +export async function completeRunIfRunning( + runId: string, + input: { + status: Exclude; + errorCode?: string | null; + errorMessage?: string | null; + }, +): Promise { + const nowEpoch = Date.now(); + const nowIso = new Date(nowEpoch).toISOString(); + + await db + .update(jobChatRuns) + .set({ + status: input.status, + completedAt: nowEpoch, + errorCode: input.errorCode ?? null, + errorMessage: input.errorMessage ?? null, + updatedAt: nowIso, + }) + .where(and(eq(jobChatRuns.id, runId), eq(jobChatRuns.status, "running"))); + + return getRunById(runId); +} diff --git a/orchestrator/src/server/services/ghostwriter.test.ts b/orchestrator/src/server/services/ghostwriter.test.ts index 069881f..ccf9cd1 100644 --- a/orchestrator/src/server/services/ghostwriter.test.ts +++ b/orchestrator/src/server/services/ghostwriter.test.ts @@ -14,6 +14,7 @@ const mocks = vi.hoisted(() => ({ createRun: vi.fn(), updateMessage: vi.fn(), completeRun: vi.fn(), + completeRunIfRunning: vi.fn(), getMessageById: vi.fn(), getLatestAssistantMessage: vi.fn(), getRunById: vi.fn(), @@ -52,6 +53,7 @@ vi.mock("../repositories/ghostwriter", () => ({ createRun: mocks.repo.createRun, updateMessage: mocks.repo.updateMessage, completeRun: mocks.repo.completeRun, + completeRunIfRunning: mocks.repo.completeRunIfRunning, getMessageById: mocks.repo.getMessageById, getLatestAssistantMessage: mocks.repo.getLatestAssistantMessage, getRunById: mocks.repo.getRunById, @@ -148,6 +150,21 @@ describe("ghostwriter service", () => { updatedAt: new Date().toISOString(), }); mocks.repo.completeRun.mockResolvedValue(null); + mocks.repo.completeRunIfRunning.mockResolvedValue({ + id: "run-1", + threadId: "thread-1", + jobId: "job-1", + status: "cancelled", + model: "model-a", + provider: "openrouter", + errorCode: "REQUEST_TIMEOUT", + errorMessage: "Generation cancelled by user", + startedAt: Date.now(), + completedAt: Date.now(), + requestId: "req-123", + createdAt: new Date().toISOString(), + updatedAt: new Date().toISOString(), + }); mocks.repo.updateMessage.mockResolvedValue(baseAssistantMessage); mocks.repo.getMessageById.mockResolvedValue(baseAssistantMessage); mocks.repo.listMessagesForThread.mockResolvedValue([ @@ -355,6 +372,48 @@ describe("ghostwriter service", () => { expect(result).toEqual({ cancelled: false, alreadyFinished: true }); expect(mocks.repo.completeRun).not.toHaveBeenCalled(); + expect(mocks.repo.completeRunIfRunning).not.toHaveBeenCalled(); + }); + + it("returns alreadyFinished when run completes before cancel write", async () => { + mocks.repo.getRunById.mockResolvedValue({ + id: "run-race", + threadId: "thread-1", + jobId: "job-1", + status: "running", + model: "model-a", + provider: "openrouter", + errorCode: null, + errorMessage: null, + startedAt: Date.now(), + completedAt: null, + requestId: "req-123", + createdAt: new Date().toISOString(), + updatedAt: new Date().toISOString(), + }); + mocks.repo.completeRunIfRunning.mockResolvedValue({ + id: "run-race", + threadId: "thread-1", + jobId: "job-1", + status: "completed", + model: "model-a", + provider: "openrouter", + errorCode: null, + errorMessage: null, + startedAt: Date.now(), + completedAt: Date.now(), + requestId: "req-123", + createdAt: new Date().toISOString(), + updatedAt: new Date().toISOString(), + }); + + const result = await cancelRun({ + jobId: "job-1", + threadId: "thread-1", + runId: "run-race", + }); + + expect(result).toEqual({ cancelled: false, alreadyFinished: true }); }); it("maps createRun unique constraint races to conflict", async () => { diff --git a/orchestrator/src/server/services/ghostwriter.ts b/orchestrator/src/server/services/ghostwriter.ts index 908b0c6..13e24b0 100644 --- a/orchestrator/src/server/services/ghostwriter.ts +++ b/orchestrator/src/server/services/ghostwriter.ts @@ -562,12 +562,19 @@ export async function cancelRun(input: { controller.abort(); } - await jobChatRepo.completeRun(input.runId, { + const runAfterCancel = await jobChatRepo.completeRunIfRunning(input.runId, { status: "cancelled", errorCode: "REQUEST_TIMEOUT", errorMessage: "Generation cancelled by user", }); + if (!runAfterCancel || runAfterCancel.status !== "cancelled") { + return { + cancelled: false, + alreadyFinished: true, + }; + } + return { cancelled: true, alreadyFinished: false,