summaryrefslogtreecommitdiff
path: root/src/gemini.ts
diff options
context:
space:
mode:
Diffstat (limited to 'src/gemini.ts')
-rw-r--r--src/gemini.ts137
1 files changed, 137 insertions, 0 deletions
diff --git a/src/gemini.ts b/src/gemini.ts
new file mode 100644
index 0000000..2f685a2
--- /dev/null
+++ b/src/gemini.ts
@@ -0,0 +1,137 @@
+import {
+ GenerativeModel,
+ GoogleGenerativeAI,
+ type Content,
+ type GenerateContentResult,
+} from "@google/generative-ai";
+import { RESPONSE_LENGTH } from "./logic/constants";
+import type { AResult, ChatMessage, OChoice, OChunk, OMessage } from "./types";
+
+export default class Conversation {
+ private tokenizer: (text: string) => number;
+ private maxTokens: number;
+ private model: GenerativeModel;
+
+ constructor(
+ maxTokens = 200_000,
+ tokenizer: (text: string) => number = (text) => text.length / 3,
+ ) {
+ this.maxTokens = maxTokens;
+ this.tokenizer = tokenizer;
+
+ const gem = new GoogleGenerativeAI(Bun.env["GEMINI_API_KEY"]!);
+ this.model = gem.getGenerativeModel({
+ model: "gemini-2.0-flash-exp",
+ generationConfig: { maxOutputTokens: RESPONSE_LENGTH },
+ });
+ }
+
+ public setModel(model: string) {
+ const gem = new GoogleGenerativeAI(Bun.env["GEMINI_API_KEY"]!);
+ this.model = gem.getGenerativeModel({
+ model,
+ generationConfig: { maxOutputTokens: RESPONSE_LENGTH },
+ });
+ }
+ private mapMessages(input: ChatMessage[]): Content[] {
+ return input.map((m) => ({
+ role: m.author === "gemini" ? "model" : "user",
+ parts: [{ text: m.text }],
+ }));
+ }
+
+ private mapMessagesR1(input: ChatMessage[]): Content[] {
+ return input.reduce((acc: Content[], m, i) => {
+ const prev = acc[i - 1];
+ const role = m.author === "gemini" ? "model" : "user";
+ const msg = { role, parts: [{ text: m.text }] };
+ if (prev?.role === role) acc[i - 1] = msg;
+ else acc = [...acc, msg];
+ return acc;
+ }, []);
+ }
+
+ private async apiCall(
+ messages: Content[],
+ isR1: boolean = false,
+ ): Promise<AResult<string[]>> {
+ try {
+ const chat = this.model.startChat({ history: messages });
+ const res = await chat.sendMessage("");
+ return { ok: [res.response.text()] };
+ } catch (e) {
+ console.log(e, "error in gemini api");
+ return { error: `${e}` };
+ }
+ }
+
+ private async apiCallStream(
+ messages: Content[],
+ handle: (c: any) => void,
+ isR1: boolean = false,
+ ): Promise<void> {
+ try {
+ const chat = this.model.startChat({ history: messages });
+ const res = await chat.sendMessage("");
+ // for await (const chunk of res.stream()) {
+ // handle(chunk.text());
+ // }
+ } catch (e) {
+ console.log(e, "error in gemini api");
+ handle(`Error streaming Gemini, ${e}`);
+ }
+ }
+
+ public async send(sys: string, input: ChatMessage[]) {
+ const messages = this.mapMessages(input);
+ const truncated = this.truncateHistory(messages);
+ const res = await this.apiCall(truncated);
+ return res;
+ }
+
+ public async sendR1(input: ChatMessage[]) {
+ const messages = this.mapMessagesR1(input);
+ const truncated = this.truncateHistory(messages);
+ const res = await this.apiCall(truncated, true);
+ return res;
+ }
+
+ public async stream(
+ sys: string,
+ input: ChatMessage[],
+ handle: (c: any) => void,
+ ) {
+ const messages = this.mapMessages(input);
+ const truncated = this.truncateHistory(messages);
+ await this.apiCallStream(truncated, handle);
+ }
+
+ public async streamR1(input: ChatMessage[], handle: (c: any) => void) {
+ const messages = this.mapMessagesR1(input);
+ const truncated = this.truncateHistory(messages);
+ await this.apiCallStream(truncated, handle, true);
+ }
+
+ public async sendDoc(data: ArrayBuffer, mimeType: string, prompt: string) {
+ const res = await this.model.generateContent([
+ {
+ inlineData: {
+ data: Buffer.from(data).toString("base64"),
+ mimeType,
+ },
+ },
+ prompt,
+ ]);
+ return res;
+ }
+
+ private truncateHistory(messages: Content[]): Content[] {
+ const totalTokens = messages.reduce((total, message) => {
+ return total + this.tokenizer(message.parts[0].text || "");
+ }, 0);
+ while (totalTokens > this.maxTokens && messages.length > 1) {
+ messages.splice(0, 1);
+ }
+ return messages;
+ }
+}