From a91775ec4e60c5e13b21377d85f2ee2fa5b57221 Mon Sep 17 00:00:00 2001 From: Simon Knott Date: Thu, 13 Feb 2025 12:11:03 +0100 Subject: [PATCH] fix aborting --- .../trace-viewer/src/ui/aiConversation.tsx | 33 +++---------- packages/trace-viewer/src/ui/errorsTab.tsx | 25 +++++----- packages/trace-viewer/src/ui/llm.tsx | 46 +++++++++++++------ 3 files changed, 51 insertions(+), 53 deletions(-) diff --git a/packages/trace-viewer/src/ui/aiConversation.tsx b/packages/trace-viewer/src/ui/aiConversation.tsx index 9fd0bc63c7..f9bc550fe0 100644 --- a/packages/trace-viewer/src/ui/aiConversation.tsx +++ b/packages/trace-viewer/src/ui/aiConversation.tsx @@ -1,38 +1,18 @@ -import { useCallback, useEffect, useState } from 'react'; +import { useCallback, useState } from 'react'; import Markdown from 'react-markdown' import './aiConversation.css'; import { clsx } from '@web/uiUtils'; import type { Conversation, LLMMessage } from './llm'; -export function AIConversation({ history, conversation, firstPrompt }: { history: LLMMessage[], conversation: Conversation, firstPrompt?: LLMMessage }) { +export function AIConversation({ history, conversation }: { history: LLMMessage[], conversation: Conversation }) { const [input, setInput] = useState(''); - const [abort, setAbort] = useState(); - - const send = useCallback(async (prompt: string, visiblePrompt?: string) => { - const controller = new AbortController(); - setAbort(controller); - try { - await conversation.send(prompt, visiblePrompt, controller.signal); - } finally { - setAbort(undefined); - } - }, [conversation]); const onSubmit = useCallback(async (event: React.FormEvent) => { event.preventDefault(); setInput(''); const content = new FormData(event.target as any).get('content') as string; - await send(content); - }, [send]); - - useEffect(() => { - if (!conversation.isEmpty()) - return; - if (!firstPrompt) - return; - - send(firstPrompt.content, firstPrompt.displayContent); - }, [conversation, firstPrompt, send]); + await conversation.send(content); + }, [conversation]); return (
@@ -63,10 +43,11 @@ export function AIConversation({ history, conversation, firstPrompt }: { history placeholder="Ask a question..." className="message-input" /> - {abort ? ( + {conversation.isSending() ? ( diff --git a/packages/trace-viewer/src/ui/errorsTab.tsx b/packages/trace-viewer/src/ui/errorsTab.tsx index 8befbcc224..3884896a83 100644 --- a/packages/trace-viewer/src/ui/errorsTab.tsx +++ b/packages/trace-viewer/src/ui/errorsTab.tsx @@ -27,7 +27,7 @@ import { fixTestPrompt } from '@web/components/prompts'; import type { GitCommitInfo } from '@testIsomorphic/types'; import { AIConversation } from './aiConversation'; import { ToolbarButton } from '@web/components/toolbarButton'; -import { LLMMessage, useLLMChat, useLLMConversation } from './llm'; +import { useLLMChat, useLLMConversation } from './llm'; import { useAsyncMemo } from '@web/uiUtils'; const GitCommitInfoContext = React.createContext(undefined); @@ -167,25 +167,22 @@ export function AIErrorConversation({ conversationId, error, pageSnapshot, diff ].join('\n') ); - const firstPrompt = React.useMemo(() => { - const message: LLMMessage = { - role: 'user', - content: `Here's the error: ${error}`, - displayContent: `Help me with the error above.` - } + React.useEffect(() => { + let content = `Here's the error: ${error}`; + let displayContent = `Help me with the error above.`; if (diff) - message.content += `\n\nCode diff:\n${diff}`; + content += `\n\nCode diff:\n${diff}`; if (pageSnapshot) - message.content += `\n\nPage snapshot:\n${pageSnapshot}`; + content += `\n\nPage snapshot:\n${pageSnapshot}`; if (diff) - message.displayContent += ` Take the code diff${pageSnapshot ? ' and page snapshot' : ''} into account.`; + displayContent += ` Take the code diff${pageSnapshot ? ' and page snapshot' : ''} into account.`; else if (pageSnapshot) - message.displayContent += ` Take the page snapshot into account.`; + displayContent += ` Take the page snapshot into account.`; - return message; - }, [diff, pageSnapshot, error]); + conversation.send(content, displayContent); + }, []); - return ; + return ; } diff --git a/packages/trace-viewer/src/ui/llm.tsx b/packages/trace-viewer/src/ui/llm.tsx index 02bb7b7d04..776b4b7dfb 100644 --- a/packages/trace-viewer/src/ui/llm.tsx +++ b/packages/trace-viewer/src/ui/llm.tsx @@ -98,7 +98,7 @@ class OpenAI implements LLM { constructor(private apiKey: string, private baseURL = 'https://api.openai.com') {} - async *chatCompletion(messages: LLMMessage[]) { + async *chatCompletion(messages: LLMMessage[], signal: AbortSignal) { const url = new URL('./v1/chat/completions', this.baseURL); const response = await fetch(url, { method: 'POST', @@ -112,6 +112,7 @@ class OpenAI implements LLM { messages: messages.map(({ role, content }) => ({ role, content })), stream: true, }), + signal, }); if (response.status !== 200 || !response.body) @@ -130,7 +131,7 @@ class OpenAI implements LLM { class Anthropic implements LLM { constructor(private apiKey: string, private baseURL = 'https://api.anthropic.com') {} - async *chatCompletion(messages: LLMMessage[]): AsyncGenerator { + async *chatCompletion(messages: LLMMessage[], signal: AbortSignal): AsyncGenerator { const response = await fetch(new URL('./v1/messages', this.baseURL), { method: 'POST', headers: { @@ -145,7 +146,8 @@ class Anthropic implements LLM { system: messages.find(({ role }) => role === 'developer')?.content, max_tokens: 1024, stream: true, - }) + }), + signal, }); if (response.status !== 200 || !response.body) @@ -176,21 +178,39 @@ class LLMChat { export class Conversation { history: LLMMessage[]; onChange = new EventEmitter(); + private _abortController: AbortController | undefined; constructor(private chat: LLMChat, systemPrompt: string) { this.history = [{ role: 'developer', content: systemPrompt }]; } - async send(content: string, displayContent: string | undefined, signal: AbortSignal) { + async send(content: string, displayContent?: string) { + if (this.isSending()) + throw new Error('Already sending'); + const response: LLMMessage = { role: 'assistant', content: '' }; this.history.push({ role: 'user', content, displayContent }, response); this.onChange.fire(); - for await (const chunk of this.chat.api.chatCompletion(this.history, signal)) { - response.content += chunk; - this.onChange.fire(); + this._abortController = new AbortController(); + try { + for await (const chunk of this.chat.api.chatCompletion(this.history, this._abortController.signal)) { + response.content += chunk; + this.onChange.fire(); + } + } finally { + this._abortController = undefined; } } + isSending(): boolean { + return this._abortController !== undefined; + } + + abortSending() { + this._abortController!.abort(); + this.onChange.fire(); + } + isEmpty() { return this.history.length < 2; } @@ -223,12 +243,12 @@ export function useLLMConversation(id: string, systemPrompt: string) { const conversation = React.useMemo(() => chat.getConversation(id, systemPrompt), [chat, id]); const [history, setHistory] = React.useState(conversation.history); React.useEffect(() => { - function update() { - setHistory([...conversation.history]); - } - update(); - const subscription = conversation.onChange.event(update); - return subscription.dispose; + function update() { + setHistory([...conversation.history]); + } + update(); + const subscription = conversation.onChange.event(update); + return subscription.dispose; }, [conversation]); return [history, conversation] as const;