summaryrefslogtreecommitdiff
path: root/src/claude.ts
diff options
context:
space:
mode:
Diffstat (limited to 'src/claude.ts')
-rw-r--r--src/claude.ts157
1 files changed, 157 insertions, 0 deletions
diff --git a/src/claude.ts b/src/claude.ts
new file mode 100644
index 0000000..377316e
--- /dev/null
+++ b/src/claude.ts
@@ -0,0 +1,157 @@
+import Claude from "@anthropic-ai/sdk";
+import { RESPONSE_LENGTH } from "./logic/constants";
+import type { AResult, ChatMessage, OChoice, OChunk, OMessage } from "./types";
+import { BOOKWORM_SYS } from "./prompts";
+
+type Message = Claude.Messages.MessageParam;
+
+export default class Conversation {
+ private tokenizer: (text: string) => number;
+ private maxTokens: number;
+ model: string = "claude-3-5-sonnet-20241022";
+ constructor(
+ maxTokens = 200_000,
+ tokenizer: (text: string) => number = (text) => text.length / 3,
+ ) {
+ this.maxTokens = maxTokens;
+ this.tokenizer = tokenizer;
+ }
+ public setModel(model: string) {
+ this.model = model;
+ }
+ private mapMessages(input: ChatMessage[]): Message[] {
+ return input.map((m) => {
+ const role = m.author === "claude" ? "assistant" : "user";
+ return { role, content: m.text };
+ });
+ }
+
+ private mapMessagesR1(input: ChatMessage[]): Message[] {
+ return input.reduce((acc: Message[], m, i) => {
+ const prev = acc[i - 1];
+ const role: any = m.author === "claude" ? "assistant" : "user";
+ const msg = { role, content: m.text };
+ if (prev?.role === role) acc[i - 1] = msg;
+ else acc = [...acc, msg];
+ return acc;
+ }, []);
+ }
+
+ public async send(sys: string, input: ChatMessage[]) {
+ const messages = this.mapMessages(input);
+ const truncated = this.truncateHistory(messages);
+ const res = await this.apiCall(sys, truncated);
+ return res;
+ }
+
+ public async sendR1(input: ChatMessage[]) {
+ const messages = this.mapMessagesR1(input);
+ const truncated = this.truncateHistory(messages);
+ const res = await this.apiCall("", truncated, true);
+ return res;
+ }
+ public async sendDoc(data: string) {
+ const sys = BOOKWORM_SYS;
+ const msg: Message = {
+ role: "user",
+ content: [
+ {
+ type: "document",
+ source: { type: "base64", data, media_type: "application/pdf" },
+ },
+ {
+ type: "text",
+ text: "Please analyze this according to your system prompt. Be thorough.",
+ },
+ ],
+ };
+ const res = await this.apiCall(sys, [msg]);
+ return res;
+ }
+
+ public async stream(
+ sys: string,
+ input: ChatMessage[],
+ handle: (c: any) => void,
+ ) {
+ const messages = this.mapMessages(input);
+ const truncated = this.truncateHistory(messages);
+ await this.apiCallStream(sys, truncated, handle);
+ }
+
+ public async streamR1(input: ChatMessage[], handle: (c: any) => void) {
+ const messages = this.mapMessagesR1(input);
+ const truncated = this.truncateHistory(messages);
+ await this.apiCallStream("", truncated, handle, true);
+ }
+
+ private truncateHistory(messages: Message[]): Message[] {
+ const totalTokens = messages.reduce((total, message) => {
+ return total + this.tokenizer(message.content as string);
+ }, 0);
+ while (totalTokens > this.maxTokens && messages.length > 1) {
+ messages.splice(0, 1);
+ }
+ return messages;
+ }
+
+ // TODO
+ // https://docs.anthropic.com/en/api/messages-examples#putting-words-in-claudes-mouth
+ private async apiCall(
+ system: string,
+ messages: Message[],
+ isR1: boolean = false,
+ ): Promise<AResult<string[]>> {
+ try {
+ const claud = new Claude();
+ // const list = await claud.models.list();
+ // console.log(list.data);
+ const res = await claud.messages.create({
+ model: this.model,
+ max_tokens: RESPONSE_LENGTH,
+ system,
+ messages,
+ });
+ return {
+ ok: res.content.reduce((acc: string[], item) => {
+ if (item.type === "tool_use") return acc;
+ else return [...acc, item.text];
+ }, []),
+ };
+ } catch (e) {
+ console.log(e, "error in claude api");
+ return { error: `${e}` };
+ }
+ }
+
+ private async apiCallStream(
+ system: string,
+ messages: Message[],
+ handle: (c: any) => void,
+ isR1: boolean = false,
+ ): Promise<void> {
+ try {
+ const claud = new Claude();
+ const stream = await claud.messages.create({
+ model: this.model,
+ max_tokens: RESPONSE_LENGTH,
+ system,
+ messages,
+ stream: true,
+ });
+
+ for await (const part of stream) {
+ if (part.type === "message_start") continue;
+ if (part.type === "content_block_start") continue;
+ if (part.type === "content_block_delta") {
+ console.log("delta", part.delta);
+ const delta: any = part.delta;
+ handle(delta.text);
+ }
+ }
+ } catch (e) {
+ console.log(e, "error in claude api");
+ handle(`Error streaming Claude, ${e}`);
+ }
+ }
+}