import OpenAI from "openai";
import { MAX_TOKENS, RESPONSE_LENGTH } from "./logic/constants";
import type { AIModelAPI, ChatMessage, InputToken } from "./types";
import type { AsyncRes } from "sortug";
import type { ChatCompletionContentPart } from "openai/resources";

type OChoice = OpenAI.Chat.Completions.ChatCompletion.Choice;
type Message = OpenAI.Chat.Completions.ChatCompletionUserMessageParam;
type OMessage = OpenAI.Chat.Completions.ChatCompletionMessageParam;

type Props = {
  baseURL: string;
  apiKey: string;
  model?: string;
  maxTokens?: number;
  tokenizer?: (text: string) => number;
  allowBrowser?: boolean;
};

/**
 * `AIModelAPI` implementation backed by the `openai` SDK, talking to the
 * chat-completions endpoint at `baseURL`.
 */
export default class OpenAIAPI implements AIModelAPI {
  private apiKey: string;
  private baseURL: string;
  private api: OpenAI;
  /** Token budget used by `truncateHistory` to trim outgoing messages. */
  maxTokens: number = MAX_TOKENS;
  /** Token estimator; the default is a rough `characters / 3` heuristic. */
  tokenizer: (text: string) => number = (text) => text.length / 3;
  model: string;

  constructor(props: Props) {
    this.apiKey = props.apiKey;
    this.baseURL = props.baseURL;
    this.api = new OpenAI({
      baseURL: this.baseURL,
      apiKey: this.apiKey,
      dangerouslyAllowBrowser: props.allowBrowser ?? false,
    });
    this.model = props.model ?? "";
    if (props.maxTokens) this.maxTokens = props.maxTokens;
    if (props.tokenizer) this.tokenizer = props.tokenizer;
  }

  public setModel(model: string) {
    this.model = model;
  }

  /**
   * Convert chat history into user-message params. Currently unused (see the
   * commented-out calls in `prepareMessages`' callers' history).
   * NOTE(review): `m.author` is used as the role, which the SDK types do not
   * accept without the cast; confirm this is intended before reviving it.
   */
  private mapMessages(input: ChatMessage[]): Message[] {
    return input.map((m) => {
      return { role: m.author as any, content: m.text, name: m.author };
    });
  }

  /**
   * Build a single multimodal user message from input tokens: `text` tokens
   * become text parts, `img` tokens become `image_url` parts.
   * NOTE(review): any other token shape is sent as the placeholder text
   * "oy vey"; confirm whether such tokens should be dropped instead.
   */
  private buildInput(tokens: InputToken[]): Message[] {
    const content: ChatCompletionContentPart[] = tokens.map((t) => {
      if ("text" in t) return { type: "text", text: t.text };
      if ("img" in t) return { type: "image_url", image_url: { url: t.img } };
      else return { type: "text", text: "oy vey" };
    });
    return [{ role: "user", content }];
  }

  /**
   * Turn raw input plus an optional system prompt into the message list sent
   * to the API (system message first, when given).
   */
  private prepareMessages(input: string | InputToken[], sys?: string): OMessage[] {
    const messages: Message[] =
      typeof input === "string"
        ? [{ role: "user" as const, content: input }]
        : this.buildInput(input);
    return sys ? [{ role: "system", content: sys }, ...messages] : messages;
  }

  /**
   * Send one non-streaming completion request.
   *
   * @param input - Plain prompt text or a list of input tokens.
   * @param sys - Optional system prompt.
   * @returns `{ ok }` with the content of every returned choice joined by
   *   newlines, or `{ error }` if the request failed.
   */
  public async send(input: string | InputToken[], sys?: string): AsyncRes<string> {
    const truncated = this.truncateHistory(this.prepareMessages(input, sys));
    const res = await this.apiCall(truncated);
    if ("error" in res) return res;
    try {
      const choices: OChoice[] = res.ok;
      // Join with a separator between choices only (no leading newline).
      const resText = choices.map((item) => item.message.content ?? "").join("\n");
      return { ok: resText };
    } catch (e) {
      return { error: `${e}` };
    }
  }

  /**
   * Stream a completion, calling `handle` with each non-empty content delta.
   * Errors are reported through `handle` rather than thrown.
   */
  public async stream(
    input: string | InputToken[],
    handle: (c: string) => void,
    sys?: string,
  ): Promise<void> {
    const truncated = this.truncateHistory(this.prepareMessages(input, sys));
    await this.apiCallStream(truncated, handle);
  }

  /**
   * Drop the oldest non-system messages until the estimated token count fits
   * within `maxTokens`. A leading system message and the newest message are
   * always kept, so an oversized prompt is still sent instead of being removed.
   *
   * @returns A new array; `messages` is not mutated.
   */
  private truncateHistory(messages: OMessage[]): OMessage[] {
    const kept = [...messages];
    const firstRemovable = kept[0]?.role === "system" ? 1 : 0;
    let totalTokens = kept.reduce((sum, m) => sum + this.countTokens(m), 0);
    // Recompute the running total after each removal so the loop terminates
    // as soon as the history fits.
    while (totalTokens > this.maxTokens && kept.length - firstRemovable > 1) {
      const [removed] = kept.splice(firstRemovable, 1);
      if (removed) totalTokens -= this.countTokens(removed);
    }
    return kept;
  }

  /**
   * Estimate one message's tokens with `this.tokenizer`. String content is
   * counted directly; for content-part arrays only parts with a `text` field
   * are counted, and other parts (e.g. images) contribute 0.
   */
  private countTokens(message: OMessage): number {
    const { content } = message;
    if (typeof content === "string") return this.tokenizer(content);
    if (!Array.isArray(content)) return 0;
    let total = 0;
    for (const part of content) {
      if ("text" in part && typeof part.text === "string") {
        total += this.tokenizer(part.text);
      }
    }
    return total;
  }

  // TODO custom temperature?
  /** Run a non-streaming chat completion; returns the raw choices or an error. */
  private async apiCall(messages: OMessage[]): AsyncRes<OChoice[]> {
    try {
      const completion = await this.api.chat.completions.create({
        // temperature: 1.3,
        model: this.model,
        messages,
        max_tokens: RESPONSE_LENGTH,
      });
      if (!completion) return { error: "null response from openai" };
      return { ok: completion.choices };
    } catch (e) {
      console.log(e, "error in openai api");
      return { error: `${e}` };
    }
  }

  /**
   * Run a streaming chat completion, forwarding each non-empty content delta
   * to `handle`. On failure, an error string is passed to `handle`.
   * NOTE(review): streaming uses temperature 1.3 while `apiCall` uses the API
   * default; confirm this difference is intended.
   */
  private async apiCallStream(
    messages: OMessage[],
    handle: (c: string) => void,
  ): Promise<void> {
    try {
      const stream = await this.api.chat.completions.create({
        temperature: 1.3,
        model: this.model,
        messages,
        max_tokens: RESPONSE_LENGTH,
        stream: true,
      });
      for await (const chunk of stream) {
        for (const choice of chunk.choices) {
          const cont = choice.delta?.content;
          if (cont) handle(cont);
        }
      }
    } catch (e) {
      console.log(e, "error in openai api");
      handle(`Error streaming OpenAI, ${e}`);
    }
  }
}