-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathkendra_chat_flan_xl_nb.py
154 lines (131 loc) · 5.09 KB
/
kendra_chat_flan_xl_nb.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
# pylint: disable=invalid-name,line-too-long
"""
Adapted from
https://github.com/aws-samples/amazon-kendra-langchain-extensions/blob/main/kendra_retriever_samples/kendra_chat_flan_xl.py
"""
import json
import os
from langchain import SagemakerEndpoint
from langchain.chains import ConversationalRetrievalChain
from langchain.llms.sagemaker_endpoint import LLMContentHandler
from langchain.prompts import PromptTemplate
from langchain.retrievers import AmazonKendraRetriever
class bcolors:  # pylint: disable=too-few-public-methods
    """
    Namespace of ANSI terminal escape sequences used to colour console output.

    Wrap text as ``f"{bcolors.OKGREEN}text{bcolors.ENDC}"``; ``ENDC`` resets
    the terminal back to its default attributes afterwards.
    Codes taken from:
    https://stackoverflow.com/questions/287871/how-do-i-print-colored-text-to-the-terminal
    """

    # Foreground colours
    HEADER = "\033[95m"
    OKBLUE = "\033[94m"
    OKCYAN = "\033[96m"
    OKGREEN = "\033[92m"
    WARNING = "\033[93m"
    FAIL = "\033[91m"

    # Reset and text attributes
    ENDC = "\033[0m"
    BOLD = "\033[1m"
    UNDERLINE = "\033[4m"


# Upper bound on the number of (question, answer) turns kept in the chat
# history; the interactive loop evicts the oldest turn once it is reached.
MAX_HISTORY_LENGTH = 5
def build_chain():
    """
    Build the conversational retrieval (Q&A) chain.

    Configuration is read from environment variables:
      - AWS_REGION: AWS region used for both the Kendra index and the
        SageMaker endpoint.
      - KENDRA_INDEX_ID: ID of the Kendra index documents are retrieved from.
      - FLAN_XL_ENDPOINT: name of the SageMaker endpoint hosting the LLM
        (presumably a FLAN XL model, per the variable name).

    Returns:
        A ``ConversationalRetrievalChain`` that rephrases follow-up questions
        into standalone ones using the chat history, retrieves context from
        Kendra, and answers with the SageMaker-hosted LLM. Source documents
        are returned alongside each answer.

    Raises:
        KeyError: if any of the environment variables above is not set.
    """
    region = os.environ["AWS_REGION"]
    kendra_index_id = os.environ["KENDRA_INDEX_ID"]
    endpoint_name = os.environ["FLAN_XL_ENDPOINT"]

    class ContentHandler(LLMContentHandler):
        """
        Handler class to transform input and output
        into a format that the SageMaker Endpoint can understand.
        """
        content_type = "application/json"
        accepts = "application/json"

        def transform_input(self, prompt: str, model_kwargs: dict) -> bytes:
            # Request body: {"text_inputs": <prompt>, **model_kwargs}, UTF-8 encoded.
            input_str = json.dumps({"text_inputs": prompt, **model_kwargs})
            return input_str.encode('utf-8')

        def transform_output(self, output: bytes) -> str:
            # Despite the annotation, `output` is read like a stream (.read());
            # only the first entry of "generated_texts" is returned.
            response_json = json.loads(output.read().decode("utf-8"))
            return response_json["generated_texts"][0]

    content_handler = ContentHandler()

    # Initialize LLM hosted on a SageMaker endpoint
    # https://python.langchain.com/en/latest/modules/models/llms/integrations/sagemaker.html
    # Use the same configured region as the Kendra retriever instead of a
    # hard-coded "us-east-1", so deployments in other regions work.
    llm = SagemakerEndpoint(
        endpoint_name=endpoint_name,
        region_name=region,
        # Temperature is tiny rather than 0 -- presumably to get near-deterministic
        # output while staying within what the endpoint accepts; TODO confirm.
        model_kwargs={"temperature": 1e-10, "max_length": 500},
        content_handler=content_handler
    )

    # Initialize Kendra index retriever
    retriever = AmazonKendraRetriever(
        index_id=kendra_index_id,
        region_name=region
    )

    # Define prompt template used to answer from the retrieved documents
    # https://python.langchain.com/en/latest/modules/prompts/prompt_templates.html
    prompt_template = """
The following is a friendly conversation between a human and an AI.
The AI is talkative and provides lots of specific details from its context.
If the AI does not know the answer to a question, it truthfully says it
does not know.
{context}
Instruction: Based on the above documents, provide a detailed answer for,
{question} Answer "don't know" if not present in the document. Solution:
"""
    qa_prompt = PromptTemplate(
        template=prompt_template, input_variables=["context", "question"]
    )

    # Prompt that rewrites a follow-up question into a standalone one
    # using the conversation so far.
    condense_qa_template = """
Given the following conversation and a follow up question, rephrase the follow up question
to be a standalone question.
Chat History:
{chat_history}
Follow Up Input: {question}
Standalone question:"""
    standalone_question_prompt = PromptTemplate.from_template(condense_qa_template)

    # Initialize QA chain with chat history
    # https://python.langchain.com/en/latest/modules/chains/index_examples/chat_vector_db.html
    qa = ConversationalRetrievalChain.from_llm(
        llm=llm,
        retriever=retriever,
        condense_question_prompt=standalone_question_prompt,
        return_source_documents=True,
        combine_docs_chain_kwargs={"prompt": qa_prompt}
    )
    return qa
def run_chain(chain, prompt: str, history=None):
    """
    Ask ``chain`` one question in the context of the conversation so far.

    Args:
        chain: Callable chain that accepts a dict with "question" and
            "chat_history" keys.
        prompt: The user's question.
        history: Previous turns to pass as chat history; ``None`` means an
            empty history. A list passed in is forwarded as-is, not copied.

    Returns:
        Whatever ``chain`` returns for the assembled inputs.
    """
    chain_inputs = {
        "question": prompt,
        "chat_history": [] if history is None else history,
    }
    return chain(chain_inputs)
def prompt_user():
    """
    Greet the user, show the available commands, and read one line of input.

    Returns:
        The line typed by the user, as returned by ``input``.
    """
    greeting = f"{bcolors.OKBLUE}Hello! How can I help you?{bcolors.ENDC}"
    usage = f"{bcolors.OKCYAN}Ask a question, start a New search: or Stop cell execution to exit.{bcolors.ENDC}"
    for banner_line in (greeting, usage):
        print(banner_line)
    return input(">")
# Command prefix (matched case-insensitively) that starts a fresh conversation.
NEW_SEARCH_PREFIX = "new search:"


def split_new_search(query: str) -> tuple:
    """
    Detect the "new search:" command at the start of a user query.

    The prefix is matched case-insensitively after surrounding whitespace is
    stripped. The remaining question keeps its original casing, and a
    "new search:" appearing later in the text is left untouched.

    Args:
        query: Raw line typed by the user.

    Returns:
        ``(True, question)`` when the query starts with the prefix, where
        ``question`` is the text after the prefix with whitespace stripped
        (possibly empty); otherwise ``(False, query)`` with the query unchanged.
    """
    stripped = query.strip()
    prefix_len = len(NEW_SEARCH_PREFIX)
    if stripped[:prefix_len].lower() == NEW_SEARCH_PREFIX:
        return True, stripped[prefix_len:].strip()
    return False, query


if __name__ == "__main__":
    # Conversation turns as (question, answer) tuples fed back to the chain
    chat_history = []
    # Initialize Q&A chain once; it is reused for every question
    qa_chain = build_chain()
    try:
        # An empty line ends the session
        while query := prompt_user():
            # Process user input in case of a new search
            is_new_search, query = split_new_search(query)
            if is_new_search:
                chat_history = []
                if not query:
                    # A bare "new search:" only resets the conversation
                    continue
            # Keep the history bounded; evict the oldest turn first
            while len(chat_history) >= MAX_HISTORY_LENGTH:
                chat_history.pop(0)
            # Show answer and keep a record
            result = run_chain(qa_chain, query, chat_history)
            answer = result["answer"]
            chat_history.append((query, answer))
            print(f"{bcolors.OKGREEN}{answer}{bcolors.ENDC}")
            # Show sources (printed in green, then the colour is reset)
            if 'source_documents' in result:
                print(f"{bcolors.OKGREEN}Sources:")
                for doc in result['source_documents']:
                    # Tolerate documents without a "source" entry instead of
                    # crashing the whole session.
                    print(f"+ {doc.metadata.get('source', '<unknown source>')}")
                print(bcolors.ENDC, end="")
    except (KeyboardInterrupt, EOFError):
        # Stopping the cell (Ctrl-C) or closing stdin ends the session quietly.
        pass