Diffstat (limited to 'packages/ai/src/generic.ts')
-rw-r--r-- | packages/ai/src/generic.ts | 204
1 file changed, 204 insertions(+), 0 deletions(-)
diff --git a/packages/ai/src/generic.ts b/packages/ai/src/generic.ts
new file mode 100644
index 0000000..8c41f19
--- /dev/null
+++ b/packages/ai/src/generic.ts
@@ -0,0 +1,204 @@
+import OpenAI from "openai";
+import { MAX_TOKENS, RESPONSE_LENGTH } from "./logic/constants";
+import type { AIModelAPI, ChatMessage, InputToken } from "./types";
+import type { AsyncRes } from "@sortug/lib";
+import type { ChatCompletionContentPart } from "openai/resources";
+import { memoize } from "./cache";
+
+type OChoice = OpenAI.Chat.Completions.ChatCompletion.Choice;
+type Message = OpenAI.Chat.Completions.ChatCompletionUserMessageParam;
+type Params = OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming;
+type OMessage = OpenAI.Chat.Completions.ChatCompletionMessageParam;
+
+type Props = {
+  baseURL: string;
+  apiKey: string;
+  model?: string;
+  maxTokens?: number;
+  tokenizer?: (text: string) => number;
+  allowBrowser?: boolean;
+};
+export default class OpenAIAPI implements AIModelAPI {
+  private cachedCreate!: (
+    args: Params,
+  ) => Promise<OpenAI.Chat.Completions.ChatCompletion>;
+
+  private apiKey;
+  private baseURL;
+  private api;
+  maxTokens: number = MAX_TOKENS;
+  tokenizer: (text: string) => number = (text) => text.length / 3;
+  model;
+
+  constructor(props: Props) {
+    this.apiKey = props.apiKey;
+    this.baseURL = props.baseURL;
+    this.api = new OpenAI({
+      baseURL: this.baseURL,
+      apiKey: this.apiKey,
+      dangerouslyAllowBrowser: props.allowBrowser || false,
+    });
+    this.model = props.model || "";
+    if (props.maxTokens) this.maxTokens = props.maxTokens;
+    if (props.tokenizer) this.tokenizer = props.tokenizer;
+
+    const boundCreate = this.api.chat.completions.create.bind(
+      this.api.chat.completions,
+    );
+
+    this.cachedCreate = memoize(boundCreate, {
+      ttlMs: 2 * 60 * 60 * 1000, // 2h
+      maxEntries: 5000,
+      persistDir: "./cache/memo",
+      // stable key for the call
+      keyFn: (args) => {
+        // args is the single object param to .create(...)
+        const {
+          model,
+          messages,
+          max_tokens,
+          temperature,
+          top_p,
+          frequency_penalty,
+          presence_penalty,
+          stop,
+        } = args as Params;
+        // stringify messages deterministically (role+content only)
+        const msg = (messages as any[])
+          .map((m) => ({ role: m.role, content: m.content }))
+          .slice(0, 200); // guard size if you want
+        return JSON.stringify({
+          model,
+          msg,
+          max_tokens,
+          temperature,
+          top_p,
+          frequency_penalty,
+          presence_penalty,
+          stop,
+        });
+      },
+    });
+  }
+  public setModel(model: string) {
+    this.model = model;
+  }
+  private mapMessages(input: ChatMessage[]): Message[] {
+    return input.map((m) => {
+      return { role: m.author as any, content: m.text, name: m.author };
+    });
+  }
+  private buildInput(tokens: InputToken[]): Message[] {
+    const content: ChatCompletionContentPart[] = tokens.map((t) => {
+      if ("text" in t) return { type: "text", text: t.text };
+      if ("img" in t) return { type: "image_url", image_url: { url: t.img } };
+      else return { type: "text", text: "oy vey" };
+    });
+    return [{ role: "user", content }];
+  }
+
+  public async send(
+    input: string | InputToken[],
+    sys?: string,
+  ): AsyncRes<string> {
+    const messages: Message[] =
+      typeof input === "string"
+        ? [{ role: "user" as const, content: input }]
+        : this.buildInput(input);
+    // const messages = this.mapMessages(input);
+    const allMessages: OMessage[] = sys
+      ? [{ role: "system", content: sys }, ...messages]
+      : messages;
+    const truncated = this.truncateHistory(allMessages);
+    const res = await this.apiCall(truncated);
+    if ("error" in res) return res;
+    else {
+      try {
+        // TODO type this properly
+        const choices: OChoice[] = res.ok;
+        const resText = choices
+          .map((c) => c.message.content || "")
+          .join("\n");
+        return { ok: resText };
+      } catch (e) {
+        return { error: `${e}` };
+      }
+    }
+  }
+
+  public async stream(
+    input: string | InputToken[],
+    handle: (c: string) => void,
+    sys?: string,
+  ) {
+    const messages: Message[] =
+      typeof input === "string"
+        ? [{ role: "user" as const, content: input }]
+        : this.buildInput(input);
+    // const messages = this.mapMessages(input);
+    const allMessages: OMessage[] = sys
+      ? [{ role: "system", content: sys }, ...messages]
+      : messages;
+    const truncated = this.truncateHistory(allMessages);
+    await this.apiCallStream(truncated, handle);
+  }
+
+  private truncateHistory(messages: OMessage[]): OMessage[] {
+    let totalTokens = messages.reduce((total, message) => {
+      return total + this.tokenizer(message.content as string);
+    }, 0);
+    while (totalTokens > this.maxTokens && messages.length > 1) {
+      // Always keep the system message if it exists
+      const startIndex = messages[0].role === "system" ? 1 : 0;
+      totalTokens -= this.tokenizer(messages[startIndex].content as string);
+      messages.splice(startIndex, 1);
+    }
+    return messages;
+  }
+
+  // TODO custom temperature?
+  private async apiCall(messages: OMessage[]): AsyncRes<OChoice[]> {
+    // console.log({ messages }, "at the very end");
+    try {
+      const completion = await this.cachedCreate({
+        // temperature: 1.3,
+        model: this.model,
+        messages,
+        max_tokens: RESPONSE_LENGTH,
+      });
+      if (!completion) return { error: "null response from openai" };
+      return { ok: completion.choices };
+    } catch (e) {
+      console.log(e, "error in openai api");
+      return { error: `${e}` };
+    }
+  }
+
+  private async apiCallStream(
+    messages: OMessage[],
+    handle: (c: string) => void,
+  ): Promise<void> {
+    try {
+      const stream = await this.api.chat.completions.create({
+        temperature: 1.3,
+        model: this.model,
+        messages,
+        max_tokens: RESPONSE_LENGTH,
+        stream: true,
+      });
+
+      for await (const chunk of stream) {
+        for (const choice of chunk.choices) {
+          console.log({ choice });
+          if (!choice.delta) continue;
+          const cont = choice.delta.content;
+          if (!cont) continue;
+          handle(cont);
+        }
+      }
+    } catch (e) {
+      console.log(e, "error in openai api");
+      handle(`Error streaming OpenAI, ${e}`);
+    }
+  }
+}
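Note: the memoize helper comes from ./cache, which is outside this diff. Judging only by the call site above (ttlMs, maxEntries, persistDir, keyFn), its contract is presumably along these lines; this is a sketch of the assumed shape, not the actual module:

// Assumed contract for ./cache's memoize, inferred from the call site; not part of this diff.
type MemoizeOpts<A extends unknown[]> = {
  ttlMs?: number;      // drop cached entries older than this
  maxEntries?: number; // cap on how many results are kept
  persistDir?: string; // directory for on-disk persistence
  keyFn?: (...args: A) => string; // deterministic key per distinct call
};

declare function memoize<A extends unknown[], R>(
  fn: (...args: A) => Promise<R>,
  opts?: MemoizeOpts<A>,
): (...args: A) => Promise<R>;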

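For reference, a minimal usage sketch of the class added in this diff. The endpoint, model name, and env var are hypothetical, and AsyncRes appears to resolve to { ok: T } | { error: string } judging by the "error" in res checks above:

import OpenAIAPI from "./generic";

// Hypothetical endpoint and model; any OpenAI-compatible server should work.
const api = new OpenAIAPI({
  baseURL: "https://api.openai.com/v1",
  apiKey: process.env.OPENAI_API_KEY ?? "",
  model: "gpt-4o-mini",
});

const res = await api.send("Summarize this diff in one sentence.");
if ("error" in res) console.error(res.error);
else console.log(res.ok);

// Streaming: each content delta is passed to the callback as it arrives.
await api.stream("Same question, but streamed.", (chunk) =>
  process.stdout.write(chunk),
);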