author    polwex <polwex@sortug.com>    2025-07-23 02:37:15 +0700
committer polwex <polwex@sortug.com>    2025-07-23 02:37:15 +0700
commit    42dd99bfac9777a4ecc6700b87edf26a5c984de6 (patch)
tree      031e45d187f45def4b58ad7590d39dec3924600d /src/generic.ts
parent    4c6913644b362b28f15b125c2fbe48165f1e048c (diff)
checkpoint
Diffstat (limited to 'src/generic.ts')
-rw-r--r--   src/generic.ts   128
1 file changed, 128 insertions(+), 0 deletions(-)
diff --git a/src/generic.ts b/src/generic.ts
new file mode 100644
index 0000000..50c4435
--- /dev/null
+++ b/src/generic.ts
@@ -0,0 +1,128 @@
+import OpenAI from "openai";
+import { MAX_TOKENS, RESPONSE_LENGTH } from "./logic/constants";
+import type { AIModelAPI, ChatMessage, OChoice } from "./types";
+import type { AsyncRes } from "sortug";
+
+type Message = OpenAI.Chat.Completions.ChatCompletionMessageParam;
+
+type Props = {
+ baseURL: string;
+ apiKey: string;
+ model?: string;
+ maxTokens?: number;
+ tokenizer?: (text: string) => number;
+};
+export default class OpenAIAPI implements AIModelAPI {
+  private apiKey: string;
+  private baseURL: string;
+  private api: OpenAI;
+  maxTokens: number = MAX_TOKENS;
+  // Default tokenizer: rough estimate of ~3 characters per token.
+  tokenizer: (text: string) => number = (text) => text.length / 3;
+  model: string;
+
+ constructor(props: Props) {
+ this.apiKey = props.apiKey;
+ this.baseURL = props.baseURL;
+ this.api = new OpenAI({ baseURL: this.baseURL, apiKey: this.apiKey });
+ this.model = props.model || "";
+ if (props.maxTokens) this.maxTokens = props.maxTokens;
+ if (props.tokenizer) this.tokenizer = props.tokenizer;
+ }
+ public setModel(model: string) {
+ this.model = model;
+ }
+  private mapMessages(input: ChatMessage[]): Message[] {
+    // Assumes ChatMessage.author holds a valid OpenAI chat role
+    // ("user" | "assistant" | ...); the cast silences the narrower SDK type.
+    return input.map((m) => {
+      return { role: m.author as any, content: m.text, name: m.author };
+    });
+  }
+
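+  // One-shot request: prepend the system prompt, truncate the history to the
+  // token budget, and return the text content of every returned choice.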
+ public async send(sys: string, input: ChatMessage[]): AsyncRes<string[]> {
+ const messages = this.mapMessages(input);
+ const sysMsg: Message = { role: "system", content: sys };
+ const allMessages = [sysMsg, ...messages];
+ console.log("before truncation", allMessages);
+ const truncated = this.truncateHistory(allMessages);
+ const res = await this.apiCall(truncated);
+ if ("error" in res) return res;
+ else {
+ try {
+ // TODO type this properly
+ const choices: OChoice[] = res.ok;
+ return { ok: choices.map((c) => c.message.content!) };
+ } catch (e) {
+ return { error: `${e}` };
+ }
+ }
+ }
+
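+  // Streaming request: same setup as send(), but each content delta is passed
+  // to the supplied handler as it arrives.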
+ public async stream(
+ sys: string,
+ input: ChatMessage[],
+ handle: (c: string) => void,
+ ) {
+ const messages = this.mapMessages(input);
+ const sysMsg: Message = { role: "system", content: sys };
+ const allMessages = [sysMsg, ...messages];
+ const truncated = this.truncateHistory(allMessages);
+ await this.apiCallStream(truncated, handle);
+ }
+
+  private truncateHistory(messages: Message[]): Message[] {
+    let totalTokens = messages.reduce((total, message) => {
+      return total + this.tokenizer(message.content as string);
+    }, 0);
+    // Drop the oldest non-system messages until the history fits the budget,
+    // updating the running token count so the loop can stop early.
+    while (totalTokens > this.maxTokens && messages.length > 1) {
+      // Always keep the system message if it exists
+      const startIndex = messages[0].role === "system" ? 1 : 0;
+      const [removed] = messages.splice(startIndex, 1);
+      totalTokens -= this.tokenizer(removed.content as string);
+    }
+    return messages;
+  }
+
+ // TODO custom temperature?
+ private async apiCall(messages: Message[]): AsyncRes<OChoice[]> {
+ console.log({ messages }, "at the very end");
+ try {
+ const completion = await this.api.chat.completions.create({
+ temperature: 1.3,
+ model: this.model,
+ messages,
+ max_tokens: RESPONSE_LENGTH,
+ });
+ if (!completion) return { error: "null response from openai" };
+ return { ok: completion.choices };
+ } catch (e) {
+ console.log(e, "error in openai api");
+ return { error: `${e}` };
+ }
+ }
+
+ private async apiCallStream(
+ messages: Message[],
+ handle: (c: string) => void,
+ ): Promise<void> {
+ try {
+ const stream = await this.api.chat.completions.create({
+ temperature: 1.3,
+ model: this.model,
+ messages,
+ max_tokens: RESPONSE_LENGTH,
+ stream: true,
+ });
+
+ for await (const chunk of stream) {
+ for (const choice of chunk.choices) {
+ console.log({ choice });
+ if (!choice.delta) continue;
+ const cont = choice.delta.content;
+ if (!cont) continue;
+ handle(cont);
+ }
+ }
+ } catch (e) {
+ console.log(e, "error in openai api");
+ handle(`Error streaming OpenAI, ${e}`);
+ }
+ }
+}
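
Usage sketch (not part of the commit): assuming ChatMessage from ./types has the
shape { author: string; text: string } and the baseURL points at an
OpenAI-compatible endpoint, the class could be driven like this; the model name
is a placeholder:

  import OpenAIAPI from "./generic";

  const api = new OpenAIAPI({
    baseURL: "https://api.openai.com/v1",
    apiKey: process.env.OPENAI_API_KEY ?? "",
    model: "gpt-4o-mini", // assumed model name
  });

  const history = [{ author: "user", text: "Hello there" }];

  // One-shot request: resolves to { ok: string[] } or { error: string }
  const res = await api.send("You are a helpful assistant.", history);
  if ("error" in res) console.error(res.error);
  else console.log(res.ok.join("\n"));

  // Streaming request: the handler receives each content delta as it arrives
  await api.stream("You are a helpful assistant.", history, (chunk) =>
    process.stdout.write(chunk),
  );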