Skip to content

feat: add llm.TextEmbeddingGenerator to support new embedding models #905

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 163 additions & 3 deletions bigframes/ml/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,18 @@

_EMBEDDING_GENERATOR_GECKO_ENDPOINT = "textembedding-gecko"
_EMBEDDING_GENERATOR_GECKO_MULTILINGUAL_ENDPOINT = "textembedding-gecko-multilingual"
_EMBEDDING_GENERATOR_ENDPOINTS = (
_PALM2_EMBEDDING_GENERATOR_ENDPOINTS = (
_EMBEDDING_GENERATOR_GECKO_ENDPOINT,
_EMBEDDING_GENERATOR_GECKO_MULTILINGUAL_ENDPOINT,
)

# Endpoint names of the text embedding models accepted by TextEmbeddingGenerator.
# TextEmbeddingGenerator._create_bqml_model raises ValueError for any model_name
# that is not in _TEXT_EMBEDDING_ENDPOINTS.
_TEXT_EMBEDDING_004_ENDPOINT = "text-embedding-004"
_TEXT_MULTILINGUAL_EMBEDDING_002_ENDPOINT = "text-multilingual-embedding-002"
_TEXT_EMBEDDING_ENDPOINTS = (
_TEXT_EMBEDDING_004_ENDPOINT,
_TEXT_MULTILINGUAL_EMBEDDING_002_ENDPOINT,
)

_GEMINI_PRO_ENDPOINT = "gemini-pro"
_GEMINI_1P5_PRO_PREVIEW_ENDPOINT = "gemini-1.5-pro-preview-0514"
_GEMINI_1P5_PRO_FLASH_PREVIEW_ENDPOINT = "gemini-1.5-flash-preview-0514"
Expand All @@ -57,6 +64,7 @@

# Names of the status columns in prediction results.
# TextEmbeddingGenerator.predict warns when any value in the
# _ML_GENERATE_EMBEDDING_STATUS column is a non-empty string; the other two are
# presumably used the same way by their own models (not visible here -- confirm).
_ML_GENERATE_TEXT_STATUS = "ml_generate_text_status"
_ML_EMBED_TEXT_STATUS = "ml_embed_text_status"
_ML_GENERATE_EMBEDDING_STATUS = "ml_generate_embedding_status"


@log_adapter.class_logger
Expand Down Expand Up @@ -387,6 +395,10 @@ def to_gbq(self, model_name: str, replace: bool = False) -> PaLM2TextGenerator:
class PaLM2TextEmbeddingGenerator(base.BaseEstimator):
"""PaLM2 text embedding generator LLM model.

.. note::
Models in this class are outdated and will be deprecated. To use the latest text embedding models, use the TextEmbeddingGenerator class instead.


Args:
model_name (str, Default to "textembedding-gecko"):
The model for text embedding. “textembedding-gecko” returns model embeddings for text inputs.
Expand Down Expand Up @@ -447,9 +459,9 @@ def _create_bqml_model(self):
iam_role="aiplatform.user",
)

if self.model_name not in _EMBEDDING_GENERATOR_ENDPOINTS:
if self.model_name not in _PALM2_EMBEDDING_GENERATOR_ENDPOINTS:
raise ValueError(
f"Model name {self.model_name} is not supported. We only support {', '.join(_EMBEDDING_GENERATOR_ENDPOINTS)}."
f"Model name {self.model_name} is not supported. We only support {', '.join(_PALM2_EMBEDDING_GENERATOR_ENDPOINTS)}."
)

endpoint = (
Expand Down Expand Up @@ -551,6 +563,154 @@ def to_gbq(
return new_model.session.read_gbq_model(model_name)


@log_adapter.class_logger
class TextEmbeddingGenerator(base.BaseEstimator):
    """Text embedding generator LLM model.

    Args:
        model_name (str, Default to "text-embedding-004"):
            The model for text embedding. Possible values are "text-embedding-004" or "text-multilingual-embedding-002".
            text-embedding models return model embeddings for text inputs.
            text-multilingual-embedding models return model embeddings for text inputs which support over 100 languages.
            Default to "text-embedding-004".
        session (bigframes.Session or None):
            BQ session to create the model. If None, use the global default session.
        connection_name (str or None):
            Connection to connect with remote service. str of the format <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>.
            If None, use default connection in session context.
    """

    def __init__(
        self,
        *,
        model_name: Literal[
            "text-embedding-004", "text-multilingual-embedding-002"
        ] = "text-embedding-004",
        session: Optional[bigframes.Session] = None,
        connection_name: Optional[str] = None,
    ):
        self.model_name = model_name
        self.session = session or bpd.get_global_session()
        self._bq_connection_manager = self.session.bqconnectionmanager

        # Fall back to the session's connection, then expand it to the full
        # <project>.<location>.<connection_id> form using the session defaults.
        connection_name = connection_name or self.session._bq_connection
        self.connection_name = clients.resolve_full_bq_connection_name(
            connection_name,
            default_project=self.session._project,
            default_location=self.session._location,
        )

        self._bqml_model_factory = globals.bqml_model_factory()
        self._bqml_model: core.BqmlModel = self._create_bqml_model()

    def _create_bqml_model(self):
        """Create the remote BQML model backing this estimator.

        Returns:
            The remote model created by the BQML model factory for
            ``self.model_name`` over ``self.connection_name``.

        Raises:
            ValueError: If no connection name is available, if the model name
                is not one of the supported text embedding endpoints, or if
                the connection name is not of the form
                <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>.
        """
        if not self.connection_name:
            raise ValueError(
                "Must provide connection_name, either in constructor or through session options."
            )

        # Validate the model name before touching the connection so that an
        # unsupported model never triggers connection creation / IAM changes.
        if self.model_name not in _TEXT_EMBEDDING_ENDPOINTS:
            raise ValueError(
                f"Model name {self.model_name} is not supported. We only support {', '.join(_TEXT_EMBEDDING_ENDPOINTS)}."
            )

        # Parse and create connection if needed.
        if self._bq_connection_manager:
            connection_name_parts = self.connection_name.split(".")
            if len(connection_name_parts) != 3:
                raise ValueError(
                    f"connection_name must be of the format <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>, got {self.connection_name}."
                )
            project_id, location, connection_id = connection_name_parts
            self._bq_connection_manager.create_bq_connection(
                project_id=project_id,
                location=location,
                connection_id=connection_id,
                iam_role="aiplatform.user",
            )

        options = {
            "endpoint": self.model_name,
        }
        return self._bqml_model_factory.create_remote_model(
            session=self.session, connection_name=self.connection_name, options=options
        )

    @classmethod
    def _from_bq(
        cls, session: bigframes.Session, bq_model: bigquery.Model
    ) -> TextEmbeddingGenerator:
        """Build an estimator from an existing BigQuery remote model.

        Args:
            session (bigframes.Session):
                Session the returned estimator is bound to.
            bq_model (bigquery.Model):
                Remote model whose ``remoteModelInfo`` carries the endpoint
                and connection.

        Returns:
            TextEmbeddingGenerator: Estimator whose ``_bqml_model`` wraps ``bq_model``.
        """
        assert bq_model.model_type == "MODEL_TYPE_UNSPECIFIED"
        assert "remoteModelInfo" in bq_model._properties
        assert "endpoint" in bq_model._properties["remoteModelInfo"]
        assert "connection" in bq_model._properties["remoteModelInfo"]

        # Parse the remote model endpoint; the model name is its last path segment.
        remote_model_info = bq_model._properties["remoteModelInfo"]
        bqml_endpoint = remote_model_info["endpoint"]
        model_connection = remote_model_info["connection"]
        model_endpoint = bqml_endpoint.split("/")[-1]

        model = cls(
            session=session,
            model_name=model_endpoint,  # type: ignore
            connection_name=model_connection,
        )

        # Replace the model created by the constructor with the loaded one.
        model._bqml_model = core.BqmlModel(session, bq_model)
        return model

    def predict(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:
        """Predict the result from input DataFrame.

        Args:
            X (bigframes.dataframe.DataFrame or bigframes.series.Series):
                Input DataFrame or Series, which must contain exactly one column. That column is
                used as the text input and is renamed to "content" before being sent to the model.

        Returns:
            bigframes.dataframe.DataFrame: DataFrame of shape (n_samples, n_input_columns + n_prediction_columns). Returns predicted values.

        Raises:
            ValueError: If the input does not have exactly one column.

        Warns:
            RuntimeWarning: If any row has a non-empty "ml_generate_embedding_status" value.
        """
        (X,) = utils.convert_to_dataframe(X)

        if len(X.columns) != 1:
            raise ValueError(
                f"Only support one column as input. {constants.FEEDBACK_LINK}"
            )

        # BQML identified the column by name
        col_label = cast(blocks.Label, X.columns[0])
        X = X.rename(columns={col_label: "content"})

        options = {
            "flatten_json_output": True,
        }

        df = self._bqml_model.generate_embedding(X, options)

        # A non-empty status string marks a row whose embedding failed.
        if (df[_ML_GENERATE_EMBEDDING_STATUS] != "").any():
            warnings.warn(
                f"Some predictions failed. Check column {_ML_GENERATE_EMBEDDING_STATUS} for detailed status. You may want to filter the failed rows and retry.",
                RuntimeWarning,
            )

        return df

    def to_gbq(self, model_name: str, replace: bool = False) -> TextEmbeddingGenerator:
        """Save the model to BigQuery.

        Args:
            model_name (str):
                The name of the model.
            replace (bool, default False):
                Determine whether to replace if the model already exists. Default to False.

        Returns:
            TextEmbeddingGenerator: Saved model."""

        new_model = self._bqml_model.copy(model_name, replace)
        return new_model.session.read_gbq_model(model_name)


@log_adapter.class_logger
class GeminiTextGenerator(base.BaseEstimator):
"""Gemini text generator LLM model.
Expand Down
3 changes: 3 additions & 0 deletions bigframes/ml/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@
llm._GEMINI_PRO_ENDPOINT: llm.GeminiTextGenerator,
llm._GEMINI_1P5_PRO_PREVIEW_ENDPOINT: llm.GeminiTextGenerator,
llm._GEMINI_1P5_PRO_FLASH_PREVIEW_ENDPOINT: llm.GeminiTextGenerator,
llm._TEXT_EMBEDDING_004_ENDPOINT: llm.TextEmbeddingGenerator,
llm._TEXT_MULTILINGUAL_EMBEDDING_002_ENDPOINT: llm.TextEmbeddingGenerator,
}
)

Expand All @@ -84,6 +86,7 @@ def from_bq(
imported.XGBoostModel,
llm.PaLM2TextGenerator,
llm.PaLM2TextEmbeddingGenerator,
llm.TextEmbeddingGenerator,
pipeline.Pipeline,
compose.ColumnTransformer,
preprocessing.PreprocessingType,
Expand Down
41 changes: 41 additions & 0 deletions tests/system/small/ml/test_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,47 @@ def test_embedding_generator_predict_series_success(
assert len(value) == 768


@pytest.mark.parametrize(
    "model_name",
    ("text-embedding-004", "text-multilingual-embedding-002"),
)
def test_create_load_text_embedding_generator_model(
    dataset_id, model_name, session, bq_connection
):
    """Create a TextEmbeddingGenerator, save it to BigQuery and check the reloaded copy."""
    saved_model_id = f"{dataset_id}.temp_text_model"

    embedder = llm.TextEmbeddingGenerator(
        session=session, connection_name=bq_connection, model_name=model_name
    )
    assert embedder is not None
    assert embedder._bqml_model is not None

    # Round-trip through BigQuery: the saved model must keep its configuration.
    reloaded = embedder.to_gbq(saved_model_id, replace=True)
    assert saved_model_id == reloaded._bqml_model.model_name
    assert reloaded.connection_name == bq_connection
    assert reloaded.model_name == model_name


@pytest.mark.parametrize(
    "model_name",
    ("text-embedding-004", "text-multilingual-embedding-002"),
)
@pytest.mark.flaky(retries=2)
def test_gemini_text_embedding_generator_predict_default_params_success(
    llm_text_df, model_name, session, bq_connection
):
    """Predict with default parameters and check the shape of the result and of one embedding."""
    result_column = "ml_generate_embedding_result"

    embedder = llm.TextEmbeddingGenerator(
        session=session, connection_name=bq_connection, model_name=model_name
    )
    result = embedder.predict(llm_text_df).to_pandas()

    assert result.shape == (3, 4)
    assert result_column in result.columns

    # Inspect the first row's embedding vector.
    first_embedding = result[result_column][0]
    assert len(first_embedding) == 768


@pytest.mark.parametrize(
"model_name",
("gemini-pro", "gemini-1.5-pro-preview-0514", "gemini-1.5-flash-preview-0514"),
Expand Down