Diffstat (limited to 'src/generic.ts')
| -rw-r--r-- | src/generic.ts | 165 |
1 file changed, 101 insertions, 64 deletions
diff --git a/src/generic.ts b/src/generic.ts
index 50c4435..ac6b55b 100644
--- a/src/generic.ts
+++ b/src/generic.ts
@@ -1,9 +1,13 @@
 import OpenAI from "openai";
 import { MAX_TOKENS, RESPONSE_LENGTH } from "./logic/constants";
-import type { AIModelAPI, ChatMessage, OChoice } from "./types";
+import type { AIModelAPI, ChatMessage, InputToken, OChoice } from "./types";
 import type { AsyncRes } from "sortug";
-
-type Message = OpenAI.Chat.Completions.ChatCompletionMessageParam;
+import type {
+  ResponseCreateParamsBase,
+  ResponseCreateParamsNonStreaming,
+  ResponseCreateParamsStreaming,
+  ResponseInput,
+} from "openai/resources/responses/responses.mjs";
 
 type Props = {
   baseURL: string;
@@ -31,25 +35,35 @@ export default class OpenAIAPI implements AIModelAPI {
   public setModel(model: string) {
     this.model = model;
   }
-  private mapMessages(input: ChatMessage[]): Message[] {
-    return input.map((m) => {
-      return { role: m.author as any, content: m.text, name: m.author };
-    });
+
+  public buildInput(tokens: InputToken[]): ResponseInput {
+    return [
+      {
+        role: "user",
+        content: tokens.map((t) =>
+          "text" in t
+            ? { type: "input_text", text: t.text }
+            : "img" in t
+              ? { type: "input_image", image_url: t.img, detail: "auto" }
+              : { type: "input_text", text: "oy vey" },
+        ),
+      },
+    ];
   }
 
-  public async send(sys: string, input: ChatMessage[]): AsyncRes<string[]> {
-    const messages = this.mapMessages(input);
-    const sysMsg: Message = { role: "system", content: sys };
-    const allMessages = [sysMsg, ...messages];
-    console.log("before truncation", allMessages);
-    const truncated = this.truncateHistory(allMessages);
-    const res = await this.apiCall(truncated);
+  // OpenAI SDK has three kinds of ResponseInputContent: text, image and file
+  // images can be URLs or base64 dataurl thingies
+  //
+  public async send(
+    input: string | ResponseInput,
+    sys?: string,
+  ): AsyncRes<string> {
+    const params = sys ? { instructions: sys, input } : { input };
+    const res = await this.apiCall(params);
     if ("error" in res) return res;
     else {
       try {
-        // TODO type this properly
-        const choices: OChoice[] = res.ok;
-        return { ok: choices.map((c) => c.message.content!) };
+        return { ok: res.ok.output_text };
       } catch (e) {
         return { error: `${e}` };
       }
@@ -57,41 +71,29 @@ export default class OpenAIAPI implements AIModelAPI {
   }
 
   public async stream(
-    sys: string,
-    input: ChatMessage[],
+    input: string,
     handle: (c: string) => void,
+    sys?: string,
   ) {
-    const messages = this.mapMessages(input);
-    const sysMsg: Message = { role: "system", content: sys };
-    const allMessages = [sysMsg, ...messages];
-    const truncated = this.truncateHistory(allMessages);
-    await this.apiCallStream(truncated, handle);
-  }
-
-  private truncateHistory(messages: Message[]): Message[] {
-    const totalTokens = messages.reduce((total, message) => {
-      return total + this.tokenizer(message.content as string);
-    }, 0);
-    while (totalTokens > this.maxTokens && messages.length > 1) {
-      // Always keep the system message if it exists
-      const startIndex = messages[0].role === "system" ? 1 : 0;
-      messages.splice(startIndex, 1);
-    }
-    return messages;
+    const params = sys ? { instructions: sys, input } : { input };
+    await this.apiCallStream(params, handle);
   }
 
   // TODO custom temperature?
-  private async apiCall(messages: Message[]): AsyncRes<OChoice[]> {
-    console.log({ messages }, "at the very end");
+  private async apiCall(
+    params: ResponseCreateParamsNonStreaming,
+  ): AsyncRes<OpenAI.Responses.Response> {
     try {
-      const completion = await this.api.chat.completions.create({
-        temperature: 1.3,
-        model: this.model,
-        messages,
-        max_tokens: RESPONSE_LENGTH,
+      const res = await this.api.responses.create({
+        ...params,
+        // temperature: 1.3,
+        model: params.model || this.model,
+        input: params.input,
+        max_output_tokens: params.max_output_tokens || RESPONSE_LENGTH,
+        stream: false,
       });
-      if (!completion) return { error: "null response from openai" };
-      return { ok: completion.choices };
+      // TODO damn there's a lot of stuff here
+      return { ok: res };
     } catch (e) {
       console.log(e, "error in openai api");
       return { error: `${e}` };
@@ -99,30 +101,65 @@ export default class OpenAIAPI implements AIModelAPI {
   }
 
   private async apiCallStream(
-    messages: Message[],
-    handle: (c: string) => void,
-  ): Promise<void> {
+    params: ResponseCreateParamsBase,
+    handler: (c: string) => void,
+  ) {
+    // temperature: 1.3,
+    const pms: ResponseCreateParamsStreaming = {
+      ...params,
+      stream: true,
+      model: params.model || this.model,
+      input: params.input,
+      max_output_tokens: params.max_output_tokens || RESPONSE_LENGTH,
+    };
     try {
-      const stream = await this.api.chat.completions.create({
-        temperature: 1.3,
-        model: this.model,
-        messages,
-        max_tokens: RESPONSE_LENGTH,
-        stream: true,
-      });
-
-      for await (const chunk of stream) {
-        for (const choice of chunk.choices) {
-          console.log({ choice });
-          if (!choice.delta) continue;
-          const cont = choice.delta.content;
-          if (!cont) continue;
-          handle(cont);
+      const stream = await this.api.responses.create(pms);
+      for await (const event of stream) {
+        console.log(event);
+        switch (event.type) {
+          // TODO deal with audio and whatever
+          case "response.output_text.delta":
+            handler(event.delta);
+            break;
+          case "response.completed":
+            break;
+          default:
+            break;
         }
+        // if (event.type === "response.completed")
+        // wtf how do we use this
       }
     } catch (e) {
       console.log(e, "error in openai api");
-      handle(`Error streaming OpenAI, ${e}`);
+      return { error: `${e}` };
     }
   }
+
+  // private async apiCallStream(
+  //   messages: Message[],
+  //   handle: (c: string) => void,
+  // ): Promise<void> {
+  //   try {
+  //     const stream = await this.api.chat.completions.create({
+  //       temperature: 1.3,
+  //       model: this.model,
+  //       messages,
+  //       max_tokens: RESPONSE_LENGTH,
+  //       stream: true,
+  //     });
+
+  //     for await (const chunk of stream) {
+  //       for (const choice of chunk.choices) {
+  //         console.log({ choice });
+  //         if (!choice.delta) continue;
+  //         const cont = choice.delta.content;
+  //         if (!cont) continue;
+  //         handle(cont);
+  //       }
+  //     }
+  //   } catch (e) {
+  //     console.log(e, "error in openai api");
+  //     handle(`Error streaming OpenAI, ${e}`);
+  //   }
+  // }
 }
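For context, a minimal usage sketch of the wrapper's surface after this commit. It is illustrative only: the `apiKey` and `model` constructor fields are assumptions (the diff shows only `baseURL` from `Props`), and `InputToken` is assumed to be roughly `{ text: string } | { img: string }`, as the narrowing in `buildInput` suggests.

```ts
// Hypothetical caller of the post-commit OpenAIAPI class.
import OpenAIAPI from "./src/generic";

const api = new OpenAIAPI({
  baseURL: "https://api.openai.com/v1",
  apiKey: process.env.OPENAI_API_KEY!, // assumed Props field
  model: "gpt-4o-mini", // assumed Props field
});

// buildInput wraps mixed text/image tokens into one user message.
const input = api.buildInput([
  { text: "Describe this image." },
  { img: "https://example.com/cat.png" },
]);

// send(): the optional second argument becomes `instructions`;
// on success, ok is the response's aggregated output_text.
const res = await api.send(input, "Be terse.");
if ("error" in res) console.error(res.error);
else console.log(res.ok);

// stream(): text deltas are pushed to the handler as they arrive.
await api.stream("Tell me a story", (chunk) => process.stdout.write(chunk));
```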
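And a sketch of the streaming events `apiCallStream` now consumes, assuming the SDK's published event shapes: `response.output_text.delta` carries incremental text, while `response.completed` wraps the finished `Response` object, which is one way to use the event the diff leaves as a TODO (e.g. reading token usage).

```ts
import OpenAI from "openai";

const client = new OpenAI(); // reads OPENAI_API_KEY from the environment

// stream: true makes responses.create return an async iterable of events.
const stream = await client.responses.create({
  model: "gpt-4o-mini",
  input: "Explain streaming in one sentence.",
  stream: true,
});

let text = "";
for await (const event of stream) {
  if (event.type === "response.output_text.delta") {
    text += event.delta; // same field the new switch statement handles
  } else if (event.type === "response.completed") {
    // The final Response rides on the completed event.
    console.log(event.response.usage);
  }
}
console.log(text);
```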
