-
Notifications
You must be signed in to change notification settings - Fork 2.5k
/
Copy pathchain.test.ts
146 lines (136 loc) · 3.83 KB
/
chain.test.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import { test, expect, jest } from "@jest/globals";
import { LLM } from "@langchain/core/language_models/llms";
import { PromptTemplate } from "@langchain/core/prompts";
import { FakeEmbeddings } from "@langchain/core/utils/testing";
import { ChainTool } from "../chain.js";
import { LLMChain } from "../../chains/llm_chain.js";
import { VectorDBQAChain } from "../../chains/vector_db_qa.js";
import { MemoryVectorStore } from "../../vectorstores/memory.js";
/**
 * Deterministic stand-in LLM for these tests: its "completion" is simply the
 * prompt it received, so assertions can inspect exactly what reached the model.
 */
class FakeLLM extends LLM {
  /** Model-type identifier reported by this fake. */
  _llmType(): string {
    return "fake";
  }

  /** Echo the rendered prompt back unchanged as the completion text. */
  async _call(renderedPrompt: string): Promise<string> {
    const completion = renderedPrompt;
    return completion;
  }
}
/**
 * A ChainTool wrapping an LLMChain must forward callbacks passed to
 * `tool.invoke` down to the nested chain and LLM. Each handler must fire
 * exactly once, in properly nested start/end order.
 */
test("chain tool with llm chain and local callback", async () => {
  // Every handler appends its label here so the relative ordering can be checked.
  const calls: string[] = [];
  // Builds a jest mock that records `label` each time it is invoked.
  const recordCall = (label: string) =>
    jest.fn(() => {
      calls.push(label);
    });

  const handleToolStart = recordCall("tool start");
  const handleToolEnd = recordCall("tool end");
  const handleLLMStart = recordCall("llm start");
  const handleLLMEnd = recordCall("llm end");
  const handleChainStart = recordCall("chain start");
  const handleChainEnd = recordCall("chain end");

  const chain = new LLMChain({
    llm: new FakeLLM({}),
    prompt: PromptTemplate.fromTemplate("hello world"),
  });
  const tool = new ChainTool({ chain, name: "fake", description: "fake" });

  const result = await tool.invoke("hi", {
    callbacks: [
      {
        // Presumably makes the callback manager await each handler, so `calls`
        // is complete once `invoke` resolves — TODO confirm against core docs.
        awaitHandlers: true,
        handleToolStart,
        handleToolEnd,
        handleLLMStart,
        handleLLMEnd,
        handleChainStart,
        handleChainEnd,
      },
    ],
  });

  // The template has no input variables and FakeLLM echoes its prompt.
  expect(result).toMatchInlineSnapshot(`"hello world"`);
  // `toHaveBeenCalledTimes` replaces the `toBeCalledTimes` alias removed in Jest 30.
  expect(handleToolStart).toHaveBeenCalledTimes(1);
  expect(handleToolEnd).toHaveBeenCalledTimes(1);
  expect(handleLLMStart).toHaveBeenCalledTimes(1);
  expect(handleLLMEnd).toHaveBeenCalledTimes(1);
  expect(handleChainStart).toHaveBeenCalledTimes(1);
  expect(handleChainEnd).toHaveBeenCalledTimes(1);
  expect(calls).toMatchInlineSnapshot(`
    [
      "tool start",
      "chain start",
      "llm start",
      "llm end",
      "chain end",
      "tool end",
    ]
  `);
});
/**
 * A ChainTool wrapping a VectorDBQAChain must propagate callbacks passed to
 * `tool.invoke` through every nested chain down to the single LLM call.
 */
test("chain tool with vectordbqa chain", async () => {
  // Every handler appends its label here so the relative ordering can be checked.
  const calls: string[] = [];
  // Builds a jest mock that records `label` each time it is invoked.
  const recordCall = (label: string) =>
    jest.fn(() => {
      calls.push(label);
    });

  const handleToolStart = recordCall("tool start");
  const handleToolEnd = recordCall("tool end");
  const handleLLMStart = recordCall("llm start");
  const handleLLMEnd = recordCall("llm end");
  const handleChainStart = recordCall("chain start");
  const handleChainEnd = recordCall("chain end");

  // The vector store is created from an empty index, so retrieval presumably
  // returns no documents and the prompt's context section renders empty.
  const chain = VectorDBQAChain.fromLLM(
    new FakeLLM({}),
    await MemoryVectorStore.fromExistingIndex(new FakeEmbeddings())
  );
  const tool = new ChainTool({ chain, name: "fake", description: "fake" });

  const result = await tool.invoke("hi", {
    callbacks: [
      {
        // Presumably makes the callback manager await each handler, so `calls`
        // is complete once `invoke` resolves — TODO confirm against core docs.
        awaitHandlers: true,
        handleToolStart,
        handleToolEnd,
        handleLLMStart,
        handleLLMEnd,
        handleChainStart,
        handleChainEnd,
      },
    ],
  });

  // FakeLLM echoes its prompt, so the result is the rendered QA prompt. The
  // three blank lines are the surrounding separators of the empty context slot.
  expect(result).toMatchInlineSnapshot(`
    "Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.



    Question: hi
    Helpful Answer:"
  `);
  // `toHaveBeenCalledTimes` replaces the `toBeCalledTimes` alias removed in Jest 30.
  expect(handleToolStart).toHaveBeenCalledTimes(1);
  expect(handleToolEnd).toHaveBeenCalledTimes(1);
  expect(handleLLMStart).toHaveBeenCalledTimes(1);
  expect(handleLLMEnd).toHaveBeenCalledTimes(1);
  // Three nested chain runs wrap the single LLM call (presumably the QA chain,
  // its combine-documents chain, and the inner LLMChain — verify if it changes).
  expect(handleChainStart).toHaveBeenCalledTimes(3);
  expect(handleChainEnd).toHaveBeenCalledTimes(3);
  expect(calls).toMatchInlineSnapshot(`
    [
      "tool start",
      "chain start",
      "chain start",
      "chain start",
      "llm start",
      "llm end",
      "chain end",
      "chain end",
      "chain end",
      "tool end",
    ]
  `);
});