feat: support reasoning models like DeepSeek R1
This commit is contained in:
@ -1,4 +1,5 @@
|
||||
import { parsePartialJson } from '@ai-sdk/ui-utils'
|
||||
import type { TextStreamPart } from 'ai'
|
||||
import { z } from 'zod'
|
||||
|
||||
export type DeepPartial<T> = T extends object
|
||||
@ -7,6 +8,13 @@ export type DeepPartial<T> = T extends object
|
||||
: { [P in keyof T]?: DeepPartial<T[P]> }
|
||||
: T
|
||||
|
||||
export type ParseStreamingJsonEvent<T> =
|
||||
| { type: 'object'; value: DeepPartial<T> }
|
||||
| { type: 'reasoning'; delta: string }
|
||||
| { type: 'error'; message: string }
|
||||
/** The call finished with invalid content that can't be parsed as JSON */
|
||||
| { type: 'bad-end'; rawText: string }
|
||||
|
||||
export function removeJsonMarkdown(text: string) {
|
||||
text = text.trim()
|
||||
if (text.startsWith('```json')) {
|
||||
@ -23,32 +31,56 @@ export function removeJsonMarkdown(text: string) {
|
||||
}
|
||||
|
||||
/**
|
||||
* 解析流式的 JSON 数据
|
||||
* @param textStream 字符串流
|
||||
* @param _schema zod schema 用于类型验证
|
||||
* @param isValid 自定义验证函数,用于判断解析出的 JSON 是否有效
|
||||
* @returns 异步生成器,yield 解析后的数据
|
||||
* Parse streaming JSON text
|
||||
* @param fullStream Returned by AI SDK
|
||||
* @param _schema zod schema for type definition
|
||||
* @param isValid Custom validation function to check if the parsed JSON is valid
|
||||
*/
|
||||
export async function* parseStreamingJson<T extends z.ZodType>(
|
||||
textStream: AsyncIterable<string>,
|
||||
fullStream: AsyncIterable<TextStreamPart<any>>,
|
||||
_schema: T,
|
||||
isValid: (value: DeepPartial<z.infer<T>>) => boolean,
|
||||
): AsyncGenerator<DeepPartial<z.infer<T>>> {
|
||||
): AsyncGenerator<ParseStreamingJsonEvent<z.infer<T>>> {
|
||||
let rawText = ''
|
||||
let isParseSuccessful = false
|
||||
|
||||
for await (const chunk of textStream) {
|
||||
rawText += chunk
|
||||
const parsed = parsePartialJson(removeJsonMarkdown(rawText))
|
||||
for await (const chunk of fullStream) {
|
||||
if (chunk.type === 'reasoning') {
|
||||
yield { type: 'reasoning', delta: chunk.textDelta }
|
||||
continue
|
||||
}
|
||||
if (chunk.type === 'error') {
|
||||
yield {
|
||||
type: 'error',
|
||||
message:
|
||||
chunk.error instanceof Error
|
||||
? chunk.error.message
|
||||
: String(chunk.error),
|
||||
}
|
||||
continue
|
||||
}
|
||||
if (chunk.type === 'text-delta') {
|
||||
rawText += chunk.textDelta
|
||||
const parsed = parsePartialJson(removeJsonMarkdown(rawText))
|
||||
|
||||
isParseSuccessful =
|
||||
parsed.state === 'repaired-parse' || parsed.state === 'successful-parse'
|
||||
if (isParseSuccessful && isValid(parsed.value as any)) {
|
||||
yield parsed.value as DeepPartial<z.infer<T>>
|
||||
} else {
|
||||
console.debug(`Failed to parse JSON:`, rawText)
|
||||
isParseSuccessful =
|
||||
parsed.state === 'repaired-parse' || parsed.state === 'successful-parse'
|
||||
if (isParseSuccessful && isValid(parsed.value as any)) {
|
||||
yield {
|
||||
type: 'object',
|
||||
value: parsed.value as DeepPartial<z.infer<T>>,
|
||||
}
|
||||
} else {
|
||||
console.debug(`Failed to parse JSON: ${removeJsonMarkdown(rawText)}`)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return { isSuccessful: isParseSuccessful }
|
||||
// If the last chunk parses failed, return an error
|
||||
if (!isParseSuccessful) {
|
||||
yield {
|
||||
type: 'bad-end',
|
||||
rawText,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user