-
Notifications
You must be signed in to change notification settings - Fork 2.5k
/
Copy pathsql.test.ts
159 lines (142 loc) · 4.57 KB
/
sql.test.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
/* eslint-disable no-process-env */
import { test, expect, beforeEach, afterEach } from "@jest/globals";
import { DataSource } from "typeorm";
import {
InfoSqlTool,
QuerySqlTool,
ListTablesSqlTool,
QueryCheckerTool,
} from "../../tools/sql.js";
import { SqlDatabase } from "../../sql_db.js";
// Snapshot of the environment object in place before any test runs. Each
// test installs a copy carrying a dummy OPENAI_API_KEY (see beforeEach) and
// afterEach restores this exact reference.
const previousEnv = process.env;
// SqlDatabase wrapping the current test's freshly seeded in-memory SQLite DB.
let db: SqlDatabase;

/**
 * DDL and seed rows for the fixture database, executed one per
 * `DataSource.query` call and strictly in this order: the insertion order
 * fixes the AUTOINCREMENT ids (1, 2, 3) that the expected outputs assert on.
 */
const seedStatements = [
  "CREATE TABLE products (id INTEGER PRIMARY KEY AUTOINCREMENT, name TEXT, price INTEGER);",
  "INSERT INTO products (name, price) VALUES ('Apple', 100);",
  "INSERT INTO products (name, price) VALUES ('Banana', 200);",
  "INSERT INTO products (name, price) VALUES ('Orange', 300);",
  "CREATE TABLE users (id INTEGER PRIMARY KEY AUTOINCREMENT, name TEXT, age INTEGER);",
  "INSERT INTO users (name, age) VALUES ('Alice', 20);",
  "INSERT INTO users (name, age) VALUES ('Bob', 21);",
  "INSERT INTO users (name, age) VALUES ('Charlie', 22);",
];

/**
 * Builds a fresh in-memory SQLite database with `products` and `users`
 * tables (three rows each), wraps it in `SqlDatabase` as `db`, and installs
 * a copy of the environment with a dummy OPENAI_API_KEY.
 *
 * If seeding or wrapping fails, the datasource is destroyed before the error
 * is rethrown, so a broken setup does not leak an open connection.
 */
beforeEach(async () => {
  const datasource = new DataSource({
    type: "sqlite",
    database: ":memory:",
    synchronize: true,
  });
  await datasource.initialize();
  try {
    for (const statement of seedStatements) {
      // Sequential on purpose: table creation must precede its inserts, and
      // row order determines the generated ids.
      // eslint-disable-next-line no-await-in-loop
      await datasource.query(statement);
    }
    db = await SqlDatabase.fromDataSourceParams({
      appDataSource: datasource,
    });
  } catch (error) {
    // `db` has not been pointed at this datasource yet, so afterEach cannot
    // close it; release the connection here before failing the hook.
    await datasource.destroy();
    throw error;
  }
  process.env = { ...previousEnv, OPENAI_API_KEY: "test" };
});
/**
 * Restores the original environment object and closes the test's database.
 */
afterEach(async () => {
  process.env = previousEnv;
  // `db` is unset if the very first beforeEach failed, and still refers to a
  // previous test's (already destroyed) datasource if a later one failed;
  // only tear down a connection that is actually open.
  if (db?.appDataSource.isInitialized) {
    await db.appDataSource.destroy();
  }
});
test.skip("QuerySqlTool", async () => {
  // Expected output: every `users` row, as compact JSON in id order.
  const expectedRows = [
    { id: 1, name: "Alice", age: 20 },
    { id: 2, name: "Bob", age: 21 },
    { id: 3, name: "Charlie", age: 22 },
  ];
  const output = await new QuerySqlTool(db).invoke("SELECT * FROM users");
  expect(output).toBe(JSON.stringify(expectedRows));
});
test.skip("QuerySqlTool with error", async () => {
  // "userss" is a deliberate typo: the driver error must come back as the
  // tool's result string rather than as a rejected promise.
  const output = await new QuerySqlTool(db).invoke("SELECT * FROM userss");
  expect(output).toBe("QueryFailedError: SQLITE_ERROR: no such table: userss");
});
test.skip("InfoSqlTool", async () => {
  // Requested as "users, products" but reported products-first: each table
  // gets its schema followed by a three-row sample query and its result.
  const expected = `
CREATE TABLE products (
id INTEGER , name TEXT , price INTEGER )
SELECT * FROM "products" LIMIT 3;
id name price
1 Apple 100
2 Banana 200
3 Orange 300
CREATE TABLE users (
id INTEGER , name TEXT , age INTEGER )
SELECT * FROM "users" LIMIT 3;
id name age
1 Alice 20
2 Bob 21
3 Charlie 22`;
  const info = await new InfoSqlTool(db).invoke("users, products");
  expect(info.trim()).toBe(expected.trim());
});
test.skip("InfoSqlTool with customDescription", async () => {
  // Each known table's description is printed above its schema; the entry
  // for the nonexistent "userss" table must not show up anywhere.
  db.customDescription = {
    products: "Custom Description for Products Table",
    users: "Custom Description for Users Table",
    userss: "Should not appear",
  };
  const expected = `
Custom Description for Products Table
CREATE TABLE products (
id INTEGER , name TEXT , price INTEGER )
SELECT * FROM "products" LIMIT 3;
id name price
1 Apple 100
2 Banana 200
3 Orange 300
Custom Description for Users Table
CREATE TABLE users (
id INTEGER , name TEXT , age INTEGER )
SELECT * FROM "users" LIMIT 3;
id name age
1 Alice 20
2 Bob 21
3 Charlie 22`;
  const info = await new InfoSqlTool(db).invoke("users, products");
  expect(info.trim()).toBe(expected.trim());
});
test.skip("InfoSqlTool with error", async () => {
  // A single unknown table name replaces the whole answer with an error.
  const output = await new InfoSqlTool(db).invoke("userss, products");
  expect(output).toBe(
    "Error: Wrong target table name: the table userss was not found in the database"
  );
});
test.skip("ListTablesSqlTool", async () => {
  // With no include/ignore filters, every table is listed, comma-separated.
  const tableList = await new ListTablesSqlTool(db).invoke("");
  expect(tableList).toBe("products, users");
});
test.skip("QueryCheckerTool", async () => {
  // NOTE(review): beforeEach installs a dummy OPENAI_API_KEY, presumably so
  // the tool's underlying model can be constructed offline — confirm.
  const queryCheckerTool = new QueryCheckerTool();
  // `not.toBeNull()` alone lets `undefined` through, which would then fail
  // only as an opaque TypeError on the `.inputKeys` access; check both.
  expect(queryCheckerTool.llmChain).toBeDefined();
  expect(queryCheckerTool.llmChain).not.toBeNull();
  expect(queryCheckerTool.llmChain.inputKeys).toEqual(["query"]);
});
test.skip("ListTablesSqlTool with include tables", async () => {
  // An allow-list on the database hides every table not named in it.
  db.includesTables = ["users"];
  const tableList = await new ListTablesSqlTool(db).invoke("");
  expect(tableList).toBe("users");
});
test.skip("ListTablesSqlTool with ignore tables", async () => {
  // A deny-list on the database removes the named tables from the listing.
  db.ignoreTables = ["products"];
  const tableList = await new ListTablesSqlTool(db).invoke("");
  expect(tableList).toBe("users");
});