Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/langchain_google_cloud_sql_mysql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from langchain_google_cloud_sql_mysql.mysql_chat_message_history import (
MySQLChatMessageHistory,
)
from langchain_google_cloud_sql_mysql.mysql_engine import MySQLEngine
from langchain_google_cloud_sql_mysql.mysql_loader import MySQLLoader

__all__ = ["MySQLEngine", "MySQLLoader"]
__all__ = ["MySQLChatMessageHistory", "MySQLEngine", "MySQLLoader"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from typing import List

import sqlalchemy
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import BaseMessage, messages_from_dict

from langchain_google_cloud_sql_mysql.mysql_engine import MySQLEngine


class MySQLChatMessageHistory(BaseChatMessageHistory):
"""Chat message history stored in a Cloud SQL MySQL database."""

def __init__(
self,
engine: MySQLEngine,
session_id: str,
table_name: str = "message_store",
) -> None:
self.engine = engine
self.session_id = session_id
self.table_name = table_name
self._create_table_if_not_exists()

def _create_table_if_not_exists(self) -> None:
create_table_query = f"""CREATE TABLE IF NOT EXISTS {self.table_name} (
id INT AUTO_INCREMENT PRIMARY KEY,
session_id TEXT NOT NULL,
data JSON NOT NULL,
type TEXT NOT NULL
);"""

with self.engine.connect() as conn:
conn.execute(sqlalchemy.text(create_table_query))
conn.commit()

@property
def messages(self) -> List[BaseMessage]: # type: ignore
"""Retrieve the messages from Cloud SQL"""
query = f"SELECT data, type FROM {self.table_name} WHERE session_id = '{self.session_id}' ORDER BY id;"
with self.engine.connect() as conn:
results = conn.execute(sqlalchemy.text(query)).fetchall()
# load SQLAlchemy row objects into dicts
items = [
{"data": json.loads(result[0]), "type": result[1]} for result in results
]
messages = messages_from_dict(items)
return messages

def add_message(self, message: BaseMessage) -> None:
"""Append the message to the record in Cloud SQL"""
query = f"INSERT INTO {self.table_name} (session_id, data, type) VALUES (:session_id, :data, :type);"
with self.engine.connect() as conn:
conn.execute(
sqlalchemy.text(query),
{
"session_id": self.session_id,
"data": json.dumps(message.dict()),
"type": message.type,
},
)
conn.commit()

def clear(self) -> None:
"""Clear session memory from Cloud SQL"""
query = f"DELETE FROM {self.table_name} WHERE session_id = :session_id;"
with self.engine.connect() as conn:
conn.execute(sqlalchemy.text(query), {"session_id": self.session_id})
conn.commit()
58 changes: 58 additions & 0 deletions tests/integration/test_mysql_chat_message_history.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Generator

import pytest
import sqlalchemy
from langchain_core.messages.ai import AIMessage
from langchain_core.messages.human import HumanMessage

from langchain_google_cloud_sql_mysql import MySQLChatMessageHistory, MySQLEngine

project_id = os.environ["PROJECT_ID"]
region = os.environ["REGION"]
instance_id = os.environ["INSTANCE_ID"]
db_name = os.environ["DB_NAME"]


@pytest.fixture(name="memory_engine")
def setup() -> Generator:
engine = MySQLEngine.from_instance(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Documentation should have notes on how to set up authentication

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do you mean "documentation" as in a docstring explaining setup or are you talking about wider documentation around how to run integrations tests and the setup pre-reqs? Will be adding a testing section to CONTRIBUTING.md in a future PR

project_id=project_id, region=region, instance=instance_id, database=db_name
)

yield engine
# use default table for MySQLChatMessageHistory
table_name = "message_store"
with engine.connect() as conn:
conn.execute(sqlalchemy.text(f"DROP TABLE IF EXISTS `{table_name}`"))
conn.commit()


def test_chat_message_history(memory_engine: MySQLEngine) -> None:
history = MySQLChatMessageHistory(engine=memory_engine, session_id="test")
history.add_user_message("hi!")
history.add_ai_message("whats up?")
messages = history.messages

# verify messages are correct
assert messages[0].content == "hi!"
assert type(messages[0]) is HumanMessage
assert messages[1].content == "whats up?"
assert type(messages[1]) is AIMessage

# verify clear() clears message history
history.clear()
assert len(history.messages) == 0