summaryrefslogtreecommitdiff
path: root/src/generic.ts
diff options
context:
space:
mode:
Diffstat (limited to 'src/generic.ts')
-rw-r--r--src/generic.ts165
1 files changed, 101 insertions, 64 deletions
diff --git a/src/generic.ts b/src/generic.ts
index 50c4435..ac6b55b 100644
--- a/src/generic.ts
+++ b/src/generic.ts
@@ -1,9 +1,13 @@
import OpenAI from "openai";
import { MAX_TOKENS, RESPONSE_LENGTH } from "./logic/constants";
-import type { AIModelAPI, ChatMessage, OChoice } from "./types";
+import type { AIModelAPI, ChatMessage, InputToken, OChoice } from "./types";
import type { AsyncRes } from "sortug";
-
-type Message = OpenAI.Chat.Completions.ChatCompletionMessageParam;
+import type {
+ ResponseCreateParamsBase,
+ ResponseCreateParamsNonStreaming,
+ ResponseCreateParamsStreaming,
+ ResponseInput,
+} from "openai/resources/responses/responses.mjs";
type Props = {
baseURL: string;
@@ -31,25 +35,35 @@ export default class OpenAIAPI implements AIModelAPI {
public setModel(model: string) {
this.model = model;
}
- private mapMessages(input: ChatMessage[]): Message[] {
- return input.map((m) => {
- return { role: m.author as any, content: m.text, name: m.author };
- });
+
+ public buildInput(tokens: InputToken[]): ResponseInput {
+ return [
+ {
+ role: "user",
+ content: tokens.map((t) =>
+ "text" in t
+ ? { type: "input_text", text: t.text }
+ : "img" in t
+ ? { type: "input_image", image_url: t.img, detail: "auto" }
+ : { type: "input_text", text: "oy vey" },
+ ),
+ },
+ ];
}
- public async send(sys: string, input: ChatMessage[]): AsyncRes<string[]> {
- const messages = this.mapMessages(input);
- const sysMsg: Message = { role: "system", content: sys };
- const allMessages = [sysMsg, ...messages];
- console.log("before truncation", allMessages);
- const truncated = this.truncateHistory(allMessages);
- const res = await this.apiCall(truncated);
+ // OpenAI SDK has three kinds ReponseInputContent: text image and file
+ // images can be URLs or base64 dataurl thingies
+ //
+ public async send(
+ input: string | ResponseInput,
+ sys?: string,
+ ): AsyncRes<string> {
+ const params = sys ? { instructions: sys, input } : { input };
+ const res = await this.apiCall(params);
if ("error" in res) return res;
else {
try {
- // TODO type this properly
- const choices: OChoice[] = res.ok;
- return { ok: choices.map((c) => c.message.content!) };
+ return { ok: res.ok.output_text };
} catch (e) {
return { error: `${e}` };
}
@@ -57,41 +71,29 @@ export default class OpenAIAPI implements AIModelAPI {
}
public async stream(
- sys: string,
- input: ChatMessage[],
+ input: string,
handle: (c: string) => void,
+ sys?: string,
) {
- const messages = this.mapMessages(input);
- const sysMsg: Message = { role: "system", content: sys };
- const allMessages = [sysMsg, ...messages];
- const truncated = this.truncateHistory(allMessages);
- await this.apiCallStream(truncated, handle);
- }
-
- private truncateHistory(messages: Message[]): Message[] {
- const totalTokens = messages.reduce((total, message) => {
- return total + this.tokenizer(message.content as string);
- }, 0);
- while (totalTokens > this.maxTokens && messages.length > 1) {
- // Always keep the system message if it exists
- const startIndex = messages[0].role === "system" ? 1 : 0;
- messages.splice(startIndex, 1);
- }
- return messages;
+ const params = sys ? { instructions: sys, input } : { input };
+ await this.apiCallStream(params, handle);
}
// TODO custom temperature?
- private async apiCall(messages: Message[]): AsyncRes<OChoice[]> {
- console.log({ messages }, "at the very end");
+ private async apiCall(
+ params: ResponseCreateParamsNonStreaming,
+ ): AsyncRes<OpenAI.Responses.Response> {
try {
- const completion = await this.api.chat.completions.create({
- temperature: 1.3,
- model: this.model,
- messages,
- max_tokens: RESPONSE_LENGTH,
+ const res = await this.api.responses.create({
+ ...params,
+ // temperature: 1.3,
+ model: params.model || this.model,
+ input: params.input,
+ max_output_tokens: params.max_output_tokens || RESPONSE_LENGTH,
+ stream: false,
});
- if (!completion) return { error: "null response from openai" };
- return { ok: completion.choices };
+ // TODO damn there's a lot of stuff here
+ return { ok: res };
} catch (e) {
console.log(e, "error in openai api");
return { error: `${e}` };
@@ -99,30 +101,65 @@ export default class OpenAIAPI implements AIModelAPI {
}
private async apiCallStream(
- messages: Message[],
- handle: (c: string) => void,
- ): Promise<void> {
+ params: ResponseCreateParamsBase,
+ handler: (c: string) => void,
+ ) {
+ // temperature: 1.3,
+ const pms: ResponseCreateParamsStreaming = {
+ ...params,
+ stream: true,
+ model: params.model || this.model,
+ input: params.input,
+ max_output_tokens: params.max_output_tokens || RESPONSE_LENGTH,
+ };
try {
- const stream = await this.api.chat.completions.create({
- temperature: 1.3,
- model: this.model,
- messages,
- max_tokens: RESPONSE_LENGTH,
- stream: true,
- });
-
- for await (const chunk of stream) {
- for (const choice of chunk.choices) {
- console.log({ choice });
- if (!choice.delta) continue;
- const cont = choice.delta.content;
- if (!cont) continue;
- handle(cont);
+ const stream = await this.api.responses.create(pms);
+ for await (const event of stream) {
+ console.log(event);
+ switch (event.type) {
+ // TODO deal with audio and whatever
+ case "response.output_text.delta":
+ handler(event.delta);
+ break;
+ case "response.completed":
+ break;
+ default:
+ break;
}
+ // if (event.type === "response.completed")
+ // wtf how do we use this
}
} catch (e) {
console.log(e, "error in openai api");
- handle(`Error streaming OpenAI, ${e}`);
+ return { error: `${e}` };
}
}
+
+ // private async apiCallStream(
+ // messages: Message[],
+ // handle: (c: string) => void,
+ // ): Promise<void> {
+ // try {
+ // const stream = await this.api.chat.completions.create({
+ // temperature: 1.3,
+ // model: this.model,
+ // messages,
+ // max_tokens: RESPONSE_LENGTH,
+ // stream: true,
+ // });
+
+ // for await (const chunk of stream) {
+ // for (const choice of chunk.choices) {
+ // console.log({ choice });
+ // if (!choice.delta) continue;
+ // const cont = choice.delta.content;
+ // if (!cont) continue;
+ // handle(cont);
+ // }
+ // }
+ // } catch (e) {
+ // console.log(e, "error in openai api");
+ // handle(`Error streaming OpenAI, ${e}`);
+ // }
+ // }
}