import OpenAI from "openai";
import { MAX_TOKENS, RESPONSE_LENGTH } from "./logic/constants";
import type { AIModelAPI, ChatMessage, OChoice } from "./types";
import type { AsyncRes } from "sortug";

type Message = OpenAI.Chat.Completions.ChatCompletionMessageParam;

type Props = {
  baseURL: string;
  apiKey: string;
  model?: string;
  maxTokens?: number;
  tokenizer?: (text: string) => number;
};
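
/**
 * Chat client for OpenAI or any OpenAI-compatible endpoint selected via
 * `baseURL`. Message history is trimmed to a token budget with a pluggable
 * tokenizer, and results come back as `{ ok } | { error }` values rather
 * than thrown exceptions.
 */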
export default class OpenAIAPI implements AIModelAPI {
  private apiKey;
  private baseURL;
  private api;
  maxTokens: number = MAX_TOKENS;
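  // Crude default estimate of roughly 3 characters per token; supply a
  // real tokenizer through Props when accurate budgeting matters.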
  tokenizer: (text: string) => number = (text) => text.length / 3;
  model;

  constructor(props: Props) {
    this.apiKey = props.apiKey;
    this.baseURL = props.baseURL;
    this.api = new OpenAI({ baseURL: this.baseURL, apiKey: this.apiKey });
    this.model = props.model || "";
    if (props.maxTokens) this.maxTokens = props.maxTokens;
    if (props.tokenizer) this.tokenizer = props.tokenizer;
  }

  public setModel(model: string) {
    this.model = model;
  }

  private mapMessages(input: ChatMessage[]): Message[] {
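    // ChatMessage.author doubles as the role; the cast is needed because
    // the SDK's message param type is a discriminated union over roles.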
    return input.map((m) => {
      return { role: m.author as any, content: m.text, name: m.author };
    });
  }

  public async send(sys: string, input: ChatMessage[]): AsyncRes<string[]> {
    const messages = this.mapMessages(input);
    const sysMsg: Message = { role: "system", content: sys };
    const allMessages = [sysMsg, ...messages];
    const truncated = this.truncateHistory(allMessages);
    const res = await this.apiCall(truncated);
    if ("error" in res) return res;
    try {
      // TODO type this properly
      const choices: OChoice[] = res.ok;
      // content can be null (e.g. for tool calls), so fall back to an
      // empty string instead of asserting it away.
      return { ok: choices.map((c) => c.message.content ?? "") };
    } catch (e) {
      return { error: `${e}` };
    }
  }
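
  // Like send, but pushes tokens to `handle` as they arrive instead of
  // resolving once with the full response.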
  public async stream(
    sys: string,
    input: ChatMessage[],
    handle: (c: string) => void,
  ) {
    const messages = this.mapMessages(input);
    const sysMsg: Message = { role: "system", content: sys };
    const allMessages = [sysMsg, ...messages];
    const truncated = this.truncateHistory(allMessages);
    await this.apiCallStream(truncated, handle);
  }

  private truncateHistory(messages: Message[]): Message[] {
    let totalTokens = messages.reduce((total, message) => {
      return total + this.tokenizer(message.content as string);
    }, 0);
    // Evict the oldest non-system message until the history fits the
    // budget, updating the running total as messages are removed.
    while (totalTokens > this.maxTokens && messages.length > 1) {
      // Always keep the system message if it exists
      const startIndex = messages[0].role === "system" ? 1 : 0;
      const [removed] = messages.splice(startIndex, 1);
      totalTokens -= this.tokenizer(removed.content as string);
    }
    return messages;
  }

  // TODO custom temperature?
  private async apiCall(messages: Message[]): AsyncRes<OChoice[]> {
    try {
      const completion = await this.api.chat.completions.create({
        temperature: 1.3,
        model: this.model,
        messages,
        max_tokens: RESPONSE_LENGTH,
      });
      if (!completion) return { error: "null response from openai" };
      return { ok: completion.choices };
    } catch (e) {
      console.error("error in openai api", e);
      return { error: `${e}` };
    }
  }

  private async apiCallStream(
    messages: Message[],
    handle: (c: string) => void,
  ): Promise<void> {
    try {
      const stream = await this.api.chat.completions.create({
        temperature: 1.3,
        model: this.model,
        messages,
        max_tokens: RESPONSE_LENGTH,
        stream: true,
      });
      for await (const chunk of stream) {
        for (const choice of chunk.choices) {
          if (!choice.delta) continue;
          const cont = choice.delta.content;
          if (!cont) continue;
          handle(cont);
        }
      }
    } catch (e) {
      console.error("error in openai stream", e);
      // Surface the failure through the same channel as normal tokens so
      // the consumer sees the stream ended abnormally.
      handle(`Error streaming OpenAI, ${e}`);
    }
  }
}
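
// Example usage (a sketch; the endpoint, model name, and env var shown
// here are assumptions, not part of this module):
//
//   const api = new OpenAIAPI({
//     baseURL: "https://api.openai.com/v1",
//     apiKey: process.env.OPENAI_API_KEY!,
//     model: "gpt-4o-mini",
//   });
//   const res = await api.send("You are terse.", [
//     { author: "user", text: "Hello there" },
//   ]);
//   if ("ok" in res) console.log(res.ok[0]);
//   await api.stream(
//     "You are terse.",
//     [{ author: "user", text: "Hello there" }],
//     (tok) => process.stdout.write(tok),
//   );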