fix aborting

This commit is contained in:
Simon Knott 2025-02-13 12:11:03 +01:00
parent e6b81d85d8
commit a91775ec4e
No known key found for this signature in database
GPG key ID: 8CEDC00028084AEC
3 changed files with 51 additions and 53 deletions

View file

@ -1,38 +1,18 @@
import { useCallback, useEffect, useState } from 'react'; import { useCallback, useState } from 'react';
import Markdown from 'react-markdown' import Markdown from 'react-markdown'
import './aiConversation.css'; import './aiConversation.css';
import { clsx } from '@web/uiUtils'; import { clsx } from '@web/uiUtils';
import type { Conversation, LLMMessage } from './llm'; import type { Conversation, LLMMessage } from './llm';
export function AIConversation({ history, conversation, firstPrompt }: { history: LLMMessage[], conversation: Conversation, firstPrompt?: LLMMessage }) { export function AIConversation({ history, conversation }: { history: LLMMessage[], conversation: Conversation }) {
const [input, setInput] = useState(''); const [input, setInput] = useState('');
const [abort, setAbort] = useState<AbortController>();
const send = useCallback(async (prompt: string, visiblePrompt?: string) => {
const controller = new AbortController();
setAbort(controller);
try {
await conversation.send(prompt, visiblePrompt, controller.signal);
} finally {
setAbort(undefined);
}
}, [conversation]);
const onSubmit = useCallback(async (event: React.FormEvent<HTMLFormElement>) => { const onSubmit = useCallback(async (event: React.FormEvent<HTMLFormElement>) => {
event.preventDefault(); event.preventDefault();
setInput(''); setInput('');
const content = new FormData(event.target as any).get('content') as string; const content = new FormData(event.target as any).get('content') as string;
await send(content); await conversation.send(content);
}, [send]); }, [conversation]);
useEffect(() => {
if (!conversation.isEmpty())
return;
if (!firstPrompt)
return;
send(firstPrompt.content, firstPrompt.displayContent);
}, [conversation, firstPrompt, send]);
return ( return (
<div className="chat-container"> <div className="chat-container">
@ -63,10 +43,11 @@ export function AIConversation({ history, conversation, firstPrompt }: { history
placeholder="Ask a question..." placeholder="Ask a question..."
className="message-input" className="message-input"
/> />
{abort ? ( {conversation.isSending() ? (
<button type="button" className="send-button" onClick={(evt) => { <button type="button" className="send-button" onClick={(evt) => {
evt.preventDefault() evt.preventDefault()
abort.abort() console.log("aborting")
conversation.abortSending();
}}> }}>
Cancel Cancel
</button> </button>

View file

@ -27,7 +27,7 @@ import { fixTestPrompt } from '@web/components/prompts';
import type { GitCommitInfo } from '@testIsomorphic/types'; import type { GitCommitInfo } from '@testIsomorphic/types';
import { AIConversation } from './aiConversation'; import { AIConversation } from './aiConversation';
import { ToolbarButton } from '@web/components/toolbarButton'; import { ToolbarButton } from '@web/components/toolbarButton';
import { LLMMessage, useLLMChat, useLLMConversation } from './llm'; import { useLLMChat, useLLMConversation } from './llm';
import { useAsyncMemo } from '@web/uiUtils'; import { useAsyncMemo } from '@web/uiUtils';
const GitCommitInfoContext = React.createContext<GitCommitInfo | undefined>(undefined); const GitCommitInfoContext = React.createContext<GitCommitInfo | undefined>(undefined);
@ -167,25 +167,22 @@ export function AIErrorConversation({ conversationId, error, pageSnapshot, diff
].join('\n') ].join('\n')
); );
const firstPrompt = React.useMemo<LLMMessage>(() => { React.useEffect(() => {
const message: LLMMessage = { let content = `Here's the error: ${error}`;
role: 'user', let displayContent = `Help me with the error above.`;
content: `Here's the error: ${error}`,
displayContent: `Help me with the error above.`
}
if (diff) if (diff)
message.content += `\n\nCode diff:\n${diff}`; content += `\n\nCode diff:\n${diff}`;
if (pageSnapshot) if (pageSnapshot)
message.content += `\n\nPage snapshot:\n${pageSnapshot}`; content += `\n\nPage snapshot:\n${pageSnapshot}`;
if (diff) if (diff)
message.displayContent += ` Take the code diff${pageSnapshot ? ' and page snapshot' : ''} into account.`; displayContent += ` Take the code diff${pageSnapshot ? ' and page snapshot' : ''} into account.`;
else if (pageSnapshot) else if (pageSnapshot)
message.displayContent += ` Take the page snapshot into account.`; displayContent += ` Take the page snapshot into account.`;
return message; conversation.send(content, displayContent);
}, [diff, pageSnapshot, error]); }, []);
return <AIConversation history={history} conversation={conversation} firstPrompt={firstPrompt} />; return <AIConversation history={history} conversation={conversation} />;
} }

View file

@ -98,7 +98,7 @@ class OpenAI implements LLM {
constructor(private apiKey: string, private baseURL = 'https://api.openai.com') {} constructor(private apiKey: string, private baseURL = 'https://api.openai.com') {}
async *chatCompletion(messages: LLMMessage[]) { async *chatCompletion(messages: LLMMessage[], signal: AbortSignal) {
const url = new URL('./v1/chat/completions', this.baseURL); const url = new URL('./v1/chat/completions', this.baseURL);
const response = await fetch(url, { const response = await fetch(url, {
method: 'POST', method: 'POST',
@ -112,6 +112,7 @@ class OpenAI implements LLM {
messages: messages.map(({ role, content }) => ({ role, content })), messages: messages.map(({ role, content }) => ({ role, content })),
stream: true, stream: true,
}), }),
signal,
}); });
if (response.status !== 200 || !response.body) if (response.status !== 200 || !response.body)
@ -130,7 +131,7 @@ class OpenAI implements LLM {
class Anthropic implements LLM { class Anthropic implements LLM {
constructor(private apiKey: string, private baseURL = 'https://api.anthropic.com') {} constructor(private apiKey: string, private baseURL = 'https://api.anthropic.com') {}
async *chatCompletion(messages: LLMMessage[]): AsyncGenerator<string> { async *chatCompletion(messages: LLMMessage[], signal: AbortSignal): AsyncGenerator<string> {
const response = await fetch(new URL('./v1/messages', this.baseURL), { const response = await fetch(new URL('./v1/messages', this.baseURL), {
method: 'POST', method: 'POST',
headers: { headers: {
@ -145,7 +146,8 @@ class Anthropic implements LLM {
system: messages.find(({ role }) => role === 'developer')?.content, system: messages.find(({ role }) => role === 'developer')?.content,
max_tokens: 1024, max_tokens: 1024,
stream: true, stream: true,
}) }),
signal,
}); });
if (response.status !== 200 || !response.body) if (response.status !== 200 || !response.body)
@ -176,21 +178,39 @@ class LLMChat {
export class Conversation { export class Conversation {
history: LLMMessage[]; history: LLMMessage[];
onChange = new EventEmitter<void>(); onChange = new EventEmitter<void>();
private _abortController: AbortController | undefined;
constructor(private chat: LLMChat, systemPrompt: string) { constructor(private chat: LLMChat, systemPrompt: string) {
this.history = [{ role: 'developer', content: systemPrompt }]; this.history = [{ role: 'developer', content: systemPrompt }];
} }
async send(content: string, displayContent: string | undefined, signal: AbortSignal) { async send(content: string, displayContent?: string) {
if (this.isSending())
throw new Error('Already sending');
const response: LLMMessage = { role: 'assistant', content: '' }; const response: LLMMessage = { role: 'assistant', content: '' };
this.history.push({ role: 'user', content, displayContent }, response); this.history.push({ role: 'user', content, displayContent }, response);
this.onChange.fire(); this.onChange.fire();
for await (const chunk of this.chat.api.chatCompletion(this.history, signal)) { this._abortController = new AbortController();
response.content += chunk; try {
this.onChange.fire(); for await (const chunk of this.chat.api.chatCompletion(this.history, this._abortController.signal)) {
response.content += chunk;
this.onChange.fire();
}
} finally {
this._abortController = undefined;
} }
} }
isSending(): boolean {
return this._abortController !== undefined;
}
abortSending() {
this._abortController!.abort();
this.onChange.fire();
}
isEmpty() { isEmpty() {
return this.history.length < 2; return this.history.length < 2;
} }
@ -223,12 +243,12 @@ export function useLLMConversation(id: string, systemPrompt: string) {
const conversation = React.useMemo(() => chat.getConversation(id, systemPrompt), [chat, id]); const conversation = React.useMemo(() => chat.getConversation(id, systemPrompt), [chat, id]);
const [history, setHistory] = React.useState(conversation.history); const [history, setHistory] = React.useState(conversation.history);
React.useEffect(() => { React.useEffect(() => {
function update() { function update() {
setHistory([...conversation.history]); setHistory([...conversation.history]);
} }
update(); update();
const subscription = conversation.onChange.event(update); const subscription = conversation.onChange.event(update);
return subscription.dispose; return subscription.dispose;
}, [conversation]); }, [conversation]);
return [history, conversation] as const; return [history, conversation] as const;