Skip to content

Commit 92fe417

Browse files
committed
basic port, no new functionality but all tests pass
1 parent feffa7a commit 92fe417

File tree

1 file changed

+132
-37
lines changed

1 file changed

+132
-37
lines changed

larray.py

+132-37
Original file line numberDiff line numberDiff line change
@@ -496,25 +496,79 @@ def __repr__(self):
496496
return "%s[%s]" % (self.axis.name, self.name)
497497

498498

499-
class LArray(np.ndarray):
500-
def __new__(cls, data, axes=None):
501-
obj = np.asarray(data).view(cls)
502-
ndim = obj.ndim
499+
def unaryop(opname):
    """Build an LArray method implementing the unary operator *opname*.

    *opname* is the bare operator name ('neg', 'pos', 'abs', 'invert', ...).
    The generated method fetches the matching ``__opname__`` special method
    from the plain numpy version of the array, calls it without arguments
    and wraps the result in a new LArray carrying the operand's axes.
    """
    def func(self):
        raw = np.asarray(self)
        special = getattr(raw, '__%s__' % opname)
        return LArray(special(), self.axes)
    return func
504+
505+
506+
def binop(opname, reversed=False):
    """Build an LArray method implementing the binary operator *opname*.

    opname   -- bare operator name ('add', 'lt', 'divmod', ...); the actual
                work is delegated to the ``__opname__`` special method of
                the numpy array obtained from the left operand.
    reversed -- when true, the generated method implements the reflected
                variant (``__radd__`` and friends): the two operands are
                swapped before the numpy method is called.

    The generated method accepts ``(self, other)`` where *other* must be a
    scalar or an LArray with the same axes as *self*. It raises TypeError
    for any other operand type and ValueError when the axes differ.

    The result is an LArray using the axes of *self*. Operators for which
    numpy returns a tuple of arrays (divmod) yield a tuple of LArrays, one
    per element, each with those same axes.
    """
    def opmethod(self, other):
        other_is_larray = isinstance(other, LArray)
        if not other_is_larray and not np.isscalar(other):
            raise TypeError("unsupported operand type(s) for %s: '%s' and '%s'"
                            % (opname, type(self), type(other)))
        if other_is_larray and self.axes != other.axes:
            raise ValueError('axes not compatible')
        # the axes always come from the LArray the method was invoked on,
        # whichever side of the operator it ends up on
        axes = self.axes
        left, right = (other, self) if reversed else (self, other)
        method = getattr(np.asarray(left), '__%s__' % opname)
        values = method(np.asarray(right))
        if isinstance(values, tuple):
            # ndarray.__divmod__ returns (quotient, remainder); feeding the
            # tuple itself to LArray would add a dimension and fail the
            # axes check, so wrap each part on its own
            return tuple(LArray(part, axes) for part in values)
        return LArray(values, axes)
    return opmethod
520+
521+
522+
class LArray(object):
523+
def __init__(self, data, axes=None, name=None):
524+
array = np.asarray(data)
525+
ndim = array.ndim
503526
if axes is not None:
504527
if len(axes) != ndim:
505528
raise ValueError("number of axes (%d) does not match "
506529
"number of dimensions of data (%d)"
507530
% (len(axes), ndim))
508531
shape = tuple(len(axis) for axis in axes)
509-
if shape != obj.shape:
532+
if shape != array.shape:
510533
raise ValueError("length of axes %s does not match "
511-
"data shape %s" % (shape, obj.shape))
534+
"data shape %s" % (shape, array.shape))
512535

513-
if axes is not None and not isinstance(axes, list):
536+
if isinstance(axes, tuple):
514537
axes = list(axes)
515-
obj.axes = axes
516-
return obj
517-
538+
self.axes = axes
539+
if False:
540+
# record/structured array => DataFrame
541+
542+
# to do this (create a dataframe for non-record arrays), we could
543+
# simply create a Series and do s.unstack()
544+
# but I am not sure it makes sense to create a DF for non record
545+
# arrays.
546+
axes_labels = [a.labels for a in axes[:-1]]
547+
axes_names = [a.name for a in axes[:-1]]
548+
axes_names[-1] = axes_names[-1] + '\\' + axes[-1].name
549+
columns = axes[-1].labels.tolist()
550+
full_index = list(product(*axes_labels))
551+
index = pd.MultiIndex.from_tuples(full_index, names=axes_names)
552+
self.data = pd.DataFrame(self.reshape(len(full_index), len(columns)), index, columns)
553+
else:
554+
# homogeneous => Series
555+
axes_names = [a.name for a in axes]
556+
full_index = list(product(*[a.labels for a in axes]))
557+
index = pd.MultiIndex.from_tuples(full_index, names=axes_names)
558+
self.data = pd.Series(array.ravel(), index=index, name=name)
559+
560+
@property
561+
def shape(self):
562+
return tuple(len(axis) for axis in self.axes)
563+
564+
@property
565+
def ndim(self):
566+
return len(self.axes)
567+
568+
@property
569+
def dtype(self):
570+
return self.data.dtype
571+
518572
def as_dataframe(self):
519573
axes_labels = [a.labels.tolist() for a in self.axes[:-1]]
520574
axes_names = [a.name for a in self.axes[:-1]]
@@ -525,31 +579,16 @@ def as_dataframe(self):
525579
df = pd.DataFrame(self.reshape(len(full_index), len(columns)), index, columns)
526580
return df
527581

528-
529-
#noinspection PyAttributeOutsideInit
530-
def __array_finalize__(self, obj):
531-
# We are in the middle of the LabeledArray.__new__ constructor,
532-
# and our special attributes will be set when we return to that
533-
# constructor, so we do not need to set them here.
534-
if obj is None:
535-
return
536-
537-
# obj is our "template" object (on which we have asked a view on).
538-
if isinstance(obj, LArray) and self.shape == obj.shape:
539-
# obj.view(LArray)
540-
# larr[:3]
541-
self.axes = obj.axes
542-
else:
543-
self.axes = None
544-
#self.row_totals = None
545-
#self.col_totals = None
546-
547582
@property
548583
def is_aggregated(self):
549584
return any(axis.is_aggregated for axis in self.axes)
550585

586+
def __setitem__(self, key, value):
587+
#FIXME: allow label keys
588+
self.asarray()[key] = value
589+
551590
def __getitem__(self, key, collapse_slices=False):
552-
data = np.asarray(self)
591+
data = self.asarray()
553592

554593
# convert scalar keys to 1D keys
555594
if not isinstance(key, tuple):
@@ -634,7 +673,8 @@ def convert(axis, values):
634673
axes = [axis.subaxis(axis_key)
635674
for axis, axis_key in zip(self.axes, translated_key)
636675
if not np.isscalar(axis_key)]
637-
return LArray(data[full_key], axes)
676+
res_data = data[full_key]
677+
return LArray(res_data, axes) if axes else res_data
638678

639679
# deprecated since Python 2.0 but we need to define it to catch "simple"
640680
# slices (with integer bounds !) because ndarray is a "builtin" type
@@ -718,7 +758,7 @@ def _axis_aggregate(self, op, axes):
718758
"""
719759
src_data = np.asarray(self)
720760
if not axes:
721-
# scalars don't need to be wrapped in LArray
761+
# scalars do not need to be wrapped in LArray
722762
return op(src_data)
723763

724764
# we need to search for the axis by name, instead of the axis object
@@ -729,6 +769,9 @@ def _axis_aggregate(self, op, axes):
729769
axes_tokill = set(axes_indices)
730770
res_axes = [axis for axis_num, axis in enumerate(self.axes)
731771
if axis_num not in axes_tokill]
772+
if not res_axes:
773+
# scalars do not need to be wrapped in LArray
774+
return res_data
732775
return LArray(res_data, res_axes)
733776

734777
def _get_axis(self, name):
@@ -823,7 +866,7 @@ def _aggregate(self, op, args, kwargs, commutative=False):
823866
return self._axis_aggregate(op, axes=args)
824867

825868
def copy(self):
826-
return LArray(np.ndarray.copy(self), axes=self.axes[:])
869+
return LArray(self.asarray().copy(), axes=self.axes[:])
827870

828871
def info(self):
829872
axes_labels = [' '.join(repr(label) for label in axis.labels)
@@ -843,8 +886,8 @@ def sum(self, *args, **kwargs):
843886
#XXX: sep argument does not seem very useful
844887
#XXX: use pandas function instead?
845888
def to_excel(self, filename, sep=None):
846-
# Why xlsxwriter? Because it is faster than openpyxl and xlwt
847-
# currently does not .xlsx (only .xls).
889+
# Why xlsxwriter? Because it is faster than openpyxl and
890+
# xlwt does not currently support .xlsx (only .xls).
848891
# PyExcelerate seems like a decent alternative too
849892
import xlsxwriter as xl
850893

@@ -884,12 +927,64 @@ def transpose(self, *args):
884927
if axis.name not in axes_names]
885928
res_axes = list(args) + missing_axes
886929
axes_indices = [self._get_axis_idx(axis.name) for axis in res_axes]
887-
src_data = np.asarray(self)
888-
res_data = src_data.transpose(axes_indices)
930+
res_data = self.asarray().transpose(axes_indices)
889931
return LArray(res_data, res_axes)
890932
#XXX: is this necessary?
891933
reorder = transpose
892934

935+
def __array__(self):
936+
return np.asarray(self.data).reshape(self.shape)
937+
asarray = __array__
938+
939+
#FIXME: this makes stuff fail
940+
# def __getattr__(self, key):
941+
# return getattr(self.data, key)
942+
943+
# Operator support: every special method below is generated by binop()
# (binary operators, reversed=True for the reflected variant) or unaryop().

# rich comparisons
__lt__ = binop('lt')
__le__ = binop('le')
__eq__ = binop('eq')
__ne__ = binop('ne')
__gt__ = binop('gt')
__ge__ = binop('ge')

# addition, subtraction, multiplication
__add__ = binop('add')
__radd__ = binop('add', reversed=True)
__sub__ = binop('sub')
__rsub__ = binop('sub', reversed=True)
__mul__ = binop('mul')
__rmul__ = binop('mul', reversed=True)

# division family (__div__/__rdiv__ are the Python 2 classic division hooks)
__div__ = binop('div')
__rdiv__ = binop('div', reversed=True)
__truediv__ = binop('truediv')
__rtruediv__ = binop('truediv', reversed=True)
__floordiv__ = binop('floordiv')
__rfloordiv__ = binop('floordiv', reversed=True)

# modulo and power
__mod__ = binop('mod')
__rmod__ = binop('mod', reversed=True)
__divmod__ = binop('divmod')
__rdivmod__ = binop('divmod', reversed=True)
__pow__ = binop('pow')
__rpow__ = binop('pow', reversed=True)

# bit shifts
__lshift__ = binop('lshift')
__rlshift__ = binop('lshift', reversed=True)
__rshift__ = binop('rshift')
__rrshift__ = binop('rshift', reversed=True)

# bitwise / logical operators
__and__ = binop('and')
__rand__ = binop('and', reversed=True)
__xor__ = binop('xor')
__rxor__ = binop('xor', reversed=True)
__or__ = binop('or')
__ror__ = binop('or', reversed=True)

# unary operators
__neg__ = unaryop('neg')
__pos__ = unaryop('pos')
__abs__ = unaryop('abs')
__invert__ = unaryop('invert')
987+
893988
def ToCsv(self, filename):
894989
res = table2csv(self.as_table(), ',', 'nan')
895990
f = open(filename, "w")

0 commit comments

Comments
 (0)