summaryrefslogtreecommitdiff
path: root/src/openai_tools.ts
blob: feb2e4af20b0cfe6b835b3b91050bf2386d70ebd (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import type OpenAI from "openai";
import type { Result } from "./types";
/** A tool call requested by the model inside an assistant message. */
type ToolCall = OpenAI.Chat.Completions.ChatCompletionMessageToolCall;

/** A tool definition advertised to the model in a completion request. */
type Tool = OpenAI.Chat.Completions.ChatCompletionTool;
/** A tool-result message sent back to the model for one tool call. */
type ToolMsg = OpenAI.Chat.Completions.ChatCompletionToolMessageParam;

/** A chat message as returned by the chat completions API. */
type Message = OpenAI.Chat.Completions.ChatCompletionMessage;

/**
 * Executes a batch of OpenAI tool calls through a WebSocket tool server and
 * then asks the model for a follow-up completion containing the results.
 *
 * Flow: when the socket opens, every entry of `calls` is sent as `{ call }`.
 * The server replies with `{ functionRes: Result<ToolMsg> }` frames. Once a
 * successful tool message exists for EVERY call (matched by `tool_call_id`),
 * the assistant `message` followed by all tool messages (in `calls` order) is
 * sent to `chat.completions.create`. The API rejects an assistant message with
 * `tool_calls` unless every call id has a matching tool message, which is why
 * results are batched rather than sent one at a time.
 */
export default class OpenAIToolUse {
  api: OpenAI;
  model: string;
  socket: WebSocket;
  tools: Tool[];
  message: Message;
  calls: ToolCall[];
  /** Most recent successful tool message received from the server. */
  res: ToolMsg | null = null;
  /** Successful tool messages received so far, keyed by `tool_call_id`. */
  results = new Map<string, ToolMsg>();
  /** Set once the follow-up completion has been requested; it runs at most once. */
  finished = false;
  /** The follow-up completion, once it has resolved. */
  completion: OpenAI.Chat.Completions.ChatCompletion | null = null;

  /**
   * @param api - OpenAI client used for the follow-up completion.
   * @param model - Model name passed to `chat.completions.create`.
   * @param tools - Tool definitions re-sent with the follow-up request.
   * @param message - Assistant message that contained the tool calls.
   * @param calls - Tool calls to execute; the follow-up waits for all of them.
   * @param url - WebSocket URL of the tool server.
   */
  constructor(
    api: OpenAI,
    model: string,
    tools: Tool[],
    message: Message,
    calls: ToolCall[],
    url = "ws://localhost:8900",
  ) {
    this.api = api;
    this.model = model;
    // `ws:` rather than `http:` — older WebSocket implementations reject http URLs.
    this.socket = new WebSocket(url);
    this.tools = tools;
    this.message = message;
    this.calls = calls;
    for (const c of calls) {
      console.log({ c });
    }
    this.wsHandlers();
  }

  /** Wires the socket: send calls on open, route `functionRes` frames to `handleRes`. */
  wsHandlers() {
    this.socket.addEventListener("open", () => {
      this.handleToolCalls();
    });
    this.socket.addEventListener("message", (ev) => {
      let payload: unknown;
      try {
        payload = JSON.parse(ev.data);
      } catch (err: unknown) {
        console.error("tool server sent a non-JSON message:", err);
        return;
      }
      // `in` throws on primitives, so make sure we have an object first.
      if (typeof payload === "object" && payload !== null && "functionRes" in payload) {
        // NOTE(review): the frame shape is trusted here, not validated.
        this.handleRes(payload.functionRes as Result<ToolMsg>).catch((err: unknown) => {
          console.error("tool-use follow-up completion failed:", err);
        });
      }
    });
  }

  /** Sends each pending tool call to the server as a `{ call }` JSON frame. */
  handleToolCalls() {
    for (const call of this.calls) {
      this.socket.send(JSON.stringify({ call }));
    }
  }

  /**
   * Records one tool result. When every call has a successful result, closes
   * the socket and requests the follow-up completion (stored in `completion`).
   * Failed results are logged; a call that failed never gets a tool message,
   * so no follow-up is sent for that batch.
   */
  async handleRes(res: Result<ToolMsg>) {
    if (this.finished) return; // late/duplicate frame after the follow-up was sent
    if ("error" in res) {
      console.error("tool call failed:", res.error);
      return;
    }
    this.res = res.ok;
    this.results.set(res.ok.tool_call_id, res.ok);

    // Tool messages must cover every call id, in the order the calls were made.
    const toolMsgs: ToolMsg[] = [];
    for (const call of this.calls) {
      const msg = this.results.get(call.id);
      if (msg === undefined) return; // still waiting on this call
      toolMsgs.push(msg);
    }

    this.finished = true;
    this.socket.close();

    const messages = [this.message, ...toolMsgs];
    console.log({ messages }, "almost there");
    const completion = await this.api.chat.completions.create({
      model: this.model,
      messages,
      tools: this.tools,
    });
    this.completion = completion;
    console.log({ completion });
    for (const choice of completion.choices) {
      console.log({ choice });
    }
  }
}