diff --git a/README.rst b/README.rst
index f6243dc6c..ff32c8881 100644
--- a/README.rst
+++ b/README.rst
@@ -129,6 +129,12 @@ For plotting
- `matplotlib `__:
required for plotting.
+Miscellaneous
+~~~~~~~~~~~~~
+
+- `pydantic <https://pydantic-docs.helpmanual.io>`__:
+ required to use `CheckedSession`.
+
.. _start-documentation:
Documentation
diff --git a/doc/source/api.rst b/doc/source/api.rst
index 029535af8..47289b1bf 100644
--- a/doc/source/api.rst
+++ b/doc/source/api.rst
@@ -795,7 +795,6 @@ Modifying
Session.add
Session.update
- Session.get
Session.apply
Session.transpose
@@ -821,6 +820,30 @@ Load/Save
Session.to_hdf
Session.to_pickle
+CheckedArray
+============
+
+.. autosummary::
+ :toctree: _generated/
+
+ CheckedArray
+
+CheckedSession
+==============
+
+.. autosummary::
+ :toctree: _generated/
+
+ CheckedSession
+
+CheckedParameters
+=================
+
+.. autosummary::
+ :toctree: _generated/
+
+ CheckedParameters
+
.. _api-editor:
Editor
diff --git a/doc/source/changes/version_0_33.rst.inc b/doc/source/changes/version_0_33.rst.inc
index c288ba1f1..03c38d856 100644
--- a/doc/source/changes/version_0_33.rst.inc
+++ b/doc/source/changes/version_0_33.rst.inc
@@ -20,30 +20,21 @@ New features
* added official support for Python 3.9 (0.32.3 already supports it even though it was not mentioned).
-* added a feature (see the :ref:`miscellaneous section ` for details). It works on :ref:`api-axis` and
- :ref:`api-group` objects.
+* added :py:obj:`CheckedSession`, :py:obj:`CheckedParameters` and :py:obj:`CheckedArray` objects.
- Here is an example of the new feature:
+ `CheckedSession` is intended to be inherited by user defined classes in which the variables of a model
+ are declared. By declaring variables, users will speed up the development of their models using the auto-completion
+ (the feature in which development tools like PyCharm try to predict the variable or function a user intends
+ to enter after only a few characters have been typed). All user defined classes inheriting from `CheckedSession`
+ will have access to the same methods as `Session` objects.
- >>> arr = ndtest((2, 3))
- >>> arr
- a\b b0 b1 b2
- a0 0 1 2
- a1 3 4 5
+ `CheckedParameters` is the same as `CheckedSession` but the declared variables cannot be
+ modified after initialization.
- And it can also be used like this:
+  The special :py:func:`CheckedArray` type represents an Array object with fixed axes and/or dtype.
+ It is intended to be only used along with :py:class:`CheckedSession`.
- >>> arr = ndtest("a=a0..a2")
- >>> arr
- a a0 a1 a2
- 0 1 2
-
-* added another feature in the editor (closes :editor_issue:`1`).
-
- .. note::
-
- - It works for foo bar !
- - It does not work for foo baz !
+ Closes :issue:`832`.
.. _misc:
diff --git a/environment.yml b/environment.yml
index 8b42a3e4d..40bbf5b8c 100644
--- a/environment.yml
+++ b/environment.yml
@@ -12,4 +12,5 @@ dependencies:
- pytest>=3.5
- flake8
- pip:
- - pytest-flake8
\ No newline at end of file
+ - pytest-flake8
+ - pydantic==1.5
\ No newline at end of file
diff --git a/larray/__init__.py b/larray/__init__.py
index bc6402e3f..8cd0eaf94 100644
--- a/larray/__init__.py
+++ b/larray/__init__.py
@@ -8,6 +8,7 @@
eye, all, any, sum, prod, cumsum, cumprod, min, max, mean, ptp, var,
std, median, percentile, stack, zip_array_values, zip_array_items)
from larray.core.session import Session, local_arrays, global_arrays, arrays
+from larray.core.checked import CheckedArray, CheckedSession, CheckedParameters
from larray.core.constants import nan, inf, pi, e, euler_gamma
from larray.core.metadata import Metadata
from larray.core.ufuncs import wrap_elementwise_array_func, maximum, minimum, where
@@ -55,6 +56,8 @@
'median', 'percentile', 'stack', 'zip_array_values', 'zip_array_items',
# session
'Session', 'local_arrays', 'global_arrays', 'arrays',
+ # constrained
+ 'CheckedArray', 'CheckedSession', 'CheckedParameters',
# constants
'nan', 'inf', 'pi', 'e', 'euler_gamma',
# metadata
diff --git a/larray/core/axis.py b/larray/core/axis.py
index e498b35d4..5ce1008c3 100644
--- a/larray/core/axis.py
+++ b/larray/core/axis.py
@@ -839,7 +839,7 @@ def __getitem__(self, key):
-----
key is label-based (slice and fancy indexing are supported)
"""
- # if isinstance(key, basestring):
+ # if isinstance(key, str):
# key = to_keys(key)
def isscalar(k):
@@ -862,7 +862,7 @@ def isscalar(k):
and key.name in self
):
return LGroup(key.name, None, self)
- # elif isinstance(key, basestring) and key in self:
+ # elif isinstance(key, str) and key in self:
# TODO: this is an awful workaround to avoid the "processing" of string keys which exist as is in the axis
# (probably because the string was used in an aggregate function to create the label)
# res = LGroup(slice(None), None, self)
diff --git a/larray/core/checked.py b/larray/core/checked.py
new file mode 100644
index 000000000..71dd6ccb6
--- /dev/null
+++ b/larray/core/checked.py
@@ -0,0 +1,537 @@
+from abc import ABCMeta
+from copy import deepcopy
+import warnings
+
+import numpy as np
+
+from typing import TYPE_CHECKING, Type, Any, Dict, Set, List, no_type_check
+
+from larray.core.metadata import Metadata
+from larray.core.axis import AxisCollection
+from larray.core.group import Group
+from larray.core.array import Array, full
+from larray.core.session import Session
+
+
+class NotLoaded:
+ pass
+
+
+try:
+ import pydantic
+except ImportError:
+ pydantic = None
+
+# moved the not implemented versions of Checked* classes in the beginning of the module
+# otherwise PyCharm does not provide auto-completion for methods of CheckedSession
+# (imported from Session)
+if not pydantic:
+ def CheckedArray(axes: AxisCollection, dtype: np.dtype = float) -> Type[Array]:
+ raise NotImplementedError("CheckedArray cannot be used because pydantic is not installed")
+
+ class CheckedSession:
+ def __init__(self, *args, **kwargs):
+ raise NotImplementedError("CheckedSession class cannot be instantiated "
+ "because pydantic is not installed")
+
+ class CheckedParameters:
+ def __init__(self, *args, **kwargs):
+ raise NotImplementedError("CheckedParameters class cannot be instantiated "
+ "because pydantic is not installed")
+else:
+ from pydantic.fields import ModelField
+ from pydantic.class_validators import Validator
+ from pydantic.main import BaseConfig
+
+ # the implementation of the class below is inspired by the 'ConstrainedBytes' class
+ # from the types.py module of the 'pydantic' library
+ class CheckedArrayImpl(Array):
+ expected_axes: AxisCollection
+ dtype: np.dtype = np.dtype(float)
+
+ # see https://pydantic-docs.helpmanual.io/usage/types/#classes-with-__get_validators__
+ @classmethod
+ def __get_validators__(cls):
+ # one or more validators may be yielded which will be called in the
+ # order to validate the input, each validator will receive as an input
+ # the value returned from the previous validator
+ yield cls.validate
+
+ @classmethod
+ def validate(cls, value, field: ModelField):
+ if not (isinstance(value, Array) or np.isscalar(value)):
+ raise TypeError(f"Expected object of type '{Array.__name__}' or a scalar for "
+ f"the variable '{field.name}' but got object of type '{type(value).__name__}'")
+
+ # check axes
+ if isinstance(value, Array):
+ error_msg = f"Array '{field.name}' was declared with axes {cls.expected_axes} but got array " \
+ f"with axes {value.axes}"
+ # check for extra axes
+ extra_axes = value.axes - cls.expected_axes
+ if extra_axes:
+ raise ValueError(f"{error_msg} (unexpected {extra_axes} "
+ f"{'axes' if len(extra_axes) > 1 else 'axis'})")
+ # check compatible axes
+ try:
+ cls.expected_axes.check_compatible(value.axes)
+ except ValueError as error:
+ error_msg = str(error).replace("incompatible axes", f"Incompatible axis for array '{field.name}'")
+ raise ValueError(error_msg)
+ # broadcast + transpose if needed
+ value = value.expand(cls.expected_axes)
+ # check dtype
+ if value.dtype != cls.dtype:
+ value = value.astype(cls.dtype)
+ return value
+ else:
+ return full(axes=cls.expected_axes, fill_value=value, dtype=cls.dtype)
+
+ # the implementation of the function below is inspired by the 'conbytes' function
+ # from the types.py module of the 'pydantic' library
+
+ def CheckedArray(axes: AxisCollection, dtype: np.dtype = float) -> Type[Array]:
+ # XXX: for a very weird reason I don't know, I have to put the fake import below
+ # to get autocompletion from PyCharm
+ from larray.core.checked import CheckedArrayImpl
+ """
+ Represents a constrained array. It is intended to only be used along with :py:class:`CheckedSession`.
+
+ Its axes are assumed to be "frozen", meaning they are constant all along the execution of the program.
+ A constraint on the dtype of the data can be also specified.
+
+ Parameters
+ ----------
+ axes: AxisCollection
+ Axes of the checked array.
+ dtype: data-type, optional
+ Data-type for the checked array. Defaults to float.
+
+ Returns
+ -------
+ Array
+ Constrained array.
+ """
+ if axes is not None and not isinstance(axes, AxisCollection):
+ axes = AxisCollection(axes)
+ _dtype = np.dtype(dtype)
+
+ class ArrayDefValue(CheckedArrayImpl):
+ expected_axes = axes
+ dtype = _dtype
+
+ return ArrayDefValue
+
+ class AbstractCheckedSession:
+ pass
+
+ # Simplified version of the ModelMetaclass class from pydantic:
+ # https://github.com/samuelcolvin/pydantic/blob/master/pydantic/main.py#L195
+
+ class ModelMetaclass(ABCMeta):
+ @no_type_check # noqa C901
+ def __new__(mcs, name, bases, namespace, **kwargs):
+ from pydantic.fields import Undefined
+ from pydantic.class_validators import extract_validators, inherit_validators
+ from pydantic.types import PyObject
+ from pydantic.typing import is_classvar, resolve_annotations
+ from pydantic.utils import lenient_issubclass, validate_field_name
+ from pydantic.main import inherit_config, prepare_config, UNTOUCHED_TYPES
+
+ fields: Dict[str, ModelField] = {}
+ config = BaseConfig
+ validators: Dict[str, List[Validator]] = {}
+
+ for base in reversed(bases):
+ if issubclass(base, AbstractCheckedSession) and base != AbstractCheckedSession:
+ config = inherit_config(base.__config__, config)
+ fields.update(deepcopy(base.__fields__))
+ validators = inherit_validators(base.__validators__, validators)
+
+ config = inherit_config(namespace.get('Config'), config)
+ validators = inherit_validators(extract_validators(namespace), validators)
+
+ # update fields inherited from base classes
+ for field in fields.values():
+ field.set_config(config)
+ extra_validators = validators.get(field.name, [])
+ if extra_validators:
+ field.class_validators.update(extra_validators)
+ # re-run prepare to add extra validators
+ field.populate_validators()
+
+ prepare_config(config, name)
+
+ # extract and build fields
+ class_vars = set()
+ if (namespace.get('__module__'), namespace.get('__qualname__')) != \
+ ('larray.core.checked', 'CheckedSession'):
+ untouched_types = UNTOUCHED_TYPES + config.keep_untouched
+
+ # annotation only fields need to come first in fields
+ annotations = resolve_annotations(namespace.get('__annotations__', {}),
+ namespace.get('__module__', None))
+ for ann_name, ann_type in annotations.items():
+ if is_classvar(ann_type):
+ class_vars.add(ann_name)
+ elif not ann_name.startswith('_'):
+ validate_field_name(bases, ann_name)
+ value = namespace.get(ann_name, Undefined)
+ if (isinstance(value, untouched_types) and ann_type != PyObject
+ and not lenient_issubclass(getattr(ann_type, '__origin__', None), Type)):
+ continue
+ fields[ann_name] = ModelField.infer(name=ann_name, value=value, annotation=ann_type,
+ class_validators=validators.get(ann_name, []),
+ config=config)
+
+ for var_name, value in namespace.items():
+ # 'var_name not in annotations' because namespace.items() contains annotated fields
+ # with default values
+ # 'var_name not in class_vars' to avoid to update a field if it was redeclared (by mistake)
+ if (var_name not in annotations and not var_name.startswith('_')
+ and not isinstance(value, untouched_types) and var_name not in class_vars):
+ validate_field_name(bases, var_name)
+ # the method ModelField.infer() fails to infer the type of Group objects
+ # (which are interpreted as ndarray objects)
+ annotation = type(value) if isinstance(value, Group) else annotations.get(var_name)
+ inferred = ModelField.infer(name=var_name, value=value, annotation=annotation,
+ class_validators=validators.get(var_name, []), config=config)
+ if var_name in fields and inferred.type_ != fields[var_name].type_:
+ raise TypeError(f'The type of {name}.{var_name} differs from the new default value; '
+ f'if you wish to change the type of this field, please use a type '
+ f'annotation')
+ fields[var_name] = inferred
+
+ new_namespace = {
+ '__config__': config,
+ '__fields__': fields,
+ '__field_defaults__': {n: f.default for n, f in fields.items() if not f.required},
+ '__validators__': validators,
+ **{n: v for n, v in namespace.items() if n not in fields},
+ }
+ return super().__new__(mcs, name, bases, new_namespace, **kwargs)
+
+ class CheckedSession(Session, AbstractCheckedSession, metaclass=ModelMetaclass):
+ """
+ This class is intended to be inherited by user defined classes in which the variables of a model are declared.
+ Each declared variable is constrained by a type defined explicitly or deduced from the given default value
+ (see examples below).
+ All classes inheriting from `CheckedSession` will have access to all methods of the :py:class:`Session` class.
+
+        The special :py:func:`CheckedArray` type represents an Array object with fixed axes and/or dtype.
+ This prevents users from modifying the dimensions (and labels) and/or the dtype of an array by mistake
+ and make sure that the definition of an array remains always valid in the model.
+
+ By declaring variables, users will speed up the development of their models using the auto-completion
+ (the feature in which development tools like PyCharm try to predict the variable or function a user intends
+ to enter after only a few characters have been typed).
+
+ As for normal Session objects, it is still possible to add undeclared variables to instances of
+ classes inheriting from `CheckedSession` but this must be done with caution.
+
+ Parameters
+ ----------
+ *args : str or dict of {str: object} or iterable of tuples (str, object)
+ Path to the file containing the session to load or
+ list/tuple/dictionary containing couples (name, object).
+ **kwargs : dict of {str: object}
+
+ * Objects to add written as name=object
+ * meta : list of pairs or dict or OrderedDict or Metadata, optional
+ Metadata (title, description, author, creation_date, ...) associated with the array.
+ Keys must be strings. Values must be of type string, int, float, date, time or datetime.
+
+ Warnings
+ --------
+        The :py:meth:`CheckedSession.filter`, :py:meth:`CheckedSession.compact`
+        and :py:meth:`CheckedSession.apply` methods return a simple Session object.
+ The type of the declared variables (and the value for the declared constants) will
+ no longer be checked.
+
+ See Also
+ --------
+ Session, CheckedParameters
+
+ Examples
+ --------
+
+ Content of file 'parameters.py'
+
+ >>> from larray import *
+ >>> FIRST_YEAR = 2020
+ >>> LAST_YEAR = 2030
+ >>> AGE = Axis('age=0..10')
+ >>> GENDER = Axis('gender=male,female')
+ >>> TIME = Axis(f'time={FIRST_YEAR}..{LAST_YEAR}')
+
+ Content of file 'model.py'
+
+ >>> class ModelVariables(CheckedSession):
+ ... # --- declare variables with defined types ---
+ ... # Their values will be defined at runtime but must match the specified type.
+ ... birth_rate: Array
+ ... births: Array
+ ... # --- declare variables with a default value ---
+ ... # The default value will be used to set the variable if no value is passed at instantiation (see below).
+ ... # Their type is deduced from their default value and cannot be changed at runtime.
+ ... target_age = AGE[:2] >> '0-2'
+ ... population = zeros((AGE, GENDER, TIME), dtype=int)
+ ... # --- declare checked arrays ---
+ ... # The checked arrays have axes assumed to be "frozen", meaning they are
+ ... # constant all along the execution of the program.
+ ... mortality_rate: CheckedArray((AGE, GENDER))
+ ... # For checked arrays, the default value can be given as a scalar.
+ ... # Optionally, a dtype can be also specified (defaults to float).
+ ... deaths: CheckedArray((AGE, GENDER, TIME), dtype=int) = 0
+
+ >>> variant_name = "baseline"
+ >>> # Instantiation --> create an instance of the ModelVariables class.
+ >>> # Warning: All variables declared without a default value must be set.
+ >>> m = ModelVariables(birth_rate = zeros((AGE, GENDER)),
+ ... births = zeros((AGE, GENDER, TIME), dtype=int),
+ ... mortality_rate = 0)
+
+ >>> # ==== model ====
+ >>> # In the definition of ModelVariables, the 'birth_rate' variable, has been declared as an Array object.
+ >>> # This means that the 'birth_rate' variable will always remain of type Array.
+        >>> # Any attempt to assign a non-Array value to 'birth_rate' will make the program crash.
+ >>> m.birth_rate = Array([0.045, 0.055], GENDER) # OK
+ >>> m.birth_rate = [0.045, 0.055] # Fails
+ Traceback (most recent call last):
+ ...
+ pydantic.errors.ArbitraryTypeError: instance of Array expected
+ >>> # However, the arrays 'birth_rate', 'births' and 'population' have not been declared as 'CheckedArray'.
+ >>> # Thus, axes and dtype of these arrays are not protected, leading to potentially unexpected behavior
+ >>> # of the model.
+ >>> # example 1: Let's say we want to calculate the new births for the year 2025 and we assume that
+        >>> # the birth rate only differs by gender.
+ >>> # In the line below, we add an additional TIME axis to 'birth_rate' while it was initialized
+ >>> # with the AGE and GENDER axes only
+ >>> m.birth_rate = full((AGE, GENDER, TIME), fill_value=Array([0.045, 0.055], GENDER))
+ >>> # here 'new_births' have the AGE, GENDER and TIME axes instead of the AGE and GENDER axes only
+ >>> new_births = m.population['female', 2025] * m.birth_rate
+ >>> print(new_births.info)
+ 11 x 2 x 11
+ age [11]: 0 1 2 ... 8 9 10
+ gender [2]: 'male' 'female'
+ time [11]: 2020 2021 2022 ... 2028 2029 2030
+ dtype: float64
+ memory used: 1.89 Kb
+ >>> # and the line below will crash
+ >>> m.births[2025] = new_births # doctest: +NORMALIZE_WHITESPACE
+ Traceback (most recent call last):
+ ...
+ ValueError: Value {time} axis is not present in target subset {age, gender}.
+ A value can only have the same axes or fewer axes than the subset being targeted
+ >>> # now let's try to do the same for deaths and making the same mistake as for 'birth_rate'.
+ >>> # The program will crash now at the first step instead of letting you go further
+ >>> m.mortality_rate = full((AGE, GENDER, TIME), fill_value=sequence(AGE, inc=0.02)) \
+ # doctest: +NORMALIZE_WHITESPACE
+ Traceback (most recent call last):
+ ...
+ ValueError: Array 'mortality_rate' was declared with axes {age, gender} but got array with axes
+ {age, gender, time} (unexpected {time} axis)
+
+ >>> # example 2: let's say we want to calculate the new births for all years.
+ >>> m.birth_rate = full((AGE, GENDER, TIME), fill_value=Array([0.045, 0.055], GENDER))
+ >>> new_births = m.population['female'] * m.birth_rate
+ >>> # here 'new_births' has the same axes as 'births' but is a float array instead of
+ >>> # an integer array as 'births'.
+ >>> # The line below will make the 'births' array become a float array while
+ >>> # it was initialized as an integer array
+ >>> m.births = new_births
+ >>> print(m.births.info)
+ 11 x 11 x 2
+ age [11]: 0 1 2 ... 8 9 10
+ time [11]: 2020 2021 2022 ... 2028 2029 2030
+ gender [2]: 'male' 'female'
+ dtype: float64
+ memory used: 1.89 Kb
+ >>> # now let's try to do the same for deaths.
+ >>> m.mortality_rate = full((AGE, GENDER), fill_value=sequence(AGE, inc=0.02))
+ >>> # here the result of the multiplication of the 'population' array by the 'mortality_rate' array
+ >>> # is automatically converted to an integer array
+ >>> m.deaths = m.population * m.mortality_rate
+ >>> print(m.deaths.info) # doctest: +SKIP
+ 11 x 2 x 11
+ age [11]: 0 1 2 ... 8 9 10
+ gender [2]: 'male' 'female'
+ time [11]: 2020 2021 2022 ... 2028 2029 2030
+ dtype: int32
+ memory used: 968 bytes
+
+        >>> # note that it is still possible to add undeclared variables to a checked session
+ >>> # but this must be done with caution.
+ >>> m.undeclared_var = 'undeclared_var'
+
+ >>> # ==== output ====
+ >>> # save all variables in an HDF5 file
+ >>> m.save(f'{variant_name}.h5', display=True)
+ dumping birth_rate ... done
+ dumping births ... done
+ dumping mortality_rate ... done
+ dumping deaths ... done
+ dumping target_age ... done
+ dumping population ... done
+ dumping undeclared_var ... done
+ """
+ if TYPE_CHECKING:
+ # populated by the metaclass, defined here to help IDEs only
+ __fields__: Dict[str, ModelField] = {}
+ __field_defaults__: Dict[str, Any] = {}
+ __validators__: Dict[str, List[Validator]] = {}
+ __config__: Type[BaseConfig] = BaseConfig
+
+ class Config:
+ # whether to allow arbitrary user types for fields (they are validated simply by checking
+ # if the value is an instance of the type). If False, RuntimeError will be raised on model declaration.
+ # (default: False)
+ arbitrary_types_allowed = True
+ # whether to validate field defaults
+ validate_all = True
+ # whether to ignore, allow, or forbid extra attributes during model initialization (and after).
+ # Accepts the string values of 'ignore', 'allow', or 'forbid', or values of the Extra enum
+ # (default: Extra.ignore)
+ extra = 'allow'
+ # whether to perform validation on assignment to attributes
+ validate_assignment = True
+ # whether or not models are faux-immutable, i.e. whether __setattr__ is allowed.
+ # (default: True)
+ allow_mutation = True
+
+ # Warning: order of fields is not preserved.
+ # As of v1.0 of pydantic all fields with annotations (whether annotation-only or with a default value)
+ # will precede all fields without an annotation. Within their respective groups, fields remain in the
+ # order they were defined.
+ # See https://pydantic-docs.helpmanual.io/usage/models/#field-ordering
+ def __init__(self, *args, **kwargs):
+ meta = kwargs.pop('meta', Metadata())
+ Session.__init__(self, meta=meta)
+
+ # create an intermediate Session object to not call the __setattr__
+ # and __setitem__ overridden in the present class and in case a filepath
+ # is given as only argument
+ # todo: refactor Session.load() to use a private function which returns the handler directly
+ # so that we can get the items out of it and avoid this
+ input_data = dict(Session(*args, **kwargs))
+
+ # --- declared variables
+ for name, field in self.__fields__.items():
+ value = input_data.pop(field.name, NotLoaded())
+
+ if isinstance(value, NotLoaded):
+ if field.default is None:
+ warnings.warn(f"No value passed for the declared variable '{field.name}'", stacklevel=2)
+ self.__setattr__(name, value, skip_allow_mutation=True, skip_validation=True)
+ else:
+ self.__setattr__(name, field.default, skip_allow_mutation=True)
+ else:
+ self.__setattr__(name, value, skip_allow_mutation=True)
+
+ # --- undeclared variables
+ for name, value in input_data.items():
+ self.__setattr__(name, value, skip_allow_mutation=True)
+
+ # code of the method below has been partly borrowed from pydantic.BaseModel.__setattr__()
+ def _check_key_value(self, name: str, value: Any, skip_allow_mutation: bool, skip_validation: bool) -> Any:
+ config = self.__config__
+ if not config.extra and name not in self.__fields__:
+ raise ValueError(f"Variable '{name}' is not declared in '{self.__class__.__name__}'. "
+ f"Adding undeclared variables is forbidden. "
+ f"List of declared variables is: {list(self.__fields__.keys())}.")
+ if not skip_allow_mutation and not config.allow_mutation:
+ raise TypeError(f"Cannot change the value of the variable '{name}' since '{self.__class__.__name__}' "
+ f"is immutable and does not support item assignment")
+ known_field = self.__fields__.get(name, None)
+ if known_field:
+ if not skip_validation:
+ value, error_ = known_field.validate(value, self.dict(exclude={name}), loc=name, cls=self.__class__)
+ if error_:
+ raise error_.exc
+ else:
+ warnings.warn(f"'{name}' is not declared in '{self.__class__.__name__}'", stacklevel=3)
+ return value
+
+ def __setitem__(self, key, value, skip_allow_mutation=False, skip_validation=False):
+ if key != 'meta':
+ value = self._check_key_value(key, value, skip_allow_mutation, skip_validation)
+ # we need to keep the attribute in sync
+ object.__setattr__(self, key, value)
+ self._objects[key] = value
+
+ def __setattr__(self, key, value, skip_allow_mutation=False, skip_validation=False):
+ if key != 'meta':
+ value = self._check_key_value(key, value, skip_allow_mutation, skip_validation)
+ # we need to keep the attribute in sync
+ object.__setattr__(self, key, value)
+ Session.__setattr__(self, key, value)
+
+ def __getstate__(self) -> Dict[str, Any]:
+ return {'__dict__': self.__dict__}
+
+ def __setstate__(self, state: Dict[str, Any]) -> None:
+ object.__setattr__(self, '__dict__', state['__dict__'])
+
+ def dict(self, exclude: Set[str] = None):
+ d = dict(self.items())
+ for name in exclude:
+ if name in d:
+ del d[name]
+ return d
+
+ class CheckedParameters(CheckedSession):
+ """
+        Same as :py:class:`CheckedSession` but declared variables cannot be modified after initialization.
+
+ Parameters
+ ----------
+ *args : str or dict of {str: object} or iterable of tuples (str, object)
+ Path to the file containing the session to load or
+ list/tuple/dictionary containing couples (name, object).
+ **kwargs : dict of {str: object}
+
+ * Objects to add written as name=object
+ * meta : list of pairs or dict or OrderedDict or Metadata, optional
+ Metadata (title, description, author, creation_date, ...) associated with the array.
+ Keys must be strings. Values must be of type string, int, float, date, time or datetime.
+
+ See Also
+ --------
+ CheckedSession
+
+ Examples
+ --------
+
+ Content of file 'parameters.py'
+
+ >>> from larray import *
+ >>> class Parameters(CheckedParameters):
+ ... # --- declare variables with fixed values ---
+ ... # The given values can never be changed
+ ... FIRST_YEAR = 2020
+ ... LAST_YEAR = 2030
+ ... AGE = Axis('age=0..10')
+ ... GENDER = Axis('gender=male,female')
+ ... TIME = Axis(f'time={FIRST_YEAR}..{LAST_YEAR}')
+ ... # --- declare variables with defined types ---
+        ...     # Their values must be defined at initialization and will be frozen afterwards.
+ ... variant_name: str
+
+ Content of file 'model.py'
+
+ >>> # instantiation --> create an instance of the ModelVariables class
+ >>> # all variables declared without value must be set
+ >>> P = Parameters(variant_name='variant_1')
+ >>> # once an instance is created, its variables can be accessed but not modified
+ >>> P.variant_name
+ 'variant_1'
+ >>> P.variant_name = 'new_variant' # doctest: +NORMALIZE_WHITESPACE
+ Traceback (most recent call last):
+ ...
+ TypeError: Cannot change the value of the variable 'variant_name' since 'Parameters'
+ is immutable and does not support item assignment
+ """
+ class Config:
+ # whether or not models are faux-immutable, i.e. whether __setattr__ is allowed.
+ # (default: True)
+ allow_mutation = False
diff --git a/larray/core/group.py b/larray/core/group.py
index 00ddbacfa..f885a2653 100644
--- a/larray/core/group.py
+++ b/larray/core/group.py
@@ -452,7 +452,7 @@ def _seq_str_to_seq(s, stack_depth=1, parse_single_int=False):
Parameters
----------
- s : basestring
+ s : str
string to parse
Returns
@@ -496,7 +496,7 @@ def _to_key(v, stack_depth=1, parse_single_int=False):
Parameters
----------
- v : int or basestring or tuple or list or slice or Array or Group
+ v : int or str or tuple or list or slice or Array or Group
value to convert into a key usable for indexing
Returns
@@ -598,7 +598,7 @@ def _to_keys(value, stack_depth=1):
Parameters
----------
- value : int or basestring or tuple or list or slice or Array or Group
+ value : int or str or tuple or list or slice or Array or Group
(collection of) value(s) to convert into key(s) usable for indexing
Returns
diff --git a/larray/core/session.py b/larray/core/session.py
index a3d34d817..d753469c0 100644
--- a/larray/core/session.py
+++ b/larray/core/session.py
@@ -85,6 +85,7 @@ def __init__(self, *args, **kwargs):
self.meta = meta
if len(args) == 1:
+ assert len(kwargs) == 0
a0 = args[0]
if isinstance(a0, str):
# assume a0 is a filename
@@ -915,7 +916,7 @@ def copy(self):
r"""Returns a copy of the session.
"""
# this actually *does* a copy of the internal mapping (the mapping is not reused-as is)
- return Session(self._objects)
+ return self.__class__(self._objects)
def keys(self):
r"""
@@ -1042,7 +1043,12 @@ def opmethod(self, other):
except Exception:
res_item = nan
res.append((name, res_item))
- return Session(res)
+ try:
+ # XXX: print a warning?
+ ses = self.__class__(res)
+ except Exception:
+ ses = Session(res)
+ return ses
opmethod.__name__ = opfullname
return opmethod
@@ -1072,7 +1078,12 @@ def opmethod(self):
except Exception:
res_array = nan
res.append((k, res_array))
- return Session(res)
+ try:
+ # XXX: print a warning?
+ ses = self.__class__(res)
+ except Exception:
+ ses = Session(res)
+ return ses
opmethod.__name__ = opfullname
return opmethod
diff --git a/larray/tests/common.py b/larray/tests/common.py
index b9b11ce43..0b0131b1f 100644
--- a/larray/tests/common.py
+++ b/larray/tests/common.py
@@ -174,3 +174,17 @@ def must_warn(warn_cls=None, msg=None, match=None, check_file=True, check_num=Tr
warning_path = caught_warnings[0].filename
assert warning_path == caller_path, \
f"{warning_path} != {caller_path}"
+
+
+@contextmanager
+def must_raise(warn_cls=None, msg=None, match=None):
+ if msg is not None and match is not None:
+ raise ValueError("bad test: can't use both msg and match arguments")
+ elif msg is not None:
+ match = re.escape(msg)
+
+ try:
+ with pytest.raises(warn_cls, match=match) as error:
+ yield error
+ finally:
+ pass
diff --git a/larray/tests/data/test_session.h5 b/larray/tests/data/test_session.h5
index d17d199b6..f27b6bf3a 100644
Binary files a/larray/tests/data/test_session.h5 and b/larray/tests/data/test_session.h5 differ
diff --git a/larray/tests/test_checked_session.py b/larray/tests/test_checked_session.py
new file mode 100644
index 000000000..9c9f07ddd
--- /dev/null
+++ b/larray/tests/test_checked_session.py
@@ -0,0 +1,670 @@
+import pytest
+
+try:
+ import pydantic # noqa: F401
+except ImportError:
+ pytestmark = pytest.mark.skip("pydantic is required for testing Checked* classes")
+
+import pickle
+import numpy as np
+
+from larray import (CheckedSession, CheckedArray, Axis, AxisCollection, Group, Array,
+ ndtest, full, full_like, zeros_like, ones, ones_like, isnan)
+from larray.tests.common import (inputpath, tmp_path, assert_array_nan_equal, meta, # noqa: F401
+ needs_pytables, needs_openpyxl, needs_xlwings,
+ must_warn, must_raise)
+from larray.tests.test_session import (a, a2, a3, anonymous, a01, ano01, b, b2, b024, # noqa: F401
+ c, d, e, f, g, h,
+ assert_seq_equal, session, test_getitem, test_getattr,
+ test_add, test_element_equals, test_eq, test_ne)
+from larray.core.checked import NotLoaded
+
+
+# avoid flake8 errors
+meta = meta
+
+
+class TestCheckedSession(CheckedSession):
+ b = b
+ b024 = b024
+ a: Axis
+ a2: Axis
+ anonymous = anonymous
+ a01: Group
+ ano01 = ano01
+ c: str = c
+ d = d
+ e: Array
+ g: Array
+ f: CheckedArray((Axis(3), Axis(2)))
+ h: CheckedArray((a3, b2), dtype=int)
+
+
+@pytest.fixture()
+def checkedsession():
+ return TestCheckedSession(a=a, a2=a2, a01=a01, e=e, g=g, f=f, h=h)
+
+
+def test_create_checkedsession_instance(meta):
+ # As of v1.0 of pydantic all fields with annotations (whether annotation-only or with a default value)
+ # will precede all fields without an annotation. Within their respective groups, fields remain in the
+ # order they were defined.
+ # See https://pydantic-docs.helpmanual.io/usage/models/#field-ordering
+ declared_variable_keys = ['a', 'a2', 'a01', 'c', 'e', 'g', 'f', 'h', 'b', 'b024', 'anonymous', 'ano01', 'd']
+
+ # setting variables without default values
+ cs = TestCheckedSession(a, a01, a2=a2, e=e, f=f, g=g, h=h)
+ assert list(cs.keys()) == declared_variable_keys
+ assert cs.b.equals(b)
+ assert cs.b024.equals(b024)
+ assert cs.a.equals(a)
+ assert cs.a2.equals(a2)
+ assert cs.anonymous.equals(anonymous)
+ assert cs.a01.equals(a01)
+ assert cs.ano01.equals(ano01)
+ assert cs.c == c
+ assert cs.d == d
+ assert cs.e.equals(e)
+ assert cs.g.equals(g)
+ assert cs.f.equals(f)
+ assert cs.h.equals(h)
+
+ # metadata
+ cs = TestCheckedSession(a, a01, a2=a2, e=e, f=f, g=g, h=h, meta=meta)
+ assert cs.meta == meta
+
+ # override default value
+ b_alt = Axis('b=b0..b4')
+ cs = TestCheckedSession(a, a01, b=b_alt, a2=a2, e=e, f=f, g=g, h=h)
+ assert cs.b is b_alt
+
+ # test for "NOT_LOADED" variables
+ with must_warn(UserWarning, msg="No value passed for the declared variable 'a'", check_file=False):
+ TestCheckedSession(a01=a01, a2=a2, e=e, f=f, g=g, h=h)
+ cs = TestCheckedSession()
+ assert list(cs.keys()) == declared_variable_keys
+ # --- variables with default values ---
+ assert cs.b.equals(b)
+ assert cs.b024.equals(b024)
+ assert cs.anonymous.equals(anonymous)
+ assert cs.ano01.equals(ano01)
+ assert cs.c == c
+ assert cs.d == d
+ # --- variables without default values ---
+ assert isinstance(cs.a, NotLoaded)
+ assert isinstance(cs.a2, NotLoaded)
+ assert isinstance(cs.a01, NotLoaded)
+ assert isinstance(cs.e, NotLoaded)
+ assert isinstance(cs.g, NotLoaded)
+ assert isinstance(cs.f, NotLoaded)
+ assert isinstance(cs.h, NotLoaded)
+
+    # passing a scalar to set all elements of a CheckedArray
+ cs = TestCheckedSession(a, a01, a2=a2, e=e, f=f, g=g, h=5)
+ assert cs.h.axes == AxisCollection((a3, b2))
+ assert cs.h.equals(full(axes=(a3, b2), fill_value=5))
+
+ # add the undeclared variable 'i'
+ with must_warn(UserWarning, f"'i' is not declared in '{cs.__class__.__name__}'", check_file=False):
+ cs = TestCheckedSession(a, a01, a2=a2, i=5, e=e, f=f, g=g, h=h)
+ assert list(cs.keys()) == declared_variable_keys + ['i']
+
+ # test inheritance between checked sessions
+ class TestInheritance(TestCheckedSession):
+ # override variables
+ b = b2
+ c: int = 5
+ f: CheckedArray((a3, b2), dtype=int)
+ h: CheckedArray((Axis(3), Axis(2)))
+ # new variables
+ n0 = 'first new var'
+ n1: str
+
+ declared_variable_keys += ['n1', 'n0']
+ cs = TestInheritance(a, a01, a2=a2, e=e, f=h, g=g, h=f, n1='second new var')
+ assert list(cs.keys()) == declared_variable_keys
+    # --- overridden variables ---
+ assert cs.b.equals(b2)
+ assert cs.c == 5
+ assert cs.f.equals(h)
+ assert cs.h.equals(f)
+ # --- new variables ---
+ assert cs.n0 == 'first new var'
+ assert cs.n1 == 'second new var'
+ # --- variables declared in the base class ---
+ assert cs.b024.equals(b024)
+ assert cs.a.equals(a)
+ assert cs.a2.equals(a2)
+ assert cs.anonymous.equals(anonymous)
+ assert cs.a01.equals(a01)
+ assert cs.ano01.equals(ano01)
+ assert cs.d == d
+ assert cs.e.equals(e)
+ assert cs.g.equals(g)
+
+
+@needs_pytables
+def test_init_checkedsession_hdf():
+ cs = TestCheckedSession(inputpath('test_session.h5'))
+ assert set(cs.keys()) == {'b', 'b024', 'a', 'a2', 'anonymous', 'a01', 'ano01', 'c', 'd', 'e', 'g', 'f', 'h'}
+
+
+def test_getitem_cs(checkedsession):
+ test_getitem(checkedsession)
+
+
+def test_setitem_cs(checkedsession):
+ cs = checkedsession
+
+ # only change values of an array -> OK
+ cs['h'] = zeros_like(h)
+
+ # trying to add an undeclared variable -> prints a warning message
+ with must_warn(UserWarning, msg=f"'i' is not declared in '{cs.__class__.__name__}'"):
+ cs['i'] = ndtest((3, 3))
+
+ # trying to set a variable with an object of different type -> should fail
+ # a) type given explicitly
+ # -> Axis
+ with must_raise(TypeError, msg="instance of Axis expected"):
+ cs['a'] = 0
+ # -> CheckedArray
+ with must_raise(TypeError, msg="Expected object of type 'Array' or a scalar for the variable 'h' but got "
+ "object of type 'ndarray'"):
+ cs['h'] = h.data
+ # b) type deduced from the given default value
+ with must_raise(TypeError, msg="instance of Axis expected"):
+ cs['b'] = ndtest((3, 3))
+
+ # trying to set a CheckedArray variable using a scalar -> OK
+ cs['h'] = 5
+
+ # trying to set a CheckedArray variable using an array with axes in different order -> OK
+ cs['h'] = h.transpose()
+ assert cs.h.axes.names == h.axes.names
+
+ # broadcasting (missing axis) is allowed
+ cs['h'] = ndtest(a3)
+ assert_array_nan_equal(cs['h']['b0'], cs['h']['b1'])
+
+ # trying to set a CheckedArray variable using an array with wrong axes -> should fail
+ # a) extra axis
+ with must_raise(ValueError, msg="Array 'h' was declared with axes {a, b} but got array with axes {a, b, c} "
+ "(unexpected {c} axis)"):
+ cs['h'] = ndtest((a3, b2, 'c=c0..c2'))
+ # b) incompatible axis
+ msg = """\
+Incompatible axis for array 'h':
+Axis(['a0', 'a1', 'a2', 'a3', 'a4'], 'a')
+vs
+Axis(['a0', 'a1', 'a2', 'a3'], 'a')"""
+ with must_raise(ValueError, msg=msg):
+ cs['h'] = h.append('a', 0, 'a4')
+
+
+def test_getattr_cs(checkedsession):
+ test_getattr(checkedsession)
+
+
+def test_setattr_cs(checkedsession):
+ cs = checkedsession
+
+ # only change values of an array -> OK
+ cs.h = zeros_like(h)
+
+ # trying to add an undeclared variable -> prints a warning message
+ with must_warn(UserWarning, msg=f"'i' is not declared in '{cs.__class__.__name__}'"):
+ cs.i = ndtest((3, 3))
+
+ # trying to set a variable with an object of different type -> should fail
+ # a) type given explicitly
+ # -> Axis
+ with must_raise(TypeError, msg="instance of Axis expected"):
+ cs.a = 0
+ # -> CheckedArray
+ with must_raise(TypeError, msg="Expected object of type 'Array' or a scalar for the variable 'h' but got "
+ "object of type 'ndarray'"):
+ cs.h = h.data
+ # b) type deduced from the given default value
+ with must_raise(TypeError, msg="instance of Axis expected"):
+ cs.b = ndtest((3, 3))
+
+ # trying to set a CheckedArray variable using a scalar -> OK
+ cs.h = 5
+
+ # trying to set a CheckedArray variable using an array with axes in different order -> OK
+ cs.h = h.transpose()
+ assert cs.h.axes.names == h.axes.names
+
+ # broadcasting (missing axis) is allowed
+ cs.h = ndtest(a3)
+ assert_array_nan_equal(cs.h['b0'], cs.h['b1'])
+
+ # trying to set a CheckedArray variable using an array with wrong axes -> should fail
+ # a) extra axis
+ with must_raise(ValueError, msg="Array 'h' was declared with axes {a, b} but got array with axes {a, b, c} "
+ "(unexpected {c} axis)"):
+ cs.h = ndtest((a3, b2, 'c=c0..c2'))
+ # b) incompatible axis
+ msg = """\
+Incompatible axis for array 'h':
+Axis(['a0', 'a1', 'a2', 'a3', 'a4'], 'a')
+vs
+Axis(['a0', 'a1', 'a2', 'a3'], 'a')"""
+ with must_raise(ValueError, msg=msg):
+ cs.h = h.append('a', 0, 'a4')
+
+
+def test_add_cs(checkedsession):
+ cs = checkedsession
+ test_add(cs)
+
+ u = Axis('u=u0..u2')
+ with must_warn(UserWarning, msg=f"'u' is not declared in '{cs.__class__.__name__}'", check_file=False):
+ cs.add(u)
+
+
+def test_iter_cs(checkedsession):
+ # As of v1.0 of pydantic all fields with annotations (whether annotation-only or with a default value)
+ # will precede all fields without an annotation. Within their respective groups, fields remain in the
+ # order they were defined.
+ # See https://pydantic-docs.helpmanual.io/usage/models/#field-ordering
+ expected = [a, a2, a01, c, e, g, f, h, b, b024, anonymous, ano01, d]
+ assert_seq_equal(checkedsession, expected)
+
+
+def test_filter_cs(checkedsession):
+ # see comment in test_iter_cs() about fields ordering
+ cs = checkedsession
+ cs.ax = 'ax'
+ assert_seq_equal(cs.filter(), [a, a2, a01, c, e, g, f, h, b, b024, anonymous, ano01, d, 'ax'])
+ assert_seq_equal(cs.filter('a*'), [a, a2, a01, anonymous, ano01, 'ax'])
+ assert list(cs.filter('a*', dict)) == []
+ assert list(cs.filter('a*', str)) == ['ax']
+ assert list(cs.filter('a*', Axis)) == [a, a2, anonymous]
+ assert list(cs.filter(kind=Axis)) == [a, a2, b, anonymous]
+ assert list(cs.filter('a01', Group)) == [a01]
+ assert list(cs.filter(kind=Group)) == [a01, b024, ano01]
+ assert_seq_equal(cs.filter(kind=Array), [e, g, f, h])
+ assert list(cs.filter(kind=dict)) == [{}]
+ assert list(cs.filter(kind=(Axis, Group))) == [a, a2, a01, b, b024, anonymous, ano01]
+
+
+def test_names_cs(checkedsession):
+ assert checkedsession.names == ['a', 'a01', 'a2', 'ano01', 'anonymous', 'b', 'b024',
+ 'c', 'd', 'e', 'f', 'g', 'h']
+
+
+def _test_io_cs(tmpdir, meta, engine, ext):
+ filename = f"test_{engine}.{ext}" if 'csv' not in engine else f"test_{engine}{ext}"
+ fpath = tmp_path(tmpdir, filename)
+
+ is_excel_or_csv = 'excel' in engine or 'csv' in engine
+
+ # Save and load
+ # -------------
+
+ # a) - all typed variables have a defined value
+ # - no extra variables are added
+ csession = TestCheckedSession(a=a, a2=a2, a01=a01, d=d, e=e, g=g, f=f, h=h, meta=meta)
+ csession.save(fpath, engine=engine)
+ cs = TestCheckedSession()
+ cs.load(fpath, engine=engine)
+ # --- keys ---
+ assert list(cs.keys()) == list(csession.keys())
+ # --- variables with default values ---
+ assert cs.b.equals(b)
+ assert cs.b024.equals(b024)
+ assert cs.anonymous.equals(anonymous)
+ assert cs.ano01.equals(ano01)
+ assert cs.d == d
+ # --- typed variables ---
+    # Array is supported by all formats
+ assert cs.e.equals(e)
+ assert cs.g.equals(g)
+ assert cs.f.equals(f)
+ assert cs.h.equals(h)
+ # Axis and Group are not supported by the Excel and CSV formats
+ if is_excel_or_csv:
+ assert isinstance(cs.a, NotLoaded)
+ assert isinstance(cs.a2, NotLoaded)
+ assert isinstance(cs.a01, NotLoaded)
+ else:
+ assert cs.a.equals(a)
+ assert cs.a2.equals(a2)
+ assert cs.a01.equals(a01)
+ # --- dtype of Axis variables ---
+ if not is_excel_or_csv:
+ for key in cs.filter(kind=Axis).keys():
+ assert cs[key].dtype == csession[key].dtype
+ # --- metadata ---
+ if engine != 'pandas_excel':
+ assert cs.meta == meta
+
+ # b) - not all typed variables have a defined value
+ # - no extra variables are added
+ csession = TestCheckedSession(a=a, d=d, e=e, h=h, meta=meta)
+ if 'csv' in engine:
+ import shutil
+ shutil.rmtree(fpath)
+ csession.save(fpath, engine=engine)
+ cs = TestCheckedSession()
+ cs.load(fpath, engine=engine)
+ # --- keys ---
+ assert list(cs.keys()) == list(csession.keys())
+ # --- variables with default values ---
+ assert cs.b.equals(b)
+ assert cs.b024.equals(b024)
+ assert cs.anonymous.equals(anonymous)
+ assert cs.ano01.equals(ano01)
+ assert cs.d == d
+ # --- typed variables ---
+    # Array is supported by all formats
+ assert cs.e.equals(e)
+ assert isinstance(cs.g, NotLoaded)
+ assert isinstance(cs.f, NotLoaded)
+ assert cs.h.equals(h)
+ # Axis and Group are not supported by the Excel and CSV formats
+ if is_excel_or_csv:
+ assert isinstance(cs.a, NotLoaded)
+ assert isinstance(cs.a2, NotLoaded)
+ assert isinstance(cs.a01, NotLoaded)
+ else:
+ assert cs.a.equals(a)
+ assert isinstance(cs.a2, NotLoaded)
+ assert isinstance(cs.a01, NotLoaded)
+
+ # c) - all typed variables have a defined value
+ # - extra variables are added
+ i = ndtest(6)
+ j = ndtest((3, 3))
+ k = ndtest((2, 2))
+ csession = TestCheckedSession(a=a, a2=a2, a01=a01, d=d, e=e, g=g, f=f, h=h, k=k, j=j, i=i, meta=meta)
+ csession.save(fpath, engine=engine)
+ cs = TestCheckedSession()
+ cs.load(fpath, engine=engine)
+ # --- names ---
+ # we do not use keys() since order of undeclared variables
+ # may not be preserved (at least for the HDF format)
+ assert cs.names == csession.names
+ # --- extra variable ---
+ assert cs.i.equals(i)
+ assert cs.j.equals(j)
+ assert cs.k.equals(k)
+
+ # Update a Group + an Axis + an array (overwrite=False)
+ # -----------------------------------------------------
+ csession = TestCheckedSession(a=a, a2=a2, a01=a01, d=d, e=e, g=g, f=f, h=h, meta=meta)
+ csession.save(fpath, engine=engine)
+ a4 = Axis('a=0..3')
+ a4_01 = a3['0,1'] >> 'a01'
+ e2 = ndtest((a4, 'b=b0..b2'))
+ h2 = full_like(h, fill_value=10)
+ TestCheckedSession(a=a4, a01=a4_01, e=e2, h=h2).save(fpath, overwrite=False, engine=engine)
+ cs = TestCheckedSession()
+ cs.load(fpath, engine=engine)
+ # --- variables with default values ---
+ assert cs.b.equals(b)
+ assert cs.b024.equals(b024)
+ assert cs.anonymous.equals(anonymous)
+ assert cs.ano01.equals(ano01)
+ # --- typed variables ---
+    # Array is supported by all formats
+ assert cs.e.equals(e2)
+ assert cs.h.equals(h2)
+ if engine == 'pandas_excel':
+        # Session.save() via engine='pandas_excel' always overwrites the output Excel files
+ # arrays 'g' and 'f' have been dropped
+ assert isinstance(cs.g, NotLoaded)
+ assert isinstance(cs.f, NotLoaded)
+ # Axis and Group are not supported by the Excel and CSV formats
+ assert isinstance(cs.a, NotLoaded)
+ assert isinstance(cs.a2, NotLoaded)
+ assert isinstance(cs.a01, NotLoaded)
+ elif is_excel_or_csv:
+ assert cs.g.equals(g)
+ assert cs.f.equals(f)
+ # Axis and Group are not supported by the Excel and CSV formats
+ assert isinstance(cs.a, NotLoaded)
+ assert isinstance(cs.a2, NotLoaded)
+ assert isinstance(cs.a01, NotLoaded)
+ else:
+ assert list(cs.keys()) == list(csession.keys())
+ assert cs.a.equals(a4)
+ assert cs.a2.equals(a2)
+ assert cs.a01.equals(a4_01)
+ assert cs.g.equals(g)
+ assert cs.f.equals(f)
+ if engine != 'pandas_excel':
+ assert cs.meta == meta
+
+ # Load only some objects
+ # ----------------------
+ csession = TestCheckedSession(a=a, a2=a2, a01=a01, d=d, e=e, g=g, f=f, h=h, meta=meta)
+ csession.save(fpath, engine=engine)
+ cs = TestCheckedSession()
+ names_to_load = ['e', 'h'] if is_excel_or_csv else ['a', 'a01', 'a2', 'e', 'h']
+ cs.load(fpath, names=names_to_load, engine=engine)
+ # --- keys ---
+ assert list(cs.keys()) == list(csession.keys())
+ # --- variables with default values ---
+ assert cs.b.equals(b)
+ assert cs.b024.equals(b024)
+ assert cs.anonymous.equals(anonymous)
+ assert cs.ano01.equals(ano01)
+ assert cs.d == d
+ # --- typed variables ---
+    # Array is supported by all formats
+ assert cs.e.equals(e)
+ assert isinstance(cs.g, NotLoaded)
+ assert isinstance(cs.f, NotLoaded)
+ assert cs.h.equals(h)
+ # Axis and Group are not supported by the Excel and CSV formats
+ if is_excel_or_csv:
+ assert isinstance(cs.a, NotLoaded)
+ assert isinstance(cs.a2, NotLoaded)
+ assert isinstance(cs.a01, NotLoaded)
+ else:
+ assert cs.a.equals(a)
+ assert cs.a2.equals(a2)
+ assert cs.a01.equals(a01)
+
+ return fpath
+
+
+@needs_pytables
+def test_h5_io_cs(tmpdir, meta):
+ _test_io_cs(tmpdir, meta, engine='pandas_hdf', ext='h5')
+
+
+@needs_openpyxl
+def test_xlsx_pandas_io_cs(tmpdir, meta):
+ _test_io_cs(tmpdir, meta, engine='pandas_excel', ext='xlsx')
+
+
+@needs_xlwings
+def test_xlsx_xlwings_io_cs(tmpdir, meta):
+ _test_io_cs(tmpdir, meta, engine='xlwings_excel', ext='xlsx')
+
+
+def test_csv_io_cs(tmpdir, meta):
+ _test_io_cs(tmpdir, meta, engine='pandas_csv', ext='csv')
+
+
+def test_pickle_io_cs(tmpdir, meta):
+ _test_io_cs(tmpdir, meta, engine='pickle', ext='pkl')
+
+
+def test_pickle_roundtrip_cs(checkedsession, meta):
+ cs = checkedsession
+ cs.meta = meta
+ s = pickle.dumps(cs)
+ res = pickle.loads(s)
+ assert res.equals(cs)
+ assert res.meta == meta
+
+
+def test_element_equals_cs(checkedsession):
+ test_element_equals(checkedsession)
+
+
+def test_eq_cs(checkedsession):
+ test_eq(checkedsession)
+
+
+def test_ne_cs(checkedsession):
+ test_ne(checkedsession)
+
+
+def test_sub_cs(checkedsession):
+ cs = checkedsession
+ session_cls = cs.__class__
+
+ # session - session
+ other = session_cls(a=a, a2=a2, a01=a01, e=e - 1, g=zeros_like(g), f=zeros_like(f), h=ones_like(h))
+ diff = cs - other
+ assert isinstance(diff, session_cls)
+ # --- non-array variables ---
+ assert diff.b is b
+ assert diff.b024 is b024
+ assert diff.a is a
+ assert diff.a2 is a2
+ assert diff.anonymous is anonymous
+ assert diff.a01 is a01
+ assert diff.ano01 is ano01
+ assert diff.c is c
+ assert diff.d is d
+ # --- array variables ---
+ assert_array_nan_equal(diff.e, np.full((2, 3), 1, dtype=np.int32))
+ assert_array_nan_equal(diff.g, g)
+ assert_array_nan_equal(diff.f, f)
+ assert_array_nan_equal(diff.h, h - ones_like(h))
+
+ # session - scalar
+ diff = cs - 2
+ assert isinstance(diff, session_cls)
+ # --- non-array variables ---
+ assert diff.b is b
+ assert diff.b024 is b024
+ assert diff.a is a
+ assert diff.a2 is a2
+ assert diff.anonymous is anonymous
+ assert diff.a01 is a01
+ assert diff.ano01 is ano01
+ assert diff.c is c
+ assert diff.d is d
+ # --- non constant arrays ---
+ assert_array_nan_equal(diff.e, e - 2)
+ assert_array_nan_equal(diff.g, g - 2)
+ assert_array_nan_equal(diff.f, f - 2)
+ assert_array_nan_equal(diff.h, h - 2)
+
+ # session - dict(Array and scalar)
+ other = {'e': ones_like(e), 'h': 1}
+ diff = cs - other
+ assert isinstance(diff, session_cls)
+ # --- non-array variables ---
+ assert diff.b is b
+ assert diff.b024 is b024
+ assert diff.a is a
+ assert diff.a2 is a2
+ assert diff.anonymous is anonymous
+ assert diff.a01 is a01
+ assert diff.ano01 is ano01
+ assert diff.c is c
+ assert diff.d is d
+ # --- non constant arrays ---
+ assert_array_nan_equal(diff.e, e - ones_like(e))
+ assert isnan(diff.g).all()
+ assert isnan(diff.f).all()
+ assert_array_nan_equal(diff.h, h - 1)
+
+ # session - array
+ axes = cs.h.axes
+ cs.e = ndtest(axes)
+ cs.g = ones_like(cs.h)
+ diff = cs - ones(axes)
+ assert isinstance(diff, session_cls)
+ # --- non-array variables ---
+ assert diff.b is b
+ assert diff.b024 is b024
+ assert diff.a is a
+ assert diff.a2 is a2
+ assert diff.anonymous is anonymous
+ assert diff.a01 is a01
+ assert diff.ano01 is ano01
+ assert diff.c is c
+ assert diff.d is d
+ # --- non constant arrays ---
+ assert_array_nan_equal(diff.e, cs.e - ones(axes))
+ assert_array_nan_equal(diff.g, cs.g - ones(axes))
+ assert isnan(diff.f).all()
+ assert_array_nan_equal(diff.h, cs.h - ones(axes))
+
+
+def test_rsub_cs(checkedsession):
+ cs = checkedsession
+ session_cls = cs.__class__
+
+ # scalar - session
+ diff = 2 - cs
+ assert isinstance(diff, session_cls)
+ # --- non-array variables ---
+ assert diff.b is b
+ assert diff.b024 is b024
+ assert diff.a is a
+ assert diff.a2 is a2
+ assert diff.anonymous is anonymous
+ assert diff.a01 is a01
+ assert diff.ano01 is ano01
+ assert diff.c is c
+ assert diff.d is d
+ # --- non constant arrays ---
+ assert_array_nan_equal(diff.e, 2 - e)
+ assert_array_nan_equal(diff.g, 2 - g)
+ assert_array_nan_equal(diff.f, 2 - f)
+ assert_array_nan_equal(diff.h, 2 - h)
+
+ # dict(Array and scalar) - session
+ other = {'e': ones_like(e), 'h': 1}
+ diff = other - cs
+ assert isinstance(diff, session_cls)
+ # --- non-array variables ---
+ assert diff.b is b
+ assert diff.b024 is b024
+ assert diff.a is a
+ assert diff.a2 is a2
+ assert diff.anonymous is anonymous
+ assert diff.a01 is a01
+ assert diff.ano01 is ano01
+ assert diff.c is c
+ assert diff.d is d
+ # --- non constant arrays ---
+ assert_array_nan_equal(diff.e, ones_like(e) - e)
+ assert isnan(diff.g).all()
+ assert isnan(diff.f).all()
+ assert_array_nan_equal(diff.h, 1 - h)
+
+
+def test_neg_cs(checkedsession):
+ cs = checkedsession
+ neg_cs = -cs
+ # --- non-array variables ---
+ assert isnan(neg_cs.b)
+ assert isnan(neg_cs.b024)
+ assert isnan(neg_cs.a)
+ assert isnan(neg_cs.a2)
+ assert isnan(neg_cs.anonymous)
+ assert isnan(neg_cs.a01)
+ assert isnan(neg_cs.ano01)
+ assert isnan(neg_cs.c)
+ assert isnan(neg_cs.d)
+ # --- non constant arrays ---
+ assert_array_nan_equal(neg_cs.e, -e)
+ assert_array_nan_equal(neg_cs.g, -g)
+ assert_array_nan_equal(neg_cs.f, -f)
+ assert_array_nan_equal(neg_cs.h, -h)
+
+
+if __name__ == "__main__":
+ pytest.main()
diff --git a/larray/tests/test_session.py b/larray/tests/test_session.py
index 264f7afdd..486df9421 100644
--- a/larray/tests/test_session.py
+++ b/larray/tests/test_session.py
@@ -12,8 +12,8 @@
from larray.tests.common import (assert_array_nan_equal, inputpath, tmp_path,
needs_xlwings, needs_pytables, needs_openpyxl, must_warn)
from larray.inout.common import _supported_scalars_types
-from larray import (Session, Axis, Array, Group, isnan, zeros_like, ndtest, ones_like, ones, full,
- local_arrays, global_arrays, arrays)
+from larray import (Session, Axis, Array, Group, isnan, zeros_like, ndtest, ones_like,
+ ones, full, full_like, stack, local_arrays, global_arrays, arrays)
# avoid flake8 errors
@@ -37,18 +37,25 @@ def assert_seq_equal(got, expected):
a = Axis('a=a0..a2')
a2 = Axis('a=a0..a4')
+a3 = Axis('a=a0..a3')
anonymous = Axis(4)
a01 = a['a0,a1'] >> 'a01'
ano01 = a['a0,a1']
b = Axis('b=0..4')
+b2 = Axis('b=b0..b4')
b024 = b[[0, 2, 4]] >> 'b024'
c = 'c'
d = {}
e = ndtest([(2, 'a'), (3, 'b')])
_e = ndtest((3, 3))
-f = ndtest((Axis(3), Axis(2)))
+f = ndtest((Axis(3), Axis(2)), dtype=float)
g = ndtest([(2, 'a'), (4, 'b')])
-h = ndtest(('a=a0..a2', 'b=b0..b4'))
+h = ndtest((a3, b2))
+k = ndtest((3, 3))
+
+# ########################### #
+# SESSION #
+# ########################### #
@pytest.fixture()
@@ -58,12 +65,12 @@ def session():
def test_init_session(meta):
- s = Session(b, b024, a, a01, a2=a2, anonymous=anonymous, ano01=ano01, c=c, d=d, e=e, f=f, g=g, h=h)
- assert s.names == ['a', 'a01', 'a2', 'ano01', 'anonymous', 'b', 'b024', 'c', 'd', 'e', 'f', 'g', 'h']
+ s = Session(b, b024, a, a01, a2=a2, anonymous=anonymous, ano01=ano01, c=c, d=d, e=e, g=g, f=f, h=h)
+ assert list(s.keys()) == ['b', 'b024', 'a', 'a01', 'a2', 'anonymous', 'ano01', 'c', 'd', 'e', 'g', 'f', 'h']
# TODO: format auto-detection does not work in this case
# s = Session('test_session_csv')
- # assert s.names == ['e', 'f', 'g']
+ # assert list(s.keys()) == ['e', 'f', 'g']
# metadata
s = Session(b, b024, a, a01, a2=a2, anonymous=anonymous, ano01=ano01, c=c, d=d, e=e, f=f, g=g, h=h, meta=meta)
@@ -73,14 +80,14 @@ def test_init_session(meta):
@needs_xlwings
def test_init_session_xlsx():
s = Session(inputpath('demography_eurostat.xlsx'))
- assert s.names == ['births', 'deaths', 'immigration', 'population',
- 'population_5_countries', 'population_benelux']
+ assert list(s.keys()) == ['population', 'population_benelux', 'population_5_countries',
+ 'births', 'deaths', 'immigration']
@needs_pytables
def test_init_session_hdf():
s = Session(inputpath('test_session.h5'))
- assert s.names == ['e', 'f', 'g']
+ assert list(s.keys()) == ['e', 'f', 'g', 'h', 'a', 'a2', 'anonymous', 'b', 'a01', 'ano01', 'b024']
def test_getitem(session):
@@ -175,13 +182,16 @@ def test_names(session):
assert session.names == ['a', 'a01', 'a2', 'ano01', 'anonymous', 'b', 'b024',
'c', 'd', 'e', 'f', 'g', 'h']
# add them in the "wrong" order
- session.add(i='i')
session.add(j='j')
+ session.add(i='i')
assert session.names == ['a', 'a01', 'a2', 'ano01', 'anonymous', 'b', 'b024',
'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j']
-def _test_io(fpath, session, meta, engine):
+def _test_io(tmpdir, session, meta, engine, ext):
+ filename = f"test_{engine}.{ext}" if 'csv' not in engine else f"test_{engine}{ext}"
+ fpath = tmp_path(tmpdir, filename)
+
is_excel_or_csv = 'excel' in engine or 'csv' in engine
kind = Array if is_excel_or_csv else (Axis, Group, Array) + _supported_scalars_types
@@ -203,21 +213,22 @@ def _test_io(fpath, session, meta, engine):
assert s.meta == meta
# update a Group + an Axis + an array (overwrite=False)
- a3 = Axis('a=0..3')
- a3_01 = a3['0,1'] >> 'a01'
- e2 = ndtest((a3, 'b=b0..b2'))
- Session(a=a3, a01=a3_01, e=e2).save(fpath, overwrite=False, engine=engine)
+ a4 = Axis('a=0..3')
+ a4_01 = a3['0,1'] >> 'a01'
+ e2 = ndtest((a4, 'b=b0..b2'))
+ h2 = full_like(h, fill_value=10)
+ Session(a=a4, a01=a4_01, e=e2, h=h2).save(fpath, overwrite=False, engine=engine)
s = Session()
s.load(fpath, engine=engine)
if engine == 'pandas_excel':
# Session.save() via engine='pandas_excel' always overwrite the output Excel files
- assert s.names == ['e']
+ assert s.names == ['e', 'h']
elif is_excel_or_csv:
assert s.names == ['e', 'f', 'g', 'h']
else:
assert s.names == session.names
- assert s['a'].equals(a3)
- assert s['a01'].equals(a3_01)
+ assert s['a'].equals(a4)
+ assert s['a01'].equals(a4_01)
assert_array_nan_equal(s['e'], e2)
if engine != 'pandas_excel':
assert s.meta == meta
@@ -225,12 +236,14 @@ def _test_io(fpath, session, meta, engine):
# load only some objects
session.save(fpath, engine=engine)
s = Session()
- names_to_load = ['e', 'f'] if is_excel_or_csv else ['a', 'a01', 'a2', 'anonymous', 'e', 'f']
+ names_to_load = ['e', 'f'] if is_excel_or_csv else ['a', 'a01', 'a2', 'anonymous', 'e', 'f', 's_bool', 's_int']
s.load(fpath, names=names_to_load, engine=engine)
assert s.names == names_to_load
if engine != 'pandas_excel':
assert s.meta == meta
+ return fpath
+
def _add_scalars_to_session(s):
# 's' for scalar
@@ -247,7 +260,6 @@ def _add_scalars_to_session(s):
@needs_pytables
def test_h5_io(tmpdir, session, meta):
session = _add_scalars_to_session(session)
- fpath = tmp_path(tmpdir, 'test_session.h5')
msg = "\nyour performance may suffer as PyTables will pickle object types"
regex = re.compile(msg, flags=re.MULTILINE)
@@ -255,27 +267,24 @@ def test_h5_io(tmpdir, session, meta):
# for some reason the PerformanceWarning is not detected as such, so this does not work:
# with pytest.warns(tables.PerformanceWarning):
with pytest.warns(Warning, match=regex):
- _test_io(fpath, session, meta, engine='pandas_hdf')
+ _test_io(tmpdir, session, meta, engine='pandas_hdf', ext='h5')
@needs_openpyxl
def test_xlsx_pandas_io(tmpdir, session, meta):
- fpath = tmp_path(tmpdir, 'test_session.xlsx')
- _test_io(fpath, session, meta, engine='pandas_excel')
+ _test_io(tmpdir, session, meta, engine='pandas_excel', ext='xlsx')
@needs_xlwings
def test_xlsx_xlwings_io(tmpdir, session, meta):
- fpath = tmp_path(tmpdir, 'test_session.xlsx')
- _test_io(fpath, session, meta, engine='xlwings_excel')
+ _test_io(tmpdir, session, meta, engine='xlwings_excel', ext='xlsx')
def test_csv_io(tmpdir, session, meta):
- fpath = tmp_path(tmpdir, 'test_session_csv')
try:
- _test_io(fpath, session, meta, engine='pandas_csv')
+ fpath = _test_io(tmpdir, session, meta, engine='pandas_csv', ext='csv')
- names = session.filter(kind=Array).names
+ names = Session({k: v for k, v in session.items() if isinstance(v, Array)}).names
# test loading with a pattern
pattern = os.path.join(fpath, '*.csv')
@@ -303,8 +312,16 @@ def test_csv_io(tmpdir, session, meta):
def test_pickle_io(tmpdir, session, meta):
session = _add_scalars_to_session(session)
- fpath = tmp_path(tmpdir, 'test_session.pkl')
- _test_io(fpath, session, meta, engine='pickle')
+ _test_io(tmpdir, session, meta, engine='pickle', ext='pkl')
+
+
+def test_pickle_roundtrip(session, meta):
+ original = session.filter(kind=Array)
+ original.meta = meta
+ s = pickle.dumps(original)
+ res = pickle.loads(s)
+ assert res.equals(original)
+ assert res.meta == meta
def test_to_globals(session):
@@ -337,84 +354,128 @@ def test_to_globals(session):
def test_element_equals(session):
- sess = session.filter(kind=(Axis, Group, Array))
- expected = Session([('b', b), ('b024', b024), ('a', a), ('a2', a2), ('anonymous', anonymous),
- ('a01', a01), ('ano01', ano01), ('e', e), ('g', g), ('f', f), ('h', h)])
- assert all(sess.element_equals(expected))
-
- other = Session([('a', a), ('a2', a2), ('anonymous', anonymous),
- ('a01', a01), ('ano01', ano01), ('e', e), ('f', f), ('h', h)])
- res = sess.element_equals(other)
- assert res.ndim == 1
- assert res.axes.names == ['name']
- assert np.array_equal(res.axes.labels[0], ['b', 'b024', 'a', 'a2', 'anonymous', 'a01', 'ano01',
- 'e', 'g', 'f', 'h'])
- assert list(res) == [False, False, True, True, True, True, True, True, False, True, True]
-
- e2 = e.copy()
- e2.i[1, 1] = 42
- other = Session([('a', a), ('a2', a2), ('anonymous', anonymous),
- ('a01', a01), ('ano01', ano01), ('e', e2), ('f', f), ('h', h)])
- res = sess.element_equals(other)
- assert res.axes.names == ['name']
- assert np.array_equal(res.axes.labels[0], ['b', 'b024', 'a', 'a2', 'anonymous', 'a01', 'ano01',
- 'e', 'g', 'f', 'h'])
- assert list(res) == [False, False, True, True, True, True, True, False, False, True, True]
+ session_cls = session.__class__
+ other_session = session_cls([(key, value) for key, value in session.items()])
+
+ keys = [key for key, value in session.items() if isinstance(value, (Axis, Group, Array))]
+ expected_res = full(Axis(keys, 'name'), fill_value=True, dtype=bool)
+
+ # ====== same sessions ======
+ res = session.element_equals(other_session)
+ assert res.axes == expected_res.axes
+ assert res.equals(expected_res)
+
+ # ====== session with missing/extra items ======
+ # delete some items
+ for deleted_key in ['b', 'b024', 'g']:
+ del other_session[deleted_key]
+ expected_res[deleted_key] = False
+ # add one item
+ other_session['k'] = k
+ expected_res = expected_res.append('name', False, label='k')
+
+ res = session.element_equals(other_session)
+ assert res.axes == expected_res.axes
+ assert res.equals(expected_res)
+
+ # ====== session with a modified array ======
+ h2 = h.copy()
+ h2['a1', 'b1'] = 42
+ other_session['h'] = h2
+ expected_res['h'] = False
+
+ res = session.element_equals(other_session)
+ assert res.axes == expected_res.axes
+ assert res.equals(expected_res)
+
+
+def to_boolean_array_eq(res):
+ return stack([(key, item.all() if isinstance(item, Array) else item)
+ for key, item in res.items()], 'name')
def test_eq(session):
- sess = session.filter(kind=(Axis, Group, Array))
- expected = Session([('b', b), ('b024', b024), ('a', a), ('a2', a2), ('anonymous', anonymous),
- ('a01', a01), ('ano01', ano01), ('e', e), ('g', g), ('f', f), ('h', h)])
- assert all([item.all() if isinstance(item, Array) else item
- for item in (sess == expected).values()])
-
- other = Session([('b', b), ('b024', b024), ('a', a), ('a2', a2), ('anonymous', anonymous),
- ('a01', a01), ('ano01', ano01), ('e', e), ('f', f), ('h', h)])
- res = sess == other
- assert list(res.keys()) == ['b', 'b024', 'a', 'a2', 'anonymous', 'a01', 'ano01',
- 'e', 'g', 'f', 'h']
- assert [item.all() if isinstance(item, Array) else item
- for item in res.values()] == [True, True, True, True, True, True, True, True, False, True, True]
-
- e2 = e.copy()
- e2.i[1, 1] = 42
- other = Session([('b', b), ('b024', b024), ('a', a), ('a2', a2), ('anonymous', anonymous),
- ('a01', a01), ('ano01', ano01), ('e', e2), ('f', f), ('h', h)])
- res = sess == other
- assert [item.all() if isinstance(item, Array) else item
- for item in res.values()] == [True, True, True, True, True, True, True, False, False, True, True]
+ session_cls = session.__class__
+ other_session = session_cls([(key, value) for key, value in session.items()])
+ expected_res = full(Axis(list(session.keys()), 'name'), fill_value=True, dtype=bool)
+
+ # ====== same sessions ======
+ res = session == other_session
+ res = to_boolean_array_eq(res)
+ assert res.axes == expected_res.axes
+ assert res.equals(expected_res)
+
+ # ====== session with missing/extra items ======
+ del other_session['g']
+ expected_res['g'] = False
+ other_session['k'] = k
+ expected_res = expected_res.append('name', False, label='k')
+
+ res = session == other_session
+ res = to_boolean_array_eq(res)
+ assert res.axes == expected_res.axes
+ assert res.equals(expected_res)
+
+ # ====== session with a modified array ======
+ h2 = h.copy()
+ h2['a1', 'b1'] = 42
+ other_session['h'] = h2
+ expected_res['h'] = False
+
+ res = session == other_session
+ assert res['h'].equals(session['h'] == other_session['h'])
+ res = to_boolean_array_eq(res)
+ assert res.axes == expected_res.axes
+ assert res.equals(expected_res)
+
+
+def to_boolean_array_ne(res):
+ return stack([(key, item.any() if isinstance(item, Array) else item)
+ for key, item in res.items()], 'name')
def test_ne(session):
- sess = session.filter(kind=(Axis, Group, Array))
- expected = Session([('b', b), ('b024', b024), ('a', a), ('a2', a2), ('anonymous', anonymous),
- ('a01', a01), ('ano01', ano01), ('e', e), ('g', g), ('f', f), ('h', h)])
- assert ([(~item).all() if isinstance(item, Array) else not item
- for item in (sess != expected).values()])
-
- other = Session([('b', b), ('b024', b024), ('a', a), ('a2', a2), ('anonymous', anonymous),
- ('a01', a01), ('ano01', ano01), ('e', e), ('f', f), ('h', h)])
- res = sess != other
- assert list(res.keys()) == ['b', 'b024', 'a', 'a2', 'anonymous', 'a01', 'ano01',
- 'e', 'g', 'f', 'h']
- assert [(~item).all() if isinstance(item, Array) else not item
- for item in res.values()] == [True, True, True, True, True, True, True, True, False, True, True]
-
- e2 = e.copy()
- e2.i[1, 1] = 42
- other = Session([('b', b), ('b024', b024), ('a', a), ('a2', a2), ('anonymous', anonymous),
- ('a01', a01), ('ano01', ano01), ('e', e2), ('f', f), ('h', h)])
- res = sess != other
- assert [(~item).all() if isinstance(item, Array) else not item
- for item in res.values()] == [True, True, True, True, True, True, True, False, False, True, True]
+ session_cls = session.__class__
+ other_session = session_cls([(key, value) for key, value in session.items()])
+ expected_res = full(Axis(list(session.keys()), 'name'), fill_value=False, dtype=bool)
+
+ # ====== same sessions ======
+ res = session != other_session
+ res = to_boolean_array_ne(res)
+ assert res.axes == expected_res.axes
+ assert res.equals(expected_res)
+
+ # ====== session with missing/extra items ======
+ del other_session['g']
+ expected_res['g'] = True
+ other_session['k'] = k
+ expected_res = expected_res.append('name', True, label='k')
+
+ res = session != other_session
+ res = to_boolean_array_ne(res)
+ assert res.axes == expected_res.axes
+ assert res.equals(expected_res)
+
+ # ====== session with a modified array ======
+ h2 = h.copy()
+ h2['a1', 'b1'] = 42
+ other_session['h'] = h2
+ expected_res['h'] = True
+
+ res = session != other_session
+ assert res['h'].equals(session['h'] != other_session['h'])
+ res = to_boolean_array_ne(res)
+ assert res.axes == expected_res.axes
+ assert res.equals(expected_res)
def test_sub(session):
sess = session
# session - session
- other = Session({'e': e - 1, 'f': ones_like(f)})
+ other = Session({'e': e, 'f': f})
+ other['e'] = e - 1
+ other['f'] = ones_like(f)
diff = sess - other
assert_array_nan_equal(diff['e'], np.full((2, 3), 1, dtype=np.int32))
assert_array_nan_equal(diff['f'], f - ones_like(f))
@@ -444,12 +505,12 @@ def test_sub(session):
# session - array
axes = [a, b]
- sess = Session([('a', a), ('a01', a01), ('c', c), ('e', ndtest(axes)),
- ('f', full(axes, fill_value=3)), ('g', ndtest('c=c0..c2'))])
- diff = sess - ones(axes)
- assert_array_nan_equal(diff['e'], sess['e'] - ones(axes))
- assert_array_nan_equal(diff['f'], sess['f'] - ones(axes))
- assert_array_nan_equal(diff['g'], sess['g'] - ones(axes))
+ other = Session([('a', a), ('a01', a01), ('c', c), ('e', ndtest((a, b))),
+ ('f', full((a, b), fill_value=3)), ('g', ndtest('c=c0..c2'))])
+ diff = other - ones(axes)
+ assert_array_nan_equal(diff['e'], other['e'] - ones(axes))
+ assert_array_nan_equal(diff['f'], other['f'] - ones(axes))
+ assert_array_nan_equal(diff['g'], other['g'] - ones(axes))
assert diff.a is a
assert diff.a01 is a01
assert diff.c is c
@@ -480,7 +541,11 @@ def test_rsub(session):
def test_div(session):
sess = session
- other = Session({'e': e - 1, 'f': f + 1})
+ session_cls = session.__class__
+
+ other = session_cls({'e': e, 'f': f})
+ other['e'] = e - 1
+ other['f'] = f + 1
with must_warn(RuntimeWarning, msg="divide by zero encountered during operation"):
res = sess / other
@@ -527,15 +592,6 @@ def test_rdiv(session):
assert res.c is c
-def test_pickle_roundtrip(session, meta):
- original = session.filter(kind=Array)
- original.meta = meta
- s = pickle.dumps(original)
- res = pickle.loads(s)
- assert res.equals(original)
- assert res.meta == meta
-
-
def test_local_arrays():
h = ndtest(2)
_h = ndtest(3)
@@ -554,12 +610,12 @@ def test_local_arrays():
def test_global_arrays():
# exclude private global arrays
s = global_arrays()
- s_expected = Session([('e', e), ('f', f), ('g', g), ('h', h)])
+ s_expected = Session([('e', e), ('f', f), ('g', g), ('h', h), ('k', k)])
assert s.equals(s_expected)
# all global arrays
s = global_arrays(include_private=True)
- s_expected = Session([('e', e), ('_e', _e), ('f', f), ('g', g), ('h', h)])
+ s_expected = Session([('e', e), ('_e', _e), ('f', f), ('g', g), ('h', h), ('k', k)])
assert s.equals(s_expected)
@@ -569,12 +625,12 @@ def test_arrays():
# exclude private arrays
s = arrays()
- s_expected = Session([('e', e), ('f', f), ('g', g), ('h', h), ('i', i)])
+ s_expected = Session([('e', e), ('f', f), ('g', g), ('h', h), ('i', i), ('k', k)])
assert s.equals(s_expected)
# all arrays
s = arrays(include_private=True)
- s_expected = Session([('_e', _e), ('_i', _i), ('e', e), ('f', f), ('g', g), ('h', h), ('i', i)])
+ s_expected = Session([('_e', _e), ('_i', _i), ('e', e), ('f', f), ('g', g), ('h', h), ('i', i), ('k', k)])
assert s.equals(s_expected)
diff --git a/make_release.py b/make_release.py
index a01d2c92b..1b6575228 100644
--- a/make_release.py
+++ b/make_release.py
@@ -39,9 +39,11 @@ def update_metapackage(local_repository, release_name, public_release=True, **ex
print(f'Updating larrayenv metapackage to version {version}')
# - excluded versions 5.0 and 5.1 of ipykernel because these versions make the console useless after any exception
# https://github.com/larray-project/larray-editor/issues/166
+ # - pydantic: cannot define numpy ndarray / pandas obj / LArray field with default value
+ # since version 1.6
check_call(['conda', 'metapackage', 'larrayenv', version, '--dependencies', f'larray =={version}',
f'larray-editor =={version}', f'larray_eurostat =={version}',
- "qtconsole", "matplotlib", "pyqt", "qtpy", "pytables",
+ "qtconsole", "matplotlib", "pyqt", "qtpy", "pytables", "pydantic <=1.5",
"xlsxwriter", "xlrd", "xlwt", "openpyxl", "xlwings", "ipykernel !=5.0,!=5.1.0",
'--user', 'larray-project',
'--home', 'http://github.com/larray-project/larray',
diff --git a/setup.py b/setup.py
index 64911a30f..6f0cbcb77 100644
--- a/setup.py
+++ b/setup.py
@@ -1,5 +1,3 @@
-from __future__ import print_function
-
import os
from setuptools import setup, find_packages