Diffstat (limited to 'packages/ai/src/generic.ts')
-rw-r--r--  packages/ai/src/generic.ts  204
1 file changed, 204 insertions, 0 deletions
diff --git a/packages/ai/src/generic.ts b/packages/ai/src/generic.ts
new file mode 100644
index 0000000..8c41f19
--- /dev/null
+++ b/packages/ai/src/generic.ts
@@ -0,0 +1,204 @@
+import OpenAI from "openai";
+import { MAX_TOKENS, RESPONSE_LENGTH } from "./logic/constants";
+import type { AIModelAPI, ChatMessage, InputToken } from "./types";
+import type { AsyncRes } from "@sortug/lib";
+import type { ChatCompletionContentPart } from "openai/resources";
+import { memoize } from "./cache";
+
+type OChoice = OpenAI.Chat.Completions.ChatCompletion.Choice;
+type Message = OpenAI.Chat.Completions.ChatCompletionUserMessageParam;
+type Params = OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming;
+type OMessage = OpenAI.Chat.Completions.ChatCompletionMessageParam;
+
+type Props = {
+ baseURL: string;
+ apiKey: string;
+ model?: string;
+ maxTokens?: number;
+ tokenizer?: (text: string) => number;
+ allowBrowser?: boolean;
+};
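+// Generic client for any OpenAI-compatible chat completions endpoint: the
+// provider is selected via `baseURL`, and non-streaming calls are memoized
+// (keyed on model, messages, and sampling params) with on-disk persistence.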
+export default class OpenAIAPI implements AIModelAPI {
+ private cachedCreate!: (
+ args: Params,
+ ) => Promise<OpenAI.Chat.Completions.ChatCompletion>;
+
+ private apiKey;
+ private baseURL;
+ private api;
+ maxTokens: number = MAX_TOKENS;
+ tokenizer: (text: string) => number = (text) => text.length / 3;
+ model;
+
+ constructor(props: Props) {
+ this.apiKey = props.apiKey;
+ this.baseURL = props.baseURL;
+ this.api = new OpenAI({
+ baseURL: this.baseURL,
+ apiKey: this.apiKey,
+ dangerouslyAllowBrowser: props.allowBrowser || false,
+ });
+ this.model = props.model || "";
+ if (props.maxTokens) this.maxTokens = props.maxTokens;
+ if (props.tokenizer) this.tokenizer = props.tokenizer;
+
+ const boundCreate = this.api.chat.completions.create.bind(
+ this.api.chat.completions,
+ );
+
+ this.cachedCreate = memoize(boundCreate, {
+ ttlMs: 2 * 60 * 60 * 1000, // 2h
+ maxEntries: 5000,
+ persistDir: "./cache/memo",
+ // stable key for the call
+ keyFn: (args) => {
+ // args is the single object param to .create(...)
+ const {
+ model,
+ messages,
+ max_tokens,
+ temperature,
+ top_p,
+ frequency_penalty,
+ presence_penalty,
+ stop,
+ } = args as Params;
+ // stringify messages deterministically (role+content only)
+ const msg = (messages as any[])
+ .map((m) => ({ role: m.role, content: m.content }))
+ .slice(0, 200); // guard size if you want
+ return JSON.stringify({
+ model,
+ msg,
+ max_tokens,
+ temperature,
+ top_p,
+ frequency_penalty,
+ presence_penalty,
+ stop,
+ });
+ },
+ });
+ }
+ public setModel(model: string) {
+ this.model = model;
+ }
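+ // Maps the package's ChatMessage shape onto OpenAI-style user messages.
+ // Currently unused: `send` and `stream` build their message lists directly.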
+ private mapMessages(input: ChatMessage[]): Message[] {
+ return input.map((m) => {
+ return { role: m.author as any, content: m.text, name: m.author };
+ });
+ }
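+ // Converts mixed text/image input tokens into a single multimodal user message.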
+ private buildInput(tokens: InputToken[]): Message[] {
+ const content: ChatCompletionContentPart[] = tokens.map((t) => {
+ if ("text" in t) return { type: "text", text: t.text };
+ if ("img" in t) return { type: "image_url", image_url: { url: t.img } };
+ // Fallback for token shapes this client does not recognize
+ return { type: "text", text: "oy vey" };
+ });
+ return [{ role: "user", content }];
+ }
+
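+ // One-shot completion: optionally prepends a system prompt, truncates the
+ // history to the token budget, and returns the concatenated choice contents.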
+ public async send(
+ input: string | InputToken[],
+ sys?: string,
+ ): AsyncRes<string> {
+ const messages: Message[] =
+ typeof input === "string"
+ ? [{ role: "user" as const, content: input }]
+ : this.buildInput(input);
+ // const messages = this.mapMessages(input);
+ const allMessages: OMessage[] = sys
+ ? [{ role: "system", content: sys }, ...messages]
+ : messages;
+ const truncated = this.truncateHistory(allMessages);
+ const res = await this.apiCall(truncated);
+ if ("error" in res) return res;
+ else {
+ try {
+ // TODO type this properly
+ const choices: OChoice[] = res.ok;
+ const resText = choices.map((c) => c.message.content || "").join("\n");
+ return { ok: resText };
+ } catch (e) {
+ return { error: `${e}` };
+ }
+ }
+ }
+
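+ // Streaming variant of `send`: each content delta is passed to `handle` as it arrives.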
+ public async stream(
+ input: string | InputToken[],
+ handle: (c: string) => void,
+ sys?: string,
+ ) {
+ const messages: Message[] =
+ typeof input === "string"
+ ? [{ role: "user" as const, content: input }]
+ : this.buildInput(input);
+ // const messages = this.mapMessages(input);
+ const allMessages: OMessage[] = sys
+ ? [{ role: "system", content: sys }, ...messages]
+ : messages;
+ const truncated = this.truncateHistory(allMessages);
+ await this.apiCallStream(truncated, handle);
+ }
+
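+ // Rough token-budget enforcement using the injected tokenizer (default: length / 3).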
+ private truncateHistory(messages: OMessage[]): OMessage[] {
+ const contentText = (m: OMessage) =>
+ typeof m.content === "string" ? m.content : JSON.stringify(m.content ?? "");
+ const countTokens = (msgs: OMessage[]) =>
+ msgs.reduce((total, m) => total + this.tokenizer(contentText(m)), 0);
+ // Drop the oldest non-system message until the estimate fits within maxTokens
+ while (countTokens(messages) > this.maxTokens && messages.length > 1) {
+ const startIndex = messages[0].role === "system" ? 1 : 0;
+ messages.splice(startIndex, 1);
+ }
+ return messages;
+ }
+
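+ // Non-streaming call through the memoized `create`; returns the raw choices.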
+ // TODO custom temperature?
+ private async apiCall(messages: OMessage[]): AsyncRes<OChoice[]> {
+ // console.log({ messages }, "at the very end");
+ try {
+ const completion = await this.cachedCreate({
+ // temperature: 1.3,
+ model: this.model,
+ messages,
+ max_tokens: RESPONSE_LENGTH,
+ });
+ if (!completion) return { error: "null response from openai" };
+ return { ok: completion.choices };
+ } catch (e) {
+ console.log(e, "error in openai api");
+ return { error: `${e}` };
+ }
+ }
+
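+ // Streaming call: iterates the response stream and forwards each content delta.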
+ private async apiCallStream(
+ messages: OMessage[],
+ handle: (c: string) => void,
+ ): Promise<void> {
+ try {
+ const stream = await this.api.chat.completions.create({
+ temperature: 1.3,
+ model: this.model,
+ messages,
+ max_tokens: RESPONSE_LENGTH,
+ stream: true,
+ });
+
+ for await (const chunk of stream) {
+ for (const choice of chunk.choices) {
+ console.log({ choice });
+ if (!choice.delta) continue;
+ const cont = choice.delta.content;
+ if (!cont) continue;
+ handle(cont);
+ }
+ }
+ } catch (e) {
+ console.log(e, "error in openai api");
+ handle(`Error streaming OpenAI, ${e}`);
+ }
+ }
+}
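+
+// Usage sketch (illustrative only; the endpoint and model name below are
+// assumptions, not values this package ships with):
+//
+//   const api = new OpenAIAPI({
+//     baseURL: "https://api.openai.com/v1",
+//     apiKey: process.env.OPENAI_API_KEY!,
+//     model: "gpt-4o-mini",
+//   });
+//   const res = await api.send("Summarize this diff in one sentence.");
+//   if ("ok" in res) console.log(res.ok);
+//   else console.error(res.error);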