Skip to content

Initial tensorflow stubs #8974

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 22 commits into from
Jan 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions stubs/tensorflow/@tests/stubtest_allowlist.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Some methods are dynamically patched onto to instances as they
# may depend on whether code is executed in graph/eager/v1/v2/etc.
# Tensorflow supports multiple modes of execution which changes some
# of the attributes/methods/even class hierachies.
tensorflow.Tensor.__int__
tensorflow.Tensor.numpy
tensorflow.Tensor.__index__
# Incomplete
tensorflow.sparse.SparseTensor.__getattr__
tensorflow.SparseTensor.__getattr__
tensorflow.TensorShape.__getattr__
tensorflow.dtypes.DType.__getattr__
tensorflow.RaggedTensor.__getattr__
tensorflow.DType.__getattr__
tensorflow.Graph.__getattr__
tensorflow.Operation.__getattr__
tensorflow.Variable.__getattr__
# Internal undocumented API
tensorflow.RaggedTensor.__init__
# Has an undocumented extra argument that tf.Variable which acts like subclass
# (by dynamically patching tf.Tensor methods) does not preserve.
tensorflow.Tensor.__getitem__
3 changes: 3 additions & 0 deletions stubs/tensorflow/METADATA.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
version = "2.10.*"
# requires a version of numpy with a `py.typed` file
requires = ["numpy>=1.20"]
195 changes: 195 additions & 0 deletions stubs/tensorflow/tensorflow/__init__.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
from _typeshed import Incomplete, Self, Unused
from abc import ABCMeta
from builtins import bool as _bool
from collections.abc import Callable, Iterable, Iterator, Sequence
from contextlib import contextmanager
from enum import Enum
from typing import Any, NoReturn, overload
from typing_extensions import TypeAlias

import numpy
from tensorflow.dtypes import *

# Most tf.math functions are exported as tf, but sadly not all are.
from tensorflow.math import abs as abs
from tensorflow.sparse import SparseTensor

# Tensors ideally should be a generic type, but properly typing data type/shape
# will be a lot of work. Until we have good non-generic tensorflow stubs,
# we will skip making Tensor generic. Also good type hints for shapes will
# run quickly into many places where type system is not strong enough today.
# So shape typing is probably not worth doing anytime soon.
_Slice: TypeAlias = int | slice | None

_FloatDataSequence: TypeAlias = Sequence[float] | Sequence[_FloatDataSequence]
_StrDataSequence: TypeAlias = Sequence[str] | Sequence[_StrDataSequence]
_ScalarTensorCompatible: TypeAlias = Tensor | str | float | numpy.ndarray[Any, Any] | numpy.number[Any]
_TensorCompatible: TypeAlias = _ScalarTensorCompatible | Sequence[_TensorCompatible]
_ShapeLike: TypeAlias = TensorShape | Iterable[_ScalarTensorCompatible | None] | int | Tensor
_DTypeLike: TypeAlias = DType | str | numpy.dtype[Any]

class Tensor:
def __init__(self, op: Operation, value_index: int, dtype: DType) -> None: ...
def consumers(self) -> list[Incomplete]: ...
@property
def shape(self) -> TensorShape: ...
def get_shape(self) -> TensorShape: ...
@property
def dtype(self) -> DType: ...
@property
def graph(self) -> Graph: ...
@property
def name(self) -> str: ...
@property
def op(self) -> Operation: ...
def numpy(self) -> numpy.ndarray[Any, Any]: ...
def __int__(self) -> int: ...
def __abs__(self, name: str | None = None) -> Tensor: ...
def __add__(self, other: _TensorCompatible) -> Tensor: ...
def __radd__(self, other: _TensorCompatible) -> Tensor: ...
def __sub__(self, other: _TensorCompatible) -> Tensor: ...
def __rsub__(self, other: _TensorCompatible) -> Tensor: ...
def __mul__(self, other: _TensorCompatible) -> Tensor: ...
def __rmul__(self, other: _TensorCompatible) -> Tensor: ...
def __pow__(self, other: _TensorCompatible) -> Tensor: ...
def __matmul__(self, other: _TensorCompatible) -> Tensor: ...
def __rmatmul__(self, other: _TensorCompatible) -> Tensor: ...
def __floordiv__(self, other: _TensorCompatible) -> Tensor: ...
def __rfloordiv__(self, other: _TensorCompatible) -> Tensor: ...
def __truediv__(self, other: _TensorCompatible) -> Tensor: ...
def __rtruediv__(self, other: _TensorCompatible) -> Tensor: ...
def __neg__(self, name: str | None = None) -> Tensor: ...
def __and__(self, other: _TensorCompatible) -> Tensor: ...
def __rand__(self, other: _TensorCompatible) -> Tensor: ...
def __or__(self, other: _TensorCompatible) -> Tensor: ...
def __ror__(self, other: _TensorCompatible) -> Tensor: ...
def __eq__(self, other: _TensorCompatible) -> Tensor: ... # type: ignore[override]
def __ne__(self, other: _TensorCompatible) -> Tensor: ... # type: ignore[override]
def __ge__(self, other: _TensorCompatible, name: str | None = None) -> Tensor: ...
def __gt__(self, other: _TensorCompatible, name: str | None = None) -> Tensor: ...
def __le__(self, other: _TensorCompatible, name: str | None = None) -> Tensor: ...
def __lt__(self, other: _TensorCompatible, name: str | None = None) -> Tensor: ...
def __bool__(self) -> NoReturn: ...
def __getitem__(self, slice_spec: _Slice | tuple[_Slice, ...]) -> Tensor: ...
def __len__(self) -> int: ...
# This only works for rank 0 tensors.
def __index__(self) -> int: ...
def __getattr__(self, name: str) -> Incomplete: ...

class VariableSynchronization(Enum):
AUTO = 0
NONE = 1
ON_WRITE = 2
ON_READ = 3

class VariableAggregation(Enum):
AUTO = 0
NONE = 1
ON_WRITE = 2
ON_READ = 3

class _VariableMetaclass(type): ...

# Variable class in intent/documentation is a Tensor. In implementation there's
# TODO comment to make it Tensor. It is not actually Tensor type wise, but even
# dynamically patches on most methods of tf.Tensor
# https://github.com/tensorflow/tensorflow/blob/9524a636cae9ae3f0554203c1ba7ee29c85fcf12/tensorflow/python/ops/variables.py#L1086.
class Variable(Tensor, metaclass=_VariableMetaclass):
def __init__(
self,
initial_value: Tensor | Callable[[], Tensor] | None = None,
trainable: _bool | None = None,
validate_shape: _bool = True,
# Valid non-None values are deprecated.
caching_device: None = None,
name: str | None = None,
# Real type is VariableDef protobuf type. Can be added after adding script
# to generate tensorflow protobuf stubs with mypy-protobuf.
variable_def: Incomplete | None = None,
dtype: _DTypeLike | None = None,
import_scope: str | None = None,
constraint: Callable[[Tensor], Tensor] | None = None,
synchronization: VariableSynchronization = VariableSynchronization.AUTO,
aggregation: VariableAggregation = VariableAggregation.NONE,
shape: _ShapeLike | None = None,
) -> None: ...
def __getattr__(self, name: str) -> Incomplete: ...

class RaggedTensor(metaclass=ABCMeta):
def bounding_shape(
self, axis: _TensorCompatible | None = None, name: str | None = None, out_type: _DTypeLike | None = None
) -> Tensor: ...
@classmethod
def from_sparse(
cls, st_input: SparseTensor, name: str | None = None, row_splits_dtype: _DTypeLike = int64
) -> RaggedTensor: ...
def to_sparse(self, name: str | None = None) -> SparseTensor: ...
def to_tensor(
self, default_value: float | str | None = None, name: str | None = None, shape: _ShapeLike | None = None
) -> Tensor: ...
def __add__(self, other: RaggedTensor | float, name: str | None = None) -> RaggedTensor: ...
def __radd__(self, other: RaggedTensor | float, name: str | None = None) -> RaggedTensor: ...
def __sub__(self, other: RaggedTensor | float, name: str | None = None) -> RaggedTensor: ...
def __mul__(self, other: RaggedTensor | float, name: str | None = None) -> RaggedTensor: ...
def __rmul__(self, other: RaggedTensor | float, name: str | None = None) -> RaggedTensor: ...
def __floordiv__(self, other: RaggedTensor | float, name: str | None = None) -> RaggedTensor: ...
def __truediv__(self, other: RaggedTensor | float, name: str | None = None) -> RaggedTensor: ...
def __getitem__(self, slice_spec: _Slice | tuple[_Slice, ...]) -> RaggedTensor: ...
def __getattr__(self, name: str) -> Incomplete: ...

class Operation:
def __init__(
self,
node_def: Incomplete,
g: Graph,
# isinstance is used so can not be Sequence/Iterable.
inputs: list[Tensor] | None = None,
output_types: Unused = None,
control_inputs: Iterable[Tensor | Operation] | None = None,
input_types: Iterable[DType] | None = None,
original_op: Operation | None = None,
op_def: Incomplete = None,
) -> None: ...
@property
def inputs(self) -> list[Tensor]: ...
@property
def outputs(self) -> list[Tensor]: ...
@property
def device(self) -> str: ...
@property
def name(self) -> str: ...
@property
def type(self) -> str: ...
def __getattr__(self, name: str) -> Incomplete: ...

class TensorShape(metaclass=ABCMeta):
def __init__(self, dims: _ShapeLike) -> None: ...
@property
def rank(self) -> int: ...
def as_list(self) -> list[int | None]: ...
def assert_has_rank(self, rank: int) -> None: ...
def assert_is_compatible_with(self, other: Iterable[int | None]) -> None: ...
def __bool__(self) -> _bool: ...
@overload
def __getitem__(self, key: int) -> int | None: ...
@overload
def __getitem__(self, key: slice) -> TensorShape: ...
def __iter__(self) -> Iterator[int | None]: ...
def __len__(self) -> int: ...
def __add__(self, other: Iterable[int | None]) -> TensorShape: ...
def __radd__(self, other: Iterable[int | None]) -> TensorShape: ...
def __getattr__(self, name: str) -> Incomplete: ...

class Graph:
def add_to_collection(self, name: str, value: object) -> None: ...
def add_to_collections(self, names: Iterable[str] | str, value: object) -> None: ...
@contextmanager
def as_default(self: Self) -> Iterator[Self]: ...
def finalize(self) -> None: ...
def get_tensor_by_name(self, name: str) -> Tensor: ...
def get_operation_by_name(self, name: str) -> Operation: ...
def get_operations(self) -> list[Operation]: ...
def get_name_scope(self) -> str: ...
def __getattr__(self, name: str) -> Incomplete: ...

def __getattr__(name: str) -> Incomplete: ...
Empty file.
55 changes: 55 additions & 0 deletions stubs/tensorflow/tensorflow/dtypes.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from _typeshed import Incomplete
from abc import ABCMeta
from builtins import bool as _bool
from typing import Any

import numpy as np
from tensorflow import _DTypeLike

class _DTypeMeta(ABCMeta): ...

class DType(metaclass=_DTypeMeta):
@property
def name(self) -> str: ...
@property
def as_numpy_dtype(self) -> type[np.number[Any]]: ...
@property
def is_numpy_compatible(self) -> _bool: ...
@property
def is_bool(self) -> _bool: ...
@property
def is_floating(self) -> _bool: ...
@property
def is_integer(self) -> _bool: ...
@property
def is_quantized(self) -> _bool: ...
@property
def is_unsigned(self) -> _bool: ...
def __getattr__(self, name: str) -> Incomplete: ...

bool: DType
complex128: DType
complex64: DType
bfloat16: DType
float16: DType
half: DType
float32: DType
float64: DType
double: DType
int8: DType
int16: DType
int32: DType
int64: DType
uint8: DType
uint16: DType
uint32: DType
uint64: DType
qint8: DType
qint16: DType
qint32: DType
quint8: DType
quint16: DType
string: DType

def as_dtype(type_value: _DTypeLike) -> DType: ...
def __getattr__(name: str) -> Incomplete: ...
13 changes: 13 additions & 0 deletions stubs/tensorflow/tensorflow/math.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from _typeshed import Incomplete
from typing import overload

from tensorflow import RaggedTensor, Tensor, _TensorCompatible
from tensorflow.sparse import SparseTensor

@overload
def abs(x: _TensorCompatible, name: str | None = None) -> Tensor: ...
@overload
def abs(x: SparseTensor, name: str | None = None) -> SparseTensor: ...
@overload
def abs(x: RaggedTensor, name: str | None = None) -> RaggedTensor: ...
def __getattr__(name: str) -> Incomplete: ...
30 changes: 30 additions & 0 deletions stubs/tensorflow/tensorflow/sparse.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from _typeshed import Incomplete
from abc import ABCMeta
from typing_extensions import TypeAlias

from tensorflow import Tensor, TensorShape, _TensorCompatible
from tensorflow.dtypes import DType

_SparseTensorCompatible: TypeAlias = _TensorCompatible | SparseTensor

class SparseTensor(metaclass=ABCMeta):
@property
def indices(self) -> Tensor: ...
@property
def values(self) -> Tensor: ...
@property
def dense_shape(self) -> Tensor: ...
@property
def shape(self) -> TensorShape: ...
@property
def dtype(self) -> DType: ...
name: str
def __init__(self, indices: _TensorCompatible, values: _TensorCompatible, dense_shape: _TensorCompatible) -> None: ...
def get_shape(self) -> TensorShape: ...
# Many arithmetic operations are not directly supported. Some have alternatives like tf.sparse.add instead of +.
def __div__(self, y: _SparseTensorCompatible) -> SparseTensor: ...
def __truediv__(self, y: _SparseTensorCompatible) -> SparseTensor: ...
def __mul__(self, y: _SparseTensorCompatible) -> SparseTensor: ...
def __getattr__(self, name: str) -> Incomplete: ...

def __getattr__(name: str) -> Incomplete: ...