Skip to content

Commit c6b4a23

Browse files
committed
fix #14 : horizontal labels accept more than 1 dimension
1 parent 4d470ec commit c6b4a23

File tree

3 files changed

+107
-68
lines changed

3 files changed

+107
-68
lines changed

larray_editor/arrayadapter.py

+49-37
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,15 @@
66

77

88
class LArrayDataAdapter(object):
9-
def __init__(self, axes_model, hlabels_model, vlabels_model, data_model,
10-
data=None, changes=None, current_filter=None, bg_gradient=None, bg_value=None):
9+
def __init__(self, axes_model, hlabels_model, vlabels_model, data_model, data=None,
10+
changes=None, current_filter=None, nb_dims_hlabels=1, bg_gradient=None, bg_value=None):
1111
# set models
1212
self.axes_model = axes_model
1313
self.hlabels_model = hlabels_model
1414
self.vlabels_model = vlabels_model
1515
self.data_model = data_model
16+
# set number of dims of hlabels
17+
self.nb_dims_hlabels = nb_dims_hlabels
1618
# set current filter
1719
if current_filter is None:
1820
current_filter = {}
@@ -31,38 +33,43 @@ def set_changes(self, changes=None):
3133
assert isinstance(changes, dict)
3234
self.changes = changes
3335

36+
def update_nb_dims_hlabels(self, nb_dims_hlabels):
37+
self.nb_dims_hlabels = nb_dims_hlabels
38+
self.update_axes_and_labels()
39+
3440
def get_axes_names(self):
3541
return self.filtered_data.axes.display_names
3642

3743
def get_axes(self):
38-
axes = self.filtered_data.axes
39-
if len(axes) == 0:
40-
return None
41-
else:
42-
axes_names = axes.display_names
43-
if len(axes_names) >= 2:
44-
axes_names = axes_names[:-2] + [axes_names[-2] + '\\' + axes_names[-1]]
45-
return [[axis_name] for axis_name in axes_names]
46-
47-
def get_hlabels(self):
48-
axes = self.filtered_data.axes
49-
if len(axes) == 0:
44+
axes_names = self.filtered_data.axes.display_names
45+
if len(axes_names) == 0:
5046
return None
51-
elif len(axes.labels[-1]) == 0:
52-
return [['']]
47+
elif len(axes_names) == 1:
48+
return [axes_names]
5349
else:
54-
return [[label] for label in axes.labels[-1]]
50+
nb_dims_vlabels = len(axes_names) - self.nb_dims_hlabels
51+
# axes corresponding to horizontal labels are set to the last column
52+
res = [['' for c in range(nb_dims_vlabels-1)] + [axis_name] for axis_name in axes_names[nb_dims_vlabels:]]
53+
# axes corresponding to vertical labels are set to the last row
54+
res = res + [[axis_name for axis_name in axes_names[:nb_dims_vlabels]]]
55+
return res
5556

56-
def get_vlabels(self):
57+
def get_labels(self):
5758
axes = self.filtered_data.axes
58-
if len(axes) == 0:
59-
return None
60-
elif len(axes) == 1:
61-
return [['']]
62-
else:
63-
labels = axes.labels[:-1]
64-
prod = Product(labels)
65-
return [_LazyDimLabels(prod, i) for i in range(len(labels))]
59+
nb_dims_vlabels = len(axes) - self.nb_dims_hlabels
60+
def get_labels_product(axes, extra_row=False):
61+
if len(axes) == 0:
62+
return None
63+
else:
64+
# XXX: appends a fake axis instead of using _LazyNone because
65+
# _LazyNone mess up with LabelsArrayModel.get_values (in which slices are used)
66+
if extra_row:
67+
axes.append(la.Axis([' ']))
68+
prod = Product(axes.labels)
69+
return [_LazyDimLabels(prod, i) for i in range(len(axes.labels))]
70+
vlabels = get_labels_product(axes[:nb_dims_vlabels])
71+
hlabels = get_labels_product(axes[nb_dims_vlabels:], nb_dims_vlabels > 0)
72+
return vlabels, hlabels
6673

6774
def get_2D_data(self):
6875
"""Returns Numpy 2D ndarray"""
@@ -110,24 +117,29 @@ def set_data(self, data, bg_gradient=None, bg_value=None, current_filter=None):
110117
self.bg_gradient = bg_gradient
111118
self.update_filtered_data(current_filter)
112119

120+
def update_axes_and_labels(self):
121+
axes = self.get_axes()
122+
vlabels, hlabels = self.get_labels()
123+
self.axes_model.set_data(axes)
124+
self.hlabels_model.set_data(hlabels)
125+
self.vlabels_model.set_data(vlabels)
126+
127+
def update_data_2D(self):
128+
data_2D = self.get_2D_data()
129+
changes_2D = self.get_changes_2D()
130+
bg_value_2D = self.get_bg_value_2D(data_2D.shape)
131+
self.data_model.set_data(data_2D, changes_2D)
132+
self.data_model.set_background(self.bg_gradient, bg_value_2D)
133+
113134
def update_filtered_data(self, current_filter=None):
114135
if current_filter is not None:
115136
assert isinstance(current_filter, dict)
116137
self.current_filter = current_filter
117138
self.filtered_data = self.la_data[self.current_filter]
118139
if np.isscalar(self.filtered_data):
119140
self.filtered_data = la.aslarray(self.filtered_data)
120-
axes = self.get_axes()
121-
hlabels = self.get_hlabels()
122-
vlabels = self.get_vlabels()
123-
data_2D = self.get_2D_data()
124-
changes_2D = self.get_changes_2D()
125-
bg_value_2D = self.get_bg_value_2D(data_2D.shape)
126-
self.axes_model.set_data(axes)
127-
self.hlabels_model.set_data(hlabels)
128-
self.vlabels_model.set_data(vlabels)
129-
self.data_model.set_data(data_2D, changes_2D)
130-
self.data_model.set_background(self.bg_gradient, bg_value_2D)
141+
self.update_axes_and_labels()
142+
self.update_data_2D()
131143

132144
def get_data(self):
133145
return self.la_data

larray_editor/arraymodel.py

+25-13
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,8 @@ class LabelsArrayModel(AbstractArrayModel):
127127
font : QFont, optional
128128
Font. Default is `Calibri` with size 11.
129129
"""
130-
def __init__(self, parent=None, data=None, readonly=False, font=None):
130+
def __init__(self, parent=None, data=None, readonly=False, font=None, orientation=Qt.Horizontal):
131+
self.orientation = orientation
131132
AbstractArrayModel.__init__(self, parent, data, readonly, font)
132133
self.font.setBold(True)
133134

@@ -138,28 +139,39 @@ def _set_data(self, data, changes=None):
138139
QMessageBox.critical(self.dialog, "Error", "Expected list or tuple.")
139140
data = [[]]
140141
self._data = data
141-
self.total_rows = len(data[0])
142-
self.total_cols = len(data) if self.total_rows > 0 else 0
142+
if self.orientation == Qt.Horizontal:
143+
self.total_rows = len(data) if self.total_cols > 0 else 0
144+
self.total_cols = len(data[0])
145+
else:
146+
self.total_rows = len(data[0])
147+
self.total_cols = len(data) if self.total_rows > 0 else 0
143148
self._compute_rows_cols_loaded()
144149

145150
def flags(self, index):
146151
"""Set editable flag"""
147152
return Qt.ItemIsEnabled
148153

149154
def get_value(self, index):
150-
i = index.row()
151-
j = index.column()
152-
# we need to inverse column and row because of the way vlabels are generated
153-
return str(self._data[j][i])
155+
if self.orientation == Qt.Horizontal:
156+
i, j = index.row(), index.column()
157+
else:
158+
i, j = index.column(), index.row()
159+
return str(self._data[i][j])
154160

155161
# XXX: I wonder if we shouldn't return a 2D Numpy array of strings?
156162
def get_values(self, left=0, top=0, right=None, bottom=None):
157-
if right is None:
158-
right = self.total_rows
159-
if bottom is None:
160-
bottom = self.total_cols
161-
values = [list(line[left:right]) for line in self._data[top:bottom]]
162-
return values
163+
if self.orientation == Qt.Horizontal:
164+
if right is None:
165+
right = self.total_cols
166+
if bottom is None:
167+
bottom = self.total_rows
168+
return [list(line[left:right]) for line in self._data[top:bottom]]
169+
else:
170+
if right is None:
171+
right = self.total_rows
172+
if bottom is None:
173+
bottom = self.total_cols
174+
return [list(line[top:bottom]) for line in self._data[left:right]]
163175

164176
def data(self, index, role=Qt.DisplayRole):
165177
# print('data', index.column(), index.row(), self.rowCount(), self.columnCount(), '\n', self._data)

larray_editor/arraywidget.py

+33-18
Original file line numberDiff line numberDiff line change
@@ -517,7 +517,7 @@ def __init__(self, parent, data, readonly=False, bg_value=None, bg_gradient=None
517517
self.model_hlabels = LabelsArrayModel(parent=self, readonly=readonly)
518518
self.view_hlabels = LabelsView(parent=self, model=self.model_hlabels, position=(TOP, RIGHT))
519519

520-
self.model_vlabels = LabelsArrayModel(parent=self, readonly=readonly)
520+
self.model_vlabels = LabelsArrayModel(parent=self, readonly=readonly, orientation=Qt.Vertical)
521521
self.view_vlabels = LabelsView(parent=self, model=self.model_vlabels, position=(BOTTOM, LEFT))
522522

523523
self.model_data = DataArrayModel(parent=self, readonly=readonly, minvalue=minvalue, maxvalue=maxvalue)
@@ -617,6 +617,13 @@ def __init__(self, parent, data, readonly=False, bg_value=None, bg_gradient=None
617617
self.bgcolor_checkbox = bgcolor
618618
btn_layout.addWidget(bgcolor)
619619

620+
label = QLabel("Horizontal Dimensions")
621+
btn_layout.addWidget(label)
622+
spin = QSpinBox(self)
623+
spin.valueChanged.connect(self.nb_horizontal_dims_changed)
624+
self.nb_horizontal_dims_spinbox = spin
625+
btn_layout.addWidget(spin)
626+
620627
# Set widget layout
621628
layout = QVBoxLayout()
622629
layout.addLayout(self.filters_layout)
@@ -702,7 +709,8 @@ def dropEvent(self, event):
702709

703710
def set_data(self, data, bg_gradient=None, bg_value=None):
704711
self.data_adapter.set_data(data, bg_gradient=bg_gradient, bg_value=bg_value)
705-
self._update_digits_scientific(self.data_adapter.get_data())
712+
self._update_digits_scientific_dims(self.data_adapter.get_data())
713+
self.nb_horizontal_dims_spinbox.setValue(1)
706714

707715
# update filters
708716
la_data = self.data_adapter.get_data()
@@ -721,11 +729,11 @@ def set_data(self, data, bg_gradient=None, bg_value=None):
721729

722730
# reset default size
723731
self.view_axes.set_default_size()
724-
self.view_ylabels.set_default_size()
725-
self.view_xlabels.set_default_size()
732+
self.view_vlabels.set_default_size()
733+
self.view_hlabels.set_default_size()
726734
self.view_data.set_default_size()
727735

728-
def _update_digits_scientific(self, data):
736+
def _update_digits_scientific_dims(self, data):
729737
"""
730738
data : LArray
731739
"""
@@ -755,6 +763,9 @@ def _update_digits_scientific(self, data):
755763
self.bgcolor_checkbox.setChecked(self.model_data.bgcolor_enabled)
756764
self.bgcolor_checkbox.setEnabled(self.model_data.bgcolor_enabled)
757765

766+
self.nb_horizontal_dims_spinbox.setMinimum(1)
767+
self.nb_horizontal_dims_spinbox.setMaximum(max(1, self.data_adapter.ndim - 1))
768+
758769
def choose_scientific(self, data):
759770
# max_digits = self.get_max_digits()
760771
# default width can fit 8 chars
@@ -887,7 +898,7 @@ def dirty(self):
887898
def accept_changes(self):
888899
"""Accept changes"""
889900
la_data = self.data_adapter.accept_changes()
890-
self._update_digits_scientific(la_data)
901+
self._update_digits_scientific_dims(la_data)
891902

892903
def reject_changes(self):
893904
"""Reject changes"""
@@ -912,10 +923,13 @@ def digits_changed(self, value):
912923
self.digits = value
913924
self.model_data.set_format(self.cell_format)
914925

926+
def nb_horizontal_dims_changed(self, value):
927+
self.data_adapter.update_nb_dims_hlabels(value)
928+
915929
def create_filter_combo(self, axis):
916930
def filter_changed(checked_items):
917931
filtered = self.data_adapter.change_filter(axis, checked_items)
918-
self._update_digits_scientific(filtered)
932+
self._update_digits_scientific_dims(filtered)
919933
combo = FilterComboBox(self)
920934
combo.addItems([str(l) for l in axis.labels])
921935
combo.checkedItemsChanged.connect(filter_changed)
@@ -947,15 +961,15 @@ def _selection_data(self, headers=True, none_selects_all=True):
947961
if not self.data_adapter.ndim:
948962
return raw_data
949963
# FIXME: this is extremely ad-hoc.
950-
# TODO: in the future (pandas-based branch) we should use to_string(data[self._selection_filter()])
964+
# TODO: in the future (multi_index supported) we should use to_string(data[self._selection_filter()])
951965
dim_headers = self.model_axes.get_values()
952-
hlabels = self.model_hlabels.get_values(top=col_min, bottom=col_max)
953-
topheaders = [[dim_header[0] for dim_header in dim_headers] + [label[0] for label in hlabels]]
966+
hlabels = self.model_hlabels.get_values(left=col_min, right=col_max)
967+
topheaders = [dims + labels for dims, labels in zip(dim_headers, hlabels)]
954968
if self.data_adapter.ndim == 1:
955969
return chain(topheaders, [chain([''], row) for row in raw_data])
956970
else:
957971
assert self.data_adapter.ndim > 1
958-
vlabels = self.model_vlabels.get_values(left=row_min, right=row_max)
972+
vlabels = self.model_vlabels.get_values(top=row_min, bottom=row_max)
959973
return chain(topheaders,
960974
[chain([vlabels[j][r] for j in range(len(vlabels))], row)
961975
for r, row in enumerate(raw_data)])
@@ -1039,12 +1053,13 @@ def plot(self):
10391053
row_min, row_max, col_min, col_max = self.view_data._selection_bounds()
10401054
dim_names = self.data_adapter.get_axes_names()
10411055
# labels
1042-
xlabels = [label[0] for label in self.model_hlabels.get_values(top=col_min, bottom=col_max)]
1043-
ylabels = self.model_vlabels.get_values(left=row_min, right=row_max)
1044-
# transpose ylabels
1045-
ylabels = [[str(ylabels[i][j]) for i in range(len(ylabels))] for j in range(len(ylabels[0]))]
1046-
# if there is only one dimension, ylabels is empty
1047-
if not ylabels:
1056+
xlabels = self.model_hlabels.get_values(left=col_min, right=col_max, bottom=self.data_adapter.nb_dims_hlabels)
1057+
xlabels = [[str(xlabels[i][j]) for i in range(len(xlabels))] for j in range(len(xlabels[0]))]
1058+
if self.data_adapter.ndim > 1:
1059+
ylabels = self.model_vlabels.get_values(top=row_min, bottom=row_max)
1060+
# transpose ylabels
1061+
ylabels = [[str(ylabels[i][j]) for i in range(len(ylabels))] for j in range(len(ylabels[0]))]
1062+
else:
10481063
ylabels = [[]]
10491064

10501065
assert data.ndim == 2
@@ -1064,7 +1079,7 @@ def plot(self):
10641079
else:
10651080
# plot each row as a line
10661081
xlabel = dim_names[-1]
1067-
xticklabels = [str(label) for label in xlabels]
1082+
xticklabels = ['\n'.join(row) for row in xlabels]
10681083
xdata = np.arange(col_max - col_min)
10691084
for row in range(len(data)):
10701085
ax.plot(xdata, data[row], label=' '.join(ylabels[row]))

0 commit comments

Comments
 (0)