"""LangGraph code-generation agent.

Generates a code solution with an LLM, checks that the generated imports and
code actually execute, and loops back to regenerate (up to ``max_iterations``
times) when a check fails.
"""

import os
import time
import uuid
from operator import itemgetter
from typing import Annotated, Dict, List, TypedDict

from IPython.display import Image, display
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.runnables import RunnablePassthrough
from langchain_groq import ChatGroq
from langgraph.checkpoint.sqlite import SqliteSaver
from langgraph.graph import END, StateGraph
from langgraph.graph.message import AnyMessage, add_messages

# Set environment variables
os.environ['TOKENIZERS_PARALLELISM'] = 'true'

# Maximum number of generate/check cycles before giving up.  BUG FIX: this was
# previously defined far below the graph definition and only worked because
# `decide_to_finish` reads it lazily at call time; define it up front so the
# dependency is explicit.
max_iterations = 5

# Set up the LLM.  BUG FIX: the API key was hard-coded as the placeholder
# string "groq_api"; read it from the environment instead (falling back to the
# old placeholder so behaviour is unchanged when the variable is unset).
llm = ChatGroq(
    temperature=0,
    groq_api_key=os.getenv("GROQ_API_KEY", "groq_api"),
    model_name="llama3-8b-8192",
)

# Define the prompt template
code_gen_prompt_claude = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            """You are a coding assistant. Ensure any code you provide can be executed with all required imports and variables defined. Structure your answer: 1) a prefix describing the code solution, 2) the imports, 3) the functioning code block. \n Here is the user question:""",
        ),
        ("placeholder", "{messages}"),
    ]
)


# Define the data model for structured LLM output.
class code(BaseModel):
    """Schema for code solutions to questions about LCEL."""

    prefix: str = Field(description="Description of the problem and approach")
    imports: str = Field(description="Code block import statements")
    code: str = Field(description="Code block not including import statements")


# Set up the structured output
code_gen_chain = llm.with_structured_output(code, include_raw=False)


# Define the graph state
class GraphState(TypedDict):
    """
    Represents the state of our graph.

    Attributes:
        error : Binary flag ("yes"/"no") for control flow to indicate whether test error was tripped
        messages : With user question, error messages, reasoning
        generation : Code solution (a `code` instance)
        iterations : Number of tries
    """

    error: str
    messages: Annotated[list[AnyMessage], add_messages]
    generation: str
    iterations: int


# Define the nodes
def generate(state: GraphState):
    """
    Generate a code solution

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): New key added to state, generation
    """
    print("---GENERATING CODE SOLUTION---")

    # State.  BUG FIX: the original also read state["error"] here, which
    # raised KeyError on the first pass because the initial input to
    # graph.stream() contains no "error" key; the value was never used, so
    # the read is simply dropped.
    messages = state["messages"]
    iterations = state["iterations"]

    # Solution
    code_solution = code_gen_chain.invoke(messages)
    messages += [
        (
            "assistant",
            f"Here is my attempt to solve the problem: {code_solution.prefix} \n Imports: {code_solution.imports} \n Code: {code_solution.code}",
        )
    ]

    # Increment
    iterations = iterations + 1

    # Add delay to reduce API requests
    time.sleep(1)  # Wait for 1 second

    return {"generation": code_solution, "messages": messages, "iterations": iterations}


def code_check(state: GraphState):
    """
    Check code

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): New key added to state, error
    """
    print("---CHECKING CODE---")

    # State
    messages = state["messages"]
    code_solution = state["generation"]
    iterations = state["iterations"]

    # Get solution components.  Renamed the last local from `code` to
    # `code_block` so it no longer shadows the pydantic `code` class.
    prefix = code_solution.prefix
    imports = code_solution.imports
    code_block = code_solution.code

    # Check imports.
    # SECURITY NOTE: exec() runs arbitrary LLM-generated code in-process.
    # That is the purpose of this checker, but it should only ever be run in
    # a sandboxed/disposable environment.
    try:
        exec(imports)
    except Exception as e:
        print("---CODE IMPORT CHECK: FAILED---")
        error_message = [("user", f"Your solution failed the import test. Here is the error: {e}. Reflect on this error and your prior attempt to solve the problem. (1) State what you think went wrong with the prior solution and (2) try to solve this problem again. Return the FULL SOLUTION. Use the code tool to structure the output with a prefix, imports, and code block:")]
        messages += error_message
        return {
            "generation": code_solution,
            "messages": messages,
            "iterations": iterations,
            "error": "yes",
        }

    # Check execution
    try:
        combined_code = f"{imports}\n{code_block}"
        # Use a shared scope for exec
        global_scope = {}
        exec(combined_code, global_scope)
    except Exception as e:
        print("---CODE BLOCK CHECK: FAILED---")
        # BUG FIX: the original message contained a stray ")" after the
        # exception text ("{e})"); replaced with a period.
        error_message = [("user", f"Your solution failed the code execution test: {e}. Reflect on this error and your prior attempt to solve the problem. (1) State what you think went wrong with the prior solution and (2) try to solve this problem again. Return the FULL SOLUTION. Use the code tool to structure the output with a prefix, imports, and code block:")]
        messages += error_message
        return {
            "generation": code_solution,
            "messages": messages,
            "iterations": iterations,
            "error": "yes",
        }

    # No errors
    print("---NO CODE TEST FAILURES---")
    return {
        "generation": code_solution,
        "messages": messages,
        "iterations": iterations,
        "error": "no",
    }


def decide_to_finish(state: GraphState):
    """
    Determines whether to finish.

    Args:
        state (dict): The current graph state

    Returns:
        str: Next node to call
    """
    error = state["error"]
    iterations = state["iterations"]

    # ">=" instead of "==" so the loop still terminates even if the iteration
    # counter ever skips past the limit.
    if error == "no" or iterations >= max_iterations:
        print("---DECISION: FINISH---")
        return "end"
    else:
        print("---DECISION: RE-TRY SOLUTION---")
        return "generate"


# Define the graph
builder = StateGraph(GraphState)

# Add nodes
builder.add_node("generate", generate)  # generation solution
builder.add_node("check_code", code_check)  # check code

# Build graph
builder.set_entry_point("generate")
builder.add_edge("generate", "check_code")
builder.add_conditional_edges(
    "check_code",
    decide_to_finish,
    {
        "end": END,
        "generate": "generate",
    },
)

# Compile the graph
memory = SqliteSaver.from_conn_string(":memory:")
graph = builder.compile(checkpointer=memory)

# Display the graph (best-effort: drawing requires optional dependencies).
# BUG FIX: narrowed the bare "except:" so KeyboardInterrupt/SystemExit are
# not swallowed.
try:
    display(Image(graph.get_graph(xray=True).draw_mermaid_png()))
except Exception:
    pass

# Run the graph
_printed = set()
thread_id = str(uuid.uuid4())
config = {
    "configurable": {
        # Checkpoints are accessed by thread_id
        "thread_id": thread_id,
    }
}

# Ask user for input
question = input("Enter your question or search query: ")

# Run the graph.  BUG FIX: seed "error" so every GraphState key exists from
# the first node invocation.
events = graph.stream(
    {"messages": [("user", question)], "iterations": 0, "error": ""},
    config,
    stream_mode="values",
)


def _print_event(event, _printed):
    # Print each streamed state snapshot once (dedup by string form).
    if str(event) not in _printed:
        print(event)
        _printed.add(str(event))


# BUG FIX: guard against an empty stream, in which case "event" would be
# undefined at the final print below.
event = None
for event in events:
    _print_event(event, _printed)

# Output the final result
print("Final Result:")
if event is not None:
    print(event['generation'])