From 179f60d54ee60b54d923dba05c61f9f6f2ad1fe8 Mon Sep 17 00:00:00 2001 From: Shuowei Li Date: Wed, 20 Aug 2025 04:40:41 +0000 Subject: [PATCH 1/2] support ai.generate --- bigframes/ml/core.py | 15 ++++ bigframes/ml/llm.py | 4 +- bigframes/ml/sql.py | 10 +++ bigframes/ml/utils.py | 10 ++- bigframes/testing/utils.py | 5 ++ tests/system/small/ml/test_llm.py | 124 +++++++++++++++++++++++++++++- tests/unit/ml/test_sql.py | 17 ++++ 7 files changed, 182 insertions(+), 3 deletions(-) diff --git a/bigframes/ml/core.py b/bigframes/ml/core.py index 73b8ba8dbc..d8eaec76ec 100644 --- a/bigframes/ml/core.py +++ b/bigframes/ml/core.py @@ -217,6 +217,21 @@ def generate_table( generate_table_tvf = TvfDef(generate_table, "status") + def ai_generate( + self, + input_data: bpd.DataFrame, + options: dict[str, Union[int, float, bool, Mapping]], + ) -> bpd.DataFrame: + return self._apply_ml_tvf( + input_data, + lambda source_sql: self._sql_generator.ai_generate( + source_sql=source_sql, + struct_options=options, + ), + ) + + ai_generate_tvf = TvfDef(ai_generate, "status") + def detect_anomalies( self, input_data: bpd.DataFrame, options: Mapping[str, int | float] ) -> bpd.DataFrame: diff --git a/bigframes/ml/llm.py b/bigframes/ml/llm.py index 11861c786e..bc9648c6c4 100644 --- a/bigframes/ml/llm.py +++ b/bigframes/ml/llm.py @@ -734,7 +734,9 @@ def predict( output_schema = { k: utils.standardize_type(v) for k, v in output_schema.items() } - options["output_schema"] = output_schema + options["output_schema"] = { + k: utils.standardize_type(v) for k, v in output_schema.items() + } return self._predict_and_retry( core.BqmlModel.generate_table_tvf, X, diff --git a/bigframes/ml/sql.py b/bigframes/ml/sql.py index 2937368c92..01bac17446 100644 --- a/bigframes/ml/sql.py +++ b/bigframes/ml/sql.py @@ -435,3 +435,13 @@ def ai_generate_table( struct_options_sql = self.struct_options(**struct_options) return f"""SELECT * FROM AI.GENERATE_TABLE(MODEL {self._model_ref_sql()}, ({source_sql}), 
{struct_options_sql})""" + + def ai_generate( + self, + source_sql: str, + struct_options: Mapping[str, Union[int, float, bool, Mapping]], + ) -> str: + """Encode AI.GENERATE for BQML""" + struct_options_sql = self.struct_options(**struct_options) + return f"""SELECT * FROM AI.GENERATE(MODEL {self._model_ref_sql()}, + ({source_sql}), {struct_options_sql})""" diff --git a/bigframes/ml/utils.py b/bigframes/ml/utils.py index 5c02789576..1a51250200 100644 --- a/bigframes/ml/utils.py +++ b/bigframes/ml/utils.py @@ -191,8 +191,16 @@ def combine_training_and_evaluation_data( def standardize_type(v: str, supported_dtypes: Optional[Iterable[str]] = None): + """Standardize type string to BQML supported type string.""" t = v.lower() - t = t.replace("boolean", "bool") + if t == "boolean": + t = "bool" + elif t == "integer": + t = "int64" + elif t == "str": + t = "string" + elif t == "float": + t = "float64" if supported_dtypes: if t not in supported_dtypes: diff --git a/bigframes/testing/utils.py b/bigframes/testing/utils.py index 5da24c5b9b..fd09945ee1 100644 --- a/bigframes/testing/utils.py +++ b/bigframes/testing/utils.py @@ -49,6 +49,11 @@ "ml_generate_text_status", "prompt", ] +AI_GENERATE_OUTPUT = [ + "result", + "full_response", + "status", +] ML_GENERATE_EMBEDDING_OUTPUT = [ "ml_generate_embedding_result", "ml_generate_embedding_statistics", diff --git a/tests/system/small/ml/test_llm.py b/tests/system/small/ml/test_llm.py index 245fead028..5362974c32 100644 --- a/tests/system/small/ml/test_llm.py +++ b/tests/system/small/ml/test_llm.py @@ -15,6 +15,7 @@ from typing import Callable from unittest import mock +from google.api_core import exceptions as api_core_exceptions import pandas as pd import pyarrow as pa import pytest @@ -216,7 +217,9 @@ def test_gemini_text_generator_predict_output_schema_success( llm_text_df: bpd.DataFrame, model_name, session, bq_connection ): gemini_text_generator_model = llm.GeminiTextGenerator( - model_name=model_name, 
connection_name=bq_connection, session=session + model_name="gemini-2.0-flash-001", + connection_name=bq_connection, + session=session, ) output_schema = { "bool_output": "bool", @@ -807,3 +810,122 @@ def test_text_embedding_generator_no_default_model_warning(model_class): message = "Since upgrading the default model can cause unintended breakages, the\ndefault model will be removed in BigFrames 3.0. Please supply an\nexplicit model to avoid this message." with pytest.warns(FutureWarning, match=message): model_class(model_name=None) + + +@pytest.mark.flaky(retries=2) +def test_gemini_text_generator_predict_struct_schema_succeeds( + llm_text_df: bpd.DataFrame, session, bq_connection +): + gemini_text_generator_model = llm.GeminiTextGenerator( + model_name="gemini-2.0-flash-001", + connection_name=bq_connection, + session=session, + ) + output_schema = { + "struct_output": "struct<name string, age int64>", + } + df = gemini_text_generator_model.predict(llm_text_df, output_schema=output_schema) + assert set(field.name for field in df["struct_output"].dtype.pyarrow_dtype) == { + "name", + "age", + } + + pd_df = df.to_pandas() + utils.check_pandas_df_schema_and_index( + pd_df, + columns=list(output_schema.keys()) + ["prompt", "full_response", "status"], + index=3, + col_exact=False, + ) + + +@pytest.mark.flaky(retries=2) +def test_gemini_text_generator_predict_struct_schema_flat_succeeds( + llm_text_df: bpd.DataFrame, session, bq_connection +): + gemini_text_generator_model = llm.GeminiTextGenerator( + model_name="gemini-2.0-flash-001", + connection_name=bq_connection, + session=session, + ) + output_schema = { + "name": "string", + "age": "int64", + } + df = gemini_text_generator_model.predict(llm_text_df, output_schema=output_schema) + assert df["name"].dtype == pd.StringDtype(storage="pyarrow") + assert df["age"].dtype == pd.Int64Dtype() + + pd_df = df.to_pandas() + utils.check_pandas_df_schema_and_index( + pd_df, + columns=list(output_schema.keys()) + ["prompt", "full_response", "status"], 
+ index=3, + col_exact=False, + ) + + +@pytest.mark.flaky(retries=2) +def test_gemini_text_generator_predict_array_schema_succeeds( + llm_text_df: bpd.DataFrame, session, bq_connection +): + gemini_text_generator_model = llm.GeminiTextGenerator( + model_name="gemini-2.0-flash-001", + connection_name=bq_connection, + session=session, + ) + output_schema = { + "array_output": "array<string>", + } + df = gemini_text_generator_model.predict(llm_text_df, output_schema=output_schema) + assert df["array_output"].dtype == pd.ArrowDtype(pa.list_(pa.string())) + + pd_df = df.to_pandas() + utils.check_pandas_df_schema_and_index( + pd_df, + columns=list(output_schema.keys()) + ["prompt", "full_response", "status"], + index=3, + col_exact=False, + ) + + +@pytest.mark.flaky(retries=2) +def test_gemini_text_generator_predict_array_struct_schema_succeeds( + llm_text_df: bpd.DataFrame, session, bq_connection +): + gemini_text_generator_model = llm.GeminiTextGenerator( + model_name="gemini-2.0-flash-001", + connection_name=bq_connection, + session=session, + ) + output_schema = { + "array_output": "array<struct<name string, age int64>>", + } + df = gemini_text_generator_model.predict(llm_text_df, output_schema=output_schema) + assert set( + field.name for field in df["array_output"].dtype.pyarrow_dtype.value_type + ) == {"name", "age"} + + pd_df = df.to_pandas() + utils.check_pandas_df_schema_and_index( + pd_df, + columns=list(output_schema.keys()) + ["prompt", "full_response", "status"], + index=3, + col_exact=False, + ) + + +@pytest.mark.flaky(retries=2) +def test_gemini_text_generator_predict_invalid_schema_fails( + llm_text_df: bpd.DataFrame, session, bq_connection +): + gemini_text_generator_model = llm.GeminiTextGenerator( + model_name="gemini-2.0-flash-001", + connection_name=bq_connection, + session=session, + ) + output_schema = { + "invalid_output": "invalid_type", + } + with pytest.raises(api_core_exceptions.BadRequest): + gemini_text_generator_model.predict(llm_text_df, output_schema=output_schema) diff --git 
a/tests/unit/ml/test_sql.py b/tests/unit/ml/test_sql.py index d605b571f3..e36d7d8acb 100644 --- a/tests/unit/ml/test_sql.py +++ b/tests/unit/ml/test_sql.py @@ -529,3 +529,20 @@ def test_ml_principal_component_info_correct( sql == """SELECT * FROM ML.PRINCIPAL_COMPONENT_INFO(MODEL `my_project_id`.`my_dataset_id`.`my_model_id`)""" ) + + +def test_ai_generate_correct( + model_manipulation_sql_generator: ml_sql.ModelManipulationSqlGenerator, + mock_df: bpd.DataFrame, +): + sql = model_manipulation_sql_generator.ai_generate( + source_sql=mock_df.sql, + struct_options={"option_key1": 1, "option_key2": 2.2}, + ) + assert ( + sql + == """SELECT * FROM AI.GENERATE(MODEL `my_project_id`.`my_dataset_id`.`my_model_id`, + (input_X_y_sql), STRUCT( + 1 AS `option_key1`, + 2.2 AS `option_key2`))""" + ) From 4dcc5c7f974f462efc396fb24e85b66430665fe9 Mon Sep 17 00:00:00 2001 From: Shuowei Li Date: Wed, 20 Aug 2025 04:47:08 +0000 Subject: [PATCH 2/2] minor update --- bigframes/ml/llm.py | 4 +- tests/system/small/ml/test_llm.py | 62 +++++++++++++++++++++++-------- 2 files changed, 47 insertions(+), 19 deletions(-) diff --git a/bigframes/ml/llm.py b/bigframes/ml/llm.py index bc9648c6c4..11861c786e 100644 --- a/bigframes/ml/llm.py +++ b/bigframes/ml/llm.py @@ -734,9 +734,7 @@ def predict( output_schema = { k: utils.standardize_type(v) for k, v in output_schema.items() } - options["output_schema"] = { - k: utils.standardize_type(v) for k, v in output_schema.items() - } + options["output_schema"] = output_schema return self._predict_and_retry( core.BqmlModel.generate_table_tvf, X, diff --git a/tests/system/small/ml/test_llm.py b/tests/system/small/ml/test_llm.py index 5362974c32..7644cec816 100644 --- a/tests/system/small/ml/test_llm.py +++ b/tests/system/small/ml/test_llm.py @@ -217,7 +217,7 @@ def test_gemini_text_generator_predict_output_schema_success( llm_text_df: bpd.DataFrame, model_name, session, bq_connection ): gemini_text_generator_model = llm.GeminiTextGenerator( - 
model_name="gemini-2.0-flash-001", + model_name=model_name, connection_name=bq_connection, session=session, ) @@ -812,12 +812,18 @@ def test_text_embedding_generator_no_default_model_warning(model_class): model_class(model_name=None) -@pytest.mark.flaky(retries=2) +@pytest.mark.parametrize( + "model_name", + ( + "gemini-2.0-flash-001", + "gemini-2.0-flash-lite-001", + ), +) def test_gemini_text_generator_predict_struct_schema_succeeds( - llm_text_df: bpd.DataFrame, session, bq_connection + llm_text_df: bpd.DataFrame, session, bq_connection, model_name ): gemini_text_generator_model = llm.GeminiTextGenerator( - model_name="gemini-2.0-flash-001", + model_name=model_name, connection_name=bq_connection, session=session, ) @@ -839,12 +845,18 @@ def test_gemini_text_generator_predict_struct_schema_succeeds( ) -@pytest.mark.flaky(retries=2) +@pytest.mark.parametrize( + "model_name", + ( + "gemini-2.0-flash-001", + "gemini-2.0-flash-lite-001", + ), +) def test_gemini_text_generator_predict_struct_schema_flat_succeeds( - llm_text_df: bpd.DataFrame, session, bq_connection + llm_text_df: bpd.DataFrame, session, bq_connection, model_name ): gemini_text_generator_model = llm.GeminiTextGenerator( - model_name="gemini-2.0-flash-001", + model_name=model_name, connection_name=bq_connection, session=session, ) @@ -865,12 +877,18 @@ def test_gemini_text_generator_predict_struct_schema_flat_succeeds( ) -@pytest.mark.flaky(retries=2) +@pytest.mark.parametrize( + "model_name", + ( + "gemini-2.0-flash-001", + "gemini-2.0-flash-lite-001", + ), +) def test_gemini_text_generator_predict_array_schema_succeeds( - llm_text_df: bpd.DataFrame, session, bq_connection + llm_text_df: bpd.DataFrame, session, bq_connection, model_name ): gemini_text_generator_model = llm.GeminiTextGenerator( - model_name="gemini-2.0-flash-001", + model_name=model_name, connection_name=bq_connection, session=session, ) @@ -889,12 +907,18 @@ def test_gemini_text_generator_predict_array_schema_succeeds( ) 
-@pytest.mark.flaky(retries=2) +@pytest.mark.parametrize( + "model_name", + ( + "gemini-2.0-flash-001", + "gemini-2.0-flash-lite-001", + ), +) def test_gemini_text_generator_predict_array_struct_schema_succeeds( - llm_text_df: bpd.DataFrame, session, bq_connection + llm_text_df: bpd.DataFrame, session, bq_connection, model_name ): gemini_text_generator_model = llm.GeminiTextGenerator( - model_name="gemini-2.0-flash-001", + model_name=model_name, connection_name=bq_connection, session=session, ) @@ -915,12 +939,18 @@ def test_gemini_text_generator_predict_array_struct_schema_succeeds( ) -@pytest.mark.flaky(retries=2) +@pytest.mark.parametrize( + "model_name", + ( + "gemini-2.0-flash-001", + "gemini-2.0-flash-lite-001", + ), +) def test_gemini_text_generator_predict_invalid_schema_fails( - llm_text_df: bpd.DataFrame, session, bq_connection + llm_text_df: bpd.DataFrame, session, bq_connection, model_name ): gemini_text_generator_model = llm.GeminiTextGenerator( - model_name="gemini-2.0-flash-001", + model_name=model_name, connection_name=bq_connection, session=session, )