refactor: move general utils to shared directory

This commit is contained in:
AnotiaWang
2025-03-23 14:55:19 +08:00
parent ff1e79603d
commit 60f6f227c3
11 changed files with 124 additions and 63 deletions

View File

@ -272,6 +272,7 @@
retryNode, retryNode,
currentDepth, currentDepth,
breadth, breadth,
aiConfig: config.ai,
maxDepth: form.value.depth, maxDepth: form.value.depth,
languageCode: locale.value, languageCode: locale.value,
searchLanguageCode: config.webSearch.searchLanguage, searchLanguageCode: config.webSearch.searchLanguage,

View File

@ -1,9 +1,6 @@
<script setup lang="ts"> <script setup lang="ts">
const { locale: globalLocale, availableLocales, t, setLocale } = useI18n() const { locale: globalLocale, availableLocales, t, setLocale } = useI18n()
export type Locale = (typeof globalLocale)['value']
export type AvailableLocales = Locale[]
const props = defineProps<{ const props = defineProps<{
/** Override display locale */ /** Override display locale */
value?: Locale value?: Locale

View File

@ -19,7 +19,7 @@
}>() }>()
const { t, locale } = useI18n() const { t, locale } = useI18n()
const { showConfigManager, isConfigValid } = storeToRefs(useConfigStore()) const { showConfigManager, isConfigValid, config } = storeToRefs(useConfigStore())
const toast = useToast() const toast = useToast()
const reasoningContent = ref('') const reasoningContent = ref('')
@ -57,6 +57,7 @@
query: form.value.query, query: form.value.query,
numQuestions: form.value.numQuestions, numQuestions: form.value.numQuestions,
language: t('language', {}, { locale: locale.value }), language: t('language', {}, { locale: locale.value }),
aiConfig: config.value.ai,
})) { })) {
if (f.type === 'reasoning') { if (f.type === 'reasoning') {
reasoningContent.value += f.delta reasoningContent.value += f.delta

View File

@ -8,6 +8,7 @@
} from '~/constants/injection-keys' } from '~/constants/injection-keys'
const { t, locale } = useI18n() const { t, locale } = useI18n()
const { config } = storeToRefs(useConfigStore())
const toast = useToast() const toast = useToast()
const error = ref('') const error = ref('')
@ -156,6 +157,7 @@
prompt: getCombinedQuery(form.value, feedback.value), prompt: getCombinedQuery(form.value, feedback.value),
language: t('language', {}, { locale: locale.value }), language: t('language', {}, { locale: locale.value }),
learnings, learnings,
aiConfig: config.value.ai,
}) })
for await (const chunk of fullStream) { for await (const chunk of fullStream) {
if (chunk.type === 'reasoning') { if (chunk.type === 'reasoning') {

View File

@ -1,5 +1,5 @@
import { skipHydrate } from 'pinia' import { skipHydrate } from 'pinia'
import type { Locale } from '@/components/LangSwitcher.vue' import { getApiBase } from '~~/shared/utils/ai-model'
export type ConfigAiProvider = export type ConfigAiProvider =
| 'openai-compatible' | 'openai-compatible'
@ -71,25 +71,7 @@ export const useConfigStore = defineStore('config', () => {
) )
const isConfigValid = computed(() => validateConfig(config.value)) const isConfigValid = computed(() => validateConfig(config.value))
const aiApiBase = computed(() => { const aiApiBase = computed(() => getApiBase(config.value.ai))
const { ai } = config.value
if (ai.provider === 'openrouter') {
return ai.apiBase || 'https://openrouter.ai/api/v1'
}
if (ai.provider === 'deepseek') {
return ai.apiBase || 'https://api.deepseek.com/v1'
}
if (ai.provider === 'ollama') {
return ai.apiBase || 'http://localhost:11434/v1'
}
if (ai.provider === 'siliconflow') {
return ai.apiBase || 'https://api.siliconflow.cn/v1'
}
if (ai.provider === 'infiniai') {
return ai.apiBase || 'https://cloud.infini-ai.com/maas/v1'
}
return ai.apiBase || 'https://api.openai.com/v1'
})
const webSearchApiBase = computed(() => { const webSearchApiBase = computed(() => {
const { webSearch } = config.value const { webSearch } = config.value
if (webSearch.provider === 'tavily') { if (webSearch.provider === 'tavily') {

View File

@ -1,14 +1,12 @@
import { streamText } from 'ai' import { streamText } from 'ai'
import { z } from 'zod' import { z } from 'zod'
import { parseStreamingJson, type DeepPartial } from '~~/utils/json' import { parseStreamingJson, type DeepPartial } from '~~/shared/utils/json'
import { trimPrompt } from './ai/providers' import { trimPrompt } from './ai/providers'
import { languagePrompt, systemPrompt } from './prompt' import { languagePrompt, systemPrompt } from './prompt'
import zodToJsonSchema from 'zod-to-json-schema' import zodToJsonSchema from 'zod-to-json-schema'
import { useAiModel } from '@/composables/useAiProvider'
import type { Locale } from '@/components/LangSwitcher.vue'
import type { DeepResearchNode } from '@/components/DeepResearch/DeepResearch.vue' import type { DeepResearchNode } from '@/components/DeepResearch/DeepResearch.vue'
import { throwAiError } from '~~/utils/errors' import { throwAiError } from '~~/shared/utils/errors'
export type ResearchResult = { export type ResearchResult = {
learnings: ProcessedSearchResult['learnings'] learnings: ProcessedSearchResult['learnings']
@ -18,6 +16,7 @@ export interface WriteFinalReportParams {
prompt: string prompt: string
learnings: ProcessedSearchResult['learnings'] learnings: ProcessedSearchResult['learnings']
language: string language: string
aiConfig: ConfigAi
} }
// Used for streaming response // Used for streaming response
@ -80,6 +79,7 @@ export function generateSearchQueries({
learnings, learnings,
language, language,
searchLanguage, searchLanguage,
aiConfig,
}: { }: {
query: string query: string
language: string language: string
@ -88,6 +88,7 @@ export function generateSearchQueries({
learnings?: string[] learnings?: string[]
/** Force the LLM to generate serp queries in a certain language */ /** Force the LLM to generate serp queries in a certain language */
searchLanguage?: string searchLanguage?: string
aiConfig: ConfigAi
}) { }) {
const schema = z.object({ const schema = z.object({
queries: z queries: z
@ -122,7 +123,7 @@ export function generateSearchQueries({
lp, lp,
].join('\n\n') ].join('\n\n')
return streamText({ return streamText({
model: useAiModel(), model: getLanguageModel(aiConfig),
system: systemPrompt(), system: systemPrompt(),
prompt, prompt,
onError({ error }) { onError({ error }) {
@ -149,12 +150,14 @@ function processSearchResult({
numLearnings = 5, numLearnings = 5,
numFollowUpQuestions = 3, numFollowUpQuestions = 3,
language, language,
aiConfig,
}: { }: {
query: string query: string
results: WebSearchResult[] results: WebSearchResult[]
language: string language: string
numLearnings?: number numLearnings?: number
numFollowUpQuestions?: number numFollowUpQuestions?: number
aiConfig: ConfigAi
}) { }) {
const schema = z.object({ const schema = z.object({
learnings: z learnings: z
@ -194,7 +197,7 @@ function processSearchResult({
].join('\n\n') ].join('\n\n')
return streamText({ return streamText({
model: useAiModel(), model: getLanguageModel(aiConfig),
system: systemPrompt(), system: systemPrompt(),
prompt, prompt,
onError({ error }) { onError({ error }) {
@ -207,6 +210,7 @@ export function writeFinalReport({
prompt, prompt,
learnings, learnings,
language, language,
aiConfig,
}: WriteFinalReportParams) { }: WriteFinalReportParams) {
const learningsString = trimPrompt( const learningsString = trimPrompt(
learnings learnings
@ -229,7 +233,7 @@ ${learning.learning}
].join('\n\n') ].join('\n\n')
return streamText({ return streamText({
model: useAiModel(), model: getLanguageModel(aiConfig),
system: systemPrompt(), system: systemPrompt(),
prompt: _prompt, prompt: _prompt,
onError({ error }) { onError({ error }) {
@ -247,6 +251,7 @@ export async function deepResearch({
breadth, breadth,
maxDepth, maxDepth,
languageCode, languageCode,
aiConfig,
searchLanguageCode, searchLanguageCode,
learnings, learnings,
onProgress, onProgress,
@ -259,6 +264,8 @@ export async function deepResearch({
maxDepth: number maxDepth: number
/** The language of generated response */ /** The language of generated response */
languageCode: Locale languageCode: Locale
/** The AI model configuration */
aiConfig: ConfigAi
/** The language of SERP query */ /** The language of SERP query */
searchLanguageCode?: Locale searchLanguageCode?: Locale
/** Accumulated learnings from all nodes visited so far */ /** Accumulated learnings from all nodes visited so far */
@ -299,6 +306,7 @@ export async function deepResearch({
numQueries: breadth, numQueries: breadth,
language, language,
searchLanguage, searchLanguage,
aiConfig,
}) })
for await (const chunk of parseStreamingJson( for await (const chunk of parseStreamingJson(
@ -370,7 +378,6 @@ export async function deepResearch({
if (!searchQuery?.query) { if (!searchQuery?.query) {
return { return {
learnings: [], learnings: [],
visitedUrls: [],
} }
} }
onProgress({ onProgress({
@ -404,6 +411,7 @@ export async function deepResearch({
results, results,
numFollowUpQuestions: nextBreadth, numFollowUpQuestions: nextBreadth,
language, language,
aiConfig,
}) })
let searchResult: PartialProcessedSearchResult = {} let searchResult: PartialProcessedSearchResult = {}
@ -495,6 +503,7 @@ export async function deepResearch({
currentDepth: nextDepth, currentDepth: nextDepth,
nodeId: searchQuery.nodeId, nodeId: searchQuery.nodeId,
languageCode, languageCode,
aiConfig,
}) })
return r return r
} catch (error) { } catch (error) {

View File

@ -3,9 +3,9 @@ import { z } from 'zod'
import { zodToJsonSchema } from 'zod-to-json-schema' import { zodToJsonSchema } from 'zod-to-json-schema'
import { languagePrompt, systemPrompt } from './prompt' import { languagePrompt, systemPrompt } from './prompt'
import { useAiModel } from '~/composables/useAiProvider' import { parseStreamingJson, type DeepPartial } from '~~/shared/utils/json'
import { parseStreamingJson, type DeepPartial } from '~~/utils/json' import { throwAiError } from '~~/shared/utils/errors'
import { throwAiError } from '~~/utils/errors' import { getLanguageModel } from '~~/shared/utils/ai-model'
type PartialFeedback = DeepPartial<z.infer<typeof feedbackTypeSchema>> type PartialFeedback = DeepPartial<z.infer<typeof feedbackTypeSchema>>
@ -17,9 +17,11 @@ export function generateFeedback({
query, query,
language, language,
numQuestions = 3, numQuestions = 3,
aiConfig,
}: { }: {
query: string query: string
language: string language: string
aiConfig: ConfigAi
numQuestions?: number numQuestions?: number
}) { }) {
const schema = z.object({ const schema = z.object({
@ -36,7 +38,7 @@ export function generateFeedback({
].join('\n\n') ].join('\n\n')
const stream = streamText({ const stream = streamText({
model: useAiModel(), model: getLanguageModel(aiConfig),
system: systemPrompt(), system: systemPrompt(),
prompt, prompt,
onError({ error }) { onError({ error }) {

4
shared/types/types.d.ts vendored Normal file
View File

@ -0,0 +1,4 @@
import type { NuxtApp } from "#app";
// Locale codes registered with the app's i18n instance. Derived from the
// Nuxt runtime type so the union stays in sync with the i18n configuration
// instead of being maintained by hand.
export type AvailableLocales = NuxtApp['$i18n']['availableLocales']
// A single locale code — one element of AvailableLocales.
export type Locale = AvailableLocales[number]

63
shared/utils/ai-model.ts Normal file
View File

@ -0,0 +1,63 @@
import { createDeepSeek } from "@ai-sdk/deepseek"
import { createOpenAI } from "@ai-sdk/openai"
import { createOpenRouter } from "@openrouter/ai-sdk-provider"
import { wrapLanguageModel, extractReasoningMiddleware } from "ai"
import type { LanguageModelV1 } from "ai"
export function getLanguageModel(config: ConfigAi) {
const apiBase = getApiBase(config)
let model: LanguageModelV1
if (config.provider === 'openrouter') {
const openRouter = createOpenRouter({
apiKey: config.apiKey,
baseURL: apiBase,
})
model = openRouter(config.model, {
includeReasoning: true,
})
} else if (
config.provider === 'deepseek' ||
config.provider === 'siliconflow' ||
config.provider === 'infiniai' ||
// Special case if model name includes 'deepseek'
// This ensures compatibilty with providers like Siliconflow
config.model?.toLowerCase().includes('deepseek')
) {
const deepSeek = createDeepSeek({
apiKey: config.apiKey,
baseURL: apiBase,
})
model = deepSeek(config.model)
} else {
const openai = createOpenAI({
apiKey: config.apiKey,
baseURL: apiBase,
})
model = openai(config.model)
}
return wrapLanguageModel({
model,
middleware: extractReasoningMiddleware({ tagName: 'think' }),
})
}
/**
 * Resolve the API base URL for the configured AI provider.
 *
 * An explicitly configured `apiBase` always wins; otherwise the provider's
 * well-known default endpoint is used, with the OpenAI endpoint as the
 * fallback for any provider not listed (e.g. openai-compatible).
 */
export function getApiBase(config: ConfigAi) {
  const defaultBases: Record<string, string> = {
    openrouter: 'https://openrouter.ai/api/v1',
    deepseek: 'https://api.deepseek.com/v1',
    ollama: 'http://localhost:11434/v1',
    siliconflow: 'https://api.siliconflow.cn/v1',
    infiniai: 'https://cloud.infini-ai.com/maas/v1',
  }
  const fallback = defaultBases[config.provider] ?? 'https://api.openai.com/v1'
  return config.apiBase || fallback
}