From 42dd99bfac9777a4ecc6700b87edf26a5c984de6 Mon Sep 17 00:00:00 2001 From: polwex Date: Wed, 23 Jul 2025 02:37:15 +0700 Subject: checkpoint --- src/generic.ts | 128 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 128 insertions(+) create mode 100644 src/generic.ts (limited to 'src/generic.ts') diff --git a/src/generic.ts b/src/generic.ts new file mode 100644 index 0000000..50c4435 --- /dev/null +++ b/src/generic.ts @@ -0,0 +1,128 @@ +import OpenAI from "openai"; +import { MAX_TOKENS, RESPONSE_LENGTH } from "./logic/constants"; +import type { AIModelAPI, ChatMessage, OChoice } from "./types"; +import type { AsyncRes } from "sortug"; + +type Message = OpenAI.Chat.Completions.ChatCompletionMessageParam; + +type Props = { + baseURL: string; + apiKey: string; + model?: string; + maxTokens?: number; + tokenizer?: (text: string) => number; +}; +export default class OpenAIAPI implements AIModelAPI { + private apiKey; + private baseURL; + private api; + maxTokens: number = MAX_TOKENS; + tokenizer: (text: string) => number = (text) => text.length / 3; + model; + + constructor(props: Props) { + this.apiKey = props.apiKey; + this.baseURL = props.baseURL; + this.api = new OpenAI({ baseURL: this.baseURL, apiKey: this.apiKey }); + this.model = props.model || ""; + if (props.maxTokens) this.maxTokens = props.maxTokens; + if (props.tokenizer) this.tokenizer = props.tokenizer; + } + public setModel(model: string) { + this.model = model; + } + private mapMessages(input: ChatMessage[]): Message[] { + return input.map((m) => { + return { role: m.author as any, content: m.text, name: m.author }; + }); + } + + public async send(sys: string, input: ChatMessage[]): AsyncRes { + const messages = this.mapMessages(input); + const sysMsg: Message = { role: "system", content: sys }; + const allMessages = [sysMsg, ...messages]; + console.log("before truncation", allMessages); + const truncated = this.truncateHistory(allMessages); + const res = await this.apiCall(truncated); + if ("error" in res) return res; + else { + try { + // TODO type this properly + const choices: OChoice[] = res.ok; + return { ok: choices.map((c) => c.message.content!) }; + } catch (e) { + return { error: `${e}` }; + } + } + } + + public async stream( + sys: string, + input: ChatMessage[], + handle: (c: string) => void, + ) { + const messages = this.mapMessages(input); + const sysMsg: Message = { role: "system", content: sys }; + const allMessages = [sysMsg, ...messages]; + const truncated = this.truncateHistory(allMessages); + await this.apiCallStream(truncated, handle); + } + + private truncateHistory(messages: Message[]): Message[] { + const totalTokens = messages.reduce((total, message) => { + return total + this.tokenizer(message.content as string); + }, 0); + while (totalTokens > this.maxTokens && messages.length > 1) { + // Always keep the system message if it exists + const startIndex = messages[0].role === "system" ? 1 : 0; + messages.splice(startIndex, 1); + } + return messages; + } + + // TODO custom temperature? + private async apiCall(messages: Message[]): AsyncRes { + console.log({ messages }, "at the very end"); + try { + const completion = await this.api.chat.completions.create({ + temperature: 1.3, + model: this.model, + messages, + max_tokens: RESPONSE_LENGTH, + }); + if (!completion) return { error: "null response from openai" }; + return { ok: completion.choices }; + } catch (e) { + console.log(e, "error in openai api"); + return { error: `${e}` }; + } + } + + private async apiCallStream( + messages: Message[], + handle: (c: string) => void, + ): Promise { + try { + const stream = await this.api.chat.completions.create({ + temperature: 1.3, + model: this.model, + messages, + max_tokens: RESPONSE_LENGTH, + stream: true, + }); + + for await (const chunk of stream) { + for (const choice of chunk.choices) { + console.log({ choice }); + if (!choice.delta) continue; + const cont = choice.delta.content; + if (!cont) continue; + handle(cont); + } + } + } catch (e) { + console.log(e, "error in openai api"); + handle(`Error streaming OpenAI, ${e}`); + } + } +} -- cgit v1.2.3