feat: Deep Research section
@@ -8,20 +8,21 @@ import { o3MiniModel, trimPrompt } from './ai/providers';
 import { systemPrompt } from './prompt';
 import zodToJsonSchema from 'zod-to-json-schema';
 import { tavily, type TavilySearchResponse } from '@tavily/core';
 // import 'dotenv/config';
 
 // Used for streaming response
-type PartialSerpQueries = DeepPartial<z.infer<typeof serpQueriesTypeSchema>['queries']>;
-type PartialSearchResult = DeepPartial<z.infer<typeof serpResultTypeSchema>>;
+export type SearchQuery = z.infer<typeof searchQueriesTypeSchema>['queries'][0];
+export type PartialSearchQuery = DeepPartial<SearchQuery>;
+export type SearchResult = z.infer<typeof searchResultTypeSchema>;
+export type PartialSearchResult = DeepPartial<SearchResult>;
 
 export type ResearchStep =
   | { type: 'start'; message: string; depth: number; breadth: number }
-  | { type: 'generating_queries'; result: PartialSerpQueries; depth: number; breadth: number }
-  | { type: 'query_generated'; query: string; researchGoal: string; depth: number; breadth: number; queryIndex: number }
-  | { type: 'searching'; query: string; depth: number; breadth: number; queryIndex: number }
-  | { type: 'search_complete'; query: string; urls: string[]; depth: number; breadth: number; queryIndex: number }
-  | { type: 'processing_search_result'; query: string; result: PartialSearchResult; depth: number; breadth: number; queryIndex: number }
-  | { type: 'error'; message: string }
+  | { type: 'generating_query'; result: PartialSearchQuery; depth: number; breadth: number; nodeIndex: number; nodeId: string }
+  | { type: 'generated_query'; query: string; result: PartialSearchQuery; depth: number; breadth: number; nodeIndex: number; nodeId: string }
+  | { type: 'searching'; query: string; depth: number; breadth: number; nodeIndex: number; nodeId: string }
+  | { type: 'search_complete'; query: string; urls: string[]; depth: number; breadth: number; nodeIndex: number; nodeId: string }
+  | { type: 'processing_search_result'; query: string; result: PartialSearchResult; depth: number; breadth: number; nodeIndex: number; nodeId: string }
+  | { type: 'processed_search_result'; query: string; result: SearchResult; depth: number; breadth: number; nodeIndex: number; nodeId: string }
+  | { type: 'error'; message: string; depth: number; nodeId: string }
+  | { type: 'complete' };
 
 // increase this if you have higher API rate limits
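The `ResearchStep` union doubles as the streaming protocol between the research loop and whatever renders it. A minimal sketch of a hypothetical consumer (no such handler ships in this commit), narrowing on the `type` discriminant:

```ts
// Hypothetical consumer; TypeScript narrows each branch by `type`.
function handleStep(step: ResearchStep) {
  switch (step.type) {
    case 'generated_query':
      // nodeId encodes the node's path in the research tree, e.g. '0-1-2'
      console.log(`[${step.nodeId}] planned query #${step.nodeIndex}`);
      break;
    case 'search_complete':
      console.log(`[${step.nodeId}] ${step.urls.length} URLs for "${step.query}"`);
      break;
    case 'error':
      console.error(`[${step.nodeId}] failed at depth ${step.depth}: ${step.message}`);
      break;
    case 'complete':
      console.log('research finished');
      break;
  }
}
```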
@@ -38,9 +39,9 @@ const tvly = tavily({
 })
 
 /**
- * Schema for {@link generateSerpQueries} without dynamic descriptions
+ * Schema for {@link generateSearchQueries} without dynamic descriptions
 */
-export const serpQueriesTypeSchema = z.object({
+export const searchQueriesTypeSchema = z.object({
   queries: z.array(
     z.object({
       query: z.string(),
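The rest of the schema is elided by the hunk, but `searchQuery.researchGoal` is read further down, so each entry presumably carries a `researchGoal` alongside `query`. A sketch of the assumed full shape and the type inferred from it (the `researchGoal` field is the assumption):

```ts
import { z } from 'zod';

// Assumed reconstruction; only `query: z.string()` is visible in this hunk.
const searchQueriesTypeSchema = z.object({
  queries: z.array(
    z.object({
      query: z.string(),
      researchGoal: z.string(), // assumed: read later as searchQuery.researchGoal
    }),
  ),
});

// Mirrors the exported alias above:
type SearchQuery = z.infer<typeof searchQueriesTypeSchema>['queries'][0];
// => { query: string; researchGoal: string }
```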
@@ -50,7 +51,7 @@ export const serpQueriesTypeSchema = z.object({
 });
 
 // take in a user query, return a list of SERP queries
-export function generateSerpQueries({
+export function generateSearchQueries({
   query,
   numQueries = 3,
   learnings,
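The function body is outside the hunk. Given the file's imports (`o3MiniModel`, `systemPrompt`, `zodToJsonSchema`) and the `.textStream` consumed in `deepResearch` below, it plausibly wraps the Vercel AI SDK's `streamText`; everything in this sketch beyond those names is an assumption:

```ts
import { streamText } from 'ai';
import zodToJsonSchema from 'zod-to-json-schema';

// Hypothetical reconstruction of the body; the prompt wording is invented.
function generateSearchQueriesSketch({ query, numQueries = 3 }: { query: string; numQueries?: number }) {
  // Callers consume the returned object's `textStream` property.
  return streamText({
    model: o3MiniModel, // from './ai/providers'
    system: systemPrompt(), // assumed to return the shared system prompt
    prompt:
      `Generate up to ${numQueries} search queries for: <query>${query}</query>\n` +
      `Answer with JSON matching: ${JSON.stringify(zodToJsonSchema(searchQueriesTypeSchema))}`,
  });
}
```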
@@ -92,11 +93,11 @@ export function generateSerpQueries({
   });
 }
 
-export const serpResultTypeSchema = z.object({
+export const searchResultTypeSchema = z.object({
   learnings: z.array(z.string()),
   followUpQuestions: z.array(z.string()),
 });
-function processSerpResult({
+function processSearchResult({
   query,
   result,
   numLearnings = 3,
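`processSearchResult`, like `generateSearchQueries`, returns a streaming result whose `textStream` is re-parsed incrementally by `parseStreamingJson`. That helper is imported from elsewhere; from its two call sites in `deepResearch` below, its contract appears to be roughly the following (a naive sketch: the real implementation presumably repairs incomplete JSON so it can yield partial objects mid-stream):

```ts
import { z } from 'zod';

// Stand-in for the DeepPartial helper this file already uses.
type DeepPartial<T> = { [K in keyof T]?: DeepPartial<T[K]> };

// Assumed contract, inferred from the call sites; `_schema` would be used
// for validation in the real implementation.
async function* parseStreamingJsonSketch<T extends z.ZodTypeAny>(
  textStream: AsyncIterable<string>,
  _schema: T,
  isValid: (value: DeepPartial<z.infer<T>>) => boolean,
): AsyncGenerator<DeepPartial<z.infer<T>>> {
  let buffer = '';
  for await (const chunk of textStream) {
    buffer += chunk;
    try {
      const value = JSON.parse(buffer); // naive: succeeds only on complete JSON
      if (isValid(value)) yield value;
    } catch {
      // buffer is still a JSON prefix; keep accumulating
    }
  }
}
```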
@@ -170,112 +171,159 @@ export async function writeFinalReport({
   return res.object.reportMarkdown + urlsSection;
 }
 
+function childNodeId(parentNodeId: string, currentIndex: number) {
+  return `${parentNodeId}-${currentIndex}`;
+}
+
 export async function deepResearch({
   query,
   breadth,
-  depth,
+  maxDepth,
   learnings = [],
   visitedUrls = [],
   onProgress,
+  currentDepth = 1,
+  nodeId = '0'
 }: {
   query: string;
   breadth: number;
-  depth: number;
+  maxDepth: number;
   learnings?: string[];
   visitedUrls?: string[];
   onProgress: (step: ResearchStep) => void;
+  currentDepth?: number;
+  nodeId?: string
 }): Promise<void> {
   onProgress({ type: 'start', message: `Starting deep research, depth: ${maxDepth}, breadth: ${breadth}`, depth: maxDepth, breadth });
 
   try {
-    const serpQueriesResult = generateSerpQueries({
+    const searchQueriesResult = generateSearchQueries({
       query,
       learnings,
       numQueries: breadth,
     });
     const limit = pLimit(ConcurrencyLimit);
 
-    let serpQueries: PartialSerpQueries = [];
+    let searchQueries: PartialSearchQuery[] = [];
 
     for await (const parsedQueries of parseStreamingJson(
-      serpQueriesResult.textStream,
-      serpQueriesTypeSchema,
+      searchQueriesResult.textStream,
+      searchQueriesTypeSchema,
       (value) => !!value.queries?.length && !!value.queries[0]?.query
     )) {
       if (parsedQueries.queries) {
-        serpQueries = parsedQueries.queries;
-        onProgress({
-          type: 'generating_queries',
-          result: serpQueries,
-          depth,
-          breadth
-        });
+        searchQueries = parsedQueries.queries;
+        for (let i = 0; i < searchQueries.length; i++) {
+          onProgress({
+            type: 'generating_query',
+            result: searchQueries[i],
+            depth: currentDepth,
+            breadth,
+            nodeIndex: i,
+            nodeId: childNodeId(nodeId, i)
+          });
+        }
       }
     }
 
+    for (let i = 0; i < searchQueries.length; i++) {
+      onProgress({
+        type: 'generated_query',
+        query,
+        result: searchQueries[i],
+        depth: currentDepth,
+        breadth,
+        nodeIndex: i,
+        nodeId: childNodeId(nodeId, i)
+      });
+    }
 
     await Promise.all(
-      serpQueries.map(serpQuery =>
+      searchQueries.map((searchQuery, nodeIndex) =>
         limit(async () => {
-          if (!serpQuery?.query) return
+          if (!searchQuery?.query) return
+          onProgress({
+            type: 'searching',
+            query: searchQuery.query,
+            depth: currentDepth,
+            breadth,
+            nodeIndex,
+            nodeId: childNodeId(nodeId, nodeIndex)
+          })
          try {
-            // const result = await firecrawl.search(serpQuery.query, {
+            // const result = await firecrawl.search(searchQuery.query, {
            //   timeout: 15000,
            //   limit: 5,
            //   scrapeOptions: { formats: ['markdown'] },
            // });
-            const result = await tvly.search(serpQuery.query, {
+            const result = await tvly.search(searchQuery.query, {
              maxResults: 5,
            })
-            console.log(`Ran ${serpQuery.query}, found ${result.results.length} contents`);
+            console.log(`Ran ${searchQuery.query}, found ${result.results.length} contents`);
 
            // Collect URLs from this search
            const newUrls = compact(result.results.map(item => item.url));
-            const newBreadth = Math.ceil(breadth / 2);
-            const newDepth = depth - 1;
+            // Breadth for the next search is half of the current breadth
+            const nextBreadth = Math.ceil(breadth / 2);
 
-            const serpResultGenerator = processSerpResult({
-              query: serpQuery.query,
+            const searchResultGenerator = processSearchResult({
+              query: searchQuery.query,
              result,
-              numFollowUpQuestions: newBreadth,
+              numFollowUpQuestions: nextBreadth,
            });
-            let serpResult: PartialSearchResult = {};
+            let searchResult: PartialSearchResult = {};
 
            for await (const parsedLearnings of parseStreamingJson(
-              serpResultGenerator.textStream,
-              serpResultTypeSchema,
+              searchResultGenerator.textStream,
+              searchResultTypeSchema,
              (value) => !!value.learnings?.length
            )) {
-              serpResult = parsedLearnings;
+              searchResult = parsedLearnings;
              onProgress({
                type: 'processing_search_result',
                result: parsedLearnings,
-                depth,
-                breadth,
-                query: serpQuery.query,
-                queryIndex: serpQueries.indexOf(serpQuery),
+                depth: currentDepth,
+                breadth: breadth,
+                query: searchQuery.query,
+                nodeIndex: nodeIndex,
+                nodeId: childNodeId(nodeId, nodeIndex)
              });
            }
-            console.log(`Processed serp result for ${serpQuery.query}`, serpResult);
-            const allLearnings = [...learnings, ...(serpResult.learnings ?? [])];
+            console.log(`Processed search result for ${searchQuery.query}`, searchResult);
+            const allLearnings = [...learnings, ...(searchResult.learnings ?? [])];
            const allUrls = [...visitedUrls, ...newUrls];
+            const nextDepth = currentDepth + 1;
 
-            if (newDepth > 0 && serpResult.followUpQuestions?.length) {
-              console.log(
-                `Researching deeper, breadth: ${newBreadth}, depth: ${newDepth}`,
+            onProgress({
+              type: 'processed_search_result',
+              result: {
+                learnings: allLearnings,
+                followUpQuestions: searchResult.followUpQuestions ?? [],
+              },
+              depth: currentDepth,
+              breadth,
+              query: searchQuery.query,
+              nodeIndex: nodeIndex,
+              nodeId: childNodeId(nodeId, nodeIndex)
+            })
+
+            if (nextDepth < maxDepth && searchResult.followUpQuestions?.length) {
+              console.warn(
+                `Researching deeper, breadth: ${nextBreadth}, depth: ${nextDepth}`,
              );
 
              const nextQuery = `
-                Previous research goal: ${serpQuery.researchGoal}
-                Follow-up research directions: ${serpResult.followUpQuestions.map(q => `\n${q}`).join('')}
+                Previous research goal: ${searchQuery.researchGoal}
+                Follow-up research directions: ${searchResult.followUpQuestions.map(q => `\n${q}`).join('')}
              `.trim();
 
              return deepResearch({
                query: nextQuery,
-                breadth: newBreadth,
-                depth: newDepth,
+                breadth: nextBreadth,
+                maxDepth,
                learnings: allLearnings,
                visitedUrls: allUrls,
                onProgress,
+                currentDepth: nextDepth,
+                nodeId: childNodeId(nodeId, nodeIndex),
              });
            } else {
              return {
@@ -284,7 +332,7 @@ export async function deepResearch({
              };
            }
          } catch (e: any) {
-            throw new Error(`Error searching for ${serpQuery.query}, depth ${depth}\nMessage: ${e.message}`)
+            throw new Error(`Error searching for ${searchQuery.query}, depth ${currentDepth}\nMessage: ${e.message}`)
          }
        }),
      ),
@@ -294,6 +342,8 @@ export async function deepResearch({
    onProgress({
      type: 'error',
      message: error?.message ?? 'Something went wrong',
+      depth: currentDepth,
+      nodeId,
    })
  }
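Taken together, `childNodeId` and the new `currentDepth`/`maxDepth` parameters define the shape of the research tree: node IDs encode the path from the root `'0'`, breadth halves (rounded up) at each level, and recursion stops once `nextDepth < maxDepth` fails. A worked example with assumed inputs `breadth = 4`, `maxDepth = 3`:

```ts
// Assumed inputs: breadth = 4, maxDepth = 3, root nodeId = '0'.
//
// Level 1: currentDepth = 1, breadth = 4 → children '0-0' … '0-3'
//          nextBreadth = Math.ceil(4 / 2) = 2; nextDepth = 2 < 3 → recurse
// Level 2: currentDepth = 2, breadth = 2 → children '0-1-0', '0-1-1' under '0-1'
//          nextDepth = 3; 3 < 3 is false → stop
//
// Note the strict `<`: with maxDepth = 3 the deepest level that runs is 2.
const childIds = [0, 1, 2, 3].map(i => childNodeId('0', i));
// => ['0-0', '0-1', '0-2', '0-3']
```

And a sketch of how a caller might drive the new signature (the call site is not part of this diff; all values are illustrative):

```ts
await deepResearch({
  query: 'state of solid-state battery research', // illustrative query
  breadth: 4,   // search queries generated at the root
  maxDepth: 3,  // recursion cut-off, compared with strict `<` above
  onProgress: (step) => {
    if (step.type === 'complete') console.log('research finished');
    else if (step.type === 'error') console.error(step.nodeId, step.message);
    else console.log(step.type);
  },
});
```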