feat: support context manager for bigframes session #1107

Merged · 5 commits · Oct 29, 2024
18 changes: 17 additions & 1 deletion bigframes/session/__init__.py
@@ -275,7 +275,23 @@ def __init__(
)

def __del__(self):
"""Automatic cleanup of internal resources"""
"""Automatic cleanup of internal resources."""
self.close()

def __enter__(self):
"""Enter the runtime context of the Session object.

See [With Statement Context Managers](https://docs.python.org/3/reference/datamodel.html#with-statement-context-managers)
for more details.
"""
return self

def __exit__(self, *_):
"""Exit the runtime context of the Session object.

See [With Statement Context Managers](https://docs.python.org/3/reference/datamodel.html#with-statement-context-managers)
for more details.
"""
self.close()

@property
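For reference, a minimal usage sketch (not part of this diff) of the new context-manager support. It assumes `bigframes.connect()` returns a `Session` whose `close()` releases the session's temporary resources:

```python
import bigframes

# Entering the `with` block returns the Session itself (__enter__);
# leaving it calls Session.close() (__exit__), which cleans up
# session-owned temporary resources such as anonymous tables.
with bigframes.connect() as session:
    df = session.read_gbq("SELECT 1 AS x")
    print(df.to_pandas())
# At this point the session has been closed and its temporary
# resources have been released.
```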
152 changes: 152 additions & 0 deletions tests/system/large/test_remote_function.py
@@ -2245,3 +2245,155 @@ def test_remote_function_ingress_settings_unsupported(session):
@session.remote_function(reuse=False, cloud_function_ingress_settings="unknown")
def square(x: int) -> int:
return x * x


@pytest.mark.parametrize(
("session_creator"),
[
pytest.param(bigframes.Session, id="session-constructor"),
pytest.param(bigframes.connect, id="connect-method"),
],
)
@pytest.mark.flaky(retries=2, delay=120)
def test_remote_function_w_context_manager_unnamed(
scalars_dfs, dataset_id, bq_cf_connection, session_creator
):
def add_one(x: int) -> int:
return x + 1

scalars_df, scalars_pandas_df = scalars_dfs
pd_result = scalars_pandas_df["int64_too"].apply(add_one)

temporary_bigquery_remote_function = None
temporary_cloud_run_function = None

try:
with session_creator() as session:
# create a temporary remote function
add_one_remote_temp = session.remote_function(
dataset=dataset_id,
bigquery_connection=bq_cf_connection,
reuse=False,
)(add_one)

temporary_bigquery_remote_function = (
add_one_remote_temp.bigframes_remote_function
)
assert temporary_bigquery_remote_function is not None
assert (
session.bqclient.get_routine(temporary_bigquery_remote_function)
is not None
)

temporary_cloud_run_function = add_one_remote_temp.bigframes_cloud_function
assert temporary_cloud_run_function is not None
assert (
session.cloudfunctionsclient.get_function(
name=temporary_cloud_run_function
)
is not None
)

bf_result = scalars_df["int64_too"].apply(add_one_remote_temp).to_pandas()
pandas.testing.assert_series_equal(bf_result, pd_result, check_dtype=False)

# outside the with statement context manager the temporary BQ remote
# function and the underlying cloud run function should have been
# cleaned up
assert temporary_bigquery_remote_function is not None
with pytest.raises(google.api_core.exceptions.NotFound):
session.bqclient.get_routine(temporary_bigquery_remote_function)
# the deletion of cloud function happens in a non-blocking way, ensure that
# it either exists in a being-deleted state, or is already deleted
assert temporary_cloud_run_function is not None
try:
gcf = session.cloudfunctionsclient.get_function(
name=temporary_cloud_run_function
)
assert gcf.state is functions_v2.Function.State.DELETING
except google.cloud.exceptions.NotFound:
pass
finally:
# clean up the gcp assets created for the temporary remote function,
# just in case it was not explicitly cleaned up in the try clause due
# to assertion failure or exception earlier than that
cleanup_remote_function_assets(
session.bqclient, session.cloudfunctionsclient, add_one_remote_temp
)


@pytest.mark.parametrize(
("session_creator"),
[
pytest.param(bigframes.Session, id="session-constructor"),
pytest.param(bigframes.connect, id="connect-method"),
],
)
@pytest.mark.flaky(retries=2, delay=120)
def test_remote_function_w_context_manager_named(
scalars_dfs, dataset_id, bq_cf_connection, session_creator
):
def add_one(x: int) -> int:
return x + 1

scalars_df, scalars_pandas_df = scalars_dfs
pd_result = scalars_pandas_df["int64_too"].apply(add_one)

persistent_bigquery_remote_function = None
persistent_cloud_run_function = None

try:
with session_creator() as session:
# create a persistent remote function
name = test_utils.prefixer.Prefixer("bigframes", "").create_prefix()
add_one_remote_persist = session.remote_function(
dataset=dataset_id,
bigquery_connection=bq_cf_connection,
reuse=False,
name=name,
)(add_one)

persistent_bigquery_remote_function = (
add_one_remote_persist.bigframes_remote_function
)
assert persistent_bigquery_remote_function is not None
assert (
session.bqclient.get_routine(persistent_bigquery_remote_function)
is not None
)

persistent_cloud_run_function = (
add_one_remote_persist.bigframes_cloud_function
)
assert persistent_cloud_run_function is not None
assert (
session.cloudfunctionsclient.get_function(
name=persistent_cloud_run_function
)
is not None
)

bf_result = (
scalars_df["int64_too"].apply(add_one_remote_persist).to_pandas()
)
pandas.testing.assert_series_equal(bf_result, pd_result, check_dtype=False)

# outside the with statement context manager the persistent BQ remote
# function and the underlying cloud run function should still exist
assert persistent_bigquery_remote_function is not None
assert (
session.bqclient.get_routine(persistent_bigquery_remote_function)
is not None
)
assert persistent_cloud_run_function is not None
assert (
session.cloudfunctionsclient.get_function(
name=persistent_cloud_run_function
)
is not None
)
finally:
# clean up the gcp assets created for the persistent remote function
cleanup_remote_function_assets(
session.bqclient, session.cloudfunctionsclient, add_one_remote_persist
)
35 changes: 35 additions & 0 deletions tests/system/large/test_session.py
@@ -133,3 +133,38 @@ def test_clean_up_by_session_id():
assert not any(
[(session.session_id in table.full_table_id) for table in tables_after]
)


@pytest.mark.parametrize(
("session_creator"),
[
pytest.param(bigframes.Session, id="session-constructor"),
pytest.param(bigframes.connect, id="connect-method"),
],
)
def test_clean_up_via_context_manager(session_creator):
# we will create two tables and confirm that they are deleted
# when the session is closed
with session_creator() as session:
bqclient = session.bqclient

expiration = (
datetime.datetime.now(datetime.timezone.utc)
+ bigframes.constants.DEFAULT_EXPIRATION
)
full_id_1 = bigframes.session._io.bigquery.create_temp_table(
session.bqclient, session._temp_storage_manager._random_table(), expiration
)
full_id_2 = bigframes.session._io.bigquery.create_temp_table(
session.bqclient, session._temp_storage_manager._random_table(), expiration
)

# check that the tables were actually created
assert bqclient.get_table(full_id_1).created is not None
assert bqclient.get_table(full_id_2).created is not None

# check that the tables are already deleted
with pytest.raises(google.cloud.exceptions.NotFound):
bqclient.delete_table(full_id_1)
with pytest.raises(google.cloud.exceptions.NotFound):
bqclient.delete_table(full_id_2)
8 changes: 4 additions & 4 deletions tests/unit/_config/test_bigquery_options.py
@@ -98,15 +98,15 @@ def test_setter_if_session_started_but_setting_the_same_value(attribute):
)
def test_location_set_to_valid_no_warning(valid_location):
# test setting location through constructor
-    def set_location_in_ctor():
+    def set_location_in_constructor():
bigquery_options.BigQueryOptions(location=valid_location)

# test setting location property
def set_location_property():
options = bigquery_options.BigQueryOptions()
options.location = valid_location

-    for op in [set_location_in_ctor, set_location_property]:
+    for op in [set_location_in_constructor, set_location_property]:
# Ensure that no warnings are emitted.
# https://docs.pytest.org/en/7.0.x/how-to/capture-warnings.html#additional-use-cases-of-warnings-in-tests
with warnings.catch_warnings():
@@ -136,15 +136,15 @@ def set_location_property():
)
def test_location_set_to_invalid_warning(invalid_location, possibility):
# test setting location through constructor
-    def set_location_in_ctor():
+    def set_location_in_constructor():
bigquery_options.BigQueryOptions(location=invalid_location)

# test setting location property
def set_location_property():
options = bigquery_options.BigQueryOptions()
options.location = invalid_location

-    for op in [set_location_in_ctor, set_location_property]:
+    for op in [set_location_in_constructor, set_location_property]:
with pytest.warns(
bigframes.exceptions.UnknownLocationWarning,
match=re.escape(