Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -51,4 +51,4 @@ jobs:
isort --check .

- name: Run type-check
run: mypy .
run: mypy --install-types --non-interactive .
13 changes: 13 additions & 0 deletions integration.cloudbuild.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,16 @@ steps:
name: python:3.11
entrypoint: python
args: ["-m", "pytest"]
env:
- 'PROJECT_ID=$PROJECT_ID'
- 'INSTANCE_ID=$_INSTANCE_ID'
- 'DB_NAME=$_DB_NAME'
- 'TABLE_NAME=test-$BUILD_ID'
- 'REGION=$_REGION'
- 'DB_USER=>$_DB_USER'
- 'DB_PASSWORD=>$_DB_PASSWORD'

substitutions:
_INSTANCE_ID: test-mssql-instance
_REGION: us-central1
_DB_NAME: test
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ requires-python = ">=3.8"
dependencies = [
"langchain==0.1.1",
"SQLAlchemy==2.0.7",
"sqlalchemy-pytds==0.3.5",
"cloud-sql-python-connector[pytds]==1.5.0"
]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from langchain_google_cloud_sql_mssql.mssql_engine import MSSQLEngine
from langchain_google_cloud_sql_mssql.mssql_loader import MSSQLLoader

__all__ = ["MSSQLEngine", "MSSQLLoader"]
120 changes: 120 additions & 0 deletions src/langchain_google_cloud_sql_mssql/mssql_engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# TODO: Remove below import when minimum supported Python version is 3.10
from __future__ import annotations

from typing import Optional

import sqlalchemy
from google.cloud.sql.connector import Connector


class MSSQLEngine:
"""A class for managing connections to a Cloud SQL for MSSQL database."""

_connector: Optional[Connector] = None

def __init__(
self,
engine: sqlalchemy.engine.Engine,
) -> None:
self.engine = engine

@classmethod
def from_instance(
cls,
project_id: str,
region: str,
instance: str,
database: str,
user: str,
password: str,
) -> MSSQLEngine:
"""Create an instance of MSSQLEngine from Cloud SQL instance
details.

This method uses the Cloud SQL Python Connector to connect to Cloud SQL
MSSQL instance using the given database credentials.

More details can be found at
https://github.com/GoogleCloudPlatform/cloud-sql-python-connector#credentials

Args:
project_id (str): Project ID of the Google Cloud Project where
the Cloud SQL instance is located.
region (str): Region where the Cloud SQL instance is located.
instance (str): The name of the Cloud SQL instance.
database (str): The name of the database to connect to on the
Cloud SQL instance.
db_user (str): The username to use for authentication.
db_password (str): The password to use for authentication.

Returns:
(MSSQLEngine): The engine configured to connect to a
Cloud SQL instance database.
"""
engine = cls._create_connector_engine(
instance_connection_name=f"{project_id}:{region}:{instance}",
database=database,
user=user,
password=password,
)
return cls(engine=engine)

@classmethod
def _create_connector_engine(
cls, instance_connection_name: str, database: str, user: str, password: str
) -> sqlalchemy.engine.Engine:
"""Create a SQLAlchemy engine using the Cloud SQL Python Connector.

Args:
instance_connection_name (str): The instance connection
name of the Cloud SQL instance to establish a connection to.
(ex. "project-id:instance-region:instance-name")
database (str): The name of the database to connect to on the
Cloud SQL instance.
user (str): The username to use for authentication.
password (str): The password to use for authentication.
Returns:
(sqlalchemy.engine.Engine): Engine configured using the Cloud SQL
Python Connector.
"""
if cls._connector is None:
cls._connector = Connector()

# anonymous function to be used for SQLAlchemy 'creator' argument
def getconn():
conn = cls._connector.connect( # type: ignore
instance_connection_name,
"pytds",
user=user,
password=password,
db=database,
)
return conn

return sqlalchemy.create_engine(
"mssql+pytds://",
creator=getconn,
)

def connect(self) -> sqlalchemy.engine.Connection:
"""Create a connection from SQLAlchemy connection pool.

Returns:
(sqlalchemy.engine.Connection): a single DBAPI connection checked
out from the connection pool.
"""
return self.engine.connect()
107 changes: 107 additions & 0 deletions src/langchain_google_cloud_sql_mssql/mssql_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from collections.abc import Iterable
from typing import Any, List, Optional, Sequence

import sqlalchemy
from langchain_community.document_loaders.base import BaseLoader
from langchain_core.documents import Document

from langchain_google_cloud_sql_mssql.mssql_engine import MSSQLEngine

DEFAULT_METADATA_COL = "langchain_metadata"


def _parse_doc_from_table(
content_columns: Iterable[str],
metadata_columns: Iterable[str],
column_names: Iterable[str],
rows: Sequence[Any],
) -> List[Document]:
docs = []
for row in rows:
page_content = " ".join(
str(getattr(row, column))
for column in content_columns
if column in column_names
)
metadata = {
column: getattr(row, column)
for column in metadata_columns
if column in column_names
}
if DEFAULT_METADATA_COL in metadata:
extra_metadata = json.loads(metadata[DEFAULT_METADATA_COL])
del metadata[DEFAULT_METADATA_COL]
metadata |= extra_metadata
doc = Document(page_content=page_content, metadata=metadata)
docs.append(doc)
return docs


class MSSQLLoader(BaseLoader):
"""A class for loading langchain documents from a Cloud SQL MSSQL database."""

def __init__(
self,
engine: MSSQLEngine,
query: str,
content_columns: Optional[List[str]] = None,
metadata_columns: Optional[List[str]] = None,
):
"""
Args:
engine (MSSQLEngine): MSSQLEngine object to connect to the MSSQL database.
query (str): The query to execute in MSSQL format.
content_columns (List[str]): The columns to write into the `page_content`
of the document. Optional.
metadata_columns (List[str]): The columns to write into the `metadata` of the document.
Optional.
"""
self.engine = engine
self.query = query
self.content_columns = content_columns
self.metadata_columns = metadata_columns

def load(self) -> List[Document]:
"""
Load langchain documents from a Cloud SQL MSSQL database.

Document page content defaults to the first columns present in the query or table and
metadata defaults to all other columns. Use with content_columns to overwrite the column
used for page content. Use metadata_columns to select specific metadata columns rather
than using all remaining columns.

If multiple content columns are specified, page_content’s string format will default to
space-separated string concatenation.

Returns:
(List[langchain_core.documents.Document]): a list of Documents with metadata from
specific columns.
"""
with self.engine.connect() as connection:
result_proxy = connection.execute(sqlalchemy.text(self.query))
column_names = list(result_proxy.keys())
results = result_proxy.fetchall()
content_columns = self.content_columns or [column_names[0]]
metadata_columns = self.metadata_columns or [
col for col in column_names if col not in content_columns
]
return _parse_doc_from_table(
content_columns,
metadata_columns,
column_names,
results,
)
Loading