Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,30 +22,60 @@


class MySQLChatMessageHistory(BaseChatMessageHistory):
"""Chat message history stored in a Cloud SQL MySQL database."""
"""Chat message history stored in a Cloud SQL MySQL database.

Args:
engine (MySQLEngine): SQLAlchemy connection pool engine for managing
connections to Cloud SQL for MySQL.
session_id (str): Arbitrary key that is used to store the messages
of a single chat session.
table_name (str): The name of the table to use for storing/retrieving
the chat message history.
"""

def __init__(
self,
engine: MySQLEngine,
session_id: str,
table_name: str = "message_store",
table_name: str,
) -> None:
self.engine = engine
self.session_id = session_id
self.table_name = table_name
self._create_table_if_not_exists()
self._verify_schema()

def _create_table_if_not_exists(self) -> None:
create_table_query = f"""CREATE TABLE IF NOT EXISTS `{self.table_name}` (
id INT AUTO_INCREMENT PRIMARY KEY,
session_id TEXT NOT NULL,
data JSON NOT NULL,
type TEXT NOT NULL
);"""
def _verify_schema(self) -> None:
"""Verify table exists with required schema for MySQLChatMessageHistory class.

with self.engine.connect() as conn:
conn.execute(sqlalchemy.text(create_table_query))
conn.commit()
Use helper method MySQLEngine.init_chat_history_table(...) to create
table with valid schema.
"""
insp = sqlalchemy.inspect(self.engine.engine)
# check table exists
if insp.has_table(self.table_name):
# check that all required columns are present
required_columns = ["id", "session_id", "data", "type"]
column_names = [
c["name"] for c in insp.get_columns(table_name=self.table_name)
]
if not (all(x in column_names for x in required_columns)):
raise IndexError(
f"Table '{self.table_name}' has incorrect schema. Got "
f"column names '{column_names}' but required column names "
f"'{required_columns}'.\nPlease create table with following schema:"
f"\nCREATE TABLE {self.table_name} ("
"\n id INT AUTO_INCREMENT PRIMARY KEY,"
"\n session_id TEXT NOT NULL,"
"\n data JSON NOT NULL,"
"\n type TEXT NOT NULL"
"\n);"
)
else:
raise AttributeError(
f"Table '{self.table_name}' does not exist. Please create "
"it before initializing MySQLChatMessageHistory. See "
"MySQLEngine.init_chat_history_table() for a helper method."
)

@property
def messages(self) -> List[BaseMessage]: # type: ignore
Expand Down
27 changes: 27 additions & 0 deletions src/langchain_google_cloud_sql_mysql/mysql_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,33 @@ def connect(self) -> sqlalchemy.engine.Connection:
"""
return self.engine.connect()

def init_chat_history_table(self, table_name: str) -> None:
"""Create table with schema required for MySQLChatMessageHistory class.

Required schema is as follows:

CREATE TABLE {table_name} (
id INT AUTO_INCREMENT PRIMARY KEY,
session_id TEXT NOT NULL,
data JSON NOT NULL,
type TEXT NOT NULL
)

Args:
table_name (str): Name of database table to create for storing chat
message history.
"""
create_table_query = f"""CREATE TABLE IF NOT EXISTS `{table_name}` (
id INT AUTO_INCREMENT PRIMARY KEY,
session_id TEXT NOT NULL,
data JSON NOT NULL,
type TEXT NOT NULL
);"""

with self.engine.connect() as conn:
conn.execute(sqlalchemy.text(create_table_query))
conn.commit()

def init_document_table(
self,
table_name: str,
Expand Down
52 changes: 34 additions & 18 deletions tests/integration/test_mysql_chat_message_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
region = os.environ["REGION"]
instance_id = os.environ["INSTANCE_ID"]
db_name = os.environ["DB_NAME"]
table_name = "message_store"


@pytest.fixture(name="memory_engine")
Expand All @@ -33,16 +34,28 @@ def setup() -> Generator:
project_id=project_id, region=region, instance=instance_id, database=db_name
)

# create table with malformed schema (missing 'type')
query = """CREATE TABLE malformed_table (
id INT AUTO_INCREMENT PRIMARY KEY,
session_id TEXT NOT NULL,
data JSON NOT NULL
);"""
with engine.connect() as conn:
conn.execute(sqlalchemy.text(query))
conn.commit()
yield engine
# use default table for MySQLChatMessageHistory
table_name = "message_store"
with engine.connect() as conn:
conn.execute(sqlalchemy.text(f"DROP TABLE IF EXISTS `{table_name}`"))
conn.execute(sqlalchemy.text(f"DROP TABLE IF EXISTS malformed_table"))
conn.commit()


def test_chat_message_history(memory_engine: MySQLEngine) -> None:
history = MySQLChatMessageHistory(engine=memory_engine, session_id="test")
memory_engine.init_chat_history_table(table_name)
history = MySQLChatMessageHistory(
engine=memory_engine, session_id="test", table_name=table_name
)
history.add_user_message("hi!")
history.add_ai_message("whats up?")
messages = history.messages
Expand All @@ -58,21 +71,24 @@ def test_chat_message_history(memory_engine: MySQLEngine) -> None:
assert len(history.messages) == 0


def test_chat_message_history_custom_table_name(memory_engine: MySQLEngine) -> None:
"""Test MySQLChatMessageHistory with custom table name"""
history = MySQLChatMessageHistory(
engine=memory_engine, session_id="test", table_name="message-store"
)
history.add_user_message("hi!")
history.add_ai_message("whats up?")
messages = history.messages
def test_chat_message_history_table_does_not_exist(memory_engine: MySQLEngine) -> None:
"""Test that MySQLChatMessageHistory fails if table does not exist."""
with pytest.raises(AttributeError) as exc_info:
MySQLChatMessageHistory(
engine=memory_engine, session_id="test", table_name="missing_table"
)
# assert custom error message for missing table
assert (
exc_info.value.args[0]
== f"Table 'missing_table' does not exist. Please create it before initializing MySQLChatMessageHistory. See MySQLEngine.init_chat_history_table() for a helper method."
)

# verify messages are correct
assert messages[0].content == "hi!"
assert type(messages[0]) is HumanMessage
assert messages[1].content == "whats up?"
assert type(messages[1]) is AIMessage

# verify clear() clears message history
history.clear()
assert len(history.messages) == 0
def test_chat_message_history_table_malformed_schema(
memory_engine: MySQLEngine,
) -> None:
"""Test that MySQLChatMessageHistory fails if schema is malformed."""
with pytest.raises(IndexError):
MySQLChatMessageHistory(
engine=memory_engine, session_id="test", table_name="malformed_table"
)