import OpenAI from "openai";
import { MAX_TOKENS, RESPONSE_LENGTH } from "./logic/constants";
import type { AResult, ChatMessage, OChoice } from "./types";

type Message = OpenAI.Chat.Completions.ChatCompletionMessageParam;

type Props = {
  baseURL: string;
  apiKey: string;
  model: string;
  maxTokens?: number;
  tokenizer?: (text: string) => number;
};

export default class Conversation {
  private apiKey: string;
  private baseURL: string;
  private maxTokens: number = MAX_TOKENS;
  // Rough default estimate: ~3 characters per token. Callers can supply a
  // real tokenizer via Props for accurate counts.
  private tokenizer: (text: string) => number = (text) => text.length / 3;
  private api: OpenAI;
  private model: string;

  constructor(props: Props) {
    this.apiKey = props.apiKey;
    this.baseURL = props.baseURL;
    this.api = new OpenAI({ baseURL: this.baseURL, apiKey: this.apiKey });
    this.model = props.model;
    if (props.maxTokens) this.maxTokens = props.maxTokens;
    if (props.tokenizer) this.tokenizer = props.tokenizer;
  }

  public setModel(model: string) {
    this.model = model;
  }

  private mapMessages(input: ChatMessage[]): Message[] {
    // The author string doubles as both role and name; the cast assumes
    // authors are valid chat roles ("user", "assistant", ...).
    return input.map((m) => ({
      role: m.author as any,
      content: m.text,
      name: m.author,
    }));
  }

  public async send(sys: string, input: ChatMessage[]): AResult {
    const messages = this.mapMessages(input);
    const sysMsg: Message = { role: "system", content: sys };
    const truncated = this.truncateHistory([sysMsg, ...messages]);
    const res = await this.apiCall(truncated);
    if ("error" in res) return res;
    try {
      return { ok: res.ok.map((c) => c.message.content!) };
    } catch (e) {
      return { error: `${e}` };
    }
  }

  public async stream(
    sys: string,
    input: ChatMessage[],
    handle: (c: string) => void,
  ) {
    const messages = this.mapMessages(input);
    const sysMsg: Message = { role: "system", content: sys };
    const truncated = this.truncateHistory([sysMsg, ...messages]);
    await this.apiCallStream(truncated, handle);
  }

  // Drops the oldest non-system messages until the estimated token count
  // fits within maxTokens. The count is recomputed after each removal so
  // the loop condition actually changes and terminates.
  private truncateHistory(messages: Message[]): Message[] {
    const countTokens = (msgs: Message[]) =>
      msgs.reduce(
        (total, m) => total + this.tokenizer((m.content as string) ?? ""),
        0,
      );
    while (countTokens(messages) > this.maxTokens && messages.length > 1) {
      // Always keep the system message if it exists
      const startIndex = messages[0].role === "system" ? 1 : 0;
      messages.splice(startIndex, 1);
    }
    return messages;
  }

  // TODO custom temperature?
  private async apiCall(messages: Message[]): AResult {
    try {
      const completion = await this.api.chat.completions.create({
        temperature: 1.3,
        model: this.model,
        messages,
        max_tokens: RESPONSE_LENGTH,
      });
      if (!completion) return { error: "null response from openai" };
      return { ok: completion.choices };
    } catch (e) {
      console.log(e, "error in openai api");
      return { error: `${e}` };
    }
  }

  private async apiCallStream(
    messages: Message[],
    handle: (c: string) => void,
  ): Promise<void> {
    try {
      const stream = await this.api.chat.completions.create({
        temperature: 1.3,
        model: this.model,
        messages,
        max_tokens: RESPONSE_LENGTH,
        stream: true,
      });
      // Forward each content delta to the handler as it arrives; chunks
      // without a content delta (e.g. role or finish markers) are skipped.
      for await (const chunk of stream) {
        for (const choice of chunk.choices) {
          const cont = choice.delta?.content;
          if (!cont) continue;
          handle(cont);
        }
      }
    } catch (e) {
      console.log(e, "error in openai api");
      handle(`Error streaming OpenAI, ${e}`);
    }
  }
}
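
/*
  Usage sketch (illustrative only, not part of the original module).
  Assumes ChatMessage is shaped like { author: string; text: string },
  that AResult resolves to { ok: ... } | { error: string }, and that the
  endpoint is OpenAI-compatible; the model name below is a placeholder.

  const convo = new Conversation({
    baseURL: "https://api.openai.com/v1",
    apiKey: process.env.OPENAI_API_KEY ?? "",
    model: "gpt-4o-mini",
  });

  // One-shot completion: the system prompt is prepended, history is
  // truncated to fit maxTokens, and choices are returned as strings.
  const result = await convo.send("You are a helpful assistant.", [
    { author: "user", text: "Hello there!" },
  ]);
  if ("error" in result) console.error(result.error);
  else console.log(result.ok[0]);

  // Streaming variant: content deltas are forwarded to the handler.
  await convo.stream(
    "You are a helpful assistant.",
    [{ author: "user", text: "Tell me a story." }],
    (chunk) => process.stdout.write(chunk),
  );
*/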