style: prettier format

This commit is contained in:
AnotiaWang
2025-02-11 22:57:49 +08:00
parent 84f63abb3d
commit d027965013
23 changed files with 485 additions and 331 deletions

View File

@ -1,25 +1,25 @@
import { createOpenAI } from '@ai-sdk/openai';
import { getEncoding } from 'js-tiktoken';
import { createOpenAI } from '@ai-sdk/openai'
import { getEncoding } from 'js-tiktoken'
import { RecursiveCharacterTextSplitter } from './text-splitter';
import { RecursiveCharacterTextSplitter } from './text-splitter'
// Providers
const openai = createOpenAI({
apiKey: import.meta.env.VITE_OPENAI_API_KEY!,
baseURL: import.meta.env.VITE_OPENAI_ENDPOINT || 'https://api.openai.com/v1',
});
})
const customModel = import.meta.env.VITE_OPENAI_MODEL || 'o3-mini';
const customModel = import.meta.env.VITE_OPENAI_MODEL || 'o3-mini'
// Models
export const o3MiniModel = openai(customModel, {
// reasoningEffort: customModel.startsWith('o') ? 'medium' : undefined,
structuredOutputs: true,
});
})
const MinChunkSize = 140;
const encoder = getEncoding('o200k_base');
const MinChunkSize = 140
const encoder = getEncoding('o200k_base')
// trim prompt to maximum context size
export function trimPrompt(
@ -27,32 +27,32 @@ export function trimPrompt(
contextSize = Number(import.meta.env.VITE_CONTEXT_SIZE) || 128_000,
) {
if (!prompt) {
return '';
return ''
}
const length = encoder.encode(prompt).length;
const length = encoder.encode(prompt).length
if (length <= contextSize) {
return prompt;
return prompt
}
const overflowTokens = length - contextSize;
const overflowTokens = length - contextSize
// on average it's 3 characters per token, so multiply by 3 to get a rough estimate of the number of characters
const chunkSize = prompt.length - overflowTokens * 3;
const chunkSize = prompt.length - overflowTokens * 3
if (chunkSize < MinChunkSize) {
return prompt.slice(0, MinChunkSize);
return prompt.slice(0, MinChunkSize)
}
const splitter = new RecursiveCharacterTextSplitter({
chunkSize,
chunkOverlap: 0,
});
const trimmedPrompt = splitter.splitText(prompt)[0] ?? '';
})
const trimmedPrompt = splitter.splitText(prompt)[0] ?? ''
// last catch, there's a chance that the trimmed prompt is same length as the original prompt, due to how tokens are split & innerworkings of the splitter, handle this case by just doing a hard cut
if (trimmedPrompt.length === prompt.length) {
return trimPrompt(prompt.slice(0, chunkSize), contextSize);
return trimPrompt(prompt.slice(0, chunkSize), contextSize)
}
// recursively trim until the prompt is within the context size
return trimPrompt(trimmedPrompt, contextSize);
return trimPrompt(trimmedPrompt, contextSize)
}

View File

@ -1,77 +1,80 @@
import assert from 'node:assert';
import { describe, it, beforeEach } from 'node:test';
import { RecursiveCharacterTextSplitter } from './text-splitter';
import assert from 'node:assert'
import { describe, it, beforeEach } from 'node:test'
import { RecursiveCharacterTextSplitter } from './text-splitter'
describe('RecursiveCharacterTextSplitter', () => {
let splitter: RecursiveCharacterTextSplitter;
let splitter: RecursiveCharacterTextSplitter
beforeEach(() => {
splitter = new RecursiveCharacterTextSplitter({
chunkSize: 50,
chunkOverlap: 10,
});
});
})
})
it('Should correctly split text by separators', () => {
const text = 'Hello world, this is a test of the recursive text splitter.';
const text = 'Hello world, this is a test of the recursive text splitter.'
// Test with initial chunkSize
assert.deepEqual(
splitter.splitText(text),
['Hello world', 'this is a test of the recursive text splitter']
);
assert.deepEqual(splitter.splitText(text), [
'Hello world',
'this is a test of the recursive text splitter',
])
// Test with updated chunkSize
splitter.chunkSize = 100;
splitter.chunkSize = 100
assert.deepEqual(
splitter.splitText(
'Hello world, this is a test of the recursive text splitter. If I have a period, it should split along the period.'
'Hello world, this is a test of the recursive text splitter. If I have a period, it should split along the period.',
),
[
'Hello world, this is a test of the recursive text splitter',
'If I have a period, it should split along the period.',
]
);
],
)
// Test with another updated chunkSize
splitter.chunkSize = 110;
splitter.chunkSize = 110
assert.deepEqual(
splitter.splitText(
'Hello world, this is a test of the recursive text splitter. If I have a period, it should split along the period.\nOr, if there is a new line, it should prioritize splitting on new lines instead.'
'Hello world, this is a test of the recursive text splitter. If I have a period, it should split along the period.\nOr, if there is a new line, it should prioritize splitting on new lines instead.',
),
[
'Hello world, this is a test of the recursive text splitter',
'If I have a period, it should split along the period.',
'Or, if there is a new line, it should prioritize splitting on new lines instead.',
]
);
});
],
)
})
it('Should handle empty string', () => {
assert.deepEqual(splitter.splitText(''), []);
});
assert.deepEqual(splitter.splitText(''), [])
})
it('Should handle special characters and large texts', () => {
const largeText = 'A'.repeat(1000);
splitter.chunkSize = 200;
const largeText = 'A'.repeat(1000)
splitter.chunkSize = 200
assert.deepEqual(
splitter.splitText(largeText),
Array(5).fill('A'.repeat(200))
);
Array(5).fill('A'.repeat(200)),
)
const specialCharText = 'Hello!@# world$%^ &*( this) is+ a-test';
assert.deepEqual(
splitter.splitText(specialCharText),
['Hello!@#', 'world$%^', '&*( this)', 'is+', 'a-test']
);
});
const specialCharText = 'Hello!@# world$%^ &*( this) is+ a-test'
assert.deepEqual(splitter.splitText(specialCharText), [
'Hello!@#',
'world$%^',
'&*( this)',
'is+',
'a-test',
])
})
it('Should handle chunkSize equal to chunkOverlap', () => {
splitter.chunkSize = 50;
splitter.chunkOverlap = 50;
splitter.chunkSize = 50
splitter.chunkOverlap = 50
assert.throws(
() => splitter.splitText('Invalid configuration'),
new Error('Cannot have chunkOverlap >= chunkSize')
);
});
});
new Error('Cannot have chunkOverlap >= chunkSize'),
)
})
})

View File

@ -1,60 +1,60 @@
interface TextSplitterParams {
chunkSize: number;
chunkSize: number
chunkOverlap: number;
chunkOverlap: number
}
abstract class TextSplitter implements TextSplitterParams {
chunkSize = 1000;
chunkOverlap = 200;
chunkSize = 1000
chunkOverlap = 200
constructor(fields?: Partial<TextSplitterParams>) {
this.chunkSize = fields?.chunkSize ?? this.chunkSize;
this.chunkOverlap = fields?.chunkOverlap ?? this.chunkOverlap;
this.chunkSize = fields?.chunkSize ?? this.chunkSize
this.chunkOverlap = fields?.chunkOverlap ?? this.chunkOverlap
if (this.chunkOverlap >= this.chunkSize) {
throw new Error('Cannot have chunkOverlap >= chunkSize');
throw new Error('Cannot have chunkOverlap >= chunkSize')
}
}
abstract splitText(text: string): string[];
abstract splitText(text: string): string[]
createDocuments(texts: string[]): string[] {
const documents: string[] = [];
const documents: string[] = []
for (let i = 0; i < texts.length; i += 1) {
const text = texts[i];
const text = texts[i]
for (const chunk of this.splitText(text!)) {
documents.push(chunk);
documents.push(chunk)
}
}
return documents;
return documents
}
splitDocuments(documents: string[]): string[] {
return this.createDocuments(documents);
return this.createDocuments(documents)
}
private joinDocs(docs: string[], separator: string): string | null {
const text = docs.join(separator).trim();
return text === '' ? null : text;
const text = docs.join(separator).trim()
return text === '' ? null : text
}
mergeSplits(splits: string[], separator: string): string[] {
const docs: string[] = [];
const currentDoc: string[] = [];
let total = 0;
const docs: string[] = []
const currentDoc: string[] = []
let total = 0
for (const d of splits) {
const _len = d.length;
const _len = d.length
if (total + _len >= this.chunkSize) {
if (total > this.chunkSize) {
console.warn(
`Created a chunk of size ${total}, +
which is longer than the specified ${this.chunkSize}`,
);
)
}
if (currentDoc.length > 0) {
const doc = this.joinDocs(currentDoc, separator);
const doc = this.joinDocs(currentDoc, separator)
if (doc !== null) {
docs.push(doc);
docs.push(doc)
}
// Keep on popping if:
// - we have a larger chunk than in the chunk overlap
@ -63,81 +63,81 @@ which is longer than the specified ${this.chunkSize}`,
total > this.chunkOverlap ||
(total + _len > this.chunkSize && total > 0)
) {
total -= currentDoc[0]!.length;
currentDoc.shift();
total -= currentDoc[0]!.length
currentDoc.shift()
}
}
}
currentDoc.push(d);
total += _len;
currentDoc.push(d)
total += _len
}
const doc = this.joinDocs(currentDoc, separator);
const doc = this.joinDocs(currentDoc, separator)
if (doc !== null) {
docs.push(doc);
docs.push(doc)
}
return docs;
return docs
}
}
export interface RecursiveCharacterTextSplitterParams
extends TextSplitterParams {
separators: string[];
separators: string[]
}
export class RecursiveCharacterTextSplitter
extends TextSplitter
implements RecursiveCharacterTextSplitterParams
{
separators: string[] = ['\n\n', '\n', '.', ',', '>', '<', ' ', ''];
separators: string[] = ['\n\n', '\n', '.', ',', '>', '<', ' ', '']
constructor(fields?: Partial<RecursiveCharacterTextSplitterParams>) {
super(fields);
this.separators = fields?.separators ?? this.separators;
super(fields)
this.separators = fields?.separators ?? this.separators
}
splitText(text: string): string[] {
const finalChunks: string[] = [];
const finalChunks: string[] = []
// Get appropriate separator to use
let separator: string = this.separators[this.separators.length - 1]!;
let separator: string = this.separators[this.separators.length - 1]!
for (const s of this.separators) {
if (s === '') {
separator = s;
break;
separator = s
break
}
if (text.includes(s)) {
separator = s;
break;
separator = s
break
}
}
// Now that we have the separator, split the text
let splits: string[];
let splits: string[]
if (separator) {
splits = text.split(separator);
splits = text.split(separator)
} else {
splits = text.split('');
splits = text.split('')
}
// Now go merging things, recursively splitting longer texts.
let goodSplits: string[] = [];
let goodSplits: string[] = []
for (const s of splits) {
if (s.length < this.chunkSize) {
goodSplits.push(s);
goodSplits.push(s)
} else {
if (goodSplits.length) {
const mergedText = this.mergeSplits(goodSplits, separator);
finalChunks.push(...mergedText);
goodSplits = [];
const mergedText = this.mergeSplits(goodSplits, separator)
finalChunks.push(...mergedText)
goodSplits = []
}
const otherInfo = this.splitText(s);
finalChunks.push(...otherInfo);
const otherInfo = this.splitText(s)
finalChunks.push(...otherInfo)
}
}
if (goodSplits.length) {
const mergedText = this.mergeSplits(goodSplits, separator);
finalChunks.push(...mergedText);
const mergedText = this.mergeSplits(goodSplits, separator)
finalChunks.push(...mergedText)
}
return finalChunks;
return finalChunks
}
}