Skip to content

Commit 8dce07d

Browse files
authored
Added summarizing examples (#955)
1 parent b88ef63 commit 8dce07d

File tree

5 files changed

+141
-2
lines changed

5 files changed

+141
-2
lines changed

pgml-sdks/rust/pgml/javascript/examples/README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,6 @@ In this example, we will use `hknlp/instructor-base` model to build text embeddi
1111

1212
### [Extractive Question Answering](./extractive_question_answering.js)
1313
In this example, we will show how to use `vector_recall` result as a `context` to a HuggingFace question answering model. We will use `Builtins.transform()` to run the model on the database.
14+
15+
### [Summarizing Question Answering](./summarizing_question_answering.js)
16+
This is an example to find documents relevant to a question from the collection of documents and then summarize those documents.
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
const pgml = require("pgml");
2+
require("dotenv").config();
3+
4+
pgml.js_init_logger();
5+
6+
const main = async () => {
7+
// Initialize the collection
8+
const collection = pgml.newCollection("my_javascript_sqa_collection");
9+
10+
// Add a pipeline
11+
const model = pgml.newModel();
12+
const splitter = pgml.newSplitter();
13+
const pipeline = pgml.newPipeline(
14+
"my_javascript_sqa_pipeline",
15+
model,
16+
splitter,
17+
);
18+
await collection.add_pipeline(pipeline);
19+
20+
// Upsert documents, these documents are automatically split into chunks and embedded by our pipeline
21+
const documents = [
22+
{
23+
id: "Document One",
24+
text: "PostgresML is the best tool for machine learning applications!",
25+
},
26+
{
27+
id: "Document Two",
28+
text: "PostgresML is open source and available to everyone!",
29+
},
30+
];
31+
await collection.upsert_documents(documents);
32+
33+
const query = "What is the best tool for machine learning?";
34+
35+
// Perform vector search
36+
const queryResults = await collection
37+
.query()
38+
.vector_recall(query, pipeline)
39+
.limit(1)
40+
.fetch_all();
41+
42+
// Construct context from results
43+
const context = queryResults
44+
.map((result) => {
45+
return result[1];
46+
})
47+
.join("\n");
48+
49+
// Query for summarization
50+
const builtins = pgml.newBuiltins();
51+
const answer = await builtins.transform(
52+
{ task: "summarization", model: "sshleifer/distilbart-cnn-12-6" },
53+
[context],
54+
);
55+
56+
// Archive the collection
57+
await collection.archive();
58+
return answer;
59+
};
60+
61+
main().then((results) => {
62+
console.log("Question summary: \n", results);
63+
});

pgml-sdks/rust/pgml/python/examples/README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,6 @@ In this example, we will show how to use `vector_recall` result as a `context` t
1515
### [Table Question Answering](./table_question_answering.py)
1616
In this example, we will use [Open Table-and-Text Question Answering (OTT-QA)
1717
](https://github.com/wenhuchen/OTT-QA) dataset to run queries on tables. We will use `deepset/all-mpnet-base-v2-table` model that is trained for embedding tabular data for retrieval tasks.
18+
19+
### [Summarizing Question Answering](./summarizing_question_answering.py)
20+
This is an example to find documents relevant to a question from the collection of documents and then summarize those documents.

pgml-sdks/rust/pgml/python/examples/extractive_question_answering.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,7 @@ async def main():
5656
"question-answering", [json.dumps({"question": query, "context": context})]
5757
)
5858
end = time()
59-
console.print("Results for query '%s'" % query, style="bold")
60-
console.print(answer)
59+
console.print("Answer '%s'" % answer, style="bold")
6160
console.print("Query time = %0.3f" % (end - start))
6261

6362
# Archive collection
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
from pgml import Collection, Model, Splitter, Pipeline, Builtins, py_init_logger
2+
import json
3+
from datasets import load_dataset
4+
from time import time
5+
from dotenv import load_dotenv
6+
from rich.console import Console
7+
import asyncio
8+
9+
10+
py_init_logger()
11+
12+
13+
async def main():
14+
load_dotenv()
15+
console = Console()
16+
17+
# Initialize collection
18+
collection = Collection("squad_collection")
19+
20+
# Create a pipeline using the default model and splitter
21+
model = Model()
22+
splitter = Splitter()
23+
pipeline = Pipeline("squadv1", model, splitter)
24+
await collection.add_pipeline(pipeline)
25+
26+
# Prep documents for upserting
27+
data = load_dataset("squad", split="train")
28+
data = data.to_pandas()
29+
data = data.drop_duplicates(subset=["context"])
30+
documents = [
31+
{"id": r["id"], "text": r["context"], "title": r["title"]}
32+
for r in data.to_dict(orient="records")
33+
]
34+
35+
# Upsert documents
36+
await collection.upsert_documents(documents[:200])
37+
38+
# Query for context
39+
query = "Who won more than 20 grammy awards?"
40+
console.print("Querying for context ...")
41+
start = time()
42+
results = (
43+
await collection.query().vector_recall(query, pipeline).limit(5).fetch_all()
44+
)
45+
end = time()
46+
console.print("\n Results for '%s' " % (query), style="bold")
47+
console.print(results)
48+
console.print("Query time = %0.3f" % (end - start))
49+
50+
# Construct context from results
51+
context = " ".join(results[0][1].strip().split())
52+
context = context.replace('"', '\\"').replace("'", "''")
53+
54+
# Query for summary
55+
builtins = Builtins()
56+
console.print("Querying for summary ...")
57+
start = time()
58+
summary = await builtins.transform(
59+
{"task": "summarization", "model": "sshleifer/distilbart-cnn-12-6"},
60+
[context],
61+
)
62+
end = time()
63+
console.print("Summary '%s'" % summary, style="bold")
64+
console.print("Query time = %0.3f" % (end - start))
65+
66+
# Archive collection
67+
await collection.archive()
68+
69+
70+
if __name__ == "__main__":
71+
asyncio.run(main())

0 commit comments

Comments
 (0)