diff options
Diffstat (limited to 'src/gemini.ts')
| -rw-r--r-- | src/gemini.ts | 207 |
1 files changed, 101 insertions, 106 deletions
diff --git a/src/gemini.ts b/src/gemini.ts index 2f685a2..3e636c2 100644 --- a/src/gemini.ts +++ b/src/gemini.ts @@ -1,137 +1,132 @@ import { - GenerativeModel, - GoogleGenerativeAI, + Chat, + GoogleGenAI, type Content, - type GenerateContentResult, -} from "@google/generative-ai"; + type GeneratedImage, + type GeneratedVideo, +} from "@google/genai"; import { RESPONSE_LENGTH } from "./logic/constants"; -import type { AResult, ChatMessage, OChoice, OChunk, OMessage } from "./types"; +import type { + AIModelAPI, + ChatMessage, + OChoice, + OChunk, + OMessage, +} from "./types"; +import type { AsyncRes } from "sortug"; -export default class Conversation { - private tokenizer: (text: string) => number; - private maxTokens: number; - private model: GenerativeModel; +export default class GeminiAPI { + tokenizer: (text: string) => number; + maxTokens: number; + private model: string; + api: GoogleGenAI; + chats: Map<string, Chat> = new Map<string, Chat>(); constructor( maxTokens = 200_000, tokenizer: (text: string) => number = (text) => text.length / 3, + model?: string, ) { this.maxTokens = maxTokens; this.tokenizer = tokenizer; - const gem = new GoogleGenerativeAI(Bun.env["GEMINI_API_KEY"]!); - this.model = gem.getGenerativeModel({ - model: "gemini-2.0-flash-exp", - generationConfig: { maxOutputTokens: RESPONSE_LENGTH }, - }); + const gem = new GoogleGenAI({ apiKey: Bun.env["GEMINI_API_KEY"]! }); + this.api = gem; + this.model = model || "gemini-2.5-pro-preview-05-06 "; } - public setModel(model: string) { - const gem = new GoogleGenerativeAI(Bun.env["GEMINI_API_KEY"]!); - this.model = gem.getGenerativeModel({ - model, - generationConfig: { maxOutputTokens: RESPONSE_LENGTH }, - }); + createChat({ name, history }: { name?: string; history?: Content[] }) { + const chat = this.api.chats.create({ model: this.model, history }); + this.chats.set(name ? name : Date.now().toString(), chat); } - private mapMessages(input: ChatMessage[]): Content[] { - return input.map((m) => ({ - role: m.author === "gemini" ? "model" : "user", - parts: [{ text: m.text }], - })); + async followChat(name: string, message: string): AsyncRes<string> { + const chat = this.chats.get(name); + if (!chat) return { error: "no chat with that name" }; + else { + const response = await chat.sendMessage({ message }); + const text = response.text; + return { ok: text || "" }; + } } - - private mapMessagesR1(input: ChatMessage[]): Content[] { - return input.reduce((acc: Content[], m, i) => { - const prev = acc[i - 1]; - const role = m.author === "gemini" ? "model" : "user"; - const msg = { role, parts: [{ text: m.text }] }; - if (prev?.role === role) acc[i - 1] = msg; - else acc = [...acc, msg]; - return acc; - }, []); + async followChatStream( + name: string, + message: string, + handler: (data: string) => void, + ) { + const chat = this.chats.get(name); + if (!chat) throw new Error("no chat!"); + else { + const response = await chat.sendMessageStream({ message }); + for await (const chunk of response) { + const text = chunk.text; + handler(text || ""); + } + } } - private async apiCall( - messages: Content[], - isR1: boolean = false, - ): Promise<AResult<string[]>> { + async send(message: string, systemPrompt?: string): AsyncRes<string> { try { - const chat = this.model.startChat({ history: messages }); - const res = await chat.sendMessage(""); - return { ok: [res.response.text()] }; + const opts = { + model: this.model, + contents: message, + }; + const fopts = systemPrompt + ? { ...opts, config: { systemInstruction: systemPrompt } } + : opts; + const response = await this.api.models.generateContent(fopts); + return { ok: response.text || "" }; } catch (e) { - console.log(e, "error in gemini api"); return { error: `${e}` }; } } + async sendStream( + handler: (s: string) => void, + message: string, + systemPrompt?: string, + ) { + const opts = { + model: this.model, + contents: message, + }; + const fopts = systemPrompt + ? { ...opts, config: { systemInstruction: systemPrompt } } + : opts; + const response = await this.api.models.generateContentStream(fopts); + for await (const chunk of response) { + handler(chunk.text || ""); + } + } - private async apiCallStream( - messages: Content[], - handle: (c: any) => void, - isR1: boolean = false, - ): Promise<void> { + async makeImage(prompt: string): AsyncRes<GeneratedImage[]> { try { - const chat = this.model.startChat({ history: messages }); - const res = await chat.sendMessage(""); - // for await (const chunk of res.stream()) { - // handle(chunk.text()); - // } + const response = await this.api.models.generateImages({ + model: this.model, + prompt, + }); + // TODO if empty or undefined return error + return { ok: response.generatedImages || [] }; } catch (e) { - console.log(e, "error in gemini api"); - handle(`Error streaming Gemini, ${e}`); + return { error: `${e}` }; } } - - public async send(sys: string, input: ChatMessage[]) { - const messages = this.mapMessages(input); - const truncated = this.truncateHistory(messages); - const res = await this.apiCall(truncated); - return res; - } - - public async sendR1(input: ChatMessage[]) { - const messages = this.mapMessagesR1(input); - const truncated = this.truncateHistory(messages); - const res = await this.apiCall(truncated, true); - return res; - } - - public async stream( - sys: string, - input: ChatMessage[], - handle: (c: any) => void, - ) { - const messages = this.mapMessages(input); - const truncated = this.truncateHistory(messages); - await this.apiCallStream(truncated, handle); - } - - public async streamR1(input: ChatMessage[], handle: (c: any) => void) { - const messages = this.mapMessagesR1(input); - const truncated = this.truncateHistory(messages); - await this.apiCallStream(truncated, handle, true); - } - - public async sendDoc(data: ArrayBuffer, mimeType: string, prompt: string) { - const res = await this.model.generateContent([ - { - inlineData: { - data: Buffer.from(data).toString("base64"), - mimeType, - }, - }, - prompt, - ]); - return res; - } - - private truncateHistory(messages: Content[]): Content[] { - const totalTokens = messages.reduce((total, message) => { - return total + this.tokenizer(message.parts[0].text || ""); - }, 0); - while (totalTokens > this.maxTokens && messages.length > 1) { - messages.splice(0, 1); + async makeVideo({ + prompt, + image, + }: { + prompt?: string; + image?: string; + }): AsyncRes<GeneratedVideo[]> { + try { + const response = await this.api.models.generateVideos({ + model: this.model, + prompt, + }); + // TODO if empty or undefined return error + return { ok: response.response?.generatedVideos || [] }; + } catch (e) { + return { error: `${e}` }; } - return messages; } } +// TODO how to use caches +// https://ai.google.dev/api/caching |
