From 42dd99bfac9777a4ecc6700b87edf26a5c984de6 Mon Sep 17 00:00:00 2001
From: polwex
Date: Wed, 23 Jul 2025 02:37:15 +0700
Subject: checkpoint

---
 src/gemini2.ts | 149 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 149 insertions(+)
 create mode 100644 src/gemini2.ts

diff --git a/src/gemini2.ts b/src/gemini2.ts
new file mode 100644
index 0000000..291553f
--- /dev/null
+++ b/src/gemini2.ts
@@ -0,0 +1,149 @@
+import {
+  GenerativeModel,
+  GoogleGenerativeAI,
+  type Content,
+  type GenerateContentResult,
+} from "@google/generative-ai";
+import { RESPONSE_LENGTH } from "./logic/constants";
+import type {
+  AIModelAPI,
+  ChatMessage,
+  OChoice,
+  OChunk,
+  OMessage,
+} from "./types";
+import type { AsyncRes } from "sortug";
+
+export default class GeminiAPI implements AIModelAPI {
+  tokenizer: (text: string) => number;
+  maxTokens: number;
+  private model: GenerativeModel;
+
+  constructor(
+    maxTokens = 200_000,
+    tokenizer: (text: string) => number = (text) => text.length / 3,
+    model?: string,
+  ) {
+    this.maxTokens = maxTokens;
+    this.tokenizer = tokenizer;
+
+    const gem = new GoogleGenerativeAI(Bun.env["GEMINI_API_KEY"]!);
+    this.model = gem.getGenerativeModel({
+      // model: model || "gemini-2.0-flash-exp",
+      model: model || "gemini-2.5-pro-preview-05-06",
+      generationConfig: { maxOutputTokens: RESPONSE_LENGTH },
+    });
+  }
+
+  public setModel(model: string) {
+    const gem = new GoogleGenerativeAI(Bun.env["GEMINI_API_KEY"]!);
+    this.model = gem.getGenerativeModel({
+      model,
+      generationConfig: { maxOutputTokens: RESPONSE_LENGTH },
+    });
+  }
+
+  private mapMessages(input: ChatMessage[]): Content[] {
+    return input.map((m) => ({
+      role: m.author === "gemini" ? "model" : "user",
+      parts: [{ text: m.text }],
+    }));
+  }
+
+  // Gemini rejects histories where two consecutive entries share a role,
+  // so merge runs of same-role messages into a single Content entry.
+  private mapMessagesR1(input: ChatMessage[]): Content[] {
+    return input.reduce((acc: Content[], m) => {
+      const role = m.author === "gemini" ? "model" : "user";
+      const prev = acc[acc.length - 1];
+      if (prev?.role === role) {
+        prev.parts.push({ text: m.text });
+        return acc;
+      }
+      return [...acc, { role, parts: [{ text: m.text }] }];
+    }, []);
+  }
+
+  private async apiCall(
+    messages: Content[],
+    isR1: boolean = false,
+  ): Promise<AsyncRes<string[]>> {
+    try {
+      // The last user turn is already in `messages`, so the prompt is empty.
+      const chat = this.model.startChat({ history: messages });
+      const res = await chat.sendMessage("");
+      return { ok: [res.response.text()] };
+    } catch (e) {
+      console.log(e, "error in gemini api");
+      return { error: `${e}` };
+    }
+  }
+
+  private async apiCallStream(
+    messages: Content[],
+    handle: (c: any) => void,
+    isR1: boolean = false,
+  ): Promise<void> {
+    try {
+      const chat = this.model.startChat({ history: messages });
+      const res = await chat.sendMessageStream("");
+      for await (const chunk of res.stream) {
+        handle(chunk.text());
+      }
+    } catch (e) {
+      console.log(e, "error in gemini api");
+      handle(`Error streaming Gemini, ${e}`);
+    }
+  }
+
+  public async send(sys: string, input: ChatMessage[]) {
+    console.log({ sys, input });
+    this.model.systemInstruction = { role: "system", parts: [{ text: sys }] };
+    const messages = this.mapMessages(input);
+    const truncated = this.truncateHistory(messages);
+    const res = await this.apiCall(truncated);
+    return res;
+  }
+
+  public async sendR1(input: ChatMessage[]) {
+    const messages = this.mapMessagesR1(input);
+    const truncated = this.truncateHistory(messages);
+    const res = await this.apiCall(truncated, true);
+    return res;
+  }
+
+  public async stream(
+    sys: string,
+    input: ChatMessage[],
+    handle: (c: any) => void,
+  ) {
+    this.model.systemInstruction = { role: "system", parts: [{ text: sys }] };
+    const messages = this.mapMessages(input);
+    const truncated = this.truncateHistory(messages);
+    await this.apiCallStream(truncated, handle);
+  }
+
+  public async streamR1(input: ChatMessage[], handle: (c: any) => void) {
+    const messages = this.mapMessagesR1(input);
+    const truncated = this.truncateHistory(messages);
+    await this.apiCallStream(truncated, handle, true);
+  }
+
+  public async sendDoc(data: ArrayBuffer, mimeType: string, prompt: string) {
+    const res = await this.model.generateContent([
+      {
+        inlineData: {
+          data: Buffer.from(data).toString("base64"),
+          mimeType,
+        },
+      },
+      prompt,
+    ]);
+    return res;
+  }
+
+  // Drop the oldest messages until the estimated token count fits the window.
+  private truncateHistory(messages: Content[]): Content[] {
+    let totalTokens = messages.reduce((total, message) => {
+      return total + this.tokenizer(message.parts[0]?.text || "");
+    }, 0);
+    while (totalTokens > this.maxTokens && messages.length > 1) {
+      const [removed] = messages.splice(0, 1);
+      totalTokens -= this.tokenizer(removed.parts[0]?.text || "");
+    }
+    return messages;
+  }
+}
--
cgit v1.2.3