diff --git a/src/langchain_google_cloud_sql_mysql/mysql_chat_message_history.py b/src/langchain_google_cloud_sql_mysql/mysql_chat_message_history.py index 655d8d3..bfc7cb2 100644 --- a/src/langchain_google_cloud_sql_mysql/mysql_chat_message_history.py +++ b/src/langchain_google_cloud_sql_mysql/mysql_chat_message_history.py @@ -22,30 +22,60 @@ class MySQLChatMessageHistory(BaseChatMessageHistory): - """Chat message history stored in a Cloud SQL MySQL database.""" + """Chat message history stored in a Cloud SQL MySQL database. + + Args: + engine (MySQLEngine): SQLAlchemy connection pool engine for managing + connections to Cloud SQL for MySQL. + session_id (str): Arbitrary key that is used to store the messages + of a single chat session. + table_name (str): The name of the table to use for storing/retrieving + the chat message history. + """ def __init__( self, engine: MySQLEngine, session_id: str, - table_name: str = "message_store", + table_name: str, ) -> None: self.engine = engine self.session_id = session_id self.table_name = table_name - self._create_table_if_not_exists() + self._verify_schema() - def _create_table_if_not_exists(self) -> None: - create_table_query = f"""CREATE TABLE IF NOT EXISTS `{self.table_name}` ( - id INT AUTO_INCREMENT PRIMARY KEY, - session_id TEXT NOT NULL, - data JSON NOT NULL, - type TEXT NOT NULL - );""" + def _verify_schema(self) -> None: + """Verify table exists with required schema for MySQLChatMessageHistory class. - with self.engine.connect() as conn: - conn.execute(sqlalchemy.text(create_table_query)) - conn.commit() + Use helper method MySQLEngine.init_chat_history_table(...) to create + table with valid schema. + """ + insp = sqlalchemy.inspect(self.engine.engine) + # check table exists + if insp.has_table(self.table_name): + # check that all required columns are present + required_columns = ["id", "session_id", "data", "type"] + column_names = [ + c["name"] for c in insp.get_columns(table_name=self.table_name) + ] + if not (all(x in column_names for x in required_columns)): + raise IndexError( + f"Table '{self.table_name}' has incorrect schema. Got " + f"column names '{column_names}' but required column names " + f"'{required_columns}'.\nPlease create table with following schema:" + f"\nCREATE TABLE {self.table_name} (" + "\n id INT AUTO_INCREMENT PRIMARY KEY," + "\n session_id TEXT NOT NULL," + "\n data JSON NOT NULL," + "\n type TEXT NOT NULL" + "\n);" + ) + else: + raise AttributeError( + f"Table '{self.table_name}' does not exist. Please create " + "it before initializing MySQLChatMessageHistory. See " + "MySQLEngine.init_chat_history_table() for a helper method." + ) @property def messages(self) -> List[BaseMessage]: # type: ignore diff --git a/src/langchain_google_cloud_sql_mysql/mysql_engine.py b/src/langchain_google_cloud_sql_mysql/mysql_engine.py index 474d60d..8b5f3be 100644 --- a/src/langchain_google_cloud_sql_mysql/mysql_engine.py +++ b/src/langchain_google_cloud_sql_mysql/mysql_engine.py @@ -202,6 +202,33 @@ def connect(self) -> sqlalchemy.engine.Connection: """ return self.engine.connect() + def init_chat_history_table(self, table_name: str) -> None: + """Create table with schema required for MySQLChatMessageHistory class. + + Required schema is as follows: + + CREATE TABLE {table_name} ( + id INT AUTO_INCREMENT PRIMARY KEY, + session_id TEXT NOT NULL, + data JSON NOT NULL, + type TEXT NOT NULL + ) + + Args: + table_name (str): Name of database table to create for storing chat + message history. + """ + create_table_query = f"""CREATE TABLE IF NOT EXISTS `{table_name}` ( + id INT AUTO_INCREMENT PRIMARY KEY, + session_id TEXT NOT NULL, + data JSON NOT NULL, + type TEXT NOT NULL + );""" + + with self.engine.connect() as conn: + conn.execute(sqlalchemy.text(create_table_query)) + conn.commit() + def init_document_table( self, table_name: str, diff --git a/tests/integration/test_mysql_chat_message_history.py b/tests/integration/test_mysql_chat_message_history.py index 44c98c9..c437392 100644 --- a/tests/integration/test_mysql_chat_message_history.py +++ b/tests/integration/test_mysql_chat_message_history.py @@ -25,6 +25,7 @@ region = os.environ["REGION"] instance_id = os.environ["INSTANCE_ID"] db_name = os.environ["DB_NAME"] +table_name = "message_store" @pytest.fixture(name="memory_engine") @@ -33,16 +34,28 @@ def setup() -> Generator: project_id=project_id, region=region, instance=instance_id, database=db_name ) + # create table with malformed schema (missing 'type') + query = """CREATE TABLE malformed_table ( + id INT AUTO_INCREMENT PRIMARY KEY, + session_id TEXT NOT NULL, + data JSON NOT NULL + );""" + with engine.connect() as conn: + conn.execute(sqlalchemy.text(query)) + conn.commit() yield engine # use default table for MySQLChatMessageHistory - table_name = "message_store" with engine.connect() as conn: conn.execute(sqlalchemy.text(f"DROP TABLE IF EXISTS `{table_name}`")) + conn.execute(sqlalchemy.text(f"DROP TABLE IF EXISTS malformed_table")) conn.commit() def test_chat_message_history(memory_engine: MySQLEngine) -> None: - history = MySQLChatMessageHistory(engine=memory_engine, session_id="test") + memory_engine.init_chat_history_table(table_name) + history = MySQLChatMessageHistory( + engine=memory_engine, session_id="test", table_name=table_name + ) history.add_user_message("hi!") history.add_ai_message("whats up?") messages = history.messages @@ -58,21 +71,24 @@ def test_chat_message_history(memory_engine: MySQLEngine) -> None: assert len(history.messages) == 0 -def test_chat_message_history_custom_table_name(memory_engine: MySQLEngine) -> None: - """Test MySQLChatMessageHistory with custom table name""" - history = MySQLChatMessageHistory( - engine=memory_engine, session_id="test", table_name="message-store" - ) - history.add_user_message("hi!") - history.add_ai_message("whats up?") - messages = history.messages +def test_chat_message_history_table_does_not_exist(memory_engine: MySQLEngine) -> None: + """Test that MySQLChatMessageHistory fails if table does not exist.""" + with pytest.raises(AttributeError) as exc_info: + MySQLChatMessageHistory( + engine=memory_engine, session_id="test", table_name="missing_table" + ) + # assert custom error message for missing table + assert ( + exc_info.value.args[0] + == f"Table 'missing_table' does not exist. Please create it before initializing MySQLChatMessageHistory. See MySQLEngine.init_chat_history_table() for a helper method." + ) - # verify messages are correct - assert messages[0].content == "hi!" - assert type(messages[0]) is HumanMessage - assert messages[1].content == "whats up?" - assert type(messages[1]) is AIMessage - # verify clear() clears message history - history.clear() - assert len(history.messages) == 0 +def test_chat_message_history_table_malformed_schema( + memory_engine: MySQLEngine, +) -> None: + """Test that MySQLChatMessageHistory fails if schema is malformed.""" + with pytest.raises(IndexError): + MySQLChatMessageHistory( + engine=memory_engine, session_id="test", table_name="malformed_table" + )