feat: Support write api as loading option #1617

Merged: 11 commits, Apr 25, 2025
78 changes: 66 additions & 12 deletions bigframes/core/local_data.py
@@ -97,27 +97,46 @@ def from_pyarrow(self, table: pa.Table) -> ManagedArrowTable:
mat.validate()
return mat

def to_pyarrow_table(
def to_arrow(
self,
*,
offsets_col: Optional[str] = None,
geo_format: Literal["wkb", "wkt"] = "wkt",
duration_type: Literal["int", "duration"] = "duration",
json_type: Literal["string"] = "string",
) -> pa.Table:
pa_table = self.data
if offsets_col is not None:
pa_table = pa_table.append_column(
offsets_col, pa.array(range(pa_table.num_rows), type=pa.int64())
)
) -> tuple[pa.Schema, Iterable[pa.RecordBatch]]:
if geo_format != "wkt":
raise NotImplementedError(f"geo format {geo_format} not yet implemented")
if duration_type != "duration":
raise NotImplementedError(
f"duration as {duration_type} not yet implemented"
)
assert json_type == "string"
return pa_table

batches = self.data.to_batches()
schema = self.data.schema
if duration_type == "int":
schema = _schema_durations_to_ints(schema)
batches = map(functools.partial(_cast_pa_batch, schema=schema), batches)

if offsets_col is not None:
return schema.append(pa.field(offsets_col, pa.int64())), _append_offsets(
batches, offsets_col
)
else:
return schema, batches

def to_pyarrow_table(
self,
*,
offsets_col: Optional[str] = None,
geo_format: Literal["wkb", "wkt"] = "wkt",
duration_type: Literal["int", "duration"] = "duration",
json_type: Literal["string"] = "string",
) -> pa.Table:
schema, batches = self.to_arrow(
offsets_col=offsets_col,
geo_format=geo_format,
duration_type=duration_type,
json_type=json_type,
)
return pa.Table.from_batches(batches, schema)

def to_parquet(
self,
@@ -391,6 +410,41 @@ def _physical_type_replacements(dtype: pa.DataType) -> pa.DataType:
return dtype


def _append_offsets(
batches: Iterable[pa.RecordBatch], offsets_col_name: str
) -> Iterable[pa.RecordBatch]:
offset = 0
for batch in batches:
offsets = pa.array(range(offset, offset + batch.num_rows), type=pa.int64())
batch_w_offsets = pa.record_batch(
[*batch.columns, offsets],
schema=batch.schema.append(pa.field(offsets_col_name, pa.int64())),
)
offset += batch.num_rows
yield batch_w_offsets


@_recursive_map_types
def _durations_to_ints(type: pa.DataType) -> pa.DataType:
if pa.types.is_duration(type):
return pa.int64()
return type


def _schema_durations_to_ints(schema: pa.Schema) -> pa.Schema:
return pa.schema(
pa.field(field.name, _durations_to_ints(field.type)) for field in schema
)


# TODO: Use RecordBatch.cast once min pyarrow>=16.0
def _cast_pa_batch(batch: pa.RecordBatch, schema: pa.Schema) -> pa.RecordBatch:
return pa.record_batch(
[arr.cast(type) for arr, type in zip(batch.columns, schema.types)],
schema=schema,
)


def _pairwise(iterable):
do_yield = False
a = None
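To make the new to_arrow offsets behavior concrete, here is a small self-contained sketch of the same pattern _append_offsets implements, written against plain pyarrow with made-up column names: the offsets column keeps counting across record-batch boundaries, so the batches together carry a dense 0..N-1 ordering column.

```python
import pyarrow as pa


def append_offsets(batches, offsets_col_name):
    # Same idea as _append_offsets above: one global counter shared by all batches.
    offset = 0
    for batch in batches:
        offsets = pa.array(range(offset, offset + batch.num_rows), type=pa.int64())
        yield pa.record_batch(
            [*batch.columns, offsets],
            schema=batch.schema.append(pa.field(offsets_col_name, pa.int64())),
        )
        offset += batch.num_rows


table = pa.table({"x": ["a", "b", "c", "d", "e"]})
batches = table.to_batches(max_chunksize=2)  # 2 + 2 + 1 rows
with_offsets = list(append_offsets(batches, "bf_offsets"))
print([b.columns[-1].to_pylist() for b in with_offsets])  # [[0, 1], [2, 3], [4]]
```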
3 changes: 3 additions & 0 deletions bigframes/core/utils.py
@@ -142,6 +142,9 @@ def label_to_identifier(label: typing.Hashable, strict: bool = False) -> str:
identifier = re.sub(r"[^a-zA-Z0-9_]", "", identifier)
if not identifier:
identifier = "id"
elif identifier[0].isdigit():
# first character must be letter or underscore
identifier = "_" + identifier
return identifier


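The identifier fix above is easiest to see with a standalone re-implementation of the rule (this sketch is not the bigframes helper itself): disallowed characters are stripped, an empty result falls back to "id", and a result that would start with a digit gets a leading underscore, since a BigQuery identifier must begin with a letter or underscore.

```python
import re


def sanitize(label: str) -> str:
    # Re-implementation of the sanitizing rule shown in the diff above.
    identifier = re.sub(r"[^a-zA-Z0-9_]", "", label)
    if not identifier:
        identifier = "id"
    elif identifier[0].isdigit():
        # first character must be a letter or underscore
        identifier = "_" + identifier
    return identifier


print(sanitize("2024 sales"))  # "_2024sales"
print(sanitize("!!!"))         # "id"
```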
9 changes: 8 additions & 1 deletion bigframes/session/__init__.py
@@ -255,6 +255,7 @@ def __init__(
session=self,
bqclient=self._clients_provider.bqclient,
storage_manager=self._temp_storage_manager,
write_client=self._clients_provider.bqstoragewriteclient,
default_index_type=self._default_index_type,
scan_index_uniqueness=self._strictly_ordered,
force_total_order=self._strictly_ordered,
@@ -731,7 +732,9 @@ def read_pandas(
workload is such that you exhaust the BigQuery load job
quota and your data cannot be embedded in SQL due to size or
data type limitations.

* "bigquery_write":
[Preview] Use the BigQuery Storage Write API. This feature
is in public preview.
Returns:
An equivalent bigframes.pandas.(DataFrame/Series/Index) object

@@ -805,6 +808,10 @@ def _read_pandas(
return self._loader.read_pandas(
pandas_dataframe, method="stream", api_name=api_name
)
elif write_engine == "bigquery_write":
[review comment] Collaborator: Could you also add this to the various docstrings? Let's mark the "bigquery_write" option as [Preview] in the docs, too.

[review comment] Contributor (author): Added to the only docstring that enumerates the engines.

return self._loader.read_pandas(
pandas_dataframe, method="write", api_name=api_name
)
else:
raise ValueError(f"Got unexpected write_engine '{write_engine}'")

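A hedged usage sketch of the new engine from user code (the DataFrame contents are placeholders, and the call assumes an already-configured BigQuery project and location):

```python
import pandas as pd

import bigframes.pandas as bpd

pdf = pd.DataFrame({"id": [1, 2, 3], "value": ["a", "b", "c"]})

# Upload via the BigQuery Storage Write API rather than a load job or the
# legacy streaming API; "bigquery_write" is in Preview per the docstring above.
session = bpd.get_global_session()
bf_df = session.read_pandas(pdf, write_engine="bigquery_write")
print(len(bf_df))  # 3
```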
31 changes: 31 additions & 0 deletions bigframes/session/clients.py
@@ -134,6 +134,9 @@ def __init__(
self._bqstoragereadclient: Optional[
google.cloud.bigquery_storage_v1.BigQueryReadClient
] = None
self._bqstoragewriteclient: Optional[
google.cloud.bigquery_storage_v1.BigQueryWriteClient
] = None
self._cloudfunctionsclient: Optional[
google.cloud.functions_v2.FunctionServiceClient
] = None
@@ -238,6 +241,34 @@ def bqstoragereadclient(self):

return self._bqstoragereadclient

@property
def bqstoragewriteclient(self):
if not self._bqstoragewriteclient:
bqstorage_options = None
if "bqstoragewriteclient" in self._client_endpoints_override:
bqstorage_options = google.api_core.client_options.ClientOptions(
api_endpoint=self._client_endpoints_override["bqstoragewriteclient"]
)
elif self._use_regional_endpoints:
bqstorage_options = google.api_core.client_options.ClientOptions(
api_endpoint=_BIGQUERYSTORAGE_REGIONAL_ENDPOINT.format(
location=self._location
)
)

bqstorage_info = google.api_core.gapic_v1.client_info.ClientInfo(
user_agent=self._application_name
)
self._bqstoragewriteclient = (
google.cloud.bigquery_storage_v1.BigQueryWriteClient(
client_info=bqstorage_info,
client_options=bqstorage_options,
credentials=self._credentials,
)
)

return self._bqstoragewriteclient

@property
def cloudfunctionsclient(self):
if not self._cloudfunctionsclient:
71 changes: 66 additions & 5 deletions bigframes/session/loader.py
@@ -23,6 +23,7 @@
import typing
from typing import (
Dict,
Generator,
Hashable,
IO,
Iterable,
@@ -36,12 +37,13 @@
import bigframes_vendored.constants as constants
import bigframes_vendored.pandas.io.gbq as third_party_pandas_gbq
import google.api_core.exceptions
from google.cloud import bigquery_storage_v1
import google.cloud.bigquery as bigquery
import google.cloud.bigquery.table
from google.cloud.bigquery_storage_v1 import types as bq_storage_types
import pandas
import pyarrow as pa

from bigframes.core import local_data, utils
from bigframes.core import guid, local_data, utils
import bigframes.core as core
import bigframes.core.blocks as blocks
import bigframes.core.schema as schemata
@@ -142,13 +144,15 @@ def __init__(
self,
session: bigframes.session.Session,
bqclient: bigquery.Client,
write_client: bigquery_storage_v1.BigQueryWriteClient,
storage_manager: bigframes.session.temporary_storage.TemporaryStorageManager,
default_index_type: bigframes.enums.DefaultIndexKind,
scan_index_uniqueness: bool,
force_total_order: bool,
metrics: Optional[bigframes.session.metrics.ExecutionMetrics] = None,
):
self._bqclient = bqclient
self._write_client = write_client
self._storage_manager = storage_manager
self._default_index_type = default_index_type
self._scan_index_uniqueness = scan_index_uniqueness
@@ -165,7 +169,7 @@ def __init__(
def read_pandas(
self,
pandas_dataframe: pandas.DataFrame,
method: Literal["load", "stream"],
method: Literal["load", "stream", "write"],
api_name: str,
) -> dataframe.DataFrame:
# TODO: Push this into from_pandas, along with index flag
@@ -183,6 +187,8 @@
array_value = self.load_data(managed_data, api_name=api_name)
elif method == "stream":
array_value = self.stream_data(managed_data)
elif method == "write":
array_value = self.write_data(managed_data)
else:
raise ValueError(f"Unsupported read method {method}")

@@ -198,7 +204,7 @@ def load_data(
self, data: local_data.ManagedArrowTable, api_name: Optional[str] = None
) -> core.ArrayValue:
"""Load managed data into bigquery"""
ordering_col = "bf_load_job_offsets"
ordering_col = guid.generate_guid("load_offsets_")

# JSON support incomplete
for item in data.schema.items:
@@ -244,7 +250,7 @@

def stream_data(self, data: local_data.ManagedArrowTable) -> core.ArrayValue:
"""Load managed data into bigquery"""
ordering_col = "bf_stream_job_offsets"
ordering_col = guid.generate_guid("stream_offsets_")
schema_w_offsets = data.schema.append(
schemata.SchemaItem(ordering_col, bigframes.dtypes.INT_DTYPE)
)
@@ -277,6 +283,61 @@ def stream_data(self, data: local_data.ManagedArrowTable) -> core.ArrayValue:
n_rows=data.data.num_rows,
).drop_columns([ordering_col])

def write_data(self, data: local_data.ManagedArrowTable) -> core.ArrayValue:
"""Load managed data into bigquery"""
ordering_col = guid.generate_guid("stream_offsets_")
schema_w_offsets = data.schema.append(
schemata.SchemaItem(ordering_col, bigframes.dtypes.INT_DTYPE)
)
bq_schema = schema_w_offsets.to_bigquery(_STREAM_JOB_TYPE_OVERRIDES)
bq_table_ref = self._storage_manager.create_temp_table(
bq_schema, [ordering_col]
)

requested_stream = bq_storage_types.stream.WriteStream()
requested_stream.type_ = bq_storage_types.stream.WriteStream.Type.COMMITTED # type: ignore

stream_request = bq_storage_types.CreateWriteStreamRequest(
parent=bq_table_ref.to_bqstorage(), write_stream=requested_stream
)
stream = self._write_client.create_write_stream(request=stream_request)

def request_gen() -> Generator[bq_storage_types.AppendRowsRequest, None, None]:
schema, batches = data.to_arrow(
offsets_col=ordering_col, duration_type="int"
)
offset = 0
for batch in batches:
request = bq_storage_types.AppendRowsRequest(
write_stream=stream.name, offset=offset
)
request.arrow_rows.writer_schema.serialized_schema = (
schema.serialize().to_pybytes()
)
request.arrow_rows.rows.serialized_record_batch = (
batch.serialize().to_pybytes()
)
offset += batch.num_rows
yield request

for response in self._write_client.append_rows(requests=request_gen()):
if response.row_errors:
raise ValueError(
f"Problem loading at least one row from DataFrame: {response.row_errors}. {constants.FEEDBACK_LINK}"
)
# This step isn't strictly necessary in COMMITTED mode, but avoids max active stream limits
response = self._write_client.finalize_write_stream(name=stream.name)
assert response.row_count == data.data.num_rows

destination_table = self._bqclient.get_table(bq_table_ref)
[review comment] Collaborator: Do we need to do something here to finalize the stream? https://cloud.google.com/bigquery/docs/write-api-streaming

[review comment] Contributor (author): Not strictly necessary, but can avoid limits, per docs: "This step is optional in committed type, but helps to prevent exceeding the limit on active streams."

[review comment] Contributor (author): Added finalize in new iteration.

return core.ArrayValue.from_table(
table=destination_table,
schema=schema_w_offsets,
session=self._session,
offsets_col=ordering_col,
n_rows=data.data.num_rows,
).drop_columns([ordering_col])

def _start_generic_job(self, job: formatting_helpers.GenericJob):
if bigframes.options.display.progress_bar is not None:
formatting_helpers.wait_for_job(
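Stripped of the bigframes plumbing, the Storage Write API flow that write_data follows is roughly the sketch below (the table path is a placeholder and credentials come from the environment). The COMMITTED stream type makes rows visible as each append succeeds, and the trailing finalize_write_stream call is the optional step discussed in the review thread above.

```python
import pyarrow as pa
from google.cloud import bigquery_storage_v1
from google.cloud.bigquery_storage_v1 import types

client = bigquery_storage_v1.BigQueryWriteClient()
parent = "projects/my-project/datasets/my_dataset/tables/my_table"  # placeholder

write_stream = types.WriteStream()
write_stream.type_ = types.WriteStream.Type.COMMITTED
stream = client.create_write_stream(parent=parent, write_stream=write_stream)


def requests(schema: pa.Schema, batches):
    # One AppendRowsRequest per record batch, with a running row offset.
    offset = 0
    for batch in batches:
        request = types.AppendRowsRequest(write_stream=stream.name, offset=offset)
        request.arrow_rows.writer_schema.serialized_schema = schema.serialize().to_pybytes()
        request.arrow_rows.rows.serialized_record_batch = batch.serialize().to_pybytes()
        offset += batch.num_rows
        yield request


table = pa.table({"id": [1, 2, 3]})  # must match the destination table schema
for response in client.append_rows(requests=requests(table.schema, table.to_batches())):
    if response.row_errors:
        raise ValueError(response.row_errors)

# Optional for COMMITTED streams, but avoids hitting the active-stream limit.
client.finalize_write_stream(name=stream.name)
```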
2 changes: 2 additions & 0 deletions setup.py
@@ -42,6 +42,8 @@
"google-cloud-bigtable >=2.24.0",
"google-cloud-pubsub >=2.21.4",
"google-cloud-bigquery[bqstorage,pandas] >=3.31.0",
# 2.30 needed for arrow support.
"google-cloud-bigquery-storage >= 2.30.0, < 3.0.0",
"google-cloud-functions >=1.12.0",
"google-cloud-bigquery-connection >=1.12.0",
"google-cloud-iam >=2.12.1",
3 changes: 2 additions & 1 deletion tests/system/small/test_dataframe.py
@@ -83,6 +83,7 @@ def test_df_construct_pandas_default(scalars_dfs):
("bigquery_inline"),
("bigquery_load"),
("bigquery_streaming"),
("bigquery_write"),
],
)
def test_read_pandas_all_nice_types(
@@ -1772,7 +1773,7 @@ def test_len(scalars_dfs):
)
@pytest.mark.parametrize(
"write_engine",
["bigquery_load", "bigquery_streaming"],
["bigquery_load", "bigquery_streaming", "bigquery_write"],
)
def test_df_len_local(session, n_rows, write_engine):
assert (