summaryrefslogtreecommitdiff
path: root/src/model.ts
blob: 39b42dc222ad40d263873bf71aeff13bfa0178af (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import OpenAI from "openai";
import { MAX_TOKENS, RESPONSE_LENGTH } from "./logic/constants";
import type { AResult, ChatMessage, OChoice } from "./types";

// Alias for the OpenAI SDK's chat-completion request message param type.
type Message = OpenAI.Chat.Completions.ChatCompletionMessageParam;

/** Construction options for `Conversation`. */
interface Props {
  /** Base URL handed to the OpenAI client. */
  baseURL: string;
  /** API key handed to the OpenAI client. */
  apiKey: string;
  /** Model identifier used for chat-completion requests. */
  model: string;
  /** Optional prompt token budget; a falsy value keeps the class default. */
  maxTokens?: number;
  /** Optional token counter replacing the built-in length-based estimate. */
  tokenizer?: (text: string) => number;
}
/**
 * Chat client for an OpenAI-compatible chat-completions endpoint.
 *
 * Each request is built from a system prompt plus the mapped conversation
 * history, trimmed (oldest non-system messages first) to fit `maxTokens`.
 */
export default class Conversation {
  private apiKey;
  private baseURL;
  /** Prompt token budget enforced by `truncateHistory`. */
  private maxTokens: number = MAX_TOKENS;
  /** Token estimator; the default assumes roughly 3 characters per token. */
  private tokenizer: (text: string) => number = (text) => text.length / 3;
  private api;
  private model;
  // TODO custom temperature?
  /** Sampling temperature shared by `send` and `stream` requests. */
  private readonly temperature: number = 1.3;

  /**
   * @param props - endpoint, credentials, model and optional token budget /
   *   token counter. `maxTokens` and `tokenizer` are only applied when truthy.
   */
  constructor(props: Props) {
    this.apiKey = props.apiKey;
    this.baseURL = props.baseURL;
    this.api = new OpenAI({ baseURL: this.baseURL, apiKey: this.apiKey });
    this.model = props.model;
    if (props.maxTokens) this.maxTokens = props.maxTokens;
    if (props.tokenizer) this.tokenizer = props.tokenizer;
  }

  /** Switch the model used for subsequent requests. */
  public setModel(model: string) {
    this.model = model;
  }

  /**
   * Convert app-level chat messages into OpenAI request messages.
   *
   * NOTE(review): `author` is used both as the `role` and as `name`, which
   * assumes authors are valid role strings ("user", "assistant", ...); the
   * `as any` hides any mismatch — confirm against `ChatMessage`.
   */
  private mapMessages(input: ChatMessage[]): Message[] {
    return input.map((m) => {
      return { role: m.author as any, content: m.text, name: m.author };
    });
  }

  /** Prepend the system prompt to the mapped history and trim it to the token budget. */
  private buildPrompt(sys: string, input: ChatMessage[]): Message[] {
    const sysMsg: Message = { role: "system", content: sys };
    return this.truncateHistory([sysMsg, ...this.mapMessages(input)]);
  }

  /**
   * Request a non-streamed completion.
   *
   * @returns `{ ok }` with one string per returned choice (null content is
   *   mapped to ""), or the `{ error }` produced by the API call.
   */
  public async send(sys: string, input: ChatMessage[]): AResult<string[]> {
    const res = await this.apiCall(this.buildPrompt(sys, input));
    if ("error" in res) return res;
    // `message.content` is nullable; a non-null assertion would let null
    // slip into the string[] unchecked, so substitute "" explicitly.
    return { ok: res.ok.map((c) => c.message.content ?? "") };
  }

  /**
   * Stream a completion, calling `handle` with each non-empty content delta.
   * On failure `handle` receives a single "Error streaming OpenAI, ..." string
   * instead of the promise rejecting.
   */
  public async stream(
    sys: string,
    input: ChatMessage[],
    handle: (c: string) => void,
  ): Promise<void> {
    await this.apiCallStream(this.buildPrompt(sys, input), handle);
  }

  /** Estimate one message's tokens; null content and non-text parts count as nothing. */
  private countTokens(message: Message): number {
    const { content } = message;
    if (typeof content === "string") return this.tokenizer(content);
    if (Array.isArray(content)) {
      const text = content
        .map((part) => ("text" in part ? part.text : ""))
        .join("\n");
      return this.tokenizer(text);
    }
    return 0;
  }

  /**
   * Drop the oldest non-system messages until the estimated total fits within
   * `maxTokens` or only one message is left. Returns a new array; the input
   * array is not modified.
   */
  private truncateHistory(messages: Message[]): Message[] {
    const kept = [...messages];
    // Running total: it must shrink as messages are removed, otherwise the
    // loop condition never changes and the whole history gets discarded.
    let total = kept.reduce((sum, m) => sum + this.countTokens(m), 0);
    while (total > this.maxTokens && kept.length > 1) {
      // Always keep the system message if it leads the list.
      const startIndex = kept[0]?.role === "system" ? 1 : 0;
      for (const removed of kept.splice(startIndex, 1)) {
        total -= this.countTokens(removed);
      }
    }
    return kept;
  }

  /** Single non-streamed chat-completion request; exceptions become `{ error }`. */
  private async apiCall(messages: Message[]): AResult<OChoice[]> {
    try {
      const completion = await this.api.chat.completions.create({
        temperature: this.temperature,
        model: this.model,
        messages,
        max_tokens: RESPONSE_LENGTH,
      });
      if (!completion) return { error: "null response from openai" };
      return { ok: completion.choices };
    } catch (e) {
      console.error(e, "error in openai api");
      return { error: `${e}` };
    }
  }

  /**
   * Streamed chat-completion request. Every non-empty `delta.content` is
   * forwarded to `handle`; any exception (including one thrown by `handle`)
   * is logged and reported through `handle` as an error string.
   */
  private async apiCallStream(
    messages: Message[],
    handle: (c: string) => void,
  ): Promise<void> {
    try {
      const stream = await this.api.chat.completions.create({
        temperature: this.temperature,
        model: this.model,
        messages,
        max_tokens: RESPONSE_LENGTH,
        stream: true,
      });

      for await (const chunk of stream) {
        for (const choice of chunk.choices) {
          const content = choice.delta?.content;
          if (content) handle(content);
        }
      }
    } catch (e) {
      console.error(e, "error in openai api");
      handle(`Error streaming OpenAI, ${e}`);
    }
  }
}