init
This commit is contained in:
58
lib/ai/providers.ts
Normal file
58
lib/ai/providers.ts
Normal file
@ -0,0 +1,58 @@
|
||||
import { createOpenAI } from '@ai-sdk/openai';
|
||||
import { getEncoding } from 'js-tiktoken';
|
||||
|
||||
import { RecursiveCharacterTextSplitter } from './text-splitter';
|
||||
|
||||
// Providers
|
||||
const openai = createOpenAI({
|
||||
apiKey: import.meta.env.VITE_OPENAI_API_KEY!,
|
||||
baseURL: import.meta.env.VITE_OPENAI_ENDPOINT || 'https://api.openai.com/v1',
|
||||
});
|
||||
|
||||
const customModel = import.meta.env.VITE_OPENAI_MODEL || 'o3-mini';
|
||||
|
||||
// Models
|
||||
|
||||
export const o3MiniModel = openai(customModel, {
|
||||
// reasoningEffort: customModel.startsWith('o') ? 'medium' : undefined,
|
||||
structuredOutputs: true,
|
||||
});
|
||||
|
||||
const MinChunkSize = 140;
|
||||
const encoder = getEncoding('o200k_base');
|
||||
|
||||
// trim prompt to maximum context size
|
||||
export function trimPrompt(
|
||||
prompt: string,
|
||||
contextSize = Number(import.meta.env.VITE_CONTEXT_SIZE) || 128_000,
|
||||
) {
|
||||
if (!prompt) {
|
||||
return '';
|
||||
}
|
||||
|
||||
const length = encoder.encode(prompt).length;
|
||||
if (length <= contextSize) {
|
||||
return prompt;
|
||||
}
|
||||
|
||||
const overflowTokens = length - contextSize;
|
||||
// on average it's 3 characters per token, so multiply by 3 to get a rough estimate of the number of characters
|
||||
const chunkSize = prompt.length - overflowTokens * 3;
|
||||
if (chunkSize < MinChunkSize) {
|
||||
return prompt.slice(0, MinChunkSize);
|
||||
}
|
||||
|
||||
const splitter = new RecursiveCharacterTextSplitter({
|
||||
chunkSize,
|
||||
chunkOverlap: 0,
|
||||
});
|
||||
const trimmedPrompt = splitter.splitText(prompt)[0] ?? '';
|
||||
|
||||
// last catch, there's a chance that the trimmed prompt is same length as the original prompt, due to how tokens are split & innerworkings of the splitter, handle this case by just doing a hard cut
|
||||
if (trimmedPrompt.length === prompt.length) {
|
||||
return trimPrompt(prompt.slice(0, chunkSize), contextSize);
|
||||
}
|
||||
|
||||
// recursively trim until the prompt is within the context size
|
||||
return trimPrompt(trimmedPrompt, contextSize);
|
||||
}
|
77
lib/ai/text-splitter.test.ts
Normal file
77
lib/ai/text-splitter.test.ts
Normal file
@ -0,0 +1,77 @@
|
||||
import assert from 'node:assert';
|
||||
import { describe, it, beforeEach } from 'node:test';
|
||||
import { RecursiveCharacterTextSplitter } from './text-splitter';
|
||||
|
||||
describe('RecursiveCharacterTextSplitter', () => {
|
||||
let splitter: RecursiveCharacterTextSplitter;
|
||||
|
||||
beforeEach(() => {
|
||||
splitter = new RecursiveCharacterTextSplitter({
|
||||
chunkSize: 50,
|
||||
chunkOverlap: 10,
|
||||
});
|
||||
});
|
||||
|
||||
it('Should correctly split text by separators', () => {
|
||||
const text = 'Hello world, this is a test of the recursive text splitter.';
|
||||
|
||||
// Test with initial chunkSize
|
||||
assert.deepEqual(
|
||||
splitter.splitText(text),
|
||||
['Hello world', 'this is a test of the recursive text splitter']
|
||||
);
|
||||
|
||||
// Test with updated chunkSize
|
||||
splitter.chunkSize = 100;
|
||||
assert.deepEqual(
|
||||
splitter.splitText(
|
||||
'Hello world, this is a test of the recursive text splitter. If I have a period, it should split along the period.'
|
||||
),
|
||||
[
|
||||
'Hello world, this is a test of the recursive text splitter',
|
||||
'If I have a period, it should split along the period.',
|
||||
]
|
||||
);
|
||||
|
||||
// Test with another updated chunkSize
|
||||
splitter.chunkSize = 110;
|
||||
assert.deepEqual(
|
||||
splitter.splitText(
|
||||
'Hello world, this is a test of the recursive text splitter. If I have a period, it should split along the period.\nOr, if there is a new line, it should prioritize splitting on new lines instead.'
|
||||
),
|
||||
[
|
||||
'Hello world, this is a test of the recursive text splitter',
|
||||
'If I have a period, it should split along the period.',
|
||||
'Or, if there is a new line, it should prioritize splitting on new lines instead.',
|
||||
]
|
||||
);
|
||||
});
|
||||
|
||||
it('Should handle empty string', () => {
|
||||
assert.deepEqual(splitter.splitText(''), []);
|
||||
});
|
||||
|
||||
it('Should handle special characters and large texts', () => {
|
||||
const largeText = 'A'.repeat(1000);
|
||||
splitter.chunkSize = 200;
|
||||
assert.deepEqual(
|
||||
splitter.splitText(largeText),
|
||||
Array(5).fill('A'.repeat(200))
|
||||
);
|
||||
|
||||
const specialCharText = 'Hello!@# world$%^ &*( this) is+ a-test';
|
||||
assert.deepEqual(
|
||||
splitter.splitText(specialCharText),
|
||||
['Hello!@#', 'world$%^', '&*( this)', 'is+', 'a-test']
|
||||
);
|
||||
});
|
||||
|
||||
it('Should handle chunkSize equal to chunkOverlap', () => {
|
||||
splitter.chunkSize = 50;
|
||||
splitter.chunkOverlap = 50;
|
||||
assert.throws(
|
||||
() => splitter.splitText('Invalid configuration'),
|
||||
new Error('Cannot have chunkOverlap >= chunkSize')
|
||||
);
|
||||
});
|
||||
});
|
143
lib/ai/text-splitter.ts
Normal file
143
lib/ai/text-splitter.ts
Normal file
@ -0,0 +1,143 @@
|
||||
interface TextSplitterParams {
|
||||
chunkSize: number;
|
||||
|
||||
chunkOverlap: number;
|
||||
}
|
||||
|
||||
abstract class TextSplitter implements TextSplitterParams {
|
||||
chunkSize = 1000;
|
||||
chunkOverlap = 200;
|
||||
|
||||
constructor(fields?: Partial<TextSplitterParams>) {
|
||||
this.chunkSize = fields?.chunkSize ?? this.chunkSize;
|
||||
this.chunkOverlap = fields?.chunkOverlap ?? this.chunkOverlap;
|
||||
if (this.chunkOverlap >= this.chunkSize) {
|
||||
throw new Error('Cannot have chunkOverlap >= chunkSize');
|
||||
}
|
||||
}
|
||||
|
||||
abstract splitText(text: string): string[];
|
||||
|
||||
createDocuments(texts: string[]): string[] {
|
||||
const documents: string[] = [];
|
||||
for (let i = 0; i < texts.length; i += 1) {
|
||||
const text = texts[i];
|
||||
for (const chunk of this.splitText(text!)) {
|
||||
documents.push(chunk);
|
||||
}
|
||||
}
|
||||
return documents;
|
||||
}
|
||||
|
||||
splitDocuments(documents: string[]): string[] {
|
||||
return this.createDocuments(documents);
|
||||
}
|
||||
|
||||
private joinDocs(docs: string[], separator: string): string | null {
|
||||
const text = docs.join(separator).trim();
|
||||
return text === '' ? null : text;
|
||||
}
|
||||
|
||||
mergeSplits(splits: string[], separator: string): string[] {
|
||||
const docs: string[] = [];
|
||||
const currentDoc: string[] = [];
|
||||
let total = 0;
|
||||
for (const d of splits) {
|
||||
const _len = d.length;
|
||||
if (total + _len >= this.chunkSize) {
|
||||
if (total > this.chunkSize) {
|
||||
console.warn(
|
||||
`Created a chunk of size ${total}, +
|
||||
which is longer than the specified ${this.chunkSize}`,
|
||||
);
|
||||
}
|
||||
if (currentDoc.length > 0) {
|
||||
const doc = this.joinDocs(currentDoc, separator);
|
||||
if (doc !== null) {
|
||||
docs.push(doc);
|
||||
}
|
||||
// Keep on popping if:
|
||||
// - we have a larger chunk than in the chunk overlap
|
||||
// - or if we still have any chunks and the length is long
|
||||
while (
|
||||
total > this.chunkOverlap ||
|
||||
(total + _len > this.chunkSize && total > 0)
|
||||
) {
|
||||
total -= currentDoc[0]!.length;
|
||||
currentDoc.shift();
|
||||
}
|
||||
}
|
||||
}
|
||||
currentDoc.push(d);
|
||||
total += _len;
|
||||
}
|
||||
const doc = this.joinDocs(currentDoc, separator);
|
||||
if (doc !== null) {
|
||||
docs.push(doc);
|
||||
}
|
||||
return docs;
|
||||
}
|
||||
}
|
||||
|
||||
export interface RecursiveCharacterTextSplitterParams
|
||||
extends TextSplitterParams {
|
||||
separators: string[];
|
||||
}
|
||||
|
||||
export class RecursiveCharacterTextSplitter
|
||||
extends TextSplitter
|
||||
implements RecursiveCharacterTextSplitterParams
|
||||
{
|
||||
separators: string[] = ['\n\n', '\n', '.', ',', '>', '<', ' ', ''];
|
||||
|
||||
constructor(fields?: Partial<RecursiveCharacterTextSplitterParams>) {
|
||||
super(fields);
|
||||
this.separators = fields?.separators ?? this.separators;
|
||||
}
|
||||
|
||||
splitText(text: string): string[] {
|
||||
const finalChunks: string[] = [];
|
||||
|
||||
// Get appropriate separator to use
|
||||
let separator: string = this.separators[this.separators.length - 1]!;
|
||||
for (const s of this.separators) {
|
||||
if (s === '') {
|
||||
separator = s;
|
||||
break;
|
||||
}
|
||||
if (text.includes(s)) {
|
||||
separator = s;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Now that we have the separator, split the text
|
||||
let splits: string[];
|
||||
if (separator) {
|
||||
splits = text.split(separator);
|
||||
} else {
|
||||
splits = text.split('');
|
||||
}
|
||||
|
||||
// Now go merging things, recursively splitting longer texts.
|
||||
let goodSplits: string[] = [];
|
||||
for (const s of splits) {
|
||||
if (s.length < this.chunkSize) {
|
||||
goodSplits.push(s);
|
||||
} else {
|
||||
if (goodSplits.length) {
|
||||
const mergedText = this.mergeSplits(goodSplits, separator);
|
||||
finalChunks.push(...mergedText);
|
||||
goodSplits = [];
|
||||
}
|
||||
const otherInfo = this.splitText(s);
|
||||
finalChunks.push(...otherInfo);
|
||||
}
|
||||
}
|
||||
if (goodSplits.length) {
|
||||
const mergedText = this.mergeSplits(goodSplits, separator);
|
||||
finalChunks.push(...mergedText);
|
||||
}
|
||||
return finalChunks;
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user