
feat: support primary key(s) in read_gbq by using them as the index_col by default #625


Merged · 5 commits · Apr 22, 2024
25 changes: 14 additions & 11 deletions bigframes/session/__init__.py
@@ -708,13 +708,15 @@ def _get_snapshot_sql_and_primary_key(
f"Current session is in {self._location} but dataset '{table.project}.{table.dataset_id}' is located in {table.location}"
)

# TODO(b/305264153): Use public properties to fetch primary keys once
# added to google-cloud-bigquery.
primary_keys = (
table._properties.get("tableConstraints", {})
.get("primaryKey", {})
.get("columns")
)
primary_keys = None
if (
(table_constraints := getattr(table, "table_constraints", None)) is not None
and (primary_key := table_constraints.primary_key) is not None
# This will be False for either None or empty list.
# We want primary_keys = None if no primary keys are set.
and (columns := primary_key.columns)
):
primary_keys = columns

job_config = bigquery.QueryJobConfig()
job_config.labels["bigframes-api"] = api_name
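
The replacement relies on the public `table_constraints` accessor that google-cloud-bigquery exposes in newer releases, instead of digging through the private `_properties` dict. A minimal sketch of that API (not part of this diff; the table ID is hypothetical, assuming it declares `PRIMARY KEY (pk_1, pk_2)`):

```python
from google.cloud import bigquery

client = bigquery.Client()
table = client.get_table("my-project.my_dataset.my_table")  # hypothetical table ID

# Mirrors the guarded chain above: each level may be None, and an
# empty column list should also fall back to None.
primary_keys = None
if (
    table.table_constraints is not None
    and table.table_constraints.primary_key is not None
    and table.table_constraints.primary_key.columns
):
    primary_keys = list(table.table_constraints.primary_key.columns)

print(primary_keys)  # e.g. ["pk_1", "pk_2"], or None when no key is declared
```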
@@ -777,12 +779,13 @@ def _read_gbq_table(
             query, default_project=self.bqclient.project
         )

-        (
-            table_expression,
-            total_ordering_cols,
-        ) = self._get_snapshot_sql_and_primary_key(
+        (table_expression, primary_keys,) = self._get_snapshot_sql_and_primary_key(
             table_ref, api_name=api_name, use_cache=use_cache
         )
+        total_ordering_cols = primary_keys
+
+        if not index_col and primary_keys is not None:
+            index_col = primary_keys

         for key in columns:
             if key not in table_expression.columns:
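
The net effect: when the caller does not pass `index_col`, the primary key columns become the DataFrame index and also serve as total ordering columns, so no row-numbering window function is needed. A hedged sketch of the resulting behavior, assuming a hypothetical table with `PRIMARY KEY (pk_1, pk_2)`:

```python
import bigframes.pandas as bpd

# No index_col given: the primary key columns become the index by default.
df = bpd.read_gbq("my-project.my_dataset.my_table")
print(df.index.names)  # ['pk_1', 'pk_2']
```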
2 changes: 1 addition & 1 deletion setup.py
@@ -38,7 +38,7 @@
"gcsfs >=2023.3.0",
"geopandas >=0.12.2",
"google-auth >=2.15.0,<3.0dev",
"google-cloud-bigquery[bqstorage,pandas] >=3.10.0",
"google-cloud-bigquery[bqstorage,pandas] >=3.16.0",
"google-cloud-functions >=1.12.0",
"google-cloud-bigquery-connection >=1.12.0",
"google-cloud-iam >=2.12.1",
2 changes: 1 addition & 1 deletion testing/constraints-3.9.txt
@@ -4,7 +4,7 @@ fsspec==2023.3.0
 gcsfs==2023.3.0
 geopandas==0.12.2
 google-auth==2.15.0
-google-cloud-bigquery==3.10.0
+google-cloud-bigquery==3.16.0
 google-cloud-functions==1.12.0
 google-cloud-bigquery-connection==1.12.0
 google-cloud-iam==2.12.1
13 changes: 6 additions & 7 deletions tests/system/small/test_session.py
@@ -236,14 +236,13 @@ def test_read_gbq_w_anonymous_query_results_table(session: bigframes.Session):
 def test_read_gbq_w_primary_keys_table(
     session: bigframes.Session, usa_names_grouped_table: bigquery.Table
 ):
     # Validate that the table we're querying has a primary key.
     table = usa_names_grouped_table
-    # TODO(b/305264153): Use public properties to fetch primary keys once
-    # added to google-cloud-bigquery.
-    primary_keys = (
-        table._properties.get("tableConstraints", {})
-        .get("primaryKey", {})
-        .get("columns")
-    )
+    table_constraints = table.table_constraints
+    assert table_constraints is not None
+    primary_key = table_constraints.primary_key
+    assert primary_key is not None
+    primary_keys = primary_key.columns
+    assert len(primary_keys) != 0

     df = session.read_gbq(f"{table.project}.{table.dataset_id}.{table.table_id}")
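
For context, BigQuery primary keys are declarative metadata (`NOT ENFORCED`), so a fixture table like `usa_names_grouped_table` could be created with DDL along these lines. The exact schema is an assumption for illustration, not taken from this PR:

```python
from google.cloud import bigquery

client = bigquery.Client()
client.query(
    """
    CREATE TABLE IF NOT EXISTS `my-project.my_dataset.usa_names_grouped` (
      name STRING,
      total_people INT64,
      PRIMARY KEY (name) NOT ENFORCED  -- metadata only; BigQuery does not enforce it
    )
    """
).result()
```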
7 changes: 4 additions & 3 deletions tests/unit/resources.py
@@ -13,7 +13,7 @@
 # limitations under the License.

 import datetime
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Sequence
 import unittest.mock as mock

 import google.auth.credentials
@@ -37,6 +37,7 @@
 def create_bigquery_session(
     bqclient: Optional[mock.Mock] = None,
     session_id: str = "abcxyz",
+    table_schema: Sequence[google.cloud.bigquery.SchemaField] = TEST_SCHEMA,
     anonymous_dataset: Optional[google.cloud.bigquery.DatasetReference] = None,
 ) -> bigframes.Session:
     credentials = mock.create_autospec(
@@ -51,7 +52,7 @@ def create_bigquery_session(
     table = mock.create_autospec(google.cloud.bigquery.Table, instance=True)
     table._properties = {}
     type(table).location = mock.PropertyMock(return_value="test-region")
-    type(table).schema = mock.PropertyMock(return_value=TEST_SCHEMA)
+    type(table).schema = mock.PropertyMock(return_value=table_schema)
     bqclient.get_table.return_value = table

     if anonymous_dataset is None:
@@ -72,7 +73,7 @@ def query_mock(query, *args, **kwargs):
         if query.startswith("SELECT CURRENT_TIMESTAMP()"):
             query_job.result = mock.MagicMock(return_value=[[datetime.datetime.now()]])
         else:
-            type(query_job).schema = mock.PropertyMock(return_value=TEST_SCHEMA)
+            type(query_job).schema = mock.PropertyMock(return_value=table_schema)

         return query_job
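
With the new `table_schema` parameter, unit tests can hand the mocked session whatever schema their fake table advertises instead of the hard-coded `TEST_SCHEMA`. A hedged usage sketch, mirroring how the new unit test below calls it (the import path assumes this repo's test layout):

```python
import google.cloud.bigquery as bq

from tests.unit import resources

# The mocked table and the mocked query jobs both report this schema.
session = resources.create_bigquery_session(
    table_schema=[
        bq.SchemaField("pk_1", "INT64"),
        bq.SchemaField("col_1", "INT64"),
    ]
)
```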
39 changes: 39 additions & 0 deletions tests/unit/session/test_session.py
@@ -19,9 +19,11 @@

 import google.api_core.exceptions
 import google.cloud.bigquery
+import google.cloud.bigquery.table
 import pytest

 import bigframes
+import bigframes.exceptions

 from .. import resources

@@ -50,6 +52,43 @@ def test_read_gbq_cached_table():
assert "1999-01-02T03:04:05.678901" in df.sql


def test_read_gbq_clustered_table_ok_default_index_with_primary_key():
"""If a primary key is set on the table, we use that as the index column
by default, no error should be raised in this case.

See internal issue 335727141.
"""
table = google.cloud.bigquery.Table("my-project.my_dataset.my_table")
table.clustering_fields = ["col1", "col2"]
table.schema = (
google.cloud.bigquery.SchemaField("pk_1", "INT64"),
google.cloud.bigquery.SchemaField("pk_2", "INT64"),
google.cloud.bigquery.SchemaField("col_1", "INT64"),
google.cloud.bigquery.SchemaField("col_2", "INT64"),
)

# TODO(b/305264153): use setter for table_constraints in client library
# when available.
table._properties["tableConstraints"] = {
"primaryKey": {
"columns": ["pk_1", "pk_2"],
},
}
bqclient = mock.create_autospec(google.cloud.bigquery.Client, instance=True)
bqclient.project = "test-project"
bqclient.get_table.return_value = table
session = resources.create_bigquery_session(
bqclient=bqclient, table_schema=table.schema
)
table._properties["location"] = session._location

df = session.read_gbq("my-project.my_dataset.my_table")

# There should be no analytic operators to prevent row filtering pushdown.
assert "OVER" not in df.sql
assert tuple(df.index.names) == ("pk_1", "pk_2")


@pytest.mark.parametrize(
"not_found_table_id",
[("unknown.dataset.table"), ("project.unknown.table"), ("project.dataset.unknown")],
3 changes: 3 additions & 0 deletions third_party/bigframes_vendored/pandas/io/gbq.py
@@ -105,6 +105,9 @@ def read_gbq(
             In that case, all matched tables will be read as one DataFrame.
         index_col (Iterable[str] or str):
             Name of result column(s) to use for index in results DataFrame.
+
+            **New in bigframes version 1.3.0**: If ``index_col`` is not
+            set, the primary key(s) of the table are used as the index.
         columns (Iterable[str]):
             List of BigQuery column names in the desired order for results
             DataFrame.
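
A hedged usage sketch of the documented behavior (table and column names are hypothetical): an explicit `index_col` still takes precedence over the primary-key default.

```python
import bigframes.pandas as bpd

# Default: the table's primary key column(s) become the index.
df = bpd.read_gbq("my-project.my_dataset.my_table")

# Explicit index_col overrides the primary-key default.
df2 = bpd.read_gbq("my-project.my_dataset.my_table", index_col=["col_1"])
```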