import {
  GoogleGenerativeAI,
  type Content,
  type GenerativeModel,
} from "@google/generative-ai";
import { RESPONSE_LENGTH } from "./logic/constants";
import type { AIModelAPI, ChatMessage } from "./types";
import type { AsyncRes } from "sortug";

export default class GeminiAPI implements AIModelAPI {
  tokenizer: (text: string) => number;
  maxTokens: number;
  private model: GenerativeModel;

  constructor(
    maxTokens = 200_000,
    // Rough default heuristic: ~3 characters per token.
    tokenizer: (text: string) => number = (text) => text.length / 3,
    model?: string,
  ) {
    this.maxTokens = maxTokens;
    this.tokenizer = tokenizer;
    const gem = new GoogleGenerativeAI(Bun.env["GEMINI_API_KEY"]!);
    this.model = gem.getGenerativeModel({
      // model: model || "gemini-2.0-flash-exp",
      model: model || "gemini-2.5-pro-preview-05-06",
      generationConfig: { maxOutputTokens: RESPONSE_LENGTH },
    });
  }

  public setModel(model: string) {
    const gem = new GoogleGenerativeAI(Bun.env["GEMINI_API_KEY"]!);
    this.model = gem.getGenerativeModel({
      model,
      generationConfig: { maxOutputTokens: RESPONSE_LENGTH },
    });
  }

  private mapMessages(input: ChatMessage[]): Content[] {
    return input.map((m) => ({
      role: m.author === "gemini" ? "model" : "user",
      parts: [{ text: m.text }],
    }));
  }

  // R1-style histories must strictly alternate user/model turns, so merge
  // consecutive messages from the same author instead of overwriting them.
  private mapMessagesR1(input: ChatMessage[]): Content[] {
    return input.reduce((acc: Content[], m) => {
      // Index the accumulator by its own length, not the input index.
      const prev = acc[acc.length - 1];
      const role = m.author === "gemini" ? "model" : "user";
      if (prev?.role === role) prev.parts.push({ text: m.text });
      else acc.push({ role, parts: [{ text: m.text }] });
      return acc;
    }, []);
  }

  private async apiCall(
    messages: Content[],
    isR1: boolean = false,
  ): Promise<AsyncRes<string[]>> {
    try {
      // The whole conversation is passed as history; the empty message
      // merely prompts the model to generate its next turn.
      const chat = this.model.startChat({ history: messages });
      const res = await chat.sendMessage("");
      return { ok: [res.response.text()] };
    } catch (e) {
      console.log(e, "error in gemini api");
      return { error: `${e}` };
    }
  }

  private async apiCallStream(
    messages: Content[],
    handle: (c: any) => void,
    isR1: boolean = false,
  ): Promise<void> {
    try {
      const chat = this.model.startChat({ history: messages });
      // sendMessageStream returns { stream, response }; stream is an async
      // iterable of partial responses.
      const res = await chat.sendMessageStream("");
      for await (const chunk of res.stream) {
        handle(chunk.text());
      }
    } catch (e) {
      console.log(e, "error in gemini api");
      handle(`Error streaming Gemini, ${e}`);
    }
  }

  public async send(sys: string, input: ChatMessage[]) {
    this.model.systemInstruction = { role: "system", parts: [{ text: sys }] };
    const messages = this.mapMessages(input);
    const truncated = this.truncateHistory(messages);
    const res = await this.apiCall(truncated);
    return res;
  }

  public async sendR1(input: ChatMessage[]) {
    const messages = this.mapMessagesR1(input);
    const truncated = this.truncateHistory(messages);
    const res = await this.apiCall(truncated, true);
    return res;
  }

  public async stream(
    sys: string,
    input: ChatMessage[],
    handle: (c: any) => void,
  ) {
    this.model.systemInstruction = { role: "system", parts: [{ text: sys }] };
    const messages = this.mapMessages(input);
    const truncated = this.truncateHistory(messages);
    await this.apiCallStream(truncated, handle);
  }

  public async streamR1(input: ChatMessage[], handle: (c: any) => void) {
    const messages = this.mapMessagesR1(input);
    const truncated = this.truncateHistory(messages);
    await this.apiCallStream(truncated, handle, true);
  }

  // Send a binary document (e.g. a PDF) inline, followed by a text prompt.
  public async sendDoc(data: ArrayBuffer, mimeType: string, prompt: string) {
    const res = await this.model.generateContent([
      {
        inlineData: {
          data: Buffer.from(data).toString("base64"),
          mimeType,
        },
      },
      prompt,
    ]);
    return res;
  }

  // Drop the oldest messages until the estimated token count fits the budget.
  private truncateHistory(messages: Content[]): Content[] {
    let totalTokens = messages.reduce(
      (total, message) =>
        total +
        message.parts.reduce((t, p) => t + this.tokenizer(p.text || ""), 0),
      0,
    );
    while (totalTokens > this.maxTokens && messages.length > 1) {
      const [removed] = messages.splice(0, 1);
      // Subtract the removed message's tokens so the loop can terminate.
      totalTokens -= removed.parts.reduce(
        (t, p) => t + this.tokenizer(p.text || ""),
        0,
      );
    }
    return messages;
  }
}
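
// Usage sketch (illustrative, not part of the module): assumes GEMINI_API_KEY
// is set in the environment and that ChatMessage has the `author`/`text`
// shape used above. The literal messages here are hypothetical.
//
// const api = new GeminiAPI();
// const history: ChatMessage[] = [
//   { author: "user", text: "Summarize this repo in one sentence." },
// ];
// const res = await api.send("You are a terse assistant.", history);
// if ("ok" in res) console.log(res.ok[0]);
// else console.error(res.error);
//
// // Streaming variant: chunks arrive incrementally via the callback.
// await api.stream("You are a terse assistant.", history, (chunk) =>
//   process.stdout.write(chunk),
// );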