feat: support reasoning models like DeepSeek R1

This commit is contained in:
AnotiaWang
2025-02-14 15:20:02 +08:00
parent 93527597b7
commit e7296df78f
17 changed files with 549 additions and 171 deletions

View File

@ -1,4 +1,5 @@
import { parsePartialJson } from '@ai-sdk/ui-utils'
import type { TextStreamPart } from 'ai'
import { z } from 'zod'
export type DeepPartial<T> = T extends object
@ -7,6 +8,13 @@ export type DeepPartial<T> = T extends object
: { [P in keyof T]?: DeepPartial<T[P]> }
: T
export type ParseStreamingJsonEvent<T> =
| { type: 'object'; value: DeepPartial<T> }
| { type: 'reasoning'; delta: string }
| { type: 'error'; message: string }
/** The call finished with invalid content that can't be parsed as JSON */
| { type: 'bad-end'; rawText: string }
export function removeJsonMarkdown(text: string) {
text = text.trim()
if (text.startsWith('```json')) {
@ -23,32 +31,56 @@ export function removeJsonMarkdown(text: string) {
}
/**
* 解析流式的 JSON 数据
* @param textStream 字符串流
* @param _schema zod schema 用于类型验证
* @param isValid 自定义验证函数,用于判断解析出的 JSON 是否有效
* @returns 异步生成器yield 解析后的数据
* Parse streaming JSON text
* @param fullStream Returned by AI SDK
* @param _schema zod schema for type definition
* @param isValid Custom validation function to check if the parsed JSON is valid
*/
export async function* parseStreamingJson<T extends z.ZodType>(
textStream: AsyncIterable<string>,
fullStream: AsyncIterable<TextStreamPart<any>>,
_schema: T,
isValid: (value: DeepPartial<z.infer<T>>) => boolean,
): AsyncGenerator<DeepPartial<z.infer<T>>> {
): AsyncGenerator<ParseStreamingJsonEvent<z.infer<T>>> {
let rawText = ''
let isParseSuccessful = false
for await (const chunk of textStream) {
rawText += chunk
const parsed = parsePartialJson(removeJsonMarkdown(rawText))
for await (const chunk of fullStream) {
if (chunk.type === 'reasoning') {
yield { type: 'reasoning', delta: chunk.textDelta }
continue
}
if (chunk.type === 'error') {
yield {
type: 'error',
message:
chunk.error instanceof Error
? chunk.error.message
: String(chunk.error),
}
continue
}
if (chunk.type === 'text-delta') {
rawText += chunk.textDelta
const parsed = parsePartialJson(removeJsonMarkdown(rawText))
isParseSuccessful =
parsed.state === 'repaired-parse' || parsed.state === 'successful-parse'
if (isParseSuccessful && isValid(parsed.value as any)) {
yield parsed.value as DeepPartial<z.infer<T>>
} else {
console.debug(`Failed to parse JSON:`, rawText)
isParseSuccessful =
parsed.state === 'repaired-parse' || parsed.state === 'successful-parse'
if (isParseSuccessful && isValid(parsed.value as any)) {
yield {
type: 'object',
value: parsed.value as DeepPartial<z.infer<T>>,
}
} else {
console.debug(`Failed to parse JSON: ${removeJsonMarkdown(rawText)}`)
}
}
}
return { isSuccessful: isParseSuccessful }
// If the last chunk parses failed, return an error
if (!isParseSuccessful) {
yield {
type: 'bad-end',
rawText,
}
}
}