Skip to content

Commit 69c1b24

Browse files
committed
fix: initial chat redirect bugs
1 parent e31b4da commit 69c1b24

File tree

2 files changed

+74
-38
lines changed

2 files changed

+74
-38
lines changed

apps/postgres-new/app/page.tsx

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,18 +68,51 @@ export default function Page() {
6868
<Workspace
6969
databaseId={nextDatabaseId}
7070
visibility="local"
71-
onStart={async () => {
71+
onMessage={async () => {
7272
// Make the DB no longer hidden
7373
await updateDatabase({ id: nextDatabaseId, name: null, isHidden: false })
7474

75-
// Navigate to this DB's path
76-
router.push(`/db/${nextDatabaseId}`)
77-
7875
// Pre-load the next DB
7976
const nextId = uniqueId()
8077
localStorage.setItem('next-db-id', JSON.stringify(nextId))
8178
preloadDb(nextId)
8279
}}
80+
onReply={async (message, append) => {
81+
if (!dbManager) {
82+
throw new Error('dbManager is not available')
83+
}
84+
85+
const messages = await dbManager.getMessages(nextDatabaseId)
86+
const isFirstReplyComplete =
87+
!messages.some((message) => message.role === 'assistant' && !message.toolInvocations) &&
88+
message.role === 'assistant' &&
89+
!message.toolInvocations
90+
91+
// The model might run multiple tool calls before ending with a message, so
92+
// we only want to redirect after all of these back-to-back calls finish
93+
if (isFirstReplyComplete) {
94+
router.push(`/db/${nextDatabaseId}`)
95+
96+
append({
97+
role: 'user',
98+
content: 'Name this conversation. No need to reply.',
99+
data: {
100+
automated: true,
101+
},
102+
})
103+
}
104+
}}
105+
onCancelReply={(append) => {
106+
router.push(`/db/${nextDatabaseId}`)
107+
108+
append({
109+
role: 'user',
110+
content: 'Name this conversation. No need to reply.',
111+
data: {
112+
automated: true,
113+
},
114+
})
115+
}}
83116
/>
84117
)
85118
}

apps/postgres-new/components/workspace.tsx

Lines changed: 37 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
11
'use client'
22

3-
import { CreateMessage, Message, useChat } from 'ai/react'
3+
import { CreateMessage, Message, useChat, UseChatHelpers } from 'ai/react'
44
import { createContext, useCallback, useContext, useMemo } from 'react'
55
import { useMessageCreateMutation } from '~/data/messages/message-create-mutation'
66
import { useMessagesQuery } from '~/data/messages/messages-query'
77
import { useTablesQuery } from '~/data/tables/tables-query'
88
import { useOnToolCall } from '~/lib/hooks'
99
import { useBreakpoint } from '~/lib/use-breakpoint'
1010
import { ensureMessageId, ensureToolResult } from '~/lib/util'
11-
import { useApp } from './app-provider'
1211
import Chat, { getInitialMessages } from './chat'
1312
import IDE from './ide'
1413

@@ -27,13 +26,31 @@ export type WorkspaceProps = {
2726
visibility: Visibility
2827

2928
/**
30-
* Callback called when the conversation has started.
29+
* Callback called after the user sends a message.
3130
*/
32-
onStart?: () => void | Promise<void>
31+
onMessage?: (
32+
message: Message | CreateMessage,
33+
append: UseChatHelpers['append']
34+
) => void | Promise<void>
35+
36+
/**
37+
* Callback called after the LLM finishes a reply.
38+
*/
39+
onReply?: (message: Message, append: UseChatHelpers['append']) => void | Promise<void>
40+
41+
/**
42+
* Callback called when the user cancels the reply.
43+
*/
44+
onCancelReply?: (append: UseChatHelpers['append']) => void | Promise<void>
3345
}
3446

35-
export default function Workspace({ databaseId, visibility, onStart }: WorkspaceProps) {
36-
const { dbManager } = useApp()
47+
export default function Workspace({
48+
databaseId,
49+
visibility,
50+
onMessage,
51+
onReply,
52+
onCancelReply,
53+
}: WorkspaceProps) {
3754
const isSmallBreakpoint = useBreakpoint('lg')
3855
const onToolCall = useOnToolCall(databaseId)
3956
const { mutateAsync: saveMessage } = useMessageCreateMutation(databaseId)
@@ -46,40 +63,18 @@ export default function Workspace({ databaseId, visibility, onStart }: Workspace
4663

4764
const initialMessages = useMemo(() => (tables ? getInitialMessages(tables) : undefined), [tables])
4865

49-
const {
50-
messages,
51-
setMessages,
52-
append,
53-
stop: stopReply,
54-
} = useChat({
66+
const { messages, setMessages, append, stop } = useChat({
5567
id: databaseId,
5668
api: '/api/chat',
5769
maxToolRoundtrips: 10,
70+
keepLastMessageOnError: true,
5871
onToolCall: onToolCall as any, // our `OnToolCall` type is more specific than `ai` SDK's
5972
initialMessages:
6073
existingMessages && existingMessages.length > 0 ? existingMessages : initialMessages,
6174
async onFinish(message) {
62-
if (!dbManager) {
63-
throw new Error('dbManager is not available')
64-
}
65-
75+
// Order is important here
76+
await onReply?.(message, append)
6677
await saveMessage({ message })
67-
68-
const database = await dbManager.getDatabase(databaseId)
69-
const isStartOfConversation = database.isHidden && !message.toolInvocations
70-
71-
if (isStartOfConversation) {
72-
await onStart?.()
73-
74-
// Intentionally using `append` vs `appendMessage` so that this message isn't persisted in the meta DB
75-
await append({
76-
role: 'user',
77-
content: 'Name this conversation. No need to reply.',
78-
data: {
79-
automated: true,
80-
},
81-
})
82-
}
8378
},
8479
})
8580

@@ -90,12 +85,20 @@ export default function Workspace({ databaseId, visibility, onStart }: Workspace
9085
return isModified ? [...messages] : messages
9186
})
9287
ensureMessageId(message)
88+
89+
// Intentionally don't await so that framer animations aren't affected
9390
append(message)
9491
saveMessage({ message })
92+
onMessage?.(message, append)
9593
},
96-
[setMessages, saveMessage, append]
94+
[onMessage, setMessages, saveMessage, append]
9795
)
9896

97+
const stopReply = useCallback(async () => {
98+
stop()
99+
onCancelReply?.(append)
100+
}, [onCancelReply, stop, append])
101+
99102
const isConversationStarted =
100103
initialMessages !== undefined && messages.length > initialMessages.length
101104

@@ -134,7 +137,7 @@ export type WorkspaceContextValues = {
134137
messages: Message[]
135138
visibility: Visibility
136139
appendMessage(message: Message | CreateMessage): Promise<void>
137-
stopReply(): void
140+
stopReply(): Promise<void>
138141
}
139142

140143
export const WorkspaceContext = createContext<WorkspaceContextValues | undefined>(undefined)

0 commit comments

Comments
 (0)