style: prettier format

This commit is contained in:
AnotiaWang
2025-02-11 22:57:49 +08:00
parent 84f63abb3d
commit d027965013
23 changed files with 485 additions and 331 deletions

View File

@ -1,25 +1,25 @@
import { createOpenAI } from '@ai-sdk/openai';
import { getEncoding } from 'js-tiktoken';
import { createOpenAI } from '@ai-sdk/openai'
import { getEncoding } from 'js-tiktoken'
import { RecursiveCharacterTextSplitter } from './text-splitter';
import { RecursiveCharacterTextSplitter } from './text-splitter'
// Providers
const openai = createOpenAI({
apiKey: import.meta.env.VITE_OPENAI_API_KEY!,
baseURL: import.meta.env.VITE_OPENAI_ENDPOINT || 'https://api.openai.com/v1',
});
})
const customModel = import.meta.env.VITE_OPENAI_MODEL || 'o3-mini';
const customModel = import.meta.env.VITE_OPENAI_MODEL || 'o3-mini'
// Models
export const o3MiniModel = openai(customModel, {
// reasoningEffort: customModel.startsWith('o') ? 'medium' : undefined,
structuredOutputs: true,
});
})
const MinChunkSize = 140;
const encoder = getEncoding('o200k_base');
const MinChunkSize = 140
const encoder = getEncoding('o200k_base')
// trim prompt to maximum context size
export function trimPrompt(
@ -27,32 +27,32 @@ export function trimPrompt(
contextSize = Number(import.meta.env.VITE_CONTEXT_SIZE) || 128_000,
) {
if (!prompt) {
return '';
return ''
}
const length = encoder.encode(prompt).length;
const length = encoder.encode(prompt).length
if (length <= contextSize) {
return prompt;
return prompt
}
const overflowTokens = length - contextSize;
const overflowTokens = length - contextSize
// on average it's 3 characters per token, so multiply by 3 to get a rough estimate of the number of characters
const chunkSize = prompt.length - overflowTokens * 3;
const chunkSize = prompt.length - overflowTokens * 3
if (chunkSize < MinChunkSize) {
return prompt.slice(0, MinChunkSize);
return prompt.slice(0, MinChunkSize)
}
const splitter = new RecursiveCharacterTextSplitter({
chunkSize,
chunkOverlap: 0,
});
const trimmedPrompt = splitter.splitText(prompt)[0] ?? '';
})
const trimmedPrompt = splitter.splitText(prompt)[0] ?? ''
// last catch, there's a chance that the trimmed prompt is same length as the original prompt, due to how tokens are split & innerworkings of the splitter, handle this case by just doing a hard cut
if (trimmedPrompt.length === prompt.length) {
return trimPrompt(prompt.slice(0, chunkSize), contextSize);
return trimPrompt(prompt.slice(0, chunkSize), contextSize)
}
// recursively trim until the prompt is within the context size
return trimPrompt(trimmedPrompt, contextSize);
return trimPrompt(trimmedPrompt, contextSize)
}

View File

@ -1,77 +1,80 @@
import assert from 'node:assert';
import { describe, it, beforeEach } from 'node:test';
import { RecursiveCharacterTextSplitter } from './text-splitter';
import assert from 'node:assert'
import { describe, it, beforeEach } from 'node:test'
import { RecursiveCharacterTextSplitter } from './text-splitter'
describe('RecursiveCharacterTextSplitter', () => {
let splitter: RecursiveCharacterTextSplitter;
let splitter: RecursiveCharacterTextSplitter
beforeEach(() => {
splitter = new RecursiveCharacterTextSplitter({
chunkSize: 50,
chunkOverlap: 10,
});
});
})
})
it('Should correctly split text by separators', () => {
const text = 'Hello world, this is a test of the recursive text splitter.';
const text = 'Hello world, this is a test of the recursive text splitter.'
// Test with initial chunkSize
assert.deepEqual(
splitter.splitText(text),
['Hello world', 'this is a test of the recursive text splitter']
);
assert.deepEqual(splitter.splitText(text), [
'Hello world',
'this is a test of the recursive text splitter',
])
// Test with updated chunkSize
splitter.chunkSize = 100;
splitter.chunkSize = 100
assert.deepEqual(
splitter.splitText(
'Hello world, this is a test of the recursive text splitter. If I have a period, it should split along the period.'
'Hello world, this is a test of the recursive text splitter. If I have a period, it should split along the period.',
),
[
'Hello world, this is a test of the recursive text splitter',
'If I have a period, it should split along the period.',
]
);
],
)
// Test with another updated chunkSize
splitter.chunkSize = 110;
splitter.chunkSize = 110
assert.deepEqual(
splitter.splitText(
'Hello world, this is a test of the recursive text splitter. If I have a period, it should split along the period.\nOr, if there is a new line, it should prioritize splitting on new lines instead.'
'Hello world, this is a test of the recursive text splitter. If I have a period, it should split along the period.\nOr, if there is a new line, it should prioritize splitting on new lines instead.',
),
[
'Hello world, this is a test of the recursive text splitter',
'If I have a period, it should split along the period.',
'Or, if there is a new line, it should prioritize splitting on new lines instead.',
]
);
});
],
)
})
it('Should handle empty string', () => {
assert.deepEqual(splitter.splitText(''), []);
});
assert.deepEqual(splitter.splitText(''), [])
})
it('Should handle special characters and large texts', () => {
const largeText = 'A'.repeat(1000);
splitter.chunkSize = 200;
const largeText = 'A'.repeat(1000)
splitter.chunkSize = 200
assert.deepEqual(
splitter.splitText(largeText),
Array(5).fill('A'.repeat(200))
);
Array(5).fill('A'.repeat(200)),
)
const specialCharText = 'Hello!@# world$%^ &*( this) is+ a-test';
assert.deepEqual(
splitter.splitText(specialCharText),
['Hello!@#', 'world$%^', '&*( this)', 'is+', 'a-test']
);
});
const specialCharText = 'Hello!@# world$%^ &*( this) is+ a-test'
assert.deepEqual(splitter.splitText(specialCharText), [
'Hello!@#',
'world$%^',
'&*( this)',
'is+',
'a-test',
])
})
it('Should handle chunkSize equal to chunkOverlap', () => {
splitter.chunkSize = 50;
splitter.chunkOverlap = 50;
splitter.chunkSize = 50
splitter.chunkOverlap = 50
assert.throws(
() => splitter.splitText('Invalid configuration'),
new Error('Cannot have chunkOverlap >= chunkSize')
);
});
});
new Error('Cannot have chunkOverlap >= chunkSize'),
)
})
})

View File

@ -1,60 +1,60 @@
interface TextSplitterParams {
chunkSize: number;
chunkSize: number
chunkOverlap: number;
chunkOverlap: number
}
abstract class TextSplitter implements TextSplitterParams {
chunkSize = 1000;
chunkOverlap = 200;
chunkSize = 1000
chunkOverlap = 200
constructor(fields?: Partial<TextSplitterParams>) {
this.chunkSize = fields?.chunkSize ?? this.chunkSize;
this.chunkOverlap = fields?.chunkOverlap ?? this.chunkOverlap;
this.chunkSize = fields?.chunkSize ?? this.chunkSize
this.chunkOverlap = fields?.chunkOverlap ?? this.chunkOverlap
if (this.chunkOverlap >= this.chunkSize) {
throw new Error('Cannot have chunkOverlap >= chunkSize');
throw new Error('Cannot have chunkOverlap >= chunkSize')
}
}
abstract splitText(text: string): string[];
abstract splitText(text: string): string[]
createDocuments(texts: string[]): string[] {
const documents: string[] = [];
const documents: string[] = []
for (let i = 0; i < texts.length; i += 1) {
const text = texts[i];
const text = texts[i]
for (const chunk of this.splitText(text!)) {
documents.push(chunk);
documents.push(chunk)
}
}
return documents;
return documents
}
splitDocuments(documents: string[]): string[] {
return this.createDocuments(documents);
return this.createDocuments(documents)
}
private joinDocs(docs: string[], separator: string): string | null {
const text = docs.join(separator).trim();
return text === '' ? null : text;
const text = docs.join(separator).trim()
return text === '' ? null : text
}
mergeSplits(splits: string[], separator: string): string[] {
const docs: string[] = [];
const currentDoc: string[] = [];
let total = 0;
const docs: string[] = []
const currentDoc: string[] = []
let total = 0
for (const d of splits) {
const _len = d.length;
const _len = d.length
if (total + _len >= this.chunkSize) {
if (total > this.chunkSize) {
console.warn(
`Created a chunk of size ${total}, +
which is longer than the specified ${this.chunkSize}`,
);
)
}
if (currentDoc.length > 0) {
const doc = this.joinDocs(currentDoc, separator);
const doc = this.joinDocs(currentDoc, separator)
if (doc !== null) {
docs.push(doc);
docs.push(doc)
}
// Keep on popping if:
// - we have a larger chunk than in the chunk overlap
@ -63,81 +63,81 @@ which is longer than the specified ${this.chunkSize}`,
total > this.chunkOverlap ||
(total + _len > this.chunkSize && total > 0)
) {
total -= currentDoc[0]!.length;
currentDoc.shift();
total -= currentDoc[0]!.length
currentDoc.shift()
}
}
}
currentDoc.push(d);
total += _len;
currentDoc.push(d)
total += _len
}
const doc = this.joinDocs(currentDoc, separator);
const doc = this.joinDocs(currentDoc, separator)
if (doc !== null) {
docs.push(doc);
docs.push(doc)
}
return docs;
return docs
}
}
export interface RecursiveCharacterTextSplitterParams
extends TextSplitterParams {
separators: string[];
separators: string[]
}
export class RecursiveCharacterTextSplitter
extends TextSplitter
implements RecursiveCharacterTextSplitterParams
{
separators: string[] = ['\n\n', '\n', '.', ',', '>', '<', ' ', ''];
separators: string[] = ['\n\n', '\n', '.', ',', '>', '<', ' ', '']
constructor(fields?: Partial<RecursiveCharacterTextSplitterParams>) {
super(fields);
this.separators = fields?.separators ?? this.separators;
super(fields)
this.separators = fields?.separators ?? this.separators
}
splitText(text: string): string[] {
const finalChunks: string[] = [];
const finalChunks: string[] = []
// Get appropriate separator to use
let separator: string = this.separators[this.separators.length - 1]!;
let separator: string = this.separators[this.separators.length - 1]!
for (const s of this.separators) {
if (s === '') {
separator = s;
break;
separator = s
break
}
if (text.includes(s)) {
separator = s;
break;
separator = s
break
}
}
// Now that we have the separator, split the text
let splits: string[];
let splits: string[]
if (separator) {
splits = text.split(separator);
splits = text.split(separator)
} else {
splits = text.split('');
splits = text.split('')
}
// Now go merging things, recursively splitting longer texts.
let goodSplits: string[] = [];
let goodSplits: string[] = []
for (const s of splits) {
if (s.length < this.chunkSize) {
goodSplits.push(s);
goodSplits.push(s)
} else {
if (goodSplits.length) {
const mergedText = this.mergeSplits(goodSplits, separator);
finalChunks.push(...mergedText);
goodSplits = [];
const mergedText = this.mergeSplits(goodSplits, separator)
finalChunks.push(...mergedText)
goodSplits = []
}
const otherInfo = this.splitText(s);
finalChunks.push(...otherInfo);
const otherInfo = this.splitText(s)
finalChunks.push(...otherInfo)
}
}
if (goodSplits.length) {
const mergedText = this.mergeSplits(goodSplits, separator);
finalChunks.push(...mergedText);
const mergedText = this.mergeSplits(goodSplits, separator)
finalChunks.push(...mergedText)
}
return finalChunks;
return finalChunks
}
}

View File

@ -1,43 +1,57 @@
import { generateObject, streamText } from 'ai';
import { compact } from 'lodash-es';
import pLimit from 'p-limit';
import { z } from 'zod';
import { parseStreamingJson, type DeepPartial } from '~/utils/json';
import { generateObject, streamText } from 'ai'
import { compact } from 'lodash-es'
import pLimit from 'p-limit'
import { z } from 'zod'
import { parseStreamingJson, type DeepPartial } from '~/utils/json'
import { o3MiniModel, trimPrompt } from './ai/providers';
import { systemPrompt } from './prompt';
import zodToJsonSchema from 'zod-to-json-schema';
import { tavily, type TavilySearchResponse } from '@tavily/core';
import { o3MiniModel, trimPrompt } from './ai/providers'
import { systemPrompt } from './prompt'
import zodToJsonSchema from 'zod-to-json-schema'
import { tavily, type TavilySearchResponse } from '@tavily/core'
export type ResearchResult = {
learnings: string[];
visitedUrls: string[];
};
learnings: string[]
visitedUrls: string[]
}
export interface WriteFinalReportParams {
prompt: string;
learnings: string[];
prompt: string
learnings: string[]
}
// useRuntimeConfig()
// Used for streaming response
export type SearchQuery = z.infer<typeof searchQueriesTypeSchema>['queries'][0];
export type PartialSearchQuery = DeepPartial<SearchQuery>;
export type SearchResult = z.infer<typeof searchResultTypeSchema>;
export type PartialSearchResult = DeepPartial<SearchResult>;
export type SearchQuery = z.infer<typeof searchQueriesTypeSchema>['queries'][0]
export type PartialSearchQuery = DeepPartial<SearchQuery>
export type SearchResult = z.infer<typeof searchResultTypeSchema>
export type PartialSearchResult = DeepPartial<SearchResult>
export type ResearchStep =
| { type: 'generating_query'; result: PartialSearchQuery; nodeId: string }
| { type: 'generated_query'; query: string; result: PartialSearchQuery; nodeId: string }
| {
type: 'generated_query'
query: string
result: PartialSearchQuery
nodeId: string
}
| { type: 'searching'; query: string; nodeId: string }
| { type: 'search_complete'; urls: string[]; nodeId: string }
| { type: 'processing_serach_result'; query: string; result: PartialSearchResult; nodeId: string }
| { type: 'processed_search_result'; query: string; result: SearchResult; nodeId: string }
| {
type: 'processing_serach_result'
query: string
result: PartialSearchResult
nodeId: string
}
| {
type: 'processed_search_result'
query: string
result: SearchResult
nodeId: string
}
| { type: 'error'; message: string; nodeId: string }
| { type: 'complete'; learnings: string[], visitedUrls: string[] };
| { type: 'complete'; learnings: string[]; visitedUrls: string[] }
// increase this if you have higher API rate limits
const ConcurrencyLimit = 2;
const ConcurrencyLimit = 2
// Initialize Firecrawl with optional API key and optional base url
@ -59,7 +73,7 @@ export const searchQueriesTypeSchema = z.object({
researchGoal: z.string(),
}),
),
});
})
// take en user query, return a list of SERP queries
export function generateSearchQueries({
@ -67,10 +81,10 @@ export function generateSearchQueries({
numQueries = 3,
learnings,
}: {
query: string;
numQueries?: number;
query: string
numQueries?: number
// optional, if provided, the research will continue from the last learning
learnings?: string[];
learnings?: string[]
}) {
const schema = z.object({
queries: z
@ -84,40 +98,38 @@ export function generateSearchQueries({
),
}),
)
.describe(`List of SERP queries, max of ${numQueries}`)
.describe(`List of SERP queries, max of ${numQueries}`),
})
const jsonSchema = JSON.stringify(zodToJsonSchema(schema));
const jsonSchema = JSON.stringify(zodToJsonSchema(schema))
const prompt = [
`Given the following prompt from the user, generate a list of SERP queries to research the topic. Return a maximum of ${numQueries} queries, but feel free to return less if the original prompt is clear. Make sure each query is unique and not similar to each other: <prompt>${query}</prompt>\n\n`,
learnings
? `Here are some learnings from previous research, use them to generate more specific queries: ${learnings.join(
'\n',
)}`
? `Here are some learnings from previous research, use them to generate more specific queries: ${learnings.join('\n')}`
: '',
`You MUST respond in JSON with the following schema: ${jsonSchema}`,
].join('\n\n');
].join('\n\n')
return streamText({
model: o3MiniModel,
system: systemPrompt(),
prompt,
});
})
}
export const searchResultTypeSchema = z.object({
learnings: z.array(z.string()),
followUpQuestions: z.array(z.string()),
});
})
function processSearchResult({
query,
result,
numLearnings = 3,
numFollowUpQuestions = 3,
}: {
query: string;
query: string
result: TavilySearchResponse
numLearnings?: number;
numFollowUpQuestions?: number;
numLearnings?: number
numFollowUpQuestions?: number
}) {
const schema = z.object({
learnings: z
@ -128,25 +140,23 @@ function processSearchResult({
.describe(
`List of follow-up questions to research the topic further, max of ${numFollowUpQuestions}`,
),
});
const jsonSchema = JSON.stringify(zodToJsonSchema(schema));
const contents = compact(result.results.map(item => item.content)).map(
content => trimPrompt(content, 25_000),
);
})
const jsonSchema = JSON.stringify(zodToJsonSchema(schema))
const contents = compact(result.results.map((item) => item.content)).map(
(content) => trimPrompt(content, 25_000),
)
const prompt = [
`Given the following contents from a SERP search for the query <query>${query}</query>, generate a list of learnings from the contents. Return a maximum of ${numLearnings} learnings, but feel free to return less if the contents are clear. Make sure each learning is unique and not similar to each other. The learnings should be concise and to the point, as detailed and information dense as possible. Make sure to include any entities like people, places, companies, products, things, etc in the learnings, as well as any exact metrics, numbers, or dates. The learnings will be used to research the topic further.`,
`<contents>${contents
.map(content => `<content>\n${content}\n</content>`)
.join('\n')}</contents>`,
`<contents>${contents.map((content) => `<content>\n${content}\n</content>`).join('\n')}</contents>`,
`You MUST respond in JSON with the following schema: ${jsonSchema}`,
].join('\n\n');
].join('\n\n')
return streamText({
model: o3MiniModel,
abortSignal: AbortSignal.timeout(60_000),
system: systemPrompt(),
prompt,
});
})
}
export function writeFinalReport({
@ -155,28 +165,28 @@ export function writeFinalReport({
}: WriteFinalReportParams) {
const learningsString = trimPrompt(
learnings
.map(learning => `<learning>\n${learning}\n</learning>`)
.map((learning) => `<learning>\n${learning}\n</learning>`)
.join('\n'),
150_000,
);
)
const _prompt = [
`Given the following prompt from the user, write a final report on the topic using the learnings from research. Make it as as detailed as possible, aim for 3 or more pages, include ALL the learnings from research:`,
`<prompt>${prompt}</prompt>`,
`Here are all the learnings from previous research:`,
`<learnings>\n${learningsString}\n</learnings>`,
`Write the report in Markdown.`,
`## Deep Research Report`
].join('\n\n');
`## Deep Research Report`,
].join('\n\n')
return streamText({
model: o3MiniModel,
system: systemPrompt(),
prompt: _prompt,
});
})
}
function childNodeId(parentNodeId: string, currentIndex: number) {
return `${parentNodeId}-${currentIndex}`;
return `${parentNodeId}-${currentIndex}`
}
export async function deepResearch({
@ -187,15 +197,15 @@ export async function deepResearch({
visitedUrls = [],
onProgress,
currentDepth = 1,
nodeId = '0'
nodeId = '0',
}: {
query: string;
breadth: number;
maxDepth: number;
learnings?: string[];
visitedUrls?: string[];
onProgress: (step: ResearchStep) => void;
currentDepth?: number;
query: string
breadth: number
maxDepth: number
learnings?: string[]
visitedUrls?: string[]
onProgress: (step: ResearchStep) => void
currentDepth?: number
nodeId?: string
}): Promise<ResearchResult> {
try {
@ -203,25 +213,25 @@ export async function deepResearch({
query,
learnings,
numQueries: breadth,
});
const limit = pLimit(ConcurrencyLimit);
})
const limit = pLimit(ConcurrencyLimit)
let searchQueries: PartialSearchQuery[] = [];
let searchQueries: PartialSearchQuery[] = []
for await (const parsedQueries of parseStreamingJson(
searchQueriesResult.textStream,
searchQueriesTypeSchema,
(value) => !!value.queries?.length && !!value.queries[0]?.query
(value) => !!value.queries?.length && !!value.queries[0]?.query,
)) {
if (parsedQueries.queries) {
for (let i = 0; i < searchQueries.length; i++) {
onProgress({
type: 'generating_query',
result: searchQueries[i],
nodeId: childNodeId(nodeId, i)
});
nodeId: childNodeId(nodeId, i),
})
}
searchQueries = parsedQueries.queries;
searchQueries = parsedQueries.queries
}
}
@ -230,21 +240,22 @@ export async function deepResearch({
type: 'generated_query',
query,
result: searchQueries[i],
nodeId: childNodeId(nodeId, i)
});
nodeId: childNodeId(nodeId, i),
})
}
const results = await Promise.all(
searchQueries.map((searchQuery, i) =>
limit(async () => {
if (!searchQuery?.query) return {
learnings: [],
visitedUrls: [],
}
if (!searchQuery?.query)
return {
learnings: [],
visitedUrls: [],
}
onProgress({
type: 'searching',
query: searchQuery.query,
nodeId: childNodeId(nodeId, i)
nodeId: childNodeId(nodeId, i),
})
try {
// const result = await firecrawl.search(searchQuery.query, {
@ -255,42 +266,50 @@ export async function deepResearch({
const result = await tvly.search(searchQuery.query, {
maxResults: 5,
})
console.log(`Ran ${searchQuery.query}, found ${result.results.length} contents`);
console.log(
`Ran ${searchQuery.query}, found ${result.results.length} contents`,
)
// Collect URLs from this search
const newUrls = compact(result.results.map(item => item.url));
const newUrls = compact(result.results.map((item) => item.url))
onProgress({
type: 'search_complete',
urls: newUrls,
nodeId: childNodeId(nodeId, i),
})
// Breadth for the next search is half of the current breadth
const nextBreadth = Math.ceil(breadth / 2);
const nextBreadth = Math.ceil(breadth / 2)
const searchResultGenerator = processSearchResult({
query: searchQuery.query,
result,
numFollowUpQuestions: nextBreadth,
});
let searchResult: PartialSearchResult = {};
})
let searchResult: PartialSearchResult = {}
for await (const parsedLearnings of parseStreamingJson(
searchResultGenerator.textStream,
searchResultTypeSchema,
(value) => !!value.learnings?.length
(value) => !!value.learnings?.length,
)) {
searchResult = parsedLearnings;
searchResult = parsedLearnings
onProgress({
type: 'processing_serach_result',
result: parsedLearnings,
query: searchQuery.query,
nodeId: childNodeId(nodeId, i)
});
nodeId: childNodeId(nodeId, i),
})
}
console.log(`Processed search result for ${searchQuery.query}`, searchResult);
const allLearnings = [...learnings, ...(searchResult.learnings ?? [])];
const allUrls = [...visitedUrls, ...newUrls];
const nextDepth = currentDepth + 1;
console.log(
`Processed search result for ${searchQuery.query}`,
searchResult,
)
const allLearnings = [
...learnings,
...(searchResult.learnings ?? []),
]
const allUrls = [...visitedUrls, ...newUrls]
const nextDepth = currentDepth + 1
onProgress({
type: 'processed_search_result',
@ -299,18 +318,21 @@ export async function deepResearch({
followUpQuestions: searchResult.followUpQuestions ?? [],
},
query: searchQuery.query,
nodeId: childNodeId(nodeId, i)
nodeId: childNodeId(nodeId, i),
})
if (nextDepth < maxDepth && searchResult.followUpQuestions?.length) {
if (
nextDepth < maxDepth &&
searchResult.followUpQuestions?.length
) {
console.warn(
`Researching deeper, breadth: ${nextBreadth}, depth: ${nextDepth}`,
);
)
const nextQuery = `
Previous research goal: ${searchQuery.researchGoal}
Follow-up research directions: ${searchResult.followUpQuestions.map(q => `\n${q}`).join('')}
`.trim();
Follow-up research directions: ${searchResult.followUpQuestions.map((q) => `\n${q}`).join('')}
`.trim()
return deepResearch({
query: nextQuery,
@ -321,36 +343,38 @@ export async function deepResearch({
onProgress,
currentDepth: nextDepth,
nodeId: childNodeId(nodeId, i),
});
})
} else {
return {
learnings: allLearnings,
visitedUrls: allUrls,
};
}
}
} catch (e: any) {
throw new Error(`Error searching for ${searchQuery.query}, depth ${currentDepth}\nMessage: ${e.message}`)
throw new Error(
`Error searching for ${searchQuery.query}, depth ${currentDepth}\nMessage: ${e.message}`,
)
}
}),
),
);
)
// Conclude results
const _learnings = [...new Set(results.flatMap(r => r.learnings))]
const _visitedUrls = [...new Set(results.flatMap(r => r.visitedUrls))]
const _learnings = [...new Set(results.flatMap((r) => r.learnings))]
const _visitedUrls = [...new Set(results.flatMap((r) => r.visitedUrls))]
// Complete should only be called once
if (nodeId === '0') {
onProgress({
type: 'complete',
learnings: _learnings,
visitedUrls: _visitedUrls,
});
})
}
return {
learnings: _learnings,
visitedUrls: _visitedUrls,
}
} catch (error: any) {
console.error(error);
console.error(error)
onProgress({
type: 'error',
message: error?.message ?? 'Something went wrong',
@ -361,4 +385,4 @@ export async function deepResearch({
visitedUrls: [],
}
}
}
}

View File

@ -1,22 +1,22 @@
import { streamText } from 'ai';
import { z } from 'zod';
import { streamText } from 'ai'
import { z } from 'zod'
import { zodToJsonSchema } from 'zod-to-json-schema'
import { o3MiniModel } from './ai/providers';
import { systemPrompt } from './prompt';
import { o3MiniModel } from './ai/providers'
import { systemPrompt } from './prompt'
type PartialFeedback = DeepPartial<z.infer<typeof feedbackTypeSchema>>
export const feedbackTypeSchema = z.object({
questions: z.array(z.string())
questions: z.array(z.string()),
})
export function generateFeedback({
query,
numQuestions = 3,
}: {
query: string;
numQuestions?: number;
query: string
numQuestions?: number
}) {
const schema = z.object({
questions: z
@ -24,22 +24,22 @@ export function generateFeedback({
.describe(
`Follow up questions to clarify the research direction, max of ${numQuestions}`,
),
});
const jsonSchema = JSON.stringify(zodToJsonSchema(schema));
})
const jsonSchema = JSON.stringify(zodToJsonSchema(schema))
const prompt = [
`Given the following query from the user, ask some follow up questions to clarify the research direction. Return a maximum of ${numQuestions} questions, but feel free to return less if the original query is clear: <query>${query}</query>`,
`You MUST respond in JSON with the following schema: ${jsonSchema}`,
].join('\n\n');
].join('\n\n')
const stream = streamText({
model: o3MiniModel,
system: systemPrompt(),
prompt,
});
})
return parseStreamingJson(
stream.textStream,
feedbackTypeSchema,
(value: PartialFeedback) => !!value.questions && value.questions.length > 0
(value: PartialFeedback) => !!value.questions && value.questions.length > 0,
)
}

View File

@ -1,5 +1,5 @@
export const systemPrompt = () => {
const now = new Date().toISOString();
const now = new Date().toISOString()
return `You are an expert researcher. Today is ${now}. Follow these instructions when responding:
- You may be asked to research subjects that is after your knowledge cutoff, assume the user is right when presented with news.
- The user is a highly experienced analyst, no need to simplify it, be as detailed as possible and make sure your response is correct.
@ -11,5 +11,5 @@ export const systemPrompt = () => {
- Provide detailed explanations, I'm comfortable with lots of detail.
- Value good arguments over authorities, the source is irrelevant.
- Consider new technologies and contrarian ideas, not just the conventional wisdom.
- You may use high levels of speculation or prediction, just flag it for me.`;
};
- You may use high levels of speculation or prediction, just flag it for me.`
}

View File

@ -1,27 +1,27 @@
import * as fs from 'fs/promises';
import * as readline from 'readline';
import * as fs from 'fs/promises'
import * as readline from 'readline'
import { deepResearch, writeFinalReport } from './deep-research';
import { generateFeedback } from './feedback';
import { deepResearch, writeFinalReport } from './deep-research'
import { generateFeedback } from './feedback'
const rl = readline.createInterface({
input: process.stdin,
output: process.stdout,
});
})
// Helper function to get user input
function askQuestion(query: string): Promise<string> {
return new Promise(resolve => {
rl.question(query, answer => {
resolve(answer);
});
});
return new Promise((resolve) => {
rl.question(query, (answer) => {
resolve(answer)
})
})
}
// run the agent
async function run() {
// Get initial query
const initialQuery = await askQuestion('What would you like to research? ');
const initialQuery = await askQuestion('What would you like to research? ')
// Get breath and depth parameters
const breadth =
@ -30,29 +30,29 @@ async function run() {
'Enter research breadth (recommended 2-10, default 4): ',
),
10,
) || 4;
) || 4
const depth =
parseInt(
await askQuestion('Enter research depth (recommended 1-5, default 2): '),
10,
) || 2;
) || 2
console.log(`Creating research plan...`);
console.log(`Creating research plan...`)
// Generate follow-up questions
const followUpQuestions = await generateFeedback({
query: initialQuery,
});
})
console.log(
'\nTo better understand your research needs, please answer these follow-up questions:',
);
)
// Collect answers to follow-up questions
const answers: string[] = [];
const answers: string[] = []
for (const question of followUpQuestions) {
const answer = await askQuestion(`\n${question}\nYour answer: `);
answers.push(answer);
const answer = await askQuestion(`\n${question}\nYour answer: `)
answers.push(answer)
}
// Combine all information for deep research
@ -60,34 +60,34 @@ async function run() {
Initial Query: ${initialQuery}
Follow-up Questions and Answers:
${followUpQuestions.map((q, i) => `Q: ${q}\nA: ${answers[i]}`).join('\n')}
`;
`
console.log('\nResearching your topic...');
console.log('\nResearching your topic...')
const { learnings, visitedUrls } = await deepResearch({
query: combinedQuery,
breadth,
depth,
});
})
console.log(`\n\nLearnings:\n\n${learnings.join('\n')}`);
console.log(`\n\nLearnings:\n\n${learnings.join('\n')}`)
console.log(
`\n\nVisited URLs (${visitedUrls.length}):\n\n${visitedUrls.join('\n')}`,
);
console.log('Writing final report...');
)
console.log('Writing final report...')
const report = await writeFinalReport({
prompt: combinedQuery,
learnings,
visitedUrls,
});
})
// Save report to file
await fs.writeFile('output.md', report, 'utf-8');
await fs.writeFile('output.md', report, 'utf-8')
console.log(`\n\nFinal Report:\n\n${report}`);
console.log('\nReport has been saved to output.md');
rl.close();
console.log(`\n\nFinal Report:\n\n${report}`)
console.log('\nReport has been saved to output.md')
rl.close()
}
run().catch(console.error);
run().catch(console.error)