perf: Use flyweight for node fields #1654

Merged · 3 commits · Apr 24, 2025

5 changes: 2 additions & 3 deletions bigframes/core/bigframe_node.py
@@ -20,10 +20,9 @@
 import functools
 import itertools
 import typing
-from typing import Callable, Dict, Generator, Iterable, Mapping, Set, Tuple
+from typing import Callable, Dict, Generator, Iterable, Mapping, Sequence, Set, Tuple

 from bigframes.core import identifiers
-import bigframes.core.guid
 import bigframes.core.schema as schemata
 import bigframes.dtypes

@@ -163,7 +162,7 @@ def roots(self) -> typing.Set[BigFrameNode]:
     # TODO: Store some local data lazily for select, aggregate nodes.
     @property
     @abc.abstractmethod
-    def fields(self) -> Iterable[Field]:
+    def fields(self) -> Sequence[Field]:
         ...

     @property

80 changes: 47 additions & 33 deletions bigframes/core/nodes.py
@@ -33,7 +33,7 @@

 import google.cloud.bigquery as bq

-from bigframes.core import identifiers, local_data
+from bigframes.core import identifiers, local_data, sequences
 from bigframes.core.bigframe_node import BigFrameNode, COLUMN_SET, Field
 import bigframes.core.expression as ex
 from bigframes.core.ordering import OrderingExpression, RowOrdering
@@ -87,7 +87,7 @@ def child_nodes(self) -> typing.Sequence[BigFrameNode]:
         return (self.child,)

     @property
-    def fields(self) -> Iterable[Field]:
+    def fields(self) -> Sequence[Field]:
         return self.child.fields

     @property
@@ -226,8 +226,8 @@ def added_fields(self) -> Tuple[Field, ...]:
         return (Field(self.indicator_col, bigframes.dtypes.BOOL_DTYPE, nullable=False),)

     @property
-    def fields(self) -> Iterable[Field]:
-        return itertools.chain(
+    def fields(self) -> Sequence[Field]:
+        return sequences.ChainedSequence(
             self.left_child.fields,
             self.added_fields,
         )
@@ -321,15 +321,15 @@ def order_ambiguous(self) -> bool:
     def explicitly_ordered(self) -> bool:
         return self.propogate_order

-    @property
-    def fields(self) -> Iterable[Field]:
-        left_fields = self.left_child.fields
+    @functools.cached_property
+    def fields(self) -> Sequence[Field]:
+        left_fields: Iterable[Field] = self.left_child.fields
         if self.type in ("right", "outer"):
             left_fields = map(lambda x: x.with_nullable(), left_fields)
-        right_fields = self.right_child.fields
+        right_fields: Iterable[Field] = self.right_child.fields
         if self.type in ("left", "outer"):
             right_fields = map(lambda x: x.with_nullable(), right_fields)
-        return itertools.chain(left_fields, right_fields)
+        return (*left_fields, *right_fields)

     @property
     def joins_nulls(self) -> bool:
@@ -430,10 +430,10 @@ def explicitly_ordered(self) -> bool:
         return True

     @property
-    def fields(self) -> Iterable[Field]:
+    def fields(self) -> Sequence[Field]:
         # TODO: Output names should probably be aligned beforehand or be part of concat definition
         # TODO: Handle nullability
-        return (
+        return tuple(
             Field(id, field.dtype)
             for id, field in zip(self.output_ids, self.children[0].fields)
         )
@@ -505,7 +505,7 @@ def explicitly_ordered(self) -> bool:
         return True

     @functools.cached_property
-    def fields(self) -> Iterable[Field]:
+    def fields(self) -> Sequence[Field]:
         return (
             Field(self.output_id, next(iter(self.start.fields)).dtype, nullable=False),
         )
@@ -626,12 +626,20 @@ class ReadLocalNode(LeafNode):
     session: typing.Optional[bigframes.session.Session] = None

     @property
-    def fields(self) -> Iterable[Field]:
-        fields = (Field(col_id, dtype) for col_id, dtype, _ in self.scan_list.items)
+    def fields(self) -> Sequence[Field]:
+        fields = tuple(
+            Field(col_id, dtype) for col_id, dtype, _ in self.scan_list.items
+        )
         if self.offsets_col is not None:
-            return itertools.chain(
-                fields,
-                (Field(self.offsets_col, bigframes.dtypes.INT_DTYPE, nullable=False),),
+            return tuple(
+                itertools.chain(
+                    fields,
+                    (
+                        Field(
+                            self.offsets_col, bigframes.dtypes.INT_DTYPE, nullable=False
+                        ),
+                    ),
+                )
             )
         return fields

@@ -767,8 +775,8 @@ def session(self):
         return self.table_session

     @property
-    def fields(self) -> Iterable[Field]:
-        return (
+    def fields(self) -> Sequence[Field]:
+        return tuple(
             Field(col_id, dtype, self.source.table.schema_by_id[source_id].is_nullable)
             for col_id, dtype, source_id in self.scan_list.items
         )
@@ -881,8 +889,8 @@ def non_local(self) -> bool:
         return True

     @property
-    def fields(self) -> Iterable[Field]:
-        return itertools.chain(self.child.fields, self.added_fields)
+    def fields(self) -> Sequence[Field]:
+        return sequences.ChainedSequence(self.child.fields, self.added_fields)

     @property
     def relation_ops_created(self) -> int:
@@ -1097,7 +1105,7 @@ def _validate(self):
             raise ValueError(f"Reference to column not in child: {ref.id}")

     @functools.cached_property
-    def fields(self) -> Iterable[Field]:
+    def fields(self) -> Sequence[Field]:
         input_fields_by_id = {field.id: field for field in self.child.fields}
         return tuple(
             Field(
Expand Down Expand Up @@ -1192,8 +1200,8 @@ def added_fields(self) -> Tuple[Field, ...]:
return tuple(fields)

@property
def fields(self) -> Iterable[Field]:
return itertools.chain(self.child.fields, self.added_fields)
def fields(self) -> Sequence[Field]:
return sequences.ChainedSequence(self.child.fields, self.added_fields)

@property
def variables_introduced(self) -> int:
@@ -1263,7 +1271,7 @@ def non_local(self) -> bool:
         return True

     @property
-    def fields(self) -> Iterable[Field]:
+    def fields(self) -> Sequence[Field]:
         return (Field(self.col_id, bigframes.dtypes.INT_DTYPE, nullable=False),)

     @property
@@ -1313,7 +1321,7 @@ def non_local(self) -> bool:
         return True

     @functools.cached_property
-    def fields(self) -> Iterable[Field]:
+    def fields(self) -> Sequence[Field]:
         # TODO: Use child nullability to infer grouping key nullability
         by_fields = (self.child.field_by_id[ref.id] for ref in self.by_column_ids)
         if self.dropna:
@@ -1411,8 +1419,8 @@ def non_local(self) -> bool:
         return True

     @property
-    def fields(self) -> Iterable[Field]:
-        return itertools.chain(self.child.fields, [self.added_field])
+    def fields(self) -> Sequence[Field]:
+        return sequences.ChainedSequence(self.child.fields, (self.added_field,))

     @property
     def variables_introduced(self) -> int:
@@ -1547,7 +1555,7 @@ def row_preserving(self) -> bool:
         return False

     @property
-    def fields(self) -> Iterable[Field]:
+    def fields(self) -> Sequence[Field]:
         fields = (
             Field(
                 field.id,
@@ -1561,11 +1569,17 @@ def fields(self) -> Iterable[Field]:
             for field in self.child.fields
         )
         if self.offsets_col is not None:
-            return itertools.chain(
-                fields,
-                (Field(self.offsets_col, bigframes.dtypes.INT_DTYPE, nullable=False),),
+            return tuple(
+                itertools.chain(
+                    fields,
+                    (
+                        Field(
+                            self.offsets_col, bigframes.dtypes.INT_DTYPE, nullable=False
+                        ),
+                    ),
+                )
             )
-        return fields
+        return tuple(fields)

     @property
     def relation_ops_created(self) -> int:
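
The pattern repeated through nodes.py — swapping itertools.chain(...) for sequences.ChainedSequence(...) — is the flyweight named in the PR title: a derived node now shares a reference to its child's existing field sequence and appends only a small tuple of added fields, instead of materializing (or re-iterating) a fresh copy per node. A minimal sketch of the effect, using only the ChainedSequence API introduced below; the plain strings stand in for real Field objects:

from bigframes.core.sequences import ChainedSequence

# Stand-ins for a wide child node's Field sequence and one added field.
child_fields = tuple(f"col_{i}" for i in range(1000))
added_fields = ("row_offset",)

# Before: tuple(itertools.chain(...)) copied all 1000 parent entries into
# every derived node's field list. ChainedSequence instead keeps references
# to the existing parts, so building a derived field list costs O(parts).
derived = ChainedSequence(child_fields, added_fields)

assert len(derived) == 1001
assert derived[-1] == "row_offset"        # random and negative access work
assert derived._parts[0] is child_fields  # storage is shared, not copied (internal detail)
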
7 changes: 5 additions & 2 deletions bigframes/core/schema.py
@@ -17,11 +17,11 @@
 from dataclasses import dataclass
 import functools
 import typing
+from typing import Sequence

 import google.cloud.bigquery
 import pyarrow

-import bigframes.core.guid
 import bigframes.dtypes

 ColumnIdentifierType = str
@@ -35,7 +35,10 @@ class SchemaItem:

 @dataclass(frozen=True)
 class ArraySchema:
-    items: typing.Tuple[SchemaItem, ...]
+    items: Sequence[SchemaItem]
+
+    def __iter__(self):
+        yield from self.items

     @classmethod
     def from_bq_table(
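
Relaxing items from typing.Tuple[SchemaItem, ...] to Sequence[SchemaItem] is what lets a schema hold a ChainedSequence (or any other sequence) without forcing a copy, and the new __iter__ keeps iteration over the schema itself working. A hedged sketch — it assumes SchemaItem exposes column and dtype fields, which this diff does not show:

from bigframes.core.schema import ArraySchema, SchemaItem
import bigframes.dtypes

# items may now be any Sequence (e.g. a ChainedSequence), not only a tuple.
schema = ArraySchema(
    items=(
        SchemaItem("a", bigframes.dtypes.INT_DTYPE),  # assumed field order
        SchemaItem("b", bigframes.dtypes.STRING_DTYPE),
    )
)

# __iter__ yields from items, so the dataclass itself is iterable.
assert [item.column for item in schema] == ["a", "b"]
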
105 changes: 105 additions & 0 deletions bigframes/core/sequences.py
@@ -0,0 +1,105 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import collections.abc
import functools
import itertools
from typing import Iterable, Iterator, Sequence, TypeVar

ColumnIdentifierType = str


T = TypeVar("T")

# Further optimizations possible:
# * Support mapping operators
# * Support insertions and deletions


class ChainedSequence(collections.abc.Sequence[T]):
    """
    Memory-optimized sequence from composing chain of existing sequences.

    Will use the provided parts as underlying storage - so do not mutate provided parts.
    May merge small underlying parts for better access performance.
    """

    def __init__(self, *parts: Sequence[T]):
        # Could build an index that makes random access faster?
        self._parts: tuple[Sequence[T], ...] = tuple(
            _defrag_parts(_flatten_parts(parts))
        )

    def __getitem__(self, index):
        if isinstance(index, slice):
            return tuple(self)[index]
        if index < 0:
            index = len(self) + index
        if index < 0:
            raise IndexError("Index out of bounds")

        offset = 0
        for part in self._parts:
            if (index - offset) < len(part):
                return part[index - offset]
            offset += len(part)
        raise IndexError("Index out of bounds")

    @functools.cache
    def __len__(self):
        return sum(map(len, self._parts))

    def __iter__(self):
        for part in self._parts:
            yield from part


def _flatten_parts(parts: Iterable[Sequence[T]]) -> Iterator[Sequence[T]]:
    for part in parts:
        if isinstance(part, ChainedSequence):
            yield from part._parts
        else:
            yield part


# Should be a cache-friendly chunk size?
_TARGET_SIZE = 128
_MAX_MERGABLE = 32


def _defrag_parts(parts: Iterable[Sequence[T]]) -> Iterator[Sequence[T]]:
    """
    Merge small chunks into larger chunks for better performance.
    """
    parts_queue: list[Sequence[T]] = []
    queued_items = 0
    for part in parts:
        # too big to merge: flush the queue, then yield the part as-is
        if len(part) > _MAX_MERGABLE:
            yield from parts_queue
            parts_queue = []
            queued_items = 0
            yield part
        else:  # small enough to merge, so add it to the queue
            parts_queue.append(part)
            queued_items += len(part)
            # if queue has reached target size, merge, dump and reset queue
            if queued_items >= _TARGET_SIZE:
                yield tuple(itertools.chain(*parts_queue))
                parts_queue = []
                queued_items = 0

    yield from parts_queue
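
Two details of the implementation above are worth tracing. _flatten_parts keeps chains-of-chains shallow by reusing the inner chain's parts, and _defrag_parts coalesces runs of small parts (each at most _MAX_MERGABLE items) into merged tuples of roughly _TARGET_SIZE items, so neither indirection depth nor part count grows unboundedly. A quick behavioral sketch, assuming it runs in the same module as the definitions above (it inspects the internal _parts purely for illustration):

# Many tiny parts get coalesced into a few merged tuples.
many_small = ChainedSequence(*[(i,) for i in range(300)])
assert list(many_small) == list(range(300))
assert len(many_small._parts) < 300  # two 128-item tuples plus the leftover small parts

# A large part is stored by reference, never copied.
big = tuple(range(1000))
chained = ChainedSequence(big, (1000, 1001))
assert chained._parts[0] is big

# Chaining an existing chain flattens it rather than nesting it.
rechained = ChainedSequence(chained, (1002,))
assert rechained._parts[0] is big
assert rechained[-1] == 1002
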