Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/langchain_google_cloud_sql_mssql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from langchain_google_cloud_sql_mssql.mssql_chat_message_history import (
MSSQLChatMessageHistory,
)
from langchain_google_cloud_sql_mssql.mssql_engine import MSSQLEngine
from langchain_google_cloud_sql_mssql.mssql_loader import MSSQLLoader

__all__ = ["MSSQLEngine", "MSSQLLoader"]
__all__ = ["MSSQLChatMessageHistory", "MSSQLEngine", "MSSQLLoader"]
112 changes: 112 additions & 0 deletions src/langchain_google_cloud_sql_mssql/mssql_chat_message_history.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from typing import List

import sqlalchemy
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import BaseMessage, messages_from_dict

from langchain_google_cloud_sql_mssql.mssql_engine import MSSQLEngine


class MSSQLChatMessageHistory(BaseChatMessageHistory):
"""Chat message history stored in a Cloud SQL MSSQL database.

Args:
engine (MSSQLEngine): SQLAlchemy connection pool engine for managing
connections to Cloud SQL for SQL Server.
session_id (str): Arbitrary key that is used to store the messages
of a single chat session.
table_name (str): The name of the table to use for storing/retrieving
the chat message history.
"""

def __init__(
self,
engine: MSSQLEngine,
session_id: str,
table_name: str,
) -> None:
self.engine = engine
self.session_id = session_id
self.table_name = table_name
self._verify_schema()

def _verify_schema(self) -> None:
"""Verify table exists with required schema for MSSQLChatMessageHistory class.

Use helper method MSSQLEngine.create_chat_history_table(...) to create
table with valid schema.
"""
insp = sqlalchemy.inspect(self.engine.engine)
# check table exists
if insp.has_table(self.table_name):
# check that all required columns are present
required_columns = ["id", "session_id", "data", "type"]
column_names = [
c["name"] for c in insp.get_columns(table_name=self.table_name)
]
if not (all(x in column_names for x in required_columns)):
raise IndexError(
f"Table '{self.table_name}' has incorrect schema. Got "
f"column names '{column_names}' but required column names "
f"'{required_columns}'.\nPlease create table with following schema:"
f"\nCREATE TABLE {self.table_name} ("
"\n id INT IDENTITY(1,1) PRIMARY KEY,"
"\n session_id NVARCHAR(MAX) NOT NULL,"
"\n data NVARCHAR(MAX) NOT NULL,"
"\n type NVARCHAR(MAX) NOT NULL"
"\n);"
)
else:
raise AttributeError(
f"Table '{self.table_name}' does not exist. Please create "
"it before initializing MSSQLChatMessageHistory. See "
"MSSQLEngine.create_chat_history_table() for a helper method."
)

@property
def messages(self) -> List[BaseMessage]: # type: ignore
"""Retrieve the messages from Cloud SQL"""
query = f'SELECT data, type FROM "{self.table_name}" WHERE session_id = :session_id ORDER BY id;'
with self.engine.connect() as conn:
results = conn.execute(
sqlalchemy.text(query), {"session_id": self.session_id}
).fetchall()
# load SQLAlchemy row objects into dicts
items = [{"data": json.loads(r[0]), "type": r[1]} for r in results]
messages = messages_from_dict(items)
return messages

def add_message(self, message: BaseMessage) -> None:
"""Append the message to the record in Cloud SQL"""
query = f'INSERT INTO "{self.table_name}" (session_id, data, type) VALUES (:session_id, :data, :type);'
with self.engine.connect() as conn:
conn.execute(
sqlalchemy.text(query),
{
"session_id": self.session_id,
"data": json.dumps(message.dict()),
"type": message.type,
},
)
conn.commit()

def clear(self) -> None:
"""Clear session memory from Cloud SQL"""
query = f'DELETE FROM "{self.table_name}" WHERE session_id = :session_id;'
with self.engine.connect() as conn:
conn.execute(sqlalchemy.text(query), {"session_id": self.session_id})
conn.commit()
30 changes: 30 additions & 0 deletions src/langchain_google_cloud_sql_mssql/mssql_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,3 +118,33 @@ def connect(self) -> sqlalchemy.engine.Connection:
out from the connection pool.
"""
return self.engine.connect()

def create_chat_history_table(self, table_name: str) -> None:
"""Create table with schema required for MSSQLChatMessageHistory class.

Required schema is as follows:

CREATE TABLE {table_name} (
id INT IDENTITY(1,1) PRIMARY KEY,
session_id NVARCHAR(MAX) NOT NULL,
data NVARCHAR(MAX) NOT NULL,
type NVARCHAR(MAX) NOT NULL
)

Args:
table_name (str): Name of database table to create for storing chat
message history.
"""
create_table_query = f"""IF NOT EXISTS (SELECT * FROM INFORMATION_SCHEMA.TABLES
WHERE TABLE_NAME = '{table_name}')
BEGIN
CREATE TABLE {table_name} (
id INT IDENTITY(1,1) PRIMARY KEY,
session_id NVARCHAR(MAX) NOT NULL,
data NVARCHAR(MAX) NOT NULL,
type NVARCHAR(MAX) NOT NULL
)
END;"""
with self.engine.connect() as conn:
conn.execute(sqlalchemy.text(create_table_query))
conn.commit()
101 changes: 101 additions & 0 deletions tests/integration/test_mssql_chat_message_history.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Generator

import pytest
import sqlalchemy
from langchain_core.messages.ai import AIMessage
from langchain_core.messages.human import HumanMessage

from langchain_google_cloud_sql_mssql import MSSQLChatMessageHistory, MSSQLEngine

project_id = os.environ["PROJECT_ID"]
region = os.environ["REGION"]
instance_id = os.environ["INSTANCE_ID"]
db_name = os.environ["DB_NAME"]
db_user = os.environ["DB_USER"]
db_password = os.environ["DB_PASSWORD"]
table_name = "message_store"


@pytest.fixture(name="memory_engine")
def setup() -> Generator:
engine = MSSQLEngine.from_instance(
project_id=project_id,
region=region,
instance=instance_id,
database=db_name,
user=db_user,
password=db_password,
)

# create table with malformed schema (missing 'type')
query = """CREATE TABLE malformed_table (
id INT IDENTITY(1,1) PRIMARY KEY,
session_id NVARCHAR(MAX) NOT NULL,
data NVARCHAR(MAX) NOT NULL,
);"""
with engine.connect() as conn:
conn.execute(sqlalchemy.text(query))
conn.commit()
yield engine
# cleanup tables
with engine.connect() as conn:
conn.execute(sqlalchemy.text(f"DROP TABLE IF EXISTS {table_name}"))
conn.execute(sqlalchemy.text(f"DROP TABLE IF EXISTS malformed_table"))
conn.commit()


def test_chat_message_history(memory_engine: MSSQLEngine) -> None:
memory_engine.create_chat_history_table(table_name)
history = MSSQLChatMessageHistory(
engine=memory_engine, session_id="test", table_name=table_name
)
history.add_user_message("hi!")
history.add_ai_message("whats up?")
messages = history.messages

# verify messages are correct
assert messages[0].content == "hi!"
assert type(messages[0]) is HumanMessage
assert messages[1].content == "whats up?"
assert type(messages[1]) is AIMessage

# verify clear() clears message history
history.clear()
assert len(history.messages) == 0


def test_chat_message_history_table_does_not_exist(memory_engine: MSSQLEngine) -> None:
"""Test that MSSQLChatMessageHistory fails if table does not exist."""
with pytest.raises(AttributeError) as exc_info:
MSSQLChatMessageHistory(
engine=memory_engine, session_id="test", table_name="missing_table"
)
# assert custom error message for missing table
assert (
exc_info.value.args[0]
== f"Table 'missing_table' does not exist. Please create it before initializing MSSQLChatMessageHistory. See MSSQLEngine.create_chat_history_table() for a helper method."
)


def test_chat_message_history_table_malformed_schema(
memory_engine: MSSQLEngine,
) -> None:
"""Test that MSSQLChatMessageHistory fails if schema is malformed."""
with pytest.raises(IndexError):
MSSQLChatMessageHistory(
engine=memory_engine, session_id="test", table_name="malformed_table"
)