diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..6d30603 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,32 @@ +# Python 3.10 slim 이미지 기반 +FROM python:3.12-slim + +# 시스템 라이브러리 설치 +RUN apt-get update && apt-get install -y \ + build-essential \ + curl \ + software-properties-common \ + git \ + libpq-dev \ + && rm -rf /var/lib/apt/lists/* + +# 작업 디렉토리 설정 +WORKDIR /app + +# 의존성 파일 복사 및 설치 +COPY requirements.txt . +RUN pip install --no-cache-dir -r requirements.txt + +# 전체 서비스 코드 복사 +COPY . . + +# Python 환경 설정 +ENV PYTHONPATH=/app +ENV PYTHONUNBUFFERED=1 + +# Streamlit 포트 노출 +EXPOSE 8501 + +# Streamlit 실행 명령 +CMD ["python", "-c", "from llm_utils.tools import set_gms_server; import os; set_gms_server(os.getenv('DATAHUB_SERVER', 'http://localhost:8080'))"] +CMD ["streamlit", "run", "./interface/streamlit_app.py", "--server.port=8501"] \ No newline at end of file diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..097e87f --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,28 @@ +services: + streamlit: + build: . + ports: + - "8501:8501" + volumes: + - .:/app + env_file: + - .env + environment: + - DATABASE_URL=postgresql://postgres:password@db:5432/streamlit_db + depends_on: + - db + + db: + image: pgvector/pgvector:pg17 + container_name: pgvector-db + environment: + POSTGRES_USER: postgres + POSTGRES_PASSWORD: password + POSTGRES_DB: streamlit_db + ports: + - "5432:5432" + volumes: + - pgdata:/var/lib/postgresql/data + - ./postgres/schema.sql:/docker-entrypoint-initdb.d/schema.sql +volumes: + pgdata: \ No newline at end of file diff --git a/evaluation/gen_persona.py b/evaluation/gen_persona.py index 4884d63..febca07 100644 --- a/evaluation/gen_persona.py +++ b/evaluation/gen_persona.py @@ -1,5 +1,3 @@ -import os - from utils import save_persona_json, pretty_print_persona from persona_class import PersonaList diff --git a/llm_utils/agent.py b/llm_utils/agent.py new file mode 100644 index 0000000..895e534 --- /dev/null +++ b/llm_utils/agent.py @@ -0,0 +1,66 @@ +from langchain_core.messages import AIMessage +from langchain_core.output_parsers import JsonOutputParser +from langchain_core.messages import SystemMessage +from .state import QueryMakerState +from .llm_factory import get_llm +from prompt.template_loader import get_prompt_template + +llm = get_llm() + +# JSON 스키마 정의 +main_agent_schema = { + "type": "object", + "properties": { + "intent": {"type": "string", "description": "유저의 의도 (search, end 등)"}, + "user_input": {"type": "string", "description": "유저의 입력"}, + "intent_reason": {"type": "string", "description": "유저의 의도 파악 이유"}, + }, + "required": ["intent", "user_input"], +} +main_agent_parser = JsonOutputParser(schema=main_agent_schema) + + +def manager_agent(state: QueryMakerState) -> dict: + """ + 가장 처음 시작하는 agent로 질문의 유무를 판단해서 적절한 Agent를 호출합니다. + 추후, 가드레일 등에 detecting될 경우에도 해당 노드를 통해 대응이 가능합니다 + """ + manager_agent_prompt = get_prompt_template("manager_agent_prompt") + messages = [SystemMessage(content=manager_agent_prompt), state["messages"][-1]] + response = llm.invoke(messages) + + try: + parsed_output = main_agent_parser.parse(response.content) + state.update( + { + "messages": state["messages"] + [response], # 기록용 + "intent": parsed_output.get("intent", "end"), # 분기용 + "user_input": parsed_output.get( + "user_input", state["messages"][-1].content + ), # SQL 쿼리 변환 대상 질문 + "intent_reason": parsed_output.get("intent_reason", ""), # 분기 이유 + } + ) + return state + + except Exception as e: + print(f"<>") + state.update( + { + "messages": state["messages"] + [AIMessage(content=response.content)], + "intent": "end", + "intent_reason": response.content, + } + ) + return state + + +def manager_agent_edge(state: QueryMakerState) -> str: + """ + Condition for main_agent + """ + print("=== In condition: main_edge ===") + if state.get("intent") == "make_query": + return "make_query" + else: + return "end" # end 시 최종 출력 값 반환 diff --git a/llm_utils/graph.py b/llm_utils/graph.py index 69a10b9..1f14d24 100644 --- a/llm_utils/graph.py +++ b/llm_utils/graph.py @@ -1,20 +1,17 @@ -import os import json -from typing_extensions import TypedDict, Annotated from langgraph.graph import END, StateGraph -from langgraph.graph.message import add_messages from langchain.chains.sql_database.prompt import SQL_PROMPTS from pydantic import BaseModel, Field +from .agent import manager_agent, manager_agent_edge from .llm_factory import get_llm - from llm_utils.chains import ( query_refiner_chain, query_maker_chain, ) - -from llm_utils.tools import get_info_from_db from llm_utils.retrieval import search_tables +from langchain.schema import AIMessage +from .state import QueryMakerState # 노드 식별자 정의 QUERY_REFINER = "query_refiner" @@ -22,23 +19,26 @@ TOOL = "tool" TABLE_FILTER = "table_filter" QUERY_MAKER = "query_maker" +MANAGER_AGENT = "manager_agent" +EXCEPTION_END_NODE = "exception_end_node" -# 상태 타입 정의 (추가 상태 정보와 메시지들을 포함) -class QueryMakerState(TypedDict): - messages: Annotated[list, add_messages] - user_database_env: str - searched_tables: dict[str, dict[str, str]] - best_practice_query: str - refined_input: str - generated_query: str - retriever_name: str - top_n: int - device: str +def exception_end_node(state: QueryMakerState): + intent_reason = state.get("intent_reason", "SQL 쿼리 생성을 위한 질문을 해주세요") + end_message_prompt = f""" +다음과 같은 이유로 답변을 할 수 없습니다! +``` +{intent_reason} +``` +""" + return { + "messages": state["messages"] + [AIMessage(content=end_message_prompt)], + } # 노드 함수: QUERY_REFINER 노드 def query_refiner_node(state: QueryMakerState): + # refined_node의 결과값으로 바로 AIMessages 반환 res = query_refiner_chain.invoke( input={ "user_input": [state["messages"][0].content], @@ -55,7 +55,7 @@ def query_refiner_node(state: QueryMakerState): def get_table_info_node(state: QueryMakerState): # retriever_name과 top_n을 이용하여 검색 수행 documents_dict = search_tables( - query=state["messages"][0].content, + query=state["user_input"], retriever_name=state["retriever_name"], top_n=state["top_n"], device=state["device"], @@ -67,9 +67,10 @@ def get_table_info_node(state: QueryMakerState): # 노드 함수: QUERY_MAKER 노드 def query_maker_node(state: QueryMakerState): + # sturctured output 사용 res = query_maker_chain.invoke( input={ - "user_input": [state["messages"][0].content], + "user_input": [state["user_input"]], "refined_input": [state["refined_input"]], "searched_tables": [json.dumps(state["searched_tables"])], "user_database_env": [state["user_database_env"]], @@ -105,19 +106,33 @@ def query_maker_node_with_db_guide(state: QueryMakerState): # StateGraph 생성 및 구성 builder = StateGraph(QueryMakerState) -builder.set_entry_point(GET_TABLE_INFO) - # 노드 추가 +builder.add_node(MANAGER_AGENT, manager_agent) builder.add_node(GET_TABLE_INFO, get_table_info_node) builder.add_node(QUERY_REFINER, query_refiner_node) builder.add_node(QUERY_MAKER, query_maker_node) # query_maker_node_with_db_guide +builder.add_node(EXCEPTION_END_NODE, exception_end_node) # builder.add_node( # QUERY_MAKER, query_maker_node_with_db_guide # ) # query_maker_node_with_db_guide # 기본 엣지 설정 +builder.set_entry_point(MANAGER_AGENT) builder.add_edge(GET_TABLE_INFO, QUERY_REFINER) builder.add_edge(QUERY_REFINER, QUERY_MAKER) +# 조건부 엣지 +builder.add_conditional_edges( + MANAGER_AGENT, + manager_agent_edge, + { + "end": EXCEPTION_END_NODE, + "make_query": GET_TABLE_INFO, + }, +) + # QUERY_MAKER 노드 후 종료 builder.add_edge(QUERY_MAKER, END) + +# EXCEPTION_END_NODE 노드 후 종료 +builder.add_edge(EXCEPTION_END_NODE, END) diff --git a/llm_utils/llm_factory.py b/llm_utils/llm_factory.py index bdb4d64..4d1967f 100644 --- a/llm_utils/llm_factory.py +++ b/llm_utils/llm_factory.py @@ -18,7 +18,6 @@ AzureChatOpenAI, OpenAIEmbeddings, ) -from langchain_community.llms.bedrock import Bedrock # .env 파일 로딩 load_dotenv() diff --git a/llm_utils/retrieval.py b/llm_utils/retrieval.py index 728141f..27e8131 100644 --- a/llm_utils/retrieval.py +++ b/llm_utils/retrieval.py @@ -18,7 +18,7 @@ def get_vector_db(): embeddings, allow_dangerous_deserialization=True, ) - except: + except Exception: documents = get_info_from_db() db = FAISS.from_documents(documents, embeddings) db.save_local(os.getcwd() + "/table_info_db") diff --git a/llm_utils/state.py b/llm_utils/state.py new file mode 100644 index 0000000..3d30e19 --- /dev/null +++ b/llm_utils/state.py @@ -0,0 +1,18 @@ +from typing_extensions import TypedDict, Annotated +from langgraph.graph.message import add_messages + + +# 상태 타입 정의 (추가 상태 정보와 메시지들을 포함) +class QueryMakerState(TypedDict): + messages: Annotated[list, add_messages] + user_database_env: str + searched_tables: dict[str, dict[str, str]] + best_practice_query: str + refined_input: str + generated_query: str + retriever_name: str + top_n: int + device: str + intent: str + intent_reason: str + user_input: str diff --git a/llm_utils/tools.py b/llm_utils/tools.py index 31c2a09..a5b10a7 100644 --- a/llm_utils/tools.py +++ b/llm_utils/tools.py @@ -1,5 +1,5 @@ import os -from typing import List, Dict, Optional, TypeVar, Callable, Iterable, Any +from typing import List, Dict, Optional, TypeVar, Callable, Iterable from langchain.schema import Document @@ -40,7 +40,7 @@ def parallel_process[T, R]( def set_gms_server(gms_server: str): try: os.environ["DATAHUB_SERVER"] = gms_server - fetcher = DatahubMetadataFetcher(gms_server=gms_server) + DatahubMetadataFetcher(gms_server=gms_server) except ValueError as e: raise ValueError(f"GMS 서버 설정 실패: {str(e)}") diff --git a/prompt/manager_agent_prompt.md b/prompt/manager_agent_prompt.md new file mode 100644 index 0000000..a7e1e55 --- /dev/null +++ b/prompt/manager_agent_prompt.md @@ -0,0 +1,20 @@ +# Role + +사용자의 질문을 기반으로, 주어진 테이블과 컬럼 정보를 활용하여 적절한 SQL 쿼리를 생성하는 데 도움을 주는 Agent입니다. +사용자의 입력을 분석하여 다음 중에서 의도(intent)를 파악하고, 만약 SQL 쿼리나 주어진 DB에 관련된 질문이 아닐 경우, end node를 반환합니다. + +1. end: SQL 관련 요청이 아닐 경우 (예: "오늘 날씨가 어때?", "Postgres에 대해 알려줘") +2. make_query: DB검색 관련 요청 + +agent는 key로 관리되며 다음과 같은 키가 있습니다. +intent가 make_query일 경우, 아래와 같은 형태로 결과를 반환해주세요. +응답 형식: + { + "intent": "<의도>", + "user_input": "<사용자의 입력>", + "intent_reason": "<의도 파악 이유>" + } +# 주의사항 +- 서술형 대답이 필요한 경우, Intent에 대한 reason만 입력해주세요. +- Intent Reason은 최대한 쉽고 간단하게 입력하되, 존댓말을 쓰세요. +- Json 형식으로만 응답을 해야하며, 추가적인 정보나 문구를 입력하지 마세요.